openauth-plugins 0.0.3

Official OpenAuth plugin modules.
Documentation
use http::{Method, StatusCode};
use openauth_core::api::{
    create_auth_endpoint, parse_request_body, ApiRequest, ApiResponse, AsyncAuthEndpoint,
    AuthEndpointOptions,
};
use openauth_core::auth::oauth::{
    generate_oauth_state, handle_oauth_user_info, parse_oauth_state, HandleOAuthUserInfoInput,
    OAuthStateInput, OAuthStateLink,
};
use openauth_core::context::AuthContext;
use openauth_core::cookies::{set_session_cookie, SessionCookieOptions};
use openauth_core::error::OpenAuthError;
use openauth_oauth::oauth2::{
    SocialAuthorizationCodeRequest, SocialAuthorizationUrlRequest, SocialOAuthProvider,
};
use serde::Deserialize;
use serde_json::Value;

use super::account::{link_account, link_error_code, normalize_user_info, oauth_account};
use super::config::{GenericOAuthFlow, GenericOAuthOptions};
use super::discovery::DiscoveryCache;
use super::errors;
use super::provider::GenericOAuthProvider;
use super::route_http::{
    api_error, link_schema, redirect, redirect_with_error, redirect_with_error_description,
    sign_in_schema,
};
use super::route_support::*;

/// Request body accepted by the `/sign-in/oauth2` endpoint.
///
/// Multi-word fields are read from their snake_case name or from the camelCase
/// alias declared on them. Only `provider_id` is required; every other field
/// falls back to its `Default` value when it is absent from the body.
#[derive(Debug, Deserialize)]
struct SignInOAuth2Body {
    /// Identifier used to look up the provider configuration.
    #[serde(alias = "providerId")]
    provider_id: String,
    /// URL stored in the OAuth state as the callback URL; the handler
    /// substitutes `"/"` when it is missing.
    #[serde(default, alias = "callbackURL")]
    callback_url: Option<String>,
    /// URL stored in the OAuth state as the error URL.
    #[serde(default, alias = "errorCallbackURL")]
    error_callback_url: Option<String>,
    /// URL stored in the OAuth state as the new-user URL.
    #[serde(default, alias = "newUserCallbackURL")]
    new_user_callback_url: Option<String>,
    /// When `true`, the handler passes `false` as the redirect flag of the
    /// JSON response.
    #[serde(default, alias = "disableRedirect")]
    disable_redirect: bool,
    /// Scopes forwarded to the authorization URL request.
    #[serde(default)]
    scopes: Vec<String>,
    /// Flag stored in the OAuth state as `request_sign_up`.
    #[serde(default, alias = "requestSignUp")]
    request_sign_up: bool,
    /// Arbitrary JSON stored in the OAuth state; the handler substitutes
    /// `Value::Null` when it is missing.
    #[serde(default, alias = "additionalData")]
    additional_data: Option<Value>,
}

/// Request body accepted by the `/oauth2/link` endpoint.
///
/// Multi-word fields are read from their snake_case name or from the camelCase
/// alias declared on them. `provider_id` and `callback_url` are required; the
/// remaining fields fall back to their `Default` value when absent.
#[derive(Debug, Deserialize)]
struct LinkOAuth2Body {
    /// Identifier used to look up the provider configuration.
    #[serde(alias = "providerId")]
    provider_id: String,
    /// URL stored in the OAuth state as the callback URL.
    #[serde(alias = "callbackURL")]
    callback_url: String,
    /// URL stored in the OAuth state as the error URL.
    #[serde(default, alias = "errorCallbackURL")]
    error_callback_url: Option<String>,
    /// Scopes forwarded to the authorization URL request.
    #[serde(default)]
    scopes: Vec<String>,
}

