better_auth_core/middleware/cors.rs

1use super::Middleware;
2use crate::error::AuthResult;
3use crate::types::{AuthRequest, AuthResponse, HttpMethod};
4use async_trait::async_trait;
5
/// Configuration for CORS middleware.
///
/// Start from [`CorsConfig::new`] (identical to [`CorsConfig::default`]) and adjust it with the
/// builder methods, or set the public fields directly.
#[derive(Debug, Clone)]
pub struct CorsConfig {
    /// Allowed origins, matched by exact string comparison against the request's `Origin`
    /// header. An empty list means no CORS headers are added.
    ///
    /// Use `["*"]` to allow all origins (not recommended for production). While
    /// `allow_credentials` is `true` (the default), a `"*"` entry makes the middleware echo
    /// back *any* request origin together with `Access-Control-Allow-Credentials: true`, so
    /// every site can make credentialed requests; pair `"*"` with `allow_credentials(false)`.
    pub allowed_origins: Vec<String>,

    /// Allowed HTTP methods, sent comma-separated as `Access-Control-Allow-Methods`
    /// (omitted when empty). Defaults to common auth methods.
    pub allowed_methods: Vec<String>,

    /// Allowed request headers, sent comma-separated as `Access-Control-Allow-Headers`
    /// (omitted when empty).
    pub allowed_headers: Vec<String>,

    /// Response headers exposed to the browser, sent comma-separated as
    /// `Access-Control-Expose-Headers` (omitted when empty). Empty by default.
    pub exposed_headers: Vec<String>,

    /// Whether credentials (cookies, authorization) are allowed. When `true`, responses carry
    /// `Access-Control-Allow-Credentials: true` and the request origin is echoed instead of `*`.
    pub allow_credentials: bool,

    /// Max age for the browser's preflight cache, in seconds (`Access-Control-Max-Age`).
    pub max_age: u64,

    /// Whether CORS handling is enabled. When `false`, requests and responses pass through
    /// the middleware untouched.
    pub enabled: bool,
}

impl Default for CorsConfig {
    /// No allowed origins (so no CORS headers until one is added), the common auth methods,
    /// `Content-Type`/`Authorization`/`X-Requested-With` request headers, no exposed headers,
    /// credentials allowed, a 24-hour (86400 s) preflight max age, and handling enabled.
    fn default() -> Self {
        Self {
            allowed_origins: Vec::new(),
            allowed_methods: vec![
                "GET".into(),
                "POST".into(),
                "PUT".into(),
                "DELETE".into(),
                "PATCH".into(),
                "OPTIONS".into(),
            ],
            allowed_headers: vec![
                "Content-Type".into(),
                "Authorization".into(),
                "X-Requested-With".into(),
            ],
            exposed_headers: Vec::new(),
            allow_credentials: true,
            max_age: 86400,
            enabled: true,
        }
    }
}

impl CorsConfig {
    /// Creates a configuration with the default settings (same as [`CorsConfig::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `origin` to `allowed_origins`.
    pub fn allowed_origin(mut self, origin: impl Into<String>) -> Self {
        self.allowed_origins.push(origin.into());
        self
    }

    /// Replaces `allowed_methods` with `methods` (the defaults are discarded).
    pub fn allowed_methods<I, S>(mut self, methods: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.allowed_methods = methods.into_iter().map(Into::into).collect();
        self
    }

    /// Replaces `allowed_headers` with `headers` (the defaults are discarded).
    pub fn allowed_headers<I, S>(mut self, headers: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.allowed_headers = headers.into_iter().map(Into::into).collect();
        self
    }

    /// Replaces `exposed_headers` with `headers`.
    pub fn exposed_headers<I, S>(mut self, headers: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.exposed_headers = headers.into_iter().map(Into::into).collect();
        self
    }

    /// Sets `allow_credentials`.
    pub fn allow_credentials(mut self, allow: bool) -> Self {
        self.allow_credentials = allow;
        self
    }

    /// Sets the preflight cache max age, in seconds.
    pub fn max_age(mut self, seconds: u64) -> Self {
        self.max_age = seconds;
        self
    }

