axum_conf/fluent/
features.rs

1//! Feature middleware: routing, compression, CORS, security headers, sessions, health checks.
2
3use super::router::FluentRouter;
4use crate::HttpMiddleware;
5
6use {axum::routing::get, http::StatusCode, tower_http::timeout::TimeoutLayer};
7
8#[cfg(feature = "path-normalization")]
9use tower_http::normalize_path::NormalizePathLayer;
10
11#[cfg(feature = "compression")]
12use tower_http::{compression::CompressionLayer, decompression::RequestDecompressionLayer};
13
14#[cfg(feature = "cors")]
15use {http::HeaderName, tower_http::cors::CorsLayer};
16
17#[cfg(feature = "security-headers")]
18use axum_helmet::{Helmet, HelmetLayer};
19
20#[cfg(feature = "session")]
21use tower_sessions::{
22    Expiry, MemoryStore, SessionManagerLayer,
23    cookie::{SameSite, time::Duration as CookieDuration},
24};
25
26#[cfg(feature = "api-versioning")]
27use {
28    crate::utils::ApiVersion,
29    axum::{body::Body, middleware::Next},
30    http::Request,
31};
32
33impl<State> FluentRouter<State>
34where
35    State: Clone + Send + Sync + 'static,
36{
37    /// Sets up cookie-based session handling using an in-memory store.
38    #[cfg(feature = "session")]
39    #[must_use]
40    pub fn setup_session_handling(mut self) -> Self {
41        let session_store = MemoryStore::default();
42        let session_layer = SessionManagerLayer::new(session_store)
43            .with_secure(false)
44            .with_same_site(SameSite::Lax)
45            .with_expiry(Expiry::OnInactivity(CookieDuration::seconds(3600)));
46        self.inner = self.inner.layer(session_layer);
47        tracing::trace!("Session middleware enabled");
48        self
49    }
50
51    /// Sets up path normalization middleware.
52    ///
53    /// When `config.http.trim_trailing_slash` is true, automatically removes
54    /// trailing slashes from request paths:
55    /// - `/api/users/` → `/api/users`
56    /// - `/health/` → `/health`
57    ///
58    /// This ensures consistent routing behavior regardless of whether clients
59    /// include trailing slashes.
60    ///
61    /// # Configuration
62    ///
63    /// ```toml
64    /// [http]
65    /// trim_trailing_slash = true  # Default
66    /// ```
67    #[cfg(feature = "path-normalization")]
68    #[must_use]
69    pub fn setup_path_normalization(mut self) -> Self {
70        if !self.is_middleware_enabled(HttpMiddleware::PathNormalization) {
71            tracing::trace!("PathNormalization middleware skipped (disabled in config)");
72            return self;
73        }
74
75        self.inner = self.inner.layer(NormalizePathLayer::trim_trailing_slash());
76        tracing::trace!("PathNormalization middleware enabled");
77        self
78    }
79
80    /// No-op when `path-normalization` feature is disabled.
81    #[cfg(not(feature = "path-normalization"))]
82    #[must_use]
83    pub fn setup_path_normalization(self) -> Self {
84        self
85    }
86
87    /// Sets up request timeout middleware.
88    ///
89    /// Aborts requests that take longer than the configured duration with a
90    /// `408 Request Timeout` response.
91    ///
92    /// # Configuration
93    ///
94    /// ```toml
95    /// [http]
96    /// request_timeout = "30s"  # Optional, uses humantime format
97    /// ```
98    ///
99    /// # Use Cases
100    ///
101    /// - Prevent slow requests from tying up resources
102    /// - Ensure predictable response times
103    /// - Protect against slowloris attacks
104    #[must_use]
105    pub fn setup_timeout(mut self) -> Self {
106        // Skip if timeout middleware is disabled
107        if !self.is_middleware_enabled(HttpMiddleware::Timeout) {
108            tracing::trace!("Timeout middleware skipped (disabled in config)");
109            return self;
110        }
111
112        if let Some(timeout) = self.config.http.request_timeout {
113            self.inner = self.inner.layer(TimeoutLayer::with_status_code(
114                StatusCode::REQUEST_TIMEOUT,
115                timeout,
116            ));
117            tracing::trace!(timeout = ?timeout, "Timeout middleware enabled");
118        }
119        self
120    }
121
122    /// Sets up API versioning middleware.
123    ///
124    /// Automatically extracts API version from requests and adds it to request extensions.
125    /// Supports multiple version detection methods:
126    /// - **Path-based**: `/v1/users`, `/api/v2/users`
127    /// - **Header-based**: `X-API-Version: 2`, `Accept: application/vnd.api+json;version=2`
128    /// - **Query parameter**: `/users?version=1`
129    ///
130    /// The version is checked in order: path → header → query → default.
131    ///
132    /// **Note**: This middleware is automatically included in `setup_middleware()` using
133    /// the `config.http.default_api_version` setting. You only need to call this method
134    /// directly if you want to override the configured default or set up versioning manually.
135    ///
136    /// # Arguments
137    ///
138    /// * `default_version` - The version to use when none is specified in the request
139    ///
140    /// # Examples
141    ///
142    /// ```rust,no_run
143    /// use axum_conf::{Config, FluentRouter, ApiVersion};
144    /// use axum::{routing::get, extract::Extension};
145    ///
146    /// async fn handler(Extension(version): Extension<ApiVersion>) -> String {
147    ///     format!("API version: {}", version)
148    /// }
149    ///
150    /// # async fn example() -> axum_conf::Result<()> {
151    /// FluentRouter::without_state(Config::default())?
152    ///     .setup_api_versioning(1)  // Default to v1
153    ///     .route("/users", get(handler));
154    /// # Ok(())
155    /// # }
156    /// ```
157    ///
158    /// # Handler Usage
159    ///
160    /// Extract the version in your handlers using `Extension<ApiVersion>`:
161    ///
162    /// ```rust,ignore
163    /// use axum::extract::Extension;
164    /// use axum_conf::ApiVersion;
165    ///
166    /// async fn my_handler(Extension(version): Extension<ApiVersion>) -> String {
167    ///     match version.as_u32() {
168    ///         1 => handle_v1(),
169    ///         2 => handle_v2(),
170    ///         _ => "Unsupported version".to_string(),
171    ///     }
172    /// }
173    /// ```
174    #[cfg(feature = "api-versioning")]
175    #[must_use]
176    pub fn setup_api_versioning(mut self, default_version: u32) -> Self {
177        if !self.is_middleware_enabled(HttpMiddleware::ApiVersioning) {
178            tracing::trace!("ApiVersioning middleware skipped (disabled in config)");
179            return self;
180        }
181
182        tracing::trace!(default_version = default_version, "ApiVersioning middleware enabled");
183        use axum::middleware;
184
185        let default_version = ApiVersion::new(default_version);
186
187        self.inner = self.inner.layer(middleware::from_fn(
188            move |mut req: Request<Body>, next: Next| async move {
189                // Try to extract version from path first
190                let version = ApiVersion::from_path(req.uri().path())
191                    // Then try X-API-Version header
192                    .or_else(|| {
193                        req.headers()
194                            .get("x-api-version")
195                            .and_then(|h| h.to_str().ok())
196                            .and_then(ApiVersion::from_header)
197                    })
198                    // Then try Accept header
199                    .or_else(|| {
200                        req.headers()
201                            .get(http::header::ACCEPT)
202                            .and_then(|h| h.to_str().ok())
203                            .and_then(ApiVersion::from_header)
204                    })
205                    // Then try query parameter
206                    .or_else(|| req.uri().query().and_then(ApiVersion::from_query))
207                    // Fall back to default
208                    .unwrap_or(default_version);
209
210                // Add version to request extensions
211                req.extensions_mut().insert(version);
212
213                // Log the version being used
214                tracing::debug!(
215                    version = %version,
216                    path = %req.uri().path(),
217                    "API version detected"
218                );
219
220                next.run(req).await
221            },
222        ));
223        self
224    }
225
226    /// No-op when `api-versioning` feature is disabled.
227    #[cfg(not(feature = "api-versioning"))]
228    #[must_use]
229    pub fn setup_api_versioning(self, _default_version: u32) -> Self {
230        self
231    }
232
233    /// Sets up security headers using Helmet.
234    ///
235    /// Adds HTTP security headers based on configuration:
236    /// - `X-Content-Type-Options: nosniff` (prevents MIME sniffing)
237    /// - `X-Frame-Options` (clickjacking protection)
238    ///
239    /// # Configuration
240    ///
241    /// ```toml
242    /// [http]
243    /// x_content_type_nosniff = true  # Default
244    /// x_frame_options = "DENY"       # Default: DENY, SAMEORIGIN, or URL
245    /// ```
246    ///
247    /// # Security Benefits
248    ///
249    /// - Prevents browsers from MIME-sniffing responses
250    /// - Protects against clickjacking attacks
251    /// - Improves security score in penetration tests
252    #[cfg(feature = "security-headers")]
253    #[must_use]
254    pub fn setup_helmet(mut self) -> Self {
255        if !self.is_middleware_enabled(HttpMiddleware::SecurityHeaders) {
256            tracing::trace!("SecurityHeaders middleware skipped (disabled in config)");
257            return self;
258        }
259
260        tracing::trace!(
261            x_frame_options = ?self.config.http.x_frame_options,
262            x_content_type_nosniff = self.config.http.x_content_type_nosniff,
263            "SecurityHeaders middleware enabled"
264        );
265        let mut helmet = Helmet::new();
266        if self.config.http.x_content_type_nosniff {
267            helmet = helmet.add(helmet_core::XContentTypeOptions::nosniff());
268        }
269        // Convert our local XFrameOptions to axum_helmet's version
270        let x_frame = match &self.config.http.x_frame_options.0 {
271            crate::XFrameOptions::Deny => axum_helmet::XFrameOptions::Deny,
272            crate::XFrameOptions::SameOrigin => axum_helmet::XFrameOptions::SameOrigin,
273            crate::XFrameOptions::AllowFrom(url) => {
274                axum_helmet::XFrameOptions::AllowFrom(url.clone())
275            }
276        };
277        helmet = helmet.add(x_frame);
278        self.inner = self.inner.layer(HelmetLayer::new(helmet));
279        self
280    }
281
282    /// No-op when `security-headers` feature is disabled.
283    #[cfg(not(feature = "security-headers"))]
284    #[must_use]
285    pub fn setup_helmet(self) -> Self {
286        self
287    }
288
289    /// Sets up request decompression and response compression.
290    ///
291    /// When `config.http.support_compression` is true, enables:
292    /// - Request body decompression (gzip, brotli, deflate, zstd)
293    /// - Response body compression (based on Accept-Encoding header)
294    ///
295    /// # Configuration
296    ///
297    /// ```toml
298    /// [http]
299    /// support_compression = true  # Default: false
300    /// ```
301    ///
302    /// # Performance
303    ///
304    /// - Reduces bandwidth usage
305    /// - May increase CPU usage
306    /// - Most beneficial for text-based responses (JSON, HTML, etc.)
307    #[cfg(feature = "compression")]
308    #[must_use]
309    pub fn setup_compression(mut self) -> Self {
310        if self.config.http.support_compression
311            && self.is_middleware_enabled(HttpMiddleware::Compression)
312        {
313            self.inner = self
314                .inner
315                .layer(RequestDecompressionLayer::new())
316                .layer(CompressionLayer::new());
317            tracing::trace!("Compression middleware enabled");
318        } else {
319            tracing::trace!("Compression middleware skipped (disabled in config)");
320        }
321        self
322    }
323
324    /// No-op when `compression` feature is disabled.
325    #[cfg(not(feature = "compression"))]
326    #[must_use]
327    pub fn setup_compression(self) -> Self {
328        if self.config.http.support_compression {
329            tracing::warn!(
330                "Compression is enabled in config but the 'compression' feature is not enabled. \
331                 Add `compression` to your Cargo.toml features to enable compression support."
332            );
333        }
334        self
335    }
336
337    /// Sets up Cross-Origin Resource Sharing (CORS) middleware.
338    ///
339    /// Configures which web domains can make requests to your API from a browser.
340    /// If no CORS configuration is provided, defaults to very permissive settings
341    /// (allows all origins, methods, and headers).
342    ///
343    /// # Configuration
344    ///
345    /// ```toml
346    /// [http.cors]
347    /// allow_credentials = true
348    /// allowed_origins = ["https://app.example.com", "https://admin.example.com"]
349    /// allowed_methods = ["GET", "POST", "PUT", "DELETE"]
350    /// allowed_headers = ["content-type", "authorization"]
351    /// exposed_headers = ["x-request-id"]
352    /// max_age = "1h"
353    /// ```
354    ///
355    /// # Security Considerations
356    ///
357    /// - When `allow_credentials` is `true`, wildcard origins are not allowed
358    /// - Without explicit configuration, uses `CorsLayer::very_permissive()` which is suitable
359    ///   for development but may be too permissive for production
360    /// - Always configure explicit `allowed_origins` in production environments
361    ///
362    /// # Examples
363    ///
364    /// ```rust,no_run
365    /// use axum_conf::{Config, HttpCorsConfig};
366    /// # fn example() -> axum_conf::Result<()> {
367    /// let mut config = Config::default();
368    /// config.http.cors = Some(HttpCorsConfig {
369    ///     allow_credentials: Some(true),
370    ///     allowed_origins: Some(vec!["https://app.example.com".to_string()]),
371    ///     allowed_methods: None,
372    ///     allowed_headers: None,
373    ///     exposed_headers: None,
374    ///     max_age: None,
375    /// });
376    /// # Ok(())
377    /// # }
378    /// ```
379    #[cfg(feature = "cors")]
380    #[must_use]
381    pub fn setup_cors(mut self) -> Self {
382        if !self.is_middleware_enabled(HttpMiddleware::Cors) {
383            tracing::trace!("CORS middleware skipped (disabled in config)");
384            return self;
385        }
386
387        use http::HeaderValue;
388
389        if let Some(cors_config) = &self.config.http.cors {
390            tracing::trace!("CORS middleware enabled with custom configuration");
391            let mut cors = CorsLayer::new();
392
393            // By default we do NOT allow credentials
394            let has_credentials = cors_config.allow_credentials.unwrap_or(false);
395
396            // Configure allowed origins
397            if let Some(origins) = &cors_config.allowed_origins {
398                for origin in origins {
399                    if let Ok(header_value) = HeaderValue::from_str(origin) {
400                        cors = cors.allow_origin(header_value);
401                    }
402                }
403            } else if !has_credentials {
404                // Only use wildcard if credentials is not enabled
405                cors = cors.allow_origin(tower_http::cors::Any);
406            }
407
408            // Configure allowed methods
409            if let Some(methods) = &cors_config.allowed_methods {
410                let method_list: Vec<http::Method> = methods.iter().map(|m| m.0.clone()).collect();
411                cors = cors.allow_methods(method_list);
412            } else if !has_credentials {
413                // Only use wildcard if credentials is not enabled
414                cors = cors.allow_methods(tower_http::cors::Any);
415            }
416
417            // Configure allowed headers
418            if let Some(headers) = &cors_config.allowed_headers {
419                let header_list: Vec<HeaderName> = headers.iter().map(|h| h.0.clone()).collect();
420                cors = cors.allow_headers(header_list);
421            } else if !has_credentials {
422                // Only use wildcard if credentials is not enabled
423                cors = cors.allow_headers(tower_http::cors::Any);
424            }
425
426            // Configure exposed headers
427            if let Some(headers) = &cors_config.exposed_headers {
428                let header_list: Vec<HeaderName> = headers.iter().map(|h| h.0.clone()).collect();
429                cors = cors.expose_headers(header_list);
430            }
431
432            // Configure max age
433            if let Some(max_age) = cors_config.max_age {
434                cors = cors.max_age(max_age);
435            }
436
437            // Configure credentials (must be set last after origins/headers)
438            if has_credentials {
439                cors = cors.allow_credentials(true);
440            }
441
442            self.inner = self.inner.layer(cors);
443        } else {
444            // No CORS config specified - behavior depends on environment
445            let rust_env = std::env::var("RUST_ENV").unwrap_or_default().to_lowercase();
446            let is_production = rust_env.is_empty()
447                || rust_env == "prod"
448                || rust_env == "production"
449                || rust_env == "release";
450
451            if is_production {
452                // Production: fail-safe to restrictive CORS (same-origin only)
453                tracing::warn!(
454                    "No CORS configuration found in production environment. \
455                     Using restrictive same-origin policy. Configure [http.cors] \
456                     in your config file to allow cross-origin requests."
457                );
458                // Default CorsLayer denies all cross-origin requests
459                self.inner = self.inner.layer(CorsLayer::new());
460            } else {
461                // Development/Test: use permissive defaults with warning
462                tracing::warn!(
463                    "No CORS configuration found (RUST_ENV={}). Using permissive defaults. \
464                     This is NOT safe for production - configure explicit CORS rules.",
465                    rust_env
466                );
467                self.inner = self.inner.layer(CorsLayer::very_permissive());
468            }
469        }
470        self
471    }
472
473    /// No-op when `cors` feature is disabled.
474    #[cfg(not(feature = "cors"))]
475    #[must_use]
476    pub fn setup_cors(self) -> Self {
477        if self.config.http.cors.is_some() {
478            tracing::warn!(
479                "CORS is configured but the 'cors' feature is not enabled. \
480                 Add `cors` to your Cargo.toml features to enable CORS support."
481            );
482        }
483        self
484    }
485
486    /// Sets up the Kubernetes liveness probe endpoint.
487    ///
488    /// Adds a simple endpoint that always returns 200 OK to indicate the process is running.
489    /// This endpoint is placed very early in the middleware stack (after panic catching) so
490    /// it remains accessible even when other middleware fails.
491    ///
492    /// # Configuration
493    ///
494    /// ```toml
495    /// [http]
496    /// liveness_route = "/live"   # Default
497    /// ```
498    ///
499    /// # Kubernetes Integration
500    ///
501    /// ```yaml
502    /// livenessProbe:
503    ///   httpGet:
504    ///     path: /live
505    ///     port: 3000
506    /// ```
507    #[must_use]
508    pub fn setup_liveness(mut self) -> Self {
509        if !self.is_middleware_enabled(HttpMiddleware::Liveness) {
510            tracing::trace!("Liveness middleware skipped (disabled in config)");
511            return self;
512        }
513
514        let liveness_route = self.config.http.liveness_route.clone();
515        tracing::trace!(route = %liveness_route, "Liveness endpoint enabled");
516        self.inner = self.inner.route(&liveness_route, get(|| async { "OK\n" }));
517        self
518    }
519
520    /// Sets up the Kubernetes readiness probe endpoint.
521    ///
522    /// Adds an endpoint that returns 200 OK if the service can handle traffic.
523    /// When the `postgres` feature is enabled and a database is configured,
524    /// this endpoint verifies database connectivity by executing a simple query.
525    /// If the database is unreachable, returns 503 Service Unavailable.
526    ///
527    /// When the `circuit-breaker` feature is also enabled, the endpoint first checks
528    /// if the database circuit breaker is open. If the circuit is open, it returns
529    /// 503 immediately without attempting a database query, preventing additional
530    /// load on a failing database.
531    ///
532    /// This endpoint is placed after rate limiting and timeout middleware so that:
533    /// - Excessive health check requests don't overwhelm the service
534    /// - Database queries have a timeout to prevent hanging probes
535    ///
536    /// # Configuration
537    ///
538    /// ```toml
539    /// [http]
540    /// readiness_route = "/ready" # Default
541    /// ```
542    ///
543    /// # Kubernetes Integration
544    ///
545    /// ```yaml
546    /// readinessProbe:
547    ///   httpGet:
548    ///     path: /ready
549    ///     port: 3000
550    /// ```
551    #[must_use]
552    pub fn setup_readiness(mut self) -> Self {
553        if !self.is_middleware_enabled(HttpMiddleware::Readiness) {
554            tracing::trace!("Readiness middleware skipped (disabled in config)");
555            return self;
556        }
557
558        let readiness_route = self.config.http.readiness_route.clone();
559        tracing::trace!(route = %readiness_route, "Readiness endpoint enabled");
560
561        #[cfg(feature = "postgres")]
562        let db_pool = self.db_pool.clone();
563
564        #[cfg(all(feature = "circuit-breaker", feature = "postgres"))]
565        let circuit_breaker_registry = self.circuit_breaker_registry.clone();
566
567        self.inner = self.inner.route(
568            &readiness_route,
569            get(|| async move {
570                // When circuit-breaker and postgres are both enabled,
571                // check circuit state before querying
572                #[cfg(all(feature = "circuit-breaker", feature = "postgres"))]
573                {
574                    let breaker = circuit_breaker_registry.get_or_default("database");
575                    if !breaker.should_allow() {
576                        tracing::warn!(
577                            "Database circuit breaker is open, skipping health check query"
578                        );
579                        return (StatusCode::SERVICE_UNAVAILABLE, "Database circuit open\n");
580                    }
581                }
582
583                #[cfg(feature = "postgres")]
584                match sqlx::query("SELECT 1").execute(&db_pool).await {
585                    Ok(_) => (StatusCode::OK, "OK\n"),
586                    Err(e) => {
587                        tracing::error!("Database health check failed: {}", e);
588                        (StatusCode::SERVICE_UNAVAILABLE, "Database unavailable\n")
589                    }
590                }
591
592                #[cfg(not(feature = "postgres"))]
593                (StatusCode::OK, "OK\n")
594            }),
595        );
596        self
597    }
598
599    /// Sets up both Kubernetes health check endpoints.
600    ///
601    /// This is a convenience method that calls both [`setup_liveness`](Self::setup_liveness)
602    /// and [`setup_readiness`](Self::setup_readiness). However, when using `setup_middleware()`,
603    /// these endpoints are placed at different positions in the middleware stack for optimal
604    /// behavior.
605    ///
606    /// # Deprecated
607    ///
608    /// Prefer using `setup_middleware()` which places liveness and readiness endpoints at
609    /// their optimal positions in the middleware stack. If you need manual control, use
610    /// `setup_liveness()` and `setup_readiness()` separately.
611    ///
612    /// # Configuration
613    ///
614    /// ```toml
615    /// [http]
616    /// liveness_route = "/live"   # Default
617    /// readiness_route = "/ready" # Default
618    /// ```
619    #[must_use]
620    #[deprecated(
621        since = "0.4.0",
622        note = "Use setup_middleware() or call setup_liveness() and setup_readiness() separately for optimal middleware ordering"
623    )]
624    pub fn setup_liveness_readiness(self) -> Self {
625        self.setup_liveness().setup_readiness()
626    }
627}