Skip to main content

tower_http/csrf/
service.rs

1use std::convert::TryFrom;
2use std::fmt::{self, Debug, Formatter};
3use std::sync::Arc;
4use std::task::{Context, Poll};
5
6use http::{Method, Request, Response, Uri};
7use tower_service::Service;
8
9use super::future::ResponseFuture;
10use super::{
11    BypassFn, DebugFn, DefaultResponseForProtectionError, Origins, ProtectionError,
12    ProtectionErrorKind, ResponseForProtectionError,
13};
14
15/// Middleware that enforces cross-origin request forgery (CSRF) protection.
16///
17/// See the [module docs](crate::csrf) for an example.
18#[derive(Clone)]
19#[must_use]
20pub struct Csrf<S, T = DefaultResponseForProtectionError> {
21    inner: S,
22    insecure_bypass: Option<Arc<BypassFn>>,
23    rejection_response: T,
24    trusted_origins: Origins,
25}
26
27impl<S, T> Csrf<S, T> {
28    pub(super) fn new(
29        inner: S,
30        insecure_bypass: Option<Arc<BypassFn>>,
31        rejection_response: T,
32        trusted_origins: Origins,
33    ) -> Self {
34        Self {
35            inner,
36            insecure_bypass,
37            rejection_response,
38            trusted_origins,
39        }
40    }
41
42    pub(super) fn verify<Body>(&self, req: &Request<Body>) -> Result<(), ProtectionError> {
43        // Deliberately not Method::is_safe: it also treats TRACE as safe, but the
44        // reference implementation only exempts GET/HEAD/OPTIONS, so we match it here.
45        if matches!(
46            req.method(),
47            &Method::GET | &Method::HEAD | &Method::OPTIONS
48        ) {
49            #[cfg(feature = "tracing")]
50            tracing::trace!(uri = %req.uri().path(), "request passed: safe method");
51            return Ok(());
52        }
53
54        let origin = req.headers().get("origin").map(|h| h.as_bytes());
55
56        let origin_uri = origin
57            .filter(|b| !b.is_empty())
58            .and_then(|b| Uri::try_from(b).ok())
59            .filter(|u| matches!(u.scheme_str(), Some("http" | "https")));
60
61        let sec_fetch_site = req.headers().get("sec-fetch-site").map(|h| h.as_bytes());
62
63        let is_exempt = || -> bool {
64            let bypass = self
65                .insecure_bypass
66                .as_ref()
67                .map_or(false, |bypass| bypass(req.method(), req.uri()));
68
69            if bypass {
70                #[cfg(feature = "tracing")]
71                tracing::trace!(uri = %req.uri().path(), "request passed: bypassed");
72                return true;
73            }
74
75            // Strict byte match of the raw Origin header against the registered
76            // set, mirroring the Go reference's `trustedOrigins[Origin]`.
77            let trusted = origin.map_or(false, |b| self.trusted_origins.contains(b));
78
79            if trusted {
80                #[cfg(feature = "tracing")]
81                tracing::trace!(uri = %req.uri().path(), "request passed: trusted origin");
82                return true;
83            }
84
85            false
86        };
87
88        // Fetch spec mandates lowercase here; exact byte match is intentional.
89        match sec_fetch_site {
90            Some(b"same-origin" | b"none") => {
91                #[cfg(feature = "tracing")]
92                tracing::trace!(uri = %req.uri().path(), "request passed: sec-fetch-site is same-origin or none");
93                return Ok(());
94            }
95            None | Some(b"") => {} // fall through to Origin check
96            Some(_) if is_exempt() => return Ok(()),
97            Some(_) => {
98                return Err(ProtectionError::new(
99                    ProtectionErrorKind::CrossOriginRequest,
100                ))
101            }
102        }
103
104        if matches!(origin, None | Some(b"")) {
105            #[cfg(feature = "tracing")]
106            tracing::trace!(uri = %req.uri().path(), "request passed: neither sec-fetch-site nor origin header (same-origin or not a browser request)");
107            return Ok(());
108        }
109
110        let host = req.headers().get("host").map(|h| h.as_bytes());
111
112        // Mirrors the reference's `url.Parse(origin).Host == req.Host`. Per RFC 7230
113        // §5.3, req.Host is the request-target authority (absolute-form URI / HTTP/2
114        // `:authority`) if present, else the Host header. Byte-exact and scheme-blind,
115        // so an http→https mismatch can't be caught here — we fail open (HSTS helps).
116        let effective_host = req
117            .uri()
118            .authority()
119            .map(|a| a.as_str().as_bytes())
120            .or(host);
121
122        if let (Some(uri), Some(effective_host)) = (&origin_uri, effective_host) {
123            if uri.authority().map(|a| a.as_str().as_bytes()) == Some(effective_host) {
124                #[cfg(feature = "tracing")]
125                tracing::trace!(uri = %req.uri().path(), "request passed: origin is same as host");
126                return Ok(());
127            }
128        }
129
130        if is_exempt() {
131            return Ok(());
132        }
133
134        Err(ProtectionError::new(
135            ProtectionErrorKind::CrossOriginRequestFromOldBrowser,
136        ))
137    }
138}
139
140impl<S, T> Default for Csrf<S, T>
141where
142    S: Default,
143    T: Default,
144{
145    fn default() -> Self {
146        Self {
147            inner: S::default(),
148            insecure_bypass: None,
149            rejection_response: T::default(),
150            trusted_origins: Origins::default(),
151        }
152    }
153}
154
155impl<S: Debug, T> Debug for Csrf<S, T> {
156    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
157        f.debug_struct("Csrf")
158            .field("inner", &self.inner)
159            .field(
160                "insecure_bypass",
161                &self.insecure_bypass.as_ref().map(|_| DebugFn),
162            )
163            .field("trusted_origins", &self.trusted_origins)
164            .field("rejection_response", &DebugFn)
165            .finish()
166    }
167}
168
169impl<S, T, ReqBody, ResBody> Service<Request<ReqBody>> for Csrf<S, T>
170where
171    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
172    T: ResponseForProtectionError<ResBody>,
173{
174    type Error = S::Error;
175    type Future = ResponseFuture<S::Future>;
176    type Response = Response<ResBody>;
177
178    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
179        match self.verify(&req) {
180            Ok(_) => ResponseFuture::future(self.inner.call(req)),
181            Err(err) => {
182                #[cfg(feature = "tracing")]
183                tracing::trace!(uri = %req.uri().path(), error = %err, "request rejected");
184
185                let mut response = self
186                    .rejection_response
187                    .response_for_protection_error(err.clone());
188
189                response.extensions_mut().insert(err);
190
191                ResponseFuture::rejected(Ok(response))
192            }
193        }
194    }
195
196    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
197        self.inner.poll_ready(cx)
198    }
199}
200
#[cfg(test)]
mod tests {
    use super::*;

    // Backs the comment in `verify`: the reference implementation exempts only
    // GET/HEAD/OPTIONS, whereas `Method::is_safe` also admits TRACE, so relying on
    // it would widen the exempt set.
    #[test]
    fn method_is_safe_covers_more_than_get_head_options() {
        let reference_exempt = [Method::GET, Method::HEAD, Method::OPTIONS];
        assert!(reference_exempt.iter().all(Method::is_safe));

        // RFC 7231 classifies TRACE as safe, yet the reference does not exempt it.
        assert!(Method::TRACE.is_safe());
    }
}