better_auth_core/middleware/
cors.rs1use super::Middleware;
2use crate::error::AuthResult;
3use crate::types::{AuthRequest, AuthResponse, HttpMethod};
4use async_trait::async_trait;
5
/// Cross-origin resource sharing settings consumed by [`CorsMiddleware`].
#[derive(Clone, Debug)]
pub struct CorsConfig {
    /// Origins that may receive CORS headers. `"*"` matches any origin;
    /// an empty list denies every origin.
    pub allowed_origins: Vec<String>,

    /// Methods advertised in `Access-Control-Allow-Methods` (header omitted when empty).
    pub allowed_methods: Vec<String>,

    /// Request headers advertised in `Access-Control-Allow-Headers` (omitted when empty).
    pub allowed_headers: Vec<String>,

    /// Response headers advertised in `Access-Control-Expose-Headers` (omitted when empty).
    pub exposed_headers: Vec<String>,

    /// Emit `Access-Control-Allow-Credentials: true` and echo the request
    /// origin instead of `*`.
    pub allow_credentials: bool,

    /// Value sent in `Access-Control-Max-Age`, in seconds.
    pub max_age: u64,

    /// When `false`, the middleware leaves requests and responses untouched.
    pub enabled: bool,
}

32impl Default for CorsConfig {
33 fn default() -> Self {
34 Self {
35 allowed_origins: Vec::new(),
36 allowed_methods: vec![
37 "GET".into(),
38 "POST".into(),
39 "PUT".into(),
40 "DELETE".into(),
41 "PATCH".into(),
42 "OPTIONS".into(),
43 ],
44 allowed_headers: vec![
45 "Content-Type".into(),
46 "Authorization".into(),
47 "X-Requested-With".into(),
48 ],
49 exposed_headers: Vec::new(),
50 allow_credentials: true,
51 max_age: 86400,
52 enabled: true,
53 }
54 }
55}
56
57impl CorsConfig {
58 pub fn new() -> Self {
59 Self::default()
60 }
61
62 pub fn allowed_origin(mut self, origin: impl Into<String>) -> Self {
63 self.allowed_origins.push(origin.into());
64 self
65 }
66
67 pub fn allow_credentials(mut self, allow: bool) -> Self {
68 self.allow_credentials = allow;
69 self
70 }
71
72 pub fn max_age(mut self, seconds: u64) -> Self {
73 self.max_age = seconds;
74 self
75 }
76
77 pub fn enabled(mut self, enabled: bool) -> Self {
78 self.enabled = enabled;
79 self
80 }
81}
82
83pub struct CorsMiddleware {
87 config: CorsConfig,
88}
89
90impl CorsMiddleware {
91 pub fn new(config: CorsConfig) -> Self {
92 Self { config }
93 }
94
95 fn is_origin_allowed(&self, origin: &str) -> bool {
96 if self.config.allowed_origins.is_empty() {
97 return false;
98 }
99 self.config
100 .allowed_origins
101 .iter()
102 .any(|o| o == "*" || o == origin)
103 }
104
105 fn cors_headers(&self, origin: &str) -> Vec<(String, String)> {
106 let mut headers = Vec::new();
107
108 let allow_origin = if self.config.allow_credentials {
110 origin.to_string()
111 } else if self.config.allowed_origins.contains(&"*".to_string()) {
112 "*".to_string()
113 } else {
114 origin.to_string()
115 };
116
117 headers.push(("Access-Control-Allow-Origin".into(), allow_origin));
118
119 if self.config.allow_credentials {
120 headers.push(("Access-Control-Allow-Credentials".into(), "true".into()));
121 }
122
123 if !self.config.allowed_methods.is_empty() {
124 headers.push((
125 "Access-Control-Allow-Methods".into(),
126 self.config.allowed_methods.join(", "),
127 ));
128 }
129
130 if !self.config.allowed_headers.is_empty() {
131 headers.push((
132 "Access-Control-Allow-Headers".into(),
133 self.config.allowed_headers.join(", "),
134 ));
135 }
136
137 if !self.config.exposed_headers.is_empty() {
138 headers.push((
139 "Access-Control-Expose-Headers".into(),
140 self.config.exposed_headers.join(", "),
141 ));
142 }
143
144 headers.push((
145 "Access-Control-Max-Age".into(),
146 self.config.max_age.to_string(),
147 ));
148
149 headers
150 }
151}
152
153#[async_trait]
154impl Middleware for CorsMiddleware {
155 fn name(&self) -> &'static str {
156 "cors"
157 }
158
159 async fn before_request(&self, req: &AuthRequest) -> AuthResult<Option<AuthResponse>> {
160 if !self.config.enabled {
161 return Ok(None);
162 }
163
164 let origin = match req.headers.get("origin") {
165 Some(o) => o.clone(),
166 None => return Ok(None), };
168
169 if !self.is_origin_allowed(&origin) {
170 return Ok(None); }
172
173 if req.method == HttpMethod::Options {
175 let mut response = AuthResponse::new(204);
176 for (key, value) in self.cors_headers(&origin) {
177 response = response.with_header(key, value);
178 }
179 return Ok(Some(response));
180 }
181
182 Ok(None)
183 }
184
185 async fn after_request(
186 &self,
187 req: &AuthRequest,
188 mut response: AuthResponse,
189 ) -> AuthResult<AuthResponse> {
190 if !self.config.enabled {
191 return Ok(response);
192 }
193
194 let origin = match req.headers.get("origin") {
195 Some(o) => o.clone(),
196 None => return Ok(response),
197 };
198
199 if !self.is_origin_allowed(&origin) {
200 return Ok(response);
201 }
202
203 for (key, value) in self.cors_headers(&origin) {
204 response.headers.insert(key, value);
205 }
206
207 Ok(response)
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Origin placed on the allow-list by [`dev_only_middleware`].
    const DEV_ORIGIN: &str = "http://localhost:5173";

    /// Builds a bodiless request.
    ///
    /// When `origin` is given, it is stored under the lowercase `origin`
    /// header key.
    fn request(method: HttpMethod, path: &str, origin: Option<&str>) -> AuthRequest {
        let headers: HashMap<String, String> = origin
            .into_iter()
            .map(|value| ("origin".to_string(), value.to_string()))
            .collect();
        AuthRequest {
            method,
            path: path.to_string(),
            headers,
            body: None,
            query: HashMap::new(),
            virtual_user_id: None,
        }
    }

    /// Middleware whose allow-list contains only [`DEV_ORIGIN`].
    fn dev_only_middleware() -> CorsMiddleware {
        CorsMiddleware::new(CorsConfig::new().allowed_origin(DEV_ORIGIN))
    }

    #[tokio::test]
    async fn test_cors_preflight_allowed() {
        let mw = dev_only_middleware();
        let req = request(HttpMethod::Options, "/sign-in/email", Some(DEV_ORIGIN));

        let resp = mw
            .before_request(&req)
            .await
            .unwrap()
            .expect("an allowed preflight must be answered directly");
        assert_eq!(resp.status, 204);
        assert_eq!(
            resp.headers.get("Access-Control-Allow-Origin").unwrap(),
            DEV_ORIGIN
        );
    }

    #[tokio::test]
    async fn test_cors_preflight_not_allowed() {
        let mw = dev_only_middleware();
        let req = request(HttpMethod::Options, "/sign-in/email", Some("http://evil.com"));

        assert!(mw.before_request(&req).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_cors_adds_headers_after_request() {
        let mw = dev_only_middleware();
        let req = request(HttpMethod::Get, "/get-session", Some(DEV_ORIGIN));

        let handler_output = AuthResponse::json(200, &serde_json::json!({"ok": true})).unwrap();
        let response = mw.after_request(&req, handler_output).await.unwrap();

        assert_eq!(
            response.headers.get("Access-Control-Allow-Origin").unwrap(),
            DEV_ORIGIN
        );
        assert_eq!(
            response
                .headers
                .get("Access-Control-Allow-Credentials")
                .unwrap(),
            "true"
        );
    }

    #[tokio::test]
    async fn test_cors_no_origin_header() {
        let mw = dev_only_middleware();
        let req = request(HttpMethod::Get, "/get-session", None);

        assert!(mw.before_request(&req).await.unwrap().is_none());

        let response = mw.after_request(&req, AuthResponse::new(200)).await.unwrap();
        assert!(!response.headers.contains_key("Access-Control-Allow-Origin"));
    }

    #[tokio::test]
    async fn test_cors_wildcard() {
        let mw = CorsMiddleware::new(
            CorsConfig::new()
                .allowed_origin("*")
                .allow_credentials(false),
        );
        let req = request(HttpMethod::Get, "/get-session", Some("http://any-origin.com"));

        let response = mw.after_request(&req, AuthResponse::new(200)).await.unwrap();
        assert_eq!(
            response.headers.get("Access-Control-Allow-Origin").unwrap(),
            "*"
        );
    }
}