1use async_trait::async_trait;
33use ranvier_core::iam::{enforce_policy, IamIdentity, IamPolicy};
34use ranvier_core::{bus::Bus, outcome::Outcome, transition::Transition};
35use serde::{Deserialize, Serialize};
36use std::collections::HashSet;
37use std::marker::PhantomData;
38use std::sync::Arc;
39use std::time::Instant;
40use subtle::ConstantTimeEq;
41use tokio::sync::Mutex;
42
/// Cross-origin resource sharing settings consumed by [`CorsGuard`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorsConfig {
    /// Origins the guard accepts; the literal `"*"` is treated as a wildcard.
    pub allowed_origins: Vec<String>,
    /// Methods that are joined with `", "` into the allow-methods header value.
    pub allowed_methods: Vec<String>,
    /// Header names that are joined with `", "` into the allow-headers header value.
    pub allowed_headers: Vec<String>,
    /// Number rendered as the max-age header value.
    pub max_age_seconds: u64,
    /// Credentials flag.
    /// NOTE(review): `CorsGuard::run` in this file never reads it — confirm where it is consumed.
    pub allow_credentials: bool,
}
56
57impl Default for CorsConfig {
58 fn default() -> Self {
59 Self {
60 allowed_origins: vec!["*".to_string()],
61 allowed_methods: vec![
62 "GET".into(),
63 "POST".into(),
64 "PUT".into(),
65 "DELETE".into(),
66 "OPTIONS".into(),
67 ],
68 allowed_headers: vec!["Content-Type".into(), "Authorization".into()],
69 max_age_seconds: 86400,
70 allow_credentials: false,
71 }
72 }
73}
74
75impl CorsConfig {
76 pub fn new() -> Self {
77 Self::default()
78 }
79
80 pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
81 self.allowed_origins.push(origin.into());
82 self
83 }
84}
85
/// Bus value carrying the request's origin string; read by `CorsGuard::run`.
#[derive(Debug, Clone)]
pub struct RequestOrigin(pub String);
89
/// Guard that checks the request origin against a [`CorsConfig`].
#[derive(Debug, Clone)]
pub struct CorsGuard<T> {
    /// Settings applied on every run.
    config: CorsConfig,
    /// Ties the guard to the payload type `T` without storing one.
    _marker: PhantomData<T>,
}
99
100impl<T> CorsGuard<T> {
101 pub fn new(config: CorsConfig) -> Self {
102 Self {
103 config,
104 _marker: PhantomData,
105 }
106 }
107
108 pub fn permissive() -> Self {
120 tracing::warn!("CorsGuard::permissive() — all origins allowed; do not use in production");
121 Self {
122 config: CorsConfig {
123 allowed_origins: vec!["*".to_string()],
124 allowed_methods: vec![
125 "GET".into(), "POST".into(), "PUT".into(), "DELETE".into(),
126 "PATCH".into(), "OPTIONS".into(), "HEAD".into(),
127 ],
128 allowed_headers: vec![
129 "Content-Type".into(), "Authorization".into(), "Accept".into(),
130 "Origin".into(), "X-Requested-With".into(),
131 ],
132 max_age_seconds: 86400,
133 allow_credentials: false,
134 },
135 _marker: PhantomData,
136 }
137 }
138
139 pub fn cors_config(&self) -> &CorsConfig {
141 &self.config
142 }
143}
144
/// CORS response header values that `CorsGuard::run` places on the bus.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorsHeaders {
    /// `"*"` when the config contains the wildcard, otherwise the request origin.
    pub access_control_allow_origin: String,
    /// Configured methods joined with `", "`.
    pub access_control_allow_methods: String,
    /// Configured header names joined with `", "`.
    pub access_control_allow_headers: String,
    /// Configured max-age rendered as a decimal string.
    pub access_control_max_age: String,
}
153
154#[async_trait]
155impl<T> Transition<T, T> for CorsGuard<T>
156where
157 T: Send + Sync + 'static,
158{
159 type Error = String;
160 type Resources = ();
161
162 async fn run(
163 &self,
164 input: T,
165 _resources: &Self::Resources,
166 bus: &mut Bus,
167 ) -> Outcome<T, Self::Error> {
168 let origin = bus
169 .read::<RequestOrigin>()
170 .map(|o| o.0.clone())
171 .unwrap_or_default();
172
173 let allowed = self.config.allowed_origins.contains(&"*".to_string())
174 || self.config.allowed_origins.contains(&origin);
175
176 if !allowed && !origin.is_empty() {
177 return Outcome::fault(format!("CORS: origin '{}' not allowed", origin));
178 }
179
180 let allow_origin = if self.config.allowed_origins.contains(&"*".to_string()) {
181 "*".to_string()
182 } else {
183 origin
184 };
185
186 bus.insert(CorsHeaders {
187 access_control_allow_origin: allow_origin,
188 access_control_allow_methods: self.config.allowed_methods.join(", "),
189 access_control_allow_headers: self.config.allowed_headers.join(", "),
190 access_control_max_age: self.config.max_age_seconds.to_string(),
191 });
192
193 Outcome::next(input)
194 }
195}
196
/// Bus value naming the caller; `RateLimitGuard::run` uses it as the bucket key.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ClientIdentity(pub String);
204
/// Structured description of a rate-limit rejection.
/// NOTE(review): the guards visible in this file report faults as `String`
/// and do not construct this type — confirm where it is used.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitError {
    /// Human-readable explanation.
    pub message: String,
    /// Suggested wait before retrying, in milliseconds.
    pub retry_after_ms: u64,
}
211
212impl std::fmt::Display for RateLimitError {
213 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214 write!(f, "{} (retry after {}ms)", self.message, self.retry_after_ms)
215 }
216}
217
/// Per-client token bucket state kept by `RateLimitGuard`.
struct RateBucket {
    /// Tokens currently available; fractional because refill is continuous.
    tokens: f64,
    /// Instant at which `tokens` was last brought up to date.
    last_refill: Instant,
}
223
/// Token-bucket rate limiter keyed by client identity.
pub struct RateLimitGuard<T> {
    /// Bucket capacity, and the number of tokens refilled per window.
    max_requests: u64,
    /// Length of the refill window in milliseconds.
    window_ms: u64,
    /// Buckets by client id; the `Arc` is shared by clones of the guard.
    buckets: Arc<Mutex<std::collections::HashMap<String, RateBucket>>>,
    /// Idle time after which a bucket is dropped; `0` disables eviction.
    bucket_ttl_ms: u64,
    /// Ties the guard to the payload type `T` without storing one.
    _marker: PhantomData<T>,
}
238
239impl<T> RateLimitGuard<T> {
240 pub fn new(max_requests: u64, window_ms: u64) -> Self {
241 Self {
242 max_requests,
243 window_ms,
244 buckets: Arc::new(Mutex::new(std::collections::HashMap::new())),
245 bucket_ttl_ms: 0,
246 _marker: PhantomData,
247 }
248 }
249
250 pub fn with_bucket_ttl(mut self, ttl: std::time::Duration) -> Self {
255 self.bucket_ttl_ms = ttl.as_millis() as u64;
256 self
257 }
258
259 pub fn max_requests(&self) -> u64 {
261 self.max_requests
262 }
263
264 pub fn window_ms(&self) -> u64 {
266 self.window_ms
267 }
268
269 pub fn bucket_ttl_ms(&self) -> u64 {
271 self.bucket_ttl_ms
272 }
273}
274
275impl<T> Clone for RateLimitGuard<T> {
276 fn clone(&self) -> Self {
277 Self {
278 max_requests: self.max_requests,
279 window_ms: self.window_ms,
280 buckets: self.buckets.clone(),
281 bucket_ttl_ms: self.bucket_ttl_ms,
282 _marker: PhantomData,
283 }
284 }
285}
286
287impl<T> std::fmt::Debug for RateLimitGuard<T> {
288 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
289 f.debug_struct("RateLimitGuard")
290 .field("max_requests", &self.max_requests)
291 .field("window_ms", &self.window_ms)
292 .field("bucket_ttl_ms", &self.bucket_ttl_ms)
293 .finish()
294 }
295}
296
297#[async_trait]
298impl<T> Transition<T, T> for RateLimitGuard<T>
299where
300 T: Send + Sync + 'static,
301{
302 type Error = String;
303 type Resources = ();
304
305 async fn run(
306 &self,
307 input: T,
308 _resources: &Self::Resources,
309 bus: &mut Bus,
310 ) -> Outcome<T, Self::Error> {
311 let client_id = bus
312 .read::<ClientIdentity>()
313 .map(|c| c.0.clone())
314 .unwrap_or_else(|| "anonymous".to_string());
315
316 let mut buckets = self.buckets.lock().await;
317 let now = Instant::now();
318
319 if self.bucket_ttl_ms > 0 {
321 let ttl = std::time::Duration::from_millis(self.bucket_ttl_ms);
322 buckets.retain(|_, b| now.duration_since(b.last_refill) < ttl);
323 }
324
325 let rate = self.max_requests as f64 / self.window_ms as f64 * 1000.0;
326
327 let bucket = buckets.entry(client_id).or_insert(RateBucket {
328 tokens: self.max_requests as f64,
329 last_refill: now,
330 });
331
332 let elapsed_ms = now.duration_since(bucket.last_refill).as_millis() as f64;
334 bucket.tokens = (bucket.tokens + elapsed_ms * rate / 1000.0).min(self.max_requests as f64);
335 bucket.last_refill = now;
336
337 if bucket.tokens >= 1.0 {
338 bucket.tokens -= 1.0;
339 Outcome::next(input)
340 } else {
341 let retry_after = ((1.0 - bucket.tokens) / rate * 1000.0) as u64;
342 Outcome::fault(format!(
343 "Rate limit exceeded. Retry after {}ms",
344 retry_after
345 ))
346 }
347 }
348}
349
/// Values for a set of security-related response headers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecurityPolicy {
    /// Value for `X-Frame-Options`.
    pub x_frame_options: String,
    /// Value for `X-Content-Type-Options`.
    pub x_content_type_options: String,
    /// Value for `Strict-Transport-Security`.
    pub strict_transport_security: String,
    /// Value for `Content-Security-Policy`; `None` when no policy is configured.
    pub content_security_policy: Option<String>,
    /// Value for `X-XSS-Protection`.
    pub x_xss_protection: String,
    /// Value for `Referrer-Policy`.
    pub referrer_policy: String,
}
364
365impl Default for SecurityPolicy {
366 fn default() -> Self {
367 Self {
368 x_frame_options: "DENY".to_string(),
369 x_content_type_options: "nosniff".to_string(),
370 strict_transport_security: "max-age=31536000; includeSubDomains".to_string(),
371 content_security_policy: None,
372 x_xss_protection: "1; mode=block".to_string(),
373 referrer_policy: "strict-origin-when-cross-origin".to_string(),
374 }
375 }
376}
377
378impl SecurityPolicy {
379 pub fn new() -> Self {
380 Self::default()
381 }
382
383 pub fn with_csp(mut self, csp: impl Into<String>) -> Self {
384 self.content_security_policy = Some(csp.into());
385 self
386 }
387}
388
/// Bus value wrapping the [`SecurityPolicy`] inserted by `SecurityHeadersGuard::run`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecurityHeaders(pub SecurityPolicy);
392
/// Guard that publishes a fixed [`SecurityPolicy`] on the bus.
#[derive(Debug, Clone)]
pub struct SecurityHeadersGuard<T> {
    /// Policy copied onto the bus on every run.
    policy: SecurityPolicy,
    /// Ties the guard to the payload type `T` without storing one.
    _marker: PhantomData<T>,
}
399
400impl<T> SecurityHeadersGuard<T> {
401 pub fn new(policy: SecurityPolicy) -> Self {
402 Self {
403 policy,
404 _marker: PhantomData,
405 }
406 }
407
408 pub fn policy(&self) -> &SecurityPolicy {
410 &self.policy
411 }
412}
413
414#[async_trait]
415impl<T> Transition<T, T> for SecurityHeadersGuard<T>
416where
417 T: Send + Sync + 'static,
418{
419 type Error = String;
420 type Resources = ();
421
422 async fn run(
423 &self,
424 input: T,
425 _resources: &Self::Resources,
426 bus: &mut Bus,
427 ) -> Outcome<T, Self::Error> {
428 bus.insert(SecurityHeaders(self.policy.clone()));
429 Outcome::next(input)
430 }
431}
432
/// Bus value holding the client's IP address as text; read by `IpFilterGuard::run`.
#[derive(Debug, Clone)]
pub struct ClientIp(pub String);
440
/// Set of proxy addresses whose forwarded-for information is trusted.
#[derive(Debug, Clone)]
pub struct TrustedProxies {
    /// Trusted proxy addresses, compared as exact strings.
    proxies: HashSet<String>,
}

