1#![allow(
2 clippy::borrow_interior_mutable_const,
3 clippy::type_complexity,
4 clippy::mutable_key_type
5)]
6use std::{
49 collections::HashSet, convert::TryFrom, iter::FromIterator, marker::PhantomData, rc::Rc,
50};
51
52use derive_more::Display;
53use ntex::http::header::{self, HeaderName, HeaderValue};
54use ntex::http::{error::HttpError, HeaderMap, Method, RequestHead, StatusCode, Uri};
55use ntex::service::{Middleware, Service, ServiceCtx};
56use ntex::util::Either;
57use ntex::web::{
58 DefaultError, ErrorRenderer, HttpResponse, WebRequest, WebResponse, WebResponseError,
59};
60
/// Reasons a request can be rejected while it is checked against the CORS policy.
///
/// Each variant's `Display` text is the message rendered to the client.
#[derive(Debug, Display)]
pub enum CorsError {
    /// The request has no `Origin` header while the policy only allows specific origins.
    #[display(fmt = "The HTTP request header `Origin` is required but was not provided")]
    MissingOrigin,
    /// The `Origin` header is present but could not be read as a string.
    #[display(fmt = "The HTTP request header `Origin` could not be parsed correctly.")]
    BadOrigin,
    /// A preflight request did not include `Access-Control-Request-Method`.
    #[display(
        fmt = "The request header `Access-Control-Request-Method` is required but is missing"
    )]
    MissingRequestMethod,
    /// `Access-Control-Request-Method` is not a readable string or not a parseable method.
    #[display(fmt = "The request header `Access-Control-Request-Method` has an invalid value")]
    BadRequestMethod,
    /// `Access-Control-Request-Headers` is unreadable or contains an invalid header name.
    #[display(
        fmt = "The request header `Access-Control-Request-Headers` has an invalid value"
    )]
    BadRequestHeaders,
    /// The request's origin is not in the set of allowed origins.
    #[display(fmt = "Origin is not allowed to make this request")]
    OriginNotAllowed,
    /// The method named in `Access-Control-Request-Method` is not in the allowed set.
    #[display(fmt = "Requested method is not allowed")]
    MethodNotAllowed,
    /// At least one header in `Access-Control-Request-Headers` is not in the allowed set.
    #[display(fmt = "One or more headers requested are not allowed")]
    HeadersNotAllowed,
}
95
/// Renders every [`CorsError`] through the default error renderer.
impl WebResponseError<DefaultError> for CorsError {
    /// All CORS validation failures are reported as `400 Bad Request`.
    fn status_code(&self) -> StatusCode {
        StatusCode::BAD_REQUEST
    }
}
102
/// Either "everything is allowed" or an explicit collection of allowed items.
///
/// The default value is [`AllOrSome::All`].
#[derive(Default, Clone, Debug, Eq, PartialEq)]
pub enum AllOrSome<T> {
    /// No restriction: every value is accepted.
    #[default]
    All,
    /// Only the wrapped value(s) are accepted.
    Some(T),
}

impl<T> AllOrSome<T> {
    /// Returns `true` when this is the unrestricted [`AllOrSome::All`] variant.
    pub fn is_all(&self) -> bool {
        matches!(self, AllOrSome::All)
    }

    /// Returns `true` when this is the [`AllOrSome::Some`] variant.
    pub fn is_some(&self) -> bool {
        matches!(self, AllOrSome::Some(_))
    }

