1use async_trait::async_trait;
11use ranvier_core::{bus::Bus, outcome::Outcome, transition::Transition};
12use serde::{Deserialize, Serialize};
13use std::collections::HashSet;
14use std::marker::PhantomData;
15use std::sync::Arc;
16use std::time::Instant;
17use tokio::sync::Mutex;
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct CorsConfig {
26 pub allowed_origins: Vec<String>,
27 pub allowed_methods: Vec<String>,
28 pub allowed_headers: Vec<String>,
29 pub max_age_seconds: u64,
30 pub allow_credentials: bool,
31}
32
33impl Default for CorsConfig {
34 fn default() -> Self {
35 Self {
36 allowed_origins: vec!["*".to_string()],
37 allowed_methods: vec![
38 "GET".into(),
39 "POST".into(),
40 "PUT".into(),
41 "DELETE".into(),
42 "OPTIONS".into(),
43 ],
44 allowed_headers: vec!["Content-Type".into(), "Authorization".into()],
45 max_age_seconds: 86400,
46 allow_credentials: false,
47 }
48 }
49}
50
51impl CorsConfig {
52 pub fn new() -> Self {
53 Self::default()
54 }
55
56 pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
57 self.allowed_origins.push(origin.into());
58 self
59 }
60}
61
/// Origin of the incoming request, placed on the bus for [`CorsGuard`].
///
/// If no `RequestOrigin` is on the bus, `CorsGuard` treats the origin as the
/// empty string.
#[derive(Clone, Debug)]
pub struct RequestOrigin(pub String);
65
66#[derive(Debug, Clone)]
71pub struct CorsGuard<T> {
72 config: CorsConfig,
73 _marker: PhantomData<T>,
74}
75
76impl<T> CorsGuard<T> {
77 pub fn new(config: CorsConfig) -> Self {
78 Self {
79 config,
80 _marker: PhantomData,
81 }
82 }
83}
84
85#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct CorsHeaders {
88 pub access_control_allow_origin: String,
89 pub access_control_allow_methods: String,
90 pub access_control_allow_headers: String,
91 pub access_control_max_age: String,
92}
93
94#[async_trait]
95impl<T> Transition<T, T> for CorsGuard<T>
96where
97 T: Send + Sync + 'static,
98{
99 type Error = String;
100 type Resources = ();
101
102 async fn run(
103 &self,
104 input: T,
105 _resources: &Self::Resources,
106 bus: &mut Bus,
107 ) -> Outcome<T, Self::Error> {
108 let origin = bus
109 .read::<RequestOrigin>()
110 .map(|o| o.0.clone())
111 .unwrap_or_default();
112
113 let allowed = self.config.allowed_origins.contains(&"*".to_string())
114 || self.config.allowed_origins.contains(&origin);
115
116 if !allowed && !origin.is_empty() {
117 return Outcome::fault(format!("CORS: origin '{}' not allowed", origin));
118 }
119
120 let allow_origin = if self.config.allowed_origins.contains(&"*".to_string()) {
121 "*".to_string()
122 } else {
123 origin
124 };
125
126 bus.insert(CorsHeaders {
127 access_control_allow_origin: allow_origin,
128 access_control_allow_methods: self.config.allowed_methods.join(", "),
129 access_control_allow_headers: self.config.allowed_headers.join(", "),
130 access_control_max_age: self.config.max_age_seconds.to_string(),
131 });
132
133 Outcome::next(input)
134 }
135}
136
/// Caller identity that [`RateLimitGuard`] uses to select a rate-limit bucket.
///
/// Requests without a `ClientIdentity` on the bus share the `"anonymous"` bucket.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct ClientIdentity(pub String);
144
145#[derive(Debug, Clone, Serialize, Deserialize)]
147pub struct RateLimitError {
148 pub message: String,
149 pub retry_after_ms: u64,
150}
151
152impl std::fmt::Display for RateLimitError {
153 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154 write!(f, "{} (retry after {}ms)", self.message, self.retry_after_ms)
155 }
156}
157
/// Per-client token-bucket state kept by [`RateLimitGuard`].
struct RateBucket {
    /// Tokens currently available; fractional between whole requests.
    tokens: f64,
    /// Instant at which `tokens` was last brought up to date.
    last_refill: Instant,
}
163
164pub struct RateLimitGuard<T> {
168 max_requests: u64,
169 window_ms: u64,
170 buckets: Arc<Mutex<std::collections::HashMap<String, RateBucket>>>,
171 _marker: PhantomData<T>,
172}
173
174impl<T> RateLimitGuard<T> {
175 pub fn new(max_requests: u64, window_ms: u64) -> Self {
176 Self {
177 max_requests,
178 window_ms,
179 buckets: Arc::new(Mutex::new(std::collections::HashMap::new())),
180 _marker: PhantomData,
181 }
182 }
183}
184
185impl<T> Clone for RateLimitGuard<T> {
186 fn clone(&self) -> Self {
187 Self {
188 max_requests: self.max_requests,
189 window_ms: self.window_ms,
190 buckets: self.buckets.clone(),
191 _marker: PhantomData,
192 }
193 }
194}
195
196impl<T> std::fmt::Debug for RateLimitGuard<T> {
197 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
198 f.debug_struct("RateLimitGuard")
199 .field("max_requests", &self.max_requests)
200 .field("window_ms", &self.window_ms)
201 .finish()
202 }
203}
204
205#[async_trait]
206impl<T> Transition<T, T> for RateLimitGuard<T>
207where
208 T: Send + Sync + 'static,
209{
210 type Error = String;
211 type Resources = ();
212
213 async fn run(
214 &self,
215 input: T,
216 _resources: &Self::Resources,
217 bus: &mut Bus,
218 ) -> Outcome<T, Self::Error> {
219 let client_id = bus
220 .read::<ClientIdentity>()
221 .map(|c| c.0.clone())
222 .unwrap_or_else(|| "anonymous".to_string());
223
224 let mut buckets = self.buckets.lock().await;
225 let now = Instant::now();
226 let rate = self.max_requests as f64 / self.window_ms as f64 * 1000.0;
227
228 let bucket = buckets.entry(client_id).or_insert(RateBucket {
229 tokens: self.max_requests as f64,
230 last_refill: now,
231 });
232
233 let elapsed_ms = now.duration_since(bucket.last_refill).as_millis() as f64;
235 bucket.tokens = (bucket.tokens + elapsed_ms * rate / 1000.0).min(self.max_requests as f64);
236 bucket.last_refill = now;
237
238 if bucket.tokens >= 1.0 {
239 bucket.tokens -= 1.0;
240 Outcome::next(input)
241 } else {
242 let retry_after = ((1.0 - bucket.tokens) / rate * 1000.0) as u64;
243 Outcome::fault(format!(
244 "Rate limit exceeded. Retry after {}ms",
245 retry_after
246 ))
247 }
248 }
249}
250
251#[derive(Debug, Clone, Serialize, Deserialize)]
257pub struct SecurityPolicy {
258 pub x_frame_options: String,
259 pub x_content_type_options: String,
260 pub strict_transport_security: String,
261 pub content_security_policy: Option<String>,
262 pub x_xss_protection: String,
263 pub referrer_policy: String,
264}
265
266impl Default for SecurityPolicy {
267 fn default() -> Self {
268 Self {
269 x_frame_options: "DENY".to_string(),
270 x_content_type_options: "nosniff".to_string(),
271 strict_transport_security: "max-age=31536000; includeSubDomains".to_string(),
272 content_security_policy: None,
273 x_xss_protection: "1; mode=block".to_string(),
274 referrer_policy: "strict-origin-when-cross-origin".to_string(),
275 }
276 }
277}
278
279impl SecurityPolicy {
280 pub fn new() -> Self {
281 Self::default()
282 }
283
284 pub fn with_csp(mut self, csp: impl Into<String>) -> Self {
285 self.content_security_policy = Some(csp.into());
286 self
287 }
288}
289
290#[derive(Debug, Clone, Serialize, Deserialize)]
292pub struct SecurityHeaders(pub SecurityPolicy);
293
294#[derive(Debug, Clone)]
296pub struct SecurityHeadersGuard<T> {
297 policy: SecurityPolicy,
298 _marker: PhantomData<T>,
299}
300
301impl<T> SecurityHeadersGuard<T> {
302 pub fn new(policy: SecurityPolicy) -> Self {
303 Self {
304 policy,
305 _marker: PhantomData,
306 }
307 }
308}
309
310#[async_trait]
311impl<T> Transition<T, T> for SecurityHeadersGuard<T>
312where
313 T: Send + Sync + 'static,
314{
315 type Error = String;
316 type Resources = ();
317
318 async fn run(
319 &self,
320 input: T,
321 _resources: &Self::Resources,
322 bus: &mut Bus,
323 ) -> Outcome<T, Self::Error> {
324 bus.insert(SecurityHeaders(self.policy.clone()));
325 Outcome::next(input)
326 }
327}
328
/// Client IP address, as a string, that [`IpFilterGuard`] checks.
///
/// The value is compared verbatim against the filter's entries, with no address
/// normalization. If no `ClientIp` is on the bus, the guard uses the empty string.
#[derive(Clone, Debug)]
pub struct ClientIp(pub String);
336
/// How [`IpFilterGuard`] interprets its address set.
#[derive(Clone, Debug)]
pub enum IpFilterMode {
    /// Only the listed addresses pass; everything else faults.
    AllowList(HashSet<String>),
    /// The listed addresses fault; everything else passes.
    DenyList(HashSet<String>),
}
345
346#[derive(Debug, Clone)]
350pub struct IpFilterGuard<T> {
351 mode: IpFilterMode,
352 _marker: PhantomData<T>,
353}
354
355impl<T> IpFilterGuard<T> {
356 pub fn allow_list(ips: impl IntoIterator<Item = impl Into<String>>) -> Self {
357 Self {
358 mode: IpFilterMode::AllowList(ips.into_iter().map(|s| s.into()).collect()),
359 _marker: PhantomData,
360 }
361 }
362
363 pub fn deny_list(ips: impl IntoIterator<Item = impl Into<String>>) -> Self {
364 Self {
365 mode: IpFilterMode::DenyList(ips.into_iter().map(|s| s.into()).collect()),
366 _marker: PhantomData,
367 }
368 }
369}
370
371#[async_trait]
372impl<T> Transition<T, T> for IpFilterGuard<T>
373where
374 T: Send + Sync + 'static,
375{
376 type Error = String;
377 type Resources = ();
378
379 async fn run(
380 &self,
381 input: T,
382 _resources: &Self::Resources,
383 bus: &mut Bus,
384 ) -> Outcome<T, Self::Error> {
385 let client_ip = bus
386 .read::<ClientIp>()
387 .map(|ip| ip.0.clone())
388 .unwrap_or_default();
389
390 match &self.mode {
391 IpFilterMode::AllowList(allowed) => {
392 if allowed.contains(&client_ip) {
393 Outcome::next(input)
394 } else {
395 Outcome::fault(format!("IP '{}' not in allow list", client_ip))
396 }
397 }
398 IpFilterMode::DenyList(denied) => {
399 if denied.contains(&client_ip) {
400 Outcome::fault(format!("IP '{}' is denied", client_ip))
401 } else {
402 Outcome::next(input)
403 }
404 }
405 }
406 }
407}
408
#[cfg(test)]
mod tests {
    //! Behavioural tests for the guards in this module. Every guard is driven
    //! through `Transition::run` with a fresh `Bus`.

