clia_ntex_cors_mod/
lib.rs

1#![allow(
2    clippy::borrow_interior_mutable_const,
3    clippy::type_complexity,
4    clippy::mutable_key_type
5)]
6//! Cross-origin resource sharing (CORS) for ntex applications
7//!
//! The CORS middleware can be applied to a whole application, a scope or a
//! single resource by passing it to `App::wrap()`, `Scope::wrap()` or
//! `Resource::wrap()`.
11//!
12//! # Example
13//!
14//! ```rust,no_run
15//! use ntex_cors::Cors;
16//! use ntex::{http, web};
17//! use ntex::web::{App, HttpRequest, HttpResponse};
18//!
19//! async fn index(req: HttpRequest) -> &'static str {
20//!     "Hello world"
21//! }
22//!
23//! #[ntex::main]
24//! async fn main() -> std::io::Result<()> {
25//!     web::server(|| App::new()
26//!         .wrap(
27//!             Cors::new() // <- Construct CORS middleware builder
28//!               .allowed_origin("https://www.rust-lang.org/")
29//!               .allowed_methods(vec!["GET", "POST"])
30//!               .allowed_headers(vec![http::header::AUTHORIZATION, http::header::ACCEPT])
31//!               .allowed_header(http::header::CONTENT_TYPE)
32//!               .max_age(3600)
33//!               .finish())
34//!         .service(
35//!             web::resource("/index.html")
36//!               .route(web::get().to(index))
37//!               .route(web::head().to(|| async { HttpResponse::MethodNotAllowed() }))
38//!         ))
39//!         .bind("127.0.0.1:8080")?
40//!         .run()
41//!         .await
42//! }
43//! ```
//! In this example a custom *CORS* middleware is registered on the
//! application, so it also covers the "/index.html" endpoint.
//!
//! The CORS middleware automatically handles *OPTIONS* preflight requests.
48use std::task::{Context, Poll};
49use std::{
50    collections::HashSet, convert::TryFrom, iter::FromIterator, marker::PhantomData, rc::Rc,
51};
52
53use derive_more::Display;
54use futures::future::{ok, Either, FutureExt, LocalBoxFuture, Ready};
55use ntex::http::header::{self, HeaderName, HeaderValue};
56use ntex::http::{error::HttpError, HeaderMap, Method, RequestHead, StatusCode, Uri};
57use ntex::service::{Middleware, Service, ServiceCtx};
58use ntex::web::{
59    DefaultError, ErrorRenderer, HttpResponse, WebRequest, WebResponse, WebResponseError,
60};
61
62/// A set of errors that can occur during processing CORS
63#[derive(Debug, Display)]
64pub enum CorsError {
65    /// The HTTP request header `Origin` is required but was not provided
66    #[display(fmt = "The HTTP request header `Origin` is required but was not provided")]
67    MissingOrigin,
68    /// The HTTP request header `Origin` could not be parsed correctly.
69    #[display(fmt = "The HTTP request header `Origin` could not be parsed correctly.")]
70    BadOrigin,
71    /// The request header `Access-Control-Request-Method` is required but is
72    /// missing
73    #[display(
74        fmt = "The request header `Access-Control-Request-Method` is required but is missing"
75    )]
76    MissingRequestMethod,
77    /// The request header `Access-Control-Request-Method` has an invalid value
78    #[display(fmt = "The request header `Access-Control-Request-Method` has an invalid value")]
79    BadRequestMethod,
80    /// The request header `Access-Control-Request-Headers`  has an invalid
81    /// value
82    #[display(
83        fmt = "The request header `Access-Control-Request-Headers`  has an invalid value"
84    )]
85    BadRequestHeaders,
86    /// Origin is not allowed to make this request
87    #[display(fmt = "Origin is not allowed to make this request")]
88    OriginNotAllowed,
89    /// Requested method is not allowed
90    #[display(fmt = "Requested method is not allowed")]
91    MethodNotAllowed,
92    /// One or more headers requested are not allowed
93    #[display(fmt = "One or more headers requested are not allowed")]
94    HeadersNotAllowed,
95}
96
97/// DefaultError renderer support
98impl WebResponseError<DefaultError> for CorsError {
99    fn status_code(&self) -> StatusCode {
100        StatusCode::BAD_REQUEST
101    }
102}
103
/// Either "everything is allowed" (`All`) or an explicit allow-list of type
/// `T` (`Some`).
///
/// `Default` is implemented for this enum and yields `All`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AllOrSome<T> {
    /// Everything is allowed. Usually equivalent to the "*" value.
    All,
    /// Only the contained `T` is allowed.
    Some(T),
}

impl<T> Default for AllOrSome<T> {
    /// The default is the permissive `All` variant.
    fn default() -> Self {
        Self::All
    }
}

