use std::{collections::HashSet, iter::FromIterator, rc::Rc};
use actix_utils::future::{self, Ready};
use actix_web::{
body::{EitherBody, MessageBody},
dev::{RequestHead, Service, ServiceRequest, ServiceResponse, Transform},
error::HttpError,
http::{
header::{HeaderName, HeaderValue},
Method, Uri,
},
Either, Error, Result,
};
use log::error;
use once_cell::sync::Lazy;
use smallvec::smallvec;
use crate::{AllOrSome, CorsError, CorsMiddleware, Inner, OriginFn};
/// Returns mutable access to the builder's inner configuration, or `None` if it must not be
/// touched.
///
/// `None` is returned when `err` already holds an error, so builder methods that gate on this
/// helper stop mutating (and stop recording new errors) after the first failure. `None` is also
/// returned when the `Rc` is not uniquely owned, because `Rc::get_mut` only succeeds for a unique
/// pointer.
fn cors<'a>(
    inner: &'a mut Rc<Inner>,
    err: &Option<Either<HttpError, CorsError>>,
) -> Option<&'a mut Inner> {
    match err {
        Some(_) => None,
        None => Rc::get_mut(inner),
    }
}
/// The nine predefined `Method` constants listed below, collected once on first use.
///
/// Cloned by `Cors::permissive` and `Cors::allow_any_method`.
static ALL_METHODS_SET: Lazy<HashSet<Method>> = Lazy::new(|| {
    let methods = [
        Method::GET,
        Method::POST,
        Method::PUT,
        Method::DELETE,
        Method::HEAD,
        Method::OPTIONS,
        Method::CONNECT,
        Method::PATCH,
        Method::TRACE,
    ];
    HashSet::from_iter(methods)
});
/// Builder for the CORS middleware.
///
/// Builder methods consume and return `self`, mutating `inner` in place. When a builder method
/// encounters invalid input it stores an error in `error` instead of failing immediately; that
/// stored error is logged and makes `new_transform` return an error when the middleware is built.
#[derive(Debug)]
#[must_use]
pub struct Cors {
    // Configuration being assembled; builder methods reach it through `Rc::get_mut`.
    inner: Rc<Inner>,
    // Error recorded by a builder method, if any: an HTTP conversion error (`Left`) or a
    // CORS-specific error (`Right`).
    error: Option<Either<HttpError, CorsError>>,
}
impl Cors {
pub fn permissive() -> Self {
let inner = Inner {
allowed_origins: AllOrSome::All,
allowed_origins_fns: smallvec![],
allowed_methods: ALL_METHODS_SET.clone(),
allowed_methods_baked: None,
allowed_headers: AllOrSome::All,
allowed_headers_baked: None,
expose_headers: AllOrSome::All,
expose_headers_baked: None,
max_age: Some(3600),
preflight: true,
send_wildcard: false,
supports_credentials: true,
#[cfg(feature = "draft-private-network-access")]
allow_private_network_access: false,
vary_header: true,
block_on_origin_mismatch: false,
};
Cors {
inner: Rc::new(inner),
error: None,
}
}
pub fn allow_any_origin(mut self) -> Cors {
if let Some(cors) = cors(&mut self.inner, &self.error) {
cors.allowed_origins = AllOrSome::All;
}
self
}
pub fn allowed_origin(mut self, origin: &str) -> Cors {
if let Some(cors) = cors(&mut self.inner, &self.error) {
match TryInto::<Uri>::try_into(origin) {
Ok(_) if origin == "*" => {
error!("Wildcard in `allowed_origin` is not allowed. Use `send_wildcard`.");
self.error = Some(Either::Right(CorsError::WildcardOrigin));
}
Ok(_) => {
if cors.allowed_origins.is_all() {
cors.allowed_origins = AllOrSome::Some(HashSet::with_capacity(8));
}
if let Some(origins) = cors.allowed_origins.as_mut() {
let hv = origin.try_into().unwrap();
origins.insert(hv);
}
}
Err(err) => {
self.error = Some(Either::Left(err.into()));
}
}
}
self
}
pub fn allowed_origin_fn<F>(mut self, f: F) -> Cors
where
F: (Fn(&HeaderValue, &RequestHead) -> bool) + 'static,
{
if let Some(cors) = cors(&mut self.inner, &self.error) {
cors.allowed_origins_fns.push(OriginFn {
boxed_fn: Rc::new(f),
});
}
self
}
pub fn allow_any_method(mut self) -> Cors {
if let Some(cors) = cors(&mut self.inner, &self.error) {
cors.allowed_methods = ALL_METHODS_SET.clone();
}
self
}
pub fn allowed_methods<U, M>(mut self, methods: U) -> Cors
where
U: IntoIterator<Item = M>,
M: TryInto<Method>,
<M as TryInto<Method>>::Error: Into<HttpError>,
{
if let Some(cors) = cors(&mut self.inner, &self.error) {
for m in methods {
match m.try_into() {
Ok(method) => {
cors.allowed_methods.insert(method);
}
Err(err) => {
self.error = Some(Either::Left(err.into()));
break;
}
}
}
}
self
}
pub fn allow_any_header(mut self) -> Cors {
if let Some(cors) = cors(&mut self.inner, &self.error) {
cors.allowed_headers = AllOrSome::All;
}
self
}
pub fn allowed_header<H>(mut self, header: H) -> Cors
where
H: TryInto<HeaderName>,
<H as TryInto<HeaderName>>::Error: Into<HttpError>,
{
if let Some(cors) = cors(&mut self.inner, &self.error) {
match header.try_into() {
Ok(method) => {
if cors.allowed_headers.is_all() {
cors.allowed_headers = AllOrSome::Some(HashSet::with_capacity(8));
}
if let AllOrSome::Some(ref mut headers) = cors.allowed_headers {
headers.insert(method);
}
}
Err(err) => self.error = Some(Either::Left(err.into())),
}
}
self
}
pub fn allowed_headers<U, H>(mut self, headers: U) -> Cors
where
U: IntoIterator<Item = H>,
H: TryInto<HeaderName>,
<H as TryInto<HeaderName>>::Error: Into<HttpError>,
{
if let Some(cors) = cors(&mut self.inner, &self.error) {
for h in headers {
match h.try_into() {
Ok(method) => {
if cors.allowed_headers.is_all() {
cors.allowed_headers = AllOrSome::Some(HashSet::with_capacity(8));
}
if let AllOrSome::Some(ref mut headers) = cors.allowed_headers {
headers.insert(method);
}
}
Err(err) => {
self.error = Some(Either::Left(err.into()));
break;
}
}
}
}
self
}
pub fn expose_any_header(mut self) -> Cors {
if let Some(cors) = cors(&mut self.inner, &self.error) {
cors.expose_headers = AllOrSome::All;
}
self
}
pub fn expose_headers<U, H>(mut self, headers: U) -> Cors
where
U: IntoIterator<Item = H>,
H: TryInto<HeaderName>,
<H as TryInto<HeaderName>>::Error: Into<HttpError>,
{
for h in headers {
match h.try_into() {
Ok(header) => {
if let Some(cors) = cors(&mut self.inner, &self.error) {
if cors.expose_headers.is_all() {
cors.expose_headers = AllOrSome::Some(HashSet::with_capacity(8));
}
if let AllOrSome::Some(ref mut headers) = cors.expose_headers {
headers.insert(header);
}
}
}
Err(err) => {
self.error = Some(Either::Left(err.into()));
break;
}
}
}
self
}
pub fn max_age(mut self, max_age: impl Into<Option<usize>>) -> Cors {
if let Some(cors) = cors(&mut self.inner, &self.error) {
cors.max_age = max_age.into();
}
self
}
pub fn send_wildcard(mut self) -> Cors {
if let Some(cors) = cors(&mut self.inner, &self.error) {
cors.send_wildcard = true;
}
self
}
pub fn supports_credentials(mut self) -> Cors {
if let Some(cors) = cors(&mut self.inner, &self.error) {
cors.supports_credentials = true;
}
self
}
#[cfg(feature = "draft-private-network-access")]
pub fn allow_private_network_access(mut self) -> Cors {
if let Some(cors) = cors(&mut self.inner, &self.error) {
cors.allow_private_network_access = true;
}
self
}
pub fn disable_vary_header(mut self) -> Cors {
if let Some(cors) = cors(&mut self.inner, &self.error) {
cors.vary_header = false;
}
self
}
pub fn disable_preflight(mut self) -> Cors {
if let Some(cors) = cors(&mut self.inner, &self.error) {
cors.preflight = false;
}
self
}
pub fn block_on_origin_mismatch(mut self, block: bool) -> Cors {
if let Some(cors) = cors(&mut self.inner, &self.error) {
cors.block_on_origin_mismatch = block;
}
self
}
}
impl Default for Cors {
    /// Restrictive starting configuration.
    ///
    /// The origin, method, allowed-header and exposed-header collections all start empty, and no
    /// origin predicates are registered. There is no `max_age`, and `send_wildcard`,
    /// `supports_credentials` and `block_on_origin_mismatch` are off. The `preflight` and
    /// `vary_header` flags stay on.
    fn default() -> Cors {
        /// Empty set pre-sized for eight entries, the capacity used for every set below.
        fn empty_set<T>() -> HashSet<T> {
            HashSet::with_capacity(8)
        }

        Cors {
            inner: Rc::new(Inner {
                allowed_origins: AllOrSome::Some(empty_set()),
                allowed_origins_fns: smallvec![],
                allowed_methods: empty_set(),
                allowed_methods_baked: None,
                allowed_headers: AllOrSome::Some(empty_set()),
                allowed_headers_baked: None,
                expose_headers: AllOrSome::Some(empty_set()),
                expose_headers_baked: None,
                max_age: None,
                preflight: true,
                send_wildcard: false,
                supports_credentials: false,
                #[cfg(feature = "draft-private-network-access")]
                allow_private_network_access: false,
                vary_header: true,
                block_on_origin_mismatch: false,
            }),
            error: None,
        }
    }
}
impl<S, B> Transform<S, ServiceRequest> for Cors
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
S::Future: 'static,
B: MessageBody + 'static,
{
type Response = ServiceResponse<EitherBody<B>>;
type Error = Error;
type InitError = ();
type Transform = CorsMiddleware<S>;
type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future {
if let Some(ref err) = self.error {
match err {
Either::Left(err) => error!("{}", err),
Either::Right(err) => error!("{}", err),
}
return future::err(());
}
let mut inner = Rc::clone(&self.inner);
if inner.supports_credentials && inner.send_wildcard && inner.allowed_origins.is_all() {
error!(
"Illegal combination of CORS options: credentials can not be supported when all \
origins are allowed and `send_wildcard` is enabled."
);
return future::err(());
}
match inner.allowed_headers.as_ref() {
Some(header_set) if !header_set.is_empty() => {
let allowed_headers_str = intersperse_header_values(header_set);
Rc::make_mut(&mut inner).allowed_headers_baked = Some(allowed_headers_str);
}
_ => {}
}
if !inner.allowed_methods.is_empty() {
let allowed_methods_str = intersperse_header_values(&inner.allowed_methods);
Rc::make_mut(&mut inner).allowed_methods_baked = Some(allowed_methods_str);
}
match inner.expose_headers.as_ref() {
Some(header_set) if !header_set.is_empty() => {
let expose_headers_str = intersperse_header_values(header_set);
Rc::make_mut(&mut inner).expose_headers_baked = Some(expose_headers_str);
}
_ => {}
}
future::ok(CorsMiddleware { service, inner })
}
}
/// Joins the string forms of `val_set` with `", "` into a single header value.
///
/// The entries follow the set's iteration order, which is unspecified for a `HashSet`.
///
/// Callers are expected to pass a non-empty set; a debug assertion checks this. In release builds
/// an empty set produces an empty header value rather than panicking.
///
/// # Panics
///
/// Panics if the joined string is not a valid header value (e.g. an entry containing control
/// characters). Callers in this module pass header names and methods, whose string forms are
/// valid.
pub(crate) fn intersperse_header_values<T>(val_set: &HashSet<T>) -> HeaderValue
where
    T: AsRef<str>,
{
    debug_assert!(
        !val_set.is_empty(),
        "only call `intersperse_header_values` when set is not empty"
    );

    let mut values = val_set.iter().map(|val| val.as_ref());
    let mut joined = String::with_capacity(64);

    // Emit the separator only between entries, so no leading ", " has to be sliced off. Slicing
    // it off would panic on an empty set.
    if let Some(first) = values.next() {
        joined.push_str(first);
        for val in values {
            joined.push_str(", ");
            joined.push_str(val);
        }
    }

    joined
        .as_str()
        .try_into()
        .expect("header names and methods always join into a valid header value")
}
#[cfg(test)]
mod test {
    use std::convert::Infallible;