/// Builds the `POST /sign-in/oauth2` endpoint that starts a generic OAuth sign-in.
///
/// The handler parses a `SignInOAuth2Body` (JSON or form encoded), resolves the
/// configuration of the requested provider, generates the OAuth state and
/// answers through `redirect_json_response` with the provider authorization URL.
/// The redirect flag of that response is the negation of `disable_redirect`.
///
/// # Errors
///
/// The handler responds with `400 Bad Request` and
/// `errors::PROVIDER_CONFIG_NOT_FOUND` when `options` has no entry for the
/// requested provider, and with the output of `config_error_response` when the
/// configuration cannot be resolved. Failures from the adapter lookup, body
/// parsing, authorization parameter resolution, state generation and
/// authorization URL creation are propagated with `?`.
pub fn sign_in_oauth2_endpoint(
    options: GenericOAuthOptions,
    discovery_cache: DiscoveryCache,
) -> AsyncAuthEndpoint {
    create_auth_endpoint(
        "/sign-in/oauth2",
        Method::POST,
        AuthEndpointOptions::new()
            .operation_id("signInWithOAuth2")
            .allowed_media_types(["application/x-www-form-urlencoded", "application/json"])
            .body_schema(sign_in_schema()),
        move |context, request| {
            // Clone per request so the returned future owns its own copies.
            let options = options.clone();
            let discovery_cache = discovery_cache.clone();
            Box::pin(async move {
                let adapter = adapter(context)?;
                let body: SignInOAuth2Body = parse_request_body(&request)?;
                // Reject unknown providers with a dedicated error code before
                // attempting to resolve the configuration.
                if options.find(&body.provider_id).is_none() {
                    return api_error(
                        StatusCode::BAD_REQUEST,
                        errors::PROVIDER_CONFIG_NOT_FOUND,
                        "No config found for provider",
                    );
                }
                let mut config =
                    match resolved_config(&options, &discovery_cache, &body.provider_id).await {
                        Ok(config) => config,
                        Err(error) => return config_error_response(error),
                    };
                let redirect_uri = callback_redirect_uri(context, &config);
                resolve_authorization_url_params(
                    &mut config,
                    GenericOAuthFlow::SignIn,
                    redirect_uri.clone(),
                )
                .await?;
                let state = generate_oauth_state(
                    context,
                    Some(adapter.as_ref()),
                    OAuthStateInput {
                        callback_url: body.callback_url.unwrap_or_else(|| "/".to_owned()),
                        error_url: body.error_callback_url,
                        new_user_url: body.new_user_callback_url,
                        request_sign_up: body.request_sign_up,
                        additional_data: body.additional_data.unwrap_or(Value::Null),
                        ..OAuthStateInput::default()
                    },
                )
                .await?;
                // `config` is not needed past this point, so it is moved into
                // the provider instead of being cloned.
                let provider = GenericOAuthProvider::with_discovery_cache(config, discovery_cache);
                let url = provider.create_authorization_url(SocialAuthorizationUrlRequest {
                    state: state.state,
                    redirect_uri,
                    code_verifier: Some(state.data.code_verifier),
                    scopes: body.scopes,
                    login_hint: None,
                })?;
                redirect_json_response(url.to_string(), !body.disable_redirect)
            })
        },
    )
}

/// Builds the `GET /oauth2/callback/:providerId` endpoint.
///
/// The endpoint is a thin wrapper: every request is handed to `callback_get`
/// together with per-request clones of `options` and `discovery_cache`.
pub fn oauth2_callback_endpoint(
    options: GenericOAuthOptions,
    discovery_cache: DiscoveryCache,
) -> AsyncAuthEndpoint {
    let endpoint_options = AuthEndpointOptions::new().operation_id("oAuth2Callback");
    create_auth_endpoint(
        "/oauth2/callback/:providerId",
        Method::GET,
        endpoint_options,
        move |context, request| {
            // The future must own what it borrows from, hence the clones.
            let plugin_options = options.clone();
            let cache = discovery_cache.clone();
            Box::pin(async move {
                callback_get(context, &plugin_options, &cache, request).await
            })
        },
    )
}

/// Builds the `POST /oauth2/link` endpoint that starts linking an OAuth account
/// to the currently signed-in user.
///
/// The handler requires an existing session, parses a `LinkOAuth2Body` (JSON or
/// form encoded), resolves the configuration of the requested provider and
/// generates an OAuth state carrying the id and email of the session user as
/// link information. It answers through `redirect_json_response` with the
/// provider authorization URL and a redirect flag of `true`.
///
/// # Errors
///
/// The handler responds with `401 Unauthorized` and `errors::SESSION_REQUIRED`
/// when `current_session` yields no session, with `404 Not Found` and
/// `errors::PROVIDER_CONFIG_NOT_FOUND` when `options` has no entry for the
/// requested provider, and with the output of `config_error_response` when the
/// configuration cannot be resolved. Failures from the adapter lookup, session
/// lookup, body parsing, authorization parameter resolution, state generation
/// and authorization URL creation are propagated with `?`.
pub fn oauth2_link_endpoint(
    options: GenericOAuthOptions,
    discovery_cache: DiscoveryCache,
) -> AsyncAuthEndpoint {
    create_auth_endpoint(
        "/oauth2/link",
        Method::POST,
        AuthEndpointOptions::new()
            .operation_id("oAuth2LinkAccount")
            .allowed_media_types(["application/x-www-form-urlencoded", "application/json"])
            .body_schema(link_schema()),
        move |context, request| {
            // Clone per request so the returned future owns its own copies.
            let options = options.clone();
            let discovery_cache = discovery_cache.clone();
            Box::pin(async move {
                let adapter = adapter(context)?;
                // The session is checked before the body is parsed, so an
                // unauthenticated caller always receives the session error.
                let Some((_session, user)) =
                    current_session(context, adapter.as_ref(), &request).await?
                else {
                    return api_error(
                        StatusCode::UNAUTHORIZED,
                        errors::SESSION_REQUIRED,
                        "Session is required",
                    );
                };
                let body: LinkOAuth2Body = parse_request_body(&request)?;
                if options.find(&body.provider_id).is_none() {
                    return api_error(
                        StatusCode::NOT_FOUND,
                        errors::PROVIDER_CONFIG_NOT_FOUND,
                        "No config found for provider",
                    );
                }
                let mut config =
                    match resolved_config(&options, &discovery_cache, &body.provider_id).await {
                        Ok(config) => config,
                        Err(error) => return config_error_response(error),
                    };
                let redirect_uri = callback_redirect_uri(context, &config);
                resolve_authorization_url_params(
                    &mut config,
                    GenericOAuthFlow::Link,
                    redirect_uri.clone(),
                )
                .await?;
                let state = generate_oauth_state(
                    context,
                    Some(adapter.as_ref()),
                    OAuthStateInput {
                        callback_url: body.callback_url,
                        error_url: body.error_callback_url,
                        // Records which user the callback should link the
                        // provider account to.
                        link: Some(OAuthStateLink {
                            user_id: user.id,
                            email: user.email,
                        }),
                        ..OAuthStateInput::default()
                    },
                )
                .await?;
                // `config` is not needed past this point, so it is moved into
                // the provider instead of being cloned.
                let provider = GenericOAuthProvider::with_discovery_cache(config, discovery_cache);
                let url = provider.create_authorization_url(SocialAuthorizationUrlRequest {
                    state: state.state,
                    redirect_uri,
                    code_verifier: Some(state.data.code_verifier),
                    scopes: body.scopes,
                    login_hint: None,
                })?;
                redirect_json_response(url.to_string(), true)
            })
        },
    )
}

