1#![allow(
2 clippy::borrow_interior_mutable_const,
3 clippy::type_complexity,
4 clippy::mutable_key_type
5)]
6use std::{
49 collections::HashSet, convert::TryFrom, iter::FromIterator, marker::PhantomData, rc::Rc,
50};
51
52use derive_more::Display;
53use ntex::http::header::{self, HeaderName, HeaderValue};
54use ntex::http::{HeaderMap, Method, RequestHead, StatusCode, Uri, error::HttpError};
55use ntex::service::{Middleware, Service, ServiceCtx};
56use ntex::util::{ByteString, Either};
57use ntex::web::{
58 DefaultError, ErrorRenderer, HttpResponse, WebRequest, WebResponse, WebResponseError,
59};
60
61#[derive(Debug, Display)]
63pub enum CorsError {
64 #[display("The HTTP request header `Origin` is required but was not provided")]
66 MissingOrigin,
67 #[display("The HTTP request header `Origin` could not be parsed correctly.")]
69 BadOrigin,
70 #[display("The request header `Access-Control-Request-Method` is required but is missing")]
73 MissingRequestMethod,
74 #[display("The request header `Access-Control-Request-Method` has an invalid value")]
76 BadRequestMethod,
77 #[display("The request header `Access-Control-Request-Headers` has an invalid value")]
80 BadRequestHeaders,
81 #[display("Origin is not allowed to make this request")]
83 OriginNotAllowed,
84 #[display("Requested method is not allowed")]
86 MethodNotAllowed,
87 #[display("One or more headers requested are not allowed")]
89 HeadersNotAllowed,
90}
91
92impl WebResponseError<DefaultError> for CorsError {
94 fn status_code(&self) -> StatusCode {
95 StatusCode::BAD_REQUEST
96 }
97}
98
/// Either "everything is allowed" or an explicit allow-list `T`.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub enum AllOrSome<T> {
    /// Everything is allowed; this is the `Default` variant.
    #[default]
    All,
    /// Only the items in `T` are allowed.
    Some(T),
}

impl<T> AllOrSome<T> {
    /// Returns `true` when this is [`AllOrSome::All`].
    pub fn is_all(&self) -> bool {
        matches!(self, AllOrSome::All)
    }

    /// Returns `true` when this is [`AllOrSome::Some`].
    pub fn is_some(&self) -> bool {
        matches!(self, AllOrSome::Some(_))
    }

