openauth-plugins 0.0.4

Official OpenAuth plugin modules.
Documentation
use http::{header, HeaderValue};
use openauth_core::api::{
    parse_request_body, redirect_with_error_response, ApiRequest, ApiResponse,
};
use openauth_core::context::{AuthContext, SecretMaterial};
use openauth_core::crypto::{symmetric_decrypt, symmetric_encrypt};
use openauth_core::error::OpenAuthError;
use serde_json::{Map, Value};
use url::Url;

use super::options::OAuthProxyOptions;

pub(crate) fn strip_trailing_slash(value: &str) -> &str {
    value.trim_end_matches('/')
}

pub(crate) fn current_url(
    context: &AuthContext,
    request: &ApiRequest,
    options: &OAuthProxyOptions,
) -> Option<Url> {
    if let Some(value) = &options.current_url {
        return Url::parse(value).ok();
    }
    if let Some(value) = request_origin(request) {
        return Url::parse(&value).ok();
    }
    if let Some(value) = vendor_current_url() {
        return Url::parse(&value).ok();
    }
    (!context.base_url.is_empty())
        .then(|| Url::parse(&context.base_url).ok())
        .flatten()
}

pub(crate) fn production_url(context: &AuthContext, options: &OAuthProxyOptions) -> Option<Url> {
    let value = options
        .production_url
        .clone()
        .or_else(|| std::env::var("OPENAUTH_URL").ok())
        .or_else(|| (!context.base_url.is_empty()).then(|| context.base_url.clone()))?;
    Url::parse(&value).ok()
}

pub(crate) fn should_skip_proxy(
    context: &AuthContext,
    request: &ApiRequest,
    options: &OAuthProxyOptions,
) -> bool {
    if request.headers().contains_key("x-skip-oauth-proxy") {
        return true;
    }
    let Some(current) = current_url(context, request, options) else {
        return false;
    };
    let Some(production) = production_url(context, options) else {
        return false;
    };
    current.origin() == production.origin()
}

pub(crate) fn proxy_callback_url(
    context: &AuthContext,
    request: &ApiRequest,
    options: &OAuthProxyOptions,
    original_callback_url: &str,
) -> Option<String> {
    let current = current_url(context, request, options)?;
    Some(format!(
        "{}{}/oauth-proxy-callback?callbackURL={}",
        strip_trailing_slash(current.origin().ascii_serialization().as_str()),
        context.base_path,
        percent_encode(original_callback_url)
    ))
}

pub(crate) fn production_base_url(context: &AuthContext, options: &OAuthProxyOptions) -> String {
    production_url(context, options)
        .map(|url| {
            format!(
                "{}{}",
                strip_trailing_slash(url.origin().ascii_serialization().as_str()),
                context.base_path
            )
        })
        .unwrap_or_else(|| context.base_url.clone())
}

pub(crate) fn encrypt(
    context: &AuthContext,
    options: &OAuthProxyOptions,
    data: &str,
) -> Result<String, OpenAuthError> {
    if let Some(secret) = &options.secret {
        return symmetric_encrypt(secret.as_str(), data);
    }
    match &context.secret_config {
        SecretMaterial::Single(secret) => symmetric_encrypt(secret.as_str(), data),
        SecretMaterial::Rotating(config) => symmetric_encrypt(config, data),
    }
}

pub(crate) fn decrypt(
    context: &AuthContext,
    options: &OAuthProxyOptions,
    data: &str,
) -> Result<String, OpenAuthError> {
    if let Some(secret) = &options.secret {
        return symmetric_decrypt(secret.as_str(), data);
    }
    match &context.secret_config {
        SecretMaterial::Single(secret) => symmetric_decrypt(secret.as_str(), data),
        SecretMaterial::Rotating(config) => symmetric_decrypt(config, data),
    }
}

pub(crate) fn query_param(request: &ApiRequest, name: &str) -> Option<String> {
    request.uri().query().and_then(|query| {
        query.split('&').find_map(|pair| {
            let (key, value) = pair.split_once('=').unwrap_or((pair, ""));
            (key == name).then(|| percent_decode(value))
        })
    })
}

