tower_http/csrf/layer.rs
1use std::fmt::{self, Debug, Formatter};
2use std::sync::Arc;
3
4use http::{Method, Uri};
5use tower_layer::Layer;
6
7use super::service::Csrf;
8use super::url::UriExt;
9use super::{BypassFn, ConfigError, DebugFn, DefaultResponseForProtectionError, Origins};
10
11/// Layer that applies the [`Csrf`] middleware.
12///
13/// See the [module docs](crate::csrf) for an example.
14#[derive(Clone)]
15#[must_use]
16pub struct CsrfLayer<T = DefaultResponseForProtectionError> {
17 insecure_bypass: Option<Arc<BypassFn>>,
18 rejection_response: T,
19 trusted_origins: Origins,
20}
21
22impl Default for CsrfLayer {
23 fn default() -> Self {
24 Self {
25 insecure_bypass: None,
26 rejection_response: DefaultResponseForProtectionError,
27 trusted_origins: Origins::default(),
28 }
29 }
30}
31
32impl<T> Debug for CsrfLayer<T> {
33 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
34 f.debug_struct("CsrfLayer")
35 .field(
36 "insecure_bypass",
37 &self.insecure_bypass.as_ref().map(|_| DebugFn),
38 )
39 .field("trusted_origins", &self.trusted_origins)
40 .field("rejection_response", &DebugFn)
41 .finish()
42 }
43}
44
45impl CsrfLayer {
46 /// Creates a new `CsrfLayer` with no trusted origins, no bypass, and the
47 /// default rejection response.
48 pub fn new() -> Self {
49 Self::default()
50 }
51}
52
53impl<T> CsrfLayer<T> {
54 /// Adds a trusted origin that allows all requests whose `Origin` header
55 /// matches the given value.
56 ///
57 /// The value is matched **byte-for-byte** against the request's `Origin`
58 /// header — there is no normalization (this mirrors the Go reference). It
59 /// must therefore be written exactly as a browser sends it:
60 ///
61 /// - form `scheme://host[:port]`, where `scheme` is `http` or `https`;
62 /// - the host lowercased (browsers lowercase it; IDN hosts must be given in
63 /// punycode, e.g. `xn--exmple-cua.com`);
64 /// - **default ports omitted** — browsers drop `:80`/`:443`, so an explicit
65 /// default port (e.g. `https://example.com:443`) will never match;
66 /// - **no trailing slash**, path, query, or fragment.
67 ///
68 /// Inputs that can't represent a browser `Origin` are rejected with a
69 /// [`ConfigError`]; inputs that parse but aren't in the canonical browser
70 /// form above are accepted but will silently never match.
71 ///
72 /// ```
73 /// # use tower_http::csrf::CsrfLayer;
74 /// // Matches `Origin: https://example.com`:
75 /// let layer = CsrfLayer::new().add_trusted_origin("https://example.com")?;
76 ///
77 /// // Accepted, but never matches a browser Origin (explicit default port):
78 /// let layer = CsrfLayer::new().add_trusted_origin("https://example.com:443")?;
79 /// # Ok::<_, tower_http::csrf::ConfigError>(())
80 /// ```
81 pub fn add_trusted_origin<S: AsRef<str>>(mut self, origin: S) -> Result<Self, ConfigError> {
82 let origin = origin.as_ref();
83
84 // validate the form; the origin is stored and matched verbatim.
85 Uri::parse_origin(origin)?;
86
87 #[cfg(feature = "tracing")]
88 tracing::debug!(origin = %origin, "added trusted origin");
89
90 self.trusted_origins.insert(origin.to_owned());
91
92 Ok(self)
93 }
94
95 /// Adds a bypass predicate that returns `true` for requests which should
96 /// skip CSRF protection.
97 ///
98 /// This is an escape hatch for endpoints that legitimately need to accept
99 /// cross-origin POSTs (e.g. webhook receivers). Bypassed endpoints must
100 /// have their own protection (signed payloads, authentication tokens,
101 /// etc.) — otherwise they are CSRF-vulnerable.
102 pub fn with_insecure_bypass<F>(mut self, predicate: F) -> Self
103 where
104 F: Fn(&Method, &Uri) -> bool + Send + Sync + 'static,
105 {
106 #[cfg(feature = "tracing")]
107 tracing::debug!("added insecure bypass");
108
109 self.insecure_bypass = Some(Arc::new(predicate));
110 self
111 }
112
113 /// Replaces the response builder used when a request is rejected.
114 ///
115 /// Accepts any type that implements [`ResponseForProtectionError`](super::ResponseForProtectionError),
116 /// including a `FnMut(ProtectionError) -> Response<B> + Clone` closure.
117 /// The default builder returns a `403 Forbidden` with an empty body.
118 /// Regardless of the builder, [`Csrf`](super::Csrf) attaches the
119 /// [`ProtectionError`](super::ProtectionError) to the response's extensions,
120 /// so a custom builder need not re-attach it.
121 pub fn with_rejection_response<R>(self, rejection_response: R) -> CsrfLayer<R>
122 where
123 R: Clone,
124 {
125 CsrfLayer {
126 insecure_bypass: self.insecure_bypass,
127 trusted_origins: self.trusted_origins,
128 rejection_response,
129 }
130 }
131}
132
133impl<S, T> Layer<S> for CsrfLayer<T>
134where
135 T: Clone,
136{
137 type Service = Csrf<S, T>;
138
139 fn layer(&self, inner: S) -> Self::Service {
140 Csrf::new(
141 inner,
142 self.insecure_bypass.clone(),
143 self.rejection_response.clone(),
144 self.trusted_origins.clone(),
145 )
146 }
147}