impl TrustedProxies {
    /// Collects `ips` into the trusted set (duplicates collapse).
    pub fn new(ips: impl IntoIterator<Item = impl Into<String>>) -> Self {
        let proxies: HashSet<String> = ips.into_iter().map(Into::into).collect();
        Self { proxies }
    }

    /// Determines the client address for a connection.
    ///
    /// * When `direct_ip` is not a trusted proxy, it is returned as-is and
    ///   `xff_header` is ignored.
    /// * Otherwise the comma-separated `xff_header` is scanned from right to
    ///   left and the first non-empty, untrusted entry (trimmed) is returned.
    /// * When every entry is trusted or empty, `direct_ip` is returned.
    pub fn extract(&self, xff_header: &str, direct_ip: &str) -> String {
        if !self.is_trusted(direct_ip) {
            return direct_ip.to_string();
        }

        xff_header
            .rsplit(',')
            .map(str::trim)
            .find(|hop| !hop.is_empty() && !self.is_trusted(hop))
            .unwrap_or(direct_ip)
            .to_string()
    }

    /// Reports whether `ip` is in the trusted set.
    pub fn is_trusted(&self, ip: &str) -> bool {
        self.proxies.contains(ip)
    }
}
503
/// How an [`IpFilterGuard`] interprets its address set.
#[derive(Debug, Clone)]
pub enum IpFilterMode {
    /// Only the listed addresses are let through.
    AllowList(HashSet<String>),
    /// The listed addresses are rejected; everything else is let through.
    DenyList(HashSet<String>),
}

/// Guard that admits or rejects requests by client IP string.
#[derive(Debug, Clone)]
pub struct IpFilterGuard<T> {
    /// Address set together with its interpretation.
    mode: IpFilterMode,
    /// Ties the guard to the payload type `T` without storing one.
    _marker: PhantomData<T>,
}

impl<T> IpFilterGuard<T> {
    /// Builds a guard that only admits the given addresses.
    pub fn allow_list(ips: impl IntoIterator<Item = impl Into<String>>) -> Self {
        let allowed: HashSet<String> = ips.into_iter().map(Into::into).collect();
        Self {
            mode: IpFilterMode::AllowList(allowed),
            _marker: PhantomData,
        }
    }

    /// Builds a guard that rejects the given addresses.
    pub fn deny_list(ips: impl IntoIterator<Item = impl Into<String>>) -> Self {
        let denied: HashSet<String> = ips.into_iter().map(Into::into).collect();
        Self {
            mode: IpFilterMode::DenyList(denied),
            _marker: PhantomData,
        }
    }

    /// Copies the filter into a guard parameterised over `()`.
    pub fn clone_as_unit(&self) -> IpFilterGuard<()> {
        let mode = self.mode.clone();
        IpFilterGuard {
            mode,
            _marker: PhantomData,
        }
    }
}
545
546#[async_trait]
547impl<T> Transition<T, T> for IpFilterGuard<T>
548where
549 T: Send + Sync + 'static,
550{
551 type Error = String;
552 type Resources = ();
553
554 async fn run(
555 &self,
556 input: T,
557 _resources: &Self::Resources,
558 bus: &mut Bus,
559 ) -> Outcome<T, Self::Error> {
560 let client_ip = bus
561 .read::<ClientIp>()
562 .map(|ip| ip.0.clone())
563 .unwrap_or_default();
564
565 match &self.mode {
566 IpFilterMode::AllowList(allowed) => {
567 if allowed.contains(&client_ip) {
568 Outcome::next(input)
569 } else {
570 Outcome::fault(format!("IP '{}' not in allow list", client_ip))
571 }
572 }
573 IpFilterMode::DenyList(denied) => {
574 if denied.contains(&client_ip) {
575 Outcome::fault(format!("IP '{}' is denied", client_ip))
576 } else {
577 Outcome::next(input)
578 }
579 }
580 }
581 }
582}
583
/// Bus value describing the request to be logged; read by `AccessLogGuard::run`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessLogRequest {
    /// Request method text.
    pub method: String,
    /// Request path; compared verbatim against the guard's redaction list.
    pub path: String,
}
596
/// Record that `AccessLogGuard::run` places on the bus after logging.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessLogEntry {
    /// Request method as logged.
    pub method: String,
    /// Request path as logged, or `"[redacted]"` when the path was redacted.
    pub path: String,
    /// Milliseconds since the Unix epoch at the time of logging.
    pub timestamp_ms: u64,
}
606
/// Guard that logs each request and records an access-log entry.
#[derive(Debug, Clone)]
pub struct AccessLogGuard<T> {
    /// Paths whose value is replaced in the log output (exact match).
    redact_paths: Vec<String>,
    /// Ties the guard to the payload type `T` without storing one.
    _marker: PhantomData<T>,
}

impl<T> AccessLogGuard<T> {
    /// Creates a guard with an empty redaction list.
    pub fn new() -> Self {
        Self {
            redact_paths: vec![],
            _marker: PhantomData,
        }
    }

    /// Replaces the redaction list with `paths` (earlier entries are dropped).
    pub fn redact_paths(self, paths: Vec<String>) -> Self {
        Self {
            redact_paths: paths,
            ..self
        }
    }

    /// Copies the guard's settings into a guard parameterised over `()`.
    pub fn clone_as_unit(&self) -> AccessLogGuard<()> {
        let redact_paths = self.redact_paths.to_vec();
        AccessLogGuard {
            redact_paths,
            _marker: PhantomData,
        }
    }
}

