rustauth-sso 0.3.0

Single sign-on support for RustAuth.
Documentation
use std::sync::{Arc, OnceLock};

use http::Method;
use rustauth_core::api::{
    create_auth_endpoint, parse_request_body, AsyncAuthEndpoint, AuthEndpointOptions,
    OpenApiOperation,
};
use rustauth_core::crypto::random::generate_random_string;
use serde_json::json;
use time::{Duration, OffsetDateTime};

use crate::audit;
use crate::openapi::{
    domain_verification_token_response, error_code_response, provider_id_body_schema,
    success_response,
};
use crate::options::{SsoAuditEvent, SsoAuditEventKind, SsoAuditSeverity, SsoOptions};
use crate::org::can_verify_provider_domain;
use crate::state::SsoStateStore;
use crate::store::SsoProviderStore;
use crate::utils;

use super::support::{authenticated_user, unauthorized, ProviderIdBody};

/// Maximum length, in bytes, of a single DNS label (RFC 1035 §2.3.4).
///
/// The verification identifier is prepended to the provider hostname as a label of the TXT
/// lookup name, so the verify endpoint rejects identifiers longer than this.
const DNS_LABEL_MAX_LENGTH: usize = 63;

/// Lazily built, process-wide resolver used by `resolve_txt_records` when no custom
/// `txt_resolver` is configured.
///
/// `OnceLock` initializes only once, so if `build_dns_resolver` fails, the error string is
/// stored and returned on every later lookup rather than retried.
static DNS_RESOLVER: OnceLock<Result<hickory_resolver::TokioResolver, String>> = OnceLock::new();

pub(super) fn request_endpoint(options: Arc<SsoOptions>) -> AsyncAuthEndpoint {
    create_auth_endpoint(
        "/sso/request-domain-verification",
        Method::POST,
        AuthEndpointOptions::new()
            .operation_id("requestDomainVerification")
            .allowed_media_types(["application/json", "application/x-www-form-urlencoded"])
            .body_schema(provider_id_body_schema())
            .openapi(
                OpenApiOperation::new("requestDomainVerification")
                    .tag("SSO")
                    .response(
                        "201",
                        domain_verification_token_response("Domain verification token"),
                    )
                    .response("404", error_code_response("Provider not found")),
            ),
        {
            let options = Arc::clone(&options);
            move |context, request| {
                let options = Arc::clone(&options);
                async move {
                    let Some((adapter, user_id)) = authenticated_user(&context, &request).await?
                    else {
                        return unauthorized();
                    };
                    let body = parse_request_body::<ProviderIdBody>(&request)?;
                    let Some(provider) =
                        SsoProviderStore::new_with_options(adapter.as_ref(), &options)
                            .find_by_provider_id(&body.provider_id)
                            .await?
                    else {
                        return utils::json(
                            http::StatusCode::NOT_FOUND,
                            &json!({"code": "PROVIDER_NOT_FOUND"}),
                        );
                    };
                    if !can_verify_provider_domain(&context, adapter.as_ref(), &user_id, &provider)
                        .await?
                    {
                        return utils::json(
                            http::StatusCode::FORBIDDEN,
                            &json!({"code": "FORBIDDEN"}),
                        );
                    }
                    if provider.domain_verified.unwrap_or(false) {
                        return utils::json(
                            http::StatusCode::CONFLICT,
                            &json!({"code": "DOMAIN_VERIFIED"}),
                        );
                    }
                    let identifier = verification_identifier(&options, &provider.provider_id);
                    let state_store = SsoStateStore::new(&context, adapter.as_ref());
                    if let Some(active) = state_store.find(&identifier).await? {
                        audit::emit(
                            &context,
                            &options,
                            SsoAuditEvent::new(
                                SsoAuditEventKind::DomainVerificationRequested,
                                SsoAuditSeverity::Info,
                            )
                            .provider_id(provider.provider_id.clone())
                            .user_id(user_id),
                        )
                        .await;
                        return utils::json(
                            http::StatusCode::CREATED,
                            &json!({"domainVerificationToken": active.value}),
                        );
                    }
                    let token = generate_random_string(24);
                    state_store
                        .create(
                            identifier,
                            token.clone(),
                            OffsetDateTime::now_utc()
                                + Duration::seconds(
                                    options.domain_verification.token_ttl_seconds as i64,
                                ),
                        )
                        .await?;
                    audit::emit(
                        &context,
                        &options,
                        SsoAuditEvent::new(
                            SsoAuditEventKind::DomainVerificationRequested,
                            SsoAuditSeverity::Info,
                        )
                        .provider_id(provider.provider_id.clone())
                        .user_id(user_id),
                    )
                    .await;
                    utils::json(
                        http::StatusCode::CREATED,
                        &json!({"domainVerificationToken": token}),
                    )
                }
            }
        },
    )
}

