axum_conf/fluent/features.rs
1//! Feature middleware: routing, compression, CORS, security headers, sessions, health checks.
2
3use super::router::FluentRouter;
4use crate::HttpMiddleware;
5
6use {axum::routing::get, http::StatusCode, tower_http::timeout::TimeoutLayer};
7
8#[cfg(feature = "path-normalization")]
9use tower_http::normalize_path::NormalizePathLayer;
10
11#[cfg(feature = "compression")]
12use tower_http::{compression::CompressionLayer, decompression::RequestDecompressionLayer};
13
14#[cfg(feature = "cors")]
15use {http::HeaderName, tower_http::cors::CorsLayer};
16
17#[cfg(feature = "security-headers")]
18use axum_helmet::{Helmet, HelmetLayer};
19
20#[cfg(feature = "session")]
21use tower_sessions::{
22 Expiry, MemoryStore, SessionManagerLayer,
23 cookie::{SameSite, time::Duration as CookieDuration},
24};
25
26#[cfg(feature = "api-versioning")]
27use {
28 crate::utils::ApiVersion,
29 axum::{body::Body, middleware::Next},
30 http::Request,
31};
32
33impl<State> FluentRouter<State>
34where
35 State: Clone + Send + Sync + 'static,
36{
/// Installs cookie-based session middleware backed by an in-memory store.
///
/// The session cookie is configured with:
/// - `Secure` disabled,
/// - `SameSite=Lax`,
/// - expiry after 3600 seconds (one hour) of inactivity.
///
/// Sessions live in a `MemoryStore` created here.
///
/// NOTE(review): `with_secure(false)` is hard-coded, and this method does not
/// consult `is_middleware_enabled` like its siblings do — confirm both are
/// intentional before relying on this in production.
#[cfg(feature = "session")]
#[must_use]
pub fn setup_session_handling(mut self) -> Self {
    // Sessions expire after one hour without activity.
    let idle_expiry = Expiry::OnInactivity(CookieDuration::seconds(3600));

    let sessions = SessionManagerLayer::new(MemoryStore::default())
        .with_secure(false)
        .with_same_site(SameSite::Lax)
        .with_expiry(idle_expiry);

    self.inner = self.inner.layer(sessions);
    tracing::trace!("Session middleware enabled");
    self
}
50
51 /// Sets up path normalization middleware.
52 ///
53 /// When `config.http.trim_trailing_slash` is true, automatically removes
54 /// trailing slashes from request paths:
55 /// - `/api/users/` → `/api/users`
56 /// - `/health/` → `/health`
57 ///
58 /// This ensures consistent routing behavior regardless of whether clients
59 /// include trailing slashes.
60 ///
61 /// # Configuration
62 ///
63 /// ```toml
64 /// [http]
65 /// trim_trailing_slash = true # Default
66 /// ```
67 #[cfg(feature = "path-normalization")]
68 #[must_use]
69 pub fn setup_path_normalization(mut self) -> Self {
70 if !self.is_middleware_enabled(HttpMiddleware::PathNormalization) {
71 tracing::trace!("PathNormalization middleware skipped (disabled in config)");
72 return self;
73 }
74
75 self.inner = self.inner.layer(NormalizePathLayer::trim_trailing_slash());
76 tracing::trace!("PathNormalization middleware enabled");
77 self
78 }
79
80 /// No-op when `path-normalization` feature is disabled.
81 #[cfg(not(feature = "path-normalization"))]
82 #[must_use]
83 pub fn setup_path_normalization(self) -> Self {
84 self
85 }
86
87 /// Sets up request timeout middleware.
88 ///
89 /// Aborts requests that take longer than the configured duration with a
90 /// `408 Request Timeout` response.
91 ///
92 /// # Configuration
93 ///
94 /// ```toml
95 /// [http]
96 /// request_timeout = "30s" # Optional, uses humantime format
97 /// ```
98 ///
99 /// # Use Cases
100 ///
101 /// - Prevent slow requests from tying up resources
102 /// - Ensure predictable response times
103 /// - Protect against slowloris attacks
104 #[must_use]
105 pub fn setup_timeout(mut self) -> Self {
106 // Skip if timeout middleware is disabled
107 if !self.is_middleware_enabled(HttpMiddleware::Timeout) {
108 tracing::trace!("Timeout middleware skipped (disabled in config)");
109 return self;
110 }
111
112 if let Some(timeout) = self.config.http.request_timeout {
113 self.inner = self.inner.layer(TimeoutLayer::with_status_code(
114 StatusCode::REQUEST_TIMEOUT,
115 timeout,
116 ));
117 tracing::trace!(timeout = ?timeout, "Timeout middleware enabled");
118 }
119 self
120 }
121
122 /// Sets up API versioning middleware.
123 ///
124 /// Automatically extracts API version from requests and adds it to request extensions.
125 /// Supports multiple version detection methods:
126 /// - **Path-based**: `/v1/users`, `/api/v2/users`
127 /// - **Header-based**: `X-API-Version: 2`, `Accept: application/vnd.api+json;version=2`
128 /// - **Query parameter**: `/users?version=1`
129 ///
130 /// The version is checked in order: path → header → query → default.
131 ///
132 /// **Note**: This middleware is automatically included in `setup_middleware()` using
133 /// the `config.http.default_api_version` setting. You only need to call this method
134 /// directly if you want to override the configured default or set up versioning manually.
135 ///
136 /// # Arguments
137 ///
138 /// * `default_version` - The version to use when none is specified in the request
139 ///
140 /// # Examples
141 ///
142 /// ```rust,no_run
143 /// use axum_conf::{Config, FluentRouter, ApiVersion};
144 /// use axum::{routing::get, extract::Extension};
145 ///
146 /// async fn handler(Extension(version): Extension<ApiVersion>) -> String {
147 /// format!("API version: {}", version)
148 /// }
149 ///
150 /// # async fn example() -> axum_conf::Result<()> {
151 /// FluentRouter::without_state(Config::default())?
152 /// .setup_api_versioning(1) // Default to v1
153 /// .route("/users", get(handler));
154 /// # Ok(())
155 /// # }
156 /// ```
157 ///
158 /// # Handler Usage
159 ///
160 /// Extract the version in your handlers using `Extension<ApiVersion>`:
161 ///
162 /// ```rust,ignore
163 /// use axum::extract::Extension;
164 /// use axum_conf::ApiVersion;
165 ///
166 /// async fn my_handler(Extension(version): Extension<ApiVersion>) -> String {
167 /// match version.as_u32() {
168 /// 1 => handle_v1(),
169 /// 2 => handle_v2(),
170 /// _ => "Unsupported version".to_string(),
171 /// }
172 /// }
173 /// ```
174 #[cfg(feature = "api-versioning")]
175 #[must_use]
176 pub fn setup_api_versioning(mut self, default_version: u32) -> Self {
177 if !self.is_middleware_enabled(HttpMiddleware::ApiVersioning) {
178 tracing::trace!("ApiVersioning middleware skipped (disabled in config)");
179 return self;
180 }
181
182 tracing::trace!(default_version = default_version, "ApiVersioning middleware enabled");
183 use axum::middleware;
184
185 let default_version = ApiVersion::new(default_version);
186
187 self.inner = self.inner.layer(middleware::from_fn(
188 move |mut req: Request<Body>, next: Next| async move {
189 // Try to extract version from path first
190 let version = ApiVersion::from_path(req.uri().path())
191 // Then try X-API-Version header
192 .or_else(|| {
193 req.headers()
194 .get("x-api-version")
195 .and_then(|h| h.to_str().ok())
196 .and_then(ApiVersion::from_header)
197 })
198 // Then try Accept header
199 .or_else(|| {
200 req.headers()
201 .get(http::header::ACCEPT)
202 .and_then(|h| h.to_str().ok())
203 .and_then(ApiVersion::from_header)
204 })
205 // Then try query parameter
206 .or_else(|| req.uri().query().and_then(ApiVersion::from_query))
207 // Fall back to default
208 .unwrap_or(default_version);
209
210 // Add version to request extensions
211 req.extensions_mut().insert(version);
212
213 // Log the version being used
214 tracing::debug!(
215 version = %version,
216 path = %req.uri().path(),
217 "API version detected"
218 );
219
220 next.run(req).await
221 },
222 ));
223 self
224 }
225
226 /// No-op when `api-versioning` feature is disabled.
227 #[cfg(not(feature = "api-versioning"))]
228 #[must_use]
229 pub fn setup_api_versioning(self, _default_version: u32) -> Self {
230 self
231 }
232
233 /// Sets up security headers using Helmet.
234 ///
235 /// Adds HTTP security headers based on configuration:
236 /// - `X-Content-Type-Options: nosniff` (prevents MIME sniffing)
237 /// - `X-Frame-Options` (clickjacking protection)
238 ///
239 /// # Configuration
240 ///
241 /// ```toml
242 /// [http]
243 /// x_content_type_nosniff = true # Default
244 /// x_frame_options = "DENY" # Default: DENY, SAMEORIGIN, or URL
245 /// ```
246 ///
247 /// # Security Benefits
248 ///
249 /// - Prevents browsers from MIME-sniffing responses
250 /// - Protects against clickjacking attacks
251 /// - Improves security score in penetration tests
252 #[cfg(feature = "security-headers")]
253 #[must_use]
254 pub fn setup_helmet(mut self) -> Self {
255 if !self.is_middleware_enabled(HttpMiddleware::SecurityHeaders) {
256 tracing::trace!("SecurityHeaders middleware skipped (disabled in config)");
257 return self;
258 }
259
260 tracing::trace!(
261 x_frame_options = ?self.config.http.x_frame_options,
262 x_content_type_nosniff = self.config.http.x_content_type_nosniff,
263 "SecurityHeaders middleware enabled"
264 );
265 let mut helmet = Helmet::new();
266 if self.config.http.x_content_type_nosniff {
267 helmet = helmet.add(helmet_core::XContentTypeOptions::nosniff());
268 }
269 // Convert our local XFrameOptions to axum_helmet's version
270 let x_frame = match &self.config.http.x_frame_options.0 {
271 crate::XFrameOptions::Deny => axum_helmet::XFrameOptions::Deny,
272 crate::XFrameOptions::SameOrigin => axum_helmet::XFrameOptions::SameOrigin,
273 crate::XFrameOptions::AllowFrom(url) => {
274 axum_helmet::XFrameOptions::AllowFrom(url.clone())
275 }
276 };
277 helmet = helmet.add(x_frame);
278 self.inner = self.inner.layer(HelmetLayer::new(helmet));
279 self
280 }
281
282 /// No-op when `security-headers` feature is disabled.
283 #[cfg(not(feature = "security-headers"))]
284 #[must_use]
285 pub fn setup_helmet(self) -> Self {
286 self
287 }
288
289 /// Sets up request decompression and response compression.
290 ///
291 /// When `config.http.support_compression` is true, enables:
292 /// - Request body decompression (gzip, brotli, deflate, zstd)
293 /// - Response body compression (based on Accept-Encoding header)
294 ///
295 /// # Configuration
296 ///
297 /// ```toml
298 /// [http]
299 /// support_compression = true # Default: false
300 /// ```
301 ///
302 /// # Performance
303 ///
304 /// - Reduces bandwidth usage
305 /// - May increase CPU usage
306 /// - Most beneficial for text-based responses (JSON, HTML, etc.)
307 #[cfg(feature = "compression")]
308 #[must_use]
309 pub fn setup_compression(mut self) -> Self {
310 if self.config.http.support_compression
311 && self.is_middleware_enabled(HttpMiddleware::Compression)
312 {
313 self.inner = self
314 .inner
315 .layer(RequestDecompressionLayer::new())
316 .layer(CompressionLayer::new());
317 tracing::trace!("Compression middleware enabled");
318 } else {
319 tracing::trace!("Compression middleware skipped (disabled in config)");
320 }
321 self
322 }
323
324 /// No-op when `compression` feature is disabled.
325 #[cfg(not(feature = "compression"))]
326 #[must_use]
327 pub fn setup_compression(self) -> Self {
328 if self.config.http.support_compression {
329 tracing::warn!(
330 "Compression is enabled in config but the 'compression' feature is not enabled. \
331 Add `compression` to your Cargo.toml features to enable compression support."
332 );
333 }
334 self
335 }
336
337 /// Sets up Cross-Origin Resource Sharing (CORS) middleware.
338 ///
339 /// Configures which web domains can make requests to your API from a browser.
340 /// If no CORS configuration is provided, defaults to very permissive settings
341 /// (allows all origins, methods, and headers).
342 ///
343 /// # Configuration
344 ///
345 /// ```toml
346 /// [http.cors]
347 /// allow_credentials = true
348 /// allowed_origins = ["https://app.example.com", "https://admin.example.com"]
349 /// allowed_methods = ["GET", "POST", "PUT", "DELETE"]
350 /// allowed_headers = ["content-type", "authorization"]
351 /// exposed_headers = ["x-request-id"]
352 /// max_age = "1h"
353 /// ```
354 ///
355 /// # Security Considerations
356 ///
357 /// - When `allow_credentials` is `true`, wildcard origins are not allowed
358 /// - Without explicit configuration, uses `CorsLayer::very_permissive()` which is suitable
359 /// for development but may be too permissive for production
360 /// - Always configure explicit `allowed_origins` in production environments
361 ///
362 /// # Examples
363 ///
364 /// ```rust,no_run
365 /// use axum_conf::{Config, HttpCorsConfig};
366 /// # fn example() -> axum_conf::Result<()> {
367 /// let mut config = Config::default();
368 /// config.http.cors = Some(HttpCorsConfig {
369 /// allow_credentials: Some(true),
370 /// allowed_origins: Some(vec!["https://app.example.com".to_string()]),
371 /// allowed_methods: None,
372 /// allowed_headers: None,
373 /// exposed_headers: None,
374 /// max_age: None,
375 /// });
376 /// # Ok(())
377 /// # }
378 /// ```
379 #[cfg(feature = "cors")]
380 #[must_use]
381 pub fn setup_cors(mut self) -> Self {
382 if !self.is_middleware_enabled(HttpMiddleware::Cors) {
383 tracing::trace!("CORS middleware skipped (disabled in config)");
384 return self;
385 }
386
387 use http::HeaderValue;
388
389 if let Some(cors_config) = &self.config.http.cors {
390 tracing::trace!("CORS middleware enabled with custom configuration");
391 let mut cors = CorsLayer::new();
392
393 // By default we do NOT allow credentials
394 let has_credentials = cors_config.allow_credentials.unwrap_or(false);
395
396 // Configure allowed origins
397 if let Some(origins) = &cors_config.allowed_origins {
398 for origin in origins {
399 if let Ok(header_value) = HeaderValue::from_str(origin) {
400 cors = cors.allow_origin(header_value);
401 }
402 }
403 } else if !has_credentials {
404 // Only use wildcard if credentials is not enabled
405 cors = cors.allow_origin(tower_http::cors::Any);
406 }
407
408 // Configure allowed methods
409 if let Some(methods) = &cors_config.allowed_methods {
410 let method_list: Vec<http::Method> = methods.iter().map(|m| m.0.clone()).collect();
411 cors = cors.allow_methods(method_list);
412 } else if !has_credentials {
413 // Only use wildcard if credentials is not enabled
414 cors = cors.allow_methods(tower_http::cors::Any);
415 }
416
417 // Configure allowed headers
418 if let Some(headers) = &cors_config.allowed_headers {
419 let header_list: Vec<HeaderName> = headers.iter().map(|h| h.0.clone()).collect();
420 cors = cors.allow_headers(header_list);
421 } else if !has_credentials {
422 // Only use wildcard if credentials is not enabled
423 cors = cors.allow_headers(tower_http::cors::Any);
424 }
425
426 // Configure exposed headers
427 if let Some(headers) = &cors_config.exposed_headers {
428 let header_list: Vec<HeaderName> = headers.iter().map(|h| h.0.clone()).collect();
429 cors = cors.expose_headers(header_list);
430 }
431
432 // Configure max age
433 if let Some(max_age) = cors_config.max_age {
434 cors = cors.max_age(max_age);
435 }
436
437 // Configure credentials (must be set last after origins/headers)
438 if has_credentials {
439 cors = cors.allow_credentials(true);
440 }
441
442 self.inner = self.inner.layer(cors);
443 } else {
444 // No CORS config specified - behavior depends on environment
445 let rust_env = std::env::var("RUST_ENV").unwrap_or_default().to_lowercase();
446 let is_production = rust_env.is_empty()
447 || rust_env == "prod"
448 || rust_env == "production"
449 || rust_env == "release";
450
451 if is_production {
452 // Production: fail-safe to restrictive CORS (same-origin only)
453 tracing::warn!(
454 "No CORS configuration found in production environment. \
455 Using restrictive same-origin policy. Configure [http.cors] \
456 in your config file to allow cross-origin requests."
457 );
458 // Default CorsLayer denies all cross-origin requests
459 self.inner = self.inner.layer(CorsLayer::new());
460 } else {
461 // Development/Test: use permissive defaults with warning
462 tracing::warn!(
463 "No CORS configuration found (RUST_ENV={}). Using permissive defaults. \
464 This is NOT safe for production - configure explicit CORS rules.",
465 rust_env
466 );
467 self.inner = self.inner.layer(CorsLayer::very_permissive());
468 }
469 }
470 self
471 }
472
473 /// No-op when `cors` feature is disabled.
474 #[cfg(not(feature = "cors"))]
475 #[must_use]
476 pub fn setup_cors(self) -> Self {
477 if self.config.http.cors.is_some() {
478 tracing::warn!(
479 "CORS is configured but the 'cors' feature is not enabled. \
480 Add `cors` to your Cargo.toml features to enable CORS support."
481 );
482 }
483 self
484 }
485
486 /// Sets up the Kubernetes liveness probe endpoint.
487 ///
488 /// Adds a simple endpoint that always returns 200 OK to indicate the process is running.
489 /// This endpoint is placed very early in the middleware stack (after panic catching) so
490 /// it remains accessible even when other middleware fails.
491 ///
492 /// # Configuration
493 ///
494 /// ```toml
495 /// [http]
496 /// liveness_route = "/live" # Default
497 /// ```
498 ///
499 /// # Kubernetes Integration
500 ///
501 /// ```yaml
502 /// livenessProbe:
503 /// httpGet:
504 /// path: /live
505 /// port: 3000
506 /// ```
507 #[must_use]
508 pub fn setup_liveness(mut self) -> Self {
509 if !self.is_middleware_enabled(HttpMiddleware::Liveness) {
510 tracing::trace!("Liveness middleware skipped (disabled in config)");
511 return self;
512 }
513
514 let liveness_route = self.config.http.liveness_route.clone();
515 tracing::trace!(route = %liveness_route, "Liveness endpoint enabled");
516 self.inner = self.inner.route(&liveness_route, get(|| async { "OK\n" }));
517 self
518 }
519
520 /// Sets up the Kubernetes readiness probe endpoint.
521 ///
522 /// Adds an endpoint that returns 200 OK if the service can handle traffic.
523 /// When the `postgres` feature is enabled and a database is configured,
524 /// this endpoint verifies database connectivity by executing a simple query.
525 /// If the database is unreachable, returns 503 Service Unavailable.
526 ///
527 /// When the `circuit-breaker` feature is also enabled, the endpoint first checks
528 /// if the database circuit breaker is open. If the circuit is open, it returns
529 /// 503 immediately without attempting a database query, preventing additional
530 /// load on a failing database.
531 ///
532 /// This endpoint is placed after rate limiting and timeout middleware so that:
533 /// - Excessive health check requests don't overwhelm the service
534 /// - Database queries have a timeout to prevent hanging probes
535 ///
536 /// # Configuration
537 ///
538 /// ```toml
539 /// [http]
540 /// readiness_route = "/ready" # Default
541 /// ```
542 ///
543 /// # Kubernetes Integration
544 ///
545 /// ```yaml
546 /// readinessProbe:
547 /// httpGet:
548 /// path: /ready
549 /// port: 3000
550 /// ```
551 #[must_use]
552 pub fn setup_readiness(mut self) -> Self {
553 if !self.is_middleware_enabled(HttpMiddleware::Readiness) {
554 tracing::trace!("Readiness middleware skipped (disabled in config)");
555 return self;
556 }
557
558 let readiness_route = self.config.http.readiness_route.clone();
559 tracing::trace!(route = %readiness_route, "Readiness endpoint enabled");
560
561 #[cfg(feature = "postgres")]
562 let db_pool = self.db_pool.clone();
563
564 #[cfg(all(feature = "circuit-breaker", feature = "postgres"))]
565 let circuit_breaker_registry = self.circuit_breaker_registry.clone();
566
567 self.inner = self.inner.route(
568 &readiness_route,
569 get(|| async move {
570 // When circuit-breaker and postgres are both enabled,
571 // check circuit state before querying
572 #[cfg(all(feature = "circuit-breaker", feature = "postgres"))]
573 {
574 let breaker = circuit_breaker_registry.get_or_default("database");
575 if !breaker.should_allow() {
576 tracing::warn!(
577 "Database circuit breaker is open, skipping health check query"
578 );
579 return (StatusCode::SERVICE_UNAVAILABLE, "Database circuit open\n");
580 }
581 }
582
583 #[cfg(feature = "postgres")]
584 match sqlx::query("SELECT 1").execute(&db_pool).await {
585 Ok(_) => (StatusCode::OK, "OK\n"),
586 Err(e) => {
587 tracing::error!("Database health check failed: {}", e);
588 (StatusCode::SERVICE_UNAVAILABLE, "Database unavailable\n")
589 }
590 }
591
592 #[cfg(not(feature = "postgres"))]
593 (StatusCode::OK, "OK\n")
594 }),
595 );
596 self
597 }
598
599 /// Sets up both Kubernetes health check endpoints.
600 ///
601 /// This is a convenience method that calls both [`setup_liveness`](Self::setup_liveness)
602 /// and [`setup_readiness`](Self::setup_readiness). However, when using `setup_middleware()`,
603 /// these endpoints are placed at different positions in the middleware stack for optimal
604 /// behavior.
605 ///
606 /// # Deprecated
607 ///
608 /// Prefer using `setup_middleware()` which places liveness and readiness endpoints at
609 /// their optimal positions in the middleware stack. If you need manual control, use
610 /// `setup_liveness()` and `setup_readiness()` separately.
611 ///
612 /// # Configuration
613 ///
614 /// ```toml
615 /// [http]
616 /// liveness_route = "/live" # Default
617 /// readiness_route = "/ready" # Default
618 /// ```
619 #[must_use]
620 #[deprecated(
621 since = "0.4.0",
622 note = "Use setup_middleware() or call setup_liveness() and setup_readiness() separately for optimal middleware ordering"
623 )]
624 pub fn setup_liveness_readiness(self) -> Self {
625 self.setup_liveness().setup_readiness()
626 }
627}