Skip to main content

modo/middleware/
csrf.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use axum::body::Body;
6use axum::response::IntoResponse;
7use axum_extra::extract::cookie::Key;
8use cookie::{Cookie, CookieJar, SameSite};
9use http::{HeaderValue, Method, Request, Response};
10use serde::Deserialize;
11use tower::{Layer, Service};
12
13/// Configuration for CSRF protection middleware.
14///
15/// Uses the double-submit cookie pattern: a signed HttpOnly cookie holds the
16/// token, and the client must echo the same token back via the configured
17/// header on state-changing requests.
18#[non_exhaustive]
19#[derive(Debug, Clone, Deserialize)]
20#[serde(default)]
21pub struct CsrfConfig {
22    /// Name of the CSRF cookie.
23    pub cookie_name: String,
24    /// Name of the HTTP header that must carry the CSRF token on unsafe requests.
25    pub header_name: String,
26    /// Intended form-field name for the CSRF token. Not currently read by the
27    /// middleware — token validation is header-only. Retained for configuration
28    /// compatibility.
29    pub field_name: String,
30    /// Cookie time-to-live in seconds.
31    pub ttl_secs: u64,
32    /// HTTP methods exempt from CSRF validation.
33    pub exempt_methods: Vec<String>,
34}
35
36impl Default for CsrfConfig {
37    fn default() -> Self {
38        Self {
39            cookie_name: "_csrf".to_string(),
40            header_name: "X-CSRF-Token".to_string(),
41            field_name: "_csrf_token".to_string(),
42            ttl_secs: 21600,
43            exempt_methods: vec!["GET", "HEAD", "OPTIONS"]
44                .into_iter()
45                .map(String::from)
46                .collect(),
47        }
48    }
49}
50
/// The plain (unsigned) CSRF token for the current request.
///
/// The middleware stores it in request and response extensions so that
/// handlers and templates can read it and embed it in outgoing pages.
#[derive(Debug, Clone)]
pub struct CsrfToken(pub String);
55
56/// A [`Layer`] that applies CSRF protection using the double-submit cookie
57/// pattern with signed cookies.
58#[derive(Clone)]
59pub struct CsrfLayer {
60    config: CsrfConfig,
61    key: Key,
62}
63
64impl<S> Layer<S> for CsrfLayer {
65    type Service = CsrfService<S>;
66
67    fn layer(&self, inner: S) -> Self::Service {
68        CsrfService {
69            inner,
70            config: self.config.clone(),
71            key: self.key.clone(),
72        }
73    }
74}
75
76/// The [`Service`] produced by `CsrfLayer`.
77///
78/// For exempt methods (GET, HEAD, OPTIONS by default), reuses the token from
79/// the request cookie when its signature verifies — only minting a new token
80/// and emitting `Set-Cookie` on first visit or signature failure. In every
81/// case, [`CsrfToken`] is injected into both request and response extensions.
82/// Keeping the token stable across GETs is required for the double-submit
83/// pattern to survive multi-tab sessions and long-lived pages.
84///
85/// For unsafe methods (POST, PUT, DELETE, PATCH, etc.), reads the signed
86/// cookie, compares the plain token with the value of the configured header,
87/// and rejects with 403 Forbidden on mismatch.
88#[derive(Clone)]
89pub struct CsrfService<S> {
90    inner: S,
91    config: CsrfConfig,
92    key: Key,
93}
94
95impl<S> CsrfService<S> {
96    /// Signs a token and returns the signed cookie value string.
97    fn sign_token(&self, token: &str) -> String {
98        let mut jar = CookieJar::new();
99        jar.signed_mut(&self.key).add(Cookie::new(
100            self.config.cookie_name.clone(),
101            token.to_string(),
102        ));
103        jar.get(&self.config.cookie_name)
104            .expect("cookie was just added")
105            .value()
106            .to_string()
107    }
108
109    /// Verifies a signed cookie value and returns the plain token if valid.
110    fn verify_token(&self, signed_value: &str) -> Option<String> {
111        let mut jar = CookieJar::new();
112        jar.add_original(Cookie::new(
113            self.config.cookie_name.clone(),
114            signed_value.to_string(),
115        ));
116        jar.signed(&self.key)
117            .get(&self.config.cookie_name)
118            .map(|c: Cookie<'_>| c.value().to_string())
119    }
120
121    /// Builds the Set-Cookie header value for the CSRF cookie.
122    fn build_set_cookie(&self, signed_value: &str) -> String {
123        Cookie::build((self.config.cookie_name.clone(), signed_value.to_string()))
124            .http_only(true)
125            .same_site(SameSite::Lax)
126            .path("/")
127            .max_age(cookie::time::Duration::seconds(self.config.ttl_secs as i64))
128            .build()
129            .to_string()
130    }
131
132    /// Returns `true` if the request method is exempt from CSRF checks.
133    fn is_exempt(&self, method: &Method) -> bool {
134        self.config
135            .exempt_methods
136            .iter()
137            .any(|m| m.eq_ignore_ascii_case(method.as_str()))
138    }
139
140    /// Extracts the token submitted by the client from the configured header.
141    fn extract_submitted_token<B>(&self, request: &Request<B>) -> Option<String> {
142        request
143            .headers()
144            .get(&self.config.header_name)
145            .and_then(|v| v.to_str().ok())
146            .map(|s| s.to_string())
147    }
148
149    /// Extracts the cookie value from the request's Cookie header.
150    fn extract_cookie_value<B>(&self, request: &Request<B>) -> Option<String> {
151        let cookie_header = request.headers().get(http::header::COOKIE)?;
152        let cookie_str = cookie_header.to_str().ok()?;
153
154        for pair in cookie_str.split(';') {
155            let pair = pair.trim();
156            if let Some((name, value)) = pair.split_once('=')
157                && name.trim() == self.config.cookie_name
158            {
159                return Some(value.trim().to_string());
160            }
161        }
162
163        None
164    }
165}
166
167impl<S, ReqBody> Service<Request<ReqBody>> for CsrfService<S>
168where
169    S: Service<Request<ReqBody>, Response = Response<Body>> + Clone + Send + 'static,
170    S::Future: Send + 'static,
171    S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
172    ReqBody: Send + 'static,
173{
174    type Response = Response<Body>;
175    type Error = S::Error;
176    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
177
178    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
179        self.inner.poll_ready(cx)
180    }
181
182    fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
183        // Clone self's inner service for use in the async block (tower pattern)
184        let mut inner = self.inner.clone();
185        std::mem::swap(&mut self.inner, &mut inner);
186
187        let is_exempt = self.is_exempt(request.method());
188
189        if is_exempt {
190            // Reuse the existing token when the cookie verifies; only mint +
191            // Set-Cookie on first visit or signature failure. Rotating on every
192            // GET would invalidate tokens already rendered into open tabs and
193            // long-lived pages, breaking the next state-changing request.
194            let existing = self
195                .extract_cookie_value(&request)
196                .and_then(|signed| self.verify_token(&signed));
197
198            let (token, set_cookie_value) = match existing {
199                Some(t) => (t, None),
200                None => {
201                    let t = crate::id::ulid();
202                    let signed = self.sign_token(&t);
203                    let sc = self.build_set_cookie(&signed);
204                    (t, Some(sc))
205                }
206            };
207
208            request.extensions_mut().insert(CsrfToken(token.clone()));
209
210            Box::pin(async move {
211                let mut response = inner.call(request).await?;
212
213                if let Some(sc) = set_cookie_value
214                    && let Ok(header_value) = HeaderValue::from_str(&sc)
215                {
216                    response
217                        .headers_mut()
218                        .append(http::header::SET_COOKIE, header_value);
219                }
220
221                response.extensions_mut().insert(CsrfToken(token));
222
223                Ok(response)
224            })
225        } else {
226            // Validate: read signed cookie, verify, compare with submitted token
227            let cookie_value = self.extract_cookie_value(&request);
228            let submitted_token = self.extract_submitted_token(&request);
229
230            let verified = cookie_value
231                .and_then(|signed| self.verify_token(&signed))
232                .zip(submitted_token)
233                .is_some_and(|(cookie_token, header_token)| {
234                    use subtle::ConstantTimeEq;
235                    cookie_token
236                        .as_bytes()
237                        .ct_eq(header_token.as_bytes())
238                        .into()
239                });
240
241            if verified {
242                Box::pin(async move { inner.call(request).await })
243            } else {
244                let header_name = self.config.header_name.clone();
245                Box::pin(async move {
246                    let error = crate::error::Error::forbidden(format!(
247                        "CSRF validation failed: missing or invalid {header_name}"
248                    ));
249                    Ok(error.into_response())
250                })
251            }
252        }
253    }
254}
255
256/// Returns a Tower layer that applies CSRF protection using the
257/// double-submit signed cookie pattern.
258///
259/// # Example
260///
261/// ```rust,no_run
262/// use modo::middleware::{csrf, CsrfConfig};
263/// use modo::cookie::Key;
264///
265/// let config = CsrfConfig::default();
266/// let key = Key::generate();
267/// let layer = csrf(&config, &key);
268/// ```
269pub fn csrf(config: &CsrfConfig, key: &Key) -> CsrfLayer {
270    CsrfLayer {
271        config: config.clone(),
272        key: key.clone(),
273    }
274}