openauth-core 0.0.6

Core types and primitives for OpenAuth.
Documentation
mod flow;
mod support;

use http::Method;
use serde_json::Value;
use std::sync::Arc;

use super::shared::{sensitive_session, unauthorized};
use crate::api::{
    create_auth_endpoint, parse_request_body, AsyncAuthEndpoint, AuthEndpointOptions,
    OpenApiOperation,
};
use crate::auth::oauth::{generate_oauth_state, OAuthStateInput, OAuthStateLink};
use crate::db::DbAdapter;
use openauth_oauth::oauth2::SocialAuthorizationUrlRequest;

use flow::{
    callback_get, callback_post_redirect, link_with_id_token, lookup_provider,
    sign_in_with_id_token,
};
use support::{
    link_social_body_schema, redirect_json_response, redirect_uri, social_sign_in_body_schema,
    LinkSocialBody, SocialSignInBody,
};

/// Builds the `POST /sign-in/social` endpoint (operation id `socialSignIn`).
///
/// All request handling is delegated to `sign_in_oauth_endpoint`; this function only
/// supplies the route path, operation id and OpenAPI description.
pub(super) fn sign_in_social_endpoint(adapter: Arc<dyn DbAdapter>) -> AsyncAuthEndpoint {
    const PATH: &str = "/sign-in/social";
    const OPERATION_ID: &str = "socialSignIn";
    const DESCRIPTION: &str = "Sign in with a social provider";
    sign_in_oauth_endpoint(PATH, OPERATION_ID, DESCRIPTION, adapter)
}

/// Builds the `POST /sign-in/oauth2` endpoint (operation id `oauth2SignIn`).
///
/// Shares its entire implementation with the social sign-in route through
/// `sign_in_oauth_endpoint`; only the path, operation id and description differ.
pub(super) fn sign_in_oauth2_endpoint(adapter: Arc<dyn DbAdapter>) -> AsyncAuthEndpoint {
    const PATH: &str = "/sign-in/oauth2";
    const OPERATION_ID: &str = "oauth2SignIn";
    const DESCRIPTION: &str = "Sign in with an OAuth2 provider";
    sign_in_oauth_endpoint(PATH, OPERATION_ID, DESCRIPTION, adapter)
}

/// Builds a `POST` endpoint that starts a sign-in with the provider named in the request body.
///
/// The body may be form-urlencoded or JSON and is parsed as a `SocialSignInBody`.
///
/// - If the body carries an `id_token`, sign-in is completed directly by
///   `sign_in_with_id_token` and no redirect is produced.
/// - Otherwise an OAuth state is generated with `generate_oauth_state` (the callback URL
///   defaults to `"/"` and missing additional data becomes `Value::Null`). The provider's
///   authorization URL is then built with the state and its code verifier, and returned via
///   `redirect_json_response`. The response redirects unless `disable_redirect` is set.
///
/// `path`, `operation_id` and `description` only affect routing and OpenAPI metadata.
fn sign_in_oauth_endpoint(
    path: &'static str,
    operation_id: &'static str,
    description: &'static str,
    adapter: Arc<dyn DbAdapter>,
) -> AsyncAuthEndpoint {
    let options = AuthEndpointOptions::new()
        .operation_id(operation_id)
        .allowed_media_types(["application/x-www-form-urlencoded", "application/json"])
        .body_schema(social_sign_in_body_schema())
        .openapi(OpenApiOperation::new(operation_id).description(description));

    create_auth_endpoint(path, Method::POST, options, move |context, request| {
        // Each invocation gets its own handle so the future can own it.
        let db = Arc::clone(&adapter);
        Box::pin(async move {
            let body: SocialSignInBody = parse_request_body(&request)?;
            let provider = lookup_provider(context, &body.provider)?;

            // A supplied ID token bypasses the authorization-code redirect entirely.
            if let Some(token) = body.id_token {
                return sign_in_with_id_token(context, db.as_ref(), provider, token).await;
            }

            let state_input = OAuthStateInput {
                callback_url: body.callback_url.unwrap_or_else(|| "/".to_owned()),
                error_url: body.error_callback_url,
                new_user_url: body.new_user_callback_url,
                request_sign_up: body.request_sign_up,
                additional_data: body.additional_data.unwrap_or(Value::Null),
                ..OAuthStateInput::default()
            };
            let state = generate_oauth_state(context, Some(db.as_ref()), state_input).await?;

            let authorization = SocialAuthorizationUrlRequest {
                state: state.state,
                redirect_uri: redirect_uri(context, &request, provider.id()),
                code_verifier: Some(state.data.code_verifier),
                scopes: body.scopes,
                login_hint: body.login_hint,
            };
            let url = provider.create_authorization_url(authorization)?;
            redirect_json_response(url.to_string(), !body.disable_redirect)
        })
    })
}

