1use std::{
53 cell::{Ref, RefMut},
54 rc::Rc,
55};
56
57use actix_http::{header, Extensions, Method as HttpMethod, RequestHead};
58
59use crate::{http::header::Header, service::ServiceRequest, HttpMessage as _};
60
61mod acceptable;
62mod host;
63
64pub use self::{
65 acceptable::Acceptable,
66 host::{Host, HostGuard},
67};
68
/// Structured description of what a guard checks, as reported by [`Guard::details`].
///
/// Only compiled when the `experimental-introspection` feature is enabled. The enum is
/// `#[non_exhaustive]`, so matches outside this crate need a wildcard arm.
#[cfg(feature = "experimental-introspection")]
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum GuardDetail {
    /// HTTP method names rendered as strings (this is what the method guard reports).
    HttpMethods(Vec<String>),
    /// `(header name, header value)` pairs (this is what the header guard reports).
    Headers(Vec<(String, String)>),
    /// Free-form description for guards without a more specific representation.
    Generic(String),
}
81
/// View of the request that is handed to [`Guard::check`].
///
/// Wraps a borrowed [`ServiceRequest`] and exposes the parts of it that guards may inspect.
#[derive(Debug)]
pub struct GuardContext<'a> {
    /// The request being evaluated; only reachable from inside this crate.
    pub(crate) req: &'a ServiceRequest,
}

impl<'a> GuardContext<'a> {
    /// Returns a reference to the request head.
    #[inline]
    pub fn head(&self) -> &RequestHead {
        self.req.head()
    }