pub(crate) fn query_or_body_param(request: &ApiRequest, name: &str) -> Option<String> {
    query_param(request, name).or_else(|| {
        if request.body().is_empty() {
            return None;
        }
        parse_request_body::<Value>(request)
            .ok()
            .and_then(|body| body_string(&body, name))
    })
}

pub(crate) fn rewrite_callback_body(
    request: &mut ApiRequest,
    callback_url: &str,
) -> Result<(), OpenAuthError> {
    let content_type = request
        .headers()
        .get(header::CONTENT_TYPE)
        .and_then(|value| value.to_str().ok())
        .unwrap_or("application/json");
    let mut body = if request.body().is_empty() {
        Value::Object(Map::new())
    } else {
        parse_request_body::<Value>(request)?
    };
    let object = body
        .as_object_mut()
        .ok_or_else(|| OpenAuthError::Api("OAuth sign-in body must be an object".to_owned()))?;
    object.insert(
        "callbackURL".to_owned(),
        Value::String(callback_url.to_owned()),
    );
    if content_type.starts_with("application/x-www-form-urlencoded") {
        let mut encoded = url::form_urlencoded::Serializer::new(String::new());
        for (key, value) in object {
            encoded.append_pair(key, &body_string_value(value));
        }
        *request.body_mut() = encoded.finish().into_bytes();
    } else {
        *request.body_mut() =
            serde_json::to_vec(&body).map_err(|error| OpenAuthError::Api(error.to_string()))?;
        request.headers_mut().insert(
            header::CONTENT_TYPE,
            HeaderValue::from_static("application/json"),
        );
    }
    Ok(())
}

pub(crate) fn redirect_error(error_url: &str, error: &str) -> Result<ApiResponse, OpenAuthError> {
    redirect_with_error_response(error_url, error)
}

pub(crate) fn is_trusted_callback_url(
    context: &AuthContext,
    request: &ApiRequest,
    callback_url: &str,
) -> Result<bool, OpenAuthError> {
    if callback_url.starts_with('/') && !callback_url.starts_with("//") {
        return Ok(true);
    }
    let Ok(url) = Url::parse(callback_url) else {
        return Ok(false);
    };
    context.is_trusted_origin_for_request(url.as_str(), None, Some(request))
}

pub(crate) fn percent_encode(value: &str) -> String {
    url::form_urlencoded::byte_serialize(value.as_bytes()).collect()
}

fn request_origin(request: &ApiRequest) -> Option<String> {
    let uri = request.uri();
    let scheme = uri.scheme_str()?;
    let authority = uri.authority()?;
    Some(format!("{scheme}://{authority}"))
}

fn vendor_current_url() -> Option<String> {
    [
        "VERCEL_URL",
        "NETLIFY_URL",
        "RENDER_URL",
        "AWS_LAMBDA_FUNCTION_URL",
        "AWS_LAMBDA_FUNCTION_NAME",
        "AWS_FUNCTION_URL",
        "GOOGLE_CLOUD_FUNCTION_URL",
        "GOOGLE_CLOUD_FUNCTION_NAME",
        "AZURE_FUNCTION_URL",
        "AZURE_FUNCTION_NAME",
        "FUNCTIONS_CUSTOMHANDLER_PORT",
    ]
    .into_iter()
    .find_map(|key| std::env::var(key).ok())
    .map(|value| {
        if value.starts_with("http://") || value.starts_with("https://") {
            value
        } else {
            format!("https://{value}")
        }
    })
}

fn body_string(body: &Value, key: &str) -> Option<String> {
    body.get(key).and_then(|value| match value {
        Value::String(value) => Some(value.clone()),
        Value::Number(value) => Some(value.to_string()),
        Value::Bool(value) => Some(value.to_string()),
        _ => None,
    })
}

fn body_string_value(value: &Value) -> String {
    match value {
        Value::String(value) => value.clone(),
        Value::Number(value) => value.to_string(),
        Value::Bool(value) => value.to_string(),
        Value::Null => String::new(),
        other => other.to_string(),
    }
}

fn percent_decode(value: &str) -> String {
    url::form_urlencoded::parse(format!("x={value}").as_bytes())
        .next()
        .map(|(_, value)| value.into_owned())
        .unwrap_or_default()
}