use std::time::Duration;
use axum::Router;
use crate::{
error::ErrorEnvelopeLayer,
health::HealthRegistry,
middleware::{
PrometheusMetrics, RateLimitConfig, RequestIdConfig, body_limit_layer, catch_panic_layer,
request_context_layer, security_headers_layer, streaming_body_limit_layer,
timeout_response_layer, validated_request_id_layer,
},
};
/// Opinionated, production-oriented bundle of routes and middleware for an
/// axum [`Router`].
///
/// Start from `ApiDefaults::production`, adjust it with the builder methods,
/// then install everything with `ApiDefaults::apply`. Middleware toggles that
/// are `None` / `false` are skipped by `apply`.
///
/// The builder methods consume `self` and return the updated value, so
/// dropping a returned `ApiDefaults` silently loses configuration; the
/// `#[must_use]` below turns that mistake into a compiler warning.
#[must_use = "ApiDefaults has no effect until it is installed with `apply`"]
#[derive(Clone)]
pub struct ApiDefaults {
    /// Service name, readable through `service_name()`.
    service_name: String,
    /// Optional version label set via `version(...)`.
    version: Option<String>,
    /// Optional deployment-environment label set via `environment(...)`.
    environment: Option<String>,
    /// Config for `validated_request_id_layer`; `None` skips that layer.
    request_ids: Option<RequestIdConfig>,
    /// Whether `request_context_layer` is installed.
    request_context: bool,
    /// Whether `ErrorEnvelopeLayer` is installed.
    error_envelope: bool,
    /// Prometheus metrics whose layer is installed when `Some`.
    metrics: Option<PrometheusMetrics>,
    /// Health registry whose routes are merged into the router when `Some`.
    health: Option<HealthRegistry>,
    /// Rate-limit config whose layer is installed when `Some`.
    rate_limit: Option<RateLimitConfig>,
    /// Whether `security_headers_layer` is installed.
    security_headers: bool,
    /// Byte limit passed to `body_limit_layer`; `None` skips that layer.
    body_limit: Option<u64>,
    /// Byte limit passed to `streaming_body_limit_layer`; `None` skips it.
    streaming_body_limit: Option<usize>,
    /// Duration passed to `timeout_response_layer`; `None` skips it.
    timeout: Option<Duration>,
    /// Whether `catch_panic_layer` is installed.
    catch_panic: bool,
}
impl ApiDefaults {
    /// Buffered request-body limit used by `production` (1 MiB).
    const DEFAULT_BODY_LIMIT_BYTES: u64 = 1024 * 1024;
    /// Per-request timeout used by `production`.
    const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);

    /// Creates the production preset for `service_name`.
    ///
    /// Enabled by default:
    /// - request IDs (`RequestIdConfig::production()`)
    /// - request context
    /// - error envelope
    /// - a fresh `HealthRegistry`
    /// - security headers
    /// - a 1 MiB body limit
    /// - a 30 second timeout
    /// - panic catching
    ///
    /// Left unset until configured: metrics, rate limiting, the streaming body
    /// limit, and the version / environment labels.
    pub fn production(service_name: impl Into<String>) -> Self {
        Self {
            service_name: service_name.into(),
            version: None,
            environment: None,
            request_ids: Some(RequestIdConfig::production()),
            request_context: true,
            error_envelope: true,
            metrics: None,
            health: Some(HealthRegistry::new()),
            rate_limit: None,
            security_headers: true,
            body_limit: Some(Self::DEFAULT_BODY_LIMIT_BYTES),
            streaming_body_limit: None,
            timeout: Some(Self::DEFAULT_TIMEOUT),
            catch_panic: true,
        }
    }

    /// Returns the service name this preset was created with.
    pub fn service_name(&self) -> &str {
        &self.service_name
    }

    /// Sets the version label.
    ///
    /// NOTE(review): nothing in this impl (including `apply`) reads this
    /// value back; confirm where it is meant to be reported.
    pub fn version(mut self, version: impl Into<String>) -> Self {
        self.version = Some(version.into());
        self
    }

    /// Sets the deployment-environment label.
    ///
    /// NOTE(review): nothing in this impl (including `apply`) reads this
    /// value back; confirm where it is meant to be reported.
    pub fn environment(mut self, environment: impl Into<String>) -> Self {
        self.environment = Some(environment.into());
        self
    }

    /// Replaces the request-ID configuration (enables the layer if disabled).
    pub fn request_ids(mut self, config: RequestIdConfig) -> Self {
        self.request_ids = Some(config);
        self
    }

    /// Disables the request-ID layer.
    pub fn without_request_ids(mut self) -> Self {
        self.request_ids = None;
        self
    }

    /// Disables the request-context layer.
    pub fn without_request_context(mut self) -> Self {
        self.request_context = false;
        self
    }

    /// Disables the error-envelope layer.
    pub fn without_error_envelope(mut self) -> Self {
        self.error_envelope = false;
        self
    }

    /// Enables the metrics layer backed by `metrics`.
    pub fn metrics(mut self, metrics: PrometheusMetrics) -> Self {
        self.metrics = Some(metrics);
        self
    }

    /// Disables the metrics layer.
    pub fn without_metrics(mut self) -> Self {
        self.metrics = None;
        self
    }

    /// Uses `health` as the registry whose routes `apply` merges in.
    pub fn health(mut self, health: HealthRegistry) -> Self {
        self.health = Some(health);
        self
    }

    /// Stops `apply` from merging any health routes.
    pub fn without_health(mut self) -> Self {
        self.health = None;
        self
    }

    /// Enables rate limiting with `config`.
    pub fn rate_limit(mut self, config: RateLimitConfig) -> Self {
        self.rate_limit = Some(config);
        self
    }

    /// Disables rate limiting.
    pub fn without_rate_limit(mut self) -> Self {
        self.rate_limit = None;
        self
    }

    /// Sets the limit passed to `body_limit_layer`.
    pub fn body_limit(mut self, max_bytes: u64) -> Self {
        self.body_limit = Some(max_bytes);
        self
    }

    /// Sets the limit passed to `streaming_body_limit_layer`.
    ///
    /// This is independent of `body_limit`; when both are set, both layers
    /// are installed.
    pub fn streaming_body_limit(mut self, max_bytes: usize) -> Self {
        self.streaming_body_limit = Some(max_bytes);
        self
    }

    /// Removes the limit installed via `body_limit_layer`.
    ///
    /// Does not touch the streaming body limit; use
    /// `without_streaming_body_limit` for that.
    pub fn without_body_limit(mut self) -> Self {
        self.body_limit = None;
        self
    }

    /// Removes the limit installed via `streaming_body_limit_layer`.
    ///
    /// Counterpart to `streaming_body_limit`, matching the `without_*` opt-out
    /// every other optional setting has.
    pub fn without_streaming_body_limit(mut self) -> Self {
        self.streaming_body_limit = None;
        self
    }

    /// Enables the security-headers layer (already on in `production`).
    pub fn security_headers(mut self) -> Self {
        self.security_headers = true;
        self
    }

    /// Disables the security-headers layer.
    pub fn without_security_headers(mut self) -> Self {
        self.security_headers = false;
        self
    }

    /// Sets the duration passed to `timeout_response_layer`.
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Disables the timeout layer.
    pub fn without_timeout(mut self) -> Self {
        self.timeout = None;
        self
    }

    /// Disables the panic-catching layer.
    pub fn without_catch_panic(mut self) -> Self {
        self.catch_panic = false;
        self
    }

    /// Installs the configured routes and middleware on `router` and returns it.
    ///
    /// Health routes (when enabled) are merged first, so they are wrapped by
    /// the middleware as well. Each `Router::layer` call wraps everything
    /// added before it. The resulting stack, from outermost (sees the request
    /// first) to innermost, is:
    ///
    /// 1. security headers
    /// 2. request IDs
    /// 3. request context
    /// 4. metrics
    /// 5. error envelope
    /// 6. timeout
    /// 7. streaming body limit
    /// 8. body limit
    /// 9. rate limit
    /// 10. panic catching
    ///
    /// Disabled features are skipped. axum only applies a layer to routes that
    /// already exist when `layer` is called, so call `apply` after every
    /// application route (and any fallback) has been registered.
    pub fn apply(self, mut router: Router) -> Router {
        if let Some(health) = self.health {
            router = router.merge(health.routes());
        }
        // Innermost first: each subsequent `.layer` wraps the previous ones.
        if self.catch_panic {
            router = router.layer(catch_panic_layer());
        }
        if let Some(rate_limit) = self.rate_limit {
            router = router.layer(rate_limit.layer());
        }
        if let Some(max_bytes) = self.body_limit {
            router = router.layer(body_limit_layer(max_bytes));
        }
        if let Some(max_bytes) = self.streaming_body_limit {
            router = router.layer(streaming_body_limit_layer(max_bytes));
        }
        if let Some(timeout) = self.timeout {
            router = router.layer(timeout_response_layer(timeout));
        }
        // Wraps the layers above so their error responses go through the envelope.
        if self.error_envelope {
            router = router.layer(ErrorEnvelopeLayer::new());
        }
        // Sits outside the envelope, so it observes the final response.
        if let Some(metrics) = self.metrics {
            router = router.layer(metrics.layer());
        }
        if self.request_context {
            router = router.layer(request_context_layer());
        }
        // Outside request context, so it runs before it on the request path.
        if let Some(request_ids) = self.request_ids {
            router = router.layer(validated_request_id_layer(request_ids));
        }
        // Outermost: applies to every response, including ones produced above.
        if self.security_headers {
            router = router.layer(security_headers_layer());
        }
        router
    }
}