impl<T> Default for AccessLogGuard<T> {
    /// Same as [`AccessLogGuard::new`].
    fn default() -> Self {
        AccessLogGuard::new()
    }
}
661
662#[async_trait]
663impl<T> Transition<T, T> for AccessLogGuard<T>
664where
665 T: Send + Sync + 'static,
666{
667 type Error = String;
668 type Resources = ();
669
670 async fn run(
671 &self,
672 input: T,
673 _resources: &Self::Resources,
674 bus: &mut Bus,
675 ) -> Outcome<T, Self::Error> {
676 let req = bus.read::<AccessLogRequest>().cloned();
677 let (method, raw_path) = match &req {
678 Some(r) => (r.method.clone(), r.path.clone()),
679 None => (String::new(), String::new()),
680 };
681
682 let display_path = if self.redact_paths.iter().any(|p| p == &raw_path) {
683 "[redacted]".to_string()
684 } else {
685 raw_path
686 };
687
688 let now_ms = std::time::SystemTime::now()
689 .duration_since(std::time::UNIX_EPOCH)
690 .unwrap_or_default()
691 .as_millis() as u64;
692
693 tracing::info!(method = %method, path = %display_path, "access");
694
695 bus.insert(AccessLogEntry {
696 method,
697 path: display_path,
698 timestamp_ms: now_ms,
699 });
700
701 Outcome::next(input)
702 }
703}
704
/// Content codings the compression guard can select.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CompressionEncoding {
    /// The `gzip` coding.
    Gzip,
    /// The `br` coding.
    Brotli,
    /// The `zstd` coding.
    Zstd,
    /// The `identity` coding.
    Identity,
}
717
718impl CompressionEncoding {
719 pub fn as_str(&self) -> &'static str {
721 match self {
722 Self::Gzip => "gzip",
723 Self::Brotli => "br",
724 Self::Zstd => "zstd",
725 Self::Identity => "identity",
726 }
727 }
728}
729
/// Bus value holding the raw `Accept-Encoding` text; read by `CompressionGuard::run`.
#[derive(Debug, Clone)]
pub struct AcceptEncoding(pub String);
733
/// Negotiation result that `CompressionGuard::run` places on the bus.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressionConfig {
    /// Coding chosen for the response.
    pub encoding: CompressionEncoding,
    /// Copy of the guard's `min_body_size` setting.
    /// NOTE(review): presumably a byte threshold below which bodies stay uncompressed — confirm with the consumer.
    pub min_body_size: usize,
}
743
/// Guard that negotiates a response content coding.
#[derive(Debug, Clone)]
pub struct CompressionGuard<T> {
    /// Codings in order of preference; the first acceptable one is selected.
    preferred: Vec<CompressionEncoding>,
    /// Value copied into `CompressionConfig::min_body_size`.
    min_body_size: usize,
    /// Ties the guard to the payload type `T` without storing one.
    _marker: PhantomData<T>,
}
763
764impl<T> CompressionGuard<T> {
765 pub fn new() -> Self {
767 Self {
768 preferred: vec![CompressionEncoding::Gzip, CompressionEncoding::Identity],
769 min_body_size: 256,
770 _marker: PhantomData,
771 }
772 }
773
774 pub fn prefer_brotli(mut self) -> Self {
776 self.preferred = vec![
777 CompressionEncoding::Brotli,
778 CompressionEncoding::Gzip,
779 CompressionEncoding::Identity,
780 ];
781 self
782 }
783
784 pub fn with_min_body_size(mut self, size: usize) -> Self {
786 self.min_body_size = size;
787 self
788 }
789
790 pub fn min_body_size(&self) -> usize {
792 self.min_body_size
793 }
794
795 pub fn preferred_encodings(&self) -> &[CompressionEncoding] {
797 &self.preferred
798 }
799}
800
801impl<T> Default for CompressionGuard<T> {
802 fn default() -> Self {
803 Self::new()
804 }
805}
806
/// Parses an `Accept-Encoding` header value into the set of acceptable
/// content codings.
///
/// The header is split on `,`; for each entry the coding name (the part
/// before the first `;`) is trimmed and lowercased. Entries with an empty
/// name are skipped. An entry carrying a `q` parameter that parses to a
/// value of zero or less (e.g. `gzip;q=0`) means the client refuses that
/// coding, so it is left out of the result. Positive weights are not used for
/// ranking; a malformed `q` value is ignored and the coding is kept.
fn parse_accept_encoding(header: &str) -> HashSet<String> {
    header
        .split(',')
        .filter_map(|entry| {
            let mut parts = entry.split(';');
            let coding = parts.next().unwrap_or("").trim().to_lowercase();
            if coding.is_empty() {
                return None;
            }

            // Remaining parts are `name=value` parameters; only `q` matters.
            let refused = parts.any(|param| {
                let mut name_value = param.splitn(2, '=');
                let name = name_value.next().unwrap_or("").trim();
                let value = name_value.next().unwrap_or("").trim();
                name.eq_ignore_ascii_case("q")
                    && matches!(value.parse::<f32>(), Ok(weight) if weight <= 0.0)
            });

            if refused {
                None
            } else {
                Some(coding)
            }
        })
        .collect()
}
821
822#[async_trait]
823impl<T> Transition<T, T> for CompressionGuard<T>
824where
825 T: Send + Sync + 'static,
826{
827 type Error = String;
828 type Resources = ();
829
830 async fn run(
831 &self,
832 input: T,
833 _resources: &Self::Resources,
834 bus: &mut Bus,
835 ) -> Outcome<T, Self::Error> {
836 let accepted = bus
837 .read::<AcceptEncoding>()
838 .map(|ae| parse_accept_encoding(&ae.0))
839 .unwrap_or_default();
840
841 let selected = if accepted.is_empty() || accepted.contains("*") {
843 self.preferred.first().copied().unwrap_or(CompressionEncoding::Identity)
844 } else {
845 self.preferred
846 .iter()
847 .find(|enc| accepted.contains(enc.as_str()))
848 .copied()
849 .unwrap_or(CompressionEncoding::Identity)
850 };
851
852 bus.insert(CompressionConfig {
853 encoding: selected,
854 min_body_size: self.min_body_size,
855 });
856
857 Outcome::next(input)
858 }
859}
860
/// Bus value holding the declared request body length; read by `RequestSizeLimitGuard::run`.
#[derive(Debug, Clone)]
pub struct ContentLength(pub u64);
868
/// Guard that rejects requests whose declared length exceeds a limit.
#[derive(Debug, Clone)]
pub struct RequestSizeLimitGuard<T> {
    /// Largest accepted length in bytes (inclusive).
    max_bytes: u64,
    /// Ties the guard to the payload type `T` without storing one.
    _marker: PhantomData<T>,
}

impl<T> RequestSizeLimitGuard<T> {
    /// Creates a guard accepting at most `max_bytes` bytes.
    pub fn new(max_bytes: u64) -> Self {
        let _marker = PhantomData;
        Self { max_bytes, _marker }
    }

    /// Guard with a 2 MiB limit (2 × 1024 × 1024 bytes).
    pub fn max_2mb() -> Self {
        let mebibyte: u64 = 1024 * 1024;
        Self::new(2 * mebibyte)
    }

    /// Guard with a 10 MiB limit (10 × 1024 × 1024 bytes).
    pub fn max_10mb() -> Self {
        let mebibyte: u64 = 1024 * 1024;
        Self::new(10 * mebibyte)
    }

    /// Configured limit in bytes.
    pub fn max_bytes(&self) -> u64 {
        self.max_bytes
    }
}
912
913#[async_trait]
914impl<T> Transition<T, T> for RequestSizeLimitGuard<T>
915where
916 T: Send + Sync + 'static,
917{
918 type Error = String;
919 type Resources = ();
920
921 async fn run(
922 &self,
923 input: T,
924 _resources: &Self::Resources,
925 bus: &mut Bus,
926 ) -> Outcome<T, Self::Error> {
927 if let Some(len) = bus.read::<ContentLength>() {
928 if len.0 > self.max_bytes {
929 return Outcome::fault(format!(
930 "413 Payload Too Large: {} bytes exceeds limit of {} bytes",
931 len.0, self.max_bytes
932 ));
933 }
934 }
935 Outcome::next(input)
936 }
937}
938
/// Bus value holding the request's identifier; created by `RequestIdGuard::run` when absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestId(pub String);
949
/// Guard that makes sure every request carries a request id.
#[derive(Debug, Clone)]
pub struct RequestIdGuard<T> {
    /// Ties the guard to the payload type `T` without storing one.
    _marker: PhantomData<T>,
}

impl<T> RequestIdGuard<T> {
    /// Creates the guard; it has no configuration.
    pub fn new() -> Self {
        let _marker = PhantomData;
        Self { _marker }
    }
}

impl<T> Default for RequestIdGuard<T> {
    /// Same as [`RequestIdGuard::new`].
    fn default() -> Self {
        RequestIdGuard::new()
    }
}
980
981#[async_trait]
982impl<T> Transition<T, T> for RequestIdGuard<T>
983where
984 T: Send + Sync + 'static,
985{
986 type Error = String;
987 type Resources = ();
988
989 async fn run(
990 &self,
991 input: T,
992 _resources: &Self::Resources,
993 bus: &mut Bus,
994 ) -> Outcome<T, Self::Error> {
995 if bus.read::<RequestId>().is_none() {
997 bus.insert(RequestId(uuid::Uuid::new_v4().to_string()));
998 }
999
1000 if let Some(rid) = bus.read::<RequestId>() {
1002 tracing::debug!(request_id = %rid.0, "request id assigned");
1003 }
1004
1005 Outcome::next(input)
1006 }
1007}
1008
/// Bus value holding the raw credential text; read by `AuthGuard::run` for every strategy.
#[derive(Debug, Clone)]
pub struct AuthorizationHeader(pub String);
1016
/// How `AuthGuard` validates the credential it finds on the bus.
pub enum AuthStrategy {
    /// Accepts a `Bearer <token>` value whose token equals one of `tokens`.
    Bearer {
        /// Accepted bearer tokens.
        tokens: Vec<String>,
    },

    /// Accepts a value that equals one of `valid_keys`.
    ApiKey {
        /// Name of the header expected to carry the key.
        /// NOTE(review): `AuthGuard::run` ignores this field and reads
        /// `AuthorizationHeader` from the bus — confirm who honours it.
        header_name: String,
        /// Accepted API keys.
        valid_keys: Vec<String>,
    },

