rustauth-core 0.2.0

Core types and primitives for RustAuth.
Documentation
use std::collections::HashSet;

use http::Method;

use crate::context::AuthContext;
use crate::error::RustAuthError;
use crate::plugin::{PluginAfterHookAction, PluginBeforeHookAction, PluginRequestAction};
use crate::plugin::{PluginPasswordValidationInput, PluginPasswordValidationRejection};

use crate::rate_limit::on_response_rate_limit;

use super::endpoint::{ApiRequest, ApiResponse, AsyncAuthEndpoint, AuthEndpoint};
use super::path::path_matches;
use super::session_request_state::ensure_session_user_in_request_state;

/// Threads `request` through every plugin `on_request` hook, in the order the plugins
/// appear in `context.plugins`.
///
/// Each hook receives the request produced by the previous hook. The first hook that
/// answers with [`PluginRequestAction::Respond`] stops the chain, and that action is
/// returned as-is. If no hook responds, the final request is returned wrapped in
/// [`PluginRequestAction::Continue`].
///
/// # Errors
/// Propagates the first error returned by a hook; later hooks are not run.
pub(super) fn run_on_request_plugins(
    context: &AuthContext,
    request: ApiRequest,
) -> Result<PluginRequestAction, RustAuthError> {
    let hooks = context
        .plugins
        .iter()
        .filter_map(|plugin| plugin.on_request.as_ref());

    let mut current = request;
    for hook in hooks {
        current = match hook(context, current)? {
            PluginRequestAction::Continue(next) => next,
            respond @ PluginRequestAction::Respond(_) => return Ok(respond),
        };
    }
    Ok(PluginRequestAction::Continue(current))
}

/// Runs every synchronous plugin middleware whose path pattern matches `path`.
///
/// Middlewares are visited plugin by plugin, in declaration order. The first middleware
/// that yields `Some(response)` wins, and that response is returned immediately. `Ok(None)`
/// means every matching middleware declined to respond.
///
/// # Errors
/// Propagates the first error returned by a middleware handler.
pub(super) fn run_matching_middlewares(
    context: &AuthContext,
    request: &ApiRequest,
    path: &str,
) -> Result<Option<ApiResponse>, RustAuthError> {
    let matching = context
        .plugins
        .iter()
        .flat_map(|plugin| &plugin.middlewares)
        .filter(|middleware| path_matches(&middleware.path, path));

    for middleware in matching {
        let outcome = (middleware.handler)(context, request)?;
        if outcome.is_some() {
            return Ok(outcome);
        }
    }
    Ok(None)
}

/// Async counterpart of `run_matching_middlewares`, over each plugin's `async_middlewares`.
///
/// Handlers are awaited one at a time, in plugin order, and only when the middleware's
/// path pattern matches `path`. The first `Some(response)` is returned immediately.
/// `Ok(None)` means no matching middleware produced a response.
///
/// # Errors
/// Propagates the first error returned by an awaited middleware handler.
pub(super) async fn run_matching_async_middlewares(
    context: &AuthContext,
    request: &ApiRequest,
    path: &str,
) -> Result<Option<ApiResponse>, RustAuthError> {
    for owner in &context.plugins {
        for candidate in &owner.async_middlewares {
            if !path_matches(&candidate.path, path) {
                continue;
            }
            let outcome = (candidate.handler)(context, request).await?;
            if outcome.is_some() {
                return Ok(outcome);
            }
        }
    }
    Ok(None)
}

/// Passes `response` through every plugin `on_response` hook, in plugin order.
///
/// Each hook receives the response returned by the previous one. The value returned by
/// the last hook, or the untouched input if no plugin defines a hook, is the result.
///
/// # Errors
/// Stops at, and returns, the first error produced by a hook.
pub(super) fn run_on_response_plugins(
    context: &AuthContext,
    request: &ApiRequest,
    response: ApiResponse,
) -> Result<ApiResponse, RustAuthError> {
    context
        .plugins
        .iter()
        .filter_map(|plugin| plugin.on_response.as_ref())
        .try_fold(response, |current, hook| hook(context, request, current))
}

/// Awaits every plugin `on_response_async` hook, in plugin order, with a shared borrow of
/// the response.
///
/// These hooks cannot replace the response; they only observe it. Each hook's
/// success value is discarded.
///
/// # Errors
/// Returns the first error produced by a hook; later hooks are not awaited.
pub(super) async fn run_on_response_async_plugins(
    context: &AuthContext,
    request: &ApiRequest,
    response: &ApiResponse,
) -> Result<(), RustAuthError> {
    let observers = context
        .plugins
        .iter()
        .filter_map(|plugin| plugin.on_response_async.as_ref());

    for observer in observers {
        observer(context, request, response).await?;
    }
    Ok(())
}

