tower_csrf/
lib.rs

//! Modern protection against cross-site request forgery (CSRF) attacks.
2//!
//! This is _experimental_ middleware for [Tower](https://crates.io/crates/tower). It provides modern CSRF protection as outlined in a [blog post](https://words.filippo.io/csrf/) by Filippo Valsorda, discussing the research background for integrating CSRF protection in Go 1.25's `net/http`.
4//!
5//! This boils down to (quoting from the blog):
6//!
7//! 1. Allow all GET, HEAD, or OPTIONS requests
8//! 2. If the Origin header matches an allow-list of trusted origins, allow the request
9//! 3. If the Sec-Fetch-Site header is present and the value is `same-origin` or `none`, allow the request, otherwise reject
10//! 4. If neither the Sec-Fetch-Site nor the Origin headers are present, allow the request
11//! 5. If the Origin header’s host (including the port) matches the Host header, allow the request, otherwise reject it
12//!
//! The crate uses [tracing](https://docs.rs/tracing/latest/tracing/) to log passed requests and
//! configuration changes. Rejections are not logged; they are returned as errors down the service
//! chain for the caller to handle.
15use std::collections::HashSet;
16use std::future::Future;
17use std::pin::Pin;
18use std::result::Result;
19use std::sync::Arc;
20use std::task::{Context, Poll};
21
22use http::{Method, Request, Response, Uri};
23use tower::{BoxError, Layer, Service};
24use tracing::{debug, instrument, trace};
25use url::Url;
26
27/// Errors that can occur during configuration of the layer.
28#[derive(thiserror::Error, Debug)]
29pub enum ConfigError {
30    /// An invalid origin url was added as a trusted origin.
31    #[error(transparent)]
32    InvalidOriginUrl(#[from] url::ParseError),
33
34    /// A origin url containing a path, query or fragment was added as a trusted origin.
35    #[error("invalid origin {origin:?}: path, query, and fragment are not allowed")]
36    InvalidOriginUrlComponents { origin: String },
37}
38
39/// Errors that can occur during request processing of the middleware.
40///
41/// These errors must be handled when using the middleware in web frameworks (such as axum) to e.g. log errors or
42/// render appropriate responses.
43#[derive(thiserror::Error, Debug, PartialEq)]
44pub enum ProtectionError {
45    /// A cross-origin request was detected.
46    #[error("Cross-Origin request detected")]
47    CrossOriginRequest,
48
49    /// A cross-origin request was detected.
50    #[error("Cross-Origin request from old browser detected")]
51    CrossOriginRequestFromOldBrowser,
52
53    /// The host request header cannot be parsed.
54    #[error("Host header cannot be parsed")]
55    MalformedHost(#[source] url::ParseError),
56
57    /// The origin request header cannot be parsed.
58    #[error("Origin header cannot be parsed")]
59    MalformedOrigin(#[source] url::ParseError),
60}
61
/// Wraps a user-supplied predicate that decides whether a request may skip the CSRF checks.
///
/// The `Fn(&Method, &Uri) -> bool` bound lives on the `Filter` impl that actually calls the
/// predicate, so the wrapper itself places no requirements on `T`.
struct Bypass<T>(T);

// Closures do not implement `Debug`, so the predicate is rendered as the placeholder `<fn>`.
impl<T> std::fmt::Debug for Bypass<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("<fn>")
    }
}
69
70trait Filter: std::fmt::Debug + Send + Sync {
71    fn is_bypassed(&self, method: &Method, uri: &Uri) -> bool;
72}
73
74impl<T: Fn(&Method, &Uri) -> bool> Filter for Option<Bypass<T>>
75where
76    T: Send + Sync,
77{
78    fn is_bypassed(&self, method: &Method, uri: &Uri) -> bool {
79        match self {
80            Some(ref p) => p.0(method, uri),
81            None => false,
82        }
83    }
84}
85
/// Set of exact `Origin` header values that are always allowed.
///
/// The set sits behind an `Arc`, so cloning it into every middleware instance is cheap; `insert`
/// copies the set only when it is currently shared (`Arc::make_mut`).
#[derive(Clone, Debug, Default)]
struct Origins(Arc<HashSet<String>>);

impl Origins {
    /// Returns `true` if `origin` is stored (exact, case-sensitive string comparison).
    fn contains(&self, origin: &str) -> bool {
        let set: &HashSet<String> = &self.0;
        set.contains(origin)
    }

    /// Stores `origin`, first cloning the underlying set if another `Origins` still shares it.
    fn insert(&mut self, origin: impl Into<String>) {
        let owned: String = origin.into();
        let set = Arc::make_mut(&mut self.0);
        set.insert(owned);
    }
}
98
99/// Decorates a HTTP service with CSRF protection.
100#[derive(Clone, Debug)]
101pub struct CrossOriginProtectionLayer {
102    insecure_bypass: Arc<dyn Filter>,
103    trusted_origins: Origins,
104}
105
106impl Default for CrossOriginProtectionLayer {
107    fn default() -> Self {
108        CrossOriginProtectionLayer {
109            insecure_bypass: Arc::new(Option::<Bypass<fn(&Method, &Uri) -> bool>>::default()),
110            trusted_origins: Origins::default(),
111        }
112    }
113}
114
115impl CrossOriginProtectionLayer {
116    /// Adds a trusted origin which allows all requests with an `Origin` header which exactly matches
117    /// the given value.
118    ///
119    /// Origin header values are of the form `scheme://host[:port]`.
120    pub fn add_trusted_origin<S: Into<String>>(mut self, origin: S) -> Result<Self, ConfigError> {
121        let origin = origin.into();
122
123        // using url crate here for fragment support (see https://github.com/hyperium/http/issues/127)
124        let url = Url::parse(&origin)?;
125
126        // note that the url crate will always normalize an empty path to "/"
127        if url.path() != "/" || url.query().is_some() || url.fragment().is_some() {
128            return Err(ConfigError::InvalidOriginUrlComponents { origin });
129        }
130
131        debug!(origin = %origin, "added trusted origin");
132
133        self.trusted_origins.insert(origin);
134
135        Ok(self)
136    }
137
138    /// Adds a bypass function that returns `true` if the given request should bypass CSRF protection. Notes that this
139    /// might be insecure.
140    pub fn with_insecure_bypass<F>(self, predicate: F) -> CrossOriginProtectionLayer
141    where
142        F: Fn(&Method, &Uri) -> bool + Send + Sync + 'static,
143    {
144        debug!("added insecure bypass");
145
146        CrossOriginProtectionLayer {
147            insecure_bypass: Arc::new(Some(Bypass(predicate))),
148            trusted_origins: self.trusted_origins,
149        }
150    }
151}
152
153impl<S> Layer<S> for CrossOriginProtectionLayer {
154    type Service = CrossOriginProtectionMiddleware<S>;
155
156    fn layer(&self, inner: S) -> Self::Service {
157        CrossOriginProtectionMiddleware {
158            inner,
159            insecure_bypass: self.insecure_bypass.clone(),
160            trusted_origins: self.trusted_origins.clone(),
161        }
162    }
163}
164
165/// CSRF protection middleware for HTTP requests.
166#[derive(Clone, Debug)]
167pub struct CrossOriginProtectionMiddleware<S> {
168    inner: S,
169    insecure_bypass: Arc<dyn Filter>,
170    trusted_origins: Origins,
171}
172
173impl<S: Default> Default for CrossOriginProtectionMiddleware<S> {
174    fn default() -> Self {
175        Self {
176            inner: S::default(),
177            insecure_bypass: Arc::new(Option::<Bypass<fn(&Method, &Uri) -> bool>>::default()),
178            trusted_origins: Origins::default(),
179        }
180    }
181}
182
183impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for CrossOriginProtectionMiddleware<S>
184where
185    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
186    S::Error: Into<BoxError> + Send,
187    S::Future: Future<Output = Result<Response<ResBody>, S::Error>> + Send,
188    ReqBody: Send + 'static,
189    ResBody: Send + 'static,
190{
191    type Error = BoxError;
192    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
193    type Response = Response<ResBody>;
194
195    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
196        let clone = self.inner.clone();
197        let mut inner = std::mem::replace(&mut self.inner, clone);
198
199        match self.verify(&req) {
200            Ok(_) => Box::pin(async move { inner.call(req).await.map_err(Into::into) }),
201            Err(err) => Box::pin(async move { Err(err.into()) }),
202        }
203    }
204
205    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
206        self.inner.poll_ready(cx).map_err(Into::into)
207    }
208}
209
210impl<S> CrossOriginProtectionMiddleware<S> {
211    #[instrument(skip(self, req), fields(uri = %req.uri()))]
212    fn is_exempt<Body>(&self, req: &Request<Body>) -> bool {
213        if self.insecure_bypass.is_bypassed(req.method(), req.uri()) {
214            trace!("request passed: bypassed");
215            return true;
216        }
217
218        if let Some(origin) = req.headers().get("origin") {
219            if self
220                .trusted_origins
221                .contains(origin.to_str().unwrap_or_default())
222            {
223                trace!("request passed: trusted origin");
224                return true;
225            }
226        }
227
228        false
229    }
230
231    #[instrument(skip(self, req), fields(uri = %req.uri()))]
232    fn verify<Body>(&self, req: &Request<Body>) -> Result<(), ProtectionError> {
233        if matches!(*req.method(), Method::GET | Method::HEAD | Method::OPTIONS) {
234            trace!("request passed: safe method");
235            return Ok(());
236        }
237
238        if let Some(sec_fetch_site) = req
239            .headers()
240            .get("sec-fetch-site")
241            .and_then(|h| h.to_str().ok())
242        {
243            if matches!(sec_fetch_site, "same-origin" | "none") {
244                trace!("request passed: sec-fetch-site is same-origin or none");
245                return Ok(());
246            } else if self.is_exempt(req) {
247                return Ok(());
248            } else {
249                return Err(ProtectionError::CrossOriginRequest);
250            }
251        }
252
253        match req.headers().get("origin").and_then(|h| h.to_str().ok()) {
254            Some("null") => {}
255            Some(origin) => {
256                let origin = Url::parse(origin).map_err(ProtectionError::MalformedOrigin)?;
257
258                let origin_host = origin.host_str();
259                let host = req.headers().get("host").and_then(|h| h.to_str().ok());
260
261                // the origin header matches the host header. note that the host header
262                // doesn't include the scheme, so we don't know if this might be an
263                // http→https cross-origin request. we fail open, since all modern
264                // browsers support sec-fetch-site since 2023, and running an older
265                // browser makes a clear security trade-off already. sites can mitigate
266                // this with http strict transport security (hsts).
267
268                match (origin_host, host) {
269                    (Some(origin_host), Some(host)) if origin_host == host => {
270                        trace!("request passed: origin is same as host - ");
271                        return Ok(());
272                    }
273                    _ => {}
274                }
275            }
276            None => {
277                trace!("request passed: neither sec-fetch-site nor origin header (same-origin or not a browser request)");
278                return Ok(());
279            }
280        }
281
282        if self.is_exempt(req) {
283            return Ok(());
284        }
285
286        Err(ProtectionError::CrossOriginRequestFromOldBrowser)
287    }
288}
289
#[cfg(test)]
mod tests {
    use tracing::Level;

