Skip to main content

ranvier_std/nodes/
guard.rs

1//! Guard nodes — HTTP security/policy Transition nodes replacing Tower middleware.
2//!
3//! These nodes are designed to run early in a Schematic pipeline, enforcing
4//! security policies as visible, traceable Transition steps rather than hidden
5//! middleware layers.
6//!
7//! Each guard reads context from the Bus (e.g., request headers, client IP)
8//! and either passes the input through or returns a Fault.
9
10use async_trait::async_trait;
11use ranvier_core::{bus::Bus, outcome::Outcome, transition::Transition};
12use serde::{Deserialize, Serialize};
13use std::collections::HashSet;
14use std::marker::PhantomData;
15use std::sync::Arc;
16use std::time::Instant;
17use tokio::sync::Mutex;
18
19// ---------------------------------------------------------------------------
20// CorsGuard
21// ---------------------------------------------------------------------------
22
23/// CORS guard configuration.
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct CorsConfig {
26    pub allowed_origins: Vec<String>,
27    pub allowed_methods: Vec<String>,
28    pub allowed_headers: Vec<String>,
29    pub max_age_seconds: u64,
30    pub allow_credentials: bool,
31}
32
33impl Default for CorsConfig {
34    fn default() -> Self {
35        Self {
36            allowed_origins: vec!["*".to_string()],
37            allowed_methods: vec![
38                "GET".into(),
39                "POST".into(),
40                "PUT".into(),
41                "DELETE".into(),
42                "OPTIONS".into(),
43            ],
44            allowed_headers: vec!["Content-Type".into(), "Authorization".into()],
45            max_age_seconds: 86400,
46            allow_credentials: false,
47        }
48    }
49}
50
51impl CorsConfig {
52    pub fn new() -> Self {
53        Self::default()
54    }
55
56    pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
57        self.allowed_origins.push(origin.into());
58        self
59    }
60}
61
/// Origin of the incoming request, placed on the Bus for [`CorsGuard`] to read.
///
/// Wraps the raw `Origin` request header value.
#[derive(Clone, Debug)]
pub struct RequestOrigin(pub String);
65
66/// CORS guard Transition — validates the request origin against allowed origins.
67///
68/// Reads `RequestOrigin` from the Bus. If the origin is not allowed, returns Fault.
69/// Writes CORS response headers to the Bus as `CorsHeaders`.
70#[derive(Debug, Clone)]
71pub struct CorsGuard<T> {
72    config: CorsConfig,
73    _marker: PhantomData<T>,
74}
75
76impl<T> CorsGuard<T> {
77    pub fn new(config: CorsConfig) -> Self {
78        Self {
79            config,
80            _marker: PhantomData,
81        }
82    }
83}
84
85/// CORS headers to be applied to the response.
86#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct CorsHeaders {
88    pub access_control_allow_origin: String,
89    pub access_control_allow_methods: String,
90    pub access_control_allow_headers: String,
91    pub access_control_max_age: String,
92}
93
94#[async_trait]
95impl<T> Transition<T, T> for CorsGuard<T>
96where
97    T: Send + Sync + 'static,
98{
99    type Error = String;
100    type Resources = ();
101
102    async fn run(
103        &self,
104        input: T,
105        _resources: &Self::Resources,
106        bus: &mut Bus,
107    ) -> Outcome<T, Self::Error> {
108        let origin = bus
109            .read::<RequestOrigin>()
110            .map(|o| o.0.clone())
111            .unwrap_or_default();
112
113        let allowed = self.config.allowed_origins.contains(&"*".to_string())
114            || self.config.allowed_origins.contains(&origin);
115
116        if !allowed && !origin.is_empty() {
117            return Outcome::fault(format!("CORS: origin '{}' not allowed", origin));
118        }
119
120        let allow_origin = if self.config.allowed_origins.contains(&"*".to_string()) {
121            "*".to_string()
122        } else {
123            origin
124        };
125
126        bus.insert(CorsHeaders {
127            access_control_allow_origin: allow_origin,
128            access_control_allow_methods: self.config.allowed_methods.join(", "),
129            access_control_allow_headers: self.config.allowed_headers.join(", "),
130            access_control_max_age: self.config.max_age_seconds.to_string(),
131        });
132
133        Outcome::next(input)
134    }
135}
136
137// ---------------------------------------------------------------------------
138// RateLimitGuard
139// ---------------------------------------------------------------------------
140
/// Key that [`RateLimitGuard`] uses to select a per-client token bucket.
///
/// Insert one into the Bus before the guard runs. Without it, requests share
/// the `"anonymous"` bucket.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct ClientIdentity(pub String);
144
145/// Rate limit error with retry-after information.
146#[derive(Debug, Clone, Serialize, Deserialize)]
147pub struct RateLimitError {
148    pub message: String,
149    pub retry_after_ms: u64,
150}
151
152impl std::fmt::Display for RateLimitError {
153    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154        write!(f, "{} (retry after {}ms)", self.message, self.retry_after_ms)
155    }
156}
157
/// Token-bucket bookkeeping for a single client.
struct RateBucket {
    /// Tokens currently available; may be fractional between requests.
    tokens: f64,
    /// Moment the bucket was last topped up.
    last_refill: Instant,
}
163
164/// Rate limit guard — enforces per-client request rate limits.
165///
166/// Reads `ClientIdentity` from the Bus. Uses a token-bucket algorithm.
167pub struct RateLimitGuard<T> {
168    max_requests: u64,
169    window_ms: u64,
170    buckets: Arc<Mutex<std::collections::HashMap<String, RateBucket>>>,
171    _marker: PhantomData<T>,
172}
173
174impl<T> RateLimitGuard<T> {
175    pub fn new(max_requests: u64, window_ms: u64) -> Self {
176        Self {
177            max_requests,
178            window_ms,
179            buckets: Arc::new(Mutex::new(std::collections::HashMap::new())),
180            _marker: PhantomData,
181        }
182    }
183}
184
185impl<T> Clone for RateLimitGuard<T> {
186    fn clone(&self) -> Self {
187        Self {
188            max_requests: self.max_requests,
189            window_ms: self.window_ms,
190            buckets: self.buckets.clone(),
191            _marker: PhantomData,
192        }
193    }
194}
195
196impl<T> std::fmt::Debug for RateLimitGuard<T> {
197    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
198        f.debug_struct("RateLimitGuard")
199            .field("max_requests", &self.max_requests)
200            .field("window_ms", &self.window_ms)
201            .finish()
202    }
203}
204
205#[async_trait]
206impl<T> Transition<T, T> for RateLimitGuard<T>
207where
208    T: Send + Sync + 'static,
209{
210    type Error = String;
211    type Resources = ();
212
213    async fn run(
214        &self,
215        input: T,
216        _resources: &Self::Resources,
217        bus: &mut Bus,
218    ) -> Outcome<T, Self::Error> {
219        let client_id = bus
220            .read::<ClientIdentity>()
221            .map(|c| c.0.clone())
222            .unwrap_or_else(|| "anonymous".to_string());
223
224        let mut buckets = self.buckets.lock().await;
225        let now = Instant::now();
226        let rate = self.max_requests as f64 / self.window_ms as f64 * 1000.0;
227
228        let bucket = buckets.entry(client_id).or_insert(RateBucket {
229            tokens: self.max_requests as f64,
230            last_refill: now,
231        });
232
233        // Refill tokens based on elapsed time
234        let elapsed_ms = now.duration_since(bucket.last_refill).as_millis() as f64;
235        bucket.tokens = (bucket.tokens + elapsed_ms * rate / 1000.0).min(self.max_requests as f64);
236        bucket.last_refill = now;
237
238        if bucket.tokens >= 1.0 {
239            bucket.tokens -= 1.0;
240            Outcome::next(input)
241        } else {
242            let retry_after = ((1.0 - bucket.tokens) / rate * 1000.0) as u64;
243            Outcome::fault(format!(
244                "Rate limit exceeded. Retry after {}ms",
245                retry_after
246            ))
247        }
248    }
249}
250
251// ---------------------------------------------------------------------------
252// SecurityHeadersGuard
253// ---------------------------------------------------------------------------
254
255/// Security policy configuration for HTTP response headers.
256#[derive(Debug, Clone, Serialize, Deserialize)]
257pub struct SecurityPolicy {
258    pub x_frame_options: String,
259    pub x_content_type_options: String,
260    pub strict_transport_security: String,
261    pub content_security_policy: Option<String>,
262    pub x_xss_protection: String,
263    pub referrer_policy: String,
264}
265
266impl Default for SecurityPolicy {
267    fn default() -> Self {
268        Self {
269            x_frame_options: "DENY".to_string(),
270            x_content_type_options: "nosniff".to_string(),
271            strict_transport_security: "max-age=31536000; includeSubDomains".to_string(),
272            content_security_policy: None,
273            x_xss_protection: "1; mode=block".to_string(),
274            referrer_policy: "strict-origin-when-cross-origin".to_string(),
275        }
276    }
277}
278
279impl SecurityPolicy {
280    pub fn new() -> Self {
281        Self::default()
282    }
283
284    pub fn with_csp(mut self, csp: impl Into<String>) -> Self {
285        self.content_security_policy = Some(csp.into());
286        self
287    }
288}
289
290/// Security headers stored in the Bus for the HTTP layer to apply.
291#[derive(Debug, Clone, Serialize, Deserialize)]
292pub struct SecurityHeaders(pub SecurityPolicy);
293
294/// Security headers guard — injects standard security headers into the Bus.
295#[derive(Debug, Clone)]
296pub struct SecurityHeadersGuard<T> {
297    policy: SecurityPolicy,
298    _marker: PhantomData<T>,
299}
300
301impl<T> SecurityHeadersGuard<T> {
302    pub fn new(policy: SecurityPolicy) -> Self {
303        Self {
304            policy,
305            _marker: PhantomData,
306        }
307    }
308}
309
310#[async_trait]
311impl<T> Transition<T, T> for SecurityHeadersGuard<T>
312where
313    T: Send + Sync + 'static,
314{
315    type Error = String;
316    type Resources = ();
317
318    async fn run(
319        &self,
320        input: T,
321        _resources: &Self::Resources,
322        bus: &mut Bus,
323    ) -> Outcome<T, Self::Error> {
324        bus.insert(SecurityHeaders(self.policy.clone()));
325        Outcome::next(input)
326    }
327}
328
329// ---------------------------------------------------------------------------
330// IpFilterGuard
331// ---------------------------------------------------------------------------
332
/// Client IP address placed on the Bus for [`IpFilterGuard`] to check.
///
/// Held as text, exactly as supplied by whoever inserts it.
#[derive(Clone, Debug)]
pub struct ClientIp(pub String);
336
/// How [`IpFilterGuard`] interprets its address set.
#[derive(Clone, Debug)]
pub enum IpFilterMode {
    /// Permit only the listed addresses; reject everything else.
    AllowList(HashSet<String>),
    /// Reject the listed addresses; permit everything else.
    DenyList(HashSet<String>),
}
345
346/// IP filter guard — allows or denies requests based on client IP.
347///
348/// Reads `ClientIp` from the Bus.
349#[derive(Debug, Clone)]
350pub struct IpFilterGuard<T> {
351    mode: IpFilterMode,
352    _marker: PhantomData<T>,
353}
354
355impl<T> IpFilterGuard<T> {
356    pub fn allow_list(ips: impl IntoIterator<Item = impl Into<String>>) -> Self {
357        Self {
358            mode: IpFilterMode::AllowList(ips.into_iter().map(|s| s.into()).collect()),
359            _marker: PhantomData,
360        }
361    }
362
363    pub fn deny_list(ips: impl IntoIterator<Item = impl Into<String>>) -> Self {
364        Self {
365            mode: IpFilterMode::DenyList(ips.into_iter().map(|s| s.into()).collect()),
366            _marker: PhantomData,
367        }
368    }
369}
370
371#[async_trait]
372impl<T> Transition<T, T> for IpFilterGuard<T>
373where
374    T: Send + Sync + 'static,
375{
376    type Error = String;
377    type Resources = ();
378
379    async fn run(
380        &self,
381        input: T,
382        _resources: &Self::Resources,
383        bus: &mut Bus,
384    ) -> Outcome<T, Self::Error> {
385        let client_ip = bus
386            .read::<ClientIp>()
387            .map(|ip| ip.0.clone())
388            .unwrap_or_default();
389
390        match &self.mode {
391            IpFilterMode::AllowList(allowed) => {
392                if allowed.contains(&client_ip) {
393                    Outcome::next(input)
394                } else {
395                    Outcome::fault(format!("IP '{}' not in allow list", client_ip))
396                }
397            }
398            IpFilterMode::DenyList(denied) => {
399                if denied.contains(&client_ip) {
400                    Outcome::fault(format!("IP '{}' is denied", client_ip))
401                } else {
402                    Outcome::next(input)
403                }
404            }
405        }
406    }
407}
408
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a Bus pre-populated with an [`AccessLogRequest`].
    fn bus_with_request(method: &str, path: &str) -> Bus {
        let mut bus = Bus::new();
        bus.insert(AccessLogRequest {
            method: method.to_string(),
            path: path.to_string(),
        });
        bus
    }

    /// Runs `guard` against a fresh Bus carrying only `ip` as the [`ClientIp`].
    async fn run_with_ip(guard: &IpFilterGuard<String>, ip: &str) -> Outcome<String, String> {
        let mut bus = Bus::new();
        bus.insert(ClientIp(ip.to_string()));
        guard.run("ok".to_string(), &(), &mut bus).await
    }

    // --- CorsGuard tests ---

    #[tokio::test]
    async fn cors_guard_allows_wildcard() {
        let mut bus = Bus::new();
        bus.insert(RequestOrigin("https://example.com".to_string()));
        let guard = CorsGuard::<String>::new(CorsConfig::default());

        let outcome = guard.run("hello".to_string(), &(), &mut bus).await;

        assert!(matches!(outcome, Outcome::Next(_)));
        assert!(bus.read::<CorsHeaders>().is_some());
    }

    #[tokio::test]
    async fn cors_guard_rejects_disallowed_origin() {
        let guard = CorsGuard::<String>::new(CorsConfig {
            allowed_origins: vec!["https://trusted.com".to_string()],
            ..CorsConfig::default()
        });
        let mut bus = Bus::new();
        bus.insert(RequestOrigin("https://evil.com".to_string()));

        let outcome = guard.run("hello".to_string(), &(), &mut bus).await;

        assert!(matches!(outcome, Outcome::Fault(_)));
    }

    // --- RateLimitGuard tests ---

    #[tokio::test]
    async fn rate_limit_allows_within_budget() {
        let guard = RateLimitGuard::<String>::new(10, 1000);
        let mut bus = Bus::new();
        bus.insert(ClientIdentity("user1".to_string()));

        let outcome = guard.run("ok".to_string(), &(), &mut bus).await;

        assert!(matches!(outcome, Outcome::Next(_)));
    }

    #[tokio::test]
    async fn rate_limit_exhausts_budget() {
        let guard = RateLimitGuard::<String>::new(2, 60000);
        let mut bus = Bus::new();
        bus.insert(ClientIdentity("user1".to_string()));

        // Spend both tokens of the budget...
        for payload in ["1", "2"] {
            let _ = guard.run(payload.to_string(), &(), &mut bus).await;
        }
        // ...so the third request must be refused.
        let outcome = guard.run("3".to_string(), &(), &mut bus).await;

        assert!(matches!(outcome, Outcome::Fault(_)));
    }

    // --- SecurityHeadersGuard tests ---

    #[tokio::test]
    async fn security_headers_injects_policy() {
        let guard = SecurityHeadersGuard::<String>::new(SecurityPolicy::default());
        let mut bus = Bus::new();

        let outcome = guard.run("ok".to_string(), &(), &mut bus).await;

        assert!(matches!(outcome, Outcome::Next(_)));
        let stored = bus
            .read::<SecurityHeaders>()
            .expect("security headers should be in bus");
        assert_eq!(stored.0.x_frame_options, "DENY");
    }

    // --- IpFilterGuard tests ---

    #[tokio::test]
    async fn ip_filter_allow_list_permits() {
        let guard = IpFilterGuard::<String>::allow_list(["10.0.0.1"]);
        assert!(matches!(run_with_ip(&guard, "10.0.0.1").await, Outcome::Next(_)));
    }

    #[tokio::test]
    async fn ip_filter_allow_list_denies() {
        let guard = IpFilterGuard::<String>::allow_list(["10.0.0.1"]);
        assert!(matches!(run_with_ip(&guard, "192.168.1.1").await, Outcome::Fault(_)));
    }

    #[tokio::test]
    async fn ip_filter_deny_list_blocks() {
        let guard = IpFilterGuard::<String>::deny_list(["10.0.0.1"]);
        assert!(matches!(run_with_ip(&guard, "10.0.0.1").await, Outcome::Fault(_)));
    }

    #[tokio::test]
    async fn ip_filter_deny_list_allows() {
        let guard = IpFilterGuard::<String>::deny_list(["10.0.0.1"]);
        assert!(matches!(run_with_ip(&guard, "192.168.1.1").await, Outcome::Next(_)));
    }

    // --- AccessLogGuard tests ---

    #[tokio::test]
    async fn access_log_guard_passes_input_through() {
        let guard = AccessLogGuard::<String>::new();
        let mut bus = bus_with_request("GET", "/users");

        let outcome = guard.run("payload".to_string(), &(), &mut bus).await;

        assert!(matches!(outcome, Outcome::Next(ref v) if v == "payload"));
    }

    #[tokio::test]
    async fn access_log_guard_writes_entry_to_bus() {
        let guard = AccessLogGuard::<String>::new();
        let mut bus = bus_with_request("POST", "/api/orders");

        let _ = guard.run("ok".to_string(), &(), &mut bus).await;

        let entry = bus.read::<AccessLogEntry>().expect("entry should be in bus");
        assert_eq!(entry.method, "POST");
        assert_eq!(entry.path, "/api/orders");
    }

    #[tokio::test]
    async fn access_log_guard_redacts_paths() {
        let guard = AccessLogGuard::<String>::new().redact_paths(vec!["/auth/login".into()]);
        let mut bus = bus_with_request("POST", "/auth/login");

        let _ = guard.run("ok".to_string(), &(), &mut bus).await;

        let entry = bus.read::<AccessLogEntry>().expect("entry should be in bus");
        assert_eq!(entry.path, "[redacted]");
    }

    #[tokio::test]
    async fn access_log_guard_works_without_request_in_bus() {
        let guard = AccessLogGuard::<String>::new();
        let mut bus = Bus::new();

        let outcome = guard.run("ok".to_string(), &(), &mut bus).await;

        assert!(matches!(outcome, Outcome::Next(_)));
        let entry = bus.read::<AccessLogEntry>().expect("entry should be in bus");
        assert!(entry.method.is_empty());
        assert!(entry.path.is_empty());
    }

    #[tokio::test]
    async fn access_log_guard_default_works() {
        let guard: AccessLogGuard<String> = Default::default();
        let mut bus = bus_with_request("DELETE", "/api/v1/users/42");

        let outcome = guard.run("ok".to_string(), &(), &mut bus).await;

        assert!(matches!(outcome, Outcome::Next(_)));
    }

    #[tokio::test]
    async fn access_log_guard_entry_has_timestamp() {
        let guard = AccessLogGuard::<String>::new();
        let mut bus = bus_with_request("GET", "/");

        let _ = guard.run("ok".to_string(), &(), &mut bus).await;

        let entry = bus.read::<AccessLogEntry>().expect("entry should be in bus");
        // Milliseconds since the Unix epoch: any real clock is past late 2023.
        assert!(entry.timestamp_ms > 1_700_000_000_000);
    }

    #[tokio::test]
    async fn access_log_guard_works_with_integer_type() {
        let guard = AccessLogGuard::<i32>::new();
        let mut bus = bus_with_request("PUT", "/count");

        let outcome = guard.run(42, &(), &mut bus).await;

        assert!(matches!(outcome, Outcome::Next(42)));
    }

    #[tokio::test]
    async fn access_log_guard_non_redacted_path_preserved() {
        let guard = AccessLogGuard::<String>::new().redact_paths(vec!["/auth/login".into()]);
        let mut bus = bus_with_request("GET", "/api/public");

        let _ = guard.run("ok".to_string(), &(), &mut bus).await;

        let entry = bus.read::<AccessLogEntry>().expect("entry should be in bus");
        assert_eq!(entry.path, "/api/public");
    }
}
608
609// ---------------------------------------------------------------------------
610// AccessLogGuard
611// ---------------------------------------------------------------------------
612
613/// Request metadata injected into the Bus before `AccessLogGuard` runs.
614///
615/// Typically set by an HTTP extractor or middleware before the guard.
616#[derive(Debug, Clone, Serialize, Deserialize)]
617pub struct AccessLogRequest {
618    pub method: String,
619    pub path: String,
620}
621
622/// Access log entry written to the Bus by `AccessLogGuard`.
623///
624/// Downstream nodes can read this to inspect what was logged.
625#[derive(Debug, Clone, Serialize, Deserialize)]
626pub struct AccessLogEntry {
627    pub method: String,
628    pub path: String,
629    pub timestamp_ms: u64,
630}
631
/// HTTP access log guard. It logs request metadata and writes an
/// [`AccessLogEntry`] to the Bus.
///
/// This is a **pass-through** guard: it always returns `Outcome::next(input)`
/// and never faults. If no [`AccessLogRequest`] is in the Bus, it logs an entry
/// with an empty method and path.
///
/// Paths configured via [`AccessLogGuard::redact_paths`] are matched exactly
/// and logged as `"[redacted]"`.
///
/// # Example
///
/// ```ignore
/// Axon::new("api")
///     .then(AccessLogGuard::new()
///         .redact_paths(vec!["/auth/login".into()]))
///     .then(CorsGuard::new(CorsConfig::default()))
///     .then(business_logic)
/// ```
#[derive(Debug, Clone)]
pub struct AccessLogGuard<T> {
    /// Exact request paths whose log entries are redacted.
    redact_paths: Vec<String>,
    _marker: PhantomData<T>,
}
653
654impl<T> AccessLogGuard<T> {
655    /// Create a new `AccessLogGuard` with default settings.
656    pub fn new() -> Self {
657        Self {
658            redact_paths: Vec::new(),
659            _marker: PhantomData,
660        }
661    }
662
663    /// Paths whose entries will have the path replaced with `"[redacted]"`.
664    ///
665    /// Use this for sensitive endpoints (e.g., login, token refresh) where
666    /// logging the path itself might leak information.
667    pub fn redact_paths(mut self, paths: Vec<String>) -> Self {
668        self.redact_paths = paths;
669        self
670    }
671}
672
673impl<T> Default for AccessLogGuard<T> {
674    fn default() -> Self {
675        Self::new()
676    }
677}
678
679#[async_trait]
680impl<T> Transition<T, T> for AccessLogGuard<T>
681where
682    T: Send + Sync + 'static,
683{
684    type Error = String;
685    type Resources = ();
686
687    async fn run(
688        &self,
689        input: T,
690        _resources: &Self::Resources,
691        bus: &mut Bus,
692    ) -> Outcome<T, Self::Error> {
693        let req = bus.read::<AccessLogRequest>().cloned();
694        let (method, raw_path) = match &req {
695            Some(r) => (r.method.clone(), r.path.clone()),
696            None => (String::new(), String::new()),
697        };
698
699        let display_path = if self.redact_paths.iter().any(|p| p == &raw_path) {
700            "[redacted]".to_string()
701        } else {
702            raw_path
703        };
704
705        let now_ms = std::time::SystemTime::now()
706            .duration_since(std::time::UNIX_EPOCH)
707            .unwrap_or_default()
708            .as_millis() as u64;
709
710        tracing::info!(method = %method, path = %display_path, "access");
711
712        bus.insert(AccessLogEntry {
713            method,
714            path: display_path,
715            timestamp_ms: now_ms,
716        });
717
718        Outcome::next(input)
719    }
720}