1#![allow(
2 clippy::borrow_interior_mutable_const,
3 clippy::type_complexity,
4 clippy::mutable_key_type
5)]
6use std::task::{Context, Poll};
49use std::{
50 collections::HashSet, convert::TryFrom, iter::FromIterator, marker::PhantomData, rc::Rc,
51};
52
53use derive_more::Display;
54use futures::future::{ok, Either, FutureExt, LocalBoxFuture, Ready};
55use ntex::http::header::{self, HeaderName, HeaderValue};
56use ntex::http::{error::HttpError, HeaderMap, Method, RequestHead, StatusCode, Uri};
57use ntex::service::{Middleware, Service, ServiceCtx};
58use ntex::web::{
59 DefaultError, ErrorRenderer, HttpResponse, WebRequest, WebResponse, WebResponseError,
60};
61
/// Errors produced while validating a request against the CORS configuration.
///
/// Every variant is rendered as `400 Bad Request` by the `WebResponseError` impl that follows.
#[derive(Debug, Display)]
pub enum CorsError {
    /// Origins are restricted and the request carries no `Origin` header.
    #[display(fmt = "The HTTP request header `Origin` is required but was not provided")]
    MissingOrigin,
    /// The `Origin` header value could not be converted to a string.
    #[display(fmt = "The HTTP request header `Origin` could not be parsed correctly.")]
    BadOrigin,
    /// A preflight request lacks the `Access-Control-Request-Method` header.
    #[display(
        fmt = "The request header `Access-Control-Request-Method` is required but is missing"
    )]
    MissingRequestMethod,
    /// `Access-Control-Request-Method` is not a string or not a parseable HTTP method.
    #[display(fmt = "The request header `Access-Control-Request-Method` has an invalid value")]
    BadRequestMethod,
    /// `Access-Control-Request-Headers` is not a string or contains an invalid header name.
    #[display(
        fmt = "The request header `Access-Control-Request-Headers` has an invalid value"
    )]
    BadRequestHeaders,
    /// The request's `Origin` is not in the configured allow-list.
    #[display(fmt = "Origin is not allowed to make this request")]
    OriginNotAllowed,
    /// The method named in `Access-Control-Request-Method` is not allowed.
    #[display(fmt = "Requested method is not allowed")]
    MethodNotAllowed,
    /// At least one header named in `Access-Control-Request-Headers` is not allowed.
    #[display(fmt = "One or more headers requested are not allowed")]
    HeadersNotAllowed,
}

/// Renders every [`CorsError`] through the default error renderer.
impl WebResponseError<DefaultError> for CorsError {
    /// Always `400 Bad Request`, regardless of the variant.
    fn status_code(&self) -> StatusCode {
        StatusCode::BAD_REQUEST
    }
}
103
/// Either an unrestricted setting (`All`) or an explicit set of accepted values (`Some`).
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AllOrSome<T> {
    /// No restriction: every value is accepted.
    All,
    /// Only the wrapped value(s) are accepted.
    Some(T),
}

impl<T> Default for AllOrSome<T> {
    /// The default is the unrestricted variant, [`AllOrSome::All`].
    fn default() -> Self {
        Self::All
    }
}

impl<T> AllOrSome<T> {
    /// `true` when this is the unrestricted [`AllOrSome::All`] variant.
    pub fn is_all(&self) -> bool {
        matches!(self, Self::All)
    }

    /// `true` when this holds an explicit [`AllOrSome::Some`] value.
    pub fn is_some(&self) -> bool {
        matches!(self, Self::Some(_))
    }