    use super::*;
    use std::sync::Once;

    static INIT: Once = Once::new();

    /// Installs a TRACE-level global subscriber exactly once, so test runs show the middleware's
    /// trace output.
    fn init() {
        INIT.call_once(|| {
            tracing_subscriber::fmt()
                .with_max_level(Level::TRACE)
                .init();
        });
    }

    #[test]
    fn test_url_path_normalization() {
        for url in ["https://example.com/", "https://example.com"] {
            let url = Url::parse(url).unwrap();
            assert_eq!(url.path(), "/");
        }
    }

    #[test]
    fn test_layer_add_trusted_origin() {
        init();

        assert!(matches!(
            CrossOriginProtectionLayer::default().add_trusted_origin("https://example.com"),
            Ok(_)
        ));

        for origin in ["not a valid url", "example.com", "https://"] {
            assert!(matches!(
                CrossOriginProtectionLayer::default().add_trusted_origin(origin),
                Err(ConfigError::InvalidOriginUrl(_))
            ));
        }

        for origin in [
            "https://example.com/path",
            "https://example.com/path?query=value",
            "https://example.com/path#fragment",
        ] {
            // Bind the payload under a different name: reusing `origin` would shadow the loop
            // variable and turn the guard into the tautology `origin == origin`.
            assert!(matches!(
                CrossOriginProtectionLayer::default().add_trusted_origin(origin),
                Err(ConfigError::InvalidOriginUrlComponents { origin: rejected }) if rejected == origin
            ));
        }
    }

