// rustauth-plugins 0.2.0
//
// Official RustAuth plugin modules.
// Documentation
use std::sync::Arc;

use http::{header, Method, StatusCode};
use rustauth_core::api::{create_auth_endpoint, ApiRequest, ApiResponse, AuthEndpointOptions};
use rustauth_core::auth::session::{GetSessionInput, SessionAuth};
use rustauth_core::context::request_state::{set_current_session, set_current_session_user};
use rustauth_core::context::AuthContext;
use rustauth_core::error::RustAuthError;
use serde::Deserialize;
use serde::Serialize;
use serde_json::{json, Value};

use super::claims::JwtClaims;
use super::keys::{public_jwk_value, Jwks};
use super::{
    adapter, keys, sign, verify, JwkAlgorithm, JwtJwksOptions, JwtOptions, JwtSessionContext,
    JwtSigningOptions, TimeInput,
};

/// Builds the `GET` endpoint that publishes the public JSON Web Key Set.
///
/// The route is taken from `options.jwks.jwks_path`. When
/// `options.jwks.remote_url` is configured the handler answers `404` with
/// `{"code": "NOT_FOUND"}`; otherwise it answers `200` with the key set
/// returned by `get_or_create_public_jwks`.
pub(crate) fn jwks_endpoint(options: Arc<JwtOptions>) -> rustauth_core::api::AsyncAuthEndpoint {
    let route = options.jwks.jwks_path.clone();
    let endpoint_options = AuthEndpointOptions::new().operation_id("getJSONWebKeySet");
    create_auth_endpoint(
        route,
        Method::GET,
        endpoint_options,
        move |context, _request| {
            // Each invocation gets its own handle to the shared options.
            let shared = Arc::clone(&options);
            async move {
                if shared.jwks.remote_url.is_some() {
                    json_response(StatusCode::NOT_FOUND, &json!({"code": "NOT_FOUND"}))
                } else {
                    let key_set = get_or_create_public_jwks(&context, &shared).await?;
                    json_response(StatusCode::OK, &key_set)
                }
            }
        },
    )
}

/// Builds the `GET /token` endpoint that issues a JWT for the current session.
///
/// The handler resolves the session and user through `session_user`. If none
/// is found it answers `401` with `{"code": "UNAUTHORIZED"}`; otherwise it
/// answers `200` with `{"token": <jwt>}` produced by `sign_session_token`.
pub(crate) fn token_endpoint(options: Arc<JwtOptions>) -> rustauth_core::api::AsyncAuthEndpoint {
    let endpoint_options = AuthEndpointOptions::new().operation_id("getJSONWebToken");
    create_auth_endpoint(
        "/token",
        Method::GET,
        endpoint_options,
        move |context, request| {
            let shared = Arc::clone(&options);
            async move {
                match session_user(&context, &request).await? {
                    Some((session, user)) => {
                        let token = sign_session_token(&context, &shared, session, user).await?;
                        json_response(StatusCode::OK, &json!({ "token": token }))
                    }
                    None => json_response(
                        StatusCode::UNAUTHORIZED,
                        &json!({"code": "UNAUTHORIZED"}),
                    ),
                }
            }
        },
    )
}

pub(crate) fn sign_jwt_endpoint(options: Arc<JwtOptions>) -> rustauth_core::api::AsyncAuthEndpoint {
    create_auth_endpoint(
        "/sign-jwt",
        Method::POST,
        AuthEndpointOptions::new()
            .operation_id("signJWT")
            .server_only(),
        move |context, request| {
            let options = Arc::clone(&options);
            async move {
                let body: SignJwtBody = parse_json(&request)?;
                let mut effective_options = options.as_ref().clone();
                if let Some(overrides) = body.override_options {
                    overrides.apply(&mut effective_options);
                    effective_options.validate()?;
                }
                let token =
                    sign::sign_jwt_with_options(&context, body.payload, &effective_options).await?;
                json_response(StatusCode::OK, &json!({ "token": token }))
            }
        },
    )
}

