rs-zero 0.2.6

Rust-first microservice framework inspired by go-zero engineering practices
Documentation
//! REST middleware used by the default runtime stack.

pub mod auth;
#[cfg(feature = "resil")]
pub mod breaker;
#[cfg(feature = "resil")]
pub mod concurrency;
#[cfg(all(feature = "resil", feature = "cache-redis"))]
pub mod limiter;
pub mod request_id;
#[cfg(feature = "resil")]
pub mod shedding;
pub mod timeout;
pub mod uniform;

use axum::Router;
use tower_http::{
    limit::RequestBodyLimitLayer, sensitive_headers::SetSensitiveRequestHeadersLayer,
    trace::TraceLayer,
};

use crate::rest::RestConfig;

/// Metrics handle threaded through the resilience middleware.
///
/// Without the `observability` feature there is nothing to record into, so the
/// handle collapses to the unit type and the metrics plumbing compiles away.
#[cfg(all(feature = "resil", not(feature = "observability")))]
pub(crate) type MiddlewareMetrics = ();
/// Metrics handle threaded through the resilience middleware: the optional
/// registry configured on the REST server.
#[cfg(all(feature = "resil", feature = "observability"))]
pub(crate) type MiddlewareMetrics = Option<crate::observability::MetricsRegistry>;

/// Builds the metrics handle the resilience middleware reports into.
///
/// With `observability` enabled this is a clone of `config.metrics_registry`,
/// which may be `None` when no registry was configured.
#[cfg(all(feature = "resil", feature = "observability"))]
pub(crate) fn middleware_metrics(config: &RestConfig) -> MiddlewareMetrics {
    let configured = &config.metrics_registry;
    configured.clone()
}

/// Builds the metrics handle the resilience middleware reports into.
///
/// Without `observability` the handle is the unit value and `config` is unused.
#[cfg(all(feature = "resil", not(feature = "observability")))]
pub(crate) fn middleware_metrics(_config: &RestConfig) -> MiddlewareMetrics {}

/// Records a no-op resilience decision when `observability` is disabled.
///
/// The signature mirrors the instrumented variant so that middleware call
/// sites compile identically under either feature set.
#[cfg(all(feature = "resil", not(feature = "observability")))]
pub(crate) fn record_resilience_event(
    _metrics: &MiddlewareMetrics,
    _component: &str,
    _outcome: &str,
) {
}

/// Records `outcome` for the resilience `component` under the `"http"`
/// transport label.
///
/// Forwards to `crate::observability::record_resilience_decision`, passing the
/// registry by reference (or `None` when no registry is configured).
#[cfg(all(feature = "resil", feature = "observability"))]
pub(crate) fn record_resilience_event(metrics: &MiddlewareMetrics, component: &str, outcome: &str) {
    let registry = metrics.as_ref();
    crate::observability::record_resilience_decision(registry, "http", component, outcome);
}

/// Applies the default rs-zero REST middleware stack.
///
/// Layers are added innermost-first. Each subsequent `.layer(...)` wraps
/// everything added before it, so the last layer added sees a request first.
/// From the handler outwards the stack is:
///
/// 1. request body size limit (`config.max_body_bytes`)
/// 2. request timeout (`middlewares.resilience.request_timeout`, falling back to `config.timeout`)
/// 3. resilience middleware (feature-gated)
/// 4. request metrics (feature-gated)
/// 5. uniform error shaping
/// 6. authentication, only when `config.auth` is set
/// 7. request-id propagation, then request-id assignment
/// 8. HTTP tracing
/// 9. marking `Authorization` as a sensitive request header
pub fn apply_default_layers(router: Router, config: RestConfig) -> Router {
    let request_timeout = config
        .middlewares
        .resilience
        .request_timeout
        .unwrap_or(config.timeout);

    let mut stacked = router.layer(RequestBodyLimitLayer::new(config.max_body_bytes));
    stacked = timeout::apply_timeout(stacked, request_timeout);
    stacked = apply_resilience_layers(stacked, &config);
    stacked = apply_metrics_layer(stacked, &config);
    stacked = stacked.layer(axum::middleware::from_fn(uniform::uniform_error_middleware));

    if let Some(auth) = config.auth {
        stacked = stacked.layer(axum::middleware::from_fn(move |request, next| {
            auth::auth_middleware(auth.clone(), request, next)
        }));
    }

    // Setting the id is added after propagation, so it runs first on the way in
    // and the propagation layer can copy it onto the response.
    let stacked = stacked
        .layer(request_id::propagate_request_id_layer())
        .layer(request_id::set_request_id_layer())
        .layer(TraceLayer::new_for_http());

    // Outermost layer: mark Authorization as sensitive before any layer above sees it.
    let sensitive_headers = std::iter::once(axum::http::header::AUTHORIZATION);
    stacked.layer(SetSensitiveRequestHeadersLayer::new(sensitive_headers))
}

