// silent/middleware/middlewares/cors.rs
use crate::{Handler, MiddleWareHandler, Next, Request, Response, Result, SilentError};
use async_trait::async_trait;
use http::{HeaderMap, HeaderValue, Method, header};
use std::sync::OnceLock;

/// A CORS list setting: either "anything" or an explicit set of values.
///
/// Used for allowed origins, methods, request headers and exposed headers.
/// When rendered as a header value, `Any` becomes `*` and an explicit list is
/// joined with `,`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CorsType {
    /// Allow everything.
    Any,
    /// Allow only the listed values.
    AllowSome(Vec<String>),
}

12impl CorsType {
13 fn get_value(&self) -> String {
14 match self {
15 CorsType::Any => "*".to_string(),
16 CorsType::AllowSome(value) => value.join(","),
17 }
18 }
19}
20
21impl From<Vec<&str>> for CorsType {
22 fn from(value: Vec<&str>) -> Self {
23 CorsType::AllowSome(value.iter().map(|s| s.to_string()).collect())
24 }
25}
26
27impl From<Vec<Method>> for CorsType {
28 fn from(value: Vec<Method>) -> Self {
29 CorsType::AllowSome(value.iter().map(|s| s.to_string()).collect())
30 }
31}
32
33impl From<Vec<header::HeaderName>> for CorsType {
34 fn from(value: Vec<header::HeaderName>) -> Self {
35 CorsType::AllowSome(value.iter().map(|s| s.to_string()).collect())
36 }
37}
38
/// Origin policy derived from a [`CorsType`].
///
/// Kept separate from `CorsType` because origins are matched against the
/// request's `Origin` header rather than rendered as a fixed list.
#[derive(Debug, Clone, PartialEq, Eq)]
enum CorsOriginType {
    /// Accept every origin.
    Any,
    /// Accept only origins that exactly equal one of these strings.
    AllowSome(Vec<String>),
}