/// Final synchronous step before a response is handed back.
///
/// It first calls `on_response_rate_limit` for `request`, then runs `response` through the
/// plugin `on_response` hooks via [`run_on_response_plugins`].
///
/// # Errors
/// Returns the rate-limit step's error, in which case no hook runs, or the first hook error.
pub(super) fn finalize_response(
    context: &AuthContext,
    request: &ApiRequest,
    response: ApiResponse,
) -> Result<ApiResponse, RustAuthError> {
    // Rate-limit bookkeeping happens before any plugin can rewrite the response.
    on_response_rate_limit(context, request)?;
    let finalized = run_on_response_plugins(context, request, response)?;
    Ok(finalized)
}

/// Async finalize pipeline, run in three steps:
///
/// 1. `ensure_session_user_in_request_state` is awaited for `request`.
/// 2. The observe-only async response hooks run (`run_on_response_async_plugins`).
/// 3. The synchronous [`finalize_response`] runs (rate-limit bookkeeping, then
///    `on_response` hooks).
///
/// # Errors
/// Returns the first error raised by any of the three steps; later steps are skipped.
pub(super) async fn finalize_response_async(
    context: &AuthContext,
    request: &ApiRequest,
    response: ApiResponse,
) -> Result<ApiResponse, RustAuthError> {
    ensure_session_user_in_request_state(context, request).await?;
    run_on_response_async_plugins(context, request, &response).await?;
    let finalized = finalize_response(context, request, response)?;
    Ok(finalized)
}

/// Awaits every plugin password validator, in plugin order, for the given endpoint
/// `path` and candidate `password`.
///
/// Each validator receives a fresh `PluginPasswordValidationInput` built from
/// `path` and `password`.
///
/// # Errors
/// Returns the first [`PluginPasswordValidationRejection`] produced by a validator;
/// remaining validators are not run.
pub(super) async fn run_password_validators(
    context: &AuthContext,
    path: &str,
    password: &str,
) -> Result<(), PluginPasswordValidationRejection> {
    let validators = context
        .plugins
        .iter()
        .flat_map(|plugin| &plugin.password_validators);

    for validator in validators {
        let input = PluginPasswordValidationInput::new(path, password);
        (validator.handler)(context, input).await?;
    }
    Ok(())
}

/// Runs the synchronous `before` hooks whose matcher accepts
/// `(method, path, operation_id)`, in plugin order.
///
/// The request is threaded through each matching hook. The first hook answering with
/// [`PluginBeforeHookAction::Respond`] short-circuits, and that action is returned.
/// Otherwise the final request comes back in [`PluginBeforeHookAction::Continue`].
///
/// # Errors
/// Propagates the first error returned by a hook handler.
pub(super) fn run_before_hooks(
    context: &AuthContext,
    request: ApiRequest,
    method: &Method,
    path: &str,
    operation_id: Option<&str>,
) -> Result<PluginBeforeHookAction, RustAuthError> {
    let applicable = context
        .plugins
        .iter()
        .flat_map(|plugin| &plugin.hooks.before)
        .filter(|hook| hook.matcher.matches(method, path, operation_id));

    let mut current = request;
    for hook in applicable {
        current = match (hook.handler)(context, current)? {
            PluginBeforeHookAction::Continue(next) => next,
            respond @ PluginBeforeHookAction::Respond(_) => return Ok(respond),
        };
    }
    Ok(PluginBeforeHookAction::Continue(current))
}

/// Async counterpart of `run_before_hooks`, over each plugin's `hooks.async_before`.
///
/// Matching handlers are awaited sequentially, and each receives the request returned
/// by the previous one. A [`PluginBeforeHookAction::Respond`] stops the chain and is
/// returned unchanged.
///
/// # Errors
/// Propagates the first error returned by an awaited hook handler.
pub(super) async fn run_async_before_hooks(
    context: &AuthContext,
    request: ApiRequest,
    method: &Method,
    path: &str,
    operation_id: Option<&str>,
) -> Result<PluginBeforeHookAction, RustAuthError> {
    let mut current = request;
    for owner in &context.plugins {
        for hook in &owner.hooks.async_before {
            if !hook.matcher.matches(method, path, operation_id) {
                continue;
            }
            let action = (hook.handler)(context, current).await?;
            match action {
                PluginBeforeHookAction::Continue(next) => current = next,
                respond @ PluginBeforeHookAction::Respond(_) => return Ok(respond),
            }
        }
    }
    Ok(PluginBeforeHookAction::Continue(current))
}

/// Runs the synchronous `after` hooks whose matcher accepts
/// `(method, path, operation_id)`, in plugin order.
///
/// The response is threaded through each matching hook. [`PluginAfterHookAction`] only
/// has a `Continue` variant, so every hook's output simply replaces the current response.
///
/// # Errors
/// Stops at, and returns, the first error produced by a hook handler.
pub(super) fn run_after_hooks(
    context: &AuthContext,
    request: &ApiRequest,
    response: ApiResponse,
    method: &Method,
    path: &str,
    operation_id: Option<&str>,
) -> Result<ApiResponse, RustAuthError> {
    context
        .plugins
        .iter()
        .flat_map(|plugin| &plugin.hooks.after)
        .filter(|hook| hook.matcher.matches(method, path, operation_id))
        .try_fold(
            response,
            |current, hook| -> Result<ApiResponse, RustAuthError> {
                let PluginAfterHookAction::Continue(next) =
                    (hook.handler)(context, request, current)?;
                Ok(next)
            },
        )
}

