openauth-plugins 0.0.4

Official OpenAuth plugin modules.
Documentation
use http::{header, Method, StatusCode};
use openauth_core::api::{
    create_auth_endpoint, json_openapi_response, parse_request_body, AsyncAuthEndpoint,
    AuthEndpointOptions, BodyField, BodySchema, JsonSchemaType, OpenApiOperation,
};
use openauth_core::db::{Create, DbValue};
use serde_json::{json, Value};
use time::OffsetDateTime;

use super::shared::{
    adapter, current_session, json_response, oauth_error, random_token, with_cors,
};
use super::{
    McpClientIdGenerator, McpClientSecretGenerator, ResolvedMcpOptions, TokenEndpointAuthMethod,
};

pub fn register_endpoint(
    _options: ResolvedMcpOptions,
    client_id_generator: Option<McpClientIdGenerator>,
    client_secret_generator: Option<McpClientSecretGenerator>,
) -> AsyncAuthEndpoint {
    create_auth_endpoint(
        "/mcp/register",
        Method::POST,
        AuthEndpointOptions::new()
            .operation_id("registerMcpClient")
            .allowed_media_types(["application/json"])
            .body_schema(BodySchema::object([
                BodyField::optional("redirect_uris", JsonSchemaType::Array),
                BodyField::optional("client_name", JsonSchemaType::String),
                BodyField::optional("grant_types", JsonSchemaType::Array),
                BodyField::optional("response_types", JsonSchemaType::Array),
                BodyField::optional("token_endpoint_auth_method", JsonSchemaType::String),
                BodyField::optional("scope", JsonSchemaType::String),
            ]))
            .openapi(
                OpenApiOperation::new("registerMcpClient")
                    .summary("Register MCP client")
                    .description("Register an OAuth client for MCP authorization")
                    .tag("MCP")
                    .response(
                        "201",
                        json_openapi_response(
                            "Created MCP OAuth client",
                            json!({
                                "type": "object",
                                "properties": {
                                    "client_id": { "type": "string" },
                                    "client_secret": { "type": "string" },
                                    "redirect_uris": {
                                        "type": "array",
                                        "items": { "type": "string" }
                                    },
                                    "token_endpoint_auth_method": { "type": "string" },
                                },
                                "required": ["client_id", "redirect_uris", "token_endpoint_auth_method"]
                            }),
                        ),
                    ),
            ),
        move |context, request| {
            let client_id_generator = client_id_generator.clone();
            let client_secret_generator = client_secret_generator.clone();
            Box::pin(async move {
                let adapter = adapter(context)?;
                let body: Value = parse_request_body(&request)?;
                let grant_types = string_array(&body, "grant_types")
                    .unwrap_or_else(|| vec!["authorization_code".to_owned()]);
                let response_types = string_array(&body, "response_types")
                    .unwrap_or_else(|| vec!["code".to_owned()]);
                let redirect_uris = string_array(&body, "redirect_uris").unwrap_or_default();
                if let Some(invalid) = redirect_uris
                    .iter()
                    .find(|redirect_uri| !is_valid_redirect_uri(redirect_uri))
                {
                    return oauth_error(
                        StatusCode::BAD_REQUEST,
                        "invalid_redirect_uri",
                        &format!("Invalid redirect URI: {invalid}"),
                    );
                }
                if let Some(invalid) = grant_types.iter().find(|grant| !is_allowed_grant(grant)) {
                    return oauth_error(
                        StatusCode::BAD_REQUEST,
                        "invalid_client_metadata",
                        &format!("Unsupported grant type: {invalid}"),
                    );
                }
                if let Some(invalid) = response_types
                    .iter()
                    .find(|response| !is_allowed_response_type(response))
                {
                    return oauth_error(
                        StatusCode::BAD_REQUEST,
                        "invalid_client_metadata",
                        &format!("Unsupported response type: {invalid}"),
                    );
                }
                if let Some(method) = string_value(&body, "token_endpoint_auth_method") {
                    if token_endpoint_auth_method(&method).is_none() {
                        return oauth_error(
                            StatusCode::BAD_REQUEST,
                            "invalid_client_metadata",
                            &format!("Unsupported token endpoint auth method: {method}"),
                        );
                    }
                }

                if requires_redirect_uri(&grant_types) && redirect_uris.is_empty() {
                    return oauth_error(
                        StatusCode::BAD_REQUEST,
                        "invalid_redirect_uri",
                        "Redirect URIs are required for authorization_code and implicit grant types",
                    );
                }
                if grant_types
                    .iter()
                    .any(|grant| grant == "authorization_code")
                    && !response_types.iter().any(|response| response == "code")
                {
                    return oauth_error(
                        StatusCode::BAD_REQUEST,
                        "invalid_client_metadata",
                        "When 'authorization_code' grant type is used, 'code' response type must be included",
                    );
                }
                if grant_types.iter().any(|grant| grant == "implicit")
                    && !response_types.iter().any(|response| response == "token")
                {
                    return oauth_error(
                        StatusCode::BAD_REQUEST,
                        "invalid_client_metadata",
                        "When 'implicit' grant type is used, 'token' response type must be included",
                    );
                }

                let session = current_session(adapter.as_ref(), context, &request).await?;
                let auth_method = auth_method(&body);
                let client_type = if auth_method == TokenEndpointAuthMethod::None {
                    "public"
                } else {
                    "web"
                };
                let client_id = client_id_generator
                    .as_ref()
                    .map(|generator| generator())
                    .unwrap_or_else(random_token);
                let client_secret = (client_type != "public").then(|| {
                    client_secret_generator
                        .as_ref()
                        .map(|generator| generator())
                        .unwrap_or_else(random_token)
                });
                let now = OffsetDateTime::now_utc();

                adapter
                    .create(
                        Create::new("oauthApplication")
                            .data(
                                "name",
                                DbValue::String(
                                    string_value(&body, "client_name")
                                        .unwrap_or_else(|| client_id.clone()),
                                ),
                            )
                            .data("icon", nullable_string(&body, "logo_uri"))
                            .data("metadata", metadata_string(&body))
                            .data("clientId", DbValue::String(client_id.clone()))
                            .data(
                                "clientSecret",
                                client_secret
                                    .clone()
                                    .map(DbValue::String)
                                    .unwrap_or(DbValue::Null),
                            )
                            .data("redirectUrls", DbValue::String(redirect_uris.join(",")))
                            .data("type", DbValue::String(client_type.to_owned()))
                            .data(
                                "authenticationScheme",
                                DbValue::String(auth_method.as_str().to_owned()),
                            )
                            .data("disabled", DbValue::Boolean(false))
                            .data(
                                "userId",
                                session
                                    .map(|session| DbValue::String(session.user_id))
                                    .unwrap_or(DbValue::Null),
                            )
                            .data("createdAt", DbValue::Timestamp(now))
                            .data("updatedAt", DbValue::Timestamp(now)),
                    )
                    .await?;

                let mut response = json!({
                    "client_id": client_id,
                    "client_id_issued_at": now.unix_timestamp(),
                    "redirect_uris": redirect_uris,
                    "token_endpoint_auth_method": auth_method.as_str(),
                    "grant_types": grant_types,
                    "response_types": response_types,
                    "client_name": string_value(&body, "client_name"),
                    "client_uri": string_value(&body, "client_uri"),
                    "logo_uri": string_value(&body, "logo_uri"),
                    "scope": string_value(&body, "scope"),
                    "contacts": string_array(&body, "contacts"),
                    "tos_uri": string_value(&body, "tos_uri"),
                    "policy_uri": string_value(&body, "policy_uri"),
                    "jwks_uri": string_value(&body, "jwks_uri"),
                    "jwks": body.get("jwks").cloned(),
                    "software_id": string_value(&body, "software_id"),
                    "software_version": string_value(&body, "software_version"),
                    "software_statement": string_value(&body, "software_statement"),
                    "metadata": body.get("metadata").cloned(),
                });
                if let Some(secret) = client_secret {
                    response["client_secret"] = Value::String(secret);
                    response["client_secret_expires_at"] = json!(0);
                }
                let mut response = json_response(StatusCode::CREATED, &response)?;
                response.headers_mut().insert(
                    header::CACHE_CONTROL,
                    http::HeaderValue::from_static("no-store"),
                );
                response
                    .headers_mut()
                    .insert(header::PRAGMA, http::HeaderValue::from_static("no-cache"));
                with_cors(response)
            })
        },
    )
}