async fn callback_get(
    context: &AuthContext,
    options: &GenericOAuthOptions,
    discovery_cache: &DiscoveryCache,
    request: ApiRequest,
) -> Result<ApiResponse, OpenAuthError> {
    let adapter = adapter(context)?;
    let provider_id = path_param(&request, "providerId")?;
    let mut config = resolved_config(options, discovery_cache, provider_id).await?;
    if let Some(error) = query_param(&request, "error") {
        return redirect_with_error_description(
            &default_error_url(context),
            &error,
            query_param(&request, "error_description").as_deref(),
        );
    }
    let Some(code) = query_param(&request, "code") else {
        return redirect_with_error(&default_error_url(context), "oAuth_code_missing");
    };
    let Some(state) = query_param(&request, "state") else {
        return redirect_with_error(&default_error_url(context), "invalid_state");
    };
    let state_data = match parse_oauth_state(context, Some(adapter.as_ref()), &state).await {
        Ok(data) => data,
        Err(_) => return redirect_with_error(&default_error_url(context), "invalid_state"),
    };
    let error_url = state_data
        .error_url
        .clone()
        .unwrap_or_else(|| default_error_url(context));
    if let Some(error) = issuer_error(&config, query_param(&request, "iss").as_deref()) {
        return redirect_with_error(&error_url, error);
    }
    let redirect_uri = callback_redirect_uri(context, &config);
    if resolve_token_url_params(
        &mut config,
        GenericOAuthFlow::Callback,
        redirect_uri.clone(),
    )
    .await
    .is_err()
    {
        return redirect_with_error(&error_url, "oauth_code_verification_failed");
    }
    let provider =
        GenericOAuthProvider::with_discovery_cache(config.clone(), discovery_cache.clone());
    let tokens = match provider
        .validate_authorization_code(SocialAuthorizationCodeRequest {
            code,
            code_verifier: Some(state_data.code_verifier),
            redirect_uri,
            device_id: query_param(&request, "device_id"),
        })
        .await
    {
        Ok(tokens) => tokens,
        Err(_) => return redirect_with_error(&error_url, "oauth_code_verification_failed"),
    };
    let Some(user_info) = provider.get_user_info(tokens.clone(), None).await? else {
        return redirect_with_error(&error_url, "user_info_is_missing");
    };
    if let Some(link) = state_data.link {
        if let Err(error) = link_account(
            context,
            adapter.as_ref(),
            &config,
            &link,
            &user_info,
            &tokens,
        )
        .await
        {
            return redirect_with_error(&error_url, link_error_code(&error));
        }
        return redirect(&state_data.callback_url, Vec::new());
    }
    let user_info = normalize_user_info(&user_info)?;
    let result = handle_oauth_user_info(
        context,
        adapter.as_ref(),
        HandleOAuthUserInfoInput {
            account: oauth_account(context, &config.provider_id, &user_info.id, &tokens)?,
            user_info,
            callback_url: Some(state_data.callback_url.clone()),
            disable_sign_up: (config.disable_implicit_sign_up && !state_data.request_sign_up)
                || config.disable_sign_up,
            override_user_info: config.override_user_info,
            is_trusted_provider: true,
        },
    )
    .await?;
    let Some(data) = result.data else {
        return redirect_with_error(
            &error_url,
            result
                .error
                .map(oauth_user_info_error)
                .unwrap_or("oauth_sign_in_failed"),
        );
    };
    let mut cookies = set_session_cookie(
        &context.auth_cookies,
        &context.secret,
        &data.session.token,
        SessionCookieOptions::default(),
    )?;
    cookies.extend(result.cookies);
    let target = if result.is_register {
        state_data
            .new_user_url
            .as_deref()
            .unwrap_or(&state_data.callback_url)
    } else {
        &state_data.callback_url
    };
    redirect(target, cookies)
}