pub(super) fn verify_endpoint(options: Arc<SsoOptions>) -> AsyncAuthEndpoint {
    create_auth_endpoint(
        "/sso/verify-domain",
        Method::POST,
        AuthEndpointOptions::new()
            .operation_id("verifyDomain")
            .allowed_media_types(["application/json", "application/x-www-form-urlencoded"])
            .body_schema(provider_id_body_schema())
            .openapi(
                OpenApiOperation::new("verifyDomain")
                    .tag("SSO")
                    .response("200", success_response("Domain verified"))
                    .response("404", error_code_response("Provider or token not found"))
                    .response("409", error_code_response("Domain already verified"))
                    .response("502", error_code_response("DNS verification failed")),
            ),
        {
            let options = Arc::clone(&options);
            move |context, request| {
                let options = Arc::clone(&options);
                async move {
                    let Some((adapter, user_id)) = authenticated_user(&context, &request).await?
                    else {
                        return unauthorized();
                    };
                    let body = parse_request_body::<ProviderIdBody>(&request)?;
                    let store = SsoProviderStore::new_with_options(adapter.as_ref(), &options);
                    let Some(provider) = store.find_by_provider_id(&body.provider_id).await? else {
                        return utils::json(
                            http::StatusCode::NOT_FOUND,
                            &json!({"code": "PROVIDER_NOT_FOUND"}),
                        );
                    };
                    if !can_verify_provider_domain(&context, adapter.as_ref(), &user_id, &provider)
                        .await?
                    {
                        return utils::json(
                            http::StatusCode::FORBIDDEN,
                            &json!({"code": "FORBIDDEN"}),
                        );
                    }
                    if provider.domain_verified.unwrap_or(false) {
                        return utils::json(
                            http::StatusCode::CONFLICT,
                            &json!({"code": "DOMAIN_VERIFIED"}),
                        );
                    }

                    let identifier = verification_identifier(&options, &provider.provider_id);
                    if identifier.len() > DNS_LABEL_MAX_LENGTH {
                        return utils::json(
                            http::StatusCode::BAD_REQUEST,
                            &json!({"code": "IDENTIFIER_TOO_LONG", "message": "Verification identifier exceeds the DNS label limit"}),
                        );
                    }

                    let state_store = SsoStateStore::new(&context, adapter.as_ref());
                    let Some(active) = state_store.find(&identifier).await? else {
                        return utils::json(
                            http::StatusCode::NOT_FOUND,
                            &json!({"code": "NO_PENDING_VERIFICATION"}),
                        );
                    };
                    let Some(hostname) = verification_hostname(&provider.domain) else {
                        return utils::json(
                            http::StatusCode::BAD_REQUEST,
                            &json!({"code": "INVALID_DOMAIN", "message": "Invalid domain"}),
                        );
                    };

                    let expected = format!("{}={}", active.identifier, active.value);
                    let records =
                        match resolve_txt_records(&options, &format!("{identifier}.{hostname}"))
                            .await
                        {
                            Ok(records) => records,
                            Err(_) => {
                                audit::emit(
                                    &context,
                                    &options,
                                    SsoAuditEvent::new(
                                        SsoAuditEventKind::DomainVerificationFailed,
                                        SsoAuditSeverity::Warn,
                                    )
                                    .provider_id(provider.provider_id.clone())
                                    .user_id(user_id.clone())
                                    .reason("resolver_error"),
                                )
                                .await;
                                return domain_verification_failed("resolver_error");
                            }
                        };
                    if records.is_empty() {
                        audit::emit(
                            &context,
                            &options,
                            SsoAuditEvent::new(
                                SsoAuditEventKind::DomainVerificationFailed,
                                SsoAuditSeverity::Warn,
                            )
                            .provider_id(provider.provider_id.clone())
                            .user_id(user_id.clone())
                            .reason("no_txt_records"),
                        )
                        .await;
                        return domain_verification_failed("no_txt_records");
                    }
                    if !records
                        .iter()
                        .any(|record| utils::constant_time_eq(record.trim(), &expected))
                    {
                        audit::emit(
                            &context,
                            &options,
                            SsoAuditEvent::new(
                                SsoAuditEventKind::DomainVerificationFailed,
                                SsoAuditSeverity::Warn,
                            )
                            .provider_id(provider.provider_id.clone())
                            .user_id(user_id.clone())
                            .reason("txt_value_mismatch"),
                        )
                        .await;
                        return domain_verification_failed("txt_value_mismatch");
                    }

                    store
                        .update_domain_verified(&provider.provider_id, true)
                        .await?;
                    audit::emit(
                        &context,
                        &options,
                        SsoAuditEvent::new(
                            SsoAuditEventKind::DomainVerificationSucceeded,
                            SsoAuditSeverity::Info,
                        )
                        .provider_id(provider.provider_id.clone())
                        .user_id(user_id),
                    )
                    .await;
                    http::Response::builder()
                        .status(http::StatusCode::NO_CONTENT)
                        .body(Vec::new())
                        .map_err(|error| {
                            rustauth_core::error::RustAuthError::Api(error.to_string())
                        })
                }
            }
        },
    )
}

