use std::{collections::HashSet, convert::TryInto, iter::FromIterator, rc::Rc};
use actix_web::{
dev::{RequestHead, Service, ServiceRequest, ServiceResponse, Transform},
error::{Error, Result},
http::{self, header::HeaderName, Error as HttpError, HeaderValue, Method, Uri},
Either,
};
use futures_util::future::{self, Ready};
use log::error;
use once_cell::sync::Lazy;
use tinyvec::tiny_vec;
use crate::{AllOrSome, CorsError, CorsMiddleware, Inner, OriginFn};
/// Hands out exclusive access to the builder's configuration.
///
/// Returns `None` when the builder has already recorded an error (`err` is
/// `Some`), which turns every later builder call into a no-op. Also returns
/// `None` when `Rc::get_mut` refuses access because the `Rc` is shared.
fn cors<'a>(
    inner: &'a mut Rc<Inner>,
    err: &Option<Either<http::Error, CorsError>>,
) -> Option<&'a mut Inner> {
    match err {
        // Builder is poisoned: do not touch the configuration any further.
        Some(_) => None,
        None => Rc::get_mut(inner),
    }
}
/// Lazily built set of the nine HTTP methods listed below.
///
/// Cloned into the configuration by `Cors::permissive` and
/// `Cors::allow_any_method`.
static ALL_METHODS_SET: Lazy<HashSet<Method>> = Lazy::new(|| {
    let every_method = [
        Method::GET,
        Method::POST,
        Method::PUT,
        Method::DELETE,
        Method::HEAD,
        Method::OPTIONS,
        Method::CONNECT,
        Method::PATCH,
        Method::TRACE,
    ];

    HashSet::from_iter(every_method.iter().cloned())
});
/// Builder for the CORS middleware.
///
/// Configuration methods consume and return the builder so they can be
/// chained. The `Transform` implementation turns the finished builder into a
/// `CorsMiddleware`.
#[derive(Debug)]
pub struct Cors {
/// Configuration being assembled; mutated in place through `Rc::get_mut`.
inner: Rc<Inner>,
/// Error recorded by a builder method, if any. While this is `Some`, the
/// `cors` helper returns `None` and `new_transform` fails.
error: Option<Either<http::Error, CorsError>>,
}
impl Cors {
pub fn permissive() -> Self {
let inner = Inner {
allowed_origins: AllOrSome::All,
allowed_origins_fns: tiny_vec![],
allowed_methods: ALL_METHODS_SET.clone(),
allowed_methods_baked: None,
allowed_headers: AllOrSome::All,
allowed_headers_baked: None,
expose_headers: AllOrSome::All,
expose_headers_baked: None,
max_age: Some(3600),
preflight: true,
send_wildcard: false,
supports_credentials: true,
vary_header: true,
};
Cors {
inner: Rc::new(inner),
error: None,
}
}
pub fn allow_any_origin(mut self) -> Cors {
if let Some(cors) = cors(&mut self.inner, &self.error) {
cors.allowed_origins = AllOrSome::All;
}
self
}
pub fn allowed_origin(mut self, origin: &str) -> Cors {
if let Some(cors) = cors(&mut self.inner, &self.error) {
match TryInto::<Uri>::try_into(origin) {
Ok(_) if origin == "*" => {
error!("Wildcard in `allowed_origin` is not allowed. Use `send_wildcard`.");
self.error = Some(Either::B(CorsError::WildcardOrigin));
}
Ok(_) => {
if cors.allowed_origins.is_all() {
cors.allowed_origins =
AllOrSome::Some(HashSet::with_capacity(8));
}
if let Some(origins) = cors.allowed_origins.as_mut() {
let hv = origin.try_into().unwrap();
origins.insert(hv);
}
}
Err(err) => {
self.error = Some(Either::A(err.into()));
}
}
}
self
}
pub fn allowed_origin_fn<F>(mut self, f: F) -> Cors
where
F: (Fn(&HeaderValue, &RequestHead) -> bool) + 'static,
{
if let Some(cors) = cors(&mut self.inner, &self.error) {
cors.allowed_origins_fns.push(OriginFn {
boxed_fn: Rc::new(f),
});
}
self
}
pub fn allow_any_method(mut self) -> Cors {
if let Some(cors) = cors(&mut self.inner, &self.error) {
cors.allowed_methods = ALL_METHODS_SET.clone();
}
self
}
pub fn allowed_methods<U, M>(mut self, methods: U) -> Cors
where
U: IntoIterator<Item = M>,
M: TryInto<Method>,
<M as TryInto<Method>>::Error: Into<HttpError>,
{
if let Some(cors) = cors(&mut self.inner, &self.error) {
for m in methods {
match m.try_into() {
Ok(method) => {
cors.allowed_methods.insert(method);
}
Err(err) => {
self.error = Some(Either::A(err.into()));
break;
}
}
}
}
self
}
pub fn allow_any_header(mut self) -> Cors {
if let Some(cors) = cors(&mut self.inner, &self.error) {
cors.allowed_origins = AllOrSome::All;
}
self
}
pub fn allowed_header<H>(mut self, header: H) -> Cors
where
H: TryInto<HeaderName>,
<H as TryInto<HeaderName>>::Error: Into<HttpError>,
{
if let Some(cors) = cors(&mut self.inner, &self.error) {
match header.try_into() {
Ok(method) => {
if cors.allowed_headers.is_all() {
cors.allowed_headers =
AllOrSome::Some(HashSet::with_capacity(8));
}
if let AllOrSome::Some(ref mut headers) = cors.allowed_headers {
headers.insert(method);
}
}
Err(err) => self.error = Some(Either::A(err.into())),
}
}
self
}
pub fn allowed_headers<U, H>(mut self, headers: U) -> Cors
where
U: IntoIterator<Item = H>,
H: TryInto<HeaderName>,
<H as TryInto<HeaderName>>::Error: Into<HttpError>,
{
if let Some(cors) = cors(&mut self.inner, &self.error) {
for h in headers {
match h.try_into() {
Ok(method) => {
if cors.allowed_headers.is_all() {
cors.allowed_headers =
AllOrSome::Some(HashSet::with_capacity(8));
}
if let AllOrSome::Some(ref mut headers) = cors.allowed_headers {
headers.insert(method);
}
}
Err(err) => {
self.error = Some(Either::A(err.into()));
break;
}
}
}
}
self
}
pub fn expose_any_header(mut self) -> Cors {
if let Some(cors) = cors(&mut self.inner, &self.error) {
cors.allowed_origins = AllOrSome::All;
}
self
}
pub fn expose_headers<U, H>(mut self, headers: U) -> Cors
where
U: IntoIterator<Item = H>,
H: TryInto<HeaderName>,
<H as TryInto<HeaderName>>::Error: Into<HttpError>,
{
for h in headers {
match h.try_into() {
Ok(header) => {
if let Some(cors) = cors(&mut self.inner, &self.error) {
if cors.expose_headers.is_all() {
cors.expose_headers =
AllOrSome::Some(HashSet::with_capacity(8));
}
if let AllOrSome::Some(ref mut headers) = cors.expose_headers {
headers.insert(header);
}
}
}
Err(err) => {
self.error = Some(Either::A(err.into()));
break;
}
}
}
self
}
pub fn max_age(mut self, max_age: impl Into<Option<usize>>) -> Cors {
if let Some(cors) = cors(&mut self.inner, &self.error) {
cors.max_age = max_age.into()
}
self
}
pub fn send_wildcard(mut self) -> Cors {
if let Some(cors) = cors(&mut self.inner, &self.error) {
cors.send_wildcard = true
}
self
}
pub fn supports_credentials(mut self) -> Cors {
if let Some(cors) = cors(&mut self.inner, &self.error) {
cors.supports_credentials = true
}
self
}
pub fn disable_vary_header(mut self) -> Cors {
if let Some(cors) = cors(&mut self.inner, &self.error) {
cors.vary_header = false
}
self
}
pub fn disable_preflight(mut self) -> Cors {
if let Some(cors) = cors(&mut self.inner, &self.error) {
cors.preflight = false
}
self
}
}
impl Default for Cors {
    /// Creates a builder with a restrictive configuration.
    ///
    /// The origin, method, request-header and exposed-header sets all start
    /// out empty (explicit lists, not "all"). `max_age` is unset, credentials
    /// and `send_wildcard` are off, and preflight handling and the vary
    /// header are enabled.
    fn default() -> Cors {
        Cors {
            error: None,
            inner: Rc::new(Inner {
                allowed_origins: AllOrSome::Some(HashSet::with_capacity(8)),
                allowed_origins_fns: tiny_vec![],

                allowed_methods: HashSet::with_capacity(8),
                allowed_methods_baked: None,

                allowed_headers: AllOrSome::Some(HashSet::with_capacity(8)),
                allowed_headers_baked: None,

                expose_headers: AllOrSome::Some(HashSet::with_capacity(8)),
                expose_headers_baked: None,

                max_age: None,
                preflight: true,
                send_wildcard: false,
                supports_credentials: false,
                vary_header: true,
            }),
        }
    }
}
impl<S, B> Transform<S> for Cors
where
S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
S::Future: 'static,
B: 'static,
{
type Request = ServiceRequest;
type Response = ServiceResponse<B>;
type Error = Error;
type InitError = ();
type Transform = CorsMiddleware<S>;
type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future {
if let Some(ref err) = self.error {
match err {
Either::A(err) => error!("{}", err),
Either::B(err) => error!("{}", err),
}
return future::err(());
}
let mut inner = Rc::clone(&self.inner);
if inner.supports_credentials
&& inner.send_wildcard
&& inner.allowed_origins.is_all()
{
error!("Illegal combination of CORS options: credentials can not be supported when all \
origins are allowed and `send_wildcard` is enabled.");
return future::err(());
}
match inner.allowed_headers.as_ref() {
Some(header_set) if !header_set.is_empty() => {
let allowed_headers_str = intersperse_header_values(header_set);
Rc::make_mut(&mut inner).allowed_headers_baked =
Some(allowed_headers_str);
}
_ => {}
}
if !inner.allowed_methods.is_empty() {
let allowed_methods_str = intersperse_header_values(&inner.allowed_methods);
Rc::make_mut(&mut inner).allowed_methods_baked = Some(allowed_methods_str);
}
match inner.expose_headers.as_ref() {
Some(header_set) if !header_set.is_empty() => {
let expose_headers_str = intersperse_header_values(header_set);
Rc::make_mut(&mut inner).expose_headers_baked = Some(expose_headers_str);
}
_ => {}
}
future::ok(CorsMiddleware { service, inner })
}
}
/// Renders the elements of `val_set` as one header value, separated by `", "`.
///
/// Element order follows the set's iteration order and is therefore
/// unspecified. An empty set yields an empty value rather than panicking.
///
/// # Panics
///
/// Panics if the joined string is not a valid header value.
/// NOTE(review): the callers in this file pass the configured method and
/// header sets, which are presumably always valid — confirm.
fn intersperse_header_values<T>(val_set: &HashSet<T>) -> HeaderValue
where
    T: AsRef<str>,
{
    let joined = val_set
        .iter()
        .map(|val| val.as_ref())
        .collect::<Vec<&str>>()
        .join(", ");

    joined.as_str().try_into().unwrap()
}
#[cfg(test)]
mod test {
    use std::convert::{Infallible, TryInto};

