// Source: better_auth_core/middleware/csrf.rs

1use super::Middleware;
2use crate::config::{AuthConfig, extract_origin};
3use crate::error::AuthResult;
4use crate::types::{AuthRequest, AuthResponse, HttpMethod};
5use async_trait::async_trait;
6use std::sync::Arc;
7
/// Settings for the CSRF protection middleware (`CsrfMiddleware`).
///
/// Construct with [`CsrfConfig::new`] or [`Default::default`] and adjust with
/// the builder-style setter.
#[derive(Debug, Clone)]
pub struct CsrfConfig {
    /// Turns CSRF protection on or off. `true` unless explicitly changed.
    pub enabled: bool,
}

impl Default for CsrfConfig {
    /// Equivalent to [`CsrfConfig::new`]: protection is enabled.
    fn default() -> Self {
        Self::new()
    }
}

impl CsrfConfig {
    /// Create a configuration with CSRF protection enabled.
    pub fn new() -> Self {
        Self { enabled: true }
    }

    /// Builder-style setter: returns the configuration with `enabled`
    /// replaced by the given value.
    pub fn enabled(self, enabled: bool) -> Self {
        let mut updated = self;
        updated.enabled = enabled;
        updated
    }
}
31
32/// CSRF protection middleware.
33///
34/// Validates `Origin` / `Referer` headers on state-changing requests
35/// (POST, PUT, DELETE, PATCH) against the configured trusted origins
36/// and the service's own base URL.
37///
38/// Origin checking is delegated to [`AuthConfig::is_origin_trusted`] so
39/// that all origin-validation logic lives in a single place.
40pub struct CsrfMiddleware {
41    config: CsrfConfig,
42    /// Shared auth configuration used for origin trust checks.
43    auth_config: Arc<AuthConfig>,
44}
45
46impl CsrfMiddleware {
47    /// Create a new CSRF middleware.
48    ///
49    /// Origin trust decisions are delegated to `auth_config`.
50    pub fn new(config: CsrfConfig, auth_config: Arc<AuthConfig>) -> Self {
51        Self {
52            config,
53            auth_config,
54        }
55    }
56
57    fn is_state_changing(method: &HttpMethod) -> bool {
58        matches!(
59            method,
60            HttpMethod::Post | HttpMethod::Put | HttpMethod::Delete | HttpMethod::Patch
61        )
62    }
63}
64
65#[async_trait]
66impl Middleware for CsrfMiddleware {
67    fn name(&self) -> &'static str {
68        "csrf"
69    }
70
71    async fn before_request(&self, req: &AuthRequest) -> AuthResult<Option<AuthResponse>> {
72        if !self.config.enabled {
73            return Ok(None);
74        }
75
76        // Only check state-changing methods
77        if !Self::is_state_changing(&req.method) {
78            return Ok(None);
79        }
80
81        // Check Origin header first, then Referer
82        let request_origin = req
83            .headers
84            .get("origin")
85            .cloned()
86            .or_else(|| req.headers.get("referer").and_then(|r| extract_origin(r)));
87
88        match request_origin {
89            Some(origin) if self.auth_config.is_origin_trusted(&origin) => Ok(None),
90            Some(_origin) => Ok(Some(AuthResponse::json(
91                403,
92                &crate::types::CodeMessageResponse {
93                    code: "CSRF_ERROR",
94                    message: "Cross-site request blocked".to_string(),
95                },
96            )?)),
97            // If no Origin/Referer header is present, allow the request.
98            // This handles same-origin requests from older browsers and
99            // non-browser clients (curl, SDKs, etc.).
100            None => Ok(None),
101        }
102    }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::extract_origin;
    use std::collections::HashMap;

    /// Build a request with `method` and the given `(name, value)` header pairs.
    fn make_request(method: HttpMethod, headers: &[(&str, &str)]) -> AuthRequest {
        AuthRequest {
            method,
            path: "/sign-in/email".to_string(),
            headers: headers
                .iter()
                .map(|&(name, value)| (name.to_string(), value.to_string()))
                .collect(),
            body: None,
            query: HashMap::new(),
            virtual_user_id: None,
        }
    }

    /// JSON POST to the sign-in endpoint, optionally carrying an `Origin` header.
    fn make_post(origin: Option<&str>) -> AuthRequest {
        let mut headers = vec![("content-type", "application/json")];
        if let Some(o) = origin {
            headers.push(("origin", o));
        }
        make_request(HttpMethod::Post, &headers)
    }

    /// Auth config whose base URL is `http://localhost:3000`, plus `trusted_origins`.
    fn test_auth_config(trusted_origins: Vec<String>) -> Arc<AuthConfig> {
        Arc::new(
            AuthConfig::new("test-secret-key-that-is-at-least-32-characters-long")
                .base_url("http://localhost:3000")
                .trusted_origins(trusted_origins),
        )
    }

    /// Assert that the middleware short-circuited with a 403 response.
    fn assert_blocked(resp: Option<AuthResponse>) {
        let resp = resp.expect("request should have been blocked");
        assert_eq!(resp.status, 403);
    }

    #[tokio::test]
    async fn test_csrf_allows_same_origin() {
        let mw = CsrfMiddleware::new(CsrfConfig::new(), test_auth_config(vec![]));
        let req = make_post(Some("http://localhost:3000"));
        assert!(mw.before_request(&req).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_csrf_blocks_cross_origin() {
        let mw = CsrfMiddleware::new(CsrfConfig::new(), test_auth_config(vec![]));
        let req = make_post(Some("http://evil.com"));
        assert_blocked(mw.before_request(&req).await.unwrap());
    }

    #[tokio::test]
    async fn test_csrf_allows_trusted_origin() {
        let mw = CsrfMiddleware::new(
            CsrfConfig::new(),
            test_auth_config(vec!["https://myapp.com".to_string()]),
        );
        let req = make_post(Some("https://myapp.com"));
        assert!(mw.before_request(&req).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_csrf_allows_glob_trusted_origin() {
        let mw = CsrfMiddleware::new(
            CsrfConfig::new(),
            test_auth_config(vec!["https://*.example.com".to_string()]),
        );
        let req = make_post(Some("https://app.example.com"));
        assert!(mw.before_request(&req).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_csrf_skips_get_requests() {
        let mw = CsrfMiddleware::new(CsrfConfig::new(), test_auth_config(vec![]));
        let req = make_request(HttpMethod::Get, &[("origin", "http://evil.com")]);
        assert!(mw.before_request(&req).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_csrf_checks_all_state_changing_methods() {
        let mw = CsrfMiddleware::new(CsrfConfig::new(), test_auth_config(vec![]));
        for method in [HttpMethod::Put, HttpMethod::Delete, HttpMethod::Patch] {
            let req = make_request(method, &[("origin", "http://evil.com")]);
            assert_blocked(mw.before_request(&req).await.unwrap());
        }
    }

    #[tokio::test]
    async fn test_csrf_allows_no_origin_header() {
        let mw = CsrfMiddleware::new(CsrfConfig::new(), test_auth_config(vec![]));
        let req = make_post(None);
        assert!(mw.before_request(&req).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_csrf_allows_trusted_referer_without_origin() {
        let mw = CsrfMiddleware::new(CsrfConfig::new(), test_auth_config(vec![]));
        let req = make_request(
            HttpMethod::Post,
            &[("referer", "http://localhost:3000/dashboard")],
        );
        assert!(mw.before_request(&req).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_csrf_blocks_untrusted_referer_without_origin() {
        let mw = CsrfMiddleware::new(CsrfConfig::new(), test_auth_config(vec![]));
        let req = make_request(HttpMethod::Post, &[("referer", "http://evil.com/attack")]);
        assert_blocked(mw.before_request(&req).await.unwrap());
    }

    #[tokio::test]
    async fn test_csrf_origin_takes_precedence_over_referer() {
        let mw = CsrfMiddleware::new(CsrfConfig::new(), test_auth_config(vec![]));
        let req = make_request(
            HttpMethod::Post,
            &[
                ("origin", "http://evil.com"),
                ("referer", "http://localhost:3000/dashboard"),
            ],
        );
        assert_blocked(mw.before_request(&req).await.unwrap());
    }

    #[tokio::test]
    async fn test_csrf_disabled() {
        let config = CsrfConfig::new().enabled(false);
        let mw = CsrfMiddleware::new(config, test_auth_config(vec![]));
        let req = make_post(Some("http://evil.com"));
        assert!(mw.before_request(&req).await.unwrap().is_none());
    }

    #[test]
    fn test_extract_origin() {
        assert_eq!(
            extract_origin("https://example.com/path"),
            Some("https://example.com".to_string())
        );
        assert_eq!(
            extract_origin("http://localhost:3000"),
            Some("http://localhost:3000".to_string())
        );
        assert_eq!(extract_origin("not-a-url"), None);
    }
}