ntex_cors/
lib.rs

1#![allow(
2    clippy::borrow_interior_mutable_const,
3    clippy::type_complexity,
4    clippy::mutable_key_type
5)]
6//! Cross-origin resource sharing (CORS) for ntex applications
7//!
//! The CORS middleware can be applied to a whole application, a scope or a
//! single resource by passing it to the `App::wrap()`, `Scope::wrap()` or
//! `Resource::wrap()` method.
11//!
12//! # Example
13//!
14//! ```rust,no_run
15//! use ntex_cors::Cors;
16//! use ntex::{http, web};
17//! use ntex::web::{App, HttpRequest, HttpResponse};
18//!
19//! async fn index(req: HttpRequest) -> &'static str {
20//!     "Hello world"
21//! }
22//!
23//! #[ntex::main]
24//! async fn main() -> std::io::Result<()> {
25//!     web::server(|| App::new()
26//!         .wrap(
27//!             Cors::new() // <- Construct CORS middleware builder
//!               .allowed_origin("https://www.rust-lang.org")
29//!               .allowed_methods(vec!["GET", "POST"])
30//!               .allowed_headers(vec![http::header::AUTHORIZATION, http::header::ACCEPT])
31//!               .allowed_header(http::header::CONTENT_TYPE)
32//!               .max_age(3600)
33//!               .finish())
34//!         .service(
35//!             web::resource("/index.html")
36//!               .route(web::get().to(index))
37//!               .route(web::head().to(|| async { HttpResponse::MethodNotAllowed() }))
38//!         ))
39//!         .bind("127.0.0.1:8080")?
40//!         .run()
41//!         .await
42//! }
43//! ```
//! In this example the *CORS* middleware is registered for the whole
//! application, so it also applies to the "/index.html" endpoint.
46//!
//! The CORS middleware automatically handles *OPTIONS* preflight requests.
48use std::{
49    collections::HashSet, convert::TryFrom, iter::FromIterator, marker::PhantomData, rc::Rc,
50};
51
52use derive_more::Display;
53use ntex::http::header::{self, HeaderName, HeaderValue};
54use ntex::http::{error::HttpError, HeaderMap, Method, RequestHead, StatusCode, Uri};
55use ntex::service::{Middleware, Service, ServiceCtx};
56use ntex::util::Either;
57use ntex::web::{
58    DefaultError, ErrorRenderer, HttpResponse, WebRequest, WebResponse, WebResponseError,
59};
60
61/// A set of errors that can occur during processing CORS
62#[derive(Debug, Display)]
63pub enum CorsError {
64    /// The HTTP request header `Origin` is required but was not provided
65    #[display(fmt = "The HTTP request header `Origin` is required but was not provided")]
66    MissingOrigin,
67    /// The HTTP request header `Origin` could not be parsed correctly.
68    #[display(fmt = "The HTTP request header `Origin` could not be parsed correctly.")]
69    BadOrigin,
70    /// The request header `Access-Control-Request-Method` is required but is
71    /// missing
72    #[display(
73        fmt = "The request header `Access-Control-Request-Method` is required but is missing"
74    )]
75    MissingRequestMethod,
76    /// The request header `Access-Control-Request-Method` has an invalid value
77    #[display(fmt = "The request header `Access-Control-Request-Method` has an invalid value")]
78    BadRequestMethod,
79    /// The request header `Access-Control-Request-Headers`  has an invalid
80    /// value
81    #[display(
82        fmt = "The request header `Access-Control-Request-Headers`  has an invalid value"
83    )]
84    BadRequestHeaders,
85    /// Origin is not allowed to make this request
86    #[display(fmt = "Origin is not allowed to make this request")]
87    OriginNotAllowed,
88    /// Requested method is not allowed
89    #[display(fmt = "Requested method is not allowed")]
90    MethodNotAllowed,
91    /// One or more headers requested are not allowed
92    #[display(fmt = "One or more headers requested are not allowed")]
93    HeadersNotAllowed,
94}
95
96/// DefaultError renderer support
97impl WebResponseError<DefaultError> for CorsError {
98    fn status_code(&self) -> StatusCode {
99        StatusCode::BAD_REQUEST
100    }
101}
102
/// Either everything is allowed (`All`) or only an explicit value of type
/// `T` is allowed (`Some`).
///
/// The `Default` value is `All`.
#[derive(Default, Clone, Debug, Eq, PartialEq)]
pub enum AllOrSome<T> {
    /// Everything is allowed. Usually equivalent to the "*" value.
    #[default]
    All,
    /// Only some of `T` is allowed
    Some(T),
}

impl<T> AllOrSome<T> {
    /// `true` exactly when this is the `All` variant.
    pub fn is_all(&self) -> bool {
        matches!(self, AllOrSome::All)
    }

    /// `true` exactly when this is the `Some` variant.
    pub fn is_some(&self) -> bool {
        matches!(self, AllOrSome::Some(_))
    }

    /// Borrows the restricted value: `Some(&t)` for `Some(t)`, `None` for `All`.
    pub fn as_ref(&self) -> Option<&T> {
        if let AllOrSome::Some(inner) = self {
            Some(inner)
        } else {
            None
        }
    }
}
138
139/// Structure that follows the builder pattern for building `Cors` middleware
140/// structs.
141///
142/// To construct a cors:
143///
144///   1. Call [`Cors::build`](struct.Cors.html#method.build) to start building.
145///   2. Use any of the builder methods to set fields in the backend.
146/// 3. Call [finish](struct.Cors.html#method.finish) to retrieve the
147/// constructed backend.
148///
149/// # Example
150///
151/// ```rust
152/// use ntex_cors::Cors;
153/// use ntex::http::header;
154///
155/// let cors = Cors::new()
156///     .allowed_origin("https://www.rust-lang.org/")
157///     .allowed_methods(vec!["GET", "POST"])
158///     .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT])
159///     .allowed_header(header::CONTENT_TYPE)
160///     .max_age(3600);
161/// ```
162#[derive(Default)]
163pub struct Cors {
164    cors: Option<Inner>,
165    methods: bool,
166    expose_hdrs: HashSet<HeaderName>,
167    error: Option<HttpError>,
168}
169
170impl Cors {
171    /// Build a new CORS middleware instance
172    pub fn new() -> Self {
173        Cors {
174            cors: Some(Inner {
175                origins: AllOrSome::All,
176                origins_str: None,
177                methods: HashSet::new(),
178                headers: AllOrSome::All,
179                expose_hdrs: None,
180                max_age: None,
181                preflight: true,
182                send_wildcard: false,
183                supports_credentials: false,
184                vary_header: true,
185            }),
186            methods: false,
187            error: None,
188            expose_hdrs: HashSet::new(),
189        }
190    }
191
192    /// Build a new CORS default middleware
193    pub fn default<Err>() -> CorsFactory<Err> {
194        let inner = Inner {
195            origins: AllOrSome::default(),
196            origins_str: None,
197            methods: HashSet::from_iter(
198                vec![
199                    Method::GET,
200                    Method::HEAD,
201                    Method::POST,
202                    Method::OPTIONS,
203                    Method::PUT,
204                    Method::PATCH,
205                    Method::DELETE,
206                ]
207                .into_iter(),
208            ),
209            headers: AllOrSome::All,
210            expose_hdrs: None,
211            max_age: None,
212            preflight: true,
213            send_wildcard: false,
214            supports_credentials: false,
215            vary_header: true,
216        };
217        CorsFactory { inner: Rc::new(inner), _t: PhantomData }
218    }
219
220    /// Add an origin that are allowed to make requests.
221    /// Will be verified against the `Origin` request header.
222    ///
223    /// When `All` is set, and `send_wildcard` is set, "*" will be sent in
224    /// the `Access-Control-Allow-Origin` response header. Otherwise, the
225    /// client's `Origin` request header will be echoed back in the
226    /// `Access-Control-Allow-Origin` response header.
227    ///
228    /// When `Some` is set, the client's `Origin` request header will be
229    /// checked in a case-sensitive manner.
230    ///
231    /// This is the `list of origins` in the
232    /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
233    ///
234    /// Defaults to `All`.
235    ///
236    /// Builder panics if supplied origin is not valid uri.
237    pub fn allowed_origin(mut self, origin: &str) -> Self {
238        if let Some(cors) = cors(&mut self.cors, &self.error) {
239            match Uri::try_from(origin) {
240                Ok(_) => {
241                    if cors.origins.is_all() {
242                        cors.origins = AllOrSome::Some(HashSet::new());
243                    }
244                    if let AllOrSome::Some(ref mut origins) = cors.origins {
245                        origins.insert(origin.to_owned());
246                    }
247                }
248                Err(e) => {
249                    self.error = Some(e.into());
250                }
251            }
252        }
253        self
254    }
255
256    /// Set a list of methods which the allowed origins are allowed to access
257    /// for requests.
258    ///
259    /// This is the `list of methods` in the
260    /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
261    ///
262    /// Defaults to `[GET, HEAD, POST, OPTIONS, PUT, PATCH, DELETE]`
263    pub fn allowed_methods<U, M>(mut self, methods: U) -> Self
264    where
265        U: IntoIterator<Item = M>,
266        Method: TryFrom<M>,
267        <Method as TryFrom<M>>::Error: Into<HttpError>,
268    {
269        self.methods = true;
270        if let Some(cors) = cors(&mut self.cors, &self.error) {
271            for m in methods {
272                match Method::try_from(m) {
273                    Ok(method) => {
274                        cors.methods.insert(method);
275                    }
276                    Err(e) => {
277                        self.error = Some(e.into());
278                        break;
279                    }
280                }
281            }
282        }
283        self
284    }
285
286    /// Set an allowed header
287    pub fn allowed_header<H>(mut self, header: H) -> Self
288    where
289        HeaderName: TryFrom<H>,
290        <HeaderName as TryFrom<H>>::Error: Into<HttpError>,
291    {
292        if let Some(cors) = cors(&mut self.cors, &self.error) {
293            match HeaderName::try_from(header) {
294                Ok(method) => {
295                    if cors.headers.is_all() {
296                        cors.headers = AllOrSome::Some(HashSet::new());
297                    }
298                    if let AllOrSome::Some(ref mut headers) = cors.headers {
299                        headers.insert(method);
300                    }
301                }
302                Err(e) => self.error = Some(e.into()),
303            }
304        }
305        self
306    }
307
308    /// Set a list of header field names which can be used when
309    /// this resource is accessed by allowed origins.
310    ///
311    /// If `All` is set, whatever is requested by the client in
312    /// `Access-Control-Request-Headers` will be echoed back in the
313    /// `Access-Control-Allow-Headers` header.
314    ///
315    /// This is the `list of headers` in the
316    /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
317    ///
318    /// Defaults to `All`.
319    pub fn allowed_headers<U, H>(mut self, headers: U) -> Self
320    where
321        U: IntoIterator<Item = H>,
322        HeaderName: TryFrom<H>,
323        <HeaderName as TryFrom<H>>::Error: Into<HttpError>,
324    {
325        if let Some(cors) = cors(&mut self.cors, &self.error) {
326            for h in headers {
327                match HeaderName::try_from(h) {
328                    Ok(method) => {
329                        if cors.headers.is_all() {
330                            cors.headers = AllOrSome::Some(HashSet::new());
331                        }
332                        if let AllOrSome::Some(ref mut headers) = cors.headers {
333                            headers.insert(method);
334                        }
335                    }
336                    Err(e) => {
337                        self.error = Some(e.into());
338                        break;
339                    }
340                }
341            }
342        }
343        self
344    }
345
346    /// Set a list of headers which are safe to expose to the API of a CORS API
347    /// specification. This corresponds to the
348    /// `Access-Control-Expose-Headers` response header.
349    ///
350    /// This is the `list of exposed headers` in the
351    /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
352    ///
353    /// This defaults to an empty set.
354    pub fn expose_headers<U, H>(mut self, headers: U) -> Self
355    where
356        U: IntoIterator<Item = H>,
357        HeaderName: TryFrom<H>,
358        <HeaderName as TryFrom<H>>::Error: Into<HttpError>,
359    {
360        for h in headers {
361            match HeaderName::try_from(h) {
362                Ok(method) => {
363                    self.expose_hdrs.insert(method);
364                }
365                Err(e) => {
366                    self.error = Some(e.into());
367                    break;
368                }
369            }
370        }
371        self
372    }
373
374    /// Set a maximum time for which this CORS request maybe cached.
375    /// This value is set as the `Access-Control-Max-Age` header.
376    ///
377    /// This defaults to `None` (unset).
378    pub fn max_age(mut self, max_age: usize) -> Self {
379        if let Some(cors) = cors(&mut self.cors, &self.error) {
380            cors.max_age = Some(max_age)
381        }
382        self
383    }
384
385    /// Set a wildcard origins
386    ///
387    /// If send wildcard is set and the `allowed_origins` parameter is `All`, a
388    /// wildcard `Access-Control-Allow-Origin` response header is sent,
389    /// rather than the request’s `Origin` header.
390    ///
391    /// This is the `supports credentials flag` in the
392    /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
393    ///
394    /// This **CANNOT** be used in conjunction with `allowed_origins` set to
395    /// `All` and `allow_credentials` set to `true`. Depending on the mode
396    /// of usage, this will either result in an `Error::
397    /// CredentialsWithWildcardOrigin` error during ntex launch or runtime.
398    ///
399    /// Defaults to `false`.
400    pub fn send_wildcard(mut self) -> Self {
401        if let Some(cors) = cors(&mut self.cors, &self.error) {
402            cors.send_wildcard = true
403        }
404        self
405    }
406
407    /// Allows users to make authenticated requests
408    ///
409    /// If true, injects the `Access-Control-Allow-Credentials` header in
410    /// responses. This allows cookies and credentials to be submitted
411    /// across domains.
412    ///
413    /// This option cannot be used in conjunction with an `allowed_origin` set
414    /// to `All` and `send_wildcards` set to `true`.
415    ///
416    /// Defaults to `false`.
417    ///
418    /// Builder panics if credentials are allowed, but the Origin is set to "*".
419    /// This is not allowed by W3C
420    pub fn supports_credentials(mut self) -> Self {
421        if let Some(cors) = cors(&mut self.cors, &self.error) {
422            cors.supports_credentials = true
423        }
424        self
425    }
426
427    /// Disable `Vary` header support.
428    ///
429    /// When enabled the header `Vary: Origin` will be returned as per the W3
430    /// implementation guidelines.
431    ///
432    /// Setting this header when the `Access-Control-Allow-Origin` is
433    /// dynamically generated (e.g. when there is more than one allowed
434    /// origin, and an Origin than '*' is returned) informs CDNs and other
435    /// caches that the CORS headers are dynamic, and cannot be cached.
436    ///
437    /// By default `vary` header support is enabled.
438    pub fn disable_vary_header(mut self) -> Self {
439        if let Some(cors) = cors(&mut self.cors, &self.error) {
440            cors.vary_header = false
441        }
442        self
443    }
444
445    /// Disable *preflight* request support.
446    ///
447    /// When enabled cors middleware automatically handles *OPTIONS* request.
448    /// This is useful application level middleware.
449    ///
450    /// By default *preflight* support is enabled.
451    pub fn disable_preflight(mut self) -> Self {
452        if let Some(cors) = cors(&mut self.cors, &self.error) {
453            cors.preflight = false
454        }
455        self
456    }
457
458    /// Construct cors middleware
459    pub fn finish<Err>(self) -> CorsFactory<Err> {
460        let mut slf = if !self.methods {
461            self.allowed_methods(vec![
462                Method::GET,
463                Method::HEAD,
464                Method::POST,
465                Method::OPTIONS,
466                Method::PUT,
467                Method::PATCH,
468                Method::DELETE,
469            ])
470        } else {
471            self
472        };
473
474        if let Some(e) = slf.error.take() {
475            panic!("{}", e);
476        }
477
478        let mut cors = slf.cors.take().expect("cannot reuse CorsBuilder");
479
480        if cors.supports_credentials && cors.send_wildcard && cors.origins.is_all() {
481            panic!("Credentials are allowed, but the Origin is set to \"*\"");
482        }
483
484        if let AllOrSome::Some(ref origins) = cors.origins {
485            let s = origins.iter().fold(String::new(), |s, v| format!("{}, {}", s, v));
486            cors.origins_str = Some(HeaderValue::try_from(&s[2..]).unwrap());
487        }
488
489        if !slf.expose_hdrs.is_empty() {
490            cors.expose_hdrs = Some(
491                HeaderValue::try_from(
492                    &slf.expose_hdrs
493                        .iter()
494                        .fold(String::new(), |s, v| format!("{}, {}", s, v.as_str()))[2..],
495                )
496                .unwrap(),
497            );
498        }
499
500        CorsFactory { inner: Rc::new(cors), _t: PhantomData }
501    }
502}
503
504fn cors<'a>(parts: &'a mut Option<Inner>, err: &Option<HttpError>) -> Option<&'a mut Inner> {
505    if err.is_some() {
506        return None;
507    }
508    parts.as_mut()
509}
510
511struct Inner {
512    methods: HashSet<Method>,
513    origins: AllOrSome<HashSet<String>>,
514    origins_str: Option<HeaderValue>,
515    headers: AllOrSome<HashSet<HeaderName>>,
516    expose_hdrs: Option<HeaderValue>,
517    max_age: Option<usize>,
518    preflight: bool,
519    send_wildcard: bool,
520    supports_credentials: bool,
521    vary_header: bool,
522}
523
524impl Inner {
525    fn validate_origin(&self, req: &RequestHead) -> Result<(), CorsError> {
526        if let Some(hdr) = req.headers().get(&header::ORIGIN) {
527            if let Ok(origin) = hdr.to_str() {
528                return match self.origins {
529                    AllOrSome::All => Ok(()),
530                    AllOrSome::Some(ref allowed_origins) => allowed_origins
531                        .get(origin)
532                        .map(|_| ())
533                        .ok_or(CorsError::OriginNotAllowed),
534                };
535            }
536            Err(CorsError::BadOrigin)
537        } else {
538            match self.origins {
539                AllOrSome::All => Ok(()),
540                _ => Err(CorsError::MissingOrigin),
541            }
542        }
543    }
544
545    fn access_control_allow_origin(&self, headers: &HeaderMap) -> Option<HeaderValue> {
546        match self.origins {
547            AllOrSome::All => {
548                if self.send_wildcard {
549                    Some(HeaderValue::from_static("*"))
550                } else {
551                    headers.get(&header::ORIGIN).cloned()
552                }
553            }
554            AllOrSome::Some(ref origins) => {
555                if let Some(origin) =
556                    headers.get(&header::ORIGIN).filter(|o| match o.to_str() {
557                        Ok(os) => origins.contains(os),
558                        _ => false,
559                    })
560                {
561                    Some(origin.clone())
562                } else {
563                    Some(self.origins_str.as_ref().unwrap().clone())
564                }
565            }
566        }
567    }
568
569    fn validate_allowed_method(&self, req: &RequestHead) -> Result<(), CorsError> {
570        if let Some(hdr) = req.headers().get(&header::ACCESS_CONTROL_REQUEST_METHOD) {
571            if let Ok(meth) = hdr.to_str() {
572                if let Ok(method) = Method::try_from(meth) {
573                    return self
574                        .methods
575                        .get(&method)
576                        .map(|_| ())
577                        .ok_or(CorsError::MethodNotAllowed);
578                }
579            }
580            Err(CorsError::BadRequestMethod)
581        } else {
582            Err(CorsError::MissingRequestMethod)
583        }
584    }
585
586    fn validate_allowed_headers(&self, req: &RequestHead) -> Result<(), CorsError> {
587        match self.headers {
588            AllOrSome::All => Ok(()),
589            AllOrSome::Some(ref allowed_headers) => {
590                if let Some(hdr) = req.headers().get(&header::ACCESS_CONTROL_REQUEST_HEADERS) {
591                    if let Ok(headers) = hdr.to_str() {
592                        let mut hdrs = HashSet::new();
593                        for hdr in headers.split(',') {
594                            match HeaderName::try_from(hdr.trim()) {
595                                Ok(hdr) => hdrs.insert(hdr),
596                                Err(_) => return Err(CorsError::BadRequestHeaders),
597                            };
598                        }
599                        // `Access-Control-Request-Headers` must contain 1 or more
600                        // `field-name`.
601                        if !hdrs.is_empty() {
602                            if !hdrs.is_subset(allowed_headers) {
603                                return Err(CorsError::HeadersNotAllowed);
604                            }
605                            return Ok(());
606                        }
607                    }
608                    Err(CorsError::BadRequestHeaders)
609                } else {
610                    Ok(())
611                }
612            }
613        }
614    }
615
616    fn preflight_check(
617        &self,
618        req: &RequestHead,
619    ) -> Result<Either<HttpResponse, ()>, CorsError> {
620        if self.preflight && Method::OPTIONS == req.method {
621            self.validate_origin(req)
622                .and_then(|_| self.validate_allowed_method(req))
623                .and_then(|_| self.validate_allowed_headers(req))?;
624
625            // allowed headers
626            let headers = if let Some(headers) = self.headers.as_ref() {
627                Some(
628                    HeaderValue::try_from(
629                        &headers
630                            .iter()
631                            .fold(String::new(), |s, v| s + "," + v.as_str())
632                            .as_str()[1..],
633                    )
634                    .unwrap(),
635                )
636            } else {
637                req.headers.get(&header::ACCESS_CONTROL_REQUEST_HEADERS).cloned()
638            };
639
640            let res = HttpResponse::Ok()
641                .if_some(self.max_age.as_ref(), |max_age, resp| {
642                    let _ = resp.header(
643                        header::ACCESS_CONTROL_MAX_AGE,
644                        format!("{}", max_age).as_str(),
645                    );
646                })
647                .if_some(headers, |headers, resp| {
648                    let _ = resp.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers);
649                })
650                .if_some(self.access_control_allow_origin(req.headers()), |origin, resp| {
651                    let _ = resp.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin);
652                })
653                .if_true(self.supports_credentials, |resp| {
654                    resp.header(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
655                })
656                .header(
657                    header::ACCESS_CONTROL_ALLOW_METHODS,
658                    &self
659                        .methods
660                        .iter()
661                        .fold(String::new(), |s, v| s + "," + v.as_str())
662                        .as_str()[1..],
663                )
664                .finish()
665                .into_body();
666
667            Ok(Either::Left(res))
668        } else {
669            if req.headers.contains_key(&header::ORIGIN) {
670                // Only check requests with a origin header.
671                self.validate_origin(req)?;
672            }
673            Ok(Either::Right(()))
674        }
675    }
676
677    fn handle_response(&self, headers: &mut HeaderMap, allowed_origin: Option<HeaderValue>) {
678        if let Some(origin) = allowed_origin {
679            headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin);
680        };
681
682        if let Some(ref expose) = self.expose_hdrs {
683            headers.insert(header::ACCESS_CONTROL_EXPOSE_HEADERS, expose.clone());
684        }
685        if self.supports_credentials {
686            headers.insert(
687                header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
688                HeaderValue::from_static("true"),
689            );
690        }
691        if self.vary_header {
692            let value = if let Some(hdr) = headers.get(&header::VARY) {
693                let mut val: Vec<u8> = Vec::with_capacity(hdr.as_bytes().len() + 8);
694                val.extend(hdr.as_bytes());
695                val.extend(b", Origin");
696                HeaderValue::try_from(&val[..]).unwrap()
697            } else {
698                HeaderValue::from_static("Origin")
699            };
700            headers.insert(header::VARY, value);
701        }
702    }
703}
704
705/// `Middleware` for Cross-origin resource sharing support
706///
707/// The Cors struct contains the settings for CORS requests to be validated and
708/// for responses to be generated.
709pub struct CorsFactory<Err> {
710    inner: Rc<Inner>,
711    _t: PhantomData<Err>,
712}
713
714impl<S, Err> Middleware<S> for CorsFactory<Err>
715where
716    S: Service<WebRequest<Err>, Response = WebResponse>,
717{
718    type Service = CorsMiddleware<S>;
719
720    fn create(&self, service: S) -> Self::Service {
721        CorsMiddleware { service, inner: self.inner.clone() }
722    }
723}
724
725/// `Middleware` for Cross-origin resource sharing support
726///
727/// The Cors struct contains the settings for CORS requests to be validated and
728/// for responses to be generated.
729#[derive(Clone)]
730pub struct CorsMiddleware<S> {
731    service: S,
732    inner: Rc<Inner>,
733}
734
735impl<S, Err> Service<WebRequest<Err>> for CorsMiddleware<S>
736where
737    S: Service<WebRequest<Err>, Response = WebResponse>,
738    Err: ErrorRenderer,
739    Err::Container: From<S::Error>,
740    CorsError: WebResponseError<Err>,
741{
742    type Response = WebResponse;
743    type Error = S::Error;
744
745    ntex::forward_ready!(service);
746    ntex::forward_shutdown!(service);
747
748    async fn call(
749        &self,
750        req: WebRequest<Err>,
751        ctx: ServiceCtx<'_, Self>,
752    ) -> Result<Self::Response, S::Error> {
753        match self.inner.preflight_check(req.head()) {
754            Ok(Either::Left(res)) => Ok(req.into_response(res)),
755            Ok(Either::Right(_)) => {
756                let inner = self.inner.clone();
757                let has_origin = req.headers().contains_key(&header::ORIGIN);
758                let allowed_origin = inner.access_control_allow_origin(req.headers());
759
760                let mut res = ctx.call(&self.service, req).await?;
761
762                if has_origin {
763                    inner.handle_response(res.headers_mut(), allowed_origin);
764                }
765                Ok(res)
766            }
767            Err(e) => Ok(req.render_error(e)),
768        }
769    }
770}
771
772#[cfg(test)]
773mod tests {
774    use ntex::service::{fn_service, Middleware, Pipeline};
775    use ntex::web::{self, test, test::TestRequest};
776
777    use super::*;
778
779    #[ntex::test]
780    #[should_panic(expected = "Credentials are allowed, but the Origin is set to")]
781    async fn cors_validates_illegal_allow_credentials() {
782        let _cors =
783            Cors::new().supports_credentials().send_wildcard().finish::<web::DefaultError>();
784    }
785
786    #[ntex::test]
787    async fn validate_origin_allows_all_origins() {
788        let cors = Cors::new().finish().create(test::ok_service()).into();
789        let req =
790            TestRequest::with_header("Origin", "https://www.example.com").to_srv_request();
791
792        let resp = test::call_service(&cors, req).await;
793        assert_eq!(resp.status(), StatusCode::OK);
794    }
795
796    #[ntex::test]
797    async fn default() {
798        let cors = Cors::default().create(test::ok_service()).into();
799        let req =
800            TestRequest::with_header("Origin", "https://www.example.com").to_srv_request();
801
802        let resp = test::call_service(&cors, req).await;
803        assert_eq!(resp.status(), StatusCode::OK);
804    }
805
806    #[ntex::test]
807    async fn test_preflight() {
808        let cors: Pipeline<_> = Cors::new()
809            .send_wildcard()
810            .max_age(3600)
811            .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
812            .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT])
813            .allowed_header(header::CONTENT_TYPE)
814            .finish()
815            .create(test::ok_service())
816            .into();
817
818        let req = TestRequest::with_header("Origin", "https://www.example.com")
819            .method(Method::OPTIONS)
820            .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "X-Not-Allowed")
821            .to_srv_request();
822
823        assert!(cors.get_ref().inner.validate_allowed_method(req.head()).is_err());
824        assert!(cors.get_ref().inner.validate_allowed_headers(req.head()).is_err());
825        let resp = test::call_service(&cors, req).await;
826        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
827
828        let req = TestRequest::with_header("Origin", "https://www.example.com")
829            .header(header::ACCESS_CONTROL_REQUEST_METHOD, "put")
830            .method(Method::OPTIONS)
831            .to_srv_request();
832
833        assert!(cors.get_ref().inner.validate_allowed_method(req.head()).is_err());
834        assert!(cors.get_ref().inner.validate_allowed_headers(req.head()).is_ok());
835
836        let req = TestRequest::with_header("Origin", "https://www.example.com")
837            .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST")
838            .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "AUTHORIZATION,ACCEPT")
839            .method(Method::OPTIONS)
840            .to_srv_request();
841
842        let resp = test::call_service(&cors, req).await;
843        assert_eq!(
844            &b"*"[..],
845            resp.headers().get(&header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
846        );
847        assert_eq!(
848            &b"3600"[..],
849            resp.headers().get(&header::ACCESS_CONTROL_MAX_AGE).unwrap().as_bytes()
850        );
851        let hdr = resp
852            .headers()
853            .get(&header::ACCESS_CONTROL_ALLOW_HEADERS)
854            .unwrap()
855            .to_str()
856            .unwrap();
857        assert!(hdr.contains("authorization"));
858        assert!(hdr.contains("accept"));
859        assert!(hdr.contains("content-type"));
860
861        let methods =
862            resp.headers().get(header::ACCESS_CONTROL_ALLOW_METHODS).unwrap().to_str().unwrap();
863        assert!(methods.contains("POST"));
864        assert!(methods.contains("GET"));
865        assert!(methods.contains("OPTIONS"));
866
867        // Rc::get_mut(&mut cors.inner).unwrap().preflight = false;
868
869        // let req = TestRequest::with_header("Origin", "https://www.example.com")
870        //     .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST")
871        //     .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "AUTHORIZATION,ACCEPT")
872        //     .method(Method::OPTIONS)
873        //     .to_srv_request();
874
875        // let resp = test::call_service(&cors, req).await;
876        // assert_eq!(resp.status(), StatusCode::OK);
877    }
878
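    // NOTE(review): this test is disabled. It appears to target an older API
    // (`Cors::build()`, `cors.start(&req)`) that the other tests in this module
    // do not use; rewrite it against `Cors::new()` / `Pipeline` or remove it.
    // #[ntex::test]
    // #[should_panic(expected = "MissingOrigin")]
    // async fn test_validate_missing_origin() {
    //    let cors = Cors::build()
    //        .allowed_origin("https://www.example.com")
    //        .finish();
    //    let mut req = HttpRequest::default();
    //    cors.start(&req).unwrap();
    // }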
888
889    #[ntex::test]
890    #[should_panic(expected = "OriginNotAllowed")]
891    async fn test_validate_not_allowed_origin() {
892        let cors: Pipeline<_> = Cors::new()
893            .allowed_origin("https://www.example.com")
894            .finish()
895            .create(test::ok_service::<web::DefaultError>())
896            .into();
897
898        let req = TestRequest::with_header("Origin", "https://www.unknown.com")
899            .method(Method::GET)
900            .to_srv_request();
901        cors.get_ref().inner.validate_origin(req.head()).unwrap();
902        cors.get_ref().inner.validate_allowed_method(req.head()).unwrap();
903        cors.get_ref().inner.validate_allowed_headers(req.head()).unwrap();
904    }
905
906    #[ntex::test]
907    async fn test_validate_origin() {
908        let cors = Cors::new()
909            .allowed_origin("https://www.example.com")
910            .finish()
911            .create(test::ok_service())
912            .into();
913
914        let req = TestRequest::with_header("Origin", "https://www.example.com")
915            .method(Method::GET)
916            .to_srv_request();
917
918        let resp = test::call_service(&cors, req).await;
919        assert_eq!(resp.status(), StatusCode::OK);
920    }
921
922    #[ntex::test]
923    async fn test_no_origin_response() {
924        let cors = Cors::new().disable_preflight().finish().create(test::ok_service()).into();
925
926        let req = TestRequest::default().method(Method::GET).to_srv_request();
927        let resp = test::call_service(&cors, req).await;
928        assert!(resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).is_none());
929
930        let req = TestRequest::with_header("Origin", "https://www.example.com")
931            .method(Method::OPTIONS)
932            .to_srv_request();
933        let resp = test::call_service(&cors, req).await;
934        assert_eq!(
935            &b"https://www.example.com"[..],
936            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
937        );
938    }
939
940    #[ntex::test]
941    async fn test_response() {
942        let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT];
943        let cors = Cors::new()
944            .send_wildcard()
945            .disable_preflight()
946            .max_age(3600)
947            .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
948            .allowed_headers(exposed_headers.clone())
949            .expose_headers(exposed_headers.clone())
950            .allowed_header(header::CONTENT_TYPE)
951            .finish()
952            .create(test::ok_service())
953            .into();
954
955        let req = TestRequest::with_header("Origin", "https://www.example.com")
956            .method(Method::OPTIONS)
957            .to_srv_request();
958
959        let resp = test::call_service(&cors, req).await;
960        assert_eq!(
961            &b"*"[..],
962            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
963        );
964        assert_eq!(&b"Origin"[..], resp.headers().get(header::VARY).unwrap().as_bytes());
965
966        {
967            let headers = resp
968                .headers()
969                .get(header::ACCESS_CONTROL_EXPOSE_HEADERS)
970                .unwrap()
971                .to_str()
972                .unwrap()
973                .split(',')
974                .map(|s| s.trim())
975                .collect::<Vec<&str>>();
976
977            for h in exposed_headers {
978                assert!(headers.contains(&h.as_str()));
979            }
980        }
981
982        let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT];
983        let cors =
984            Cors::new()
985                .send_wildcard()
986                .disable_preflight()
987                .max_age(3600)
988                .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
989                .allowed_headers(exposed_headers.clone())
990                .expose_headers(exposed_headers.clone())
991                .allowed_header(header::CONTENT_TYPE)
992                .finish()
993                .create(fn_service(|req: WebRequest<DefaultError>| async move {
994                    Ok::<_, std::convert::Infallible>(req.into_response(
995                        HttpResponse::Ok().header(header::VARY, "Accept").finish(),
996                    ))
997                }))
998                .into();
999        let req = TestRequest::with_header("Origin", "https://www.example.com")
1000            .method(Method::OPTIONS)
1001            .to_srv_request();
1002        let resp = test::call_service(&cors, req).await;
1003        assert_eq!(
1004            &b"Accept, Origin"[..],
1005            resp.headers().get(header::VARY).unwrap().as_bytes()
1006        );
1007
1008        let cors = Cors::new()
1009            .disable_vary_header()
1010            .allowed_origin("https://www.example.com")
1011            .allowed_origin("https://www.google.com")
1012            .finish()
1013            .create(test::ok_service())
1014            .into();
1015
1016        let req = TestRequest::with_header("Origin", "https://www.example.com")
1017            .method(Method::OPTIONS)
1018            .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST")
1019            .to_srv_request();
1020        let resp = test::call_service(&cors, req).await;
1021
1022        let origins_str =
1023            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().to_str().unwrap();
1024
1025        assert_eq!("https://www.example.com", origins_str);
1026    }
1027
1028    #[ntex::test]
1029    async fn test_multiple_origins() {
1030        let cors = Cors::new()
1031            .allowed_origin("https://example.com")
1032            .allowed_origin("https://example.org")
1033            .allowed_methods(vec![Method::GET])
1034            .finish()
1035            .create(test::ok_service())
1036            .into();
1037
1038        let req = TestRequest::with_header("Origin", "https://example.com")
1039            .method(Method::GET)
1040            .to_srv_request();
1041
1042        let resp = test::call_service(&cors, req).await;
1043        assert_eq!(
1044            &b"https://example.com"[..],
1045            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
1046        );
1047
1048        let req = TestRequest::with_header("Origin", "https://example.org")
1049            .method(Method::GET)
1050            .to_srv_request();
1051
1052        let resp = test::call_service(&cors, req).await;
1053        assert_eq!(
1054            &b"https://example.org"[..],
1055            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
1056        );
1057    }
1058
1059    #[ntex::test]
1060    async fn test_multiple_origins_preflight() {
1061        let cors = Cors::new()
1062            .allowed_origin("https://example.com")
1063            .allowed_origin("https://example.org")
1064            .allowed_methods(vec![Method::GET])
1065            .finish()
1066            .create(test::ok_service())
1067            .into();
1068
1069        let req = TestRequest::with_header("Origin", "https://example.com")
1070            .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
1071            .method(Method::OPTIONS)
1072            .to_srv_request();
1073
1074        let resp = test::call_service(&cors, req).await;
1075        assert_eq!(
1076            &b"https://example.com"[..],
1077            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
1078        );
1079
1080        let req = TestRequest::with_header("Origin", "https://example.org")
1081            .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
1082            .method(Method::OPTIONS)
1083            .to_srv_request();
1084
1085        let resp = test::call_service(&cors, req).await;
1086        assert_eq!(
1087            &b"https://example.org"[..],
1088            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
1089        );
1090    }
1091}