actix_cors/
builder.rs

1use std::{collections::HashSet, rc::Rc};
2
3use actix_utils::future::{self, Ready};
4use actix_web::{
5    body::{EitherBody, MessageBody},
6    dev::{RequestHead, Service, ServiceRequest, ServiceResponse, Transform},
7    error::HttpError,
8    http::{
9        header::{HeaderName, HeaderValue},
10        Method, Uri,
11    },
12    Either, Error, Result,
13};
14use log::error;
15use once_cell::sync::Lazy;
16use smallvec::smallvec;
17
18use crate::{AllOrSome, CorsError, CorsMiddleware, Inner, OriginFn};
19
20/// Convenience for getting mut refs to inner. Cleaner than `Rc::get_mut`.
21/// Additionally, always causes first error (if any) to be reported during initialization.
22fn cors<'a>(
23    inner: &'a mut Rc<Inner>,
24    err: &Option<Either<HttpError, CorsError>>,
25) -> Option<&'a mut Inner> {
26    if err.is_some() {
27        return None;
28    }
29
30    Rc::get_mut(inner)
31}
32
33static ALL_METHODS_SET: Lazy<HashSet<Method>> = Lazy::new(|| {
34    HashSet::from_iter(vec![
35        Method::GET,
36        Method::POST,
37        Method::PUT,
38        Method::DELETE,
39        Method::HEAD,
40        Method::OPTIONS,
41        Method::CONNECT,
42        Method::PATCH,
43        Method::TRACE,
44    ])
45});
46
/// Builder for CORS middleware.
///
/// To construct a CORS middleware, call [`Cors::default()`] to create a blank, restrictive builder.
/// Then use any of the builder methods to customize CORS behavior.
///
/// The alternative [`Cors::permissive()`] constructor is available for local development, allowing
/// all origins and headers, etc. **The permissive constructor should not be used in production.**
///
/// # Behavior
///
/// In all cases, behavior for this crate follows the [Fetch Standard CORS protocol]. See that
/// document for information on exact semantics for configuration options and combinations.
///
/// # Errors
///
/// Errors surface in the middleware initialization phase. This means that, if you have logs enabled
/// in Actix Web (using `env_logger` or another crate that exposes logs from the `log` crate), error
/// messages will outline what is wrong with the CORS configuration in the server logs and the
/// server will fail to start up or serve requests.
///
/// # Example
///
/// ```
/// use actix_cors::Cors;
/// use actix_web::http::header;
///
/// let cors = Cors::default()
///     .allowed_origin("https://www.rust-lang.org")
///     .allowed_methods(vec!["GET", "POST"])
///     .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT])
///     .allowed_header(header::CONTENT_TYPE)
///     .max_age(3600);
///
/// // `cors` can now be used in `App::wrap`.
/// ```
///
/// [Fetch Standard CORS protocol]: https://fetch.spec.whatwg.org/#http-cors-protocol
#[derive(Debug)]
#[must_use]
pub struct Cors {
    /// Configuration being built. Builder methods get mutable access through the `cors` helper
    /// (which uses `Rc::get_mut`); `new_transform` starts each `CorsMiddleware` from a clone of
    /// this `Rc`.
    inner: Rc<Inner>,
    /// Error recorded by a builder method, if any. While this is `Some`, the `cors` helper
    /// returns `None`, so later builder calls stop changing `inner`; `new_transform` logs the
    /// error and fails initialization.
    error: Option<Either<HttpError, CorsError>>,
}
90
91impl Cors {
92    /// Constructs a very permissive set of defaults for quick development. (Not recommended for
93    /// production use.)
94    ///
95    /// *All* origins, methods, request headers and exposed headers allowed. Credentials supported.
96    /// Max age 1 hour. Does not send wildcard.
97    pub fn permissive() -> Self {
98        let inner = Inner {
99            allowed_origins: AllOrSome::All,
100            allowed_origins_fns: smallvec![],
101
102            allowed_methods: ALL_METHODS_SET.clone(),
103            allowed_methods_baked: None,
104
105            allowed_headers: AllOrSome::All,
106            allowed_headers_baked: None,
107
108            expose_headers: AllOrSome::All,
109            expose_headers_baked: None,
110
111            max_age: Some(3600),
112            preflight: true,
113            send_wildcard: false,
114            supports_credentials: true,
115            #[cfg(feature = "draft-private-network-access")]
116            allow_private_network_access: false,
117            vary_header: true,
118            block_on_origin_mismatch: false,
119        };
120
121        Cors {
122            inner: Rc::new(inner),
123            error: None,
124        }
125    }
126
127    /// Resets allowed origin list to a state where any origin is accepted.
128    ///
129    /// See [`Cors::allowed_origin`] for more info on allowed origins.
130    pub fn allow_any_origin(mut self) -> Cors {
131        if let Some(cors) = cors(&mut self.inner, &self.error) {
132            cors.allowed_origins = AllOrSome::All;
133        }
134
135        self
136    }
137
138    /// Adds an origin that is allowed to make requests.
139    ///
140    /// This method allows specifying a finite set of origins to verify the value of the `Origin`
141    /// request header. These are `origin-or-null` types in the [Fetch Standard].
142    ///
143    /// By default, no origins are accepted.
144    ///
145    /// When this list is set, the client's `Origin` request header will be checked in a
146    /// case-sensitive manner.
147    ///
148    /// When all origins are allowed and `send_wildcard` is set, `*` will be sent in the
149    /// `Access-Control-Allow-Origin` response header. If `send_wildcard` is not set, the client's
150    /// `Origin` request header will be echoed back in the `Access-Control-Allow-Origin`
151    /// response header.
152    ///
153    /// If the origin of the request doesn't match any allowed origins and at least one
154    /// `allowed_origin_fn` function is set, these functions will be used to determinate
155    /// allowed origins.
156    ///
157    /// # Initialization Errors
158    /// - If supplied origin is not valid uri
159    /// - If supplied origin is a wildcard (`*`). [`Cors::send_wildcard`] should be used instead.
160    ///
161    /// [Fetch Standard]: https://fetch.spec.whatwg.org/#origin-header
162    pub fn allowed_origin(mut self, origin: &str) -> Cors {
163        if let Some(cors) = cors(&mut self.inner, &self.error) {
164            match TryInto::<Uri>::try_into(origin) {
165                Ok(_) if origin == "*" => {
166                    error!("Wildcard in `allowed_origin` is not allowed. Use `send_wildcard`.");
167                    self.error = Some(Either::Right(CorsError::WildcardOrigin));
168                }
169
170                Ok(_) => {
171                    if cors.allowed_origins.is_all() {
172                        cors.allowed_origins = AllOrSome::Some(HashSet::with_capacity(8));
173                    }
174
175                    if let Some(origins) = cors.allowed_origins.as_mut() {
176                        // any uri is a valid header value
177                        let hv = origin.try_into().unwrap();
178                        origins.insert(hv);
179                    }
180                }
181
182                Err(err) => {
183                    self.error = Some(Either::Left(err.into()));
184                }
185            }
186        }
187
188        self
189    }
190
191    /// Determinates allowed origins by processing requests which didn't match any origins specified
192    /// in the `allowed_origin`.
193    ///
194    /// The function will receive two parameters, the Origin header value, and the `RequestHead` of
195    /// each request, which can be used to determine whether to allow the request or not.
196    ///
197    /// If the function returns `true`, the client's `Origin` request header will be echoed back
198    /// into the `Access-Control-Allow-Origin` response header.
199    pub fn allowed_origin_fn<F>(mut self, f: F) -> Cors
200    where
201        F: (Fn(&HeaderValue, &RequestHead) -> bool) + 'static,
202    {
203        if let Some(cors) = cors(&mut self.inner, &self.error) {
204            cors.allowed_origins_fns.push(OriginFn {
205                boxed_fn: Rc::new(f),
206            });
207        }
208
209        self
210    }
211
212    /// Resets allowed methods list to all methods.
213    ///
214    /// See [`Cors::allowed_methods`] for more info on allowed methods.
215    pub fn allow_any_method(mut self) -> Cors {
216        if let Some(cors) = cors(&mut self.inner, &self.error) {
217            ALL_METHODS_SET.clone_into(&mut cors.allowed_methods);
218        }
219
220        self
221    }
222
223    /// Sets a list of methods which allowed origins can perform.
224    ///
225    /// These will be sent in the `Access-Control-Allow-Methods` response header.
226    ///
227    /// This defaults to an empty set.
228    pub fn allowed_methods<U, M>(mut self, methods: U) -> Cors
229    where
230        U: IntoIterator<Item = M>,
231        M: TryInto<Method>,
232        <M as TryInto<Method>>::Error: Into<HttpError>,
233    {
234        if let Some(cors) = cors(&mut self.inner, &self.error) {
235            for m in methods {
236                match m.try_into() {
237                    Ok(method) => {
238                        cors.allowed_methods.insert(method);
239                    }
240
241                    Err(err) => {
242                        self.error = Some(Either::Left(err.into()));
243                        break;
244                    }
245                }
246            }
247        }
248
249        self
250    }
251
252    /// Resets allowed request header list to a state where any header is accepted.
253    ///
254    /// See [`Cors::allowed_headers`] for more info on allowed request headers.
255    pub fn allow_any_header(mut self) -> Cors {
256        if let Some(cors) = cors(&mut self.inner, &self.error) {
257            cors.allowed_headers = AllOrSome::All;
258        }
259
260        self
261    }
262
263    /// Add an allowed request header.
264    ///
265    /// See [`Cors::allowed_headers`] for more info on allowed request headers.
266    pub fn allowed_header<H>(mut self, header: H) -> Cors
267    where
268        H: TryInto<HeaderName>,
269        <H as TryInto<HeaderName>>::Error: Into<HttpError>,
270    {
271        if let Some(cors) = cors(&mut self.inner, &self.error) {
272            match header.try_into() {
273                Ok(method) => {
274                    if cors.allowed_headers.is_all() {
275                        cors.allowed_headers = AllOrSome::Some(HashSet::with_capacity(8));
276                    }
277
278                    if let AllOrSome::Some(ref mut headers) = cors.allowed_headers {
279                        headers.insert(method);
280                    }
281                }
282
283                Err(err) => self.error = Some(Either::Left(err.into())),
284            }
285        }
286
287        self
288    }
289
290    /// Sets a list of request header field names which can be used when this resource is accessed
291    /// by allowed origins.
292    ///
293    /// If `All` is set, whatever is requested by the client in `Access-Control-Request-Headers`
294    /// will be echoed back in the `Access-Control-Allow-Headers` header.
295    ///
296    /// This defaults to an empty set.
297    pub fn allowed_headers<U, H>(mut self, headers: U) -> Cors
298    where
299        U: IntoIterator<Item = H>,
300        H: TryInto<HeaderName>,
301        <H as TryInto<HeaderName>>::Error: Into<HttpError>,
302    {
303        if let Some(cors) = cors(&mut self.inner, &self.error) {
304            for h in headers {
305                match h.try_into() {
306                    Ok(method) => {
307                        if cors.allowed_headers.is_all() {
308                            cors.allowed_headers = AllOrSome::Some(HashSet::with_capacity(8));
309                        }
310
311                        if let AllOrSome::Some(ref mut headers) = cors.allowed_headers {
312                            headers.insert(method);
313                        }
314                    }
315                    Err(err) => {
316                        self.error = Some(Either::Left(err.into()));
317                        break;
318                    }
319                }
320            }
321        }
322
323        self
324    }
325
326    /// Resets exposed response header list to a state where all headers are exposed.
327    ///
328    /// See [`Cors::expose_headers`] for more info on exposed response headers.
329    pub fn expose_any_header(mut self) -> Cors {
330        if let Some(cors) = cors(&mut self.inner, &self.error) {
331            cors.expose_headers = AllOrSome::All;
332        }
333
334        self
335    }
336
337    /// Sets a list of headers which are safe to expose to the API of a CORS API specification.
338    ///
339    /// This corresponds to the `Access-Control-Expose-Headers` response header.
340    ///
341    /// This defaults to an empty set.
342    pub fn expose_headers<U, H>(mut self, headers: U) -> Cors
343    where
344        U: IntoIterator<Item = H>,
345        H: TryInto<HeaderName>,
346        <H as TryInto<HeaderName>>::Error: Into<HttpError>,
347    {
348        for h in headers {
349            match h.try_into() {
350                Ok(header) => {
351                    if let Some(cors) = cors(&mut self.inner, &self.error) {
352                        if cors.expose_headers.is_all() {
353                            cors.expose_headers = AllOrSome::Some(HashSet::with_capacity(8));
354                        }
355                        if let AllOrSome::Some(ref mut headers) = cors.expose_headers {
356                            headers.insert(header);
357                        }
358                    }
359                }
360                Err(err) => {
361                    self.error = Some(Either::Left(err.into()));
362                    break;
363                }
364            }
365        }
366
367        self
368    }
369
370    /// Sets a maximum time (in seconds) for which this CORS request may be cached.
371    ///
372    /// This value is set as the `Access-Control-Max-Age` header.
373    ///
374    /// Pass a number (of seconds) or use None to disable sending max age header.
375    pub fn max_age(mut self, max_age: impl Into<Option<usize>>) -> Cors {
376        if let Some(cors) = cors(&mut self.inner, &self.error) {
377            cors.max_age = max_age.into();
378        }
379
380        self
381    }
382
383    /// Configures use of wildcard (`*`) origin in responses when appropriate.
384    ///
385    /// If send wildcard is set and the `allowed_origins` parameter is `All`, a wildcard
386    /// `Access-Control-Allow-Origin` response header is sent, rather than the request’s
387    /// `Origin` header.
388    ///
389    /// This option **CANNOT** be used in conjunction with a [credential
390    /// supported](Self::supports_credentials()) configuration. Doing so will result in an error
391    /// during server startup.
392    ///
393    /// Defaults to disabled.
394    pub fn send_wildcard(mut self) -> Cors {
395        if let Some(cors) = cors(&mut self.inner, &self.error) {
396            cors.send_wildcard = true;
397        }
398
399        self
400    }
401
402    /// Allows users to make authenticated requests.
403    ///
404    /// If true, injects the `Access-Control-Allow-Credentials` header in responses. This allows
405    /// cookies and credentials to be submitted across domains.
406    ///
407    /// This option **CANNOT** be used in conjunction with option cannot be used in conjunction
408    /// with [wildcard origins](Self::send_wildcard()) configured. Doing so will result in an error
409    /// during server startup.
410    ///
411    /// Defaults to disabled.
412    pub fn supports_credentials(mut self) -> Cors {
413        if let Some(cors) = cors(&mut self.inner, &self.error) {
414            cors.supports_credentials = true;
415        }
416
417        self
418    }
419
420    /// Allow private network access.
421    ///
422    /// If true, injects the `Access-Control-Allow-Private-Network: true` header in responses if the
423    /// request contained the `Access-Control-Request-Private-Network: true` header.
424    ///
425    /// For more information on this behavior, see the draft [Private Network Access] spec.
426    ///
427    /// Defaults to `false`.
428    ///
429    /// [Private Network Access]: https://wicg.github.io/private-network-access
430    #[cfg(feature = "draft-private-network-access")]
431    pub fn allow_private_network_access(mut self) -> Cors {
432        if let Some(cors) = cors(&mut self.inner, &self.error) {
433            cors.allow_private_network_access = true;
434        }
435
436        self
437    }
438
439    /// Disables `Vary` header support.
440    ///
441    /// When enabled the header `Vary: Origin` will be returned as per the Fetch Standard
442    /// implementation guidelines.
443    ///
444    /// Setting this header when the `Access-Control-Allow-Origin` is dynamically generated
445    /// (eg. when there is more than one allowed origin, and an Origin other than '*' is returned)
446    /// informs CDNs and other caches that the CORS headers are dynamic, and cannot be cached.
447    ///
448    /// By default, `Vary` header support is enabled.
449    pub fn disable_vary_header(mut self) -> Cors {
450        if let Some(cors) = cors(&mut self.inner, &self.error) {
451            cors.vary_header = false;
452        }
453
454        self
455    }
456
457    /// Disables preflight request handling.
458    ///
459    /// When enabled CORS middleware automatically handles `OPTIONS` requests. This is useful for
460    /// application level middleware.
461    ///
462    /// By default, preflight support is enabled.
463    pub fn disable_preflight(mut self) -> Cors {
464        if let Some(cors) = cors(&mut self.inner, &self.error) {
465            cors.preflight = false;
466        }
467
468        self
469    }
470
471    /// Configures whether requests should be pre-emptively blocked on mismatched origin.
472    ///
473    /// If `true`, a 400 Bad Request is returned immediately when a request fails origin validation.
474    ///
475    /// If `false`, the request will be processed as normal but relevant CORS headers will not be
476    /// appended to the response. In this case, the browser is trusted to validate CORS headers and
477    /// and block requests based on pre-flight requests. Use this setting to allow cURL and other
478    /// non-browser HTTP clients to function as normal, no matter what `Origin` the request has.
479    ///
480    /// Defaults to false.
481    pub fn block_on_origin_mismatch(mut self, block: bool) -> Cors {
482        if let Some(cors) = cors(&mut self.inner, &self.error) {
483            cors.block_on_origin_mismatch = block;
484        }
485
486        self
487    }
488}
489
490impl Default for Cors {
491    /// A restrictive (security paranoid) set of defaults.
492    ///
493    /// *No* allowed origins, methods, request headers or exposed headers. Credentials
494    /// not supported. No max age (will use browser's default).
495    fn default() -> Cors {
496        let inner = Inner {
497            allowed_origins: AllOrSome::Some(HashSet::with_capacity(8)),
498            allowed_origins_fns: smallvec![],
499
500            allowed_methods: HashSet::with_capacity(8),
501            allowed_methods_baked: None,
502
503            allowed_headers: AllOrSome::Some(HashSet::with_capacity(8)),
504            allowed_headers_baked: None,
505
506            expose_headers: AllOrSome::Some(HashSet::with_capacity(8)),
507            expose_headers_baked: None,
508
509            max_age: None,
510            preflight: true,
511            send_wildcard: false,
512            supports_credentials: false,
513            #[cfg(feature = "draft-private-network-access")]
514            allow_private_network_access: false,
515            vary_header: true,
516            block_on_origin_mismatch: false,
517        };
518
519        Cors {
520            inner: Rc::new(inner),
521            error: None,
522        }
523    }
524}
525
526impl<S, B> Transform<S, ServiceRequest> for Cors
527where
528    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
529    S::Future: 'static,
530
531    B: MessageBody + 'static,
532{
533    type Response = ServiceResponse<EitherBody<B>>;
534    type Error = Error;
535    type InitError = ();
536    type Transform = CorsMiddleware<S>;
537    type Future = Ready<Result<Self::Transform, Self::InitError>>;
538
539    fn new_transform(&self, service: S) -> Self::Future {
540        if let Some(ref err) = self.error {
541            match err {
542                Either::Left(err) => error!("{}", err),
543                Either::Right(err) => error!("{}", err),
544            }
545
546            return future::err(());
547        }
548
549        let mut inner = Rc::clone(&self.inner);
550
551        if inner.supports_credentials && inner.send_wildcard && inner.allowed_origins.is_all() {
552            error!(
553                "Illegal combination of CORS options: credentials can not be supported when all \
554                    origins are allowed and `send_wildcard` is enabled."
555            );
556            return future::err(());
557        }
558
559        // bake allowed headers value if Some and not empty
560        match inner.allowed_headers.as_ref() {
561            Some(header_set) if !header_set.is_empty() => {
562                let allowed_headers_str = intersperse_header_values(header_set);
563                Rc::make_mut(&mut inner).allowed_headers_baked = Some(allowed_headers_str);
564            }
565            _ => {}
566        }
567
568        // bake allowed methods value if not empty
569        if !inner.allowed_methods.is_empty() {
570            let allowed_methods_str = intersperse_header_values(&inner.allowed_methods);
571            Rc::make_mut(&mut inner).allowed_methods_baked = Some(allowed_methods_str);
572        }
573
574        // bake exposed headers value if Some and not empty
575        match inner.expose_headers.as_ref() {
576            Some(header_set) if !header_set.is_empty() => {
577                let expose_headers_str = intersperse_header_values(header_set);
578                Rc::make_mut(&mut inner).expose_headers_baked = Some(expose_headers_str);
579            }
580            _ => {}
581        }
582
583        future::ok(CorsMiddleware { service, inner })
584    }
585}
586
587/// Only call when values are guaranteed to be valid header values and set is not empty.
588pub(crate) fn intersperse_header_values<T>(val_set: &HashSet<T>) -> HeaderValue
589where
590    T: AsRef<str>,
591{
592    debug_assert!(
593        !val_set.is_empty(),
594        "only call `intersperse_header_values` when set is not empty"
595    );
596
597    val_set
598        .iter()
599        .fold(String::with_capacity(64), |mut acc, val| {
600            acc.push_str(", ");
601            acc.push_str(val.as_ref());
602            acc
603        })
604        // set is not empty so string will always have leading ", " to trim
605        [2..]
606        .try_into()
607        // all method names are valid header values
608        .unwrap()
609}
610
611impl PartialEq for Cors {
612    fn eq(&self, other: &Self) -> bool {
613        self.inner == other.inner
614        // Because of the cors-function, checking if the content is equal implies that the errors are equal
615        //
616        // Proof by contradiction:
617        // Lets assume that the inner values are equal, but the error values are not.
618        // This means there had been an error, which has been fixed.
619        // This cannot happen as the first call to set the invalid value means that further usages of the cors-function will reject other input.
620        // => inner has to be in a different state
621    }
622}
623
#[cfg(test)]
mod test {
    use std::convert::Infallible;

