// openauth-plugins 0.0.4
//
// Official OpenAuth plugin modules.
// Documentation
use std::sync::Arc;

use http::{header, HeaderValue, Method, StatusCode};
use openauth_core::api::{
    create_auth_endpoint, parse_request_body, ApiErrorResponse, ApiRequest, ApiResponse,
    AsyncAuthEndpoint, AuthEndpointOptions, BodyField, BodySchema, JsonSchemaType,
    OpenApiOperation,
};
use openauth_core::auth::session::{GetSessionInput, SessionAuth};
use openauth_core::context::AuthContext;
use openauth_core::cookies::{set_session_cookie, Cookie, CookieOptions, SessionCookieOptions};
use openauth_core::crypto::random::generate_random_string;
use openauth_core::db::{DbAdapter, Session, User};
use openauth_core::error::OpenAuthError;
use openauth_core::session::DbSessionStore;
use openauth_core::user::DbUserStore;
use openauth_core::verification::{CreateVerificationInput, DbVerificationStore};
use serde::{Deserialize, Serialize};
use serde_json::json;
use time::{Duration, OffsetDateTime};

use super::hashing::default_key_hasher;
use super::options::{OneTimeTokenOptions, OneTimeTokenSession, StoreToken};

/// JSON body returned by the `/one-time-token/generate` endpoint on success.
#[derive(Debug, Serialize)]
struct GenerateResponse {
    /// The token value produced by `generate_and_store_token`.
    token: String,
}

/// JSON request body accepted by the `/one-time-token/verify` endpoint.
#[derive(Debug, Deserialize)]
struct VerifyBody {
    /// The one-time token submitted by the caller for verification.
    token: String,
}

/// JSON body returned by the `/one-time-token/verify` endpoint on success.
#[derive(Debug, Serialize)]
struct VerifyResponse {
    /// The session whose token was stored alongside the one-time token.
    session: Session,
    /// The user looked up from `session.user_id`.
    user: User,
}

/// Builds the `GET /one-time-token/generate` endpoint.
///
/// The handler:
/// 1. answers `400 BAD_REQUEST` when `options.disable_client_request` is set;
/// 2. resolves the caller's session via `current_session`, answering
///    `401 UNAUTHORIZED` when none is found;
/// 3. creates and stores a token with `generate_and_store_token` and returns
///    it as `{ "token": ... }` with status `200`.
///
/// Errors from the adapter lookup, session lookup, or token storage are
/// propagated as `OpenAuthError`.
///
/// NOTE(review): the cookies returned together with the session are not
/// forwarded to the response (an empty cookie list is passed) — confirm this
/// is intentional.
pub fn generate_endpoint(options: OneTimeTokenOptions) -> AsyncAuthEndpoint {
    let endpoint_options = AuthEndpointOptions::new()
        .operation_id("generateOneTimeToken")
        .openapi(
            OpenApiOperation::new("generateOneTimeToken")
                .description("Generate a one-time token for the current session")
                .response("200", generate_openapi_response()),
        );

    create_auth_endpoint(
        "/one-time-token/generate",
        Method::GET,
        endpoint_options,
        move |context, request| {
            // Each invocation gets its own copy of the plugin options.
            let options = options.clone();
            Box::pin(async move {
                if options.disable_client_request {
                    return error_response(
                        StatusCode::BAD_REQUEST,
                        "BAD_REQUEST",
                        "Client requests are disabled",
                    );
                }

                let adapter = required_adapter(context)?;
                let (session, user) =
                    match current_session(adapter.as_ref(), context, &request).await? {
                        Some((session, user, _cookies)) => (session, user),
                        None => {
                            return error_response(
                                StatusCode::UNAUTHORIZED,
                                "UNAUTHORIZED",
                                "Unauthorized",
                            );
                        }
                    };

                let owner = OneTimeTokenSession { session, user };
                let token =
                    generate_and_store_token(adapter.as_ref(), context, &owner, &options).await?;
                json_response(StatusCode::OK, &GenerateResponse { token }, Vec::new())
            })
        },
    )
}