pub(crate) fn verify_jwt_endpoint(
    options: Arc<JwtOptions>,
) -> rustauth_core::api::AsyncAuthEndpoint {
    create_auth_endpoint(
        "/verify-jwt",
        Method::POST,
        AuthEndpointOptions::new()
            .operation_id("verifyJWT")
            .server_only(),
        move |context, request| {
            let options = Arc::clone(&options);
            async move {
                let body: VerifyJwtBody = parse_json(&request)?;
                let payload = verify::verify_jwt_with_options(
                    &context,
                    &body.token,
                    &options,
                    body.issuer.as_deref(),
                )
                .await?;
                json_response(StatusCode::OK, &json!({ "payload": payload }))
            }
        },
    )
}

/// Loads the stored keys and returns their public halves as a `Jwks`.
///
/// If the adapter reports no keys, one is generated, passed through
/// `encrypt_private_key`, stored, and the key list is read again. Keys whose
/// `expires_at` plus the configured grace period (in seconds) is not after
/// the current UTC time are left out; keys without `expires_at` are always
/// included.
///
/// # Errors
///
/// Propagates any error from the adapter, key generation, encryption, or the
/// public-JWK conversion (the first failing key aborts the whole call).
async fn get_or_create_public_jwks(
    context: &AuthContext,
    options: &JwtOptions,
) -> Result<Jwks, RustAuthError> {
    let mut stored = adapter::get_all_keys(context, options).await?;
    if stored.is_empty() {
        let generated = keys::generate_jwk(options)?;
        let prepared = super::crypto::encrypt_private_key(
            context,
            generated,
            options.jwks.disable_private_key_encryption,
        )?;
        adapter::create_jwk(context, options, prepared).await?;
        // Re-read so the response reflects what the store actually holds.
        stored = adapter::get_all_keys(context, options).await?;
    }

    let now = time::OffsetDateTime::now_utc();
    let grace = time::Duration::seconds(options.jwks.grace_period);

    let mut public = Vec::new();
    for key in stored {
        let still_served = key
            .expires_at
            .map_or(true, |expires_at| expires_at + grace > now);
        if still_served {
            public.push(public_jwk_value(&key, options)?);
        }
    }
    Ok(Jwks { keys: public })
}

pub(crate) async fn sign_session_token(
    context: &AuthContext,
    options: &JwtOptions,
    session: rustauth_core::db::Session,
    user: rustauth_core::db::User,
) -> Result<String, RustAuthError> {
    let session_context = JwtSessionContext {
        session,
        user: user.clone(),
    };
    let mut claims = if let Some(define_payload) = &options.jwt.define_payload {
        define_payload(&session_context).await?
    } else {
        let Value::Object(map) =
            serde_json::to_value(&user).map_err(|error| RustAuthError::Api(error.to_string()))?
        else {
            return Err(RustAuthError::Api(
                "user must serialize to object".to_owned(),
            ));
        };
        map
    };
    let subject = if let Some(get_subject) = &options.jwt.get_subject {
        get_subject(&session_context).await?
    } else {
        user.id
    };
    claims.insert("sub".to_owned(), Value::String(subject));
    sign::sign_jwt_with_options(context, claims, options).await
}

async fn session_user(
    context: &AuthContext,
    request: &ApiRequest,
) -> Result<Option<(rustauth_core::db::Session, rustauth_core::db::User)>, RustAuthError> {
    let Some(_adapter) = context.adapter() else {
        return Ok(None);
    };
    let Some(result) = SessionAuth::new(context)?
        .get_session(GetSessionInput::new(
            cookie_header(request).unwrap_or_default(),
        ))
        .await?
    else {
        return Ok(None);
    };
    let Some(session) = result.session else {
        return Ok(None);
    };
    let Some(user) = result.user else {
        return Ok(None);
    };
    if rustauth_core::context::request_state::has_request_state() {
        set_current_session(session.clone(), user.clone())?;
        set_current_session_user(
            serde_json::to_value(&user).map_err(|error| RustAuthError::Api(error.to_string()))?,
        )?;
    }
    Ok(Some((session, user)))
}