fn requires_redirect_uri(grant_types: &[String]) -> bool {
    grant_types.is_empty()
        || grant_types
            .iter()
            .any(|grant| grant == "authorization_code" || grant == "implicit")
}

fn auth_method(body: &Value) -> TokenEndpointAuthMethod {
    string_value(body, "token_endpoint_auth_method")
        .and_then(|method| token_endpoint_auth_method(&method))
        .unwrap_or(TokenEndpointAuthMethod::ClientSecretBasic)
}

fn token_endpoint_auth_method(method: &str) -> Option<TokenEndpointAuthMethod> {
    match method {
        "none" => Some(TokenEndpointAuthMethod::None),
        "client_secret_basic" => Some(TokenEndpointAuthMethod::ClientSecretBasic),
        "client_secret_post" => Some(TokenEndpointAuthMethod::ClientSecretPost),
        _ => None,
    }
}

fn is_allowed_grant(grant: &str) -> bool {
    matches!(
        grant,
        "authorization_code"
            | "implicit"
            | "password"
            | "client_credentials"
            | "refresh_token"
            | "urn:ietf:params:oauth:grant-type:jwt-bearer"
            | "urn:ietf:params:oauth:grant-type:saml2-bearer"
    )
}

fn is_allowed_response_type(response_type: &str) -> bool {
    matches!(response_type, "code" | "token")
}

fn is_valid_redirect_uri(redirect_uri: &str) -> bool {
    if redirect_uri.trim().is_empty() {
        return false;
    }
    let Ok(url) = url::Url::parse(redirect_uri) else {
        return false;
    };
    if url.has_authority() && url.fragment().is_none() {
        match url.scheme() {
            "https" => true,
            "http" => url
                .host_str()
                .is_some_and(|host| matches!(host, "localhost" | "127.0.0.1" | "::1")),
            _ => false,
        }
    } else {
        false
    }
}

fn string_value(body: &Value, field: &str) -> Option<String> {
    body.get(field).and_then(Value::as_str).map(str::to_owned)
}

fn nullable_string(body: &Value, field: &str) -> DbValue {
    string_value(body, field)
        .map(DbValue::String)
        .unwrap_or(DbValue::Null)
}

fn metadata_string(body: &Value) -> DbValue {
    body.get("metadata")
        .map(|metadata| DbValue::String(metadata.to_string()))
        .unwrap_or(DbValue::Null)
}

fn string_array(body: &Value, field: &str) -> Option<Vec<String>> {
    body.get(field)?.as_array().map(|values| {
        values
            .iter()
            .filter_map(Value::as_str)
            .map(str::to_owned)
            .collect()
    })
}