45impl CorsOriginType {
46 fn get_value(&self, origin: &str) -> String {
47 match self {
48 CorsOriginType::Any => origin.to_string(),
49 CorsOriginType::AllowSome(value) => {
50 if let Some(v) = value.iter().find(|&v| v == origin) {
51 v.to_string()
52 } else {
53 "".to_string()
54 }
55 }
56 }
57 }
58}
59
60impl From<CorsType> for CorsOriginType {
61 fn from(value: CorsType) -> Self {
62 match value {
63 CorsType::Any => CorsOriginType::Any,
64 CorsType::AllowSome(value) => CorsOriginType::AllowSome(value),
65 }
66 }
67}
68
69impl From<&str> for CorsType {
70 fn from(value: &str) -> Self {
71 if value == "*" {
72 CorsType::Any
73 } else {
74 CorsType::AllowSome(value.split(',').map(|s| s.to_string()).collect())
75 }
76 }
77}
78
79#[derive(Debug)]
102pub struct Cors {
103 origin: Option<CorsOriginType>,
104 methods: Option<CorsType>,
105 headers: Option<CorsType>,
106 credentials: Option<bool>,
107 max_age: Option<u32>,
108 expose: Option<CorsType>,
109 cached_headers: OnceLock<HeaderMap>,
111}
112
113impl Default for Cors {
114 fn default() -> Self {
115 Self {
116 origin: None,
117 methods: None,
118 headers: None,
119 credentials: None,
120 max_age: None,
121 expose: None,
122 cached_headers: OnceLock::new(),
123 }
124 }
125}
126
127impl Cors {
128 pub fn new() -> Self {
129 Self::default()
130 }
131 pub fn origin<T>(mut self, origin: T) -> Self
132 where
133 T: Into<CorsType>,
134 {
135 self.origin = Some(origin.into().into());
136 self
137 }
138 pub fn methods<T>(mut self, methods: T) -> Self
139 where
140 T: Into<CorsType>,
141 {
142 self.methods = Some(methods.into());
143 self
144 }
145 pub fn headers<T>(mut self, headers: T) -> Self
146 where
147 T: Into<CorsType>,
148 {
149 self.headers = Some(headers.into());
150 self
151 }
152 pub fn credentials(mut self, credentials: bool) -> Self {
153 self.credentials = Some(credentials);
154 self
155 }
156 pub fn max_age(mut self, max_age: u32) -> Self {
157 self.max_age = Some(max_age);
158 self
159 }
160 pub fn expose<T>(mut self, expose: T) -> Self
161 where
162 T: Into<CorsType>,
163 {
164 self.expose = Some(expose.into());
165 self
166 }
167
168 fn get_cached_headers(&self) -> &HeaderMap {
170 self.cached_headers.get_or_init(|| {
171 let mut headers = HeaderMap::new();
172
173 if let Some(ref methods) = self.methods
174 && let Ok(value) = methods.get_value().parse()
175 {
176 headers.insert("Access-Control-Allow-Methods", value);
177 }
178 if let Some(ref cors_headers) = self.headers
179 && let Ok(value) = cors_headers.get_value().parse()
180 {
181 headers.insert("Access-Control-Allow-Headers", value);
182 }
183 if let Some(ref credentials) = self.credentials
184 && let Ok(value) = credentials.to_string().parse()
185 {
186 headers.insert("Access-Control-Allow-Credentials", value);
187 }
188 if let Some(ref max_age) = self.max_age
189 && let Ok(value) = max_age.to_string().parse()
190 {
191 headers.insert("Access-Control-Max-Age", value);
192 }
193 if let Some(ref expose) = self.expose
194 && let Ok(value) = expose.get_value().parse()
195 {
196 headers.insert("Access-Control-Expose-Headers", value);
197 }
198
199 headers
200 })
201 }
202}
203
204#[async_trait]
205impl MiddleWareHandler for Cors {
206 async fn handle(&self, req: Request, next: &Next) -> Result<Response> {
207 let req_origin = req
208 .headers()
209 .get("origin")
210 .map_or("", |v| v.to_str().unwrap_or(""))
211 .to_string();
212
213 if req_origin.is_empty() {
215 return next.call(req).await;
216 }
217
218 let mut res = Response::empty();
220
221 let cached_headers = self.get_cached_headers();
223 res.headers_mut().extend(cached_headers.clone());
224
225 if let Some(ref origin) = self.origin {
227 let origin = origin.get_value(&req_origin);
228 if origin.is_empty() {
229 return Err(SilentError::business_error(
230 http::StatusCode::FORBIDDEN,
231 format!("Cors: Origin \"{req_origin}\" is not allowed"),
232 ));
233 }
234 res.headers_mut().insert(
235 "Access-Control-Allow-Origin",
236 origin.parse().map_err(|e| {
237 SilentError::business_error(
238 http::StatusCode::INTERNAL_SERVER_ERROR,
239 format!("Cors: Failed to parse cors allow origin: {e}"),
240 )
241 })?,
242 );
243 }
244
245 if req.method() == Method::OPTIONS {
246 return Ok(res);
247 }
248 match next.call(req).await {
249 Ok(result) => {
250 res.copy_from_response(result);
251 Ok(res)
252 }
253 Err(e) => {
254 res.copy_from_response(e.into());
255 Ok(res)
256 }
257 }
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prelude::Route;

    // ----- CorsType -----

    #[test]
    fn test_cors_type_any_get_value() {
        assert_eq!(CorsType::Any.get_value(), "*");
    }

    #[test]
    fn test_cors_type_allow_some_get_value() {
        let cors_type = CorsType::AllowSome(vec!["GET".to_string(), "POST".to_string()]);
        assert_eq!(cors_type.get_value(), "GET,POST");
    }

    #[test]
    fn test_cors_type_allow_some_empty() {
        // An empty allow-list renders as an empty header value.
        assert_eq!(CorsType::AllowSome(vec![]).get_value(), "");
    }

    #[test]
    fn test_cors_type_from_vec_str() {
        let CorsType::AllowSome(v) = CorsType::from(vec!["GET", "POST"]) else {
            panic!("Expected AllowSome");
        };
        assert_eq!(v, ["GET", "POST"]);
    }

    #[test]
    fn test_cors_type_from_vec_method() {
        let CorsType::AllowSome(v) = CorsType::from(vec![Method::GET, Method::POST]) else {
            panic!("Expected AllowSome");
        };
        assert_eq!(v, ["GET", "POST"]);
    }

    #[test]
    fn test_cors_type_from_vec_header_name() {
        let CorsType::AllowSome(v) = CorsType::from(vec![header::AUTHORIZATION, header::ACCEPT])
        else {
            panic!("Expected AllowSome");
        };
        assert_eq!(v, ["authorization", "accept"]);
    }

    #[test]
    fn test_cors_type_from_str_any() {
        assert!(matches!(CorsType::from("*"), CorsType::Any));
    }

    #[test]
    fn test_cors_type_from_str_multiple() {
        let CorsType::AllowSome(v) = CorsType::from("GET,POST,PUT") else {
            panic!("Expected AllowSome");
        };
        assert_eq!(v, ["GET", "POST", "PUT"]);
    }

    // ----- CorsOriginType -----

    #[test]
    fn test_cors_origin_type_any_get_value() {
        let origin_type = CorsOriginType::Any;
        assert_eq!(
            origin_type.get_value("http://example.com"),
            "http://example.com"
        );
    }

    #[test]
    fn test_cors_origin_type_allow_some_match() {
        let origin_type = CorsOriginType::AllowSome(vec![
            "http://example.com".to_string(),
            "http://localhost:8080".to_string(),
        ]);
        assert_eq!(
            origin_type.get_value("http://example.com"),
            "http://example.com"
        );
    }

    #[test]
    fn test_cors_origin_type_allow_some_no_match() {
        let origin_type = CorsOriginType::AllowSome(vec!["http://example.com".to_string()]);
        assert_eq!(origin_type.get_value("http://evil.com"), "");
    }

    #[test]
    fn test_cors_origin_empty_list() {
        let origin_type = CorsOriginType::AllowSome(vec![]);
        assert_eq!(origin_type.get_value("http://example.com"), "");
    }

    #[test]
    fn test_cors_origin_type_from_cors_type_any() {
        let origin_type: CorsOriginType = CorsType::Any.into();
        assert!(matches!(origin_type, CorsOriginType::Any));
    }

    #[test]
    fn test_cors_origin_type_from_cors_type_allow_some() {
        let cors_type = CorsType::AllowSome(vec!["http://example.com".to_string()]);
        let CorsOriginType::AllowSome(v) = CorsOriginType::from(cors_type) else {
            panic!("Expected AllowSome");
        };
        assert_eq!(v, ["http://example.com"]);
    }

    // ----- Cors builder -----

    #[test]
    fn test_cors_new() {
        let cors = Cors::new();
        assert!(cors.origin.is_none());
        assert!(cors.methods.is_none());
        assert!(cors.headers.is_none());
        assert!(cors.credentials.is_none());
        assert!(cors.max_age.is_none());
        assert!(cors.expose.is_none());
    }

    #[test]
    fn test_cors_default() {
        let cors = Cors::default();
        assert!(cors.origin.is_none());
        assert!(cors.methods.is_none());
        assert!(cors.headers.is_none());
        assert!(cors.credentials.is_none());
        assert!(cors.max_age.is_none());
        assert!(cors.expose.is_none());
    }

    #[test]
    fn test_cors_origin_any() {
        let cors = Cors::new().origin(CorsType::Any);
        assert!(matches!(cors.origin, Some(CorsOriginType::Any)));
    }

    #[test]
    fn test_cors_origin_str() {
        let cors = Cors::new().origin("http://example.com");
        let Some(CorsOriginType::AllowSome(v)) = cors.origin else {
            panic!("Expected AllowSome");
        };
        assert_eq!(v, ["http://example.com"]);
    }

    #[test]
    fn test_cors_methods() {
        let cors = Cors::new().methods(vec![Method::GET, Method::POST]);
        let Some(CorsType::AllowSome(v)) = cors.methods else {
            panic!("Expected AllowSome");
        };
        assert_eq!(v, ["GET", "POST"]);
    }

    #[test]
    fn test_cors_headers() {
        let cors = Cors::new().headers(vec![header::AUTHORIZATION, header::ACCEPT]);
        let Some(CorsType::AllowSome(v)) = cors.headers else {
            panic!("Expected AllowSome");
        };
        assert_eq!(v, ["authorization", "accept"]);
    }

    #[test]
    fn test_cors_credentials() {
        assert_eq!(Cors::new().credentials(true).credentials, Some(true));
        assert_eq!(Cors::new().credentials(false).credentials, Some(false));
    }

    #[test]
    fn test_cors_max_age() {
        assert_eq!(Cors::new().max_age(3600).max_age, Some(3600));
    }

    #[test]
    fn test_cors_expose() {
        let cors = Cors::new().expose("Content-Length,X-Custom-Header");
        let Some(CorsType::AllowSome(v)) = cors.expose else {
            panic!("Expected AllowSome");
        };
        assert_eq!(v, ["Content-Length", "X-Custom-Header"]);
    }

    #[test]
    fn test_cors_builder_chain() {
        let cors = Cors::new()
            .origin(CorsType::Any)
            .methods(vec![Method::GET])
            .headers(vec![header::ACCEPT])
            .credentials(true)
            .max_age(3600)
            .expose("Content-Length");

        assert!(matches!(cors.origin, Some(CorsOriginType::Any)));
        assert!(cors.methods.is_some());
        assert!(cors.headers.is_some());
        assert_eq!(cors.credentials, Some(true));
        assert_eq!(cors.max_age, Some(3600));
        assert!(cors.expose.is_some());
    }

    // ----- cached, request-independent headers -----

    #[test]
    fn test_get_cached_headers_empty_by_default() {
        assert!(Cors::new().get_cached_headers().is_empty());
    }

    #[test]
    fn test_get_cached_headers_with_methods() {
        let cors = Cors::new().methods(vec![Method::GET, Method::POST]);
        assert_eq!(
            cors.get_cached_headers().get("Access-Control-Allow-Methods"),
            Some(&"GET,POST".parse().unwrap())
        );
    }

    #[test]
    fn test_get_cached_headers_with_headers() {
        let cors = Cors::new().headers("authorization,accept");
        assert_eq!(
            cors.get_cached_headers().get("Access-Control-Allow-Headers"),
            Some(&"authorization,accept".parse().unwrap())
        );
    }

    #[test]
    fn test_get_cached_headers_with_credentials() {
        let cors = Cors::new().credentials(true);
        assert_eq!(
            cors.get_cached_headers().get("Access-Control-Allow-Credentials"),
            Some(&"true".parse().unwrap())
        );
    }

    #[test]
    fn test_get_cached_headers_with_max_age() {
        let cors = Cors::new().max_age(3600);
        assert_eq!(
            cors.get_cached_headers().get("Access-Control-Max-Age"),
            Some(&"3600".parse().unwrap())
        );
    }

    #[test]
    fn test_get_cached_headers_with_expose() {
        let cors = Cors::new().expose("Content-Length");
        assert_eq!(
            cors.get_cached_headers().get("Access-Control-Expose-Headers"),
            Some(&"Content-Length".parse().unwrap())
        );
    }

    #[test]
    fn test_get_cached_headers_combined() {
        let cors = Cors::new()
            .methods("GET,POST")
            .headers("authorization")
            .credentials(true)
            .max_age(3600);
        let headers = cors.get_cached_headers();

        assert!(headers.contains_key("Access-Control-Allow-Methods"));
        assert!(headers.contains_key("Access-Control-Allow-Headers"));
        assert!(headers.contains_key("Access-Control-Allow-Credentials"));
        assert!(headers.contains_key("Access-Control-Max-Age"));
    }

    // ----- middleware integration -----

    #[tokio::test]
    async fn test_cors_integration() {
        let route = Route::new("/")
            .hook(Cors::new().origin(CorsType::Any))
            .get(|_req: Request| async { Ok("hello world") });
        let route = Route::new_root().append(route);
        let mut req = Request::empty();
        *req.method_mut() = Method::OPTIONS;
        *req.uri_mut() = "http://localhost:8080/".parse().unwrap();
        req.headers_mut()
            .insert("origin", "http://localhost:8080".parse().unwrap());
        req.headers_mut()
            .insert("access-control-request-method", "GET".parse().unwrap());
        req.headers_mut().insert(
            "access-control-request-headers",
            "content-type".parse().unwrap(),
        );

        let res = route.call(req).await.unwrap();
        assert_eq!(res.status, http::StatusCode::OK);
        // `Any` echoes the request's own origin back.
        assert_eq!(
            res.headers()
                .get("Access-Control-Allow-Origin")
                .and_then(|v| v.to_str().ok()),
            Some("http://localhost:8080")
        );
    }

    #[tokio::test]
    async fn test_cors_with_post_request() {
        let route = Route::new("/")
            .hook(
                Cors::new()
                    .origin("http://localhost:8080")
                    .methods(vec![Method::GET, Method::POST])
                    .credentials(true),
            )
            .post(|_req: Request| async { Ok("posted") });
        let route = Route::new_root().append(route);

        let mut req = Request::empty();
        *req.method_mut() = Method::POST;
        *req.uri_mut() = "http://localhost:8080/".parse().unwrap();
        req.headers_mut()
            .insert("origin", "http://localhost:8080".parse().unwrap());

        let res = route.call(req).await.unwrap();
        assert_eq!(res.status, http::StatusCode::OK);
        assert_eq!(
            res.headers()
                .get("Access-Control-Allow-Origin")
                .and_then(|v| v.to_str().ok()),
            Some("http://localhost:8080")
        );
        assert_eq!(
            res.headers()
                .get("Access-Control-Allow-Credentials")
                .and_then(|v| v.to_str().ok()),
            Some("true")
        );
    }

    #[tokio::test]
    async fn test_handle_without_origin_header() {
        // Without an `Origin` header the middleware must pass the request through.
        let route = Route::new("/")
            .hook(Cors::new().origin(CorsType::Any))
            .get(|_req: Request| async { Ok("hello") });
        let route = Route::new_root().append(route);

        let mut req = Request::empty();
        *req.method_mut() = Method::GET;
        *req.uri_mut() = "http://localhost:8080/".parse().unwrap();

        let res = route.call(req).await.unwrap();
        assert_eq!(res.status, http::StatusCode::OK);
        assert!(!res.headers().contains_key("Access-Control-Allow-Origin"));
    }

    #[tokio::test]
    async fn test_handle_empty_string_origin() {
        // An empty `Origin` header is treated the same as a missing one.
        let route = Route::new("/")
            .hook(Cors::new().origin("http://example.com"))
            .get(|_req: Request| async { Ok("hello") });
        let route = Route::new_root().append(route);

        let mut req = Request::empty();
        *req.method_mut() = Method::GET;
        *req.uri_mut() = "http://localhost:8080/".parse().unwrap();
        req.headers_mut().insert("origin", "".parse().unwrap());

        let res = route.call(req).await.unwrap();
        assert_eq!(res.status, http::StatusCode::OK);
        assert!(!res.headers().contains_key("Access-Control-Allow-Origin"));
    }
}