openauth-core 0.0.4

Core types and primitives for OpenAuth.
Documentation
use std::collections::HashSet;

use http::Method;

use crate::context::AuthContext;
use crate::error::OpenAuthError;
use crate::plugin::{PluginAfterHookAction, PluginBeforeHookAction, PluginRequestAction};
use crate::plugin::{PluginPasswordValidationInput, PluginPasswordValidationRejection};

use super::endpoint::{ApiRequest, ApiResponse, AsyncAuthEndpoint, AuthEndpoint};
use super::path::path_matches;

/// Passes `request` through each plugin's `on_request` hook, in plugin order.
///
/// Plugins without an `on_request` hook are skipped. A hook that yields
/// `PluginRequestAction::Continue` supplies the request handed to the next
/// hook; the first `PluginRequestAction::Respond` ends the chain and is
/// returned as-is. If no hook responds, the final request is returned wrapped
/// in `PluginRequestAction::Continue`.
///
/// # Errors
///
/// Returns the first error produced by a hook; later hooks are not run.
pub(super) fn run_on_request_plugins(
    context: &AuthContext,
    mut request: ApiRequest,
) -> Result<PluginRequestAction, OpenAuthError> {
    let hooks = context
        .plugins
        .iter()
        .filter_map(|plugin| plugin.on_request.as_ref());
    for hook in hooks {
        request = match hook(context, request)? {
            PluginRequestAction::Continue(next_request) => next_request,
            // Hand the short-circuit action back untouched.
            short_circuit @ PluginRequestAction::Respond(_) => return Ok(short_circuit),
        };
    }
    Ok(PluginRequestAction::Continue(request))
}

/// Runs the synchronous plugin middlewares whose path pattern matches `path`.
///
/// Middlewares are visited plugin by plugin, in declaration order, and only
/// those for which `path_matches(&middleware.path, path)` holds are invoked.
/// The first middleware that produces a response stops the walk.
///
/// Returns `Ok(Some(response))` for that first response, or `Ok(None)` when
/// no matching middleware produced one.
///
/// # Errors
///
/// Returns the first error produced by a middleware handler.
pub(super) fn run_matching_middlewares(
    context: &AuthContext,
    request: &ApiRequest,
    path: &str,
) -> Result<Option<ApiResponse>, OpenAuthError> {
    let matching = context
        .plugins
        .iter()
        .flat_map(|plugin| plugin.middlewares.iter())
        .filter(|middleware| path_matches(&middleware.path, path));
    for middleware in matching {
        let outcome = (middleware.handler)(context, request)?;
        if outcome.is_some() {
            return Ok(outcome);
        }
    }
    Ok(None)
}

/// Runs the asynchronous plugin middlewares whose path pattern matches `path`.
///
/// Middlewares are awaited one at a time, plugin by plugin, in declaration
/// order. Non-matching middlewares are skipped, and the first one that
/// produces a response stops the walk.
///
/// Returns `Ok(Some(response))` for that first response, or `Ok(None)` when
/// no matching middleware produced one.
///
/// # Errors
///
/// Returns the first error produced by a middleware handler.
pub(super) async fn run_matching_async_middlewares(
    context: &AuthContext,
    request: &ApiRequest,
    path: &str,
) -> Result<Option<ApiResponse>, OpenAuthError> {
    for plugin in &context.plugins {
        for middleware in &plugin.async_middlewares {
            if !path_matches(&middleware.path, path) {
                continue;
            }
            let outcome = (middleware.handler)(context, request).await?;
            if outcome.is_some() {
                return Ok(outcome);
            }
        }
    }
    Ok(None)
}

/// Passes `response` through each plugin's `on_response` hook, in plugin order.
///
/// Plugins without an `on_response` hook are skipped. Each hook receives the
/// response returned by the previous one and its result replaces it; the
/// response left after the last hook is returned.
///
/// # Errors
///
/// Returns the first error produced by a hook; later hooks are not run.
pub(super) fn run_on_response_plugins(
    context: &AuthContext,
    request: &ApiRequest,
    mut response: ApiResponse,
) -> Result<ApiResponse, OpenAuthError> {
    let hooks = context
        .plugins
        .iter()
        .flat_map(|plugin| plugin.on_response.iter());
    for hook in hooks {
        response = hook(context, request, response)?;
    }
    Ok(response)
}

