rs_zero/rest/middleware/
mod.rs1pub mod auth;
4#[cfg(feature = "resil")]
5pub mod breaker;
6#[cfg(feature = "resil")]
7pub mod concurrency;
8#[cfg(all(feature = "resil", feature = "cache-redis"))]
9pub mod limiter;
10pub mod request_id;
11#[cfg(feature = "resil")]
12pub mod shedding;
13pub mod timeout;
14pub mod uniform;
15
16use axum::Router;
17use tower_http::{
18 limit::RequestBodyLimitLayer, sensitive_headers::SetSensitiveRequestHeadersLayer,
19 trace::TraceLayer,
20};
21
22use crate::rest::RestConfig;
23
/// Metrics handle passed to every resilience middleware when the `observability` feature is
/// off: there is no registry, so the handle is the unit type.
#[cfg(all(feature = "resil", not(feature = "observability")))]
pub(crate) type MiddlewareMetrics = ();
/// Metrics handle passed to every resilience middleware: the optionally configured registry
/// (`None` when `RestConfig::metrics_registry` is unset).
#[cfg(all(feature = "resil", feature = "observability"))]
pub(crate) type MiddlewareMetrics = Option<crate::observability::MetricsRegistry>;
28
/// Without the `observability` feature there is no registry to hand out, so the resilience
/// middlewares receive the unit value.
#[cfg(all(feature = "resil", not(feature = "observability")))]
pub(crate) fn middleware_metrics(_: &RestConfig) -> MiddlewareMetrics {}

/// Returns the metrics handle shared by the resilience middlewares: a clone of
/// `config.metrics_registry`, which is `None` when no registry is configured.
#[cfg(all(feature = "resil", feature = "observability"))]
pub(crate) fn middleware_metrics(config: &RestConfig) -> MiddlewareMetrics {
    config.metrics_registry.clone()
}
36
/// No-op stand-in used when the `observability` feature is off; it has the same signature as
/// the recording variant so middleware call sites are identical under both feature sets.
#[cfg(all(feature = "resil", not(feature = "observability")))]
pub(crate) fn record_resilience_event(_: &MiddlewareMetrics, _: &str, _: &str) {}

/// Forwards one resilience decision (`component` reached `outcome`) to
/// `crate::observability::record_resilience_decision`, tagged with the `"http"` label and
/// passing the optional registry through as `Option<&_>`.
#[cfg(all(feature = "resil", feature = "observability"))]
pub(crate) fn record_resilience_event(metrics: &MiddlewareMetrics, component: &str, outcome: &str) {
    let registry = metrics.as_ref();
    crate::observability::record_resilience_decision(registry, "http", component, outcome);
}
49
50pub fn apply_default_layers(router: Router, config: RestConfig) -> Router {
52 let router = router.layer(RequestBodyLimitLayer::new(config.max_body_bytes));
53 let request_timeout = config
54 .middlewares
55 .resilience
56 .request_timeout
57 .unwrap_or(config.timeout);
58 let router = timeout::apply_timeout(router, request_timeout);
59 let router = apply_resilience_layers(router, &config);
60 let router = apply_metrics_layer(router, &config);
61 let router = router.layer(axum::middleware::from_fn(uniform::uniform_error_middleware));
62
63 let router = if let Some(auth) = config.auth {
64 router.layer(axum::middleware::from_fn(move |request, next| {
65 auth::auth_middleware(auth.clone(), request, next)
66 }))
67 } else {
68 router
69 };
70
71 router
72 .layer(request_id::propagate_request_id_layer())
73 .layer(request_id::set_request_id_layer())
74 .layer(TraceLayer::new_for_http())
75 .layer(SetSensitiveRequestHeadersLayer::new(std::iter::once(
76 axum::http::header::AUTHORIZATION,
77 )))
78}
79
/// Adds the resilience middlewares enabled in `config.middlewares.resilience`.
///
/// Layers are added innermost-first:
///
/// 1. the concurrency gate
/// 2. the rate limiter (only with the `cache-redis` feature)
/// 3. load shedding
/// 4. the circuit breaker
///
/// Because each `Router::layer` call wraps what came before, a request meets the breaker first
/// and the concurrency gate last. Every middleware receives its own clone of the metrics handle
/// from `middleware_metrics`.
///
/// # Panics
///
/// `tokio::sync::Semaphore::new` panics when `max_concurrency` exceeds
/// `Semaphore::MAX_PERMITS`.
#[cfg(feature = "resil")]
fn apply_resilience_layers(mut router: Router, config: &RestConfig) -> Router {
    use std::sync::Arc;

    use crate::resil::{BreakerRegistry, ShedderRegistry};
    use tokio::sync::Semaphore;

    let settings = &config.middlewares.resilience;
    let metrics = middleware_metrics(config);

    // Concurrency gate: one semaphore sized by `max_concurrency`, shared by every request.
    // NOTE(review): `Some(0)` yields a semaphore with no permits. Whether that rejects or stalls
    // requests depends on `concurrency_middleware`; confirm such configs are refused upstream.
    if let Some(max) = settings.max_concurrency {
        let gate = Arc::new(Semaphore::new(max));
        let metrics = metrics.clone();
        router = router.layer(axum::middleware::from_fn(move |request, next| {
            concurrency::concurrency_middleware(gate.clone(), metrics.clone(), request, next)
        }));
    }

    // Rate limiter: always constructed, but only installed when it reports itself enabled.
    #[cfg(feature = "cache-redis")]
    {
        let limiter = limiter::RestRateLimiter::new(settings.rate_limiter.clone());
        if !limiter.is_disabled() {
            let metrics = metrics.clone();
            router = router.layer(axum::middleware::from_fn(move |request, next| {
                limiter::rate_limiter_middleware(limiter.clone(), metrics.clone(), request, next)
            }));
        }
    }

    // Load shedding, with a fresh registry, the service name and the resilience settings.
    if settings.shedding_enabled {
        let shedders = ShedderRegistry::new();
        let service = config.name.clone();
        let resilience = settings.clone();
        let metrics = metrics.clone();
        router = router.layer(axum::middleware::from_fn(move |request, next| {
            shedding::shedding_middleware(
                shedders.clone(),
                service.clone(),
                resilience.clone(),
                metrics.clone(),
                request,
                next,
            )
        }));
    }

    // Circuit breaker, outermost. It sees every response produced by the layers it wraps,
    // including any rejections from shedding. NOTE(review): confirm those should count
    // toward tripping the breaker.
    if settings.breaker_enabled {
        let breakers = BreakerRegistry::new();
        let service = config.name.clone();
        let resilience = settings.clone();
        let metrics = metrics.clone();
        router = router.layer(axum::middleware::from_fn(move |request, next| {
            breaker::breaker_middleware(
                breakers.clone(),
                service.clone(),
                resilience.clone(),
                metrics.clone(),
                request,
                next,
            )
        }));
    }

    router
}
151
152#[cfg(not(feature = "resil"))]
153fn apply_resilience_layers(router: Router, _config: &RestConfig) -> Router {
154 router
155}
156
157#[cfg(feature = "observability")]
158fn apply_metrics_layer(router: Router, config: &RestConfig) -> Router {
159 if config.middlewares.metrics.enabled {
160 let registry = config.metrics_registry.clone().unwrap_or_default();
161 router.layer(axum::middleware::from_fn_with_state(
162 registry,
163 crate::observability::record_metrics_middleware,
164 ))
165 } else {
166 router
167 }
168}
169
170#[cfg(not(feature = "observability"))]
171fn apply_metrics_layer(router: Router, _config: &RestConfig) -> Router {
172 router
173}