/// Builds the `POST /one-time-token/verify` endpoint.
///
/// The handler parses a JSON body containing `token`, derives the stored
/// form of that token (see `stored_token`), and then:
/// 1. looks up the matching verification record (expired ones included),
///    answering `400 "Invalid token"` when there is none;
/// 2. deletes the record *before* checking its expiry, so a token is removed
///    on its first lookup whether or not it is still valid;
/// 3. answers `400 "Token expired"` when the record has expired;
/// 4. loads the session named by the record's value and its user, answering
///    `400 "Session not found"` / `400 "Session expired"` as appropriate;
/// 5. returns `{ session, user }` with status `200`, attaching session
///    cookies unless `options.disable_set_session_cookie` is set.
///
/// Body parsing, adapter, store, and cookie errors are propagated as
/// `OpenAuthError`.
pub fn verify_endpoint(options: OneTimeTokenOptions) -> AsyncAuthEndpoint {
    let endpoint_options = AuthEndpointOptions::new()
        .operation_id("verifyOneTimeToken")
        .allowed_media_types(["application/json"])
        .body_schema(verify_body_schema())
        .openapi(
            OpenApiOperation::new("verifyOneTimeToken")
                .description("Verify a one-time token and return its session")
                .response("200", verify_openapi_response()),
        );

    create_auth_endpoint(
        "/one-time-token/verify",
        Method::POST,
        endpoint_options,
        move |context, request| {
            // Each invocation gets its own copy of the plugin options.
            let options = options.clone();
            Box::pin(async move {
                let body: VerifyBody = parse_request_body(&request)?;
                let adapter = required_adapter(context)?;
                let identifier = token_identifier(&stored_token(context, &body.token, &options)?);

                let verifications = DbVerificationStore::new(adapter.as_ref());
                let verification = match verifications
                    .find_verification_including_expired(&identifier)
                    .await?
                {
                    Some(found) => found,
                    None => {
                        return error_response(
                            StatusCode::BAD_REQUEST,
                            "BAD_REQUEST",
                            "Invalid token",
                        );
                    }
                };
                // Consume the token first; the expiry check comes afterwards on purpose.
                verifications.delete_verification(&identifier).await?;
                if verification.expires_at <= OffsetDateTime::now_utc() {
                    return error_response(StatusCode::BAD_REQUEST, "BAD_REQUEST", "Token expired");
                }

                let sessions = DbSessionStore::new(adapter.as_ref());
                let session = match sessions
                    .find_session_including_expired(&verification.value)
                    .await?
                {
                    Some(found) => found,
                    None => {
                        return error_response(
                            StatusCode::BAD_REQUEST,
                            "BAD_REQUEST",
                            "Session not found",
                        );
                    }
                };
                if session.expires_at <= OffsetDateTime::now_utc() {
                    return error_response(
                        StatusCode::BAD_REQUEST,
                        "BAD_REQUEST",
                        "Session expired",
                    );
                }

                let users = DbUserStore::new(adapter.as_ref());
                let user = match users.find_user_by_id(&session.user_id).await? {
                    Some(found) => found,
                    // A session without its user is reported the same way as a missing session.
                    None => {
                        return error_response(
                            StatusCode::BAD_REQUEST,
                            "BAD_REQUEST",
                            "Session not found",
                        );
                    }
                };

                let cookies = if !options.disable_set_session_cookie {
                    set_session_cookie(
                        &context.auth_cookies,
                        &context.secret,
                        &session.token,
                        SessionCookieOptions {
                            dont_remember: false,
                            overrides: CookieOptions::default(),
                        },
                    )?
                } else {
                    Vec::new()
                };
                json_response(StatusCode::OK, &VerifyResponse { session, user }, cookies)
            })
        },
    )
}

pub async fn generate_and_store_token(
    adapter: &dyn DbAdapter,
    context: &AuthContext,
    session: &OneTimeTokenSession,
    options: &OneTimeTokenOptions,
) -> Result<String, OpenAuthError> {
    let token = match &options.generate_token {
        Some(generate) => generate(session, context)?,
        None => generate_random_string(32),
    };
    let stored_token = stored_token(context, &token, options)?;
    let expires_at = OffsetDateTime::now_utc() + Duration::minutes(options.expires_in as i64);
    DbVerificationStore::new(adapter)
        .create_verification(CreateVerificationInput::new(
            token_identifier(&stored_token),
            session.session.token.clone(),
            expires_at,
        ))
        .await?;
    Ok(token)
}

/// Resolves the session, user, and accompanying cookies for `request`.
///
/// The `Cookie` request header is passed to `SessionAuth::get_session`; a
/// missing header or one that is not valid visible ASCII is treated as an
/// empty string.
///
/// Returns `Ok(None)` when the lookup yields nothing, or when either the
/// session or the user is absent from the result. Lookup errors are
/// propagated.
async fn current_session(
    adapter: &dyn DbAdapter,
    context: &AuthContext,
    request: &ApiRequest,
) -> Result<Option<(Session, User, Vec<Cookie>)>, OpenAuthError> {
    let cookie_header = match request.headers().get(header::COOKIE) {
        Some(raw) => raw.to_str().unwrap_or_default(),
        None => "",
    };

    let lookup = SessionAuth::new(adapter, context)
        .get_session(GetSessionInput::new(cookie_header))
        .await?;

    Ok(lookup.and_then(|result| {
        // Both halves are required; `?` short-circuits to `None` otherwise.
        let session = result.session?;
        let user = result.user?;
        Some((session, user, result.cookies))
    }))
}

/// Returns the database adapter configured on `context`.
///
/// # Errors
///
/// Returns `OpenAuthError::InvalidConfig` when the context has no adapter.
fn required_adapter(context: &AuthContext) -> Result<Arc<dyn DbAdapter>, OpenAuthError> {
    match context.adapter() {
        Some(adapter) => Ok(adapter),
        None => Err(OpenAuthError::InvalidConfig(
            "one-time-token plugin requires a database adapter".to_owned(),
        )),
    }
}