impl<T> AllOrSome<T> {
    /// Returns `true` for the `All` variant.
    pub fn is_all(&self) -> bool {
        matches!(self, Self::All)
    }

    /// Returns `true` for the `Some` variant.
    pub fn is_some(&self) -> bool {
        matches!(self, Self::Some(_))
    }

    /// Borrows the allow-list, or returns `None` when everything is allowed.
    pub fn as_ref(&self) -> Option<&T> {
        if let Self::Some(inner) = self {
            Some(inner)
        } else {
            None
        }
    }
}
144
145/// Structure that follows the builder pattern for building `Cors` middleware
146/// structs.
147///
148/// To construct a cors:
149///
150///   1. Call [`Cors::build`](struct.Cors.html#method.build) to start building.
151///   2. Use any of the builder methods to set fields in the backend.
152/// 3. Call [finish](struct.Cors.html#method.finish) to retrieve the
153/// constructed backend.
154///
155/// # Example
156///
157/// ```rust
158/// use ntex_cors::Cors;
159/// use ntex::http::header;
160///
161/// let cors = Cors::new()
162///     .allowed_origin("https://www.rust-lang.org/")
163///     .allowed_methods(vec!["GET", "POST"])
164///     .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT])
165///     .allowed_header(header::CONTENT_TYPE)
166///     .max_age(3600);
167/// ```
168#[derive(Default)]
169pub struct Cors {
170    cors: Option<Inner>,
171    methods: bool,
172    expose_hdrs: HashSet<HeaderName>,
173    error: Option<HttpError>,
174}
175
176impl Cors {
177    /// Build a new CORS middleware instance
178    pub fn new() -> Self {
179        Cors {
180            cors: Some(Inner {
181                origins: AllOrSome::All,
182                origins_str: None,
183                methods: HashSet::new(),
184                headers: AllOrSome::All,
185                expose_hdrs: None,
186                max_age: None,
187                preflight: true,
188                send_wildcard: false,
189                supports_credentials: false,
190                vary_header: true,
191            }),
192            methods: false,
193            error: None,
194            expose_hdrs: HashSet::new(),
195        }
196    }
197
198    /// Build a new CORS default middleware
199    pub fn default<Err>() -> CorsFactory<Err> {
200        let inner = Inner {
201            origins: AllOrSome::default(),
202            origins_str: None,
203            methods: HashSet::from_iter(
204                vec![
205                    Method::GET,
206                    Method::HEAD,
207                    Method::POST,
208                    Method::OPTIONS,
209                    Method::PUT,
210                    Method::PATCH,
211                    Method::DELETE,
212                ]
213                .into_iter(),
214            ),
215            headers: AllOrSome::All,
216            expose_hdrs: None,
217            max_age: None,
218            preflight: true,
219            send_wildcard: false,
220            supports_credentials: false,
221            vary_header: true,
222        };
223        CorsFactory { inner: Rc::new(inner), _t: PhantomData }
224    }
225
226    /// Add an origin that are allowed to make requests.
227    /// Will be verified against the `Origin` request header.
228    ///
229    /// When `All` is set, and `send_wildcard` is set, "*" will be sent in
230    /// the `Access-Control-Allow-Origin` response header. Otherwise, the
231    /// client's `Origin` request header will be echoed back in the
232    /// `Access-Control-Allow-Origin` response header.
233    ///
234    /// When `Some` is set, the client's `Origin` request header will be
235    /// checked in a case-sensitive manner.
236    ///
237    /// This is the `list of origins` in the
238    /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
239    ///
240    /// Defaults to `All`.
241    ///
242    /// Builder panics if supplied origin is not valid uri.
243    pub fn allowed_origin(mut self, origin: &str) -> Self {
244        if let Some(cors) = cors(&mut self.cors, &self.error) {
245            match Uri::try_from(origin) {
246                Ok(_) => {
247                    if cors.origins.is_all() {
248                        cors.origins = AllOrSome::Some(HashSet::new());
249                    }
250                    if let AllOrSome::Some(ref mut origins) = cors.origins {
251                        origins.insert(origin.to_owned());
252                    }
253                }
254                Err(e) => {
255                    self.error = Some(e.into());
256                }
257            }
258        }
259        self
260    }
261
262    /// Set a list of methods which the allowed origins are allowed to access
263    /// for requests.
264    ///
265    /// This is the `list of methods` in the
266    /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
267    ///
268    /// Defaults to `[GET, HEAD, POST, OPTIONS, PUT, PATCH, DELETE]`
269    pub fn allowed_methods<U, M>(mut self, methods: U) -> Self
270    where
271        U: IntoIterator<Item = M>,
272        Method: TryFrom<M>,
273        <Method as TryFrom<M>>::Error: Into<HttpError>,
274    {
275        self.methods = true;
276        if let Some(cors) = cors(&mut self.cors, &self.error) {
277            for m in methods {
278                match Method::try_from(m) {
279                    Ok(method) => {
280                        cors.methods.insert(method);
281                    }
282                    Err(e) => {
283                        self.error = Some(e.into());
284                        break;
285                    }
286                }
287            }
288        }
289        self
290    }
291
292    /// Set an allowed header
293    pub fn allowed_header<H>(mut self, header: H) -> Self
294    where
295        HeaderName: TryFrom<H>,
296        <HeaderName as TryFrom<H>>::Error: Into<HttpError>,
297    {
298        if let Some(cors) = cors(&mut self.cors, &self.error) {
299            match HeaderName::try_from(header) {
300                Ok(method) => {
301                    if cors.headers.is_all() {
302                        cors.headers = AllOrSome::Some(HashSet::new());
303                    }
304                    if let AllOrSome::Some(ref mut headers) = cors.headers {
305                        headers.insert(method);
306                    }
307                }
308                Err(e) => self.error = Some(e.into()),
309            }
310        }
311        self
312    }
313
314    /// Set a list of header field names which can be used when
315    /// this resource is accessed by allowed origins.
316    ///
317    /// If `All` is set, whatever is requested by the client in
318    /// `Access-Control-Request-Headers` will be echoed back in the
319    /// `Access-Control-Allow-Headers` header.
320    ///
321    /// This is the `list of headers` in the
322    /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
323    ///
324    /// Defaults to `All`.
325    pub fn allowed_headers<U, H>(mut self, headers: U) -> Self
326    where
327        U: IntoIterator<Item = H>,
328        HeaderName: TryFrom<H>,
329        <HeaderName as TryFrom<H>>::Error: Into<HttpError>,
330    {
331        if let Some(cors) = cors(&mut self.cors, &self.error) {
332            for h in headers {
333                match HeaderName::try_from(h) {
334                    Ok(method) => {
335                        if cors.headers.is_all() {
336                            cors.headers = AllOrSome::Some(HashSet::new());
337                        }
338                        if let AllOrSome::Some(ref mut headers) = cors.headers {
339                            headers.insert(method);
340                        }
341                    }
342                    Err(e) => {
343                        self.error = Some(e.into());
344                        break;
345                    }
346                }
347            }
348        }
349        self
350    }
351
352    /// Set a list of headers which are safe to expose to the API of a CORS API
353    /// specification. This corresponds to the
354    /// `Access-Control-Expose-Headers` response header.
355    ///
356    /// This is the `list of exposed headers` in the
357    /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
358    ///
359    /// This defaults to an empty set.
360    pub fn expose_headers<U, H>(mut self, headers: U) -> Self
361    where
362        U: IntoIterator<Item = H>,
363        HeaderName: TryFrom<H>,
364        <HeaderName as TryFrom<H>>::Error: Into<HttpError>,
365    {
366        for h in headers {
367            match HeaderName::try_from(h) {
368                Ok(method) => {
369                    self.expose_hdrs.insert(method);
370                }
371                Err(e) => {
372                    self.error = Some(e.into());
373                    break;
374                }
375            }
376        }
377        self
378    }
379
380    /// Set a maximum time for which this CORS request maybe cached.
381    /// This value is set as the `Access-Control-Max-Age` header.
382    ///
383    /// This defaults to `None` (unset).
384    pub fn max_age(mut self, max_age: usize) -> Self {
385        if let Some(cors) = cors(&mut self.cors, &self.error) {
386            cors.max_age = Some(max_age)
387        }
388        self
389    }
390
391    /// Set a wildcard origins
392    ///
393    /// If send wildcard is set and the `allowed_origins` parameter is `All`, a
394    /// wildcard `Access-Control-Allow-Origin` response header is sent,
395    /// rather than the request’s `Origin` header.
396    ///
397    /// This is the `supports credentials flag` in the
398    /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
399    ///
400    /// This **CANNOT** be used in conjunction with `allowed_origins` set to
401    /// `All` and `allow_credentials` set to `true`. Depending on the mode
402    /// of usage, this will either result in an `Error::
403    /// CredentialsWithWildcardOrigin` error during ntex launch or runtime.
404    ///
405    /// Defaults to `false`.
406    pub fn send_wildcard(mut self) -> Self {
407        if let Some(cors) = cors(&mut self.cors, &self.error) {
408            cors.send_wildcard = true
409        }
410        self
411    }
412
413    /// Allows users to make authenticated requests
414    ///
415    /// If true, injects the `Access-Control-Allow-Credentials` header in
416    /// responses. This allows cookies and credentials to be submitted
417    /// across domains.
418    ///
419    /// This option cannot be used in conjunction with an `allowed_origin` set
420    /// to `All` and `send_wildcards` set to `true`.
421    ///
422    /// Defaults to `false`.
423    ///
424    /// Builder panics if credentials are allowed, but the Origin is set to "*".
425    /// This is not allowed by W3C
426    pub fn supports_credentials(mut self) -> Self {
427        if let Some(cors) = cors(&mut self.cors, &self.error) {
428            cors.supports_credentials = true
429        }
430        self
431    }
432
433    /// Disable `Vary` header support.
434    ///
435    /// When enabled the header `Vary: Origin` will be returned as per the W3
436    /// implementation guidelines.
437    ///
438    /// Setting this header when the `Access-Control-Allow-Origin` is
439    /// dynamically generated (e.g. when there is more than one allowed
440    /// origin, and an Origin than '*' is returned) informs CDNs and other
441    /// caches that the CORS headers are dynamic, and cannot be cached.
442    ///
443    /// By default `vary` header support is enabled.
444    pub fn disable_vary_header(mut self) -> Self {
445        if let Some(cors) = cors(&mut self.cors, &self.error) {
446            cors.vary_header = false
447        }
448        self
449    }
450
451    /// Disable *preflight* request support.
452    ///
453    /// When enabled cors middleware automatically handles *OPTIONS* request.
454    /// This is useful application level middleware.
455    ///
456    /// By default *preflight* support is enabled.
457    pub fn disable_preflight(mut self) -> Self {
458        if let Some(cors) = cors(&mut self.cors, &self.error) {
459            cors.preflight = false
460        }
461        self
462    }
463
464    /// Construct cors middleware
465    pub fn finish<Err>(self) -> CorsFactory<Err> {
466        let mut slf = if !self.methods {
467            self.allowed_methods(vec![
468                Method::GET,
469                Method::HEAD,
470                Method::POST,
471                Method::OPTIONS,
472                Method::PUT,
473                Method::PATCH,
474                Method::DELETE,
475            ])
476        } else {
477            self
478        };
479
480        if let Some(e) = slf.error.take() {
481            panic!("{}", e);
482        }
483
484        let mut cors = slf.cors.take().expect("cannot reuse CorsBuilder");
485
486        if cors.supports_credentials && cors.send_wildcard && cors.origins.is_all() {
487            panic!("Credentials are allowed, but the Origin is set to \"*\"");
488        }
489
490        if let AllOrSome::Some(ref origins) = cors.origins {
491            let s = origins.iter().fold(String::new(), |s, v| format!("{}, {}", s, v));
492            cors.origins_str = Some(HeaderValue::try_from(&s[2..]).unwrap());
493        }
494
495        if !slf.expose_hdrs.is_empty() {
496            cors.expose_hdrs = Some(
497                HeaderValue::try_from(
498                    &slf.expose_hdrs
499                        .iter()
500                        .fold(String::new(), |s, v| format!("{}, {}", s, v.as_str()))[2..],
501                )
502                .unwrap(),
503            );
504        }
505
506        CorsFactory { inner: Rc::new(cors), _t: PhantomData }
507    }
508}
509
510fn cors<'a>(parts: &'a mut Option<Inner>, err: &Option<HttpError>) -> Option<&'a mut Inner> {
511    if err.is_some() {
512        return None;
513    }
514    parts.as_mut()
515}
516
517struct Inner {
518    methods: HashSet<Method>,
519    origins: AllOrSome<HashSet<String>>,
520    origins_str: Option<HeaderValue>,
521    headers: AllOrSome<HashSet<HeaderName>>,
522    expose_hdrs: Option<HeaderValue>,
523    max_age: Option<usize>,
524    preflight: bool,
525    send_wildcard: bool,
526    supports_credentials: bool,
527    vary_header: bool,
528}
529
530impl Inner {
531    fn validate_origin(&self, req: &RequestHead) -> Result<(), CorsError> {
532        if let Some(hdr) = req.headers().get(&header::ORIGIN) {
533            if let Ok(origin) = hdr.to_str() {
534                return match self.origins {
535                    AllOrSome::All => Ok(()),
536                    AllOrSome::Some(ref allowed_origins) => allowed_origins
537                        .get(origin)
538                        .map(|_| ())
539                        .ok_or(CorsError::OriginNotAllowed),
540                };
541            }
542            Err(CorsError::BadOrigin)
543        } else {
544            match self.origins {
545                AllOrSome::All => Ok(()),
546                _ => Err(CorsError::MissingOrigin),
547            }
548        }
549    }
550
551    fn access_control_allow_origin(&self, headers: &HeaderMap) -> Option<HeaderValue> {
552        match self.origins {
553            AllOrSome::All => {
554                if self.send_wildcard {
555                    Some(HeaderValue::from_static("*"))
556                } else {
557                    headers.get(&header::ORIGIN).cloned()
558                }
559            }
560            AllOrSome::Some(ref origins) => {
561                if let Some(origin) =
562                    headers.get(&header::ORIGIN).filter(|o| match o.to_str() {
563                        Ok(os) => origins.contains(os),
564                        _ => false,
565                    })
566                {
567                    Some(origin.clone())
568                } else {
569                    Some(self.origins_str.as_ref().unwrap().clone())
570                }
571            }
572        }
573    }
574
575    fn validate_allowed_method(&self, req: &RequestHead) -> Result<(), CorsError> {
576        if let Some(hdr) = req.headers().get(&header::ACCESS_CONTROL_REQUEST_METHOD) {
577            if let Ok(meth) = hdr.to_str() {
578                if let Ok(method) = Method::try_from(meth) {
579                    return self
580                        .methods
581                        .get(&method)
582                        .map(|_| ())
583                        .ok_or(CorsError::MethodNotAllowed);
584                }
585            }
586            Err(CorsError::BadRequestMethod)
587        } else {
588            Err(CorsError::MissingRequestMethod)
589        }
590    }
591
592    fn validate_allowed_headers(&self, req: &RequestHead) -> Result<(), CorsError> {
593        match self.headers {
594            AllOrSome::All => Ok(()),
595            AllOrSome::Some(ref allowed_headers) => {
596                if let Some(hdr) = req.headers().get(&header::ACCESS_CONTROL_REQUEST_HEADERS) {
597                    if let Ok(headers) = hdr.to_str() {
598                        let mut hdrs = HashSet::new();
599                        for hdr in headers.split(',') {
600                            match HeaderName::try_from(hdr.trim()) {
601                                Ok(hdr) => hdrs.insert(hdr),
602                                Err(_) => return Err(CorsError::BadRequestHeaders),
603                            };
604                        }
605                        // `Access-Control-Request-Headers` must contain 1 or more
606                        // `field-name`.
607                        if !hdrs.is_empty() {
608                            if !hdrs.is_subset(allowed_headers) {
609                                return Err(CorsError::HeadersNotAllowed);
610                            }
611                            return Ok(());
612                        }
613                    }
614                    Err(CorsError::BadRequestHeaders)
615                } else {
616                    Ok(())
617                }
618            }
619        }
620    }
621
622    fn preflight_check(
623        &self,
624        req: &RequestHead,
625    ) -> Result<Either<HttpResponse, ()>, CorsError> {
626        if self.preflight && Method::OPTIONS == req.method {
627            self.validate_origin(req)
628                .and_then(|_| self.validate_allowed_method(req))
629                .and_then(|_| self.validate_allowed_headers(req))?;
630
631            // allowed headers
632            let headers = if let Some(headers) = self.headers.as_ref() {
633                Some(
634                    HeaderValue::try_from(
635                        &headers
636                            .iter()
637                            .fold(String::new(), |s, v| s + "," + v.as_str())
638                            .as_str()[1..],
639                    )
640                    .unwrap(),
641                )
642            } else {
643                req.headers.get(&header::ACCESS_CONTROL_REQUEST_HEADERS).cloned()
644            };
645
646            let res = HttpResponse::Ok()
647                .if_some(self.max_age.as_ref(), |max_age, resp| {
648                    let _ = resp.header(
649                        header::ACCESS_CONTROL_MAX_AGE,
650                        format!("{}", max_age).as_str(),
651                    );
652                })
653                .if_some(headers, |headers, resp| {
654                    let _ = resp.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers);
655                })
656                .if_some(self.access_control_allow_origin(req.headers()), |origin, resp| {
657                    let _ = resp.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin);
658                })
659                .if_true(self.supports_credentials, |resp| {
660                    resp.header(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
661                })
662                .header(
663                    header::ACCESS_CONTROL_ALLOW_METHODS,
664                    &self
665                        .methods
666                        .iter()
667                        .fold(String::new(), |s, v| s + "," + v.as_str())
668                        .as_str()[1..],
669                )
670                .finish()
671                .into_body();
672
673            Ok(Either::Left(res))
674        } else {
675            if req.headers.contains_key(&header::ORIGIN) {
676                // Only check requests with a origin header.
677                self.validate_origin(req)?;
678            }
679            Ok(Either::Right(()))
680        }
681    }
682
683    fn handle_response(&self, headers: &mut HeaderMap, allowed_origin: Option<HeaderValue>) {
684        if let Some(origin) = allowed_origin {
685            headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin);
686        };
687
688        if let Some(ref expose) = self.expose_hdrs {
689            headers.insert(header::ACCESS_CONTROL_EXPOSE_HEADERS, expose.clone());
690        }
691        if self.supports_credentials {
692            headers.insert(
693                header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
694                HeaderValue::from_static("true"),
695            );
696        }
697        if self.vary_header {
698            let value = if let Some(hdr) = headers.get(&header::VARY) {
699                let mut val: Vec<u8> = Vec::with_capacity(hdr.as_bytes().len() + 8);
700                val.extend(hdr.as_bytes());
701                val.extend(b", Origin");
702                HeaderValue::try_from(&val[..]).unwrap()
703            } else {
704                HeaderValue::from_static("Origin")
705            };
706            headers.insert(header::VARY, value);
707        }
708    }
709}
710
711/// `Middleware` for Cross-origin resource sharing support
712///
713/// The Cors struct contains the settings for CORS requests to be validated and
714/// for responses to be generated.
715pub struct CorsFactory<Err> {
716    inner: Rc<Inner>,
717    _t: PhantomData<Err>,
718}
719
720impl<S, Err> Middleware<S> for CorsFactory<Err>
721where
722    S: Service<WebRequest<Err>, Response = WebResponse>,
723{
724    type Service = CorsMiddleware<S>;
725
726    fn create(&self, service: S) -> Self::Service {
727        CorsMiddleware { service, inner: self.inner.clone() }
728    }
729}
730
731/// `Middleware` for Cross-origin resource sharing support
732///
733/// The Cors struct contains the settings for CORS requests to be validated and
734/// for responses to be generated.
735#[derive(Clone)]
736pub struct CorsMiddleware<S> {
737    service: S,
738    inner: Rc<Inner>,
739}
740
741impl<S, Err> Service<WebRequest<Err>> for CorsMiddleware<S>
742where
743    S: Service<WebRequest<Err>, Response = WebResponse>,
744    Err: ErrorRenderer,
745    Err::Container: From<S::Error>,
746    CorsError: WebResponseError<Err>,
747{
748    type Response = WebResponse;
749    type Error = S::Error;
750    type Future<'f> = Either<
751        Ready<Result<Self::Response, S::Error>>,
752        LocalBoxFuture<'f, Result<Self::Response, S::Error>>,
753    > where Self: 'f;
754
755    fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
756        self.service.poll_ready(cx)
757    }
758
759    fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<()> {
760        self.service.poll_shutdown(cx)
761    }
762
763    fn call<'a>(&'a self, req: WebRequest<Err>, ctx: ServiceCtx<'a, Self>) -> Self::Future<'a> {
764        match self.inner.preflight_check(req.head()) {
765            Ok(Either::Left(res)) => Either::Left(ok(req.into_response(res))),
766            Ok(Either::Right(_)) => {
767                let inner = self.inner.clone();
768                // let has_origin = req.headers().contains_key(&header::ORIGIN);
769                let allowed_origin = inner.access_control_allow_origin(req.headers());
770
771                Either::Right(
772                    async move {
773                        let mut res = ctx.call(&self.service, req).await?;
774
775                        // if has_origin {
776                        inner.handle_response(res.headers_mut(), allowed_origin);
777                        // }
778                        Ok(res)
779                    }
780                    .boxed_local(),
781                )
782            }
783            Err(e) => Either::Left(ok(req.render_error(e))),
784        }
785    }
786}
787
788#[cfg(test)]
789mod tests {
790    use ntex::service::{fn_service, Middleware, Pipeline};
791    use ntex::web::{self, test, test::TestRequest};
792
793    use super::*;
794
795    #[ntex::test]
796    #[should_panic(expected = "Credentials are allowed, but the Origin is set to")]
797    async fn cors_validates_illegal_allow_credentials() {
798        let _cors =
799            Cors::new().supports_credentials().send_wildcard().finish::<web::DefaultError>();
800    }
801
802    #[ntex::test]
803    async fn validate_origin_allows_all_origins() {
804        let cors = Cors::new().finish().create(test::ok_service()).into();
805        let req =
806            TestRequest::with_header("Origin", "https://www.example.com").to_srv_request();
807
808        let resp = test::call_service(&cors, req).await;
809        assert_eq!(resp.status(), StatusCode::OK);
810    }
811
812    #[ntex::test]
813    async fn default() {
814        let cors = Cors::default().create(test::ok_service()).into();
815        let req =
816            TestRequest::with_header("Origin", "https://www.example.com").to_srv_request();
817
818        let resp = test::call_service(&cors, req).await;
819        assert_eq!(resp.status(), StatusCode::OK);
820    }
821
822    #[ntex::test]
823    async fn test_preflight() {
824        let mut cors: Pipeline<_> = Cors::new()
825            .send_wildcard()
826            .max_age(3600)
827            .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
828            .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT])
829            .allowed_header(header::CONTENT_TYPE)
830            .finish()
831            .create(test::ok_service())
832            .into();
833
834        let req = TestRequest::with_header("Origin", "https://www.example.com")
835            .method(Method::OPTIONS)
836            .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "X-Not-Allowed")
837            .to_srv_request();
838
839        assert!(cors.get_ref().inner.validate_allowed_method(req.head()).is_err());
840        assert!(cors.get_ref().inner.validate_allowed_headers(req.head()).is_err());
841        let resp = test::call_service(&cors, req).await;
842        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
843
844        let req = TestRequest::with_header("Origin", "https://www.example.com")
845            .header(header::ACCESS_CONTROL_REQUEST_METHOD, "put")
846            .method(Method::OPTIONS)
847            .to_srv_request();
848
849        assert!(cors.get_ref().inner.validate_allowed_method(req.head()).is_err());
850        assert!(cors.get_ref().inner.validate_allowed_headers(req.head()).is_ok());
851
852        let req = TestRequest::with_header("Origin", "https://www.example.com")
853            .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST")
854            .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "AUTHORIZATION,ACCEPT")
855            .method(Method::OPTIONS)
856            .to_srv_request();
857
858        let resp = test::call_service(&cors, req).await;
859        assert_eq!(
860            &b"*"[..],
861            resp.headers().get(&header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
862        );
863        assert_eq!(
864            &b"3600"[..],
865            resp.headers().get(&header::ACCESS_CONTROL_MAX_AGE).unwrap().as_bytes()
866        );
867        let hdr = resp
868            .headers()
869            .get(&header::ACCESS_CONTROL_ALLOW_HEADERS)
870            .unwrap()
871            .to_str()
872            .unwrap();
873        assert!(hdr.contains("authorization"));
874        assert!(hdr.contains("accept"));
875        assert!(hdr.contains("content-type"));
876
877        let methods =
878            resp.headers().get(header::ACCESS_CONTROL_ALLOW_METHODS).unwrap().to_str().unwrap();
879        assert!(methods.contains("POST"));
880        assert!(methods.contains("GET"));
881        assert!(methods.contains("OPTIONS"));
882
883        // Rc::get_mut(&mut cors.inner).unwrap().preflight = false;
884
885        // let req = TestRequest::with_header("Origin", "https://www.example.com")
886        //     .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST")
887        //     .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "AUTHORIZATION,ACCEPT")
888        //     .method(Method::OPTIONS)
889        //     .to_srv_request();
890
891        // let resp = test::call_service(&cors, req).await;
892        // assert_eq!(resp.status(), StatusCode::OK);
893    }
894
895    // #[ntex::test]
896    // #[should_panic(expected = "MissingOrigin")]
897    // async fn test_validate_missing_origin() {
898    //    let cors = Cors::build()
899    //        .allowed_origin("https://www.example.com")
900    //        .finish();
901    //    let mut req = HttpRequest::default();
902    //    cors.start(&req).unwrap();
903    // }
904
905    #[ntex::test]
906    #[should_panic(expected = "OriginNotAllowed")]
907    async fn test_validate_not_allowed_origin() {
908        let cors: Pipeline<_> = Cors::new()
909            .allowed_origin("https://www.example.com")
910            .finish()
911            .create(test::ok_service::<web::DefaultError>())
912            .into();
913
914        let req = TestRequest::with_header("Origin", "https://www.unknown.com")
915            .method(Method::GET)
916            .to_srv_request();
917        cors.get_ref().inner.validate_origin(req.head()).unwrap();
918        cors.get_ref().inner.validate_allowed_method(req.head()).unwrap();
919        cors.get_ref().inner.validate_allowed_headers(req.head()).unwrap();
920    }
921
922    #[ntex::test]
923    async fn test_validate_origin() {
924        let cors = Cors::new()
925            .allowed_origin("https://www.example.com")
926            .finish()
927            .create(test::ok_service())
928            .into();
929
930        let req = TestRequest::with_header("Origin", "https://www.example.com")
931            .method(Method::GET)
932            .to_srv_request();
933
934        let resp = test::call_service(&cors, req).await;
935        assert_eq!(resp.status(), StatusCode::OK);
936    }
937
938    #[ntex::test]
939    async fn test_no_origin_response() {
940        let cors = Cors::new().disable_preflight().finish().create(test::ok_service()).into();
941
942        let req = TestRequest::default().method(Method::GET).to_srv_request();
943        let resp = test::call_service(&cors, req).await;
944        assert!(resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).is_none());
945
946        let req = TestRequest::with_header("Origin", "https://www.example.com")
947            .method(Method::OPTIONS)
948            .to_srv_request();
949        let resp = test::call_service(&cors, req).await;
950        assert_eq!(
951            &b"https://www.example.com"[..],
952            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
953        );
954    }
955
956    #[ntex::test]
957    async fn test_response() {
958        let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT];
959        let cors = Cors::new()
960            .send_wildcard()
961            .disable_preflight()
962            .max_age(3600)
963            .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
964            .allowed_headers(exposed_headers.clone())
965            .expose_headers(exposed_headers.clone())
966            .allowed_header(header::CONTENT_TYPE)
967            .finish()
968            .create(test::ok_service())
969            .into();
970
971        let req = TestRequest::with_header("Origin", "https://www.example.com")
972            .method(Method::OPTIONS)
973            .to_srv_request();
974
975        let resp = test::call_service(&cors, req).await;
976        assert_eq!(
977            &b"*"[..],
978            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
979        );
980        assert_eq!(&b"Origin"[..], resp.headers().get(header::VARY).unwrap().as_bytes());
981
982        {
983            let headers = resp
984                .headers()
985                .get(header::ACCESS_CONTROL_EXPOSE_HEADERS)
986                .unwrap()
987                .to_str()
988                .unwrap()
989                .split(',')
990                .map(|s| s.trim())
991                .collect::<Vec<&str>>();
992
993            for h in exposed_headers {
994                assert!(headers.contains(&h.as_str()));
995            }
996        }
997
998        let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT];
999        let cors =
1000            Cors::new()
1001                .send_wildcard()
1002                .disable_preflight()
1003                .max_age(3600)
1004                .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
1005                .allowed_headers(exposed_headers.clone())
1006                .expose_headers(exposed_headers.clone())
1007                .allowed_header(header::CONTENT_TYPE)
1008                .finish()
1009                .create(fn_service(|req: WebRequest<DefaultError>| {
1010                    ok::<_, std::convert::Infallible>(req.into_response(
1011                        HttpResponse::Ok().header(header::VARY, "Accept").finish(),
1012                    ))
1013                }))
1014                .into();
1015        let req = TestRequest::with_header("Origin", "https://www.example.com")
1016            .method(Method::OPTIONS)
1017            .to_srv_request();
1018        let resp = test::call_service(&cors, req).await;
1019        assert_eq!(
1020            &b"Accept, Origin"[..],
1021            resp.headers().get(header::VARY).unwrap().as_bytes()
1022        );
1023
1024        let cors = Cors::new()
1025            .disable_vary_header()
1026            .allowed_origin("https://www.example.com")
1027            .allowed_origin("https://www.google.com")
1028            .finish()
1029            .create(test::ok_service())
1030            .into();
1031
1032        let req = TestRequest::with_header("Origin", "https://www.example.com")
1033            .method(Method::OPTIONS)
1034            .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST")
1035            .to_srv_request();
1036        let resp = test::call_service(&cors, req).await;
1037
1038        let origins_str =
1039            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().to_str().unwrap();
1040
1041        assert_eq!("https://www.example.com", origins_str);
1042    }
1043
1044    #[ntex::test]
1045    async fn test_multiple_origins() {
1046        let cors = Cors::new()
1047            .allowed_origin("https://example.com")
1048            .allowed_origin("https://example.org")
1049            .allowed_methods(vec![Method::GET])
1050            .finish()
1051            .create(test::ok_service())
1052            .into();
1053
1054        let req = TestRequest::with_header("Origin", "https://example.com")
1055            .method(Method::GET)
1056            .to_srv_request();
1057
1058        let resp = test::call_service(&cors, req).await;
1059        assert_eq!(
1060            &b"https://example.com"[..],
1061            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
1062        );
1063
1064        let req = TestRequest::with_header("Origin", "https://example.org")
1065            .method(Method::GET)
1066            .to_srv_request();
1067
1068        let resp = test::call_service(&cors, req).await;
1069        assert_eq!(
1070            &b"https://example.org"[..],
1071            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
1072        );
1073    }
1074
1075    #[ntex::test]
1076    async fn test_multiple_origins_preflight() {
1077        let cors = Cors::new()
1078            .allowed_origin("https://example.com")
1079            .allowed_origin("https://example.org")
1080            .allowed_methods(vec![Method::GET])
1081            .finish()
1082            .create(test::ok_service())
1083            .into();
1084
1085        let req = TestRequest::with_header("Origin", "https://example.com")
1086            .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
1087            .method(Method::OPTIONS)
1088            .to_srv_request();
1089
1090        let resp = test::call_service(&cors, req).await;
1091        assert_eq!(
1092            &b"https://example.com"[..],
1093            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
1094        );
1095
1096        let req = TestRequest::with_header("Origin", "https://example.org")
1097            .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
1098            .method(Method::OPTIONS)
1099            .to_srv_request();
1100
1101        let resp = test::call_service(&cors, req).await;
1102        assert_eq!(
1103            &b"https://example.org"[..],
1104            resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes()
1105        );
1106    }
1107}