better_auth_core/middleware/
csrf.rs1use super::Middleware;
2use crate::error::AuthResult;
3use crate::types::{AuthRequest, AuthResponse, HttpMethod};
4use async_trait::async_trait;
5
/// Settings for the CSRF origin check.
#[derive(Debug, Clone)]
pub struct CsrfConfig {
    /// Origins (or full URLs) that are accepted in addition to the
    /// application's own base origin.
    pub trusted_origins: Vec<String>,

    /// Master switch: when `false` the middleware lets every request through.
    pub enabled: bool,
}
15
16impl Default for CsrfConfig {
17 fn default() -> Self {
18 Self {
19 trusted_origins: Vec::new(),
20 enabled: true,
21 }
22 }
23}
24
25impl CsrfConfig {
26 pub fn new() -> Self {
27 Self::default()
28 }
29
30 pub fn trusted_origin(mut self, origin: impl Into<String>) -> Self {
31 self.trusted_origins.push(origin.into());
32 self
33 }
34
35 pub fn enabled(mut self, enabled: bool) -> Self {
36 self.enabled = enabled;
37 self
38 }
39}
40
41pub struct CsrfMiddleware {
47 config: CsrfConfig,
48 base_origin: String,
50}
51
52impl CsrfMiddleware {
53 pub fn new(config: CsrfConfig, base_url: &str) -> Self {
54 let base_origin = extract_origin(base_url).unwrap_or_default();
55 Self {
56 config,
57 base_origin,
58 }
59 }
60
61 fn is_state_changing(method: &HttpMethod) -> bool {
62 matches!(
63 method,
64 HttpMethod::Post | HttpMethod::Put | HttpMethod::Delete | HttpMethod::Patch
65 )
66 }
67
68 fn is_origin_trusted(&self, origin: &str) -> bool {
69 if origin == self.base_origin {
70 return true;
71 }
72 self.config.trusted_origins.iter().any(|trusted| {
73 let trusted_origin = extract_origin(trusted).unwrap_or_default();
74 origin == trusted_origin
75 })
76 }
77}
78
79#[async_trait]
80impl Middleware for CsrfMiddleware {
81 fn name(&self) -> &'static str {
82 "csrf"
83 }
84
85 async fn before_request(&self, req: &AuthRequest) -> AuthResult<Option<AuthResponse>> {
86 if !self.config.enabled {
87 return Ok(None);
88 }
89
90 if !Self::is_state_changing(&req.method) {
92 return Ok(None);
93 }
94
95 let request_origin = req
97 .headers
98 .get("origin")
99 .cloned()
100 .or_else(|| req.headers.get("referer").and_then(|r| extract_origin(r)));
101
102 match request_origin {
103 Some(origin) if self.is_origin_trusted(&origin) => Ok(None),
104 Some(_origin) => Ok(Some(AuthResponse::json(
105 403,
106 &serde_json::json!({
107 "code": "CSRF_ERROR",
108 "message": "Cross-site request blocked"
109 }),
110 )?)),
111 None => Ok(None),
115 }
116 }
117}
118
/// Reduces a URL to its origin, i.e. `scheme://authority`.
///
/// The authority ends at the first `/`, `?` or `#` after `://`, so a URL
/// that has a query string or fragment but no path (for example
/// `https://host?x=1`) still yields `https://host`.
///
/// Returns `None` only when `url` contains no `://` separator. Any input that
/// does contain it yields `Some`, even with an empty host; the middleware
/// treats a missing origin as "allow", so this function must not become more
/// eager to return `None`.
fn extract_origin(url: &str) -> Option<String> {
    let (scheme, rest) = url.split_once("://")?;
    // Path, query and fragment are not part of an origin.
    let authority_end = rest
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(rest.len());
    Some(format!("{}://{}", scheme, &rest[..authority_end]))
}
128
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Base URL shared by every middleware instance in these tests.
    const BASE_URL: &str = "http://localhost:3000";

    /// Assembles a body-less, query-less request from a method, a path and a
    /// list of header name/value pairs.
    fn build_request(method: HttpMethod, path: &str, header_pairs: &[(&str, &str)]) -> AuthRequest {
        AuthRequest {
            method,
            path: path.to_string(),
            headers: header_pairs
                .iter()
                .map(|&(name, value)| (name.to_string(), value.to_string()))
                .collect(),
            body: None,
            query: HashMap::new(),
        }
    }

    /// JSON `POST` to the sign-in endpoint, optionally carrying an `origin`
    /// header.
    fn post_with_origin(origin: Option<&str>) -> AuthRequest {
        let mut pairs = vec![("content-type", "application/json")];
        pairs.extend(origin.map(|value| ("origin", value)));
        build_request(HttpMethod::Post, "/sign-in/email", &pairs)
    }

    #[tokio::test]
    async fn test_csrf_allows_same_origin() {
        let middleware = CsrfMiddleware::new(CsrfConfig::new(), BASE_URL);
        let request = post_with_origin(Some("http://localhost:3000"));
        let outcome = middleware.before_request(&request).await.unwrap();
        assert!(outcome.is_none());
    }

    #[tokio::test]
    async fn test_csrf_blocks_cross_origin() {
        let middleware = CsrfMiddleware::new(CsrfConfig::new(), BASE_URL);
        let request = post_with_origin(Some("http://evil.com"));
        let outcome = middleware.before_request(&request).await.unwrap();
        let rejection = outcome.expect("cross-origin POST should be rejected");
        assert_eq!(rejection.status, 403);
    }

    #[tokio::test]
    async fn test_csrf_allows_trusted_origin() {
        let middleware = CsrfMiddleware::new(
            CsrfConfig::new().trusted_origin("https://myapp.com"),
            BASE_URL,
        );
        let request = post_with_origin(Some("https://myapp.com"));
        let outcome = middleware.before_request(&request).await.unwrap();
        assert!(outcome.is_none());
    }

    #[tokio::test]
    async fn test_csrf_skips_get_requests() {
        let middleware = CsrfMiddleware::new(CsrfConfig::new(), BASE_URL);
        let request = build_request(
            HttpMethod::Get,
            "/get-session",
            &[("origin", "http://evil.com")],
        );
        let outcome = middleware.before_request(&request).await.unwrap();
        assert!(outcome.is_none());
    }

    #[tokio::test]
    async fn test_csrf_allows_no_origin_header() {
        let middleware = CsrfMiddleware::new(CsrfConfig::new(), BASE_URL);
        let request = post_with_origin(None);
        let outcome = middleware.before_request(&request).await.unwrap();
        assert!(outcome.is_none());
    }

    #[tokio::test]
    async fn test_csrf_disabled() {
        let middleware = CsrfMiddleware::new(CsrfConfig::new().enabled(false), BASE_URL);
        let request = post_with_origin(Some("http://evil.com"));
        let outcome = middleware.before_request(&request).await.unwrap();
        assert!(outcome.is_none());
    }
}