    /// Enables or disables CORS handling entirely.
    pub fn enabled(mut self, enabled: bool) -> Self {
        self.enabled = enabled;
        self
    }
}
82
83/// CORS middleware.
84///
85/// Handles preflight OPTIONS requests and adds CORS response headers.
86pub struct CorsMiddleware {
87    config: CorsConfig,
88}
89
90impl CorsMiddleware {
91    pub fn new(config: CorsConfig) -> Self {
92        Self { config }
93    }
94
95    fn is_origin_allowed(&self, origin: &str) -> bool {
96        if self.config.allowed_origins.is_empty() {
97            return false;
98        }
99        self.config
100            .allowed_origins
101            .iter()
102            .any(|o| o == "*" || o == origin)
103    }
104
105    fn cors_headers(&self, origin: &str) -> Vec<(String, String)> {
106        let mut headers = Vec::new();
107
108        // Use the request origin if allowed (not wildcard when credentials are on)
109        let allow_origin = if self.config.allow_credentials {
110            origin.to_string()
111        } else if self.config.allowed_origins.contains(&"*".to_string()) {
112            "*".to_string()
113        } else {
114            origin.to_string()
115        };
116
117        headers.push(("Access-Control-Allow-Origin".into(), allow_origin));
118
119        if self.config.allow_credentials {
120            headers.push(("Access-Control-Allow-Credentials".into(), "true".into()));
121        }
122
123        if !self.config.allowed_methods.is_empty() {
124            headers.push((
125                "Access-Control-Allow-Methods".into(),
126                self.config.allowed_methods.join(", "),
127            ));
128        }
129
130        if !self.config.allowed_headers.is_empty() {
131            headers.push((
132                "Access-Control-Allow-Headers".into(),
133                self.config.allowed_headers.join(", "),
134            ));
135        }
136
137        if !self.config.exposed_headers.is_empty() {
138            headers.push((
139                "Access-Control-Expose-Headers".into(),
140                self.config.exposed_headers.join(", "),
141            ));
142        }
143
144        headers.push((
145            "Access-Control-Max-Age".into(),
146            self.config.max_age.to_string(),
147        ));
148
149        headers
150    }
151}
152
153#[async_trait]
154impl Middleware for CorsMiddleware {
155    fn name(&self) -> &'static str {
156        "cors"
157    }
158
159    async fn before_request(&self, req: &AuthRequest) -> AuthResult<Option<AuthResponse>> {
160        if !self.config.enabled {
161            return Ok(None);
162        }
163
164        let origin = match req.headers.get("origin") {
165            Some(o) => o.clone(),
166            None => return Ok(None), // No Origin header → not a CORS request
167        };
168
169        if !self.is_origin_allowed(&origin) {
170            return Ok(None); // Origin not allowed → skip CORS headers
171        }
172
173        // Handle preflight
174        if req.method == HttpMethod::Options {
175            let mut response = AuthResponse::new(204);
176            for (key, value) in self.cors_headers(&origin) {
177                response = response.with_header(key, value);
178            }
179            return Ok(Some(response));
180        }
181
182        Ok(None)
183    }
184
185    async fn after_request(
186        &self,
187        req: &AuthRequest,
188        mut response: AuthResponse,
189    ) -> AuthResult<AuthResponse> {
190        if !self.config.enabled {
191            return Ok(response);
192        }
193
194        let origin = match req.headers.get("origin") {
195            Some(o) => o.clone(),
196            None => return Ok(response),
197        };
198
199        if !self.is_origin_allowed(&origin) {
200            return Ok(response);
201        }
202
203        for (key, value) in self.cors_headers(&origin) {
204            response.headers.insert(key, value);
205        }
206
207        Ok(response)
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    const ALLOWED: &str = "http://localhost:5173";

    /// Builds a request with the given method and path. When `origin` is provided it is
    /// stored under the lowercase `origin` header key.
    fn make_request(method: HttpMethod, path: &str, origin: Option<&str>) -> AuthRequest {
        let mut headers = HashMap::new();
        if let Some(origin) = origin {
            headers.insert("origin".to_string(), origin.to_string());
        }
        AuthRequest {
            method,
            path: path.to_string(),
            headers,
            body: None,
            query: HashMap::new(),
            virtual_user_id: None,
        }
    }

    /// Preflight request for a sign-in endpoint.
    fn make_options(origin: &str) -> AuthRequest {
        make_request(HttpMethod::Options, "/sign-in/email", Some(origin))
    }

    /// Simple GET request for the session endpoint.
    fn make_get(origin: &str) -> AuthRequest {
        make_request(HttpMethod::Get, "/get-session", Some(origin))
    }

    #[tokio::test]
    async fn test_cors_preflight_allowed() {
        let config = CorsConfig::new().allowed_origin(ALLOWED);
        let mw = CorsMiddleware::new(config);
        let req = make_options(ALLOWED);

        let resp = mw.before_request(&req).await.unwrap();
        assert!(resp.is_some());
        let resp = resp.unwrap();
        assert_eq!(resp.status, 204);
        assert_eq!(
            resp.headers.get("Access-Control-Allow-Origin").unwrap(),
            ALLOWED
        );
    }

    #[tokio::test]
    async fn test_cors_preflight_not_allowed() {
        let config = CorsConfig::new().allowed_origin(ALLOWED);
        let mw = CorsMiddleware::new(config);
        let req = make_options("http://evil.com");

        let resp = mw.before_request(&req).await.unwrap();
        assert!(resp.is_none()); // No CORS headers added for disallowed origin
    }

    #[tokio::test]
    async fn test_cors_non_preflight_is_not_short_circuited() {
        let mw = CorsMiddleware::new(CorsConfig::new().allowed_origin(ALLOWED));
        let req = make_get(ALLOWED);

        assert!(mw.before_request(&req).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_cors_adds_headers_after_request() {
        let config = CorsConfig::new().allowed_origin(ALLOWED);
        let mw = CorsMiddleware::new(config);
        let req = make_get(ALLOWED);

        let response = AuthResponse::json(200, &serde_json::json!({"ok": true})).unwrap();
        let response = mw.after_request(&req, response).await.unwrap();

        assert_eq!(
            response.headers.get("Access-Control-Allow-Origin").unwrap(),
            ALLOWED
        );
        assert_eq!(
            response
                .headers
                .get("Access-Control-Allow-Credentials")
                .unwrap(),
            "true"
        );
    }

    #[tokio::test]
    async fn test_cors_list_and_max_age_headers() {
        let mut config = CorsConfig::new().allowed_origin(ALLOWED);
        config.exposed_headers = vec!["X-Request-Id".to_string()];
        let mw = CorsMiddleware::new(config);
        let req = make_get(ALLOWED);

        let response = mw.after_request(&req, AuthResponse::new(200)).await.unwrap();
        assert_eq!(
            response.headers.get("Access-Control-Allow-Methods").unwrap(),
            "GET, POST, PUT, DELETE, PATCH, OPTIONS"
        );
        assert_eq!(
            response.headers.get("Access-Control-Allow-Headers").unwrap(),
            "Content-Type, Authorization, X-Requested-With"
        );
        assert_eq!(
            response.headers.get("Access-Control-Expose-Headers").unwrap(),
            "X-Request-Id"
        );
        assert_eq!(
            response.headers.get("Access-Control-Max-Age").unwrap(),
            "86400"
        );
    }

    #[tokio::test]
    async fn test_cors_without_credentials_omits_credentials_header() {
        let config = CorsConfig::new()
            .allowed_origin(ALLOWED)
            .allow_credentials(false);
        let mw = CorsMiddleware::new(config);
        let req = make_get(ALLOWED);

        let response = mw.after_request(&req, AuthResponse::new(200)).await.unwrap();
        assert_eq!(
            response.headers.get("Access-Control-Allow-Origin").unwrap(),
            ALLOWED
        );
        assert!(!response
            .headers
            .contains_key("Access-Control-Allow-Credentials"));
    }

    #[tokio::test]
    async fn test_cors_no_origin_header() {
        let config = CorsConfig::new().allowed_origin(ALLOWED);
        let mw = CorsMiddleware::new(config);
        let req = make_request(HttpMethod::Get, "/get-session", None);

        assert!(mw.before_request(&req).await.unwrap().is_none());

        let response = AuthResponse::new(200);
        let response = mw.after_request(&req, response).await.unwrap();
        assert!(!response.headers.contains_key("Access-Control-Allow-Origin"));
    }

    #[tokio::test]
    async fn test_cors_empty_allowlist_adds_nothing() {
        let mw = CorsMiddleware::new(CorsConfig::new());

        assert!(mw
            .before_request(&make_options(ALLOWED))
            .await
            .unwrap()
            .is_none());

        let response = mw
            .after_request(&make_get(ALLOWED), AuthResponse::new(200))
            .await
            .unwrap();
        assert!(!response.headers.contains_key("Access-Control-Allow-Origin"));
    }

    #[tokio::test]
    async fn test_cors_disabled_passes_through() {
        let config = CorsConfig::new().allowed_origin(ALLOWED).enabled(false);
        let mw = CorsMiddleware::new(config);

        assert!(mw
            .before_request(&make_options(ALLOWED))
            .await
            .unwrap()
            .is_none());

        let response = mw
            .after_request(&make_get(ALLOWED), AuthResponse::new(200))
            .await
            .unwrap();
        assert!(!response.headers.contains_key("Access-Control-Allow-Origin"));
    }

    #[tokio::test]
    async fn test_cors_wildcard() {
        let config = CorsConfig::new()
            .allowed_origin("*")
            .allow_credentials(false);
        let mw = CorsMiddleware::new(config);
        let req = make_get("http://any-origin.com");

        let response = AuthResponse::new(200);
        let response = mw.after_request(&req, response).await.unwrap();
        assert_eq!(
            response.headers.get("Access-Control-Allow-Origin").unwrap(),
            "*"
        );
    }
}