    #[test]
    fn test_middleware_debug_trait() {
        init();

        let layer = CrossOriginProtectionLayer::default();

        let middleware = layer
            .clone()
            .with_insecure_bypass(|method, uri| method == Method::POST && uri.path() == "/bypass")
            .layer(());

        assert_eq!(
            format!("{:?}", middleware),
            "CrossOriginProtectionMiddleware { inner: (), insecure_bypass: Some(<fn>), trusted_origins: Origins({}) }"
        );

        let middleware = layer.layer(());

        assert_eq!(
            format!("{:?}", middleware),
            "CrossOriginProtectionMiddleware { inner: (), insecure_bypass: None, trusted_origins: Origins({}) }"
        );
    }

    #[test]
    fn test_middleware_sec_fetch_site() {
        init();

        let middleware: CrossOriginProtectionMiddleware<()> = Default::default();

        struct Test {
            name: &'static str,
            method: http::Method,
            sec_fetch_site: Option<&'static str>,
            origin: Option<&'static str>,
            result: Result<(), ProtectionError>,
        }

        let tests = [
            Test {
                // POST, so the request actually reaches the Sec-Fetch-Site check.
                name: "same-origin allowed",
                method: Method::POST,
                sec_fetch_site: Some("same-origin"),
                origin: None,
                result: Ok(()),
            },
            Test {
                name: "none allowed",
                method: Method::POST,
                sec_fetch_site: Some("none"),
                origin: None,
                result: Ok(()),
            },
            Test {
                name: "cross-site blocked",
                method: Method::POST,
                sec_fetch_site: Some("cross-site"),
                origin: None,
                result: Err(ProtectionError::CrossOriginRequest),
            },
            Test {
                name: "same-site blocked",
                method: Method::POST,
                sec_fetch_site: Some("same-site"),
                origin: None,
                result: Err(ProtectionError::CrossOriginRequest),
            },
            Test {
                name: "no header with no origin",
                method: Method::POST,
                sec_fetch_site: None,
                origin: None,
                result: Ok(()),
            },
            Test {
                name: "no header with matching origin",
                method: Method::POST,
                sec_fetch_site: None,
                origin: Some("https://example.com"),
                result: Ok(()),
            },
            Test {
                name: "no header with mismatched origin",
                method: Method::POST,
                sec_fetch_site: None,
                origin: Some("https://attacker.example"),
                result: Err(ProtectionError::CrossOriginRequestFromOldBrowser),
            },
            Test {
                name: "no header with null origin",
                method: Method::POST,
                sec_fetch_site: None,
                origin: Some("null"),
                result: Err(ProtectionError::CrossOriginRequestFromOldBrowser),
            },
            Test {
                name: "GET allowed",
                method: Method::GET,
                sec_fetch_site: Some("cross-site"),
                origin: None,
                result: Ok(()),
            },
            Test {
                name: "HEAD allowed",
                method: Method::HEAD,
                sec_fetch_site: Some("cross-site"),
                origin: None,
                result: Ok(()),
            },
            Test {
                name: "OPTIONS allowed",
                method: Method::OPTIONS,
                sec_fetch_site: Some("cross-site"),
                origin: None,
                result: Ok(()),
            },
            Test {
                name: "PUT blocked",
                method: Method::PUT,
                sec_fetch_site: Some("cross-site"),
                origin: None,
                result: Err(ProtectionError::CrossOriginRequest),
            },
        ];

        for test in tests {
            let mut req = Request::builder()
                .method(test.method)
                .header("host", "example.com");

            if let Some(sec_fetch_site) = test.sec_fetch_site {
                req = req.header("sec-fetch-site", sec_fetch_site);
            }

            if let Some(origin) = test.origin {
                req = req.header("origin", origin);
            }

            let req = req.body(()).unwrap();

            assert_eq!(middleware.verify(&req), test.result, "{}", test.name);
        }
    }