    use actix_web::{
        dev::Transform,
        http::{HeaderName, StatusCode},
        test::{self, TestRequest},
    };

    use super::*;

    /// Credentials plus `send_wildcard` on top of "all origins" must be
    /// rejected when the middleware is built.
    #[test]
    fn illegal_allow_credentials() {
        let builder = Cors::permissive()
            .supports_credentials()
            .send_wildcard();

        let outcome = builder.new_transform(test::ok_service()).into_inner();

        assert!(outcome.is_err());
    }

    /// A default (restrictive) middleware answers a request that carries an
    /// `Origin` header with `400 Bad Request`.
    #[actix_rt::test]
    async fn restrictive_defaults() {
        let mut middleware = Cors::default()
            .new_transform(test::ok_service())
            .await
            .unwrap();

        let request = TestRequest::with_header("Origin", "https://www.example.com")
            .to_srv_request();

        let response = test::call_service(&mut middleware, request).await;
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }

    /// `allowed_header` accepts a value that converts into a `HeaderName`
    /// through an existing conversion (here a string slice).
    #[actix_rt::test]
    async fn allowed_header_try_from() {
        let _cors = Cors::default().allowed_header("Content-Type");
    }

    /// `allowed_header` accepts a custom type that implements
    /// `TryInto<HeaderName>` directly.
    #[actix_rt::test]
    async fn allowed_header_try_into() {
        struct ContentType;

        impl TryInto<HeaderName> for ContentType {
            type Error = Infallible;

            fn try_into(self) -> Result<HeaderName, Self::Error> {
                Ok(HeaderName::from_static("content-type"))
            }
        }

        let _cors = Cors::default().allowed_header(ContentType);
    }
}