Skip to main content

better_auth_core/middleware/
csrf.rs

1use super::Middleware;
2use crate::error::AuthResult;
3use crate::types::{AuthRequest, AuthResponse, HttpMethod};
4use async_trait::async_trait;
5
/// Settings for the CSRF protection middleware.
///
/// Start from [`CsrfConfig::new`] (or `Default`) and chain the builder
/// methods, e.g. `CsrfConfig::new().trusted_origin("https://app.example")`.
#[derive(Clone, Debug)]
pub struct CsrfConfig {
    /// Extra origins, besides the service's own base URL, whose
    /// state-changing requests are accepted.
    pub trusted_origins: Vec<String>,

    /// Master switch for the CSRF check; `true` unless explicitly turned off.
    pub enabled: bool,
}

impl Default for CsrfConfig {
    /// Protection switched on, with no extra trusted origins.
    fn default() -> Self {
        CsrfConfig {
            enabled: true,
            trusted_origins: vec![],
        }
    }
}

impl CsrfConfig {
    /// Same as [`CsrfConfig::default`].
    pub fn new() -> Self {
        Default::default()
    }

    /// Append `origin` to the trusted list and hand back the updated config.
    pub fn trusted_origin(mut self, origin: impl Into<String>) -> Self {
        let origin: String = origin.into();
        self.trusted_origins.push(origin);
        self
    }