    #[test]
    fn test_middleware_trusted_origin_bypass() {
        init();

        let layer = CrossOriginProtectionLayer::default()
            .add_trusted_origin("https://trusted.example")
            .unwrap();

        let middleware = layer.layer(());

        struct Test {
            name: &'static str,
            sec_fetch_site: Option<&'static str>,
            origin: Option<&'static str>,
            result: Result<(), ProtectionError>,
        }

        let tests = [
            Test {
                name: "trusted origin without sec-fetch-site",
                origin: Some("https://trusted.example"),
                sec_fetch_site: None,
                result: Ok(()),
            },
            Test {
                name: "trusted origin with cross-site",
                origin: Some("https://trusted.example"),
                sec_fetch_site: Some("cross-site"),
                result: Ok(()),
            },
            Test {
                name: "untrusted origin without sec-fetch-site",
                origin: Some("https://attacker.example"),
                sec_fetch_site: None,
                result: Err(ProtectionError::CrossOriginRequestFromOldBrowser),
            },
            Test {
                name: "untrusted origin with cross-site",
                origin: Some("https://attacker.example"),
                sec_fetch_site: Some("cross-site"),
                result: Err(ProtectionError::CrossOriginRequest),
            },
        ];

        for test in tests {
            let mut req = Request::builder()
                .method("POST")
                .header("host", "example.com");

            if let Some(sec_fetch_site) = test.sec_fetch_site {
                req = req.header("sec-fetch-site", sec_fetch_site);
            }

            if let Some(origin) = test.origin {
                req = req.header("origin", origin);
            }

            let req = req.body(()).unwrap();

            assert_eq!(middleware.verify(&req), test.result, "{}", test.name);
        }
    }