    /// Delegates validation to a caller-supplied function.
    Custom {
        /// Receives the raw credential text (empty when none was supplied) and
        /// returns the authenticated identity or an error message.
        validator: Arc<dyn Fn(&str) -> Result<IamIdentity, String> + Send + Sync + 'static>,
    },
}
1044
1045impl Clone for AuthStrategy {
1046 fn clone(&self) -> Self {
1047 match self {
1048 Self::Bearer { tokens } => Self::Bearer {
1049 tokens: tokens.clone(),
1050 },
1051 Self::ApiKey {
1052 header_name,
1053 valid_keys,
1054 } => Self::ApiKey {
1055 header_name: header_name.clone(),
1056 valid_keys: valid_keys.clone(),
1057 },
1058 Self::Custom { validator } => Self::Custom {
1059 validator: validator.clone(),
1060 },
1061 }
1062 }
1063}
1064
1065impl std::fmt::Debug for AuthStrategy {
1066 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1067 match self {
1068 Self::Bearer { tokens } => f
1069 .debug_struct("Bearer")
1070 .field("token_count", &tokens.len())
1071 .finish(),
1072 Self::ApiKey { header_name, valid_keys } => f
1073 .debug_struct("ApiKey")
1074 .field("header_name", header_name)
1075 .field("key_count", &valid_keys.len())
1076 .finish(),
1077 Self::Custom { .. } => f.debug_struct("Custom").finish(),
1078 }
1079 }
1080}
1081
/// Guard that authenticates a request and then enforces an IAM policy.
pub struct AuthGuard<T> {
    /// How the credential is validated.
    strategy: AuthStrategy,
    /// Policy checked against the authenticated identity.
    policy: IamPolicy,
    /// Ties the guard to the payload type `T` without storing one.
    _marker: PhantomData<T>,
}
1108
1109impl<T> AuthGuard<T> {
1110 pub fn new(strategy: AuthStrategy) -> Self {
1112 Self {
1113 strategy,
1114 policy: IamPolicy::None,
1115 _marker: PhantomData,
1116 }
1117 }
1118
1119 pub fn bearer(tokens: Vec<String>) -> Self {
1121 Self::new(AuthStrategy::Bearer { tokens })
1122 }
1123
1124 pub fn api_key(header_name: impl Into<String>, valid_keys: Vec<String>) -> Self {
1126 Self::new(AuthStrategy::ApiKey {
1127 header_name: header_name.into(),
1128 valid_keys,
1129 })
1130 }
1131
1132 pub fn custom(
1134 validator: impl Fn(&str) -> Result<IamIdentity, String> + Send + Sync + 'static,
1135 ) -> Self {
1136 Self::new(AuthStrategy::Custom {
1137 validator: Arc::new(validator),
1138 })
1139 }
1140
1141 pub fn with_policy(mut self, policy: IamPolicy) -> Self {
1143 self.policy = policy;
1144 self
1145 }
1146
1147 pub fn strategy(&self) -> &AuthStrategy {
1149 &self.strategy
1150 }
1151
1152 pub fn iam_policy(&self) -> &IamPolicy {
1154 &self.policy
1155 }
1156}
1157
1158impl<T> Clone for AuthGuard<T> {
1159 fn clone(&self) -> Self {
1160 Self {
1161 strategy: self.strategy.clone(),
1162 policy: self.policy.clone(),
1163 _marker: PhantomData,
1164 }
1165 }
1166}
1167
1168impl<T> std::fmt::Debug for AuthGuard<T> {
1169 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1170 f.debug_struct("AuthGuard")
1171 .field("strategy", &self.strategy)
1172 .field("policy", &self.policy)
1173 .finish()
1174 }
1175}
1176
1177fn ct_eq(a: &[u8], b: &[u8]) -> bool {
1179 a.len() == b.len() && a.ct_eq(b).into()
1180}
1181
1182#[async_trait]
1183impl<T> Transition<T, T> for AuthGuard<T>
1184where
1185 T: Send + Sync + 'static,
1186{
1187 type Error = String;
1188 type Resources = ();
1189
1190 async fn run(
1191 &self,
1192 input: T,
1193 _resources: &Self::Resources,
1194 bus: &mut Bus,
1195 ) -> Outcome<T, Self::Error> {
1196 let auth_value = bus.read::<AuthorizationHeader>().map(|h| h.0.clone());
1197
1198 let identity = match &self.strategy {
1199 AuthStrategy::Bearer { tokens } => {
1200 let Some(auth) = auth_value else {
1201 return Outcome::fault(
1202 "401 Unauthorized: missing Authorization header".to_string(),
1203 );
1204 };
1205 let Some(token) = auth.strip_prefix("Bearer ") else {
1206 return Outcome::fault(
1207 "401 Unauthorized: expected Bearer scheme".to_string(),
1208 );
1209 };
1210 let token = token.trim();
1211 let matched = tokens
1212 .iter()
1213 .any(|valid| ct_eq(token.as_bytes(), valid.as_bytes()));
1214 if !matched {
1215 return Outcome::fault(
1216 "401 Unauthorized: invalid bearer token".to_string(),
1217 );
1218 }
1219 IamIdentity::new("bearer-authenticated")
1220 }
1221 AuthStrategy::ApiKey { valid_keys, .. } => {
1222 let Some(key) = auth_value else {
1223 return Outcome::fault("401 Unauthorized: missing API key".to_string());
1224 };
1225 let matched = valid_keys
1226 .iter()
1227 .any(|valid| ct_eq(key.as_bytes(), valid.as_bytes()));
1228 if !matched {
1229 return Outcome::fault("401 Unauthorized: invalid API key".to_string());
1230 }
1231 IamIdentity::new("apikey-authenticated")
1232 }
1233 AuthStrategy::Custom { validator } => {
1234 let raw = auth_value.unwrap_or_default();
1235 match validator(&raw) {
1236 Ok(identity) => identity,
1237 Err(msg) => {
1238 return Outcome::fault(format!("401 Unauthorized: {}", msg));
1239 }
1240 }
1241 }
1242 };
1243
1244 if let Err(e) = enforce_policy(&self.policy, &identity) {
1246 return Outcome::fault(format!("403 Forbidden: {}", e));
1247 }
1248
1249 bus.insert(identity);
1250 Outcome::next(input)
1251 }
1252}
1253
/// Bus value holding the raw `Content-Type` text; read by `ContentTypeGuard::run`.
#[derive(Debug, Clone)]
pub struct RequestContentType(pub String);
1261
/// Guard that restricts requests to a set of media types.
#[derive(Debug, Clone)]
pub struct ContentTypeGuard<T> {
    /// Accepted media types; compared case-insensitively by `run`.
    allowed_types: Vec<String>,
    /// Ties the guard to the payload type `T` without storing one.
    _marker: PhantomData<T>,
}

impl<T> ContentTypeGuard<T> {
    /// Creates a guard accepting exactly `allowed_types`.
    pub fn new(allowed_types: Vec<String>) -> Self {
        let _marker = PhantomData;
        Self {
            allowed_types,
            _marker,
        }
    }

    /// Guard accepting only `application/json`.
    pub fn json() -> Self {
        Self::new(vec![String::from("application/json")])
    }

    /// Guard accepting only `application/x-www-form-urlencoded`.
    pub fn form() -> Self {
        Self::new(vec![String::from("application/x-www-form-urlencoded")])
    }

    /// Guard accepting every media type yielded by `types`.
    pub fn accept(types: impl IntoIterator<Item = impl Into<String>>) -> Self {
        let allowed: Vec<String> = types.into_iter().map(Into::into).collect();
        Self::new(allowed)
    }

    /// Accepted media types, in the order they were supplied.
    pub fn allowed_types(&self) -> &[String] {
        self.allowed_types.as_slice()
    }
}
1313
1314#[async_trait]
1315impl<T> Transition<T, T> for ContentTypeGuard<T>
1316where
1317 T: Send + Sync + 'static,
1318{
1319 type Error = String;
1320 type Resources = ();
1321
1322 async fn run(
1323 &self,
1324 input: T,
1325 _resources: &Self::Resources,
1326 bus: &mut Bus,
1327 ) -> Outcome<T, Self::Error> {
1328 let content_type = bus.read::<RequestContentType>().map(|ct| ct.0.clone());
1329
1330 let Some(ct) = content_type else {
1332 return Outcome::next(input);
1333 };
1334
1335 let media_type = ct.split(';').next().unwrap_or("").trim().to_lowercase();
1337 let matched = self
1338 .allowed_types
1339 .iter()
1340 .any(|allowed| allowed.to_lowercase() == media_type);
1341
1342 if matched {
1343 Outcome::next(input)
1344 } else {
1345 Outcome::fault(format!(
1346 "415 Unsupported Media Type: expected one of [{}], got '{}'",
1347 self.allowed_types.join(", "),
1348 media_type,
1349 ))
1350 }
1351 }
1352}
1353
/// A time budget measured from the moment the value was created.
#[derive(Debug, Clone)]
pub struct TimeoutDeadline {
    /// Moment the deadline was created.
    created_at: std::time::Instant,
    /// Total budget counted from `created_at`.
    timeout: std::time::Duration,
}

impl TimeoutDeadline {
    /// Starts a budget of `timeout`, counted from now.
    pub fn new(timeout: std::time::Duration) -> Self {
        let created_at = std::time::Instant::now();
        Self {
            created_at,
            timeout,
        }
    }

    /// Time left in the budget; zero once it is used up.
    pub fn remaining(&self) -> std::time::Duration {
        let spent = self.created_at.elapsed();
        self.timeout
            .checked_sub(spent)
            .unwrap_or(std::time::Duration::ZERO)
    }

    /// `true` once at least `timeout` has elapsed since creation.
    pub fn is_expired(&self) -> bool {
        let spent = self.created_at.elapsed();
        !(spent < self.timeout)
    }

    /// The full budget this deadline was created with.
    pub fn duration(&self) -> std::time::Duration {
        self.timeout
    }
}
1396
/// Guard that stamps each request with a time budget.
#[derive(Debug, Clone)]
pub struct TimeoutGuard<T> {
    /// Budget given to every request.
    timeout: std::time::Duration,
    /// Ties the guard to the payload type `T` without storing one.
    _marker: PhantomData<T>,
}

impl<T> TimeoutGuard<T> {
    /// Creates a guard with the given budget.
    pub fn new(timeout: std::time::Duration) -> Self {
        let _marker = PhantomData;
        Self { timeout, _marker }
    }

    /// Guard with a 5 second budget.
    pub fn secs_5() -> Self {
        let budget = std::time::Duration::from_secs(5);
        Self::new(budget)
    }

    /// Guard with a 30 second budget.
    pub fn secs_30() -> Self {
        let budget = std::time::Duration::from_secs(30);
        Self::new(budget)
    }

    /// Guard with a 60 second budget.
    pub fn secs_60() -> Self {
        let budget = std::time::Duration::from_secs(60);
        Self::new(budget)
    }

    /// Configured budget.
    pub fn timeout(&self) -> std::time::Duration {
        self.timeout
    }
}
1452
1453#[async_trait]
1454impl<T> Transition<T, T> for TimeoutGuard<T>
1455where
1456 T: Send + Sync + 'static,
1457{
1458 type Error = String;
1459 type Resources = ();
1460
1461 async fn run(
1462 &self,
1463 input: T,
1464 _resources: &Self::Resources,
1465 bus: &mut Bus,
1466 ) -> Outcome<T, Self::Error> {
1467 bus.insert(TimeoutDeadline::new(self.timeout));
1468 Outcome::next(input)
1469 }
1470}
1471
/// Bus value holding the request's idempotency key; read by `IdempotencyGuard::run`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdempotencyKey(pub String);
1479
/// Bus value inserted by `IdempotencyGuard::run` when the key has a cached response.
#[derive(Debug, Clone)]
pub struct IdempotencyCachedResponse {
    /// Cached response body bytes.
    pub body: Vec<u8>,
}
1489
/// One cached response together with its expiry time.
struct IdempotencyCacheEntry {
    /// Cached response body bytes.
    body: Vec<u8>,
    /// Instant from which the entry no longer counts as valid.
    expires_at: std::time::Instant,
}

/// In-memory response cache keyed by idempotency key.
///
/// Clones share the same underlying map.
#[derive(Clone)]
pub struct IdempotencyCache {
    /// Shared map from key to cached entry.
    inner: Arc<std::sync::Mutex<std::collections::HashMap<String, IdempotencyCacheEntry>>>,
    /// Lifetime given to each inserted entry.
    ttl: std::time::Duration,
}