    /// Turn the CSRF check on or off and hand back the updated config.
    pub fn enabled(self, enabled: bool) -> Self {
        Self { enabled, ..self }
    }
}
40
41/// CSRF protection middleware.
42///
43/// Validates `Origin` / `Referer` headers on state-changing requests
44/// (POST, PUT, DELETE, PATCH) against the configured trusted origins
45/// and the service's own base URL.
46pub struct CsrfMiddleware {
47    config: CsrfConfig,
48    /// Base URL of the service (extracted from AuthConfig at construction).
49    base_origin: String,
50}
51
52impl CsrfMiddleware {
53    pub fn new(config: CsrfConfig, base_url: &str) -> Self {
54        let base_origin = extract_origin(base_url).unwrap_or_default();
55        Self {
56            config,
57            base_origin,
58        }
59    }
60
61    fn is_state_changing(method: &HttpMethod) -> bool {
62        matches!(
63            method,
64            HttpMethod::Post | HttpMethod::Put | HttpMethod::Delete | HttpMethod::Patch
65        )
66    }
67
68    fn is_origin_trusted(&self, origin: &str) -> bool {
69        if origin == self.base_origin {
70            return true;
71        }
72        self.config.trusted_origins.iter().any(|trusted| {
73            let trusted_origin = extract_origin(trusted).unwrap_or_default();
74            origin == trusted_origin
75        })
76    }
77}
78
79#[async_trait]
80impl Middleware for CsrfMiddleware {
81    fn name(&self) -> &'static str {
82        "csrf"
83    }
84
85    async fn before_request(&self, req: &AuthRequest) -> AuthResult<Option<AuthResponse>> {
86        if !self.config.enabled {
87            return Ok(None);
88        }
89
90        // Only check state-changing methods
91        if !Self::is_state_changing(&req.method) {
92            return Ok(None);
93        }
94
95        // Check Origin header first, then Referer
96        let request_origin = req
97            .headers
98            .get("origin")
99            .cloned()
100            .or_else(|| req.headers.get("referer").and_then(|r| extract_origin(r)));
101
102        match request_origin {
103            Some(origin) if self.is_origin_trusted(&origin) => Ok(None),
104            Some(_origin) => Ok(Some(AuthResponse::json(
105                403,
106                &serde_json::json!({
107                    "code": "CSRF_ERROR",
108                    "message": "Cross-site request blocked"
109                }),
110            )?)),
111            // If no Origin/Referer header is present, allow the request.
112            // This handles same-origin requests from older browsers and
113            // non-browser clients (curl, SDKs, etc.).
114            None => Ok(None),
115        }
116    }
117}
118
/// Extract the origin (`scheme://host[:port]`) from a URL string.
///
/// The result is normalized so it compares byte-for-byte with the `Origin`
/// header a browser sends for the same site:
///
/// * surrounding whitespace is ignored;
/// * scheme and host are ASCII-lowercased (both are case-insensitive);
/// * any `user:password@` credentials are dropped;
/// * path, query (`?…`) and fragment (`#…`) are dropped;
/// * the scheme's default port is omitted (`:80` for `http`, `:443` for `https`).
///
/// Returns `None` when `url` has no `://` separator or when the scheme or host
/// is empty. Callers therefore never compare against an empty string.
fn extract_origin(url: &str) -> Option<String> {
    let (scheme, rest) = url.trim().split_once("://")?;
    if scheme.is_empty() {
        return None;
    }
    let scheme = scheme.to_ascii_lowercase();

    // The authority runs up to the first path, query or fragment delimiter.
    let authority_end = rest
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(rest.len());
    let authority = &rest[..authority_end];

    // Drop any `user:password@` prefix; the real host follows the last `@`.
    let host_port = authority
        .rsplit_once('@')
        .map_or(authority, |(_, host_port)| host_port)
        .to_ascii_lowercase();

    // Browsers omit the scheme's default port when serializing an origin.
    let default_port = match scheme.as_str() {
        "http" => Some("80"),
        "https" => Some("443"),
        _ => None,
    };
    let host_and_port = match host_port.rsplit_once(':') {
        Some((host, port)) if Some(port) == default_port => host,
        _ => host_port.as_str(),
    };

    // Reject an empty host, with or without a port (`http://`, `http://:3000`).
    if host_and_port.is_empty() || host_and_port.starts_with(':') {
        return None;
    }
    Some(format!("{}://{}", scheme, host_and_port))
}
128
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    const BASE_URL: &str = "http://localhost:3000";

    /// Build a request with a JSON content type plus the given extra headers.
    fn make_request(method: HttpMethod, path: &str, extra_headers: &[(&str, &str)]) -> AuthRequest {
        let mut headers = HashMap::new();
        headers.insert("content-type".to_string(), "application/json".to_string());
        for (name, value) in extra_headers {
            headers.insert(name.to_string(), value.to_string());
        }
        AuthRequest {
            method,
            path: path.to_string(),
            headers,
            body: None,
            query: HashMap::new(),
        }
    }

    /// Build a sign-in POST with an optional `Origin` header.
    fn make_post(origin: Option<&str>) -> AuthRequest {
        match origin {
            Some(o) => make_request(HttpMethod::Post, "/sign-in/email", &[("origin", o)]),
            None => make_request(HttpMethod::Post, "/sign-in/email", &[]),
        }
    }

    /// Assert that the middleware short-circuits `req` with a 403.
    async fn assert_blocked(mw: &CsrfMiddleware, req: &AuthRequest) {
        let resp = mw.before_request(req).await.unwrap();
        assert_eq!(resp.expect("request should be blocked").status, 403);
    }

    #[tokio::test]
    async fn test_csrf_allows_same_origin() {
        let mw = CsrfMiddleware::new(CsrfConfig::new(), BASE_URL);
        let req = make_post(Some("http://localhost:3000"));
        assert!(mw.before_request(&req).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_csrf_blocks_cross_origin() {
        let mw = CsrfMiddleware::new(CsrfConfig::new(), BASE_URL);
        let req = make_post(Some("http://evil.com"));
        assert_blocked(&mw, &req).await;
    }

    #[tokio::test]
    async fn test_csrf_allows_trusted_origin() {
        let config = CsrfConfig::new().trusted_origin("https://myapp.com");
        let mw = CsrfMiddleware::new(config, BASE_URL);
        let req = make_post(Some("https://myapp.com"));
        assert!(mw.before_request(&req).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_csrf_trusted_origin_ignores_path() {
        let config = CsrfConfig::new().trusted_origin("https://myapp.com/app");
        let mw = CsrfMiddleware::new(config, BASE_URL);
        let req = make_post(Some("https://myapp.com"));
        assert!(mw.before_request(&req).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_csrf_blocks_lookalike_of_trusted_origin() {
        let config = CsrfConfig::new().trusted_origin("https://myapp.com");
        let mw = CsrfMiddleware::new(config, BASE_URL);
        let req = make_post(Some("https://myapp.com.evil.com"));
        assert_blocked(&mw, &req).await;
    }

    #[tokio::test]
    async fn test_csrf_skips_get_requests() {
        let mw = CsrfMiddleware::new(CsrfConfig::new(), BASE_URL);
        let req = AuthRequest {
            method: HttpMethod::Get,
            path: "/get-session".to_string(),
            headers: {
                let mut h = HashMap::new();
                h.insert("origin".to_string(), "http://evil.com".to_string());
                h
            },
            body: None,
            query: HashMap::new(),
        };
        assert!(mw.before_request(&req).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_csrf_checks_all_state_changing_methods() {
        let mw = CsrfMiddleware::new(CsrfConfig::new(), BASE_URL);
        for method in vec![HttpMethod::Put, HttpMethod::Delete, HttpMethod::Patch] {
            let req = make_request(method, "/update-user", &[("origin", "http://evil.com")]);
            assert_blocked(&mw, &req).await;
        }
    }

    #[tokio::test]
    async fn test_csrf_allows_no_origin_header() {
        let mw = CsrfMiddleware::new(CsrfConfig::new(), BASE_URL);
        let req = make_post(None);
        assert!(mw.before_request(&req).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_csrf_blocks_null_origin() {
        let mw = CsrfMiddleware::new(CsrfConfig::new(), BASE_URL);
        let req = make_post(Some("null"));
        assert_blocked(&mw, &req).await;
    }

    #[tokio::test]
    async fn test_csrf_allows_same_origin_referer() {
        let mw = CsrfMiddleware::new(CsrfConfig::new(), BASE_URL);
        let req = make_request(
            HttpMethod::Post,
            "/sign-in/email",
            &[("referer", "http://localhost:3000/login?next=/home")],
        );
        assert!(mw.before_request(&req).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_csrf_blocks_cross_origin_referer() {
        let mw = CsrfMiddleware::new(CsrfConfig::new(), BASE_URL);
        let req = make_request(
            HttpMethod::Post,
            "/sign-in/email",
            &[("referer", "http://evil.com/attack")],
        );
        assert_blocked(&mw, &req).await;
    }

    #[tokio::test]
    async fn test_csrf_origin_takes_precedence_over_referer() {
        let mw = CsrfMiddleware::new(CsrfConfig::new(), BASE_URL);
        let req = make_request(
            HttpMethod::Post,
            "/sign-in/email",
            &[("origin", "http://evil.com"), ("referer", "http://localhost:3000/")],
        );
        assert_blocked(&mw, &req).await;
    }

    #[tokio::test]
    async fn test_csrf_disabled() {
        let config = CsrfConfig::new().enabled(false);
        let mw = CsrfMiddleware::new(config, BASE_URL);
        let req = make_post(Some("http://evil.com"));
        assert!(mw.before_request(&req).await.unwrap().is_none());
    }
}