    use super::*;

    // ---- CorsGuard ---------------------------------------------------------

    #[tokio::test]
    async fn cors_guard_allows_wildcard() {
        let guard = CorsGuard::<String>::new(CorsConfig::default());
        let mut bus = Bus::new();
        bus.insert(RequestOrigin("https://example.com".into()));
        let result = guard.run("hello".into(), &(), &mut bus).await;
        assert!(matches!(result, Outcome::Next(_)));
        let headers = bus.read::<CorsHeaders>().expect("CORS headers should be in bus");
        assert_eq!(headers.access_control_allow_origin, "*");
    }

    #[tokio::test]
    async fn cors_guard_rejects_disallowed_origin() {
        let config = CorsConfig {
            allowed_origins: vec!["https://trusted.com".into()],
            ..Default::default()
        };
        let guard = CorsGuard::<String>::new(config);
        let mut bus = Bus::new();
        bus.insert(RequestOrigin("https://evil.com".into()));
        let result = guard.run("hello".into(), &(), &mut bus).await;
        assert!(matches!(result, Outcome::Fault(_)));
        // A rejected request must not publish CORS headers.
        assert!(bus.read::<CorsHeaders>().is_none());
    }

    #[tokio::test]
    async fn cors_guard_echoes_listed_origin() {
        let config = CorsConfig {
            allowed_origins: vec!["https://trusted.com".into()],
            ..Default::default()
        };
        let guard = CorsGuard::<String>::new(config);
        let mut bus = Bus::new();
        bus.insert(RequestOrigin("https://trusted.com".into()));
        let result = guard.run("hello".into(), &(), &mut bus).await;
        assert!(matches!(result, Outcome::Next(_)));
        let headers = bus.read::<CorsHeaders>().expect("CORS headers should be in bus");
        assert_eq!(headers.access_control_allow_origin, "https://trusted.com");
        assert_eq!(headers.access_control_max_age, "86400");
    }