impl IdempotencyCache {
    /// Creates an empty cache whose entries live for `ttl`.
    pub fn new(ttl: std::time::Duration) -> Self {
        Self {
            inner: Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())),
            ttl,
        }
    }

    /// Locks the map, recovering the guard when the mutex is poisoned.
    ///
    /// Poisoning only records that some thread panicked while holding the
    /// lock. Giving up in that case would silently turn the cache off for the
    /// rest of the process, so the guard is recovered and the map stays usable.
    fn lock_map(
        &self,
    ) -> std::sync::MutexGuard<'_, std::collections::HashMap<String, IdempotencyCacheEntry>> {
        self.inner
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Returns a copy of the body cached under `key`, if present and not yet
    /// expired. An expired entry found under `key` is removed.
    pub fn get(&self, key: &str) -> Option<Vec<u8>> {
        let mut cache = self.lock_map();
        let now = std::time::Instant::now();

        let fresh = cache
            .get(key)
            .filter(|entry| entry.expires_at > now)
            .map(|entry| entry.body.clone());
        if fresh.is_none() {
            // Either absent (no-op) or expired (evict).
            cache.remove(key);
        }
        fresh
    }

    /// Stores `body` under `key`, valid for the cache's TTL from now, and
    /// replaces any previous entry for that key.
    ///
    /// Each call also removes up to five already-expired entries so stale
    /// data is cleaned up incrementally.
    pub fn insert(&self, key: String, body: Vec<u8>) {
        let mut cache = self.lock_map();
        let now = std::time::Instant::now();

        let expired: Vec<String> = cache
            .iter()
            .filter(|(_, entry)| entry.expires_at <= now)
            .take(5)
            .map(|(stale_key, _)| stale_key.clone())
            .collect();
        for stale_key in expired {
            cache.remove(&stale_key);
        }

        cache.insert(
            key,
            IdempotencyCacheEntry {
                body,
                expires_at: now + self.ttl,
            },
        );
    }

    /// Lifetime given to each inserted entry.
    pub fn ttl(&self) -> std::time::Duration {
        self.ttl
    }
}

impl std::fmt::Debug for IdempotencyCache {
    /// Prints only the TTL; cached bodies are not included.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("IdempotencyCache")
            .field("ttl", &self.ttl)
            .finish()
    }
}
1562
/// Guard that looks up cached responses by idempotency key.
pub struct IdempotencyGuard<T> {
    /// Response cache consulted on every run; shared by clones of the guard.
    cache: IdempotencyCache,
    /// Ties the guard to the payload type `T` without storing one.
    _marker: PhantomData<T>,
}
1588
1589impl<T> IdempotencyGuard<T> {
1590 pub fn new(ttl: std::time::Duration) -> Self {
1592 Self {
1593 cache: IdempotencyCache::new(ttl),
1594 _marker: PhantomData,
1595 }
1596 }
1597
1598 pub fn ttl_5min() -> Self {
1600 Self::new(std::time::Duration::from_secs(300))
1601 }
1602
1603 pub fn ttl(&self) -> std::time::Duration {
1605 self.cache.ttl()
1606 }
1607
1608 pub fn cache(&self) -> &IdempotencyCache {
1610 &self.cache
1611 }
1612
1613 pub fn clone_as_unit(&self) -> IdempotencyGuard<()> {
1615 IdempotencyGuard {
1616 cache: self.cache.clone(),
1617 _marker: PhantomData,
1618 }
1619 }
1620}
1621
1622impl<T> Clone for IdempotencyGuard<T> {
1623 fn clone(&self) -> Self {
1624 Self {
1625 cache: self.cache.clone(),
1626 _marker: PhantomData,
1627 }
1628 }
1629}
1630
1631impl<T> std::fmt::Debug for IdempotencyGuard<T> {
1632 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1633 f.debug_struct("IdempotencyGuard")
1634 .field("ttl", &self.cache.ttl())
1635 .finish()
1636 }
1637}
1638
1639#[async_trait]
1640impl<T> Transition<T, T> for IdempotencyGuard<T>
1641where
1642 T: Send + Sync + 'static,
1643{
1644 type Error = String;
1645 type Resources = ();
1646
1647 async fn run(
1648 &self,
1649 input: T,
1650 _resources: &Self::Resources,
1651 bus: &mut Bus,
1652 ) -> Outcome<T, Self::Error> {
1653 let Some(key) = bus.read::<IdempotencyKey>().map(|k| k.0.clone()) else {
1654 return Outcome::next(input);
1655 };
1656
1657 if let Some(body) = self.cache.get(&key) {
1658 bus.insert(IdempotencyCachedResponse { body });
1659 tracing::debug!(idempotency_key = %key, "idempotency cache hit");
1660 } else {
1661 tracing::debug!(idempotency_key = %key, "idempotency cache miss");
1662 }
1663
1664 Outcome::next(input)
1665 }
1666}
1667
// Additional guards, compiled only when the `advanced` feature is enabled.
#[cfg(feature = "advanced")]
mod advanced_guards;

// Everything public in `advanced_guards` is re-exported at the crate root.
#[cfg(feature = "advanced")]
pub use advanced_guards::*;

// Distributed rate limiting, compiled only when the `distributed` feature is enabled.
#[cfg(feature = "distributed")]
mod distributed;

// Only the guard type is re-exported from `distributed`.
#[cfg(feature = "distributed")]
pub use distributed::DistributedRateLimitGuard;
1683
/// Convenience re-exports of the guard types and their bus value types.
///
/// NOTE(review): `TrustedProxies`, `IpFilterMode` and `RateLimitError` are
/// public in this file but not re-exported here — confirm whether that is
/// intentional.
pub mod prelude {
    // Guards and bus types that are always available.
    pub use crate::{
        AcceptEncoding, AccessLogEntry, AccessLogGuard, AccessLogRequest, AuthGuard,
        AuthStrategy, AuthorizationHeader, ClientIdentity, ClientIp, CompressionConfig,
        CompressionEncoding, CompressionGuard, ContentLength, ContentTypeGuard, CorsConfig,
        CorsGuard, CorsHeaders, IdempotencyCache, IdempotencyCachedResponse, IdempotencyGuard,
        IdempotencyKey, IpFilterGuard, RateLimitGuard, RequestContentType, RequestId,
        RequestIdGuard, RequestOrigin, RequestSizeLimitGuard, SecurityHeaders,
        SecurityHeadersGuard, SecurityPolicy, TimeoutDeadline, TimeoutGuard,
    };

    // Items that exist only with the `advanced` feature.
    #[cfg(feature = "advanced")]
    pub use crate::advanced_guards::{
        ConditionalRequestGuard, DecompressionGuard, ETag, IfModifiedSince, IfNoneMatch,
        LastModified, RedirectGuard, RedirectRule, RequestBody,
    };

    // Item that exists only with the `distributed` feature.
    #[cfg(feature = "distributed")]
    pub use crate::distributed::DistributedRateLimitGuard;
}
1708
#[cfg(test)]
mod tests {
    //! Unit tests for the guards defined in the parent module.
    //!
    //! Most tests share one shape: build a guard, seed a fresh `Bus` with the values the
    //! guard reads, call `run`, then assert on the returned `Outcome` and on whatever the
    //! guard wrote back into the bus.

    use super::*;
    use std::time::Duration;

    // Asserts that the given guard outcome is `Outcome::Next`.
    macro_rules! assert_next {
        ($outcome:expr) => {
            assert!(matches!($outcome, Outcome::Next(_)))
        };
    }

    // Asserts that the given guard outcome is `Outcome::Fault`. With a second argument the
    // fault message must additionally contain that substring (the tests pass a status code).
    macro_rules! assert_fault {
        ($outcome:expr) => {
            assert!(matches!($outcome, Outcome::Fault(_)))
        };
        ($outcome:expr, $needle:expr) => {
            assert!(matches!($outcome, Outcome::Fault(ref msg) if msg.contains($needle)))
        };
    }

    // ---------------------------------------------------------------------
    // CorsGuard
    // ---------------------------------------------------------------------

    /// The default config allows `*`, so any origin passes and CORS headers are written.
    #[tokio::test]
    async fn cors_guard_allows_wildcard() {
        let guard = CorsGuard::<String>::new(CorsConfig::default());
        let mut bus = Bus::new();
        bus.insert(RequestOrigin("https://example.com".into()));

        let outcome = guard.run("hello".into(), &(), &mut bus).await;

        assert_next!(outcome);
        assert!(bus.read::<CorsHeaders>().is_some());
    }

    /// An origin that is not in the allow-list produces a fault.
    #[tokio::test]
    async fn cors_guard_rejects_disallowed_origin() {
        let guard = CorsGuard::<String>::new(CorsConfig {
            allowed_origins: vec!["https://trusted.com".into()],
            ..CorsConfig::default()
        });
        let mut bus = Bus::new();
        bus.insert(RequestOrigin("https://evil.com".into()));

        let outcome = guard.run("hello".into(), &(), &mut bus).await;

        assert_fault!(outcome);
    }

    // ---------------------------------------------------------------------
    // RateLimitGuard
    // ---------------------------------------------------------------------

    /// A single request against a budget of 10 is let through.
    #[tokio::test]
    async fn rate_limit_allows_within_budget() {
        let guard = RateLimitGuard::<String>::new(10, 1000);
        let mut bus = Bus::new();
        bus.insert(ClientIdentity("user1".into()));

        let outcome = guard.run("ok".into(), &(), &mut bus).await;

        assert_next!(outcome);
    }

    /// With a budget of 2, the third request from the same identity faults.
    #[tokio::test]
    async fn rate_limit_exhausts_budget() {
        let guard = RateLimitGuard::<String>::new(2, 60000);
        let mut bus = Bus::new();
        bus.insert(ClientIdentity("user1".into()));

        // Spend the whole budget; the outcomes of these calls are not under test.
        for payload in ["1", "2"] {
            let _ = guard.run(payload.into(), &(), &mut bus).await;
        }

        let outcome = guard.run("3".into(), &(), &mut bus).await;
        assert_fault!(outcome);
    }

    // ---------------------------------------------------------------------
    // SecurityHeadersGuard
    // ---------------------------------------------------------------------

    /// The default policy is written to the bus with `x_frame_options` set to `DENY`.
    #[tokio::test]
    async fn security_headers_injects_policy() {
        let guard = SecurityHeadersGuard::<String>::new(SecurityPolicy::default());
        let mut bus = Bus::new();

        let outcome = guard.run("ok".into(), &(), &mut bus).await;

        assert_next!(outcome);
        let injected = bus.read::<SecurityHeaders>().unwrap();
        assert_eq!(injected.0.x_frame_options, "DENY");
    }