/// Wraps `router` in the resilience middleware enabled by
/// `config.middlewares.resilience`.
///
/// Each later `.layer(...)` wraps everything added before it. On the request
/// path the layers therefore run in this order:
///
/// 1. circuit breaker (`breaker_enabled`)
/// 2. load shedding (`shedding_enabled`)
/// 3. rate limiter (only with `cache-redis`, and only when the limiter is not disabled)
/// 4. concurrency cap (`max_concurrency`), closest to the handler
///
/// Every layer receives the metrics handle returned by `middleware_metrics`.
#[cfg(feature = "resil")]
fn apply_resilience_layers(router: Router, config: &RestConfig) -> Router {
    use std::sync::Arc;

    use crate::resil::{BreakerRegistry, ShedderRegistry};
    use tokio::sync::Semaphore;

    let resilience = &config.middlewares.resilience;
    let metrics = middleware_metrics(config);

    let router = if let Some(max) = resilience.max_concurrency {
        // `Semaphore::new` panics when asked for more than `MAX_PERMITS` permits.
        // Clamp the value so that an "effectively unlimited" setting
        // (e.g. `usize::MAX`) cannot abort startup.
        // NOTE(review): `Some(0)` still builds a semaphore with no permits;
        // confirm how `concurrency_middleware` treats that case.
        let permits = max.min(Semaphore::MAX_PERMITS);
        let semaphore = Arc::new(Semaphore::new(permits));
        let metrics = metrics.clone();
        router.layer(axum::middleware::from_fn(move |request, next| {
            concurrency::concurrency_middleware(semaphore.clone(), metrics.clone(), request, next)
        }))
    } else {
        router
    };

    #[cfg(feature = "cache-redis")]
    let router = {
        let limiter = limiter::RestRateLimiter::new(resilience.rate_limiter.clone());
        if limiter.is_disabled() {
            router
        } else {
            let metrics = metrics.clone();
            router.layer(axum::middleware::from_fn(move |request, next| {
                limiter::rate_limiter_middleware(limiter.clone(), metrics.clone(), request, next)
            }))
        }
    };

    let router = if resilience.shedding_enabled {
        let registry = ShedderRegistry::new();
        let service = config.name.clone();
        // Each `from_fn` closure owns its own copy of the settings.
        let settings = resilience.clone();
        let metrics = metrics.clone();
        router.layer(axum::middleware::from_fn(move |request, next| {
            shedding::shedding_middleware(
                registry.clone(),
                service.clone(),
                settings.clone(),
                metrics.clone(),
                request,
                next,
            )
        }))
    } else {
        router
    };

    if resilience.breaker_enabled {
        let registry = BreakerRegistry::new();
        let service = config.name.clone();
        let settings = resilience.clone();
        // Last use of `metrics`: move it into the closure instead of cloning again.
        router.layer(axum::middleware::from_fn(move |request, next| {
            breaker::breaker_middleware(
                registry.clone(),
                service.clone(),
                settings.clone(),
                metrics.clone(),
                request,
                next,
            )
        }))
    } else {
        router
    }
}

/// Without the `resil` feature no resilience middleware exists; `router` is
/// returned unchanged.
#[cfg(not(feature = "resil"))]
fn apply_resilience_layers(router: Router, _config: &RestConfig) -> Router {
    router
}

/// Adds the request-metrics middleware when `config.middlewares.metrics.enabled`.
///
/// The middleware reports into `config.metrics_registry`, or into a freshly
/// defaulted registry when none was configured.
/// NOTE(review): that fallback registry is created here only. It is not the
/// handle `middleware_metrics` hands to the resilience layers, which read
/// `config.metrics_registry` directly. Confirm this split is intended.
#[cfg(feature = "observability")]
fn apply_metrics_layer(router: Router, config: &RestConfig) -> Router {
    if !config.middlewares.metrics.enabled {
        return router;
    }

    let registry = config.metrics_registry.clone().unwrap_or_default();
    router.layer(axum::middleware::from_fn_with_state(
        registry,
        crate::observability::record_metrics_middleware,
    ))
}

/// Without the `observability` feature there is no metrics middleware;
/// `router` is returned unchanged.
#[cfg(not(feature = "observability"))]
fn apply_metrics_layer(router: Router, _config: &RestConfig) -> Router {
    router
}