openauth-plugins 0.0.4

Official OpenAuth plugin modules.
Documentation
use http::{header, Method, StatusCode};
use openauth_core::api::{create_auth_endpoint, AsyncAuthEndpoint, AuthEndpointOptions};
use openauth_core::db::{DbValue, FindOne, Where};
use serde_json::{json, Value};
use time::OffsetDateTime;

use super::claims::user_claims;
use super::shared::{
    adapter, find_user, json_response, oauth_error, optional_timestamp, required_string,
    OAUTH_TOKEN_MODEL,
};
use super::McpAdditionalIdTokenClaims;

pub fn userinfo_endpoint(
    additional_claims: Option<McpAdditionalIdTokenClaims>,
) -> AsyncAuthEndpoint {
    create_auth_endpoint(
        "/mcp/userinfo",
        Method::GET,
        AuthEndpointOptions::new().operation_id("mcpUserInfo"),
        move |context, request| {
            let additional_claims = additional_claims.clone();
            Box::pin(async move {
                let Some(token) = request
                    .headers()
                    .get(header::AUTHORIZATION)
                    .and_then(|value| value.to_str().ok())
                    .and_then(|value| value.strip_prefix("Bearer "))
                    .map(str::to_owned)
                else {
                    return oauth_error(
                        StatusCode::UNAUTHORIZED,
                        "invalid_request",
                        "authorization header not found",
                    );
                };
                let adapter = adapter(context)?;
                let Some(record) = adapter
                    .find_one(
                        FindOne::new(OAUTH_TOKEN_MODEL)
                            .where_clause(Where::new("accessToken", DbValue::String(token))),
                    )
                    .await?
                else {
                    return oauth_error(
                        StatusCode::UNAUTHORIZED,
                        "invalid_token",
                        "invalid access token",
                    );
                };
                if optional_timestamp(&record, "accessTokenExpiresAt")?
                    .is_some_and(|expires_at| expires_at <= OffsetDateTime::now_utc())
                {
                    return oauth_error(
                        StatusCode::UNAUTHORIZED,
                        "invalid_token",
                        "The Access Token expired",
                    );
                }
                let user_id = required_string(&record, "userId")?;
                let Some(user) = find_user(adapter.as_ref(), &user_id).await? else {
                    return oauth_error(
                        StatusCode::UNAUTHORIZED,
                        "invalid_token",
                        "user not found",
                    );
                };
                let scopes = required_string(&record, "scopes")?
                    .split_whitespace()
                    .map(str::to_owned)
                    .collect::<Vec<_>>();
                let mut claims = user_claims(&user, &scopes);
                if let Some(callback) = additional_claims {
                    for (key, value) in callback(&user, &scopes)? {
                        claims.insert(key, value);
                    }
                }
                json_response(StatusCode::OK, &Value::Object(claims))
            })
        },
    )
}

pub fn jwks_endpoint() -> AsyncAuthEndpoint {
    create_auth_endpoint(
        "/mcp/jwks",
        Method::GET,
        AuthEndpointOptions::new().operation_id("mcpJwks"),
        move |_context, _request| {
            Box::pin(async move {
                let mut response = json_response(StatusCode::OK, &json!({ "keys": [] }))?;
                response.headers_mut().insert(
                    header::CACHE_CONTROL,
                    http::HeaderValue::from_static("no-store"),
                );
                Ok(response)
            })
        },
    )
}