    // ---------------------------------------------------------------------
    // IpFilterGuard
    // ---------------------------------------------------------------------

    /// An address on the allow-list passes.
    #[tokio::test]
    async fn ip_filter_allow_list_permits() {
        let guard = IpFilterGuard::<String>::allow_list(["10.0.0.1"]);
        let mut bus = Bus::new();
        bus.insert(ClientIp("10.0.0.1".into()));

        let outcome = guard.run("ok".into(), &(), &mut bus).await;

        assert_next!(outcome);
    }

    /// An address missing from the allow-list faults.
    #[tokio::test]
    async fn ip_filter_allow_list_denies() {
        let guard = IpFilterGuard::<String>::allow_list(["10.0.0.1"]);
        let mut bus = Bus::new();
        bus.insert(ClientIp("192.168.1.1".into()));

        let outcome = guard.run("ok".into(), &(), &mut bus).await;

        assert_fault!(outcome);
    }

    /// An address on the deny-list faults.
    #[tokio::test]
    async fn ip_filter_deny_list_blocks() {
        let guard = IpFilterGuard::<String>::deny_list(["10.0.0.1"]);
        let mut bus = Bus::new();
        bus.insert(ClientIp("10.0.0.1".into()));

        let outcome = guard.run("ok".into(), &(), &mut bus).await;

        assert_fault!(outcome);
    }

    /// An address missing from the deny-list passes.
    #[tokio::test]
    async fn ip_filter_deny_list_allows() {
        let guard = IpFilterGuard::<String>::deny_list(["10.0.0.1"]);
        let mut bus = Bus::new();
        bus.insert(ClientIp("192.168.1.1".into()));

        let outcome = guard.run("ok".into(), &(), &mut bus).await;

        assert_next!(outcome);
    }

    // ---------------------------------------------------------------------
    // AccessLogGuard
    // ---------------------------------------------------------------------

    /// The guard hands the input value back unchanged.
    #[tokio::test]
    async fn access_log_guard_passes_input_through() {
        let guard = AccessLogGuard::<String>::new();
        let mut bus = Bus::new();
        bus.insert(AccessLogRequest {
            method: "GET".into(),
            path: "/users".into(),
        });

        let outcome = guard.run("payload".into(), &(), &mut bus).await;

        assert!(matches!(outcome, Outcome::Next(ref payload) if payload == "payload"));
    }

    /// The method and path of the request end up in an `AccessLogEntry` on the bus.
    #[tokio::test]
    async fn access_log_guard_writes_entry_to_bus() {
        let guard = AccessLogGuard::<String>::new();
        let mut bus = Bus::new();
        bus.insert(AccessLogRequest {
            method: "POST".into(),
            path: "/api/orders".into(),
        });

        let _ = guard.run("ok".into(), &(), &mut bus).await;

        let logged = bus.read::<AccessLogEntry>().expect("entry should be in bus");
        assert_eq!(logged.method, "POST");
        assert_eq!(logged.path, "/api/orders");
    }

    /// A path registered via `redact_paths` is logged as `[redacted]`.
    #[tokio::test]
    async fn access_log_guard_redacts_paths() {
        let guard = AccessLogGuard::<String>::new().redact_paths(vec!["/auth/login".into()]);
        let mut bus = Bus::new();
        bus.insert(AccessLogRequest {
            method: "POST".into(),
            path: "/auth/login".into(),
        });

        let _ = guard.run("ok".into(), &(), &mut bus).await;

        let logged = bus.read::<AccessLogEntry>().expect("entry should be in bus");
        assert_eq!(logged.path, "[redacted]");
    }

    /// Without an `AccessLogRequest` on the bus the guard still logs, using empty strings.
    #[tokio::test]
    async fn access_log_guard_works_without_request_in_bus() {
        let guard = AccessLogGuard::<String>::new();
        let mut bus = Bus::new();

        let outcome = guard.run("ok".into(), &(), &mut bus).await;

        assert_next!(outcome);
        let logged = bus.read::<AccessLogEntry>().expect("entry should be in bus");
        assert_eq!(logged.method, "");
        assert_eq!(logged.path, "");
    }

    /// A guard built through `Default` behaves like one built through `new`.
    #[tokio::test]
    async fn access_log_guard_default_works() {
        let guard = AccessLogGuard::<String>::default();
        let mut bus = Bus::new();
        bus.insert(AccessLogRequest {
            method: "DELETE".into(),
            path: "/api/v1/users/42".into(),
        });

        let outcome = guard.run("ok".into(), &(), &mut bus).await;

        assert_next!(outcome);
    }

    /// The logged entry carries a plausible timestamp.
    #[tokio::test]
    async fn access_log_guard_entry_has_timestamp() {
        let guard = AccessLogGuard::<String>::new();
        let mut bus = Bus::new();
        bus.insert(AccessLogRequest {
            method: "GET".into(),
            path: "/".into(),
        });

        let _ = guard.run("ok".into(), &(), &mut bus).await;

        let logged = bus.read::<AccessLogEntry>().unwrap();
        // Sanity floor: 1.7e12 corresponds to November 2023 when read as Unix epoch
        // milliseconds, so anything produced "now" must be larger.
        assert!(logged.timestamp_ms > 1_700_000_000_000);
    }

    /// The guard is generic over the payload type; an `i32` passes through untouched.
    #[tokio::test]
    async fn access_log_guard_works_with_integer_type() {
        let guard = AccessLogGuard::<i32>::new();
        let mut bus = Bus::new();
        bus.insert(AccessLogRequest {
            method: "PUT".into(),
            path: "/count".into(),
        });

        let outcome = guard.run(42, &(), &mut bus).await;

        assert!(matches!(outcome, Outcome::Next(42)));
    }

    /// A path that is not in the redaction list is logged verbatim.
    #[tokio::test]
    async fn access_log_guard_non_redacted_path_preserved() {
        let guard = AccessLogGuard::<String>::new().redact_paths(vec!["/auth/login".into()]);
        let mut bus = Bus::new();
        bus.insert(AccessLogRequest {
            method: "GET".into(),
            path: "/api/public".into(),
        });

        let _ = guard.run("ok".into(), &(), &mut bus).await;

        let logged = bus.read::<AccessLogEntry>().unwrap();
        assert_eq!(logged.path, "/api/public");
    }

    // ---------------------------------------------------------------------
    // CompressionGuard
    // ---------------------------------------------------------------------

    /// `gzip, deflate` negotiates gzip.
    #[tokio::test]
    async fn compression_guard_negotiates_gzip() {
        let guard = CompressionGuard::<String>::new();
        let mut bus = Bus::new();
        bus.insert(AcceptEncoding("gzip, deflate".into()));

        let outcome = guard.run("ok".into(), &(), &mut bus).await;

        assert_next!(outcome);
        let negotiated = bus.read::<CompressionConfig>().unwrap();
        assert_eq!(negotiated.encoding, CompressionEncoding::Gzip);
    }

    /// With `prefer_brotli`, brotli wins when the client offers it alongside gzip.
    #[tokio::test]
    async fn compression_guard_prefer_brotli() {
        let guard = CompressionGuard::<String>::new().prefer_brotli();
        let mut bus = Bus::new();
        bus.insert(AcceptEncoding("gzip, br, zstd".into()));

        let outcome = guard.run("ok".into(), &(), &mut bus).await;

        assert_next!(outcome);
        let negotiated = bus.read::<CompressionConfig>().unwrap();
        assert_eq!(negotiated.encoding, CompressionEncoding::Brotli);
    }

    /// A client that only offers `deflate` gets the identity encoding.
    #[tokio::test]
    async fn compression_guard_falls_back_to_identity() {
        let guard = CompressionGuard::<String>::new();
        let mut bus = Bus::new();
        bus.insert(AcceptEncoding("deflate".into()));

        let outcome = guard.run("ok".into(), &(), &mut bus).await;

        assert_next!(outcome);
        let negotiated = bus.read::<CompressionConfig>().unwrap();
        assert_eq!(negotiated.encoding, CompressionEncoding::Identity);
    }

    /// A wildcard `Accept-Encoding` resolves to gzip.
    #[tokio::test]
    async fn compression_guard_wildcard_accept() {
        let guard = CompressionGuard::<String>::new();
        let mut bus = Bus::new();
        bus.insert(AcceptEncoding("*".into()));

        let outcome = guard.run("ok".into(), &(), &mut bus).await;

        assert_next!(outcome);
        let negotiated = bus.read::<CompressionConfig>().unwrap();
        assert_eq!(negotiated.encoding, CompressionEncoding::Gzip);
    }

    /// The configured minimum body size is carried into the `CompressionConfig`.
    #[tokio::test]
    async fn compression_guard_min_body_size() {
        let guard = CompressionGuard::<String>::new().with_min_body_size(1024);
        let mut bus = Bus::new();
        bus.insert(AcceptEncoding("gzip".into()));

        let _ = guard.run("ok".into(), &(), &mut bus).await;

        let negotiated = bus.read::<CompressionConfig>().unwrap();
        assert_eq!(negotiated.min_body_size, 1024);
    }

    // ---------------------------------------------------------------------
    // RequestSizeLimitGuard
    // ---------------------------------------------------------------------

    /// A 1 KiB body is within the 2 MiB limit.
    #[tokio::test]
    async fn size_limit_allows_within_limit() {
        let guard = RequestSizeLimitGuard::<String>::max_2mb();
        let mut bus = Bus::new();
        bus.insert(ContentLength(1024));

        let outcome = guard.run("ok".into(), &(), &mut bus).await;

        assert_next!(outcome);
    }

    /// A declared length above the limit faults with a message mentioning 413.
    #[tokio::test]
    async fn size_limit_rejects_over_limit() {
        let guard = RequestSizeLimitGuard::<String>::new(1000);
        let mut bus = Bus::new();
        bus.insert(ContentLength(2000));

        let outcome = guard.run("ok".into(), &(), &mut bus).await;

        assert_fault!(outcome, "413");
    }

    /// With no `ContentLength` on the bus the guard lets the request through.
    #[tokio::test]
    async fn size_limit_passes_without_content_length() {
        let guard = RequestSizeLimitGuard::<String>::new(100);
        let mut bus = Bus::new();

        let outcome = guard.run("ok".into(), &(), &mut bus).await;

        assert_next!(outcome);
    }

