openauth-plugins 0.0.4

Official OpenAuth plugin modules.
Documentation
use std::sync::Arc;

use http::{header, Method, StatusCode};
use openauth_core::api::{create_auth_endpoint, ApiRequest, ApiResponse, AuthEndpointOptions};
use openauth_core::auth::session::{GetSessionInput, SessionAuth};
use openauth_core::context::request_state::{set_current_session, set_current_session_user};
use openauth_core::context::AuthContext;
use openauth_core::error::OpenAuthError;
use serde::Deserialize;
use serde::Serialize;
use serde_json::{json, Value};

use super::claims::JwtClaims;
use super::keys::{public_jwk_value, Jwks};
use super::{adapter, keys, sign, verify, JwtOptions, JwtSessionContext};

pub(crate) fn jwks_endpoint(options: Arc<JwtOptions>) -> openauth_core::api::AsyncAuthEndpoint {
    let path = options.jwks.jwks_path.clone();
    create_auth_endpoint(
        path,
        Method::GET,
        AuthEndpointOptions::new().operation_id("getJSONWebKeySet"),
        move |context, _request| {
            let options = Arc::clone(&options);
            Box::pin(async move {
                if options.jwks.remote_url.is_some() {
                    return json_response(StatusCode::NOT_FOUND, &json!({"code": "NOT_FOUND"}));
                }
                let jwks = get_or_create_public_jwks(context, &options).await?;
                json_response(StatusCode::OK, &jwks)
            })
        },
    )
}

pub(crate) fn token_endpoint(options: Arc<JwtOptions>) -> openauth_core::api::AsyncAuthEndpoint {
    create_auth_endpoint(
        "/token",
        Method::GET,
        AuthEndpointOptions::new().operation_id("getJSONWebToken"),
        move |context, request| {
            let options = Arc::clone(&options);
            Box::pin(async move {
                let Some((session, user)) = session_user(context, &request).await? else {
                    return json_response(
                        StatusCode::UNAUTHORIZED,
                        &json!({"code": "UNAUTHORIZED"}),
                    );
                };
                let token = sign_session_token(context, &options, session, user).await?;
                json_response(StatusCode::OK, &json!({ "token": token }))
            })
        },
    )
}

pub(crate) fn sign_jwt_endpoint(options: Arc<JwtOptions>) -> openauth_core::api::AsyncAuthEndpoint {
    create_auth_endpoint(
        "/sign-jwt",
        Method::POST,
        AuthEndpointOptions::new()
            .operation_id("signJWT")
            .server_only(),
        move |context, request| {
            let options = Arc::clone(&options);
            Box::pin(async move {
                let body: SignJwtBody = parse_json(&request)?;
                let token = sign::sign_jwt_with_options(context, body.payload, &options).await?;
                json_response(StatusCode::OK, &json!({ "token": token }))
            })
        },
    )
}

pub(crate) fn verify_jwt_endpoint(
    options: Arc<JwtOptions>,
) -> openauth_core::api::AsyncAuthEndpoint {
    create_auth_endpoint(
        "/verify-jwt",
        Method::POST,
        AuthEndpointOptions::new()
            .operation_id("verifyJWT")
            .server_only(),
        move |context, request| {
            let options = Arc::clone(&options);
            Box::pin(async move {
                let body: VerifyJwtBody = parse_json(&request)?;
                let payload = verify::verify_jwt_with_options(
                    context,
                    &body.token,
                    &options,
                    body.issuer.as_deref(),
                )
                .await?;
                json_response(StatusCode::OK, &json!({ "payload": payload }))
            })
        },
    )
}

async fn get_or_create_public_jwks(
    context: &AuthContext,
    options: &JwtOptions,
) -> Result<Jwks, OpenAuthError> {
    let mut keys = adapter::get_all_keys(context, options).await?;
    if keys.is_empty() {
        let key = super::crypto::encrypt_private_key(
            context,
            keys::generate_jwk(options)?,
            options.jwks.disable_private_key_encryption,
        )?;
        adapter::create_jwk(context, options, key).await?;
        keys = adapter::get_all_keys(context, options).await?;
    }
    let now = time::OffsetDateTime::now_utc();
    let grace = time::Duration::seconds(options.jwks.grace_period);
    let public = keys
        .into_iter()
        .filter(|key| match key.expires_at {
            Some(expires_at) => expires_at + grace > now,
            None => true,
        })
        .map(|key| public_jwk_value(&key, options))
        .collect::<Result<Vec<_>, _>>()?;
    Ok(Jwks { keys: public })
}

pub(crate) async fn sign_session_token(
    context: &AuthContext,
    options: &JwtOptions,
    session: openauth_core::db::Session,
    user: openauth_core::db::User,
) -> Result<String, OpenAuthError> {
    let session_context = JwtSessionContext {
        session,
        user: user.clone(),
    };
    let mut claims = if let Some(define_payload) = &options.jwt.define_payload {
        define_payload(&session_context).await?
    } else {
        let Value::Object(map) =
            serde_json::to_value(&user).map_err(|error| OpenAuthError::Api(error.to_string()))?
        else {
            return Err(OpenAuthError::Api(
                "user must serialize to object".to_owned(),
            ));
        };
        map
    };
    let subject = if let Some(get_subject) = &options.jwt.get_subject {
        get_subject(&session_context).await?
    } else {
        user.id
    };
    claims.insert("sub".to_owned(), Value::String(subject));
    sign::sign_jwt_with_options(context, claims, options).await
}

async fn session_user(
    context: &AuthContext,
    request: &ApiRequest,
) -> Result<Option<(openauth_core::db::Session, openauth_core::db::User)>, OpenAuthError> {
    let Some(adapter) = context.adapter() else {
        return Ok(None);
    };
    let Some(result) = SessionAuth::new(adapter.as_ref(), context)
        .get_session(GetSessionInput::new(
            cookie_header(request).unwrap_or_default(),
        ))
        .await?
    else {
        return Ok(None);
    };
    let Some(session) = result.session else {
        return Ok(None);
    };
    let Some(user) = result.user else {
        return Ok(None);
    };
    if openauth_core::context::request_state::has_request_state() {
        set_current_session(session.clone(), user.clone())?;
        set_current_session_user(
            serde_json::to_value(&user).map_err(|error| OpenAuthError::Api(error.to_string()))?,
        )?;
    }
    Ok(Some((session, user)))
}

fn cookie_header(request: &ApiRequest) -> Option<String> {
    request
        .headers()
        .get(header::COOKIE)
        .and_then(|value| value.to_str().ok())
        .map(str::to_owned)
}

fn parse_json<T>(request: &ApiRequest) -> Result<T, OpenAuthError>
where
    T: for<'de> Deserialize<'de>,
{
    serde_json::from_slice(request.body()).map_err(|error| OpenAuthError::Api(error.to_string()))
}

fn json_response<T>(status: StatusCode, body: &T) -> Result<ApiResponse, OpenAuthError>
where
    T: Serialize,
{
    let body = serde_json::to_vec(body).map_err(|error| OpenAuthError::Api(error.to_string()))?;
    http::Response::builder()
        .status(status)
        .header(header::CONTENT_TYPE, "application/json")
        .body(body)
        .map_err(|error| OpenAuthError::Api(error.to_string()))
}

#[derive(Debug, Deserialize)]
struct SignJwtBody {
    payload: JwtClaims,
}

#[derive(Debug, Deserialize)]
struct VerifyJwtBody {
    token: String,
    issuer: Option<String>,
}