Skip to main content

fastapi_core/
middleware.rs

1//! Middleware abstraction for request/response processing.
2//!
3//! This module provides a flexible middleware system that allows:
4//! - Pre-processing requests before handlers run
5//! - Post-processing responses after handlers complete
6//! - Short-circuiting to return early without calling handlers
7//! - Composable middleware stacks with defined ordering
8//!
9//! # Design Philosophy
10//!
11//! The middleware system follows these principles:
12//! - **Zero-cost when empty**: No overhead if no middleware is configured
13//! - **Async-native**: All hooks are async for I/O operations
14//! - **Cancel-aware**: Integrates with asupersync's cancellation
15//! - **Composable**: Middleware can be stacked and layered
16//!
17//! # Ordering Semantics
18//!
19//! Middleware executes in a specific order:
20//! 1. `before` hooks run in **registration order** (first registered, first run)
21//! 2. Handler executes
22//! 3. `after` hooks run in **reverse order** (last registered, first run)
23//!
24//! This creates an "onion" model where the first middleware wraps everything:
25//!
26//! ```text
27//! Request → MW1.before → MW2.before → MW3.before → Handler
28//!                                                     ↓
29//! Response ← MW1.after ← MW2.after ← MW3.after ← Response
30//! ```
31//!
32//! # Example
33//!
34//! ```ignore
//! use fastapi_core::middleware::{BoxFuture, ControlFlow, Middleware};
//! use fastapi_core::{Request, RequestContext, Response};
//!
//! struct LoggingMiddleware;
//!
//! impl Middleware for LoggingMiddleware {
//!     fn before<'a>(
//!         &'a self,
//!         _ctx: &'a RequestContext,
//!         req: &'a mut Request,
//!     ) -> BoxFuture<'a, ControlFlow> {
//!         println!("Request: {} {}", req.method(), req.path());
//!         Box::pin(async { ControlFlow::Continue })
//!     }
//!
//!     fn after<'a>(
//!         &'a self,
//!         _ctx: &'a RequestContext,
//!         _req: &'a Request,
//!         resp: Response,
//!     ) -> BoxFuture<'a, Response> {
//!         println!("Response: {}", resp.status().as_u16());
//!         Box::pin(async move { resp })
//!     }
//! }
51//! ```
52
53use std::collections::HashSet;
54use std::future::Future;
55use std::ops::ControlFlow as StdControlFlow;
56use std::pin::Pin;
57use std::sync::Arc;
58use std::time::Instant;
59
60use crate::context::RequestContext;
61use crate::dependency::DependencyOverrides;
62use crate::logging::{LogConfig, RequestLogger};
63use crate::request::{Body, Request};
64use crate::response::Response;
65
66/// A boxed future for async middleware operations.
67pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
68
69/// Control flow for middleware `before` hooks.
70///
71/// Determines whether request processing should continue to the handler
72/// or short-circuit with an early response.
73#[derive(Debug)]
74pub enum ControlFlow {
75    /// Continue processing - call the next middleware or handler.
76    Continue,
77    /// Short-circuit - return this response immediately without calling the handler.
78    ///
79    /// Subsequent `before` hooks and the handler will NOT run.
80    /// However, `after` hooks for middleware that already ran their `before` WILL run.
81    Break(Response),
82}
83
84impl ControlFlow {
85    /// Returns `true` if this is `Continue`.
86    #[must_use]
87    pub fn is_continue(&self) -> bool {
88        matches!(self, Self::Continue)
89    }
90
91    /// Returns `true` if this is `Break`.
92    #[must_use]
93    pub fn is_break(&self) -> bool {
94        matches!(self, Self::Break(_))
95    }
96}
97
98impl From<ControlFlow> for StdControlFlow<Response, ()> {
99    fn from(cf: ControlFlow) -> Self {
100        match cf {
101            ControlFlow::Continue => StdControlFlow::Continue(()),
102            ControlFlow::Break(r) => StdControlFlow::Break(r),
103        }
104    }
105}
106
107/// The core middleware trait.
108///
109/// Middleware wraps request handling with pre-processing and post-processing hooks.
110/// Implementations must be thread-safe (`Send + Sync`) as middleware may be shared
111/// across concurrent requests.
112///
113/// # Implementation Guide
114///
115/// - **`before`**: Inspect/modify the request, optionally short-circuit
116/// - **`after`**: Inspect/modify the response
117///
118/// Both methods have default implementations that do nothing, so you can
119/// implement only what you need.
120///
121/// # Cancel-Safety
122///
123/// Middleware should check `ctx.checkpoint()` for long operations to support
124/// graceful cancellation when clients disconnect or timeouts occur.
125///
126/// # Example: Request Timing
127///
128/// ```ignore
129/// use std::time::Instant;
130/// use fastapi_core::middleware::{Middleware, ControlFlow};
131///
132/// struct TimingMiddleware;
133///
134/// impl Middleware for TimingMiddleware {
135///     async fn before(&self, ctx: &RequestContext, req: &mut Request) -> ControlFlow {
136///         // Store start time in request extensions (future feature)
137///         ControlFlow::Continue
138///     }
139///
140///     async fn after(&self, _ctx: &RequestContext, _req: &Request, mut resp: Response) -> Response {
141///         // Add timing header
142///         resp = resp.header("X-Response-Time", b"42ms".to_vec());
143///         resp
144///     }
145/// }
146/// ```
147pub trait Middleware: Send + Sync {
148    /// Called before the handler executes.
149    ///
150    /// # Parameters
151    ///
152    /// - `ctx`: Request context with cancellation support
153    /// - `req`: Mutable request that can be inspected or modified
154    ///
155    /// # Returns
156    ///
157    /// - `ControlFlow::Continue` to proceed to the next middleware/handler
158    /// - `ControlFlow::Break(response)` to short-circuit and return immediately
159    ///
160    /// # Default Implementation
161    ///
162    /// Returns `ControlFlow::Continue` (no-op).
163    fn before<'a>(
164        &'a self,
165        _ctx: &'a RequestContext,
166        _req: &'a mut Request,
167    ) -> BoxFuture<'a, ControlFlow> {
168        Box::pin(async { ControlFlow::Continue })
169    }
170
171    /// Called after the handler executes.
172    ///
173    /// # Parameters
174    ///
175    /// - `ctx`: Request context with cancellation support
176    /// - `req`: The request (read-only at this point)
177    /// - `response`: The response from the handler or previous `after` hooks
178    ///
179    /// # Returns
180    ///
181    /// The response to pass to the next `after` hook or to return to the client.
182    ///
183    /// # Default Implementation
184    ///
185    /// Returns the response unchanged (no-op).
186    fn after<'a>(
187        &'a self,
188        _ctx: &'a RequestContext,
189        _req: &'a Request,
190        response: Response,
191    ) -> BoxFuture<'a, Response> {
192        Box::pin(async move { response })
193    }
194
195    /// Returns the middleware name for debugging and logging.
196    ///
197    /// Override this to provide a meaningful name for your middleware.
198    fn name(&self) -> &'static str {
199        std::any::type_name::<Self>()
200    }
201}
202
203/// A handler that processes requests into responses.
204///
205/// This trait abstracts over handler functions, allowing middleware to wrap
206/// any type that can handle requests.
207pub trait Handler: Send + Sync {
208    /// Process a request and return a response.
209    fn call<'a>(&'a self, ctx: &'a RequestContext, req: &'a mut Request)
210    -> BoxFuture<'a, Response>;
211
212    /// Optional dependency overrides to apply when building request contexts.
213    ///
214    /// Default implementation returns `None`, which means no overrides.
215    fn dependency_overrides(&self) -> Option<Arc<DependencyOverrides>> {
216        None
217    }
218}
219
220/// Implement Handler for async functions.
221///
222/// This allows any async function with the signature
223/// `async fn(&RequestContext, &mut Request) -> Response` to be used as a handler.
224impl<F, Fut> Handler for F
225where
226    F: Fn(&RequestContext, &mut Request) -> Fut + Send + Sync,
227    Fut: Future<Output = Response> + Send + 'static,
228{
229    fn call<'a>(
230        &'a self,
231        ctx: &'a RequestContext,
232        req: &'a mut Request,
233    ) -> BoxFuture<'a, Response> {
234        let fut = self(ctx, req);
235        Box::pin(fut)
236    }
237}
238
239/// Delegate `Handler` to an `Arc`-wrapped handler.
240///
241/// This is a convenience for building apps behind `Arc` (common in tests and when
242/// cloning shared handlers).
243impl<H: Handler + ?Sized> Handler for Arc<H> {
244    fn call<'a>(
245        &'a self,
246        ctx: &'a RequestContext,
247        req: &'a mut Request,
248    ) -> BoxFuture<'a, Response> {
249        (**self).call(ctx, req)
250    }
251
252    fn dependency_overrides(&self) -> Option<Arc<DependencyOverrides>> {
253        (**self).dependency_overrides()
254    }
255}
256
257/// A stack of middleware that wraps a handler.
258///
259/// The stack executes middleware in order:
260/// 1. `before` hooks run first-to-last (registration order)
261/// 2. Handler executes (if no middleware short-circuited)
262/// 3. `after` hooks run last-to-first (reverse order)
263///
264/// # Example
265///
266/// ```ignore
267/// let mut stack = MiddlewareStack::new();
268/// stack.push(LoggingMiddleware);
269/// stack.push(AuthMiddleware);
270/// stack.push(CorsMiddleware);
271///
272/// let response = stack.execute(&handler, &ctx, &mut request).await;
273/// ```
274#[derive(Default)]
275pub struct MiddlewareStack {
276    middleware: Vec<Arc<dyn Middleware>>,
277}
278
279impl MiddlewareStack {
280    /// Creates an empty middleware stack.
281    #[must_use]
282    pub fn new() -> Self {
283        Self {
284            middleware: Vec::new(),
285        }
286    }
287
288    /// Creates a middleware stack with pre-allocated capacity.
289    #[must_use]
290    pub fn with_capacity(capacity: usize) -> Self {
291        Self {
292            middleware: Vec::with_capacity(capacity),
293        }
294    }
295
296    /// Adds middleware to the end of the stack.
297    ///
298    /// Middleware added first will have its `before` run first and `after` run last.
299    pub fn push<M: Middleware + 'static>(&mut self, middleware: M) {
300        self.middleware.push(Arc::new(middleware));
301    }
302
303    /// Adds middleware wrapped in an Arc.
304    ///
305    /// Useful for sharing middleware across multiple stacks.
306    pub fn push_arc(&mut self, middleware: Arc<dyn Middleware>) {
307        self.middleware.push(middleware);
308    }
309
310    /// Returns the number of middleware in the stack.
311    #[must_use]
312    pub fn len(&self) -> usize {
313        self.middleware.len()
314    }
315
316    /// Returns `true` if the stack is empty.
317    #[must_use]
318    pub fn is_empty(&self) -> bool {
319        self.middleware.is_empty()
320    }
321
322    /// Executes the middleware stack with the given handler.
323    ///
324    /// # Execution Order
325    ///
326    /// 1. Each middleware's `before` hook runs in order
327    /// 2. If any `before` returns `Break`, skip remaining middleware and handler
328    /// 3. Handler executes
329    /// 4. Each middleware's `after` hook runs in reverse order
330    ///
331    /// # Short-Circuit Behavior
332    ///
333    /// If middleware N calls `Break(response)`:
334    /// - Middleware N+1..end `before` hooks do NOT run
335    /// - Handler does NOT run
336    /// - Middleware 0..N `after` hooks STILL run (in reverse: N, N-1, ..., 0)
337    ///
338    /// This ensures cleanup middleware (like timing or logging) always runs.
339    pub async fn execute<H: Handler>(
340        &self,
341        handler: &H,
342        ctx: &RequestContext,
343        req: &mut Request,
344    ) -> Response {
345        // Track which middleware ran their `before` hook
346        let mut ran_before_count = 0;
347
348        // Run before hooks in order
349        for mw in &self.middleware {
350            let _ = ctx.checkpoint();
351            match mw.before(ctx, req).await {
352                ControlFlow::Continue => {
353                    ran_before_count += 1;
354                }
355                ControlFlow::Break(response) => {
356                    // Short-circuit: run after hooks for middleware that already ran
357                    return self
358                        .run_after_hooks(ctx, req, response, ran_before_count)
359                        .await;
360                }
361            }
362        }
363
364        // All before hooks passed, call the handler
365        let _ = ctx.checkpoint();
366        let response = handler.call(ctx, req).await;
367
368        // Run after hooks in reverse order
369        self.run_after_hooks(ctx, req, response, ran_before_count)
370            .await
371    }
372
373    /// Runs after hooks for middleware that ran their before hook.
374    async fn run_after_hooks(
375        &self,
376        ctx: &RequestContext,
377        req: &Request,
378        mut response: Response,
379        count: usize,
380    ) -> Response {
381        // Run in reverse order (last middleware's after runs first)
382        for mw in self.middleware[..count].iter().rev() {
383            let _ = ctx.checkpoint();
384            response = mw.after(ctx, req, response).await;
385        }
386        response
387    }
388}
389
390/// A layer that can wrap handlers with middleware.
391///
392/// This provides a more functional composition style similar to Tower's Layer trait.
393///
394/// # Example
395///
396/// ```ignore
397/// let layer = Layer::new(LoggingMiddleware);
398/// let wrapped = layer.wrap(my_handler);
399/// ```
400pub struct Layer<M> {
401    middleware: M,
402}
403
404impl<M: Middleware + Clone> Layer<M> {
405    /// Creates a new layer with the given middleware.
406    pub fn new(middleware: M) -> Self {
407        Self { middleware }
408    }
409
410    /// Wraps a handler with this layer's middleware.
411    pub fn wrap<H: Handler>(&self, handler: H) -> Layered<M, H> {
412        Layered {
413            middleware: self.middleware.clone(),
414            inner: handler,
415        }
416    }
417}
418
419/// A handler wrapped with middleware via a Layer.
420pub struct Layered<M, H> {
421    middleware: M,
422    inner: H,
423}
424
425impl<M: Middleware, H: Handler> Handler for Layered<M, H> {
426    fn call<'a>(
427        &'a self,
428        ctx: &'a RequestContext,
429        req: &'a mut Request,
430    ) -> BoxFuture<'a, Response> {
431        Box::pin(async move {
432            // Run before hook
433            let _ = ctx.checkpoint();
434            match self.middleware.before(ctx, req).await {
435                ControlFlow::Continue => {
436                    // Call inner handler
437                    let _ = ctx.checkpoint();
438                    let response = self.inner.call(ctx, req).await;
439                    // Run after hook
440                    let _ = ctx.checkpoint();
441                    self.middleware.after(ctx, req, response).await
442                }
443                ControlFlow::Break(response) => {
444                    // Short-circuit: still run after for this middleware
445                    let _ = ctx.checkpoint();
446                    self.middleware.after(ctx, req, response).await
447                }
448            }
449        })
450    }
451}
452
453// ============================================================================
454// Common Middleware Implementations
455// ============================================================================
456
457/// No-op middleware that does nothing.
458///
459/// Useful as a placeholder or for testing.
460#[derive(Debug, Clone, Copy, Default)]
461pub struct NoopMiddleware;
462
463impl Middleware for NoopMiddleware {
464    fn name(&self) -> &'static str {
465        "Noop"
466    }
467}
468
469/// Middleware that adds a custom header to all responses.
470///
471/// # Example
472///
473/// ```ignore
474/// // Add X-Powered-By header to all responses
475/// let mw = AddResponseHeader::new("X-Powered-By", "fastapi_rust");
476/// stack.push(mw);
477/// ```
478#[derive(Debug, Clone)]
479pub struct AddResponseHeader {
480    name: String,
481    value: Vec<u8>,
482}
483
484impl AddResponseHeader {
485    /// Creates a new middleware that adds the specified header to responses.
486    pub fn new(name: impl Into<String>, value: impl Into<Vec<u8>>) -> Self {
487        Self {
488            name: name.into(),
489            value: value.into(),
490        }
491    }
492}
493
494impl Middleware for AddResponseHeader {
495    fn after<'a>(
496        &'a self,
497        _ctx: &'a RequestContext,
498        _req: &'a Request,
499        response: Response,
500    ) -> BoxFuture<'a, Response> {
501        let name = self.name.clone();
502        let value = self.value.clone();
503        Box::pin(async move { response.header(name, value) })
504    }
505
506    fn name(&self) -> &'static str {
507        "AddResponseHeader"
508    }
509}
510
511/// Middleware that requires a specific header to be present.
512///
513/// Returns 400 Bad Request if the header is missing.
514///
515/// # Example
516///
517/// ```ignore
518/// // Require X-Api-Key header
519/// let mw = RequireHeader::new("X-Api-Key");
520/// stack.push(mw);
521/// ```
522#[derive(Debug, Clone)]
523pub struct RequireHeader {
524    name: String,
525}
526
527impl RequireHeader {
528    /// Creates a new middleware that requires the specified header.
529    pub fn new(name: impl Into<String>) -> Self {
530        Self { name: name.into() }
531    }
532}
533
534impl Middleware for RequireHeader {
535    fn before<'a>(
536        &'a self,
537        _ctx: &'a RequestContext,
538        req: &'a mut Request,
539    ) -> BoxFuture<'a, ControlFlow> {
540        let has_header = req.headers().get(&self.name).is_some();
541        let name = self.name.clone();
542        Box::pin(async move {
543            if has_header {
544                ControlFlow::Continue
545            } else {
546                let body = format!("Missing required header: {name}");
547                ControlFlow::Break(
548                    Response::with_status(crate::response::StatusCode::BAD_REQUEST)
549                        .header("content-type", b"text/plain".to_vec())
550                        .body(crate::response::ResponseBody::Bytes(body.into_bytes())),
551                )
552            }
553        })
554    }
555
556    fn name(&self) -> &'static str {
557        "RequireHeader"
558    }
559}
560
561/// Middleware that limits request processing based on path prefix.
562///
563/// Only allows requests to paths starting with the specified prefix.
564/// Other requests receive a 404 Not Found response.
565///
566/// # Example
567///
568/// ```ignore
569/// // Only allow requests to /api/*
570/// let mw = PathPrefixFilter::new("/api");
571/// stack.push(mw);
572/// ```
573#[derive(Debug, Clone)]
574pub struct PathPrefixFilter {
575    prefix: String,
576}
577
578impl PathPrefixFilter {
579    /// Creates a new middleware that only allows requests with the specified path prefix.
580    pub fn new(prefix: impl Into<String>) -> Self {
581        Self {
582            prefix: prefix.into(),
583        }
584    }
585}
586
587impl Middleware for PathPrefixFilter {
588    fn before<'a>(
589        &'a self,
590        _ctx: &'a RequestContext,
591        req: &'a mut Request,
592    ) -> BoxFuture<'a, ControlFlow> {
593        let path_matches = req.path().starts_with(&self.prefix);
594        Box::pin(async move {
595            if path_matches {
596                ControlFlow::Continue
597            } else {
598                ControlFlow::Break(Response::with_status(
599                    crate::response::StatusCode::NOT_FOUND,
600                ))
601            }
602        })
603    }
604
605    fn name(&self) -> &'static str {
606        "PathPrefixFilter"
607    }
608}
609
610/// Middleware that sets response status code based on a condition.
611///
612/// This is useful for implementing health checks or conditional responses.
613#[derive(Debug, Clone)]
614pub struct ConditionalStatus<F>
615where
616    F: Fn(&Request) -> bool + Send + Sync,
617{
618    condition: F,
619    status_if_true: crate::response::StatusCode,
620    status_if_false: crate::response::StatusCode,
621}
622
623impl<F> ConditionalStatus<F>
624where
625    F: Fn(&Request) -> bool + Send + Sync,
626{
627    /// Creates a new conditional status middleware.
628    ///
629    /// If the condition returns true, the response gets `status_if_true`.
630    /// Otherwise, it gets `status_if_false`.
631    pub fn new(
632        condition: F,
633        status_if_true: crate::response::StatusCode,
634        status_if_false: crate::response::StatusCode,
635    ) -> Self {
636        Self {
637            condition,
638            status_if_true,
639            status_if_false,
640        }
641    }
642}
643
644impl<F> Middleware for ConditionalStatus<F>
645where
646    F: Fn(&Request) -> bool + Send + Sync,
647{
648    fn after<'a>(
649        &'a self,
650        _ctx: &'a RequestContext,
651        req: &'a Request,
652        response: Response,
653    ) -> BoxFuture<'a, Response> {
654        let matches = (self.condition)(req);
655        let status = if matches {
656            self.status_if_true
657        } else {
658            self.status_if_false
659        };
660        Box::pin(async move { Response::with_status(status).body(response.body_ref().into()) })
661    }
662
663    fn name(&self) -> &'static str {
664        "ConditionalStatus"
665    }
666}
667
668// ============================================================================
669// CORS Middleware
670// ============================================================================
671
672/// Origin matching pattern for CORS.
673#[derive(Debug, Clone)]
674pub enum OriginPattern {
675    /// Allow any origin.
676    Any,
677    /// Exact match.
678    Exact(String),
679    /// Wildcard match (supports `*`).
680    Wildcard(String),
681    /// Simple regex match (supports `^`, `$`, `.`, `*`).
682    Regex(String),
683}
684
685impl OriginPattern {
686    fn matches(&self, origin: &str) -> bool {
687        match self {
688            Self::Any => true,
689            Self::Exact(value) => value == origin,
690            Self::Wildcard(pattern) => wildcard_match(pattern, origin),
691            Self::Regex(pattern) => regex_match(pattern, origin),
692        }
693    }
694}
695
696/// Cross-Origin Resource Sharing (CORS) configuration.
697///
698/// Controls which origins, methods, and headers are allowed for
699/// cross-origin requests. By default, no origins are allowed.
700///
701/// # Defaults
702///
703/// | Setting | Default |
704/// |---------|---------|
705/// | `allow_any_origin` | `false` |
706/// | `allow_credentials` | `false` |
707/// | `allowed_methods` | GET, POST, PUT, PATCH, DELETE, OPTIONS, HEAD |
708/// | `allowed_headers` | none |
709/// | `expose_headers` | none |
710/// | `max_age` | none |
711///
712/// # Security: Credentials and Wildcards
713///
714/// According to the CORS specification (Fetch Standard), when credentials
715/// mode is enabled (`allow_credentials: true`), the following headers
716/// **cannot** use the `*` wildcard value:
717///
718/// - `Access-Control-Allow-Origin` (must echo the specific origin)
719/// - `Access-Control-Allow-Headers` (must list specific headers)
720/// - `Access-Control-Allow-Methods` (must list specific methods)
721/// - `Access-Control-Expose-Headers` (must list specific headers)
722///
723/// This implementation enforces this: when `allow_credentials(true)` is
724/// combined with `allow_any_origin()`, the response echoes back the
725/// specific request origin instead of returning `*`.
726///
727/// # Example
728///
729/// ```ignore
730/// use fastapi_core::Cors;
731///
732/// // Secure: specific origin with credentials
733/// let cors = Cors::new()
734///     .allow_origin("https://myapp.example.com")
735///     .allow_credentials(true)
736///     .expose_headers(["X-Request-Id"]);
737///
738/// // Also secure: any origin echoes back specific origin when credentials enabled
739/// // (not recommended - prefer explicit origins for security)
740/// let cors = Cors::new()
741///     .allow_any_origin()
742///     .allow_credentials(true);
743/// ```
744#[derive(Debug, Clone)]
745pub struct CorsConfig {
746    allow_any_origin: bool,
747    allow_credentials: bool,
748    allowed_methods: Vec<crate::request::Method>,
749    allowed_headers: Vec<String>,
750    expose_headers: Vec<String>,
751    max_age: Option<u32>,
752    origins: Vec<OriginPattern>,
753}
754
755impl Default for CorsConfig {
756    fn default() -> Self {
757        Self {
758            allow_any_origin: false,
759            allow_credentials: false,
760            allowed_methods: vec![
761                crate::request::Method::Get,
762                crate::request::Method::Post,
763                crate::request::Method::Put,
764                crate::request::Method::Patch,
765                crate::request::Method::Delete,
766                crate::request::Method::Options,
767                crate::request::Method::Head,
768            ],
769            allowed_headers: Vec::new(),
770            expose_headers: Vec::new(),
771            max_age: None,
772            origins: Vec::new(),
773        }
774    }
775}
776
777/// CORS middleware.
778#[derive(Debug, Clone)]
779pub struct Cors {
780    config: CorsConfig,
781}
782
783impl Cors {
784    /// Create a new CORS middleware with default configuration.
785    #[must_use]
786    pub fn new() -> Self {
787        Self {
788            config: CorsConfig::default(),
789        }
790    }
791
792    /// Replace the configuration entirely.
793    #[must_use]
794    pub fn config(mut self, config: CorsConfig) -> Self {
795        self.config = config;
796        self
797    }
798
799    /// Allow any origin.
800    #[must_use]
801    pub fn allow_any_origin(mut self) -> Self {
802        self.config.allow_any_origin = true;
803        self
804    }
805
806    /// Allow a single exact origin.
807    #[must_use]
808    pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
809        self.config
810            .origins
811            .push(OriginPattern::Exact(origin.into()));
812        self
813    }
814
815    /// Allow a wildcard origin pattern (supports `*`).
816    #[must_use]
817    pub fn allow_origin_wildcard(mut self, pattern: impl Into<String>) -> Self {
818        self.config
819            .origins
820            .push(OriginPattern::Wildcard(pattern.into()));
821        self
822    }
823
824    /// Allow a simple regex origin pattern (supports `^`, `$`, `.`, `*`).
825    #[must_use]
826    pub fn allow_origin_regex(mut self, pattern: impl Into<String>) -> Self {
827        self.config
828            .origins
829            .push(OriginPattern::Regex(pattern.into()));
830        self
831    }
832
833    /// Allow credentials for CORS responses.
834    #[must_use]
835    pub fn allow_credentials(mut self, allow: bool) -> Self {
836        self.config.allow_credentials = allow;
837        self
838    }
839
840    /// Override allowed HTTP methods for preflight.
841    #[must_use]
842    pub fn allow_methods<I>(mut self, methods: I) -> Self
843    where
844        I: IntoIterator<Item = crate::request::Method>,
845    {
846        self.config.allowed_methods = methods.into_iter().collect();
847        self
848    }
849
850    /// Override allowed headers for preflight.
851    #[must_use]
852    pub fn allow_headers<I, S>(mut self, headers: I) -> Self
853    where
854        I: IntoIterator<Item = S>,
855        S: Into<String>,
856    {
857        self.config.allowed_headers = headers.into_iter().map(Into::into).collect();
858        self
859    }
860
861    /// Add exposed headers for responses.
862    #[must_use]
863    pub fn expose_headers<I, S>(mut self, headers: I) -> Self
864    where
865        I: IntoIterator<Item = S>,
866        S: Into<String>,
867    {
868        self.config.expose_headers = headers.into_iter().map(Into::into).collect();
869        self
870    }
871
872    /// Set the preflight max-age in seconds.
873    #[must_use]
874    pub fn max_age(mut self, seconds: u32) -> Self {
875        self.config.max_age = Some(seconds);
876        self
877    }
878
879    fn is_origin_allowed(&self, origin: &str) -> bool {
880        if self.config.allow_any_origin {
881            return true;
882        }
883        self.config
884            .origins
885            .iter()
886            .any(|pattern| pattern.matches(origin))
887    }
888
889    fn allow_origin_value(&self, origin: &str) -> Option<String> {
890        if !self.is_origin_allowed(origin) {
891            return None;
892        }
893        if self.config.allow_any_origin && !self.config.allow_credentials {
894            Some("*".to_string())
895        } else {
896            Some(origin.to_string())
897        }
898    }
899
900    fn allow_methods_value(&self) -> String {
901        self.config
902            .allowed_methods
903            .iter()
904            .map(|method| method.as_str())
905            .collect::<Vec<_>>()
906            .join(", ")
907    }
908
909    fn allow_headers_value(&self, request: &Request) -> Option<String> {
910        if self.config.allowed_headers.is_empty() {
911            // No allowed headers configured — do NOT reflect the request's
912            // Access-Control-Request-Headers back, as that effectively allows
913            // arbitrary headers. Return None so the header is omitted entirely,
914            // meaning only CORS-safelisted request headers are permitted.
915            return None;
916        }
917
918        // Check for wildcard "*" — if any entry is wildcard, reflect request
919        // headers (standard CORS wildcard behavior when credentials are not
920        // in use). When credentials are enabled, wildcard is NOT valid per
921        // the Fetch spec, so we reflect the requested headers instead.
922        if self.config.allowed_headers.iter().any(|h| h == "*") {
923            if self.config.allow_credentials {
924                // With credentials, we cannot use literal "*" so reflect
925                // the request's headers as an explicit allow list.
926                return request
927                    .headers()
928                    .get("access-control-request-headers")
929                    .and_then(|value| std::str::from_utf8(value).ok())
930                    .map(ToString::to_string);
931            }
932            return Some("*".to_string());
933        }
934
935        Some(self.config.allowed_headers.join(", "))
936    }
937
938    fn apply_common_headers(&self, mut response: Response, origin: &str) -> Response {
939        if let Some(allow_origin) = self.allow_origin_value(origin) {
940            let is_wildcard = allow_origin == "*";
941            response = response.header("access-control-allow-origin", allow_origin.into_bytes());
942            if !is_wildcard {
943                response = response.header("vary", b"Origin".to_vec());
944            }
945            if self.config.allow_credentials {
946                response = response.header("access-control-allow-credentials", b"true".to_vec());
947            }
948            if !self.config.expose_headers.is_empty() {
949                response = response.header(
950                    "access-control-expose-headers",
951                    self.config.expose_headers.join(", ").into_bytes(),
952                );
953            }
954        }
955        response
956    }
957}
958
959impl Default for Cors {
960    fn default() -> Self {
961        Self::new()
962    }
963}
964
965#[derive(Debug, Clone)]
966struct CorsOrigin(String);
967
968impl Middleware for Cors {
969    fn before<'a>(
970        &'a self,
971        _ctx: &'a RequestContext,
972        req: &'a mut Request,
973    ) -> BoxFuture<'a, ControlFlow> {
974        let origin = req
975            .headers()
976            .get("origin")
977            .and_then(|value| std::str::from_utf8(value).ok())
978            .map(ToString::to_string);
979
980        let Some(origin) = origin else {
981            return Box::pin(async { ControlFlow::Continue });
982        };
983
984        if !self.is_origin_allowed(&origin) {
985            let is_preflight = req.method() == crate::request::Method::Options
986                && req.headers().get("access-control-request-method").is_some();
987            if is_preflight {
988                return Box::pin(async {
989                    ControlFlow::Break(Response::with_status(
990                        crate::response::StatusCode::FORBIDDEN,
991                    ))
992                });
993            }
994            return Box::pin(async { ControlFlow::Continue });
995        }
996
997        let is_preflight = req.method() == crate::request::Method::Options
998            && req.headers().get("access-control-request-method").is_some();
999
1000        if is_preflight {
1001            let mut response = Response::no_content();
1002            response = self.apply_common_headers(response, &origin);
1003            response = response.header(
1004                "access-control-allow-methods",
1005                self.allow_methods_value().into_bytes(),
1006            );
1007
1008            if let Some(value) = self.allow_headers_value(req) {
1009                response = response.header("access-control-allow-headers", value.into_bytes());
1010            }
1011
1012            if let Some(max_age) = self.config.max_age {
1013                response =
1014                    response.header("access-control-max-age", max_age.to_string().into_bytes());
1015            }
1016
1017            return Box::pin(async move { ControlFlow::Break(response) });
1018        }
1019
1020        req.insert_extension(CorsOrigin(origin));
1021        Box::pin(async { ControlFlow::Continue })
1022    }
1023
1024    fn after<'a>(
1025        &'a self,
1026        _ctx: &'a RequestContext,
1027        req: &'a Request,
1028        response: Response,
1029    ) -> BoxFuture<'a, Response> {
1030        let origin = req.get_extension::<CorsOrigin>().map(|v| v.0.clone());
1031        Box::pin(async move {
1032            if let Some(origin) = origin {
1033                return self.apply_common_headers(response, &origin);
1034            }
1035            response
1036        })
1037    }
1038
1039    fn name(&self) -> &'static str {
1040        "Cors"
1041    }
1042}
1043
/// Matches `value` against a glob `pattern` in which `*` stands for any
/// (possibly empty) run of characters; every other character must match
/// literally. Comparison is per `char` and case-sensitive.
///
/// Uses the classic linear-backtracking glob algorithm. On a mismatch, the
/// most recent `*` absorbs one more character of `value`, and matching resumes
/// right after that `*`. This also covers a pattern that is exhausted while
/// `value` still has characters left, e.g. `"a*"` vs `"abc"` or
/// `"*.com"` vs `"a.com.com"`.
fn wildcard_match(pattern: &str, value: &str) -> bool {
    let pat: Vec<char> = pattern.chars().collect();
    let val: Vec<char> = value.chars().collect();

    let mut p = 0; // next pattern index to match
    let mut v = 0; // next value index to match
    // Index of the most recent '*' in `pat`, and the value index up to which
    // that star currently extends.
    let mut star: Option<usize> = None;
    let mut star_end = 0;

    while v < val.len() {
        match pat.get(p) {
            Some(&'*') => {
                star = Some(p);
                star_end = v;
                p += 1;
            }
            Some(&c) if c == val[v] => {
                p += 1;
                v += 1;
            }
            _ => match star {
                // Let the last star swallow one more character and retry.
                Some(s) => {
                    p = s + 1;
                    star_end += 1;
                    v = star_end;
                }
                None => return false,
            },
        }
    }

    // The value is fully consumed; any remaining pattern must be only stars.
    pat[p..].iter().all(|&c| c == '*')
}
1097
/// Minimal regular-expression matcher used for origin patterns.
///
/// Supported syntax:
/// - a leading `^` anchors at the start;
/// - a trailing `$` anchors at the end;
/// - `.` matches any single byte;
/// - `x*` matches zero or more repetitions of the preceding byte `x`.
///
/// Everything else is literal; there is no escaping, grouping or alternation.
/// Without `^`, the pattern may match starting at any offset of `value`.
/// Matching is byte-wise.
fn regex_match(pattern: &str, value: &str) -> bool {
    let pat = pattern.as_bytes();
    let text = value.as_bytes();

    match pat.strip_prefix(b"^") {
        Some(anchored) => regex_match_here(anchored, text),
        // Unanchored: try every starting offset, including the empty suffix.
        None => (0..=text.len()).any(|start| regex_match_here(pat, &text[start..])),
    }
}

/// Returns whether `pattern` matches a prefix of `text` (or exactly the rest
/// of `text` once a final `$` is reached).
fn regex_match_here(pattern: &[u8], text: &[u8]) -> bool {
    match pattern {
        [] => true,
        [b'$'] => text.is_empty(),
        [repeated, b'*', rest @ ..] => regex_match_star(*repeated, rest, text),
        [expected, rest @ ..] => match text.split_first() {
            Some((actual, tail)) if *expected == b'.' || expected == actual => {
                regex_match_here(rest, tail)
            }
            _ => false,
        },
    }
}

/// Handles `ch*` followed by `pattern`. It tries to match `pattern` after
/// consuming 0, 1, 2, ... leading bytes of `text` equal to `ch` (any byte when
/// `ch` is `.`), shortest first.
fn regex_match_star(ch: u8, pattern: &[u8], text: &[u8]) -> bool {
    let mut remaining = text;
    loop {
        if regex_match_here(pattern, remaining) {
            return true;
        }
        match remaining.split_first() {
            Some((&next, tail)) if ch == b'.' || next == ch => remaining = tail,
            _ => return false,
        }
    }
}
1151
1152// ============================================================================
1153// Request/Response Logging Middleware
1154// ============================================================================
1155
1156/// Middleware that logs requests and responses with configurable redaction.
1157#[derive(Debug, Clone)]
1158pub struct RequestResponseLogger {
1159    log_config: LogConfig,
1160    redact_headers: HashSet<String>,
1161    log_request_headers: bool,
1162    log_response_headers: bool,
1163    log_body: bool,
1164    max_body_bytes: usize,
1165}
1166
1167impl Default for RequestResponseLogger {
1168    fn default() -> Self {
1169        Self {
1170            log_config: LogConfig::production(),
1171            redact_headers: default_redacted_headers(),
1172            log_request_headers: true,
1173            log_response_headers: true,
1174            log_body: false,
1175            max_body_bytes: 1024,
1176        }
1177    }
1178}
1179
1180impl RequestResponseLogger {
1181    /// Create a new logger middleware with defaults.
1182    #[must_use]
1183    pub fn new() -> Self {
1184        Self::default()
1185    }
1186
1187    /// Override the logging configuration.
1188    #[must_use]
1189    pub fn log_config(mut self, config: LogConfig) -> Self {
1190        self.log_config = config;
1191        self
1192    }
1193
1194    /// Enable or disable request header logging.
1195    #[must_use]
1196    pub fn log_request_headers(mut self, enabled: bool) -> Self {
1197        self.log_request_headers = enabled;
1198        self
1199    }
1200
1201    /// Enable or disable response header logging.
1202    #[must_use]
1203    pub fn log_response_headers(mut self, enabled: bool) -> Self {
1204        self.log_response_headers = enabled;
1205        self
1206    }
1207
1208    /// Enable or disable request/response body logging.
1209    #[must_use]
1210    pub fn log_body(mut self, enabled: bool) -> Self {
1211        self.log_body = enabled;
1212        self
1213    }
1214
1215    /// Set the maximum number of body bytes to include in logs.
1216    #[must_use]
1217    pub fn max_body_bytes(mut self, max: usize) -> Self {
1218        self.max_body_bytes = max;
1219        self
1220    }
1221
1222    /// Add a header name to redact (case-insensitive).
1223    #[must_use]
1224    pub fn redact_header(mut self, name: impl Into<String>) -> Self {
1225        self.redact_headers.insert(name.into().to_ascii_lowercase());
1226        self
1227    }
1228}
1229
/// Request extension stamped by the logger's `before` hook with the moment
/// processing started, so `after` can report the elapsed duration.
#[derive(Clone, Debug)]
struct RequestStart(Instant);
1232
1233impl Middleware for RequestResponseLogger {
1234    fn before<'a>(
1235        &'a self,
1236        ctx: &'a RequestContext,
1237        req: &'a mut Request,
1238    ) -> BoxFuture<'a, ControlFlow> {
1239        let logger = RequestLogger::new(ctx, self.log_config.clone());
1240        req.insert_extension(RequestStart(Instant::now()));
1241
1242        let method = req.method();
1243        let path = req.path();
1244        let query = req.query();
1245        let body_bytes = body_len(req.body());
1246
1247        logger.info_with_fields("request", |entry| {
1248            let mut entry = entry
1249                .field("method", method)
1250                .field("path", path)
1251                .field("body_bytes", body_bytes);
1252
1253            if let Some(q) = query {
1254                entry = entry.field("query", q);
1255            }
1256
1257            if self.log_request_headers {
1258                let headers = format_headers(req.headers().iter(), &self.redact_headers);
1259                entry = entry.field("headers", headers);
1260            }
1261
1262            if self.log_body {
1263                if let Some(body) = preview_body(req.body(), self.max_body_bytes) {
1264                    entry = entry.field("body", body);
1265                }
1266            }
1267
1268            entry
1269        });
1270
1271        Box::pin(async { ControlFlow::Continue })
1272    }
1273
1274    fn after<'a>(
1275        &'a self,
1276        ctx: &'a RequestContext,
1277        req: &'a Request,
1278        response: Response,
1279    ) -> BoxFuture<'a, Response> {
1280        let logger = RequestLogger::new(ctx, self.log_config.clone());
1281        let duration = req
1282            .get_extension::<RequestStart>()
1283            .map(|start| start.0.elapsed())
1284            .unwrap_or_default();
1285
1286        let status = response.status();
1287        let body_bytes = response.body_ref().len();
1288
1289        logger.info_with_fields("response", |entry| {
1290            let mut entry = entry
1291                .field("status", status.as_u16())
1292                .field("duration_us", duration.as_micros())
1293                .field("body_bytes", body_bytes);
1294
1295            if self.log_response_headers {
1296                let headers = format_response_headers(response.headers(), &self.redact_headers);
1297                entry = entry.field("headers", headers);
1298            }
1299
1300            if self.log_body {
1301                if let Some(body) = preview_response_body(response.body_ref(), self.max_body_bytes)
1302                {
1303                    entry = entry.field("body", body);
1304                }
1305            }
1306
1307            entry
1308        });
1309
1310        Box::pin(async move { response })
1311    }
1312
1313    fn name(&self) -> &'static str {
1314        "RequestResponseLogger"
1315    }
1316}
1317
/// Header names (lower-case) whose values the logger redacts by default:
/// credentials and cookies.
fn default_redacted_headers() -> HashSet<String> {
    const SENSITIVE: [&str; 4] = ["authorization", "proxy-authorization", "cookie", "set-cookie"];
    SENSITIVE.iter().map(|name| (*name).to_owned()).collect()
}
1329
1330fn body_len(body: &Body) -> usize {
1331    match body {
1332        Body::Empty => 0,
1333        Body::Bytes(bytes) => bytes.len(),
1334        Body::Stream { content_length, .. } => content_length.unwrap_or(0),
1335    }
1336}
1337
1338fn preview_body(body: &Body, max_bytes: usize) -> Option<String> {
1339    if max_bytes == 0 {
1340        return None;
1341    }
1342    match body {
1343        Body::Empty => None,
1344        Body::Bytes(bytes) => {
1345            if bytes.is_empty() {
1346                None
1347            } else {
1348                Some(format_bytes(bytes, max_bytes))
1349            }
1350        }
1351        Body::Stream { .. } => None,
1352    }
1353}
1354
1355fn preview_response_body(body: &crate::response::ResponseBody, max_bytes: usize) -> Option<String> {
1356    if max_bytes == 0 {
1357        return None;
1358    }
1359    match body {
1360        crate::response::ResponseBody::Empty => None,
1361        crate::response::ResponseBody::Bytes(bytes) => {
1362            if bytes.is_empty() {
1363                None
1364            } else {
1365                Some(format_bytes(bytes, max_bytes))
1366            }
1367        }
1368        crate::response::ResponseBody::Stream(_) => None,
1369    }
1370}
1371
/// Renders headers as `name=value` pairs joined by `", "`.
///
/// Each value is shown as:
/// - `<redacted>` when the lower-cased name is in `redacted`;
/// - `<binary>` when the value is not valid UTF-8;
/// - the value text otherwise.
fn format_headers<'a>(
    headers: impl Iterator<Item = (&'a str, &'a [u8])>,
    redacted: &HashSet<String>,
) -> String {
    let rendered: Vec<String> = headers
        .map(|(name, value)| {
            let shown = if redacted.contains(&name.to_ascii_lowercase()) {
                "<redacted>"
            } else {
                std::str::from_utf8(value).unwrap_or("<binary>")
            };
            format!("{}={}", name, shown)
        })
        .collect();
    rendered.join(", ")
}
1397
1398fn format_response_headers(headers: &[(String, Vec<u8>)], redacted: &HashSet<String>) -> String {
1399    format_headers(
1400        headers
1401            .iter()
1402            .map(|(name, value)| (name.as_str(), value.as_slice())),
1403        redacted,
1404    )
1405}
1406
/// Renders up to `max_bytes` of `bytes` for a log line.
///
/// - Valid UTF-8 is shown as text, with `...` appended when the input was
///   longer than `max_bytes`.
/// - If the cut at `max_bytes` lands inside a multi-byte character, the
///   dangling partial character is dropped. The whole body is not treated as
///   binary in that case.
/// - Anything else that is not valid UTF-8 is summarised as
///   `<N bytes binary>`, where `N` is the full input length.
fn format_bytes(bytes: &[u8], max_bytes: usize) -> String {
    let truncated = bytes.len() > max_bytes;
    let head = &bytes[..max_bytes.min(bytes.len())];

    let text = match std::str::from_utf8(head) {
        Ok(text) => text,
        // `error_len() == None` means the input ended mid-character. When we
        // truncated, that is an artifact of our cut, not of the body itself.
        Err(err) if truncated && err.error_len().is_none() => {
            // `valid_up_to` marks the longest valid UTF-8 prefix, so this
            // second decode cannot fail.
            std::str::from_utf8(&head[..err.valid_up_to()]).unwrap_or_default()
        }
        Err(_) => return format!("<{} bytes binary>", bytes.len()),
    };

    let mut output = String::with_capacity(text.len() + 3);
    output.push_str(text);
    if truncated {
        output.push_str("...");
    }
    output
}
1420
1421// Helper for ResponseBody conversion
1422impl From<&crate::response::ResponseBody> for crate::response::ResponseBody {
1423    fn from(body: &crate::response::ResponseBody) -> Self {
1424        match body {
1425            crate::response::ResponseBody::Empty => crate::response::ResponseBody::Empty,
1426            crate::response::ResponseBody::Bytes(b) => {
1427                crate::response::ResponseBody::Bytes(b.clone())
1428            }
1429            crate::response::ResponseBody::Stream(_) => crate::response::ResponseBody::Empty,
1430        }
1431    }
1432}
1433
1434// ============================================================================
1435// Request ID Middleware
1436// ============================================================================
1437
/// Identifier assigned to a single request, either taken from the client or
/// generated locally.
///
/// The request-ID middleware stores it in the request's extensions, where
/// handlers and other middleware can read it for logging and tracing.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct RequestId(pub String);

impl RequestId {
    /// Wraps `id` as a request ID.
    #[must_use]
    pub fn new(id: impl Into<String>) -> Self {
        let owned: String = id.into();
        RequestId(owned)
    }

    /// Borrows the ID text.
    #[must_use]
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }

    /// Produces a fresh ID of the form `<timestamp-hex>-<counter-hex>`.
    ///
    /// The timestamp is microseconds since the Unix epoch, or 0 if the clock
    /// reads earlier than the epoch. The counter is a process-wide atomic
    /// sequence number, so IDs generated within one process do not repeat
    /// until the 64-bit counter wraps. IDs from different processes are not
    /// guaranteed to be distinct.
    #[must_use]
    pub fn generate() -> Self {
        use std::sync::atomic::{AtomicU64, Ordering};
        use std::time::{SystemTime, UNIX_EPOCH};

        static SEQUENCE: AtomicU64 = AtomicU64::new(0);

        let micros = match SystemTime::now().duration_since(UNIX_EPOCH) {
            Ok(since_epoch) => since_epoch.as_micros() as u64,
            Err(_) => 0,
        };
        let seq = SEQUENCE.fetch_add(1, Ordering::Relaxed);

        RequestId(format!("{:x}-{:x}", micros, seq))
    }
}

impl std::fmt::Display for RequestId {
    /// Writes the raw ID text (width/fill flags are not applied).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

impl From<String> for RequestId {
    fn from(value: String) -> Self {
        RequestId(value)
    }
}

impl From<&str> for RequestId {
    fn from(value: &str) -> Self {
        RequestId(value.to_owned())
    }
}
1497
/// Settings for the request-ID middleware.
#[derive(Clone, Debug)]
pub struct RequestIdConfig {
    /// Header used both to read an incoming ID and to echo it back
    /// (default: `"x-request-id"`).
    pub header_name: String,
    /// Whether an ID supplied by the client may be reused (default: `true`).
    pub accept_from_client: bool,
    /// Whether the ID is written to the response headers (default: `true`).
    pub add_to_response: bool,
    /// Longest client-supplied ID that will be accepted (default: `128`).
    pub max_client_id_length: usize,
}

impl Default for RequestIdConfig {
    fn default() -> Self {
        RequestIdConfig {
            header_name: String::from("x-request-id"),
            accept_from_client: true,
            add_to_response: true,
            max_client_id_length: 128,
        }
    }
}

impl RequestIdConfig {
    /// Creates a configuration with the default settings.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Changes the header used to read and write the request ID.
    #[must_use]
    pub fn header_name(self, name: impl Into<String>) -> Self {
        Self {
            header_name: name.into(),
            ..self
        }
    }

    /// Controls whether a client-supplied ID may be reused.
    #[must_use]
    pub fn accept_from_client(self, accept: bool) -> Self {
        Self {
            accept_from_client: accept,
            ..self
        }
    }

    /// Controls whether the ID is echoed in the response headers.
    #[must_use]
    pub fn add_to_response(self, add: bool) -> Self {
        Self {
            add_to_response: add,
            ..self
        }
    }

    /// Sets the maximum accepted length of a client-supplied ID.
    #[must_use]
    pub fn max_client_id_length(self, max: usize) -> Self {
        Self {
            max_client_id_length: max,
            ..self
        }
    }
}
1557
1558/// Middleware that adds unique request IDs to requests and responses.
1559///
1560/// This middleware:
1561/// 1. Checks for an existing X-Request-ID header from the client
1562/// 2. If present and valid, uses it; otherwise generates a new ID
1563/// 3. Stores the ID in request extensions for handlers to access
1564/// 4. Adds the ID to response headers
1565///
1566/// # Example
1567///
1568/// ```ignore
1569/// use fastapi_core::middleware::RequestIdMiddleware;
1570///
1571/// let mut stack = MiddlewareStack::new();
1572/// stack.push(RequestIdMiddleware::new());
1573///
1574/// // In your handler:
1575/// async fn handler(ctx: &RequestContext, req: &Request) -> Response {
1576///     if let Some(request_id) = req.get_extension::<RequestId>() {
1577///         println!("Request ID: {}", request_id);
1578///     }
1579///     Response::ok()
1580/// }
1581/// ```
1582#[derive(Debug, Clone)]
1583pub struct RequestIdMiddleware {
1584    config: RequestIdConfig,
1585}
1586
1587impl Default for RequestIdMiddleware {
1588    fn default() -> Self {
1589        Self::new()
1590    }
1591}
1592
1593impl RequestIdMiddleware {
1594    /// Creates a new request ID middleware with default configuration.
1595    #[must_use]
1596    pub fn new() -> Self {
1597        Self {
1598            config: RequestIdConfig::default(),
1599        }
1600    }
1601
1602    /// Creates a new request ID middleware with the given configuration.
1603    #[must_use]
1604    pub fn with_config(config: RequestIdConfig) -> Self {
1605        Self { config }
1606    }
1607
1608    /// Extracts or generates a request ID for the given request.
1609    fn get_or_generate_id(&self, req: &Request) -> RequestId {
1610        if self.config.accept_from_client {
1611            if let Some(header_value) = req.headers().get(&self.config.header_name) {
1612                if let Ok(client_id) = std::str::from_utf8(header_value) {
1613                    // Validate length and basic content
1614                    if !client_id.is_empty()
1615                        && client_id.len() <= self.config.max_client_id_length
1616                        && is_valid_request_id(client_id)
1617                    {
1618                        return RequestId::new(client_id);
1619                    }
1620                }
1621            }
1622        }
1623        RequestId::generate()
1624    }
1625}
1626
/// Returns `true` when `id` is non-empty and consists solely of ASCII letters,
/// digits, `-`, `_` or `.`. Whitespace, control characters (including CR/LF)
/// and non-ASCII text are rejected.
fn is_valid_request_id(id: &str) -> bool {
    let permitted = |c: char| c.is_ascii_alphanumeric() || matches!(c, '-' | '_' | '.');
    !id.is_empty() && id.chars().all(permitted)
}
1634
1635impl Middleware for RequestIdMiddleware {
1636    fn before<'a>(
1637        &'a self,
1638        _ctx: &'a RequestContext,
1639        req: &'a mut Request,
1640    ) -> BoxFuture<'a, ControlFlow> {
1641        let request_id = self.get_or_generate_id(req);
1642        req.insert_extension(request_id);
1643        Box::pin(async { ControlFlow::Continue })
1644    }
1645
1646    fn after<'a>(
1647        &'a self,
1648        _ctx: &'a RequestContext,
1649        req: &'a Request,
1650        response: Response,
1651    ) -> BoxFuture<'a, Response> {
1652        if !self.config.add_to_response {
1653            return Box::pin(async move { response });
1654        }
1655
1656        let request_id = req.get_extension::<RequestId>().cloned();
1657        let header_name = self.config.header_name.clone();
1658
1659        Box::pin(async move {
1660            if let Some(id) = request_id {
1661                response.header(header_name, id.0.into_bytes())
1662            } else {
1663                response
1664            }
1665        })
1666    }
1667
1668    fn name(&self) -> &'static str {
1669        "RequestId"
1670    }
1671}
1672
1673// ============================================================================
1674// Security Headers Middleware
1675// ============================================================================
1676
/// Value of the `X-Frame-Options` response header, which tells browsers
/// whether the page may be rendered inside a frame (clickjacking defence).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum XFrameOptions {
    /// `DENY`: the page must not be framed by any site.
    Deny,
    /// `SAMEORIGIN`: only pages from the same origin may frame it.
    SameOrigin,
}

impl XFrameOptions {
    /// Header value bytes for this option.
    fn as_bytes(self) -> &'static [u8] {
        let token = match self {
            XFrameOptions::Deny => "DENY",
            XFrameOptions::SameOrigin => "SAMEORIGIN",
        };
        token.as_bytes()
    }
}
1696
/// Referrer-Policy header value.
///
/// Controls how much referrer information (origin, path, query) the browser
/// includes with requests made from the page.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReferrerPolicy {
    /// `no-referrer`: never send a `Referer` header.
    NoReferrer,
    /// `no-referrer-when-downgrade`: send the full URL (origin, path, query)
    /// unless navigating to a less secure protocol (HTTPS → HTTP), in which
    /// case send nothing.
    NoReferrerWhenDowngrade,
    /// `origin`: send only the origin, never the path or query.
    Origin,
    /// `origin-when-cross-origin`: send the full URL for same-origin requests
    /// and only the origin for cross-origin requests or HTTPS → HTTP.
    OriginWhenCrossOrigin,
    /// `same-origin`: send the full URL for same-origin requests and nothing
    /// for cross-origin requests.
    SameOrigin,
    /// `strict-origin`: send only the origin when the protocol security level
    /// stays the same (HTTPS → HTTPS); send nothing to a less secure protocol.
    StrictOrigin,
    /// `strict-origin-when-cross-origin`: the full URL for same-origin
    /// requests. For cross-origin requests at the same security level, only
    /// the origin. Nothing when downgrading to a less secure protocol.
    StrictOriginWhenCrossOrigin,
    /// `unsafe-url`: always send the full URL, even when downgrading
    /// (not recommended; can leak private paths and queries).
    UnsafeUrl,
}

impl ReferrerPolicy {
    /// Header value bytes for this policy.
    fn as_bytes(self) -> &'static [u8] {
        match self {
            Self::NoReferrer => b"no-referrer",
            Self::NoReferrerWhenDowngrade => b"no-referrer-when-downgrade",
            Self::Origin => b"origin",
            Self::OriginWhenCrossOrigin => b"origin-when-cross-origin",
            Self::SameOrigin => b"same-origin",
            Self::StrictOrigin => b"strict-origin",
            Self::StrictOriginWhenCrossOrigin => b"strict-origin-when-cross-origin",
            Self::UnsafeUrl => b"unsafe-url",
        }
    }
}
1734
1735/// Configuration for the Security Headers middleware.
1736///
1737/// All headers are optional. Set a value to `Some(...)` to include the header,
1738/// or `None` to skip it.
1739///
1740/// # Defaults
1741///
1742/// The default configuration provides secure defaults:
1743/// - `X-Content-Type-Options: nosniff`
1744/// - `X-Frame-Options: DENY`
1745/// - `X-XSS-Protection: 0` (disabled as modern browsers have built-in protection)
1746/// - `Referrer-Policy: strict-origin-when-cross-origin`
1747///
1748/// # Example
1749///
1750/// ```ignore
1751/// use fastapi_core::middleware::{SecurityHeadersConfig, XFrameOptions, ReferrerPolicy};
1752///
1753/// let config = SecurityHeadersConfig::default()
1754///     .x_frame_options(XFrameOptions::SameOrigin)
1755///     .content_security_policy("default-src 'self'")
1756///     .hsts(31536000, true);  // 1 year, includeSubDomains
1757/// ```
1758#[derive(Debug, Clone)]
1759pub struct SecurityHeadersConfig {
1760    /// X-Content-Type-Options header.
1761    /// Default: `Some("nosniff")`
1762    pub x_content_type_options: Option<&'static str>,
1763    /// X-Frame-Options header.
1764    /// Default: `Some(XFrameOptions::Deny)`
1765    pub x_frame_options: Option<XFrameOptions>,
1766    /// X-XSS-Protection header.
1767    /// Default: `Some("0")` (disabled - modern browsers have built-in protection)
1768    ///
1769    /// Note: This header is largely obsolete. Setting it to "0" is recommended
1770    /// to prevent potential security issues in older browsers.
1771    pub x_xss_protection: Option<&'static str>,
1772    /// Content-Security-Policy header.
1773    /// Default: `None` (should be configured based on your application)
1774    pub content_security_policy: Option<String>,
1775    /// Strict-Transport-Security (HSTS) header.
1776    /// Tuple of (max_age_seconds, include_sub_domains, preload)
1777    /// Default: `None` (only set this for HTTPS-only sites)
1778    pub hsts: Option<(u64, bool, bool)>,
1779    /// Referrer-Policy header.
1780    /// Default: `Some(ReferrerPolicy::StrictOriginWhenCrossOrigin)`
1781    pub referrer_policy: Option<ReferrerPolicy>,
1782    /// Permissions-Policy header (formerly Feature-Policy).
1783    /// Default: `None` (should be configured based on your application)
1784    pub permissions_policy: Option<String>,
1785}
1786
1787impl Default for SecurityHeadersConfig {
1788    fn default() -> Self {
1789        Self {
1790            x_content_type_options: Some("nosniff"),
1791            x_frame_options: Some(XFrameOptions::Deny),
1792            x_xss_protection: Some("0"),
1793            content_security_policy: None,
1794            hsts: None,
1795            referrer_policy: Some(ReferrerPolicy::StrictOriginWhenCrossOrigin),
1796            permissions_policy: None,
1797        }
1798    }
1799}
1800
1801impl SecurityHeadersConfig {
1802    /// Creates a new configuration with secure defaults.
1803    #[must_use]
1804    pub fn new() -> Self {
1805        Self::default()
1806    }
1807
1808    /// Creates an empty configuration (no headers).
1809    #[must_use]
1810    pub fn none() -> Self {
1811        Self {
1812            x_content_type_options: None,
1813            x_frame_options: None,
1814            x_xss_protection: None,
1815            content_security_policy: None,
1816            hsts: None,
1817            referrer_policy: None,
1818            permissions_policy: None,
1819        }
1820    }
1821
1822    /// Creates a strict configuration for high-security applications.
1823    ///
1824    /// Includes:
1825    /// - All default headers
1826    /// - HSTS with 1 year max-age and includeSubDomains
1827    /// - A basic CSP that only allows same-origin resources
1828    #[must_use]
1829    pub fn strict() -> Self {
1830        Self {
1831            x_content_type_options: Some("nosniff"),
1832            x_frame_options: Some(XFrameOptions::Deny),
1833            x_xss_protection: Some("0"),
1834            content_security_policy: Some("default-src 'self'".to_string()),
1835            hsts: Some((31536000, true, false)), // 1 year, includeSubDomains
1836            referrer_policy: Some(ReferrerPolicy::NoReferrer),
1837            permissions_policy: Some("geolocation=(), camera=(), microphone=()".to_string()),
1838        }
1839    }
1840
1841    /// Sets the X-Content-Type-Options header.
1842    #[must_use]
1843    pub fn x_content_type_options(mut self, value: Option<&'static str>) -> Self {
1844        self.x_content_type_options = value;
1845        self
1846    }
1847
1848    /// Sets the X-Frame-Options header.
1849    #[must_use]
1850    pub fn x_frame_options(mut self, value: Option<XFrameOptions>) -> Self {
1851        self.x_frame_options = value;
1852        self
1853    }
1854
1855    /// Sets the X-XSS-Protection header.
1856    #[must_use]
1857    pub fn x_xss_protection(mut self, value: Option<&'static str>) -> Self {
1858        self.x_xss_protection = value;
1859        self
1860    }
1861
1862    /// Sets the Content-Security-Policy header.
1863    #[must_use]
1864    pub fn content_security_policy(mut self, value: impl Into<String>) -> Self {
1865        self.content_security_policy = Some(value.into());
1866        self
1867    }
1868
1869    /// Clears the Content-Security-Policy header.
1870    #[must_use]
1871    pub fn no_content_security_policy(mut self) -> Self {
1872        self.content_security_policy = None;
1873        self
1874    }
1875
1876    /// Sets the Strict-Transport-Security (HSTS) header.
1877    ///
1878    /// # Arguments
1879    ///
1880    /// - `max_age`: Maximum time (in seconds) the browser should remember HTTPS
1881    /// - `include_sub_domains`: Whether to apply to all subdomains
1882    /// - `preload`: Whether to include in browser preload lists (use with caution)
1883    ///
1884    /// # Warning
1885    ///
1886    /// Only enable HSTS for sites that are HTTPS-only. Enabling HSTS incorrectly
1887    /// can make your site inaccessible.
1888    #[must_use]
1889    pub fn hsts(mut self, max_age: u64, include_sub_domains: bool, preload: bool) -> Self {
1890        self.hsts = Some((max_age, include_sub_domains, preload));
1891        self
1892    }
1893
1894    /// Clears the HSTS header.
1895    #[must_use]
1896    pub fn no_hsts(mut self) -> Self {
1897        self.hsts = None;
1898        self
1899    }
1900
1901    /// Sets the Referrer-Policy header.
1902    #[must_use]
1903    pub fn referrer_policy(mut self, value: Option<ReferrerPolicy>) -> Self {
1904        self.referrer_policy = value;
1905        self
1906    }
1907
1908    /// Sets the Permissions-Policy header.
1909    #[must_use]
1910    pub fn permissions_policy(mut self, value: impl Into<String>) -> Self {
1911        self.permissions_policy = Some(value.into());
1912        self
1913    }
1914
1915    /// Clears the Permissions-Policy header.
1916    #[must_use]
1917    pub fn no_permissions_policy(mut self) -> Self {
1918        self.permissions_policy = None;
1919        self
1920    }
1921
1922    /// Builds the HSTS header value.
1923    fn build_hsts_value(&self) -> Option<String> {
1924        self.hsts.map(|(max_age, include_sub, preload)| {
1925            let mut value = format!("max-age={}", max_age);
1926            if include_sub {
1927                value.push_str("; includeSubDomains");
1928            }
1929            if preload {
1930                value.push_str("; preload");
1931            }
1932            value
1933        })
1934    }
1935}
1936
1937/// Middleware that adds security-related HTTP headers to responses.
1938///
1939/// This middleware helps protect against common web vulnerabilities by setting
1940/// appropriate security headers. It's recommended for all web applications.
1941///
1942/// # Headers
1943///
1944/// - **X-Content-Type-Options**: Prevents MIME type sniffing
1945/// - **X-Frame-Options**: Controls iframe embedding (clickjacking protection)
1946/// - **X-XSS-Protection**: Legacy XSS filter control (disabled by default)
1947/// - **Content-Security-Policy**: Controls resource loading
1948/// - **Strict-Transport-Security**: Enforces HTTPS
1949/// - **Referrer-Policy**: Controls referrer information
1950/// - **Permissions-Policy**: Controls browser features
1951///
1952/// # Example
1953///
1954/// ```ignore
1955/// use fastapi_core::middleware::{SecurityHeaders, SecurityHeadersConfig};
1956///
1957/// // Use defaults
1958/// let mw = SecurityHeaders::new();
1959///
1960/// // Custom configuration
1961/// let config = SecurityHeadersConfig::default()
1962///     .content_security_policy("default-src 'self'; img-src *")
1963///     .hsts(86400, false, false);  // 1 day
1964///
1965/// let mw = SecurityHeaders::with_config(config);
1966/// ```
1967#[derive(Debug, Clone)]
1968pub struct SecurityHeaders {
1969    config: SecurityHeadersConfig,
1970}
1971
1972impl Default for SecurityHeaders {
1973    fn default() -> Self {
1974        Self::new()
1975    }
1976}
1977
1978impl SecurityHeaders {
1979    /// Creates a new middleware with default configuration.
1980    #[must_use]
1981    pub fn new() -> Self {
1982        Self {
1983            config: SecurityHeadersConfig::default(),
1984        }
1985    }
1986
1987    /// Creates a new middleware with custom configuration.
1988    #[must_use]
1989    pub fn with_config(config: SecurityHeadersConfig) -> Self {
1990        Self { config }
1991    }
1992
1993    /// Creates a middleware with strict security settings.
1994    #[must_use]
1995    pub fn strict() -> Self {
1996        Self {
1997            config: SecurityHeadersConfig::strict(),
1998        }
1999    }
2000}
2001
2002impl Middleware for SecurityHeaders {
2003    fn after<'a>(
2004        &'a self,
2005        _ctx: &'a RequestContext,
2006        _req: &'a Request,
2007        response: Response,
2008    ) -> BoxFuture<'a, Response> {
2009        let config = self.config.clone();
2010        Box::pin(async move {
2011            let mut resp = response;
2012
2013            // X-Content-Type-Options
2014            if let Some(value) = config.x_content_type_options {
2015                resp = resp.header("X-Content-Type-Options", value.as_bytes().to_vec());
2016            }
2017
2018            // X-Frame-Options
2019            if let Some(value) = config.x_frame_options {
2020                resp = resp.header("X-Frame-Options", value.as_bytes().to_vec());
2021            }
2022
2023            // X-XSS-Protection
2024            if let Some(value) = config.x_xss_protection {
2025                resp = resp.header("X-XSS-Protection", value.as_bytes().to_vec());
2026            }
2027
2028            // Content-Security-Policy
2029            if let Some(ref value) = config.content_security_policy {
2030                resp = resp.header("Content-Security-Policy", value.as_bytes().to_vec());
2031            }
2032
2033            // Strict-Transport-Security
2034            if let Some(ref hsts_value) = config.build_hsts_value() {
2035                resp = resp.header("Strict-Transport-Security", hsts_value.as_bytes().to_vec());
2036            }
2037
2038            // Referrer-Policy
2039            if let Some(value) = config.referrer_policy {
2040                resp = resp.header("Referrer-Policy", value.as_bytes().to_vec());
2041            }
2042
2043            // Permissions-Policy
2044            if let Some(ref value) = config.permissions_policy {
2045                resp = resp.header("Permissions-Policy", value.as_bytes().to_vec());
2046            }
2047
2048            resp
2049        })
2050    }
2051
2052    fn name(&self) -> &'static str {
2053        "SecurityHeaders"
2054    }
2055}
2056
2057// ============================================================================
2058// CSRF Protection Middleware
2059// ============================================================================
2060
/// CSRF token stored in request extensions.
///
/// Middleware stores this after generating or validating a token,
/// allowing handlers to access the current CSRF token.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CsrfToken(pub String);

impl CsrfToken {
    /// Creates a new CSRF token with the given value.
    #[must_use]
    pub fn new(token: impl Into<String>) -> Self {
        Self(token.into())
    }

    /// Returns the token as a string slice.
    #[must_use]
    pub fn as_str(&self) -> &str {
        &self.0
    }

    /// Generates a new unique CSRF token using cryptographic randomness.
    ///
    /// Reads 32 bytes from `/dev/urandom` and hex-encodes them, giving a
    /// 64-character lowercase token.
    ///
    /// # Panics
    ///
    /// Panics if `/dev/urandom` cannot be opened or read. This always happens on
    /// platforms without that device, such as Windows. CSRF tokens MUST be
    /// cryptographically unpredictable, so there is no safe fallback. The panic
    /// message includes the underlying I/O error to aid diagnosis.
    #[must_use]
    pub fn generate() -> Self {
        // CSRF tokens must be cryptographically secure - no weak fallback.
        let bytes = Self::read_urandom(32).unwrap_or_else(|err| {
            panic!(
                "FATAL: Cryptographically secure random source (/dev/urandom) is unavailable \
                 ({err}). CSRF token generation requires a CSPRNG. Cannot safely generate CSRF \
                 tokens without cryptographic entropy."
            )
        });
        Self(Self::bytes_to_hex(&bytes))
    }

    /// Reads exactly `len` bytes from `/dev/urandom`.
    fn read_urandom(len: usize) -> std::io::Result<Vec<u8>> {
        use std::io::Read;
        let mut f = std::fs::File::open("/dev/urandom")?;
        let mut buf = vec![0u8; len];
        f.read_exact(&mut buf)?;
        Ok(buf)
    }

    /// Lowercase hex encoding, two characters per byte.
    fn bytes_to_hex(bytes: &[u8]) -> String {
        use std::fmt::Write;
        let mut s = String::with_capacity(bytes.len() * 2);
        for b in bytes {
            // Writing into a `String` cannot fail.
            let _ = write!(s, "{b:02x}");
        }
        s
    }
}

impl std::fmt::Display for CsrfToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

impl From<&str> for CsrfToken {
    fn from(s: &str) -> Self {
        Self(s.to_string())
    }
}
2130}
2131
/// CSRF protection mode.
///
/// Selects how `CsrfMiddleware` validates state-changing (non-safe-method) requests.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CsrfMode {
    /// Double-submit cookie pattern: token in cookie must match token in header.
    /// This is the default and most common pattern.
    #[default]
    DoubleSubmit,
    /// Require token in header only (for APIs where cookies are not used).
    ///
    /// Any non-empty header value is accepted; no cookie comparison is made.
    /// Protection therefore relies on cross-origin requests being unable to set
    /// custom headers without a CORS preflight.
    HeaderOnly,
}
2142
2143/// Configuration for CSRF protection middleware.
2144#[derive(Debug, Clone)]
2145pub struct CsrfConfig {
2146    /// Cookie name for CSRF token (default: "csrf_token").
2147    pub cookie_name: String,
2148    /// Header name for CSRF token (default: "x-csrf-token").
2149    pub header_name: String,
2150    /// CSRF protection mode (default: DoubleSubmit).
2151    pub mode: CsrfMode,
2152    /// Whether to rotate token on each request (default: false).
2153    pub rotate_token: bool,
2154    /// Whether in production mode (affects Secure cookie flag).
2155    pub production: bool,
2156    /// Custom error message for CSRF failures.
2157    pub error_message: Option<String>,
2158}
2159
2160impl Default for CsrfConfig {
2161    fn default() -> Self {
2162        Self {
2163            cookie_name: "csrf_token".to_string(),
2164            header_name: "x-csrf-token".to_string(),
2165            mode: CsrfMode::DoubleSubmit,
2166            rotate_token: false,
2167            production: true,
2168            error_message: None,
2169        }
2170    }
2171}
2172
2173impl CsrfConfig {
2174    /// Creates a new configuration with defaults.
2175    #[must_use]
2176    pub fn new() -> Self {
2177        Self::default()
2178    }
2179
2180    /// Sets the cookie name for CSRF token.
2181    #[must_use]
2182    pub fn cookie_name(mut self, name: impl Into<String>) -> Self {
2183        self.cookie_name = name.into();
2184        self
2185    }
2186
2187    /// Sets the header name for CSRF token.
2188    #[must_use]
2189    pub fn header_name(mut self, name: impl Into<String>) -> Self {
2190        self.header_name = name.into();
2191        self
2192    }
2193
2194    /// Sets the CSRF protection mode.
2195    #[must_use]
2196    pub fn mode(mut self, mode: CsrfMode) -> Self {
2197        self.mode = mode;
2198        self
2199    }
2200
2201    /// Enables token rotation on each request.
2202    #[must_use]
2203    pub fn rotate_token(mut self, rotate: bool) -> Self {
2204        self.rotate_token = rotate;
2205        self
2206    }
2207
2208    /// Sets production mode (affects Secure cookie flag).
2209    #[must_use]
2210    pub fn production(mut self, production: bool) -> Self {
2211        self.production = production;
2212        self
2213    }
2214
2215    /// Sets a custom error message for CSRF failures.
2216    #[must_use]
2217    pub fn error_message(mut self, message: impl Into<String>) -> Self {
2218        self.error_message = Some(message.into());
2219        self
2220    }
2221}
2222
2223/// CSRF protection middleware.
2224///
2225/// Implements protection against Cross-Site Request Forgery attacks using
2226/// the double-submit cookie pattern by default.
2227///
2228/// # How It Works
2229///
2230/// 1. For safe methods (GET, HEAD, OPTIONS, TRACE): generates a CSRF token
2231///    and sets it in a cookie if not present.
2232/// 2. For state-changing methods (POST, PUT, DELETE, PATCH): validates that
2233///    the token in the header matches the token in the cookie.
2234///
2235/// # Example
2236///
2237/// ```ignore
2238/// use fastapi_core::middleware::{CsrfMiddleware, CsrfConfig};
2239///
2240/// let mut stack = MiddlewareStack::new();
2241/// stack.push(CsrfMiddleware::new());
2242///
2243/// // Or with custom configuration:
2244/// let csrf = CsrfMiddleware::with_config(
2245///     CsrfConfig::new()
2246///         .header_name("X-XSRF-Token")
2247///         .cookie_name("XSRF-TOKEN")
2248///         .production(false)
2249/// );
2250/// stack.push(csrf);
2251/// ```
2252#[derive(Debug, Clone)]
2253pub struct CsrfMiddleware {
2254    config: CsrfConfig,
2255}
2256
2257impl Default for CsrfMiddleware {
2258    fn default() -> Self {
2259        Self::new()
2260    }
2261}
2262
2263impl CsrfMiddleware {
2264    /// Creates a new CSRF middleware with default configuration.
2265    #[must_use]
2266    pub fn new() -> Self {
2267        Self {
2268            config: CsrfConfig::default(),
2269        }
2270    }
2271
2272    /// Creates a new CSRF middleware with the given configuration.
2273    #[must_use]
2274    pub fn with_config(config: CsrfConfig) -> Self {
2275        Self { config }
2276    }
2277
2278    /// Checks if the HTTP method is safe (does not modify state).
2279    fn is_safe_method(method: crate::request::Method) -> bool {
2280        matches!(
2281            method,
2282            crate::request::Method::Get
2283                | crate::request::Method::Head
2284                | crate::request::Method::Options
2285                | crate::request::Method::Trace
2286        )
2287    }
2288
2289    /// Extracts the CSRF token from the cookie header.
2290    fn get_cookie_token(&self, req: &Request) -> Option<String> {
2291        let cookie_header = req.headers().get("cookie")?;
2292        let cookie_str = std::str::from_utf8(cookie_header).ok()?;
2293
2294        // Parse cookie header: "name1=value1; name2=value2"
2295        for part in cookie_str.split(';') {
2296            let part = part.trim();
2297            if let Some((name, value)) = part.split_once('=') {
2298                if name.trim() == self.config.cookie_name {
2299                    return Some(value.trim().to_string());
2300                }
2301            }
2302        }
2303        None
2304    }
2305
2306    /// Extracts the CSRF token from the request header.
2307    fn get_header_token(&self, req: &Request) -> Option<String> {
2308        let header_value = req.headers().get(&self.config.header_name)?;
2309        std::str::from_utf8(header_value)
2310            .ok()
2311            .map(|s| s.trim().to_string())
2312    }
2313
2314    /// Validates the CSRF token for state-changing requests.
2315    fn validate_token(&self, req: &Request) -> Result<Option<CsrfToken>, Response> {
2316        let header_token = self.get_header_token(req);
2317
2318        match self.config.mode {
2319            CsrfMode::DoubleSubmit => {
2320                let cookie_token = self.get_cookie_token(req);
2321
2322                match (header_token, cookie_token) {
2323                    (Some(header), Some(cookie))
2324                        if !header.is_empty()
2325                            && crate::password::constant_time_eq(
2326                                header.as_bytes(),
2327                                cookie.as_bytes(),
2328                            ) =>
2329                    {
2330                        Ok(Some(CsrfToken::new(header)))
2331                    }
2332                    (None, _) | (_, None) => Err(self.csrf_error_response("CSRF token missing")),
2333                    _ => Err(self.csrf_error_response("CSRF token mismatch")),
2334                }
2335            }
2336            CsrfMode::HeaderOnly => match header_token {
2337                Some(token) if !token.is_empty() => Ok(Some(CsrfToken::new(token))),
2338                _ => Err(self.csrf_error_response("CSRF token missing in header")),
2339            },
2340        }
2341    }
2342
2343    /// Creates a 403 Forbidden response for CSRF failures.
2344    fn csrf_error_response(&self, default_message: &str) -> Response {
2345        let message = self
2346            .config
2347            .error_message
2348            .as_deref()
2349            .unwrap_or(default_message);
2350
2351        // Create a FastAPI-compatible error response using serde_json
2352        // to properly escape header_name and message values.
2353        let detail = serde_json::json!({
2354            "detail": [{
2355                "type": "csrf_error",
2356                "loc": ["header", self.config.header_name],
2357                "msg": message,
2358            }]
2359        });
2360        let body = detail.to_string();
2361
2362        Response::with_status(crate::response::StatusCode::FORBIDDEN)
2363            .header("content-type", b"application/json".to_vec())
2364            .body(crate::response::ResponseBody::Bytes(body.into_bytes()))
2365    }
2366
2367    /// Creates the Set-Cookie header value for a CSRF token.
2368    fn make_set_cookie_header_value(cookie_name: &str, token: &str, production: bool) -> Vec<u8> {
2369        let mut cookie = format!("{}={}; Path=/; SameSite=Strict", cookie_name, token);
2370
2371        if production {
2372            cookie.push_str("; Secure");
2373        }
2374
2375        // Note: HttpOnly is NOT set - CSRF cookies must be readable by JavaScript
2376
2377        cookie.into_bytes()
2378    }
2379}
2380
2381impl Middleware for CsrfMiddleware {
2382    fn before<'a>(
2383        &'a self,
2384        _ctx: &'a RequestContext,
2385        req: &'a mut Request,
2386    ) -> BoxFuture<'a, ControlFlow> {
2387        Box::pin(async move {
2388            if Self::is_safe_method(req.method()) {
2389                // Safe methods: generate token if not present
2390                let existing_token = self.get_cookie_token(req);
2391                let token = existing_token
2392                    .map(CsrfToken::new)
2393                    .unwrap_or_else(CsrfToken::generate);
2394                req.insert_extension(token);
2395                ControlFlow::Continue
2396            } else {
2397                // State-changing methods: validate token
2398                match self.validate_token(req) {
2399                    Ok(Some(token)) => {
2400                        req.insert_extension(token);
2401                        ControlFlow::Continue
2402                    }
2403                    Ok(None) => ControlFlow::Continue,
2404                    Err(response) => ControlFlow::Break(response),
2405                }
2406            }
2407        })
2408    }
2409
2410    fn after<'a>(
2411        &'a self,
2412        _ctx: &'a RequestContext,
2413        req: &'a Request,
2414        response: Response,
2415    ) -> BoxFuture<'a, Response> {
2416        let config = self.config.clone();
2417        let is_safe = Self::is_safe_method(req.method());
2418        let existing_cookie_token = self.get_cookie_token(req);
2419        let token = req.get_extension::<CsrfToken>().cloned();
2420
2421        Box::pin(async move {
2422            // Set cookie for safe methods if:
2423            // 1. No cookie exists yet, or
2424            // 2. Token rotation is enabled
2425            if is_safe {
2426                let should_set_cookie = existing_cookie_token.is_none() || config.rotate_token;
2427
2428                if should_set_cookie {
2429                    if let Some(token) = token {
2430                        let cookie_value = Self::make_set_cookie_header_value(
2431                            &config.cookie_name,
2432                            token.as_str(),
2433                            config.production,
2434                        );
2435                        return response.header("set-cookie", cookie_value);
2436                    }
2437                }
2438            }
2439            response
2440        })
2441    }
2442
2443    fn name(&self) -> &'static str {
2444        "CSRF"
2445    }
2446}
2447
2448// ============================================================================
2449// Compression Middleware (requires "compression" feature)
2450// ============================================================================
2451
/// Configuration for response compression.
///
/// Controls when and how responses are compressed using gzip.
///
/// # Example
///
/// ```ignore
/// use fastapi_core::middleware::{CompressionMiddleware, CompressionConfig};
///
/// // Use defaults (min size 1024, level 6)
/// let mw = CompressionMiddleware::new();
///
/// // Custom configuration
/// let config = CompressionConfig::new()
///     .min_size(512)
///     .level(9);  // Maximum compression
/// let mw = CompressionMiddleware::with_config(config);
/// ```
#[cfg(feature = "compression")]
#[derive(Debug, Clone)]
pub struct CompressionConfig {
    /// Minimum response size in bytes to compress.
    /// Responses smaller than this are not compressed.
    /// Default: 1024 bytes (1 KB)
    pub min_size: usize,
    /// Compression level (1-9).
    /// 1 = fastest, 9 = best compression, 6 = balanced (default)
    ///
    /// Assigning this field directly bypasses the clamping done by
    /// [`CompressionConfig::level`].
    pub level: u32,
    /// Content types that are already compressed and should be skipped.
    /// Default includes common compressed formats.
    ///
    /// Entries ending in `/` (e.g. `"video/"`) match every subtype. Other entries must
    /// equal the media type. Matching ignores ASCII case, surrounding whitespace, and
    /// any `;` parameters.
    pub skip_content_types: Vec<&'static str>,
}

#[cfg(feature = "compression")]
impl Default for CompressionConfig {
    fn default() -> Self {
        Self {
            min_size: 1024,
            level: 6,
            skip_content_types: vec![
                // Images (already compressed)
                "image/jpeg",
                "image/png",
                "image/gif",
                "image/webp",
                "image/avif",
                // Video/Audio (already compressed)
                "video/",
                "audio/",
                // Archives (already compressed)
                "application/zip",
                "application/gzip",
                "application/x-gzip",
                "application/x-bzip2",
                "application/x-xz",
                "application/x-7z-compressed",
                "application/x-rar-compressed",
                // Other compressed formats
                "application/pdf",
                "application/woff",
                "application/woff2",
                "font/woff",
                "font/woff2",
            ],
        }
    }
}

#[cfg(feature = "compression")]
impl CompressionConfig {
    /// Creates a new configuration with default values.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the minimum response size to compress.
    ///
    /// Responses smaller than this threshold will not be compressed,
    /// as compression overhead may exceed the savings.
    #[must_use]
    pub fn min_size(mut self, size: usize) -> Self {
        self.min_size = size;
        self
    }

    /// Sets the compression level (1-9).
    ///
    /// - 1: Fastest compression, lowest ratio
    /// - 6: Balanced (default)
    /// - 9: Best compression ratio, slowest
    ///
    /// Values outside 1-9 are clamped.
    #[must_use]
    pub fn level(mut self, level: u32) -> Self {
        self.level = level.clamp(1, 9);
        self
    }

    /// Adds a content type to skip during compression.
    ///
    /// Content types can be exact media types or prefixes ending in `/`
    /// (e.g. "video/" matches all video types). Matching is ASCII case-insensitive.
    #[must_use]
    pub fn skip_content_type(mut self, content_type: &'static str) -> Self {
        self.skip_content_types.push(content_type);
        self
    }

    /// Checks if the given `Content-Type` value should be skipped.
    ///
    /// Only the media type (the text before the first `;`, trimmed) is compared. This
    /// means `image/png; charset=x` and `image/png ; q=1` both match `image/png`.
    fn should_skip_content_type(&self, content_type: &str) -> bool {
        let media_type = content_type.split(';').next().unwrap_or("").trim();
        let media_bytes = media_type.as_bytes();

        self.skip_content_types.iter().any(|skip| {
            if skip.ends_with('/') {
                // Prefix match (e.g., "video/" matches "video/mp4"). Comparing bytes
                // avoids slicing `str` on a non-char boundary.
                matches!(
                    media_bytes.get(..skip.len()),
                    Some(prefix) if prefix.eq_ignore_ascii_case(skip.as_bytes())
                )
            } else {
                media_type.eq_ignore_ascii_case(skip)
            }
        })
    }
}
2579
/// Middleware that compresses responses using gzip.
///
/// This middleware inspects the `Accept-Encoding` header and compresses
/// eligible responses with gzip. Compression is skipped for:
/// - Responses smaller than `min_size`
/// - Responses with already-compressed content types
/// - Responses that already have a `Content-Encoding` header
/// - Clients whose `Accept-Encoding` does not allow gzip (header absent, or
///   `gzip` / `*` given `q=0`)
///
/// # Example
///
/// ```ignore
/// use fastapi_core::middleware::{CompressionMiddleware, CompressionConfig, MiddlewareStack};
///
/// let mut stack = MiddlewareStack::new();
///
/// // Default configuration
/// stack.push(CompressionMiddleware::new());
///
/// // Or with custom settings
/// let config = CompressionConfig::new()
///     .min_size(256)   // Compress smaller responses
///     .level(9);       // Maximum compression
/// stack.push(CompressionMiddleware::with_config(config));
/// ```
///
/// # Headers
///
/// When compression is applied:
/// - `Content-Encoding: gzip` is added
/// - `Vary: Accept-Encoding` is added (for caching)
/// - Any existing `Content-Length` header is removed, since it described the
///   uncompressed body
#[cfg(feature = "compression")]
#[derive(Debug, Clone)]
pub struct CompressionMiddleware {
    config: CompressionConfig,
}

#[cfg(feature = "compression")]
impl Default for CompressionMiddleware {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(feature = "compression")]
impl CompressionMiddleware {
    /// Creates compression middleware with default configuration.
    #[must_use]
    pub fn new() -> Self {
        Self {
            config: CompressionConfig::default(),
        }
    }

    /// Creates compression middleware with custom configuration.
    #[must_use]
    pub fn with_config(config: CompressionConfig) -> Self {
        Self { config }
    }

    /// Checks if the client accepts gzip encoding.
    ///
    /// Returns `false` when `Accept-Encoding` is absent or not valid UTF-8.
    fn accepts_gzip(req: &Request) -> bool {
        if let Some(accept_encoding) = req.headers().get("accept-encoding") {
            if let Ok(value) = std::str::from_utf8(accept_encoding) {
                return Self::accept_encoding_allows_gzip(value);
            }
        }
        false
    }

    /// Decides whether an `Accept-Encoding` value permits a gzip response.
    ///
    /// Examples: "gzip", "gzip, deflate", "gzip;q=1.0, identity;q=0.5".
    /// An explicit `gzip` entry takes precedence over a `*` wildcard. An entry with
    /// `q=0` marks that coding as not acceptable (RFC 9110 §12.5.3). Unparseable `q`
    /// values are treated as acceptable.
    fn accept_encoding_allows_gzip(value: &str) -> bool {
        let mut gzip: Option<bool> = None;
        let mut wildcard: Option<bool> = None;

        for entry in value.split(',') {
            let mut parts = entry.split(';');
            let coding = parts.next().unwrap_or("").trim();
            let q = parts
                .filter_map(|param| param.split_once('='))
                .find(|(name, _)| name.trim().eq_ignore_ascii_case("q"))
                .and_then(|(_, raw_q)| raw_q.trim().parse::<f32>().ok());
            // `q=0` (in any spelling, e.g. `0.000`) means "not acceptable".
            let acceptable = !matches!(q, Some(weight) if weight <= 0.0);

            let slot = if coding.eq_ignore_ascii_case("gzip") {
                &mut gzip
            } else if coding == "*" {
                &mut wildcard
            } else {
                continue;
            };
            // If a coding is listed more than once, any acceptable listing wins.
            *slot = Some(slot.unwrap_or(false) || acceptable);
        }

        gzip.or(wildcard).unwrap_or(false)
    }

    /// Gets the Content-Type from response headers.
    fn get_content_type(headers: &[(String, Vec<u8>)]) -> Option<String> {
        for (name, value) in headers {
            if name.eq_ignore_ascii_case("content-type") {
                return std::str::from_utf8(value).ok().map(String::from);
            }
        }
        None
    }

    /// Checks if response already has Content-Encoding header.
    fn has_content_encoding(headers: &[(String, Vec<u8>)]) -> bool {
        headers
            .iter()
            .any(|(name, _)| name.eq_ignore_ascii_case("content-encoding"))
    }

    /// Compresses data using gzip.
    fn compress_gzip(data: &[u8], level: u32) -> Result<Vec<u8>, std::io::Error> {
        use flate2::Compression;
        use flate2::write::GzEncoder;
        use std::io::Write;

        let mut encoder = GzEncoder::new(Vec::new(), Compression::new(level));
        encoder.write_all(data)?;
        encoder.finish()
    }
}
2690
#[cfg(feature = "compression")]
impl Middleware for CompressionMiddleware {
    /// Gzips eligible in-memory response bodies.
    ///
    /// A response is passed through unchanged when any of these hold:
    /// - the client does not accept gzip
    /// - the response already has a `Content-Encoding`
    /// - the body is not a `Bytes` body
    /// - the body is below `min_size`
    /// - the content type is on the skip list
    /// - compression fails or does not shrink the body
    fn after<'a>(
        &'a self,
        _ctx: &'a RequestContext,
        req: &'a Request,
        response: Response,
    ) -> BoxFuture<'a, Response> {
        let config = self.config.clone();

        Box::pin(async move {
            if !Self::accepts_gzip(req) {
                return response;
            }

            let (status, headers, body) = response.into_parts();

            // Only `Bytes` bodies without an existing Content-Encoding are candidates;
            // Empty/Stream bodies and pre-encoded responses go back untouched.
            let raw = match body {
                crate::response::ResponseBody::Bytes(bytes)
                    if !Self::has_content_encoding(&headers) =>
                {
                    bytes
                }
                untouched => {
                    return Response::with_status(status)
                        .body(untouched)
                        .rebuild_with_headers(headers);
                }
            };

            let too_small = raw.len() < config.min_size;
            let skipped_type = !too_small
                && matches!(
                    Self::get_content_type(&headers),
                    Some(ref content_type) if config.should_skip_content_type(content_type)
                );

            // A compression error, or output that is not strictly smaller, falls back
            // to the original body.
            let packed = if too_small || skipped_type {
                None
            } else {
                Self::compress_gzip(&raw, config.level)
                    .ok()
                    .filter(|gz| gz.len() < raw.len())
            };

            match packed {
                None => Response::with_status(status)
                    .body(crate::response::ResponseBody::Bytes(raw))
                    .rebuild_with_headers(headers),
                Some(gz) => {
                    let base = Response::with_status(status)
                        .body(crate::response::ResponseBody::Bytes(gz));
                    // Carry over every original header except the now-stale Content-Length.
                    headers
                        .into_iter()
                        .filter(|(name, _)| !name.eq_ignore_ascii_case("content-length"))
                        .fold(base, |resp, (name, value)| resp.header(name, value))
                        .header("Content-Encoding", b"gzip".to_vec())
                        .header("Vary", b"Accept-Encoding".to_vec())
                }
            }
        })
    }

    fn name(&self) -> &'static str {
        "Compression"
    }
}
2785
2786// ---------------------------------------------------------------------------
2787// Rate Limiting Middleware
2788// ---------------------------------------------------------------------------
2789
2790use parking_lot::Mutex;
2791use std::collections::HashMap as StdHashMap;
2792use std::time::Duration;
2793
2794/// Rate limiting algorithm.
2795#[derive(Debug, Clone, Copy, PartialEq, Eq)]
2796pub enum RateLimitAlgorithm {
2797    /// Token bucket: steady refill rate, allows short bursts.
2798    TokenBucket,
2799    /// Fixed window: resets at the start of each interval.
2800    FixedWindow,
2801    /// Sliding window: weighted combination of current and previous window.
2802    SlidingWindow,
2803}
2804
2805/// Result of a rate limit check.
2806#[derive(Debug, Clone)]
2807pub struct RateLimitResult {
2808    /// Whether the request is allowed.
2809    pub allowed: bool,
2810    /// Maximum requests per window.
2811    pub limit: u64,
2812    /// Remaining requests in the current window.
2813    pub remaining: u64,
2814    /// Seconds until the window resets.
2815    pub reset_after_secs: u64,
2816}
2817
2818/// Extracts a rate limit key from a request.
2819///
2820/// Different extractors allow rate limiting by different criteria:
2821/// IP address, API key header, path, or custom logic.
2822pub trait KeyExtractor: Send + Sync {
2823    /// Extract the key string from the request.
2824    ///
2825    /// Returns `None` if no key can be extracted (request is not rate-limited).
2826    fn extract_key(&self, req: &Request) -> Option<String>;
2827}
2828
2829/// The remote address (peer IP) of the TCP connection.
2830///
2831/// This should be set by the HTTP server layer as a request extension to enable
2832/// secure IP-based rate limiting. Unlike `X-Forwarded-For` headers, this value
2833/// cannot be spoofed by clients.
2834///
2835/// # Example
2836///
2837/// ```ignore
2838/// // In your HTTP server code:
2839/// use fastapi_core::middleware::RemoteAddr;
2840/// use std::net::IpAddr;
2841///
2842/// // When accepting a connection:
2843/// let peer_addr: IpAddr = socket.peer_addr()?.ip();
2844/// request.insert_extension(RemoteAddr(peer_addr));
2845/// ```
2846#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
2847pub struct RemoteAddr(pub std::net::IpAddr);
2848
2849impl std::fmt::Display for RemoteAddr {
2850    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2851        write!(f, "{}", self.0)
2852    }
2853}
2854
2855/// Rate limit by the actual TCP connection IP address.
2856///
2857/// This is the **secure** way to do IP-based rate limiting. It uses the
2858/// `RemoteAddr` extension set by the HTTP server, which represents the actual
2859/// TCP peer address and cannot be spoofed by clients.
2860///
2861/// # Prerequisites
2862///
2863/// Your HTTP server must set the `RemoteAddr` extension on each request:
2864///
2865/// ```ignore
2866/// request.insert_extension(RemoteAddr(peer_addr.ip()));
2867/// ```
2868///
2869/// If `RemoteAddr` is not set, this extractor returns `None` (request is not rate-limited).
2870///
2871/// # Security
2872///
2873/// This extractor is safe to use without a reverse proxy, as it relies on the
2874/// TCP connection's peer address rather than client-supplied headers.
2875#[derive(Debug, Clone)]
2876pub struct ConnectedIpKeyExtractor;
2877
2878impl KeyExtractor for ConnectedIpKeyExtractor {
2879    fn extract_key(&self, req: &Request) -> Option<String> {
2880        req.get_extension::<RemoteAddr>().map(ToString::to_string)
2881    }
2882}
2883
2884/// Rate limit by client IP address from `X-Forwarded-For` or `X-Real-IP` headers.
2885///
2886/// # Security Warning
2887///
2888/// **This extractor trusts client-supplied headers, which can be spoofed!**
2889///
2890/// Only use this extractor when:
2891/// 1. Your application runs behind a trusted reverse proxy (nginx, Cloudflare, etc.)
2892/// 2. The proxy is configured to set/override these headers
2893/// 3. Clients cannot connect directly to your application
2894///
2895/// For direct client connections, use [`ConnectedIpKeyExtractor`] instead.
2896///
2897/// # How Proxies Work
2898///
2899/// When a request passes through proxies:
2900/// - `X-Forwarded-For: client_ip, proxy1_ip, proxy2_ip`
2901/// - The first IP is typically the original client
2902/// - Each proxy appends its own IP
2903///
2904/// This extractor takes the **first** IP from `X-Forwarded-For`, which is correct
2905/// only if your trusted proxy always sets/overwrites this header.
2906///
2907/// # Fallback Behavior
2908///
2909/// Falls back to `"unknown"` when no IP header is present, which means all such
2910/// requests share the same rate limit bucket. This may not be desirable in
2911/// production - consider using [`TrustedProxyIpKeyExtractor`] for better control.
2912#[derive(Debug, Clone)]
2913pub struct IpKeyExtractor;
2914
2915impl KeyExtractor for IpKeyExtractor {
2916    fn extract_key(&self, req: &Request) -> Option<String> {
2917        // Try X-Forwarded-For first, then X-Real-IP, then fall back
2918        if let Some(forwarded) = req.headers().get("x-forwarded-for") {
2919            if let Ok(s) = std::str::from_utf8(forwarded) {
2920                // Take the first IP (client IP) from the chain
2921                if let Some(ip) = s.split(',').next() {
2922                    return Some(ip.trim().to_string());
2923                }
2924            }
2925        }
2926        if let Some(real_ip) = req.headers().get("x-real-ip") {
2927            if let Ok(s) = std::str::from_utf8(real_ip) {
2928                return Some(s.trim().to_string());
2929            }
2930        }
2931        Some("unknown".to_string())
2932    }
2933}
2934
2935/// Rate limit by client IP with trusted proxy validation.
2936///
2937/// This is a **secure** IP extractor that only trusts `X-Forwarded-For` headers
2938/// when the immediate upstream (TCP peer) is a known trusted proxy.
2939///
2940/// # How It Works
2941///
2942/// 1. If `RemoteAddr` extension is set and matches a trusted proxy CIDR:
2943///    - Extract client IP from `X-Forwarded-For` (first IP in chain)
2944/// 2. If `RemoteAddr` is set but NOT a trusted proxy:
2945///    - Use the `RemoteAddr` directly (the client connected directly)
2946/// 3. If `RemoteAddr` is not set:
2947///    - Returns `None` (request is not rate-limited) - safer than guessing
2948///
2949/// # Example
2950///
2951/// ```ignore
2952/// use fastapi_core::middleware::{TrustedProxyIpKeyExtractor, RateLimitMiddleware};
2953///
2954/// let extractor = TrustedProxyIpKeyExtractor::new()
2955///     .trust_cidr("10.0.0.0/8")      // Internal network
2956///     .trust_cidr("172.16.0.0/12")   // Docker default
2957///     .trust_loopback();              // localhost
2958///
2959/// let rate_limiter = RateLimitMiddleware::builder()
2960///     .requests(100)
2961///     .per(Duration::from_secs(60))
2962///     .key_extractor(extractor)
2963///     .build();
2964/// ```
2965#[derive(Debug, Clone)]
2966pub struct TrustedProxyIpKeyExtractor {
2967    /// List of trusted proxy CIDRs (stored as (ip, prefix_len))
2968    trusted_cidrs: Vec<(std::net::IpAddr, u8)>,
2969}
2970
2971impl TrustedProxyIpKeyExtractor {
2972    /// Create a new trusted proxy IP extractor with no trusted proxies.
2973    #[must_use]
2974    pub fn new() -> Self {
2975        Self {
2976            trusted_cidrs: Vec::new(),
2977        }
2978    }
2979
2980    /// Add a trusted CIDR range (e.g., "10.0.0.0/8", "192.168.1.0/24").
2981    ///
2982    /// # Panics
2983    ///
2984    /// Panics if the CIDR string is invalid.
2985    #[must_use]
2986    pub fn trust_cidr(mut self, cidr: &str) -> Self {
2987        let (ip, prefix) = parse_cidr(cidr).expect("invalid CIDR notation");
2988        self.trusted_cidrs.push((ip, prefix));
2989        self
2990    }
2991
2992    /// Trust loopback addresses (127.0.0.0/8 for IPv4, ::1/128 for IPv6).
2993    #[must_use]
2994    pub fn trust_loopback(mut self) -> Self {
2995        self.trusted_cidrs.push((
2996            std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 0)),
2997            8,
2998        ));
2999        self.trusted_cidrs
3000            .push((std::net::IpAddr::V6(std::net::Ipv6Addr::LOCALHOST), 128));
3001        self
3002    }
3003
3004    /// Check if an IP is within any trusted CIDR range.
3005    fn is_trusted(&self, ip: std::net::IpAddr) -> bool {
3006        self.trusted_cidrs
3007            .iter()
3008            .any(|(cidr_ip, prefix)| ip_in_cidr(ip, *cidr_ip, *prefix))
3009    }
3010
3011    /// Extract client IP from X-Forwarded-For header.
3012    fn extract_from_header(&self, req: &Request) -> Option<String> {
3013        if let Some(forwarded) = req.headers().get("x-forwarded-for") {
3014            if let Ok(s) = std::str::from_utf8(forwarded) {
3015                if let Some(ip) = s.split(',').next() {
3016                    return Some(ip.trim().to_string());
3017                }
3018            }
3019        }
3020        if let Some(real_ip) = req.headers().get("x-real-ip") {
3021            if let Ok(s) = std::str::from_utf8(real_ip) {
3022                return Some(s.trim().to_string());
3023            }
3024        }
3025        None
3026    }
3027}
3028
3029impl Default for TrustedProxyIpKeyExtractor {
3030    fn default() -> Self {
3031        Self::new()
3032    }
3033}
3034
3035impl KeyExtractor for TrustedProxyIpKeyExtractor {
3036    fn extract_key(&self, req: &Request) -> Option<String> {
3037        let remote = req.get_extension::<RemoteAddr>()?;
3038
3039        if self.is_trusted(remote.0) {
3040            // Request came from trusted proxy - use header value
3041            self.extract_from_header(req)
3042                .or_else(|| Some(remote.to_string()))
3043        } else {
3044            // Request came directly from client - use connection IP
3045            Some(remote.to_string())
3046        }
3047    }
3048}
3049
3050/// Parse a CIDR string like "192.168.1.0/24" into (ip, prefix_length).
3051fn parse_cidr(cidr: &str) -> Option<(std::net::IpAddr, u8)> {
3052    let (ip_str, prefix_str) = cidr.split_once('/')?;
3053    let ip: std::net::IpAddr = ip_str.parse().ok()?;
3054    let prefix: u8 = prefix_str.parse().ok()?;
3055
3056    // Validate prefix length
3057    let max_prefix = match ip {
3058        std::net::IpAddr::V4(_) => 32,
3059        std::net::IpAddr::V6(_) => 128,
3060    };
3061    if prefix > max_prefix {
3062        return None;
3063    }
3064
3065    Some((ip, prefix))
3066}
3067
3068/// Check if an IP address is within a CIDR range.
3069fn ip_in_cidr(ip: std::net::IpAddr, cidr_ip: std::net::IpAddr, prefix: u8) -> bool {
3070    match (ip, cidr_ip) {
3071        (std::net::IpAddr::V4(ip), std::net::IpAddr::V4(cidr)) => {
3072            if prefix == 0 {
3073                return true;
3074            }
3075            let ip_bits = u32::from(ip);
3076            let cidr_bits = u32::from(cidr);
3077            let mask = !0u32 << (32 - prefix);
3078            (ip_bits & mask) == (cidr_bits & mask)
3079        }
3080        (std::net::IpAddr::V6(ip), std::net::IpAddr::V6(cidr)) => {
3081            if prefix == 0 {
3082                return true;
3083            }
3084            let ip_bits = u128::from(ip);
3085            let cidr_bits = u128::from(cidr);
3086            let mask = !0u128 << (128 - prefix);
3087            (ip_bits & mask) == (cidr_bits & mask)
3088        }
3089        _ => false, // IPv4 vs IPv6 mismatch
3090    }
3091}
3092
3093/// Rate limit by a specific header value (e.g., `X-API-Key`).
3094#[derive(Debug, Clone)]
3095pub struct HeaderKeyExtractor {
3096    header_name: String,
3097}
3098
3099impl HeaderKeyExtractor {
3100    /// Create a new header key extractor.
3101    #[must_use]
3102    pub fn new(header_name: impl Into<String>) -> Self {
3103        Self {
3104            header_name: header_name.into(),
3105        }
3106    }
3107}
3108
3109impl KeyExtractor for HeaderKeyExtractor {
3110    fn extract_key(&self, req: &Request) -> Option<String> {
3111        req.headers()
3112            .get(&self.header_name)
3113            .and_then(|v| std::str::from_utf8(v).ok())
3114            .map(str::to_string)
3115    }
3116}
3117
3118/// Rate limit by request path.
3119#[derive(Debug, Clone)]
3120pub struct PathKeyExtractor;
3121
3122impl KeyExtractor for PathKeyExtractor {
3123    fn extract_key(&self, req: &Request) -> Option<String> {
3124        Some(req.path().to_string())
3125    }
3126}
3127
3128/// A composite key extractor that combines multiple extractors.
3129///
3130/// Keys from all extractors are joined with `:` to form a composite key.
3131/// If any extractor returns `None`, that part is omitted.
3132pub struct CompositeKeyExtractor {
3133    extractors: Vec<Box<dyn KeyExtractor>>,
3134}
3135
3136impl CompositeKeyExtractor {
3137    /// Create a composite key extractor from multiple extractors.
3138    #[must_use]
3139    pub fn new(extractors: Vec<Box<dyn KeyExtractor>>) -> Self {
3140        Self { extractors }
3141    }
3142}
3143
3144impl KeyExtractor for CompositeKeyExtractor {
3145    fn extract_key(&self, req: &Request) -> Option<String> {
3146        let parts: Vec<String> = self
3147            .extractors
3148            .iter()
3149            .filter_map(|e| e.extract_key(req))
3150            .collect();
3151        if parts.is_empty() {
3152            None
3153        } else {
3154            Some(parts.join(":"))
3155        }
3156    }
3157}
3158
3159/// Token bucket state for a single key.
3160#[derive(Debug, Clone)]
3161struct TokenBucketState {
3162    tokens: f64,
3163    last_refill: Instant,
3164}
3165
3166/// Fixed window state for a single key.
3167#[derive(Debug, Clone)]
3168struct FixedWindowState {
3169    count: u64,
3170    window_start: Instant,
3171}
3172
3173/// Sliding window state for a single key.
3174#[derive(Debug, Clone)]
3175struct SlidingWindowState {
3176    current_count: u64,
3177    previous_count: u64,
3178    current_window_start: Instant,
3179}
3180
3181/// In-memory rate limit store.
3182///
3183/// Uses a `HashMap` protected by a `Mutex` for thread-safe access.
3184/// Suitable for single-process deployments. For distributed systems,
3185/// implement a custom store using Redis or similar.
3186pub struct InMemoryRateLimitStore {
3187    token_buckets: Mutex<StdHashMap<String, TokenBucketState>>,
3188    fixed_windows: Mutex<StdHashMap<String, FixedWindowState>>,
3189    sliding_windows: Mutex<StdHashMap<String, SlidingWindowState>>,
3190}
3191
3192impl InMemoryRateLimitStore {
3193    /// Create a new in-memory store.
3194    #[must_use]
3195    pub fn new() -> Self {
3196        Self {
3197            token_buckets: Mutex::new(StdHashMap::new()),
3198            fixed_windows: Mutex::new(StdHashMap::new()),
3199            sliding_windows: Mutex::new(StdHashMap::new()),
3200        }
3201    }
3202
3203    #[allow(clippy::cast_precision_loss, clippy::cast_sign_loss)]
3204    fn check_token_bucket(
3205        &self,
3206        key: &str,
3207        max_tokens: u64,
3208        refill_rate: f64,
3209        window: Duration,
3210    ) -> RateLimitResult {
3211        let mut buckets = self.token_buckets.lock();
3212        let now = Instant::now();
3213
3214        let state = buckets
3215            .entry(key.to_string())
3216            .or_insert_with(|| TokenBucketState {
3217                tokens: max_tokens as f64,
3218                last_refill: now,
3219            });
3220
3221        // Refill tokens based on elapsed time
3222        let elapsed = now.duration_since(state.last_refill);
3223        let refill = elapsed.as_secs_f64() * refill_rate;
3224        state.tokens = (state.tokens + refill).min(max_tokens as f64);
3225        state.last_refill = now;
3226
3227        if state.tokens >= 1.0 {
3228            state.tokens -= 1.0;
3229            RateLimitResult {
3230                allowed: true,
3231                limit: max_tokens,
3232                remaining: state.tokens as u64,
3233                reset_after_secs: if state.tokens < max_tokens as f64 {
3234                    ((max_tokens as f64 - state.tokens) / refill_rate).ceil() as u64
3235                } else {
3236                    window.as_secs()
3237                },
3238            }
3239        } else {
3240            let wait_secs = ((1.0 - state.tokens) / refill_rate).ceil() as u64;
3241            RateLimitResult {
3242                allowed: false,
3243                limit: max_tokens,
3244                remaining: 0,
3245                reset_after_secs: wait_secs,
3246            }
3247        }
3248    }
3249
3250    fn check_fixed_window(
3251        &self,
3252        key: &str,
3253        max_requests: u64,
3254        window: Duration,
3255    ) -> RateLimitResult {
3256        let mut windows = self.fixed_windows.lock();
3257        let now = Instant::now();
3258
3259        let state = windows
3260            .entry(key.to_string())
3261            .or_insert_with(|| FixedWindowState {
3262                count: 0,
3263                window_start: now,
3264            });
3265
3266        // Check if window has expired
3267        let elapsed = now.duration_since(state.window_start);
3268        if elapsed >= window {
3269            state.count = 0;
3270            state.window_start = now;
3271        }
3272
3273        let remaining_time = window
3274            .checked_sub(now.duration_since(state.window_start))
3275            .unwrap_or(Duration::ZERO);
3276
3277        if state.count < max_requests {
3278            state.count += 1;
3279            RateLimitResult {
3280                allowed: true,
3281                limit: max_requests,
3282                remaining: max_requests - state.count,
3283                reset_after_secs: remaining_time.as_secs(),
3284            }
3285        } else {
3286            RateLimitResult {
3287                allowed: false,
3288                limit: max_requests,
3289                remaining: 0,
3290                reset_after_secs: remaining_time.as_secs(),
3291            }
3292        }
3293    }
3294
3295    #[allow(clippy::cast_precision_loss, clippy::cast_sign_loss)]
3296    fn check_sliding_window(
3297        &self,
3298        key: &str,
3299        max_requests: u64,
3300        window: Duration,
3301    ) -> RateLimitResult {
3302        let mut windows = self.sliding_windows.lock();
3303        let now = Instant::now();
3304
3305        let state = windows
3306            .entry(key.to_string())
3307            .or_insert_with(|| SlidingWindowState {
3308                current_count: 0,
3309                previous_count: 0,
3310                current_window_start: now,
3311            });
3312
3313        // Check if we need to rotate windows
3314        let elapsed = now.duration_since(state.current_window_start);
3315        if elapsed >= window {
3316            // Rotate: current becomes previous
3317            state.previous_count = state.current_count;
3318            state.current_count = 0;
3319            state.current_window_start = now;
3320        }
3321
3322        // Calculate weighted count using the proportion of the previous window
3323        // that overlaps with the current sliding window
3324        let window_elapsed = now.duration_since(state.current_window_start);
3325        let window_fraction = window_elapsed.as_secs_f64() / window.as_secs_f64();
3326        let previous_weight = 1.0 - window_fraction;
3327        let weighted_count =
3328            (state.previous_count as f64 * previous_weight) + state.current_count as f64;
3329
3330        let remaining_time = window.checked_sub(window_elapsed).unwrap_or(Duration::ZERO);
3331
3332        if weighted_count < max_requests as f64 {
3333            state.current_count += 1;
3334            let new_weighted =
3335                (state.previous_count as f64 * previous_weight) + state.current_count as f64;
3336            let remaining = (max_requests as f64 - new_weighted).max(0.0) as u64;
3337            RateLimitResult {
3338                allowed: true,
3339                limit: max_requests,
3340                remaining,
3341                reset_after_secs: remaining_time.as_secs(),
3342            }
3343        } else {
3344            RateLimitResult {
3345                allowed: false,
3346                limit: max_requests,
3347                remaining: 0,
3348                reset_after_secs: remaining_time.as_secs(),
3349            }
3350        }
3351    }
3352
3353    /// Check and consume a request against the rate limit.
3354    #[allow(clippy::cast_precision_loss)]
3355    pub fn check(
3356        &self,
3357        key: &str,
3358        algorithm: RateLimitAlgorithm,
3359        max_requests: u64,
3360        window: Duration,
3361    ) -> RateLimitResult {
3362        match algorithm {
3363            RateLimitAlgorithm::TokenBucket => {
3364                let refill_rate = max_requests as f64 / window.as_secs_f64();
3365                self.check_token_bucket(key, max_requests, refill_rate, window)
3366            }
3367            RateLimitAlgorithm::FixedWindow => self.check_fixed_window(key, max_requests, window),
3368            RateLimitAlgorithm::SlidingWindow => {
3369                self.check_sliding_window(key, max_requests, window)
3370            }
3371        }
3372    }
3373}
3374
3375impl Default for InMemoryRateLimitStore {
3376    fn default() -> Self {
3377        Self::new()
3378    }
3379}
3380
3381/// Configuration for the rate limiting middleware.
3382///
3383/// Controls request rate limits using token bucket or sliding window algorithms.
3384/// When the limit is exceeded, a 429 Too Many Requests response is returned.
3385///
3386/// # Defaults
3387///
3388/// | Setting | Default |
3389/// |---------|---------|
3390/// | `max_requests` | 100 |
3391/// | `window` | 60s |
3392/// | `algorithm` | `TokenBucket` |
3393/// | `include_headers` | `true` |
3394/// | `retry_message` | "Rate limit exceeded. Please retry later." |
3395///
3396/// # Response Headers (when `include_headers` is `true`)
3397///
3398/// - `X-RateLimit-Limit`: Maximum requests per window
3399/// - `X-RateLimit-Remaining`: Remaining requests in current window
3400/// - `X-RateLimit-Reset`: Seconds until window resets
3401/// - `Retry-After`: Seconds to wait (only on 429 responses)
3402///
3403/// # Example
3404///
3405/// ```ignore
3406/// use fastapi_core::middleware::{RateLimitBuilder, RateLimitAlgorithm};
3407///
3408/// let rate_limit = RateLimitBuilder::new()
3409///     .max_requests(1000)
3410///     .window_secs(3600) // 1000 req/hour
3411///     .algorithm(RateLimitAlgorithm::SlidingWindow)
3412///     .build();
3413/// ```
3414#[derive(Clone)]
3415pub struct RateLimitConfig {
3416    /// Maximum number of requests allowed per window.
3417    pub max_requests: u64,
3418    /// Time window for the rate limit.
3419    pub window: Duration,
3420    /// The algorithm to use.
3421    pub algorithm: RateLimitAlgorithm,
3422    /// Whether to include rate limit headers in responses.
3423    pub include_headers: bool,
3424    /// Custom message for 429 responses.
3425    pub retry_message: String,
3426}
3427
3428impl Default for RateLimitConfig {
3429    fn default() -> Self {
3430        Self {
3431            max_requests: 100,
3432            window: Duration::from_secs(60),
3433            algorithm: RateLimitAlgorithm::TokenBucket,
3434            include_headers: true,
3435            retry_message: "Rate limit exceeded. Please retry later.".to_string(),
3436        }
3437    }
3438}
3439
3440/// Builder for `RateLimitConfig`.
3441pub struct RateLimitBuilder {
3442    config: RateLimitConfig,
3443    key_extractor: Option<Box<dyn KeyExtractor>>,
3444}
3445
3446impl RateLimitBuilder {
3447    /// Create a new rate limit builder with default configuration.
3448    #[must_use]
3449    pub fn new() -> Self {
3450        Self {
3451            config: RateLimitConfig::default(),
3452            key_extractor: None,
3453        }
3454    }
3455
3456    /// Set the maximum number of requests per window.
3457    #[must_use]
3458    pub fn requests(mut self, max: u64) -> Self {
3459        self.config.max_requests = max;
3460        self
3461    }
3462
3463    /// Set the time window.
3464    #[must_use]
3465    pub fn per(mut self, window: Duration) -> Self {
3466        self.config.window = window;
3467        self
3468    }
3469
3470    /// Shorthand: set the window to the given number of seconds.
3471    #[must_use]
3472    pub fn per_second(self, secs: u64) -> Self {
3473        self.per(Duration::from_secs(secs))
3474    }
3475
3476    /// Shorthand: set the window to the given number of minutes.
3477    #[must_use]
3478    pub fn per_minute(self, minutes: u64) -> Self {
3479        self.per(Duration::from_secs(minutes * 60))
3480    }
3481
3482    /// Shorthand: set the window to the given number of hours.
3483    #[must_use]
3484    pub fn per_hour(self, hours: u64) -> Self {
3485        self.per(Duration::from_secs(hours * 3600))
3486    }
3487
3488    /// Set the rate limiting algorithm.
3489    #[must_use]
3490    pub fn algorithm(mut self, algo: RateLimitAlgorithm) -> Self {
3491        self.config.algorithm = algo;
3492        self
3493    }
3494
3495    /// Set the key extractor.
3496    #[must_use]
3497    pub fn key_extractor(mut self, extractor: impl KeyExtractor + 'static) -> Self {
3498        self.key_extractor = Some(Box::new(extractor));
3499        self
3500    }
3501
3502    /// Whether to include rate limit headers in responses.
3503    #[must_use]
3504    pub fn include_headers(mut self, include: bool) -> Self {
3505        self.config.include_headers = include;
3506        self
3507    }
3508
3509    /// Set the custom message for 429 responses.
3510    #[must_use]
3511    pub fn retry_message(mut self, msg: impl Into<String>) -> Self {
3512        self.config.retry_message = msg.into();
3513        self
3514    }
3515
3516    /// Build the rate limiting middleware.
3517    #[must_use]
3518    pub fn build(self) -> RateLimitMiddleware {
3519        let key_extractor = self
3520            .key_extractor
3521            .unwrap_or_else(|| Box::new(IpKeyExtractor));
3522        RateLimitMiddleware {
3523            config: self.config,
3524            store: Arc::new(InMemoryRateLimitStore::new()),
3525            key_extractor: Arc::from(key_extractor),
3526        }
3527    }
3528}
3529
3530impl Default for RateLimitBuilder {
3531    fn default() -> Self {
3532        Self::new()
3533    }
3534}
3535
3536/// Extension type stored on requests to carry rate limit info to `after` hook.
3537#[derive(Debug, Clone)]
3538struct RateLimitInfo {
3539    result: RateLimitResult,
3540}
3541
3542/// Rate limiting middleware.
3543///
3544/// Tracks request rates per key and returns 429 Too Many Requests
3545/// when a client exceeds the configured limit.
3546///
3547/// # Example
3548///
3549/// ```ignore
3550/// use fastapi_core::middleware::{RateLimitMiddleware, RateLimitAlgorithm, IpKeyExtractor};
3551/// use std::time::Duration;
3552///
3553/// let rate_limiter = RateLimitMiddleware::builder()
3554///     .requests(100)
3555///     .per(Duration::from_secs(60))
3556///     .algorithm(RateLimitAlgorithm::TokenBucket)
3557///     .key_extractor(IpKeyExtractor)
3558///     .build();
3559///
3560/// let app = App::builder()
3561///     .middleware(rate_limiter)
3562///     .build();
3563/// ```
3564pub struct RateLimitMiddleware {
3565    config: RateLimitConfig,
3566    store: Arc<InMemoryRateLimitStore>,
3567    key_extractor: Arc<dyn KeyExtractor>,
3568}
3569
3570impl RateLimitMiddleware {
3571    /// Create a new rate limiter with default settings (100 requests/minute, token bucket, IP-based).
3572    #[must_use]
3573    pub fn new() -> Self {
3574        Self::builder().build()
3575    }
3576
3577    /// Create a builder for configuring the rate limiter.
3578    #[must_use]
3579    pub fn builder() -> RateLimitBuilder {
3580        RateLimitBuilder::new()
3581    }
3582
3583    /// Format a 429 response body as JSON.
3584    fn too_many_requests_body(&self, result: &RateLimitResult) -> Vec<u8> {
3585        format!(
3586            r#"{{"detail":"{}","retry_after_secs":{}}}"#,
3587            self.config.retry_message, result.reset_after_secs
3588        )
3589        .into_bytes()
3590    }
3591
3592    /// Add rate limit headers to a response.
3593    fn add_headers(&self, response: Response, result: &RateLimitResult) -> Response {
3594        response
3595            .header("X-RateLimit-Limit", result.limit.to_string().into_bytes())
3596            .header(
3597                "X-RateLimit-Remaining",
3598                result.remaining.to_string().into_bytes(),
3599            )
3600            .header(
3601                "X-RateLimit-Reset",
3602                result.reset_after_secs.to_string().into_bytes(),
3603            )
3604    }
3605}
3606
3607impl Default for RateLimitMiddleware {
3608    fn default() -> Self {
3609        Self::new()
3610    }
3611}
3612
3613impl Middleware for RateLimitMiddleware {
3614    fn before<'a>(
3615        &'a self,
3616        _ctx: &'a RequestContext,
3617        req: &'a mut Request,
3618    ) -> BoxFuture<'a, ControlFlow> {
3619        Box::pin(async move {
3620            // Extract the key for this request
3621            let Some(key) = self.key_extractor.extract_key(req) else {
3622                // No key extracted — skip rate limiting for this request
3623                return ControlFlow::Continue;
3624            };
3625
3626            // Check the rate limit
3627            let result = self.store.check(
3628                &key,
3629                self.config.algorithm,
3630                self.config.max_requests,
3631                self.config.window,
3632            );
3633
3634            if result.allowed {
3635                // Store the result for the `after` hook to add headers
3636                req.insert_extension(RateLimitInfo { result });
3637                ControlFlow::Continue
3638            } else {
3639                // Return 429 Too Many Requests
3640                let body = self.too_many_requests_body(&result);
3641                let mut response =
3642                    Response::with_status(crate::response::StatusCode::TOO_MANY_REQUESTS)
3643                        .header("Content-Type", b"application/json".to_vec())
3644                        .header(
3645                            "Retry-After",
3646                            result.reset_after_secs.to_string().into_bytes(),
3647                        )
3648                        .body(crate::response::ResponseBody::Bytes(body));
3649
3650                if self.config.include_headers {
3651                    response = self.add_headers(response, &result);
3652                }
3653
3654                ControlFlow::Break(response)
3655            }
3656        })
3657    }
3658
3659    fn after<'a>(
3660        &'a self,
3661        _ctx: &'a RequestContext,
3662        req: &'a Request,
3663        response: Response,
3664    ) -> BoxFuture<'a, Response> {
3665        Box::pin(async move {
3666            if !self.config.include_headers {
3667                return response;
3668            }
3669
3670            // Retrieve the rate limit info stored in `before`
3671            if let Some(info) = req.get_extension::<RateLimitInfo>() {
3672                self.add_headers(response, &info.result)
3673            } else {
3674                response
3675            }
3676        })
3677    }
3678
3679    fn name(&self) -> &'static str {
3680        "RateLimit"
3681    }
3682}
3683
3684// ---------------------------------------------------------------------------
3685// End Rate Limiting Middleware
3686// ---------------------------------------------------------------------------
3687
3688// ============================================================================
3689// Request Inspection Middleware (Development)
3690// ============================================================================
3691
/// How much detail the request inspection middleware prints per exchange.
///
/// Each level includes everything printed by the level before it.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum InspectionVerbosity {
    /// One summary line for the request and one for the response,
    /// e.g. `-->  GET /path` and `<--  200 OK (12ms)`.
    Minimal,

    /// The summary lines followed by the header lists (sensitive values
    /// redacted).
    Normal,

    /// Headers plus a preview of the request and response bodies; JSON
    /// bodies are pretty-printed when they can be formatted.
    Verbose,
}
3713
3714/// Development middleware that logs detailed, human-readable request/response
3715/// information using arrow-style formatting.
3716///
3717/// This middleware is designed for development and debugging. It outputs
3718/// concise inspection lines showing request flow:
3719///
3720/// ```text
3721/// -->  POST /api/users
3722///      Content-Type: application/json
3723///      Content-Length: 42
3724///      {"name": "Alice"}
3725/// <--  201 Created (12ms)
3726///      Content-Type: application/json
3727///      {"id": 1, "name": "Alice"}
3728/// ```
3729///
3730/// # Features
3731///
3732/// - **Configurable verbosity**: Minimal (one-liner), Normal (+ headers),
3733///   Verbose (+ body preview with JSON pretty-printing)
3734/// - **Slow request highlighting**: Marks requests exceeding a threshold
3735/// - **Sensitive header filtering**: Redacts authorization, cookie, etc.
3736/// - **JSON pretty-printing**: Detects JSON bodies and formats them
3737/// - **Body size limits**: Truncates large bodies to a configurable max
3738///
3739/// # Example
3740///
3741/// ```ignore
3742/// use fastapi_core::middleware::RequestInspectionMiddleware;
3743///
3744/// let inspector = RequestInspectionMiddleware::new()
3745///     .verbosity(InspectionVerbosity::Verbose)
3746///     .slow_threshold_ms(500)
3747///     .max_body_preview(4096);
3748///
3749/// let mut stack = MiddlewareStack::new();
3750/// stack.push(inspector);
3751/// ```
3752pub struct RequestInspectionMiddleware {
3753    log_config: LogConfig,
3754    verbosity: InspectionVerbosity,
3755    redact_headers: HashSet<String>,
3756    slow_threshold_ms: u64,
3757    max_body_preview: usize,
3758}
3759
3760impl Default for RequestInspectionMiddleware {
3761    fn default() -> Self {
3762        Self {
3763            log_config: LogConfig::development(),
3764            verbosity: InspectionVerbosity::Normal,
3765            redact_headers: default_redacted_headers(),
3766            slow_threshold_ms: 1000,
3767            max_body_preview: 2048,
3768        }
3769    }
3770}
3771
3772impl RequestInspectionMiddleware {
3773    /// Create a new inspection middleware with development defaults.
3774    #[must_use]
3775    pub fn new() -> Self {
3776        Self::default()
3777    }
3778
3779    /// Set the logging configuration.
3780    #[must_use]
3781    pub fn log_config(mut self, config: LogConfig) -> Self {
3782        self.log_config = config;
3783        self
3784    }
3785
3786    /// Set the verbosity level.
3787    #[must_use]
3788    pub fn verbosity(mut self, level: InspectionVerbosity) -> Self {
3789        self.verbosity = level;
3790        self
3791    }
3792
3793    /// Set the threshold (in milliseconds) above which requests are flagged as slow.
3794    #[must_use]
3795    pub fn slow_threshold_ms(mut self, ms: u64) -> Self {
3796        self.slow_threshold_ms = ms;
3797        self
3798    }
3799
3800    /// Set the maximum number of bytes to show in body previews.
3801    #[must_use]
3802    pub fn max_body_preview(mut self, max: usize) -> Self {
3803        self.max_body_preview = max;
3804        self
3805    }
3806
3807    /// Add a header name to the redaction set (case-insensitive).
3808    #[must_use]
3809    pub fn redact_header(mut self, name: impl Into<String>) -> Self {
3810        self.redact_headers.insert(name.into().to_ascii_lowercase());
3811        self
3812    }
3813
3814    /// Format a request body for display, with optional JSON pretty-printing.
3815    fn format_body_preview(&self, bytes: &[u8], content_type: Option<&[u8]>) -> Option<String> {
3816        if bytes.is_empty() || self.max_body_preview == 0 {
3817            return None;
3818        }
3819
3820        let is_json = content_type
3821            .and_then(|ct| std::str::from_utf8(ct).ok())
3822            .is_some_and(|ct| ct.contains("application/json"));
3823
3824        let limit = self.max_body_preview.min(bytes.len());
3825        let truncated = bytes.len() > self.max_body_preview;
3826
3827        match std::str::from_utf8(&bytes[..limit]) {
3828            Ok(text) => {
3829                if is_json {
3830                    // Attempt JSON pretty-printing on the full available text
3831                    if let Some(pretty) = try_pretty_json(text) {
3832                        let mut output = pretty;
3833                        if truncated {
3834                            output.push_str("\n     ... (truncated)");
3835                        }
3836                        return Some(output);
3837                    }
3838                }
3839                let mut output = text.to_string();
3840                if truncated {
3841                    output.push_str("...");
3842                }
3843                Some(output)
3844            }
3845            Err(_) => Some(format!("<{} bytes binary>", bytes.len())),
3846        }
3847    }
3848
3849    /// Format a response body for display.
3850    fn format_response_preview(
3851        &self,
3852        body: &crate::response::ResponseBody,
3853        content_type: Option<&[u8]>,
3854    ) -> Option<String> {
3855        match body {
3856            crate::response::ResponseBody::Empty => None,
3857            crate::response::ResponseBody::Bytes(bytes) => {
3858                self.format_body_preview(bytes, content_type)
3859            }
3860            crate::response::ResponseBody::Stream(_) => Some("<streaming body>".to_string()),
3861        }
3862    }
3863
3864    /// Build the formatted header block for display.
3865    fn format_inspection_headers<'a>(
3866        &self,
3867        headers: impl Iterator<Item = (&'a str, &'a [u8])>,
3868    ) -> String {
3869        let mut out = String::new();
3870        for (name, value) in headers {
3871            out.push_str("\n     ");
3872            out.push_str(name);
3873            out.push_str(": ");
3874
3875            let lowered = name.to_ascii_lowercase();
3876            if self.redact_headers.contains(&lowered) {
3877                out.push_str("[REDACTED]");
3878            } else {
3879                match std::str::from_utf8(value) {
3880                    Ok(text) => out.push_str(text),
3881                    Err(_) => out.push_str("<binary>"),
3882                }
3883            }
3884        }
3885        out
3886    }
3887
3888    /// Build the response header block from (String, Vec<u8>) pairs.
3889    fn format_response_inspection_headers(&self, headers: &[(String, Vec<u8>)]) -> String {
3890        self.format_inspection_headers(
3891            headers
3892                .iter()
3893                .map(|(name, value)| (name.as_str(), value.as_slice())),
3894        )
3895    }
3896}
3897
/// Request extension recording when inspection of a request began, so the
/// response hook can report the elapsed time.
#[derive(Clone, Debug)]
struct InspectionStart(Instant);
3901
3902impl Middleware for RequestInspectionMiddleware {
3903    fn before<'a>(
3904        &'a self,
3905        ctx: &'a RequestContext,
3906        req: &'a mut Request,
3907    ) -> BoxFuture<'a, ControlFlow> {
3908        let logger = RequestLogger::new(ctx, self.log_config.clone());
3909        req.insert_extension(InspectionStart(Instant::now()));
3910
3911        let method = req.method();
3912        let path = req.path();
3913        let query = req.query();
3914
3915        // Build the request line: "-->  GET /path?query"
3916        let mut request_line = format!("-->  {method} {path}");
3917        if let Some(q) = query {
3918            request_line.push('?');
3919            request_line.push_str(q);
3920        }
3921
3922        let body_size = body_len(req.body());
3923        if body_size > 0 {
3924            request_line.push_str(&format!(" ({body_size} bytes)"));
3925        }
3926
3927        match self.verbosity {
3928            InspectionVerbosity::Minimal => {
3929                logger.info(request_line);
3930            }
3931            InspectionVerbosity::Normal => {
3932                let headers = self.format_inspection_headers(req.headers().iter());
3933                logger.info(format!("{request_line}{headers}"));
3934            }
3935            InspectionVerbosity::Verbose => {
3936                let headers = self.format_inspection_headers(req.headers().iter());
3937                let content_type = req.headers().get("content-type");
3938                let body_preview = match req.body() {
3939                    Body::Empty => None,
3940                    Body::Bytes(bytes) => self.format_body_preview(bytes, content_type),
3941                    Body::Stream { .. } => None,
3942                };
3943
3944                let mut output = format!("{request_line}{headers}");
3945                if let Some(body) = body_preview {
3946                    output.push_str("\n     ");
3947                    // Indent multi-line body previews
3948                    output.push_str(&body.replace('\n', "\n     "));
3949                }
3950                logger.info(output);
3951            }
3952        }
3953
3954        Box::pin(async { ControlFlow::Continue })
3955    }
3956
3957    fn after<'a>(
3958        &'a self,
3959        ctx: &'a RequestContext,
3960        req: &'a Request,
3961        response: Response,
3962    ) -> BoxFuture<'a, Response> {
3963        let logger = RequestLogger::new(ctx, self.log_config.clone());
3964        let duration = req
3965            .get_extension::<InspectionStart>()
3966            .map(|start| start.0.elapsed())
3967            .unwrap_or_default();
3968
3969        let status = response.status();
3970        let duration_ms = duration.as_millis();
3971
3972        // Build the response line: "<--  200 OK (12ms)"
3973        let mut response_line = format!(
3974            "<--  {} {} ({duration_ms}ms)",
3975            status.as_u16(),
3976            status.canonical_reason(),
3977        );
3978
3979        // Flag slow requests
3980        if duration_ms >= u128::from(self.slow_threshold_ms) {
3981            response_line.push_str(" [SLOW]");
3982        }
3983
3984        match self.verbosity {
3985            InspectionVerbosity::Minimal => {
3986                if duration_ms >= u128::from(self.slow_threshold_ms) {
3987                    logger.warn(response_line);
3988                } else {
3989                    logger.info(response_line);
3990                }
3991            }
3992            InspectionVerbosity::Normal => {
3993                let headers = self.format_response_inspection_headers(response.headers());
3994                let output = format!("{response_line}{headers}");
3995                if duration_ms >= u128::from(self.slow_threshold_ms) {
3996                    logger.warn(output);
3997                } else {
3998                    logger.info(output);
3999                }
4000            }
4001            InspectionVerbosity::Verbose => {
4002                let headers = self.format_response_inspection_headers(response.headers());
4003
4004                // Find content-type from response headers for JSON detection
4005                let resp_content_type: Option<&[u8]> = response
4006                    .headers()
4007                    .iter()
4008                    .find(|(name, _)| name.eq_ignore_ascii_case("content-type"))
4009                    .map(|(_, value)| value.as_slice());
4010
4011                let body_preview =
4012                    self.format_response_preview(response.body_ref(), resp_content_type);
4013
4014                let mut output = format!("{response_line}{headers}");
4015                if let Some(body) = body_preview {
4016                    output.push_str("\n     ");
4017                    output.push_str(&body.replace('\n', "\n     "));
4018                }
4019
4020                if duration_ms >= u128::from(self.slow_threshold_ms) {
4021                    logger.warn(output);
4022                } else {
4023                    logger.info(output);
4024                }
4025            }
4026        }
4027
4028        Box::pin(async move { response })
4029    }
4030
4031    fn name(&self) -> &'static str {
4032        "RequestInspection"
4033    }
4034}
4035
4036/// Attempt to parse and pretty-print a JSON string.
4037///
4038/// Returns `None` if the input is not valid JSON. Uses a minimal
4039/// recursive formatter to avoid external dependencies.
4040fn try_pretty_json(input: &str) -> Option<String> {
4041    let trimmed = input.trim();
4042    if !trimmed.starts_with('{') && !trimmed.starts_with('[') {
4043        return None;
4044    }
4045
4046    // Validate it's actual JSON by attempting a parse, then pretty-format.
4047    let mut output = String::with_capacity(trimmed.len() * 2);
4048    if json_pretty_format(trimmed, &mut output).is_ok() {
4049        Some(output)
4050    } else {
4051        None
4052    }
4053}
4054
4055/// Minimal JSON pretty-formatter without external dependencies.
4056///
4057/// Handles objects, arrays, strings, numbers, booleans, and null.
4058/// Produces 2-space indented output.
4059fn json_pretty_format(input: &str, output: &mut String) -> Result<(), ()> {
4060    let bytes = input.as_bytes();
4061    let mut pos = 0;
4062    let mut indent: usize = 0;
4063    let mut in_string = false;
4064    let mut escape_next = false;
4065
4066    while pos < bytes.len() {
4067        let ch = bytes[pos] as char;
4068
4069        if escape_next {
4070            output.push(ch);
4071            escape_next = false;
4072            pos += 1;
4073            continue;
4074        }
4075
4076        if in_string {
4077            output.push(ch);
4078            if ch == '\\' {
4079                escape_next = true;
4080            } else if ch == '"' {
4081                in_string = false;
4082            }
4083            pos += 1;
4084            continue;
4085        }
4086
4087        match ch {
4088            '"' => {
4089                in_string = true;
4090                output.push('"');
4091            }
4092            '{' | '[' => {
4093                output.push(ch);
4094                // Peek ahead: if the next non-whitespace is the closing bracket, keep compact
4095                let peek = skip_whitespace(bytes, pos + 1);
4096                let closing = if ch == '{' { '}' } else { ']' };
4097                if peek < bytes.len() && bytes[peek] as char == closing {
4098                    output.push(closing);
4099                    pos = peek + 1;
4100                    continue;
4101                }
4102                indent += 1;
4103                output.push('\n');
4104                push_indent(output, indent);
4105            }
4106            '}' | ']' => {
4107                indent = indent.saturating_sub(1);
4108                output.push('\n');
4109                push_indent(output, indent);
4110                output.push(ch);
4111            }
4112            ':' => {
4113                output.push_str(": ");
4114            }
4115            ',' => {
4116                output.push(',');
4117                output.push('\n');
4118                push_indent(output, indent);
4119            }
4120            c if c.is_ascii_whitespace() => {
4121                // Skip whitespace outside strings
4122            }
4123            _ => {
4124                output.push(ch);
4125            }
4126        }
4127
4128        pos += 1;
4129    }
4130
4131    if in_string || indent != 0 {
4132        return Err(());
4133    }
4134
4135    Ok(())
4136}
4137
/// Index of the first non-ASCII-whitespace byte at or after `start`.
///
/// Returns `bytes.len()` if only whitespace remains, and `start` itself if
/// `start` is already past the end of `bytes`.
fn skip_whitespace(bytes: &[u8], start: usize) -> usize {
    match bytes.get(start..) {
        Some(rest) => rest
            .iter()
            .position(|b| !b.is_ascii_whitespace())
            .map_or(bytes.len(), |offset| start + offset),
        None => start,
    }
}
4145
/// Append `level` indentation steps (two spaces each) to `output`.
fn push_indent(output: &mut String, level: usize) {
    (0..level).for_each(|_| output.push_str("  "));
}
4151
4152// ---------------------------------------------------------------------------
4153// End Request Inspection Middleware
4154// ---------------------------------------------------------------------------
4155
4156// ===========================================================================
4157// ETag Middleware
4158// ===========================================================================
4159
/// Strategy the ETag middleware uses to obtain an entity tag.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ETagMode {
    /// Compute an ETag by hashing the response body with 64-bit FNV-1a,
    /// unless the handler already set one.
    Auto,
    /// Never compute ETags. Only ETags set by the handler take part in
    /// conditional request (`If-None-Match`) handling.
    Manual,
    /// Turn off all ETag processing.
    Disabled,
}
4172
4173impl Default for ETagMode {
4174    fn default() -> Self {
4175        Self::Auto
4176    }
4177}
4178
4179/// Configuration for ETag middleware.
4180#[derive(Debug, Clone)]
4181pub struct ETagConfig {
4182    /// ETag generation mode.
4183    pub mode: ETagMode,
4184    /// Generate weak ETags (W/"...") instead of strong ETags.
4185    /// Weak ETags indicate semantic equivalence, allowing minor changes.
4186    pub weak: bool,
4187    /// Minimum response body size to generate ETag.
4188    /// Responses smaller than this won't get an ETag.
4189    pub min_size: usize,
4190}
4191
4192impl Default for ETagConfig {
4193    fn default() -> Self {
4194        Self {
4195            mode: ETagMode::Auto,
4196            weak: false,
4197            min_size: 0,
4198        }
4199    }
4200}
4201
4202impl ETagConfig {
4203    /// Create a new ETag configuration with default settings.
4204    #[must_use]
4205    pub fn new() -> Self {
4206        Self::default()
4207    }
4208
4209    /// Set the ETag generation mode.
4210    #[must_use]
4211    pub fn mode(mut self, mode: ETagMode) -> Self {
4212        self.mode = mode;
4213        self
4214    }
4215
4216    /// Enable weak ETags.
4217    #[must_use]
4218    pub fn weak(mut self, weak: bool) -> Self {
4219        self.weak = weak;
4220        self
4221    }
4222
4223    /// Set minimum body size for ETag generation.
4224    #[must_use]
4225    pub fn min_size(mut self, size: usize) -> Self {
4226        self.min_size = size;
4227        self
4228    }
4229}
4230
4231/// Middleware for ETag generation and conditional request handling.
4232///
4233/// Implements HTTP caching through ETags as defined in RFC 7232.
4234///
4235/// # Features
4236///
4237/// - **Automatic ETag generation**: Computes ETag from response body hash
4238/// - **If-None-Match handling**: Returns 304 Not Modified for GET/HEAD when ETag matches
4239/// - **Weak and strong ETags**: Configurable ETag strength
4240///
4241/// # Example
4242///
4243/// ```ignore
4244/// use fastapi_core::middleware::{ETagMiddleware, ETagConfig, ETagMode};
4245///
4246/// // Default: auto-generate strong ETags
4247/// let middleware = ETagMiddleware::new();
4248///
4249/// // With custom configuration
4250/// let middleware = ETagMiddleware::with_config(
4251///     ETagConfig::new()
4252///         .mode(ETagMode::Auto)
4253///         .weak(true)
4254///         .min_size(1024)
4255/// );
4256/// ```
4257///
4258/// # Conditional Request Flow
4259///
4260/// For GET/HEAD requests with `If-None-Match` header:
4261/// 1. Generate ETag for response body
4262/// 2. Compare with client's cached ETag
4263/// 3. If match: return 304 Not Modified (empty body)
4264/// 4. If no match: return full response with ETag header
4265///
4266/// # Note on If-Match
4267///
4268/// `If-Match` handling for PUT/PATCH/DELETE is typically done at the
4269/// application level since it requires knowledge of the current resource
4270/// state before the modification occurs.
4271pub struct ETagMiddleware {
4272    config: ETagConfig,
4273}
4274
4275impl Default for ETagMiddleware {
4276    fn default() -> Self {
4277        Self::new()
4278    }
4279}
4280
4281impl ETagMiddleware {
4282    /// Create ETag middleware with default configuration.
4283    #[must_use]
4284    pub fn new() -> Self {
4285        Self {
4286            config: ETagConfig::default(),
4287        }
4288    }
4289
4290    /// Create ETag middleware with custom configuration.
4291    #[must_use]
4292    pub fn with_config(config: ETagConfig) -> Self {
4293        Self { config }
4294    }
4295
4296    /// Generate an ETag from response body bytes using FNV-1a hash.
4297    ///
4298    /// FNV-1a is chosen for:
4299    /// - Speed: Very fast for small to medium data
4300    /// - Consistency: Deterministic output
4301    /// - Simplicity: No external dependencies
4302    fn generate_etag(data: &[u8], weak: bool) -> String {
4303        // FNV-1a 64-bit hash
4304        const FNV_OFFSET_BASIS: u64 = 0xcbf29ce484222325;
4305        const FNV_PRIME: u64 = 0x100000001b3;
4306
4307        let mut hash = FNV_OFFSET_BASIS;
4308        for &byte in data {
4309            hash ^= u64::from(byte);
4310            hash = hash.wrapping_mul(FNV_PRIME);
4311        }
4312
4313        // Format as quoted hex string
4314        if weak {
4315            format!("W/\"{:016x}\"", hash)
4316        } else {
4317            format!("\"{:016x}\"", hash)
4318        }
4319    }
4320
4321    /// Parse ETags from If-None-Match header value.
4322    ///
4323    /// Handles:
4324    /// - Single ETag: "abc123"
4325    /// - Multiple ETags: "abc123", "def456"
4326    /// - Wildcard: *
4327    /// - Weak ETags: W/"abc123"
4328    fn parse_if_none_match(value: &str) -> Vec<String> {
4329        let trimmed = value.trim();
4330
4331        // Handle wildcard
4332        if trimmed == "*" {
4333            return vec!["*".to_string()];
4334        }
4335
4336        let mut etags = Vec::new();
4337        let mut current = String::new();
4338        let mut in_quote = false;
4339        let mut prev_char = '\0';
4340
4341        for ch in trimmed.chars() {
4342            match ch {
4343                '"' if prev_char != '\\' => {
4344                    current.push(ch);
4345                    if in_quote {
4346                        // End of ETag value
4347                        let etag = current.trim().to_string();
4348                        if !etag.is_empty() {
4349                            etags.push(etag);
4350                        }
4351                        current.clear();
4352                    }
4353                    in_quote = !in_quote;
4354                }
4355                ',' if !in_quote => {
4356                    // ETag separator, already handled by quote closing
4357                    current.clear();
4358                }
4359                _ => {
4360                    current.push(ch);
4361                }
4362            }
4363            prev_char = ch;
4364        }
4365
4366        etags
4367    }
4368
4369    /// Check if two ETags match according to weak comparison rules.
4370    ///
4371    /// Weak comparison (for If-None-Match with GET/HEAD):
4372    /// - W/"a" matches W/"a"
4373    /// - W/"a" matches "a"
4374    /// - "a" matches W/"a"
4375    /// - "a" matches "a"
4376    fn etags_match_weak(etag1: &str, etag2: &str) -> bool {
4377        // Strip W/ prefix for weak comparison
4378        let e1 = Self::strip_weak_prefix(etag1);
4379        let e2 = Self::strip_weak_prefix(etag2);
4380        e1 == e2
4381    }
4382
4383    /// Strip the weak ETag prefix (W/) if present.
4384    fn strip_weak_prefix(s: &str) -> &str {
4385        if s.starts_with("W/") || s.starts_with("w/") {
4386            &s[2..]
4387        } else {
4388            s
4389        }
4390    }
4391
4392    /// Check if request method is cacheable (GET or HEAD).
4393    fn is_cacheable_method(method: crate::request::Method) -> bool {
4394        matches!(
4395            method,
4396            crate::request::Method::Get | crate::request::Method::Head
4397        )
4398    }
4399
4400    /// Get existing ETag from response headers.
4401    fn get_existing_etag(headers: &[(String, Vec<u8>)]) -> Option<String> {
4402        for (name, value) in headers {
4403            if name.eq_ignore_ascii_case("etag") {
4404                return std::str::from_utf8(value).ok().map(String::from);
4405            }
4406        }
4407        None
4408    }
4409}
4410
4411impl Middleware for ETagMiddleware {
4412    fn after<'a>(
4413        &'a self,
4414        _ctx: &'a RequestContext,
4415        req: &'a Request,
4416        response: Response,
4417    ) -> BoxFuture<'a, Response> {
4418        let config = self.config.clone();
4419
4420        Box::pin(async move {
4421            // Skip if disabled
4422            if config.mode == ETagMode::Disabled {
4423                return response;
4424            }
4425
4426            // Only handle cacheable methods
4427            if !Self::is_cacheable_method(req.method()) {
4428                return response;
4429            }
4430
4431            // Decompose response to work with parts
4432            let (status, headers, body) = response.into_parts();
4433
4434            // Check for existing ETag (for Manual mode or pre-set ETags)
4435            let existing_etag = Self::get_existing_etag(&headers);
4436
4437            // Get body bytes if available
4438            let body_bytes = match &body {
4439                crate::response::ResponseBody::Bytes(bytes) => Some(bytes.clone()),
4440                crate::response::ResponseBody::Empty => Some(Vec::new()),
4441                crate::response::ResponseBody::Stream(_) => None,
4442            };
4443
4444            // Determine the ETag to use
4445            let etag = if let Some(existing) = existing_etag {
4446                Some(existing)
4447            } else if config.mode == ETagMode::Auto {
4448                if let Some(ref bytes) = body_bytes {
4449                    if bytes.len() >= config.min_size {
4450                        Some(Self::generate_etag(bytes, config.weak))
4451                    } else {
4452                        None
4453                    }
4454                } else {
4455                    None
4456                }
4457            } else {
4458                None
4459            };
4460
4461            // Check If-None-Match header
4462            if let Some(ref etag_value) = etag {
4463                if let Some(if_none_match) = req.headers().get("if-none-match") {
4464                    if let Ok(value) = std::str::from_utf8(if_none_match) {
4465                        let client_etags = Self::parse_if_none_match(value);
4466
4467                        // Check for wildcard or matching ETag
4468                        let matches = client_etags.iter().any(|client_etag| {
4469                            client_etag == "*" || Self::etags_match_weak(client_etag, etag_value)
4470                        });
4471
4472                        if matches {
4473                            // Return 304 Not Modified with ETag header
4474                            return Response::with_status(
4475                                crate::response::StatusCode::NOT_MODIFIED,
4476                            )
4477                            .header("etag", etag_value.as_bytes().to_vec());
4478                        }
4479                    }
4480                }
4481            }
4482
4483            // Rebuild response with ETag header if we have one
4484            let mut new_response = Response::with_status(status)
4485                .body(body)
4486                .rebuild_with_headers(headers);
4487
4488            if let Some(etag_value) = etag {
4489                new_response = new_response.header("etag", etag_value.into_bytes());
4490            }
4491
4492            new_response
4493        })
4494    }
4495
4496    fn name(&self) -> &'static str {
4497        "ETagMiddleware"
4498    }
4499}
4500
4501// ===========================================================================
4502// HTTP Cache Control Middleware
4503// ===========================================================================
4504
/// Individual Cache-Control directives.
///
/// These directives control how responses are cached by browsers, proxies,
/// and CDNs. See RFC 7234, since superseded by RFC 9111, for the full
/// specification.
///
/// Some directives take an argument in the header value, for example
/// `s-maxage=N`, `stale-while-revalidate=N` and `stale-if-error=N`. A variant
/// only names the directive.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheDirective {
    /// Response may be stored by any cache, including shared caches.
    Public,
    /// Response is intended for a single user: private (browser) caches may
    /// store it, shared caches such as proxies and CDNs must not.
    Private,
    /// Caches must not store the response at all.
    NoStore,
    /// Caches may store the response but must revalidate it with the origin
    /// before each reuse.
    NoCache,
    /// Intermediaries must not transform the payload (e.g., recompress images).
    NoTransform,
    /// Once stale, the response must not be reused without successful
    /// revalidation with the origin.
    MustRevalidate,
    /// Like `must-revalidate`, but applies only to shared caches.
    ProxyRevalidate,
    /// (`stale-if-error=N`, RFC 5861) A stale response may be served for up to
    /// N seconds when revalidation fails with an error.
    StaleIfError,
    /// (`stale-while-revalidate=N`, RFC 5861) A stale response may be served
    /// for up to N seconds while it is revalidated in the background.
    StaleWhileRevalidate,
    /// (`s-maxage=N`) Freshness lifetime for shared caches only. It overrides
    /// `max-age` there and is ignored by private caches.
    SMaxAge,
    /// Request directive: the client wants only an already-stored response,
    /// with no contact with the origin. A cache that has nothing suitable
    /// typically answers 504. Not meaningful on responses.
    OnlyIfCached,
    /// (RFC 8246) The response will not change while fresh, so clients need
    /// not revalidate it.
    Immutable,
}
4536
4537impl CacheDirective {
4538    /// Returns the directive as a Cache-Control header string fragment.
4539    fn as_str(self) -> &'static str {
4540        match self {
4541            Self::Public => "public",
4542            Self::Private => "private",
4543            Self::NoStore => "no-store",
4544            Self::NoCache => "no-cache",
4545            Self::NoTransform => "no-transform",
4546            Self::MustRevalidate => "must-revalidate",
4547            Self::ProxyRevalidate => "proxy-revalidate",
4548            Self::StaleIfError => "stale-if-error",
4549            Self::StaleWhileRevalidate => "stale-while-revalidate",
4550            Self::SMaxAge => "s-maxage",
4551            Self::OnlyIfCached => "only-if-cached",
4552            Self::Immutable => "immutable",
4553        }
4554    }
4555}
4556
4557/// Builder for constructing Cache-Control header values.
4558///
4559/// Provides a fluent API for building complex cache control policies.
4560///
4561/// # Example
4562///
4563/// ```ignore
4564/// use fastapi_core::middleware::CacheControlBuilder;
4565///
4566/// // Public, cacheable for 1 hour, must revalidate after
4567/// let cache = CacheControlBuilder::new()
4568///     .public()
4569///     .max_age_secs(3600)
4570///     .must_revalidate()
4571///     .build();
4572///
4573/// // Private, no caching
4574/// let no_cache = CacheControlBuilder::new()
4575///     .private()
4576///     .no_store()
4577///     .build();
4578///
4579/// // CDN-friendly: public with different browser/CDN TTLs
4580/// let cdn = CacheControlBuilder::new()
4581///     .public()
4582///     .max_age_secs(60)        // Browser caches for 1 minute
4583///     .s_maxage_secs(3600)     // CDN caches for 1 hour
4584///     .build();
4585/// ```
4586#[derive(Debug, Clone, Default)]
4587pub struct CacheControlBuilder {
4588    directives: Vec<CacheDirective>,
4589    max_age: Option<u32>,
4590    s_maxage: Option<u32>,
4591    stale_while_revalidate: Option<u32>,
4592    stale_if_error: Option<u32>,
4593}
4594
4595impl CacheControlBuilder {
4596    /// Create a new empty Cache-Control builder.
4597    #[must_use]
4598    pub fn new() -> Self {
4599        Self::default()
4600    }
4601
4602    /// Add the `public` directive - response may be cached by any cache.
4603    #[must_use]
4604    pub fn public(mut self) -> Self {
4605        self.directives.push(CacheDirective::Public);
4606        self
4607    }
4608
4609    /// Add the `private` directive - response may only be cached by browser.
4610    #[must_use]
4611    pub fn private(mut self) -> Self {
4612        self.directives.push(CacheDirective::Private);
4613        self
4614    }
4615
4616    /// Add the `no-store` directive - response must not be cached.
4617    #[must_use]
4618    pub fn no_store(mut self) -> Self {
4619        self.directives.push(CacheDirective::NoStore);
4620        self
4621    }
4622
4623    /// Add the `no-cache` directive - must revalidate before using cache.
4624    #[must_use]
4625    pub fn no_cache(mut self) -> Self {
4626        self.directives.push(CacheDirective::NoCache);
4627        self
4628    }
4629
4630    /// Add the `no-transform` directive - caches must not modify response.
4631    #[must_use]
4632    pub fn no_transform(mut self) -> Self {
4633        self.directives.push(CacheDirective::NoTransform);
4634        self
4635    }
4636
4637    /// Add the `must-revalidate` directive - cache must check origin when stale.
4638    #[must_use]
4639    pub fn must_revalidate(mut self) -> Self {
4640        self.directives.push(CacheDirective::MustRevalidate);
4641        self
4642    }
4643
4644    /// Add the `proxy-revalidate` directive - shared caches must check origin when stale.
4645    #[must_use]
4646    pub fn proxy_revalidate(mut self) -> Self {
4647        self.directives.push(CacheDirective::ProxyRevalidate);
4648        self
4649    }
4650
4651    /// Add the `immutable` directive - response won't change during freshness lifetime.
4652    #[must_use]
4653    pub fn immutable(mut self) -> Self {
4654        self.directives.push(CacheDirective::Immutable);
4655        self
4656    }
4657
4658    /// Set `max-age` directive - maximum time response is fresh (in seconds).
4659    #[must_use]
4660    pub fn max_age_secs(mut self, seconds: u32) -> Self {
4661        self.max_age = Some(seconds);
4662        self
4663    }
4664
4665    /// Set `max-age` directive from a Duration.
4666    #[must_use]
4667    pub fn max_age(self, duration: std::time::Duration) -> Self {
4668        self.max_age_secs(duration.as_secs() as u32)
4669    }
4670
4671    /// Set `s-maxage` directive - maximum time for shared caches (in seconds).
4672    #[must_use]
4673    pub fn s_maxage_secs(mut self, seconds: u32) -> Self {
4674        self.s_maxage = Some(seconds);
4675        self
4676    }
4677
4678    /// Set `s-maxage` directive from a Duration.
4679    #[must_use]
4680    pub fn s_maxage(self, duration: std::time::Duration) -> Self {
4681        self.s_maxage_secs(duration.as_secs() as u32)
4682    }
4683
4684    /// Set `stale-while-revalidate` directive - serve stale while revalidating (in seconds).
4685    #[must_use]
4686    pub fn stale_while_revalidate_secs(mut self, seconds: u32) -> Self {
4687        self.stale_while_revalidate = Some(seconds);
4688        self
4689    }
4690
4691    /// Set `stale-if-error` directive - serve stale if origin errors (in seconds).
4692    #[must_use]
4693    pub fn stale_if_error_secs(mut self, seconds: u32) -> Self {
4694        self.stale_if_error = Some(seconds);
4695        self
4696    }
4697
4698    /// Build the Cache-Control header value string.
4699    #[must_use]
4700    pub fn build(&self) -> String {
4701        let mut parts = Vec::new();
4702
4703        // Add directives
4704        for directive in &self.directives {
4705            parts.push(directive.as_str().to_string());
4706        }
4707
4708        // Add max-age
4709        if let Some(age) = self.max_age {
4710            parts.push(format!("max-age={age}"));
4711        }
4712
4713        // Add s-maxage
4714        if let Some(age) = self.s_maxage {
4715            parts.push(format!("s-maxage={age}"));
4716        }
4717
4718        // Add stale-while-revalidate
4719        if let Some(seconds) = self.stale_while_revalidate {
4720            parts.push(format!("stale-while-revalidate={seconds}"));
4721        }
4722
4723        // Add stale-if-error
4724        if let Some(seconds) = self.stale_if_error {
4725            parts.push(format!("stale-if-error={seconds}"));
4726        }
4727
4728        parts.join(", ")
4729    }
4730
4731    /// Check if this represents a no-cache policy.
4732    #[must_use]
4733    pub fn is_no_cache(&self) -> bool {
4734        self.directives.contains(&CacheDirective::NoStore)
4735            || self.directives.contains(&CacheDirective::NoCache)
4736    }
4737}
4738
4739/// Common cache control presets for typical use cases.
4740#[derive(Debug, Clone, Copy, PartialEq, Eq)]
4741pub enum CachePreset {
4742    /// No caching: `no-store, no-cache, must-revalidate`
4743    NoCache,
4744    /// Private caching only: `private, max-age=0, must-revalidate`
4745    PrivateNoCache,
4746    /// Standard public caching: `public, max-age=3600`
4747    PublicOneHour,
4748    /// Long-term immutable: `public, max-age=31536000, immutable`
4749    Immutable,
4750    /// CDN-friendly with short browser TTL: `public, max-age=60, s-maxage=3600`
4751    CdnFriendly,
4752    /// Static assets: `public, max-age=86400`
4753    StaticAssets,
4754}
4755
4756impl CachePreset {
4757    /// Convert preset to Cache-Control header value.
4758    #[must_use]
4759    pub fn to_header_value(&self) -> String {
4760        match self {
4761            Self::NoCache => "no-store, no-cache, must-revalidate".to_string(),
4762            Self::PrivateNoCache => "private, max-age=0, must-revalidate".to_string(),
4763            Self::PublicOneHour => "public, max-age=3600".to_string(),
4764            Self::Immutable => "public, max-age=31536000, immutable".to_string(),
4765            Self::CdnFriendly => "public, max-age=60, s-maxage=3600".to_string(),
4766            Self::StaticAssets => "public, max-age=86400".to_string(),
4767        }
4768    }
4769
4770    /// Convert preset to a CacheControlBuilder for further customization.
4771    #[must_use]
4772    pub fn to_builder(&self) -> CacheControlBuilder {
4773        match self {
4774            Self::NoCache => CacheControlBuilder::new()
4775                .no_store()
4776                .no_cache()
4777                .must_revalidate(),
4778            Self::PrivateNoCache => CacheControlBuilder::new()
4779                .private()
4780                .max_age_secs(0)
4781                .must_revalidate(),
4782            Self::PublicOneHour => CacheControlBuilder::new().public().max_age_secs(3600),
4783            Self::Immutable => CacheControlBuilder::new()
4784                .public()
4785                .max_age_secs(31536000)
4786                .immutable(),
4787            Self::CdnFriendly => CacheControlBuilder::new()
4788                .public()
4789                .max_age_secs(60)
4790                .s_maxage_secs(3600),
4791            Self::StaticAssets => CacheControlBuilder::new().public().max_age_secs(86400),
4792        }
4793    }
4794}
4795
4796/// Configuration for the Cache Control middleware.
4797#[derive(Debug, Clone)]
4798pub struct CacheControlConfig {
4799    /// The Cache-Control header value to set.
4800    pub cache_control: String,
4801    /// Optional Vary header values for content negotiation.
4802    pub vary: Vec<String>,
4803    /// Whether to set Expires header (deprecated but still used).
4804    pub set_expires: bool,
4805    /// Whether to preserve existing Cache-Control headers.
4806    pub preserve_existing: bool,
4807    /// HTTP methods to apply caching to (default: GET, HEAD).
4808    pub methods: Vec<crate::request::Method>,
4809    /// Path patterns to match (empty = match all).
4810    pub path_patterns: Vec<String>,
4811    /// Status codes to cache (default: 200-299).
4812    pub cacheable_statuses: Vec<u16>,
4813}
4814
4815impl Default for CacheControlConfig {
4816    fn default() -> Self {
4817        Self {
4818            cache_control: CachePreset::NoCache.to_header_value(),
4819            vary: Vec::new(),
4820            set_expires: false,
4821            preserve_existing: true,
4822            methods: vec![crate::request::Method::Get, crate::request::Method::Head],
4823            path_patterns: Vec::new(),
4824            cacheable_statuses: (200..300).collect(),
4825        }
4826    }
4827}
4828
4829impl CacheControlConfig {
4830    /// Create a new configuration with the default no-cache policy.
4831    #[must_use]
4832    pub fn new() -> Self {
4833        Self::default()
4834    }
4835
4836    /// Create configuration from a preset.
4837    #[must_use]
4838    pub fn from_preset(preset: CachePreset) -> Self {
4839        Self {
4840            cache_control: preset.to_header_value(),
4841            ..Self::default()
4842        }
4843    }
4844
4845    /// Create configuration from a custom builder.
4846    #[must_use]
4847    pub fn from_builder(builder: CacheControlBuilder) -> Self {
4848        Self {
4849            cache_control: builder.build(),
4850            ..Self::default()
4851        }
4852    }
4853
4854    /// Set the Cache-Control header value.
4855    #[must_use]
4856    pub fn cache_control(mut self, value: impl Into<String>) -> Self {
4857        self.cache_control = value.into();
4858        self
4859    }
4860
4861    /// Add a Vary header value (for content negotiation).
4862    #[must_use]
4863    pub fn vary(mut self, header: impl Into<String>) -> Self {
4864        self.vary.push(header.into());
4865        self
4866    }
4867
4868    /// Add multiple Vary header values.
4869    #[must_use]
4870    pub fn vary_headers(mut self, headers: Vec<String>) -> Self {
4871        self.vary.extend(headers);
4872        self
4873    }
4874
4875    /// Enable setting the Expires header.
4876    #[must_use]
4877    pub fn with_expires(mut self, enable: bool) -> Self {
4878        self.set_expires = enable;
4879        self
4880    }
4881
4882    /// Whether to preserve existing Cache-Control headers.
4883    #[must_use]
4884    pub fn preserve_existing(mut self, preserve: bool) -> Self {
4885        self.preserve_existing = preserve;
4886        self
4887    }
4888
4889    /// Set the HTTP methods to apply caching to.
4890    #[must_use]
4891    pub fn methods(mut self, methods: Vec<crate::request::Method>) -> Self {
4892        self.methods = methods;
4893        self
4894    }
4895
4896    /// Set path patterns to match (glob-style).
4897    #[must_use]
4898    pub fn path_patterns(mut self, patterns: Vec<String>) -> Self {
4899        self.path_patterns = patterns;
4900        self
4901    }
4902
4903    /// Set cacheable status codes.
4904    #[must_use]
4905    pub fn cacheable_statuses(mut self, statuses: Vec<u16>) -> Self {
4906        self.cacheable_statuses = statuses;
4907        self
4908    }
4909}
4910
4911/// Middleware for setting HTTP cache control headers.
4912///
4913/// This middleware adds Cache-Control, Vary, and optionally Expires headers
4914/// to responses. It supports various caching strategies from no-cache to
4915/// aggressive caching for static assets.
4916///
4917/// # Features
4918///
4919/// - **Cache-Control directives**: Full support for RFC 7234 directives
4920/// - **Vary header**: Content negotiation support for Accept-Encoding, Accept-Language, etc.
4921/// - **Expires header**: Optional legacy header support
4922/// - **Per-route configuration**: Apply different policies via middleware stacks
4923/// - **Method filtering**: Only cache GET/HEAD by default
4924/// - **Status filtering**: Only cache successful responses
4925///
4926/// # Example
4927///
4928/// ```ignore
4929/// use fastapi_core::middleware::{CacheControlMiddleware, CacheControlConfig, CachePreset};
4930///
4931/// // No caching for API responses (default)
4932/// let api_cache = CacheControlMiddleware::new();
4933///
4934/// // Public caching for static assets
4935/// let static_cache = CacheControlMiddleware::with_preset(CachePreset::StaticAssets);
4936///
4937/// // Custom caching with Vary header
4938/// let custom_cache = CacheControlMiddleware::with_config(
4939///     CacheControlConfig::from_preset(CachePreset::PublicOneHour)
4940///         .vary("Accept-Encoding")
4941///         .vary("Accept-Language")
4942///         .with_expires(true)
4943/// );
4944///
4945/// // CDN-friendly caching
4946/// let cdn_cache = CacheControlMiddleware::with_preset(CachePreset::CdnFriendly);
4947/// ```
4948///
4949/// # Response Headers Set
4950///
4951/// | Header | Description |
4952/// |--------|-------------|
4953/// | `Cache-Control` | Main caching directive |
4954/// | `Vary` | Headers that affect caching |
4955/// | `Expires` | Legacy expiration (if enabled) |
4956///
4957pub struct CacheControlMiddleware {
4958    config: CacheControlConfig,
4959}
4960
4961impl Default for CacheControlMiddleware {
4962    fn default() -> Self {
4963        Self::new()
4964    }
4965}
4966
4967impl CacheControlMiddleware {
4968    /// Create middleware with default no-cache policy.
4969    ///
4970    /// This is the safest default - no caching unless explicitly configured.
4971    #[must_use]
4972    pub fn new() -> Self {
4973        Self {
4974            config: CacheControlConfig::default(),
4975        }
4976    }
4977
4978    /// Create middleware with a preset caching policy.
4979    #[must_use]
4980    pub fn with_preset(preset: CachePreset) -> Self {
4981        Self {
4982            config: CacheControlConfig::from_preset(preset),
4983        }
4984    }
4985
4986    /// Create middleware with custom configuration.
4987    #[must_use]
4988    pub fn with_config(config: CacheControlConfig) -> Self {
4989        Self { config }
4990    }
4991
4992    /// Check if the request method is cacheable.
4993    fn is_cacheable_method(&self, method: crate::request::Method) -> bool {
4994        self.config.methods.contains(&method)
4995    }
4996
4997    /// Check if the response status is cacheable.
4998    fn is_cacheable_status(&self, status: u16) -> bool {
4999        self.config.cacheable_statuses.contains(&status)
5000    }
5001
5002    /// Check if the path matches any configured patterns.
5003    fn matches_path(&self, path: &str) -> bool {
5004        if self.config.path_patterns.is_empty() {
5005            return true; // Match all if no patterns configured
5006        }
5007
5008        for pattern in &self.config.path_patterns {
5009            if path_matches_pattern(path, pattern) {
5010                return true;
5011            }
5012        }
5013        false
5014    }
5015
5016    /// Check if response already has a Cache-Control header.
5017    fn has_cache_control(headers: &[(String, Vec<u8>)]) -> bool {
5018        headers
5019            .iter()
5020            .any(|(name, _)| name.eq_ignore_ascii_case("cache-control"))
5021    }
5022
5023    /// Calculate Expires date from max-age value.
5024    fn calculate_expires(cache_control: &str) -> Option<String> {
5025        // Extract max-age value if present
5026        for directive in cache_control.split(',') {
5027            let directive = directive.trim();
5028            if directive.starts_with("max-age=") {
5029                if let Ok(seconds) = directive[8..].parse::<u64>() {
5030                    // Calculate expiration time
5031                    let now = std::time::SystemTime::now();
5032                    if let Some(expires) = now.checked_add(std::time::Duration::from_secs(seconds))
5033                    {
5034                        return Some(format_http_date(expires));
5035                    }
5036                }
5037            }
5038        }
5039        None
5040    }
5041}
5042
/// Glob-style path matching where `*` matches any (possibly empty) run of
/// characters.
///
/// The match is anchored at both ends. The text before the first `*` must be
/// a prefix of `path`, the text after the last `*` must be a suffix, and every
/// literal piece in between must appear in order without overlapping its
/// neighbours. A pattern without `*` must equal `path` exactly.
///
/// Examples: `"/static/*"` matches `/static/app.css`, `"*.css"` matches any
/// path ending in `.css`, and `"/api/*/users/*"` matches `/api/v1/users/42`
/// but not `/x/api/v1/users/42`.
fn path_matches_pattern(path: &str, pattern: &str) -> bool {
    if !pattern.contains('*') {
        return path == pattern;
    }

    // `pattern` contains at least one '*', so this yields at least two segments.
    let segments: Vec<&str> = pattern.split('*').collect();
    let first = segments[0];
    let last = segments[segments.len() - 1];
    let middle = &segments[1..segments.len() - 1];

    // Anchor the leading literal at the start of the path.
    let mut rest = match path.strip_prefix(first) {
        Some(rest) => rest,
        None => return false,
    };

    // Leftmost matching of each middle literal is optimal for `*`-only globs.
    for &segment in middle {
        if segment.is_empty() {
            continue; // consecutive `**` behaves like a single `*`
        }
        match rest.find(segment) {
            Some(pos) => rest = &rest[pos + segment.len()..],
            None => return false,
        }
    }

    // Anchor the trailing literal at the end. Checking it against `rest`
    // (not `path`) prevents it from overlapping earlier matched text.
    rest.ends_with(last)
}
5071
5072/// Format a SystemTime as an HTTP date (RFC 7231).
5073fn format_http_date(time: std::time::SystemTime) -> String {
5074    // Use UNIX_EPOCH to calculate duration
5075    match time.duration_since(std::time::UNIX_EPOCH) {
5076        Ok(duration) => {
5077            // Calculate date components
5078            let secs = duration.as_secs();
5079            // Days since epoch
5080            let days = secs / 86400;
5081            let remaining_secs = secs % 86400;
5082            let hours = remaining_secs / 3600;
5083            let minutes = (remaining_secs % 3600) / 60;
5084            let seconds = remaining_secs % 60;
5085
5086            // Calculate day of week (Jan 1, 1970 was Thursday = 4)
5087            let day_of_week = ((days + 4) % 7) as usize;
5088            let day_names = ["Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat"];
5089
5090            // Calculate date (simplified - doesn't account for leap years perfectly but good enough)
5091            let (year, month, day) = days_to_date(days);
5092            let month_names = [
5093                "Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec",
5094            ];
5095
5096            format!(
5097                "{}, {:02} {} {} {:02}:{:02}:{:02} GMT",
5098                day_names[day_of_week],
5099                day,
5100                month_names[(month - 1) as usize],
5101                year,
5102                hours,
5103                minutes,
5104                seconds
5105            )
5106        }
5107        Err(_) => "Thu, 01 Jan 1970 00:00:00 GMT".to_string(),
5108    }
5109}
5110
5111/// Convert days since UNIX epoch to (year, month, day).
5112fn days_to_date(days: u64) -> (u64, u64, u64) {
5113    // Simplified algorithm - works for dates 1970-2099
5114    let mut remaining_days = days;
5115    let mut year = 1970u64;
5116
5117    loop {
5118        let days_in_year = if is_leap_year(year) { 366 } else { 365 };
5119        if remaining_days < days_in_year {
5120            break;
5121        }
5122        remaining_days -= days_in_year;
5123        year += 1;
5124    }
5125
5126    let leap = is_leap_year(year);
5127    let month_days: [u64; 12] = if leap {
5128        [31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
5129    } else {
5130        [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
5131    };
5132
5133    let mut month = 1u64;
5134    for &days_in_month in &month_days {
5135        if remaining_days < days_in_month {
5136            break;
5137        }
5138        remaining_days -= days_in_month;
5139        month += 1;
5140    }
5141
5142    (year, month, remaining_days + 1)
5143}
5144
/// Gregorian leap-year rule: every 4th year, except century years, unless the
/// century is divisible by 400.
fn is_leap_year(year: u64) -> bool {
    match (year % 4, year % 100, year % 400) {
        (_, _, 0) => true,
        (_, 0, _) => false,
        (0, _, _) => true,
        _ => false,
    }
}
5149
5150impl Middleware for CacheControlMiddleware {
5151    fn after<'a>(
5152        &'a self,
5153        _ctx: &'a RequestContext,
5154        req: &'a Request,
5155        response: Response,
5156    ) -> BoxFuture<'a, Response> {
5157        let config = self.config.clone();
5158
5159        Box::pin(async move {
5160            // Check if this request/response is cacheable
5161            if !self.is_cacheable_method(req.method()) {
5162                return response;
5163            }
5164
5165            if !self.is_cacheable_status(response.status().as_u16()) {
5166                return response;
5167            }
5168
5169            if !self.matches_path(req.path()) {
5170                return response;
5171            }
5172
5173            // Decompose response to modify headers
5174            let (status, mut headers, body) = response.into_parts();
5175
5176            // Check for existing Cache-Control header
5177            if config.preserve_existing && Self::has_cache_control(&headers) {
5178                // Reconstruct and return unchanged
5179                let mut resp = Response::with_status(status);
5180                for (name, value) in headers {
5181                    resp = resp.header(name, value);
5182                }
5183                return resp.body(body);
5184            }
5185
5186            // Add Cache-Control header
5187            headers.push((
5188                "Cache-Control".to_string(),
5189                config.cache_control.as_bytes().to_vec(),
5190            ));
5191
5192            // Add Vary header if configured
5193            if !config.vary.is_empty() {
5194                let vary_value = config.vary.join(", ");
5195                headers.push(("Vary".to_string(), vary_value.into_bytes()));
5196            }
5197
5198            // Add Expires header if configured
5199            if config.set_expires {
5200                if let Some(expires) = Self::calculate_expires(&config.cache_control) {
5201                    headers.push(("Expires".to_string(), expires.into_bytes()));
5202                }
5203            }
5204
5205            // Reconstruct response
5206            let mut resp = Response::with_status(status);
5207            for (name, value) in headers {
5208                resp = resp.header(name, value);
5209            }
5210            resp.body(body)
5211        })
5212    }
5213
5214    fn name(&self) -> &'static str {
5215        "CacheControlMiddleware"
5216    }
5217}
5218
5219// ===========================================================================
5220// End Cache Control Middleware
5221// ===========================================================================
5222
5223// ===========================================================================
5224// TRACE Method Rejection Middleware (Security)
5225// ===========================================================================
5226
5227/// Middleware that rejects HTTP TRACE requests to prevent Cross-Site Tracing (XST) attacks.
5228///
5229/// The HTTP TRACE method echoes the request back to the client, which can be exploited
5230/// in XSS attacks to steal sensitive headers like Authorization or cookies.
5231///
5232/// # Security Rationale
5233///
5234/// - TRACE can expose Authorization headers via XSS attacks
5235/// - No legitimate use case in modern APIs
5236/// - OWASP recommends disabling TRACE
5237///
5238/// # Example
5239///
5240/// ```ignore
5241/// use fastapi_core::middleware::TraceRejectionMiddleware;
5242///
5243/// let app = App::builder()
5244///     .middleware(TraceRejectionMiddleware::new())
5245///     .build();
5246/// ```
5247///
5248/// # Behavior
5249///
5250/// - Returns 405 Method Not Allowed for all TRACE requests
5251/// - Logs TRACE attempts as security events (when log_attempts is true)
5252/// - Cannot be disabled per-route (intentionally)
5253#[derive(Debug, Clone)]
5254pub struct TraceRejectionMiddleware {
5255    /// Whether to log TRACE attempts as security events.
5256    log_attempts: bool,
5257}
5258
5259impl Default for TraceRejectionMiddleware {
5260    fn default() -> Self {
5261        Self::new()
5262    }
5263}
5264
5265impl TraceRejectionMiddleware {
5266    /// Create a new TRACE rejection middleware with default settings.
5267    ///
5268    /// By default, logging of TRACE attempts is enabled.
5269    #[must_use]
5270    pub fn new() -> Self {
5271        Self { log_attempts: true }
5272    }
5273
5274    /// Configure whether to log TRACE attempts.
5275    ///
5276    /// When enabled, each TRACE request is logged as a security event
5277    /// including the remote IP (if available) and request path.
5278    #[must_use]
5279    pub fn log_attempts(mut self, log: bool) -> Self {
5280        self.log_attempts = log;
5281        self
5282    }
5283
5284    /// Create a response for rejected TRACE requests.
5285    fn rejection_response(path: &str) -> Response {
5286        let body = format!(
5287            r#"{{"detail":"HTTP TRACE method is not allowed","path":"{}"}}"#,
5288            path.replace('"', "\\\"")
5289        );
5290        Response::with_status(crate::response::StatusCode::METHOD_NOT_ALLOWED)
5291            .header("Content-Type", b"application/json".to_vec())
5292            .header(
5293                "Allow",
5294                b"GET, POST, PUT, DELETE, PATCH, OPTIONS, HEAD".to_vec(),
5295            )
5296            .body(crate::response::ResponseBody::Bytes(body.into_bytes()))
5297    }
5298}
5299
5300impl Middleware for TraceRejectionMiddleware {
5301    fn before<'a>(
5302        &'a self,
5303        _ctx: &'a RequestContext,
5304        req: &'a mut Request,
5305    ) -> BoxFuture<'a, ControlFlow> {
5306        Box::pin(async move {
5307            if req.method() == crate::request::Method::Trace {
5308                if self.log_attempts {
5309                    // Log as security event
5310                    let path = req.path();
5311                    let remote_ip = req
5312                        .headers()
5313                        .get("X-Forwarded-For")
5314                        .or_else(|| req.headers().get("X-Real-IP"))
5315                        .map(|v| String::from_utf8_lossy(v).to_string())
5316                        .unwrap_or_else(|| "unknown".to_string());
5317
5318                    eprintln!(
5319                        "[SECURITY] TRACE request blocked: path={}, remote_ip={}",
5320                        path, remote_ip
5321                    );
5322                }
5323
5324                return ControlFlow::Break(Self::rejection_response(req.path()));
5325            }
5326
5327            ControlFlow::Continue
5328        })
5329    }
5330
5331    fn name(&self) -> &'static str {
5332        "TraceRejection"
5333    }
5334}
5335
5336// ===========================================================================
5337// End TRACE Rejection Middleware
5338// ===========================================================================
5339
5340// ===========================================================================
5341// HTTPS Redirect and HSTS Middleware (Security)
5342// ===========================================================================
5343
/// Settings for [`HttpsRedirectMiddleware`]: redirect behaviour plus the
/// parameters of the `Strict-Transport-Security` (HSTS) header.
#[derive(Debug, Clone)]
#[allow(clippy::struct_excessive_bools)]
pub struct HttpsRedirectConfig {
    /// Whether plain-HTTP requests get redirected to HTTPS.
    pub redirect_enabled: bool,
    /// `true` selects a permanent (301) redirect, `false` a temporary (307) one.
    pub permanent_redirect: bool,
    /// HSTS `max-age` in seconds; 0 disables the HSTS header.
    pub hsts_max_age_secs: u64,
    /// Whether HSTS also covers subdomains (`includeSubDomains`).
    pub hsts_include_subdomains: bool,
    /// Whether the HSTS header requests preload-list inclusion (`preload`).
    pub hsts_preload: bool,
    /// Paths that are never redirected (for example, health checks).
    pub exclude_paths: Vec<String>,
    /// Port used when building the HTTPS redirect target (443 by default).
    pub https_port: u16,
}

impl Default for HttpsRedirectConfig {
    fn default() -> Self {
        // 365 days, the commonly recommended HSTS lifetime.
        const ONE_YEAR_SECS: u64 = 365 * 24 * 60 * 60;
        const STANDARD_HTTPS_PORT: u16 = 443;

        Self {
            https_port: STANDARD_HTTPS_PORT,
            exclude_paths: Vec::new(),
            hsts_preload: false,
            hsts_include_subdomains: false,
            hsts_max_age_secs: ONE_YEAR_SECS,
            permanent_redirect: true,
            redirect_enabled: true,
        }
    }
}
5377
5378/// Middleware that redirects HTTP requests to HTTPS and sets HSTS headers.
5379///
5380/// This middleware provides two critical security features:
5381///
5382/// 1. **HTTP to HTTPS Redirect**: Automatically redirects insecure HTTP requests
5383///    to their HTTPS equivalents, ensuring all traffic is encrypted.
5384///
5385/// 2. **HSTS (Strict Transport Security)**: Adds the `Strict-Transport-Security`
5386///    header to HTTPS responses, instructing browsers to always use HTTPS.
5387///
5388/// # Proxy Awareness
5389///
5390/// The middleware respects the `X-Forwarded-Proto` header, so it works correctly
5391/// behind reverse proxies like nginx or HAProxy. If the proxy sets this header
5392/// to "https", the request is treated as secure.
5393///
5394/// # Example
5395///
5396/// ```ignore
5397/// use fastapi_core::middleware::HttpsRedirectMiddleware;
5398///
5399/// let app = App::builder()
5400///     .middleware(HttpsRedirectMiddleware::new()
5401///         .hsts_max_age_secs(31536000)  // 1 year
5402///         .include_subdomains(true)
5403///         .preload(true)
5404///         .exclude_path("/health")
5405///         .exclude_path("/readiness"))
5406///     .build();
5407/// ```
5408///
5409/// # Configuration Options
5410///
5411/// - `redirect_enabled`: Enable/disable redirects (default: true)
5412/// - `permanent_redirect`: Use 301 (true) or 307 (false) redirects
5413/// - `hsts_max_age_secs`: HSTS max-age value in seconds
5414/// - `include_subdomains`: Apply HSTS to all subdomains
5415/// - `preload`: Mark site for HSTS preload list
5416/// - `exclude_path`: Paths that should remain accessible over HTTP
5417#[derive(Debug, Clone)]
5418pub struct HttpsRedirectMiddleware {
5419    config: HttpsRedirectConfig,
5420}
5421
5422impl Default for HttpsRedirectMiddleware {
5423    fn default() -> Self {
5424        Self::new()
5425    }
5426}
5427
5428impl HttpsRedirectMiddleware {
5429    /// Create a new HTTPS redirect middleware with default settings.
5430    #[must_use]
5431    pub fn new() -> Self {
5432        Self {
5433            config: HttpsRedirectConfig::default(),
5434        }
5435    }
5436
5437    /// Enable or disable HTTP to HTTPS redirects.
5438    #[must_use]
5439    pub fn redirect_enabled(mut self, enabled: bool) -> Self {
5440        self.config.redirect_enabled = enabled;
5441        self
5442    }
5443
5444    /// Use permanent (301) redirects instead of temporary (307).
5445    ///
5446    /// Default is true (permanent redirects).
5447    #[must_use]
5448    pub fn permanent_redirect(mut self, permanent: bool) -> Self {
5449        self.config.permanent_redirect = permanent;
5450        self
5451    }
5452
5453    /// Set the HSTS max-age in seconds.
5454    ///
5455    /// Set to 0 to disable HSTS header.
5456    /// Default is 31536000 (1 year).
5457    #[must_use]
5458    pub fn hsts_max_age_secs(mut self, secs: u64) -> Self {
5459        self.config.hsts_max_age_secs = secs;
5460        self
5461    }
5462
5463    /// Include subdomains in HSTS policy.
5464    #[must_use]
5465    pub fn include_subdomains(mut self, include: bool) -> Self {
5466        self.config.hsts_include_subdomains = include;
5467        self
5468    }
5469
5470    /// Enable HSTS preload.
5471    ///
5472    /// Only enable this if you're ready to submit your site to the
5473    /// HSTS preload list at hstspreload.org.
5474    #[must_use]
5475    pub fn preload(mut self, preload: bool) -> Self {
5476        self.config.hsts_preload = preload;
5477        self
5478    }
5479
5480    /// Add a path to exclude from redirects.
5481    ///
5482    /// Use this for health check endpoints that need to remain
5483    /// accessible over HTTP for load balancer probes.
5484    #[must_use]
5485    pub fn exclude_path(mut self, path: impl Into<String>) -> Self {
5486        self.config.exclude_paths.push(path.into());
5487        self
5488    }
5489
5490    /// Set multiple excluded paths at once.
5491    #[must_use]
5492    pub fn exclude_paths(mut self, paths: Vec<String>) -> Self {
5493        self.config.exclude_paths = paths;
5494        self
5495    }
5496
5497    /// Set the HTTPS port (default 443).
5498    #[must_use]
5499    pub fn https_port(mut self, port: u16) -> Self {
5500        self.config.https_port = port;
5501        self
5502    }
5503
5504    /// Check if the request is using HTTPS.
5505    ///
5506    /// This checks both the scheme and the X-Forwarded-Proto header
5507    /// for proxy-aware detection.
5508    fn is_secure(&self, req: &Request) -> bool {
5509        fn trim_ascii(mut bytes: &[u8]) -> &[u8] {
5510            while matches!(bytes.first(), Some(b' ' | b'\t')) {
5511                bytes = &bytes[1..];
5512            }
5513            while matches!(bytes.last(), Some(b' ' | b'\t')) {
5514                bytes = &bytes[..bytes.len() - 1];
5515            }
5516            bytes
5517        }
5518
5519        if let Some(info) = req.get_extension::<crate::request::ConnectionInfo>() {
5520            if info.is_tls {
5521                return true;
5522            }
5523        }
5524
5525        // RFC 7239 Forwarded: for=...;proto=https;host=...
5526        if let Some(forwarded) = req.headers().get("Forwarded") {
5527            if let Ok(s) = std::str::from_utf8(forwarded) {
5528                for entry in s.split(',') {
5529                    for param in entry.split(';') {
5530                        let param = param.trim();
5531                        if let Some((k, v)) = param.split_once('=') {
5532                            if k.trim().eq_ignore_ascii_case("proto") {
5533                                let proto = v.trim().trim_matches('"');
5534                                if proto.eq_ignore_ascii_case("https") {
5535                                    return true;
5536                                }
5537                            }
5538                        }
5539                    }
5540                }
5541            }
5542        }
5543
5544        // Check X-Forwarded-Proto header first (for reverse proxy)
5545        if let Some(proto) = req.headers().get("X-Forwarded-Proto") {
5546            let first = proto.split(|&b| b == b',').next().unwrap_or(proto);
5547            return trim_ascii(first).eq_ignore_ascii_case(b"https");
5548        }
5549
5550        // Check X-Forwarded-Ssl header (alternative)
5551        if let Some(ssl) = req.headers().get("X-Forwarded-Ssl") {
5552            return ssl.eq_ignore_ascii_case(b"on");
5553        }
5554
5555        // Check Front-End-Https header (Microsoft IIS)
5556        if let Some(https) = req.headers().get("Front-End-Https") {
5557            return https.eq_ignore_ascii_case(b"on");
5558        }
5559
5560        false
5561    }
5562
5563    /// Check if a path should be excluded from redirects.
5564    fn is_excluded(&self, path: &str) -> bool {
5565        self.config
5566            .exclude_paths
5567            .iter()
5568            .any(|p| path.starts_with(p))
5569    }
5570
5571    /// Build the HSTS header value.
5572    fn build_hsts_header(&self) -> Option<Vec<u8>> {
5573        if self.config.hsts_max_age_secs == 0 {
5574            return None;
5575        }
5576
5577        let mut value = format!("max-age={}", self.config.hsts_max_age_secs);
5578
5579        if self.config.hsts_include_subdomains {
5580            value.push_str("; includeSubDomains");
5581        }
5582
5583        if self.config.hsts_preload {
5584            value.push_str("; preload");
5585        }
5586
5587        Some(value.into_bytes())
5588    }
5589
5590    /// Build the redirect URL.
5591    fn build_redirect_url(&self, req: &Request) -> String {
5592        let host = req
5593            .headers()
5594            .get("Host")
5595            .map(|h| String::from_utf8_lossy(h).to_string())
5596            .unwrap_or_else(|| "localhost".to_string());
5597
5598        // Remove port from host if present
5599        let host_without_port = host.split(':').next().unwrap_or(&host);
5600
5601        let path = req.path();
5602        let query = req.query();
5603
5604        if self.config.https_port == 443 {
5605            match query {
5606                Some(q) => format!("https://{}{}?{}", host_without_port, path, q),
5607                None => format!("https://{}{}", host_without_port, path),
5608            }
5609        } else {
5610            match query {
5611                Some(q) => format!(
5612                    "https://{}:{}{}?{}",
5613                    host_without_port, self.config.https_port, path, q
5614                ),
5615                None => format!(
5616                    "https://{}:{}{}",
5617                    host_without_port, self.config.https_port, path
5618                ),
5619            }
5620        }
5621    }
5622}
5623
5624impl Middleware for HttpsRedirectMiddleware {
5625    fn before<'a>(
5626        &'a self,
5627        _ctx: &'a RequestContext,
5628        req: &'a mut Request,
5629    ) -> BoxFuture<'a, ControlFlow> {
5630        Box::pin(async move {
5631            // Skip if redirects are disabled
5632            if !self.config.redirect_enabled {
5633                return ControlFlow::Continue;
5634            }
5635
5636            // Skip if already HTTPS
5637            if self.is_secure(req) {
5638                return ControlFlow::Continue;
5639            }
5640
5641            // Skip excluded paths (e.g., health checks)
5642            if self.is_excluded(req.path()) {
5643                return ControlFlow::Continue;
5644            }
5645
5646            // Build redirect URL
5647            let redirect_url = self.build_redirect_url(req);
5648
5649            // Choose status code
5650            let status = if self.config.permanent_redirect {
5651                crate::response::StatusCode::MOVED_PERMANENTLY
5652            } else {
5653                crate::response::StatusCode::TEMPORARY_REDIRECT
5654            };
5655
5656            // Create redirect response
5657            let response = Response::with_status(status)
5658                .header("Location", redirect_url.into_bytes())
5659                .header("Content-Type", b"text/plain".to_vec())
5660                .body(crate::response::ResponseBody::Bytes(
5661                    b"Redirecting to HTTPS...".to_vec(),
5662                ));
5663
5664            ControlFlow::Break(response)
5665        })
5666    }
5667
5668    fn after<'a>(
5669        &'a self,
5670        _ctx: &'a RequestContext,
5671        req: &'a Request,
5672        response: Response,
5673    ) -> BoxFuture<'a, Response> {
5674        Box::pin(async move {
5675            // Only add HSTS to secure responses
5676            if !self.is_secure(req) {
5677                return response;
5678            }
5679
5680            // Add HSTS header if configured
5681            if let Some(hsts_value) = self.build_hsts_header() {
5682                response.header("Strict-Transport-Security", hsts_value)
5683            } else {
5684                response
5685            }
5686        })
5687    }
5688
5689    fn name(&self) -> &'static str {
5690        "HttpsRedirect"
5691    }
5692}
5693
5694// ===========================================================================
5695// End HTTPS Redirect Middleware
5696// ===========================================================================
5697
5698// ===========================================================================
5699// Response Interceptors and Transformers
5700// ===========================================================================
5701//
5702// This section provides a simplified abstraction for response-only processing.
5703// Unlike full Middleware, ResponseInterceptor only handles post-handler processing,
5704// making it lighter weight and easier to compose for response transformations.
5705
5706/// A response interceptor that processes responses after handler execution.
5707///
5708/// Unlike the full [`Middleware`] trait, `ResponseInterceptor` only handles
5709/// the post-handler phase, making it simpler to implement for response-only
5710/// processing like:
5711/// - Adding timing headers
5712/// - Transforming response bodies
5713/// - Adding debug information
5714/// - Logging response details
5715///
5716/// # Example
5717///
5718/// ```ignore
5719/// use fastapi_core::middleware::{ResponseInterceptor, ResponseInterceptorContext};
5720///
5721/// struct TimingInterceptor {
5722///     start_time: Instant,
5723/// }
5724///
5725/// impl ResponseInterceptor for TimingInterceptor {
5726///     fn intercept(&self, ctx: &ResponseInterceptorContext, response: Response) -> Response {
5727///         let elapsed = self.start_time.elapsed();
5728///         response.header("X-Response-Time", format!("{}ms", elapsed.as_millis()).into_bytes())
5729///     }
5730/// }
5731/// ```
5732pub trait ResponseInterceptor: Send + Sync {
5733    /// Process a response after the handler has executed.
5734    ///
5735    /// # Parameters
5736    ///
5737    /// - `ctx`: Context containing request information and timing data
5738    /// - `response`: The response from the handler or previous interceptors
5739    ///
5740    /// # Returns
5741    ///
5742    /// The modified response to pass to the next interceptor or return to client.
5743    fn intercept<'a>(
5744        &'a self,
5745        ctx: &'a ResponseInterceptorContext<'a>,
5746        response: Response,
5747    ) -> BoxFuture<'a, Response>;
5748
5749    /// Returns the interceptor name for debugging and logging.
5750    fn name(&self) -> &'static str {
5751        std::any::type_name::<Self>()
5752    }
5753}
5754
/// Context provided to response interceptors.
///
/// Contains information about the original request and timing data
/// that interceptors might need to process responses. Interceptors receive
/// it by shared reference, so they can read but not modify it.
#[derive(Debug)]
pub struct ResponseInterceptorContext<'a> {
    /// The original request (read-only).
    pub request: &'a Request,
    /// When the request processing started; `elapsed`/`elapsed_ms` measure from here.
    pub start_time: Instant,
    /// The request context for cancellation support.
    pub request_ctx: &'a RequestContext,
}
5768
5769impl<'a> ResponseInterceptorContext<'a> {
5770    /// Create a new interceptor context.
5771    pub fn new(request: &'a Request, request_ctx: &'a RequestContext, start_time: Instant) -> Self {
5772        Self {
5773            request,
5774            start_time,
5775            request_ctx,
5776        }
5777    }
5778
5779    /// Get the elapsed time since request processing started.
5780    pub fn elapsed(&self) -> std::time::Duration {
5781        self.start_time.elapsed()
5782    }
5783
5784    /// Get the elapsed time in milliseconds.
5785    pub fn elapsed_ms(&self) -> u128 {
5786        self.start_time.elapsed().as_millis()
5787    }
5788}
5789
5790/// A stack of response interceptors that run in order.
5791///
5792/// Interceptors are executed in registration order (first registered, first run).
5793/// Each interceptor receives the response from the previous one and can modify it.
5794///
5795/// # Example
5796///
5797/// ```ignore
5798/// let mut stack = ResponseInterceptorStack::new();
5799/// stack.push(TimingInterceptor);
5800/// stack.push(DebugHeadersInterceptor::new());
5801///
5802/// let response = stack.process(&ctx, response).await;
5803/// ```
5804#[derive(Default)]
5805pub struct ResponseInterceptorStack {
5806    interceptors: Vec<Arc<dyn ResponseInterceptor>>,
5807}
5808
5809impl ResponseInterceptorStack {
5810    /// Create an empty interceptor stack.
5811    #[must_use]
5812    pub fn new() -> Self {
5813        Self {
5814            interceptors: Vec::new(),
5815        }
5816    }
5817
5818    /// Create a stack with pre-allocated capacity.
5819    #[must_use]
5820    pub fn with_capacity(capacity: usize) -> Self {
5821        Self {
5822            interceptors: Vec::with_capacity(capacity),
5823        }
5824    }
5825
5826    /// Add an interceptor to the end of the stack.
5827    pub fn push<I: ResponseInterceptor + 'static>(&mut self, interceptor: I) {
5828        self.interceptors.push(Arc::new(interceptor));
5829    }
5830
5831    /// Add an Arc-wrapped interceptor.
5832    pub fn push_arc(&mut self, interceptor: Arc<dyn ResponseInterceptor>) {
5833        self.interceptors.push(interceptor);
5834    }
5835
5836    /// Return the number of interceptors in the stack.
5837    #[must_use]
5838    pub fn len(&self) -> usize {
5839        self.interceptors.len()
5840    }
5841
5842    /// Return true if the stack is empty.
5843    #[must_use]
5844    pub fn is_empty(&self) -> bool {
5845        self.interceptors.is_empty()
5846    }
5847
5848    /// Process a response through all interceptors.
5849    pub async fn process(
5850        &self,
5851        ctx: &ResponseInterceptorContext<'_>,
5852        mut response: Response,
5853    ) -> Response {
5854        for interceptor in &self.interceptors {
5855            let _ = ctx.request_ctx.checkpoint();
5856            response = interceptor.intercept(ctx, response).await;
5857        }
5858        response
5859    }
5860}
5861
5862// ---------------------------------------------------------------------------
5863// Timing Interceptor
5864// ---------------------------------------------------------------------------
5865
/// Interceptor that adds response timing headers.
///
/// Adds a response-time header (`X-Response-Time` by default) whose value is
/// the elapsed time in whole milliseconds, e.g. `12ms`. It can optionally
/// also add a `Server-Timing` header (`<metric>;dur=<ms>`) for browser
/// DevTools integration.
///
/// # Example
///
/// ```ignore
/// let interceptor = TimingInterceptor::new();
/// // Or with a Server-Timing header using the metric name "app"
/// let interceptor = TimingInterceptor::new().with_server_timing("app");
/// ```
#[derive(Debug, Clone)]
pub struct TimingInterceptor {
    /// Header name for the response time (default: X-Response-Time).
    header_name: String,
    /// Whether to include Server-Timing header.
    include_server_timing: bool,
    /// The timing metric name for Server-Timing (default: "total").
    server_timing_name: String,
}

impl Default for TimingInterceptor {
    fn default() -> Self {
        Self::new()
    }
}

impl TimingInterceptor {
    /// Create a new timing interceptor with default settings.
    ///
    /// Uses the `X-Response-Time` header, with Server-Timing disabled.
    #[must_use]
    pub fn new() -> Self {
        Self {
            header_name: "X-Response-Time".to_string(),
            include_server_timing: false,
            server_timing_name: "total".to_string(),
        }
    }

    /// Enable Server-Timing header with the given metric name.
    #[must_use]
    pub fn with_server_timing(mut self, metric_name: impl Into<String>) -> Self {
        self.include_server_timing = true;
        self.server_timing_name = metric_name.into();
        self
    }

    /// Set a custom header name instead of X-Response-Time.
    #[must_use]
    pub fn header_name(mut self, name: impl Into<String>) -> Self {
        self.header_name = name.into();
        self
    }
}
5920
5921impl ResponseInterceptor for TimingInterceptor {
5922    fn intercept<'a>(
5923        &'a self,
5924        ctx: &'a ResponseInterceptorContext<'a>,
5925        response: Response,
5926    ) -> BoxFuture<'a, Response> {
5927        Box::pin(async move {
5928            let elapsed_ms = ctx.elapsed_ms();
5929            let timing_value = format!("{}ms", elapsed_ms);
5930
5931            let response = response.header(&self.header_name, timing_value.clone().into_bytes());
5932
5933            if self.include_server_timing {
5934                // Server-Timing format: name;dur=value;desc="description"
5935                let server_timing = format!("{};dur={}", self.server_timing_name, elapsed_ms);
5936                response.header("Server-Timing", server_timing.into_bytes())
5937            } else {
5938                response
5939            }
5940        })
5941    }
5942
5943    fn name(&self) -> &'static str {
5944        "TimingInterceptor"
5945    }
5946}
5947
5948// ---------------------------------------------------------------------------
5949// Debug Headers Interceptor
5950// ---------------------------------------------------------------------------
5951
/// Interceptor that stamps diagnostic headers onto every response.
///
/// Intended for development and staging deployments, where exposing request
/// details in response headers is acceptable.
///
/// # Headers Added
///
/// With the default `X-Debug-` prefix:
///
/// - `X-Debug-Path`: the request path
/// - `X-Debug-Method`: the HTTP method
/// - `X-Debug-Request-Id`: the request ID, when one is attached to the request
/// - `X-Debug-Handler-Time`: milliseconds elapsed since the interceptor context's start time
///
/// # Example
///
/// ```ignore
/// let interceptor = DebugInfoInterceptor::new()
///     .include_path(true)
///     .include_method(true);
/// ```
#[derive(Debug, Clone)]
// Four independent on/off switches are the natural shape for this builder.
#[allow(clippy::struct_excessive_bools)]
pub struct DebugInfoInterceptor {
    /// Emit the `<prefix>Path` header.
    include_path: bool,
    /// Emit the `<prefix>Method` header.
    include_method: bool,
    /// Emit the `<prefix>Request-Id` header.
    include_request_id: bool,
    /// Emit the `<prefix>Handler-Time` header.
    include_timing: bool,
    /// String prepended to every debug header name.
    header_prefix: String,
}

impl Default for DebugInfoInterceptor {
    fn default() -> Self {
        DebugInfoInterceptor::new()
    }
}

impl DebugInfoInterceptor {
    /// Build an interceptor with every header enabled and the `X-Debug-` prefix.
    #[must_use]
    pub fn new() -> Self {
        let header_prefix = String::from("X-Debug-");
        Self {
            include_path: true,
            include_method: true,
            include_request_id: true,
            include_timing: true,
            header_prefix,
        }
    }

    /// Toggle the path header.
    #[must_use]
    pub fn include_path(self, include: bool) -> Self {
        Self {
            include_path: include,
            ..self
        }
    }

    /// Toggle the HTTP method header.
    #[must_use]
    pub fn include_method(self, include: bool) -> Self {
        Self {
            include_method: include,
            ..self
        }
    }

    /// Toggle the request ID header.
    #[must_use]
    pub fn include_request_id(self, include: bool) -> Self {
        Self {
            include_request_id: include,
            ..self
        }
    }

    /// Toggle the elapsed-time header.
    #[must_use]
    pub fn include_timing(self, include: bool) -> Self {
        Self {
            include_timing: include,
            ..self
        }
    }

    /// Replace the header-name prefix.
    #[must_use]
    pub fn header_prefix(self, prefix: impl Into<String>) -> Self {
        Self {
            header_prefix: prefix.into(),
            ..self
        }
    }
}
6040
6041impl ResponseInterceptor for DebugInfoInterceptor {
6042    fn intercept<'a>(
6043        &'a self,
6044        ctx: &'a ResponseInterceptorContext<'a>,
6045        response: Response,
6046    ) -> BoxFuture<'a, Response> {
6047        Box::pin(async move {
6048            let mut resp = response;
6049
6050            if self.include_path {
6051                let header_name = format!("{}Path", self.header_prefix);
6052                resp = resp.header(header_name, ctx.request.path().as_bytes().to_vec());
6053            }
6054
6055            if self.include_method {
6056                let header_name = format!("{}Method", self.header_prefix);
6057                resp = resp.header(
6058                    header_name,
6059                    ctx.request.method().as_str().as_bytes().to_vec(),
6060                );
6061            }
6062
6063            if self.include_request_id {
6064                if let Some(request_id) = ctx.request.get_extension::<RequestId>() {
6065                    let header_name = format!("{}Request-Id", self.header_prefix);
6066                    resp = resp.header(header_name, request_id.0.as_bytes().to_vec());
6067                }
6068            }
6069
6070            if self.include_timing {
6071                let header_name = format!("{}Handler-Time", self.header_prefix);
6072                let timing = format!("{}ms", ctx.elapsed_ms());
6073                resp = resp.header(header_name, timing.into_bytes());
6074            }
6075
6076            resp
6077        })
6078    }
6079
6080    fn name(&self) -> &'static str {
6081        "DebugInfoInterceptor"
6082    }
6083}
6084
6085// ---------------------------------------------------------------------------
6086// Response Body Transform
6087// ---------------------------------------------------------------------------
6088
6089/// A response transformer that applies a function to the response body.
6090///
6091/// This is useful for content transformations like:
6092/// - Minification
6093/// - Pretty-printing
6094/// - Wrapping responses
6095/// - Filtering content
6096///
6097/// # Example
6098///
6099/// ```ignore
6100/// // Wrap JSON responses in an envelope
6101/// let transformer = ResponseBodyTransform::new(|body| {
6102///     format!(r#"{{"data": {}}}"#, String::from_utf8_lossy(&body)).into_bytes()
6103/// });
6104/// ```
6105pub struct ResponseBodyTransform<F>
6106where
6107    F: Fn(Vec<u8>) -> Vec<u8> + Send + Sync,
6108{
6109    transform_fn: F,
6110    /// Optional content type filter - only transform if content type matches.
6111    content_type_filter: Option<String>,
6112}
6113
6114impl<F> ResponseBodyTransform<F>
6115where
6116    F: Fn(Vec<u8>) -> Vec<u8> + Send + Sync,
6117{
6118    /// Create a new body transformer with the given function.
6119    pub fn new(transform_fn: F) -> Self {
6120        Self {
6121            transform_fn,
6122            content_type_filter: None,
6123        }
6124    }
6125
6126    /// Only apply transformation if the response content type starts with this value.
6127    #[must_use]
6128    pub fn for_content_type(mut self, content_type: impl Into<String>) -> Self {
6129        self.content_type_filter = Some(content_type.into());
6130        self
6131    }
6132
6133    fn should_transform(&self, response: &Response) -> bool {
6134        match &self.content_type_filter {
6135            Some(filter) => response
6136                .headers()
6137                .iter()
6138                .find(|(name, _)| name.eq_ignore_ascii_case("content-type"))
6139                .and_then(|(_, ct)| std::str::from_utf8(ct).ok())
6140                .map(|ct| ct.starts_with(filter))
6141                .unwrap_or(false),
6142            None => true,
6143        }
6144    }
6145}
6146
6147impl<F> ResponseInterceptor for ResponseBodyTransform<F>
6148where
6149    F: Fn(Vec<u8>) -> Vec<u8> + Send + Sync,
6150{
6151    fn intercept<'a>(
6152        &'a self,
6153        _ctx: &'a ResponseInterceptorContext<'a>,
6154        response: Response,
6155    ) -> BoxFuture<'a, Response> {
6156        Box::pin(async move {
6157            if !self.should_transform(&response) {
6158                return response;
6159            }
6160
6161            // Extract the body bytes
6162            let body_bytes = match response.body_ref() {
6163                crate::response::ResponseBody::Empty => Vec::new(),
6164                crate::response::ResponseBody::Bytes(b) => b.clone(),
6165                crate::response::ResponseBody::Stream(_) => {
6166                    // Cannot transform streaming responses
6167                    return response;
6168                }
6169            };
6170
6171            // Apply transformation
6172            let transformed = (self.transform_fn)(body_bytes);
6173
6174            // Rebuild response with new body
6175            response.body(crate::response::ResponseBody::Bytes(transformed))
6176        })
6177    }
6178
6179    fn name(&self) -> &'static str {
6180        "ResponseBodyTransform"
6181    }
6182}
6183
6184// ---------------------------------------------------------------------------
6185// Header Transform Interceptor
6186// ---------------------------------------------------------------------------
6187
/// An interceptor that rewrites response headers.
///
/// Supports three operations, which the interceptor applies in this order:
/// renames, then additions, then removals.
///
/// # Example
///
/// ```ignore
/// let interceptor = HeaderTransformInterceptor::new()
///     .add("X-Powered-By", "fastapi_rust")
///     .remove("Server")
///     .rename("X-Request-Id", "X-Trace-Id");
/// ```
#[derive(Debug, Clone, Default)]
pub struct HeaderTransformInterceptor {
    /// `(name, value)` pairs appended to every response.
    add_headers: Vec<(String, Vec<u8>)>,
    /// Header names stripped from every response.
    remove_headers: Vec<String>,
    /// `(old_name, new_name)` pairs; each matching header moves to the new name.
    rename_headers: Vec<(String, String)>,
}

impl HeaderTransformInterceptor {
    /// Start with no additions, removals, or renames.
    #[must_use]
    pub fn new() -> Self {
        Self {
            add_headers: Vec::new(),
            remove_headers: Vec::new(),
            rename_headers: Vec::new(),
        }
    }

    /// Queue a header to append to the response.
    #[must_use]
    pub fn add(mut self, name: impl Into<String>, value: impl Into<Vec<u8>>) -> Self {
        let entry: (String, Vec<u8>) = (name.into(), value.into());
        self.add_headers.push(entry);
        self
    }

    /// Queue a header name to strip from the response.
    #[must_use]
    pub fn remove(mut self, name: impl Into<String>) -> Self {
        let header: String = name.into();
        self.remove_headers.push(header);
        self
    }

    /// Queue a rename; it has no effect when the old header is absent.
    #[must_use]
    pub fn rename(mut self, old_name: impl Into<String>, new_name: impl Into<String>) -> Self {
        let pair: (String, String) = (old_name.into(), new_name.into());
        self.rename_headers.push(pair);
        self
    }
}
6238
6239impl ResponseInterceptor for HeaderTransformInterceptor {
6240    fn intercept<'a>(
6241        &'a self,
6242        _ctx: &'a ResponseInterceptorContext<'a>,
6243        response: Response,
6244    ) -> BoxFuture<'a, Response> {
6245        let add_headers = self.add_headers.clone();
6246        let remove_headers = self.remove_headers.clone();
6247        let rename_headers = self.rename_headers.clone();
6248
6249        Box::pin(async move {
6250            let mut resp = response;
6251
6252            // Handle renames first - get values of headers to rename
6253            for (old_name, new_name) in &rename_headers {
6254                let values: Vec<Vec<u8>> = resp
6255                    .headers()
6256                    .iter()
6257                    .filter(|(name, _)| name.eq_ignore_ascii_case(old_name))
6258                    .map(|(_, v)| v.clone())
6259                    .collect();
6260
6261                if !values.is_empty() {
6262                    resp = resp.remove_header(old_name);
6263                    for v in values {
6264                        resp = resp.header(new_name, v);
6265                    }
6266                }
6267            }
6268
6269            // Add new headers
6270            for (name, value) in add_headers {
6271                resp = resp.header(name, value);
6272            }
6273
6274            // Remove headers (case-insensitive) after renames/additions.
6275            for name in &remove_headers {
6276                resp = resp.remove_header(name);
6277            }
6278
6279            resp
6280        })
6281    }
6282
6283    fn name(&self) -> &'static str {
6284        "HeaderTransformInterceptor"
6285    }
6286}
6287
6288// ---------------------------------------------------------------------------
6289// Conditional Interceptor Wrapper
6290// ---------------------------------------------------------------------------
6291
6292/// Wrapper that applies an interceptor only when a condition is met.
6293///
6294/// # Example
6295///
6296/// ```ignore
6297/// // Only add debug headers for non-production requests
6298/// let interceptor = ConditionalInterceptor::new(
6299///     DebugInfoInterceptor::new(),
6300///     |ctx, resp| ctx.request.headers().get("X-Debug").is_some()
6301/// );
6302/// ```
6303pub struct ConditionalInterceptor<I, F>
6304where
6305    I: ResponseInterceptor,
6306    F: Fn(&ResponseInterceptorContext, &Response) -> bool + Send + Sync,
6307{
6308    inner: I,
6309    condition: F,
6310}
6311
6312impl<I, F> ConditionalInterceptor<I, F>
6313where
6314    I: ResponseInterceptor,
6315    F: Fn(&ResponseInterceptorContext, &Response) -> bool + Send + Sync,
6316{
6317    /// Create a new conditional interceptor.
6318    pub fn new(inner: I, condition: F) -> Self {
6319        Self { inner, condition }
6320    }
6321}
6322
6323impl<I, F> ResponseInterceptor for ConditionalInterceptor<I, F>
6324where
6325    I: ResponseInterceptor,
6326    F: Fn(&ResponseInterceptorContext, &Response) -> bool + Send + Sync,
6327{
6328    fn intercept<'a>(
6329        &'a self,
6330        ctx: &'a ResponseInterceptorContext<'a>,
6331        response: Response,
6332    ) -> BoxFuture<'a, Response> {
6333        Box::pin(async move {
6334            if (self.condition)(ctx, &response) {
6335                self.inner.intercept(ctx, response).await
6336            } else {
6337                response
6338            }
6339        })
6340    }
6341
6342    fn name(&self) -> &'static str {
6343        "ConditionalInterceptor"
6344    }
6345}
6346
6347// ---------------------------------------------------------------------------
6348// Error Response Transformer
6349// ---------------------------------------------------------------------------
6350
6351/// Interceptor that transforms error responses.
6352///
6353/// Useful for:
6354/// - Hiding internal error details in production
6355/// - Adding consistent error formatting
6356/// - Logging error responses
6357///
6358/// # Example
6359///
6360/// ```ignore
6361/// let interceptor = ErrorResponseTransformer::new()
6362///     .hide_details_for_status(StatusCode::INTERNAL_SERVER_ERROR)
6363///     .with_replacement_body(b"An internal error occurred".to_vec());
6364/// ```
6365#[derive(Debug, Clone)]
6366pub struct ErrorResponseTransformer {
6367    /// Status codes to transform.
6368    status_codes: HashSet<u16>,
6369    /// Replacement body for error responses.
6370    replacement_body: Option<Vec<u8>>,
6371    /// Whether to add an error ID header.
6372    add_error_id: bool,
6373}
6374
6375impl Default for ErrorResponseTransformer {
6376    fn default() -> Self {
6377        Self::new()
6378    }
6379}
6380
6381impl ErrorResponseTransformer {
6382    /// Create a new error response transformer.
6383    #[must_use]
6384    pub fn new() -> Self {
6385        Self {
6386            status_codes: HashSet::new(),
6387            replacement_body: None,
6388            add_error_id: false,
6389        }
6390    }
6391
6392    /// Hide details for the given status code.
6393    #[must_use]
6394    pub fn hide_details_for_status(mut self, status: crate::response::StatusCode) -> Self {
6395        self.status_codes.insert(status.as_u16());
6396        self
6397    }
6398
6399    /// Set the replacement body for error responses.
6400    #[must_use]
6401    pub fn with_replacement_body(mut self, body: impl Into<Vec<u8>>) -> Self {
6402        self.replacement_body = Some(body.into());
6403        self
6404    }
6405
6406    /// Enable adding an error ID header for tracking.
6407    #[must_use]
6408    pub fn add_error_id(mut self, enable: bool) -> Self {
6409        self.add_error_id = enable;
6410        self
6411    }
6412}
6413
6414impl ResponseInterceptor for ErrorResponseTransformer {
6415    fn intercept<'a>(
6416        &'a self,
6417        ctx: &'a ResponseInterceptorContext<'a>,
6418        response: Response,
6419    ) -> BoxFuture<'a, Response> {
6420        Box::pin(async move {
6421            let status_code = response.status().as_u16();
6422
6423            if !self.status_codes.contains(&status_code) {
6424                return response;
6425            }
6426
6427            let mut resp = response;
6428
6429            // Replace body if configured
6430            if let Some(ref replacement) = self.replacement_body {
6431                resp = resp.body(crate::response::ResponseBody::Bytes(replacement.clone()));
6432            }
6433
6434            // Add error ID header if enabled
6435            if self.add_error_id {
6436                // Use request ID if available, otherwise generate a simple one
6437                let error_id = ctx
6438                    .request
6439                    .get_extension::<RequestId>()
6440                    .map(|r| r.0.clone())
6441                    .unwrap_or_else(|| format!("err-{}", ctx.elapsed_ms()));
6442                resp = resp.header("X-Error-Id", error_id.into_bytes());
6443            }
6444
6445            resp
6446        })
6447    }
6448
6449    fn name(&self) -> &'static str {
6450        "ErrorResponseTransformer"
6451    }
6452}
6453
6454// ---------------------------------------------------------------------------
6455// Middleware adapter for ResponseInterceptor
6456// ---------------------------------------------------------------------------
6457
6458/// Adapter that wraps a `ResponseInterceptor` as a `Middleware`.
6459///
6460/// This allows using response interceptors in the existing middleware stack.
6461///
6462/// # Example
6463///
6464/// ```ignore
6465/// let timing = TimingInterceptor::new();
6466/// let middleware = ResponseInterceptorMiddleware::new(timing);
6467/// stack.push(middleware);
6468/// ```
6469pub struct ResponseInterceptorMiddleware<I>
6470where
6471    I: ResponseInterceptor,
6472{
6473    interceptor: I,
6474}
6475
6476impl<I> ResponseInterceptorMiddleware<I>
6477where
6478    I: ResponseInterceptor,
6479{
6480    /// Wrap a response interceptor as middleware.
6481    pub fn new(interceptor: I) -> Self {
6482        Self { interceptor }
6483    }
6484}
6485
6486impl<I> Middleware for ResponseInterceptorMiddleware<I>
6487where
6488    I: ResponseInterceptor,
6489{
6490    fn before<'a>(
6491        &'a self,
6492        _ctx: &'a RequestContext,
6493        req: &'a mut Request,
6494    ) -> BoxFuture<'a, ControlFlow> {
6495        // Store the start time in request extensions
6496        req.insert_extension(InterceptorStartTime(Instant::now()));
6497        Box::pin(async { ControlFlow::Continue })
6498    }
6499
6500    fn after<'a>(
6501        &'a self,
6502        ctx: &'a RequestContext,
6503        req: &'a Request,
6504        response: Response,
6505    ) -> BoxFuture<'a, Response> {
6506        Box::pin(async move {
6507            // Retrieve start time from extensions
6508            let start_time = req
6509                .get_extension::<InterceptorStartTime>()
6510                .map(|t| t.0)
6511                .unwrap_or_else(Instant::now);
6512
6513            let interceptor_ctx = ResponseInterceptorContext::new(req, ctx, start_time);
6514            self.interceptor.intercept(&interceptor_ctx, response).await
6515        })
6516    }
6517
6518    fn name(&self) -> &'static str {
6519        self.interceptor.name()
6520    }
6521}
6522
6523/// Internal type for storing interceptor start time in request extensions.
6524#[derive(Debug, Clone, Copy)]
6525struct InterceptorStartTime(Instant);
6526
6527// ===========================================================================
6528// End Response Interceptors and Transformers
6529// ===========================================================================
6530
6531// ===========================================================================
6532// Response Timing Metrics Collection
6533// ===========================================================================
6534//
6535// This section provides comprehensive timing metrics for monitoring:
6536// - Request duration
6537// - Time-to-first-byte (TTFB)
6538// - Server-Timing header with multiple metrics
6539// - Histogram collection for aggregation
6540// - Integration with logging
6541
/// A single entry in the `Server-Timing` header.
///
/// Each entry has a name, a duration in milliseconds, and an optional
/// description.
///
/// # Server-Timing Format
///
/// ```text
/// Server-Timing: name;dur=value;desc="description"
/// ```
///
/// The duration is always written with three decimal places. The description
/// is written as an HTTP quoted-string:
/// - `"` and `\` are backslash-escaped.
/// - Control characters other than horizontal tab are dropped, because a
///   quoted-string may not contain them.
///
/// As a result, a description can never terminate the string early or split
/// the header. The name is written verbatim; callers should use an HTTP token
/// (no spaces, `;`, `,` or `"`).
///
/// # Example
///
/// ```ignore
/// let entry = ServerTimingEntry::new("db", 42.5)
///     .with_description("Database query");
/// assert_eq!(entry.to_header_value(), r#"db;dur=42.500;desc="Database query""#);
/// ```
#[derive(Debug, Clone)]
pub struct ServerTimingEntry {
    /// The metric name (e.g., "db", "cache", "render").
    name: String,
    /// Duration in milliseconds (supports sub-millisecond precision).
    duration_ms: f64,
    /// Optional description for the metric.
    description: Option<String>,
}

impl ServerTimingEntry {
    /// Create a new Server-Timing entry.
    #[must_use]
    pub fn new(name: impl Into<String>, duration_ms: f64) -> Self {
        Self {
            name: name.into(),
            duration_ms,
            description: None,
        }
    }

    /// Add a description to the entry.
    #[must_use]
    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
        self.description = Some(desc.into());
        self
    }

    /// Format this entry for the Server-Timing header.
    ///
    /// Produces `name;dur=X.XXX` or `name;dur=X.XXX;desc="..."`, with the
    /// description escaped as described on the type.
    #[must_use]
    pub fn to_header_value(&self) -> String {
        match &self.description {
            Some(desc) => format!(
                "{};dur={:.3};desc=\"{}\"",
                self.name,
                self.duration_ms,
                Self::quote_escape(desc)
            ),
            None => format!("{};dur={:.3}", self.name, self.duration_ms),
        }
    }

    /// Escape `raw` for embedding inside an HTTP quoted-string (RFC 9110 §5.6.4).
    fn quote_escape(raw: &str) -> String {
        let mut out = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '"' | '\\' => {
                    out.push('\\');
                    out.push(ch);
                }
                // Tab is the only control character a quoted-string permits.
                '\t' => out.push(ch),
                // CR/LF and other controls are not allowed in a quoted-string;
                // dropping them also rules out header splitting.
                c if c.is_control() => {}
                c => out.push(c),
            }
        }
        out
    }
}

/// Builder for a `Server-Timing` header value made of several metrics.
///
/// Entries are rendered in insertion order, separated by `", "`.
///
/// # Example
///
/// ```ignore
/// let timing = ServerTimingBuilder::new()
///     .add("total", 150.5)
///     .add_with_desc("db", 42.0, "Database queries")
///     .add_with_desc("cache", 5.0, "Cache lookup")
///     .build();
///
/// // Header value:
/// // total;dur=150.500, db;dur=42.000;desc="Database queries", cache;dur=5.000;desc="Cache lookup"
/// ```
#[derive(Debug, Clone, Default)]
pub struct ServerTimingBuilder {
    /// Entries in the order they were added.
    entries: Vec<ServerTimingEntry>,
}

impl ServerTimingBuilder {
    /// Create a new empty builder.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Add a timing entry with just a name and duration (milliseconds).
    #[must_use]
    pub fn add(self, name: impl Into<String>, duration_ms: f64) -> Self {
        self.add_entry(ServerTimingEntry::new(name, duration_ms))
    }

    /// Add a timing entry with a description.
    #[must_use]
    pub fn add_with_desc(
        self,
        name: impl Into<String>,
        duration_ms: f64,
        description: impl Into<String>,
    ) -> Self {
        self.add_entry(ServerTimingEntry::new(name, duration_ms).with_description(description))
    }

    /// Add a pre-built entry.
    #[must_use]
    pub fn add_entry(mut self, entry: ServerTimingEntry) -> Self {
        self.entries.push(entry);
        self
    }

    /// Build the Server-Timing header value (empty string when no entries).
    #[must_use]
    pub fn build(&self) -> String {
        self.entries
            .iter()
            .map(ServerTimingEntry::to_header_value)
            .collect::<Vec<_>>()
            .join(", ")
    }

    /// Return true if no entries have been added.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Return the number of entries.
    #[must_use]
    pub fn len(&self) -> usize {
        self.entries.len()
    }
}
6670    #[must_use]
6671    pub fn len(&self) -> usize {
6672        self.entries.len()
6673    }
6674}
6675
6676/// Collected timing metrics for a single request.
6677///
6678/// This struct is stored in request extensions and can be read by
6679/// interceptors or logging middleware to expose timing data.
6680///
6681/// # Usage
6682///
6683/// Handlers can access and modify timing metrics via request extensions:
6684///
6685/// ```ignore
6686/// // Add a custom timing metric
6687/// if let Some(metrics) = req.get_extension_mut::<TimingMetrics>() {
6688///     metrics.add_metric("db", db_time.as_secs_f64() * 1000.0);
6689/// }
6690/// ```
6691#[derive(Debug, Clone)]
6692pub struct TimingMetrics {
6693    /// When the request processing started.
6694    pub start_time: Instant,
6695    /// When the first byte of the response was sent (if known).
6696    pub first_byte_time: Option<Instant>,
6697    /// Custom metrics added by handlers (name -> duration_ms).
6698    pub custom_metrics: Vec<(String, f64, Option<String>)>,
6699}
6700
6701impl TimingMetrics {
6702    /// Create new timing metrics starting now.
6703    #[must_use]
6704    pub fn new() -> Self {
6705        Self {
6706            start_time: Instant::now(),
6707            first_byte_time: None,
6708            custom_metrics: Vec::new(),
6709        }
6710    }
6711
6712    /// Create timing metrics with a specific start time.
6713    #[must_use]
6714    pub fn with_start_time(start_time: Instant) -> Self {
6715        Self {
6716            start_time,
6717            first_byte_time: None,
6718            custom_metrics: Vec::new(),
6719        }
6720    }
6721
6722    /// Mark the time when the first byte of the response was sent.
6723    pub fn mark_first_byte(&mut self) {
6724        self.first_byte_time = Some(Instant::now());
6725    }
6726
6727    /// Add a custom metric (e.g., database query time).
6728    pub fn add_metric(&mut self, name: impl Into<String>, duration_ms: f64) {
6729        self.custom_metrics.push((name.into(), duration_ms, None));
6730    }
6731
6732    /// Add a custom metric with a description.
6733    pub fn add_metric_with_desc(
6734        &mut self,
6735        name: impl Into<String>,
6736        duration_ms: f64,
6737        desc: impl Into<String>,
6738    ) {
6739        self.custom_metrics
6740            .push((name.into(), duration_ms, Some(desc.into())));
6741    }
6742
6743    /// Get the total elapsed time in milliseconds.
6744    #[must_use]
6745    pub fn total_ms(&self) -> f64 {
6746        self.start_time.elapsed().as_secs_f64() * 1000.0
6747    }
6748
6749    /// Get the time-to-first-byte in milliseconds (if available).
6750    #[must_use]
6751    pub fn ttfb_ms(&self) -> Option<f64> {
6752        self.first_byte_time
6753            .map(|t| t.duration_since(self.start_time).as_secs_f64() * 1000.0)
6754    }
6755
6756    /// Build a Server-Timing header from the collected metrics.
6757    #[must_use]
6758    pub fn to_server_timing(&self) -> ServerTimingBuilder {
6759        let mut builder = ServerTimingBuilder::new().add_with_desc(
6760            "total",
6761            self.total_ms(),
6762            "Total request time",
6763        );
6764
6765        if let Some(ttfb) = self.ttfb_ms() {
6766            builder = builder.add_with_desc("ttfb", ttfb, "Time to first byte");
6767        }
6768
6769        for (name, duration, desc) in &self.custom_metrics {
6770            match desc {
6771                Some(d) => builder = builder.add_with_desc(name, *duration, d),
6772                None => builder = builder.add(name, *duration),
6773            }
6774        }
6775
6776        builder
6777    }
6778}
6779
6780impl Default for TimingMetrics {
6781    fn default() -> Self {
6782        Self::new()
6783    }
6784}
6785
6786/// Configuration for the timing metrics middleware.
6787#[derive(Debug, Clone)]
6788#[allow(clippy::struct_excessive_bools)]
6789pub struct TimingMetricsConfig {
6790    /// Whether to add the Server-Timing header.
6791    pub add_server_timing_header: bool,
6792    /// Whether to add the X-Response-Time header.
6793    pub add_response_time_header: bool,
6794    /// Custom header name for response time (default: "X-Response-Time").
6795    pub response_time_header_name: String,
6796    /// Whether to include custom metrics from handlers.
6797    pub include_custom_metrics: bool,
6798    /// Whether to include TTFB in the Server-Timing header.
6799    pub include_ttfb: bool,
6800}
6801
6802impl Default for TimingMetricsConfig {
6803    fn default() -> Self {
6804        Self {
6805            add_server_timing_header: true,
6806            add_response_time_header: true,
6807            response_time_header_name: "X-Response-Time".to_string(),
6808            include_custom_metrics: true,
6809            include_ttfb: true,
6810        }
6811    }
6812}
6813
6814impl TimingMetricsConfig {
6815    /// Create a new config with default settings.
6816    #[must_use]
6817    pub fn new() -> Self {
6818        Self::default()
6819    }
6820
6821    /// Enable or disable Server-Timing header.
6822    #[must_use]
6823    pub fn server_timing(mut self, enabled: bool) -> Self {
6824        self.add_server_timing_header = enabled;
6825        self
6826    }
6827
6828    /// Enable or disable X-Response-Time header.
6829    #[must_use]
6830    pub fn response_time(mut self, enabled: bool) -> Self {
6831        self.add_response_time_header = enabled;
6832        self
6833    }
6834
6835    /// Set a custom response time header name.
6836    #[must_use]
6837    pub fn response_time_header(mut self, name: impl Into<String>) -> Self {
6838        self.response_time_header_name = name.into();
6839        self
6840    }
6841
6842    /// Enable or disable custom metrics.
6843    #[must_use]
6844    pub fn custom_metrics(mut self, enabled: bool) -> Self {
6845        self.include_custom_metrics = enabled;
6846        self
6847    }
6848
6849    /// Enable or disable TTFB tracking.
6850    #[must_use]
6851    pub fn ttfb(mut self, enabled: bool) -> Self {
6852        self.include_ttfb = enabled;
6853        self
6854    }
6855
6856    /// Create a production-safe config (minimal headers).
6857    #[must_use]
6858    pub fn production() -> Self {
6859        Self {
6860            add_server_timing_header: false,
6861            add_response_time_header: true,
6862            response_time_header_name: "X-Response-Time".to_string(),
6863            include_custom_metrics: false,
6864            include_ttfb: false,
6865        }
6866    }
6867
6868    /// Create a development config (all timing info exposed).
6869    #[must_use]
6870    pub fn development() -> Self {
6871        Self::default()
6872    }
6873}
6874
6875/// Middleware that collects and exposes timing metrics.
6876///
6877/// This middleware:
6878/// 1. Records the request start time
6879/// 2. Injects `TimingMetrics` into request extensions for handlers to use
6880/// 3. Adds timing headers to the response
6881///
6882/// # Example
6883///
6884/// ```ignore
6885/// let timing = TimingMetricsMiddleware::new();
6886/// // Or with custom config:
6887/// let timing = TimingMetricsMiddleware::with_config(
6888///     TimingMetricsConfig::production()
6889/// );
6890///
6891/// middleware_stack.push(timing);
6892/// ```
6893#[derive(Debug, Clone)]
6894pub struct TimingMetricsMiddleware {
6895    config: TimingMetricsConfig,
6896}
6897
6898impl TimingMetricsMiddleware {
6899    /// Create a new timing metrics middleware with default config.
6900    #[must_use]
6901    pub fn new() -> Self {
6902        Self {
6903            config: TimingMetricsConfig::default(),
6904        }
6905    }
6906
6907    /// Create with a custom configuration.
6908    #[must_use]
6909    pub fn with_config(config: TimingMetricsConfig) -> Self {
6910        Self { config }
6911    }
6912
6913    /// Create a production-safe instance (minimal headers).
6914    #[must_use]
6915    pub fn production() -> Self {
6916        Self {
6917            config: TimingMetricsConfig::production(),
6918        }
6919    }
6920
6921    /// Create a development instance (all timing info exposed).
6922    #[must_use]
6923    pub fn development() -> Self {
6924        Self {
6925            config: TimingMetricsConfig::development(),
6926        }
6927    }
6928}
6929
6930impl Default for TimingMetricsMiddleware {
6931    fn default() -> Self {
6932        Self::new()
6933    }
6934}
6935
6936impl Middleware for TimingMetricsMiddleware {
6937    fn before<'a>(
6938        &'a self,
6939        _ctx: &'a RequestContext,
6940        req: &'a mut Request,
6941    ) -> BoxFuture<'a, ControlFlow> {
6942        // Store timing metrics in request extensions
6943        req.insert_extension(TimingMetrics::new());
6944        Box::pin(async { ControlFlow::Continue })
6945    }
6946
6947    fn after<'a>(
6948        &'a self,
6949        _ctx: &'a RequestContext,
6950        req: &'a Request,
6951        response: Response,
6952    ) -> BoxFuture<'a, Response> {
6953        let config = self.config.clone();
6954
6955        Box::pin(async move {
6956            let mut resp = response;
6957
6958            // Get timing metrics from extensions
6959            let metrics = req.get_extension::<TimingMetrics>();
6960
6961            match metrics {
6962                Some(metrics) => {
6963                    // Add X-Response-Time header
6964                    if config.add_response_time_header {
6965                        let timing = format!("{:.3}ms", metrics.total_ms());
6966                        resp = resp.header(&config.response_time_header_name, timing.into_bytes());
6967                    }
6968
6969                    // Add Server-Timing header
6970                    if config.add_server_timing_header {
6971                        let mut builder = ServerTimingBuilder::new().add_with_desc(
6972                            "total",
6973                            metrics.total_ms(),
6974                            "Total request time",
6975                        );
6976
6977                        // Add TTFB if available and enabled
6978                        if config.include_ttfb {
6979                            if let Some(ttfb) = metrics.ttfb_ms() {
6980                                builder = builder.add_with_desc("ttfb", ttfb, "Time to first byte");
6981                            }
6982                        }
6983
6984                        // Add custom metrics if enabled
6985                        if config.include_custom_metrics {
6986                            for (name, duration, desc) in &metrics.custom_metrics {
6987                                match desc {
6988                                    Some(d) => builder = builder.add_with_desc(name, *duration, d),
6989                                    None => builder = builder.add(name, *duration),
6990                                }
6991                            }
6992                        }
6993
6994                        let header_value = builder.build();
6995                        resp = resp.header("Server-Timing", header_value.into_bytes());
6996                    }
6997                }
6998                None => {
6999                    // No timing metrics in extensions - add basic timing
7000                    // This shouldn't happen if middleware is properly registered
7001                    if config.add_response_time_header {
7002                        resp = resp.header(&config.response_time_header_name, b"0.000ms".to_vec());
7003                    }
7004                }
7005            }
7006
7007            resp
7008        })
7009    }
7010
7011    fn name(&self) -> &'static str {
7012        "TimingMetrics"
7013    }
7014}
7015
7016/// Simple histogram bucket for collecting timing distributions.
7017///
7018/// Useful for aggregating timing data across many requests.
7019#[derive(Debug, Clone)]
7020pub struct TimingHistogramBucket {
7021    /// Upper bound for this bucket (milliseconds).
7022    pub le: f64,
7023    /// Count of observations in this bucket.
7024    pub count: u64,
7025}
7026
7027/// A histogram for collecting timing distributions.
7028///
7029/// This provides Prometheus-style histogram buckets for aggregating
7030/// timing data across many requests.
7031///
7032/// # Example
7033///
7034/// ```ignore
7035/// let mut histogram = TimingHistogram::with_buckets(vec![
7036///     1.0, 5.0, 10.0, 25.0, 50.0, 100.0, 250.0, 500.0, 1000.0
7037/// ]);
7038///
7039/// histogram.observe(42.5);  // 42.5ms response time
7040/// histogram.observe(150.0);
7041///
7042/// let buckets = histogram.buckets();
7043/// let avg = histogram.mean();
7044/// ```
7045#[derive(Debug, Clone)]
7046pub struct TimingHistogram {
7047    /// Bucket upper bounds in milliseconds.
7048    bucket_bounds: Vec<f64>,
7049    /// Count per bucket.
7050    bucket_counts: Vec<u64>,
7051    /// Sum of all observed values.
7052    sum: f64,
7053    /// Total count of observations.
7054    count: u64,
7055}
7056
7057impl TimingHistogram {
7058    /// Create a histogram with the given bucket upper bounds.
7059    ///
7060    /// Bounds should be sorted in ascending order.
7061    #[must_use]
7062    pub fn with_buckets(bucket_bounds: Vec<f64>) -> Self {
7063        let bucket_counts = vec![0; bucket_bounds.len()];
7064        Self {
7065            bucket_bounds,
7066            bucket_counts,
7067            sum: 0.0,
7068            count: 0,
7069        }
7070    }
7071
7072    /// Create a histogram with default HTTP latency buckets.
7073    ///
7074    /// Buckets: 1ms, 5ms, 10ms, 25ms, 50ms, 100ms, 250ms, 500ms, 1s, 2.5s, 5s, 10s
7075    #[must_use]
7076    pub fn http_latency() -> Self {
7077        Self::with_buckets(vec![
7078            1.0, 5.0, 10.0, 25.0, 50.0, 100.0, 250.0, 500.0, 1000.0, 2500.0, 5000.0, 10000.0,
7079        ])
7080    }
7081
7082    /// Record an observation.
7083    pub fn observe(&mut self, value_ms: f64) {
7084        self.sum += value_ms;
7085        self.count += 1;
7086
7087        // Increment bucket counts (cumulative)
7088        for (i, bound) in self.bucket_bounds.iter().enumerate() {
7089            if value_ms <= *bound {
7090                self.bucket_counts[i] += 1;
7091            }
7092        }
7093    }
7094
7095    /// Get the total count of observations.
7096    #[must_use]
7097    pub fn count(&self) -> u64 {
7098        self.count
7099    }
7100
7101    /// Get the sum of all observed values.
7102    #[must_use]
7103    pub fn sum(&self) -> f64 {
7104        self.sum
7105    }
7106
7107    /// Get the mean value.
7108    #[must_use]
7109    pub fn mean(&self) -> f64 {
7110        if self.count == 0 {
7111            0.0
7112        } else {
7113            #[allow(clippy::cast_precision_loss)]
7114            {
7115                self.sum / self.count as f64
7116            }
7117        }
7118    }
7119
7120    /// Get the bucket data.
7121    #[must_use]
7122    pub fn buckets(&self) -> Vec<TimingHistogramBucket> {
7123        self.bucket_bounds
7124            .iter()
7125            .zip(&self.bucket_counts)
7126            .map(|(&le, &count)| TimingHistogramBucket { le, count })
7127            .collect()
7128    }
7129
7130    /// Reset the histogram.
7131    pub fn reset(&mut self) {
7132        self.sum = 0.0;
7133        self.count = 0;
7134        for count in &mut self.bucket_counts {
7135            *count = 0;
7136        }
7137    }
7138}
7139
7140impl Default for TimingHistogram {
7141    fn default() -> Self {
7142        Self::http_latency()
7143    }
7144}
7145
7146// ===========================================================================
7147// End Response Timing Metrics Collection
7148// ===========================================================================
7149
7150#[cfg(test)]
7151mod timing_metrics_tests {
7152    use super::*;
7153    use crate::request::Method;
7154    use crate::response::StatusCode;
7155
7156    fn test_context() -> RequestContext {
7157        RequestContext::new(asupersync::Cx::for_testing(), 1)
7158    }
7159
7160    fn test_request() -> Request {
7161        Request::new(Method::Get, "/test")
7162    }
7163
7164    fn run_middleware_before(mw: &impl Middleware, req: &mut Request) -> ControlFlow {
7165        let ctx = test_context();
7166        futures_executor::block_on(mw.before(&ctx, req))
7167    }
7168
7169    fn run_middleware_after(mw: &impl Middleware, req: &Request, resp: Response) -> Response {
7170        let ctx = test_context();
7171        futures_executor::block_on(mw.after(&ctx, req, resp))
7172    }
7173
7174    #[test]
7175    fn server_timing_entry_basic() {
7176        let entry = ServerTimingEntry::new("db", 42.5);
7177        assert_eq!(entry.to_header_value(), "db;dur=42.500");
7178    }
7179
7180    #[test]
7181    fn server_timing_entry_with_description() {
7182        let entry = ServerTimingEntry::new("db", 42.5).with_description("Database query");
7183        assert_eq!(
7184            entry.to_header_value(),
7185            "db;dur=42.500;desc=\"Database query\""
7186        );
7187    }
7188
7189    #[test]
7190    fn server_timing_builder_single_entry() {
7191        let timing = ServerTimingBuilder::new().add("total", 150.0).build();
7192        assert_eq!(timing, "total;dur=150.000");
7193    }
7194
7195    #[test]
7196    fn server_timing_builder_multiple_entries() {
7197        let timing = ServerTimingBuilder::new()
7198            .add("total", 150.0)
7199            .add_with_desc("db", 42.0, "Database")
7200            .add("cache", 5.0)
7201            .build();
7202
7203        assert!(timing.contains("total;dur=150.000"));
7204        assert!(timing.contains("db;dur=42.000;desc=\"Database\""));
7205        assert!(timing.contains("cache;dur=5.000"));
7206        assert!(timing.contains(", ")); // Multiple entries separated by comma
7207    }
7208
7209    #[test]
7210    fn server_timing_builder_empty() {
7211        let builder = ServerTimingBuilder::new();
7212        assert!(builder.is_empty());
7213        assert_eq!(builder.len(), 0);
7214        assert_eq!(builder.build(), "");
7215    }
7216
7217    #[test]
7218    fn timing_metrics_basic() {
7219        let metrics = TimingMetrics::new();
7220        std::thread::sleep(std::time::Duration::from_millis(5));
7221
7222        let total = metrics.total_ms();
7223        assert!(total >= 5.0, "Total should be at least 5ms");
7224        assert!(metrics.ttfb_ms().is_none(), "TTFB should not be set");
7225    }
7226
7227    #[test]
7228    fn timing_metrics_custom_metrics() {
7229        let mut metrics = TimingMetrics::new();
7230        metrics.add_metric("db", 42.5);
7231        metrics.add_metric_with_desc("cache", 5.0, "Cache lookup");
7232
7233        let timing = metrics.to_server_timing();
7234        assert_eq!(timing.len(), 3); // total + 2 custom
7235
7236        let header = timing.build();
7237        assert!(header.contains("total"));
7238        assert!(header.contains("db;dur=42.500"));
7239        assert!(header.contains("cache;dur=5.000;desc=\"Cache lookup\""));
7240    }
7241
7242    #[test]
7243    fn timing_metrics_ttfb() {
7244        let mut metrics = TimingMetrics::new();
7245        std::thread::sleep(std::time::Duration::from_millis(5));
7246        metrics.mark_first_byte();
7247
7248        let ttfb = metrics.ttfb_ms().unwrap();
7249        assert!(ttfb >= 5.0, "TTFB should be at least 5ms");
7250    }
7251
7252    #[test]
7253    fn timing_metrics_config_default() {
7254        let config = TimingMetricsConfig::default();
7255        assert!(config.add_server_timing_header);
7256        assert!(config.add_response_time_header);
7257        assert!(config.include_custom_metrics);
7258        assert!(config.include_ttfb);
7259    }
7260
7261    #[test]
7262    fn timing_metrics_config_production() {
7263        let config = TimingMetricsConfig::production();
7264        assert!(!config.add_server_timing_header);
7265        assert!(config.add_response_time_header);
7266        assert!(!config.include_custom_metrics);
7267    }
7268
7269    #[test]
7270    fn timing_middleware_adds_metrics_to_request() {
7271        let mw = TimingMetricsMiddleware::new();
7272        let mut req = test_request();
7273
7274        // Before should insert TimingMetrics
7275        let result = run_middleware_before(&mw, &mut req);
7276        assert!(result.is_continue());
7277
7278        let metrics = req.get_extension::<TimingMetrics>();
7279        assert!(metrics.is_some(), "TimingMetrics should be in extensions");
7280    }
7281
7282    #[test]
7283    fn timing_middleware_adds_response_time_header() {
7284        let mw = TimingMetricsMiddleware::new();
7285        let mut req = test_request();
7286
7287        // Run before to insert TimingMetrics
7288        run_middleware_before(&mw, &mut req);
7289
7290        let resp = Response::with_status(StatusCode::OK);
7291        let result = run_middleware_after(&mw, &req, resp);
7292
7293        let has_timing = result
7294            .headers()
7295            .iter()
7296            .any(|(name, _)| name == "X-Response-Time");
7297        assert!(has_timing, "Should have X-Response-Time header");
7298    }
7299
7300    #[test]
7301    fn timing_middleware_adds_server_timing_header() {
7302        let mw = TimingMetricsMiddleware::new();
7303        let mut req = test_request();
7304
7305        run_middleware_before(&mw, &mut req);
7306
7307        let resp = Response::with_status(StatusCode::OK);
7308        let result = run_middleware_after(&mw, &req, resp);
7309
7310        let server_timing = result
7311            .headers()
7312            .iter()
7313            .find(|(name, _)| name == "Server-Timing")
7314            .map(|(_, v)| String::from_utf8_lossy(v).to_string());
7315
7316        assert!(server_timing.is_some(), "Should have Server-Timing header");
7317        let header = server_timing.unwrap();
7318        assert!(header.contains("total"), "Should have total timing");
7319    }
7320
7321    #[test]
7322    fn timing_middleware_production_mode() {
7323        let mw = TimingMetricsMiddleware::production();
7324        let mut req = test_request();
7325
7326        run_middleware_before(&mw, &mut req);
7327
7328        let resp = Response::with_status(StatusCode::OK);
7329        let result = run_middleware_after(&mw, &req, resp);
7330
7331        // Should have X-Response-Time
7332        let has_response_time = result
7333            .headers()
7334            .iter()
7335            .any(|(name, _)| name == "X-Response-Time");
7336        assert!(has_response_time);
7337
7338        // Should NOT have Server-Timing
7339        let has_server_timing = result
7340            .headers()
7341            .iter()
7342            .any(|(name, _)| name == "Server-Timing");
7343        assert!(!has_server_timing);
7344    }
7345
7346    #[test]
7347    #[allow(clippy::float_cmp)]
7348    fn timing_histogram_basic() {
7349        let mut histogram = TimingHistogram::http_latency();
7350        assert_eq!(histogram.count(), 0);
7351        assert_eq!(histogram.sum(), 0.0);
7352
7353        histogram.observe(42.0);
7354        histogram.observe(150.0);
7355        histogram.observe(5.0);
7356
7357        assert_eq!(histogram.count(), 3);
7358        assert_eq!(histogram.sum(), 197.0);
7359        assert!((histogram.mean() - 65.666).abs() < 0.01);
7360    }
7361
7362    #[test]
7363    fn timing_histogram_buckets() {
7364        let mut histogram = TimingHistogram::with_buckets(vec![10.0, 50.0, 100.0]);
7365
7366        histogram.observe(5.0); // Falls in 10 bucket
7367        histogram.observe(25.0); // Falls in 50 bucket
7368        histogram.observe(75.0); // Falls in 100 bucket
7369        histogram.observe(150.0); // Above all buckets
7370
7371        let buckets = histogram.buckets();
7372        assert_eq!(buckets.len(), 3);
7373
7374        // Buckets are cumulative
7375        assert_eq!(buckets[0].count, 1); // <= 10: 1
7376        assert_eq!(buckets[1].count, 2); // <= 50: 2
7377        assert_eq!(buckets[2].count, 3); // <= 100: 3
7378    }
7379
7380    #[test]
7381    #[allow(clippy::float_cmp)]
7382    fn timing_histogram_reset() {
7383        let mut histogram = TimingHistogram::http_latency();
7384        histogram.observe(100.0);
7385        histogram.observe(200.0);
7386
7387        assert_eq!(histogram.count(), 2);
7388
7389        histogram.reset();
7390
7391        assert_eq!(histogram.count(), 0);
7392        assert_eq!(histogram.sum(), 0.0);
7393    }
7394}
7395
#[cfg(test)]
mod response_interceptor_tests {
    // Tests for the response interceptors and `ResponseInterceptorStack`.
    // Header lookups go through `has_header`, which compares names
    // case-insensitively (HTTP field names are case-insensitive), matching the
    // `find_header` helpers used by the other test modules in this file.
    use super::*;
    use crate::request::Method;
    use crate::response::StatusCode;

    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 1)
    }

    fn test_request() -> Request {
        Request::new(Method::Get, "/test")
    }

    /// Runs a single interceptor to completion against `resp`, using a fresh
    /// request context whose start time is taken just before the call.
    fn run_interceptor<I: ResponseInterceptor>(
        interceptor: &I,
        req: &Request,
        resp: Response,
    ) -> Response {
        let ctx = test_context();
        let start_time = Instant::now();
        let interceptor_ctx = ResponseInterceptorContext::new(req, &ctx, start_time);
        futures_executor::block_on(interceptor.intercept(&interceptor_ctx, resp))
    }

    /// Runs a whole interceptor stack against `resp`; the stack counterpart of
    /// `run_interceptor`.
    fn run_stack(stack: &ResponseInterceptorStack, req: &Request, resp: Response) -> Response {
        let ctx = test_context();
        let start_time = Instant::now();
        let interceptor_ctx = ResponseInterceptorContext::new(req, &ctx, start_time);
        futures_executor::block_on(stack.process(&interceptor_ctx, resp))
    }

    /// Returns `true` if `resp` carries a header called `name` (any casing).
    fn has_header(resp: &Response, name: &str) -> bool {
        resp.headers().iter().any(|(header, _)| header.eq_ignore_ascii_case(name))
    }

    /// Returns the bytes of a `ResponseBody::Bytes` body; panics on any other
    /// body variant.
    fn body_bytes(resp: &Response) -> &[u8] {
        match resp.body_ref() {
            crate::response::ResponseBody::Bytes(b) => b.as_slice(),
            _ => panic!("Expected bytes body"),
        }
    }

    #[test]
    fn timing_interceptor_adds_header() {
        let interceptor = TimingInterceptor::new();
        let req = test_request();
        let resp = Response::with_status(StatusCode::OK);

        let result = run_interceptor(&interceptor, &req, resp);

        assert!(
            has_header(&result, "X-Response-Time"),
            "Should have X-Response-Time header"
        );
    }

    #[test]
    fn timing_interceptor_with_server_timing() {
        let interceptor = TimingInterceptor::new().with_server_timing("app");
        let req = test_request();
        let resp = Response::with_status(StatusCode::OK);

        let result = run_interceptor(&interceptor, &req, resp);

        assert!(
            has_header(&result, "Server-Timing"),
            "Should have Server-Timing header"
        );
    }

    #[test]
    fn timing_interceptor_custom_header_name() {
        let interceptor = TimingInterceptor::new().header_name("X-Custom-Time");
        let req = test_request();
        let resp = Response::with_status(StatusCode::OK);

        let result = run_interceptor(&interceptor, &req, resp);

        assert!(
            has_header(&result, "X-Custom-Time"),
            "Should have X-Custom-Time header"
        );
    }

    #[test]
    fn debug_info_interceptor_adds_headers() {
        let interceptor = DebugInfoInterceptor::new();
        let req = test_request();
        let resp = Response::with_status(StatusCode::OK);

        let result = run_interceptor(&interceptor, &req, resp);

        assert!(has_header(&result, "X-Debug-Path"), "Should have X-Debug-Path header");
        assert!(has_header(&result, "X-Debug-Method"), "Should have X-Debug-Method header");
        assert!(
            has_header(&result, "X-Debug-Handler-Time"),
            "Should have X-Debug-Handler-Time header"
        );
    }

    #[test]
    fn debug_info_interceptor_custom_prefix() {
        let interceptor = DebugInfoInterceptor::new().header_prefix("X-Trace-");
        let req = test_request();
        let resp = Response::with_status(StatusCode::OK);

        let result = run_interceptor(&interceptor, &req, resp);

        assert!(has_header(&result, "X-Trace-Path"), "Should have X-Trace-Path header");
    }

    #[test]
    fn debug_info_interceptor_selective_options() {
        let interceptor = DebugInfoInterceptor::new()
            .include_path(true)
            .include_method(false)
            .include_timing(false)
            .include_request_id(false);
        let req = test_request();
        let resp = Response::with_status(StatusCode::OK);

        let result = run_interceptor(&interceptor, &req, resp);

        assert!(has_header(&result, "X-Debug-Path"), "Should have X-Debug-Path header");
        assert!(
            !has_header(&result, "X-Debug-Method"),
            "Should NOT have X-Debug-Method header"
        );
        // `include_timing(false)` must suppress the handler-time header too.
        assert!(
            !has_header(&result, "X-Debug-Handler-Time"),
            "Should NOT have X-Debug-Handler-Time header"
        );
    }

    #[test]
    fn header_transform_adds_headers() {
        let interceptor = HeaderTransformInterceptor::new()
            .add("X-Powered-By", b"fastapi_rust".to_vec())
            .add("X-Version", b"1.0".to_vec());
        let req = test_request();
        let resp = Response::with_status(StatusCode::OK);

        let result = run_interceptor(&interceptor, &req, resp);

        assert!(has_header(&result, "X-Powered-By"), "Should have X-Powered-By header");
        assert!(has_header(&result, "X-Version"), "Should have X-Version header");
    }

    #[test]
    fn response_body_transform_modifies_body() {
        let transformer = ResponseBodyTransform::new(|body| {
            let mut result = b"[".to_vec();
            result.extend_from_slice(&body);
            result.extend_from_slice(b"]");
            result
        });
        let req = test_request();
        let resp = Response::with_status(StatusCode::OK)
            .body(crate::response::ResponseBody::Bytes(b"hello".to_vec()));

        let result = run_interceptor(&transformer, &req, resp);

        assert_eq!(body_bytes(&result), b"[hello]");
    }

    #[test]
    fn response_body_transform_with_content_type_filter() {
        let transformer =
            ResponseBodyTransform::new(|_| b"transformed".to_vec()).for_content_type("text/plain");
        let req = test_request();

        // JSON response should NOT be transformed
        let json_resp = Response::with_status(StatusCode::OK)
            .header("content-type", b"application/json".to_vec())
            .body(crate::response::ResponseBody::Bytes(b"original".to_vec()));
        let result = run_interceptor(&transformer, &req, json_resp);
        assert_eq!(
            body_bytes(&result),
            b"original",
            "JSON should not be transformed"
        );

        // Plain text response SHOULD be transformed
        let text_resp = Response::with_status(StatusCode::OK)
            .header("content-type", b"text/plain".to_vec())
            .body(crate::response::ResponseBody::Bytes(b"original".to_vec()));
        let result = run_interceptor(&transformer, &req, text_resp);
        assert_eq!(
            body_bytes(&result),
            b"transformed",
            "Text should be transformed"
        );
    }

    #[test]
    fn error_response_transformer_hides_details() {
        let transformer = ErrorResponseTransformer::new()
            .hide_details_for_status(StatusCode::INTERNAL_SERVER_ERROR)
            .with_replacement_body(b"An error occurred");

        let req = test_request();

        // 500 response should be transformed
        let error_resp = Response::with_status(StatusCode::INTERNAL_SERVER_ERROR).body(
            crate::response::ResponseBody::Bytes(b"Sensitive error details".to_vec()),
        );
        let result = run_interceptor(&transformer, &req, error_resp);
        assert_eq!(body_bytes(&result), b"An error occurred");

        // 200 response should NOT be transformed
        let ok_resp = Response::with_status(StatusCode::OK)
            .body(crate::response::ResponseBody::Bytes(b"Success".to_vec()));
        let result = run_interceptor(&transformer, &req, ok_resp);
        assert_eq!(body_bytes(&result), b"Success");
    }

    #[test]
    fn response_interceptor_stack_chains_interceptors() {
        let mut stack = ResponseInterceptorStack::new();
        stack.push(TimingInterceptor::new());
        stack.push(HeaderTransformInterceptor::new().add("X-Extra", b"value".to_vec()));

        let req = test_request();
        let resp = Response::with_status(StatusCode::OK);

        let result = run_stack(&stack, &req, resp);

        assert!(
            has_header(&result, "X-Response-Time"),
            "Should have timing header from first interceptor"
        );
        assert!(
            has_header(&result, "X-Extra"),
            "Should have extra header from second interceptor"
        );
    }

    #[test]
    fn response_interceptor_stack_empty_is_noop() {
        let stack = ResponseInterceptorStack::new();
        assert!(stack.is_empty());
        assert_eq!(stack.len(), 0);

        let req = test_request();
        let resp = Response::with_status(StatusCode::OK)
            .body(crate::response::ResponseBody::Bytes(b"unchanged".to_vec()));

        let result = run_stack(&stack, &req, resp);

        assert_eq!(body_bytes(&result), b"unchanged");
    }

    #[test]
    fn interceptor_context_provides_timing() {
        let ctx = test_context();
        let req = test_request();
        let start_time = Instant::now();
        std::thread::sleep(std::time::Duration::from_millis(5));

        let interceptor_ctx = ResponseInterceptorContext::new(&req, &ctx, start_time);

        assert!(
            interceptor_ctx.elapsed_ms() >= 5,
            "Elapsed time should be at least 5ms"
        );
        assert!(interceptor_ctx.elapsed().as_millis() >= 5);
    }

    #[test]
    fn conditional_interceptor_applies_conditionally() {
        // Only add header if response is 200 OK
        let inner = HeaderTransformInterceptor::new().add("X-Success", b"true".to_vec());
        let conditional =
            ConditionalInterceptor::new(inner, |_ctx, resp| resp.status().as_u16() == 200);

        let req = test_request();

        // 200 response should get the header
        let ok_resp = Response::with_status(StatusCode::OK);
        let result = run_interceptor(&conditional, &req, ok_resp);
        assert!(
            has_header(&result, "X-Success"),
            "200 response should get X-Success header"
        );

        // 404 response should NOT get the header
        let not_found = Response::with_status(StatusCode::NOT_FOUND);
        let result = run_interceptor(&conditional, &req, not_found);
        assert!(
            !has_header(&result, "X-Success"),
            "404 response should NOT get X-Success header"
        );
    }
}
7736
#[cfg(test)]
mod cache_control_tests {
    // Tests for `CacheControlMiddleware`, the Cache-Control builder/presets,
    // and the path-pattern and HTTP-date helpers.
    use super::*;
    use crate::request::Method;
    use crate::response::StatusCode;

    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 1)
    }

    /// Drives `mw.after` to completion on a throwaway request context.
    fn run_after(mw: &CacheControlMiddleware, req: &Request, resp: Response) -> Response {
        let ctx = test_context();
        let fut = mw.after(&ctx, req, resp);
        futures_executor::block_on(fut)
    }

    /// Returns the value of the first header named `name`, compared
    /// case-insensitively. This mirrors the `find_header` helper in the TRACE
    /// rejection tests.
    fn find_header<'a>(headers: &'a [(String, Vec<u8>)], name: &str) -> Option<&'a [u8]> {
        headers
            .iter()
            .find(|(n, _)| n.eq_ignore_ascii_case(name))
            .map(|(_, v)| v.as_slice())
    }

    #[test]
    fn cache_directive_as_str_works() {
        assert_eq!(CacheDirective::Public.as_str(), "public");
        assert_eq!(CacheDirective::Private.as_str(), "private");
        assert_eq!(CacheDirective::NoStore.as_str(), "no-store");
        assert_eq!(CacheDirective::NoCache.as_str(), "no-cache");
        assert_eq!(CacheDirective::MustRevalidate.as_str(), "must-revalidate");
        assert_eq!(CacheDirective::Immutable.as_str(), "immutable");
    }

    #[test]
    fn cache_control_builder_basic() {
        let cc = CacheControlBuilder::new()
            .public()
            .max_age_secs(3600)
            .build();
        assert!(cc.contains("public"));
        assert!(cc.contains("max-age=3600"));
    }

    #[test]
    fn cache_control_builder_complex() {
        let cc = CacheControlBuilder::new()
            .public()
            .max_age_secs(60)
            .s_maxage_secs(3600)
            .stale_while_revalidate_secs(86400)
            .build();
        assert!(cc.contains("public"));
        assert!(cc.contains("max-age=60"));
        assert!(cc.contains("s-maxage=3600"));
        assert!(cc.contains("stale-while-revalidate=86400"));
    }

    #[test]
    fn cache_control_builder_no_cache() {
        let cc = CacheControlBuilder::new()
            .no_store()
            .no_cache()
            .must_revalidate()
            .build();
        assert!(cc.contains("no-store"));
        assert!(cc.contains("no-cache"));
        assert!(cc.contains("must-revalidate"));
    }

    #[test]
    fn cache_preset_no_cache() {
        let value = CachePreset::NoCache.to_header_value();
        assert!(value.contains("no-store"));
        assert!(value.contains("no-cache"));
        assert!(value.contains("must-revalidate"));
    }

    #[test]
    fn cache_preset_immutable() {
        let value = CachePreset::Immutable.to_header_value();
        assert!(value.contains("public"));
        assert!(value.contains("max-age=31536000"));
        assert!(value.contains("immutable"));
    }

    #[test]
    fn cache_preset_static_assets() {
        let value = CachePreset::StaticAssets.to_header_value();
        assert!(value.contains("public"));
        assert!(value.contains("max-age=86400"));
    }

    #[test]
    fn middleware_adds_cache_control_header() {
        let mw = CacheControlMiddleware::with_preset(CachePreset::PublicOneHour);
        let req = Request::new(Method::Get, "/api/test");
        let resp = Response::with_status(StatusCode::OK);

        let result = run_after(&mw, &req, resp);
        let value = find_header(result.headers(), "cache-control")
            .expect("Cache-Control header should be present");
        let value_str = String::from_utf8_lossy(value);
        assert!(value_str.contains("public"));
        assert!(value_str.contains("max-age=3600"));
    }

    #[test]
    fn middleware_skips_post_requests() {
        let mw = CacheControlMiddleware::with_preset(CachePreset::PublicOneHour);
        let req = Request::new(Method::Post, "/api/test");
        let resp = Response::with_status(StatusCode::OK);

        let result = run_after(&mw, &req, resp);
        assert!(
            find_header(result.headers(), "cache-control").is_none(),
            "Cache-Control should not be added for POST"
        );
    }

    #[test]
    fn middleware_skips_error_responses() {
        let mw = CacheControlMiddleware::with_preset(CachePreset::PublicOneHour);
        let req = Request::new(Method::Get, "/api/test");
        let resp = Response::with_status(StatusCode::INTERNAL_SERVER_ERROR);

        let result = run_after(&mw, &req, resp);
        assert!(
            find_header(result.headers(), "cache-control").is_none(),
            "Cache-Control should not be added for error responses"
        );
    }

    #[test]
    fn middleware_with_vary_header() {
        let mw = CacheControlMiddleware::with_config(
            CacheControlConfig::from_preset(CachePreset::PublicOneHour)
                .vary("Accept-Encoding")
                .vary("Accept-Language"),
        );
        let req = Request::new(Method::Get, "/api/test");
        let resp = Response::with_status(StatusCode::OK);

        let result = run_after(&mw, &req, resp);
        let value =
            find_header(result.headers(), "vary").expect("Vary header should be present");
        let value_str = String::from_utf8_lossy(value);
        assert!(value_str.contains("Accept-Encoding"));
        assert!(value_str.contains("Accept-Language"));
    }

    #[test]
    fn middleware_preserves_existing_cache_control() {
        let mw = CacheControlMiddleware::with_config(
            CacheControlConfig::from_preset(CachePreset::PublicOneHour).preserve_existing(true),
        );
        let req = Request::new(Method::Get, "/api/test");
        let resp =
            Response::with_status(StatusCode::OK).header("Cache-Control", b"max-age=60".to_vec());

        let result = run_after(&mw, &req, resp);
        // Collect every Cache-Control value (not just the first), so a duplicate
        // header added by the middleware would be caught.
        let cc_values: Vec<&[u8]> = result
            .headers()
            .iter()
            .filter(|(name, _)| name.eq_ignore_ascii_case("cache-control"))
            .map(|(_, value)| value.as_slice())
            .collect();
        // Should only have the original header, not add a new one
        assert_eq!(
            cc_values,
            [b"max-age=60".as_slice()],
            "Existing Cache-Control should be kept unchanged and not duplicated"
        );
    }

    #[test]
    fn path_pattern_matching_exact() {
        assert!(path_matches_pattern("/api/users", "/api/users"));
        assert!(!path_matches_pattern("/api/users", "/api/items"));
    }

    #[test]
    fn path_pattern_matching_wildcard() {
        assert!(path_matches_pattern("/api/users/123", "/api/users/*"));
        assert!(path_matches_pattern("/static/css/style.css", "/static/*"));
        assert!(path_matches_pattern("/anything", "*"));
    }

    #[test]
    fn date_formatting_works() {
        // Test that format_http_date doesn't panic and produces valid format
        let now = std::time::SystemTime::now();
        let formatted = format_http_date(now);
        // Should contain GMT
        assert!(formatted.ends_with(" GMT"));
        // Should have day name
        let days = ["Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat"];
        assert!(days.iter().any(|d| formatted.starts_with(d)));
    }

    #[test]
    fn leap_year_detection() {
        assert!(!is_leap_year(1900)); // Divisible by 100 but not 400
        assert!(!is_leap_year(2100)); // Divisible by 100 but not 400 (future century)
        assert!(is_leap_year(2000)); // Divisible by 400
        assert!(is_leap_year(2024)); // Divisible by 4 but not 100
        assert!(!is_leap_year(2023)); // Not divisible by 4
    }
}
7954
7955// ===========================================================================
7956// TRACE Rejection Middleware Tests
7957// ===========================================================================
7958
#[cfg(test)]
mod trace_rejection_tests {
    // Tests for `TraceRejectionMiddleware`: TRACE requests are short-circuited
    // with 405, every other method passes through `before` untouched.
    use super::*;
    use crate::request::Method;
    use crate::response::StatusCode;

    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 1)
    }

    /// Drives `mw.before` to completion on a throwaway request context.
    fn run_before(mw: &TraceRejectionMiddleware, req: &mut Request) -> ControlFlow {
        let ctx = test_context();
        futures_executor::block_on(mw.before(&ctx, req))
    }

    /// Case-insensitive lookup of a header value by name.
    fn find_header<'a>(headers: &'a [(String, Vec<u8>)], name: &str) -> Option<&'a [u8]> {
        headers.iter().find_map(|(n, v)| {
            if n.eq_ignore_ascii_case(name) {
                Some(v.as_slice())
            } else {
                None
            }
        })
    }

    /// Sends `method path` through a default middleware and returns the
    /// short-circuit response; panics if the request was allowed through.
    fn expect_rejected(method: Method, path: &str) -> Response {
        let mw = TraceRejectionMiddleware::new();
        let mut req = Request::new(method, path);
        match run_before(&mw, &mut req) {
            ControlFlow::Break(response) => response,
            ControlFlow::Continue => panic!("TRACE request should have been rejected"),
        }
    }

    /// Sends `method path` through a default middleware and panics with
    /// `failure_message` if it was short-circuited.
    fn assert_allowed(method: Method, path: &str, failure_message: &str) {
        let mw = TraceRejectionMiddleware::new();
        let mut req = Request::new(method, path);
        match run_before(&mw, &mut req) {
            ControlFlow::Continue => {}
            ControlFlow::Break(_) => panic!("{failure_message}"),
        }
    }

    #[test]
    fn trace_request_rejected() {
        let response = expect_rejected(Method::Trace, "/");
        assert_eq!(response.status(), StatusCode::METHOD_NOT_ALLOWED);
    }

    #[test]
    fn trace_request_with_path() {
        let response = expect_rejected(Method::Trace, "/api/users/123");
        assert_eq!(response.status(), StatusCode::METHOD_NOT_ALLOWED);
    }

    #[test]
    fn get_request_allowed() {
        assert_allowed(Method::Get, "/", "GET request should be allowed");
    }

    #[test]
    fn post_request_allowed() {
        assert_allowed(Method::Post, "/api/users", "POST request should be allowed");
    }

    #[test]
    fn put_request_allowed() {
        assert_allowed(Method::Put, "/api/users/1", "PUT request should be allowed");
    }

    #[test]
    fn delete_request_allowed() {
        assert_allowed(Method::Delete, "/api/users/1", "DELETE request should be allowed");
    }

    #[test]
    fn patch_request_allowed() {
        assert_allowed(Method::Patch, "/api/users/1", "PATCH request should be allowed");
    }

    #[test]
    fn options_request_allowed() {
        assert_allowed(Method::Options, "/api/users", "OPTIONS request should be allowed");
    }

    #[test]
    fn head_request_allowed() {
        assert_allowed(Method::Head, "/", "HEAD request should be allowed");
    }

    #[test]
    fn response_includes_allow_header() {
        let response = expect_rejected(Method::Trace, "/");
        let allow_present = find_header(response.headers(), "Allow").is_some();
        assert!(allow_present, "Response should include Allow header");
    }

    #[test]
    fn response_has_json_content_type() {
        let response = expect_rejected(Method::Trace, "/");
        let content_type = find_header(response.headers(), "Content-Type");
        assert_eq!(content_type, Some(b"application/json".as_slice()));
    }

    #[test]
    fn default_enables_logging() {
        let middleware = TraceRejectionMiddleware::new();
        assert!(middleware.log_attempts);
    }

    #[test]
    fn log_attempts_can_be_disabled() {
        let middleware = TraceRejectionMiddleware::new().log_attempts(false);
        assert!(!middleware.log_attempts);
    }

    #[test]
    fn middleware_name() {
        let middleware = TraceRejectionMiddleware::new();
        assert_eq!(middleware.name(), "TraceRejection");
    }

    #[test]
    fn default_impl() {
        let middleware = TraceRejectionMiddleware::default();
        assert!(middleware.log_attempts);
    }
}
8162
8163// ===========================================================================
8164// End TRACE Rejection Middleware Tests
8165// ===========================================================================
8166
8167// ===========================================================================
8168// HTTPS Redirect Middleware Tests
8169// ===========================================================================
8170
#[cfg(test)]
mod https_redirect_tests {
    use super::*;
    use crate::request::Method;
    use crate::response::StatusCode;

    /// Builds a request context backed by asupersync's testing `Cx`.
    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 1)
    }

    /// Runs the middleware's `before` hook to completion and returns its decision.
    fn run_before(mw: &HttpsRedirectMiddleware, req: &mut Request) -> ControlFlow {
        let ctx = test_context();
        futures_executor::block_on(mw.before(&ctx, req))
    }

    /// Runs the middleware's `after` hook to completion and returns the final response.
    fn run_after(mw: &HttpsRedirectMiddleware, req: &Request, resp: Response) -> Response {
        let ctx = test_context();
        futures_executor::block_on(mw.after(&ctx, req, resp))
    }

    /// Returns the value of the first header whose name matches `name`
    /// (ASCII case-insensitive), if any.
    fn find_header<'a>(headers: &'a [(String, Vec<u8>)], name: &str) -> Option<&'a [u8]> {
        headers
            .iter()
            .find(|(n, _)| n.eq_ignore_ascii_case(name))
            .map(|(_, v)| v.as_slice())
    }

    /// Returns the `Strict-Transport-Security` value as text, failing the test
    /// with a descriptive message (instead of a bare `unwrap` panic) when the
    /// header is missing.
    fn hsts_value(response: &Response) -> String {
        let hsts = find_header(response.headers(), "Strict-Transport-Security")
            .expect("HSTS header should be present on HTTPS response");
        String::from_utf8_lossy(hsts).into_owned()
    }

    #[test]
    fn http_request_redirected() {
        let mw = HttpsRedirectMiddleware::new();
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut().insert("Host", b"example.com".to_vec());

        let ControlFlow::Break(response) = run_before(&mw, &mut req) else {
            panic!("HTTP request should be redirected");
        };
        assert_eq!(response.status(), StatusCode::MOVED_PERMANENTLY);
        assert_eq!(
            find_header(response.headers(), "Location"),
            Some(b"https://example.com/".as_slice())
        );
    }

    #[test]
    fn http_request_with_path_and_query() {
        let mw = HttpsRedirectMiddleware::new();
        let mut req = Request::new(Method::Get, "/api/users?page=1");
        req.headers_mut().insert("Host", b"example.com".to_vec());

        let ControlFlow::Break(response) = run_before(&mw, &mut req) else {
            panic!("HTTP request should be redirected");
        };
        // The redirect target must keep both the path and the query string.
        assert_eq!(
            find_header(response.headers(), "Location"),
            Some(b"https://example.com/api/users?page=1".as_slice())
        );
    }

    #[test]
    fn https_request_not_redirected() {
        let mw = HttpsRedirectMiddleware::new();
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut().insert("Host", b"example.com".to_vec());
        req.headers_mut().insert("X-Forwarded-Proto", b"https".to_vec());

        assert!(
            matches!(run_before(&mw, &mut req), ControlFlow::Continue),
            "HTTPS request should not be redirected"
        );
    }

    #[test]
    fn x_forwarded_ssl_recognized() {
        let mw = HttpsRedirectMiddleware::new();
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut().insert("Host", b"example.com".to_vec());
        req.headers_mut().insert("X-Forwarded-Ssl", b"on".to_vec());

        assert!(
            matches!(run_before(&mw, &mut req), ControlFlow::Continue),
            "Request with X-Forwarded-Ssl=on should not redirect"
        );
    }

    #[test]
    fn excluded_path_not_redirected() {
        let mw = HttpsRedirectMiddleware::new().exclude_path("/health");
        let mut req = Request::new(Method::Get, "/health");
        req.headers_mut().insert("Host", b"example.com".to_vec());

        assert!(
            matches!(run_before(&mw, &mut req), ControlFlow::Continue),
            "Excluded path should not be redirected"
        );
    }

    #[test]
    fn excluded_path_prefix_matches() {
        let mw = HttpsRedirectMiddleware::new().exclude_path("/health");
        let mut req = Request::new(Method::Get, "/health/live");
        req.headers_mut().insert("Host", b"example.com".to_vec());

        assert!(
            matches!(run_before(&mw, &mut req), ControlFlow::Continue),
            "Path with excluded prefix should not be redirected"
        );
    }

    #[test]
    fn temporary_redirect_option() {
        let mw = HttpsRedirectMiddleware::new().permanent_redirect(false);
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut().insert("Host", b"example.com".to_vec());

        let ControlFlow::Break(response) = run_before(&mw, &mut req) else {
            panic!("HTTP request should be redirected");
        };
        assert_eq!(response.status(), StatusCode::TEMPORARY_REDIRECT);
    }

    #[test]
    fn redirect_disabled() {
        let mw = HttpsRedirectMiddleware::new().redirect_enabled(false);
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut().insert("Host", b"example.com".to_vec());

        assert!(
            matches!(run_before(&mw, &mut req), ControlFlow::Continue),
            "Redirects are disabled, should continue"
        );
    }

    #[test]
    fn hsts_header_on_https_response() {
        let mw = HttpsRedirectMiddleware::new();
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut().insert("X-Forwarded-Proto", b"https".to_vec());

        let result = run_after(&mw, &req, Response::with_status(StatusCode::OK));

        assert!(hsts_value(&result).contains("max-age=31536000"));
    }

    #[test]
    fn hsts_header_not_on_http_response() {
        let mw = HttpsRedirectMiddleware::new().redirect_enabled(false);
        // No X-Forwarded-Proto, so this is HTTP.
        let req = Request::new(Method::Get, "/");

        let result = run_after(&mw, &req, Response::with_status(StatusCode::OK));

        let hsts = find_header(result.headers(), "Strict-Transport-Security");
        assert!(hsts.is_none(), "HSTS header should not be on HTTP response");
    }

    #[test]
    fn hsts_with_include_subdomains() {
        let mw = HttpsRedirectMiddleware::new().include_subdomains(true);
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut().insert("X-Forwarded-Proto", b"https".to_vec());

        let result = run_after(&mw, &req, Response::with_status(StatusCode::OK));

        assert!(hsts_value(&result).contains("includeSubDomains"));
    }

    #[test]
    fn hsts_with_preload() {
        let mw = HttpsRedirectMiddleware::new().preload(true);
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut().insert("X-Forwarded-Proto", b"https".to_vec());

        let result = run_after(&mw, &req, Response::with_status(StatusCode::OK));

        assert!(hsts_value(&result).contains("preload"));
    }

    #[test]
    fn hsts_disabled_with_zero_max_age() {
        let mw = HttpsRedirectMiddleware::new().hsts_max_age_secs(0);
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut().insert("X-Forwarded-Proto", b"https".to_vec());

        let result = run_after(&mw, &req, Response::with_status(StatusCode::OK));

        let hsts = find_header(result.headers(), "Strict-Transport-Security");
        assert!(hsts.is_none(), "HSTS should be disabled with max-age=0");
    }

    #[test]
    fn custom_https_port() {
        let mw = HttpsRedirectMiddleware::new().https_port(8443);
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut().insert("Host", b"example.com".to_vec());

        let ControlFlow::Break(response) = run_before(&mw, &mut req) else {
            panic!("HTTP request should be redirected");
        };
        assert_eq!(
            find_header(response.headers(), "Location"),
            Some(b"https://example.com:8443/".as_slice())
        );
    }

    #[test]
    fn host_with_port_stripped() {
        let mw = HttpsRedirectMiddleware::new();
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut().insert("Host", b"example.com:8080".to_vec());

        let ControlFlow::Break(response) = run_before(&mw, &mut req) else {
            panic!("HTTP request should be redirected");
        };
        // The incoming port is dropped and the default HTTPS port (443) is
        // implied, so no port appears in the Location.
        assert_eq!(
            find_header(response.headers(), "Location"),
            Some(b"https://example.com/".as_slice())
        );
    }

    #[test]
    fn middleware_name() {
        let mw = HttpsRedirectMiddleware::new();
        assert_eq!(mw.name(), "HttpsRedirect");
    }

    #[test]
    fn default_impl() {
        let mw = HttpsRedirectMiddleware::default();
        assert!(mw.config.redirect_enabled);
        assert!(mw.config.permanent_redirect);
        assert_eq!(mw.config.hsts_max_age_secs, 31_536_000);
    }

    #[test]
    fn config_builder() {
        let mw = HttpsRedirectMiddleware::new()
            .redirect_enabled(false)
            .permanent_redirect(false)
            .hsts_max_age_secs(86400)
            .include_subdomains(true)
            .preload(true)
            .https_port(8443);

        assert!(!mw.config.redirect_enabled);
        assert!(!mw.config.permanent_redirect);
        assert_eq!(mw.config.hsts_max_age_secs, 86400);
        assert!(mw.config.hsts_include_subdomains);
        assert!(mw.config.hsts_preload);
        assert_eq!(mw.config.https_port, 8443);
    }

    #[test]
    fn exclude_paths_method() {
        let mw = HttpsRedirectMiddleware::new()
            .exclude_paths(vec!["/health".to_string(), "/ready".to_string()]);

        assert_eq!(mw.config.exclude_paths.len(), 2);
        assert!(mw.config.exclude_paths.contains(&"/health".to_string()));
        assert!(mw.config.exclude_paths.contains(&"/ready".to_string()));
    }
}
8481
8482// ===========================================================================
8483// End HTTPS Redirect Middleware Tests
8484// ===========================================================================
8485
8486// ===========================================================================
8487// End ETag Middleware
8488// ===========================================================================
8489
8490#[cfg(test)]
8491mod tests {
8492    use super::*;
8493    use crate::response::{ResponseBody, StatusCode};
8494
8495    // Test middleware that adds a header
8496    #[allow(dead_code)]
8497    struct AddHeaderMiddleware {
8498        name: &'static str,
8499        value: &'static [u8],
8500    }
8501
8502    impl Middleware for AddHeaderMiddleware {
8503        fn after<'a>(
8504            &'a self,
8505            _ctx: &'a RequestContext,
8506            _req: &'a Request,
8507            response: Response,
8508        ) -> BoxFuture<'a, Response> {
8509            Box::pin(async move { response.header(self.name, self.value.to_vec()) })
8510        }
8511    }
8512
8513    // Test middleware that short-circuits
8514    #[allow(dead_code)]
8515    struct BlockingMiddleware;
8516
8517    impl Middleware for BlockingMiddleware {
8518        fn before<'a>(
8519            &'a self,
8520            _ctx: &'a RequestContext,
8521            _req: &'a mut Request,
8522        ) -> BoxFuture<'a, ControlFlow> {
8523            Box::pin(async {
8524                ControlFlow::Break(
8525                    Response::with_status(StatusCode::FORBIDDEN)
8526                        .body(ResponseBody::Bytes(b"blocked".to_vec())),
8527                )
8528            })
8529        }
8530    }
8531
8532    // Test middleware that tracks calls
8533    #[allow(dead_code)]
8534    struct TrackingMiddleware {
8535        before_count: std::sync::atomic::AtomicUsize,
8536        after_count: std::sync::atomic::AtomicUsize,
8537    }
8538
8539    #[allow(dead_code)]
8540    impl TrackingMiddleware {
8541        fn new() -> Self {
8542            Self {
8543                before_count: std::sync::atomic::AtomicUsize::new(0),
8544                after_count: std::sync::atomic::AtomicUsize::new(0),
8545            }
8546        }
8547
8548        fn before_count(&self) -> usize {
8549            self.before_count.load(std::sync::atomic::Ordering::SeqCst)
8550        }
8551
8552        fn after_count(&self) -> usize {
8553            self.after_count.load(std::sync::atomic::Ordering::SeqCst)
8554        }
8555    }
8556
8557    impl Middleware for TrackingMiddleware {
8558        fn before<'a>(
8559            &'a self,
8560            _ctx: &'a RequestContext,
8561            _req: &'a mut Request,
8562        ) -> BoxFuture<'a, ControlFlow> {
8563            self.before_count
8564                .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
8565            Box::pin(async { ControlFlow::Continue })
8566        }
8567
8568        fn after<'a>(
8569            &'a self,
8570            _ctx: &'a RequestContext,
8571            _req: &'a Request,
8572            response: Response,
8573        ) -> BoxFuture<'a, Response> {
8574            self.after_count
8575                .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
8576            Box::pin(async move { response })
8577        }
8578    }
8579
8580    #[test]
8581    fn control_flow_variants() {
8582        let cont = ControlFlow::Continue;
8583        assert!(cont.is_continue());
8584        assert!(!cont.is_break());
8585
8586        let brk = ControlFlow::Break(Response::ok());
8587        assert!(!brk.is_continue());
8588        assert!(brk.is_break());
8589    }
8590
8591    #[test]
8592    fn middleware_stack_empty() {
8593        let stack = MiddlewareStack::new();
8594        assert!(stack.is_empty());
8595        assert_eq!(stack.len(), 0);
8596    }
8597
8598    #[test]
8599    fn middleware_stack_push() {
8600        let mut stack = MiddlewareStack::new();
8601        stack.push(NoopMiddleware);
8602        stack.push(NoopMiddleware);
8603        assert_eq!(stack.len(), 2);
8604        assert!(!stack.is_empty());
8605    }
8606
8607    #[test]
8608    fn noop_middleware_name() {
8609        let mw = NoopMiddleware;
8610        assert_eq!(mw.name(), "Noop");
8611    }
8612
8613    #[test]
8614    fn logging_redacts_sensitive_headers() {
8615        let mut headers = crate::request::Headers::new();
8616        headers.insert("Authorization", b"secret".to_vec());
8617        headers.insert("X-Request-Id", b"abc123".to_vec());
8618
8619        let redacted = super::default_redacted_headers();
8620        let formatted = super::format_headers(headers.iter(), &redacted);
8621
8622        assert!(formatted.contains("authorization=<redacted>"));
8623        assert!(formatted.contains("x-request-id=abc123"));
8624    }
8625
8626    #[test]
8627    fn logging_body_truncation() {
8628        let body = b"abcdef";
8629        let preview = super::format_bytes(body, 4);
8630        assert_eq!(preview, "abcd...");
8631
8632        let preview_full = super::format_bytes(body, 10);
8633        assert_eq!(preview_full, "abcdef");
8634    }
8635
8636    fn test_context() -> RequestContext {
8637        let cx = asupersync::Cx::for_testing();
8638        RequestContext::new(cx, 1)
8639    }
8640
8641    fn header_value(response: &Response, name: &str) -> Option<String> {
8642        response
8643            .headers()
8644            .iter()
8645            .find(|(n, _)| n.eq_ignore_ascii_case(name))
8646            .and_then(|(_, v)| std::str::from_utf8(v).ok())
8647            .map(ToString::to_string)
8648    }
8649
8650    #[test]
8651    fn cors_exact_origin_allows() {
8652        let cors = Cors::new().allow_origin("https://example.com");
8653        let ctx = test_context();
8654        let mut req = Request::new(crate::request::Method::Get, "/");
8655        req.headers_mut()
8656            .insert("origin", b"https://example.com".to_vec());
8657
8658        let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8659        assert!(matches!(result, ControlFlow::Continue));
8660
8661        let response = Response::ok().body(ResponseBody::Bytes(b"ok".to_vec()));
8662        let response = futures_executor::block_on(cors.after(&ctx, &req, response));
8663
8664        assert_eq!(
8665            header_value(&response, "access-control-allow-origin"),
8666            Some("https://example.com".to_string())
8667        );
8668        assert_eq!(header_value(&response, "vary"), Some("Origin".to_string()));
8669    }
8670
8671    #[test]
8672    fn cors_wildcard_origin_allows() {
8673        let cors = Cors::new().allow_origin_wildcard("https://*.example.com");
8674        let ctx = test_context();
8675        let mut req = Request::new(crate::request::Method::Get, "/");
8676        req.headers_mut()
8677            .insert("origin", b"https://api.example.com".to_vec());
8678
8679        let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8680        assert!(matches!(result, ControlFlow::Continue));
8681    }
8682
8683    #[test]
8684    fn cors_regex_origin_allows() {
8685        let cors = Cors::new().allow_origin_regex(r"^https://.*\.example\.com$");
8686        let ctx = test_context();
8687        let mut req = Request::new(crate::request::Method::Get, "/");
8688        req.headers_mut()
8689            .insert("origin", b"https://svc.example.com".to_vec());
8690
8691        let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8692        assert!(matches!(result, ControlFlow::Continue));
8693    }
8694
8695    #[test]
8696    fn cors_preflight_handled() {
8697        let cors = Cors::new()
8698            .allow_any_origin()
8699            .allow_headers(["x-test", "content-type"])
8700            .max_age(600);
8701        let ctx = test_context();
8702        let mut req = Request::new(crate::request::Method::Options, "/");
8703        req.headers_mut()
8704            .insert("origin", b"https://example.com".to_vec());
8705        req.headers_mut()
8706            .insert("access-control-request-method", b"POST".to_vec());
8707        req.headers_mut().insert(
8708            "access-control-request-headers",
8709            b"x-test, content-type".to_vec(),
8710        );
8711
8712        let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8713        let ControlFlow::Break(response) = result else {
8714            panic!("expected preflight break");
8715        };
8716
8717        assert_eq!(response.status().as_u16(), 204);
8718        assert_eq!(
8719            header_value(&response, "access-control-allow-origin"),
8720            Some("*".to_string())
8721        );
8722        assert_eq!(
8723            header_value(&response, "access-control-allow-methods"),
8724            Some("GET, POST, PUT, PATCH, DELETE, OPTIONS, HEAD".to_string())
8725        );
8726        assert_eq!(
8727            header_value(&response, "access-control-allow-headers"),
8728            Some("x-test, content-type".to_string())
8729        );
8730        assert_eq!(
8731            header_value(&response, "access-control-max-age"),
8732            Some("600".to_string())
8733        );
8734    }
8735
8736    #[test]
8737    fn cors_credentials_echo_origin() {
8738        let cors = Cors::new().allow_any_origin().allow_credentials(true);
8739        let ctx = test_context();
8740        let mut req = Request::new(crate::request::Method::Get, "/");
8741        req.headers_mut()
8742            .insert("origin", b"https://example.com".to_vec());
8743
8744        let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8745        assert!(matches!(result, ControlFlow::Continue));
8746
8747        let response = futures_executor::block_on(cors.after(&ctx, &req, Response::ok()));
8748        assert_eq!(
8749            header_value(&response, "access-control-allow-origin"),
8750            Some("https://example.com".to_string())
8751        );
8752        assert_eq!(
8753            header_value(&response, "access-control-allow-credentials"),
8754            Some("true".to_string())
8755        );
8756    }
8757
    // CORS Spec Compliance Tests (bd-l1qe)
    // Per the Fetch Standard's CORS check, when a request's credentials mode is
    // "include", the response's Access-Control-Allow-Origin header MUST NOT be
    // "*"; the server has to echo the request's Origin value instead.
8761
8762    #[test]
8763    fn cors_spec_compliance_credentials_never_wildcard_origin() {
8764        // When credentials are enabled, Access-Control-Allow-Origin
8765        // must echo the specific origin, never "*"
8766        let cors = Cors::new().allow_any_origin().allow_credentials(true);
8767        let ctx = test_context();
8768
8769        // Test with various origins
8770        for origin in &[
8771            "https://example.com",
8772            "https://api.example.com",
8773            "http://localhost:3000",
8774        ] {
8775            let mut req = Request::new(crate::request::Method::Get, "/");
8776            req.headers_mut()
8777                .insert("origin", origin.as_bytes().to_vec());
8778
8779            futures_executor::block_on(cors.before(&ctx, &mut req));
8780            let response = futures_executor::block_on(cors.after(&ctx, &req, Response::ok()));
8781
8782            let allow_origin = header_value(&response, "access-control-allow-origin");
8783            assert_eq!(
8784                allow_origin,
8785                Some((*origin).to_string()),
8786                "With credentials enabled, Access-Control-Allow-Origin must echo '{}', not '*'",
8787                origin
8788            );
8789            assert_ne!(
8790                allow_origin,
8791                Some("*".to_string()),
8792                "CORS spec violation: credentials + wildcard origin is forbidden"
8793            );
8794        }
8795    }
8796
8797    #[test]
8798    fn cors_spec_compliance_preflight_with_credentials() {
8799        // Preflight response with credentials should also echo origin, not "*"
8800        let cors = Cors::new()
8801            .allow_any_origin()
8802            .allow_credentials(true)
8803            .allow_headers(["content-type", "x-custom-header"]);
8804        let ctx = test_context();
8805
8806        let mut req = Request::new(crate::request::Method::Options, "/");
8807        req.headers_mut()
8808            .insert("origin", b"https://example.com".to_vec());
8809        req.headers_mut()
8810            .insert("access-control-request-method", b"POST".to_vec());
8811        req.headers_mut()
8812            .insert("access-control-request-headers", b"content-type".to_vec());
8813
8814        let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8815        let ControlFlow::Break(response) = result else {
8816            panic!("expected preflight break");
8817        };
8818
8819        // Verify Access-Control-Allow-Origin is NOT "*" with credentials
8820        let allow_origin = header_value(&response, "access-control-allow-origin");
8821        assert_eq!(allow_origin, Some("https://example.com".to_string()));
8822        assert_ne!(
8823            allow_origin,
8824            Some("*".to_string()),
8825            "CORS spec violation: preflight with credentials must not use wildcard origin"
8826        );
8827
8828        // Verify credentials header is set
8829        assert_eq!(
8830            header_value(&response, "access-control-allow-credentials"),
8831            Some("true".to_string())
8832        );
8833    }
8834
8835    #[test]
8836    fn cors_spec_without_credentials_allows_wildcard() {
8837        // When credentials are NOT enabled, "*" is allowed for Access-Control-Allow-Origin
8838        let cors = Cors::new().allow_any_origin();
8839        let ctx = test_context();
8840        let mut req = Request::new(crate::request::Method::Get, "/");
8841        req.headers_mut()
8842            .insert("origin", b"https://example.com".to_vec());
8843
8844        futures_executor::block_on(cors.before(&ctx, &mut req));
8845        let response = futures_executor::block_on(cors.after(&ctx, &req, Response::ok()));
8846
8847        // Without credentials, wildcard IS allowed
8848        assert_eq!(
8849            header_value(&response, "access-control-allow-origin"),
8850            Some("*".to_string())
8851        );
8852        // Should NOT have credentials header
8853        assert!(header_value(&response, "access-control-allow-credentials").is_none());
8854    }
8855
8856    #[test]
8857    fn cors_disallowed_preflight_forbidden() {
8858        let cors = Cors::new().allow_origin("https://good.example");
8859        let ctx = test_context();
8860        let mut req = Request::new(crate::request::Method::Options, "/");
8861        req.headers_mut()
8862            .insert("origin", b"https://evil.example".to_vec());
8863        req.headers_mut()
8864            .insert("access-control-request-method", b"GET".to_vec());
8865
8866        let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8867        let ControlFlow::Break(response) = result else {
8868            panic!("expected forbidden preflight");
8869        };
8870        assert_eq!(response.status().as_u16(), 403);
8871    }
8872
8873    #[test]
8874    fn cors_simple_request_disallowed_origin_no_headers() {
8875        // Non-preflight request from disallowed origin should proceed but not get CORS headers
8876        let cors = Cors::new().allow_origin("https://good.example");
8877        let ctx = test_context();
8878        let mut req = Request::new(crate::request::Method::Get, "/");
8879        req.headers_mut()
8880            .insert("origin", b"https://evil.example".to_vec());
8881
8882        let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8883        // Simple requests proceed (browser will block based on missing headers)
8884        assert!(matches!(result, ControlFlow::Continue));
8885
8886        let response = futures_executor::block_on(cors.after(&ctx, &req, Response::ok()));
8887        // No CORS headers should be added for disallowed origin
8888        assert!(header_value(&response, "access-control-allow-origin").is_none());
8889    }
8890
8891    #[test]
8892    fn cors_expose_headers_configuration() {
8893        let cors = Cors::new()
8894            .allow_any_origin()
8895            .expose_headers(["x-custom-header", "x-another-header"]);
8896        let ctx = test_context();
8897        let mut req = Request::new(crate::request::Method::Get, "/");
8898        req.headers_mut()
8899            .insert("origin", b"https://example.com".to_vec());
8900
8901        let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8902        assert!(matches!(result, ControlFlow::Continue));
8903
8904        let response = futures_executor::block_on(cors.after(&ctx, &req, Response::ok()));
8905        assert_eq!(
8906            header_value(&response, "access-control-expose-headers"),
8907            Some("x-custom-header, x-another-header".to_string())
8908        );
8909    }
8910
8911    #[test]
8912    fn cors_any_origin_sets_wildcard() {
8913        let cors = Cors::new().allow_any_origin();
8914        let ctx = test_context();
8915        let mut req = Request::new(crate::request::Method::Get, "/");
8916        req.headers_mut()
8917            .insert("origin", b"https://any-site.com".to_vec());
8918
8919        let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8920        assert!(matches!(result, ControlFlow::Continue));
8921
8922        let response = futures_executor::block_on(cors.after(&ctx, &req, Response::ok()));
8923        assert_eq!(
8924            header_value(&response, "access-control-allow-origin"),
8925            Some("*".to_string())
8926        );
8927    }
8928
8929    #[test]
8930    fn cors_config_allows_method_override() {
8931        // Test that allow_methods overrides defaults
8932        let cors = Cors::new()
8933            .allow_any_origin()
8934            .allow_methods([crate::request::Method::Get, crate::request::Method::Post]);
8935        let ctx = test_context();
8936        let mut req = Request::new(crate::request::Method::Options, "/");
8937        req.headers_mut()
8938            .insert("origin", b"https://example.com".to_vec());
8939        req.headers_mut()
8940            .insert("access-control-request-method", b"POST".to_vec());
8941
8942        let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8943        let ControlFlow::Break(response) = result else {
8944            panic!("expected preflight break");
8945        };
8946        assert_eq!(
8947            header_value(&response, "access-control-allow-methods"),
8948            Some("GET, POST".to_string())
8949        );
8950    }
8951
8952    #[test]
8953    fn cors_no_origin_header_skips_cors() {
8954        // Request without Origin header should not get CORS headers
8955        let cors = Cors::new().allow_any_origin();
8956        let ctx = test_context();
8957        let mut req = Request::new(crate::request::Method::Get, "/");
8958
8959        let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8960        assert!(matches!(result, ControlFlow::Continue));
8961
8962        let response = futures_executor::block_on(cors.after(&ctx, &req, Response::ok()));
8963        assert!(header_value(&response, "access-control-allow-origin").is_none());
8964    }
8965
8966    #[test]
8967    fn cors_middleware_name() {
8968        let cors = Cors::new();
8969        assert_eq!(cors.name(), "Cors");
8970    }
8971
8972    #[test]
8973    fn cors_empty_allowed_headers_does_not_reflect_request_headers() {
8974        // When allowed_headers is empty (default), the CORS middleware should
8975        // NOT reflect the client's Access-Control-Request-Headers back. That
8976        // would effectively allow arbitrary headers — a security risk.
8977        let cors = Cors::new().allow_any_origin(); // default: allowed_headers = []
8978        let ctx = test_context();
8979        let mut req = Request::new(crate::request::Method::Options, "/api");
8980        req.headers_mut()
8981            .insert("origin", b"https://example.com".to_vec());
8982        req.headers_mut()
8983            .insert("access-control-request-method", b"GET".to_vec());
8984        req.headers_mut().insert(
8985            "access-control-request-headers",
8986            b"x-evil-custom, authorization".to_vec(),
8987        );
8988
8989        let result = futures_executor::block_on(cors.before(&ctx, &mut req));
8990        if let ControlFlow::Break(response) = result {
8991            // Preflight response should NOT have access-control-allow-headers
8992            // when no allowed_headers are configured.
8993            assert_eq!(
8994                header_value(&response, "access-control-allow-headers"),
8995                None,
8996                "Empty allowed_headers must not reflect request headers"
8997            );
8998        } else {
8999            panic!("Preflight should have been handled (Break)");
9000        }
9001    }
9002
9003    #[test]
9004    fn cors_explicit_allowed_headers_returned_in_preflight() {
9005        let cors = Cors::new()
9006            .allow_any_origin()
9007            .allow_headers(["x-token", "content-type"]);
9008        let ctx = test_context();
9009        let mut req = Request::new(crate::request::Method::Options, "/api");
9010        req.headers_mut()
9011            .insert("origin", b"https://example.com".to_vec());
9012        req.headers_mut()
9013            .insert("access-control-request-method", b"POST".to_vec());
9014
9015        let result = futures_executor::block_on(cors.before(&ctx, &mut req));
9016        if let ControlFlow::Break(response) = result {
9017            let headers_val = header_value(&response, "access-control-allow-headers");
9018            assert!(headers_val.is_some());
9019            let val = headers_val.unwrap();
9020            assert!(val.contains("x-token"));
9021            assert!(val.contains("content-type"));
9022        } else {
9023            panic!("Preflight should have been handled (Break)");
9024        }
9025    }
9026
9027    // =========================================================================
9028    // Request ID Middleware tests
9029    // =========================================================================
9030
9031    #[test]
9032    fn request_id_generates_unique_ids() {
9033        let id1 = RequestId::generate();
9034        let id2 = RequestId::generate();
9035        let id3 = RequestId::generate();
9036
9037        assert_ne!(id1, id2);
9038        assert_ne!(id2, id3);
9039        assert_ne!(id1, id3);
9040
9041        // IDs should be non-empty
9042        assert!(!id1.as_str().is_empty());
9043        assert!(!id2.as_str().is_empty());
9044        assert!(!id3.as_str().is_empty());
9045    }
9046
9047    #[test]
9048    fn request_id_display() {
9049        let id = RequestId::new("test-request-123");
9050        assert_eq!(format!("{}", id), "test-request-123");
9051    }
9052
9053    #[test]
9054    fn request_id_from_string() {
9055        let id: RequestId = "my-id".into();
9056        assert_eq!(id.as_str(), "my-id");
9057
9058        let id2: RequestId = String::from("my-id-2").into();
9059        assert_eq!(id2.as_str(), "my-id-2");
9060    }
9061
9062    #[test]
9063    fn request_id_config_defaults() {
9064        let config = RequestIdConfig::default();
9065        assert_eq!(config.header_name, "x-request-id");
9066        assert!(config.accept_from_client);
9067        assert!(config.add_to_response);
9068        assert_eq!(config.max_client_id_length, 128);
9069    }
9070
9071    #[test]
9072    fn request_id_config_builder() {
9073        let config = RequestIdConfig::new()
9074            .header_name("X-Trace-ID")
9075            .accept_from_client(false)
9076            .add_to_response(false)
9077            .max_client_id_length(64);
9078
9079        assert_eq!(config.header_name, "X-Trace-ID");
9080        assert!(!config.accept_from_client);
9081        assert!(!config.add_to_response);
9082        assert_eq!(config.max_client_id_length, 64);
9083    }
9084
9085    #[test]
9086    fn request_id_middleware_generates_id() {
9087        let middleware = RequestIdMiddleware::new();
9088        let ctx = test_context();
9089        let mut req = Request::new(crate::request::Method::Get, "/");
9090
9091        let result = futures_executor::block_on(middleware.before(&ctx, &mut req));
9092        assert!(matches!(result, ControlFlow::Continue));
9093
9094        let stored_id = req.get_extension::<RequestId>();
9095        assert!(stored_id.is_some());
9096        assert!(!stored_id.unwrap().as_str().is_empty());
9097    }
9098
9099    #[test]
9100    fn request_id_middleware_accepts_client_id() {
9101        let middleware = RequestIdMiddleware::new();
9102        let ctx = test_context();
9103        let mut req = Request::new(crate::request::Method::Get, "/");
9104        req.headers_mut()
9105            .insert("x-request-id", b"client-provided-id-123".to_vec());
9106
9107        futures_executor::block_on(middleware.before(&ctx, &mut req));
9108
9109        let stored_id = req.get_extension::<RequestId>().unwrap();
9110        assert_eq!(stored_id.as_str(), "client-provided-id-123");
9111    }
9112
9113    #[test]
9114    fn request_id_middleware_rejects_invalid_client_id() {
9115        let middleware = RequestIdMiddleware::new();
9116        let ctx = test_context();
9117
9118        // Test with invalid characters
9119        let mut req = Request::new(crate::request::Method::Get, "/");
9120        req.headers_mut()
9121            .insert("x-request-id", b"invalid<script>id".to_vec());
9122
9123        futures_executor::block_on(middleware.before(&ctx, &mut req));
9124
9125        let stored_id = req.get_extension::<RequestId>().unwrap();
9126        // Should have generated a new ID instead of using the invalid one
9127        assert_ne!(stored_id.as_str(), "invalid<script>id");
9128    }
9129
9130    #[test]
9131    fn request_id_middleware_rejects_too_long_client_id() {
9132        let config = RequestIdConfig::new().max_client_id_length(10);
9133        let middleware = RequestIdMiddleware::with_config(config);
9134        let ctx = test_context();
9135
9136        let mut req = Request::new(crate::request::Method::Get, "/");
9137        req.headers_mut()
9138            .insert("x-request-id", b"this-id-is-way-too-long".to_vec());
9139
9140        futures_executor::block_on(middleware.before(&ctx, &mut req));
9141
9142        let stored_id = req.get_extension::<RequestId>().unwrap();
9143        // Should have generated a new ID instead of using the too-long one
9144        assert_ne!(stored_id.as_str(), "this-id-is-way-too-long");
9145    }
9146
9147    #[test]
9148    fn request_id_middleware_adds_to_response() {
9149        let middleware = RequestIdMiddleware::new();
9150        let ctx = test_context();
9151        let mut req = Request::new(crate::request::Method::Get, "/");
9152
9153        futures_executor::block_on(middleware.before(&ctx, &mut req));
9154        let stored_id = req.get_extension::<RequestId>().unwrap().clone();
9155
9156        let response = Response::ok();
9157        let response = futures_executor::block_on(middleware.after(&ctx, &req, response));
9158
9159        let header = header_value(&response, "x-request-id");
9160        assert_eq!(header, Some(stored_id.0));
9161    }
9162
9163    #[test]
9164    fn request_id_middleware_respects_add_to_response_false() {
9165        let config = RequestIdConfig::new().add_to_response(false);
9166        let middleware = RequestIdMiddleware::with_config(config);
9167        let ctx = test_context();
9168        let mut req = Request::new(crate::request::Method::Get, "/");
9169
9170        futures_executor::block_on(middleware.before(&ctx, &mut req));
9171
9172        let response = Response::ok();
9173        let response = futures_executor::block_on(middleware.after(&ctx, &req, response));
9174
9175        let header = header_value(&response, "x-request-id");
9176        assert!(header.is_none());
9177    }
9178
9179    #[test]
9180    fn request_id_middleware_respects_accept_from_client_false() {
9181        let config = RequestIdConfig::new().accept_from_client(false);
9182        let middleware = RequestIdMiddleware::with_config(config);
9183        let ctx = test_context();
9184        let mut req = Request::new(crate::request::Method::Get, "/");
9185        req.headers_mut()
9186            .insert("x-request-id", b"client-id".to_vec());
9187
9188        futures_executor::block_on(middleware.before(&ctx, &mut req));
9189
9190        let stored_id = req.get_extension::<RequestId>().unwrap();
9191        // Should ignore client ID and generate new one
9192        assert_ne!(stored_id.as_str(), "client-id");
9193    }
9194
9195    #[test]
9196    fn request_id_middleware_custom_header_name() {
9197        let config = RequestIdConfig::new().header_name("X-Trace-ID");
9198        let middleware = RequestIdMiddleware::with_config(config);
9199        let ctx = test_context();
9200        let mut req = Request::new(crate::request::Method::Get, "/");
9201        req.headers_mut()
9202            .insert("X-Trace-ID", b"trace-123".to_vec());
9203
9204        futures_executor::block_on(middleware.before(&ctx, &mut req));
9205
9206        let stored_id = req.get_extension::<RequestId>().unwrap();
9207        assert_eq!(stored_id.as_str(), "trace-123");
9208
9209        let response = Response::ok();
9210        let response = futures_executor::block_on(middleware.after(&ctx, &req, response));
9211
9212        let header = header_value(&response, "X-Trace-ID");
9213        assert_eq!(header, Some("trace-123".to_string()));
9214    }
9215
9216    #[test]
9217    fn is_valid_request_id_accepts_valid() {
9218        assert!(super::is_valid_request_id("abc123"));
9219        assert!(super::is_valid_request_id("request-id-123"));
9220        assert!(super::is_valid_request_id("request_id_123"));
9221        assert!(super::is_valid_request_id("request.id.123"));
9222        assert!(super::is_valid_request_id("ABC123"));
9223        assert!(super::is_valid_request_id("a-b_c.D"));
9224    }
9225
9226    #[test]
9227    fn is_valid_request_id_rejects_invalid() {
9228        assert!(!super::is_valid_request_id(""));
9229        assert!(!super::is_valid_request_id("id with spaces"));
9230        assert!(!super::is_valid_request_id("id<script>"));
9231        assert!(!super::is_valid_request_id("id\nwith\nnewlines"));
9232        assert!(!super::is_valid_request_id("id;with;semicolons"));
9233        assert!(!super::is_valid_request_id("id/with/slashes"));
9234    }
9235
9236    #[test]
9237    fn request_id_middleware_name() {
9238        let middleware = RequestIdMiddleware::new();
9239        assert_eq!(middleware.name(), "RequestId");
9240    }
9241
9242    // =========================================================================
9243    // Middleware Stack Execution Order Tests
9244    // =========================================================================
9245
9246    /// Test middleware that records when its before/after hooks run
9247    struct OrderTrackingMiddleware {
9248        id: &'static str,
9249        log: Arc<std::sync::Mutex<Vec<String>>>,
9250    }
9251
9252    impl OrderTrackingMiddleware {
9253        fn new(id: &'static str, log: Arc<std::sync::Mutex<Vec<String>>>) -> Self {
9254            Self { id, log }
9255        }
9256    }
9257
9258    impl Middleware for OrderTrackingMiddleware {
9259        fn before<'a>(
9260            &'a self,
9261            _ctx: &'a RequestContext,
9262            _req: &'a mut Request,
9263        ) -> BoxFuture<'a, ControlFlow> {
9264            self.log.lock().unwrap().push(format!("{}.before", self.id));
9265            Box::pin(async { ControlFlow::Continue })
9266        }
9267
9268        fn after<'a>(
9269            &'a self,
9270            _ctx: &'a RequestContext,
9271            _req: &'a Request,
9272            response: Response,
9273        ) -> BoxFuture<'a, Response> {
9274            self.log.lock().unwrap().push(format!("{}.after", self.id));
9275            Box::pin(async move { response })
9276        }
9277    }
9278
9279    /// Test middleware that short-circuits with a configurable condition
9280    struct ConditionalBreakMiddleware {
9281        id: &'static str,
9282        should_break: bool,
9283        log: Arc<std::sync::Mutex<Vec<String>>>,
9284    }
9285
9286    impl ConditionalBreakMiddleware {
9287        fn new(
9288            id: &'static str,
9289            should_break: bool,
9290            log: Arc<std::sync::Mutex<Vec<String>>>,
9291        ) -> Self {
9292            Self {
9293                id,
9294                should_break,
9295                log,
9296            }
9297        }
9298    }
9299
9300    impl Middleware for ConditionalBreakMiddleware {
9301        fn before<'a>(
9302            &'a self,
9303            _ctx: &'a RequestContext,
9304            _req: &'a mut Request,
9305        ) -> BoxFuture<'a, ControlFlow> {
9306            self.log.lock().unwrap().push(format!("{}.before", self.id));
9307            let should_break = self.should_break;
9308            Box::pin(async move {
9309                if should_break {
9310                    ControlFlow::Break(
9311                        Response::with_status(StatusCode::FORBIDDEN)
9312                            .body(ResponseBody::Bytes(b"blocked".to_vec())),
9313                    )
9314                } else {
9315                    ControlFlow::Continue
9316                }
9317            })
9318        }
9319
9320        fn after<'a>(
9321            &'a self,
9322            _ctx: &'a RequestContext,
9323            _req: &'a Request,
9324            response: Response,
9325        ) -> BoxFuture<'a, Response> {
9326            self.log.lock().unwrap().push(format!("{}.after", self.id));
9327            Box::pin(async move { response })
9328        }
9329    }
9330
9331    /// Simple test handler that returns 200 OK
9332    struct OkHandler;
9333
9334    impl Handler for OkHandler {
9335        fn call<'a>(
9336            &'a self,
9337            _ctx: &'a RequestContext,
9338            _req: &'a mut Request,
9339        ) -> BoxFuture<'a, Response> {
9340            Box::pin(async move { Response::ok().body(ResponseBody::Bytes(b"handler".to_vec())) })
9341        }
9342    }
9343
9344    /// Handler that checks for a header injected by middleware.
9345    struct CheckHeaderHandler;
9346
9347    impl Handler for CheckHeaderHandler {
9348        fn call<'a>(
9349            &'a self,
9350            _ctx: &'a RequestContext,
9351            req: &'a mut Request,
9352        ) -> BoxFuture<'a, Response> {
9353            let has_header = req.headers().get("X-Modified-By").is_some();
9354            Box::pin(async move {
9355                if has_header {
9356                    Response::ok().body(ResponseBody::Bytes(b"header-present".to_vec()))
9357                } else {
9358                    Response::with_status(StatusCode::BAD_REQUEST)
9359                }
9360            })
9361        }
9362    }
9363
9364    /// Handler that returns an error status.
9365    struct ErrorHandler;
9366
9367    impl Handler for ErrorHandler {
9368        fn call<'a>(
9369            &'a self,
9370            _ctx: &'a RequestContext,
9371            _req: &'a mut Request,
9372        ) -> BoxFuture<'a, Response> {
9373            Box::pin(async move { Response::with_status(StatusCode::INTERNAL_SERVER_ERROR) })
9374        }
9375    }
9376
9377    #[test]
9378    fn middleware_stack_executes_in_correct_order() {
9379        // Verify the "onion" model: before hooks run first-to-last,
9380        // after hooks run last-to-first
9381        let log = Arc::new(std::sync::Mutex::new(Vec::new()));
9382
9383        let mut stack = MiddlewareStack::new();
9384        stack.push(OrderTrackingMiddleware::new("mw1", log.clone()));
9385        stack.push(OrderTrackingMiddleware::new("mw2", log.clone()));
9386        stack.push(OrderTrackingMiddleware::new("mw3", log.clone()));
9387
9388        let ctx = test_context();
9389        let mut req = Request::new(crate::request::Method::Get, "/");
9390
9391        futures_executor::block_on(stack.execute(&OkHandler, &ctx, &mut req));
9392
9393        let calls = log.lock().unwrap().clone();
9394        assert_eq!(
9395            calls,
9396            vec![
9397                "mw1.before",
9398                "mw2.before",
9399                "mw3.before",
9400                "mw3.after",
9401                "mw2.after",
9402                "mw1.after",
9403            ]
9404        );
9405    }
9406
9407    #[test]
9408    fn middleware_stack_short_circuit_skips_later_middleware() {
9409        // When middleware 2 breaks, middleware 3's before should NOT run
9410        // But middleware 1 and 2's after hooks should still run
9411        let log = Arc::new(std::sync::Mutex::new(Vec::new()));
9412
9413        let mut stack = MiddlewareStack::new();
9414        stack.push(OrderTrackingMiddleware::new("mw1", log.clone()));
9415        stack.push(ConditionalBreakMiddleware::new("mw2", true, log.clone()));
9416        stack.push(OrderTrackingMiddleware::new("mw3", log.clone()));
9417
9418        let ctx = test_context();
9419        let mut req = Request::new(crate::request::Method::Get, "/");
9420
9421        let response = futures_executor::block_on(stack.execute(&OkHandler, &ctx, &mut req));
9422
9423        // Should get 403 from the break
9424        assert_eq!(response.status().as_u16(), 403);
9425
9426        let calls = log.lock().unwrap().clone();
9427        assert_eq!(
9428            calls,
9429            vec![
9430                "mw1.before",
9431                "mw2.before",
9432                // mw3.before NOT called because mw2 broke
9433                // mw2.after NOT called because it was the one that broke (ran_before_count = 1)
9434                "mw1.after",
9435            ]
9436        );
9437    }
9438
9439    #[test]
9440    fn middleware_stack_first_middleware_breaks() {
9441        // When the first middleware breaks, no other middleware should run
9442        let log = Arc::new(std::sync::Mutex::new(Vec::new()));
9443
9444        let mut stack = MiddlewareStack::new();
9445        stack.push(ConditionalBreakMiddleware::new("mw1", true, log.clone()));
9446        stack.push(OrderTrackingMiddleware::new("mw2", log.clone()));
9447
9448        let ctx = test_context();
9449        let mut req = Request::new(crate::request::Method::Get, "/");
9450
9451        let response = futures_executor::block_on(stack.execute(&OkHandler, &ctx, &mut req));
9452
9453        assert_eq!(response.status().as_u16(), 403);
9454
9455        let calls = log.lock().unwrap().clone();
9456        assert_eq!(calls, vec!["mw1.before"]);
9457        // No after hooks because ran_before_count = 0
9458    }
9459
9460    #[test]
9461    fn middleware_stack_last_middleware_breaks() {
9462        // When the last middleware breaks, all previous after hooks should run
9463        let log = Arc::new(std::sync::Mutex::new(Vec::new()));
9464
9465        let mut stack = MiddlewareStack::new();
9466        stack.push(OrderTrackingMiddleware::new("mw1", log.clone()));
9467        stack.push(OrderTrackingMiddleware::new("mw2", log.clone()));
9468        stack.push(ConditionalBreakMiddleware::new("mw3", true, log.clone()));
9469
9470        let ctx = test_context();
9471        let mut req = Request::new(crate::request::Method::Get, "/");
9472
9473        let response = futures_executor::block_on(stack.execute(&OkHandler, &ctx, &mut req));
9474
9475        assert_eq!(response.status().as_u16(), 403);
9476
9477        let calls = log.lock().unwrap().clone();
9478        assert_eq!(
9479            calls,
9480            vec![
9481                "mw1.before",
9482                "mw2.before",
9483                "mw3.before",
9484                // mw3 broke, so only mw1 and mw2 after hooks run
9485                "mw2.after",
9486                "mw1.after",
9487            ]
9488        );
9489    }
9490
9491    #[test]
9492    fn middleware_stack_empty_executes_handler_directly() {
9493        let stack = MiddlewareStack::new();
9494        let ctx = test_context();
9495        let mut req = Request::new(crate::request::Method::Get, "/");
9496
9497        let response = futures_executor::block_on(stack.execute(&OkHandler, &ctx, &mut req));
9498
9499        assert_eq!(response.status().as_u16(), 200);
9500    }
9501
9502    #[test]
9503    fn middleware_stack_with_capacity() {
9504        let stack = MiddlewareStack::with_capacity(10);
9505        assert!(stack.is_empty());
9506        assert_eq!(stack.len(), 0);
9507    }
9508
9509    #[test]
9510    fn middleware_stack_push_arc() {
9511        let mut stack = MiddlewareStack::new();
9512        let mw: Arc<dyn Middleware> = Arc::new(NoopMiddleware);
9513        stack.push_arc(mw);
9514        assert_eq!(stack.len(), 1);
9515    }
9516
9517    // =========================================================================
9518    // AddResponseHeader Middleware Tests
9519    // =========================================================================
9520
9521    #[test]
9522    fn add_response_header_adds_header() {
9523        let mw = AddResponseHeader::new("X-Custom", b"custom-value".to_vec());
9524        let ctx = test_context();
9525        let req = Request::new(crate::request::Method::Get, "/");
9526
9527        let response = Response::ok();
9528        let response = futures_executor::block_on(mw.after(&ctx, &req, response));
9529
9530        assert_eq!(
9531            header_value(&response, "X-Custom"),
9532            Some("custom-value".to_string())
9533        );
9534    }
9535
9536    #[test]
9537    fn add_response_header_preserves_existing_headers() {
9538        let mw = AddResponseHeader::new("X-New", b"new".to_vec());
9539        let ctx = test_context();
9540        let req = Request::new(crate::request::Method::Get, "/");
9541
9542        let response = Response::ok().header("X-Existing", b"existing".to_vec());
9543        let response = futures_executor::block_on(mw.after(&ctx, &req, response));
9544
9545        assert_eq!(
9546            header_value(&response, "X-Existing"),
9547            Some("existing".to_string())
9548        );
9549        assert_eq!(header_value(&response, "X-New"), Some("new".to_string()));
9550    }
9551
9552    #[test]
9553    fn add_response_header_name() {
9554        let mw = AddResponseHeader::new("X-Test", b"test".to_vec());
9555        assert_eq!(mw.name(), "AddResponseHeader");
9556    }
9557
9558    // =========================================================================
9559    // RequireHeader Middleware Tests
9560    // =========================================================================
9561
9562    #[test]
9563    fn require_header_allows_with_header() {
9564        let mw = RequireHeader::new("X-Api-Key");
9565        let ctx = test_context();
9566        let mut req = Request::new(crate::request::Method::Get, "/");
9567        req.headers_mut()
9568            .insert("X-Api-Key", b"secret-key".to_vec());
9569
9570        let result = futures_executor::block_on(mw.before(&ctx, &mut req));
9571        assert!(matches!(result, ControlFlow::Continue));
9572    }
9573
9574    #[test]
9575    fn require_header_blocks_without_header() {
9576        let mw = RequireHeader::new("X-Api-Key");
9577        let ctx = test_context();
9578        let mut req = Request::new(crate::request::Method::Get, "/");
9579
9580        let result = futures_executor::block_on(mw.before(&ctx, &mut req));
9581
9582        match result {
9583            ControlFlow::Break(response) => {
9584                assert_eq!(response.status().as_u16(), 400);
9585            }
9586            ControlFlow::Continue => panic!("Expected Break, got Continue"),
9587        }
9588    }
9589
9590    #[test]
9591    fn require_header_name() {
9592        let mw = RequireHeader::new("X-Test");
9593        assert_eq!(mw.name(), "RequireHeader");
9594    }
9595
9596    // =========================================================================
9597    // PathPrefixFilter Middleware Tests
9598    // =========================================================================
9599
9600    #[test]
9601    fn path_prefix_filter_allows_matching_path() {
9602        let mw = PathPrefixFilter::new("/api");
9603        let ctx = test_context();
9604        let mut req = Request::new(crate::request::Method::Get, "/api/users");
9605
9606        let result = futures_executor::block_on(mw.before(&ctx, &mut req));
9607        assert!(matches!(result, ControlFlow::Continue));
9608    }
9609
9610    #[test]
9611    fn path_prefix_filter_allows_exact_prefix() {
9612        let mw = PathPrefixFilter::new("/api");
9613        let ctx = test_context();
9614        let mut req = Request::new(crate::request::Method::Get, "/api");
9615
9616        let result = futures_executor::block_on(mw.before(&ctx, &mut req));
9617        assert!(matches!(result, ControlFlow::Continue));
9618    }
9619
9620    #[test]
9621    fn path_prefix_filter_blocks_non_matching_path() {
9622        let mw = PathPrefixFilter::new("/api");
9623        let ctx = test_context();
9624        let mut req = Request::new(crate::request::Method::Get, "/admin/users");
9625
9626        let result = futures_executor::block_on(mw.before(&ctx, &mut req));
9627
9628        match result {
9629            ControlFlow::Break(response) => {
9630                assert_eq!(response.status().as_u16(), 404);
9631            }
9632            ControlFlow::Continue => panic!("Expected Break, got Continue"),
9633        }
9634    }
9635
9636    #[test]
9637    fn path_prefix_filter_name() {
9638        let mw = PathPrefixFilter::new("/api");
9639        assert_eq!(mw.name(), "PathPrefixFilter");
9640    }
9641
9642    // =========================================================================
9643    // ConditionalStatus Middleware Tests
9644    // =========================================================================
9645
9646    #[test]
9647    fn conditional_status_applies_true_status() {
9648        let mw = ConditionalStatus::new(
9649            |req| req.path() == "/health",
9650            StatusCode::OK,
9651            StatusCode::NOT_FOUND,
9652        );
9653        let ctx = test_context();
9654        let req = Request::new(crate::request::Method::Get, "/health");
9655        let response = Response::with_status(StatusCode::INTERNAL_SERVER_ERROR);
9656
9657        let response = futures_executor::block_on(mw.after(&ctx, &req, response));
9658        assert_eq!(response.status().as_u16(), 200);
9659    }
9660
9661    #[test]
9662    fn conditional_status_applies_false_status() {
9663        let mw = ConditionalStatus::new(
9664            |req| req.path() == "/health",
9665            StatusCode::OK,
9666            StatusCode::NOT_FOUND,
9667        );
9668        let ctx = test_context();
9669        let req = Request::new(crate::request::Method::Get, "/other");
9670        let response = Response::with_status(StatusCode::INTERNAL_SERVER_ERROR);
9671
9672        let response = futures_executor::block_on(mw.after(&ctx, &req, response));
9673        assert_eq!(response.status().as_u16(), 404);
9674    }
9675
9676    #[test]
9677    fn conditional_status_name() {
9678        let mw = ConditionalStatus::new(|_| true, StatusCode::OK, StatusCode::NOT_FOUND);
9679        assert_eq!(mw.name(), "ConditionalStatus");
9680    }
9681
9682    // =========================================================================
9683    // Layer and Layered Tests
9684    // =========================================================================
9685
9686    #[derive(Clone)]
9687    struct LayerTestMiddleware {
9688        prefix: String,
9689    }
9690
9691    impl LayerTestMiddleware {
9692        fn new(prefix: impl Into<String>) -> Self {
9693            Self {
9694                prefix: prefix.into(),
9695            }
9696        }
9697    }
9698
9699    impl Middleware for LayerTestMiddleware {
9700        fn after<'a>(
9701            &'a self,
9702            _ctx: &'a RequestContext,
9703            _req: &'a Request,
9704            response: Response,
9705        ) -> BoxFuture<'a, Response> {
9706            let prefix = self.prefix.clone();
9707            Box::pin(async move { response.header("X-Layer", prefix.into_bytes()) })
9708        }
9709    }
9710
9711    #[test]
9712    fn layer_wraps_handler() {
9713        let layer = Layer::new(LayerTestMiddleware::new("wrapped"));
9714        let wrapped = layer.wrap(OkHandler);
9715
9716        let ctx = test_context();
9717        let mut req = Request::new(crate::request::Method::Get, "/");
9718
9719        let response = futures_executor::block_on(wrapped.call(&ctx, &mut req));
9720
9721        assert_eq!(response.status().as_u16(), 200);
9722        assert_eq!(
9723            header_value(&response, "X-Layer"),
9724            Some("wrapped".to_string())
9725        );
9726    }
9727
9728    #[test]
9729    fn layered_handles_break() {
9730        #[derive(Clone)]
9731        struct BreakingMiddleware;
9732
9733        impl Middleware for BreakingMiddleware {
9734            fn before<'a>(
9735                &'a self,
9736                _ctx: &'a RequestContext,
9737                _req: &'a mut Request,
9738            ) -> BoxFuture<'a, ControlFlow> {
9739                Box::pin(async {
9740                    ControlFlow::Break(Response::with_status(StatusCode::UNAUTHORIZED))
9741                })
9742            }
9743
9744            fn after<'a>(
9745                &'a self,
9746                _ctx: &'a RequestContext,
9747                _req: &'a Request,
9748                response: Response,
9749            ) -> BoxFuture<'a, Response> {
9750                Box::pin(async move { response.header("X-After", b"ran".to_vec()) })
9751            }
9752        }
9753
9754        let layer = Layer::new(BreakingMiddleware);
9755        let wrapped = layer.wrap(OkHandler);
9756
9757        let ctx = test_context();
9758        let mut req = Request::new(crate::request::Method::Get, "/");
9759
9760        let response = futures_executor::block_on(wrapped.call(&ctx, &mut req));
9761
9762        // Should get 401 from break
9763        assert_eq!(response.status().as_u16(), 401);
9764        // After hook should still run
9765        assert_eq!(header_value(&response, "X-After"), Some("ran".to_string()));
9766    }
9767
9768    // =========================================================================
9769    // RequestResponseLogger Tests
9770    // =========================================================================
9771
9772    #[test]
9773    fn request_response_logger_default() {
9774        let logger = RequestResponseLogger::default();
9775        assert!(logger.log_request_headers);
9776        assert!(logger.log_response_headers);
9777        assert!(!logger.log_body);
9778        assert_eq!(logger.max_body_bytes, 1024);
9779    }
9780
9781    #[test]
9782    fn request_response_logger_builder() {
9783        let logger = RequestResponseLogger::new()
9784            .log_request_headers(false)
9785            .log_response_headers(false)
9786            .log_body(true)
9787            .max_body_bytes(2048)
9788            .redact_header("x-secret");
9789
9790        assert!(!logger.log_request_headers);
9791        assert!(!logger.log_response_headers);
9792        assert!(logger.log_body);
9793        assert_eq!(logger.max_body_bytes, 2048);
9794        assert!(logger.redact_headers.contains("x-secret"));
9795    }
9796
9797    #[test]
9798    fn request_response_logger_name() {
9799        let logger = RequestResponseLogger::new();
9800        assert_eq!(logger.name(), "RequestResponseLogger");
9801    }
9802
9803    // =========================================================================
9804    // Integration Tests with Handlers
9805    // =========================================================================
9806
9807    #[test]
9808    fn middleware_stack_modifies_request_for_handler() {
9809        /// Middleware that adds a header that the handler can see
9810        struct RequestModifier;
9811
9812        impl Middleware for RequestModifier {
9813            fn before<'a>(
9814                &'a self,
9815                _ctx: &'a RequestContext,
9816                req: &'a mut Request,
9817            ) -> BoxFuture<'a, ControlFlow> {
9818                req.headers_mut()
9819                    .insert("X-Modified-By", b"middleware".to_vec());
9820                Box::pin(async { ControlFlow::Continue })
9821            }
9822        }
9823
9824        let mut stack = MiddlewareStack::new();
9825        stack.push(RequestModifier);
9826
9827        let ctx = test_context();
9828        let mut req = Request::new(crate::request::Method::Get, "/");
9829
9830        let response =
9831            futures_executor::block_on(stack.execute(&CheckHeaderHandler, &ctx, &mut req));
9832
9833        assert_eq!(response.status().as_u16(), 200);
9834    }
9835
9836    #[test]
9837    fn middleware_stack_multiple_response_modifications() {
9838        let mut stack = MiddlewareStack::new();
9839        stack.push(AddResponseHeader::new("X-First", b"1".to_vec()));
9840        stack.push(AddResponseHeader::new("X-Second", b"2".to_vec()));
9841        stack.push(AddResponseHeader::new("X-Third", b"3".to_vec()));
9842
9843        let ctx = test_context();
9844        let mut req = Request::new(crate::request::Method::Get, "/");
9845
9846        let response = futures_executor::block_on(stack.execute(&OkHandler, &ctx, &mut req));
9847
9848        // All headers should be present (after hooks run in reverse)
9849        assert_eq!(header_value(&response, "X-First"), Some("1".to_string()));
9850        assert_eq!(header_value(&response, "X-Second"), Some("2".to_string()));
9851        assert_eq!(header_value(&response, "X-Third"), Some("3".to_string()));
9852    }
9853
9854    #[test]
9855    fn middleware_stack_handler_receives_response_after_break() {
9856        // Verify that when middleware breaks, the response body is from the break
9857        let mut stack = MiddlewareStack::new();
9858        stack.push(ConditionalBreakMiddleware::new(
9859            "breaker",
9860            true,
9861            Arc::new(std::sync::Mutex::new(Vec::new())),
9862        ));
9863
9864        let ctx = test_context();
9865        let mut req = Request::new(crate::request::Method::Get, "/");
9866
9867        let response = futures_executor::block_on(stack.execute(&OkHandler, &ctx, &mut req));
9868
9869        assert_eq!(response.status().as_u16(), 403);
9870        // Body should be from the breaking middleware, not the handler
9871        match response.body_ref() {
9872            ResponseBody::Bytes(b) => assert_eq!(b, b"blocked"),
9873            _ => panic!("Expected Bytes body"),
9874        }
9875    }
9876
9877    // =========================================================================
9878    // Error Propagation Tests
9879    // =========================================================================
9880
9881    #[test]
9882    fn middleware_after_can_change_status() {
9883        struct StatusChanger;
9884
9885        impl Middleware for StatusChanger {
9886            fn after<'a>(
9887                &'a self,
9888                _ctx: &'a RequestContext,
9889                _req: &'a Request,
9890                _response: Response,
9891            ) -> BoxFuture<'a, Response> {
9892                Box::pin(async { Response::with_status(StatusCode::SERVICE_UNAVAILABLE) })
9893            }
9894        }
9895
9896        let mut stack = MiddlewareStack::new();
9897        stack.push(StatusChanger);
9898
9899        let ctx = test_context();
9900        let mut req = Request::new(crate::request::Method::Get, "/");
9901
9902        let response = futures_executor::block_on(stack.execute(&OkHandler, &ctx, &mut req));
9903
9904        // Should be changed by after hook
9905        assert_eq!(response.status().as_u16(), 503);
9906    }
9907
9908    #[test]
9909    fn middleware_after_runs_even_on_error_status() {
9910        let log = Arc::new(std::sync::Mutex::new(Vec::new()));
9911        let mut stack = MiddlewareStack::new();
9912        stack.push(OrderTrackingMiddleware::new("mw1", log.clone()));
9913
9914        let ctx = test_context();
9915        let mut req = Request::new(crate::request::Method::Get, "/");
9916
9917        let response = futures_executor::block_on(stack.execute(&ErrorHandler, &ctx, &mut req));
9918
9919        assert_eq!(response.status().as_u16(), 500);
9920
9921        let calls = log.lock().unwrap().clone();
9922        // After should run even when handler returns error status
9923        assert_eq!(calls, vec!["mw1.before", "mw1.after"]);
9924    }
9925
9926    // =========================================================================
9927    // Wildcard and Regex Matching Tests
9928    // =========================================================================
9929
9930    #[test]
9931    fn wildcard_match_simple() {
9932        assert!(super::wildcard_match("*.example.com", "api.example.com"));
9933        assert!(super::wildcard_match("*.example.com", "www.example.com"));
9934        assert!(!super::wildcard_match("*.example.com", "example.com"));
9935    }
9936
9937    #[test]
9938    fn wildcard_match_suffix_pattern() {
9939        // Wildcard at start with fixed suffix - primary use case for CORS
9940        assert!(super::wildcard_match("*.txt", "file.txt"));
9941        assert!(super::wildcard_match("*.txt", "document.txt"));
9942        assert!(!super::wildcard_match("*.txt", "file.doc"));
9943        assert!(super::wildcard_match("*-suffix", "any-suffix"));
9944    }
9945
9946    #[test]
9947    fn wildcard_match_no_wildcard() {
9948        assert!(super::wildcard_match("exact", "exact"));
9949        assert!(!super::wildcard_match("exact", "different"));
9950    }
9951
9952    #[test]
9953    fn regex_match_anchored() {
9954        assert!(super::regex_match("^hello$", "hello"));
9955        assert!(!super::regex_match("^hello$", "hello world"));
9956        assert!(!super::regex_match("^hello$", "say hello"));
9957    }
9958
9959    #[test]
9960    fn regex_match_dot_wildcard() {
9961        assert!(super::regex_match("h.llo", "hello"));
9962        assert!(super::regex_match("h.llo", "hallo"));
9963    }
9964
9965    #[test]
9966    fn regex_match_star() {
9967        assert!(super::regex_match("hel*o", "hello"));
9968        assert!(super::regex_match("hel*o", "helo"));
9969        assert!(super::regex_match("hel*o", "hellllllo"));
9970    }
9971
9972    // =========================================================================
9973    // Middleware Trait Default Implementation Tests
9974    // =========================================================================
9975
9976    #[test]
9977    fn middleware_default_before_continues() {
9978        struct DefaultBefore;
9979        impl Middleware for DefaultBefore {}
9980
9981        let mw = DefaultBefore;
9982        let ctx = test_context();
9983        let mut req = Request::new(crate::request::Method::Get, "/");
9984
9985        let result = futures_executor::block_on(mw.before(&ctx, &mut req));
9986        assert!(matches!(result, ControlFlow::Continue));
9987    }
9988
9989    #[test]
9990    fn middleware_default_after_passes_through() {
9991        struct DefaultAfter;
9992        impl Middleware for DefaultAfter {}
9993
9994        let mw = DefaultAfter;
9995        let ctx = test_context();
9996        let req = Request::new(crate::request::Method::Get, "/");
9997        let response = Response::with_status(StatusCode::CREATED);
9998
9999        let result = futures_executor::block_on(mw.after(&ctx, &req, response));
10000        assert_eq!(result.status().as_u16(), 201);
10001    }
10002
10003    #[test]
10004    fn middleware_default_name_is_type_name() {
10005        struct MyCustomMiddleware;
10006        impl Middleware for MyCustomMiddleware {}
10007
10008        let mw = MyCustomMiddleware;
10009        assert!(mw.name().contains("MyCustomMiddleware"));
10010    }
10011
10012    // =========================================================================
10013    // Security Headers Middleware Tests
10014    // =========================================================================
10015
10016    #[test]
10017    fn security_headers_default_config() {
10018        let config = SecurityHeadersConfig::default();
10019        assert_eq!(config.x_content_type_options, Some("nosniff"));
10020        assert_eq!(config.x_frame_options, Some(XFrameOptions::Deny));
10021        assert_eq!(config.x_xss_protection, Some("0"));
10022        assert!(config.content_security_policy.is_none());
10023        assert!(config.hsts.is_none());
10024        assert_eq!(
10025            config.referrer_policy,
10026            Some(ReferrerPolicy::StrictOriginWhenCrossOrigin)
10027        );
10028        assert!(config.permissions_policy.is_none());
10029    }
10030
10031    #[test]
10032    fn security_headers_none_config() {
10033        let config = SecurityHeadersConfig::none();
10034        assert!(config.x_content_type_options.is_none());
10035        assert!(config.x_frame_options.is_none());
10036        assert!(config.x_xss_protection.is_none());
10037        assert!(config.content_security_policy.is_none());
10038        assert!(config.hsts.is_none());
10039        assert!(config.referrer_policy.is_none());
10040        assert!(config.permissions_policy.is_none());
10041    }
10042
10043    #[test]
10044    fn security_headers_strict_config() {
10045        let config = SecurityHeadersConfig::strict();
10046        assert_eq!(config.x_content_type_options, Some("nosniff"));
10047        assert_eq!(config.x_frame_options, Some(XFrameOptions::Deny));
10048        assert_eq!(
10049            config.content_security_policy,
10050            Some("default-src 'self'".to_string())
10051        );
10052        assert_eq!(config.hsts, Some((31536000, true, false)));
10053        assert_eq!(config.referrer_policy, Some(ReferrerPolicy::NoReferrer));
10054        assert!(config.permissions_policy.is_some());
10055    }
10056
10057    #[test]
10058    fn security_headers_config_builder() {
10059        let config = SecurityHeadersConfig::new()
10060            .x_frame_options(Some(XFrameOptions::SameOrigin))
10061            .content_security_policy("default-src 'self'")
10062            .hsts(86400, false, false)
10063            .referrer_policy(Some(ReferrerPolicy::Origin));
10064
10065        assert_eq!(config.x_frame_options, Some(XFrameOptions::SameOrigin));
10066        assert_eq!(
10067            config.content_security_policy,
10068            Some("default-src 'self'".to_string())
10069        );
10070        assert_eq!(config.hsts, Some((86400, false, false)));
10071        assert_eq!(config.referrer_policy, Some(ReferrerPolicy::Origin));
10072    }
10073
10074    #[test]
10075    fn security_headers_hsts_value_format() {
10076        // Basic HSTS
10077        let config = SecurityHeadersConfig::none().hsts(3600, false, false);
10078        assert_eq!(config.build_hsts_value(), Some("max-age=3600".to_string()));
10079
10080        // With includeSubDomains
10081        let config = SecurityHeadersConfig::none().hsts(3600, true, false);
10082        assert_eq!(
10083            config.build_hsts_value(),
10084            Some("max-age=3600; includeSubDomains".to_string())
10085        );
10086
10087        // With preload
10088        let config = SecurityHeadersConfig::none().hsts(3600, false, true);
10089        assert_eq!(
10090            config.build_hsts_value(),
10091            Some("max-age=3600; preload".to_string())
10092        );
10093
10094        // With both
10095        let config = SecurityHeadersConfig::none().hsts(3600, true, true);
10096        assert_eq!(
10097            config.build_hsts_value(),
10098            Some("max-age=3600; includeSubDomains; preload".to_string())
10099        );
10100    }
10101
10102    #[test]
10103    fn security_headers_middleware_adds_default_headers() {
10104        let mw = SecurityHeaders::new();
10105        let ctx = test_context();
10106        let req = Request::new(crate::request::Method::Get, "/");
10107        let response = Response::ok();
10108
10109        let result = futures_executor::block_on(mw.after(&ctx, &req, response));
10110
10111        // Check that default headers are present
10112        assert!(header_value(&result, "X-Content-Type-Options").is_some());
10113        assert!(header_value(&result, "X-Frame-Options").is_some());
10114        assert!(header_value(&result, "X-XSS-Protection").is_some());
10115        assert!(header_value(&result, "Referrer-Policy").is_some());
10116
10117        // Check that optional headers are NOT present by default
10118        assert!(header_value(&result, "Content-Security-Policy").is_none());
10119        assert!(header_value(&result, "Strict-Transport-Security").is_none());
10120        assert!(header_value(&result, "Permissions-Policy").is_none());
10121    }
10122
10123    #[test]
10124    fn security_headers_middleware_with_csp() {
10125        let config = SecurityHeadersConfig::new()
10126            .content_security_policy("default-src 'self'; script-src 'self' 'unsafe-inline'");
10127        let mw = SecurityHeaders::with_config(config);
10128        let ctx = test_context();
10129        let req = Request::new(crate::request::Method::Get, "/");
10130        let response = Response::ok();
10131
10132        let result = futures_executor::block_on(mw.after(&ctx, &req, response));
10133
10134        let csp = header_value(&result, "Content-Security-Policy");
10135        assert!(csp.is_some());
10136        assert_eq!(
10137            csp.unwrap(),
10138            "default-src 'self'; script-src 'self' 'unsafe-inline'"
10139        );
10140    }
10141
10142    #[test]
10143    fn security_headers_middleware_with_hsts() {
10144        let config = SecurityHeadersConfig::new().hsts(31536000, true, false);
10145        let mw = SecurityHeaders::with_config(config);
10146        let ctx = test_context();
10147        let req = Request::new(crate::request::Method::Get, "/");
10148        let response = Response::ok();
10149
10150        let result = futures_executor::block_on(mw.after(&ctx, &req, response));
10151
10152        let hsts = header_value(&result, "Strict-Transport-Security");
10153        assert!(hsts.is_some());
10154        assert_eq!(hsts.unwrap(), "max-age=31536000; includeSubDomains");
10155    }
10156
10157    #[test]
10158    fn security_headers_middleware_name() {
10159        let mw = SecurityHeaders::new();
10160        assert_eq!(mw.name(), "SecurityHeaders");
10161    }
10162
10163    #[test]
10164    fn x_frame_options_values() {
10165        assert_eq!(XFrameOptions::Deny.as_bytes(), b"DENY");
10166        assert_eq!(XFrameOptions::SameOrigin.as_bytes(), b"SAMEORIGIN");
10167    }
10168
10169    #[test]
10170    fn referrer_policy_values() {
10171        assert_eq!(ReferrerPolicy::NoReferrer.as_bytes(), b"no-referrer");
10172        assert_eq!(
10173            ReferrerPolicy::NoReferrerWhenDowngrade.as_bytes(),
10174            b"no-referrer-when-downgrade"
10175        );
10176        assert_eq!(ReferrerPolicy::Origin.as_bytes(), b"origin");
10177        assert_eq!(
10178            ReferrerPolicy::OriginWhenCrossOrigin.as_bytes(),
10179            b"origin-when-cross-origin"
10180        );
10181        assert_eq!(ReferrerPolicy::SameOrigin.as_bytes(), b"same-origin");
10182        assert_eq!(ReferrerPolicy::StrictOrigin.as_bytes(), b"strict-origin");
10183        assert_eq!(
10184            ReferrerPolicy::StrictOriginWhenCrossOrigin.as_bytes(),
10185            b"strict-origin-when-cross-origin"
10186        );
10187        assert_eq!(ReferrerPolicy::UnsafeUrl.as_bytes(), b"unsafe-url");
10188    }
10189
10190    #[test]
10191    fn security_headers_strict_preset() {
10192        let mw = SecurityHeaders::strict();
10193        let ctx = test_context();
10194        let req = Request::new(crate::request::Method::Get, "/");
10195        let response = Response::ok();
10196
10197        let result = futures_executor::block_on(mw.after(&ctx, &req, response));
10198
10199        // All headers should be present with strict config
10200        assert!(header_value(&result, "X-Content-Type-Options").is_some());
10201        assert!(header_value(&result, "X-Frame-Options").is_some());
10202        assert!(header_value(&result, "Content-Security-Policy").is_some());
10203        assert!(header_value(&result, "Strict-Transport-Security").is_some());
10204        assert!(header_value(&result, "Referrer-Policy").is_some());
10205        assert!(header_value(&result, "Permissions-Policy").is_some());
10206    }
10207
10208    #[test]
10209    fn security_headers_config_clearing_methods() {
10210        let config = SecurityHeadersConfig::strict()
10211            .no_content_security_policy()
10212            .no_hsts()
10213            .no_permissions_policy();
10214
10215        assert!(config.content_security_policy.is_none());
10216        assert!(config.hsts.is_none());
10217        assert!(config.permissions_policy.is_none());
10218    }
10219
10220    // =========================================================================
10221    // CSRF Middleware Tests
10222    // =========================================================================
10223
10224    #[test]
10225    fn csrf_token_generate_produces_unique_tokens() {
10226        let token1 = CsrfToken::generate();
10227        let token2 = CsrfToken::generate();
10228        assert_ne!(token1, token2);
10229        assert!(!token1.as_str().is_empty());
10230        assert!(!token2.as_str().is_empty());
10231    }
10232
10233    #[test]
10234    fn csrf_token_display() {
10235        let token = CsrfToken::new("test-token-123");
10236        assert_eq!(format!("{}", token), "test-token-123");
10237    }
10238
10239    #[test]
10240    fn csrf_config_defaults() {
10241        let config = CsrfConfig::default();
10242        assert_eq!(config.cookie_name, "csrf_token");
10243        assert_eq!(config.header_name, "x-csrf-token");
10244        assert_eq!(config.mode, CsrfMode::DoubleSubmit);
10245        assert!(!config.rotate_token);
10246        assert!(config.production);
10247        assert!(config.error_message.is_none());
10248    }
10249
10250    #[test]
10251    fn csrf_config_builder() {
10252        let config = CsrfConfig::new()
10253            .cookie_name("XSRF-TOKEN")
10254            .header_name("X-XSRF-Token")
10255            .mode(CsrfMode::HeaderOnly)
10256            .rotate_token(true)
10257            .production(false)
10258            .error_message("Custom CSRF error");
10259
10260        assert_eq!(config.cookie_name, "XSRF-TOKEN");
10261        assert_eq!(config.header_name, "X-XSRF-Token");
10262        assert_eq!(config.mode, CsrfMode::HeaderOnly);
10263        assert!(config.rotate_token);
10264        assert!(!config.production);
10265        assert_eq!(config.error_message, Some("Custom CSRF error".to_string()));
10266    }
10267
10268    #[test]
10269    fn csrf_middleware_allows_get_without_token() {
10270        let csrf = CsrfMiddleware::new();
10271        let ctx = test_context();
10272        let mut req = Request::new(crate::request::Method::Get, "/");
10273
10274        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10275        assert!(result.is_continue());
10276        // Token should be generated and stored
10277        assert!(req.get_extension::<CsrfToken>().is_some());
10278    }
10279
10280    #[test]
10281    fn csrf_middleware_allows_head_without_token() {
10282        let csrf = CsrfMiddleware::new();
10283        let ctx = test_context();
10284        let mut req = Request::new(crate::request::Method::Head, "/");
10285
10286        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10287        assert!(result.is_continue());
10288    }
10289
10290    #[test]
10291    fn csrf_middleware_allows_options_without_token() {
10292        let csrf = CsrfMiddleware::new();
10293        let ctx = test_context();
10294        let mut req = Request::new(crate::request::Method::Options, "/");
10295
10296        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10297        assert!(result.is_continue());
10298    }
10299
10300    #[test]
10301    fn csrf_middleware_blocks_post_without_token() {
10302        let csrf = CsrfMiddleware::new();
10303        let ctx = test_context();
10304        let mut req = Request::new(crate::request::Method::Post, "/");
10305
10306        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10307        assert!(result.is_break());
10308
10309        if let ControlFlow::Break(response) = result {
10310            assert_eq!(response.status(), StatusCode::FORBIDDEN);
10311        }
10312    }
10313
10314    #[test]
10315    fn csrf_middleware_blocks_put_without_token() {
10316        let csrf = CsrfMiddleware::new();
10317        let ctx = test_context();
10318        let mut req = Request::new(crate::request::Method::Put, "/");
10319
10320        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10321        assert!(result.is_break());
10322    }
10323
10324    #[test]
10325    fn csrf_middleware_blocks_delete_without_token() {
10326        let csrf = CsrfMiddleware::new();
10327        let ctx = test_context();
10328        let mut req = Request::new(crate::request::Method::Delete, "/");
10329
10330        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10331        assert!(result.is_break());
10332    }
10333
10334    #[test]
10335    fn csrf_middleware_blocks_patch_without_token() {
10336        let csrf = CsrfMiddleware::new();
10337        let ctx = test_context();
10338        let mut req = Request::new(crate::request::Method::Patch, "/");
10339
10340        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10341        assert!(result.is_break());
10342    }
10343
10344    #[test]
10345    fn csrf_middleware_allows_post_with_matching_tokens() {
10346        let csrf = CsrfMiddleware::new();
10347        let ctx = test_context();
10348        let mut req = Request::new(crate::request::Method::Post, "/");
10349
10350        // Set matching cookie and header
10351        let token = "valid-csrf-token-12345";
10352        req.headers_mut()
10353            .insert("cookie", format!("csrf_token={}", token).into_bytes());
10354        req.headers_mut()
10355            .insert("x-csrf-token", token.as_bytes().to_vec());
10356
10357        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10358        assert!(result.is_continue());
10359
10360        // Token should be stored in extensions
10361        let stored_token = req.get_extension::<CsrfToken>().unwrap();
10362        assert_eq!(stored_token.as_str(), token);
10363    }
10364
10365    #[test]
10366    fn csrf_middleware_blocks_post_with_mismatched_tokens() {
10367        let csrf = CsrfMiddleware::new();
10368        let ctx = test_context();
10369        let mut req = Request::new(crate::request::Method::Post, "/");
10370
10371        // Set mismatched cookie and header
10372        req.headers_mut()
10373            .insert("cookie", b"csrf_token=token-in-cookie".to_vec());
10374        req.headers_mut()
10375            .insert("x-csrf-token", b"different-token".to_vec());
10376
10377        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10378        assert!(result.is_break());
10379
10380        if let ControlFlow::Break(response) = result {
10381            assert_eq!(response.status(), StatusCode::FORBIDDEN);
10382        }
10383    }
10384
10385    #[test]
10386    fn csrf_middleware_blocks_post_with_header_only_in_double_submit_mode() {
10387        let csrf = CsrfMiddleware::new();
10388        let ctx = test_context();
10389        let mut req = Request::new(crate::request::Method::Post, "/");
10390
10391        // Only header, no cookie
10392        req.headers_mut()
10393            .insert("x-csrf-token", b"some-token".to_vec());
10394
10395        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10396        assert!(result.is_break());
10397    }
10398
10399    #[test]
10400    fn csrf_middleware_blocks_post_with_cookie_only_in_double_submit_mode() {
10401        let csrf = CsrfMiddleware::new();
10402        let ctx = test_context();
10403        let mut req = Request::new(crate::request::Method::Post, "/");
10404
10405        // Only cookie, no header
10406        req.headers_mut()
10407            .insert("cookie", b"csrf_token=some-token".to_vec());
10408
10409        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10410        assert!(result.is_break());
10411    }
10412
10413    #[test]
10414    fn csrf_middleware_header_only_mode_accepts_header_token() {
10415        let csrf = CsrfMiddleware::with_config(CsrfConfig::new().mode(CsrfMode::HeaderOnly));
10416        let ctx = test_context();
10417        let mut req = Request::new(crate::request::Method::Post, "/");
10418
10419        req.headers_mut()
10420            .insert("x-csrf-token", b"valid-token".to_vec());
10421
10422        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10423        assert!(result.is_continue());
10424    }
10425
10426    #[test]
10427    fn csrf_middleware_header_only_mode_rejects_empty_header() {
10428        let csrf = CsrfMiddleware::with_config(CsrfConfig::new().mode(CsrfMode::HeaderOnly));
10429        let ctx = test_context();
10430        let mut req = Request::new(crate::request::Method::Post, "/");
10431
10432        req.headers_mut().insert("x-csrf-token", b"".to_vec());
10433
10434        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10435        assert!(result.is_break());
10436    }
10437
10438    #[test]
10439    fn csrf_middleware_sets_cookie_on_get() {
10440        let csrf = CsrfMiddleware::new();
10441        let ctx = test_context();
10442        let mut req = Request::new(crate::request::Method::Get, "/");
10443
10444        // Run before to generate token
10445        let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10446
10447        // Run after to set cookie
10448        let response = Response::ok();
10449        let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
10450
10451        // Check Set-Cookie header
10452        let cookie_value = header_value(&result, "set-cookie");
10453        assert!(cookie_value.is_some());
10454
10455        let cookie_value = cookie_value.unwrap();
10456        assert!(cookie_value.starts_with("csrf_token="));
10457        assert!(cookie_value.contains("SameSite=Strict"));
10458        assert!(cookie_value.contains("Secure")); // Production mode
10459    }
10460
10461    #[test]
10462    fn csrf_middleware_no_secure_in_dev_mode() {
10463        let csrf = CsrfMiddleware::with_config(CsrfConfig::new().production(false));
10464        let ctx = test_context();
10465        let mut req = Request::new(crate::request::Method::Get, "/");
10466
10467        let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10468
10469        let response = Response::ok();
10470        let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
10471
10472        let cookie_value = header_value(&result, "set-cookie").unwrap();
10473        assert!(!cookie_value.contains("Secure")); // No Secure in dev mode
10474    }
10475
10476    #[test]
10477    fn csrf_middleware_does_not_set_cookie_if_already_present() {
10478        let csrf = CsrfMiddleware::new();
10479        let ctx = test_context();
10480        let mut req = Request::new(crate::request::Method::Get, "/");
10481
10482        // Cookie already present
10483        req.headers_mut()
10484            .insert("cookie", b"csrf_token=existing-token".to_vec());
10485
10486        let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10487
10488        let response = Response::ok();
10489        let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
10490
10491        // Should not set a new cookie
10492        assert!(header_value(&result, "set-cookie").is_none());
10493    }
10494
10495    #[test]
10496    fn csrf_middleware_rotates_token_when_configured() {
10497        let csrf = CsrfMiddleware::with_config(CsrfConfig::new().rotate_token(true));
10498        let ctx = test_context();
10499        let mut req = Request::new(crate::request::Method::Get, "/");
10500
10501        // Cookie already present
10502        req.headers_mut()
10503            .insert("cookie", b"csrf_token=old-token".to_vec());
10504
10505        let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10506
10507        let response = Response::ok();
10508        let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
10509
10510        // Should set a new cookie even though one exists
10511        assert!(header_value(&result, "set-cookie").is_some());
10512    }
10513
10514    #[test]
10515    fn csrf_middleware_custom_header_name() {
10516        let csrf = CsrfMiddleware::with_config(
10517            CsrfConfig::new()
10518                .header_name("X-XSRF-Token")
10519                .cookie_name("XSRF-TOKEN"),
10520        );
10521        let ctx = test_context();
10522        let mut req = Request::new(crate::request::Method::Post, "/");
10523
10524        let token = "custom-token-value";
10525        req.headers_mut()
10526            .insert("cookie", format!("XSRF-TOKEN={}", token).into_bytes());
10527        req.headers_mut()
10528            .insert("x-xsrf-token", token.as_bytes().to_vec());
10529
10530        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10531        assert!(result.is_continue());
10532    }
10533
10534    #[test]
10535    fn csrf_middleware_error_response_is_json() {
10536        let csrf = CsrfMiddleware::new();
10537        let ctx = test_context();
10538        let mut req = Request::new(crate::request::Method::Post, "/");
10539
10540        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10541
10542        if let ControlFlow::Break(response) = result {
10543            let content_type = header_value(&response, "content-type");
10544            assert_eq!(content_type, Some("application/json".to_string()));
10545
10546            // Check body contains proper error structure
10547            if let ResponseBody::Bytes(body) = response.body_ref() {
10548                let body_str = std::str::from_utf8(body).unwrap();
10549                assert!(body_str.contains("csrf_error"));
10550                assert!(body_str.contains("x-csrf-token"));
10551            } else {
10552                panic!("Expected Bytes body");
10553            }
10554        } else {
10555            panic!("Expected Break");
10556        }
10557    }
10558
10559    #[test]
10560    fn csrf_middleware_custom_error_message() {
10561        let csrf = CsrfMiddleware::with_config(
10562            CsrfConfig::new().error_message("Access denied: invalid security token"),
10563        );
10564        let ctx = test_context();
10565        let mut req = Request::new(crate::request::Method::Post, "/");
10566
10567        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10568
10569        if let ControlFlow::Break(response) = result {
10570            if let ResponseBody::Bytes(body) = response.body_ref() {
10571                let body_str = std::str::from_utf8(body).unwrap();
10572                assert!(body_str.contains("Access denied: invalid security token"));
10573            }
10574        }
10575    }
10576
10577    #[test]
10578    fn csrf_middleware_name() {
10579        let csrf = CsrfMiddleware::new();
10580        assert_eq!(csrf.name(), "CSRF");
10581    }
10582
10583    #[test]
10584    fn csrf_middleware_parses_cookie_with_multiple_cookies() {
10585        let csrf = CsrfMiddleware::new();
10586        let ctx = test_context();
10587        let mut req = Request::new(crate::request::Method::Post, "/");
10588
10589        // Multiple cookies in the header
10590        let token = "the-csrf-token";
10591        req.headers_mut().insert(
10592            "cookie",
10593            format!("session=abc123; csrf_token={}; user=test", token).into_bytes(),
10594        );
10595        req.headers_mut()
10596            .insert("x-csrf-token", token.as_bytes().to_vec());
10597
10598        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10599        assert!(result.is_continue());
10600    }
10601
10602    #[test]
10603    fn csrf_middleware_handles_empty_token_value() {
10604        let csrf = CsrfMiddleware::new();
10605        let ctx = test_context();
10606        let mut req = Request::new(crate::request::Method::Post, "/");
10607
10608        // Empty token values
10609        req.headers_mut().insert("cookie", b"csrf_token=".to_vec());
10610        req.headers_mut().insert("x-csrf-token", b"".to_vec());
10611
10612        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10613        assert!(result.is_break()); // Should reject empty tokens
10614    }
10615
10616    // ---- Comprehensive CSRF tests (bd-3v0c) ----
10617
10618    #[test]
10619    fn csrf_token_generate_many_unique() {
10620        // Generate many tokens and verify all are unique
10621        let mut tokens = std::collections::HashSet::new();
10622        for _ in 0..100 {
10623            let token = CsrfToken::generate();
10624            assert!(
10625                tokens.insert(token.0.clone()),
10626                "Duplicate token generated: {}",
10627                token.0
10628            );
10629        }
10630        assert_eq!(tokens.len(), 100);
10631    }
10632
10633    #[test]
10634    fn csrf_token_generate_format_is_hex() {
10635        let token = CsrfToken::generate();
10636        let s = token.as_str();
10637        // Token should be all hex characters, at least 64 chars (32 bytes from urandom)
10638        assert!(
10639            s.len() >= 64,
10640            "Expected at least 64 hex characters, got {} in '{s}'",
10641            s.len()
10642        );
10643        assert!(
10644            s.chars().all(|c| c.is_ascii_hexdigit()),
10645            "Non-hex character in token: {s}"
10646        );
10647    }
10648
10649    #[test]
10650    fn csrf_token_generate_minimum_length() {
10651        let token = CsrfToken::generate();
10652        // 32 bytes from urandom = 64 hex chars
10653        assert!(
10654            token.as_str().len() >= 64,
10655            "Token too short: {} (len={})",
10656            token.as_str(),
10657            token.as_str().len()
10658        );
10659    }
10660
10661    #[test]
10662    fn csrf_token_from_str() {
10663        let token: CsrfToken = "my-token".into();
10664        assert_eq!(token.as_str(), "my-token");
10665        assert_eq!(token.0, "my-token");
10666    }
10667
10668    #[test]
10669    fn csrf_token_clone_eq() {
10670        let t1 = CsrfToken::new("abc");
10671        let t2 = t1.clone();
10672        assert_eq!(t1, t2);
10673        assert_eq!(t1.as_str(), t2.as_str());
10674    }
10675
10676    #[test]
10677    fn csrf_middleware_allows_trace_without_token() {
10678        let csrf = CsrfMiddleware::new();
10679        let ctx = test_context();
10680        let mut req = Request::new(crate::request::Method::Trace, "/");
10681
10682        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10683        assert!(result.is_continue());
10684        // Token should be generated
10685        assert!(req.get_extension::<CsrfToken>().is_some());
10686    }
10687
10688    #[test]
10689    fn csrf_safe_method_generates_token_into_extension() {
10690        let csrf = CsrfMiddleware::new();
10691        let ctx = test_context();
10692
10693        for method in [
10694            crate::request::Method::Get,
10695            crate::request::Method::Head,
10696            crate::request::Method::Options,
10697            crate::request::Method::Trace,
10698        ] {
10699            let mut req = Request::new(method, "/test");
10700            let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10701            assert!(result.is_continue());
10702            let token = req.get_extension::<CsrfToken>().expect("token missing");
10703            assert!(!token.as_str().is_empty());
10704        }
10705    }
10706
10707    #[test]
10708    fn csrf_safe_method_preserves_existing_cookie_token() {
10709        let csrf = CsrfMiddleware::new();
10710        let ctx = test_context();
10711        let mut req = Request::new(crate::request::Method::Get, "/");
10712        req.headers_mut()
10713            .insert("cookie", b"csrf_token=my-existing-token".to_vec());
10714
10715        let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10716
10717        // Extension should contain the existing cookie token, not a new one
10718        let token = req.get_extension::<CsrfToken>().unwrap();
10719        assert_eq!(token.as_str(), "my-existing-token");
10720    }
10721
10722    #[test]
10723    fn csrf_valid_post_stores_token_in_extension() {
10724        let csrf = CsrfMiddleware::new();
10725        let ctx = test_context();
10726        let mut req = Request::new(crate::request::Method::Post, "/submit");
10727
10728        let tk = "valid-token-xyz";
10729        req.headers_mut()
10730            .insert("cookie", format!("csrf_token={}", tk).into_bytes());
10731        req.headers_mut()
10732            .insert("x-csrf-token", tk.as_bytes().to_vec());
10733
10734        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10735        assert!(result.is_continue());
10736        let stored = req.get_extension::<CsrfToken>().unwrap();
10737        assert_eq!(stored.as_str(), tk);
10738    }
10739
10740    #[test]
10741    fn csrf_double_submit_both_empty_strings_rejected() {
10742        let csrf = CsrfMiddleware::new();
10743        let ctx = test_context();
10744        let mut req = Request::new(crate::request::Method::Post, "/");
10745
10746        // Both cookie and header have empty string values
10747        req.headers_mut().insert("cookie", b"csrf_token=".to_vec());
10748        req.headers_mut().insert("x-csrf-token", b"".to_vec());
10749
10750        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10751        assert!(result.is_break());
10752    }
10753
10754    #[test]
10755    fn csrf_double_submit_matching_empty_rejected() {
10756        // Even if both are technically "equal" (empty), should reject
10757        let csrf = CsrfMiddleware::new();
10758        let ctx = test_context();
10759        let mut req = Request::new(crate::request::Method::Post, "/");
10760
10761        req.headers_mut().insert("cookie", b"csrf_token=".to_vec());
10762        req.headers_mut().insert("x-csrf-token", b"".to_vec());
10763
10764        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10765        assert!(
10766            result.is_break(),
10767            "Empty matching tokens should be rejected"
10768        );
10769    }
10770
10771    #[test]
10772    fn csrf_header_only_mode_does_not_need_cookie() {
10773        let csrf = CsrfMiddleware::with_config(CsrfConfig::new().mode(CsrfMode::HeaderOnly));
10774        let ctx = test_context();
10775        let mut req = Request::new(crate::request::Method::Post, "/");
10776
10777        // Header only, no cookie
10778        req.headers_mut()
10779            .insert("x-csrf-token", b"header-only-token".to_vec());
10780
10781        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10782        assert!(result.is_continue());
10783        let token = req.get_extension::<CsrfToken>().unwrap();
10784        assert_eq!(token.as_str(), "header-only-token");
10785    }
10786
10787    #[test]
10788    fn csrf_header_only_mode_ignores_mismatched_cookie() {
10789        // In HeaderOnly mode, the cookie value is irrelevant
10790        let csrf = CsrfMiddleware::with_config(CsrfConfig::new().mode(CsrfMode::HeaderOnly));
10791        let ctx = test_context();
10792        let mut req = Request::new(crate::request::Method::Post, "/");
10793
10794        req.headers_mut()
10795            .insert("cookie", b"csrf_token=different-value".to_vec());
10796        req.headers_mut()
10797            .insert("x-csrf-token", b"header-value".to_vec());
10798
10799        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10800        assert!(result.is_continue(), "HeaderOnly should ignore cookie");
10801    }
10802
10803    #[test]
10804    fn csrf_header_only_mode_rejects_no_header() {
10805        let csrf = CsrfMiddleware::with_config(CsrfConfig::new().mode(CsrfMode::HeaderOnly));
10806        let ctx = test_context();
10807        let mut req = Request::new(crate::request::Method::Post, "/");
10808        // No header at all
10809        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10810        assert!(result.is_break());
10811    }
10812
10813    #[test]
10814    fn csrf_header_only_error_message_mentions_header() {
10815        let csrf = CsrfMiddleware::with_config(CsrfConfig::new().mode(CsrfMode::HeaderOnly));
10816        let ctx = test_context();
10817        let mut req = Request::new(crate::request::Method::Post, "/");
10818
10819        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
10820        if let ControlFlow::Break(response) = result {
10821            if let ResponseBody::Bytes(body) = response.body_ref() {
10822                let body_str = std::str::from_utf8(body).unwrap();
10823                assert!(
10824                    body_str.contains("missing in header"),
10825                    "Expected 'missing in header' in: {}",
10826                    body_str
10827                );
10828            }
10829        } else {
10830            panic!("Expected Break");
10831        }
10832    }
10833
10834    #[test]
10835    fn csrf_mismatch_error_differs_from_missing_error() {
10836        let csrf = CsrfMiddleware::new();
10837        let ctx = test_context();
10838
10839        // Missing: no header or cookie
10840        let mut req_missing = Request::new(crate::request::Method::Post, "/");
10841        let missing_result = futures_executor::block_on(csrf.before(&ctx, &mut req_missing));
10842        let missing_body = match missing_result {
10843            ControlFlow::Break(r) => match r.body_ref() {
10844                ResponseBody::Bytes(b) => std::str::from_utf8(b).unwrap().to_string(),
10845                ResponseBody::Empty | ResponseBody::Stream(_) => panic!("Expected Bytes"),
10846            },
10847            ControlFlow::Continue => panic!("Expected Break"),
10848        };
10849
10850        // Mismatch: both present but different
10851        let mut req_mismatch = Request::new(crate::request::Method::Post, "/");
10852        req_mismatch
10853            .headers_mut()
10854            .insert("cookie", b"csrf_token=aaa".to_vec());
10855        req_mismatch
10856            .headers_mut()
10857            .insert("x-csrf-token", b"bbb".to_vec());
10858        let mismatch_result = futures_executor::block_on(csrf.before(&ctx, &mut req_mismatch));
10859        let mismatch_body = match mismatch_result {
10860            ControlFlow::Break(r) => match r.body_ref() {
10861                ResponseBody::Bytes(b) => std::str::from_utf8(b).unwrap().to_string(),
10862                ResponseBody::Empty | ResponseBody::Stream(_) => panic!("Expected Bytes"),
10863            },
10864            ControlFlow::Continue => panic!("Expected Break"),
10865        };
10866
10867        // Error messages should differ
10868        assert_ne!(
10869            missing_body, mismatch_body,
10870            "Missing vs mismatch should have different error messages"
10871        );
10872        assert!(missing_body.contains("missing"));
10873        assert!(mismatch_body.contains("mismatch"));
10874    }
10875
10876    #[test]
10877    fn csrf_cookie_not_httponly() {
10878        // CSRF cookies MUST be readable by JavaScript (no HttpOnly)
10879        let csrf = CsrfMiddleware::new();
10880        let ctx = test_context();
10881        let mut req = Request::new(crate::request::Method::Get, "/");
10882
10883        let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10884        let response = Response::ok();
10885        let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
10886
10887        let cookie_value = header_value(&result, "set-cookie").unwrap();
10888        assert!(
10889            !cookie_value.to_lowercase().contains("httponly"),
10890            "CSRF cookie must NOT be HttpOnly (needs JS access), got: {}",
10891            cookie_value
10892        );
10893    }
10894
10895    #[test]
10896    fn csrf_cookie_has_path_slash() {
10897        let csrf = CsrfMiddleware::new();
10898        let ctx = test_context();
10899        let mut req = Request::new(crate::request::Method::Get, "/");
10900
10901        let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10902        let response = Response::ok();
10903        let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
10904
10905        let cookie_value = header_value(&result, "set-cookie").unwrap();
10906        assert!(
10907            cookie_value.contains("Path=/"),
10908            "Cookie should have Path=/, got: {}",
10909            cookie_value
10910        );
10911    }
10912
10913    #[test]
10914    fn csrf_cookie_has_samesite_strict() {
10915        let csrf = CsrfMiddleware::new();
10916        let ctx = test_context();
10917        let mut req = Request::new(crate::request::Method::Get, "/");
10918
10919        let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10920        let response = Response::ok();
10921        let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
10922
10923        let cookie_value = header_value(&result, "set-cookie").unwrap();
10924        assert!(
10925            cookie_value.contains("SameSite=Strict"),
10926            "Cookie should have SameSite=Strict, got: {}",
10927            cookie_value
10928        );
10929    }
10930
10931    #[test]
10932    fn csrf_production_mode_sets_secure_flag() {
10933        let csrf = CsrfMiddleware::with_config(CsrfConfig::new().production(true));
10934        let ctx = test_context();
10935        let mut req = Request::new(crate::request::Method::Get, "/");
10936
10937        let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10938        let response = Response::ok();
10939        let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
10940
10941        let cookie_value = header_value(&result, "set-cookie").unwrap();
10942        assert!(
10943            cookie_value.contains("Secure"),
10944            "Production cookie must have Secure flag, got: {}",
10945            cookie_value
10946        );
10947    }
10948
10949    #[test]
10950    fn csrf_no_set_cookie_on_post_response() {
10951        // Set-Cookie should only be added for safe methods, not POST
10952        let csrf = CsrfMiddleware::new();
10953        let ctx = test_context();
10954        let mut req = Request::new(crate::request::Method::Post, "/");
10955
10956        let token = "valid-token";
10957        req.headers_mut()
10958            .insert("cookie", format!("csrf_token={}", token).into_bytes());
10959        req.headers_mut()
10960            .insert("x-csrf-token", token.as_bytes().to_vec());
10961
10962        let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10963        let response = Response::ok();
10964        let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
10965
10966        assert!(
10967            header_value(&result, "set-cookie").is_none(),
10968            "POST response should not set CSRF cookie"
10969        );
10970    }
10971
10972    #[test]
10973    fn csrf_head_method_sets_cookie() {
10974        let csrf = CsrfMiddleware::new();
10975        let ctx = test_context();
10976        let mut req = Request::new(crate::request::Method::Head, "/");
10977
10978        let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10979        let response = Response::ok();
10980        let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
10981
10982        assert!(
10983            header_value(&result, "set-cookie").is_some(),
10984            "HEAD response should set CSRF cookie"
10985        );
10986    }
10987
10988    #[test]
10989    fn csrf_options_method_sets_cookie() {
10990        let csrf = CsrfMiddleware::new();
10991        let ctx = test_context();
10992        let mut req = Request::new(crate::request::Method::Options, "/");
10993
10994        let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
10995        let response = Response::ok();
10996        let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
10997
10998        assert!(
10999            header_value(&result, "set-cookie").is_some(),
11000            "OPTIONS response should set CSRF cookie"
11001        );
11002    }
11003
11004    #[test]
11005    fn csrf_rotation_produces_different_token_in_cookie() {
11006        let csrf = CsrfMiddleware::with_config(CsrfConfig::new().rotate_token(true));
11007        let ctx = test_context();
11008        let mut req = Request::new(crate::request::Method::Get, "/");
11009
11010        let old_token = "old-token-value";
11011        req.headers_mut()
11012            .insert("cookie", format!("csrf_token={}", old_token).into_bytes());
11013
11014        let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
11015        let response = Response::ok();
11016        let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
11017
11018        let cookie_value = header_value(&result, "set-cookie").unwrap();
11019        // When rotation is enabled, old token is reused from cookie parse, but
11020        // the cookie IS set (which the before phase stored in extension).
11021        // The existing token from cookie is used, so cookie_value will contain old_token.
11022        // This verifies the Set-Cookie is emitted even with an existing cookie.
11023        assert!(cookie_value.starts_with("csrf_token="));
11024    }
11025
11026    #[test]
11027    fn csrf_no_rotation_skips_set_cookie_when_present() {
11028        let csrf = CsrfMiddleware::with_config(CsrfConfig::new().rotate_token(false));
11029        let ctx = test_context();
11030        let mut req = Request::new(crate::request::Method::Get, "/");
11031
11032        req.headers_mut()
11033            .insert("cookie", b"csrf_token=existing".to_vec());
11034
11035        let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
11036        let response = Response::ok();
11037        let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
11038
11039        assert!(
11040            header_value(&result, "set-cookie").is_none(),
11041            "Without rotation, should not re-set existing cookie"
11042        );
11043    }
11044
11045    #[test]
11046    fn csrf_custom_cookie_name_in_set_cookie_response() {
11047        let csrf = CsrfMiddleware::with_config(CsrfConfig::new().cookie_name("XSRF-TOKEN"));
11048        let ctx = test_context();
11049        let mut req = Request::new(crate::request::Method::Get, "/");
11050
11051        let _ = futures_executor::block_on(csrf.before(&ctx, &mut req));
11052        let response = Response::ok();
11053        let result = futures_executor::block_on(csrf.after(&ctx, &req, response));
11054
11055        let cookie_value = header_value(&result, "set-cookie").unwrap();
11056        assert!(
11057            cookie_value.starts_with("XSRF-TOKEN="),
11058            "Custom cookie name should appear in Set-Cookie, got: {}",
11059            cookie_value
11060        );
11061    }
11062
11063    #[test]
11064    fn csrf_custom_header_name_validated() {
11065        let csrf = CsrfMiddleware::with_config(
11066            CsrfConfig::new()
11067                .header_name("X-Custom-CSRF")
11068                .cookie_name("my_csrf"),
11069        );
11070        let ctx = test_context();
11071        let mut req = Request::new(crate::request::Method::Post, "/");
11072
11073        let token = "custom-tok";
11074        req.headers_mut()
11075            .insert("cookie", format!("my_csrf={}", token).into_bytes());
11076        req.headers_mut()
11077            .insert("x-custom-csrf", token.as_bytes().to_vec());
11078
11079        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
11080        assert!(result.is_continue());
11081    }
11082
11083    #[test]
11084    fn csrf_custom_header_name_wrong_header_rejected() {
11085        let csrf = CsrfMiddleware::with_config(CsrfConfig::new().header_name("X-Custom-CSRF"));
11086        let ctx = test_context();
11087        let mut req = Request::new(crate::request::Method::Post, "/");
11088
11089        let token = "some-token";
11090        req.headers_mut()
11091            .insert("cookie", format!("csrf_token={}", token).into_bytes());
11092        // Using default header name instead of custom one
11093        req.headers_mut()
11094            .insert("x-csrf-token", token.as_bytes().to_vec());
11095
11096        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
11097        assert!(result.is_break(), "Wrong header name should be rejected");
11098    }
11099
11100    #[test]
11101    fn csrf_cookie_parsing_multiple_cookies_picks_correct() {
11102        let csrf = CsrfMiddleware::new();
11103        let ctx = test_context();
11104        let mut req = Request::new(crate::request::Method::Post, "/");
11105
11106        let token = "correct-csrf";
11107        req.headers_mut().insert(
11108            "cookie",
11109            format!("session=abc; other=xyz; csrf_token={}; tracking=123", token).into_bytes(),
11110        );
11111        req.headers_mut()
11112            .insert("x-csrf-token", token.as_bytes().to_vec());
11113
11114        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
11115        assert!(result.is_continue());
11116    }
11117
11118    #[test]
11119    fn csrf_cookie_parsing_spaces_around_semicolons() {
11120        let csrf = CsrfMiddleware::new();
11121        let ctx = test_context();
11122        let mut req = Request::new(crate::request::Method::Post, "/");
11123
11124        let token = "spaced-token";
11125        req.headers_mut().insert(
11126            "cookie",
11127            format!("session=abc ;  csrf_token={}  ; other=xyz", token).into_bytes(),
11128        );
11129        req.headers_mut()
11130            .insert("x-csrf-token", token.as_bytes().to_vec());
11131
11132        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
11133        assert!(result.is_continue());
11134    }
11135
11136    #[test]
11137    fn csrf_error_response_status_is_403() {
11138        let csrf = CsrfMiddleware::new();
11139        let ctx = test_context();
11140
11141        // Test all state-changing methods return 403
11142        for method in [
11143            crate::request::Method::Post,
11144            crate::request::Method::Put,
11145            crate::request::Method::Delete,
11146            crate::request::Method::Patch,
11147        ] {
11148            let mut req = Request::new(method, "/");
11149            let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
11150            match result {
11151                ControlFlow::Break(response) => {
11152                    assert_eq!(
11153                        response.status(),
11154                        StatusCode::FORBIDDEN,
11155                        "Expected 403 for {:?}",
11156                        method
11157                    );
11158                }
11159                ControlFlow::Continue => panic!("Expected Break for {:?}", method),
11160            }
11161        }
11162    }
11163
11164    #[test]
11165    fn csrf_error_body_json_structure() {
11166        let csrf = CsrfMiddleware::new();
11167        let ctx = test_context();
11168        let mut req = Request::new(crate::request::Method::Post, "/");
11169
11170        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
11171        if let ControlFlow::Break(response) = result {
11172            if let ResponseBody::Bytes(body) = response.body_ref() {
11173                let body_str = std::str::from_utf8(body).unwrap();
11174                // Verify JSON structure
11175                let parsed: serde_json::Value = serde_json::from_str(body_str)
11176                    .unwrap_or_else(|e| panic!("Invalid JSON: {}: {}", body_str, e));
11177                assert!(parsed["detail"].is_array());
11178                let detail = &parsed["detail"][0];
11179                assert_eq!(detail["type"], "csrf_error");
11180                assert!(detail["loc"].is_array());
11181                assert_eq!(detail["loc"][0], "header");
11182                assert_eq!(detail["loc"][1], "x-csrf-token");
11183                assert!(detail["msg"].is_string());
11184            } else {
11185                panic!("Expected Bytes body");
11186            }
11187        } else {
11188            panic!("Expected Break");
11189        }
11190    }
11191
11192    #[test]
11193    fn csrf_default_trait() {
11194        let csrf = CsrfMiddleware::default();
11195        assert_eq!(csrf.name(), "CSRF");
11196        // Should behave identically to new()
11197        let ctx = test_context();
11198        let mut req = Request::new(crate::request::Method::Get, "/");
11199        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
11200        assert!(result.is_continue());
11201    }
11202
11203    #[test]
11204    fn csrf_mode_default_is_double_submit() {
11205        assert_eq!(CsrfMode::default(), CsrfMode::DoubleSubmit);
11206    }
11207
11208    #[test]
11209    fn csrf_double_submit_both_present_same_non_empty_passes() {
11210        // Explicit test of the core double-submit pattern
11211        let csrf = CsrfMiddleware::new();
11212        let ctx = test_context();
11213
11214        let token = "a1b2c3d4e5f6";
11215        let mut req = Request::new(crate::request::Method::Delete, "/resource/1");
11216        req.headers_mut()
11217            .insert("cookie", format!("csrf_token={}", token).into_bytes());
11218        req.headers_mut()
11219            .insert("x-csrf-token", token.as_bytes().to_vec());
11220
11221        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
11222        assert!(result.is_continue());
11223    }
11224
11225    #[test]
11226    fn csrf_double_submit_case_sensitive() {
11227        // Token comparison should be case-sensitive
11228        let csrf = CsrfMiddleware::new();
11229        let ctx = test_context();
11230        let mut req = Request::new(crate::request::Method::Post, "/");
11231
11232        req.headers_mut()
11233            .insert("cookie", b"csrf_token=AbCdEf".to_vec());
11234        req.headers_mut().insert("x-csrf-token", b"abcdef".to_vec());
11235
11236        let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
11237        assert!(
11238            result.is_break(),
11239            "Token comparison should be case-sensitive"
11240        );
11241    }
11242
11243    #[test]
11244    fn csrf_token_cookie_extractor_reads_csrf_cookie() {
11245        // Test that CsrfTokenCookie works as a cookie name marker
11246        use crate::extract::{CookieName, CsrfTokenCookie};
11247        assert_eq!(CsrfTokenCookie::NAME, "csrf_token");
11248    }
11249
11250    #[test]
11251    fn csrf_make_set_cookie_header_value_production() {
11252        let value = CsrfMiddleware::make_set_cookie_header_value("csrf_token", "tok123", true);
11253        let s = std::str::from_utf8(&value).unwrap();
11254        assert!(s.contains("csrf_token=tok123"));
11255        assert!(s.contains("Path=/"));
11256        assert!(s.contains("SameSite=Strict"));
11257        assert!(s.contains("Secure"));
11258        assert!(!s.to_lowercase().contains("httponly"));
11259    }
11260
11261    #[test]
11262    fn csrf_make_set_cookie_header_value_development() {
11263        let value = CsrfMiddleware::make_set_cookie_header_value("csrf_token", "tok123", false);
11264        let s = std::str::from_utf8(&value).unwrap();
11265        assert!(s.contains("csrf_token=tok123"));
11266        assert!(s.contains("Path=/"));
11267        assert!(s.contains("SameSite=Strict"));
11268        assert!(!s.contains("Secure"));
11269    }
11270
11271    #[test]
11272    fn csrf_before_after_full_cycle_get_then_post() {
11273        // Simulate a full CSRF flow: GET sets cookie, POST uses it
11274        let csrf = CsrfMiddleware::new();
11275        let ctx = test_context();
11276
11277        // Step 1: GET request - generates token and sets cookie
11278        let mut get_req = Request::new(crate::request::Method::Get, "/form");
11279        let _ = futures_executor::block_on(csrf.before(&ctx, &mut get_req));
11280        let get_response = Response::ok();
11281        let get_result = futures_executor::block_on(csrf.after(&ctx, &get_req, get_response));
11282
11283        let set_cookie = header_value(&get_result, "set-cookie").expect("GET should set cookie");
11284        // Extract token value from "csrf_token=<value>; Path=/; ..."
11285        let token_value = set_cookie
11286            .strip_prefix("csrf_token=")
11287            .unwrap()
11288            .split(';')
11289            .next()
11290            .unwrap();
11291        assert!(!token_value.is_empty());
11292
11293        // Step 2: POST request - uses the token from cookie + header
11294        let mut post_req = Request::new(crate::request::Method::Post, "/form");
11295        post_req
11296            .headers_mut()
11297            .insert("cookie", format!("csrf_token={}", token_value).into_bytes());
11298        post_req
11299            .headers_mut()
11300            .insert("x-csrf-token", token_value.as_bytes().to_vec());
11301
11302        let result = futures_executor::block_on(csrf.before(&ctx, &mut post_req));
11303        assert!(result.is_continue(), "POST with valid token should pass");
11304    }
11305
11306    #[test]
11307    fn csrf_all_state_changing_methods_require_token() {
11308        let csrf = CsrfMiddleware::new();
11309        let ctx = test_context();
11310
11311        for method in [
11312            crate::request::Method::Post,
11313            crate::request::Method::Put,
11314            crate::request::Method::Delete,
11315            crate::request::Method::Patch,
11316        ] {
11317            let mut req = Request::new(method, "/resource");
11318            let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
11319            assert!(
11320                result.is_break(),
11321                "{:?} without token should be rejected",
11322                method
11323            );
11324        }
11325    }
11326
11327    #[test]
11328    fn csrf_all_safe_methods_pass_without_token() {
11329        let csrf = CsrfMiddleware::new();
11330        let ctx = test_context();
11331
11332        for method in [
11333            crate::request::Method::Get,
11334            crate::request::Method::Head,
11335            crate::request::Method::Options,
11336            crate::request::Method::Trace,
11337        ] {
11338            let mut req = Request::new(method, "/resource");
11339            let result = futures_executor::block_on(csrf.before(&ctx, &mut req));
11340            assert!(
11341                result.is_continue(),
11342                "{:?} should be allowed without token",
11343                method
11344            );
11345        }
11346    }
11347
11348    // =========================================================================
11349    // Middleware Stack Ordering Tests (Onion Model)
11350    // =========================================================================
11351
11352    /// Middleware that records execution order to a shared Vec.
11353    /// Used to verify the onion model (before in order, after in reverse).
11354    #[derive(Clone)]
11355    struct OrderRecordingMiddleware {
11356        id: &'static str,
11357        log: std::sync::Arc<std::sync::Mutex<Vec<String>>>,
11358    }
11359
11360    impl OrderRecordingMiddleware {
11361        fn new(id: &'static str, log: std::sync::Arc<std::sync::Mutex<Vec<String>>>) -> Self {
11362            Self { id, log }
11363        }
11364    }
11365
11366    impl Middleware for OrderRecordingMiddleware {
11367        fn before<'a>(
11368            &'a self,
11369            _ctx: &'a RequestContext,
11370            _req: &'a mut Request,
11371        ) -> BoxFuture<'a, ControlFlow> {
11372            let id = self.id;
11373            let log = self.log.clone();
11374            Box::pin(async move {
11375                log.lock().unwrap().push(format!("{id}:before"));
11376                ControlFlow::Continue
11377            })
11378        }
11379
11380        fn after<'a>(
11381            &'a self,
11382            _ctx: &'a RequestContext,
11383            _req: &'a Request,
11384            response: Response,
11385        ) -> BoxFuture<'a, Response> {
11386            let id = self.id;
11387            let log = self.log.clone();
11388            Box::pin(async move {
11389                log.lock().unwrap().push(format!("{id}:after"));
11390                response
11391            })
11392        }
11393
11394        fn name(&self) -> &'static str {
11395            "OrderRecording"
11396        }
11397    }
11398
11399    /// Middleware that short-circuits in its before hook.
11400    struct ShortCircuitMiddleware {
11401        id: &'static str,
11402        log: std::sync::Arc<std::sync::Mutex<Vec<String>>>,
11403    }
11404
11405    impl ShortCircuitMiddleware {
11406        fn new(id: &'static str, log: std::sync::Arc<std::sync::Mutex<Vec<String>>>) -> Self {
11407            Self { id, log }
11408        }
11409    }
11410
11411    impl Middleware for ShortCircuitMiddleware {
11412        fn before<'a>(
11413            &'a self,
11414            _ctx: &'a RequestContext,
11415            _req: &'a mut Request,
11416        ) -> BoxFuture<'a, ControlFlow> {
11417            let id = self.id;
11418            let log = self.log.clone();
11419            Box::pin(async move {
11420                log.lock().unwrap().push(format!("{id}:before:break"));
11421                ControlFlow::Break(
11422                    Response::with_status(StatusCode::FORBIDDEN)
11423                        .body(ResponseBody::Bytes(b"short-circuited".to_vec())),
11424                )
11425            })
11426        }
11427
11428        fn after<'a>(
11429            &'a self,
11430            _ctx: &'a RequestContext,
11431            _req: &'a Request,
11432            response: Response,
11433        ) -> BoxFuture<'a, Response> {
11434            let id = self.id;
11435            let log = self.log.clone();
11436            Box::pin(async move {
11437                log.lock().unwrap().push(format!("{id}:after"));
11438                response
11439            })
11440        }
11441
11442        fn name(&self) -> &'static str {
11443            "ShortCircuit"
11444        }
11445    }
11446
11447    /// Simple handler that records when it runs.
11448    struct RecordingHandler {
11449        log: std::sync::Arc<std::sync::Mutex<Vec<String>>>,
11450    }
11451
11452    impl RecordingHandler {
11453        fn new(log: std::sync::Arc<std::sync::Mutex<Vec<String>>>) -> Self {
11454            Self { log }
11455        }
11456    }
11457
11458    impl Handler for RecordingHandler {
11459        fn call<'a>(
11460            &'a self,
11461            _ctx: &'a RequestContext,
11462            _req: &'a mut Request,
11463        ) -> BoxFuture<'a, Response> {
11464            let log = self.log.clone();
11465            Box::pin(async move {
11466                log.lock().unwrap().push("handler".to_string());
11467                Response::ok().body(ResponseBody::Bytes(b"ok".to_vec()))
11468            })
11469        }
11470    }
11471
11472    #[test]
11473    fn middleware_stack_three_middleware_onion_order() {
11474        // Test that three middleware follow the onion model:
11475        // Before hooks run in order: 1 -> 2 -> 3
11476        // After hooks run in reverse: 3 -> 2 -> 1
11477        let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
11478
11479        let mut stack = MiddlewareStack::new();
11480        stack.push(OrderRecordingMiddleware::new("mw1", log.clone()));
11481        stack.push(OrderRecordingMiddleware::new("mw2", log.clone()));
11482        stack.push(OrderRecordingMiddleware::new("mw3", log.clone()));
11483
11484        let handler = RecordingHandler::new(log.clone());
11485        let ctx = test_context();
11486        let mut req = Request::new(crate::request::Method::Get, "/");
11487
11488        let _response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
11489
11490        let execution_log = log.lock().unwrap().clone();
11491        assert_eq!(
11492            execution_log,
11493            vec![
11494                "mw1:before",
11495                "mw2:before",
11496                "mw3:before",
11497                "handler",
11498                "mw3:after",
11499                "mw2:after",
11500                "mw1:after",
11501            ]
11502        );
11503    }
11504
11505    #[test]
11506    fn middleware_stack_short_circuit_runs_prior_after_hooks() {
11507        // When middleware 2 short-circuits:
11508        // - mw1:before runs (returns Continue, count=1)
11509        // - mw2:before short-circuits (returns Break, count stays at 1)
11510        // - mw3:before does NOT run
11511        // - handler does NOT run
11512        // - Only middleware that successfully completed before (mw1) have after run
11513        // - mw1:after runs
11514        let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
11515
11516        let mut stack = MiddlewareStack::new();
11517        stack.push(OrderRecordingMiddleware::new("mw1", log.clone()));
11518        stack.push(ShortCircuitMiddleware::new("mw2", log.clone()));
11519        stack.push(OrderRecordingMiddleware::new("mw3", log.clone()));
11520
11521        let handler = RecordingHandler::new(log.clone());
11522        let ctx = test_context();
11523        let mut req = Request::new(crate::request::Method::Get, "/");
11524
11525        let response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
11526
11527        // Should return the short-circuit response
11528        assert_eq!(response.status().as_u16(), 403);
11529
11530        let execution_log = log.lock().unwrap().clone();
11531        // Note: mw2's after hook does NOT run because it didn't return Continue
11532        // Only middleware that successfully completed before (returned Continue) have after run
11533        assert_eq!(
11534            execution_log,
11535            vec!["mw1:before", "mw2:before:break", "mw1:after",]
11536        );
11537    }
11538
11539    #[test]
11540    fn middleware_stack_first_middleware_short_circuits() {
11541        // When the first middleware short-circuits:
11542        // - mw1:before short-circuits (returns Break, count=0)
11543        // - No after hooks run (count=0)
11544        let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
11545
11546        let mut stack = MiddlewareStack::new();
11547        stack.push(ShortCircuitMiddleware::new("mw1", log.clone()));
11548        stack.push(OrderRecordingMiddleware::new("mw2", log.clone()));
11549
11550        let handler = RecordingHandler::new(log.clone());
11551        let ctx = test_context();
11552        let mut req = Request::new(crate::request::Method::Get, "/");
11553
11554        let response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
11555        assert_eq!(response.status().as_u16(), 403);
11556
11557        let execution_log = log.lock().unwrap().clone();
11558        // No after hooks run because no middleware returned Continue
11559        assert_eq!(execution_log, vec!["mw1:before:break",]);
11560    }
11561
11562    #[test]
11563    fn middleware_stack_empty_runs_handler_only() {
11564        // Empty stack should just run the handler (onion ordering variant)
11565        let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
11566
11567        let stack = MiddlewareStack::new();
11568        let handler = RecordingHandler::new(log.clone());
11569        let ctx = test_context();
11570        let mut req = Request::new(crate::request::Method::Get, "/");
11571
11572        let response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
11573        assert_eq!(response.status().as_u16(), 200);
11574
11575        let execution_log = log.lock().unwrap().clone();
11576        assert_eq!(execution_log, vec!["handler"]);
11577    }
11578
11579    #[test]
11580    fn middleware_stack_single_middleware_ordering() {
11581        // Single middleware should have before -> handler -> after
11582        let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
11583
11584        let mut stack = MiddlewareStack::new();
11585        stack.push(OrderRecordingMiddleware::new("mw1", log.clone()));
11586
11587        let handler = RecordingHandler::new(log.clone());
11588        let ctx = test_context();
11589        let mut req = Request::new(crate::request::Method::Get, "/");
11590
11591        let _response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
11592
11593        let execution_log = log.lock().unwrap().clone();
11594        assert_eq!(execution_log, vec!["mw1:before", "handler", "mw1:after",]);
11595    }
11596
11597    #[test]
11598    fn middleware_stack_five_middleware_onion_order() {
11599        // Test with five middleware for a longer chain
11600        let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
11601
11602        let mut stack = MiddlewareStack::new();
11603        stack.push(OrderRecordingMiddleware::new("a", log.clone()));
11604        stack.push(OrderRecordingMiddleware::new("b", log.clone()));
11605        stack.push(OrderRecordingMiddleware::new("c", log.clone()));
11606        stack.push(OrderRecordingMiddleware::new("d", log.clone()));
11607        stack.push(OrderRecordingMiddleware::new("e", log.clone()));
11608
11609        let handler = RecordingHandler::new(log.clone());
11610        let ctx = test_context();
11611        let mut req = Request::new(crate::request::Method::Get, "/");
11612
11613        let _response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
11614
11615        let execution_log = log.lock().unwrap().clone();
11616        assert_eq!(
11617            execution_log,
11618            vec![
11619                "a:before", "b:before", "c:before", "d:before", "e:before", "handler", "e:after",
11620                "d:after", "c:after", "b:after", "a:after",
11621            ]
11622        );
11623    }
11624
11625    #[test]
11626    fn middleware_stack_short_circuit_at_end_runs_prior_afters() {
11627        // When the last middleware short-circuits:
11628        // - mw1:before runs (Continue, count=1)
11629        // - mw2:before runs (Continue, count=2)
11630        // - mw3:before short-circuits (Break, count stays at 2)
11631        // - handler does NOT run
11632        // - After hooks run for mw1 and mw2 only (they returned Continue)
11633        let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
11634
11635        let mut stack = MiddlewareStack::new();
11636        stack.push(OrderRecordingMiddleware::new("mw1", log.clone()));
11637        stack.push(OrderRecordingMiddleware::new("mw2", log.clone()));
11638        stack.push(ShortCircuitMiddleware::new("mw3", log.clone()));
11639
11640        let handler = RecordingHandler::new(log.clone());
11641        let ctx = test_context();
11642        let mut req = Request::new(crate::request::Method::Get, "/");
11643
11644        let response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
11645        assert_eq!(response.status().as_u16(), 403);
11646
11647        let execution_log = log.lock().unwrap().clone();
11648        // mw3's after hook does NOT run because it didn't return Continue
11649        assert_eq!(
11650            execution_log,
11651            vec![
11652                "mw1:before",
11653                "mw2:before",
11654                "mw3:before:break",
11655                "mw2:after",
11656                "mw1:after",
11657            ]
11658        );
11659    }
11660
11661    /// Middleware that modifies the request in before and response in after.
11662    struct ModifyingMiddleware {
11663        id: &'static str,
11664        log: std::sync::Arc<std::sync::Mutex<Vec<String>>>,
11665    }
11666
11667    impl ModifyingMiddleware {
11668        fn new(id: &'static str, log: std::sync::Arc<std::sync::Mutex<Vec<String>>>) -> Self {
11669            Self { id, log }
11670        }
11671    }
11672
11673    impl Middleware for ModifyingMiddleware {
11674        fn before<'a>(
11675            &'a self,
11676            _ctx: &'a RequestContext,
11677            req: &'a mut Request,
11678        ) -> BoxFuture<'a, ControlFlow> {
11679            let id = self.id;
11680            let log = self.log.clone();
11681            Box::pin(async move {
11682                // Add a header to track middleware order
11683                req.headers_mut()
11684                    .insert(format!("x-{id}-before"), b"true".to_vec());
11685                log.lock().unwrap().push(format!("{id}:before"));
11686                ControlFlow::Continue
11687            })
11688        }
11689
11690        fn after<'a>(
11691            &'a self,
11692            _ctx: &'a RequestContext,
11693            _req: &'a Request,
11694            response: Response,
11695        ) -> BoxFuture<'a, Response> {
11696            let id = self.id;
11697            let log = self.log.clone();
11698            Box::pin(async move {
11699                log.lock().unwrap().push(format!("{id}:after"));
11700                // Add a header to the response
11701                response.header(format!("x-{id}-after"), b"true".to_vec())
11702            })
11703        }
11704
11705        fn name(&self) -> &'static str {
11706            "Modifying"
11707        }
11708    }
11709
11710    #[test]
11711    fn middleware_stack_modifications_accumulate_correctly() {
11712        // Test that request modifications in before hooks accumulate,
11713        // and response modifications in after hooks accumulate
11714        let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
11715
11716        let mut stack = MiddlewareStack::new();
11717        stack.push(ModifyingMiddleware::new("mw1", log.clone()));
11718        stack.push(ModifyingMiddleware::new("mw2", log.clone()));
11719        stack.push(ModifyingMiddleware::new("mw3", log.clone()));
11720
11721        let handler = RecordingHandler::new(log.clone());
11722        let ctx = test_context();
11723        let mut req = Request::new(crate::request::Method::Get, "/");
11724
11725        let response = futures_executor::block_on(stack.execute(&handler, &ctx, &mut req));
11726
11727        // Check that all after hooks added their headers
11728        assert!(header_value(&response, "x-mw1-after").is_some());
11729        assert!(header_value(&response, "x-mw2-after").is_some());
11730        assert!(header_value(&response, "x-mw3-after").is_some());
11731
11732        // Check that the request was modified by all before hooks
11733        assert!(req.headers().contains("x-mw1-before"));
11734        assert!(req.headers().contains("x-mw2-before"));
11735        assert!(req.headers().contains("x-mw3-before"));
11736    }
11737
11738    #[test]
11739    fn layer_wrap_maintains_middleware_order() {
11740        // Test that Layer::wrap creates a Layered handler that maintains before->after ordering
11741        let log = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
11742
11743        // Create a layer with our recording middleware
11744        let layer = Layer::new(OrderRecordingMiddleware::new("layer", log.clone()));
11745
11746        // Wrap the recording handler
11747        let handler = RecordingHandler::new(log.clone());
11748        let layered_handler = layer.wrap(handler);
11749
11750        let ctx = test_context();
11751        let mut req = Request::new(crate::request::Method::Get, "/");
11752
11753        // Execute the layered handler directly (not via middleware stack)
11754        let _response = futures_executor::block_on(layered_handler.call(&ctx, &mut req));
11755
11756        let execution_log = log.lock().unwrap().clone();
11757        assert_eq!(
11758            execution_log,
11759            vec!["layer:before", "handler", "layer:after",]
11760        );
11761    }
11762}
11763
11764// ============================================================================
11765// Compression Middleware Tests (requires "compression" feature)
11766// ============================================================================
11767
#[cfg(all(test, feature = "compression"))]
mod compression_tests {
    use super::*;
    use crate::request::Method;
    use crate::response::ResponseBody;

    /// Builds a request context for driving middleware hooks in tests.
    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 1)
    }

    /// Builds a `GET /` request whose `Accept-Encoding` header is `encoding`.
    fn request_accepting(encoding: &[u8]) -> Request {
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut()
            .insert("accept-encoding", encoding.to_vec());
        req
    }

    /// Returns `true` if `response` has a header called `name`, compared
    /// ASCII case-insensitively.
    fn has_header(response: &Response, name: &str) -> bool {
        response
            .headers()
            .iter()
            .any(|(header_name, _)| header_name.eq_ignore_ascii_case(name))
    }

    #[test]
    fn compression_config_defaults() {
        let config = CompressionConfig::default();
        assert_eq!(config.min_size, 1024);
        assert_eq!(config.level, 6);
        assert!(!config.skip_content_types.is_empty());
    }

    #[test]
    fn compression_config_builder() {
        let config = CompressionConfig::new().min_size(512).level(9);
        assert_eq!(config.min_size, 512);
        assert_eq!(config.level, 9);
    }

    #[test]
    fn compression_level_clamped() {
        // Out-of-range levels are clamped into 1..=9
        let config = CompressionConfig::new().level(100);
        assert_eq!(config.level, 9);

        let config = CompressionConfig::new().level(0);
        assert_eq!(config.level, 1);
    }

    #[test]
    fn skip_content_type_exact_match() {
        let config = CompressionConfig::default();
        assert!(config.should_skip_content_type("image/jpeg"));
        // Media-type parameters must not defeat the match
        assert!(config.should_skip_content_type("image/jpeg; charset=utf-8"));
        assert!(!config.should_skip_content_type("text/html"));
    }

    #[test]
    fn skip_content_type_prefix_match() {
        let config = CompressionConfig::default();
        // Any "video/" subtype should be skipped by prefix...
        assert!(config.should_skip_content_type("video/mp4"));
        assert!(config.should_skip_content_type("video/webm"));
        // ...and audio such as "audio/mpeg" is skipped as well
        assert!(config.should_skip_content_type("audio/mpeg"));
    }

    #[test]
    fn compression_skips_small_responses() {
        let middleware = CompressionMiddleware::new();
        let ctx = test_context();
        let req = request_accepting(b"gzip");

        // Body is smaller than the default 1024-byte minimum
        let response = Response::ok()
            .header("content-type", b"text/plain".to_vec())
            .body(ResponseBody::Bytes(b"Hello, World!".to_vec()));

        let result = futures_executor::block_on(middleware.after(&ctx, &req, response));

        assert!(
            !has_header(&result, "content-encoding"),
            "Small response should not be compressed"
        );
    }

    #[test]
    fn compression_works_for_large_responses() {
        let config = CompressionConfig::new().min_size(10); // Lower threshold
        let middleware = CompressionMiddleware::with_config(config);
        let ctx = test_context();
        let req = request_accepting(b"gzip");

        // Repetitive content compresses well
        let body = "Hello, World! ".repeat(100);
        let original_size = body.len();

        let response = Response::ok()
            .header("content-type", b"text/plain".to_vec())
            .body(ResponseBody::Bytes(body.into_bytes()));

        let result = futures_executor::block_on(middleware.after(&ctx, &req, response));

        // Should be gzip-compressed
        let (_, value) = result
            .headers()
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case("content-encoding"))
            .expect("Large response should be compressed");
        assert_eq!(value, b"gzip");

        assert!(has_header(&result, "vary"), "Should have Vary header");

        // Verify compressed size is smaller
        if let ResponseBody::Bytes(compressed) = result.body_ref() {
            assert!(
                compressed.len() < original_size,
                "Compressed size should be smaller"
            );
        } else {
            panic!("Expected Bytes body");
        }
    }

    #[test]
    fn compression_skips_without_accept_encoding() {
        let config = CompressionConfig::new().min_size(10);
        let middleware = CompressionMiddleware::with_config(config);
        let ctx = test_context();

        // Request WITHOUT Accept-Encoding
        let req = Request::new(Method::Get, "/");

        let body = "Hello, World! ".repeat(100);
        let response = Response::ok()
            .header("content-type", b"text/plain".to_vec())
            .body(ResponseBody::Bytes(body.into_bytes()));

        let result = futures_executor::block_on(middleware.after(&ctx, &req, response));

        assert!(
            !has_header(&result, "content-encoding"),
            "Should not compress without Accept-Encoding"
        );
    }

    #[test]
    fn compression_skips_already_compressed_content() {
        let config = CompressionConfig::new().min_size(10);
        let middleware = CompressionMiddleware::with_config(config);
        let ctx = test_context();
        let req = request_accepting(b"gzip");

        // image/jpeg is on the default skip list
        let body = "Some image data".repeat(100);
        let response = Response::ok()
            .header("content-type", b"image/jpeg".to_vec())
            .body(ResponseBody::Bytes(body.into_bytes()));

        let result = futures_executor::block_on(middleware.after(&ctx, &req, response));

        assert!(
            !has_header(&result, "content-encoding"),
            "Should not compress already-compressed content types"
        );
    }

    #[test]
    fn compression_skips_if_already_has_content_encoding() {
        let config = CompressionConfig::new().min_size(10);
        let middleware = CompressionMiddleware::with_config(config);
        let ctx = test_context();
        let req = request_accepting(b"gzip");

        // Response that already declares its own Content-Encoding
        let body = "Hello, World! ".repeat(100);
        let response = Response::ok()
            .header("content-type", b"text/plain".to_vec())
            .header("content-encoding", b"br".to_vec())
            .body(ResponseBody::Bytes(body.into_bytes()));

        let result = futures_executor::block_on(middleware.after(&ctx, &req, response));

        // Should NOT double-compress
        let encodings: Vec<_> = result
            .headers()
            .iter()
            .filter(|(name, _)| name.eq_ignore_ascii_case("content-encoding"))
            .collect();

        // Should still have exactly one Content-Encoding header (the original br)
        assert_eq!(encodings.len(), 1);
        assert_eq!(encodings[0].1, b"br");
    }

    #[test]
    fn accepts_gzip_parses_header_correctly() {
        /// Runs `accepts_gzip` on a request carrying `Accept-Encoding: <value>`.
        fn accepts(value: &[u8]) -> bool {
            CompressionMiddleware::accepts_gzip(&request_accepting(value))
        }

        // Simple gzip
        assert!(accepts(b"gzip"));

        // Multiple encodings
        assert!(accepts(b"deflate, gzip, br"));

        // With quality values
        assert!(accepts(b"gzip;q=1.0, identity;q=0.5"));

        // Wildcard
        assert!(accepts(b"*"));

        // No gzip
        assert!(!accepts(b"deflate, br"));

        // No header
        let req_no_header = Request::new(Method::Get, "/");
        assert!(!CompressionMiddleware::accepts_gzip(&req_no_header));
    }

    #[test]
    fn compression_middleware_name() {
        let middleware = CompressionMiddleware::new();
        assert_eq!(middleware.name(), "Compression");
    }
}
12025
12026// ============================================================================
12027// Request Inspection Middleware Tests
12028// ============================================================================
12029
12030#[cfg(test)]
12031mod request_inspection_tests {
12032    use super::*;
12033    use crate::request::Method;
12034    use crate::response::ResponseBody;
12035
12036    fn test_context() -> RequestContext {
12037        RequestContext::new(asupersync::Cx::for_testing(), 1)
12038    }
12039
12040    #[test]
12041    fn inspection_middleware_default_creates_normal_verbosity() {
12042        let mw = RequestInspectionMiddleware::new();
12043        assert_eq!(mw.verbosity, InspectionVerbosity::Normal);
12044        assert_eq!(mw.slow_threshold_ms, 1000);
12045        assert_eq!(mw.max_body_preview, 2048);
12046        assert_eq!(mw.name(), "RequestInspection");
12047    }
12048
12049    #[test]
12050    fn inspection_middleware_builder_methods() {
12051        let mw = RequestInspectionMiddleware::new()
12052            .verbosity(InspectionVerbosity::Verbose)
12053            .slow_threshold_ms(500)
12054            .max_body_preview(4096)
12055            .log_config(LogConfig::development())
12056            .redact_header("x-api-key");
12057
12058        assert_eq!(mw.verbosity, InspectionVerbosity::Verbose);
12059        assert_eq!(mw.slow_threshold_ms, 500);
12060        assert_eq!(mw.max_body_preview, 4096);
12061        assert!(mw.redact_headers.contains("x-api-key"));
12062        // Default redacted headers should still be present
12063        assert!(mw.redact_headers.contains("authorization"));
12064        assert!(mw.redact_headers.contains("cookie"));
12065    }
12066
12067    #[test]
12068    fn inspection_before_continues_processing() {
12069        let mw = RequestInspectionMiddleware::new().verbosity(InspectionVerbosity::Minimal);
12070        let ctx = test_context();
12071        let mut req = Request::new(Method::Post, "/api/users");
12072
12073        let result = futures_executor::block_on(mw.before(&ctx, &mut req));
12074        assert!(result.is_continue());
12075    }
12076
12077    #[test]
12078    fn inspection_after_returns_response_unchanged() {
12079        let mw = RequestInspectionMiddleware::new().verbosity(InspectionVerbosity::Minimal);
12080        let ctx = test_context();
12081        let mut req = Request::new(Method::Get, "/health");
12082
12083        // Run before to set the InspectionStart extension
12084        let _ = futures_executor::block_on(mw.before(&ctx, &mut req));
12085
12086        let response = Response::ok().body(ResponseBody::Bytes(b"OK".to_vec()));
12087
12088        let result = futures_executor::block_on(mw.after(&ctx, &req, response));
12089        assert_eq!(result.status().as_u16(), 200);
12090        assert_eq!(result.body_ref().len(), 2);
12091    }
12092
12093    #[test]
12094    fn inspection_stores_start_extension() {
12095        let mw = RequestInspectionMiddleware::new();
12096        let ctx = test_context();
12097        let mut req = Request::new(Method::Get, "/");
12098
12099        let _ = futures_executor::block_on(mw.before(&ctx, &mut req));
12100
12101        // Verify the InspectionStart extension was set
12102        assert!(req.get_extension::<InspectionStart>().is_some());
12103    }
12104
12105    #[test]
12106    fn inspection_all_verbosity_levels_continue() {
12107        for verbosity in [
12108            InspectionVerbosity::Minimal,
12109            InspectionVerbosity::Normal,
12110            InspectionVerbosity::Verbose,
12111        ] {
12112            let mw = RequestInspectionMiddleware::new().verbosity(verbosity);
12113            let ctx = test_context();
12114            let mut req = Request::new(Method::Get, "/test");
12115            req.headers_mut()
12116                .insert("content-type", b"text/plain".to_vec());
12117
12118            let result = futures_executor::block_on(mw.before(&ctx, &mut req));
12119            assert!(
12120                result.is_continue(),
12121                "Verbosity {verbosity:?} should continue"
12122            );
12123        }
12124    }
12125
12126    #[test]
12127    fn inspection_verbose_with_json_body() {
12128        let mw = RequestInspectionMiddleware::new().verbosity(InspectionVerbosity::Verbose);
12129        let ctx = test_context();
12130        let body = br#"{"name":"Alice","age":30}"#;
12131        let mut req = Request::new(Method::Post, "/api/users");
12132        req.headers_mut()
12133            .insert("content-type", b"application/json".to_vec());
12134        req.set_body(Body::Bytes(body.to_vec()));
12135
12136        let result = futures_executor::block_on(mw.before(&ctx, &mut req));
12137        assert!(result.is_continue());
12138    }
12139
12140    #[test]
12141    fn inspection_verbose_after_with_json_response() {
12142        let mw = RequestInspectionMiddleware::new().verbosity(InspectionVerbosity::Verbose);
12143        let ctx = test_context();
12144        let mut req = Request::new(Method::Get, "/api/users/1");
12145
12146        let _ = futures_executor::block_on(mw.before(&ctx, &mut req));
12147
12148        let response = Response::ok()
12149            .header("content-type", b"application/json".to_vec())
12150            .body(ResponseBody::Bytes(br#"{"id":1,"name":"Alice"}"#.to_vec()));
12151
12152        let result = futures_executor::block_on(mw.after(&ctx, &req, response));
12153        assert_eq!(result.status().as_u16(), 200);
12154    }
12155
12156    #[test]
12157    fn inspection_redacts_sensitive_headers() {
12158        let mw = RequestInspectionMiddleware::new();
12159
12160        // Verify default redacted headers are present
12161        assert!(mw.redact_headers.contains("authorization"));
12162        assert!(mw.redact_headers.contains("proxy-authorization"));
12163        assert!(mw.redact_headers.contains("cookie"));
12164        assert!(mw.redact_headers.contains("set-cookie"));
12165    }
12166
12167    #[test]
12168    fn inspection_format_headers_redacts() {
12169        let mw = RequestInspectionMiddleware::new().redact_header("x-secret");
12170
12171        let headers = vec![
12172            ("content-type", b"text/plain".as_slice()),
12173            ("x-secret", b"my-secret-value".as_slice()),
12174            ("x-normal", b"visible".as_slice()),
12175        ];
12176
12177        let output = mw.format_inspection_headers(headers.into_iter());
12178        assert!(output.contains("content-type: text/plain"));
12179        assert!(output.contains("x-secret: [REDACTED]"));
12180        assert!(output.contains("x-normal: visible"));
12181        assert!(!output.contains("my-secret-value"));
12182    }
12183
12184    #[test]
12185    fn inspection_format_body_preview_truncates() {
12186        let mw = RequestInspectionMiddleware::new().max_body_preview(10);
12187
12188        let body = b"Hello, World! This is a long body.";
12189        let result = mw.format_body_preview(body, None);
12190        assert!(result.is_some());
12191        let text = result.unwrap();
12192        assert!(text.ends_with("..."));
12193        assert!(text.len() <= 15); // 10 chars + "..."
12194    }
12195
12196    #[test]
12197    fn inspection_format_body_preview_empty() {
12198        let mw = RequestInspectionMiddleware::new();
12199        assert!(mw.format_body_preview(b"", None).is_none());
12200    }
12201
12202    #[test]
12203    fn inspection_format_body_preview_zero_max() {
12204        let mw = RequestInspectionMiddleware::new().max_body_preview(0);
12205        assert!(mw.format_body_preview(b"hello", None).is_none());
12206    }
12207
12208    #[test]
12209    fn inspection_format_body_preview_json_pretty() {
12210        let mw = RequestInspectionMiddleware::new();
12211        let body = br#"{"key":"value","num":42}"#;
12212        let ct = b"application/json".as_slice();
12213        let result = mw.format_body_preview(body, Some(ct));
12214        assert!(result.is_some());
12215        let text = result.unwrap();
12216        // Pretty-printed JSON should contain newlines
12217        assert!(text.contains('\n'));
12218        assert!(text.contains("\"key\": \"value\""));
12219    }
12220
12221    #[test]
12222    fn inspection_format_body_preview_non_json() {
12223        let mw = RequestInspectionMiddleware::new();
12224        let body = b"Hello, World!";
12225        let ct = b"text/plain".as_slice();
12226        let result = mw.format_body_preview(body, Some(ct));
12227        assert_eq!(result.unwrap(), "Hello, World!");
12228    }
12229
12230    #[test]
12231    fn inspection_format_body_preview_binary() {
12232        let mw = RequestInspectionMiddleware::new();
12233        let body: &[u8] = &[0xFF, 0xFE, 0xFD, 0x00];
12234        let result = mw.format_body_preview(body, None);
12235        assert!(result.is_some());
12236        assert!(result.unwrap().contains("binary"));
12237    }
12238
12239    #[test]
12240    fn try_pretty_json_valid_object() {
12241        let result = try_pretty_json(r#"{"a":"b","c":1}"#);
12242        assert!(result.is_some());
12243        let pretty = result.unwrap();
12244        assert!(pretty.contains('\n'));
12245        assert!(pretty.contains("  \"a\": \"b\""));
12246    }
12247
12248    #[test]
12249    fn try_pretty_json_valid_array() {
12250        let result = try_pretty_json(r"[1,2,3]");
12251        assert!(result.is_some());
12252        let pretty = result.unwrap();
12253        assert!(pretty.contains('\n'));
12254    }
12255
12256    #[test]
12257    fn try_pretty_json_empty_object() {
12258        let result = try_pretty_json("{}");
12259        assert!(result.is_some());
12260        assert_eq!(result.unwrap(), "{}");
12261    }
12262
12263    #[test]
12264    fn try_pretty_json_empty_array() {
12265        let result = try_pretty_json("[]");
12266        assert!(result.is_some());
12267        assert_eq!(result.unwrap(), "[]");
12268    }
12269
12270    #[test]
12271    fn try_pretty_json_not_json() {
12272        assert!(try_pretty_json("hello world").is_none());
12273        assert!(try_pretty_json("12345").is_none());
12274    }
12275
12276    #[test]
12277    fn try_pretty_json_nested() {
12278        let input = r#"{"user":{"name":"Alice","roles":["admin","user"]}}"#;
12279        let result = try_pretty_json(input);
12280        assert!(result.is_some());
12281        let pretty = result.unwrap();
12282        assert!(pretty.contains("\"user\":"));
12283        assert!(pretty.contains("\"name\": \"Alice\""));
12284        assert!(pretty.contains("\"roles\":"));
12285    }
12286
12287    #[test]
12288    fn try_pretty_json_with_escapes() {
12289        let input = r#"{"msg":"hello \"world\""}"#;
12290        let result = try_pretty_json(input);
12291        assert!(result.is_some());
12292        let pretty = result.unwrap();
12293        assert!(pretty.contains(r#"\"world\""#));
12294    }
12295
12296    #[test]
12297    fn inspection_name() {
12298        let mw = RequestInspectionMiddleware::new();
12299        assert_eq!(mw.name(), "RequestInspection");
12300    }
12301
12302    #[test]
12303    fn inspection_default_via_default_trait() {
12304        let mw = RequestInspectionMiddleware::default();
12305        assert_eq!(mw.verbosity, InspectionVerbosity::Normal);
12306        assert_eq!(mw.slow_threshold_ms, 1000);
12307    }
12308
12309    #[test]
12310    fn inspection_with_query_string() {
12311        let mw = RequestInspectionMiddleware::new().verbosity(InspectionVerbosity::Minimal);
12312        let ctx = test_context();
12313        let mut req = Request::new(Method::Get, "/search");
12314        req.set_query(Some("q=rust&page=1".to_string()));
12315
12316        let result = futures_executor::block_on(mw.before(&ctx, &mut req));
12317        assert!(result.is_continue());
12318    }
12319
12320    #[test]
12321    fn inspection_response_body_stream() {
12322        let mw = RequestInspectionMiddleware::new();
12323        let result = mw.format_response_preview(&ResponseBody::Empty, None);
12324        assert!(result.is_none());
12325    }
12326}
12327
12328// ============================================================================
12329// Rate Limiting Middleware Tests
12330// ============================================================================
12331
12332#[cfg(test)]
12333mod rate_limit_tests {
12334    use super::*;
12335    use crate::request::Method;
12336    use crate::response::{ResponseBody, StatusCode};
12337    use std::time::Duration;
12338
12339    fn test_context() -> RequestContext {
12340        RequestContext::new(asupersync::Cx::for_testing(), 1)
12341    }
12342
12343    fn run_rate_limit_before(mw: &RateLimitMiddleware, req: &mut Request) -> ControlFlow {
12344        let ctx = test_context();
12345        let fut = mw.before(&ctx, req);
12346        futures_executor::block_on(fut)
12347    }
12348
12349    fn run_rate_limit_after(mw: &RateLimitMiddleware, req: &Request, resp: Response) -> Response {
12350        let ctx = test_context();
12351        let fut = mw.after(&ctx, req, resp);
12352        futures_executor::block_on(fut)
12353    }
12354
12355    #[test]
12356    fn rate_limit_default_allows_requests() {
12357        let mw = RateLimitMiddleware::new();
12358        let mut req = Request::new(Method::Get, "/api/test");
12359        req.headers_mut()
12360            .insert("x-forwarded-for", b"192.168.1.1".to_vec());
12361
12362        let result = run_rate_limit_before(&mw, &mut req);
12363        assert!(result.is_continue(), "first request should be allowed");
12364    }
12365
12366    #[test]
12367    fn rate_limit_fixed_window_blocks_after_limit() {
12368        let mw = RateLimitMiddleware::builder()
12369            .requests(3)
12370            .per(Duration::from_secs(60))
12371            .algorithm(RateLimitAlgorithm::FixedWindow)
12372            .key_extractor(IpKeyExtractor)
12373            .build();
12374
12375        for i in 0..3 {
12376            let mut req = Request::new(Method::Get, "/api/test");
12377            req.headers_mut()
12378                .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12379            let result = run_rate_limit_before(&mw, &mut req);
12380            assert!(
12381                result.is_continue(),
12382                "request {i} should be allowed within limit"
12383            );
12384        }
12385
12386        // Fourth request should be blocked
12387        let mut req = Request::new(Method::Get, "/api/test");
12388        req.headers_mut()
12389            .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12390        let result = run_rate_limit_before(&mw, &mut req);
12391        assert!(result.is_break(), "fourth request should be blocked");
12392
12393        // Verify 429 status
12394        if let ControlFlow::Break(resp) = result {
12395            assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
12396        }
12397    }
12398
12399    #[test]
12400    fn rate_limit_different_keys_independent() {
12401        let mw = RateLimitMiddleware::builder()
12402            .requests(2)
12403            .per(Duration::from_secs(60))
12404            .algorithm(RateLimitAlgorithm::FixedWindow)
12405            .key_extractor(IpKeyExtractor)
12406            .build();
12407
12408        // Two requests from IP A
12409        for _ in 0..2 {
12410            let mut req = Request::new(Method::Get, "/");
12411            req.headers_mut()
12412                .insert("x-forwarded-for", b"1.1.1.1".to_vec());
12413            assert!(run_rate_limit_before(&mw, &mut req).is_continue());
12414        }
12415
12416        // IP A is now exhausted
12417        let mut req = Request::new(Method::Get, "/");
12418        req.headers_mut()
12419            .insert("x-forwarded-for", b"1.1.1.1".to_vec());
12420        assert!(run_rate_limit_before(&mw, &mut req).is_break());
12421
12422        // IP B should still be fine
12423        let mut req = Request::new(Method::Get, "/");
12424        req.headers_mut()
12425            .insert("x-forwarded-for", b"2.2.2.2".to_vec());
12426        assert!(run_rate_limit_before(&mw, &mut req).is_continue());
12427    }
12428
12429    #[test]
12430    fn rate_limit_token_bucket_allows_burst() {
12431        let mw = RateLimitMiddleware::builder()
12432            .requests(5)
12433            .per(Duration::from_secs(60))
12434            .algorithm(RateLimitAlgorithm::TokenBucket)
12435            .key_extractor(IpKeyExtractor)
12436            .build();
12437
12438        // Should allow 5 rapid requests (full bucket)
12439        for i in 0..5 {
12440            let mut req = Request::new(Method::Get, "/");
12441            req.headers_mut()
12442                .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12443            let result = run_rate_limit_before(&mw, &mut req);
12444            assert!(result.is_continue(), "burst request {i} should be allowed");
12445        }
12446
12447        // 6th request should be blocked (bucket empty)
12448        let mut req = Request::new(Method::Get, "/");
12449        req.headers_mut()
12450            .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12451        assert!(run_rate_limit_before(&mw, &mut req).is_break());
12452    }
12453
12454    #[test]
12455    fn rate_limit_sliding_window_basic() {
12456        let mw = RateLimitMiddleware::builder()
12457            .requests(3)
12458            .per(Duration::from_secs(60))
12459            .algorithm(RateLimitAlgorithm::SlidingWindow)
12460            .key_extractor(IpKeyExtractor)
12461            .build();
12462
12463        for i in 0..3 {
12464            let mut req = Request::new(Method::Get, "/");
12465            req.headers_mut()
12466                .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12467            assert!(
12468                run_rate_limit_before(&mw, &mut req).is_continue(),
12469                "sliding window request {i} should be allowed"
12470            );
12471        }
12472
12473        // Should block once limit reached
12474        let mut req = Request::new(Method::Get, "/");
12475        req.headers_mut()
12476            .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12477        assert!(run_rate_limit_before(&mw, &mut req).is_break());
12478    }
12479
12480    #[test]
12481    fn rate_limit_header_key_extractor() {
12482        let mw = RateLimitMiddleware::builder()
12483            .requests(2)
12484            .per(Duration::from_secs(60))
12485            .algorithm(RateLimitAlgorithm::FixedWindow)
12486            .key_extractor(HeaderKeyExtractor::new("x-api-key"))
12487            .build();
12488
12489        // Two requests with same API key
12490        for _ in 0..2 {
12491            let mut req = Request::new(Method::Get, "/");
12492            req.headers_mut().insert("x-api-key", b"key-abc".to_vec());
12493            assert!(run_rate_limit_before(&mw, &mut req).is_continue());
12494        }
12495
12496        // Same key blocked
12497        let mut req = Request::new(Method::Get, "/");
12498        req.headers_mut().insert("x-api-key", b"key-abc".to_vec());
12499        assert!(run_rate_limit_before(&mw, &mut req).is_break());
12500
12501        // Different key still allowed
12502        let mut req = Request::new(Method::Get, "/");
12503        req.headers_mut().insert("x-api-key", b"key-xyz".to_vec());
12504        assert!(run_rate_limit_before(&mw, &mut req).is_continue());
12505    }
12506
12507    #[test]
12508    fn rate_limit_path_key_extractor() {
12509        let mw = RateLimitMiddleware::builder()
12510            .requests(1)
12511            .per(Duration::from_secs(60))
12512            .algorithm(RateLimitAlgorithm::FixedWindow)
12513            .key_extractor(PathKeyExtractor)
12514            .build();
12515
12516        let mut req = Request::new(Method::Get, "/api/a");
12517        assert!(run_rate_limit_before(&mw, &mut req).is_continue());
12518
12519        // Same path is blocked
12520        let mut req = Request::new(Method::Get, "/api/a");
12521        assert!(run_rate_limit_before(&mw, &mut req).is_break());
12522
12523        // Different path is allowed
12524        let mut req = Request::new(Method::Get, "/api/b");
12525        assert!(run_rate_limit_before(&mw, &mut req).is_continue());
12526    }
12527
12528    #[test]
12529    fn rate_limit_no_key_skips_limiting() {
12530        let mw = RateLimitMiddleware::builder()
12531            .requests(1)
12532            .per(Duration::from_secs(60))
12533            .algorithm(RateLimitAlgorithm::FixedWindow)
12534            .key_extractor(HeaderKeyExtractor::new("x-api-key"))
12535            .build();
12536
12537        // Request without the header — no key extracted, should pass
12538        let mut req = Request::new(Method::Get, "/");
12539        assert!(run_rate_limit_before(&mw, &mut req).is_continue());
12540
12541        // Still passes even with many requests (no key = no limiting)
12542        for _ in 0..10 {
12543            let mut req = Request::new(Method::Get, "/");
12544            assert!(run_rate_limit_before(&mw, &mut req).is_continue());
12545        }
12546    }
12547
12548    #[test]
12549    fn rate_limit_response_headers_on_success() {
12550        let mw = RateLimitMiddleware::builder()
12551            .requests(10)
12552            .per(Duration::from_secs(60))
12553            .algorithm(RateLimitAlgorithm::FixedWindow)
12554            .key_extractor(IpKeyExtractor)
12555            .build();
12556
12557        let mut req = Request::new(Method::Get, "/");
12558        req.headers_mut()
12559            .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12560        let cf = run_rate_limit_before(&mw, &mut req);
12561        assert!(cf.is_continue());
12562
12563        let resp = Response::with_status(StatusCode::OK);
12564        let resp = run_rate_limit_after(&mw, &req, resp);
12565
12566        // Verify rate limit headers are present
12567        let headers = resp.headers();
12568        let has_limit = headers
12569            .iter()
12570            .any(|(n, _)| n.eq_ignore_ascii_case("x-ratelimit-limit"));
12571        let has_remaining = headers
12572            .iter()
12573            .any(|(n, _)| n.eq_ignore_ascii_case("x-ratelimit-remaining"));
12574        let has_reset = headers
12575            .iter()
12576            .any(|(n, _)| n.eq_ignore_ascii_case("x-ratelimit-reset"));
12577
12578        assert!(has_limit, "should have X-RateLimit-Limit header");
12579        assert!(has_remaining, "should have X-RateLimit-Remaining header");
12580        assert!(has_reset, "should have X-RateLimit-Reset header");
12581
12582        // Check limit value
12583        let limit_val = headers
12584            .iter()
12585            .find(|(n, _)| n.eq_ignore_ascii_case("x-ratelimit-limit"))
12586            .map(|(_, v)| std::str::from_utf8(v).unwrap().to_string())
12587            .unwrap();
12588        assert_eq!(limit_val, "10");
12589    }
12590
12591    #[test]
12592    fn rate_limit_429_response_has_retry_after() {
12593        let mw = RateLimitMiddleware::builder()
12594            .requests(1)
12595            .per(Duration::from_secs(60))
12596            .algorithm(RateLimitAlgorithm::FixedWindow)
12597            .key_extractor(IpKeyExtractor)
12598            .build();
12599
12600        // Consume the single allowed request
12601        let mut req = Request::new(Method::Get, "/");
12602        req.headers_mut()
12603            .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12604        assert!(run_rate_limit_before(&mw, &mut req).is_continue());
12605
12606        // Second request should be blocked with 429
12607        let mut req = Request::new(Method::Get, "/");
12608        req.headers_mut()
12609            .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12610        let result = run_rate_limit_before(&mw, &mut req);
12611
12612        if let ControlFlow::Break(resp) = result {
12613            assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
12614
12615            // Should have Retry-After header
12616            let has_retry = resp
12617                .headers()
12618                .iter()
12619                .any(|(n, _)| n.eq_ignore_ascii_case("retry-after"));
12620            assert!(has_retry, "429 response should have Retry-After header");
12621
12622            // Should have JSON body
12623            let has_ct = resp
12624                .headers()
12625                .iter()
12626                .any(|(n, v)| n.eq_ignore_ascii_case("content-type") && v == b"application/json");
12627            assert!(has_ct, "429 response should have JSON content type");
12628        } else {
12629            panic!("expected Break(429)");
12630        }
12631    }
12632
12633    #[test]
12634    fn rate_limit_no_headers_when_disabled() {
12635        let mw = RateLimitMiddleware::builder()
12636            .requests(10)
12637            .per(Duration::from_secs(60))
12638            .algorithm(RateLimitAlgorithm::FixedWindow)
12639            .key_extractor(IpKeyExtractor)
12640            .include_headers(false)
12641            .build();
12642
12643        let mut req = Request::new(Method::Get, "/");
12644        req.headers_mut()
12645            .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12646        assert!(run_rate_limit_before(&mw, &mut req).is_continue());
12647
12648        let resp = Response::with_status(StatusCode::OK);
12649        let resp = run_rate_limit_after(&mw, &req, resp);
12650
12651        let has_limit = resp
12652            .headers()
12653            .iter()
12654            .any(|(n, _)| n.eq_ignore_ascii_case("x-ratelimit-limit"));
12655        assert!(
12656            !has_limit,
12657            "should NOT have rate limit headers when disabled"
12658        );
12659    }
12660
12661    #[test]
12662    fn rate_limit_custom_retry_message() {
12663        let mw = RateLimitMiddleware::builder()
12664            .requests(1)
12665            .per(Duration::from_secs(60))
12666            .algorithm(RateLimitAlgorithm::FixedWindow)
12667            .key_extractor(IpKeyExtractor)
12668            .retry_message("Slow down, partner!")
12669            .build();
12670
12671        // Exhaust limit
12672        let mut req = Request::new(Method::Get, "/");
12673        req.headers_mut()
12674            .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12675        run_rate_limit_before(&mw, &mut req);
12676
12677        // Check custom message in 429 body
12678        let mut req = Request::new(Method::Get, "/");
12679        req.headers_mut()
12680            .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12681        if let ControlFlow::Break(resp) = run_rate_limit_before(&mw, &mut req) {
12682            if let ResponseBody::Bytes(body) = resp.body_ref() {
12683                let body_str = std::str::from_utf8(body).unwrap();
12684                assert!(
12685                    body_str.contains("Slow down, partner!"),
12686                    "expected custom message in body, got: {body_str}"
12687                );
12688            } else {
12689                panic!("expected Bytes body");
12690            }
12691        } else {
12692            panic!("expected Break(429)");
12693        }
12694    }
12695
12696    #[test]
12697    fn rate_limit_ip_extractor_x_forwarded_for() {
12698        let extractor = IpKeyExtractor;
12699        let mut req = Request::new(Method::Get, "/");
12700        req.headers_mut()
12701            .insert("x-forwarded-for", b"1.2.3.4, 5.6.7.8".to_vec());
12702        assert_eq!(extractor.extract_key(&req), Some("1.2.3.4".to_string()));
12703    }
12704
12705    #[test]
12706    fn rate_limit_ip_extractor_x_real_ip() {
12707        let extractor = IpKeyExtractor;
12708        let mut req = Request::new(Method::Get, "/");
12709        req.headers_mut().insert("x-real-ip", b"9.8.7.6".to_vec());
12710        assert_eq!(extractor.extract_key(&req), Some("9.8.7.6".to_string()));
12711    }
12712
12713    #[test]
12714    fn rate_limit_ip_extractor_fallback() {
12715        let extractor = IpKeyExtractor;
12716        let req = Request::new(Method::Get, "/");
12717        assert_eq!(extractor.extract_key(&req), Some("unknown".to_string()));
12718    }
12719
12720    // Tests for secure ConnectedIpKeyExtractor (bd-u9gw)
12721    #[test]
12722    fn connected_ip_extractor_with_remote_addr() {
12723        use std::net::{IpAddr, Ipv4Addr};
12724
12725        let extractor = ConnectedIpKeyExtractor;
12726        let mut req = Request::new(Method::Get, "/");
12727        req.insert_extension(RemoteAddr(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100))));
12728
12729        assert_eq!(
12730            extractor.extract_key(&req),
12731            Some("192.168.1.100".to_string())
12732        );
12733    }
12734
12735    #[test]
12736    fn connected_ip_extractor_without_remote_addr() {
12737        let extractor = ConnectedIpKeyExtractor;
12738        let req = Request::new(Method::Get, "/");
12739
12740        // Should return None when no RemoteAddr is set
12741        assert_eq!(extractor.extract_key(&req), None);
12742    }
12743
12744    #[test]
12745    fn connected_ip_extractor_ignores_headers() {
12746        use std::net::{IpAddr, Ipv4Addr};
12747
12748        let extractor = ConnectedIpKeyExtractor;
12749        let mut req = Request::new(Method::Get, "/");
12750        req.insert_extension(RemoteAddr(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))));
12751        // Add spoofed header - should be ignored
12752        req.headers_mut()
12753            .insert("x-forwarded-for", b"1.2.3.4".to_vec());
12754
12755        // Should use RemoteAddr, not the header
12756        assert_eq!(extractor.extract_key(&req), Some("10.0.0.1".to_string()));
12757    }
12758
12759    // Tests for TrustedProxyIpKeyExtractor (bd-u9gw)
12760    #[test]
12761    fn trusted_proxy_extractor_from_trusted_proxy() {
12762        use std::net::{IpAddr, Ipv4Addr};
12763
12764        let extractor = TrustedProxyIpKeyExtractor::new().trust_cidr("10.0.0.0/8");
12765
12766        let mut req = Request::new(Method::Get, "/");
12767        // Request came from trusted proxy 10.0.0.1
12768        req.insert_extension(RemoteAddr(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))));
12769        // Proxy set X-Forwarded-For with real client IP
12770        req.headers_mut()
12771            .insert("x-forwarded-for", b"203.0.113.50".to_vec());
12772
12773        // Should trust the header and extract client IP
12774        assert_eq!(
12775            extractor.extract_key(&req),
12776            Some("203.0.113.50".to_string())
12777        );
12778    }
12779
12780    #[test]
12781    fn trusted_proxy_extractor_from_untrusted_direct() {
12782        use std::net::{IpAddr, Ipv4Addr};
12783
12784        let extractor = TrustedProxyIpKeyExtractor::new().trust_cidr("10.0.0.0/8");
12785
12786        let mut req = Request::new(Method::Get, "/");
12787        // Request came directly from client (not a trusted proxy)
12788        req.insert_extension(RemoteAddr(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 50))));
12789        // Client tries to spoof X-Forwarded-For
12790        req.headers_mut()
12791            .insert("x-forwarded-for", b"1.2.3.4".to_vec());
12792
12793        // Should ignore header and use RemoteAddr
12794        assert_eq!(
12795            extractor.extract_key(&req),
12796            Some("203.0.113.50".to_string())
12797        );
12798    }
12799
12800    #[test]
12801    fn trusted_proxy_extractor_no_remote_addr() {
12802        let extractor = TrustedProxyIpKeyExtractor::new().trust_loopback();
12803
12804        let mut req = Request::new(Method::Get, "/");
12805        // No RemoteAddr set - should return None (safer than guessing)
12806        req.headers_mut()
12807            .insert("x-forwarded-for", b"1.2.3.4".to_vec());
12808
12809        assert_eq!(extractor.extract_key(&req), None);
12810    }
12811
12812    #[test]
12813    fn trusted_proxy_extractor_loopback_ipv4() {
12814        use std::net::{IpAddr, Ipv4Addr};
12815
12816        let extractor = TrustedProxyIpKeyExtractor::new().trust_loopback();
12817
12818        let mut req = Request::new(Method::Get, "/");
12819        req.insert_extension(RemoteAddr(IpAddr::V4(Ipv4Addr::LOCALHOST)));
12820        req.headers_mut()
12821            .insert("x-forwarded-for", b"8.8.8.8".to_vec());
12822
12823        assert_eq!(extractor.extract_key(&req), Some("8.8.8.8".to_string()));
12824    }
12825
12826    #[test]
12827    fn trusted_proxy_extractor_loopback_ipv6() {
12828        use std::net::{IpAddr, Ipv6Addr};
12829
12830        let extractor = TrustedProxyIpKeyExtractor::new().trust_loopback();
12831
12832        let mut req = Request::new(Method::Get, "/");
12833        req.insert_extension(RemoteAddr(IpAddr::V6(Ipv6Addr::LOCALHOST)));
12834        req.headers_mut()
12835            .insert("x-forwarded-for", b"8.8.8.8".to_vec());
12836
12837        assert_eq!(extractor.extract_key(&req), Some("8.8.8.8".to_string()));
12838    }
12839
12840    #[test]
12841    fn cidr_parsing() {
12842        // Valid CIDRs
12843        assert!(parse_cidr("10.0.0.0/8").is_some());
12844        assert!(parse_cidr("192.168.1.0/24").is_some());
12845        assert!(parse_cidr("0.0.0.0/0").is_some());
12846        assert!(parse_cidr("::1/128").is_some());
12847        assert!(parse_cidr("::/0").is_some());
12848
12849        // Invalid CIDRs
12850        assert!(parse_cidr("10.0.0.0/33").is_none()); // Prefix too large for IPv4
12851        assert!(parse_cidr("invalid").is_none());
12852        assert!(parse_cidr("10.0.0.0").is_none()); // Missing prefix
12853    }
12854
12855    #[test]
12856    fn ip_in_cidr_matching() {
12857        use std::net::{IpAddr, Ipv4Addr};
12858
12859        let cidr_10 = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 0));
12860
12861        // In range
12862        assert!(ip_in_cidr(
12863            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
12864            cidr_10,
12865            8
12866        ));
12867        assert!(ip_in_cidr(
12868            IpAddr::V4(Ipv4Addr::new(10, 255, 255, 255)),
12869            cidr_10,
12870            8
12871        ));
12872
12873        // Out of range
12874        assert!(!ip_in_cidr(
12875            IpAddr::V4(Ipv4Addr::new(11, 0, 0, 1)),
12876            cidr_10,
12877            8
12878        ));
12879        assert!(!ip_in_cidr(
12880            IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
12881            cidr_10,
12882            8
12883        ));
12884    }
12885
12886    #[test]
12887    fn rate_limit_composite_key_extractor() {
12888        let extractor =
12889            CompositeKeyExtractor::new(vec![Box::new(IpKeyExtractor), Box::new(PathKeyExtractor)]);
12890
12891        let mut req = Request::new(Method::Get, "/api/users");
12892        req.headers_mut()
12893            .insert("x-forwarded-for", b"10.0.0.1".to_vec());
12894
12895        let key = extractor.extract_key(&req);
12896        assert_eq!(key, Some("10.0.0.1:/api/users".to_string()));
12897    }
12898
12899    #[test]
12900    fn rate_limit_builder_defaults() {
12901        let mw = RateLimitMiddleware::builder().build();
12902        assert_eq!(mw.config.max_requests, 100);
12903        assert_eq!(mw.config.window, Duration::from_secs(60));
12904        assert_eq!(mw.config.algorithm, RateLimitAlgorithm::TokenBucket);
12905        assert!(mw.config.include_headers);
12906    }
12907
12908    #[test]
12909    fn rate_limit_builder_per_minute() {
12910        let mw = RateLimitMiddleware::builder()
12911            .requests(50)
12912            .per_minute(2)
12913            .algorithm(RateLimitAlgorithm::SlidingWindow)
12914            .build();
12915        assert_eq!(mw.config.max_requests, 50);
12916        assert_eq!(mw.config.window, Duration::from_secs(120));
12917        assert_eq!(mw.config.algorithm, RateLimitAlgorithm::SlidingWindow);
12918    }
12919
12920    #[test]
12921    fn rate_limit_builder_per_hour() {
12922        let mw = RateLimitMiddleware::builder()
12923            .requests(1000)
12924            .per_hour(1)
12925            .build();
12926        assert_eq!(mw.config.window, Duration::from_secs(3600));
12927    }
12928
12929    #[test]
12930    fn rate_limit_middleware_name() {
12931        let mw = RateLimitMiddleware::new();
12932        assert_eq!(mw.name(), "RateLimit");
12933    }
12934
12935    #[test]
12936    fn rate_limit_default_via_default_trait() {
12937        let mw = RateLimitMiddleware::default();
12938        assert_eq!(mw.config.max_requests, 100);
12939    }
12940
12941    // ========================================================================
12942    // ETag Middleware Tests
12943    // ========================================================================
12944
12945    #[test]
12946    fn etag_middleware_generates_etag_for_get() {
12947        let mw = ETagMiddleware::new();
12948        let ctx = test_context();
12949        let req = Request::new(crate::request::Method::Get, "/resource");
12950
12951        // Create response with body
12952        let response = Response::ok()
12953            .header("content-type", b"application/json".to_vec())
12954            .body(ResponseBody::Bytes(br#"{"status":"ok"}"#.to_vec()));
12955
12956        let response = futures_executor::block_on(mw.after(&ctx, &req, response));
12957
12958        // Should have ETag header
12959        let etag = response
12960            .headers()
12961            .iter()
12962            .find(|(name, _)| name.eq_ignore_ascii_case("etag"));
12963        assert!(etag.is_some(), "Response should have ETag header");
12964
12965        // ETag should be a quoted hex string
12966        let etag_value = std::str::from_utf8(&etag.unwrap().1).unwrap();
12967        assert!(etag_value.starts_with('"'), "ETag should start with quote");
12968        assert!(etag_value.ends_with('"'), "ETag should end with quote");
12969    }
12970
12971    #[test]
12972    fn etag_middleware_returns_304_on_match() {
12973        let mw = ETagMiddleware::new();
12974        let ctx = test_context();
12975
12976        // First request to get the ETag
12977        let req1 = Request::new(crate::request::Method::Get, "/resource");
12978        let body = br#"{"status":"ok"}"#.to_vec();
12979        let response1 = Response::ok().body(ResponseBody::Bytes(body.clone()));
12980        let response1 = futures_executor::block_on(mw.after(&ctx, &req1, response1));
12981
12982        let etag = response1
12983            .headers()
12984            .iter()
12985            .find(|(name, _)| name.eq_ignore_ascii_case("etag"))
12986            .map(|(_, v)| std::str::from_utf8(v).unwrap().to_string())
12987            .unwrap();
12988
12989        // Second request with If-None-Match header
12990        let mut req2 = Request::new(crate::request::Method::Get, "/resource");
12991        req2.headers_mut()
12992            .insert("if-none-match", etag.as_bytes().to_vec());
12993
12994        let response2 = Response::ok().body(ResponseBody::Bytes(body));
12995        let response2 = futures_executor::block_on(mw.after(&ctx, &req2, response2));
12996
12997        // Should return 304 Not Modified
12998        assert_eq!(response2.status().as_u16(), 304);
12999        assert!(response2.body_ref().is_empty());
13000    }
13001
13002    #[test]
13003    fn etag_middleware_returns_full_response_on_mismatch() {
13004        let mw = ETagMiddleware::new();
13005        let ctx = test_context();
13006
13007        let mut req = Request::new(crate::request::Method::Get, "/resource");
13008        req.headers_mut()
13009            .insert("if-none-match", b"\"old-etag\"".to_vec());
13010
13011        let body = br#"{"status":"updated"}"#.to_vec();
13012        let response = Response::ok().body(ResponseBody::Bytes(body.clone()));
13013        let response = futures_executor::block_on(mw.after(&ctx, &req, response));
13014
13015        // Should return 200 OK with body
13016        assert_eq!(response.status().as_u16(), 200);
13017        assert!(!response.body_ref().is_empty());
13018    }
13019
13020    #[test]
13021    fn etag_middleware_weak_etag_generation() {
13022        let config = ETagConfig::new().weak(true);
13023        let mw = ETagMiddleware::with_config(config);
13024        let ctx = test_context();
13025        let req = Request::new(crate::request::Method::Get, "/resource");
13026
13027        let response = Response::ok().body(ResponseBody::Bytes(b"data".to_vec()));
13028        let response = futures_executor::block_on(mw.after(&ctx, &req, response));
13029
13030        let etag = response
13031            .headers()
13032            .iter()
13033            .find(|(name, _)| name.eq_ignore_ascii_case("etag"))
13034            .map(|(_, v)| std::str::from_utf8(v).unwrap().to_string())
13035            .unwrap();
13036
13037        assert!(etag.starts_with("W/"), "Weak ETag should start with W/");
13038    }
13039
13040    #[test]
13041    fn etag_middleware_skips_post_requests() {
13042        let mw = ETagMiddleware::new();
13043        let ctx = test_context();
13044        let req = Request::new(crate::request::Method::Post, "/resource");
13045
13046        let response = Response::ok().body(ResponseBody::Bytes(b"created".to_vec()));
13047        let response = futures_executor::block_on(mw.after(&ctx, &req, response));
13048
13049        // POST should not get ETag
13050        let etag = response
13051            .headers()
13052            .iter()
13053            .find(|(name, _)| name.eq_ignore_ascii_case("etag"));
13054        assert!(etag.is_none(), "POST should not have ETag");
13055    }
13056
13057    #[test]
13058    fn etag_middleware_handles_head_requests() {
13059        let mw = ETagMiddleware::new();
13060        let ctx = test_context();
13061        let req = Request::new(crate::request::Method::Head, "/resource");
13062
13063        let response = Response::ok().body(ResponseBody::Bytes(b"data".to_vec()));
13064        let response = futures_executor::block_on(mw.after(&ctx, &req, response));
13065
13066        // HEAD should get ETag
13067        let etag = response
13068            .headers()
13069            .iter()
13070            .find(|(name, _)| name.eq_ignore_ascii_case("etag"));
13071        assert!(etag.is_some(), "HEAD should have ETag");
13072    }
13073
13074    #[test]
13075    fn etag_middleware_disabled_mode() {
13076        let config = ETagConfig::new().mode(ETagMode::Disabled);
13077        let mw = ETagMiddleware::with_config(config);
13078        let ctx = test_context();
13079        let req = Request::new(crate::request::Method::Get, "/resource");
13080
13081        let response = Response::ok().body(ResponseBody::Bytes(b"data".to_vec()));
13082        let response = futures_executor::block_on(mw.after(&ctx, &req, response));
13083
13084        // Should not have ETag when disabled
13085        let etag = response
13086            .headers()
13087            .iter()
13088            .find(|(name, _)| name.eq_ignore_ascii_case("etag"));
13089        assert!(etag.is_none(), "Disabled mode should not add ETag");
13090    }
13091
13092    #[test]
13093    fn etag_middleware_min_size_filter() {
13094        let config = ETagConfig::new().min_size(1000);
13095        let mw = ETagMiddleware::with_config(config);
13096        let ctx = test_context();
13097        let req = Request::new(crate::request::Method::Get, "/resource");
13098
13099        // Small body below min_size
13100        let response = Response::ok().body(ResponseBody::Bytes(b"small".to_vec()));
13101        let response = futures_executor::block_on(mw.after(&ctx, &req, response));
13102
13103        // Should not have ETag for small body
13104        let etag = response
13105            .headers()
13106            .iter()
13107            .find(|(name, _)| name.eq_ignore_ascii_case("etag"));
13108        assert!(etag.is_none(), "Small body should not get ETag");
13109    }
13110
13111    #[test]
13112    fn etag_middleware_preserves_existing_etag() {
13113        let config = ETagConfig::new().mode(ETagMode::Manual);
13114        let mw = ETagMiddleware::with_config(config);
13115        let ctx = test_context();
13116
13117        // First request to set up cached ETag
13118        let mut req = Request::new(crate::request::Method::Get, "/resource");
13119        req.headers_mut()
13120            .insert("if-none-match", b"\"custom-etag\"".to_vec());
13121
13122        // Response with pre-set ETag matching the request
13123        let response = Response::ok()
13124            .header("etag", b"\"custom-etag\"".to_vec())
13125            .body(ResponseBody::Bytes(b"data".to_vec()));
13126        let response = futures_executor::block_on(mw.after(&ctx, &req, response));
13127
13128        // Should return 304 since custom ETag matches
13129        assert_eq!(response.status().as_u16(), 304);
13130    }
13131
13132    #[test]
13133    fn etag_middleware_wildcard_if_none_match() {
13134        let mw = ETagMiddleware::new();
13135        let ctx = test_context();
13136        let mut req = Request::new(crate::request::Method::Get, "/resource");
13137        req.headers_mut().insert("if-none-match", b"*".to_vec());
13138
13139        let response = Response::ok().body(ResponseBody::Bytes(b"data".to_vec()));
13140        let response = futures_executor::block_on(mw.after(&ctx, &req, response));
13141
13142        // Wildcard should match any ETag
13143        assert_eq!(response.status().as_u16(), 304);
13144    }
13145
13146    #[test]
13147    fn etag_middleware_weak_comparison_matches() {
13148        let mw = ETagMiddleware::new();
13149        let ctx = test_context();
13150
13151        // Get the strong ETag
13152        let req1 = Request::new(crate::request::Method::Get, "/resource");
13153        let body = b"test data".to_vec();
13154        let response1 = Response::ok().body(ResponseBody::Bytes(body.clone()));
13155        let response1 = futures_executor::block_on(mw.after(&ctx, &req1, response1));
13156
13157        let etag = response1
13158            .headers()
13159            .iter()
13160            .find(|(name, _)| name.eq_ignore_ascii_case("etag"))
13161            .map(|(_, v)| std::str::from_utf8(v).unwrap().to_string())
13162            .unwrap();
13163
13164        // Send request with weak version of the same ETag
13165        let mut req2 = Request::new(crate::request::Method::Get, "/resource");
13166        let weak_etag = format!("W/{}", etag);
13167        req2.headers_mut()
13168            .insert("if-none-match", weak_etag.as_bytes().to_vec());
13169
13170        let response2 = Response::ok().body(ResponseBody::Bytes(body));
13171        let response2 = futures_executor::block_on(mw.after(&ctx, &req2, response2));
13172
13173        // Weak comparison should match
13174        assert_eq!(response2.status().as_u16(), 304);
13175    }
13176
13177    #[test]
13178    fn etag_middleware_name() {
13179        let mw = ETagMiddleware::new();
13180        assert_eq!(mw.name(), "ETagMiddleware");
13181    }
13182
13183    #[test]
13184    fn etag_config_builder() {
13185        let config = ETagConfig::new()
13186            .mode(ETagMode::Auto)
13187            .weak(true)
13188            .min_size(512);
13189
13190        assert_eq!(config.mode, ETagMode::Auto);
13191        assert!(config.weak);
13192        assert_eq!(config.min_size, 512);
13193    }
13194
13195    #[test]
13196    fn etag_generates_consistent_hash() {
13197        // Same data should produce same ETag
13198        let etag1 = ETagMiddleware::generate_etag(b"hello world", false);
13199        let etag2 = ETagMiddleware::generate_etag(b"hello world", false);
13200        assert_eq!(etag1, etag2);
13201
13202        // Different data should produce different ETag
13203        let etag3 = ETagMiddleware::generate_etag(b"hello world!", false);
13204        assert_ne!(etag1, etag3);
13205    }
13206}