    use actix_web::{
        body,
        dev::{fn_service, Transform},
        http::{header::HeaderName, StatusCode},
        test::{self, TestRequest},
        HttpResponse,
    };

    use super::*;

    /// Credentials + `send_wildcard` + any origin must make transform construction fail.
    #[test]
    fn illegal_allow_credentials() {
        let outcome = Cors::permissive()
            .supports_credentials()
            .send_wildcard()
            .new_transform(test::ok_service())
            .into_inner();

        assert!(outcome.is_err());
    }

    /// With default settings, a cross-origin request gets no `Access-Control-Allow-Origin`.
    #[actix_web::test]
    async fn restrictive_defaults() {
        let middleware = Cors::default()
            .new_transform(test::ok_service())
            .await
            .unwrap();

        let request = TestRequest::default()
            .insert_header(("Origin", "https://www.example.com"))
            .to_srv_request();
        let response = test::call_service(&middleware, request).await;

        assert_eq!(response.status(), StatusCode::OK);
        let has_allow_origin = response
            .headers()
            .contains_key("Access-Control-Allow-Origin");
        assert!(!has_allow_origin);
    }

    /// `allowed_header` accepts a string literal.
    #[actix_web::test]
    async fn allowed_header_try_from() {
        let _cors = Cors::default().allowed_header("Content-Type");
    }

    /// `allowed_header` accepts any type with a `TryInto<HeaderName>` impl.
    #[actix_web::test]
    async fn allowed_header_try_into() {
        struct ContentType;

        impl TryInto<HeaderName> for ContentType {
            type Error = Infallible;

            fn try_into(self) -> Result<HeaderName, Self::Error> {
                Ok(HeaderName::from_static("content-type"))
            }
        }

        let _cors = Cors::default().allowed_header(ContentType);
    }

    /// The middleware can wrap a service whose body type is not the default boxed body.
    #[actix_web::test]
    async fn middleware_generic_over_body_type() {
        let service = fn_service(|req: ServiceRequest| async move {
            let response = HttpResponse::with_body(StatusCode::OK, body::None::new());
            Ok(req.into_response(response))
        });

        Cors::default().new_transform(service).await.unwrap();
    }
}