/// Returns the per-provider verification identifier, `_<token_prefix>-<provider_id>`.
///
/// The identifier is both the key of the pending token in the state store and the DNS label
/// prepended to the provider hostname for the TXT lookup.
pub(super) fn verification_identifier(options: &SsoOptions, provider_id: &str) -> String {
    let prefix = &options.domain_verification.token_prefix;
    format!("_{prefix}-{provider_id}")
}

/// Builds the `502 Bad Gateway` response returned when DNS-based verification fails.
///
/// `reason` is a short machine-readable tag (e.g. `"no_txt_records"`) echoed in the body
/// alongside a generic, retry-oriented message.
fn domain_verification_failed(
    reason: &'static str,
) -> Result<http::Response<Vec<u8>>, rustauth_core::error::RustAuthError> {
    let payload = json!({
        "code": "DOMAIN_VERIFICATION_FAILED",
        "message": "Unable to verify domain ownership. Try again later",
        "reason": reason,
    });
    utils::json(http::StatusCode::BAD_GATEWAY, &payload)
}

/// Extracts the hostname under which the verification TXT record is looked up.
///
/// Only the first entry of a comma-separated `domain` list is used. The following are removed:
/// - a leading URL scheme (matched case-insensitively, e.g. `HTTPS://`);
/// - any path, query or fragment;
/// - `userinfo@`;
/// - a trailing `:port`;
/// - trailing dots.
///
/// The result is lowercased. Returns `None` when no host remains.
fn verification_hostname(domain: &str) -> Option<String> {
    let first = domain.split(',').next()?.trim();
    // Strip an RFC 3986 scheme (ALPHA *( ALPHA / DIGIT / "+" / "-" / "." )) plus "://";
    // matching the grammar instead of fixed literals also handles upper-case schemes.
    let without_scheme = match first.split_once("://") {
        Some((scheme, rest))
            if scheme.starts_with(|c: char| c.is_ascii_alphabetic())
                && scheme
                    .chars()
                    .all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '-' | '.')) =>
        {
            rest
        }
        _ => first,
    };
    // The authority ends at the first path, query, or fragment delimiter.
    let authority = without_scheme
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()?;
    // Drop any `user[:password]@` prefix.
    let host_port = authority
        .rsplit_once('@')
        .map_or(authority, |(_, host)| host);
    // Drop a trailing `:port`, which would otherwise end up inside the DNS query name.
    let host = match host_port.rsplit_once(':') {
        Some((host, port)) if port.chars().all(|c| c.is_ascii_digit()) => host,
        _ => host_port,
    };
    let host = host.trim().trim_end_matches('.');
    (!host.is_empty()).then(|| host.to_ascii_lowercase())
}

/// Resolves the TXT records published at `name`.
///
/// When `options.domain_verification.txt_resolver` is set, the lookup is delegated to it and
/// its result is returned as-is. Otherwise the shared `DNS_RESOLVER` is used, built on first
/// call via `build_dns_resolver`. For each TXT answer, its character-strings are
/// concatenated into one `String`, with invalid UTF-8 replaced lossily. Non-TXT answers are
/// skipped.
///
/// # Errors
/// Returns `RustAuthError::Api` if the default resolver could not be built (this failure is
/// cached) or if the TXT lookup fails.
///
/// NOTE(review): depending on the hickory version, a name with no TXT records may surface as
/// a lookup error rather than an empty answer list. Callers would then report it as a resolver
/// error instead of "no records" — confirm.
async fn resolve_txt_records(
    options: &SsoOptions,
    name: &str,
) -> Result<Vec<String>, rustauth_core::error::RustAuthError> {
    if let Some(resolver) = &options.domain_verification.txt_resolver {
        return resolver.resolve(name).await;
    }
    let resolver = DNS_RESOLVER
        .get_or_init(|| {
            build_dns_resolver().map_err(|error| format!("failed to build DNS resolver: {error}"))
        })
        .as_ref()
        .map_err(|error| rustauth_core::error::RustAuthError::Api(error.clone()))?;
    let lookup = resolver.txt_lookup(name).await.map_err(|error| {
        rustauth_core::error::RustAuthError::Api(format!("DNS TXT lookup failed: {error}"))
    })?;
    Ok(lookup
        .answers()
        .iter()
        .filter_map(|record| match &record.data {
            // A single TXT record may be split into several character-strings; join them so
            // the caller sees the full record value.
            hickory_resolver::proto::rr::RData::TXT(txt) => Some(
                txt.txt_data
                    .iter()
                    .map(|chunk| String::from_utf8_lossy(chunk).into_owned())
                    .collect::<String>(),
            ),
            _ => None,
        })
        .collect())
}

/// Builds the Tokio-backed resolver via `Resolver::builder_tokio()`.
///
/// Creating the builder and building the resolver can each fail. Either error is boxed
/// into `Box<dyn Error>` and returned to the caller.
fn build_dns_resolver() -> Result<hickory_resolver::TokioResolver, Box<dyn std::error::Error>> {
    let builder = hickory_resolver::Resolver::builder_tokio()?;
    let resolver = builder.build()?;
    Ok(resolver)
}