Skip to main content

actix_web/guard/
mod.rs

1//! Route guards.
2//!
3//! Guards are used during routing to help select a matching service or handler using some aspect of
4//! the request; though guards should not be used for path matching since it is a built-in function
5//! of the Actix Web router.
6//!
7//! Guards can be used on [`Scope`]s, [`Resource`]s, [`Route`]s, and other custom services.
8//!
9//! Fundamentally, a guard is a predicate function that receives a reference to a request context
10//! object and returns a boolean; true if the request _should_ be handled by the guarded service
11//! or handler. This interface is defined by the [`Guard`] trait.
12//!
13//! Commonly-used guards are provided in this module as well as a way of creating a guard from a
14//! closure ([`fn_guard`]). The [`Not`], [`Any()`], and [`All()`] guards are noteworthy, as they can be
15//! used to compose other guards in a more flexible and semantic way than calling `.guard(...)` on
16//! services multiple times (which might have different combining behavior than you want).
17//!
18//! There are shortcuts for routes with method guards in the [`web`](crate::web) module:
19//! [`web::get()`](crate::web::get), [`web::post()`](crate::web::post), etc. The routes created by
20//! the following calls are equivalent:
21//!
22//! - `web::get()` (recommended form)
23//! - `web::route().guard(guard::Get())`
24//!
//! Guards cannot modify anything about the request. However, it is possible to store extra
//! attributes in the request-local data container obtained with [`GuardContext::req_data_mut`].
27//!
28//! Guards can prevent resource definitions from overlapping which, when only considering paths,
29//! would result in inaccessible routes. See the [`Host`] guard for an example of virtual hosting.
30//!
31//! # Examples
32//!
//! In the following code, the `/guarded` resource has one defined route whose handler will only be
//! called if the request method is GET or POST and there is an `x-guarded` request header with a
//! value equal to `secret`.
36//!
37//! ```
38//! use actix_web::{web, http::Method, guard, HttpResponse};
39//!
40//! web::resource("/guarded").route(
41//!     web::route()
42//!         .guard(guard::Any(guard::Get()).or(guard::Post()))
43//!         .guard(guard::Header("x-guarded", "secret"))
44//!         .to(|| HttpResponse::Ok())
45//! );
46//! ```
47//!
48//! [`Scope`]: crate::Scope::guard()
49//! [`Resource`]: crate::Resource::guard()
50//! [`Route`]: crate::Route::guard()
51
52use std::{
53    cell::{Ref, RefMut},
54    rc::Rc,
55};
56
57use actix_http::{header, Extensions, Method as HttpMethod, RequestHead};
58
59use crate::{http::header::Header, service::ServiceRequest, HttpMessage as _};
60
61mod acceptable;
62mod host;
63
64pub use self::{
65    acceptable::Acceptable,
66    host::{Host, HostGuard},
67};
68
/// Structured introspection details describing what a guard checks.
///
/// Only compiled when the `experimental-introspection` feature is enabled.
#[cfg(feature = "experimental-introspection")]
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum GuardDetail {
    /// Names of the HTTP methods an explicit method guard matches.
    HttpMethods(Vec<String>),

    /// `(name, value)` pairs a header guard matches.
    Headers(Vec<(String, String)>),

    /// Free-form description, typically used to represent compound guards.
    Generic(String),
}
81
82/// Provides access to request parts that are useful during routing.
83#[derive(Debug)]
84pub struct GuardContext<'a> {
85    pub(crate) req: &'a ServiceRequest,
86}
87
88impl<'a> GuardContext<'a> {
89    /// Returns reference to the request head.
90    #[inline]
91    pub fn head(&self) -> &RequestHead {
92        self.req.head()
93    }
94
95    /// Returns reference to the request-local data/extensions container.
96    #[inline]
97    pub fn req_data(&self) -> Ref<'a, Extensions> {
98        self.req.extensions()
99    }
100
101    /// Returns mutable reference to the request-local data/extensions container.
102    #[inline]
103    pub fn req_data_mut(&self) -> RefMut<'a, Extensions> {
104        self.req.extensions_mut()
105    }
106
107    /// Extracts a typed header from the request.
108    ///
109    /// Returns `None` if parsing `H` fails.
110    ///
111    /// # Examples
112    /// ```
113    /// use actix_web::{guard::fn_guard, http::header};
114    ///
115    /// let image_accept_guard = fn_guard(|ctx| {
116    ///     match ctx.header::<header::Accept>() {
117    ///         Some(hdr) => hdr.preference() == "image/*",
118    ///         None => false,
119    ///     }
120    /// });
121    /// ```
122    #[inline]
123    pub fn header<H: Header>(&self) -> Option<H> {
124        H::parse(self.req).ok()
125    }
126
127    /// Counterpart to [HttpRequest::app_data](crate::HttpRequest::app_data).
128    #[inline]
129    pub fn app_data<T: 'static>(&self) -> Option<&T> {
130        self.req.app_data()
131    }
132}
133
134/// Interface for routing guards.
135///
136/// See [module level documentation](self) for more.
137pub trait Guard {
138    /// Returns true if predicate condition is met for a given request.
139    fn check(&self, ctx: &GuardContext<'_>) -> bool;
140
141    /// Returns a nominal representation of the guard.
142    #[cfg(feature = "experimental-introspection")]
143    fn name(&self) -> String {
144        std::any::type_name::<Self>().to_string()
145    }
146
147    /// Returns detailed introspection information, when available.
148    ///
149    /// This is best-effort and may omit complex guard logic.
150    #[cfg(feature = "experimental-introspection")]
151    fn details(&self) -> Option<Vec<GuardDetail>> {
152        None
153    }
154}
155
156impl Guard for Rc<dyn Guard> {
157    fn check(&self, ctx: &GuardContext<'_>) -> bool {
158        (**self).check(ctx)
159    }
160
161    #[cfg(feature = "experimental-introspection")]
162    fn name(&self) -> String {
163        (**self).name()
164    }
165
166    #[cfg(feature = "experimental-introspection")]
167    fn details(&self) -> Option<Vec<GuardDetail>> {
168        (**self).details()
169    }
170}
171
172/// Creates a guard using the given function.
173///
174/// # Examples
175/// ```
176/// use actix_web::{guard, web, HttpResponse};
177///
178/// web::route()
179///     .guard(guard::fn_guard(|ctx| {
180///         ctx.head().headers().contains_key("content-type")
181///     }))
182///     .to(|| HttpResponse::Ok());
183/// ```
184pub fn fn_guard<F>(f: F) -> impl Guard
185where
186    F: Fn(&GuardContext<'_>) -> bool,
187{
188    FnGuard(f)
189}
190
191struct FnGuard<F: Fn(&GuardContext<'_>) -> bool>(F);
192
193impl<F> Guard for FnGuard<F>
194where
195    F: Fn(&GuardContext<'_>) -> bool,
196{
197    fn check(&self, ctx: &GuardContext<'_>) -> bool {
198        (self.0)(ctx)
199    }
200}
201
202impl<F> Guard for F
203where
204    F: Fn(&GuardContext<'_>) -> bool,
205{
206    fn check(&self, ctx: &GuardContext<'_>) -> bool {
207        (self)(ctx)
208    }
209}
210
211/// Creates a guard that matches if any added guards match.
212///
213/// # Examples
214/// The handler below will be called for either request method `GET` or `POST`.
215/// ```
216/// use actix_web::{web, guard, HttpResponse};
217///
218/// web::route()
219///     .guard(
220///         guard::Any(guard::Get())
221///             .or(guard::Post()))
222///     .to(|| HttpResponse::Ok());
223/// ```
224#[allow(non_snake_case)]
225pub fn Any<F: Guard + 'static>(guard: F) -> AnyGuard {
226    AnyGuard {
227        guards: vec![Box::new(guard)],
228    }
229}
230
231/// A collection of guards that match if the disjunction of their `check` outcomes is true.
232///
233/// That is, only one contained guard needs to match in order for the aggregate guard to match.
234///
235/// Construct an `AnyGuard` using [`Any()`].
236pub struct AnyGuard {
237    guards: Vec<Box<dyn Guard>>,
238}
239
240impl AnyGuard {
241    /// Adds new guard to the collection of guards to check.
242    pub fn or<F: Guard + 'static>(mut self, guard: F) -> Self {
243        self.guards.push(Box::new(guard));
244        self
245    }
246}
247
248impl Guard for AnyGuard {
249    #[inline]
250    fn check(&self, ctx: &GuardContext<'_>) -> bool {
251        for guard in &self.guards {
252            if guard.check(ctx) {
253                return true;
254            }
255        }
256
257        false
258    }
259
260    #[cfg(feature = "experimental-introspection")]
261    fn name(&self) -> String {
262        format!(
263            "AnyGuard({})",
264            self.guards
265                .iter()
266                .map(|g| g.name())
267                .collect::<Vec<_>>()
268                .join(", ")
269        )
270    }
271
272    #[cfg(feature = "experimental-introspection")]
273    fn details(&self) -> Option<Vec<GuardDetail>> {
274        Some(
275            self.guards
276                .iter()
277                .flat_map(|g| g.details().unwrap_or_default())
278                .collect(),
279        )
280    }
281}
282
283/// Creates a guard that matches if all added guards match.
284///
285/// # Examples
286/// The handler below will only be called if the request method is `GET` **and** the specified
287/// header name and value match exactly.
288/// ```
289/// use actix_web::{guard, web, HttpResponse};
290///
291/// web::route()
292///     .guard(
293///         guard::All(guard::Get())
294///             .and(guard::Header("accept", "text/plain"))
295///     )
296///     .to(|| HttpResponse::Ok());
297/// ```
298#[allow(non_snake_case)]
299pub fn All<F: Guard + 'static>(guard: F) -> AllGuard {
300    AllGuard {
301        guards: vec![Box::new(guard)],
302    }
303}
304
305/// A collection of guards that match if the conjunction of their `check` outcomes is true.
306///
307/// That is, **all** contained guard needs to match in order for the aggregate guard to match.
308///
309/// Construct an `AllGuard` using [`All()`].
310pub struct AllGuard {
311    guards: Vec<Box<dyn Guard>>,
312}
313
314impl AllGuard {
315    /// Adds new guard to the collection of guards to check.
316    pub fn and<F: Guard + 'static>(mut self, guard: F) -> Self {
317        self.guards.push(Box::new(guard));
318        self
319    }
320}
321
322impl Guard for AllGuard {
323    #[inline]
324    fn check(&self, ctx: &GuardContext<'_>) -> bool {
325        for guard in &self.guards {
326            if !guard.check(ctx) {
327                return false;
328            }
329        }
330
331        true
332    }
333
334    #[cfg(feature = "experimental-introspection")]
335    fn name(&self) -> String {
336        format!(
337            "AllGuard({})",
338            self.guards
339                .iter()
340                .map(|g| g.name())
341                .collect::<Vec<_>>()
342                .join(", ")
343        )
344    }
345
346    #[cfg(feature = "experimental-introspection")]
347    fn details(&self) -> Option<Vec<GuardDetail>> {
348        Some(
349            self.guards
350                .iter()
351                .flat_map(|g| g.details().unwrap_or_default())
352                .collect(),
353        )
354    }
355}
356
357/// Wraps a guard and inverts the outcome of its `Guard` implementation.
358///
359/// # Examples
360/// The handler below will be called for any request method apart from `GET`.
361/// ```
362/// use actix_web::{guard, web, HttpResponse};
363///
364/// web::route()
365///     .guard(guard::Not(guard::Get()))
366///     .to(|| HttpResponse::Ok());
367/// ```
368pub struct Not<G>(pub G);
369
370impl<G: Guard> Guard for Not<G> {
371    #[inline]
372    fn check(&self, ctx: &GuardContext<'_>) -> bool {
373        !self.0.check(ctx)
374    }
375
376    #[cfg(feature = "experimental-introspection")]
377    fn name(&self) -> String {
378        format!("Not({})", self.0.name())
379    }
380
381    #[cfg(feature = "experimental-introspection")]
382    fn details(&self) -> Option<Vec<GuardDetail>> {
383        Some(vec![GuardDetail::Generic(self.name())])
384    }
385}
386
387/// Creates a guard that matches a specified HTTP method.
388#[allow(non_snake_case)]
389pub fn Method(method: HttpMethod) -> impl Guard {
390    MethodGuard(method)
391}
392
393#[derive(Debug, Clone)]
394pub(crate) struct RegisteredMethods(pub(crate) Vec<HttpMethod>);
395
396/// HTTP method guard.
397#[derive(Debug)]
398pub(crate) struct MethodGuard(HttpMethod);
399
400impl Guard for MethodGuard {
401    fn check(&self, ctx: &GuardContext<'_>) -> bool {
402        let registered = ctx.req_data_mut().remove::<RegisteredMethods>();
403
404        if let Some(mut methods) = registered {
405            methods.0.push(self.0.clone());
406            ctx.req_data_mut().insert(methods);
407        } else {
408            ctx.req_data_mut()
409                .insert(RegisteredMethods(vec![self.0.clone()]));
410        }
411
412        ctx.head().method == self.0
413    }
414
415    #[cfg(feature = "experimental-introspection")]
416    fn name(&self) -> String {
417        self.0.to_string()
418    }
419
420    #[cfg(feature = "experimental-introspection")]
421    fn details(&self) -> Option<Vec<GuardDetail>> {
422        Some(vec![GuardDetail::HttpMethods(vec![self.0.to_string()])])
423    }
424}
425
426macro_rules! method_guard {
427    ($method_fn:ident, $method_const:ident) => {
428        #[doc = concat!("Creates a guard that matches the `", stringify!($method_const), "` request method.")]
429        ///
430        /// # Examples
431        #[doc = concat!("The route in this example will only respond to `", stringify!($method_const), "` requests.")]
432        /// ```
433        /// use actix_web::{guard, web, HttpResponse};
434        ///
435        /// web::route()
436        #[doc = concat!("    .guard(guard::", stringify!($method_fn), "())")]
437        ///     .to(|| HttpResponse::Ok());
438        /// ```
439        #[allow(non_snake_case)]
440        pub fn $method_fn() -> impl Guard {
441            MethodGuard(HttpMethod::$method_const)
442        }
443    };
444}
445
446method_guard!(Get, GET);
447method_guard!(Post, POST);
448method_guard!(Put, PUT);
449method_guard!(Delete, DELETE);
450method_guard!(Head, HEAD);
451method_guard!(Options, OPTIONS);
452method_guard!(Connect, CONNECT);
453method_guard!(Patch, PATCH);
454method_guard!(Trace, TRACE);
455
456/// Creates a guard that matches if request contains given header name and value.
457///
458/// # Examples
459/// The handler below will be called when the request contains an `x-guarded` header with value
460/// equal to `secret`.
461/// ```
462/// use actix_web::{guard, web, HttpResponse};
463///
464/// web::route()
465///     .guard(guard::Header("x-guarded", "secret"))
466///     .to(|| HttpResponse::Ok());
467/// ```
468#[allow(non_snake_case)]
469pub fn Header(name: &'static str, value: &'static str) -> impl Guard {
470    HeaderGuard(
471        header::HeaderName::try_from(name).unwrap(),
472        header::HeaderValue::from_static(value),
473    )
474}
475
476struct HeaderGuard(header::HeaderName, header::HeaderValue);
477
478impl Guard for HeaderGuard {
479    fn check(&self, ctx: &GuardContext<'_>) -> bool {
480        if let Some(val) = ctx.head().headers.get(&self.0) {
481            return val == self.1;
482        }
483
484        false
485    }
486
487    #[cfg(feature = "experimental-introspection")]
488    fn name(&self) -> String {
489        format!("Header({}, {})", self.0, self.1.to_str().unwrap_or(""))
490    }
491
492    #[cfg(feature = "experimental-introspection")]
493    fn details(&self) -> Option<Vec<GuardDetail>> {
494        Some(vec![GuardDetail::Headers(vec![(
495            self.0.to_string(),
496            self.1.to_str().unwrap_or("").to_string(),
497        )])])
498    }
499}
500
#[cfg(test)]
mod tests {
    use actix_http::Method;

    use super::*;
    use crate::test::TestRequest;

    #[test]
    fn header_match() {
        let req = TestRequest::default()
            .insert_header((header::TRANSFER_ENCODING, "chunked"))
            .to_srv_request();

        let hdr = Header("transfer-encoding", "chunked");
        assert!(hdr.check(&req.guard_ctx()));

        let hdr = Header("transfer-encoding", "other");
        assert!(!hdr.check(&req.guard_ctx()));

        let hdr = Header("content-type", "chunked");
        assert!(!hdr.check(&req.guard_ctx()));

        let hdr = Header("content-type", "other");
        assert!(!hdr.check(&req.guard_ctx()));
    }

    #[test]
    fn method_guards() {
        let get_req = TestRequest::get().to_srv_request();
        let post_req = TestRequest::post().to_srv_request();

        assert!(Get().check(&get_req.guard_ctx()));
        assert!(!Get().check(&post_req.guard_ctx()));

        assert!(Post().check(&post_req.guard_ctx()));
        assert!(!Post().check(&get_req.guard_ctx()));

        let req = TestRequest::put().to_srv_request();
        assert!(Put().check(&req.guard_ctx()));
        assert!(!Put().check(&get_req.guard_ctx()));

        let req = TestRequest::patch().to_srv_request();
        assert!(Patch().check(&req.guard_ctx()));
        assert!(!Patch().check(&get_req.guard_ctx()));

        let r = TestRequest::delete().to_srv_request();
        assert!(Delete().check(&r.guard_ctx()));
        assert!(!Delete().check(&get_req.guard_ctx()));

        let req = TestRequest::default().method(Method::HEAD).to_srv_request();
        assert!(Head().check(&req.guard_ctx()));
        assert!(!Head().check(&get_req.guard_ctx()));

        let req = TestRequest::default()
            .method(Method::OPTIONS)
            .to_srv_request();
        assert!(Options().check(&req.guard_ctx()));
        assert!(!Options().check(&get_req.guard_ctx()));

        let req = TestRequest::default()
            .method(Method::CONNECT)
            .to_srv_request();
        assert!(Connect().check(&req.guard_ctx()));
        assert!(!Connect().check(&get_req.guard_ctx()));

        let req = TestRequest::default()
            .method(Method::TRACE)
            .to_srv_request();
        assert!(Trace().check(&req.guard_ctx()));
        assert!(!Trace().check(&get_req.guard_ctx()));
    }

    /// Every method guard check, matching or not, appends its method to the request-local
    /// `RegisteredMethods` list in check order.
    #[test]
    fn method_guard_records_registered_methods() {
        let req = TestRequest::get().to_srv_request();

        assert!(Get().check(&req.guard_ctx()));
        assert!(!Post().check(&req.guard_ctx()));
        assert!(!Put().check(&req.guard_ctx()));

        let ctx = req.guard_ctx();
        let data = ctx.req_data();
        let registered = data
            .get::<RegisteredMethods>()
            .expect("method guards should record the methods they were checked against");
        assert_eq!(registered.0, vec![Method::GET, Method::POST, Method::PUT]);
    }

    #[test]
    fn aggregate_any() {
        let req = TestRequest::default()
            .method(Method::TRACE)
            .to_srv_request();

        assert!(Any(Trace()).check(&req.guard_ctx()));
        assert!(Any(Trace()).or(Get()).check(&req.guard_ctx()));
        assert!(!Any(Get()).or(Get()).check(&req.guard_ctx()));
    }

    #[test]
    fn aggregate_all() {
        let req = TestRequest::default()
            .method(Method::TRACE)
            .to_srv_request();

        assert!(All(Trace()).check(&req.guard_ctx()));
        assert!(All(Trace()).and(Trace()).check(&req.guard_ctx()));
        assert!(!All(Trace()).and(Get()).check(&req.guard_ctx()));
    }

    #[test]
    fn nested_not() {
        let req = TestRequest::default().to_srv_request();

        let get = Get();
        assert!(get.check(&req.guard_ctx()));

        let not_get = Not(get);
        assert!(!not_get.check(&req.guard_ctx()));

        let not_not_get = Not(not_get);
        assert!(not_not_get.check(&req.guard_ctx()));
    }

    #[test]
    fn function_guard() {
        let domain = "rust-lang.org".to_owned();
        let guard = fn_guard(|ctx| ctx.head().uri.host().unwrap().ends_with(&domain));

        let req = TestRequest::default()
            .uri("blog.rust-lang.org")
            .to_srv_request();
        assert!(guard.check(&req.guard_ctx()));

        let req = TestRequest::default().uri("crates.io").to_srv_request();
        assert!(!guard.check(&req.guard_ctx()));
    }

    #[test]
    fn mega_nesting() {
        let guard = fn_guard(|ctx| All(Not(Any(Not(Trace())))).check(ctx));

        let req = TestRequest::default().to_srv_request();
        assert!(!guard.check(&req.guard_ctx()));

        let req = TestRequest::default()
            .method(Method::TRACE)
            .to_srv_request();
        assert!(guard.check(&req.guard_ctx()));
    }

    #[test]
    fn app_data() {
        const TEST_VALUE: u32 = 42;
        let guard = fn_guard(|ctx| ctx.app_data::<u32>() == Some(&TEST_VALUE));

        let req = TestRequest::default().app_data(TEST_VALUE).to_srv_request();
        assert!(guard.check(&req.guard_ctx()));

        let req = TestRequest::default()
            .app_data(TEST_VALUE * 2)
            .to_srv_request();
        assert!(!guard.check(&req.guard_ctx()));
    }
}