/// Builds the `/callback/:id` endpoint (operation id `handleOAuthCallback`), registered for
/// the given HTTP `method`.
///
/// The handler branches on the incoming request's method:
/// - A non-`POST` request is processed (awaited) by `callback_get`.
/// - A `POST` request is handed to `callback_post_redirect` synchronously.
pub(super) fn callback_oauth_endpoint(
    method: Method,
    adapter: Arc<dyn DbAdapter>,
) -> AsyncAuthEndpoint {
    const OPERATION_ID: &str = "handleOAuthCallback";
    let options = AuthEndpointOptions::new()
        .operation_id(OPERATION_ID)
        .openapi(OpenApiOperation::new(OPERATION_ID).description("Handle OAuth callback"));

    create_auth_endpoint("/callback/:id", method, options, move |context, request| {
        let db = Arc::clone(&adapter);
        Box::pin(async move {
            if request.method() != Method::POST {
                return callback_get(context, db.as_ref(), request).await;
            }
            callback_post_redirect(context, &request)
        })
    })
}

/// Builds the `POST /link-social` endpoint (operation id `linkSocialAccount`), which links
/// a provider account to the currently signed-in user.
///
/// The request must carry a session accepted by `sensitive_session`; otherwise the
/// `unauthorized()` response is returned. The body may be form-urlencoded or JSON and is
/// parsed as a `LinkSocialBody`.
///
/// - With an `id_token`, the account is linked directly by `link_with_id_token`.
/// - Otherwise an OAuth state is generated that carries the user's id and email as
///   `OAuthStateLink` data. The callback URL defaults to `"/"` and missing additional data
///   becomes `Value::Null`. The provider's authorization URL, built without a login hint,
///   is returned via `redirect_json_response`, which redirects unless `disable_redirect`
///   is set.
pub(super) fn link_social_endpoint(adapter: Arc<dyn DbAdapter>) -> AsyncAuthEndpoint {
    const OPERATION_ID: &str = "linkSocialAccount";
    let options = AuthEndpointOptions::new()
        .operation_id(OPERATION_ID)
        .allowed_media_types(["application/x-www-form-urlencoded", "application/json"])
        .body_schema(link_social_body_schema())
        .openapi(OpenApiOperation::new(OPERATION_ID).description("Link a social account"));

    create_auth_endpoint("/link-social", Method::POST, options, move |context, request| {
        let db = Arc::clone(&adapter);
        Box::pin(async move {
            // Linking is only allowed for an authenticated user.
            let session = sensitive_session(db.as_ref(), context, &request).await?;
            let Some((_session, user, _cookies)) = session else {
                return unauthorized();
            };

            let body: LinkSocialBody = parse_request_body(&request)?;
            let provider = lookup_provider(context, &body.provider)?;

            if let Some(token) = body.id_token {
                return link_with_id_token(context, db.as_ref(), provider, &user, token).await;
            }

            let link = OAuthStateLink {
                user_id: user.id,
                email: user.email,
            };
            let state_input = OAuthStateInput {
                callback_url: body.callback_url.unwrap_or_else(|| "/".to_owned()),
                error_url: body.error_callback_url,
                link: Some(link),
                request_sign_up: body.request_sign_up,
                additional_data: body.additional_data.unwrap_or(Value::Null),
                ..OAuthStateInput::default()
            };
            let state = generate_oauth_state(context, Some(db.as_ref()), state_input).await?;

            let authorization = SocialAuthorizationUrlRequest {
                state: state.state,
                redirect_uri: redirect_uri(context, &request, provider.id()),
                code_verifier: Some(state.data.code_verifier),
                scopes: body.scopes,
                login_hint: None,
            };
            let url = provider.create_authorization_url(authorization)?;
            redirect_json_response(url.to_string(), !body.disable_redirect)
        })
    })
}