/// Returns the request's `Cookie` header as an owned string.
///
/// Yields `None` when the header is absent or when `HeaderValue::to_str`
/// rejects its contents.
fn cookie_header(request: &ApiRequest) -> Option<String> {
    let raw = request.headers().get(header::COOKIE)?;
    let text = raw.to_str().ok()?;
    Some(text.to_owned())
}

fn parse_json<T>(request: &ApiRequest) -> Result<T, RustAuthError>
where
    T: for<'de> Deserialize<'de>,
{
    serde_json::from_slice(request.body()).map_err(|error| RustAuthError::Api(error.to_string()))
}

fn json_response<T>(status: StatusCode, body: &T) -> Result<ApiResponse, RustAuthError>
where
    T: Serialize,
{
    let body = serde_json::to_vec(body).map_err(|error| RustAuthError::Api(error.to_string()))?;
    http::Response::builder()
        .status(status)
        .header(header::CONTENT_TYPE, "application/json")
        .body(body)
        .map_err(|error| RustAuthError::Api(error.to_string()))
}

/// JSON request body accepted by the `/sign-jwt` endpoint.
#[derive(Debug, Deserialize)]
struct SignJwtBody {
    /// Claims to sign.
    payload: JwtClaims,
    /// Optional per-request option overrides, read from the JSON key
    /// `overrideOptions`; absent means "use the configured options".
    #[serde(default, rename = "overrideOptions")]
    override_options: Option<JwtEndpointOverrideOptions>,
}

/// JSON request body accepted by the `/verify-jwt` endpoint.
#[derive(Debug, Deserialize)]
struct VerifyJwtBody {
    /// The encoded token to verify.
    token: String,
    /// Optional issuer forwarded to `verify::verify_jwt_with_options`.
    issuer: Option<String>,
}

/// Per-request overrides for `JwtOptions`, grouped the same way as the
/// options themselves (`jwt` and `jwks`). Both groups are optional.
#[derive(Debug, Deserialize)]
struct JwtEndpointOverrideOptions {
    /// Overrides applied to `JwtOptions::jwt`.
    #[serde(default)]
    jwt: Option<JwtEndpointSigningOverride>,
    /// Overrides applied to `JwtOptions::jwks`.
    #[serde(default)]
    jwks: Option<JwtEndpointJwksOverride>,
}

impl JwtEndpointOverrideOptions {
    /// Applies whichever override groups are present to `options`; a missing
    /// group leaves the corresponding part of `options` untouched.
    fn apply(self, options: &mut JwtOptions) {
        let Self { jwt, jwks } = self;
        if let Some(signing) = jwt {
            signing.apply(&mut options.jwt);
        }
        if let Some(key_store) = jwks {
            key_store.apply(&mut options.jwks);
        }
    }
}

/// Per-request overrides for `JwtSigningOptions`. Every field is optional;
/// absent fields leave the configured value in place.
#[derive(Debug, Deserialize)]
struct JwtEndpointSigningOverride {
    /// Replacement issuer.
    #[serde(default)]
    issuer: Option<String>,
    /// Replacement audience: a single string or a list of strings.
    #[serde(default)]
    audience: Option<AudienceOverride>,
    /// Replacement expiration, read from the JSON key `expirationTime`:
    /// either an integer or a string.
    #[serde(default, rename = "expirationTime")]
    expiration_time: Option<TimeInputOverride>,
}

impl JwtEndpointSigningOverride {
    fn apply(self, options: &mut JwtSigningOptions) {
        if let Some(issuer) = self.issuer {
            options.issuer = Some(issuer);
        }
        if let Some(audience) = self.audience {
            options.audience = Some(audience.into_vec());
        }
        if let Some(expiration_time) = self.expiration_time {
            options.expiration_time = Some(expiration_time.into_time_input());
        }
    }
}