    /// Borrows the explicit allow-list, or returns `None` for [`AllOrSome::All`].
    pub fn as_ref(&self) -> Option<&T> {
        if let AllOrSome::Some(inner) = self {
            Some(inner)
        } else {
            None
        }
    }
}
134
135#[derive(Default)]
158pub struct Cors {
159 cors: Option<Inner>,
160 methods: bool,
161 expose_hdrs: HashSet<HeaderName>,
162 error: Option<HttpError>,
163}
164
165impl Cors {
166 pub fn new() -> Self {
168 Cors {
169 cors: Some(Inner {
170 origins: AllOrSome::All,
171 origins_str: None,
172 methods: HashSet::new(),
173 headers: AllOrSome::All,
174 expose_hdrs: None,
175 max_age: None,
176 preflight: true,
177 send_wildcard: false,
178 supports_credentials: false,
179 vary_header: true,
180 }),
181 methods: false,
182 error: None,
183 expose_hdrs: HashSet::new(),
184 }
185 }
186
187 #[allow(clippy::should_implement_trait)]
188 pub fn default<Err>() -> CorsFactory<Err> {
190 let inner = Inner {
191 origins: AllOrSome::default(),
192 origins_str: None,
193 methods: HashSet::from_iter(vec![
194 Method::GET,
195 Method::HEAD,
196 Method::POST,
197 Method::OPTIONS,
198 Method::PUT,
199 Method::PATCH,
200 Method::DELETE,
201 ]),
202 headers: AllOrSome::All,
203 expose_hdrs: None,
204 max_age: None,
205 preflight: true,
206 send_wildcard: false,
207 supports_credentials: false,
208 vary_header: true,
209 };
210 CorsFactory { inner: Rc::new(inner), _t: PhantomData }
211 }
212
213 pub fn allowed_origin(mut self, origin: &str) -> Self {
231 if let Some(cors) = cors(&mut self.cors, &self.error) {
232 match Uri::try_from(origin) {
233 Ok(_) => {
234 if origin.trim() == "*" {
236 cors.origins = AllOrSome::All;
237 return self;
238 }
239 if cors.origins.is_all() {
240 cors.origins = AllOrSome::Some(HashSet::new());
241 }
242 if let AllOrSome::Some(ref mut origins) = cors.origins {
243 origins.insert(origin.to_owned());
244 }
245 }
246 Err(e) => {
247 self.error = Some(e.into());
248 }
249 }
250 }
251 self
252 }
253
254 pub fn allowed_methods<U, M>(mut self, methods: U) -> Self
273 where
274 U: IntoIterator<Item = M>,
275 Method: TryFrom<M>,
276 <Method as TryFrom<M>>::Error: Into<HttpError>,
277 {
278 self.methods = true;
279 if let Some(cors) = cors(&mut self.cors, &self.error) {
280 for m in methods {
281 match Method::try_from(m) {
282 Ok(method) => {
283 cors.methods.insert(method);
284 }
285 Err(e) => {
286 self.error = Some(e.into());
287 break;
288 }
289 }
290 }
291 }
292 self
293 }
294
295 pub fn allowed_header<H>(mut self, header: H) -> Self
297 where
298 HeaderName: TryFrom<H>,
299 <HeaderName as TryFrom<H>>::Error: Into<HttpError>,
300 {
301 if let Some(cors) = cors(&mut self.cors, &self.error) {
302 match HeaderName::try_from(header) {
303 Ok(method) => {
304 if cors.headers.is_all() {
305 cors.headers = AllOrSome::Some(HashSet::new());
306 }
307 if let AllOrSome::Some(ref mut headers) = cors.headers {
308 headers.insert(method);
309 }
310 }
311 Err(e) => self.error = Some(e.into()),
312 }
313 }
314 self
315 }
316
317 pub fn allowed_headers<U, H>(mut self, headers: U) -> Self
329 where
330 U: IntoIterator<Item = H>,
331 HeaderName: TryFrom<H>,
332 <HeaderName as TryFrom<H>>::Error: Into<HttpError>,
333 {
334 if let Some(cors) = cors(&mut self.cors, &self.error) {
335 for h in headers {
336 match HeaderName::try_from(h) {
337 Ok(method) => {
338 if cors.headers.is_all() {
339 cors.headers = AllOrSome::Some(HashSet::new());
340 }
341 if let AllOrSome::Some(ref mut headers) = cors.headers {
342 headers.insert(method);
343 }
344 }
345 Err(e) => {
346 self.error = Some(e.into());
347 break;
348 }
349 }
350 }
351 }
352 self
353 }
354
355 pub fn expose_headers<U, H>(mut self, headers: U) -> Self
364 where
365 U: IntoIterator<Item = H>,
366 HeaderName: TryFrom<H>,
367 <HeaderName as TryFrom<H>>::Error: Into<HttpError>,
368 {
369 for h in headers {
370 match HeaderName::try_from(h) {
371 Ok(method) => {
372 self.expose_hdrs.insert(method);
373 }
374 Err(e) => {
375 self.error = Some(e.into());
376 break;
377 }
378 }
379 }
380 self
381 }
382
383 pub fn max_age(mut self, max_age: usize) -> Self {
388 if let Some(cors) = cors(&mut self.cors, &self.error) {
389 cors.max_age = Some(max_age)
390 }
391 self
392 }
393
394 pub fn send_wildcard(mut self) -> Self {
410 if let Some(cors) = cors(&mut self.cors, &self.error) {
411 cors.send_wildcard = true
412 }
413 self
414 }
415
416 pub fn supports_credentials(mut self) -> Self {
430 if let Some(cors) = cors(&mut self.cors, &self.error) {
431 cors.supports_credentials = true
432 }
433 self
434 }
435
436 pub fn disable_vary_header(mut self) -> Self {
448 if let Some(cors) = cors(&mut self.cors, &self.error) {
449 cors.vary_header = false
450 }
451 self
452 }
453
454 pub fn disable_preflight(mut self) -> Self {
461 if let Some(cors) = cors(&mut self.cors, &self.error) {
462 cors.preflight = false
463 }
464 self
465 }
466
467 pub fn finish<Err>(self) -> CorsFactory<Err> {
469 let mut slf = if !self.methods {
470 self.allowed_methods(vec![
471 Method::GET,
472 Method::HEAD,
473 Method::POST,
474 Method::OPTIONS,
475 Method::PUT,
476 Method::PATCH,
477 Method::DELETE,
478 ])
479 } else {
480 self
481 };
482
483 if let Some(e) = slf.error.take() {
484 panic!("{}", e);
485 }
486
487 let mut cors = slf.cors.take().expect("cannot reuse CorsBuilder");
488
489 if cors.supports_credentials && cors.send_wildcard && cors.origins.is_all() {
490 panic!("Credentials are allowed, but the Origin is set to \"*\"");
491 }
492
493 if let AllOrSome::Some(ref origins) = cors.origins {
494 let s = origins.iter().fold(String::new(), |s, v| format!("{}, {}", s, v));
495 cors.origins_str = Some(HeaderValue::try_from(&s[2..]).unwrap());
496 }
497
498 if !slf.expose_hdrs.is_empty() {
499 cors.expose_hdrs = Some(
500 HeaderValue::try_from(
501 &slf.expose_hdrs
502 .iter()
503 .fold(String::new(), |s, v| format!("{}, {}", s, v.as_str()))[2..],
504 )
505 .unwrap(),
506 );
507 }
508
509 CorsFactory { inner: Rc::new(cors), _t: PhantomData }
510 }
511}
512
513fn cors<'a>(parts: &'a mut Option<Inner>, err: &Option<HttpError>) -> Option<&'a mut Inner> {
514 if err.is_some() {
515 return None;
516 }
517 parts.as_mut()
518}
519
520struct Inner {
521 methods: HashSet<Method>,
522 origins: AllOrSome<HashSet<String>>,
523 origins_str: Option<HeaderValue>,
524 headers: AllOrSome<HashSet<HeaderName>>,
525 expose_hdrs: Option<HeaderValue>,
526 max_age: Option<usize>,
527 preflight: bool,
528 send_wildcard: bool,
529 supports_credentials: bool,
530 vary_header: bool,
531}
532
533impl Inner {
534 fn validate_origin(&self, req: &RequestHead) -> Result<(), CorsError> {
535 if let Some(hdr) = req.headers().get(&header::ORIGIN) {
536 if let Ok(origin) = hdr.to_str() {
537 return match self.origins {
538 AllOrSome::All => Ok(()),
539 AllOrSome::Some(ref allowed_origins) => allowed_origins
540 .get(origin)
541 .map(|_| ())
542 .ok_or(CorsError::OriginNotAllowed),
543 };
544 }
545 Err(CorsError::BadOrigin)
546 } else {
547 match self.origins {
548 AllOrSome::All => Ok(()),
549 _ => Err(CorsError::MissingOrigin),
550 }
551 }
552 }
553
554 fn access_control_allow_origin(&self, headers: &HeaderMap) -> Option<HeaderValue> {
555 match self.origins {
556 AllOrSome::All => {
557 if self.send_wildcard {
558 Some(HeaderValue::from_static("*"))
559 } else {
560 headers.get(&header::ORIGIN).cloned()
561 }
562 }
563 AllOrSome::Some(ref origins) => {
564 if let Some(origin) =
565 headers.get(&header::ORIGIN).filter(|o| match o.to_str() {
566 Ok(os) => origins.contains(os),
567 _ => false,
568 })
569 {
570 Some(origin.clone())
571 } else {
572 Some(self.origins_str.as_ref().unwrap().clone())
573 }
574 }
575 }
576 }
577
578 fn validate_allowed_method(&self, req: &RequestHead) -> Result<(), CorsError> {
579 if let Some(hdr) = req.headers().get(&header::ACCESS_CONTROL_REQUEST_METHOD) {
580 if let Ok(meth) = hdr.to_str()
581 && let Ok(method) = Method::try_from(meth)
582 {
583 return self
584 .methods
585 .get(&method)
586 .map(|_| ())
587 .ok_or(CorsError::MethodNotAllowed);
588 }
589 Err(CorsError::BadRequestMethod)
590 } else {
591 Err(CorsError::MissingRequestMethod)
592 }
593 }
594
595 fn validate_allowed_headers(&self, req: &RequestHead) -> Result<(), CorsError> {
596 match self.headers {
597 AllOrSome::All => Ok(()),
598 AllOrSome::Some(ref allowed_headers) => {
599 if let Some(hdr) = req.headers().get(&header::ACCESS_CONTROL_REQUEST_HEADERS) {
600 if let Ok(headers) = hdr.to_str() {
601 let mut hdrs = HashSet::new();
602 for hdr in headers.split(',') {
603 match HeaderName::try_from(hdr.trim()) {
604 Ok(hdr) => hdrs.insert(hdr),
605 Err(_) => return Err(CorsError::BadRequestHeaders),
606 };
607 }
608 if !hdrs.is_empty() {
611 if !hdrs.is_subset(allowed_headers) {
612 return Err(CorsError::HeadersNotAllowed);
613 }
614 return Ok(());
615 }
616 }
617 Err(CorsError::BadRequestHeaders)
618 } else {
619 Ok(())
620 }
621 }
622 }
623 }
624
625 fn preflight_check(
626 &self,
627 req: &RequestHead,
628 ) -> Result<Either<HttpResponse, ()>, CorsError> {
629 if self.preflight && Method::OPTIONS == req.method {
630 self.validate_origin(req)
631 .and_then(|_| self.validate_allowed_method(req))
632 .and_then(|_| self.validate_allowed_headers(req))?;
633
634 let headers = if let Some(headers) = self.headers.as_ref() {
636 Some(
637 HeaderValue::try_from(
638 &headers
639 .iter()
640 .fold(String::new(), |s, v| s + "," + v.as_str())
641 .as_str()[1..],
642 )
643 .unwrap(),
644 )
645 } else {
646 req.headers.get(&header::ACCESS_CONTROL_REQUEST_HEADERS).cloned()
647 };
648
649 let mut res = HttpResponse::Ok();
650
651 if let Some(ref max_age) = self.max_age {
652 res.header(
653 header::ACCESS_CONTROL_MAX_AGE,
654 ByteString::from(format!("{}", max_age)),
655 );
656 }
657 if let Some(headers) = headers {
658 res.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers);
659 }
660 if let Some(origin) = self.access_control_allow_origin(req.headers()) {
661 res.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin);
662 }
663 if self.supports_credentials {
664 res.header(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
665 }
666 let res = res
667 .header(
668 header::ACCESS_CONTROL_ALLOW_METHODS,
669 &self
670 .methods
671 .iter()
672 .fold(String::new(), |s, v| s + "," + v.as_str())
673 .as_str()[1..],
674 )
675 .finish()
676 .into_body();
677
678 Ok(Either::Left(res))
679 } else {
680 if req.headers.contains_key(&header::ORIGIN) {
681 self.validate_origin(req)?;
683 }
684 Ok(Either::Right(()))
685 }
686 }
687
688 fn handle_response(&self, headers: &mut HeaderMap, allowed_origin: Option<HeaderValue>) {
689 if let Some(origin) = allowed_origin {
690 headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin);
691 };
692
693 if let Some(ref expose) = self.expose_hdrs {
694 headers.insert(header::ACCESS_CONTROL_EXPOSE_HEADERS, expose.clone());
695 }
696 if self.supports_credentials {
697 headers.insert(
698 header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
699 HeaderValue::from_static("true"),
700 );
701 }
702 if self.vary_header {
703 let value = if let Some(hdr) = headers.get(&header::VARY) {
704 let mut val: Vec<u8> = Vec::with_capacity(hdr.as_bytes().len() + 8);
705 val.extend(hdr.as_bytes());
706 val.extend(b", Origin");
707 HeaderValue::try_from(&val[..]).unwrap()
708 } else {
709 HeaderValue::from_static("Origin")
710 };
711 headers.insert(header::VARY, value);
712 }
713 }
714}
715
716pub struct CorsFactory<Err> {
721 inner: Rc<Inner>,
722 _t: PhantomData<Err>,
723}
724
725impl<S, C, Err> Middleware<S, C> for CorsFactory<Err>
726where
727 S: Service<WebRequest<Err>, Response = WebResponse>,
728{
729 type Service = CorsMiddleware<S>;
730
731 fn create(&self, service: S, _: C) -> Self::Service {
732 CorsMiddleware { service, inner: self.inner.clone() }
733 }
734}
735
736#[derive(Clone)]
741pub struct CorsMiddleware<S> {
742 service: S,
743 inner: Rc<Inner>,
744}
745
746impl<S, Err> Service<WebRequest<Err>> for CorsMiddleware<S>
747where
748 S: Service<WebRequest<Err>, Response = WebResponse>,
749 Err: ErrorRenderer,
750 Err::Container: From<S::Error>,
751 CorsError: WebResponseError<Err>,
752{
753 type Response = WebResponse;
754 type Error = S::Error;
755
756 ntex::forward_ready!(service);
757 ntex::forward_shutdown!(service);
758
759 async fn call(
760 &self,
761 req: WebRequest<Err>,
762 ctx: ServiceCtx<'_, Self>,
763 ) -> Result<Self::Response, S::Error> {
764 match self.inner.preflight_check(req.head()) {
765 Ok(Either::Left(res)) => Ok(req.into_response(res)),
766 Ok(Either::Right(_)) => {
767 let inner = self.inner.clone();
768 let has_origin = req.headers().contains_key(&header::ORIGIN);
769 let allowed_origin = inner.access_control_allow_origin(req.headers());
770
771 let mut res = ctx.call(&self.service, req).await?;
772
773 if has_origin {
774 inner.handle_response(res.headers_mut(), allowed_origin);
775 }
776 Ok(res)
777 }
778 Err(e) => Ok(req.render_error(&e)),
779 }
780 }
781}
782
#[cfg(test)]
mod tests {
    use ntex::service::{Pipeline, fn_service};
    use ntex::web::{self, test, test::TestRequest};

