pub mod auth;
#[cfg(feature = "resil")]
pub mod breaker;
#[cfg(feature = "resil")]
pub mod concurrency;
#[cfg(all(feature = "resil", feature = "cache-redis"))]
pub mod limiter;
pub mod request_id;
#[cfg(feature = "resil")]
pub mod shedding;
pub mod timeout;
pub mod uniform;
use axum::Router;
use tower_http::{
limit::RequestBodyLimitLayer, sensitive_headers::SetSensitiveRequestHeadersLayer,
trace::TraceLayer,
};
use crate::rest::RestConfig;
/// Metrics handle handed to the resilience middlewares.
///
/// With the `observability` feature this is the (optional) observability
/// `MetricsRegistry`; `None` means no registry was configured.
#[cfg(all(feature = "resil", feature = "observability"))]
pub(crate) type MiddlewareMetrics = Option<crate::observability::MetricsRegistry>;
/// Placeholder metrics handle used when `observability` is disabled; it carries
/// no data so the middleware signatures stay identical across feature sets.
#[cfg(all(feature = "resil", not(feature = "observability")))]
pub(crate) type MiddlewareMetrics = ();
/// Builds the metrics handle the resilience middlewares report to.
///
/// Returns a clone of `config.metrics_registry`, which is `None` when no
/// registry was configured.
#[cfg(all(feature = "resil", feature = "observability"))]
pub(crate) fn middleware_metrics(config: &RestConfig) -> MiddlewareMetrics {
    config.metrics_registry.as_ref().cloned()
}
/// Without `observability` there is nothing to report to; yields the unit
/// placeholder handle.
#[cfg(all(feature = "resil", not(feature = "observability")))]
pub(crate) fn middleware_metrics(_config: &RestConfig) -> MiddlewareMetrics {}
/// Reports one resilience decision (`component` + `outcome`) through
/// `crate::observability::record_resilience_decision`, tagged with `"http"`
/// (presumably the transport label — confirm against the observability module).
///
/// A `None` registry is forwarded as-is; how that is handled is up to
/// `record_resilience_decision`.
#[cfg(all(feature = "resil", feature = "observability"))]
pub(crate) fn record_resilience_event(metrics: &MiddlewareMetrics, component: &str, outcome: &str) {
    let registry = metrics.as_ref();
    crate::observability::record_resilience_decision(registry, "http", component, outcome);
}
/// No-op when `observability` is disabled; every argument is ignored.
#[cfg(all(feature = "resil", not(feature = "observability")))]
pub(crate) fn record_resilience_event(
    _metrics: &MiddlewareMetrics,
    _component: &str,
    _outcome: &str,
) {
}
/// Wraps `router` in the framework's standard middleware stack.
///
/// Each `Router::layer` call wraps everything added before it, so the layer
/// added last is the first to see a request. From the outside in, a request
/// passes through:
///
/// 1. `SetSensitiveRequestHeadersLayer` marking `Authorization` as sensitive,
/// 2. `TraceLayer::new_for_http()`,
/// 3. `request_id::set_request_id_layer()`,
/// 4. `request_id::propagate_request_id_layer()`,
/// 5. `auth::auth_middleware` (only when `config.auth` is `Some`),
/// 6. `uniform::uniform_error_middleware`,
/// 7. the metrics layer (`apply_metrics_layer`),
/// 8. the resilience layers (`apply_resilience_layers`),
/// 9. the request timeout (`timeout::apply_timeout`),
/// 10. `RequestBodyLimitLayer` with `config.max_body_bytes`,
///
/// and finally reaches the routes.
///
/// The timeout is `config.middlewares.resilience.request_timeout` when set,
/// otherwise `config.timeout`.
pub fn apply_default_layers(router: Router, config: RestConfig) -> Router {
    let timeout_budget = config
        .middlewares
        .resilience
        .request_timeout
        .unwrap_or(config.timeout);

    // Innermost section: body limit, timeout, resilience, metrics, uniform errors.
    let core = timeout::apply_timeout(
        router.layer(RequestBodyLimitLayer::new(config.max_body_bytes)),
        timeout_budget,
    );
    let core = apply_metrics_layer(apply_resilience_layers(core, &config), &config)
        .layer(axum::middleware::from_fn(uniform::uniform_error_middleware));

    // Authentication wraps everything above, so it runs before the resilience layers.
    let guarded = match config.auth {
        Some(auth_settings) => core.layer(axum::middleware::from_fn(move |req, next| {
            auth::auth_middleware(auth_settings.clone(), req, next)
        })),
        None => core,
    };

    // Outermost section, listed innermost-first.
    guarded
        .layer(request_id::propagate_request_id_layer())
        .layer(request_id::set_request_id_layer())
        .layer(TraceLayer::new_for_http())
        .layer(SetSensitiveRequestHeadersLayer::new(std::iter::once(
            axum::http::header::AUTHORIZATION,
        )))
}
/// Wraps `router` with the resilience middlewares enabled in
/// `config.middlewares.resilience`.
///
/// Each stage is added by its own helper and skipped when its setting is off.
/// Because later layers wrap earlier ones, from the outside in a request passes
/// through:
///
/// 1. circuit breaker,
/// 2. load shedding,
/// 3. rate limiter (`cache-redis` builds only),
/// 4. concurrency limiter,
///
/// and then reaches `router`.
///
/// All stages report to the same handle from `middleware_metrics`.
#[cfg(feature = "resil")]
fn apply_resilience_layers(router: Router, config: &RestConfig) -> Router {
    let metrics = middleware_metrics(config);
    let router = with_concurrency(router, config, &metrics);
    #[cfg(feature = "cache-redis")]
    let router = with_rate_limiter(router, config, &metrics);
    let router = with_shedding(router, config, &metrics);
    with_breaker(router, config, &metrics)
}