    /// Borrows the explicit value, or returns `None` for [`AllOrSome::All`].
    pub fn as_ref(&self) -> Option<&T> {
        if let Self::Some(value) = self {
            Some(value)
        } else {
            None
        }
    }
}
144
/// Builder for the CORS middleware; call `finish` to obtain a [`CorsFactory`].
///
/// Configuration errors are recorded rather than returned, and `finish` panics with the
/// recorded error.
#[derive(Default)]
pub struct Cors {
    // Configuration under construction; `finish` takes it, and a `None` here makes `finish`
    // panic with "cannot reuse CorsBuilder".
    cors: Option<Inner>,
    // Set once `allowed_methods` has been called; otherwise `finish` installs a default list.
    methods: bool,
    // Headers for `Access-Control-Expose-Headers`; joined into a single value by `finish`.
    expose_hdrs: HashSet<HeaderName>,
    // Configuration error recorded by a builder method; `finish` panics with it.
    error: Option<HttpError>,
}
175
176impl Cors {
177 pub fn new() -> Self {
179 Cors {
180 cors: Some(Inner {
181 origins: AllOrSome::All,
182 origins_str: None,
183 methods: HashSet::new(),
184 headers: AllOrSome::All,
185 expose_hdrs: None,
186 max_age: None,
187 preflight: true,
188 send_wildcard: false,
189 supports_credentials: false,
190 vary_header: true,
191 }),
192 methods: false,
193 error: None,
194 expose_hdrs: HashSet::new(),
195 }
196 }
197
198 pub fn default<Err>() -> CorsFactory<Err> {
200 let inner = Inner {
201 origins: AllOrSome::default(),
202 origins_str: None,
203 methods: HashSet::from_iter(
204 vec![
205 Method::GET,
206 Method::HEAD,
207 Method::POST,
208 Method::OPTIONS,
209 Method::PUT,
210 Method::PATCH,
211 Method::DELETE,
212 ]
213 .into_iter(),
214 ),
215 headers: AllOrSome::All,
216 expose_hdrs: None,
217 max_age: None,
218 preflight: true,
219 send_wildcard: false,
220 supports_credentials: false,
221 vary_header: true,
222 };
223 CorsFactory { inner: Rc::new(inner), _t: PhantomData }
224 }
225
226 pub fn allowed_origin(mut self, origin: &str) -> Self {
244 if let Some(cors) = cors(&mut self.cors, &self.error) {
245 match Uri::try_from(origin) {
246 Ok(_) => {
247 if cors.origins.is_all() {
248 cors.origins = AllOrSome::Some(HashSet::new());
249 }
250 if let AllOrSome::Some(ref mut origins) = cors.origins {
251 origins.insert(origin.to_owned());
252 }
253 }
254 Err(e) => {
255 self.error = Some(e.into());
256 }
257 }
258 }
259 self
260 }
261
262 pub fn allowed_methods<U, M>(mut self, methods: U) -> Self
270 where
271 U: IntoIterator<Item = M>,
272 Method: TryFrom<M>,
273 <Method as TryFrom<M>>::Error: Into<HttpError>,
274 {
275 self.methods = true;
276 if let Some(cors) = cors(&mut self.cors, &self.error) {
277 for m in methods {
278 match Method::try_from(m) {
279 Ok(method) => {
280 cors.methods.insert(method);
281 }
282 Err(e) => {
283 self.error = Some(e.into());
284 break;
285 }
286 }
287 }
288 }
289 self
290 }
291
292 pub fn allowed_header<H>(mut self, header: H) -> Self
294 where
295 HeaderName: TryFrom<H>,
296 <HeaderName as TryFrom<H>>::Error: Into<HttpError>,
297 {
298 if let Some(cors) = cors(&mut self.cors, &self.error) {
299 match HeaderName::try_from(header) {
300 Ok(method) => {
301 if cors.headers.is_all() {
302 cors.headers = AllOrSome::Some(HashSet::new());
303 }
304 if let AllOrSome::Some(ref mut headers) = cors.headers {
305 headers.insert(method);
306 }
307 }
308 Err(e) => self.error = Some(e.into()),
309 }
310 }
311 self
312 }
313
314 pub fn allowed_headers<U, H>(mut self, headers: U) -> Self
326 where
327 U: IntoIterator<Item = H>,
328 HeaderName: TryFrom<H>,
329 <HeaderName as TryFrom<H>>::Error: Into<HttpError>,
330 {
331 if let Some(cors) = cors(&mut self.cors, &self.error) {
332 for h in headers {
333 match HeaderName::try_from(h) {
334 Ok(method) => {
335 if cors.headers.is_all() {
336 cors.headers = AllOrSome::Some(HashSet::new());
337 }
338 if let AllOrSome::Some(ref mut headers) = cors.headers {
339 headers.insert(method);
340 }
341 }
342 Err(e) => {
343 self.error = Some(e.into());
344 break;
345 }
346 }
347 }
348 }
349 self
350 }
351
352 pub fn expose_headers<U, H>(mut self, headers: U) -> Self
361 where
362 U: IntoIterator<Item = H>,
363 HeaderName: TryFrom<H>,
364 <HeaderName as TryFrom<H>>::Error: Into<HttpError>,
365 {
366 for h in headers {
367 match HeaderName::try_from(h) {
368 Ok(method) => {
369 self.expose_hdrs.insert(method);
370 }
371 Err(e) => {
372 self.error = Some(e.into());
373 break;
374 }
375 }
376 }
377 self
378 }
379
380 pub fn max_age(mut self, max_age: usize) -> Self {
385 if let Some(cors) = cors(&mut self.cors, &self.error) {
386 cors.max_age = Some(max_age)
387 }
388 self
389 }
390
391 pub fn send_wildcard(mut self) -> Self {
407 if let Some(cors) = cors(&mut self.cors, &self.error) {
408 cors.send_wildcard = true
409 }
410 self
411 }
412
413 pub fn supports_credentials(mut self) -> Self {
427 if let Some(cors) = cors(&mut self.cors, &self.error) {
428 cors.supports_credentials = true
429 }
430 self
431 }
432
433 pub fn disable_vary_header(mut self) -> Self {
445 if let Some(cors) = cors(&mut self.cors, &self.error) {
446 cors.vary_header = false
447 }
448 self
449 }
450
451 pub fn disable_preflight(mut self) -> Self {
458 if let Some(cors) = cors(&mut self.cors, &self.error) {
459 cors.preflight = false
460 }
461 self
462 }
463
464 pub fn finish<Err>(self) -> CorsFactory<Err> {
466 let mut slf = if !self.methods {
467 self.allowed_methods(vec![
468 Method::GET,
469 Method::HEAD,
470 Method::POST,
471 Method::OPTIONS,
472 Method::PUT,
473 Method::PATCH,
474 Method::DELETE,
475 ])
476 } else {
477 self
478 };
479
480 if let Some(e) = slf.error.take() {
481 panic!("{}", e);
482 }
483
484 let mut cors = slf.cors.take().expect("cannot reuse CorsBuilder");
485
486 if cors.supports_credentials && cors.send_wildcard && cors.origins.is_all() {
487 panic!("Credentials are allowed, but the Origin is set to \"*\"");
488 }
489
490 if let AllOrSome::Some(ref origins) = cors.origins {
491 let s = origins.iter().fold(String::new(), |s, v| format!("{}, {}", s, v));
492 cors.origins_str = Some(HeaderValue::try_from(&s[2..]).unwrap());
493 }
494
495 if !slf.expose_hdrs.is_empty() {
496 cors.expose_hdrs = Some(
497 HeaderValue::try_from(
498 &slf.expose_hdrs
499 .iter()
500 .fold(String::new(), |s, v| format!("{}, {}", s, v.as_str()))[2..],
501 )
502 .unwrap(),
503 );
504 }
505
506 CorsFactory { inner: Rc::new(cors), _t: PhantomData }
507 }
508}
509
510fn cors<'a>(parts: &'a mut Option<Inner>, err: &Option<HttpError>) -> Option<&'a mut Inner> {
511 if err.is_some() {
512 return None;
513 }
514 parts.as_mut()
515}
516
/// Finalized CORS configuration shared (via `Rc`) by the factory and every middleware instance.
struct Inner {
    // Methods accepted in preflight requests and listed in `Access-Control-Allow-Methods`.
    methods: HashSet<Method>,
    // Allowed request origins; `All` accepts any origin.
    origins: AllOrSome<HashSet<String>>,
    // Comma-separated list of `origins`, built by `Cors::finish` when origins are restricted.
    origins_str: Option<HeaderValue>,
    // Allowed request headers; `All` echoes back whatever the preflight request asked for.
    headers: AllOrSome<HashSet<HeaderName>>,
    // Prebuilt `Access-Control-Expose-Headers` value, if any headers are exposed.
    expose_hdrs: Option<HeaderValue>,
    // Value for `Access-Control-Max-Age` on preflight responses.
    max_age: Option<usize>,
    // Whether `OPTIONS` requests are answered as preflight requests by the middleware.
    preflight: bool,
    // Send `*` instead of echoing the origin when all origins are allowed.
    send_wildcard: bool,
    // Emit `Access-Control-Allow-Credentials: true`.
    supports_credentials: bool,
    // Append `Origin` to the `Vary` header of non-preflight responses.
    vary_header: bool,
}
529
530impl Inner {
531 fn validate_origin(&self, req: &RequestHead) -> Result<(), CorsError> {
532 if let Some(hdr) = req.headers().get(&header::ORIGIN) {
533 if let Ok(origin) = hdr.to_str() {
534 return match self.origins {
535 AllOrSome::All => Ok(()),
536 AllOrSome::Some(ref allowed_origins) => allowed_origins
537 .get(origin)
538 .map(|_| ())
539 .ok_or(CorsError::OriginNotAllowed),
540 };
541 }
542 Err(CorsError::BadOrigin)
543 } else {
544 match self.origins {
545 AllOrSome::All => Ok(()),
546 _ => Err(CorsError::MissingOrigin),
547 }
548 }
549 }
550
551 fn access_control_allow_origin(&self, headers: &HeaderMap) -> Option<HeaderValue> {
552 match self.origins {
553 AllOrSome::All => {
554 if self.send_wildcard {
555 Some(HeaderValue::from_static("*"))
556 } else {
557 headers.get(&header::ORIGIN).cloned()
558 }
559 }
560 AllOrSome::Some(ref origins) => {
561 if let Some(origin) =
562 headers.get(&header::ORIGIN).filter(|o| match o.to_str() {
563 Ok(os) => origins.contains(os),
564 _ => false,
565 })
566 {
567 Some(origin.clone())
568 } else {
569 Some(self.origins_str.as_ref().unwrap().clone())
570 }
571 }
572 }
573 }
574
575 fn validate_allowed_method(&self, req: &RequestHead) -> Result<(), CorsError> {
576 if let Some(hdr) = req.headers().get(&header::ACCESS_CONTROL_REQUEST_METHOD) {
577 if let Ok(meth) = hdr.to_str() {
578 if let Ok(method) = Method::try_from(meth) {
579 return self
580 .methods
581 .get(&method)
582 .map(|_| ())
583 .ok_or(CorsError::MethodNotAllowed);
584 }
585 }
586 Err(CorsError::BadRequestMethod)
587 } else {
588 Err(CorsError::MissingRequestMethod)
589 }
590 }
591
592 fn validate_allowed_headers(&self, req: &RequestHead) -> Result<(), CorsError> {
593 match self.headers {
594 AllOrSome::All => Ok(()),
595 AllOrSome::Some(ref allowed_headers) => {
596 if let Some(hdr) = req.headers().get(&header::ACCESS_CONTROL_REQUEST_HEADERS) {
597 if let Ok(headers) = hdr.to_str() {
598 let mut hdrs = HashSet::new();
599 for hdr in headers.split(',') {
600 match HeaderName::try_from(hdr.trim()) {
601 Ok(hdr) => hdrs.insert(hdr),
602 Err(_) => return Err(CorsError::BadRequestHeaders),
603 };
604 }
605 if !hdrs.is_empty() {
608 if !hdrs.is_subset(allowed_headers) {
609 return Err(CorsError::HeadersNotAllowed);
610 }
611 return Ok(());
612 }
613 }
614 Err(CorsError::BadRequestHeaders)
615 } else {
616 Ok(())
617 }
618 }
619 }
620 }
621
622 fn preflight_check(
623 &self,
624 req: &RequestHead,
625 ) -> Result<Either<HttpResponse, ()>, CorsError> {
626 if self.preflight && Method::OPTIONS == req.method {
627 self.validate_origin(req)
628 .and_then(|_| self.validate_allowed_method(req))
629 .and_then(|_| self.validate_allowed_headers(req))?;
630
631 let headers = if let Some(headers) = self.headers.as_ref() {
633 Some(
634 HeaderValue::try_from(
635 &headers
636 .iter()
637 .fold(String::new(), |s, v| s + "," + v.as_str())
638 .as_str()[1..],
639 )
640 .unwrap(),
641 )
642 } else {
643 req.headers.get(&header::ACCESS_CONTROL_REQUEST_HEADERS).cloned()
644 };
645
646 let res = HttpResponse::Ok()
647 .if_some(self.max_age.as_ref(), |max_age, resp| {
648 let _ = resp.header(
649 header::ACCESS_CONTROL_MAX_AGE,
650 format!("{}", max_age).as_str(),
651 );
652 })
653 .if_some(headers, |headers, resp| {
654 let _ = resp.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers);
655 })
656 .if_some(self.access_control_allow_origin(req.headers()), |origin, resp| {
657 let _ = resp.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin);
658 })
659 .if_true(self.supports_credentials, |resp| {
660 resp.header(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
661 })
662 .header(
663 header::ACCESS_CONTROL_ALLOW_METHODS,
664 &self
665 .methods
666 .iter()
667 .fold(String::new(), |s, v| s + "," + v.as_str())
668 .as_str()[1..],
669 )
670 .finish()
671 .into_body();
672
673 Ok(Either::Left(res))
674 } else {
675 if req.headers.contains_key(&header::ORIGIN) {
676 self.validate_origin(req)?;
678 }
679 Ok(Either::Right(()))
680 }
681 }
682
683 fn handle_response(&self, headers: &mut HeaderMap, allowed_origin: Option<HeaderValue>) {
684 if let Some(origin) = allowed_origin {
685 headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin);
686 };
687
688 if let Some(ref expose) = self.expose_hdrs {
689 headers.insert(header::ACCESS_CONTROL_EXPOSE_HEADERS, expose.clone());
690 }
691 if self.supports_credentials {
692 headers.insert(
693 header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
694 HeaderValue::from_static("true"),
695 );
696 }
697 if self.vary_header {
698 let value = if let Some(hdr) = headers.get(&header::VARY) {
699 let mut val: Vec<u8> = Vec::with_capacity(hdr.as_bytes().len() + 8);
700 val.extend(hdr.as_bytes());
701 val.extend(b", Origin");
702 HeaderValue::try_from(&val[..]).unwrap()
703 } else {
704 HeaderValue::from_static("Origin")
705 };
706 headers.insert(header::VARY, value);
707 }
708 }
709}
710
/// Middleware factory produced by `Cors::finish` or `Cors::default`.
///
/// Each created [`CorsMiddleware`] shares the same configuration through an `Rc`.
pub struct CorsFactory<Err> {
    // Finalized configuration, shared with every middleware instance.
    inner: Rc<Inner>,
    // Ties the factory to the error-renderer type `Err` without storing a value of it.
    _t: PhantomData<Err>,
}
719
720impl<S, Err> Middleware<S> for CorsFactory<Err>
721where
722 S: Service<WebRequest<Err>, Response = WebResponse>,
723{
724 type Service = CorsMiddleware<S>;
725
726 fn create(&self, service: S) -> Self::Service {
727 CorsMiddleware { service, inner: self.inner.clone() }
728 }
729}
730
/// Service wrapper that answers preflight requests and decorates responses with CORS headers.
#[derive(Clone)]
pub struct CorsMiddleware<S> {
    // The wrapped (inner) service that handles non-preflight requests.
    service: S,
    // Configuration shared with the factory that created this middleware.
    inner: Rc<Inner>,
}
740
741impl<S, Err> Service<WebRequest<Err>> for CorsMiddleware<S>
742where
743 S: Service<WebRequest<Err>, Response = WebResponse>,
744 Err: ErrorRenderer,
745 Err::Container: From<S::Error>,
746 CorsError: WebResponseError<Err>,
747{
748 type Response = WebResponse;
749 type Error = S::Error;
750 type Future<'f> = Either<
751 Ready<Result<Self::Response, S::Error>>,
752 LocalBoxFuture<'f, Result<Self::Response, S::Error>>,
753 > where Self: 'f;
754
755 fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
756 self.service.poll_ready(cx)
757 }
758
759 fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<()> {
760 self.service.poll_shutdown(cx)
761 }
762
763 fn call<'a>(&'a self, req: WebRequest<Err>, ctx: ServiceCtx<'a, Self>) -> Self::Future<'a> {
764 match self.inner.preflight_check(req.head()) {
765 Ok(Either::Left(res)) => Either::Left(ok(req.into_response(res))),
766 Ok(Either::Right(_)) => {
767 let inner = self.inner.clone();
768 let allowed_origin = inner.access_control_allow_origin(req.headers());
770
771 Either::Right(
772 async move {
773 let mut res = ctx.call(&self.service, req).await?;
774
775 inner.handle_response(res.headers_mut(), allowed_origin);
777 Ok(res)
779 }
780 .boxed_local(),
781 )
782 }
783 Err(e) => Either::Left(ok(req.render_error(e))),
784 }
785 }
786}
787
#[cfg(test)]
mod tests {
    use ntex::service::{fn_service, Middleware, Pipeline};
    use ntex::web::{self, test, test::TestRequest};