    use actix_web::{
        body,
        dev::fn_service,
        http::StatusCode,
        test::{self, TestRequest},
        HttpResponse,
    };

    use super::*;

    #[test]
    fn illegal_allow_credentials() {
        // The permissive defaults allow every origin; enabling both `send_wildcard` and
        // `supports_credentials` on top of that must make initialization fail.
        let init_result = Cors::permissive()
            .supports_credentials()
            .send_wildcard()
            .new_transform(test::ok_service())
            .into_inner();

        assert!(init_result.is_err());
    }

    #[actix_web::test]
    async fn restrictive_defaults() {
        let middleware = Cors::default()
            .new_transform(test::ok_service())
            .await
            .unwrap();

        let request = TestRequest::default()
            .insert_header(("Origin", "https://www.example.com"))
            .to_srv_request();

        let response = test::call_service(&middleware, request).await;

        // the request is served, but without any CORS headers attached
        assert_eq!(response.status(), StatusCode::OK);
        assert!(!response.headers().contains_key("Access-Control-Allow-Origin"));
    }

    #[actix_web::test]
    async fn allowed_header_try_from() {
        // a plain `&str` header name is accepted
        let _cors = Cors::default().allowed_header("Content-Type");
    }

    #[actix_web::test]
    async fn allowed_header_try_into() {
        // a type that implements `TryInto<HeaderName>` directly is accepted too
        struct ContentType;

        impl TryInto<HeaderName> for ContentType {
            type Error = Infallible;

            fn try_into(self) -> Result<HeaderName, Self::Error> {
                Ok(HeaderName::from_static("content-type"))
            }
        }

        let _cors = Cors::default().allowed_header(ContentType);
    }

    #[actix_web::test]
    async fn middleware_generic_over_body_type() {
        // the wrapped service responds with a non-boxed body type
        let service = fn_service(|req: ServiceRequest| async move {
            let res = HttpResponse::with_body(StatusCode::OK, body::None::new());
            Ok(req.into_response(res))
        });

        Cors::default().new_transform(service).await.unwrap();
    }

    #[test]
    fn impl_eq() {
        let restrictive = Cors::default();
        assert_eq!(restrictive, Cors::default());

        assert_ne!(Cors::default().send_wildcard(), Cors::default());
        assert_ne!(Cors::default(), Cors::permissive());
    }
}