rustauth-plugins 0.2.0

Official RustAuth plugin modules.
Documentation
use http::Method;
use rustauth_core::api::{
    create_auth_endpoint, redirect_response, session_cookies, ApiRequest, AsyncAuthEndpoint,
    AuthEndpointOptions, OpenApiOperation,
};
use rustauth_core::auth::oauth::{
    handle_oauth_user_info, parse_oauth_state_with_input, HandleOAuthUserInfoInput,
    OAuthStateParseInput,
};
use rustauth_core::error::RustAuthError;
use rustauth_core::options::OAuthStateStoreStrategy;
use time::OffsetDateTime;

use super::options::OAuthProxyOptions;
use super::payload::PassthroughPayload;
use super::utils::{decrypt, is_trusted_callback_url, query_param, redirect_error};

pub(crate) fn oauth_proxy_callback_endpoint(options: OAuthProxyOptions) -> AsyncAuthEndpoint {
    // Builds the `GET /oauth-proxy-callback` endpoint (operation id `oauthProxyCallback`).
    //
    // Request handling is delegated to `handle_callback`. The plugin options are moved into
    // the handler closure and cloned once per request, so each returned future owns its own
    // copy.
    let openapi = OpenApiOperation::new("oauthProxyCallback").description("OAuth Proxy Callback");
    let endpoint_options = AuthEndpointOptions::new()
        .operation_id("oauthProxyCallback")
        .openapi(openapi);

    create_auth_endpoint(
        "/oauth-proxy-callback",
        Method::GET,
        endpoint_options,
        move |context, request| {
            let per_request_options = options.clone();
            async move { handle_callback(&context, request, &per_request_options).await }
        },
    )
}

/// Completes the OAuth proxy flow for `GET /oauth-proxy-callback`.
///
/// Steps, in order:
/// 1. Require a `callbackURL` query parameter and check it with `is_trusted_callback_url`.
/// 2. Require a `profile` query parameter, decrypt it, and deserialize it into a
///    [`PassthroughPayload`] that passes `has_required_fields`.
/// 3. Reject payloads older than `options.max_age`, or timestamped more than a few seconds
///    in the future (see [`is_payload_fresh`]).
/// 4. Unless `skip_state_cookie_check` is set, check the OAuth state according to the
///    configured [`OAuthStateStoreStrategy`].
/// 5. Sign the user in (or up) via `handle_oauth_user_info`, attach session cookies, and
///    redirect to the payload's `callback_url`. Newly registered users go to `new_user_url`
///    when it is set.
///
/// Failures before the payload is decoded redirect to `{base_url}/error`. Later failures
/// redirect to the payload's `error_url` when present, otherwise to that same default. Each
/// redirect carries an error code such as `invalid_profile` or `payload_expired`.
///
/// # Errors
///
/// Propagates, rather than redirecting on, errors returned by `is_trusted_callback_url`,
/// `handle_oauth_user_info`, `session_cookies`, and the redirect helpers themselves.
async fn handle_callback(
    context: &rustauth_core::context::AuthContext,
    request: ApiRequest,
    options: &OAuthProxyOptions,
) -> Result<rustauth_core::api::ApiResponse, RustAuthError> {
    let default_error_url = format!("{}/error", context.base_url.trim_end_matches('/'));

    let Some(callback_url) = query_param(&request, "callbackURL") else {
        return redirect_error(&default_error_url, "missing_callback_url");
    };
    if !is_trusted_callback_url(context, &request, &callback_url)? {
        return redirect_error(&default_error_url, "invalid_callback_url");
    }

    let Some(encrypted_profile) = query_param(&request, "profile") else {
        return redirect_error(&default_error_url, "missing_profile");
    };
    // The underlying decryption error is discarded; the client only sees `invalid_profile`.
    let Ok(decrypted) = decrypt(context, options, &encrypted_profile) else {
        return redirect_error(&default_error_url, "invalid_profile");
    };
    let payload = match serde_json::from_str::<PassthroughPayload>(&decrypted) {
        Ok(value) if value.has_required_fields() => value,
        _ => return redirect_error(&default_error_url, "invalid_payload"),
    };

    // Once the payload is decoded, failures redirect to its own `error_url` when one is set.
    let error_url = payload.error_url.as_deref().unwrap_or(&default_error_url);

    let now = OffsetDateTime::now_utc().unix_timestamp();
    if !is_payload_fresh(payload.timestamp, now, options.max_age.whole_seconds()) {
        return redirect_error(error_url, "payload_expired");
    }

    let adapter = match context.require_adapter() {
        Ok(adapter) => adapter,
        Err(_) => return redirect_error(error_url, "user_creation_failed"),
    };

    if !context.options.account.skip_state_cookie_check {
        match context.options.account.store_state_strategy {
            OAuthStateStoreStrategy::Cookie => {
                if parse_oauth_state_with_input(
                    context,
                    None,
                    OAuthStateParseInput {
                        state: &payload.state,
                        oauth_state: payload.oauth_state.as_deref(),
                        skip_state_cookie_check: false,
                    },
                )
                .await
                .is_err()
                {
                    return redirect_error(error_url, "invalid_state");
                }
            }
            OAuthStateStoreStrategy::Database => {
                // NOTE(review): this only checks that a non-empty `oauth_state` is present;
                // presumably the value itself is validated elsewhere — confirm.
                if payload.oauth_state.as_deref().map_or(true, str::is_empty) {
                    return redirect_error(error_url, "invalid_state");
                }
            }
        }
    }

    let trusted_provider = is_trusted_provider(context, &payload.account.provider_id);
    let result = handle_oauth_user_info(
        context,
        adapter.as_ref(),
        HandleOAuthUserInfoInput {
            user_info: payload.user_info,
            account: payload.account,
            callback_url: Some(payload.callback_url.clone()),
            disable_sign_up: payload.disable_sign_up,
            override_user_info: false,
            is_trusted_provider: trusted_provider,
            require_trusted_provider_for_implicit_link: false,
        },
    )
    .await?;
    let Some(data) = result.data else {
        return redirect_error(error_url, "user_creation_failed");
    };

    let cookies = session_cookies(context, &data.session, &data.user, false)?;
    let final_url = if result.is_register {
        payload
            .new_user_url
            .as_deref()
            .unwrap_or(&payload.callback_url)
    } else {
        &payload.callback_url
    };
    redirect_response(final_url, cookies)
}

/// Returns `true` when a payload stamped at `timestamp` is still acceptable at `now`.
///
/// Both values are Unix timestamps in seconds. A payload is accepted when its age
/// (`now - timestamp`) lies in `-MAX_FUTURE_SKEW_SECONDS..=max_age_seconds`. The small
/// negative allowance tolerates clock skew between the machine that created the payload
/// and this one.
///
/// The subtraction is checked. A timestamp far enough from `now` to overflow `i64` is
/// treated as expired instead of panicking (debug) or wrapping (release).
fn is_payload_fresh(timestamp: i64, now: i64, max_age_seconds: i64) -> bool {
    const MAX_FUTURE_SKEW_SECONDS: i64 = 10;
    match now.checked_sub(timestamp) {
        Some(age) => (-MAX_FUTURE_SKEW_SECONDS..=max_age_seconds).contains(&age),
        None => false,
    }
}

/// Reports whether `provider_id` appears in the configured account-linking trusted providers
/// (`context.options.account.account_linking.trusted_providers`).
fn is_trusted_provider(context: &rustauth_core::context::AuthContext, provider_id: &str) -> bool {
    let trusted_providers = &context.options.account.account_linking.trusted_providers;
    trusted_providers
        .iter()
        .any(|candidate| candidate == provider_id)
}