    use super::*;

    /// Credentials plus `send_wildcard` while every origin is allowed must make `finish` panic.
    #[ntex::test]
    #[should_panic(expected = "Credentials are allowed, but the Origin is set to")]
    async fn cors_validates_illegal_allow_credentials() {
        let _cors =
            Cors::new().supports_credentials().send_wildcard().finish::<web::DefaultError>();
    }

    /// With no origin restriction, any `Origin` passes through to the inner service.
    #[ntex::test]
    async fn validate_origin_allows_all_origins() {
        let cors = Cors::new().finish().create(test::ok_service()).into();
        let req =
            TestRequest::with_header("Origin", "https://www.example.com").to_srv_request();

        let resp = test::call_service(&cors, req).await;
        assert_eq!(resp.status(), StatusCode::OK);
    }

    /// The `Cors::default()` shortcut factory accepts a plain request with an `Origin`.
    #[ntex::test]
    async fn default() {
        let cors = Cors::default().create(test::ok_service()).into();
        let req =
            TestRequest::with_header("Origin", "https://www.example.com").to_srv_request();

        let resp = test::call_service(&cors, req).await;
        assert_eq!(resp.status(), StatusCode::OK);
    }

    /// Preflight validation failures yield 400; a valid preflight gets the full header set.
    #[ntex::test]
    async fn test_preflight() {
        let cors: Pipeline<_> = Cors::new()
            .send_wildcard()
            .max_age(3600)
            .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
            .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT])
            .allowed_header(header::CONTENT_TYPE)
            .finish()
            .create(test::ok_service())
            .into();

        // No `Access-Control-Request-Method` and a disallowed requested header.
        let req = TestRequest::with_header("Origin", "https://www.example.com")
            .method(Method::OPTIONS)
            .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "X-Not-Allowed")
            .to_srv_request();

        assert!(cors.get_ref().inner.validate_allowed_method(req.head()).is_err());
        assert!(cors.get_ref().inner.validate_allowed_headers(req.head()).is_err());
        let resp = test::call_service(&cors, req).await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);

        // A requested method outside the allowed set is rejected.
        let req = TestRequest::with_header("Origin", "https://www.example.com")
            .header(header::ACCESS_CONTROL_REQUEST_METHOD, "put")
            .method(Method::OPTIONS)
            .to_srv_request();

        assert!(cors.get_ref().inner.validate_allowed_method(req.head()).is_err());
        assert!(cors.get_ref().inner.validate_allowed_headers(req.head()).is_ok());

        // Fully valid preflight request.
        let req = TestRequest::with_header("Origin", "https://www.example.com")
            .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST")
            .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "AUTHORIZATION,ACCEPT")
            .method(Method::OPTIONS)
            .to_srv_request();

        let resp = test::call_service(&cors, req).await;
        assert_eq!(
            &b"*"[..],
            resp.headers().get(&header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
        );
        assert_eq!(
            &b"3600"[..],
            resp.headers().get(&header::ACCESS_CONTROL_MAX_AGE).unwrap().as_bytes()
        );
        let hdr = resp
            .headers()
            .get(&header::ACCESS_CONTROL_ALLOW_HEADERS)
            .unwrap()
            .to_str()
            .unwrap();
        assert!(hdr.contains("authorization"));
        assert!(hdr.contains("accept"));
        assert!(hdr.contains("content-type"));

        let methods =
            resp.headers().get(header::ACCESS_CONTROL_ALLOW_METHODS).unwrap().to_str().unwrap();
        assert!(methods.contains("POST"));
        assert!(methods.contains("GET"));
        assert!(methods.contains("OPTIONS"));
    }

    /// An origin outside the allow-list fails `validate_origin`.
    #[ntex::test]
    #[should_panic(expected = "OriginNotAllowed")]
    async fn test_validate_not_allowed_origin() {
        let cors: Pipeline<_> = Cors::new()
            .allowed_origin("https://www.example.com")
            .finish()
            .create(test::ok_service::<web::DefaultError>())
            .into();

        let req = TestRequest::with_header("Origin", "https://www.unknown.com")
            .method(Method::GET)
            .to_srv_request();
        cors.get_ref().inner.validate_origin(req.head()).unwrap();
        cors.get_ref().inner.validate_allowed_method(req.head()).unwrap();
        cors.get_ref().inner.validate_allowed_headers(req.head()).unwrap();
    }

    /// A listed origin is accepted for a regular request.
    #[ntex::test]
    async fn test_validate_origin() {
        let cors = Cors::new()
            .allowed_origin("https://www.example.com")
            .finish()
            .create(test::ok_service())
            .into();

        let req = TestRequest::with_header("Origin", "https://www.example.com")
            .method(Method::GET)
            .to_srv_request();

        let resp = test::call_service(&cors, req).await;
        assert_eq!(resp.status(), StatusCode::OK);
    }

    /// No `Origin` means no allow-origin header; with preflight disabled the origin is echoed.
    #[ntex::test]
    async fn test_no_origin_response() {
        let cors = Cors::new().disable_preflight().finish().create(test::ok_service()).into();

        let req = TestRequest::default().method(Method::GET).to_srv_request();
        let resp = test::call_service(&cors, req).await;
        assert!(resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).is_none());

        let req = TestRequest::with_header("Origin", "https://www.example.com")
            .method(Method::OPTIONS)
            .to_srv_request();
        let resp = test::call_service(&cors, req).await;
        assert_eq!(
            &b"https://www.example.com"[..],
            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
        );
    }

    /// Response decoration covers several cases:
    /// - wildcard origin;
    /// - `Vary` handling;
    /// - exposed headers;
    /// - echoing a listed origin.
    #[ntex::test]
    async fn test_response() {
        let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT];
        let cors = Cors::new()
            .send_wildcard()
            .disable_preflight()
            .max_age(3600)
            .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
            .allowed_headers(exposed_headers.clone())
            .expose_headers(exposed_headers.clone())
            .allowed_header(header::CONTENT_TYPE)
            .finish()
            .create(test::ok_service())
            .into();

        let req = TestRequest::with_header("Origin", "https://www.example.com")
            .method(Method::OPTIONS)
            .to_srv_request();

        let resp = test::call_service(&cors, req).await;
        assert_eq!(
            &b"*"[..],
            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
        );
        assert_eq!(&b"Origin"[..], resp.headers().get(header::VARY).unwrap().as_bytes());

        {
            let headers = resp
                .headers()
                .get(header::ACCESS_CONTROL_EXPOSE_HEADERS)
                .unwrap()
                .to_str()
                .unwrap()
                .split(',')
                .map(|s| s.trim())
                .collect::<Vec<&str>>();

            for h in exposed_headers {
                assert!(headers.contains(&h.as_str()));
            }
        }

        // An existing `Vary` value from the inner service is extended, not replaced.
        let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT];
        let cors = Cors::new()
            .send_wildcard()
            .disable_preflight()
            .max_age(3600)
            .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
            .allowed_headers(exposed_headers.clone())
            .expose_headers(exposed_headers.clone())
            .allowed_header(header::CONTENT_TYPE)
            .finish()
            .create(fn_service(|req: WebRequest<DefaultError>| {
                ok::<_, std::convert::Infallible>(req.into_response(
                    HttpResponse::Ok().header(header::VARY, "Accept").finish(),
                ))
            }))
            .into();
        let req = TestRequest::with_header("Origin", "https://www.example.com")
            .method(Method::OPTIONS)
            .to_srv_request();
        let resp = test::call_service(&cors, req).await;
        assert_eq!(
            &b"Accept, Origin"[..],
            resp.headers().get(header::VARY).unwrap().as_bytes()
        );

        // With restricted origins, a listed origin is echoed back on its own.
        let cors = Cors::new()
            .disable_vary_header()
            .allowed_origin("https://www.example.com")
            .allowed_origin("https://www.google.com")
            .finish()
            .create(test::ok_service())
            .into();

        let req = TestRequest::with_header("Origin", "https://www.example.com")
            .method(Method::OPTIONS)
            .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST")
            .to_srv_request();
        let resp = test::call_service(&cors, req).await;

        let origins_str =
            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().to_str().unwrap();

        assert_eq!("https://www.example.com", origins_str);
    }

    /// Each of several allowed origins is echoed back on regular requests.
    #[ntex::test]
    async fn test_multiple_origins() {
        let cors = Cors::new()
            .allowed_origin("https://example.com")
            .allowed_origin("https://example.org")
            .allowed_methods(vec![Method::GET])
            .finish()
            .create(test::ok_service())
            .into();

        let req = TestRequest::with_header("Origin", "https://example.com")
            .method(Method::GET)
            .to_srv_request();

        let resp = test::call_service(&cors, req).await;
        assert_eq!(
            &b"https://example.com"[..],
            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
        );

        let req = TestRequest::with_header("Origin", "https://example.org")
            .method(Method::GET)
            .to_srv_request();

        let resp = test::call_service(&cors, req).await;
        assert_eq!(
            &b"https://example.org"[..],
            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
        );
    }

    /// Each of several allowed origins is echoed back on preflight requests.
    #[ntex::test]
    async fn test_multiple_origins_preflight() {
        let cors = Cors::new()
            .allowed_origin("https://example.com")
            .allowed_origin("https://example.org")
            .allowed_methods(vec![Method::GET])
            .finish()
            .create(test::ok_service())
            .into();

        let req = TestRequest::with_header("Origin", "https://example.com")
            .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
            .method(Method::OPTIONS)
            .to_srv_request();

        let resp = test::call_service(&cors, req).await;
        assert_eq!(
            &b"https://example.com"[..],
            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
        );

        let req = TestRequest::with_header("Origin", "https://example.org")
            .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
            .method(Method::OPTIONS)
            .to_srv_request();

        let resp = test::call_service(&cors, req).await;
        assert_eq!(
            &b"https://example.org"[..],
            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
        );
    }
}