/// Per-request overrides for `JwtJwksOptions`. Every field is optional and is
/// read from its camelCase JSON key; absent fields leave the configured value
/// in place.
#[derive(Debug, Deserialize)]
struct JwtEndpointJwksOverride {
    /// JSON key `remoteUrl`.
    #[serde(default, rename = "remoteUrl")]
    remote_url: Option<String>,
    /// JSON key `keyPairConfig`: algorithm and/or RSA modulus length.
    #[serde(default, rename = "keyPairConfig")]
    key_pair_config: Option<KeyPairConfigOverride>,
    /// JSON key `disablePrivateKeyEncryption`.
    #[serde(default, rename = "disablePrivateKeyEncryption")]
    disable_private_key_encryption: Option<bool>,
    /// JSON key `rotationInterval`.
    #[serde(default, rename = "rotationInterval")]
    rotation_interval: Option<i64>,
    /// JSON key `gracePeriod`.
    #[serde(default, rename = "gracePeriod")]
    grace_period: Option<i64>,
    /// JSON key `jwksPath`.
    #[serde(default, rename = "jwksPath")]
    jwks_path: Option<String>,
}

impl JwtEndpointJwksOverride {
    fn apply(self, options: &mut JwtJwksOptions) {
        if let Some(remote_url) = self.remote_url {
            options.remote_url = Some(remote_url);
        }
        if let Some(key_pair_config) = self.key_pair_config {
            if let Some(algorithm) = key_pair_config.alg {
                options.key_pair_algorithm = Some(algorithm);
            }
            if let Some(modulus_length) = key_pair_config.modulus_length {
                options.rsa_modulus_length = Some(modulus_length);
            }
        }
        if let Some(disable_private_key_encryption) = self.disable_private_key_encryption {
            options.disable_private_key_encryption = disable_private_key_encryption;
        }
        if let Some(rotation_interval) = self.rotation_interval {
            options.rotation_interval = Some(rotation_interval);
        }
        if let Some(grace_period) = self.grace_period {
            options.grace_period = grace_period;
        }
        if let Some(jwks_path) = self.jwks_path {
            options.jwks_path = jwks_path;
        }
    }
}

/// Key-pair portion of a JWKS override; both fields are optional.
#[derive(Debug, Deserialize)]
struct KeyPairConfigOverride {
    /// Replacement key algorithm.
    #[serde(default)]
    alg: Option<JwkAlgorithm>,
    /// Replacement RSA modulus length, read from the JSON key `modulusLength`.
    #[serde(default, rename = "modulusLength")]
    modulus_length: Option<u32>,
}

/// Audience override as it appears in JSON. The enum is untagged, so the
/// value is written bare: either one string or an array of strings.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum AudienceOverride {
    /// A single audience string.
    One(String),
    /// A list of audience strings.
    Many(Vec<String>),
}

impl AudienceOverride {
    /// Normalizes the override to a list: a single audience becomes a
    /// one-element vector, a list is returned unchanged.
    fn into_vec(self) -> Vec<String> {
        match self {
            Self::Many(list) => list,
            Self::One(single) => Vec::from([single]),
        }
    }
}

/// Expiration override as it appears in JSON. The enum is untagged, so the
/// value is written bare: either an integer or a string.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum TimeInputOverride {
    /// An integer value; converted to `TimeInput::UnixTimestamp`.
    Number(i64),
    /// A string value; converted to `TimeInput::Duration`.
    String(String),
}

impl TimeInputOverride {
    /// Converts the wire representation into a `TimeInput`: strings become
    /// `TimeInput::Duration`, integers become `TimeInput::UnixTimestamp`.
    fn into_time_input(self) -> TimeInput {
        match self {
            Self::String(text) => TimeInput::Duration(text),
            Self::Number(timestamp) => TimeInput::UnixTimestamp(timestamp),
        }
    }
}