// better_auth_core/middleware/csrf.rs
use super::Middleware;
use crate::config::{AuthConfig, extract_origin};
use crate::error::AuthResult;
use crate::types::{AuthRequest, AuthResponse, HttpMethod};
use async_trait::async_trait;
use std::sync::Arc;

/// Settings for the CSRF protection middleware.
#[derive(Debug, Clone)]
pub struct CsrfConfig {
    /// When `false`, [`CsrfMiddleware`] performs no origin check at all.
    pub enabled: bool,
}

15impl Default for CsrfConfig {
16 fn default() -> Self {
17 Self { enabled: true }
18 }
19}
20
21impl CsrfConfig {
22 pub fn new() -> Self {
23 Self::default()
24 }
25
26 pub fn enabled(mut self, enabled: bool) -> Self {
27 self.enabled = enabled;
28 self
29 }
30}
31
32pub struct CsrfMiddleware {
41 config: CsrfConfig,
42 auth_config: Arc<AuthConfig>,
44}
45
46impl CsrfMiddleware {
47 pub fn new(config: CsrfConfig, auth_config: Arc<AuthConfig>) -> Self {
51 Self {
52 config,
53 auth_config,
54 }
55 }
56
57 fn is_state_changing(method: &HttpMethod) -> bool {
58 matches!(
59 method,
60 HttpMethod::Post | HttpMethod::Put | HttpMethod::Delete | HttpMethod::Patch
61 )
62 }
63}
64
65#[async_trait]
66impl Middleware for CsrfMiddleware {
67 fn name(&self) -> &'static str {
68 "csrf"
69 }
70
71 async fn before_request(&self, req: &AuthRequest) -> AuthResult<Option<AuthResponse>> {
72 if !self.config.enabled {
73 return Ok(None);
74 }
75
76 if !Self::is_state_changing(&req.method) {
78 return Ok(None);
79 }
80
81 let request_origin = req
83 .headers
84 .get("origin")
85 .cloned()
86 .or_else(|| req.headers.get("referer").and_then(|r| extract_origin(r)));
87
88 match request_origin {
89 Some(origin) if self.auth_config.is_origin_trusted(&origin) => Ok(None),
90 Some(_origin) => Ok(Some(AuthResponse::json(
91 403,
92 &crate::types::CodeMessageResponse {
93 code: "CSRF_ERROR",
94 message: "Cross-site request blocked".to_string(),
95 },
96 )?)),
97 None => Ok(None),
101 }
102 }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::extract_origin;
    use std::collections::HashMap;

    /// Builds a JSON `POST /sign-in/email` request, optionally with an
    /// `origin` header.
    fn make_post(origin: Option<&str>) -> AuthRequest {
        let mut headers = HashMap::new();
        headers.insert("content-type".to_string(), "application/json".to_string());
        if let Some(o) = origin {
            headers.insert("origin".to_string(), o.to_string());
        }
        AuthRequest {
            method: HttpMethod::Post,
            path: "/sign-in/email".to_string(),
            headers,
            body: None,
            query: HashMap::new(),
            virtual_user_id: None,
        }
    }

    /// Builds the same `POST` request without `origin` but with a `referer`
    /// header, to exercise the fallback path.
    fn make_post_with_referer(referer: &str) -> AuthRequest {
        let mut req = make_post(None);
        req.headers
            .insert("referer".to_string(), referer.to_string());
        req
    }

    /// Auth configuration rooted at `http://localhost:3000` with the given
    /// extra trusted origins.
    fn test_auth_config(trusted_origins: Vec<String>) -> Arc<AuthConfig> {
        Arc::new(
            AuthConfig::new("test-secret-key-that-is-at-least-32-characters-long")
                .base_url("http://localhost:3000")
                .trusted_origins(trusted_origins),
        )
    }

    #[tokio::test]
    async fn test_csrf_allows_same_origin() {
        let mw = CsrfMiddleware::new(CsrfConfig::new(), test_auth_config(vec![]));
        let req = make_post(Some("http://localhost:3000"));
        assert!(mw.before_request(&req).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_csrf_blocks_cross_origin() {
        let mw = CsrfMiddleware::new(CsrfConfig::new(), test_auth_config(vec![]));
        let req = make_post(Some("http://evil.com"));
        let resp = mw.before_request(&req).await.unwrap();
        assert!(resp.is_some());
        assert_eq!(resp.unwrap().status, 403);
    }

    #[tokio::test]
    async fn test_csrf_allows_trusted_origin() {
        let mw = CsrfMiddleware::new(
            CsrfConfig::new(),
            test_auth_config(vec!["https://myapp.com".to_string()]),
        );
        let req = make_post(Some("https://myapp.com"));
        assert!(mw.before_request(&req).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_csrf_allows_glob_trusted_origin() {
        let mw = CsrfMiddleware::new(
            CsrfConfig::new(),
            test_auth_config(vec!["https://*.example.com".to_string()]),
        );
        let req = make_post(Some("https://app.example.com"));
        assert!(mw.before_request(&req).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_csrf_skips_get_requests() {
        let mw = CsrfMiddleware::new(CsrfConfig::new(), test_auth_config(vec![]));
        let req = AuthRequest {
            method: HttpMethod::Get,
            path: "/get-session".to_string(),
            headers: {
                let mut h = HashMap::new();
                h.insert("origin".to_string(), "http://evil.com".to_string());
                h
            },
            body: None,
            query: HashMap::new(),
            virtual_user_id: None,
        };
        assert!(mw.before_request(&req).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_csrf_allows_no_origin_header() {
        let mw = CsrfMiddleware::new(CsrfConfig::new(), test_auth_config(vec![]));
        let req = make_post(None);
        assert!(mw.before_request(&req).await.unwrap().is_none());
    }

    // The `referer` header is only consulted when `origin` is absent.
    #[tokio::test]
    async fn test_csrf_allows_trusted_referer_fallback() {
        let mw = CsrfMiddleware::new(CsrfConfig::new(), test_auth_config(vec![]));
        let req = make_post_with_referer("http://localhost:3000/settings");
        assert!(mw.before_request(&req).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_csrf_blocks_untrusted_referer_fallback() {
        let mw = CsrfMiddleware::new(CsrfConfig::new(), test_auth_config(vec![]));
        let req = make_post_with_referer("http://evil.com/attack");
        let resp = mw.before_request(&req).await.unwrap();
        assert!(resp.is_some());
        assert_eq!(resp.unwrap().status, 403);
    }

    #[tokio::test]
    async fn test_csrf_disabled() {
        let config = CsrfConfig::new().enabled(false);
        let mw = CsrfMiddleware::new(config, test_auth_config(vec![]));
        let req = make_post(Some("http://evil.com"));
        assert!(mw.before_request(&req).await.unwrap().is_none());
    }

    #[test]
    fn test_extract_origin() {
        assert_eq!(
            extract_origin("https://example.com/path"),
            Some("https://example.com".to_string())
        );
        assert_eq!(
            extract_origin("http://localhost:3000"),
            Some("http://localhost:3000".to_string())
        );
        assert_eq!(extract_origin("not-a-url"), None);
    }
}