/// Adds the concurrency limiter when `max_concurrency` is a non-zero value.
#[cfg(feature = "resil")]
fn with_concurrency(router: Router, config: &RestConfig, metrics: &MiddlewareMetrics) -> Router {
    use std::sync::Arc;
    use tokio::sync::Semaphore;

    // `Semaphore::new(0)` starts with no permits, so no request could ever acquire
    // one. Treat `Some(0)` like `None` (no limit) instead of locking out all traffic.
    let max = match config.middlewares.resilience.max_concurrency {
        Some(max) if max > 0 => max,
        _ => return router,
    };
    let semaphore = Arc::new(Semaphore::new(max));
    let metrics = metrics.clone();
    router.layer(axum::middleware::from_fn(move |request, next| {
        concurrency::concurrency_middleware(semaphore.clone(), metrics.clone(), request, next)
    }))
}

/// Adds the rate limiter unless `RestRateLimiter` reports itself disabled for
/// the configured `rate_limiter` settings.
#[cfg(all(feature = "resil", feature = "cache-redis"))]
fn with_rate_limiter(router: Router, config: &RestConfig, metrics: &MiddlewareMetrics) -> Router {
    let limiter_config = config.middlewares.resilience.rate_limiter.clone();
    let limiter = limiter::RestRateLimiter::new(limiter_config);
    if limiter.is_disabled() {
        return router;
    }
    let metrics = metrics.clone();
    router.layer(axum::middleware::from_fn(move |request, next| {
        limiter::rate_limiter_middleware(limiter.clone(), metrics.clone(), request, next)
    }))
}

/// Adds load shedding when `shedding_enabled` is set.
///
/// The middleware receives a fresh `ShedderRegistry`, `config.name`, and a
/// copy of the resilience settings.
#[cfg(feature = "resil")]
fn with_shedding(router: Router, config: &RestConfig, metrics: &MiddlewareMetrics) -> Router {
    use crate::resil::ShedderRegistry;

    if !config.middlewares.resilience.shedding_enabled {
        return router;
    }
    let registry = ShedderRegistry::new();
    let service = config.name.clone();
    let resilience = config.middlewares.resilience.clone();
    let metrics = metrics.clone();
    router.layer(axum::middleware::from_fn(move |request, next| {
        shedding::shedding_middleware(
            registry.clone(),
            service.clone(),
            resilience.clone(),
            metrics.clone(),
            request,
            next,
        )
    }))
}

/// Adds the circuit breaker when `breaker_enabled` is set.
///
/// The middleware receives a fresh `BreakerRegistry`, `config.name`, and a
/// copy of the resilience settings.
#[cfg(feature = "resil")]
fn with_breaker(router: Router, config: &RestConfig, metrics: &MiddlewareMetrics) -> Router {
    use crate::resil::BreakerRegistry;

    if !config.middlewares.resilience.breaker_enabled {
        return router;
    }
    let registry = BreakerRegistry::new();
    let service = config.name.clone();
    let resilience = config.middlewares.resilience.clone();
    let metrics = metrics.clone();
    router.layer(axum::middleware::from_fn(move |request, next| {
        breaker::breaker_middleware(
            registry.clone(),
            service.clone(),
            resilience.clone(),
            metrics.clone(),
            request,
            next,
        )
    }))
}

/// Without the `resil` feature there are no resilience middlewares; `router`
/// is returned unchanged.
#[cfg(not(feature = "resil"))]
fn apply_resilience_layers(router: Router, _config: &RestConfig) -> Router {
    router
}
/// Adds `crate::observability::record_metrics_middleware` when
/// `config.middlewares.metrics.enabled` is true.
///
/// The middleware's state is a clone of `config.metrics_registry` when one is
/// configured, otherwise `MetricsRegistry::default()`.
// NOTE(review): a default registry built here is not stored back into `config`.
// Confirm that metrics recorded into it are reachable by whatever exports them.
#[cfg(feature = "observability")]
fn apply_metrics_layer(router: Router, config: &RestConfig) -> Router {
    if !config.middlewares.metrics.enabled {
        return router;
    }
    let registry = match &config.metrics_registry {
        Some(shared) => shared.clone(),
        None => Default::default(),
    };
    router.layer(axum::middleware::from_fn_with_state(
        registry,
        crate::observability::record_metrics_middleware,
    ))
}
/// Without the `observability` feature no metrics middleware exists; `router`
/// is returned unchanged.
#[cfg(not(feature = "observability"))]
fn apply_metrics_layer(router: Router, _config: &RestConfig) -> Router {
    router
}