    /// The named constructors map to the byte counts their names advertise.
    #[tokio::test]
    async fn size_limit_convenience_constructors() {
        let two_mb = RequestSizeLimitGuard::<()>::max_2mb();
        let ten_mb = RequestSizeLimitGuard::<()>::max_10mb();

        assert_eq!(two_mb.max_bytes(), 2 * 1024 * 1024);
        assert_eq!(ten_mb.max_bytes(), 10 * 1024 * 1024);
    }

    // ---------------------------------------------------------------------
    // RequestIdGuard
    // ---------------------------------------------------------------------

    /// When no id is present the guard writes one to the bus.
    #[tokio::test]
    async fn request_id_generates_uuid() {
        let guard = RequestIdGuard::<String>::new();
        let mut bus = Bus::new();

        let outcome = guard.run("ok".into(), &(), &mut bus).await;

        assert_next!(outcome);
        let generated = bus.read::<RequestId>().expect("request id should be in bus");
        // 36 characters is the length of a hyphenated UUID string (8-4-4-4-12).
        assert_eq!(generated.0.len(), 36);
    }

    /// An id that is already on the bus is left untouched.
    #[tokio::test]
    async fn request_id_preserves_existing() {
        let guard = RequestIdGuard::<String>::new();
        let mut bus = Bus::new();
        bus.insert(RequestId("custom-id-123".into()));

        let _ = guard.run("ok".into(), &(), &mut bus).await;

        let kept = bus.read::<RequestId>().unwrap();
        assert_eq!(kept.0, "custom-id-123");
    }

    // ---------------------------------------------------------------------
    // AuthGuard
    // ---------------------------------------------------------------------

    /// A known bearer token passes and an identity is written to the bus.
    #[tokio::test]
    async fn auth_bearer_success() {
        let guard = AuthGuard::<String>::bearer(vec!["secret-token".into()]);
        let mut bus = Bus::new();
        bus.insert(AuthorizationHeader("Bearer secret-token".into()));

        let outcome = guard.run("ok".into(), &(), &mut bus).await;

        assert_next!(outcome);
        let who = bus.read::<IamIdentity>().expect("identity should be in bus");
        assert_eq!(who.subject, "bearer-authenticated");
    }

    /// An unknown bearer token faults with a message mentioning 401.
    #[tokio::test]
    async fn auth_bearer_invalid_token() {
        let guard = AuthGuard::<String>::bearer(vec!["secret-token".into()]);
        let mut bus = Bus::new();
        bus.insert(AuthorizationHeader("Bearer wrong-token".into()));

        let outcome = guard.run("ok".into(), &(), &mut bus).await;

        assert_fault!(outcome, "401");
    }

    /// A request without any authorization header faults with a message mentioning 401.
    #[tokio::test]
    async fn auth_bearer_missing_header() {
        let guard = AuthGuard::<String>::bearer(vec!["token".into()]);
        let mut bus = Bus::new();

        let outcome = guard.run("ok".into(), &(), &mut bus).await;

        assert_fault!(outcome, "401");
    }

    /// A known API key passes. The key is supplied through `AuthorizationHeader`.
    #[tokio::test]
    async fn auth_apikey_success() {
        let guard = AuthGuard::<String>::api_key("X-Api-Key", vec!["my-api-key".into()]);
        let mut bus = Bus::new();
        bus.insert(AuthorizationHeader("my-api-key".into()));

        let outcome = guard.run("ok".into(), &(), &mut bus).await;

        assert_next!(outcome);
    }

    /// An unknown API key faults with a message mentioning 401.
    #[tokio::test]
    async fn auth_apikey_invalid() {
        let guard = AuthGuard::<String>::api_key("X-Api-Key", vec!["valid-key".into()]);
        let mut bus = Bus::new();
        bus.insert(AuthorizationHeader("invalid-key".into()));

        let outcome = guard.run("ok".into(), &(), &mut bus).await;

        assert_fault!(outcome, "401");
    }

    /// A custom validator decides the identity; its roles are visible on the bus afterwards.
    #[tokio::test]
    async fn auth_custom_validator() {
        let guard = AuthGuard::<String>::custom(|token| {
            if token == "Bearer magic" {
                Ok(IamIdentity::new("custom-user").with_role("admin"))
            } else {
                Err("bad token".into())
            }
        });
        let mut bus = Bus::new();
        bus.insert(AuthorizationHeader("Bearer magic".into()));

        let outcome = guard.run("ok".into(), &(), &mut bus).await;

        assert_next!(outcome);
        let who = bus.read::<IamIdentity>().unwrap();
        assert!(who.has_role("admin"));
    }

    /// A valid token that fails the attached role policy faults with a message mentioning 403.
    #[tokio::test]
    async fn auth_policy_enforcement_role() {
        let guard = AuthGuard::<String>::bearer(vec!["token".into()])
            .with_policy(IamPolicy::RequireRole("admin".into()));
        let mut bus = Bus::new();
        bus.insert(AuthorizationHeader("Bearer token".into()));

        let outcome = guard.run("ok".into(), &(), &mut bus).await;

        assert_fault!(outcome, "403");
    }

    /// A candidate token whose length differs from the stored one is still rejected cleanly.
    #[tokio::test]
    async fn auth_timing_safe_comparison() {
        let guard = AuthGuard::<String>::bearer(vec!["short".into()]);
        let mut bus = Bus::new();
        bus.insert(AuthorizationHeader("Bearer a-very-long-different-token".into()));

        let outcome = guard.run("ok".into(), &(), &mut bus).await;

        assert_fault!(outcome);
    }

    // ---------------------------------------------------------------------
    // ContentTypeGuard
    // ---------------------------------------------------------------------

    /// `application/json` satisfies the JSON guard.
    #[tokio::test]
    async fn content_type_json_match() {
        let guard = ContentTypeGuard::<String>::json();
        let mut bus = Bus::new();
        bus.insert(RequestContentType("application/json".into()));

        let outcome = guard.run("ok".into(), &(), &mut bus).await;

        assert_next!(outcome);
    }

    /// A `charset` parameter after the media type does not break the match.
    #[tokio::test]
    async fn content_type_json_with_charset() {
        let guard = ContentTypeGuard::<String>::json();
        let mut bus = Bus::new();
        bus.insert(RequestContentType("application/json; charset=utf-8".into()));

        let outcome = guard.run("ok".into(), &(), &mut bus).await;

        assert_next!(outcome);
    }

    /// A different media type faults with a message mentioning 415.
    #[tokio::test]
    async fn content_type_mismatch() {
        let guard = ContentTypeGuard::<String>::json();
        let mut bus = Bus::new();
        bus.insert(RequestContentType("text/plain".into()));

        let outcome = guard.run("ok".into(), &(), &mut bus).await;

        assert_fault!(outcome, "415");
    }

    /// With no `RequestContentType` on the bus the guard lets the request through.
    #[tokio::test]
    async fn content_type_no_header_allows() {
        let guard = ContentTypeGuard::<String>::json();
        let mut bus = Bus::new();

        let outcome = guard.run("ok".into(), &(), &mut bus).await;

        assert_next!(outcome);
    }

    /// The form guard accepts URL-encoded form bodies.
    #[tokio::test]
    async fn content_type_form() {
        let guard = ContentTypeGuard::<String>::form();
        let mut bus = Bus::new();
        bus.insert(RequestContentType("application/x-www-form-urlencoded".into()));

        let outcome = guard.run("ok".into(), &(), &mut bus).await;

        assert_next!(outcome);
    }

    /// A guard built with several accepted types matches any one of them.
    #[tokio::test]
    async fn content_type_accept_multiple() {
        let guard = ContentTypeGuard::<String>::accept(["application/json", "text/xml"]);
        let mut bus = Bus::new();
        bus.insert(RequestContentType("text/xml".into()));

        let outcome = guard.run("ok".into(), &(), &mut bus).await;

        assert_next!(outcome);
    }

    // ---------------------------------------------------------------------
    // TimeoutGuard
    // ---------------------------------------------------------------------

    /// The guard writes a not-yet-expired deadline with roughly the full budget remaining.
    #[tokio::test]
    async fn timeout_sets_deadline() {
        let guard = TimeoutGuard::<String>::secs_30();
        let mut bus = Bus::new();

        let outcome = guard.run("ok".into(), &(), &mut bus).await;

        assert_next!(outcome);
        let deadline = bus.read::<TimeoutDeadline>().expect("deadline should be in bus");
        assert!(!deadline.is_expired());
        // 29 rather than 30: some time has already elapsed since the guard ran.
        assert!(deadline.remaining().as_secs() >= 29);
    }

    /// The named constructors map to the durations their names advertise.
    #[tokio::test]
    async fn timeout_convenience_constructors() {
        let five = TimeoutGuard::<()>::secs_5();
        let thirty = TimeoutGuard::<()>::secs_30();
        let sixty = TimeoutGuard::<()>::secs_60();

        assert_eq!(five.timeout().as_secs(), 5);
        assert_eq!(thirty.timeout().as_secs(), 30);
        assert_eq!(sixty.timeout().as_secs(), 60);
    }

    // ---------------------------------------------------------------------
    // IdempotencyGuard
    // ---------------------------------------------------------------------

    /// Without an `IdempotencyKey` the guard passes and records no cached response.
    #[tokio::test]
    async fn idempotency_no_key_passes_through() {
        let guard = IdempotencyGuard::<String>::ttl_5min();
        let mut bus = Bus::new();

        let outcome = guard.run("ok".into(), &(), &mut bus).await;

        assert_next!(outcome);
        assert!(bus.read::<IdempotencyCachedResponse>().is_none());
    }

    /// A key that is not in the cache passes and records no cached response.
    #[tokio::test]
    async fn idempotency_cache_miss() {
        let guard = IdempotencyGuard::<String>::ttl_5min();
        let mut bus = Bus::new();
        bus.insert(IdempotencyKey("key-1".into()));

        let outcome = guard.run("ok".into(), &(), &mut bus).await;

        assert_next!(outcome);
        assert!(bus.read::<IdempotencyCachedResponse>().is_none());
    }

    /// A key that is in the cache still yields `Next`, with the cached body put on the bus.
    #[tokio::test]
    async fn idempotency_cache_hit() {
        let guard = IdempotencyGuard::<String>::ttl_5min();
        guard.cache().insert("key-1".into(), b"cached-body".to_vec());

        let mut bus = Bus::new();
        bus.insert(IdempotencyKey("key-1".into()));

        let outcome = guard.run("ok".into(), &(), &mut bus).await;

        assert_next!(outcome);
        let replay = bus.read::<IdempotencyCachedResponse>().expect("cached response");
        assert_eq!(replay.body, b"cached-body");
    }

