Skip to main content

silent/middleware/middlewares/
cors.rs

1use crate::{Handler, MiddleWareHandler, Next, Request, Response, Result, SilentError};
2use async_trait::async_trait;
3use http::{HeaderMap, Method, header};
4use std::sync::OnceLock;
5
/// A set of values for a CORS response header, or the `*` wildcard.
///
/// Used to configure the allowed origins, methods, request headers and
/// exposed headers of [`Cors`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CorsType {
    /// Any value; rendered as `*` by `get_value`.
    Any,
    /// An explicit list of values; rendered comma-separated by `get_value`.
    AllowSome(Vec<String>),
}

impl CorsType {
    /// Renders this value the way it is written into a response header.
    ///
    /// Returns `*` for [`CorsType::Any`]; otherwise the entries joined with `,`
    /// (an empty list yields an empty string).
    fn get_value(&self) -> String {
        match self {
            Self::Any => "*".to_owned(),
            Self::AllowSome(values) => values.join(","),
        }
    }
}
20
21impl From<Vec<&str>> for CorsType {
22    fn from(value: Vec<&str>) -> Self {
23        CorsType::AllowSome(value.iter().map(|s| s.to_string()).collect())
24    }
25}
26
27impl From<Vec<Method>> for CorsType {
28    fn from(value: Vec<Method>) -> Self {
29        CorsType::AllowSome(value.iter().map(|s| s.to_string()).collect())
30    }
31}
32
33impl From<Vec<header::HeaderName>> for CorsType {
34    fn from(value: Vec<header::HeaderName>) -> Self {
35        CorsType::AllowSome(value.iter().map(|s| s.to_string()).collect())
36    }
37}
38
/// Policy deciding which request origins are allowed.
#[derive(Debug)]
enum CorsOriginType {
    Any,
    AllowSome(Vec<String>),
}