    /// Returns a shared borrow of the request's [`Extensions`] container.
    ///
    /// The result is a [`std::cell::Ref`], so it presumably follows `RefCell` borrow rules
    /// (i.e. must not overlap with a live [`req_data_mut`](Self::req_data_mut) borrow) —
    /// confirm against `HttpMessage::extensions`.
    #[inline]
    pub fn req_data(&self) -> Ref<'a, Extensions> {
        self.req.extensions()
    }

    /// Returns an exclusive borrow of the request's [`Extensions`] container.
    ///
    /// Works through `&self` because the borrow is handed out as a [`std::cell::RefMut`].
    #[inline]
    pub fn req_data_mut(&self) -> RefMut<'a, Extensions> {
        self.req.extensions_mut()
    }

    /// Parses the typed header `H` from the request.
    ///
    /// Returns `None` whenever `H::parse` reports an error; the error itself is discarded.
    #[inline]
    pub fn header<H: Header>(&self) -> Option<H> {
        H::parse(self.req).ok()
    }

    /// Looks up application data of type `T` on the underlying request.
    ///
    /// Returns `None` when the request has no value of that type.
    #[inline]
    pub fn app_data<T: 'static>(&self) -> Option<&T> {
        self.req.app_data()
    }
}
133
/// A predicate over an incoming request, evaluated against a [`GuardContext`].
pub trait Guard {
    /// Returns `true` when the request described by `ctx` satisfies this guard.
    fn check(&self, ctx: &GuardContext<'_>) -> bool;

    /// Returns a human-readable name for this guard.
    ///
    /// The default implementation reports the implementing type's name as given by
    /// [`std::any::type_name`]. Only available with the `experimental-introspection` feature.
    #[cfg(feature = "experimental-introspection")]
    fn name(&self) -> String {
        std::any::type_name::<Self>().to_string()
    }

    /// Returns structured details about what this guard checks.
    ///
    /// The default implementation returns `None`. Only available with the
    /// `experimental-introspection` feature.
    #[cfg(feature = "experimental-introspection")]
    fn details(&self) -> Option<Vec<GuardDetail>> {
        None
    }
}
155
156impl Guard for Rc<dyn Guard> {
157 fn check(&self, ctx: &GuardContext<'_>) -> bool {
158 (**self).check(ctx)
159 }
160
161 #[cfg(feature = "experimental-introspection")]
162 fn name(&self) -> String {
163 (**self).name()
164 }
165
166 #[cfg(feature = "experimental-introspection")]
167 fn details(&self) -> Option<Vec<GuardDetail>> {
168 (**self).details()
169 }
170}
171
/// Wraps a function or closure in a [`Guard`].
///
/// The returned guard's `check` calls `f` with the guard context and returns its result
/// unchanged.
pub fn fn_guard<F>(f: F) -> impl Guard
where
    F: Fn(&GuardContext<'_>) -> bool,
{
    FnGuard(f)
}
190
191struct FnGuard<F: Fn(&GuardContext<'_>) -> bool>(F);
192
193impl<F> Guard for FnGuard<F>
194where
195 F: Fn(&GuardContext<'_>) -> bool,
196{
197 fn check(&self, ctx: &GuardContext<'_>) -> bool {
198 (self.0)(ctx)
199 }
200}
201
202impl<F> Guard for F
203where
204 F: Fn(&GuardContext<'_>) -> bool,
205{
206 fn check(&self, ctx: &GuardContext<'_>) -> bool {
207 (self)(ctx)
208 }
209}
210
211#[allow(non_snake_case)]
225pub fn Any<F: Guard + 'static>(guard: F) -> AnyGuard {
226 AnyGuard {
227 guards: vec![Box::new(guard)],
228 }
229}
230
231pub struct AnyGuard {
237 guards: Vec<Box<dyn Guard>>,
238}
239
240impl AnyGuard {
241 pub fn or<F: Guard + 'static>(mut self, guard: F) -> Self {
243 self.guards.push(Box::new(guard));
244 self
245 }
246}
247
248impl Guard for AnyGuard {
249 #[inline]
250 fn check(&self, ctx: &GuardContext<'_>) -> bool {
251 for guard in &self.guards {
252 if guard.check(ctx) {
253 return true;
254 }
255 }
256
257 false
258 }
259
260 #[cfg(feature = "experimental-introspection")]
261 fn name(&self) -> String {
262 format!(
263 "AnyGuard({})",
264 self.guards
265 .iter()
266 .map(|g| g.name())
267 .collect::<Vec<_>>()
268 .join(", ")
269 )
270 }
271
272 #[cfg(feature = "experimental-introspection")]
273 fn details(&self) -> Option<Vec<GuardDetail>> {
274 Some(
275 self.guards
276 .iter()
277 .flat_map(|g| g.details().unwrap_or_default())
278 .collect(),
279 )
280 }
281}
282
283#[allow(non_snake_case)]
299pub fn All<F: Guard + 'static>(guard: F) -> AllGuard {
300 AllGuard {
301 guards: vec![Box::new(guard)],
302 }
303}
304
305pub struct AllGuard {
311 guards: Vec<Box<dyn Guard>>,
312}
313
314impl AllGuard {
315 pub fn and<F: Guard + 'static>(mut self, guard: F) -> Self {
317 self.guards.push(Box::new(guard));
318 self
319 }
320}
321
322impl Guard for AllGuard {
323 #[inline]
324 fn check(&self, ctx: &GuardContext<'_>) -> bool {
325 for guard in &self.guards {
326 if !guard.check(ctx) {
327 return false;
328 }
329 }
330
331 true
332 }
333
334 #[cfg(feature = "experimental-introspection")]
335 fn name(&self) -> String {
336 format!(
337 "AllGuard({})",
338 self.guards
339 .iter()
340 .map(|g| g.name())
341 .collect::<Vec<_>>()
342 .join(", ")
343 )
344 }
345
346 #[cfg(feature = "experimental-introspection")]
347 fn details(&self) -> Option<Vec<GuardDetail>> {
348 Some(
349 self.guards
350 .iter()
351 .flat_map(|g| g.details().unwrap_or_default())
352 .collect(),
353 )
354 }
355}
356
357pub struct Not<G>(pub G);
369
370impl<G: Guard> Guard for Not<G> {
371 #[inline]
372 fn check(&self, ctx: &GuardContext<'_>) -> bool {
373 !self.0.check(ctx)
374 }
375
376 #[cfg(feature = "experimental-introspection")]
377 fn name(&self) -> String {
378 format!("Not({})", self.0.name())
379 }
380
381 #[cfg(feature = "experimental-introspection")]
382 fn details(&self) -> Option<Vec<GuardDetail>> {
383 Some(vec![GuardDetail::Generic(self.name())])
384 }
385}
386
/// Creates a guard that matches requests whose HTTP method equals `method`.
// Capitalised to match the other guard constructors in this module.
#[allow(non_snake_case)]
pub fn Method(method: HttpMethod) -> impl Guard {
    MethodGuard(method)
}
392
/// Request-local list of the HTTP methods that method guards have been checked against.
///
/// `MethodGuard::check` stores this in the request extensions and appends its own method on
/// every call, whether or not the request matched.
#[derive(Debug, Clone)]
pub(crate) struct RegisteredMethods(pub(crate) Vec<HttpMethod>);
395
396#[derive(Debug)]
398pub(crate) struct MethodGuard(HttpMethod);
399
400impl Guard for MethodGuard {
401 fn check(&self, ctx: &GuardContext<'_>) -> bool {
402 let registered = ctx.req_data_mut().remove::<RegisteredMethods>();
403
404 if let Some(mut methods) = registered {
405 methods.0.push(self.0.clone());
406 ctx.req_data_mut().insert(methods);
407 } else {
408 ctx.req_data_mut()
409 .insert(RegisteredMethods(vec![self.0.clone()]));
410 }
411
412 ctx.head().method == self.0
413 }
414
415 #[cfg(feature = "experimental-introspection")]
416 fn name(&self) -> String {
417 self.0.to_string()
418 }
419
420 #[cfg(feature = "experimental-introspection")]
421 fn details(&self) -> Option<Vec<GuardDetail>> {
422 Some(vec![GuardDetail::HttpMethods(vec![self.0.to_string()])])
423 }
424}
425
// Generates one public guard constructor per HTTP method.
//
// `$method_fn` is the name of the generated function (e.g. `Get`) and `$method_const` is the
// matching associated constant on `HttpMethod` (e.g. `GET`). The docs are assembled with
// `#[doc = concat!(...)]` so that each generated function mentions its own method name.
macro_rules! method_guard {
    ($method_fn:ident, $method_const:ident) => {
        #[doc = concat!("Creates a guard that matches the `", stringify!($method_const), "` request method.")]
        ///
        #[doc = concat!("The route in this example will only respond to `", stringify!($method_const), "` requests.")]
        ///
        /// ```ignore
        #[doc = concat!(" .guard(guard::", stringify!($method_fn), "())")]
        /// ```
        // Generated names mirror the method names (`Get`, `Post`, ...), hence not snake case.
        #[allow(non_snake_case)]
        pub fn $method_fn() -> impl Guard {
            MethodGuard(HttpMethod::$method_const)
        }
    };
}

// One constructor for each method constant used below.
method_guard!(Get, GET);
method_guard!(Post, POST);
method_guard!(Put, PUT);
method_guard!(Delete, DELETE);
method_guard!(Head, HEAD);
method_guard!(Options, OPTIONS);
method_guard!(Connect, CONNECT);
method_guard!(Patch, PATCH);
method_guard!(Trace, TRACE);
455
456#[allow(non_snake_case)]
469pub fn Header(name: &'static str, value: &'static str) -> impl Guard {
470 HeaderGuard(
471 header::HeaderName::try_from(name).unwrap(),
472 header::HeaderValue::from_static(value),
473 )
474}
475
476struct HeaderGuard(header::HeaderName, header::HeaderValue);
477
478impl Guard for HeaderGuard {
479 fn check(&self, ctx: &GuardContext<'_>) -> bool {
480 if let Some(val) = ctx.head().headers.get(&self.0) {
481 return val == self.1;
482 }
483
484 false
485 }
486
487 #[cfg(feature = "experimental-introspection")]
488 fn name(&self) -> String {
489 format!("Header({}, {})", self.0, self.1.to_str().unwrap_or(""))
490 }
491
492 #[cfg(feature = "experimental-introspection")]
493 fn details(&self) -> Option<Vec<GuardDetail>> {
494 Some(vec![GuardDetail::Headers(vec![(
495 self.0.to_string(),
496 self.1.to_str().unwrap_or("").to_string(),
497 )])])
498 }
499}
500
#[cfg(test)]
mod tests {
    use actix_http::Method;