    /// Borrows the wrapped value, or yields `None` for [`AllOrSome::All`].
    pub fn as_ref(&self) -> Option<&T> {
        if let AllOrSome::Some(inner) = self {
            Some(inner)
        } else {
            None
        }
    }
}
138
/// Builder for the CORS middleware.
///
/// Configuration methods consume and return the builder; `finish` turns it into a
/// `CorsFactory`. The first conversion error hit while configuring is stored in `error`
/// and makes `finish` panic.
///
/// Note: the derived `Default` implementation leaves `cors` as `None`, and `finish`
/// panics with "cannot reuse CorsBuilder" in that state. Use `Cors::new()` to obtain a
/// usable builder.
#[derive(Default)]
pub struct Cors {
    /// Policy under construction; taken out by `finish`.
    cors: Option<Inner>,
    /// Set once `allowed_methods` has been called, so `finish` knows whether to apply
    /// the default method set.
    methods: bool,
    /// Header names to expose; joined into a single header value by `finish`.
    expose_hdrs: HashSet<HeaderName>,
    /// First configuration error recorded by the builder methods, if any.
    error: Option<HttpError>,
}
169
170impl Cors {
171 pub fn new() -> Self {
173 Cors {
174 cors: Some(Inner {
175 origins: AllOrSome::All,
176 origins_str: None,
177 methods: HashSet::new(),
178 headers: AllOrSome::All,
179 expose_hdrs: None,
180 max_age: None,
181 preflight: true,
182 send_wildcard: false,
183 supports_credentials: false,
184 vary_header: true,
185 }),
186 methods: false,
187 error: None,
188 expose_hdrs: HashSet::new(),
189 }
190 }
191
192 pub fn default<Err>() -> CorsFactory<Err> {
194 let inner = Inner {
195 origins: AllOrSome::default(),
196 origins_str: None,
197 methods: HashSet::from_iter(
198 vec![
199 Method::GET,
200 Method::HEAD,
201 Method::POST,
202 Method::OPTIONS,
203 Method::PUT,
204 Method::PATCH,
205 Method::DELETE,
206 ]
207 .into_iter(),
208 ),
209 headers: AllOrSome::All,
210 expose_hdrs: None,
211 max_age: None,
212 preflight: true,
213 send_wildcard: false,
214 supports_credentials: false,
215 vary_header: true,
216 };
217 CorsFactory { inner: Rc::new(inner), _t: PhantomData }
218 }
219
220 pub fn allowed_origin(mut self, origin: &str) -> Self {
238 if let Some(cors) = cors(&mut self.cors, &self.error) {
239 match Uri::try_from(origin) {
240 Ok(_) => {
241 if cors.origins.is_all() {
242 cors.origins = AllOrSome::Some(HashSet::new());
243 }
244 if let AllOrSome::Some(ref mut origins) = cors.origins {
245 origins.insert(origin.to_owned());
246 }
247 }
248 Err(e) => {
249 self.error = Some(e.into());
250 }
251 }
252 }
253 self
254 }
255
256 pub fn allowed_methods<U, M>(mut self, methods: U) -> Self
264 where
265 U: IntoIterator<Item = M>,
266 Method: TryFrom<M>,
267 <Method as TryFrom<M>>::Error: Into<HttpError>,
268 {
269 self.methods = true;
270 if let Some(cors) = cors(&mut self.cors, &self.error) {
271 for m in methods {
272 match Method::try_from(m) {
273 Ok(method) => {
274 cors.methods.insert(method);
275 }
276 Err(e) => {
277 self.error = Some(e.into());
278 break;
279 }
280 }
281 }
282 }
283 self
284 }
285
286 pub fn allowed_header<H>(mut self, header: H) -> Self
288 where
289 HeaderName: TryFrom<H>,
290 <HeaderName as TryFrom<H>>::Error: Into<HttpError>,
291 {
292 if let Some(cors) = cors(&mut self.cors, &self.error) {
293 match HeaderName::try_from(header) {
294 Ok(method) => {
295 if cors.headers.is_all() {
296 cors.headers = AllOrSome::Some(HashSet::new());
297 }
298 if let AllOrSome::Some(ref mut headers) = cors.headers {
299 headers.insert(method);
300 }
301 }
302 Err(e) => self.error = Some(e.into()),
303 }
304 }
305 self
306 }
307
308 pub fn allowed_headers<U, H>(mut self, headers: U) -> Self
320 where
321 U: IntoIterator<Item = H>,
322 HeaderName: TryFrom<H>,
323 <HeaderName as TryFrom<H>>::Error: Into<HttpError>,
324 {
325 if let Some(cors) = cors(&mut self.cors, &self.error) {
326 for h in headers {
327 match HeaderName::try_from(h) {
328 Ok(method) => {
329 if cors.headers.is_all() {
330 cors.headers = AllOrSome::Some(HashSet::new());
331 }
332 if let AllOrSome::Some(ref mut headers) = cors.headers {
333 headers.insert(method);
334 }
335 }
336 Err(e) => {
337 self.error = Some(e.into());
338 break;
339 }
340 }
341 }
342 }
343 self
344 }
345
346 pub fn expose_headers<U, H>(mut self, headers: U) -> Self
355 where
356 U: IntoIterator<Item = H>,
357 HeaderName: TryFrom<H>,
358 <HeaderName as TryFrom<H>>::Error: Into<HttpError>,
359 {
360 for h in headers {
361 match HeaderName::try_from(h) {
362 Ok(method) => {
363 self.expose_hdrs.insert(method);
364 }
365 Err(e) => {
366 self.error = Some(e.into());
367 break;
368 }
369 }
370 }
371 self
372 }
373
374 pub fn max_age(mut self, max_age: usize) -> Self {
379 if let Some(cors) = cors(&mut self.cors, &self.error) {
380 cors.max_age = Some(max_age)
381 }
382 self
383 }
384
385 pub fn send_wildcard(mut self) -> Self {
401 if let Some(cors) = cors(&mut self.cors, &self.error) {
402 cors.send_wildcard = true
403 }
404 self
405 }
406
407 pub fn supports_credentials(mut self) -> Self {
421 if let Some(cors) = cors(&mut self.cors, &self.error) {
422 cors.supports_credentials = true
423 }
424 self
425 }
426
427 pub fn disable_vary_header(mut self) -> Self {
439 if let Some(cors) = cors(&mut self.cors, &self.error) {
440 cors.vary_header = false
441 }
442 self
443 }
444
445 pub fn disable_preflight(mut self) -> Self {
452 if let Some(cors) = cors(&mut self.cors, &self.error) {
453 cors.preflight = false
454 }
455 self
456 }
457
458 pub fn finish<Err>(self) -> CorsFactory<Err> {
460 let mut slf = if !self.methods {
461 self.allowed_methods(vec![
462 Method::GET,
463 Method::HEAD,
464 Method::POST,
465 Method::OPTIONS,
466 Method::PUT,
467 Method::PATCH,
468 Method::DELETE,
469 ])
470 } else {
471 self
472 };
473
474 if let Some(e) = slf.error.take() {
475 panic!("{}", e);
476 }
477
478 let mut cors = slf.cors.take().expect("cannot reuse CorsBuilder");
479
480 if cors.supports_credentials && cors.send_wildcard && cors.origins.is_all() {
481 panic!("Credentials are allowed, but the Origin is set to \"*\"");
482 }
483
484 if let AllOrSome::Some(ref origins) = cors.origins {
485 let s = origins.iter().fold(String::new(), |s, v| format!("{}, {}", s, v));
486 cors.origins_str = Some(HeaderValue::try_from(&s[2..]).unwrap());
487 }
488
489 if !slf.expose_hdrs.is_empty() {
490 cors.expose_hdrs = Some(
491 HeaderValue::try_from(
492 &slf.expose_hdrs
493 .iter()
494 .fold(String::new(), |s, v| format!("{}, {}", s, v.as_str()))[2..],
495 )
496 .unwrap(),
497 );
498 }
499
500 CorsFactory { inner: Rc::new(cors), _t: PhantomData }
501 }
502}
503
504fn cors<'a>(parts: &'a mut Option<Inner>, err: &Option<HttpError>) -> Option<&'a mut Inner> {
505 if err.is_some() {
506 return None;
507 }
508 parts.as_mut()
509}
510
/// Finalized CORS policy shared by the factory and every middleware instance.
struct Inner {
    /// Methods accepted in `Access-Control-Request-Method` and advertised in preflight
    /// responses.
    methods: HashSet<Method>,
    /// Origins allowed to make requests, or `All` for no restriction.
    origins: AllOrSome<HashSet<String>>,
    /// Comma separated rendering of `origins`; filled in by `Cors::finish` when
    /// `origins` is `Some`.
    origins_str: Option<HeaderValue>,
    /// Request headers accepted in `Access-Control-Request-Headers`, or `All`.
    headers: AllOrSome<HashSet<HeaderName>>,
    /// Pre-rendered `Access-Control-Expose-Headers` value, if any headers are exposed.
    expose_hdrs: Option<HeaderValue>,
    /// Value sent as `Access-Control-Max-Age` in preflight responses.
    max_age: Option<usize>,
    /// Whether `OPTIONS` requests are answered as preflight requests.
    preflight: bool,
    /// Whether `*` is sent instead of the request's origin when all origins are allowed.
    send_wildcard: bool,
    /// Whether `Access-Control-Allow-Credentials: true` is added to responses.
    supports_credentials: bool,
    /// Whether `Origin` is added to the response's `Vary` header.
    vary_header: bool,
}
523
524impl Inner {
525 fn validate_origin(&self, req: &RequestHead) -> Result<(), CorsError> {
526 if let Some(hdr) = req.headers().get(&header::ORIGIN) {
527 if let Ok(origin) = hdr.to_str() {
528 return match self.origins {
529 AllOrSome::All => Ok(()),
530 AllOrSome::Some(ref allowed_origins) => allowed_origins
531 .get(origin)
532 .map(|_| ())
533 .ok_or(CorsError::OriginNotAllowed),
534 };
535 }
536 Err(CorsError::BadOrigin)
537 } else {
538 match self.origins {
539 AllOrSome::All => Ok(()),
540 _ => Err(CorsError::MissingOrigin),
541 }
542 }
543 }
544
545 fn access_control_allow_origin(&self, headers: &HeaderMap) -> Option<HeaderValue> {
546 match self.origins {
547 AllOrSome::All => {
548 if self.send_wildcard {
549 Some(HeaderValue::from_static("*"))
550 } else {
551 headers.get(&header::ORIGIN).cloned()
552 }
553 }
554 AllOrSome::Some(ref origins) => {
555 if let Some(origin) =
556 headers.get(&header::ORIGIN).filter(|o| match o.to_str() {
557 Ok(os) => origins.contains(os),
558 _ => false,
559 })
560 {
561 Some(origin.clone())
562 } else {
563 Some(self.origins_str.as_ref().unwrap().clone())
564 }
565 }
566 }
567 }
568
569 fn validate_allowed_method(&self, req: &RequestHead) -> Result<(), CorsError> {
570 if let Some(hdr) = req.headers().get(&header::ACCESS_CONTROL_REQUEST_METHOD) {
571 if let Ok(meth) = hdr.to_str() {
572 if let Ok(method) = Method::try_from(meth) {
573 return self
574 .methods
575 .get(&method)
576 .map(|_| ())
577 .ok_or(CorsError::MethodNotAllowed);
578 }
579 }
580 Err(CorsError::BadRequestMethod)
581 } else {
582 Err(CorsError::MissingRequestMethod)
583 }
584 }
585
586 fn validate_allowed_headers(&self, req: &RequestHead) -> Result<(), CorsError> {
587 match self.headers {
588 AllOrSome::All => Ok(()),
589 AllOrSome::Some(ref allowed_headers) => {
590 if let Some(hdr) = req.headers().get(&header::ACCESS_CONTROL_REQUEST_HEADERS) {
591 if let Ok(headers) = hdr.to_str() {
592 let mut hdrs = HashSet::new();
593 for hdr in headers.split(',') {
594 match HeaderName::try_from(hdr.trim()) {
595 Ok(hdr) => hdrs.insert(hdr),
596 Err(_) => return Err(CorsError::BadRequestHeaders),
597 };
598 }
599 if !hdrs.is_empty() {
602 if !hdrs.is_subset(allowed_headers) {
603 return Err(CorsError::HeadersNotAllowed);
604 }
605 return Ok(());
606 }
607 }
608 Err(CorsError::BadRequestHeaders)
609 } else {
610 Ok(())
611 }
612 }
613 }
614 }
615
616 fn preflight_check(
617 &self,
618 req: &RequestHead,
619 ) -> Result<Either<HttpResponse, ()>, CorsError> {
620 if self.preflight && Method::OPTIONS == req.method {
621 self.validate_origin(req)
622 .and_then(|_| self.validate_allowed_method(req))
623 .and_then(|_| self.validate_allowed_headers(req))?;
624
625 let headers = if let Some(headers) = self.headers.as_ref() {
627 Some(
628 HeaderValue::try_from(
629 &headers
630 .iter()
631 .fold(String::new(), |s, v| s + "," + v.as_str())
632 .as_str()[1..],
633 )
634 .unwrap(),
635 )
636 } else {
637 req.headers.get(&header::ACCESS_CONTROL_REQUEST_HEADERS).cloned()
638 };
639
640 let res = HttpResponse::Ok()
641 .if_some(self.max_age.as_ref(), |max_age, resp| {
642 let _ = resp.header(
643 header::ACCESS_CONTROL_MAX_AGE,
644 format!("{}", max_age).as_str(),
645 );
646 })
647 .if_some(headers, |headers, resp| {
648 let _ = resp.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers);
649 })
650 .if_some(self.access_control_allow_origin(req.headers()), |origin, resp| {
651 let _ = resp.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin);
652 })
653 .if_true(self.supports_credentials, |resp| {
654 resp.header(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
655 })
656 .header(
657 header::ACCESS_CONTROL_ALLOW_METHODS,
658 &self
659 .methods
660 .iter()
661 .fold(String::new(), |s, v| s + "," + v.as_str())
662 .as_str()[1..],
663 )
664 .finish()
665 .into_body();
666
667 Ok(Either::Left(res))
668 } else {
669 if req.headers.contains_key(&header::ORIGIN) {
670 self.validate_origin(req)?;
672 }
673 Ok(Either::Right(()))
674 }
675 }
676
677 fn handle_response(&self, headers: &mut HeaderMap, allowed_origin: Option<HeaderValue>) {
678 if let Some(origin) = allowed_origin {
679 headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin);
680 };
681
682 if let Some(ref expose) = self.expose_hdrs {
683 headers.insert(header::ACCESS_CONTROL_EXPOSE_HEADERS, expose.clone());
684 }
685 if self.supports_credentials {
686 headers.insert(
687 header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
688 HeaderValue::from_static("true"),
689 );
690 }
691 if self.vary_header {
692 let value = if let Some(hdr) = headers.get(&header::VARY) {
693 let mut val: Vec<u8> = Vec::with_capacity(hdr.as_bytes().len() + 8);
694 val.extend(hdr.as_bytes());
695 val.extend(b", Origin");
696 HeaderValue::try_from(&val[..]).unwrap()
697 } else {
698 HeaderValue::from_static("Origin")
699 };
700 headers.insert(header::VARY, value);
701 }
702 }
703}
704
/// Middleware factory holding the finalized CORS policy.
///
/// Produced by `Cors::finish` or `Cors::default`; each created middleware shares the
/// same policy through a reference-counted pointer.
pub struct CorsFactory<Err> {
    /// The finalized policy, shared with every `CorsMiddleware` created from it.
    inner: Rc<Inner>,
    /// Ties the factory to an error type parameter without storing a value of it.
    _t: PhantomData<Err>,
}
713
714impl<S, Err> Middleware<S> for CorsFactory<Err>
715where
716 S: Service<WebRequest<Err>, Response = WebResponse>,
717{
718 type Service = CorsMiddleware<S>;
719
720 fn create(&self, service: S) -> Self::Service {
721 CorsMiddleware { service, inner: self.inner.clone() }
722 }
723}
724
/// Service wrapper that applies the CORS policy to requests before they reach the
/// wrapped service and to the responses it returns.
#[derive(Clone)]
pub struct CorsMiddleware<S> {
    /// The wrapped service that handles non-preflight requests.
    service: S,
    /// The CORS policy shared with the factory that created this middleware.
    inner: Rc<Inner>,
}
734
735impl<S, Err> Service<WebRequest<Err>> for CorsMiddleware<S>
736where
737 S: Service<WebRequest<Err>, Response = WebResponse>,
738 Err: ErrorRenderer,
739 Err::Container: From<S::Error>,
740 CorsError: WebResponseError<Err>,
741{
742 type Response = WebResponse;
743 type Error = S::Error;
744
745 ntex::forward_ready!(service);
746 ntex::forward_shutdown!(service);
747
748 async fn call(
749 &self,
750 req: WebRequest<Err>,
751 ctx: ServiceCtx<'_, Self>,
752 ) -> Result<Self::Response, S::Error> {
753 match self.inner.preflight_check(req.head()) {
754 Ok(Either::Left(res)) => Ok(req.into_response(res)),
755 Ok(Either::Right(_)) => {
756 let inner = self.inner.clone();
757 let has_origin = req.headers().contains_key(&header::ORIGIN);
758 let allowed_origin = inner.access_control_allow_origin(req.headers());
759
760 let mut res = ctx.call(&self.service, req).await?;
761
762 if has_origin {
763 inner.handle_response(res.headers_mut(), allowed_origin);
764 }
765 Ok(res)
766 }
767 Err(e) => Ok(req.render_error(e)),
768 }
769 }
770}
771
/// Tests for the CORS builder, the policy checks and the middleware's responses.
#[cfg(test)]
mod tests {
    use ntex::service::{fn_service, Middleware, Pipeline};
    use ntex::web::{self, test, test::TestRequest};

