1use async_trait::async_trait;
23use ranvier_core::bus::Bus;
24use ranvier_core::outcome::Outcome;
25use ranvier_core::transition::Transition;
26use std::sync::Arc;
27
28pub type BusInjectorFn = Arc<dyn Fn(&http::request::Parts, &mut Bus) + Send + Sync + 'static>;
30
31pub type ResponseExtractorFn =
33 Arc<dyn Fn(&Bus, &mut http::HeaderMap) + Send + Sync + 'static>;
34
35pub type ResponseBodyTransformFn =
40 Arc<dyn Fn(&Bus, bytes::Bytes) -> bytes::Bytes + Send + Sync + 'static>;
41
42#[derive(Debug, Clone)]
47pub struct GuardRejection {
48 pub status: http::StatusCode,
49 pub message: String,
50}
51
52impl GuardRejection {
53 pub fn new(status: http::StatusCode, message: impl Into<String>) -> Self {
54 Self {
55 status,
56 message: message.into(),
57 }
58 }
59
60 pub fn forbidden(message: impl Into<String>) -> Self {
61 Self::new(http::StatusCode::FORBIDDEN, message)
62 }
63
64 pub fn unauthorized(message: impl Into<String>) -> Self {
65 Self::new(http::StatusCode::UNAUTHORIZED, message)
66 }
67
68 pub fn too_many_requests(message: impl Into<String>) -> Self {
69 Self::new(http::StatusCode::TOO_MANY_REQUESTS, message)
70 }
71
72 pub fn payload_too_large(message: impl Into<String>) -> Self {
73 Self::new(http::StatusCode::PAYLOAD_TOO_LARGE, message)
74 }
75}
76
77#[async_trait]
83pub trait GuardExec: Send + Sync {
84 async fn exec_guard(&self, bus: &mut Bus) -> Result<(), GuardRejection>;
88}
89
90struct TransitionGuardExec<G> {
96 guard: G,
97 default_status: http::StatusCode,
98}
99
100fn parse_status_prefix(msg: &str, default: http::StatusCode) -> (http::StatusCode, String) {
105 if msg.len() >= 4 && msg.as_bytes()[3] == b' ' {
106 if let Ok(code) = msg[..3].parse::<u16>() {
107 if let Ok(status) = http::StatusCode::from_u16(code) {
108 return (status, msg[4..].to_string());
109 }
110 }
111 }
112 (default, msg.to_string())
113}
114
115#[async_trait]
116impl<G> GuardExec for TransitionGuardExec<G>
117where
118 G: Transition<(), (), Error = String, Resources = ()> + Send + Sync + 'static,
119{
120 async fn exec_guard(&self, bus: &mut Bus) -> Result<(), GuardRejection> {
121 match self.guard.run((), &(), bus).await {
122 Outcome::Next(_) => Ok(()),
123 Outcome::Fault(e) => {
124 let (status, message) = parse_status_prefix(&e, self.default_status);
125 Err(GuardRejection { status, message })
126 }
127 _ => Ok(()),
128 }
129 }
130}
131
132pub struct RegisteredGuard {
134 pub bus_injectors: Vec<BusInjectorFn>,
136 pub response_extractor: Option<ResponseExtractorFn>,
138 pub response_body_transform: Option<ResponseBodyTransformFn>,
140 pub exec: Arc<dyn GuardExec>,
142 pub handles_preflight: bool,
144 pub preflight_config: Option<PreflightConfig>,
146}
147
/// CORS settings for answering preflight requests, pre-rendered as header-ready strings.
///
/// Plain data, so it derives `Debug` (as public types should) in addition to `Clone`.
#[derive(Debug, Clone)]
pub struct PreflightConfig {
    /// Origins allowed to make cross-origin requests.
    pub allowed_origins: Vec<String>,
    /// Allowed methods joined with `", "` (e.g. `"GET, POST"`).
    pub allowed_methods: String,
    /// Allowed request headers joined with `", "`.
    pub allowed_headers: String,
    /// Preflight cache lifetime in seconds, already formatted as a string.
    pub max_age: String,
    /// Whether credentialed cross-origin requests are allowed.
    pub allow_credentials: bool,
}
157
158pub trait GuardIntegration: Send + Sync + 'static {
176 fn register(self) -> RegisteredGuard;
178}
179
180impl<T> GuardIntegration for ranvier_guard::CorsGuard<T>
185where
186 T: Send + Sync + 'static,
187{
188 fn register(self) -> RegisteredGuard {
189 let config = self.cors_config().clone();
190 let preflight = PreflightConfig {
191 allowed_origins: config.allowed_origins.clone(),
192 allowed_methods: config.allowed_methods.join(", "),
193 allowed_headers: config.allowed_headers.join(", "),
194 max_age: config.max_age_seconds.to_string(),
195 allow_credentials: config.allow_credentials,
196 };
197
198 let exec_guard = ranvier_guard::CorsGuard::<()>::new(config);
200
201 RegisteredGuard {
202 bus_injectors: vec![Arc::new(|parts: &http::request::Parts, bus: &mut Bus| {
203 if let Some(origin) = parts.headers.get("origin") {
204 if let Ok(origin_str) = origin.to_str() {
205 bus.insert(ranvier_guard::RequestOrigin(origin_str.to_string()));
206 }
207 }
208 })],
209 response_extractor: Some(Arc::new(|bus: &Bus, headers: &mut http::HeaderMap| {
210 if let Some(cors) = bus.read::<ranvier_guard::CorsHeaders>() {
211 if let Ok(v) = cors.access_control_allow_origin.parse() {
212 headers.insert("access-control-allow-origin", v);
213 }
214 if let Ok(v) = cors.access_control_allow_methods.parse() {
215 headers.insert("access-control-allow-methods", v);
216 }
217 if let Ok(v) = cors.access_control_allow_headers.parse() {
218 headers.insert("access-control-allow-headers", v);
219 }
220 if let Ok(v) = cors.access_control_max_age.parse() {
221 headers.insert("access-control-max-age", v);
222 }
223 }
224 })),
225 response_body_transform: None,
226 exec: Arc::new(TransitionGuardExec {
227 guard: exec_guard,
228 default_status: http::StatusCode::FORBIDDEN,
229 }),
230 handles_preflight: true,
231 preflight_config: Some(preflight),
232 }
233 }
234}
235
236impl<T> GuardIntegration for ranvier_guard::RateLimitGuard<T>
237where
238 T: Send + Sync + 'static,
239{
240 fn register(self) -> RegisteredGuard {
241 let exec_guard = ranvier_guard::RateLimitGuard::<()>::new(
242 self.max_requests(),
243 self.window_ms(),
244 );
245
246 RegisteredGuard {
247 bus_injectors: vec![Arc::new(|parts: &http::request::Parts, bus: &mut Bus| {
248 let identity = parts
250 .headers
251 .get("x-forwarded-for")
252 .and_then(|v| v.to_str().ok())
253 .map(|s| s.split(',').next().unwrap_or("").trim().to_string())
254 .unwrap_or_else(|| "unknown".to_string());
255 bus.insert(ranvier_guard::ClientIdentity(identity));
256 })],
257 response_extractor: None,
258 response_body_transform: None,
259 exec: Arc::new(TransitionGuardExec {
260 guard: exec_guard,
261 default_status: http::StatusCode::TOO_MANY_REQUESTS,
262 }),
263 handles_preflight: false,
264 preflight_config: None,
265 }
266 }
267}
268
269impl<T> GuardIntegration for ranvier_guard::SecurityHeadersGuard<T>
270where
271 T: Send + Sync + 'static,
272{
273 fn register(self) -> RegisteredGuard {
274 let policy = self.policy().clone();
275 let exec_guard = ranvier_guard::SecurityHeadersGuard::<()>::new(policy);
276
277 RegisteredGuard {
278 bus_injectors: vec![],
279 response_extractor: Some(Arc::new(|bus: &Bus, headers: &mut http::HeaderMap| {
280 if let Some(sec) = bus.read::<ranvier_guard::SecurityHeaders>() {
281 if let Ok(v) = sec.0.x_frame_options.parse() {
282 headers.insert("x-frame-options", v);
283 }
284 if let Ok(v) = sec.0.x_content_type_options.parse() {
285 headers.insert("x-content-type-options", v);
286 }
287 if let Ok(v) = sec.0.strict_transport_security.parse() {
288 headers.insert("strict-transport-security", v);
289 }
290 if let Some(ref csp) = sec.0.content_security_policy {
291 if let Ok(v) = csp.parse() {
292 headers.insert("content-security-policy", v);
293 }
294 }
295 if let Ok(v) = sec.0.x_xss_protection.parse() {
296 headers.insert("x-xss-protection", v);
297 }
298 if let Ok(v) = sec.0.referrer_policy.parse() {
299 headers.insert("referrer-policy", v);
300 }
301 }
302 })),
303 response_body_transform: None,
304 exec: Arc::new(TransitionGuardExec {
305 guard: exec_guard,
306 default_status: http::StatusCode::INTERNAL_SERVER_ERROR,
307 }),
308 handles_preflight: false,
309 preflight_config: None,
310 }
311 }
312}
313
314impl<T> GuardIntegration for ranvier_guard::IpFilterGuard<T>
315where
316 T: Send + Sync + 'static,
317{
318 fn register(self) -> RegisteredGuard {
319 let exec_guard = self.clone_as_unit();
320
321 RegisteredGuard {
322 bus_injectors: vec![Arc::new(|parts: &http::request::Parts, bus: &mut Bus| {
323 let ip = parts
324 .headers
325 .get("x-forwarded-for")
326 .and_then(|v| v.to_str().ok())
327 .map(|s| s.split(',').next().unwrap_or("").trim().to_string())
328 .unwrap_or_else(|| "unknown".to_string());
329 bus.insert(ranvier_guard::ClientIp(ip));
330 })],
331 response_extractor: None,
332 response_body_transform: None,
333 exec: Arc::new(TransitionGuardExec {
334 guard: exec_guard,
335 default_status: http::StatusCode::FORBIDDEN,
336 }),
337 handles_preflight: false,
338 preflight_config: None,
339 }
340 }
341}
342
343impl<T> GuardIntegration for ranvier_guard::AccessLogGuard<T>
344where
345 T: Send + Sync + 'static,
346{
347 fn register(self) -> RegisteredGuard {
348 let exec_guard = self.clone_as_unit();
349
350 RegisteredGuard {
351 bus_injectors: vec![Arc::new(|parts: &http::request::Parts, bus: &mut Bus| {
352 bus.insert(ranvier_guard::AccessLogRequest {
353 method: parts.method.to_string(),
354 path: parts.uri.path().to_string(),
355 });
356 })],
357 response_extractor: None,
358 response_body_transform: None,
359 exec: Arc::new(TransitionGuardExec {
360 guard: exec_guard,
361 default_status: http::StatusCode::INTERNAL_SERVER_ERROR,
362 }),
363 handles_preflight: false,
364 preflight_config: None,
365 }
366 }
367}
368
369impl<T> GuardIntegration for ranvier_guard::CompressionGuard<T>
374where
375 T: Send + Sync + 'static,
376{
377 fn register(self) -> RegisteredGuard {
378 let min_body_size = self.min_body_size();
379 let preferred = self.preferred_encodings().to_vec();
380
381 let mut exec_guard = ranvier_guard::CompressionGuard::<()>::new()
382 .with_min_body_size(min_body_size);
383 if preferred.first() == Some(&ranvier_guard::CompressionEncoding::Brotli) {
384 exec_guard = exec_guard.prefer_brotli();
385 }
386
387 RegisteredGuard {
388 bus_injectors: vec![Arc::new(|parts: &http::request::Parts, bus: &mut Bus| {
389 if let Some(accept) = parts.headers.get("accept-encoding") {
390 if let Ok(s) = accept.to_str() {
391 bus.insert(ranvier_guard::AcceptEncoding(s.to_string()));
392 }
393 }
394 })],
395 response_extractor: Some(Arc::new(|bus: &Bus, headers: &mut http::HeaderMap| {
396 if let Some(config) = bus.read::<ranvier_guard::CompressionConfig>() {
397 if config.encoding != ranvier_guard::CompressionEncoding::Identity {
398 if let Ok(v) = config.encoding.as_str().parse() {
399 headers.insert("content-encoding", v);
400 }
401 }
402 if let Ok(v) = "accept-encoding".parse() {
403 headers.insert("vary", v);
404 }
405 }
406 })),
407 response_body_transform: Some(Arc::new(move |bus: &Bus, body: bytes::Bytes| {
408 let Some(config) = bus.read::<ranvier_guard::CompressionConfig>() else {
409 return body;
410 };
411 if body.len() < config.min_body_size {
412 return body;
413 }
414 match config.encoding {
415 ranvier_guard::CompressionEncoding::Gzip => {
416 use flate2::write::GzEncoder;
417 use flate2::Compression;
418 use std::io::Write;
419 let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
420 if encoder.write_all(&body).is_err() {
421 return body;
422 }
423 match encoder.finish() {
424 Ok(compressed) => bytes::Bytes::from(compressed),
425 Err(_) => body,
426 }
427 }
428 ranvier_guard::CompressionEncoding::Identity => body,
429 _ => body,
431 }
432 })),
433 exec: Arc::new(TransitionGuardExec {
434 guard: exec_guard,
435 default_status: http::StatusCode::INTERNAL_SERVER_ERROR,
436 }),
437 handles_preflight: false,
438 preflight_config: None,
439 }
440 }
441}
442
443impl<T> GuardIntegration for ranvier_guard::RequestSizeLimitGuard<T>
444where
445 T: Send + Sync + 'static,
446{
447 fn register(self) -> RegisteredGuard {
448 let exec_guard = ranvier_guard::RequestSizeLimitGuard::<()>::new(self.max_bytes());
449
450 RegisteredGuard {
451 bus_injectors: vec![Arc::new(|parts: &http::request::Parts, bus: &mut Bus| {
452 if let Some(len) = parts.headers.get("content-length") {
453 if let Ok(s) = len.to_str() {
454 if let Ok(n) = s.parse::<u64>() {
455 bus.insert(ranvier_guard::ContentLength(n));
456 }
457 }
458 }
459 })],
460 response_extractor: None,
461 response_body_transform: None,
462 exec: Arc::new(TransitionGuardExec {
463 guard: exec_guard,
464 default_status: http::StatusCode::PAYLOAD_TOO_LARGE,
465 }),
466 handles_preflight: false,
467 preflight_config: None,
468 }
469 }
470}
471
472impl<T> GuardIntegration for ranvier_guard::RequestIdGuard<T>
473where
474 T: Send + Sync + 'static,
475{
476 fn register(self) -> RegisteredGuard {
477 let exec_guard = ranvier_guard::RequestIdGuard::<()>::new();
478
479 RegisteredGuard {
480 bus_injectors: vec![Arc::new(|parts: &http::request::Parts, bus: &mut Bus| {
481 if let Some(rid) = parts.headers.get("x-request-id") {
483 if let Ok(s) = rid.to_str() {
484 bus.insert(ranvier_guard::RequestId(s.to_string()));
485 }
486 }
487 })],
489 response_extractor: Some(Arc::new(|bus: &Bus, headers: &mut http::HeaderMap| {
490 if let Some(rid) = bus.read::<ranvier_guard::RequestId>() {
491 if let Ok(v) = rid.0.parse() {
492 headers.insert("x-request-id", v);
493 }
494 }
495 })),
496 response_body_transform: None,
497 exec: Arc::new(TransitionGuardExec {
498 guard: exec_guard,
499 default_status: http::StatusCode::INTERNAL_SERVER_ERROR,
500 }),
501 handles_preflight: false,
502 preflight_config: None,
503 }
504 }
505}
506
507impl<T> GuardIntegration for ranvier_guard::AuthGuard<T>
508where
509 T: Send + Sync + 'static,
510{
511 fn register(self) -> RegisteredGuard {
512 let header_name: &'static str = match self.strategy() {
514 ranvier_guard::AuthStrategy::ApiKey { header_name, .. } => {
515 Box::leak(header_name.clone().into_boxed_str())
518 }
519 _ => "authorization",
520 };
521
522 let exec_guard = ranvier_guard::AuthGuard::<()>::new(self.strategy().clone())
523 .with_policy(self.iam_policy().clone());
524
525 RegisteredGuard {
526 bus_injectors: vec![Arc::new(move |parts: &http::request::Parts, bus: &mut Bus| {
527 if let Some(value) = parts.headers.get(header_name) {
528 if let Ok(s) = value.to_str() {
529 bus.insert(ranvier_guard::AuthorizationHeader(s.to_string()));
530 }
531 }
532 })],
533 response_extractor: Some(Arc::new(|bus: &Bus, headers: &mut http::HeaderMap| {
534 if bus.read::<ranvier_core::iam::IamIdentity>().is_none() {
536 if let Ok(v) = "Bearer".parse() {
537 headers.insert("www-authenticate", v);
538 }
539 }
540 })),
541 response_body_transform: None,
542 exec: Arc::new(TransitionGuardExec {
543 guard: exec_guard,
544 default_status: http::StatusCode::UNAUTHORIZED,
545 }),
546 handles_preflight: false,
547 preflight_config: None,
548 }
549 }
550}
551
552impl<T> GuardIntegration for ranvier_guard::ContentTypeGuard<T>
557where
558 T: Send + Sync + 'static,
559{
560 fn register(self) -> RegisteredGuard {
561 let allowed_types = self.allowed_types().to_vec();
562 let exec_guard = ranvier_guard::ContentTypeGuard::<()>::new(allowed_types);
563
564 RegisteredGuard {
565 bus_injectors: vec![Arc::new(|parts: &http::request::Parts, bus: &mut Bus| {
566 if let Some(ct) = parts.headers.get("content-type") {
567 if let Ok(s) = ct.to_str() {
568 bus.insert(ranvier_guard::RequestContentType(s.to_string()));
569 }
570 }
571 })],
572 response_extractor: None,
573 response_body_transform: None,
574 exec: Arc::new(TransitionGuardExec {
575 guard: exec_guard,
576 default_status: http::StatusCode::UNSUPPORTED_MEDIA_TYPE,
577 }),
578 handles_preflight: false,
579 preflight_config: None,
580 }
581 }
582}
583
584impl<T> GuardIntegration for ranvier_guard::TimeoutGuard<T>
585where
586 T: Send + Sync + 'static,
587{
588 fn register(self) -> RegisteredGuard {
589 let exec_guard = ranvier_guard::TimeoutGuard::<()>::new(self.timeout());
590
591 RegisteredGuard {
592 bus_injectors: vec![],
593 response_extractor: None,
594 response_body_transform: None,
595 exec: Arc::new(TransitionGuardExec {
596 guard: exec_guard,
597 default_status: http::StatusCode::REQUEST_TIMEOUT,
598 }),
599 handles_preflight: false,
600 preflight_config: None,
601 }
602 }
603}
604
605impl<T> GuardIntegration for ranvier_guard::IdempotencyGuard<T>
606where
607 T: Send + Sync + 'static,
608{
609 fn register(self) -> RegisteredGuard {
610 let cache = self.cache().clone();
611 let exec_guard = self.clone_as_unit();
612
613 RegisteredGuard {
614 bus_injectors: vec![Arc::new(|parts: &http::request::Parts, bus: &mut Bus| {
615 if let Some(key) = parts.headers.get("idempotency-key") {
616 if let Ok(s) = key.to_str() {
617 bus.insert(ranvier_guard::IdempotencyKey(s.to_string()));
618 }
619 }
620 })],
621 response_extractor: Some(Arc::new(|bus: &Bus, headers: &mut http::HeaderMap| {
622 if bus.read::<ranvier_guard::IdempotencyCachedResponse>().is_some() {
623 if let Ok(v) = "true".parse() {
624 headers.insert("idempotency-replayed", v);
625 }
626 }
627 })),
628 response_body_transform: Some(Arc::new(move |bus: &Bus, body: bytes::Bytes| {
629 if let Some(key) = bus.read::<ranvier_guard::IdempotencyKey>() {
631 if bus.read::<ranvier_guard::IdempotencyCachedResponse>().is_none() {
632 cache.insert(key.0.clone(), body.to_vec());
633 }
634 }
635 body
636 })),
637 exec: Arc::new(TransitionGuardExec {
638 guard: exec_guard,
639 default_status: http::StatusCode::INTERNAL_SERVER_ERROR,
640 }),
641 handles_preflight: false,
642 preflight_config: None,
643 }
644 }
645}
646
#[cfg(feature = "advanced")]
impl<T> GuardIntegration for ranvier_guard::DecompressionGuard<T>
where
    T: Send + Sync + 'static,
{
    /// Registers the decompression guard.
    ///
    /// The request's `Content-Encoding` (when visible ASCII) is stored as
    /// `RequestContentEncoding`. Rejections default to `400 Bad Request`.
    fn register(self) -> RegisteredGuard {
        let exec_guard = ranvier_guard::DecompressionGuard::<()>::new();

        RegisteredGuard {
            bus_injectors: vec![Arc::new(|parts: &http::request::Parts, bus: &mut Bus| {
                let encoding = parts
                    .headers
                    .get("content-encoding")
                    .and_then(|value| value.to_str().ok());
                if let Some(ce) = encoding {
                    bus.insert(ranvier_guard::RequestContentEncoding(ce.to_owned()));
                }
            })],
            response_extractor: None,
            response_body_transform: None,
            exec: Arc::new(TransitionGuardExec {
                guard: exec_guard,
                default_status: http::StatusCode::BAD_REQUEST,
            }),
            handles_preflight: false,
            preflight_config: None,
        }
    }
}
678
#[cfg(feature = "advanced")]
impl<T> GuardIntegration for ranvier_guard::ConditionalRequestGuard<T>
where
    T: Send + Sync + 'static,
{
    /// Registers conditional-request handling.
    ///
    /// `If-None-Match` and `If-Modified-Since` (when present and visible ASCII)
    /// are stored as `IfNoneMatch` / `IfModifiedSince`. `ETag` and `LastModified`
    /// values left in the bus are written to the response. A guard fault defaults
    /// to `304 Not Modified`.
    fn register(self) -> RegisteredGuard {
        let exec_guard = ranvier_guard::ConditionalRequestGuard::<()>::new();

        RegisteredGuard {
            bus_injectors: vec![Arc::new(|parts: &http::request::Parts, bus: &mut Bus| {
                let text_of = |name: &str| {
                    parts
                        .headers
                        .get(name)
                        .and_then(|value| value.to_str().ok())
                        .map(str::to_owned)
                };
                if let Some(tag) = text_of("if-none-match") {
                    bus.insert(ranvier_guard::IfNoneMatch(tag));
                }
                if let Some(since) = text_of("if-modified-since") {
                    bus.insert(ranvier_guard::IfModifiedSince(since));
                }
            })],
            response_extractor: Some(Arc::new(|bus: &Bus, headers: &mut http::HeaderMap| {
                let etag = bus
                    .read::<ranvier_guard::ETag>()
                    .and_then(|tag| tag.0.parse::<http::HeaderValue>().ok());
                if let Some(v) = etag {
                    headers.insert("etag", v);
                }
                let last_modified = bus
                    .read::<ranvier_guard::LastModified>()
                    .and_then(|lm| lm.0.parse::<http::HeaderValue>().ok());
                if let Some(v) = last_modified {
                    headers.insert("last-modified", v);
                }
            })),
            response_body_transform: None,
            exec: Arc::new(TransitionGuardExec {
                guard: exec_guard,
                default_status: http::StatusCode::NOT_MODIFIED,
            }),
            handles_preflight: false,
            preflight_config: None,
        }
    }
}
722
#[cfg(feature = "advanced")]
impl<T> GuardIntegration for ranvier_guard::RedirectGuard<T>
where
    T: Send + Sync + 'static,
{
    /// Registers path-based redirects.
    ///
    /// The request URI path is stored as `RedirectRequestPath`. A guard fault
    /// defaults to `301 Moved Permanently`, unless the fault message starts with
    /// an `"NNN "` status prefix.
    fn register(self) -> RegisteredGuard {
        let exec_guard = ranvier_guard::RedirectGuard::<()>::new(self.rules().to_vec());

        RegisteredGuard {
            bus_injectors: vec![Arc::new(|parts: &http::request::Parts, bus: &mut Bus| {
                let path = parts.uri.path().to_owned();
                bus.insert(ranvier_guard::RedirectRequestPath(path));
            })],
            // NOTE(review): no response hook sets a `Location` header here. Confirm the
            // redirect target reaches the client some other way (e.g. via the rejection
            // message); otherwise 3xx responses carry no destination.
            response_extractor: None,
            response_body_transform: None,
            exec: Arc::new(TransitionGuardExec {
                guard: exec_guard,
                default_status: http::StatusCode::MOVED_PERMANENTLY,
            }),
            handles_preflight: false,
            preflight_config: None,
        }
    }
}