    #[tokio::test]
    async fn cors_guard_allows_request_without_origin() {
        let config = CorsConfig {
            allowed_origins: vec!["https://trusted.com".into()],
            ..Default::default()
        };
        let guard = CorsGuard::<String>::new(config);
        let mut bus = Bus::new();
        let result = guard.run("hello".into(), &(), &mut bus).await;
        assert!(matches!(result, Outcome::Next(_)));
        assert!(bus.read::<CorsHeaders>().is_some());
    }

    // ---- RateLimitGuard ----------------------------------------------------

    #[tokio::test]
    async fn rate_limit_allows_within_budget() {
        let guard = RateLimitGuard::<String>::new(10, 1000);
        let mut bus = Bus::new();
        bus.insert(ClientIdentity("user1".into()));
        let result = guard.run("ok".into(), &(), &mut bus).await;
        assert!(matches!(result, Outcome::Next(_)));
    }

    #[tokio::test]
    async fn rate_limit_exhausts_budget() {
        let guard = RateLimitGuard::<String>::new(2, 60000);
        let mut bus = Bus::new();
        bus.insert(ClientIdentity("user1".into()));

        // The first two requests fit in the budget; the third must be rejected.
        let first = guard.run("1".into(), &(), &mut bus).await;
        let second = guard.run("2".into(), &(), &mut bus).await;
        let third = guard.run("3".into(), &(), &mut bus).await;
        assert!(matches!(first, Outcome::Next(_)));
        assert!(matches!(second, Outcome::Next(_)));
        assert!(matches!(third, Outcome::Fault(_)));
    }