/// Awaits every plugin password validator against `password`.
///
/// Validators run sequentially, plugin by plugin, in declaration order. Each
/// one receives a fresh `PluginPasswordValidationInput` built from `path` and
/// `password`.
///
/// # Errors
///
/// Returns the first rejection produced by a validator; remaining validators
/// are not run.
pub(super) async fn run_password_validators(
    context: &AuthContext,
    path: &str,
    password: &str,
) -> Result<(), PluginPasswordValidationRejection> {
    let validators = context
        .plugins
        .iter()
        .flat_map(|plugin| plugin.password_validators.iter());
    for validator in validators {
        let input = PluginPasswordValidationInput::new(path, password);
        (validator.handler)(context, input).await?;
    }
    Ok(())
}

/// Runs the synchronous "before" hooks that match the current endpoint.
///
/// Hooks are visited plugin by plugin, in declaration order. A hook is
/// invoked only when its matcher accepts `method`, `path`, and
/// `operation_id`. A `PluginBeforeHookAction::Continue` result supplies the
/// request passed to the next hook; the first
/// `PluginBeforeHookAction::Respond` ends the chain and is returned. If no
/// hook responds, the final request is returned wrapped in
/// `PluginBeforeHookAction::Continue`.
///
/// # Errors
///
/// Returns the first error produced by a hook handler.
pub(super) fn run_before_hooks(
    context: &AuthContext,
    mut request: ApiRequest,
    method: &Method,
    path: &str,
    operation_id: Option<&str>,
) -> Result<PluginBeforeHookAction, OpenAuthError> {
    let applicable = context
        .plugins
        .iter()
        .flat_map(|plugin| plugin.hooks.before.iter())
        .filter(|hook| hook.matcher.matches(method, path, operation_id));
    for hook in applicable {
        request = match (hook.handler)(context, request)? {
            PluginBeforeHookAction::Continue(next_request) => next_request,
            PluginBeforeHookAction::Respond(response) => {
                return Ok(PluginBeforeHookAction::Respond(response));
            }
        };
    }
    Ok(PluginBeforeHookAction::Continue(request))
}

/// Runs the asynchronous "before" hooks that match the current endpoint.
///
/// Hooks are awaited one at a time, plugin by plugin, in declaration order.
/// Hooks whose matcher rejects `method`, `path`, and `operation_id` are
/// skipped. A `PluginBeforeHookAction::Continue` result supplies the request
/// passed to the next hook; the first `PluginBeforeHookAction::Respond` ends
/// the chain and is returned as-is. If no hook responds, the final request is
/// returned wrapped in `PluginBeforeHookAction::Continue`.
///
/// # Errors
///
/// Returns the first error produced by a hook handler.
pub(super) async fn run_async_before_hooks(
    context: &AuthContext,
    mut request: ApiRequest,
    method: &Method,
    path: &str,
    operation_id: Option<&str>,
) -> Result<PluginBeforeHookAction, OpenAuthError> {
    for plugin in &context.plugins {
        for hook in &plugin.hooks.async_before {
            if !hook.matcher.matches(method, path, operation_id) {
                continue;
            }
            request = match (hook.handler)(context, request).await? {
                PluginBeforeHookAction::Continue(next_request) => next_request,
                // Hand the short-circuit action back untouched.
                early @ PluginBeforeHookAction::Respond(_) => return Ok(early),
            };
        }
    }
    Ok(PluginBeforeHookAction::Continue(request))
}

/// Runs the synchronous "after" hooks that match the current endpoint.
///
/// Hooks are visited plugin by plugin, in declaration order, and invoked only
/// when their matcher accepts `method`, `path`, and `operation_id`. Each
/// handler receives the response produced by the previous one; the response
/// carried by its `PluginAfterHookAction::Continue` result replaces it.
///
/// # Errors
///
/// Returns the first error produced by a hook handler.
pub(super) fn run_after_hooks(
    context: &AuthContext,
    request: &ApiRequest,
    mut response: ApiResponse,
    method: &Method,
    path: &str,
    operation_id: Option<&str>,
) -> Result<ApiResponse, OpenAuthError> {
    let applicable = context
        .plugins
        .iter()
        .flat_map(|plugin| plugin.hooks.after.iter())
        .filter(|hook| hook.matcher.matches(method, path, operation_id));
    for hook in applicable {
        let action = (hook.handler)(context, request, response)?;
        // `Continue` is the only action an after hook can return, so this
        // destructuring cannot fail.
        let PluginAfterHookAction::Continue(updated) = action;
        response = updated;
    }
    Ok(response)
}