    use super::*;

    // `finish` must refuse credentials combined with `send_wildcard` while all
    // origins are allowed.
    #[ntex::test]
    #[should_panic(expected = "Credentials are allowed, but the Origin is set to")]
    async fn cors_validates_illegal_allow_credentials() {
        let _cors =
            Cors::new().supports_credentials().send_wildcard().finish::<web::DefaultError>();
    }

    // With no origin restrictions, a request carrying any `Origin` reaches the service.
    #[ntex::test]
    async fn validate_origin_allows_all_origins() {
        let cors = Cors::new().finish().create(test::ok_service(), ()).into();
        let req =
            TestRequest::with_header("Origin", "https://www.example.com").to_srv_request();

        let resp = test::call_service(&cors, req).await;
        assert_eq!(resp.status(), StatusCode::OK);
    }

    // The ready-made `Cors::default()` factory also lets a simple request through.
    #[ntex::test]
    async fn default() {
        let cors = Cors::default().create(test::ok_service(), ()).into();
        let req =
            TestRequest::with_header("Origin", "https://www.example.com").to_srv_request();

        let resp = test::call_service(&cors, req).await;
        assert_eq!(resp.status(), StatusCode::OK);
    }

    // Preflight validation failures, then the headers of a successful preflight.
    #[ntex::test]
    async fn test_preflight() {
        let cors: Pipeline<_> = Cors::new()
            .send_wildcard()
            .max_age(3600)
            .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
            .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT])
            .allowed_header(header::CONTENT_TYPE)
            .finish()
            .create(test::ok_service(), ())
            .into();

        // No `Access-Control-Request-Method` and a header outside the allow-list:
        // both validators fail and the middleware answers 400.
        let req = TestRequest::with_header("Origin", "https://www.example.com")
            .method(Method::OPTIONS)
            .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "X-Not-Allowed")
            .to_srv_request();

        assert!(cors.get_ref().inner.validate_allowed_method(req.head()).is_err());
        assert!(cors.get_ref().inner.validate_allowed_headers(req.head()).is_err());
        let resp = test::call_service(&cors, req).await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);

        // "put" does not match the allowed methods; without
        // `Access-Control-Request-Headers` the header check passes.
        let req = TestRequest::with_header("Origin", "https://www.example.com")
            .header(header::ACCESS_CONTROL_REQUEST_METHOD, "put")
            .method(Method::OPTIONS)
            .to_srv_request();

        assert!(cors.get_ref().inner.validate_allowed_method(req.head()).is_err());
        assert!(cors.get_ref().inner.validate_allowed_headers(req.head()).is_ok());

        // A valid preflight gets the wildcard origin, max-age and both allow lists.
        let req = TestRequest::with_header("Origin", "https://www.example.com")
            .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST")
            .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "AUTHORIZATION,ACCEPT")
            .method(Method::OPTIONS)
            .to_srv_request();

        let resp = test::call_service(&cors, req).await;
        assert_eq!(
            &b"*"[..],
            resp.headers().get(&header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
        );
        assert_eq!(
            &b"3600"[..],
            resp.headers().get(&header::ACCESS_CONTROL_MAX_AGE).unwrap().as_bytes()
        );
        // Set iteration order is unspecified, so check membership, not exact text.
        let hdr = resp
            .headers()
            .get(&header::ACCESS_CONTROL_ALLOW_HEADERS)
            .unwrap()
            .to_str()
            .unwrap();
        assert!(hdr.contains("authorization"));
        assert!(hdr.contains("accept"));
        assert!(hdr.contains("content-type"));

        let methods =
            resp.headers().get(header::ACCESS_CONTROL_ALLOW_METHODS).unwrap().to_str().unwrap();
        assert!(methods.contains("POST"));
        assert!(methods.contains("GET"));
        assert!(methods.contains("OPTIONS"));

    }

    // An origin outside the explicit list fails `validate_origin`; the
    // `unwrap` panic message contains the variant's Debug name.
    #[ntex::test]
    #[should_panic(expected = "OriginNotAllowed")]
    async fn test_validate_not_allowed_origin() {
        let cors: Pipeline<_> = Cors::new()
            .allowed_origin("https://www.example.com")
            .finish()
            .create(test::ok_service::<web::DefaultError>(), ())
            .into();

        let req = TestRequest::with_header("Origin", "https://www.unknown.com")
            .method(Method::GET)
            .to_srv_request();
        cors.get_ref().inner.validate_origin(req.head()).unwrap();
        cors.get_ref().inner.validate_allowed_method(req.head()).unwrap();
        cors.get_ref().inner.validate_allowed_headers(req.head()).unwrap();
    }

    // An origin in the explicit list reaches the service.
    #[ntex::test]
    async fn test_validate_origin() {
        let cors = Cors::new()
            .allowed_origin("https://www.example.com")
            .finish()
            .create(test::ok_service(), ())
            .into();

        let req = TestRequest::with_header("Origin", "https://www.example.com")
            .method(Method::GET)
            .to_srv_request();

        let resp = test::call_service(&cors, req).await;
        assert_eq!(resp.status(), StatusCode::OK);
    }

    // With preflight disabled, requests without `Origin` get no CORS headers,
    // and an OPTIONS request is forwarded with its origin echoed back.
    #[ntex::test]
    async fn test_no_origin_response() {
        let cors =
            Cors::new().disable_preflight().finish().create(test::ok_service(), ()).into();

        let req = TestRequest::default().method(Method::GET).to_srv_request();
        let resp = test::call_service(&cors, req).await;
        assert!(resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).is_none());

        let req = TestRequest::with_header("Origin", "https://www.example.com")
            .method(Method::OPTIONS)
            .to_srv_request();
        let resp = test::call_service(&cors, req).await;
        assert_eq!(
            &b"https://www.example.com"[..],
            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
        );
    }

    // Response-side headers: wildcard origin, `Vary`, exposed headers, merging
    // with a service-provided `Vary`, and origin echo when `Vary` is disabled.
    #[ntex::test]
    async fn test_response() {
        let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT];
        let cors = Cors::new()
            .send_wildcard()
            .disable_preflight()
            .max_age(3600)
            .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
            .allowed_headers(exposed_headers.clone())
            .expose_headers(exposed_headers.clone())
            .allowed_header(header::CONTENT_TYPE)
            .finish()
            .create(test::ok_service(), ())
            .into();

        let req = TestRequest::with_header("Origin", "https://www.example.com")
            .method(Method::OPTIONS)
            .to_srv_request();

        let resp = test::call_service(&cors, req).await;
        assert_eq!(
            &b"*"[..],
            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
        );
        assert_eq!(&b"Origin"[..], resp.headers().get(header::VARY).unwrap().as_bytes());

        {
            // Split the joined header value back into names to compare as a set.
            let headers = resp
                .headers()
                .get(header::ACCESS_CONTROL_EXPOSE_HEADERS)
                .unwrap()
                .to_str()
                .unwrap()
                .split(',')
                .map(|s| s.trim())
                .collect::<Vec<&str>>();

            for h in exposed_headers {
                assert!(headers.contains(&h.as_str()));
            }
        }

        // The wrapped service sets `Vary: Accept`; the middleware appends `Origin`.
        let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT];
        let cors = Cors::new()
            .send_wildcard()
            .disable_preflight()
            .max_age(3600)
            .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
            .allowed_headers(exposed_headers.clone())
            .expose_headers(exposed_headers.clone())
            .allowed_header(header::CONTENT_TYPE)
            .finish()
            .create(
                fn_service(|req: WebRequest<DefaultError>| async move {
                    Ok::<_, std::convert::Infallible>(req.into_response(
                        HttpResponse::Ok().header(header::VARY, "Accept").finish(),
                    ))
                }),
                (),
            )
            .into();
        let req = TestRequest::with_header("Origin", "https://www.example.com")
            .method(Method::OPTIONS)
            .to_srv_request();
        let resp = test::call_service(&cors, req).await;
        assert_eq!(
            &b"Accept, Origin"[..],
            resp.headers().get(header::VARY).unwrap().as_bytes()
        );

        // With an explicit origin list, a preflight echoes the matching origin only.
        let cors = Cors::new()
            .disable_vary_header()
            .allowed_origin("https://www.example.com")
            .allowed_origin("https://www.google.com")
            .finish()
            .create(test::ok_service(), ())
            .into();

        let req = TestRequest::with_header("Origin", "https://www.example.com")
            .method(Method::OPTIONS)
            .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST")
            .to_srv_request();
        let resp = test::call_service(&cors, req).await;

        let origins_str =
            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().to_str().unwrap();

        assert_eq!("https://www.example.com", origins_str);
    }

    // Each configured origin is echoed individually on simple requests.
    #[ntex::test]
    async fn test_multiple_origins() {
        let cors = Cors::new()
            .allowed_origin("https://example.com")
            .allowed_origin("https://example.org")
            .allowed_methods(vec![Method::GET])
            .finish()
            .create(test::ok_service(), ())
            .into();

        let req = TestRequest::with_header("Origin", "https://example.com")
            .method(Method::GET)
            .to_srv_request();

        let resp = test::call_service(&cors, req).await;
        assert_eq!(
            &b"https://example.com"[..],
            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
        );

        let req = TestRequest::with_header("Origin", "https://example.org")
            .method(Method::GET)
            .to_srv_request();

        let resp = test::call_service(&cors, req).await;
        assert_eq!(
            &b"https://example.org"[..],
            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
        );
    }

    // Same as above, but through the preflight path.
    #[ntex::test]
    async fn test_multiple_origins_preflight() {
        let cors = Cors::new()
            .allowed_origin("https://example.com")
            .allowed_origin("https://example.org")
            .allowed_methods(vec![Method::GET])
            .finish()
            .create(test::ok_service(), ())
            .into();

        let req = TestRequest::with_header("Origin", "https://example.com")
            .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
            .method(Method::OPTIONS)
            .to_srv_request();

        let resp = test::call_service(&cors, req).await;
        assert_eq!(
            &b"https://example.com"[..],
            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
        );

        let req = TestRequest::with_header("Origin", "https://example.org")
            .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
            .method(Method::OPTIONS)
            .to_srv_request();

        let resp = test::call_service(&cors, req).await;
        assert_eq!(
            &b"https://example.org"[..],
            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
        );
    }

    // `allowed_origin("*")` keeps every origin allowed.
    #[ntex::test]
    async fn test_set_allowed_origin_to_all() {
        let cors =
            Cors::new().allowed_origin("*").finish().create(test::ok_service(), ()).into();

        let req = TestRequest::with_header("Origin", "https://www.example.com")
            .method(Method::GET)
            .to_srv_request();

        let resp = test::call_service(&cors, req).await;
        assert_eq!(resp.status(), StatusCode::OK);
    }
}
1118}