    #[test]
    fn test_middleware_bypass() {
        init();

        let layer = CrossOriginProtectionLayer::default()
            .with_insecure_bypass(|_method, uri| -> bool { uri.path() == "/bypass" });

        let middleware = layer.layer(());

        struct Test {
            name: &'static str,
            path: &'static str,
            sec_fetch_site: Option<&'static str>,
            result: Result<(), ProtectionError>,
        }

        let tests = [
            Test {
                name: "bypass path without sec-fetch-site",
                path: "/bypass",
                sec_fetch_site: None,
                result: Ok(()),
            },
            Test {
                name: "bypass path with cross-site",
                path: "/bypass",
                sec_fetch_site: Some("cross-site"),
                result: Ok(()),
            },
            Test {
                name: "non-bypass path without sec-fetch-site",
                path: "/api",
                sec_fetch_site: None,
                result: Err(ProtectionError::CrossOriginRequestFromOldBrowser),
            },
            Test {
                name: "non-bypass path with cross-site",
                path: "/api",
                sec_fetch_site: Some("cross-site"),
                result: Err(ProtectionError::CrossOriginRequest),
            },
        ];

        for test in tests {
            let mut req = Request::builder()
                .method("POST")
                .header("host", "example.com")
                .header("origin", "https://attacker.example")
                .uri(format!("https://example.com{}", test.path));

            if let Some(sec_fetch_site) = test.sec_fetch_site {
                req = req.header("sec-fetch-site", sec_fetch_site);
            }

            let req = req.body(()).unwrap();

            assert_eq!(middleware.verify(&req), test.result, "{}", test.name);
        }
    }
}