    #[tokio::test]
    async fn rate_limit_buckets_are_per_client() {
        let guard = RateLimitGuard::<String>::new(1, 60000);
        let mut alice = Bus::new();
        alice.insert(ClientIdentity("alice".into()));
        let mut bob = Bus::new();
        bob.insert(ClientIdentity("bob".into()));

        let a1 = guard.run("a1".into(), &(), &mut alice).await;
        let a2 = guard.run("a2".into(), &(), &mut alice).await;
        // Alice exhausting her budget must not affect Bob.
        let b1 = guard.run("b1".into(), &(), &mut bob).await;
        assert!(matches!(a1, Outcome::Next(_)));
        assert!(matches!(a2, Outcome::Fault(_)));
        assert!(matches!(b1, Outcome::Next(_)));
    }

    #[tokio::test]
    async fn rate_limit_anonymous_clients_share_a_bucket() {
        let guard = RateLimitGuard::<String>::new(1, 60000);
        let mut first_bus = Bus::new();
        let mut second_bus = Bus::new();
        let first = guard.run("1".into(), &(), &mut first_bus).await;
        let second = guard.run("2".into(), &(), &mut second_bus).await;
        assert!(matches!(first, Outcome::Next(_)));
        assert!(matches!(second, Outcome::Fault(_)));
    }

    // ---- SecurityHeadersGuard ----------------------------------------------

    #[tokio::test]
    async fn security_headers_injects_policy() {
        let guard = SecurityHeadersGuard::<String>::new(SecurityPolicy::default());
        let mut bus = Bus::new();
        let result = guard.run("ok".into(), &(), &mut bus).await;
        assert!(matches!(result, Outcome::Next(_)));
        let headers = bus.read::<SecurityHeaders>().unwrap();
        assert_eq!(headers.0.x_frame_options, "DENY");
        assert_eq!(headers.0.x_content_type_options, "nosniff");
    }

    // ---- IpFilterGuard -----------------------------------------------------

    #[tokio::test]
    async fn ip_filter_allow_list_permits() {
        let guard = IpFilterGuard::<String>::allow_list(["10.0.0.1"]);
        let mut bus = Bus::new();
        bus.insert(ClientIp("10.0.0.1".into()));
        let result = guard.run("ok".into(), &(), &mut bus).await;
        assert!(matches!(result, Outcome::Next(_)));
    }