    use super::*;
    use crate::{service::ServiceRequest, test::TestRequest};

    /// Evaluates `guard` against a fresh guard context built from `req`.
    fn accepts(guard: &dyn Guard, req: &ServiceRequest) -> bool {
        guard.check(&req.guard_ctx())
    }

    /// Builds a request with the given method and otherwise default settings.
    fn request_with(method: Method) -> ServiceRequest {
        TestRequest::default().method(method).to_srv_request()
    }

    #[test]
    fn header_match() {
        let req = TestRequest::default()
            .insert_header((header::TRANSFER_ENCODING, "chunked"))
            .to_srv_request();

        // (header name, header value, whether the guard should match)
        let cases = [
            ("transfer-encoding", "chunked", true),
            ("transfer-encoding", "other", false),
            ("content-type", "chunked", false),
            ("content-type", "other", false),
        ];

        for (name, value, expected) in cases {
            let guard = Header(name, value);
            assert_eq!(accepts(&guard, &req), expected, "Header({name}, {value})");
        }
    }

    #[test]
    fn method_guards() {
        let get_req = TestRequest::get().to_srv_request();
        let post_req = TestRequest::post().to_srv_request();

        assert!(accepts(&Get(), &get_req));
        assert!(!accepts(&Get(), &post_req));

        assert!(accepts(&Post(), &post_req));
        assert!(!accepts(&Post(), &get_req));

        // Every remaining guard must accept its own method and reject a GET request.
        let cases: Vec<(&str, Box<dyn Guard>, ServiceRequest)> = vec![
            ("PUT", Box::new(Put()), TestRequest::put().to_srv_request()),
            (
                "PATCH",
                Box::new(Patch()),
                TestRequest::patch().to_srv_request(),
            ),
            (
                "DELETE",
                Box::new(Delete()),
                TestRequest::delete().to_srv_request(),
            ),
            ("HEAD", Box::new(Head()), request_with(Method::HEAD)),
            (
                "OPTIONS",
                Box::new(Options()),
                request_with(Method::OPTIONS),
            ),
            (
                "CONNECT",
                Box::new(Connect()),
                request_with(Method::CONNECT),
            ),
            ("TRACE", Box::new(Trace()), request_with(Method::TRACE)),
        ];

        for (label, guard, own_req) in &cases {
            assert!(accepts(&**guard, own_req), "{label} guard vs {label}");
            assert!(!accepts(&**guard, &get_req), "{label} guard vs GET");
        }
    }

