Skip to main content

ntex_cors/
lib.rs

1#![allow(
2    clippy::borrow_interior_mutable_const,
3    clippy::type_complexity,
4    clippy::mutable_key_type
5)]
6//! Cross-origin resource sharing (CORS) for ntex applications
7//!
//! The CORS middleware can be applied to a whole application, a scope, or a
//! single resource: pass it to the `App::wrap()`, `Scope::wrap()` or
//! `Resource::wrap()` method.
11//!
12//! # Example
13//!
14//! ```rust,no_run
15//! use ntex_cors::Cors;
16//! use ntex::{http, web};
17//! use ntex::web::{App, HttpRequest, HttpResponse};
18//!
19//! async fn index(req: HttpRequest) -> &'static str {
20//!     "Hello world"
21//! }
22//!
23//! #[ntex::main]
24//! async fn main() -> std::io::Result<()> {
25//!     web::server(async || App::new()
26//!         .wrap(
27//!             Cors::new() // <- Construct CORS middleware builder
//!               .allowed_origin("https://www.rust-lang.org")
29//!               .allowed_methods(vec![http::Method::GET, http::Method::POST])
30//!               .allowed_headers(vec![http::header::AUTHORIZATION, http::header::ACCEPT])
31//!               .allowed_header(http::header::CONTENT_TYPE)
32//!               .max_age(3600)
33//!               .finish())
34//!         .service(
35//!             web::resource("/index.html")
36//!               .route(web::get().to(index))
37//!               .route(web::head().to(|| async { HttpResponse::MethodNotAllowed() }))
38//!         ))
39//!         .bind("127.0.0.1:8080")?
40//!         .run()
41//!         .await
42//! }
43//! ```
//! In this example the *CORS* middleware is registered on the whole
//! application, so it also applies to the "/index.html" endpoint.
//!
//! The middleware automatically handles *OPTIONS* preflight requests.
48use std::{
49    collections::HashSet, convert::TryFrom, iter::FromIterator, marker::PhantomData, rc::Rc,
50};
51
52use derive_more::Display;
53use ntex::http::header::{self, HeaderName, HeaderValue};
54use ntex::http::{HeaderMap, Method, RequestHead, StatusCode, Uri, error::HttpError};
55use ntex::service::{Middleware, Service, ServiceCtx};
56use ntex::util::{ByteString, Either};
57use ntex::web::{
58    DefaultError, ErrorRenderer, HttpResponse, WebRequest, WebResponse, WebResponseError,
59};
60
61/// A set of errors that can occur during processing CORS
62#[derive(Debug, Display)]
63pub enum CorsError {
64    /// The HTTP request header `Origin` is required but was not provided
65    #[display("The HTTP request header `Origin` is required but was not provided")]
66    MissingOrigin,
67    /// The HTTP request header `Origin` could not be parsed correctly.
68    #[display("The HTTP request header `Origin` could not be parsed correctly.")]
69    BadOrigin,
70    /// The request header `Access-Control-Request-Method` is required but is
71    /// missing
72    #[display("The request header `Access-Control-Request-Method` is required but is missing")]
73    MissingRequestMethod,
74    /// The request header `Access-Control-Request-Method` has an invalid value
75    #[display("The request header `Access-Control-Request-Method` has an invalid value")]
76    BadRequestMethod,
77    /// The request header `Access-Control-Request-Headers`  has an invalid
78    /// value
79    #[display("The request header `Access-Control-Request-Headers`  has an invalid value")]
80    BadRequestHeaders,
81    /// Origin is not allowed to make this request
82    #[display("Origin is not allowed to make this request")]
83    OriginNotAllowed,
84    /// Requested method is not allowed
85    #[display("Requested method is not allowed")]
86    MethodNotAllowed,
87    /// One or more headers requested are not allowed
88    #[display("One or more headers requested are not allowed")]
89    HeadersNotAllowed,
90}
91
92/// DefaultError renderer support
93impl WebResponseError<DefaultError> for CorsError {
94    fn status_code(&self) -> StatusCode {
95        StatusCode::BAD_REQUEST
96    }
97}
98
99/// An enum signifying that some of type T is allowed, or `All` (everything is
100/// allowed).
101///
102/// `Default` is implemented for this enum and is `All`.
103#[derive(Default, Clone, Debug, Eq, PartialEq)]
104pub enum AllOrSome<T> {
105    /// Everything is allowed. Usually equivalent to the "*" value.
106    #[default]
107    All,
108    /// Only some of `T` is allowed
109    Some(T),
110}
111
112impl<T> AllOrSome<T> {
113    /// Returns whether this is an `All` variant
114    pub fn is_all(&self) -> bool {
115        match *self {
116            AllOrSome::All => true,
117            AllOrSome::Some(_) => false,
118        }
119    }
120
121    /// Returns whether this is a `Some` variant
122    pub fn is_some(&self) -> bool {
123        !self.is_all()
124    }
125
126    /// Returns &T
127    pub fn as_ref(&self) -> Option<&T> {
128        match *self {
129            AllOrSome::All => None,
130            AllOrSome::Some(ref t) => Some(t),
131        }
132    }
133}
134
135/// Structure that follows the builder pattern for building `Cors` middleware
136/// structs.
137///
138/// To construct a cors:
139///
140///   1. Call [`Cors::build`](struct.Cors.html#method.build) to start building.
141///   2. Use any of the builder methods to set fields in the backend.
142///   3. Call [finish](struct.Cors.html#method.finish) to retrieve the constructed backend.
143///
144/// # Example
145///
146/// ```rust
147/// use ntex_cors::Cors;
148/// use ntex::http::header;
149///
150/// let cors = Cors::new()
151///     .allowed_origin("https://www.rust-lang.org/")
152///     .allowed_methods(vec!["GET", "POST"])
153///     .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT])
154///     .allowed_header(header::CONTENT_TYPE)
155///     .max_age(3600);
156/// ```
157#[derive(Default)]
158pub struct Cors {
159    cors: Option<Inner>,
160    methods: bool,
161    expose_hdrs: HashSet<HeaderName>,
162    error: Option<HttpError>,
163}
164
165impl Cors {
166    /// Build a new CORS middleware instance
167    pub fn new() -> Self {
168        Cors {
169            cors: Some(Inner {
170                origins: AllOrSome::All,
171                origins_str: None,
172                methods: HashSet::new(),
173                headers: AllOrSome::All,
174                expose_hdrs: None,
175                max_age: None,
176                preflight: true,
177                send_wildcard: false,
178                supports_credentials: false,
179                vary_header: true,
180            }),
181            methods: false,
182            error: None,
183            expose_hdrs: HashSet::new(),
184        }
185    }
186
187    #[allow(clippy::should_implement_trait)]
188    /// Build a new CORS default middleware
189    pub fn default<Err>() -> CorsFactory<Err> {
190        let inner = Inner {
191            origins: AllOrSome::default(),
192            origins_str: None,
193            methods: HashSet::from_iter(vec![
194                Method::GET,
195                Method::HEAD,
196                Method::POST,
197                Method::OPTIONS,
198                Method::PUT,
199                Method::PATCH,
200                Method::DELETE,
201            ]),
202            headers: AllOrSome::All,
203            expose_hdrs: None,
204            max_age: None,
205            preflight: true,
206            send_wildcard: false,
207            supports_credentials: false,
208            vary_header: true,
209        };
210        CorsFactory { inner: Rc::new(inner), _t: PhantomData }
211    }
212
213    /// Add an origin that are allowed to make requests.
214    /// Will be verified against the `Origin` request header.
215    ///
216    /// When `All` is set, and `send_wildcard` is set, "*" will be sent in
217    /// the `Access-Control-Allow-Origin` response header. Otherwise, the
218    /// client's `Origin` request header will be echoed back in the
219    /// `Access-Control-Allow-Origin` response header.
220    ///
221    /// When `Some` is set, the client's `Origin` request header will be
222    /// checked in a case-sensitive manner.
223    ///
224    /// This is the `list of origins` in the
225    /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
226    ///
227    /// Defaults to `All`.
228    ///
229    /// Builder panics if supplied origin is not valid uri.
230    pub fn allowed_origin(mut self, origin: &str) -> Self {
231        if let Some(cors) = cors(&mut self.cors, &self.error) {
232            match Uri::try_from(origin) {
233                Ok(_) => {
234                    // If the origin is "*", set the origins to `All`
235                    if origin.trim() == "*" {
236                        cors.origins = AllOrSome::All;
237                        return self;
238                    }
239                    if cors.origins.is_all() {
240                        cors.origins = AllOrSome::Some(HashSet::new());
241                    }
242                    if let AllOrSome::Some(ref mut origins) = cors.origins {
243                        origins.insert(origin.to_owned());
244                    }
245                }
246                Err(e) => {
247                    self.error = Some(e.into());
248                }
249            }
250        }
251        self
252    }
253
254    /// Set a list of methods which the allowed origins are allowed to access
255    /// for requests.
256    ///
257    /// This is the `list of methods` in the
258    /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
259    ///
260    /// You can use either `&str` (e.g. `"GET"`) or directly use the [`http::Method`] enum
261    /// (e.g. `http::Method::GET`). Using [`http::Method`] is recommended for better type safety
262    /// and to avoid runtime errors due to typos.
263    ///
264    /// # Example
265    /// ```rust,ignore
266    /// use ntex::http::Method;
267    ///
268    /// cors.allowed_methods(vec![Method::GET, Method::POST]);
269    /// ```
270    ///
271    /// The default is `[GET, HEAD, POST, OPTIONS, PUT, PATCH, DELETE]`.
272    pub fn allowed_methods<U, M>(mut self, methods: U) -> Self
273    where
274        U: IntoIterator<Item = M>,
275        Method: TryFrom<M>,
276        <Method as TryFrom<M>>::Error: Into<HttpError>,
277    {
278        self.methods = true;
279        if let Some(cors) = cors(&mut self.cors, &self.error) {
280            for m in methods {
281                match Method::try_from(m) {
282                    Ok(method) => {
283                        cors.methods.insert(method);
284                    }
285                    Err(e) => {
286                        self.error = Some(e.into());
287                        break;
288                    }
289                }
290            }
291        }
292        self
293    }
294
295    /// Set an allowed header
296    pub fn allowed_header<H>(mut self, header: H) -> Self
297    where
298        HeaderName: TryFrom<H>,
299        <HeaderName as TryFrom<H>>::Error: Into<HttpError>,
300    {
301        if let Some(cors) = cors(&mut self.cors, &self.error) {
302            match HeaderName::try_from(header) {
303                Ok(method) => {
304                    if cors.headers.is_all() {
305                        cors.headers = AllOrSome::Some(HashSet::new());
306                    }
307                    if let AllOrSome::Some(ref mut headers) = cors.headers {
308                        headers.insert(method);
309                    }
310                }
311                Err(e) => self.error = Some(e.into()),
312            }
313        }
314        self
315    }
316
317    /// Set a list of header field names which can be used when
318    /// this resource is accessed by allowed origins.
319    ///
320    /// If `All` is set, whatever is requested by the client in
321    /// `Access-Control-Request-Headers` will be echoed back in the
322    /// `Access-Control-Allow-Headers` header.
323    ///
324    /// This is the `list of headers` in the
325    /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
326    ///
327    /// Defaults to `All`.
328    pub fn allowed_headers<U, H>(mut self, headers: U) -> Self
329    where
330        U: IntoIterator<Item = H>,
331        HeaderName: TryFrom<H>,
332        <HeaderName as TryFrom<H>>::Error: Into<HttpError>,
333    {
334        if let Some(cors) = cors(&mut self.cors, &self.error) {
335            for h in headers {
336                match HeaderName::try_from(h) {
337                    Ok(method) => {
338                        if cors.headers.is_all() {
339                            cors.headers = AllOrSome::Some(HashSet::new());
340                        }
341                        if let AllOrSome::Some(ref mut headers) = cors.headers {
342                            headers.insert(method);
343                        }
344                    }
345                    Err(e) => {
346                        self.error = Some(e.into());
347                        break;
348                    }
349                }
350            }
351        }
352        self
353    }
354
355    /// Set a list of headers which are safe to expose to the API of a CORS API
356    /// specification. This corresponds to the
357    /// `Access-Control-Expose-Headers` response header.
358    ///
359    /// This is the `list of exposed headers` in the
360    /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
361    ///
362    /// This defaults to an empty set.
363    pub fn expose_headers<U, H>(mut self, headers: U) -> Self
364    where
365        U: IntoIterator<Item = H>,
366        HeaderName: TryFrom<H>,
367        <HeaderName as TryFrom<H>>::Error: Into<HttpError>,
368    {
369        for h in headers {
370            match HeaderName::try_from(h) {
371                Ok(method) => {
372                    self.expose_hdrs.insert(method);
373                }
374                Err(e) => {
375                    self.error = Some(e.into());
376                    break;
377                }
378            }
379        }
380        self
381    }
382
383    /// Set a maximum time for which this CORS request maybe cached.
384    /// This value is set as the `Access-Control-Max-Age` header.
385    ///
386    /// This defaults to `None` (unset).
387    pub fn max_age(mut self, max_age: usize) -> Self {
388        if let Some(cors) = cors(&mut self.cors, &self.error) {
389            cors.max_age = Some(max_age)
390        }
391        self
392    }
393
394    /// Set a wildcard origins
395    ///
396    /// If send wildcard is set and the `allowed_origins` parameter is `All`, a
397    /// wildcard `Access-Control-Allow-Origin` response header is sent,
398    /// rather than the request’s `Origin` header.
399    ///
400    /// This is the `supports credentials flag` in the
401    /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
402    ///
403    /// This **CANNOT** be used in conjunction with `allowed_origins` set to
404    /// `All` and `allow_credentials` set to `true`. Depending on the mode
405    /// of usage, this will either result in an `Error::
406    /// CredentialsWithWildcardOrigin` error during ntex launch or runtime.
407    ///
408    /// Defaults to `false`.
409    pub fn send_wildcard(mut self) -> Self {
410        if let Some(cors) = cors(&mut self.cors, &self.error) {
411            cors.send_wildcard = true
412        }
413        self
414    }
415
416    /// Allows users to make authenticated requests
417    ///
418    /// If true, injects the `Access-Control-Allow-Credentials` header in
419    /// responses. This allows cookies and credentials to be submitted
420    /// across domains.
421    ///
422    /// This option cannot be used in conjunction with an `allowed_origin` set
423    /// to `All` and `send_wildcards` set to `true`.
424    ///
425    /// Defaults to `false`.
426    ///
427    /// Builder panics if credentials are allowed, but the Origin is set to "*".
428    /// This is not allowed by W3C
429    pub fn supports_credentials(mut self) -> Self {
430        if let Some(cors) = cors(&mut self.cors, &self.error) {
431            cors.supports_credentials = true
432        }
433        self
434    }
435
436    /// Disable `Vary` header support.
437    ///
438    /// When enabled the header `Vary: Origin` will be returned as per the W3
439    /// implementation guidelines.
440    ///
441    /// Setting this header when the `Access-Control-Allow-Origin` is
442    /// dynamically generated (e.g. when there is more than one allowed
443    /// origin, and an Origin than '*' is returned) informs CDNs and other
444    /// caches that the CORS headers are dynamic, and cannot be cached.
445    ///
446    /// By default `vary` header support is enabled.
447    pub fn disable_vary_header(mut self) -> Self {
448        if let Some(cors) = cors(&mut self.cors, &self.error) {
449            cors.vary_header = false
450        }
451        self
452    }
453
454    /// Disable *preflight* request support.
455    ///
456    /// When enabled cors middleware automatically handles *OPTIONS* request.
457    /// This is useful application level middleware.
458    ///
459    /// By default *preflight* support is enabled.
460    pub fn disable_preflight(mut self) -> Self {
461        if let Some(cors) = cors(&mut self.cors, &self.error) {
462            cors.preflight = false
463        }
464        self
465    }
466
467    /// Construct cors middleware
468    pub fn finish<Err>(self) -> CorsFactory<Err> {
469        let mut slf = if !self.methods {
470            self.allowed_methods(vec![
471                Method::GET,
472                Method::HEAD,
473                Method::POST,
474                Method::OPTIONS,
475                Method::PUT,
476                Method::PATCH,
477                Method::DELETE,
478            ])
479        } else {
480            self
481        };
482
483        if let Some(e) = slf.error.take() {
484            panic!("{}", e);
485        }
486
487        let mut cors = slf.cors.take().expect("cannot reuse CorsBuilder");
488
489        if cors.supports_credentials && cors.send_wildcard && cors.origins.is_all() {
490            panic!("Credentials are allowed, but the Origin is set to \"*\"");
491        }
492
493        if let AllOrSome::Some(ref origins) = cors.origins {
494            let s = origins.iter().fold(String::new(), |s, v| format!("{}, {}", s, v));
495            cors.origins_str = Some(HeaderValue::try_from(&s[2..]).unwrap());
496        }
497
498        if !slf.expose_hdrs.is_empty() {
499            cors.expose_hdrs = Some(
500                HeaderValue::try_from(
501                    &slf.expose_hdrs
502                        .iter()
503                        .fold(String::new(), |s, v| format!("{}, {}", s, v.as_str()))[2..],
504                )
505                .unwrap(),
506            );
507        }
508
509        CorsFactory { inner: Rc::new(cors), _t: PhantomData }
510    }
511}
512
513fn cors<'a>(parts: &'a mut Option<Inner>, err: &Option<HttpError>) -> Option<&'a mut Inner> {
514    if err.is_some() {
515        return None;
516    }
517    parts.as_mut()
518}
519
520struct Inner {
521    methods: HashSet<Method>,
522    origins: AllOrSome<HashSet<String>>,
523    origins_str: Option<HeaderValue>,
524    headers: AllOrSome<HashSet<HeaderName>>,
525    expose_hdrs: Option<HeaderValue>,
526    max_age: Option<usize>,
527    preflight: bool,
528    send_wildcard: bool,
529    supports_credentials: bool,
530    vary_header: bool,
531}
532
533impl Inner {
534    fn validate_origin(&self, req: &RequestHead) -> Result<(), CorsError> {
535        if let Some(hdr) = req.headers().get(&header::ORIGIN) {
536            if let Ok(origin) = hdr.to_str() {
537                return match self.origins {
538                    AllOrSome::All => Ok(()),
539                    AllOrSome::Some(ref allowed_origins) => allowed_origins
540                        .get(origin)
541                        .map(|_| ())
542                        .ok_or(CorsError::OriginNotAllowed),
543                };
544            }
545            Err(CorsError::BadOrigin)
546        } else {
547            match self.origins {
548                AllOrSome::All => Ok(()),
549                _ => Err(CorsError::MissingOrigin),
550            }
551        }
552    }
553
554    fn access_control_allow_origin(&self, headers: &HeaderMap) -> Option<HeaderValue> {
555        match self.origins {
556            AllOrSome::All => {
557                if self.send_wildcard {
558                    Some(HeaderValue::from_static("*"))
559                } else {
560                    headers.get(&header::ORIGIN).cloned()
561                }
562            }
563            AllOrSome::Some(ref origins) => {
564                if let Some(origin) =
565                    headers.get(&header::ORIGIN).filter(|o| match o.to_str() {
566                        Ok(os) => origins.contains(os),
567                        _ => false,
568                    })
569                {
570                    Some(origin.clone())
571                } else {
572                    Some(self.origins_str.as_ref().unwrap().clone())
573                }
574            }
575        }
576    }
577
578    fn validate_allowed_method(&self, req: &RequestHead) -> Result<(), CorsError> {
579        if let Some(hdr) = req.headers().get(&header::ACCESS_CONTROL_REQUEST_METHOD) {
580            if let Ok(meth) = hdr.to_str()
581                && let Ok(method) = Method::try_from(meth)
582            {
583                return self
584                    .methods
585                    .get(&method)
586                    .map(|_| ())
587                    .ok_or(CorsError::MethodNotAllowed);
588            }
589            Err(CorsError::BadRequestMethod)
590        } else {
591            Err(CorsError::MissingRequestMethod)
592        }
593    }
594
595    fn validate_allowed_headers(&self, req: &RequestHead) -> Result<(), CorsError> {
596        match self.headers {
597            AllOrSome::All => Ok(()),
598            AllOrSome::Some(ref allowed_headers) => {
599                if let Some(hdr) = req.headers().get(&header::ACCESS_CONTROL_REQUEST_HEADERS) {
600                    if let Ok(headers) = hdr.to_str() {
601                        let mut hdrs = HashSet::new();
602                        for hdr in headers.split(',') {
603                            match HeaderName::try_from(hdr.trim()) {
604                                Ok(hdr) => hdrs.insert(hdr),
605                                Err(_) => return Err(CorsError::BadRequestHeaders),
606                            };
607                        }
608                        // `Access-Control-Request-Headers` must contain 1 or more
609                        // `field-name`.
610                        if !hdrs.is_empty() {
611                            if !hdrs.is_subset(allowed_headers) {
612                                return Err(CorsError::HeadersNotAllowed);
613                            }
614                            return Ok(());
615                        }
616                    }
617                    Err(CorsError::BadRequestHeaders)
618                } else {
619                    Ok(())
620                }
621            }
622        }
623    }
624
625    fn preflight_check(
626        &self,
627        req: &RequestHead,
628    ) -> Result<Either<HttpResponse, ()>, CorsError> {
629        if self.preflight && Method::OPTIONS == req.method {
630            self.validate_origin(req)
631                .and_then(|_| self.validate_allowed_method(req))
632                .and_then(|_| self.validate_allowed_headers(req))?;
633
634            // allowed headers
635            let headers = if let Some(headers) = self.headers.as_ref() {
636                Some(
637                    HeaderValue::try_from(
638                        &headers
639                            .iter()
640                            .fold(String::new(), |s, v| s + "," + v.as_str())
641                            .as_str()[1..],
642                    )
643                    .unwrap(),
644                )
645            } else {
646                req.headers.get(&header::ACCESS_CONTROL_REQUEST_HEADERS).cloned()
647            };
648
649            let mut res = HttpResponse::Ok();
650
651            if let Some(ref max_age) = self.max_age {
652                res.header(
653                    header::ACCESS_CONTROL_MAX_AGE,
654                    ByteString::from(format!("{}", max_age)),
655                );
656            }
657            if let Some(headers) = headers {
658                res.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers);
659            }
660            if let Some(origin) = self.access_control_allow_origin(req.headers()) {
661                res.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin);
662            }
663            if self.supports_credentials {
664                res.header(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
665            }
666            let res = res
667                .header(
668                    header::ACCESS_CONTROL_ALLOW_METHODS,
669                    &self
670                        .methods
671                        .iter()
672                        .fold(String::new(), |s, v| s + "," + v.as_str())
673                        .as_str()[1..],
674                )
675                .finish()
676                .into_body();
677
678            Ok(Either::Left(res))
679        } else {
680            if req.headers.contains_key(&header::ORIGIN) {
681                // Only check requests with a origin header.
682                self.validate_origin(req)?;
683            }
684            Ok(Either::Right(()))
685        }
686    }
687
688    fn handle_response(&self, headers: &mut HeaderMap, allowed_origin: Option<HeaderValue>) {
689        if let Some(origin) = allowed_origin {
690            headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin);
691        };
692
693        if let Some(ref expose) = self.expose_hdrs {
694            headers.insert(header::ACCESS_CONTROL_EXPOSE_HEADERS, expose.clone());
695        }
696        if self.supports_credentials {
697            headers.insert(
698                header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
699                HeaderValue::from_static("true"),
700            );
701        }
702        if self.vary_header {
703            let value = if let Some(hdr) = headers.get(&header::VARY) {
704                let mut val: Vec<u8> = Vec::with_capacity(hdr.as_bytes().len() + 8);
705                val.extend(hdr.as_bytes());
706                val.extend(b", Origin");
707                HeaderValue::try_from(&val[..]).unwrap()
708            } else {
709                HeaderValue::from_static("Origin")
710            };
711            headers.insert(header::VARY, value);
712        }
713    }
714}
715
716/// `Middleware` for Cross-origin resource sharing support
717///
718/// The Cors struct contains the settings for CORS requests to be validated and
719/// for responses to be generated.
720pub struct CorsFactory<Err> {
721    inner: Rc<Inner>,
722    _t: PhantomData<Err>,
723}
724
725impl<S, C, Err> Middleware<S, C> for CorsFactory<Err>
726where
727    S: Service<WebRequest<Err>, Response = WebResponse>,
728{
729    type Service = CorsMiddleware<S>;
730
731    fn create(&self, service: S, _: C) -> Self::Service {
732        CorsMiddleware { service, inner: self.inner.clone() }
733    }
734}
735
736/// `Middleware` for Cross-origin resource sharing support
737///
738/// The Cors struct contains the settings for CORS requests to be validated and
739/// for responses to be generated.
740#[derive(Clone)]
741pub struct CorsMiddleware<S> {
742    service: S,
743    inner: Rc<Inner>,
744}
745
746impl<S, Err> Service<WebRequest<Err>> for CorsMiddleware<S>
747where
748    S: Service<WebRequest<Err>, Response = WebResponse>,
749    Err: ErrorRenderer,
750    Err::Container: From<S::Error>,
751    CorsError: WebResponseError<Err>,
752{
753    type Response = WebResponse;
754    type Error = S::Error;
755
756    ntex::forward_ready!(service);
757    ntex::forward_shutdown!(service);
758
759    async fn call(
760        &self,
761        req: WebRequest<Err>,
762        ctx: ServiceCtx<'_, Self>,
763    ) -> Result<Self::Response, S::Error> {
764        match self.inner.preflight_check(req.head()) {
765            Ok(Either::Left(res)) => Ok(req.into_response(res)),
766            Ok(Either::Right(_)) => {
767                let inner = self.inner.clone();
768                let has_origin = req.headers().contains_key(&header::ORIGIN);
769                let allowed_origin = inner.access_control_allow_origin(req.headers());
770
771                let mut res = ctx.call(&self.service, req).await?;
772
773                if has_origin {
774                    inner.handle_response(res.headers_mut(), allowed_origin);
775                }
776                Ok(res)
777            }
778            Err(e) => Ok(req.render_error(&e)),
779        }
780    }
781}
782
783#[cfg(test)]
784mod tests {
785    use ntex::service::{Pipeline, fn_service};
786    use ntex::web::{self, test, test::TestRequest};
787
788    use super::*;
789
790    #[ntex::test]
791    #[should_panic(expected = "Credentials are allowed, but the Origin is set to")]
792    async fn cors_validates_illegal_allow_credentials() {
793        let _cors =
794            Cors::new().supports_credentials().send_wildcard().finish::<web::DefaultError>();
795    }
796
797    #[ntex::test]
798    async fn validate_origin_allows_all_origins() {
799        let cors = Cors::new().finish().create(test::ok_service(), ()).into();
800        let req =
801            TestRequest::with_header("Origin", "https://www.example.com").to_srv_request();
802
803        let resp = test::call_service(&cors, req).await;
804        assert_eq!(resp.status(), StatusCode::OK);
805    }
806
807    #[ntex::test]
808    async fn default() {
809        let cors = Cors::default().create(test::ok_service(), ()).into();
810        let req =
811            TestRequest::with_header("Origin", "https://www.example.com").to_srv_request();
812
813        let resp = test::call_service(&cors, req).await;
814        assert_eq!(resp.status(), StatusCode::OK);
815    }
816
817    #[ntex::test]
818    async fn test_preflight() {
819        let cors: Pipeline<_> = Cors::new()
820            .send_wildcard()
821            .max_age(3600)
822            .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
823            .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT])
824            .allowed_header(header::CONTENT_TYPE)
825            .finish()
826            .create(test::ok_service(), ())
827            .into();
828
829        let req = TestRequest::with_header("Origin", "https://www.example.com")
830            .method(Method::OPTIONS)
831            .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "X-Not-Allowed")
832            .to_srv_request();
833
834        assert!(cors.get_ref().inner.validate_allowed_method(req.head()).is_err());
835        assert!(cors.get_ref().inner.validate_allowed_headers(req.head()).is_err());
836        let resp = test::call_service(&cors, req).await;
837        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
838
839        let req = TestRequest::with_header("Origin", "https://www.example.com")
840            .header(header::ACCESS_CONTROL_REQUEST_METHOD, "put")
841            .method(Method::OPTIONS)
842            .to_srv_request();
843
844        assert!(cors.get_ref().inner.validate_allowed_method(req.head()).is_err());
845        assert!(cors.get_ref().inner.validate_allowed_headers(req.head()).is_ok());
846
847        let req = TestRequest::with_header("Origin", "https://www.example.com")
848            .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST")
849            .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "AUTHORIZATION,ACCEPT")
850            .method(Method::OPTIONS)
851            .to_srv_request();
852
853        let resp = test::call_service(&cors, req).await;
854        assert_eq!(
855            &b"*"[..],
856            resp.headers().get(&header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
857        );
858        assert_eq!(
859            &b"3600"[..],
860            resp.headers().get(&header::ACCESS_CONTROL_MAX_AGE).unwrap().as_bytes()
861        );
862        let hdr = resp
863            .headers()
864            .get(&header::ACCESS_CONTROL_ALLOW_HEADERS)
865            .unwrap()
866            .to_str()
867            .unwrap();
868        assert!(hdr.contains("authorization"));
869        assert!(hdr.contains("accept"));
870        assert!(hdr.contains("content-type"));
871
872        let methods =
873            resp.headers().get(header::ACCESS_CONTROL_ALLOW_METHODS).unwrap().to_str().unwrap();
874        assert!(methods.contains("POST"));
875        assert!(methods.contains("GET"));
876        assert!(methods.contains("OPTIONS"));
877
878        // Rc::get_mut(&mut cors.inner).unwrap().preflight = false;
879
880        // let req = TestRequest::with_header("Origin", "https://www.example.com")
881        //     .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST")
882        //     .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "AUTHORIZATION,ACCEPT")
883        //     .method(Method::OPTIONS)
884        //     .to_srv_request();
885
886        // let resp = test::call_service(&cors, req).await;
887        // assert_eq!(resp.status(), StatusCode::OK);
888    }
889
890    // #[ntex::test]
891    // #[should_panic(expected = "MissingOrigin")]
892    // async fn test_validate_missing_origin() {
893    //    let cors = Cors::build()
894    //        .allowed_origin("https://www.example.com")
895    //        .finish();
896    //    let mut req = HttpRequest::default();
897    //    cors.start(&req).unwrap();
898    // }
899
900    #[ntex::test]
901    #[should_panic(expected = "OriginNotAllowed")]
902    async fn test_validate_not_allowed_origin() {
903        let cors: Pipeline<_> = Cors::new()
904            .allowed_origin("https://www.example.com")
905            .finish()
906            .create(test::ok_service::<web::DefaultError>(), ())
907            .into();
908
909        let req = TestRequest::with_header("Origin", "https://www.unknown.com")
910            .method(Method::GET)
911            .to_srv_request();
912        cors.get_ref().inner.validate_origin(req.head()).unwrap();
913        cors.get_ref().inner.validate_allowed_method(req.head()).unwrap();
914        cors.get_ref().inner.validate_allowed_headers(req.head()).unwrap();
915    }
916
917    #[ntex::test]
918    async fn test_validate_origin() {
919        let cors = Cors::new()
920            .allowed_origin("https://www.example.com")
921            .finish()
922            .create(test::ok_service(), ())
923            .into();
924
925        let req = TestRequest::with_header("Origin", "https://www.example.com")
926            .method(Method::GET)
927            .to_srv_request();
928
929        let resp = test::call_service(&cors, req).await;
930        assert_eq!(resp.status(), StatusCode::OK);
931    }
932
933    #[ntex::test]
934    async fn test_no_origin_response() {
935        let cors =
936            Cors::new().disable_preflight().finish().create(test::ok_service(), ()).into();
937
938        let req = TestRequest::default().method(Method::GET).to_srv_request();
939        let resp = test::call_service(&cors, req).await;
940        assert!(resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).is_none());
941
942        let req = TestRequest::with_header("Origin", "https://www.example.com")
943            .method(Method::OPTIONS)
944            .to_srv_request();
945        let resp = test::call_service(&cors, req).await;
946        assert_eq!(
947            &b"https://www.example.com"[..],
948            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
949        );
950    }
951
952    #[ntex::test]
953    async fn test_response() {
954        let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT];
955        let cors = Cors::new()
956            .send_wildcard()
957            .disable_preflight()
958            .max_age(3600)
959            .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
960            .allowed_headers(exposed_headers.clone())
961            .expose_headers(exposed_headers.clone())
962            .allowed_header(header::CONTENT_TYPE)
963            .finish()
964            .create(test::ok_service(), ())
965            .into();
966
967        let req = TestRequest::with_header("Origin", "https://www.example.com")
968            .method(Method::OPTIONS)
969            .to_srv_request();
970
971        let resp = test::call_service(&cors, req).await;
972        assert_eq!(
973            &b"*"[..],
974            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
975        );
976        assert_eq!(&b"Origin"[..], resp.headers().get(header::VARY).unwrap().as_bytes());
977
978        {
979            let headers = resp
980                .headers()
981                .get(header::ACCESS_CONTROL_EXPOSE_HEADERS)
982                .unwrap()
983                .to_str()
984                .unwrap()
985                .split(',')
986                .map(|s| s.trim())
987                .collect::<Vec<&str>>();
988
989            for h in exposed_headers {
990                assert!(headers.contains(&h.as_str()));
991            }
992        }
993
994        let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT];
995        let cors = Cors::new()
996            .send_wildcard()
997            .disable_preflight()
998            .max_age(3600)
999            .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
1000            .allowed_headers(exposed_headers.clone())
1001            .expose_headers(exposed_headers.clone())
1002            .allowed_header(header::CONTENT_TYPE)
1003            .finish()
1004            .create(
1005                fn_service(|req: WebRequest<DefaultError>| async move {
1006                    Ok::<_, std::convert::Infallible>(req.into_response(
1007                        HttpResponse::Ok().header(header::VARY, "Accept").finish(),
1008                    ))
1009                }),
1010                (),
1011            )
1012            .into();
1013        let req = TestRequest::with_header("Origin", "https://www.example.com")
1014            .method(Method::OPTIONS)
1015            .to_srv_request();
1016        let resp = test::call_service(&cors, req).await;
1017        assert_eq!(
1018            &b"Accept, Origin"[..],
1019            resp.headers().get(header::VARY).unwrap().as_bytes()
1020        );
1021
1022        let cors = Cors::new()
1023            .disable_vary_header()
1024            .allowed_origin("https://www.example.com")
1025            .allowed_origin("https://www.google.com")
1026            .finish()
1027            .create(test::ok_service(), ())
1028            .into();
1029
1030        let req = TestRequest::with_header("Origin", "https://www.example.com")
1031            .method(Method::OPTIONS)
1032            .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST")
1033            .to_srv_request();
1034        let resp = test::call_service(&cors, req).await;
1035
1036        let origins_str =
1037            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().to_str().unwrap();
1038
1039        assert_eq!("https://www.example.com", origins_str);
1040    }
1041
1042    #[ntex::test]
1043    async fn test_multiple_origins() {
1044        let cors = Cors::new()
1045            .allowed_origin("https://example.com")
1046            .allowed_origin("https://example.org")
1047            .allowed_methods(vec![Method::GET])
1048            .finish()
1049            .create(test::ok_service(), ())
1050            .into();
1051
1052        let req = TestRequest::with_header("Origin", "https://example.com")
1053            .method(Method::GET)
1054            .to_srv_request();
1055
1056        let resp = test::call_service(&cors, req).await;
1057        assert_eq!(
1058            &b"https://example.com"[..],
1059            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
1060        );
1061
1062        let req = TestRequest::with_header("Origin", "https://example.org")
1063            .method(Method::GET)
1064            .to_srv_request();
1065
1066        let resp = test::call_service(&cors, req).await;
1067        assert_eq!(
1068            &b"https://example.org"[..],
1069            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
1070        );
1071    }
1072
1073    #[ntex::test]
1074    async fn test_multiple_origins_preflight() {
1075        let cors = Cors::new()
1076            .allowed_origin("https://example.com")
1077            .allowed_origin("https://example.org")
1078            .allowed_methods(vec![Method::GET])
1079            .finish()
1080            .create(test::ok_service(), ())
1081            .into();
1082
1083        let req = TestRequest::with_header("Origin", "https://example.com")
1084            .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
1085            .method(Method::OPTIONS)
1086            .to_srv_request();
1087
1088        let resp = test::call_service(&cors, req).await;
1089        assert_eq!(
1090            &b"https://example.com"[..],
1091            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
1092        );
1093
1094        let req = TestRequest::with_header("Origin", "https://example.org")
1095            .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
1096            .method(Method::OPTIONS)
1097            .to_srv_request();
1098
1099        let resp = test::call_service(&cors, req).await;
1100        assert_eq!(
1101            &b"https://example.org"[..],
1102            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
1103        );
1104    }
1105
1106    #[ntex::test]
1107    async fn test_set_allowed_origin_to_all() {
1108        let cors =
1109            Cors::new().allowed_origin("*").finish().create(test::ok_service(), ()).into();
1110
1111        let req = TestRequest::with_header("Origin", "https://www.example.com")
1112            .method(Method::GET)
1113            .to_srv_request();
1114
1115        let resp = test::call_service(&cors, req).await;
1116        assert_eq!(resp.status(), StatusCode::OK);
1117    }
1118}