/// Converts a raw token into the form that is persisted, per `options.store_token`.
///
/// - `StoreToken::Plain`: the token unchanged.
/// - `StoreToken::Hashed`: the token passed through `default_key_hasher`.
/// - `StoreToken::Custom`: the result of the configured function, including
///   any error it returns.
///
/// `_context` is currently unused.
fn stored_token(
    _context: &AuthContext,
    token: &str,
    options: &OneTimeTokenOptions,
) -> Result<String, OpenAuthError> {
    let persisted = match &options.store_token {
        StoreToken::Plain => token.to_owned(),
        StoreToken::Hashed => default_key_hasher(token),
        // The custom transform is fallible; hand its result back untouched.
        StoreToken::Custom(hash) => return hash(token),
    };
    Ok(persisted)
}

/// Builds the verification identifier for a stored token by prefixing it
/// with `one-time-token:`.
fn token_identifier(stored_token: &str) -> String {
    const PREFIX: &str = "one-time-token:";
    let mut identifier = String::with_capacity(PREFIX.len() + stored_token.len());
    identifier.push_str(PREFIX);
    identifier.push_str(stored_token);
    identifier
}

/// Describes the verify request body: an object with a single string field `token`.
fn verify_body_schema() -> BodySchema {
    let token_field =
        BodyField::new("token", JsonSchemaType::String).description("The token to verify");
    BodySchema::object([token_field])
}

/// OpenAPI description of the `200` response of the generate endpoint:
/// a JSON object with a required string property `token`.
fn generate_openapi_response() -> serde_json::Value {
    let schema = json!({
        "type": "object",
        "properties": {
            "token": { "type": "string" }
        },
        "required": ["token"]
    });
    json!({
        "description": "One-time token generated",
        "content": {
            "application/json": { "schema": schema }
        }
    })
}

/// OpenAPI description of the `200` response of the verify endpoint:
/// a JSON object with required `session` and `user` properties that
/// reference the `Session` and `User` component schemas.
fn verify_openapi_response() -> serde_json::Value {
    let schema = json!({
        "type": "object",
        "properties": {
            "session": { "$ref": "#/components/schemas/Session" },
            "user": { "$ref": "#/components/schemas/User" }
        },
        "required": ["session", "user"]
    });
    json!({
        "description": "One-time token verified",
        "content": {
            "application/json": { "schema": schema }
        }
    })
}

fn json_response<T>(
    status: StatusCode,
    body: &T,
    cookies: Vec<Cookie>,
) -> Result<ApiResponse, OpenAuthError>
where
    T: Serialize,
{
    let body = serde_json::to_vec(body).map_err(|error| OpenAuthError::Api(error.to_string()))?;
    let mut response = http::Response::builder()
        .status(status)
        .header(header::CONTENT_TYPE, "application/json")
        .body(body)
        .map_err(|error| OpenAuthError::Api(error.to_string()))?;
    for cookie in cookies {
        response.headers_mut().append(
            header::SET_COOKIE,
            HeaderValue::from_str(&serialize_cookie(&cookie))
                .map_err(|error| OpenAuthError::Cookie(error.to_string()))?,
        );
    }
    Ok(response)
}

/// Builds a JSON error response carrying `code` and `message`, with no
/// original message and no cookies.
///
/// Errors from `json_response` are propagated.
fn error_response(
    status: StatusCode,
    code: &str,
    message: &str,
) -> Result<ApiResponse, OpenAuthError> {
    let payload = ApiErrorResponse {
        code: String::from(code),
        message: String::from(message),
        original_message: None,
    };
    json_response(status, &payload, Vec::new())
}

/// Renders `cookie` as a `Set-Cookie` header value.
///
/// The result starts with `name=value`, followed by whichever of these
/// attributes are present, in this order and separated by `"; "`:
/// `Max-Age`, `Expires`, `Domain`, `Path`, `Secure`, `HttpOnly`, `SameSite`,
/// `Partitioned`. Boolean attributes are emitted only when set to `Some(true)`.
fn serialize_cookie(cookie: &Cookie) -> String {
    let attributes = &cookie.attributes;

    // One slot per possible segment; absent attributes stay `None`.
    let segments = vec![
        Some(format!("{}={}", cookie.name, cookie.value)),
        attributes
            .max_age
            .map(|max_age| format!("Max-Age={max_age}")),
        attributes
            .expires
            .as_ref()
            .map(|expires| format!("Expires={expires}")),
        attributes
            .domain
            .as_ref()
            .map(|domain| format!("Domain={domain}")),
        attributes.path.as_ref().map(|path| format!("Path={path}")),
        (attributes.secure == Some(true)).then(|| "Secure".to_owned()),
        (attributes.http_only == Some(true)).then(|| "HttpOnly".to_owned()),
        attributes
            .same_site
            .as_ref()
            .map(|same_site| format!("SameSite={same_site}")),
        (attributes.partitioned == Some(true)).then(|| "Partitioned".to_owned()),
    ];

    segments
        .into_iter()
        .flatten()
        .collect::<Vec<String>>()
        .join("; ")
}