/// Async counterpart of `run_after_hooks`, over each plugin's `hooks.async_after`.
///
/// Matching handlers are awaited sequentially. Each replaces the current response with
/// the one carried in its [`PluginAfterHookAction::Continue`].
///
/// # Errors
/// Propagates the first error returned by an awaited hook handler.
pub(super) async fn run_async_after_hooks(
    context: &AuthContext,
    request: &ApiRequest,
    response: ApiResponse,
    method: &Method,
    path: &str,
    operation_id: Option<&str>,
) -> Result<ApiResponse, RustAuthError> {
    let mut current = response;
    for owner in &context.plugins {
        for hook in &owner.hooks.async_after {
            if !hook.matcher.matches(method, path, operation_id) {
                continue;
            }
            current = match (hook.handler)(context, request, current).await? {
                PluginAfterHookAction::Continue(next) => next,
            };
        }
    }
    Ok(current)
}

/// Appends clones of every plugin's `endpoints` to `async_endpoints` and returns the
/// combined list.
///
/// The caller's endpoints stay first, followed by each plugin's endpoints in plugin
/// order. The passed-in vector's allocation is reused.
pub(super) fn plugin_async_endpoints(
    context: &AuthContext,
    async_endpoints: Vec<AsyncAuthEndpoint>,
) -> Vec<AsyncAuthEndpoint> {
    let mut combined = async_endpoints;
    combined.extend(
        context
            .plugins
            .iter()
            .flat_map(|plugin| plugin.endpoints.iter().cloned()),
    );
    combined
}

/// Resolves the operation id used for hook matching.
///
/// It prefers the endpoint's explicit `options.operation_id`. Otherwise it falls back to
/// `options.openapi.operation_id` when OpenAPI metadata is attached. Returns `None` when
/// neither is set.
pub(super) fn endpoint_operation_id(endpoint: &AsyncAuthEndpoint) -> Option<&str> {
    let options = &endpoint.options;
    if let Some(explicit) = options.operation_id.as_deref() {
        return Some(explicit);
    }
    options
        .openapi
        .as_ref()
        .and_then(|openapi| openapi.operation_id.as_deref())
}

/// Ensures no two endpoints (sync or async) share the same method and normalized path.
///
/// Paths are normalized with `endpoint_conflict_key`, which collapses every named
/// `:param` segment to a bare `:`. As a result, `/users/:id` and `/users/:name` are
/// treated as the same route. Sync endpoints are registered first, then async ones.
///
/// # Errors
/// Returns `RustAuthError::Api("endpoint conflict for <METHOD> <path>")` for the first
/// endpoint whose route was already claimed. The message names the later of the two
/// clashing endpoints.
pub(super) fn validate_endpoint_conflicts(
    endpoints: &[AuthEndpoint],
    async_endpoints: &[AsyncAuthEndpoint],
) -> Result<(), RustAuthError> {
    // Claims one `(method, normalized path)` route, failing if it is already taken.
    fn claim_route<M>(
        claimed: &mut HashSet<(M, String)>,
        method: &M,
        path: &str,
    ) -> Result<(), RustAuthError>
    where
        M: Clone + Eq + std::hash::Hash + std::fmt::Display,
    {
        let is_new = claimed.insert((method.clone(), endpoint_conflict_key(path)));
        if is_new {
            Ok(())
        } else {
            Err(RustAuthError::Api(format!(
                "endpoint conflict for {} {}",
                method, path
            )))
        }
    }

    let mut claimed = HashSet::new();
    for endpoint in endpoints {
        claim_route(&mut claimed, &endpoint.method, &endpoint.path)?;
    }
    for endpoint in async_endpoints {
        claim_route(&mut claimed, &endpoint.method, &endpoint.path)?;
    }
    Ok(())
}

/// Normalizes a route path for conflict detection.
///
/// Every `/`-separated segment that starts with `:` and has at least one character after
/// it (a named parameter such as `:id`) is replaced by a bare `:`. All other segments,
/// including empty ones and a lone `:`, are kept verbatim, so separators are preserved.
fn endpoint_conflict_key(path: &str) -> String {
    let mut key = String::with_capacity(path.len());
    for (index, segment) in path.split('/').enumerate() {
        if index > 0 {
            key.push('/');
        }
        let is_named_param = segment.len() > 1 && segment.starts_with(':');
        key.push_str(if is_named_param { ":" } else { segment });
    }
    key
}