Skip to main content

rs_zero/rest/middleware/
mod.rs

1//! REST middleware used by the default runtime stack.
2
3pub mod auth;
4#[cfg(feature = "resil")]
5pub mod breaker;
6#[cfg(feature = "resil")]
7pub mod concurrency;
8#[cfg(all(feature = "resil", feature = "cache-redis"))]
9pub mod limiter;
10pub mod request_id;
11#[cfg(feature = "resil")]
12pub mod shedding;
13pub mod timeout;
14pub mod uniform;
15
16use axum::Router;
17use tower_http::{
18    limit::RequestBodyLimitLayer, sensitive_headers::SetSensitiveRequestHeadersLayer,
19    trace::TraceLayer,
20};
21
22use crate::rest::RestConfig;
23
/// Metrics handle passed to the resilience middleware.
///
/// Built with `resil` but without `observability`, there is no registry type to
/// carry, so the handle is the unit type.
#[cfg(all(feature = "resil", not(feature = "observability")))]
pub(crate) type MiddlewareMetrics = ();

/// Metrics handle passed to the resilience middleware.
///
/// Built with both `resil` and `observability`, the handle is an optional
/// metrics registry.
#[cfg(all(feature = "resil", feature = "observability"))]
pub(crate) type MiddlewareMetrics = Option<crate::observability::MetricsRegistry>;
28
/// Produces the metrics handle for the resilience middleware.
///
/// With `observability` enabled this is a clone of `config.metrics_registry`.
#[cfg(all(feature = "resil", feature = "observability"))]
pub(crate) fn middleware_metrics(config: &RestConfig) -> MiddlewareMetrics {
    let registry = &config.metrics_registry;
    registry.clone()
}

/// Produces the metrics handle for the resilience middleware.
///
/// Without `observability` the handle is the unit value and `_config` is not read.
#[cfg(all(feature = "resil", not(feature = "observability")))]
pub(crate) fn middleware_metrics(_config: &RestConfig) -> MiddlewareMetrics {}
36
/// Reports a resilience decision made by an HTTP middleware `component`.
///
/// Forwards to `crate::observability::record_resilience_decision` with the
/// transport label fixed to `"http"`.
#[cfg(all(feature = "resil", feature = "observability"))]
pub(crate) fn record_resilience_event(metrics: &MiddlewareMetrics, component: &str, outcome: &str) {
    // Every caller of this helper is REST middleware, so the transport is constant.
    const TRANSPORT: &str = "http";

    let registry = metrics.as_ref();
    crate::observability::record_resilience_decision(registry, TRANSPORT, component, outcome);
}

/// Reports a resilience decision made by an HTTP middleware component.
///
/// Without `observability` this does nothing; it exists so call sites compile
/// unchanged under either feature set.
#[cfg(all(feature = "resil", not(feature = "observability")))]
pub(crate) fn record_resilience_event(
    _metrics: &MiddlewareMetrics,
    _component: &str,
    _outcome: &str,
) {
    // Intentionally empty.
}
49
50/// Applies the default rs-zero REST middleware stack.
51pub fn apply_default_layers(router: Router, config: RestConfig) -> Router {
52    let router = router.layer(RequestBodyLimitLayer::new(config.max_body_bytes));
53    let request_timeout = config
54        .middlewares
55        .resilience
56        .request_timeout
57        .unwrap_or(config.timeout);
58    let router = timeout::apply_timeout(router, request_timeout);
59    let router = apply_resilience_layers(router, &config);
60    let router = apply_metrics_layer(router, &config);
61    let router = router.layer(axum::middleware::from_fn(uniform::uniform_error_middleware));
62
63    let router = if let Some(auth) = config.auth {
64        router.layer(axum::middleware::from_fn(move |request, next| {
65            auth::auth_middleware(auth.clone(), request, next)
66        }))
67    } else {
68        router
69    };
70
71    router
72        .layer(request_id::propagate_request_id_layer())
73        .layer(request_id::set_request_id_layer())
74        .layer(TraceLayer::new_for_http())
75        .layer(SetSensitiveRequestHeadersLayer::new(std::iter::once(
76            axum::http::header::AUTHORIZATION,
77        )))
78}
79
80#[cfg(feature = "resil")]
81fn apply_resilience_layers(router: Router, config: &RestConfig) -> Router {
82    use std::sync::Arc;
83
84    use crate::resil::{BreakerRegistry, ShedderRegistry};
85    use tokio::sync::Semaphore;
86
87    let resilience = config.middlewares.resilience.clone();
88    let metrics = middleware_metrics(config);
89    #[cfg(feature = "cache-redis")]
90    let limiter_config = resilience.rate_limiter.clone();
91    let router = if let Some(max) = resilience.max_concurrency {
92        let semaphore = Arc::new(Semaphore::new(max));
93        let metrics = metrics.clone();
94        router.layer(axum::middleware::from_fn(move |request, next| {
95            concurrency::concurrency_middleware(semaphore.clone(), metrics.clone(), request, next)
96        }))
97    } else {
98        router
99    };
100
101    #[cfg(feature = "cache-redis")]
102    let router = {
103        let limiter = limiter::RestRateLimiter::new(limiter_config);
104        if limiter.is_disabled() {
105            router
106        } else {
107            let metrics = metrics.clone();
108            router.layer(axum::middleware::from_fn(move |request, next| {
109                limiter::rate_limiter_middleware(limiter.clone(), metrics.clone(), request, next)
110            }))
111        }
112    };
113
114    let router = if resilience.shedding_enabled {
115        let registry = ShedderRegistry::new();
116        let service = config.name.clone();
117        let metrics = metrics.clone();
118        router.layer(axum::middleware::from_fn(move |request, next| {
119            shedding::shedding_middleware(
120                registry.clone(),
121                service.clone(),
122                resilience.clone(),
123                metrics.clone(),
124                request,
125                next,
126            )
127        }))
128    } else {
129        router
130    };
131
132    if config.middlewares.resilience.breaker_enabled {
133        let registry = BreakerRegistry::new();
134        let service = config.name.clone();
135        let resilience = config.middlewares.resilience.clone();
136        let metrics = metrics.clone();
137        router.layer(axum::middleware::from_fn(move |request, next| {
138            breaker::breaker_middleware(
139                registry.clone(),
140                service.clone(),
141                resilience.clone(),
142                metrics.clone(),
143                request,
144                next,
145            )
146        }))
147    } else {
148        router
149    }
150}
151
152#[cfg(not(feature = "resil"))]
153fn apply_resilience_layers(router: Router, _config: &RestConfig) -> Router {
154    router
155}
156
157#[cfg(feature = "observability")]
158fn apply_metrics_layer(router: Router, config: &RestConfig) -> Router {
159    if config.middlewares.metrics.enabled {
160        let registry = config.metrics_registry.clone().unwrap_or_default();
161        router.layer(axum::middleware::from_fn_with_state(
162            registry,
163            crate::observability::record_metrics_middleware,
164        ))
165    } else {
166        router
167    }
168}
169
170#[cfg(not(feature = "observability"))]
171fn apply_metrics_layer(router: Router, _config: &RestConfig) -> Router {
172    router
173}