    #[test]
    fn aggregate_any() {
        let trace_req = request_with(Method::TRACE);

        let only_trace = Any(Trace());
        assert!(accepts(&only_trace, &trace_req));

        let trace_or_get = Any(Trace()).or(Get());
        assert!(accepts(&trace_or_get, &trace_req));

        let get_or_get = Any(Get()).or(Get());
        assert!(!accepts(&get_or_get, &trace_req));
    }

    #[test]
    fn aggregate_all() {
        let trace_req = request_with(Method::TRACE);

        let only_trace = All(Trace());
        assert!(accepts(&only_trace, &trace_req));

        let trace_and_trace = All(Trace()).and(Trace());
        assert!(accepts(&trace_and_trace, &trace_req));

        let trace_and_get = All(Trace()).and(Get());
        assert!(!accepts(&trace_and_get, &trace_req));
    }

    #[test]
    fn nested_not() {
        let default_req = TestRequest::default().to_srv_request();

        let get = Get();
        assert!(accepts(&get, &default_req));

        let negated = Not(get);
        assert!(!accepts(&negated, &default_req));

        let double_negated = Not(negated);
        assert!(accepts(&double_negated, &default_req));
    }

    #[test]
    fn function_guard() {
        let domain = "rust-lang.org".to_owned();
        let guard = fn_guard(|ctx| ctx.head().uri.host().unwrap().ends_with(&domain));

        let matching = TestRequest::default()
            .uri("blog.rust-lang.org")
            .to_srv_request();
        assert!(accepts(&guard, &matching));

        let other = TestRequest::default().uri("crates.io").to_srv_request();
        assert!(!accepts(&guard, &other));
    }

    #[test]
    fn mega_nesting() {
        let guard = fn_guard(|ctx| All(Not(Any(Not(Trace())))).check(ctx));

        let default_req = TestRequest::default().to_srv_request();
        assert!(!accepts(&guard, &default_req));

        let trace_req = request_with(Method::TRACE);
        assert!(accepts(&guard, &trace_req));
    }

    #[test]
    fn app_data() {
        const TEST_VALUE: u32 = 42;
        let guard = fn_guard(|ctx| ctx.app_data::<u32>() == Some(&TEST_VALUE));

        let with_expected = TestRequest::default().app_data(TEST_VALUE).to_srv_request();
        assert!(accepts(&guard, &with_expected));

        let with_other = TestRequest::default()
            .app_data(TEST_VALUE * 2)
            .to_srv_request();
        assert!(!accepts(&guard, &with_other));
    }
}