    use super::*;

    /// Credentials combined with wildcard responses for all origins must panic in `finish`.
    #[ntex::test]
    #[should_panic(expected = "Credentials are allowed, but the Origin is set to")]
    async fn cors_validates_illegal_allow_credentials() {
        let _cors =
            Cors::new().supports_credentials().send_wildcard().finish::<web::DefaultError>();
    }

    /// A builder without origin restrictions accepts a request from any origin.
    #[ntex::test]
    async fn validate_origin_allows_all_origins() {
        let cors = Cors::new().finish().create(test::ok_service()).into();
        let req =
            TestRequest::with_header("Origin", "https://www.example.com").to_srv_request();

        let resp = test::call_service(&cors, req).await;
        assert_eq!(resp.status(), StatusCode::OK);
    }

    /// The factory from `Cors::default()` accepts a request from any origin.
    #[ntex::test]
    async fn default() {
        let cors = Cors::default().create(test::ok_service()).into();
        let req =
            TestRequest::with_header("Origin", "https://www.example.com").to_srv_request();

        let resp = test::call_service(&cors, req).await;
        assert_eq!(resp.status(), StatusCode::OK);
    }

    /// Preflight requests are rejected when invalid and answered with the configured
    /// headers when valid.
    #[ntex::test]
    async fn test_preflight() {
        let cors: Pipeline<_> = Cors::new()
            .send_wildcard()
            .max_age(3600)
            .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
            .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT])
            .allowed_header(header::CONTENT_TYPE)
            .finish()
            .create(test::ok_service())
            .into();

        // No `Access-Control-Request-Method` and a header that is not allowed.
        let req = TestRequest::with_header("Origin", "https://www.example.com")
            .method(Method::OPTIONS)
            .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "X-Not-Allowed")
            .to_srv_request();

        assert!(cors.get_ref().inner.validate_allowed_method(req.head()).is_err());
        assert!(cors.get_ref().inner.validate_allowed_headers(req.head()).is_err());
        let resp = test::call_service(&cors, req).await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);

        // A requested method ("put") that is not in the allowed set.
        let req = TestRequest::with_header("Origin", "https://www.example.com")
            .header(header::ACCESS_CONTROL_REQUEST_METHOD, "put")
            .method(Method::OPTIONS)
            .to_srv_request();

        assert!(cors.get_ref().inner.validate_allowed_method(req.head()).is_err());
        assert!(cors.get_ref().inner.validate_allowed_headers(req.head()).is_ok());

        // A valid preflight request.
        let req = TestRequest::with_header("Origin", "https://www.example.com")
            .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST")
            .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "AUTHORIZATION,ACCEPT")
            .method(Method::OPTIONS)
            .to_srv_request();

        let resp = test::call_service(&cors, req).await;
        assert_eq!(
            &b"*"[..],
            resp.headers().get(&header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
        );
        assert_eq!(
            &b"3600"[..],
            resp.headers().get(&header::ACCESS_CONTROL_MAX_AGE).unwrap().as_bytes()
        );
        let hdr = resp
            .headers()
            .get(&header::ACCESS_CONTROL_ALLOW_HEADERS)
            .unwrap()
            .to_str()
            .unwrap();
        assert!(hdr.contains("authorization"));
        assert!(hdr.contains("accept"));
        assert!(hdr.contains("content-type"));

        let methods =
            resp.headers().get(header::ACCESS_CONTROL_ALLOW_METHODS).unwrap().to_str().unwrap();
        assert!(methods.contains("POST"));
        assert!(methods.contains("GET"));
        assert!(methods.contains("OPTIONS"));

    }

    /// An origin outside the allowed set makes `validate_origin` fail; the `unwrap`
    /// turns that into the expected panic.
    #[ntex::test]
    #[should_panic(expected = "OriginNotAllowed")]
    async fn test_validate_not_allowed_origin() {
        let cors: Pipeline<_> = Cors::new()
            .allowed_origin("https://www.example.com")
            .finish()
            .create(test::ok_service::<web::DefaultError>())
            .into();

        let req = TestRequest::with_header("Origin", "https://www.unknown.com")
            .method(Method::GET)
            .to_srv_request();
        cors.get_ref().inner.validate_origin(req.head()).unwrap();
        cors.get_ref().inner.validate_allowed_method(req.head()).unwrap();
        cors.get_ref().inner.validate_allowed_headers(req.head()).unwrap();
    }

    /// A request from an explicitly allowed origin passes through.
    #[ntex::test]
    async fn test_validate_origin() {
        let cors = Cors::new()
            .allowed_origin("https://www.example.com")
            .finish()
            .create(test::ok_service())
            .into();

        let req = TestRequest::with_header("Origin", "https://www.example.com")
            .method(Method::GET)
            .to_srv_request();

        let resp = test::call_service(&cors, req).await;
        assert_eq!(resp.status(), StatusCode::OK);
    }

    /// Without an `Origin` header no allow-origin header is added; with one, the
    /// origin is echoed back.
    #[ntex::test]
    async fn test_no_origin_response() {
        let cors = Cors::new().disable_preflight().finish().create(test::ok_service()).into();

        let req = TestRequest::default().method(Method::GET).to_srv_request();
        let resp = test::call_service(&cors, req).await;
        assert!(resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).is_none());

        let req = TestRequest::with_header("Origin", "https://www.example.com")
            .method(Method::OPTIONS)
            .to_srv_request();
        let resp = test::call_service(&cors, req).await;
        assert_eq!(
            &b"https://www.example.com"[..],
            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
        );
    }

    /// Response headers: wildcard origin, `Vary`, exposed headers, and the origin
    /// echoed for a restricted origin set.
    #[ntex::test]
    async fn test_response() {
        let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT];
        let cors = Cors::new()
            .send_wildcard()
            .disable_preflight()
            .max_age(3600)
            .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
            .allowed_headers(exposed_headers.clone())
            .expose_headers(exposed_headers.clone())
            .allowed_header(header::CONTENT_TYPE)
            .finish()
            .create(test::ok_service())
            .into();

        let req = TestRequest::with_header("Origin", "https://www.example.com")
            .method(Method::OPTIONS)
            .to_srv_request();

        let resp = test::call_service(&cors, req).await;
        assert_eq!(
            &b"*"[..],
            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
        );
        assert_eq!(&b"Origin"[..], resp.headers().get(header::VARY).unwrap().as_bytes());

        {
            let headers = resp
                .headers()
                .get(header::ACCESS_CONTROL_EXPOSE_HEADERS)
                .unwrap()
                .to_str()
                .unwrap()
                .split(',')
                .map(|s| s.trim())
                .collect::<Vec<&str>>();

            for h in exposed_headers {
                assert!(headers.contains(&h.as_str()));
            }
        }

        // The wrapped service already sets `Vary: Accept`; `Origin` must be appended.
        let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT];
        let cors =
            Cors::new()
                .send_wildcard()
                .disable_preflight()
                .max_age(3600)
                .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
                .allowed_headers(exposed_headers.clone())
                .expose_headers(exposed_headers.clone())
                .allowed_header(header::CONTENT_TYPE)
                .finish()
                .create(fn_service(|req: WebRequest<DefaultError>| async move {
                    Ok::<_, std::convert::Infallible>(req.into_response(
                        HttpResponse::Ok().header(header::VARY, "Accept").finish(),
                    ))
                }))
                .into();
        let req = TestRequest::with_header("Origin", "https://www.example.com")
            .method(Method::OPTIONS)
            .to_srv_request();
        let resp = test::call_service(&cors, req).await;
        assert_eq!(
            &b"Accept, Origin"[..],
            resp.headers().get(header::VARY).unwrap().as_bytes()
        );

        // Two allowed origins: the preflight response echoes the requesting one.
        let cors = Cors::new()
            .disable_vary_header()
            .allowed_origin("https://www.example.com")
            .allowed_origin("https://www.google.com")
            .finish()
            .create(test::ok_service())
            .into();

        let req = TestRequest::with_header("Origin", "https://www.example.com")
            .method(Method::OPTIONS)
            .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST")
            .to_srv_request();
        let resp = test::call_service(&cors, req).await;

        let origins_str =
            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().to_str().unwrap();

        assert_eq!("https://www.example.com", origins_str);
    }

    /// With several allowed origins, a regular request gets its own origin echoed.
    #[ntex::test]
    async fn test_multiple_origins() {
        let cors = Cors::new()
            .allowed_origin("https://example.com")
            .allowed_origin("https://example.org")
            .allowed_methods(vec![Method::GET])
            .finish()
            .create(test::ok_service())
            .into();

        let req = TestRequest::with_header("Origin", "https://example.com")
            .method(Method::GET)
            .to_srv_request();

        let resp = test::call_service(&cors, req).await;
        assert_eq!(
            &b"https://example.com"[..],
            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
        );

        let req = TestRequest::with_header("Origin", "https://example.org")
            .method(Method::GET)
            .to_srv_request();

        let resp = test::call_service(&cors, req).await;
        assert_eq!(
            &b"https://example.org"[..],
            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
        );
    }

    /// With several allowed origins, a preflight request gets its own origin echoed.
    #[ntex::test]
    async fn test_multiple_origins_preflight() {
        let cors = Cors::new()
            .allowed_origin("https://example.com")
            .allowed_origin("https://example.org")
            .allowed_methods(vec![Method::GET])
            .finish()
            .create(test::ok_service())
            .into();

        let req = TestRequest::with_header("Origin", "https://example.com")
            .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
            .method(Method::OPTIONS)
            .to_srv_request();

        let resp = test::call_service(&cors, req).await;
        assert_eq!(
            &b"https://example.com"[..],
            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
        );

        let req = TestRequest::with_header("Origin", "https://example.org")
            .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
            .method(Method::OPTIONS)
            .to_srv_request();

        let resp = test::call_service(&cors, req).await;
        assert_eq!(
            &b"https://example.org"[..],
            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
        );
    }
}