    #[tokio::test]
    async fn ip_filter_allow_list_denies() {
        let guard = IpFilterGuard::<String>::allow_list(["10.0.0.1"]);
        let mut bus = Bus::new();
        bus.insert(ClientIp("192.168.1.1".into()));
        let result = guard.run("ok".into(), &(), &mut bus).await;
        assert!(matches!(result, Outcome::Fault(_)));
    }

    #[tokio::test]
    async fn ip_filter_allow_list_rejects_missing_ip() {
        let guard = IpFilterGuard::<String>::allow_list(["10.0.0.1"]);
        let mut bus = Bus::new();
        let result = guard.run("ok".into(), &(), &mut bus).await;
        assert!(matches!(result, Outcome::Fault(_)));
    }

    #[tokio::test]
    async fn ip_filter_deny_list_blocks() {
        let guard = IpFilterGuard::<String>::deny_list(["10.0.0.1"]);
        let mut bus = Bus::new();
        bus.insert(ClientIp("10.0.0.1".into()));
        let result = guard.run("ok".into(), &(), &mut bus).await;
        assert!(matches!(result, Outcome::Fault(_)));
    }

    #[tokio::test]
    async fn ip_filter_deny_list_allows() {
        let guard = IpFilterGuard::<String>::deny_list(["10.0.0.1"]);
        let mut bus = Bus::new();
        bus.insert(ClientIp("192.168.1.1".into()));
        let result = guard.run("ok".into(), &(), &mut bus).await;
        assert!(matches!(result, Outcome::Next(_)));
    }

    #[tokio::test]
    async fn ip_filter_deny_list_allows_missing_ip() {
        let guard = IpFilterGuard::<String>::deny_list(["10.0.0.1"]);
        let mut bus = Bus::new();
        let result = guard.run("ok".into(), &(), &mut bus).await;
        assert!(matches!(result, Outcome::Next(_)));
    }

    // ---- AccessLogGuard ----------------------------------------------------

    #[tokio::test]
    async fn access_log_guard_passes_input_through() {
        let guard = AccessLogGuard::<String>::new();
        let mut bus = Bus::new();
        bus.insert(AccessLogRequest {
            method: "GET".into(),
            path: "/users".into(),
        });
        let result = guard.run("payload".into(), &(), &mut bus).await;
        assert!(matches!(result, Outcome::Next(ref v) if v == "payload"));
    }

    #[tokio::test]
    async fn access_log_guard_writes_entry_to_bus() {
        let guard = AccessLogGuard::<String>::new();
        let mut bus = Bus::new();
        bus.insert(AccessLogRequest {
            method: "POST".into(),
            path: "/api/orders".into(),
        });
        let _result = guard.run("ok".into(), &(), &mut bus).await;
        let entry = bus.read::<AccessLogEntry>().expect("entry should be in bus");
        assert_eq!(entry.method, "POST");
        assert_eq!(entry.path, "/api/orders");
    }

    #[tokio::test]
    async fn access_log_guard_redacts_paths() {
        let guard = AccessLogGuard::<String>::new().redact_paths(vec!["/auth/login".into()]);
        let mut bus = Bus::new();
        bus.insert(AccessLogRequest {
            method: "POST".into(),
            path: "/auth/login".into(),
        });
        let _result = guard.run("ok".into(), &(), &mut bus).await;
        let entry = bus.read::<AccessLogEntry>().expect("entry should be in bus");
        assert_eq!(entry.path, "[redacted]");
    }

    #[tokio::test]
    async fn access_log_guard_works_without_request_in_bus() {
        let guard = AccessLogGuard::<String>::new();
        let mut bus = Bus::new();
        let result = guard.run("ok".into(), &(), &mut bus).await;
        assert!(matches!(result, Outcome::Next(_)));
        let entry = bus.read::<AccessLogEntry>().expect("entry should be in bus");
        assert_eq!(entry.method, "");
        assert_eq!(entry.path, "");
    }

    #[tokio::test]
    async fn access_log_guard_default_works() {
        let guard = AccessLogGuard::<String>::default();
        let mut bus = Bus::new();
        bus.insert(AccessLogRequest {
            method: "DELETE".into(),
            path: "/api/v1/users/42".into(),
        });
        let result = guard.run("ok".into(), &(), &mut bus).await;
        assert!(matches!(result, Outcome::Next(_)));
    }

