Skip to main content

modo/middleware/
csrf.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use axum::body::Body;
6use axum::response::IntoResponse;
7use axum_extra::extract::cookie::Key;
8use cookie::{Cookie, CookieJar, SameSite};
9use http::{HeaderValue, Method, Request, Response};
10use serde::Deserialize;
11use tower::{Layer, Service};
12
13/// Configuration for CSRF protection middleware.
14///
15/// Uses the double-submit cookie pattern: a signed HttpOnly cookie holds the
16/// token, and the client must echo the same token back via the configured
17/// header on state-changing requests.
18#[non_exhaustive]
19#[derive(Debug, Clone, Deserialize)]
20#[serde(default)]
21pub struct CsrfConfig {
22    /// Name of the CSRF cookie.
23    pub cookie_name: String,
24    /// Name of the HTTP header that must carry the CSRF token on unsafe requests.
25    pub header_name: String,
26    /// Intended form-field name for the CSRF token. Not currently read by the
27    /// middleware — token validation is header-only. Retained for configuration
28    /// compatibility.
29    pub field_name: String,
30    /// Cookie time-to-live in seconds.
31    pub ttl_secs: u64,
32    /// HTTP methods exempt from CSRF validation.
33    pub exempt_methods: Vec<String>,
34}
35
36impl Default for CsrfConfig {
37    fn default() -> Self {
38        Self {
39            cookie_name: "_csrf".to_string(),
40            header_name: "X-CSRF-Token".to_string(),
41            field_name: "_csrf_token".to_string(),
42            ttl_secs: 21600,
43            exempt_methods: vec!["GET", "HEAD", "OPTIONS"]
44                .into_iter()
45                .map(String::from)
46                .collect(),
47        }
48    }
49}
50
/// Newtype wrapping the plain (unsigned) CSRF token.
///
/// The middleware places it in the request and response extensions of
/// requests whose method is exempt from validation, so handlers and templates
/// can read the freshly issued token.
#[derive(Debug, Clone)]
pub struct CsrfToken(pub String);
55
56/// A [`Layer`] that applies CSRF protection using the double-submit cookie
57/// pattern with signed cookies.
58#[derive(Clone)]
59pub struct CsrfLayer {
60    config: CsrfConfig,
61    key: Key,
62}
63
64impl<S> Layer<S> for CsrfLayer {
65    type Service = CsrfService<S>;
66
67    fn layer(&self, inner: S) -> Self::Service {
68        CsrfService {
69            inner,
70            config: self.config.clone(),
71            key: self.key.clone(),
72        }
73    }
74}
75
76/// The [`Service`] produced by `CsrfLayer`.
77///
78/// For exempt methods (GET, HEAD, OPTIONS by default), generates a new CSRF
79/// token, sets a signed cookie, and injects [`CsrfToken`] into both request
80/// and response extensions.
81///
82/// For unsafe methods (POST, PUT, DELETE, PATCH, etc.), reads the signed
83/// cookie, compares the plain token with the value of the configured header,
84/// and rejects with 403 Forbidden on mismatch.
85#[derive(Clone)]
86pub struct CsrfService<S> {
87    inner: S,
88    config: CsrfConfig,
89    key: Key,
90}
91
92impl<S> CsrfService<S> {
93    /// Signs a token and returns the signed cookie value string.
94    fn sign_token(&self, token: &str) -> String {
95        let mut jar = CookieJar::new();
96        jar.signed_mut(&self.key).add(Cookie::new(
97            self.config.cookie_name.clone(),
98            token.to_string(),
99        ));
100        jar.get(&self.config.cookie_name)
101            .expect("cookie was just added")
102            .value()
103            .to_string()
104    }
105
106    /// Verifies a signed cookie value and returns the plain token if valid.
107    fn verify_token(&self, signed_value: &str) -> Option<String> {
108        let mut jar = CookieJar::new();
109        jar.add_original(Cookie::new(
110            self.config.cookie_name.clone(),
111            signed_value.to_string(),
112        ));
113        jar.signed(&self.key)
114            .get(&self.config.cookie_name)
115            .map(|c: Cookie<'_>| c.value().to_string())
116    }
117
118    /// Builds the Set-Cookie header value for the CSRF cookie.
119    fn build_set_cookie(&self, signed_value: &str) -> String {
120        Cookie::build((self.config.cookie_name.clone(), signed_value.to_string()))
121            .http_only(true)
122            .same_site(SameSite::Lax)
123            .path("/")
124            .max_age(cookie::time::Duration::seconds(self.config.ttl_secs as i64))
125            .build()
126            .to_string()
127    }
128
129    /// Returns `true` if the request method is exempt from CSRF checks.
130    fn is_exempt(&self, method: &Method) -> bool {
131        self.config
132            .exempt_methods
133            .iter()
134            .any(|m| m.eq_ignore_ascii_case(method.as_str()))
135    }
136
137    /// Extracts the token submitted by the client from the configured header.
138    fn extract_submitted_token<B>(&self, request: &Request<B>) -> Option<String> {
139        request
140            .headers()
141            .get(&self.config.header_name)
142            .and_then(|v| v.to_str().ok())
143            .map(|s| s.to_string())
144    }
145
146    /// Extracts the cookie value from the request's Cookie header.
147    fn extract_cookie_value<B>(&self, request: &Request<B>) -> Option<String> {
148        let cookie_header = request.headers().get(http::header::COOKIE)?;
149        let cookie_str = cookie_header.to_str().ok()?;
150
151        for pair in cookie_str.split(';') {
152            let pair = pair.trim();
153            if let Some((name, value)) = pair.split_once('=')
154                && name.trim() == self.config.cookie_name
155            {
156                return Some(value.trim().to_string());
157            }
158        }
159
160        None
161    }
162}
163
164impl<S, ReqBody> Service<Request<ReqBody>> for CsrfService<S>
165where
166    S: Service<Request<ReqBody>, Response = Response<Body>> + Clone + Send + 'static,
167    S::Future: Send + 'static,
168    S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
169    ReqBody: Send + 'static,
170{
171    type Response = Response<Body>;
172    type Error = S::Error;
173    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
174
175    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
176        self.inner.poll_ready(cx)
177    }
178
179    fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
180        // Clone self's inner service for use in the async block (tower pattern)
181        let mut inner = self.inner.clone();
182        std::mem::swap(&mut self.inner, &mut inner);
183
184        let is_exempt = self.is_exempt(request.method());
185
186        if is_exempt {
187            // Generate a new token, sign it, set cookie, inject into extensions
188            let token = crate::id::ulid();
189            let signed_value = self.sign_token(&token);
190            let set_cookie_value = self.build_set_cookie(&signed_value);
191
192            request.extensions_mut().insert(CsrfToken(token.clone()));
193
194            Box::pin(async move {
195                let mut response = inner.call(request).await?;
196
197                if let Ok(header_value) = HeaderValue::from_str(&set_cookie_value) {
198                    response
199                        .headers_mut()
200                        .append(http::header::SET_COOKIE, header_value);
201                }
202
203                response.extensions_mut().insert(CsrfToken(token));
204
205                Ok(response)
206            })
207        } else {
208            // Validate: read signed cookie, verify, compare with submitted token
209            let cookie_value = self.extract_cookie_value(&request);
210            let submitted_token = self.extract_submitted_token(&request);
211
212            let verified = cookie_value
213                .and_then(|signed| self.verify_token(&signed))
214                .zip(submitted_token)
215                .is_some_and(|(cookie_token, header_token)| {
216                    use subtle::ConstantTimeEq;
217                    cookie_token
218                        .as_bytes()
219                        .ct_eq(header_token.as_bytes())
220                        .into()
221                });
222
223            if verified {
224                Box::pin(async move { inner.call(request).await })
225            } else {
226                let header_name = self.config.header_name.clone();
227                Box::pin(async move {
228                    let error = crate::error::Error::forbidden(format!(
229                        "CSRF validation failed: missing or invalid {header_name}"
230                    ));
231                    Ok(error.into_response())
232                })
233            }
234        }
235    }
236}
237
238/// Returns a Tower layer that applies CSRF protection using the
239/// double-submit signed cookie pattern.
240///
241/// # Example
242///
243/// ```rust,no_run
244/// use modo::middleware::{csrf, CsrfConfig};
245/// use modo::cookie::Key;
246///
247/// let config = CsrfConfig::default();
248/// let key = Key::generate();
249/// let layer = csrf(&config, &key);
250/// ```
251pub fn csrf(config: &CsrfConfig, key: &Key) -> CsrfLayer {
252    CsrfLayer {
253        config: config.clone(),
254        key: key.clone(),
255    }
256}