/// Runs the asynchronous "after" hooks that match the current endpoint.
///
/// Hooks are awaited one at a time, plugin by plugin, in declaration order.
/// Hooks whose matcher rejects `method`, `path`, and `operation_id` are
/// skipped. Each handler receives the response produced by the previous one;
/// the response carried by its `PluginAfterHookAction::Continue` result
/// replaces it.
///
/// # Errors
///
/// Returns the first error produced by a hook handler.
pub(super) async fn run_async_after_hooks(
    context: &AuthContext,
    request: &ApiRequest,
    mut response: ApiResponse,
    method: &Method,
    path: &str,
    operation_id: Option<&str>,
) -> Result<ApiResponse, OpenAuthError> {
    for plugin in &context.plugins {
        for hook in &plugin.hooks.async_after {
            if !hook.matcher.matches(method, path, operation_id) {
                continue;
            }
            let action = (hook.handler)(context, request, response).await?;
            // `Continue` is the only action an after hook can return, so this
            // destructuring cannot fail.
            let PluginAfterHookAction::Continue(updated) = action;
            response = updated;
        }
    }
    Ok(response)
}

/// Appends a clone of every plugin-provided endpoint to `async_endpoints`.
///
/// The endpoints passed in keep their position at the front; plugin endpoints
/// follow in plugin order, and within a plugin in declaration order.
pub(super) fn plugin_async_endpoints(
    context: &AuthContext,
    mut async_endpoints: Vec<AsyncAuthEndpoint>,
) -> Vec<AsyncAuthEndpoint> {
    let from_plugins = context
        .plugins
        .iter()
        .flat_map(|plugin| plugin.endpoints.iter().cloned());
    async_endpoints.extend(from_plugins);
    async_endpoints
}

/// Resolves the operation id of `endpoint`, if it has one.
///
/// The id set directly on the endpoint options takes precedence. When that is
/// absent, the id from the endpoint's OpenAPI options is used instead.
/// Returns `None` when neither is present.
pub(super) fn endpoint_operation_id(endpoint: &AsyncAuthEndpoint) -> Option<&str> {
    let options = &endpoint.options;
    if let Some(explicit_id) = options.operation_id.as_deref() {
        return Some(explicit_id);
    }
    let openapi = options.openapi.as_ref()?;
    openapi.operation_id.as_deref()
}

pub(super) fn validate_endpoint_conflicts(
    endpoints: &[AuthEndpoint],
    async_endpoints: &[AsyncAuthEndpoint],
) -> Result<(), OpenAuthError> {
    let mut seen = HashSet::new();
    for endpoint in endpoints {
        let key = (
            endpoint.method.clone(),
            endpoint_conflict_key(&endpoint.path),
        );
        if !seen.insert(key) {
            return Err(OpenAuthError::Api(format!(
                "endpoint conflict for {} {}",
                endpoint.method, endpoint.path
            )));
        }
    }
    for endpoint in async_endpoints {
        let key = (
            endpoint.method.clone(),
            endpoint_conflict_key(&endpoint.path),
        );
        if !seen.insert(key) {
            return Err(OpenAuthError::Api(format!(
                "endpoint conflict for {} {}",
                endpoint.method, endpoint.path
            )));
        }
    }
    Ok(())
}

/// Normalises `path` so that routes differing only in parameter names compare
/// equal.
///
/// The path is split on `/`; every segment made of `:` followed by at least
/// one more character is replaced by a bare `:`. All other segments,
/// including empty ones and a lone `:`, are kept verbatim, and the segments
/// are joined back with `/`.
fn endpoint_conflict_key(path: &str) -> String {
    let normalized: Vec<&str> = path
        .split('/')
        .map(|segment| match segment.strip_prefix(':') {
            // `:name` -> `:`; the parameter name is irrelevant for conflicts.
            Some(name) if !name.is_empty() => ":",
            _ => segment,
        })
        .collect();
    normalized.join("/")
}