//! rustauth-sso 0.3.0
//!
//! Single sign-on support for RustAuth.
use rustauth_core::api::{ApiRequest, ApiResponse};
#[cfg(feature = "saml")]
use rustauth_core::auth::session::{GetSessionInput, SessionAuth};
use rustauth_core::context::request_state::current_new_session;
use rustauth_core::context::AuthContext;
#[cfg(feature = "saml")]
use rustauth_core::db::DbAdapter;
use rustauth_core::error::RustAuthError;
use rustauth_core::plugin::PluginAfterHookAction;
#[cfg(feature = "saml")]
use rustauth_core::plugin::PluginBeforeHookAction;
use std::sync::Arc;

use crate::linking_impl::assign_organization_by_domain_with_model;
use crate::options::SsoOptions;
#[cfg(feature = "saml")]
use crate::saml_impl::state::{saml_session_by_id_key, SESSION_PREFIX};
#[cfg(feature = "saml")]
use crate::state::SsoStateStore;

/// Request extension that remembers which session a sign-out request targeted.
///
/// Stored on the request by `capture_sign_out_session` (a before hook) and
/// read back by `cleanup_sign_out_session` (an after hook).
#[cfg(feature = "saml")]
#[derive(Clone, Debug)]
struct SignOutSamlSession {
    /// Id of the session resolved from the sign-out request's `Cookie` header.
    session_id: String,
}

/// Before hook for sign-out: records the id of the session being signed out.
///
/// Resolves the current session from the request's `Cookie` header without
/// refreshing it. When a session is found, its id is attached to the request
/// as a `SignOutSamlSession` extension so the matching after hook can clear the
/// SAML lookup state. The request is always passed on via
/// `PluginBeforeHookAction::Continue`. Nothing is attached when no database
/// adapter is configured or when no session is resolved.
///
/// # Errors
///
/// Propagates errors from `SessionAuth::new` and from the session lookup.
#[cfg(feature = "saml")]
pub(crate) async fn capture_sign_out_session(
    context: &AuthContext,
    mut request: ApiRequest,
) -> Result<PluginBeforeHookAction, RustAuthError> {
    // The after hook can only clean up when an adapter exists, so skip the
    // session lookup entirely without one.
    if context.adapter.is_none() {
        return Ok(PluginBeforeHookAction::Continue(request));
    }

    // A missing or non-visible-ASCII Cookie header is treated as empty.
    let cookies = match request
        .headers()
        .get(http::header::COOKIE)
        .map(|value| value.to_str())
    {
        Some(Ok(raw)) => raw.to_owned(),
        _ => String::new(),
    };

    let session_auth = SessionAuth::new(context)?;
    let found = session_auth
        .get_session(GetSessionInput::new(cookies).disable_refresh())
        .await?;

    let marker = found
        .and_then(|result| result.session)
        .map(|session| SignOutSamlSession {
            session_id: session.id,
        });
    if let Some(marker) = marker {
        request.extensions_mut().insert(marker);
    }

    Ok(PluginBeforeHookAction::Continue(request))
}

/// After hook for sign-out: clears the SAML session lookup state for the
/// session recorded by `capture_sign_out_session`.
///
/// Cleanup runs only when all three conditions hold:
/// - the sign-out response has a success status;
/// - a database adapter is configured;
/// - the request carries a `SignOutSamlSession` extension.
///
/// The response is always passed through unchanged.
///
/// # Errors
///
/// Propagates errors from `clear_saml_session_lookup_state`.
#[cfg(feature = "saml")]
pub(crate) async fn cleanup_sign_out_session(
    context: &AuthContext,
    request: &ApiRequest,
    response: ApiResponse,
) -> Result<PluginAfterHookAction, RustAuthError> {
    if !response.status().is_success() {
        return Ok(PluginAfterHookAction::Continue(response));
    }
    let Some(adapter) = context.adapter.as_deref() else {
        return Ok(PluginAfterHookAction::Continue(response));
    };
    let Some(marker) = request.extensions().get::<SignOutSamlSession>() else {
        return Ok(PluginAfterHookAction::Continue(response));
    };

    clear_saml_session_lookup_state(context, adapter, &marker.session_id).await?;
    Ok(PluginAfterHookAction::Continue(response))
}

/// After hook that assigns the user of a newly created session to an
/// organization by domain.
///
/// The response is returned unchanged without doing anything when any of
/// these hold:
/// - the response status is not successful;
/// - no database adapter is configured;
/// - the request did not produce a new session.
///
/// Otherwise the work is delegated to `assign_organization_by_domain_with_model`.
/// It receives the plugin's `model_name`, `organization_provisioning` and
/// `domain_verification` options.
///
/// # Errors
///
/// Propagates errors from `current_new_session` and from the organization
/// assignment.
pub(crate) async fn assign_domain_organization_after_auth(
    context: &AuthContext,
    _request: &ApiRequest,
    response: ApiResponse,
    options: Arc<SsoOptions>,
) -> Result<PluginAfterHookAction, RustAuthError> {
    let succeeded = response.status().is_success();
    let adapter = match context.adapter.as_deref() {
        Some(adapter) if succeeded => adapter,
        _ => return Ok(PluginAfterHookAction::Continue(response)),
    };
    // Only consulted once the response and adapter checks have passed.
    let new_session = match current_new_session()? {
        Some(new_session) => new_session,
        None => return Ok(PluginAfterHookAction::Continue(response)),
    };

    let settings: &SsoOptions = &options;
    assign_organization_by_domain_with_model(
        context,
        adapter,
        &settings.model_name,
        &settings.organization_provisioning,
        &settings.domain_verification,
        &new_session.user,
    )
    .await?;

    Ok(PluginAfterHookAction::Continue(response))
}

/// Removes the SAML state-store records tied to an auth session.
///
/// Looks up the record keyed by `saml_session_by_id_key(session_id)`. If it
/// does not exist, this is a no-op. Otherwise, when that record's value starts
/// with `SESSION_PREFIX`, the record it names is deleted first. The
/// by-session-id record itself is then deleted.
///
/// # Errors
///
/// Propagates any error from the state store's `find` or `delete` calls.
#[cfg(feature = "saml")]
pub(crate) async fn clear_saml_session_lookup_state(
    context: &AuthContext,
    adapter: &dyn DbAdapter,
    session_id: &str,
) -> Result<(), RustAuthError> {
    let store = SsoStateStore::new(context, adapter);
    let index_key = saml_session_by_id_key(session_id);
    let found = store.find(&index_key).await?;

    match found {
        None => Ok(()),
        Some(entry) => {
            // Only follow the stored value when it has the SAML session
            // prefix. This is presumably a guard so an unexpected value
            // cannot delete an unrelated record.
            if entry.value.starts_with(SESSION_PREFIX) {
                store.delete(&entry.value).await?;
            }
            store.delete(&index_key).await
        }
    }
}