    #[tokio::test]
    async fn access_log_guard_entry_has_timestamp() {
        let guard = AccessLogGuard::<String>::new();
        let mut bus = Bus::new();
        bus.insert(AccessLogRequest {
            method: "GET".into(),
            path: "/".into(),
        });
        let _result = guard.run("ok".into(), &(), &mut bus).await;
        let entry = bus.read::<AccessLogEntry>().unwrap();
        // 1_700_000_000_000 ms is November 2023; any real clock is past it.
        assert!(entry.timestamp_ms > 1_700_000_000_000);
    }

    #[tokio::test]
    async fn access_log_guard_works_with_integer_type() {
        let guard = AccessLogGuard::<i32>::new();
        let mut bus = Bus::new();
        bus.insert(AccessLogRequest {
            method: "PUT".into(),
            path: "/count".into(),
        });
        let result = guard.run(42, &(), &mut bus).await;
        assert!(matches!(result, Outcome::Next(42)));
    }

    #[tokio::test]
    async fn access_log_guard_non_redacted_path_preserved() {
        let guard = AccessLogGuard::<String>::new().redact_paths(vec!["/auth/login".into()]);
        let mut bus = Bus::new();
        bus.insert(AccessLogRequest {
            method: "GET".into(),
            path: "/api/public".into(),
        });
        let _result = guard.run("ok".into(), &(), &mut bus).await;
        let entry = bus.read::<AccessLogEntry>().unwrap();
        assert_eq!(entry.path, "/api/public");
    }
}
608
609#[derive(Debug, Clone, Serialize, Deserialize)]
617pub struct AccessLogRequest {
618 pub method: String,
619 pub path: String,
620}
621
622#[derive(Debug, Clone, Serialize, Deserialize)]
626pub struct AccessLogEntry {
627 pub method: String,
628 pub path: String,
629 pub timestamp_ms: u64,
630}
631
/// Transition that emits a `tracing` access-log event and stores an
/// [`AccessLogEntry`] on the bus, passing the payload `T` through unchanged.
///
/// `T` is only tracked through `PhantomData`. `Clone` and `Debug` are therefore
/// implemented by hand, matching `RateLimitGuard`, so they do not require
/// `T: Clone` / `T: Debug` as `#[derive]` would.
pub struct AccessLogGuard<T> {
    /// Exact paths whose logged value is replaced with `"[redacted]"`.
    redact_paths: Vec<String>,
    _marker: PhantomData<T>,
}

impl<T> Clone for AccessLogGuard<T> {
    fn clone(&self) -> Self {
        Self {
            redact_paths: self.redact_paths.clone(),
            _marker: PhantomData,
        }
    }
}

impl<T> std::fmt::Debug for AccessLogGuard<T> {
    // Same output as the former derived impl.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AccessLogGuard")
            .field("redact_paths", &self.redact_paths)
            .field("_marker", &self._marker)
            .finish()
    }
}
653
654impl<T> AccessLogGuard<T> {
655 pub fn new() -> Self {
657 Self {
658 redact_paths: Vec::new(),
659 _marker: PhantomData,
660 }
661 }
662
663 pub fn redact_paths(mut self, paths: Vec<String>) -> Self {
668 self.redact_paths = paths;
669 self
670 }
671}
672
673impl<T> Default for AccessLogGuard<T> {
674 fn default() -> Self {
675 Self::new()
676 }
677}
678
679#[async_trait]
680impl<T> Transition<T, T> for AccessLogGuard<T>
681where
682 T: Send + Sync + 'static,
683{
684 type Error = String;
685 type Resources = ();
686
687 async fn run(
688 &self,
689 input: T,
690 _resources: &Self::Resources,
691 bus: &mut Bus,
692 ) -> Outcome<T, Self::Error> {
693 let req = bus.read::<AccessLogRequest>().cloned();
694 let (method, raw_path) = match &req {
695 Some(r) => (r.method.clone(), r.path.clone()),
696 None => (String::new(), String::new()),
697 };
698
699 let display_path = if self.redact_paths.iter().any(|p| p == &raw_path) {
700 "[redacted]".to_string()
701 } else {
702 raw_path
703 };
704
705 let now_ms = std::time::SystemTime::now()
706 .duration_since(std::time::UNIX_EPOCH)
707 .unwrap_or_default()
708 .as_millis() as u64;
709
710 tracing::info!(method = %method, path = %display_path, "access");
711
712 bus.insert(AccessLogEntry {
713 method,
714 path: display_path,
715 timestamp_ms: now_ms,
716 });
717
718 Outcome::next(input)
719 }
720}