rustauth-sso 0.3.0

Single sign-on support for RustAuth.
Documentation
use std::sync::Arc;

use http::Method;
use rustauth_core::api::{
    create_auth_endpoint, parse_request_body, AsyncAuthEndpoint, AuthEndpointOptions,
    OpenApiOperation,
};
use serde_json::json;

use crate::audit;
use crate::openapi::{
    error_code_response, provider_id_body_schema, provider_id_query_parameter,
    sso_provider_list_response, sso_provider_response, success_response,
};
use crate::options::{SsoAuditEvent, SsoAuditEventKind, SsoAuditSeverity, SsoOptions};
use crate::org::{accessible_providers, can_manage_provider};
use crate::store::SsoProviderStore;
use crate::utils;

use super::support::{
    authenticated_user, invalid_provider_id, query_param, unauthorized, valid_provider_id,
    ProviderIdBody,
};

pub(super) fn list_endpoint(options: Arc<SsoOptions>) -> AsyncAuthEndpoint {
    create_auth_endpoint(
        "/sso/providers",
        Method::GET,
        AuthEndpointOptions::new()
            .operation_id("listSSOProviders")
            .openapi(
                OpenApiOperation::new("listSSOProviders")
                    .tag("SSO")
                    .response("200", sso_provider_list_response()),
            ),
        {
            let options = Arc::clone(&options);
            move |context, request| {
                let options = Arc::clone(&options);
                async move {
                    let Some((adapter, user_id)) = authenticated_user(&context, &request).await?
                    else {
                        return unauthorized();
                    };
                    let providers = accessible_providers(
                        &context,
                        adapter.as_ref(),
                        &user_id,
                        SsoProviderStore::new_with_options(adapter.as_ref(), &options)
                            .list()
                            .await?,
                    )
                    .await?
                    .into_iter()
                    .map(|provider| {
                        provider.sanitized_with_options(&context.base_url, Some(&options))
                    })
                    .collect::<Vec<_>>();
                    utils::json(http::StatusCode::OK, &json!({ "providers": providers }))
                }
            }
        },
    )
}

pub(super) fn get_endpoint(options: Arc<SsoOptions>) -> AsyncAuthEndpoint {
    create_auth_endpoint(
        "/sso/get-provider",
        Method::GET,
        AuthEndpointOptions::new()
            .operation_id("getSSOProvider")
            .allowed_media_types(["application/json", "application/x-www-form-urlencoded"])
            .openapi(
                OpenApiOperation::new("getSSOProvider")
                    .tag("SSO")
                    .parameter(provider_id_query_parameter())
                    .response("200", sso_provider_response("SSO provider"))
                    .response("400", error_code_response("Invalid provider id"))
                    .response("404", error_code_response("Provider not found")),
            ),
        {
            let options = Arc::clone(&options);
            move |context, request| {
                let options = Arc::clone(&options);
                async move {
                    let Some((adapter, user_id)) = authenticated_user(&context, &request).await?
                    else {
                        return unauthorized();
                    };
                    let provider_id = match query_param(&request, "providerId")
                        .filter(|value| !value.is_empty())
                    {
                        Some(provider_id) => provider_id,
                        None => parse_request_body::<ProviderIdBody>(&request)?.provider_id,
                    };
                    if !valid_provider_id(&provider_id) {
                        return invalid_provider_id();
                    }
                    let Some(provider) =
                        SsoProviderStore::new_with_options(adapter.as_ref(), &options)
                            .find_by_provider_id(&provider_id)
                            .await?
                    else {
                        return utils::json(
                            http::StatusCode::NOT_FOUND,
                            &json!({"code": "PROVIDER_NOT_FOUND"}),
                        );
                    };
                    if !can_manage_provider(&context, adapter.as_ref(), &user_id, &provider).await?
                    {
                        return utils::json(
                            http::StatusCode::FORBIDDEN,
                            &json!({"code": "FORBIDDEN"}),
                        );
                    }
                    utils::json(
                        http::StatusCode::OK,
                        &provider.sanitized_with_options(&context.base_url, Some(&options)),
                    )
                }
            }
        },
    )
}

pub(super) fn delete_endpoint(options: Arc<SsoOptions>) -> AsyncAuthEndpoint {
    create_auth_endpoint(
        "/sso/delete-provider",
        Method::POST,
        AuthEndpointOptions::new()
            .operation_id("deleteSSOProvider")
            .allowed_media_types(["application/json", "application/x-www-form-urlencoded"])
            .body_schema(provider_id_body_schema())
            .openapi(
                OpenApiOperation::new("deleteSSOProvider")
                    .tag("SSO")
                    .response("200", success_response("Provider deleted")),
            ),
        {
            let options = Arc::clone(&options);
            move |context, request| {
                let options = Arc::clone(&options);
                async move {
                    let Some((adapter, user_id)) = authenticated_user(&context, &request).await?
                    else {
                        return unauthorized();
                    };
                    let body = parse_request_body::<ProviderIdBody>(&request)?;
                    if !valid_provider_id(&body.provider_id) {
                        return invalid_provider_id();
                    }
                    let store = SsoProviderStore::new_with_options(adapter.as_ref(), &options);
                    let Some(provider) = store.find_by_provider_id(&body.provider_id).await? else {
                        return utils::json(
                            http::StatusCode::NOT_FOUND,
                            &json!({"code": "PROVIDER_NOT_FOUND"}),
                        );
                    };
                    if !can_manage_provider(&context, adapter.as_ref(), &user_id, &provider).await?
                    {
                        return utils::json(
                            http::StatusCode::FORBIDDEN,
                            &json!({"code": "FORBIDDEN"}),
                        );
                    }
                    store.delete(&body.provider_id).await?;
                    let mut event = SsoAuditEvent::new(
                        SsoAuditEventKind::ProviderDeleted,
                        SsoAuditSeverity::Warn,
                    )
                    .provider_id(provider.provider_id.clone())
                    .user_id(user_id);
                    if let Some(organization_id) = provider.organization_id {
                        event = event.organization_id(organization_id);
                    }
                    audit::emit(&context, &options, event).await;
                    utils::json(http::StatusCode::OK, &json!({"success": true}))
                }
            }
        },
    )
}