    /// Cloning a guard shares its cache: an insert through one clone is visible via the other.
    #[tokio::test]
    async fn idempotency_cache_shared_across_clones() {
        let original = IdempotencyGuard::<String>::ttl_5min();
        let cloned = original.clone();

        original.cache().insert("shared-key".into(), b"data".to_vec());

        assert!(cloned.cache().get("shared-key").is_some());
    }

    /// An entry older than the TTL is treated as a miss.
    #[tokio::test]
    async fn idempotency_expired_entry_treated_as_miss() {
        let guard = IdempotencyGuard::<String>::new(Duration::from_millis(1));
        guard.cache().insert("key-1".into(), b"old".to_vec());
        // Wait comfortably longer than the 1 ms TTL so the entry has expired.
        tokio::time::sleep(Duration::from_millis(5)).await;

        let mut bus = Bus::new();
        bus.insert(IdempotencyKey("key-1".into()));

        let outcome = guard.run("ok".into(), &(), &mut bus).await;

        assert_next!(outcome);
        assert!(bus.read::<IdempotencyCachedResponse>().is_none());
    }

    // ---------------------------------------------------------------------
    // CorsGuard: additional cases
    // ---------------------------------------------------------------------

    /// With an explicit allow-list, the matching origin is echoed back in the headers.
    #[tokio::test]
    async fn cors_guard_specific_origin_reflected() {
        let guard = CorsGuard::<String>::new(CorsConfig {
            allowed_origins: vec!["https://app.example.com".into()],
            ..CorsConfig::default()
        });
        let mut bus = Bus::new();
        bus.insert(RequestOrigin("https://app.example.com".into()));

        let outcome = guard.run("ok".into(), &(), &mut bus).await;

        assert_next!(outcome);
        let cors = bus.read::<CorsHeaders>().unwrap();
        assert_eq!(cors.access_control_allow_origin, "https://app.example.com");
    }

    /// With no `RequestOrigin` on the bus the guard passes even with a restrictive allow-list.
    #[tokio::test]
    async fn cors_guard_no_origin_passes() {
        let guard = CorsGuard::<String>::new(CorsConfig {
            allowed_origins: vec!["https://trusted.com".into()],
            ..CorsConfig::default()
        });
        let mut bus = Bus::new();

        let outcome = guard.run("ok".into(), &(), &mut bus).await;

        assert_next!(outcome);
    }

    // ---------------------------------------------------------------------
    // SecurityHeadersGuard: additional cases
    // ---------------------------------------------------------------------

    /// A CSP set via `with_csp` is carried through to the injected headers verbatim.
    #[tokio::test]
    async fn security_headers_custom_csp() {
        let csp_policy =
            SecurityPolicy::default().with_csp("default-src 'self'; script-src 'none'");
        let guard = SecurityHeadersGuard::<String>::new(csp_policy);
        let mut bus = Bus::new();

        let _ = guard.run("ok".into(), &(), &mut bus).await;

        let injected = bus.read::<SecurityHeaders>().unwrap();
        assert_eq!(
            injected.0.content_security_policy.as_deref(),
            Some("default-src 'self'; script-src 'none'")
        );
    }

    /// The default policy has no CSP and a strict cross-origin referrer policy.
    #[tokio::test]
    async fn security_headers_default_no_csp() {
        let guard = SecurityHeadersGuard::<String>::new(SecurityPolicy::default());
        let mut bus = Bus::new();

        let _ = guard.run("ok".into(), &(), &mut bus).await;

        let injected = bus.read::<SecurityHeaders>().unwrap();
        assert!(injected.0.content_security_policy.is_none());
        assert_eq!(injected.0.referrer_policy, "strict-origin-when-cross-origin");
    }

    // ---------------------------------------------------------------------
    // TimeoutGuard: additional cases
    // ---------------------------------------------------------------------

    /// A 100 ms deadline is live right after the guard runs and expired 150 ms later.
    #[tokio::test]
    async fn timeout_custom_duration() {
        let guard = TimeoutGuard::<String>::new(Duration::from_millis(100));
        let mut bus = Bus::new();

        let _ = guard.run("ok".into(), &(), &mut bus).await;

        let deadline = bus.read::<TimeoutDeadline>().unwrap();
        assert!(!deadline.is_expired());
        // Sleep past the 100 ms budget, then the same deadline must report expiry.
        tokio::time::sleep(Duration::from_millis(150)).await;
        assert!(deadline.is_expired());
    }

    // ---------------------------------------------------------------------
    // RateLimitGuard: bucket TTL
    // ---------------------------------------------------------------------

    /// After the bucket TTL has elapsed, a client that had exhausted its budget is served again.
    #[tokio::test]
    async fn rate_limit_bucket_ttl_prunes_stale_buckets() {
        // Part 1: a smoke run with a generous budget. Nothing is asserted here; it only
        // exercises a request after another client's bucket has outlived the 50 ms TTL.
        let roomy = RateLimitGuard::<String>::new(100, 60000)
            .with_bucket_ttl(Duration::from_millis(50));

        let mut stale_bus = Bus::new();
        stale_bus.insert(ClientIdentity("stale-user".into()));
        let _ = roomy.run("ok".into(), &(), &mut stale_bus).await;

        tokio::time::sleep(Duration::from_millis(80)).await;

        let mut fresh_bus = Bus::new();
        fresh_bus.insert(ClientIdentity("fresh-user".into()));
        let _ = roomy.run("ok".into(), &(), &mut fresh_bus).await;

        // Part 2: the observable effect. `user-a` burns a budget of 2 and is rejected.
        let tight = RateLimitGuard::<String>::new(2, 60000)
            .with_bucket_ttl(Duration::from_millis(50));

        let mut user_a_bus = Bus::new();
        user_a_bus.insert(ClientIdentity("user-a".into()));
        for payload in ["1", "2"] {
            let _ = tight.run(payload.into(), &(), &mut user_a_bus).await;
        }
        let rejected = tight.run("3".into(), &(), &mut user_a_bus).await;
        assert_fault!(rejected);

        // Let `user-a`'s bucket outlive the 50 ms TTL.
        tokio::time::sleep(Duration::from_millis(80)).await;

        // A request from a different client comes first (presumably this is what gives
        // the guard the chance to prune; the pruning trigger itself is not visible here).
        let mut user_b_bus = Bus::new();
        user_b_bus.insert(ClientIdentity("user-b".into()));
        let _ = tight.run("ok".into(), &(), &mut user_b_bus).await;

        // `user-a` is accepted again although the 60 s rate window has not passed.
        let mut retry_bus = Bus::new();
        retry_bus.insert(ClientIdentity("user-a".into()));
        let accepted = tight.run("retry".into(), &(), &mut retry_bus).await;
        assert_next!(accepted);
    }

    /// Without `with_bucket_ttl` the TTL is 0 and an exhausted bucket stays exhausted.
    #[tokio::test]
    async fn rate_limit_bucket_ttl_zero_disables_pruning() {
        let guard = RateLimitGuard::<String>::new(2, 60000);
        assert_eq!(guard.bucket_ttl_ms(), 0);

        let mut bus = Bus::new();
        bus.insert(ClientIdentity("user".into()));
        for payload in ["1", "2"] {
            let _ = guard.run(payload.into(), &(), &mut bus).await;
        }

        tokio::time::sleep(Duration::from_millis(30)).await;

        let outcome = guard.run("3".into(), &(), &mut bus).await;
        assert_fault!(outcome);
    }

    /// `with_bucket_ttl` stores the duration, reported back in milliseconds.
    #[tokio::test]
    async fn rate_limit_with_bucket_ttl_builder() {
        let guard =
            RateLimitGuard::<String>::new(10, 1000).with_bucket_ttl(Duration::from_secs(300));

        assert_eq!(guard.bucket_ttl_ms(), 300_000);
    }

    // ---------------------------------------------------------------------
    // TrustedProxies
    // ---------------------------------------------------------------------

    /// When the directly connected peer is not a trusted proxy, the forwarded list is ignored.
    #[test]
    fn trusted_proxies_ignores_xff_from_untrusted_direct() {
        let proxies = TrustedProxies::new(["10.0.0.1", "10.0.0.2"]);

        let resolved = proxies.extract("1.2.3.4, 10.0.0.1", "192.168.1.100");

        assert_eq!(resolved, "192.168.1.100");
    }

    /// With a trusted direct peer, the right-most forwarded entry that is not trusted wins.
    #[test]
    fn trusted_proxies_extracts_rightmost_non_trusted() {
        let proxies = TrustedProxies::new(["10.0.0.1", "10.0.0.2"]);

        let resolved = proxies.extract("203.0.113.5, 10.0.0.2", "10.0.0.1");

        assert_eq!(resolved, "203.0.113.5");
    }

    /// Several trusted hops in a row are all skipped.
    #[test]
    fn trusted_proxies_multi_hop_chain() {
        let proxies = TrustedProxies::new(["10.0.0.1", "10.0.0.2", "10.0.0.3"]);

        let resolved = proxies.extract("8.8.8.8, 10.0.0.3, 10.0.0.2", "10.0.0.1");

        assert_eq!(resolved, "8.8.8.8");
    }

    /// If every forwarded entry is itself trusted, the direct peer address is returned.
    #[test]
    fn trusted_proxies_all_xff_trusted_falls_back_to_direct() {
        let proxies = TrustedProxies::new(["10.0.0.1", "10.0.0.2"]);

        let resolved = proxies.extract("10.0.0.2, 10.0.0.1", "10.0.0.1");

        assert_eq!(resolved, "10.0.0.1");
    }

    /// An empty forwarded list yields the direct peer address.
    #[test]
    fn trusted_proxies_empty_xff() {
        let proxies = TrustedProxies::new(["10.0.0.1"]);

        let resolved = proxies.extract("", "10.0.0.1");

        assert_eq!(resolved, "10.0.0.1");
    }

    /// `is_trusted` is true exactly for the addresses the set was built from.
    #[test]
    fn trusted_proxies_is_trusted() {
        let proxies = TrustedProxies::new(["10.0.0.1", "10.0.0.2"]);

        for trusted in ["10.0.0.1", "10.0.0.2"] {
            assert!(proxies.is_trusted(trusted));
        }
        assert!(!proxies.is_trusted("192.168.1.1"));
    }
}