Skip to main content

ranvier_guard/
lib.rs

//! # ranvier-guard — HTTP Security/Policy Guard Nodes
//!
//! Guard nodes are typed Transition nodes that enforce security and policy
//! constraints as visible, traceable pipeline steps — replacing hidden Tower
//! middleware layers.
//!
//! Each Guard reads context from the Bus (e.g., request headers, client IP)
//! and either passes the input through unchanged or returns a Fault.
//!
//! ## Available Guards
//!
//! | Guard | Purpose | Bus Read | Bus Write |
//! |-------|---------|----------|-----------|
//! | [`CorsGuard`] | Origin validation + CORS headers | `RequestOrigin` | `CorsHeaders` |
//! | [`RateLimitGuard`] | Per-client token-bucket rate limiting | `ClientIdentity` | — |
//! | [`SecurityHeadersGuard`] | Standard security response headers | — | `SecurityHeaders` |
//! | [`IpFilterGuard`] | Allow/deny-list IP filtering | `ClientIp` | — |
//! | [`AccessLogGuard`] | Structured access logging | `AccessLogRequest` | `AccessLogEntry` |
//! | [`CompressionGuard`] | Response encoding negotiation | `AcceptEncoding` | `CompressionConfig` |
//! | [`RequestSizeLimitGuard`] | Rejects oversized request bodies | `ContentLength` | — |
//! | [`RequestIdGuard`] | Ensures a per-request identifier | `RequestId` | `RequestId` (if absent) |
//!
//! ## Example
//!
//! ```rust,ignore
//! use ranvier_guard::*;
//!
//! Axon::simple::<String>("api")
//!     .then(AccessLogGuard::new())
//!     .then(CorsGuard::new(CorsConfig::default()))
//!     .then(SecurityHeadersGuard::new(SecurityPolicy::default()))
//!     .then(business_logic)
//! ```
31
32use async_trait::async_trait;
33use ranvier_core::iam::{enforce_policy, IamIdentity, IamPolicy};
34use ranvier_core::{bus::Bus, outcome::Outcome, transition::Transition};
35use serde::{Deserialize, Serialize};
36use std::collections::HashSet;
37use std::marker::PhantomData;
38use std::sync::Arc;
39use std::time::Instant;
40use subtle::ConstantTimeEq;
41use tokio::sync::Mutex;
42
43// ---------------------------------------------------------------------------
44// CorsGuard
45// ---------------------------------------------------------------------------
46
47/// CORS guard configuration.
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct CorsConfig {
50    pub allowed_origins: Vec<String>,
51    pub allowed_methods: Vec<String>,
52    pub allowed_headers: Vec<String>,
53    pub max_age_seconds: u64,
54    pub allow_credentials: bool,
55}
56
57impl Default for CorsConfig {
58    fn default() -> Self {
59        Self {
60            allowed_origins: vec!["*".to_string()],
61            allowed_methods: vec![
62                "GET".into(),
63                "POST".into(),
64                "PUT".into(),
65                "DELETE".into(),
66                "OPTIONS".into(),
67            ],
68            allowed_headers: vec!["Content-Type".into(), "Authorization".into()],
69            max_age_seconds: 86400,
70            allow_credentials: false,
71        }
72    }
73}
74
75impl CorsConfig {
76    pub fn new() -> Self {
77        Self::default()
78    }
79
80    pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
81        self.allowed_origins.push(origin.into());
82        self
83    }
84}
85
86/// Bus-injectable type representing the request origin header.
87#[derive(Debug, Clone)]
88pub struct RequestOrigin(pub String);
89
90/// CORS guard Transition — validates the request origin against allowed origins.
91///
92/// Reads `RequestOrigin` from the Bus. If the origin is not allowed, returns Fault.
93/// Writes CORS response headers to the Bus as `CorsHeaders`.
94#[derive(Debug, Clone)]
95pub struct CorsGuard<T> {
96    config: CorsConfig,
97    _marker: PhantomData<T>,
98}
99
100impl<T> CorsGuard<T> {
101    pub fn new(config: CorsConfig) -> Self {
102        Self {
103            config,
104            _marker: PhantomData,
105        }
106    }
107
108    /// Create a fully permissive CORS guard for development and testing.
109    ///
110    /// Allows all origins (`*`), all standard HTTP methods, and common headers.
111    /// A `tracing::warn` is emitted to remind that this should not be used in production.
112    ///
113    /// # Example
114    ///
115    /// ```rust,ignore
116    /// Ranvier::http()
117    ///     .guard(CorsGuard::<()>::permissive())
118    /// ```
119    pub fn permissive() -> Self {
120        tracing::warn!("CorsGuard::permissive() — all origins allowed; do not use in production");
121        Self {
122            config: CorsConfig {
123                allowed_origins: vec!["*".to_string()],
124                allowed_methods: vec![
125                    "GET".into(), "POST".into(), "PUT".into(), "DELETE".into(),
126                    "PATCH".into(), "OPTIONS".into(), "HEAD".into(),
127                ],
128                allowed_headers: vec![
129                    "Content-Type".into(), "Authorization".into(), "Accept".into(),
130                    "Origin".into(), "X-Requested-With".into(),
131                ],
132                max_age_seconds: 86400,
133                allow_credentials: false,
134            },
135            _marker: PhantomData,
136        }
137    }
138
139    /// Returns a reference to the CORS configuration.
140    pub fn cors_config(&self) -> &CorsConfig {
141        &self.config
142    }
143}
144
145/// CORS headers to be applied to the response.
146#[derive(Debug, Clone, Serialize, Deserialize)]
147pub struct CorsHeaders {
148    pub access_control_allow_origin: String,
149    pub access_control_allow_methods: String,
150    pub access_control_allow_headers: String,
151    pub access_control_max_age: String,
152}
153
154#[async_trait]
155impl<T> Transition<T, T> for CorsGuard<T>
156where
157    T: Send + Sync + 'static,
158{
159    type Error = String;
160    type Resources = ();
161
162    async fn run(
163        &self,
164        input: T,
165        _resources: &Self::Resources,
166        bus: &mut Bus,
167    ) -> Outcome<T, Self::Error> {
168        let origin = bus
169            .read::<RequestOrigin>()
170            .map(|o| o.0.clone())
171            .unwrap_or_default();
172
173        let allowed = self.config.allowed_origins.contains(&"*".to_string())
174            || self.config.allowed_origins.contains(&origin);
175
176        if !allowed && !origin.is_empty() {
177            return Outcome::fault(format!("CORS: origin '{}' not allowed", origin));
178        }
179
180        let allow_origin = if self.config.allowed_origins.contains(&"*".to_string()) {
181            "*".to_string()
182        } else {
183            origin
184        };
185
186        bus.insert(CorsHeaders {
187            access_control_allow_origin: allow_origin,
188            access_control_allow_methods: self.config.allowed_methods.join(", "),
189            access_control_allow_headers: self.config.allowed_headers.join(", "),
190            access_control_max_age: self.config.max_age_seconds.to_string(),
191        });
192
193        Outcome::next(input)
194    }
195}
196
197// ---------------------------------------------------------------------------
198// RateLimitGuard
199// ---------------------------------------------------------------------------
200
201/// Bus-injectable type representing the client identity for rate limiting.
202#[derive(Debug, Clone, Hash, PartialEq, Eq)]
203pub struct ClientIdentity(pub String);
204
205/// Rate limit error with retry-after information.
206#[derive(Debug, Clone, Serialize, Deserialize)]
207pub struct RateLimitError {
208    pub message: String,
209    pub retry_after_ms: u64,
210}
211
212impl std::fmt::Display for RateLimitError {
213    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214        write!(f, "{} (retry after {}ms)", self.message, self.retry_after_ms)
215    }
216}
217
218/// Simple token-bucket rate limiter state.
219struct RateBucket {
220    tokens: f64,
221    last_refill: Instant,
222}
223
224/// Rate limit guard — enforces per-client request rate limits.
225///
226/// Reads `ClientIdentity` from the Bus. Uses a token-bucket algorithm.
227///
228/// Stale buckets are automatically pruned when `bucket_ttl` is set.
229/// The TTL check runs lazily on each request (no background task required).
230pub struct RateLimitGuard<T> {
231    max_requests: u64,
232    window_ms: u64,
233    buckets: Arc<Mutex<std::collections::HashMap<String, RateBucket>>>,
234    /// If > 0, buckets idle longer than this (in ms) are removed on next access.
235    bucket_ttl_ms: u64,
236    _marker: PhantomData<T>,
237}
238
239impl<T> RateLimitGuard<T> {
240    pub fn new(max_requests: u64, window_ms: u64) -> Self {
241        Self {
242            max_requests,
243            window_ms,
244            buckets: Arc::new(Mutex::new(std::collections::HashMap::new())),
245            bucket_ttl_ms: 0,
246            _marker: PhantomData,
247        }
248    }
249
250    /// Set a TTL for idle buckets. Buckets not accessed within this duration
251    /// are lazily pruned on subsequent requests.
252    ///
253    /// Default: no TTL (buckets persist forever).
254    pub fn with_bucket_ttl(mut self, ttl: std::time::Duration) -> Self {
255        self.bucket_ttl_ms = ttl.as_millis() as u64;
256        self
257    }
258
259    /// Returns the maximum requests per window.
260    pub fn max_requests(&self) -> u64 {
261        self.max_requests
262    }
263
264    /// Returns the window duration in milliseconds.
265    pub fn window_ms(&self) -> u64 {
266        self.window_ms
267    }
268
269    /// Returns the configured bucket TTL in milliseconds (0 = disabled).
270    pub fn bucket_ttl_ms(&self) -> u64 {
271        self.bucket_ttl_ms
272    }
273}
274
275impl<T> Clone for RateLimitGuard<T> {
276    fn clone(&self) -> Self {
277        Self {
278            max_requests: self.max_requests,
279            window_ms: self.window_ms,
280            buckets: self.buckets.clone(),
281            bucket_ttl_ms: self.bucket_ttl_ms,
282            _marker: PhantomData,
283        }
284    }
285}
286
287impl<T> std::fmt::Debug for RateLimitGuard<T> {
288    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
289        f.debug_struct("RateLimitGuard")
290            .field("max_requests", &self.max_requests)
291            .field("window_ms", &self.window_ms)
292            .field("bucket_ttl_ms", &self.bucket_ttl_ms)
293            .finish()
294    }
295}
296
297#[async_trait]
298impl<T> Transition<T, T> for RateLimitGuard<T>
299where
300    T: Send + Sync + 'static,
301{
302    type Error = String;
303    type Resources = ();
304
305    async fn run(
306        &self,
307        input: T,
308        _resources: &Self::Resources,
309        bus: &mut Bus,
310    ) -> Outcome<T, Self::Error> {
311        let client_id = bus
312            .read::<ClientIdentity>()
313            .map(|c| c.0.clone())
314            .unwrap_or_else(|| "anonymous".to_string());
315
316        let mut buckets = self.buckets.lock().await;
317        let now = Instant::now();
318
319        // Lazy prune: remove stale buckets that haven't been accessed within the TTL
320        if self.bucket_ttl_ms > 0 {
321            let ttl = std::time::Duration::from_millis(self.bucket_ttl_ms);
322            buckets.retain(|_, b| now.duration_since(b.last_refill) < ttl);
323        }
324
325        let rate = self.max_requests as f64 / self.window_ms as f64 * 1000.0;
326
327        let bucket = buckets.entry(client_id).or_insert(RateBucket {
328            tokens: self.max_requests as f64,
329            last_refill: now,
330        });
331
332        // Refill tokens based on elapsed time
333        let elapsed_ms = now.duration_since(bucket.last_refill).as_millis() as f64;
334        bucket.tokens = (bucket.tokens + elapsed_ms * rate / 1000.0).min(self.max_requests as f64);
335        bucket.last_refill = now;
336
337        if bucket.tokens >= 1.0 {
338            bucket.tokens -= 1.0;
339            Outcome::next(input)
340        } else {
341            let retry_after = ((1.0 - bucket.tokens) / rate * 1000.0) as u64;
342            Outcome::fault(format!(
343                "Rate limit exceeded. Retry after {}ms",
344                retry_after
345            ))
346        }
347    }
348}
349
350// ---------------------------------------------------------------------------
351// SecurityHeadersGuard
352// ---------------------------------------------------------------------------
353
354/// Security policy configuration for HTTP response headers.
355#[derive(Debug, Clone, Serialize, Deserialize)]
356pub struct SecurityPolicy {
357    pub x_frame_options: String,
358    pub x_content_type_options: String,
359    pub strict_transport_security: String,
360    pub content_security_policy: Option<String>,
361    pub x_xss_protection: String,
362    pub referrer_policy: String,
363}
364
365impl Default for SecurityPolicy {
366    fn default() -> Self {
367        Self {
368            x_frame_options: "DENY".to_string(),
369            x_content_type_options: "nosniff".to_string(),
370            strict_transport_security: "max-age=31536000; includeSubDomains".to_string(),
371            content_security_policy: None,
372            x_xss_protection: "1; mode=block".to_string(),
373            referrer_policy: "strict-origin-when-cross-origin".to_string(),
374        }
375    }
376}
377
378impl SecurityPolicy {
379    pub fn new() -> Self {
380        Self::default()
381    }
382
383    pub fn with_csp(mut self, csp: impl Into<String>) -> Self {
384        self.content_security_policy = Some(csp.into());
385        self
386    }
387}
388
389/// Security headers stored in the Bus for the HTTP layer to apply.
390#[derive(Debug, Clone, Serialize, Deserialize)]
391pub struct SecurityHeaders(pub SecurityPolicy);
392
393/// Security headers guard — injects standard security headers into the Bus.
394#[derive(Debug, Clone)]
395pub struct SecurityHeadersGuard<T> {
396    policy: SecurityPolicy,
397    _marker: PhantomData<T>,
398}
399
400impl<T> SecurityHeadersGuard<T> {
401    pub fn new(policy: SecurityPolicy) -> Self {
402        Self {
403            policy,
404            _marker: PhantomData,
405        }
406    }
407
408    /// Returns a reference to the security policy.
409    pub fn policy(&self) -> &SecurityPolicy {
410        &self.policy
411    }
412}
413
414#[async_trait]
415impl<T> Transition<T, T> for SecurityHeadersGuard<T>
416where
417    T: Send + Sync + 'static,
418{
419    type Error = String;
420    type Resources = ();
421
422    async fn run(
423        &self,
424        input: T,
425        _resources: &Self::Resources,
426        bus: &mut Bus,
427    ) -> Outcome<T, Self::Error> {
428        bus.insert(SecurityHeaders(self.policy.clone()));
429        Outcome::next(input)
430    }
431}
432
433// ---------------------------------------------------------------------------
434// IpFilterGuard
435// ---------------------------------------------------------------------------
436
437/// Bus-injectable type representing the client IP address.
438#[derive(Debug, Clone)]
439pub struct ClientIp(pub String);
440
441/// A set of trusted proxy IPs for safe X-Forwarded-For extraction.
442///
443/// When the direct connection comes from a trusted proxy, the rightmost
444/// non-trusted IP in the X-Forwarded-For chain is used as the client IP.
445/// If the direct connection is NOT from a trusted proxy, X-Forwarded-For
446/// is ignored and the direct connection IP is used instead.
447///
448/// ## Example
449///
450/// ```rust,ignore
451/// use ranvier_guard::TrustedProxies;
452///
453/// let proxies = TrustedProxies::new(["10.0.0.1", "10.0.0.2"]);
454///
455/// // Direct IP from trusted proxy, XFF has client chain
456/// let ip = proxies.extract("203.0.113.50, 10.0.0.1", "10.0.0.2");
457/// assert_eq!(ip, "203.0.113.50");
458///
459/// // Direct IP is NOT a trusted proxy → ignore XFF
460/// let ip = proxies.extract("spoofed-ip", "198.51.100.7");
461/// assert_eq!(ip, "198.51.100.7");
462/// ```
463#[derive(Debug, Clone)]
464pub struct TrustedProxies {
465    proxies: HashSet<String>,
466}
467
468impl TrustedProxies {
469    /// Create a new TrustedProxies set.
470    pub fn new(ips: impl IntoIterator<Item = impl Into<String>>) -> Self {
471        Self {
472            proxies: ips.into_iter().map(|s| s.into()).collect(),
473        }
474    }
475
476    /// Extract the real client IP from X-Forwarded-For header and direct connection IP.
477    ///
478    /// If the `direct_ip` is NOT in the trusted set, XFF is ignored (anti-spoofing).
479    /// Otherwise, walks the XFF chain right-to-left and returns the first non-trusted IP.
480    pub fn extract(&self, xff_header: &str, direct_ip: &str) -> String {
481        // If direct connection is not from a trusted proxy, don't trust XFF
482        if !self.proxies.contains(direct_ip) {
483            return direct_ip.to_string();
484        }
485
486        // Walk the XFF chain right-to-left, skip trusted proxies
487        let parts: Vec<&str> = xff_header.split(',').map(|s| s.trim()).collect();
488        for ip in parts.iter().rev() {
489            if !ip.is_empty() && !self.proxies.contains(*ip) {
490                return ip.to_string();
491            }
492        }
493
494        // Fallback: all IPs in XFF are trusted proxies, use direct IP
495        direct_ip.to_string()
496    }
497
498    /// Check if the given IP is a trusted proxy.
499    pub fn is_trusted(&self, ip: &str) -> bool {
500        self.proxies.contains(ip)
501    }
502}
503
504/// IP filter mode.
505#[derive(Debug, Clone)]
506pub enum IpFilterMode {
507    /// Only allow IPs in the set.
508    AllowList(HashSet<String>),
509    /// Block IPs in the set.
510    DenyList(HashSet<String>),
511}
512
513/// IP filter guard — allows or denies requests based on client IP.
514///
515/// Reads `ClientIp` from the Bus.
516#[derive(Debug, Clone)]
517pub struct IpFilterGuard<T> {
518    mode: IpFilterMode,
519    _marker: PhantomData<T>,
520}
521
522impl<T> IpFilterGuard<T> {
523    pub fn allow_list(ips: impl IntoIterator<Item = impl Into<String>>) -> Self {
524        Self {
525            mode: IpFilterMode::AllowList(ips.into_iter().map(|s| s.into()).collect()),
526            _marker: PhantomData,
527        }
528    }
529
530    pub fn deny_list(ips: impl IntoIterator<Item = impl Into<String>>) -> Self {
531        Self {
532            mode: IpFilterMode::DenyList(ips.into_iter().map(|s| s.into()).collect()),
533            _marker: PhantomData,
534        }
535    }
536
537    /// Clone the guard configuration as `IpFilterGuard<()>` for type-erased execution.
538    pub fn clone_as_unit(&self) -> IpFilterGuard<()> {
539        IpFilterGuard {
540            mode: self.mode.clone(),
541            _marker: PhantomData,
542        }
543    }
544}
545
546#[async_trait]
547impl<T> Transition<T, T> for IpFilterGuard<T>
548where
549    T: Send + Sync + 'static,
550{
551    type Error = String;
552    type Resources = ();
553
554    async fn run(
555        &self,
556        input: T,
557        _resources: &Self::Resources,
558        bus: &mut Bus,
559    ) -> Outcome<T, Self::Error> {
560        let client_ip = bus
561            .read::<ClientIp>()
562            .map(|ip| ip.0.clone())
563            .unwrap_or_default();
564
565        match &self.mode {
566            IpFilterMode::AllowList(allowed) => {
567                if allowed.contains(&client_ip) {
568                    Outcome::next(input)
569                } else {
570                    Outcome::fault(format!("IP '{}' not in allow list", client_ip))
571                }
572            }
573            IpFilterMode::DenyList(denied) => {
574                if denied.contains(&client_ip) {
575                    Outcome::fault(format!("IP '{}' is denied", client_ip))
576                } else {
577                    Outcome::next(input)
578                }
579            }
580        }
581    }
582}
583
584// ---------------------------------------------------------------------------
585// AccessLogGuard
586// ---------------------------------------------------------------------------
587
588/// Request metadata injected into the Bus before `AccessLogGuard` runs.
589///
590/// Typically set by an HTTP extractor or middleware before the guard.
591#[derive(Debug, Clone, Serialize, Deserialize)]
592pub struct AccessLogRequest {
593    pub method: String,
594    pub path: String,
595}
596
597/// Access log entry written to the Bus by `AccessLogGuard`.
598///
599/// Downstream nodes can read this to inspect what was logged.
600#[derive(Debug, Clone, Serialize, Deserialize)]
601pub struct AccessLogEntry {
602    pub method: String,
603    pub path: String,
604    pub timestamp_ms: u64,
605}
606
607/// HTTP access log guard — logs request metadata and writes an [`AccessLogEntry`]
608/// to the Bus.
609///
610/// This is a **pass-through** guard: it always returns `Outcome::next(input)`.
611/// It never faults — if no [`AccessLogRequest`] is in the Bus, it logs an empty
612/// entry.
613///
614/// # Example
615///
616/// ```ignore
617/// Axon::new("api")
618///     .then(AccessLogGuard::new()
619///         .redact_paths(vec!["/auth/login".into()]))
620///     .then(CorsGuard::default())
621///     .then(business_logic)
622/// ```
623#[derive(Debug, Clone)]
624pub struct AccessLogGuard<T> {
625    redact_paths: Vec<String>,
626    _marker: PhantomData<T>,
627}
628
629impl<T> AccessLogGuard<T> {
630    /// Create a new `AccessLogGuard` with default settings.
631    pub fn new() -> Self {
632        Self {
633            redact_paths: Vec::new(),
634            _marker: PhantomData,
635        }
636    }
637
638    /// Paths whose entries will have the path replaced with `"[redacted]"`.
639    ///
640    /// Use this for sensitive endpoints (e.g., login, token refresh) where
641    /// logging the path itself might leak information.
642    pub fn redact_paths(mut self, paths: Vec<String>) -> Self {
643        self.redact_paths = paths;
644        self
645    }
646
647    /// Clone the guard configuration as `AccessLogGuard<()>` for type-erased execution.
648    pub fn clone_as_unit(&self) -> AccessLogGuard<()> {
649        AccessLogGuard {
650            redact_paths: self.redact_paths.clone(),
651            _marker: PhantomData,
652        }
653    }
654}
655
656impl<T> Default for AccessLogGuard<T> {
657    fn default() -> Self {
658        Self::new()
659    }
660}
661
662#[async_trait]
663impl<T> Transition<T, T> for AccessLogGuard<T>
664where
665    T: Send + Sync + 'static,
666{
667    type Error = String;
668    type Resources = ();
669
670    async fn run(
671        &self,
672        input: T,
673        _resources: &Self::Resources,
674        bus: &mut Bus,
675    ) -> Outcome<T, Self::Error> {
676        let req = bus.read::<AccessLogRequest>().cloned();
677        let (method, raw_path) = match &req {
678            Some(r) => (r.method.clone(), r.path.clone()),
679            None => (String::new(), String::new()),
680        };
681
682        let display_path = if self.redact_paths.iter().any(|p| p == &raw_path) {
683            "[redacted]".to_string()
684        } else {
685            raw_path
686        };
687
688        let now_ms = std::time::SystemTime::now()
689            .duration_since(std::time::UNIX_EPOCH)
690            .unwrap_or_default()
691            .as_millis() as u64;
692
693        tracing::info!(method = %method, path = %display_path, "access");
694
695        bus.insert(AccessLogEntry {
696            method,
697            path: display_path,
698            timestamp_ms: now_ms,
699        });
700
701        Outcome::next(input)
702    }
703}
704
705// ---------------------------------------------------------------------------
706// CompressionGuard
707// ---------------------------------------------------------------------------
708
709/// Supported compression encodings.
710#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
711pub enum CompressionEncoding {
712    Gzip,
713    Brotli,
714    Zstd,
715    Identity,
716}
717
718impl CompressionEncoding {
719    /// HTTP `Content-Encoding` header value.
720    pub fn as_str(&self) -> &'static str {
721        match self {
722            Self::Gzip => "gzip",
723            Self::Brotli => "br",
724            Self::Zstd => "zstd",
725            Self::Identity => "identity",
726        }
727    }
728}
729
730/// Bus-injectable type representing the client's `Accept-Encoding` header.
731#[derive(Debug, Clone)]
732pub struct AcceptEncoding(pub String);
733
734/// Compression configuration written to the Bus after encoding negotiation.
735///
736/// The HTTP response layer reads this to decide whether and how to compress
737/// the response body.
738#[derive(Debug, Clone, Serialize, Deserialize)]
739pub struct CompressionConfig {
740    pub encoding: CompressionEncoding,
741    pub min_body_size: usize,
742}
743
744/// Compression guard — negotiates response encoding from `Accept-Encoding`.
745///
746/// Reads [`AcceptEncoding`] from the Bus and selects the best encoding
747/// based on the configured preference order. Writes [`CompressionConfig`]
748/// to the Bus for the HTTP layer to apply.
749///
750/// # Example
751///
752/// ```rust,ignore
753/// Ranvier::http()
754///     .guard(CompressionGuard::new().prefer_brotli())
755///     .get("/api/data", circuit)
756/// ```
757#[derive(Debug, Clone)]
758pub struct CompressionGuard<T> {
759    preferred: Vec<CompressionEncoding>,
760    min_body_size: usize,
761    _marker: PhantomData<T>,
762}
763
764impl<T> CompressionGuard<T> {
765    /// Create with default preference order: gzip > identity.
766    pub fn new() -> Self {
767        Self {
768            preferred: vec![CompressionEncoding::Gzip, CompressionEncoding::Identity],
769            min_body_size: 256,
770            _marker: PhantomData,
771        }
772    }
773
774    /// Set preference order to brotli > gzip > identity.
775    pub fn prefer_brotli(mut self) -> Self {
776        self.preferred = vec![
777            CompressionEncoding::Brotli,
778            CompressionEncoding::Gzip,
779            CompressionEncoding::Identity,
780        ];
781        self
782    }
783
784    /// Set minimum body size for compression (default: 256 bytes).
785    pub fn with_min_body_size(mut self, size: usize) -> Self {
786        self.min_body_size = size;
787        self
788    }
789
790    /// Returns the minimum body size threshold.
791    pub fn min_body_size(&self) -> usize {
792        self.min_body_size
793    }
794
795    /// Returns the preference order.
796    pub fn preferred_encodings(&self) -> &[CompressionEncoding] {
797        &self.preferred
798    }
799}
800
801impl<T> Default for CompressionGuard<T> {
802    fn default() -> Self {
803        Self::new()
804    }
805}
806
807/// Parse `Accept-Encoding` header value into a set of supported encodings.
808fn parse_accept_encoding(header: &str) -> HashSet<String> {
809    header
810        .split(',')
811        .map(|s| {
812            s.split(';')
813                .next()
814                .unwrap_or("")
815                .trim()
816                .to_lowercase()
817        })
818        .filter(|s| !s.is_empty())
819        .collect()
820}
821
822#[async_trait]
823impl<T> Transition<T, T> for CompressionGuard<T>
824where
825    T: Send + Sync + 'static,
826{
827    type Error = String;
828    type Resources = ();
829
830    async fn run(
831        &self,
832        input: T,
833        _resources: &Self::Resources,
834        bus: &mut Bus,
835    ) -> Outcome<T, Self::Error> {
836        let accepted = bus
837            .read::<AcceptEncoding>()
838            .map(|ae| parse_accept_encoding(&ae.0))
839            .unwrap_or_default();
840
841        // Negotiate: pick first preferred encoding the client accepts
842        let selected = if accepted.is_empty() || accepted.contains("*") {
843            self.preferred.first().copied().unwrap_or(CompressionEncoding::Identity)
844        } else {
845            self.preferred
846                .iter()
847                .find(|enc| accepted.contains(enc.as_str()))
848                .copied()
849                .unwrap_or(CompressionEncoding::Identity)
850        };
851
852        bus.insert(CompressionConfig {
853            encoding: selected,
854            min_body_size: self.min_body_size,
855        });
856
857        Outcome::next(input)
858    }
859}
860
861// ---------------------------------------------------------------------------
862// RequestSizeLimitGuard
863// ---------------------------------------------------------------------------
864
865/// Bus-injectable type representing the request's `Content-Length` header value.
866#[derive(Debug, Clone)]
867pub struct ContentLength(pub u64);
868
869/// Request body size limit guard — rejects requests exceeding the configured
870/// maximum `Content-Length`.
871///
872/// Reads [`ContentLength`] from the Bus. If the value exceeds the limit,
873/// returns a Fault with "413 Payload Too Large".
874///
875/// # Example
876///
877/// ```rust,ignore
878/// Ranvier::http()
879///     .guard(RequestSizeLimitGuard::max_2mb())
880///     .post("/api/upload", upload_circuit)
881/// ```
882#[derive(Debug, Clone)]
883pub struct RequestSizeLimitGuard<T> {
884    max_bytes: u64,
885    _marker: PhantomData<T>,
886}
887
888impl<T> RequestSizeLimitGuard<T> {
889    /// Create with a custom byte limit.
890    pub fn new(max_bytes: u64) -> Self {
891        Self {
892            max_bytes,
893            _marker: PhantomData,
894        }
895    }
896
897    /// 2 MB limit.
898    pub fn max_2mb() -> Self {
899        Self::new(2 * 1024 * 1024)
900    }
901
902    /// 10 MB limit.
903    pub fn max_10mb() -> Self {
904        Self::new(10 * 1024 * 1024)
905    }
906
907    /// Returns the configured maximum bytes.
908    pub fn max_bytes(&self) -> u64 {
909        self.max_bytes
910    }
911}
912
913#[async_trait]
914impl<T> Transition<T, T> for RequestSizeLimitGuard<T>
915where
916    T: Send + Sync + 'static,
917{
918    type Error = String;
919    type Resources = ();
920
921    async fn run(
922        &self,
923        input: T,
924        _resources: &Self::Resources,
925        bus: &mut Bus,
926    ) -> Outcome<T, Self::Error> {
927        if let Some(len) = bus.read::<ContentLength>() {
928            if len.0 > self.max_bytes {
929                return Outcome::fault(format!(
930                    "413 Payload Too Large: {} bytes exceeds limit of {} bytes",
931                    len.0, self.max_bytes
932                ));
933            }
934        }
935        Outcome::next(input)
936    }
937}
938
939// ---------------------------------------------------------------------------
940// RequestIdGuard
941// ---------------------------------------------------------------------------
942
943/// Bus type representing a unique request identifier.
944///
945/// Propagated from the `X-Request-Id` header or generated as UUID v4.
946/// The HTTP response layer reflects this back in the `X-Request-Id` response header.
947#[derive(Debug, Clone, Serialize, Deserialize)]
948pub struct RequestId(pub String);
949
950/// Request ID guard — ensures every request has a unique identifier.
951///
952/// If `RequestId` is already in the Bus (from a client-provided `X-Request-Id`
953/// header), it is preserved. Otherwise a UUID v4 is generated.
954///
955/// # Example
956///
957/// ```rust,ignore
958/// Ranvier::http()
959///     .guard(RequestIdGuard::new())
960///     .get("/api/data", circuit)
961/// ```
962#[derive(Debug, Clone)]
963pub struct RequestIdGuard<T> {
964    _marker: PhantomData<T>,
965}
966
967impl<T> RequestIdGuard<T> {
968    pub fn new() -> Self {
969        Self {
970            _marker: PhantomData,
971        }
972    }
973}
974
975impl<T> Default for RequestIdGuard<T> {
976    fn default() -> Self {
977        Self::new()
978    }
979}
980
981#[async_trait]
982impl<T> Transition<T, T> for RequestIdGuard<T>
983where
984    T: Send + Sync + 'static,
985{
986    type Error = String;
987    type Resources = ();
988
989    async fn run(
990        &self,
991        input: T,
992        _resources: &Self::Resources,
993        bus: &mut Bus,
994    ) -> Outcome<T, Self::Error> {
995        // Generate a UUID v4 if no RequestId was injected by the HTTP layer
996        if bus.read::<RequestId>().is_none() {
997            bus.insert(RequestId(uuid::Uuid::new_v4().to_string()));
998        }
999
1000        // Integrate with tracing: record request_id on current span
1001        if let Some(rid) = bus.read::<RequestId>() {
1002            tracing::debug!(request_id = %rid.0, "request id assigned");
1003        }
1004
1005        Outcome::next(input)
1006    }
1007}
1008
1009// ---------------------------------------------------------------------------
1010// AuthGuard
1011// ---------------------------------------------------------------------------
1012
1013/// Bus-injectable type representing the raw `Authorization` header value.
1014#[derive(Debug, Clone)]
1015pub struct AuthorizationHeader(pub String);
1016
1017/// Authentication strategy for [`AuthGuard`].
1018pub enum AuthStrategy {
1019    /// Bearer token authentication.
1020    ///
1021    /// Compares the request's `Authorization: Bearer <token>` against a set
1022    /// of valid tokens using constant-time comparison to prevent timing attacks.
1023    Bearer {
1024        tokens: Vec<String>,
1025    },
1026
1027    /// API key authentication from a custom header.
1028    ///
1029    /// Validates the value of `header_name` against a set of valid keys
1030    /// using constant-time comparison.
1031    ApiKey {
1032        header_name: String,
1033        valid_keys: Vec<String>,
1034    },
1035
1036    /// Custom authentication via a validator function.
1037    ///
1038    /// The function receives the raw `Authorization` header value and returns
1039    /// either a verified [`IamIdentity`] or an error message.
1040    Custom {
1041        validator: Arc<dyn Fn(&str) -> Result<IamIdentity, String> + Send + Sync + 'static>,
1042    },
1043}
1044
1045impl Clone for AuthStrategy {
1046    fn clone(&self) -> Self {
1047        match self {
1048            Self::Bearer { tokens } => Self::Bearer {
1049                tokens: tokens.clone(),
1050            },
1051            Self::ApiKey {
1052                header_name,
1053                valid_keys,
1054            } => Self::ApiKey {
1055                header_name: header_name.clone(),
1056                valid_keys: valid_keys.clone(),
1057            },
1058            Self::Custom { validator } => Self::Custom {
1059                validator: validator.clone(),
1060            },
1061        }
1062    }
1063}
1064
1065impl std::fmt::Debug for AuthStrategy {
1066    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1067        match self {
1068            Self::Bearer { tokens } => f
1069                .debug_struct("Bearer")
1070                .field("token_count", &tokens.len())
1071                .finish(),
1072            Self::ApiKey { header_name, valid_keys } => f
1073                .debug_struct("ApiKey")
1074                .field("header_name", header_name)
1075                .field("key_count", &valid_keys.len())
1076                .finish(),
1077            Self::Custom { .. } => f.debug_struct("Custom").finish(),
1078        }
1079    }
1080}
1081
1082/// Authentication guard — validates credentials and injects [`IamIdentity`]
1083/// into the Bus.
1084///
1085/// Supports Bearer token, API key, and custom authentication strategies.
1086/// Uses constant-time comparison (`subtle::ConstantTimeEq`) for Bearer and
1087/// API key validation to prevent timing attacks.
1088///
1089/// # Examples
1090///
1091/// ```rust,ignore
1092/// // Bearer token auth
1093/// Ranvier::http()
1094///     .guard(AuthGuard::bearer(vec!["secret-token".into()]))
1095///     .get("/api/protected", circuit)
1096///
1097/// // With role requirement
1098/// Ranvier::http()
1099///     .guard(AuthGuard::bearer(vec!["admin-token".into()])
1100///         .with_policy(IamPolicy::RequireRole("admin".into())))
1101///     .get("/api/admin", circuit)
1102/// ```
1103pub struct AuthGuard<T> {
1104    strategy: AuthStrategy,
1105    policy: IamPolicy,
1106    _marker: PhantomData<T>,
1107}
1108
1109impl<T> AuthGuard<T> {
1110    /// Create with a specific strategy and no policy enforcement.
1111    pub fn new(strategy: AuthStrategy) -> Self {
1112        Self {
1113            strategy,
1114            policy: IamPolicy::None,
1115            _marker: PhantomData,
1116        }
1117    }
1118
1119    /// Bearer token authentication.
1120    pub fn bearer(tokens: Vec<String>) -> Self {
1121        Self::new(AuthStrategy::Bearer { tokens })
1122    }
1123
1124    /// API key authentication from a custom header.
1125    pub fn api_key(header_name: impl Into<String>, valid_keys: Vec<String>) -> Self {
1126        Self::new(AuthStrategy::ApiKey {
1127            header_name: header_name.into(),
1128            valid_keys,
1129        })
1130    }
1131
1132    /// Custom validator function.
1133    pub fn custom(
1134        validator: impl Fn(&str) -> Result<IamIdentity, String> + Send + Sync + 'static,
1135    ) -> Self {
1136        Self::new(AuthStrategy::Custom {
1137            validator: Arc::new(validator),
1138        })
1139    }
1140
1141    /// Set the IAM policy to enforce after successful authentication.
1142    pub fn with_policy(mut self, policy: IamPolicy) -> Self {
1143        self.policy = policy;
1144        self
1145    }
1146
1147    /// Returns the authentication strategy.
1148    pub fn strategy(&self) -> &AuthStrategy {
1149        &self.strategy
1150    }
1151
1152    /// Returns the IAM policy.
1153    pub fn iam_policy(&self) -> &IamPolicy {
1154        &self.policy
1155    }
1156}
1157
1158impl<T> Clone for AuthGuard<T> {
1159    fn clone(&self) -> Self {
1160        Self {
1161            strategy: self.strategy.clone(),
1162            policy: self.policy.clone(),
1163            _marker: PhantomData,
1164        }
1165    }
1166}
1167
1168impl<T> std::fmt::Debug for AuthGuard<T> {
1169    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1170        f.debug_struct("AuthGuard")
1171            .field("strategy", &self.strategy)
1172            .field("policy", &self.policy)
1173            .finish()
1174    }
1175}
1176
1177/// Constant-time comparison of two byte slices.
1178fn ct_eq(a: &[u8], b: &[u8]) -> bool {
1179    a.len() == b.len() && a.ct_eq(b).into()
1180}
1181
1182#[async_trait]
1183impl<T> Transition<T, T> for AuthGuard<T>
1184where
1185    T: Send + Sync + 'static,
1186{
1187    type Error = String;
1188    type Resources = ();
1189
1190    async fn run(
1191        &self,
1192        input: T,
1193        _resources: &Self::Resources,
1194        bus: &mut Bus,
1195    ) -> Outcome<T, Self::Error> {
1196        let auth_value = bus.read::<AuthorizationHeader>().map(|h| h.0.clone());
1197
1198        let identity = match &self.strategy {
1199            AuthStrategy::Bearer { tokens } => {
1200                let Some(auth) = auth_value else {
1201                    return Outcome::fault(
1202                        "401 Unauthorized: missing Authorization header".to_string(),
1203                    );
1204                };
1205                let Some(token) = auth.strip_prefix("Bearer ") else {
1206                    return Outcome::fault(
1207                        "401 Unauthorized: expected Bearer scheme".to_string(),
1208                    );
1209                };
1210                let token = token.trim();
1211                let matched = tokens
1212                    .iter()
1213                    .any(|valid| ct_eq(token.as_bytes(), valid.as_bytes()));
1214                if !matched {
1215                    return Outcome::fault(
1216                        "401 Unauthorized: invalid bearer token".to_string(),
1217                    );
1218                }
1219                IamIdentity::new("bearer-authenticated")
1220            }
1221            AuthStrategy::ApiKey { valid_keys, .. } => {
1222                let Some(key) = auth_value else {
1223                    return Outcome::fault("401 Unauthorized: missing API key".to_string());
1224                };
1225                let matched = valid_keys
1226                    .iter()
1227                    .any(|valid| ct_eq(key.as_bytes(), valid.as_bytes()));
1228                if !matched {
1229                    return Outcome::fault("401 Unauthorized: invalid API key".to_string());
1230                }
1231                IamIdentity::new("apikey-authenticated")
1232            }
1233            AuthStrategy::Custom { validator } => {
1234                let raw = auth_value.unwrap_or_default();
1235                match validator(&raw) {
1236                    Ok(identity) => identity,
1237                    Err(msg) => {
1238                        return Outcome::fault(format!("401 Unauthorized: {}", msg));
1239                    }
1240                }
1241            }
1242        };
1243
1244        // Enforce IAM policy
1245        if let Err(e) = enforce_policy(&self.policy, &identity) {
1246            return Outcome::fault(format!("403 Forbidden: {}", e));
1247        }
1248
1249        bus.insert(identity);
1250        Outcome::next(input)
1251    }
1252}
1253
1254// ---------------------------------------------------------------------------
1255// ContentTypeGuard
1256// ---------------------------------------------------------------------------
1257
1258/// Bus-injectable type representing the request's `Content-Type` header value.
1259#[derive(Debug, Clone)]
1260pub struct RequestContentType(pub String);
1261
1262/// Content-Type validation guard — rejects requests with unsupported media types.
1263///
1264/// Reads [`RequestContentType`] from the Bus. If the content type does not
1265/// match any of the allowed types, returns a Fault with "415 Unsupported Media Type".
1266///
1267/// Useful as a per-route guard: apply to POST/PUT/PATCH endpoints while
1268/// leaving GET/DELETE endpoints unrestricted.
1269///
1270/// # Example
1271///
1272/// ```rust,ignore
1273/// Ranvier::http()
1274///     .post_with_guards("/api/data", circuit, guards![
1275///         ContentTypeGuard::json(),
1276///     ])
1277/// ```
1278#[derive(Debug, Clone)]
1279pub struct ContentTypeGuard<T> {
1280    allowed_types: Vec<String>,
1281    _marker: PhantomData<T>,
1282}
1283
1284impl<T> ContentTypeGuard<T> {
1285    /// Create with specific allowed content types.
1286    pub fn new(allowed_types: Vec<String>) -> Self {
1287        Self {
1288            allowed_types,
1289            _marker: PhantomData,
1290        }
1291    }
1292
1293    /// Accept only `application/json`.
1294    pub fn json() -> Self {
1295        Self::new(vec!["application/json".into()])
1296    }
1297
1298    /// Accept only `application/x-www-form-urlencoded`.
1299    pub fn form() -> Self {
1300        Self::new(vec!["application/x-www-form-urlencoded".into()])
1301    }
1302
1303    /// Accept specific content types.
1304    pub fn accept(types: impl IntoIterator<Item = impl Into<String>>) -> Self {
1305        Self::new(types.into_iter().map(|t| t.into()).collect())
1306    }
1307
1308    /// Returns the allowed content types.
1309    pub fn allowed_types(&self) -> &[String] {
1310        &self.allowed_types
1311    }
1312}
1313
1314#[async_trait]
1315impl<T> Transition<T, T> for ContentTypeGuard<T>
1316where
1317    T: Send + Sync + 'static,
1318{
1319    type Error = String;
1320    type Resources = ();
1321
1322    async fn run(
1323        &self,
1324        input: T,
1325        _resources: &Self::Resources,
1326        bus: &mut Bus,
1327    ) -> Outcome<T, Self::Error> {
1328        let content_type = bus.read::<RequestContentType>().map(|ct| ct.0.clone());
1329
1330        // If no Content-Type header, allow (GET/DELETE may not have body)
1331        let Some(ct) = content_type else {
1332            return Outcome::next(input);
1333        };
1334
1335        // Compare the media type portion (before any ;charset=... parameters)
1336        let media_type = ct.split(';').next().unwrap_or("").trim().to_lowercase();
1337        let matched = self
1338            .allowed_types
1339            .iter()
1340            .any(|allowed| allowed.to_lowercase() == media_type);
1341
1342        if matched {
1343            Outcome::next(input)
1344        } else {
1345            Outcome::fault(format!(
1346                "415 Unsupported Media Type: expected one of [{}], got '{}'",
1347                self.allowed_types.join(", "),
1348                media_type,
1349            ))
1350        }
1351    }
1352}
1353
1354// ---------------------------------------------------------------------------
1355// TimeoutGuard
1356// ---------------------------------------------------------------------------
1357
1358/// Deadline for the current request pipeline.
1359///
1360/// Written to the Bus by [`TimeoutGuard`]. The HTTP ingress layer reads this
1361/// to enforce the deadline by wrapping circuit execution with
1362/// `tokio::time::timeout()`.
1363///
1364/// Complements `Axon::then_with_timeout()`: TimeoutGuard sets the global
1365/// pipeline deadline, while `then_with_timeout()` adds per-node timeouts.
1366#[derive(Debug, Clone)]
1367pub struct TimeoutDeadline {
1368    created_at: std::time::Instant,
1369    timeout: std::time::Duration,
1370}
1371
1372impl TimeoutDeadline {
1373    /// Create a new deadline starting from now.
1374    pub fn new(timeout: std::time::Duration) -> Self {
1375        Self {
1376            created_at: std::time::Instant::now(),
1377            timeout,
1378        }
1379    }
1380
1381    /// Returns the remaining time until the deadline.
1382    pub fn remaining(&self) -> std::time::Duration {
1383        self.timeout.saturating_sub(self.created_at.elapsed())
1384    }
1385
1386    /// Returns true if the deadline has passed.
1387    pub fn is_expired(&self) -> bool {
1388        self.created_at.elapsed() >= self.timeout
1389    }
1390
1391    /// Returns the configured timeout duration.
1392    pub fn duration(&self) -> std::time::Duration {
1393        self.timeout
1394    }
1395}
1396
1397/// Pipeline timeout guard — sets a [`TimeoutDeadline`] in the Bus.
1398///
1399/// This is a **pass-through** guard that writes the deadline. The HTTP
1400/// ingress layer enforces it by wrapping circuit execution with
1401/// `tokio::time::timeout()`.
1402///
1403/// # Relationship with `Axon::then_with_timeout()`
1404///
1405/// - `TimeoutGuard`: outer boundary — limits total pipeline duration
1406/// - `Axon::then_with_timeout()`: inner granularity — limits a single node
1407///
1408/// # Example
1409///
1410/// ```rust,ignore
1411/// use std::time::Duration;
1412///
1413/// Ranvier::http()
1414///     .guard(TimeoutGuard::new(Duration::from_secs(30)))
1415///     .post("/api/slow", slow_circuit)
1416/// ```
1417#[derive(Debug, Clone)]
1418pub struct TimeoutGuard<T> {
1419    timeout: std::time::Duration,
1420    _marker: PhantomData<T>,
1421}
1422
1423impl<T> TimeoutGuard<T> {
1424    /// Create with a specific timeout.
1425    pub fn new(timeout: std::time::Duration) -> Self {
1426        Self {
1427            timeout,
1428            _marker: PhantomData,
1429        }
1430    }
1431
1432    /// 5-second timeout.
1433    pub fn secs_5() -> Self {
1434        Self::new(std::time::Duration::from_secs(5))
1435    }
1436
1437    /// 30-second timeout.
1438    pub fn secs_30() -> Self {
1439        Self::new(std::time::Duration::from_secs(30))
1440    }
1441
1442    /// 60-second timeout.
1443    pub fn secs_60() -> Self {
1444        Self::new(std::time::Duration::from_secs(60))
1445    }
1446
1447    /// Returns the configured timeout.
1448    pub fn timeout(&self) -> std::time::Duration {
1449        self.timeout
1450    }
1451}
1452
1453#[async_trait]
1454impl<T> Transition<T, T> for TimeoutGuard<T>
1455where
1456    T: Send + Sync + 'static,
1457{
1458    type Error = String;
1459    type Resources = ();
1460
1461    async fn run(
1462        &self,
1463        input: T,
1464        _resources: &Self::Resources,
1465        bus: &mut Bus,
1466    ) -> Outcome<T, Self::Error> {
1467        bus.insert(TimeoutDeadline::new(self.timeout));
1468        Outcome::next(input)
1469    }
1470}
1471
1472// ---------------------------------------------------------------------------
1473// IdempotencyGuard
1474// ---------------------------------------------------------------------------
1475
1476/// Bus-injectable type representing the `Idempotency-Key` header value.
1477#[derive(Debug, Clone, Serialize, Deserialize)]
1478pub struct IdempotencyKey(pub String);
1479
1480/// Cached response from a previous idempotent request.
1481///
1482/// When found in the Bus after guard execution, the HTTP ingress skips
1483/// circuit execution and returns the cached response directly with an
1484/// `Idempotency-Replayed: true` header.
1485#[derive(Debug, Clone)]
1486pub struct IdempotencyCachedResponse {
1487    pub body: Vec<u8>,
1488}
1489
1490/// TTL-based in-memory cache entry for idempotency.
1491struct IdempotencyCacheEntry {
1492    body: Vec<u8>,
1493    expires_at: std::time::Instant,
1494}
1495
1496/// Shared TTL-based in-memory cache for idempotency key tracking.
1497#[derive(Clone)]
1498pub struct IdempotencyCache {
1499    inner: Arc<std::sync::Mutex<std::collections::HashMap<String, IdempotencyCacheEntry>>>,
1500    ttl: std::time::Duration,
1501}
1502
1503impl IdempotencyCache {
1504    /// Create a new cache with the given TTL.
1505    pub fn new(ttl: std::time::Duration) -> Self {
1506        Self {
1507            inner: Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())),
1508            ttl,
1509        }
1510    }
1511
1512    /// Look up a cached response body by key. Returns `None` if not found or expired.
1513    pub fn get(&self, key: &str) -> Option<Vec<u8>> {
1514        let mut cache = self.inner.lock().ok()?;
1515        let now = std::time::Instant::now();
1516        if let Some(entry) = cache.get(key) {
1517            if entry.expires_at > now {
1518                return Some(entry.body.clone());
1519            }
1520            cache.remove(key);
1521        }
1522        None
1523    }
1524
1525    /// Insert a response body into the cache.
1526    pub fn insert(&self, key: String, body: Vec<u8>) {
1527        if let Ok(mut cache) = self.inner.lock() {
1528            let now = std::time::Instant::now();
1529            // Lazy cleanup: remove a few expired entries on insert
1530            let expired: Vec<String> = cache
1531                .iter()
1532                .filter(|(_, e)| e.expires_at <= now)
1533                .take(5)
1534                .map(|(k, _)| k.clone())
1535                .collect();
1536            for k in expired {
1537                cache.remove(&k);
1538            }
1539            cache.insert(
1540                key,
1541                IdempotencyCacheEntry {
1542                    body,
1543                    expires_at: now + self.ttl,
1544                },
1545            );
1546        }
1547    }
1548
1549    /// Returns the configured TTL.
1550    pub fn ttl(&self) -> std::time::Duration {
1551        self.ttl
1552    }
1553}
1554
1555impl std::fmt::Debug for IdempotencyCache {
1556    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1557        f.debug_struct("IdempotencyCache")
1558            .field("ttl", &self.ttl)
1559            .finish()
1560    }
1561}
1562
1563/// Idempotency guard — prevents duplicate request processing via an
1564/// in-memory TTL cache.
1565///
1566/// Reads [`IdempotencyKey`] from the Bus (extracted from the `Idempotency-Key`
1567/// HTTP header). On cache hit, writes [`IdempotencyCachedResponse`] to the
1568/// Bus to signal the ingress layer to skip circuit execution.
1569///
1570/// On cache miss, the HTTP ingress layer caches the response body after
1571/// circuit execution via `ResponseBodyTransformFn`.
1572///
1573/// # Example
1574///
1575/// ```rust,ignore
1576/// use std::time::Duration;
1577///
1578/// Ranvier::http()
1579///     .post_with_guards("/api/orders", order_circuit, guards![
1580///         ContentTypeGuard::json(),
1581///         IdempotencyGuard::new(Duration::from_secs(300)),
1582///     ])
1583/// ```
1584pub struct IdempotencyGuard<T> {
1585    cache: IdempotencyCache,
1586    _marker: PhantomData<T>,
1587}
1588
1589impl<T> IdempotencyGuard<T> {
1590    /// Create with a specific TTL for cached entries.
1591    pub fn new(ttl: std::time::Duration) -> Self {
1592        Self {
1593            cache: IdempotencyCache::new(ttl),
1594            _marker: PhantomData,
1595        }
1596    }
1597
1598    /// 5-minute TTL (default for most APIs).
1599    pub fn ttl_5min() -> Self {
1600        Self::new(std::time::Duration::from_secs(300))
1601    }
1602
1603    /// Returns the configured TTL.
1604    pub fn ttl(&self) -> std::time::Duration {
1605        self.cache.ttl()
1606    }
1607
1608    /// Returns a reference to the internal cache.
1609    pub fn cache(&self) -> &IdempotencyCache {
1610        &self.cache
1611    }
1612
1613    /// Clone the guard configuration as `IdempotencyGuard<()>` for type-erased execution.
1614    pub fn clone_as_unit(&self) -> IdempotencyGuard<()> {
1615        IdempotencyGuard {
1616            cache: self.cache.clone(),
1617            _marker: PhantomData,
1618        }
1619    }
1620}
1621
1622impl<T> Clone for IdempotencyGuard<T> {
1623    fn clone(&self) -> Self {
1624        Self {
1625            cache: self.cache.clone(),
1626            _marker: PhantomData,
1627        }
1628    }
1629}
1630
1631impl<T> std::fmt::Debug for IdempotencyGuard<T> {
1632    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1633        f.debug_struct("IdempotencyGuard")
1634            .field("ttl", &self.cache.ttl())
1635            .finish()
1636    }
1637}
1638
1639#[async_trait]
1640impl<T> Transition<T, T> for IdempotencyGuard<T>
1641where
1642    T: Send + Sync + 'static,
1643{
1644    type Error = String;
1645    type Resources = ();
1646
1647    async fn run(
1648        &self,
1649        input: T,
1650        _resources: &Self::Resources,
1651        bus: &mut Bus,
1652    ) -> Outcome<T, Self::Error> {
1653        let Some(key) = bus.read::<IdempotencyKey>().map(|k| k.0.clone()) else {
1654            return Outcome::next(input);
1655        };
1656
1657        if let Some(body) = self.cache.get(&key) {
1658            bus.insert(IdempotencyCachedResponse { body });
1659            tracing::debug!(idempotency_key = %key, "idempotency cache hit");
1660        } else {
1661            tracing::debug!(idempotency_key = %key, "idempotency cache miss");
1662        }
1663
1664        Outcome::next(input)
1665    }
1666}
1667
1668// ===========================================================================
1669// Tier 3 Guards (feature-gated: `advanced`)
1670// ===========================================================================
1671
1672#[cfg(feature = "advanced")]
1673mod advanced_guards;
1674
1675#[cfg(feature = "advanced")]
1676pub use advanced_guards::*;
1677
1678#[cfg(feature = "distributed")]
1679mod distributed;
1680
1681#[cfg(feature = "distributed")]
1682pub use distributed::DistributedRateLimitGuard;
1683
1684// ---------------------------------------------------------------------------
1685// Prelude
1686// ---------------------------------------------------------------------------
1687
1688pub mod prelude {
1689    pub use crate::{
1690        AcceptEncoding, AccessLogEntry, AccessLogGuard, AccessLogRequest, AuthGuard,
1691        AuthStrategy, AuthorizationHeader, ClientIdentity, ClientIp, CompressionConfig,
1692        CompressionEncoding, CompressionGuard, ContentLength, ContentTypeGuard, CorsConfig,
1693        CorsGuard, CorsHeaders, IdempotencyCache, IdempotencyCachedResponse, IdempotencyGuard,
1694        IdempotencyKey, IpFilterGuard, RateLimitGuard, RequestContentType, RequestId,
1695        RequestIdGuard, RequestOrigin, RequestSizeLimitGuard, SecurityHeaders,
1696        SecurityHeadersGuard, SecurityPolicy, TimeoutDeadline, TimeoutGuard,
1697    };
1698
1699    #[cfg(feature = "advanced")]
1700    pub use crate::advanced_guards::{
1701        ConditionalRequestGuard, DecompressionGuard, ETag, IfModifiedSince, IfNoneMatch,
1702        LastModified, RedirectGuard, RedirectRule, RequestBody,
1703    };
1704
1705    #[cfg(feature = "distributed")]
1706    pub use crate::distributed::DistributedRateLimitGuard;
1707}
1708
1709// ---------------------------------------------------------------------------
1710// Tests
1711// ---------------------------------------------------------------------------
1712
1713#[cfg(test)]
1714mod tests {
1715    use super::*;
1716
1717    #[tokio::test]
1718    async fn cors_guard_allows_wildcard() {
1719        let guard = CorsGuard::<String>::new(CorsConfig::default());
1720        let mut bus = Bus::new();
1721        bus.insert(RequestOrigin("https://example.com".into()));
1722        let result = guard.run("hello".into(), &(), &mut bus).await;
1723        assert!(matches!(result, Outcome::Next(_)));
1724        assert!(bus.read::<CorsHeaders>().is_some());
1725    }
1726
1727    #[tokio::test]
1728    async fn cors_guard_rejects_disallowed_origin() {
1729        let config = CorsConfig {
1730            allowed_origins: vec!["https://trusted.com".into()],
1731            ..Default::default()
1732        };
1733        let guard = CorsGuard::<String>::new(config);
1734        let mut bus = Bus::new();
1735        bus.insert(RequestOrigin("https://evil.com".into()));
1736        let result = guard.run("hello".into(), &(), &mut bus).await;
1737        assert!(matches!(result, Outcome::Fault(_)));
1738    }
1739
1740    #[tokio::test]
1741    async fn rate_limit_allows_within_budget() {
1742        let guard = RateLimitGuard::<String>::new(10, 1000);
1743        let mut bus = Bus::new();
1744        bus.insert(ClientIdentity("user1".into()));
1745        let result = guard.run("ok".into(), &(), &mut bus).await;
1746        assert!(matches!(result, Outcome::Next(_)));
1747    }
1748
1749    #[tokio::test]
1750    async fn rate_limit_exhausts_budget() {
1751        let guard = RateLimitGuard::<String>::new(2, 60000);
1752        let mut bus = Bus::new();
1753        bus.insert(ClientIdentity("user1".into()));
1754
1755        // Use up the budget
1756        let _ = guard.run("1".into(), &(), &mut bus).await;
1757        let _ = guard.run("2".into(), &(), &mut bus).await;
1758        let result = guard.run("3".into(), &(), &mut bus).await;
1759        assert!(matches!(result, Outcome::Fault(_)));
1760    }
1761
1762    #[tokio::test]
1763    async fn security_headers_injects_policy() {
1764        let guard = SecurityHeadersGuard::<String>::new(SecurityPolicy::default());
1765        let mut bus = Bus::new();
1766        let result = guard.run("ok".into(), &(), &mut bus).await;
1767        assert!(matches!(result, Outcome::Next(_)));
1768        let headers = bus.read::<SecurityHeaders>().unwrap();
1769        assert_eq!(headers.0.x_frame_options, "DENY");
1770    }
1771
1772    #[tokio::test]
1773    async fn ip_filter_allow_list_permits() {
1774        let guard = IpFilterGuard::<String>::allow_list(["10.0.0.1"]);
1775        let mut bus = Bus::new();
1776        bus.insert(ClientIp("10.0.0.1".into()));
1777        let result = guard.run("ok".into(), &(), &mut bus).await;
1778        assert!(matches!(result, Outcome::Next(_)));
1779    }
1780
1781    #[tokio::test]
1782    async fn ip_filter_allow_list_denies() {
1783        let guard = IpFilterGuard::<String>::allow_list(["10.0.0.1"]);
1784        let mut bus = Bus::new();
1785        bus.insert(ClientIp("192.168.1.1".into()));
1786        let result = guard.run("ok".into(), &(), &mut bus).await;
1787        assert!(matches!(result, Outcome::Fault(_)));
1788    }
1789
1790    #[tokio::test]
1791    async fn ip_filter_deny_list_blocks() {
1792        let guard = IpFilterGuard::<String>::deny_list(["10.0.0.1"]);
1793        let mut bus = Bus::new();
1794        bus.insert(ClientIp("10.0.0.1".into()));
1795        let result = guard.run("ok".into(), &(), &mut bus).await;
1796        assert!(matches!(result, Outcome::Fault(_)));
1797    }
1798
1799    #[tokio::test]
1800    async fn ip_filter_deny_list_allows() {
1801        let guard = IpFilterGuard::<String>::deny_list(["10.0.0.1"]);
1802        let mut bus = Bus::new();
1803        bus.insert(ClientIp("192.168.1.1".into()));
1804        let result = guard.run("ok".into(), &(), &mut bus).await;
1805        assert!(matches!(result, Outcome::Next(_)));
1806    }
1807
1808    // --- AccessLogGuard tests ---
1809
1810    #[tokio::test]
1811    async fn access_log_guard_passes_input_through() {
1812        let guard = AccessLogGuard::<String>::new();
1813        let mut bus = Bus::new();
1814        bus.insert(AccessLogRequest {
1815            method: "GET".into(),
1816            path: "/users".into(),
1817        });
1818        let result = guard.run("payload".into(), &(), &mut bus).await;
1819        assert!(matches!(result, Outcome::Next(ref v) if v == "payload"));
1820    }
1821
1822    #[tokio::test]
1823    async fn access_log_guard_writes_entry_to_bus() {
1824        let guard = AccessLogGuard::<String>::new();
1825        let mut bus = Bus::new();
1826        bus.insert(AccessLogRequest {
1827            method: "POST".into(),
1828            path: "/api/orders".into(),
1829        });
1830        let _result = guard.run("ok".into(), &(), &mut bus).await;
1831        let entry = bus.read::<AccessLogEntry>().expect("entry should be in bus");
1832        assert_eq!(entry.method, "POST");
1833        assert_eq!(entry.path, "/api/orders");
1834    }
1835
1836    #[tokio::test]
1837    async fn access_log_guard_redacts_paths() {
1838        let guard = AccessLogGuard::<String>::new().redact_paths(vec!["/auth/login".into()]);
1839        let mut bus = Bus::new();
1840        bus.insert(AccessLogRequest {
1841            method: "POST".into(),
1842            path: "/auth/login".into(),
1843        });
1844        let _result = guard.run("ok".into(), &(), &mut bus).await;
1845        let entry = bus.read::<AccessLogEntry>().expect("entry should be in bus");
1846        assert_eq!(entry.path, "[redacted]");
1847    }
1848
1849    #[tokio::test]
1850    async fn access_log_guard_works_without_request_in_bus() {
1851        let guard = AccessLogGuard::<String>::new();
1852        let mut bus = Bus::new();
1853        let result = guard.run("ok".into(), &(), &mut bus).await;
1854        assert!(matches!(result, Outcome::Next(_)));
1855        let entry = bus.read::<AccessLogEntry>().expect("entry should be in bus");
1856        assert_eq!(entry.method, "");
1857        assert_eq!(entry.path, "");
1858    }
1859
1860    #[tokio::test]
1861    async fn access_log_guard_default_works() {
1862        let guard = AccessLogGuard::<String>::default();
1863        let mut bus = Bus::new();
1864        bus.insert(AccessLogRequest {
1865            method: "DELETE".into(),
1866            path: "/api/v1/users/42".into(),
1867        });
1868        let result = guard.run("ok".into(), &(), &mut bus).await;
1869        assert!(matches!(result, Outcome::Next(_)));
1870    }
1871
1872    #[tokio::test]
1873    async fn access_log_guard_entry_has_timestamp() {
1874        let guard = AccessLogGuard::<String>::new();
1875        let mut bus = Bus::new();
1876        bus.insert(AccessLogRequest {
1877            method: "GET".into(),
1878            path: "/".into(),
1879        });
1880        let _result = guard.run("ok".into(), &(), &mut bus).await;
1881        let entry = bus.read::<AccessLogEntry>().unwrap();
1882        // Timestamp should be non-zero (milliseconds since epoch)
1883        assert!(entry.timestamp_ms > 1_700_000_000_000);
1884    }
1885
1886    #[tokio::test]
1887    async fn access_log_guard_works_with_integer_type() {
1888        let guard = AccessLogGuard::<i32>::new();
1889        let mut bus = Bus::new();
1890        bus.insert(AccessLogRequest {
1891            method: "PUT".into(),
1892            path: "/count".into(),
1893        });
1894        let result = guard.run(42, &(), &mut bus).await;
1895        assert!(matches!(result, Outcome::Next(42)));
1896    }
1897
1898    #[tokio::test]
1899    async fn access_log_guard_non_redacted_path_preserved() {
1900        let guard = AccessLogGuard::<String>::new()
1901            .redact_paths(vec!["/auth/login".into()]);
1902        let mut bus = Bus::new();
1903        bus.insert(AccessLogRequest {
1904            method: "GET".into(),
1905            path: "/api/public".into(),
1906        });
1907        let _result = guard.run("ok".into(), &(), &mut bus).await;
1908        let entry = bus.read::<AccessLogEntry>().unwrap();
1909        assert_eq!(entry.path, "/api/public");
1910    }
1911
1912    // --- CompressionGuard tests ---
1913
1914    #[tokio::test]
1915    async fn compression_guard_negotiates_gzip() {
1916        let guard = CompressionGuard::<String>::new();
1917        let mut bus = Bus::new();
1918        bus.insert(AcceptEncoding("gzip, deflate".into()));
1919        let result = guard.run("ok".into(), &(), &mut bus).await;
1920        assert!(matches!(result, Outcome::Next(_)));
1921        let config = bus.read::<CompressionConfig>().unwrap();
1922        assert_eq!(config.encoding, CompressionEncoding::Gzip);
1923    }
1924
1925    #[tokio::test]
1926    async fn compression_guard_prefer_brotli() {
1927        let guard = CompressionGuard::<String>::new().prefer_brotli();
1928        let mut bus = Bus::new();
1929        bus.insert(AcceptEncoding("gzip, br, zstd".into()));
1930        let result = guard.run("ok".into(), &(), &mut bus).await;
1931        assert!(matches!(result, Outcome::Next(_)));
1932        let config = bus.read::<CompressionConfig>().unwrap();
1933        assert_eq!(config.encoding, CompressionEncoding::Brotli);
1934    }
1935
1936    #[tokio::test]
1937    async fn compression_guard_falls_back_to_identity() {
1938        let guard = CompressionGuard::<String>::new();
1939        let mut bus = Bus::new();
1940        bus.insert(AcceptEncoding("deflate".into()));
1941        let result = guard.run("ok".into(), &(), &mut bus).await;
1942        assert!(matches!(result, Outcome::Next(_)));
1943        let config = bus.read::<CompressionConfig>().unwrap();
1944        assert_eq!(config.encoding, CompressionEncoding::Identity);
1945    }
1946
1947    #[tokio::test]
1948    async fn compression_guard_wildcard_accept() {
1949        let guard = CompressionGuard::<String>::new();
1950        let mut bus = Bus::new();
1951        bus.insert(AcceptEncoding("*".into()));
1952        let result = guard.run("ok".into(), &(), &mut bus).await;
1953        assert!(matches!(result, Outcome::Next(_)));
1954        let config = bus.read::<CompressionConfig>().unwrap();
1955        assert_eq!(config.encoding, CompressionEncoding::Gzip);
1956    }
1957
1958    #[tokio::test]
1959    async fn compression_guard_min_body_size() {
1960        let guard = CompressionGuard::<String>::new().with_min_body_size(1024);
1961        let mut bus = Bus::new();
1962        bus.insert(AcceptEncoding("gzip".into()));
1963        let _ = guard.run("ok".into(), &(), &mut bus).await;
1964        let config = bus.read::<CompressionConfig>().unwrap();
1965        assert_eq!(config.min_body_size, 1024);
1966    }
1967
1968    // --- RequestSizeLimitGuard tests ---
1969
1970    #[tokio::test]
1971    async fn size_limit_allows_within_limit() {
1972        let guard = RequestSizeLimitGuard::<String>::max_2mb();
1973        let mut bus = Bus::new();
1974        bus.insert(ContentLength(1024));
1975        let result = guard.run("ok".into(), &(), &mut bus).await;
1976        assert!(matches!(result, Outcome::Next(_)));
1977    }
1978
1979    #[tokio::test]
1980    async fn size_limit_rejects_over_limit() {
1981        let guard = RequestSizeLimitGuard::<String>::new(1000);
1982        let mut bus = Bus::new();
1983        bus.insert(ContentLength(2000));
1984        let result = guard.run("ok".into(), &(), &mut bus).await;
1985        assert!(matches!(result, Outcome::Fault(ref e) if e.contains("413")));
1986    }
1987
1988    #[tokio::test]
1989    async fn size_limit_passes_without_content_length() {
1990        let guard = RequestSizeLimitGuard::<String>::new(100);
1991        let mut bus = Bus::new();
1992        let result = guard.run("ok".into(), &(), &mut bus).await;
1993        assert!(matches!(result, Outcome::Next(_)));
1994    }
1995
1996    #[tokio::test]
1997    async fn size_limit_convenience_constructors() {
1998        let guard_2mb = RequestSizeLimitGuard::<()>::max_2mb();
1999        assert_eq!(guard_2mb.max_bytes(), 2 * 1024 * 1024);
2000
2001        let guard_10mb = RequestSizeLimitGuard::<()>::max_10mb();
2002        assert_eq!(guard_10mb.max_bytes(), 10 * 1024 * 1024);
2003    }
2004
2005    // --- RequestIdGuard tests ---
2006
2007    #[tokio::test]
2008    async fn request_id_generates_uuid() {
2009        let guard = RequestIdGuard::<String>::new();
2010        let mut bus = Bus::new();
2011        let result = guard.run("ok".into(), &(), &mut bus).await;
2012        assert!(matches!(result, Outcome::Next(_)));
2013        let rid = bus.read::<RequestId>().expect("request id should be in bus");
2014        assert_eq!(rid.0.len(), 36); // UUID v4 format
2015    }
2016
2017    #[tokio::test]
2018    async fn request_id_preserves_existing() {
2019        let guard = RequestIdGuard::<String>::new();
2020        let mut bus = Bus::new();
2021        bus.insert(RequestId("custom-id-123".into()));
2022        let _ = guard.run("ok".into(), &(), &mut bus).await;
2023        let rid = bus.read::<RequestId>().unwrap();
2024        assert_eq!(rid.0, "custom-id-123");
2025    }
2026
2027    // --- AuthGuard tests ---
2028
2029    #[tokio::test]
2030    async fn auth_bearer_success() {
2031        let guard = AuthGuard::<String>::bearer(vec!["secret-token".into()]);
2032        let mut bus = Bus::new();
2033        bus.insert(AuthorizationHeader("Bearer secret-token".into()));
2034        let result = guard.run("ok".into(), &(), &mut bus).await;
2035        assert!(matches!(result, Outcome::Next(_)));
2036        let identity = bus.read::<IamIdentity>().expect("identity should be in bus");
2037        assert_eq!(identity.subject, "bearer-authenticated");
2038    }
2039
2040    #[tokio::test]
2041    async fn auth_bearer_invalid_token() {
2042        let guard = AuthGuard::<String>::bearer(vec!["secret-token".into()]);
2043        let mut bus = Bus::new();
2044        bus.insert(AuthorizationHeader("Bearer wrong-token".into()));
2045        let result = guard.run("ok".into(), &(), &mut bus).await;
2046        assert!(matches!(result, Outcome::Fault(ref e) if e.contains("401")));
2047    }
2048
2049    #[tokio::test]
2050    async fn auth_bearer_missing_header() {
2051        let guard = AuthGuard::<String>::bearer(vec!["token".into()]);
2052        let mut bus = Bus::new();
2053        let result = guard.run("ok".into(), &(), &mut bus).await;
2054        assert!(matches!(result, Outcome::Fault(ref e) if e.contains("401")));
2055    }
2056
2057    #[tokio::test]
2058    async fn auth_apikey_success() {
2059        let guard = AuthGuard::<String>::api_key("X-Api-Key", vec!["my-api-key".into()]);
2060        let mut bus = Bus::new();
2061        bus.insert(AuthorizationHeader("my-api-key".into()));
2062        let result = guard.run("ok".into(), &(), &mut bus).await;
2063        assert!(matches!(result, Outcome::Next(_)));
2064    }
2065
2066    #[tokio::test]
2067    async fn auth_apikey_invalid() {
2068        let guard = AuthGuard::<String>::api_key("X-Api-Key", vec!["valid-key".into()]);
2069        let mut bus = Bus::new();
2070        bus.insert(AuthorizationHeader("invalid-key".into()));
2071        let result = guard.run("ok".into(), &(), &mut bus).await;
2072        assert!(matches!(result, Outcome::Fault(ref e) if e.contains("401")));
2073    }
2074
2075    #[tokio::test]
2076    async fn auth_custom_validator() {
2077        let guard = AuthGuard::<String>::custom(|token| {
2078            if token == "Bearer magic" {
2079                Ok(IamIdentity::new("custom-user").with_role("admin"))
2080            } else {
2081                Err("bad token".into())
2082            }
2083        });
2084        let mut bus = Bus::new();
2085        bus.insert(AuthorizationHeader("Bearer magic".into()));
2086        let result = guard.run("ok".into(), &(), &mut bus).await;
2087        assert!(matches!(result, Outcome::Next(_)));
2088        let id = bus.read::<IamIdentity>().unwrap();
2089        assert!(id.has_role("admin"));
2090    }
2091
2092    #[tokio::test]
2093    async fn auth_policy_enforcement_role() {
2094        let guard = AuthGuard::<String>::bearer(vec!["token".into()])
2095            .with_policy(IamPolicy::RequireRole("admin".into()));
2096        let mut bus = Bus::new();
2097        bus.insert(AuthorizationHeader("Bearer token".into()));
2098        let result = guard.run("ok".into(), &(), &mut bus).await;
2099        // Bearer-authenticated identity has no roles → policy fails
2100        assert!(matches!(result, Outcome::Fault(ref e) if e.contains("403")));
2101    }
2102
2103    #[tokio::test]
2104    async fn auth_timing_safe_comparison() {
2105        // Ensure different-length tokens don't short-circuit
2106        let guard = AuthGuard::<String>::bearer(vec!["short".into()]);
2107        let mut bus = Bus::new();
2108        bus.insert(AuthorizationHeader("Bearer a-very-long-different-token".into()));
2109        let result = guard.run("ok".into(), &(), &mut bus).await;
2110        assert!(matches!(result, Outcome::Fault(_)));
2111    }
2112
2113    // --- ContentTypeGuard tests ---
2114
2115    #[tokio::test]
2116    async fn content_type_json_match() {
2117        let guard = ContentTypeGuard::<String>::json();
2118        let mut bus = Bus::new();
2119        bus.insert(RequestContentType("application/json".into()));
2120        let result = guard.run("ok".into(), &(), &mut bus).await;
2121        assert!(matches!(result, Outcome::Next(_)));
2122    }
2123
2124    #[tokio::test]
2125    async fn content_type_json_with_charset() {
2126        let guard = ContentTypeGuard::<String>::json();
2127        let mut bus = Bus::new();
2128        bus.insert(RequestContentType("application/json; charset=utf-8".into()));
2129        let result = guard.run("ok".into(), &(), &mut bus).await;
2130        assert!(matches!(result, Outcome::Next(_)));
2131    }
2132
2133    #[tokio::test]
2134    async fn content_type_mismatch() {
2135        let guard = ContentTypeGuard::<String>::json();
2136        let mut bus = Bus::new();
2137        bus.insert(RequestContentType("text/plain".into()));
2138        let result = guard.run("ok".into(), &(), &mut bus).await;
2139        assert!(matches!(result, Outcome::Fault(ref e) if e.contains("415")));
2140    }
2141
2142    #[tokio::test]
2143    async fn content_type_no_header_allows() {
2144        let guard = ContentTypeGuard::<String>::json();
2145        let mut bus = Bus::new();
2146        let result = guard.run("ok".into(), &(), &mut bus).await;
2147        assert!(matches!(result, Outcome::Next(_)));
2148    }
2149
2150    #[tokio::test]
2151    async fn content_type_form() {
2152        let guard = ContentTypeGuard::<String>::form();
2153        let mut bus = Bus::new();
2154        bus.insert(RequestContentType("application/x-www-form-urlencoded".into()));
2155        let result = guard.run("ok".into(), &(), &mut bus).await;
2156        assert!(matches!(result, Outcome::Next(_)));
2157    }
2158
2159    #[tokio::test]
2160    async fn content_type_accept_multiple() {
2161        let guard = ContentTypeGuard::<String>::accept(["application/json", "text/xml"]);
2162        let mut bus = Bus::new();
2163        bus.insert(RequestContentType("text/xml".into()));
2164        let result = guard.run("ok".into(), &(), &mut bus).await;
2165        assert!(matches!(result, Outcome::Next(_)));
2166    }
2167
2168    // --- TimeoutGuard tests ---
2169
2170    #[tokio::test]
2171    async fn timeout_sets_deadline() {
2172        let guard = TimeoutGuard::<String>::secs_30();
2173        let mut bus = Bus::new();
2174        let result = guard.run("ok".into(), &(), &mut bus).await;
2175        assert!(matches!(result, Outcome::Next(_)));
2176        let deadline = bus.read::<TimeoutDeadline>().expect("deadline should be in bus");
2177        assert!(!deadline.is_expired());
2178        assert!(deadline.remaining().as_secs() >= 29);
2179    }
2180
2181    #[tokio::test]
2182    async fn timeout_convenience_constructors() {
2183        assert_eq!(TimeoutGuard::<()>::secs_5().timeout().as_secs(), 5);
2184        assert_eq!(TimeoutGuard::<()>::secs_30().timeout().as_secs(), 30);
2185        assert_eq!(TimeoutGuard::<()>::secs_60().timeout().as_secs(), 60);
2186    }
2187
2188    // --- IdempotencyGuard tests ---
2189
2190    #[tokio::test]
2191    async fn idempotency_no_key_passes_through() {
2192        let guard = IdempotencyGuard::<String>::ttl_5min();
2193        let mut bus = Bus::new();
2194        let result = guard.run("ok".into(), &(), &mut bus).await;
2195        assert!(matches!(result, Outcome::Next(_)));
2196        assert!(bus.read::<IdempotencyCachedResponse>().is_none());
2197    }
2198
2199    #[tokio::test]
2200    async fn idempotency_cache_miss() {
2201        let guard = IdempotencyGuard::<String>::ttl_5min();
2202        let mut bus = Bus::new();
2203        bus.insert(IdempotencyKey("key-1".into()));
2204        let result = guard.run("ok".into(), &(), &mut bus).await;
2205        assert!(matches!(result, Outcome::Next(_)));
2206        assert!(bus.read::<IdempotencyCachedResponse>().is_none());
2207    }
2208
2209    #[tokio::test]
2210    async fn idempotency_cache_hit() {
2211        let guard = IdempotencyGuard::<String>::ttl_5min();
2212        // Pre-populate cache
2213        guard.cache().insert("key-1".into(), b"cached-body".to_vec());
2214
2215        let mut bus = Bus::new();
2216        bus.insert(IdempotencyKey("key-1".into()));
2217        let result = guard.run("ok".into(), &(), &mut bus).await;
2218        assert!(matches!(result, Outcome::Next(_)));
2219        let cached = bus.read::<IdempotencyCachedResponse>().expect("cached response");
2220        assert_eq!(cached.body, b"cached-body");
2221    }
2222
2223    #[tokio::test]
2224    async fn idempotency_cache_shared_across_clones() {
2225        let guard1 = IdempotencyGuard::<String>::ttl_5min();
2226        let guard2 = guard1.clone();
2227        guard1.cache().insert("shared-key".into(), b"data".to_vec());
2228        assert!(guard2.cache().get("shared-key").is_some());
2229    }
2230
2231    #[tokio::test]
2232    async fn idempotency_expired_entry_treated_as_miss() {
2233        let guard = IdempotencyGuard::<String>::new(std::time::Duration::from_millis(1));
2234        guard.cache().insert("key-1".into(), b"old".to_vec());
2235        // Wait for expiry
2236        tokio::time::sleep(std::time::Duration::from_millis(5)).await;
2237
2238        let mut bus = Bus::new();
2239        bus.insert(IdempotencyKey("key-1".into()));
2240        let result = guard.run("ok".into(), &(), &mut bus).await;
2241        assert!(matches!(result, Outcome::Next(_)));
2242        assert!(bus.read::<IdempotencyCachedResponse>().is_none());
2243    }
2244
2245    // --- CorsGuard additional tests ---
2246
2247    #[tokio::test]
2248    async fn cors_guard_specific_origin_reflected() {
2249        let config = CorsConfig {
2250            allowed_origins: vec!["https://app.example.com".into()],
2251            ..Default::default()
2252        };
2253        let guard = CorsGuard::<String>::new(config);
2254        let mut bus = Bus::new();
2255        bus.insert(RequestOrigin("https://app.example.com".into()));
2256        let result = guard.run("ok".into(), &(), &mut bus).await;
2257        assert!(matches!(result, Outcome::Next(_)));
2258        let headers = bus.read::<CorsHeaders>().unwrap();
2259        assert_eq!(headers.access_control_allow_origin, "https://app.example.com");
2260    }
2261
2262    #[tokio::test]
2263    async fn cors_guard_no_origin_passes() {
2264        let config = CorsConfig {
2265            allowed_origins: vec!["https://trusted.com".into()],
2266            ..Default::default()
2267        };
2268        let guard = CorsGuard::<String>::new(config);
2269        let mut bus = Bus::new();
2270        // No RequestOrigin in bus — empty origin should pass
2271        let result = guard.run("ok".into(), &(), &mut bus).await;
2272        assert!(matches!(result, Outcome::Next(_)));
2273    }
2274
2275    // --- SecurityHeadersGuard additional tests ---
2276
2277    #[tokio::test]
2278    async fn security_headers_custom_csp() {
2279        let policy = SecurityPolicy::default()
2280            .with_csp("default-src 'self'; script-src 'none'");
2281        let guard = SecurityHeadersGuard::<String>::new(policy);
2282        let mut bus = Bus::new();
2283        let _ = guard.run("ok".into(), &(), &mut bus).await;
2284        let headers = bus.read::<SecurityHeaders>().unwrap();
2285        assert_eq!(
2286            headers.0.content_security_policy.as_deref(),
2287            Some("default-src 'self'; script-src 'none'")
2288        );
2289    }
2290
2291    #[tokio::test]
2292    async fn security_headers_default_no_csp() {
2293        let guard = SecurityHeadersGuard::<String>::new(SecurityPolicy::default());
2294        let mut bus = Bus::new();
2295        let _ = guard.run("ok".into(), &(), &mut bus).await;
2296        let headers = bus.read::<SecurityHeaders>().unwrap();
2297        assert!(headers.0.content_security_policy.is_none());
2298        assert_eq!(headers.0.referrer_policy, "strict-origin-when-cross-origin");
2299    }
2300
2301    // --- TimeoutGuard additional test ---
2302
2303    #[tokio::test]
2304    async fn timeout_custom_duration() {
2305        let guard = TimeoutGuard::<String>::new(std::time::Duration::from_millis(100));
2306        let mut bus = Bus::new();
2307        let _ = guard.run("ok".into(), &(), &mut bus).await;
2308        let deadline = bus.read::<TimeoutDeadline>().unwrap();
2309        assert!(!deadline.is_expired());
2310        // After sleeping past the deadline, it should be expired
2311        tokio::time::sleep(std::time::Duration::from_millis(150)).await;
2312        assert!(deadline.is_expired());
2313    }
2314
2315    // --- RateLimitGuard bucket TTL tests ---
2316
2317    #[tokio::test]
2318    async fn rate_limit_bucket_ttl_prunes_stale_buckets() {
2319        // TTL of 50ms — buckets inactive for 50ms are pruned
2320        let guard = RateLimitGuard::<String>::new(100, 60000)
2321            .with_bucket_ttl(std::time::Duration::from_millis(50));
2322
2323        // Create a bucket for "stale-user"
2324        let mut bus = Bus::new();
2325        bus.insert(ClientIdentity("stale-user".into()));
2326        let _ = guard.run("ok".into(), &(), &mut bus).await;
2327
2328        // Wait for the bucket to become stale
2329        tokio::time::sleep(std::time::Duration::from_millis(80)).await;
2330
2331        // Create a request from a different user — triggers lazy prune
2332        let mut bus2 = Bus::new();
2333        bus2.insert(ClientIdentity("fresh-user".into()));
2334        let _ = guard.run("ok".into(), &(), &mut bus2).await;
2335
2336        // Now "stale-user" bucket should have been pruned.
2337        // Verify by exhausting "stale-user" budget — if pruned, they get fresh tokens.
2338        let guard2 = RateLimitGuard::<String>::new(2, 60000)
2339            .with_bucket_ttl(std::time::Duration::from_millis(50));
2340
2341        let mut bus3 = Bus::new();
2342        bus3.insert(ClientIdentity("user-a".into()));
2343        let _ = guard2.run("1".into(), &(), &mut bus3).await;
2344        let _ = guard2.run("2".into(), &(), &mut bus3).await;
2345        // Budget exhausted
2346        let result = guard2.run("3".into(), &(), &mut bus3).await;
2347        assert!(matches!(result, Outcome::Fault(_)));
2348
2349        // Wait for TTL to expire
2350        tokio::time::sleep(std::time::Duration::from_millis(80)).await;
2351
2352        // Trigger prune with a different user
2353        let mut bus4 = Bus::new();
2354        bus4.insert(ClientIdentity("user-b".into()));
2355        let _ = guard2.run("ok".into(), &(), &mut bus4).await;
2356
2357        // user-a's bucket was pruned, they get a fresh budget
2358        let mut bus5 = Bus::new();
2359        bus5.insert(ClientIdentity("user-a".into()));
2360        let result = guard2.run("retry".into(), &(), &mut bus5).await;
2361        assert!(matches!(result, Outcome::Next(_)));
2362    }
2363
2364    #[tokio::test]
2365    async fn rate_limit_bucket_ttl_zero_disables_pruning() {
2366        // Default TTL is 0 (disabled)
2367        let guard = RateLimitGuard::<String>::new(2, 60000);
2368        assert_eq!(guard.bucket_ttl_ms(), 0);
2369
2370        let mut bus = Bus::new();
2371        bus.insert(ClientIdentity("user".into()));
2372        let _ = guard.run("1".into(), &(), &mut bus).await;
2373        let _ = guard.run("2".into(), &(), &mut bus).await;
2374
2375        // Budget exhausted — even after sleeping, no TTL prune occurs
2376        tokio::time::sleep(std::time::Duration::from_millis(30)).await;
2377
2378        let result = guard.run("3".into(), &(), &mut bus).await;
2379        assert!(matches!(result, Outcome::Fault(_)));
2380    }
2381
2382    #[tokio::test]
2383    async fn rate_limit_with_bucket_ttl_builder() {
2384        let guard = RateLimitGuard::<String>::new(10, 1000)
2385            .with_bucket_ttl(std::time::Duration::from_secs(300));
2386        assert_eq!(guard.bucket_ttl_ms(), 300_000);
2387    }
2388
2389    // --- TrustedProxies tests ---
2390
2391    #[test]
2392    fn trusted_proxies_ignores_xff_from_untrusted_direct() {
2393        let proxies = TrustedProxies::new(["10.0.0.1", "10.0.0.2"]);
2394        // Direct IP is NOT trusted — XFF should be ignored
2395        let result = proxies.extract("1.2.3.4, 10.0.0.1", "192.168.1.100");
2396        assert_eq!(result, "192.168.1.100");
2397    }
2398
2399    #[test]
2400    fn trusted_proxies_extracts_rightmost_non_trusted() {
2401        let proxies = TrustedProxies::new(["10.0.0.1", "10.0.0.2"]);
2402        // Direct IP is trusted, so walk XFF right-to-left
2403        // XFF: "203.0.113.5, 10.0.0.2" — rightmost non-trusted is 203.0.113.5
2404        let result = proxies.extract("203.0.113.5, 10.0.0.2", "10.0.0.1");
2405        assert_eq!(result, "203.0.113.5");
2406    }
2407
2408    #[test]
2409    fn trusted_proxies_multi_hop_chain() {
2410        let proxies = TrustedProxies::new(["10.0.0.1", "10.0.0.2", "10.0.0.3"]);
2411        // XFF: "real-client, 10.0.0.3, 10.0.0.2" — all hops after real-client are trusted
2412        let result = proxies.extract("8.8.8.8, 10.0.0.3, 10.0.0.2", "10.0.0.1");
2413        assert_eq!(result, "8.8.8.8");
2414    }
2415
2416    #[test]
2417    fn trusted_proxies_all_xff_trusted_falls_back_to_direct() {
2418        let proxies = TrustedProxies::new(["10.0.0.1", "10.0.0.2"]);
2419        // All IPs in XFF are trusted — fall back to direct IP
2420        let result = proxies.extract("10.0.0.2, 10.0.0.1", "10.0.0.1");
2421        assert_eq!(result, "10.0.0.1");
2422    }
2423
2424    #[test]
2425    fn trusted_proxies_empty_xff() {
2426        let proxies = TrustedProxies::new(["10.0.0.1"]);
2427        let result = proxies.extract("", "10.0.0.1");
2428        assert_eq!(result, "10.0.0.1");
2429    }
2430
2431    #[test]
2432    fn trusted_proxies_is_trusted() {
2433        let proxies = TrustedProxies::new(["10.0.0.1", "10.0.0.2"]);
2434        assert!(proxies.is_trusted("10.0.0.1"));
2435        assert!(proxies.is_trusted("10.0.0.2"));
2436        assert!(!proxies.is_trusted("192.168.1.1"));
2437    }
2438}