impl CorsOriginType {
    /// Returns the `Access-Control-Allow-Origin` value to send for `origin`.
    ///
    /// `Any` echoes the request origin back. `AllowSome` echoes it only when it
    /// is an exact, case-sensitive member of the list. An empty string means
    /// the origin is not allowed.
    fn get_value(&self, origin: &str) -> String {
        let permitted = match self {
            Self::Any => true,
            Self::AllowSome(allowed) => allowed.iter().any(|candidate| candidate == origin),
        };
        if permitted {
            origin.to_owned()
        } else {
            String::new()
        }
    }
}
59
60impl From<CorsType> for CorsOriginType {
61    fn from(value: CorsType) -> Self {
62        match value {
63            CorsType::Any => CorsOriginType::Any,
64            CorsType::AllowSome(value) => CorsOriginType::AllowSome(value),
65        }
66    }
67}
68
69impl From<&str> for CorsType {
70    fn from(value: &str) -> Self {
71        if value == "*" {
72            CorsType::Any
73        } else {
74            CorsType::AllowSome(value.split(',').map(|s| s.to_string()).collect())
75        }
76    }
77}
78
79/// cors 中间件
80/// ```rust
81/// use silent::prelude::*;
82/// use silent::middlewares::{Cors, CorsType};
83/// // set with CorsType
84/// let _ = Cors::new()
85///                .origin(CorsType::Any)
86///                .methods(CorsType::AllowSome(vec![Method::POST.to_string()]))
87///                .headers(CorsType::AllowSome(vec![header::AUTHORIZATION.to_string(), header::ACCEPT.to_string()]))
88///                .credentials(true);
89/// // set with Method or header
90/// let _ = Cors::new()
91///                .origin(CorsType::Any)
92///                .methods(vec![Method::POST])
93///                .headers(vec![header::AUTHORIZATION, header::ACCEPT])
94///                .credentials(true);
95/// // set with str
96/// let _ = Cors::new()
97///                .origin("*")
98///                .methods("POST")
99///                .headers("authorization,accept")
100///                .credentials(true);
101#[derive(Debug)]
102pub struct Cors {
103    origin: Option<CorsOriginType>,
104    methods: Option<CorsType>,
105    headers: Option<CorsType>,
106    credentials: Option<bool>,
107    max_age: Option<u32>,
108    expose: Option<CorsType>,
109    // 优化:延迟初始化的响应头缓存
110    cached_headers: OnceLock<HeaderMap>,
111}
112
113impl Default for Cors {
114    fn default() -> Self {
115        Self {
116            origin: None,
117            methods: None,
118            headers: None,
119            credentials: None,
120            max_age: None,
121            expose: None,
122            cached_headers: OnceLock::new(),
123        }
124    }
125}
126
127impl Cors {
128    pub fn new() -> Self {
129        Self::default()
130    }
131    pub fn origin<T>(mut self, origin: T) -> Self
132    where
133        T: Into<CorsType>,
134    {
135        self.origin = Some(origin.into().into());
136        self
137    }
138    pub fn methods<T>(mut self, methods: T) -> Self
139    where
140        T: Into<CorsType>,
141    {
142        self.methods = Some(methods.into());
143        self
144    }
145    pub fn headers<T>(mut self, headers: T) -> Self
146    where
147        T: Into<CorsType>,
148    {
149        self.headers = Some(headers.into());
150        self
151    }
152    pub fn credentials(mut self, credentials: bool) -> Self {
153        self.credentials = Some(credentials);
154        self
155    }
156    pub fn max_age(mut self, max_age: u32) -> Self {
157        self.max_age = Some(max_age);
158        self
159    }
160    pub fn expose<T>(mut self, expose: T) -> Self
161    where
162        T: Into<CorsType>,
163    {
164        self.expose = Some(expose.into());
165        self
166    }
167
168    // 优化:获取或构建静态响应头缓存
169    fn get_cached_headers(&self) -> &HeaderMap {
170        self.cached_headers.get_or_init(|| {
171            let mut headers = HeaderMap::new();
172
173            if let Some(ref methods) = self.methods
174                && let Ok(value) = methods.get_value().parse()
175            {
176                headers.insert("Access-Control-Allow-Methods", value);
177            }
178            if let Some(ref cors_headers) = self.headers
179                && let Ok(value) = cors_headers.get_value().parse()
180            {
181                headers.insert("Access-Control-Allow-Headers", value);
182            }
183            if let Some(ref credentials) = self.credentials
184                && let Ok(value) = credentials.to_string().parse()
185            {
186                headers.insert("Access-Control-Allow-Credentials", value);
187            }
188            if let Some(ref max_age) = self.max_age
189                && let Ok(value) = max_age.to_string().parse()
190            {
191                headers.insert("Access-Control-Max-Age", value);
192            }
193            if let Some(ref expose) = self.expose
194                && let Ok(value) = expose.get_value().parse()
195            {
196                headers.insert("Access-Control-Expose-Headers", value);
197            }
198
199            headers
200        })
201    }
202}
203
204#[async_trait]
205impl MiddleWareHandler for Cors {
206    async fn handle(&self, req: Request, next: &Next) -> Result<Response> {
207        let req_origin = req
208            .headers()
209            .get("origin")
210            .map_or("", |v| v.to_str().unwrap_or(""))
211            .to_string();
212
213        // 如果没有 origin 一般为同源请求,直接返回
214        if req_origin.is_empty() {
215            return next.call(req).await;
216        }
217
218        // 优化:复用预构建的响应头模板
219        let mut res = Response::empty();
220
221        // 复制缓存的静态头部 (优化:避免重复构建)
222        let cached_headers = self.get_cached_headers();
223        res.headers_mut().extend(cached_headers.clone());
224
225        // 只处理动态的 Origin 头部
226        if let Some(ref origin) = self.origin {
227            let origin = origin.get_value(&req_origin);
228            if origin.is_empty() {
229                return Err(SilentError::business_error(
230                    http::StatusCode::FORBIDDEN,
231                    format!("Cors: Origin \"{req_origin}\" is not allowed"),
232                ));
233            }
234            res.headers_mut().insert(
235                "Access-Control-Allow-Origin",
236                origin.parse().map_err(|e| {
237                    SilentError::business_error(
238                        http::StatusCode::INTERNAL_SERVER_ERROR,
239                        format!("Cors: Failed to parse cors allow origin: {e}"),
240                    )
241                })?,
242            );
243        }
244
245        if req.method() == Method::OPTIONS {
246            return Ok(res);
247        }
248        match next.call(req).await {
249            Ok(result) => {
250                res.copy_from_response(result);
251                Ok(res)
252            }
253            Err(e) => {
254                res.copy_from_response(e.into());
255                Ok(res)
256            }
257        }
258    }
259}
260
// Unit tests for the CORS value types and the `Cors` builder, plus
// routing-level integration tests of the middleware.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prelude::Route;

    // ==================== CorsType tests ====================

    #[test]
    fn test_cors_type_any_get_value() {
        let cors_type = CorsType::Any;
        assert_eq!(cors_type.get_value(), "*");
    }

    #[test]
    fn test_cors_type_allow_some_get_value() {
        let cors_type = CorsType::AllowSome(vec!["GET".to_string(), "POST".to_string()]);
        assert_eq!(cors_type.get_value(), "GET,POST");
    }

    #[test]
    fn test_cors_type_allow_some_empty() {
        let cors_type = CorsType::AllowSome(vec![]);
        assert_eq!(cors_type.get_value(), "");
    }

    #[test]
    fn test_cors_type_from_vec_str() {
        let cors_type: CorsType = vec!["GET", "POST"].into();
        match cors_type {
            CorsType::AllowSome(ref v) => {
                assert_eq!(v, &["GET".to_string(), "POST".to_string()]);
            }
            _ => panic!("Expected AllowSome"),
        }
    }

    #[test]
    fn test_cors_type_from_vec_method() {
        let methods = vec![Method::GET, Method::POST];
        let cors_type: CorsType = methods.into();
        match cors_type {
            CorsType::AllowSome(ref v) => {
                assert_eq!(v, &["GET".to_string(), "POST".to_string()]);
            }
            _ => panic!("Expected AllowSome"),
        }
    }

    #[test]
    fn test_cors_type_from_vec_header_name() {
        let headers = vec![header::AUTHORIZATION, header::ACCEPT];
        let cors_type: CorsType = headers.into();
        match cors_type {
            CorsType::AllowSome(ref v) => {
                assert_eq!(v, &["authorization".to_string(), "accept".to_string()]);
            }
            _ => panic!("Expected AllowSome"),
        }
    }

    #[test]
    fn test_cors_type_from_str_any() {
        let cors_type: CorsType = "*".into();
        assert!(matches!(cors_type, CorsType::Any));
    }

    #[test]
    fn test_cors_type_from_str_multiple() {
        let cors_type: CorsType = "GET,POST,PUT".into();
        match cors_type {
            CorsType::AllowSome(ref v) => {
                assert_eq!(
                    v,
                    &["GET".to_string(), "POST".to_string(), "PUT".to_string()]
                );
            }
            _ => panic!("Expected AllowSome"),
        }
    }

    // ==================== CorsOriginType tests ====================

    #[test]
    fn test_cors_origin_type_any_get_value() {
        let origin_type = CorsOriginType::Any;
        assert_eq!(
            origin_type.get_value("http://example.com"),
            "http://example.com"
        );
    }

    #[test]
    fn test_cors_origin_type_allow_some_match() {
        let origin_type = CorsOriginType::AllowSome(vec![
            "http://example.com".to_string(),
            "http://localhost:8080".to_string(),
        ]);
        assert_eq!(
            origin_type.get_value("http://example.com"),
            "http://example.com"
        );
    }

    #[test]
    fn test_cors_origin_type_allow_some_no_match() {
        let origin_type = CorsOriginType::AllowSome(vec!["http://example.com".to_string()]);
        assert_eq!(origin_type.get_value("http://evil.com"), "");
    }

    #[test]
    fn test_cors_origin_type_from_cors_type_any() {
        let cors_type = CorsType::Any;
        let origin_type: CorsOriginType = cors_type.into();
        assert!(matches!(origin_type, CorsOriginType::Any));
    }

    #[test]
    fn test_cors_origin_type_from_cors_type_allow_some() {
        let cors_type = CorsType::AllowSome(vec!["http://example.com".to_string()]);
        let origin_type: CorsOriginType = cors_type.into();
        match origin_type {
            CorsOriginType::AllowSome(ref v) => {
                assert_eq!(v, &["http://example.com".to_string()]);
            }
            _ => panic!("Expected AllowSome"),
        }
    }

    // ==================== Cors struct construction tests ====================

    #[test]
    fn test_cors_new() {
        let cors = Cors::new();
        assert!(cors.origin.is_none());
        assert!(cors.methods.is_none());
        assert!(cors.headers.is_none());
        assert!(cors.credentials.is_none());
        assert!(cors.max_age.is_none());
        assert!(cors.expose.is_none());
    }

    #[test]
    fn test_cors_default() {
        let cors = Cors::default();
        assert!(cors.origin.is_none());
        assert!(cors.methods.is_none());
        assert!(cors.headers.is_none());
    }

    #[test]
    fn test_cors_origin_any() {
        let cors = Cors::new().origin(CorsType::Any);
        assert!(matches!(cors.origin, Some(CorsOriginType::Any)));
    }

    #[test]
    fn test_cors_origin_str() {
        let cors = Cors::new().origin("http://example.com");
        match cors.origin {
            Some(CorsOriginType::AllowSome(ref v)) => {
                assert_eq!(v, &["http://example.com".to_string()]);
            }
            _ => panic!("Expected AllowSome"),
        }
    }

    #[test]
    fn test_cors_methods() {
        let cors = Cors::new().methods(vec![Method::GET, Method::POST]);
        match cors.methods {
            Some(CorsType::AllowSome(ref v)) => {
                assert_eq!(v, &["GET".to_string(), "POST".to_string()]);
            }
            _ => panic!("Expected AllowSome"),
        }
    }

    #[test]
    fn test_cors_headers() {
        let cors = Cors::new().headers(vec![header::AUTHORIZATION, header::ACCEPT]);
        match cors.headers {
            Some(CorsType::AllowSome(ref v)) => {
                assert_eq!(v, &["authorization".to_string(), "accept".to_string()]);
            }
            _ => panic!("Expected AllowSome"),
        }
    }

    #[test]
    fn test_cors_credentials() {
        let cors = Cors::new().credentials(true);
        assert_eq!(cors.credentials, Some(true));

        let cors = Cors::new().credentials(false);
        assert_eq!(cors.credentials, Some(false));
    }

    #[test]
    fn test_cors_max_age() {
        let cors = Cors::new().max_age(3600);
        assert_eq!(cors.max_age, Some(3600));
    }

    #[test]
    fn test_cors_expose() {
        let cors = Cors::new().expose("Content-Length,X-Custom-Header");
        match cors.expose {
            Some(CorsType::AllowSome(ref v)) => {
                assert_eq!(
                    v,
                    &["Content-Length".to_string(), "X-Custom-Header".to_string()]
                );
            }
            _ => panic!("Expected AllowSome"),
        }
    }

    #[test]
    fn test_cors_builder_chain() {
        let cors = Cors::new()
            .origin(CorsType::Any)
            .methods(vec![Method::GET])
            .headers(vec![header::ACCEPT])
            .credentials(true)
            .max_age(3600)
            .expose("Content-Length");

        assert!(matches!(cors.origin, Some(CorsOriginType::Any)));
        assert!(cors.methods.is_some());
        assert!(cors.headers.is_some());
        assert_eq!(cors.credentials, Some(true));
        assert_eq!(cors.max_age, Some(3600));
        assert!(cors.expose.is_some());
    }

    // ==================== get_cached_headers tests ====================

    #[test]
    fn test_get_cached_headers_with_methods() {
        let cors = Cors::new().methods(vec![Method::GET, Method::POST]);
        let headers = cors.get_cached_headers();

        assert_eq!(
            headers.get("Access-Control-Allow-Methods"),
            Some(&"GET,POST".parse().unwrap())
        );
    }

    #[test]
    fn test_get_cached_headers_with_headers() {
        let cors = Cors::new().headers("authorization,accept");
        let headers = cors.get_cached_headers();

        assert_eq!(
            headers.get("Access-Control-Allow-Headers"),
            Some(&"authorization,accept".parse().unwrap())
        );
    }

    #[test]
    fn test_get_cached_headers_with_credentials() {
        let cors = Cors::new().credentials(true);
        let headers = cors.get_cached_headers();

        assert_eq!(
            headers.get("Access-Control-Allow-Credentials"),
            Some(&"true".parse().unwrap())
        );
    }

    #[test]
    fn test_get_cached_headers_with_max_age() {
        let cors = Cors::new().max_age(3600);
        let headers = cors.get_cached_headers();

        assert_eq!(
            headers.get("Access-Control-Max-Age"),
            Some(&"3600".parse().unwrap())
        );
    }

    #[test]
    fn test_get_cached_headers_with_expose() {
        let cors = Cors::new().expose("Content-Length");
        let headers = cors.get_cached_headers();

        assert_eq!(
            headers.get("Access-Control-Expose-Headers"),
            Some(&"Content-Length".parse().unwrap())
        );
    }

    #[test]
    fn test_get_cached_headers_combined() {
        let cors = Cors::new()
            .methods("GET,POST")
            .headers("authorization")
            .credentials(true)
            .max_age(3600);
        let headers = cors.get_cached_headers();

        assert!(headers.contains_key("Access-Control-Allow-Methods"));
        assert!(headers.contains_key("Access-Control-Allow-Headers"));
        assert!(headers.contains_key("Access-Control-Allow-Credentials"));
        assert!(headers.contains_key("Access-Control-Max-Age"));
    }

    // ==================== Integration tests ====================

    #[tokio::test]
    async fn test_cors_integration() {
        let route = Route::new("/")
            .hook(Cors::new().origin(CorsType::Any))
            .get(|_req: Request| async { Ok("hello world") });
        let route = Route::new_root().append(route);
        let mut req = Request::empty();
        *req.method_mut() = Method::OPTIONS;
        *req.uri_mut() = "http://localhost:8080/".parse().unwrap();
        req.headers_mut()
            .insert("origin", "http://localhost:8080".parse().unwrap());
        req.headers_mut()
            .insert("access-control-request-method", "GET".parse().unwrap());
        req.headers_mut().insert(
            "access-control-request-headers",
            "content-type".parse().unwrap(),
        );
        let res = route.call(req).await.unwrap();
        assert_eq!(res.status, http::StatusCode::OK);
    }

    #[tokio::test]
    async fn test_cors_with_post_request() {
        let route = Route::new("/")
            .hook(
                Cors::new()
                    .origin("http://localhost:8080")
                    .methods(vec![Method::GET, Method::POST])
                    .credentials(true),
            )
            .post(|_req: Request| async { Ok("posted") });
        let route = Route::new_root().append(route);

        let mut req = Request::empty();
        *req.method_mut() = Method::POST;
        *req.uri_mut() = "http://localhost:8080/".parse().unwrap();
        req.headers_mut()
            .insert("origin", "http://localhost:8080".parse().unwrap());

        let res = route.call(req).await.unwrap();
        assert_eq!(res.status, http::StatusCode::OK);
        assert!(res.headers().contains_key("Access-Control-Allow-Origin"));
        assert!(
            res.headers()
                .contains_key("Access-Control-Allow-Credentials")
        );
    }

    // ==================== Edge-case tests ====================

    #[test]
    fn test_cors_type_empty_methods() {
        let cors_type = CorsType::AllowSome(vec![]);
        assert_eq!(cors_type.get_value(), "");
    }

    #[test]
    fn test_cors_origin_empty_list() {
        let origin_type = CorsOriginType::AllowSome(vec![]);
        assert_eq!(origin_type.get_value("http://example.com"), "");
    }

    #[tokio::test]
    async fn test_handle_without_origin_header() {
        // Same-origin request: no Origin header at all.
        let route = Route::new("/")
            .hook(Cors::new().origin(CorsType::Any))
            .get(|_req: Request| async { Ok("hello") });
        let route = Route::new_root().append(route);

        let mut req = Request::empty();
        *req.method_mut() = Method::GET;
        *req.uri_mut() = "http://localhost:8080/".parse().unwrap();
        // Deliberately no Origin header.

        let res = route.call(req).await.unwrap();
        // Without an Origin the request is passed through normally (same-origin).
        assert_eq!(res.status, http::StatusCode::OK);
    }

    #[tokio::test]
    async fn test_handle_empty_string_origin() {
        // An Origin header whose value is the empty string.
        let route = Route::new("/")
            .hook(Cors::new().origin("http://example.com"))
            .get(|_req: Request| async { Ok("hello") });
        let route = Route::new_root().append(route);

        let mut req = Request::empty();
        *req.method_mut() = Method::GET;
        *req.uri_mut() = "http://localhost:8080/".parse().unwrap();
        req.headers_mut().insert("origin", "".parse().unwrap());

        let res = route.call(req).await.unwrap();
        // An empty Origin is treated like a same-origin request.
        assert_eq!(res.status, http::StatusCode::OK);
    }
}