use crate::http::header::{HeaderName, HeaderValue, ORIGIN, VARY};
use crate::http::{Method, StatusCode};
use crate::{async_trait, Context, Middleware, Next, Result};
use headers::{
AccessControlAllowCredentials, AccessControlAllowHeaders, AccessControlAllowMethods,
AccessControlAllowOrigin, AccessControlExposeHeaders, AccessControlMaxAge,
AccessControlRequestHeaders, AccessControlRequestMethod, Header, HeaderMapExt,
};
use roa_core::Status;
use std::collections::HashSet;
use std::convert::TryInto;
use std::fmt::Debug;
use std::iter::FromIterator;
use std::time::Duration;
/// CORS middleware; build a configured instance with `Cors::builder()` or use
/// `Cors::new()` for an unconfigured one.
///
/// Every field is optional. A `None` field means the corresponding header is
/// either derived from the incoming request or omitted from the response.
///
/// `Clone` is derived so a single configuration can be reused across several
/// apps or routers; every field type is itself `Clone`.
#[derive(Clone, Debug, Default)]
pub struct Cors {
    /// Fixed `Access-Control-Allow-Origin`; when `None`, the request's `Origin` is echoed.
    allow_origin: Option<AccessControlAllowOrigin>,
    /// `Access-Control-Allow-Methods` for preflights; when `None`, the method named in
    /// `Access-Control-Request-Method` is echoed.
    allow_methods: Option<AccessControlAllowMethods>,
    /// `Access-Control-Expose-Headers` for non-preflight responses; omitted when `None`.
    expose_headers: Option<AccessControlExposeHeaders>,
    /// `Access-Control-Allow-Headers` for preflights; when `None`, the request's
    /// `Access-Control-Request-Headers` are echoed (if present).
    allow_headers: Option<AccessControlAllowHeaders>,
    /// `Access-Control-Max-Age` for preflights; omitted when `None`.
    max_age: Option<AccessControlMaxAge>,
    /// `Access-Control-Allow-Credentials: true` is sent when `Some`.
    credentials: Option<AccessControlAllowCredentials>,
}
/// Builder for [`Cors`]; obtain one with `Cors::builder()` and finish with `build()`.
///
/// Every setter consumes the builder and returns the updated one. Discarding a
/// return value therefore silently drops configuration, which `#[must_use]` flags.
#[must_use = "builder methods return the updated `Builder`; call `build()` to obtain a `Cors`"]
#[derive(Clone, Debug, Default)]
pub struct Builder {
    /// Whether responses carry `Access-Control-Allow-Credentials: true`.
    credentials: bool,
    /// Names sent in `Access-Control-Allow-Headers`; empty means "echo the request".
    allowed_headers: HashSet<HeaderName>,
    /// Names sent in `Access-Control-Expose-Headers`; empty means "omit the header".
    exposed_headers: HashSet<HeaderName>,
    /// `Access-Control-Max-Age` in seconds, if configured.
    max_age: Option<u64>,
    /// Methods sent in `Access-Control-Allow-Methods`; empty means "echo the request".
    methods: HashSet<Method>,
    /// Raw configured origin; it is only decoded as an origin in `build()`.
    origins: Option<HeaderValue>,
}
impl Cors {
pub fn new() -> Self {
Self::default()
}
pub fn builder() -> Builder {
Builder::default()
}
}
impl Builder {
pub fn allow_credentials(mut self, allow: bool) -> Self {
self.credentials = allow;
self
}
pub fn allow_method(mut self, method: Method) -> Self {
self.methods.insert(method);
self
}
pub fn allow_methods(mut self, methods: impl IntoIterator<Item = Method>) -> Self {
self.methods.extend(methods);
self
}
pub fn allow_header<H>(mut self, header: H) -> Self
where
H: TryInto<HeaderName>,
H::Error: Debug,
{
self.allowed_headers
.insert(header.try_into().expect("invalid header"));
self
}
pub fn allow_headers<I>(mut self, headers: I) -> Self
where
I: IntoIterator,
I::Item: TryInto<HeaderName>,
<I::Item as TryInto<HeaderName>>::Error: Debug,
{
let iter = headers
.into_iter()
.map(|h| h.try_into().expect("invalid header"));
self.allowed_headers.extend(iter);
self
}
pub fn expose_header<H>(mut self, header: H) -> Self
where
H: TryInto<HeaderName>,
H::Error: Debug,
{
self.exposed_headers
.insert(header.try_into().expect("illegal Header"));
self
}
pub fn expose_headers<I>(mut self, headers: I) -> Self
where
I: IntoIterator,
I::Item: TryInto<HeaderName>,
<I::Item as TryInto<HeaderName>>::Error: Debug,
{
let iter = headers
.into_iter()
.map(|h| h.try_into().expect("illegal Header"));
self.exposed_headers.extend(iter);
self
}
pub fn allow_origin<H>(mut self, origin: H) -> Self
where
H: TryInto<HeaderValue>,
H::Error: Debug,
{
self.origins = Some(origin.try_into().expect("invalid origin"));
self
}
pub fn max_age(mut self, seconds: u64) -> Self {
self.max_age = Some(seconds);
self
}
pub fn build(self) -> Cors {
let Builder {
allowed_headers,
credentials,
exposed_headers,
max_age,
origins,
methods,
} = self;
let mut cors = Cors::default();
if !allowed_headers.is_empty() {
cors.allow_headers =
Some(AccessControlAllowHeaders::from_iter(allowed_headers))
}
if credentials {
cors.credentials = Some(AccessControlAllowCredentials)
}
if !exposed_headers.is_empty() {
cors.expose_headers =
Some(AccessControlExposeHeaders::from_iter(exposed_headers))
}
if let Some(age) = max_age {
cors.max_age = Some(Duration::from_secs(age).into())
}
if origins.is_some() {
cors.allow_origin = Some(
AccessControlAllowOrigin::decode(&mut origins.iter())
.expect("invalid origins"),
);
}
if !methods.is_empty() {
cors.allow_methods = Some(AccessControlAllowMethods::from_iter(methods))
}
cors
}
}
#[async_trait(?Send)]
impl<'a, S> Middleware<'a, S> for Cors {
#[inline]
async fn handle(&'a self, ctx: &'a mut Context<S>, next: Next<'a>) -> Result {
ctx.resp.headers.append(VARY, ORIGIN.into());
let origin = match ctx.req.headers.get(ORIGIN) {
None => return next.await,
Some(origin) => AccessControlAllowOrigin::decode(
&mut Some(origin).into_iter(),
)
.map_err(|err| {
Status::new(
StatusCode::BAD_REQUEST,
format!("invalid origin: {}", err),
true,
)
})?,
};
let allow_origin = self.allow_origin.clone().unwrap_or(origin);
let credentials = self.credentials.clone();
let insert_origin_and_credentials = move |ctx: &mut Context<S>| {
ctx.resp.headers.typed_insert(allow_origin);
if let Some(credentials) = credentials {
ctx.resp.headers.typed_insert(credentials);
}
};
if ctx.method() != Method::OPTIONS {
insert_origin_and_credentials(ctx);
if let Some(ref exposed_headers) = self.expose_headers {
ctx.resp.headers.typed_insert(exposed_headers.clone());
}
next.await
} else {
let request_method =
match ctx.req.headers.typed_get::<AccessControlRequestMethod>() {
None => return next.await,
Some(request_method) => request_method,
};
let allow_methods = match self.allow_methods {
Some(ref origin) => origin.clone(),
None => {
AccessControlAllowMethods::from_iter(Some(request_method.into()))
}
};
ctx.resp.headers.typed_insert(allow_methods);
insert_origin_and_credentials(ctx);
if let Some(ref max_age) = self.max_age {
ctx.resp.headers.typed_insert(max_age.clone());
}
let allow_headers = self.allow_headers.clone().or_else(|| {
ctx.req
.headers
.typed_get::<AccessControlRequestHeaders>()
.map(|headers| AccessControlAllowHeaders::from_iter(headers.iter()))
});
if let Some(headers) = allow_headers {
ctx.resp.headers.typed_insert(headers);
};
ctx.resp.status = StatusCode::NO_CONTENT;
Ok(())
}
}
}
// Integration tests for `Cors`: each test starts a real server and sends HTTP
// requests with `reqwest`, so the module only builds with the `tcp` feature.
// NOTE(review): the server future is spawned with `async_std::task::spawn` while
// the test body runs under `#[tokio::test]`.
#[cfg(all(test, feature = "tcp"))]
mod tests {
use super::Cors;
use crate::http::header::{
ACCESS_CONTROL_ALLOW_CREDENTIALS, ACCESS_CONTROL_ALLOW_HEADERS,
ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN,
ACCESS_CONTROL_MAX_AGE, ACCESS_CONTROL_REQUEST_HEADERS,
ACCESS_CONTROL_REQUEST_METHOD, AUTHORIZATION, CONTENT_DISPOSITION, CONTENT_TYPE,
ORIGIN, VARY, WWW_AUTHENTICATE,
};
use crate::http::{HeaderValue, Method, StatusCode};
use crate::preload::*;
use crate::{App, Context};
use async_std::task::spawn;
use headers::{
AccessControlAllowCredentials, AccessControlAllowOrigin,
AccessControlExposeHeaders, HeaderMapExt, HeaderName,
};
// Terminal handler: writes "Hello, World" as the body of any request that reaches it.
async fn end(ctx: &mut Context) -> crate::Result {
ctx.resp.write("Hello, World");
Ok(())
}
// `Cors::new()` has nothing configured, so origin, method and headers are
// echoed back from the request.
#[tokio::test]
async fn default_cors() -> Result<(), Box<dyn std::error::Error>> {
let (addr, server) = App::new().gate(Cors::new()).end(end).run()?;
spawn(server);
let client = reqwest::Client::new();
// No `Origin` header: the request reaches `end` with no CORS headers except `Vary: Origin`.
let resp = client.get(&format!("http://{}", addr)).send().await?;
assert_eq!(StatusCode::OK, resp.status());
assert!(resp
.headers()
.typed_get::<AccessControlAllowOrigin>()
.is_none());
assert_eq!(
HeaderValue::from_name(ORIGIN),
resp.headers().get(VARY).unwrap()
);
assert_eq!("Hello, World", resp.text().await?);
// An `Origin` without a scheme fails to decode and is rejected with 400.
let resp = client
.get(&format!("http://{}", addr))
.header(ORIGIN, "github.com")
.send()
.await?;
assert_eq!(StatusCode::BAD_REQUEST, resp.status());
// A valid `Origin` is echoed in `Access-Control-Allow-Origin`; credentials and
// expose headers are not configured, so they are absent.
let resp = client
.get(&format!("http://{}", addr))
.header(ORIGIN, "http://github.com")
.send()
.await?;
assert_eq!(StatusCode::OK, resp.status());
let allow_origin = resp
.headers()
.typed_get::<AccessControlAllowOrigin>()
.unwrap();
let origin = allow_origin.origin().unwrap();
assert_eq!("http", origin.scheme());
assert_eq!("github.com", origin.hostname());
assert!(origin.port().is_none());
assert!(resp
.headers()
.typed_get::<AccessControlAllowCredentials>()
.is_none());
assert!(resp
.headers()
.typed_get::<AccessControlExposeHeaders>()
.is_none());
assert_eq!("Hello, World", resp.text().await?);
// `OPTIONS` without `Access-Control-Request-Method` is not a preflight: it
// reaches `end` and gets no allow-origin header.
let resp = client
.request(Method::OPTIONS, &format!("http://{}", addr))
.header(ORIGIN, "http://github.com")
.send()
.await?;
assert_eq!(StatusCode::OK, resp.status());
assert!(resp.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).is_none());
assert_eq!(
HeaderValue::from_name(ORIGIN),
resp.headers().get(VARY).unwrap()
);
assert_eq!("Hello, World", resp.text().await?);
// A real preflight is answered by the middleware with 204 and an empty body,
// echoing the requested origin, method and headers.
let resp = client
.request(Method::OPTIONS, &format!("http://{}", addr))
.header(ORIGIN, "http://github.com")
.header(ACCESS_CONTROL_REQUEST_METHOD, "POST")
.header(
ACCESS_CONTROL_REQUEST_HEADERS,
HeaderValue::from_name(CONTENT_TYPE),
)
.send()
.await?;
assert_eq!(StatusCode::NO_CONTENT, resp.status());
assert_eq!(
"http://github.com",
resp.headers()
.get(ACCESS_CONTROL_ALLOW_ORIGIN)
.unwrap()
.to_str()?
);
assert!(resp
.headers()
.get(ACCESS_CONTROL_ALLOW_CREDENTIALS)
.is_none());
assert!(resp.headers().get(ACCESS_CONTROL_MAX_AGE).is_none());
assert_eq!(
"POST",
resp.headers()
.get(ACCESS_CONTROL_ALLOW_METHODS)
.unwrap()
.to_str()?
);
assert_eq!(
HeaderValue::from_name(CONTENT_TYPE),
resp.headers().get(ACCESS_CONTROL_ALLOW_HEADERS).unwrap()
);
assert_eq!("", resp.text().await?);
Ok(())
}
// Every builder option is set, so the configured values are sent instead of
// values echoed from the request.
#[tokio::test]
async fn configured_cors() -> Result<(), Box<dyn std::error::Error>> {
let configured_cors = Cors::builder()
.allow_credentials(true)
.max_age(86400)
.allow_origin("https://github.com")
.allow_methods(vec![Method::GET, Method::POST])
.allow_method(Method::PUT)
.expose_headers(vec![CONTENT_DISPOSITION])
.expose_header(WWW_AUTHENTICATE)
.allow_headers(vec![AUTHORIZATION])
.allow_header(CONTENT_TYPE)
.build();
let (addr, server) = App::new().gate(configured_cors).end(end).run()?;
spawn(server);
let client = reqwest::Client::new();
// No `Origin` header: no allow-origin header, but `Vary: Origin` is still set.
let resp = client.get(&format!("http://{}", addr)).send().await?;
assert_eq!(StatusCode::OK, resp.status());
assert!(resp
.headers()
.typed_get::<AccessControlAllowOrigin>()
.is_none());
assert_eq!(
HeaderValue::from_name(ORIGIN),
resp.headers().get(VARY).unwrap()
);
assert_eq!("Hello, World", resp.text().await?);
// A malformed `Origin` is rejected with 400 even though an origin is configured.
let resp = client
.get(&format!("http://{}", addr))
.header(ORIGIN, "github.com")
.send()
.await?;
assert_eq!(StatusCode::BAD_REQUEST, resp.status());
// The configured origin (https://github.com) is sent regardless of the request's
// `Origin` (http://github.io), together with credentials and exposed headers.
let resp = client
.get(&format!("http://{}", addr))
.header(ORIGIN, "http://github.io")
.send()
.await?;
assert_eq!(StatusCode::OK, resp.status());
let allow_origin = resp
.headers()
.typed_get::<AccessControlAllowOrigin>()
.unwrap();
let origin = allow_origin.origin().unwrap();
assert_eq!("https", origin.scheme());
assert_eq!("github.com", origin.hostname());
assert!(origin.port().is_none());
assert!(resp
.headers()
.typed_get::<AccessControlAllowCredentials>()
.is_some());
let expose_headers = resp
.headers()
.typed_get::<AccessControlExposeHeaders>()
.unwrap();
let headers = expose_headers.iter().collect::<Vec<HeaderName>>();
assert!(headers.contains(&CONTENT_DISPOSITION));
assert!(headers.contains(&WWW_AUTHENTICATE));
assert_eq!("Hello, World", resp.text().await?);
// `OPTIONS` without `Access-Control-Request-Method` still passes through to `end`.
let resp = client
.request(Method::OPTIONS, &format!("http://{}", addr))
.header(ORIGIN, "http://github.com")
.send()
.await?;
assert_eq!(StatusCode::OK, resp.status());
assert!(resp.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).is_none());
assert_eq!(
HeaderValue::from_name(ORIGIN),
resp.headers().get(VARY).unwrap()
);
assert_eq!("Hello, World", resp.text().await?);
// A preflight gets 204 with every configured value: origin, credentials,
// max-age, all allowed methods and all allowed headers.
let resp = client
.request(Method::OPTIONS, &format!("http://{}", addr))
.header(ORIGIN, "http://github.io")
.header(ACCESS_CONTROL_REQUEST_METHOD, "POST")
.header(
ACCESS_CONTROL_REQUEST_HEADERS,
HeaderValue::from_name(CONTENT_TYPE),
)
.send()
.await?;
assert_eq!(StatusCode::NO_CONTENT, resp.status());
assert_eq!(
"https://github.com",
resp.headers()
.get(ACCESS_CONTROL_ALLOW_ORIGIN)
.unwrap()
.to_str()?
);
assert_eq!(
"true",
resp.headers()
.get(ACCESS_CONTROL_ALLOW_CREDENTIALS)
.unwrap()
.to_str()?
);
assert_eq!("86400", resp.headers().get(ACCESS_CONTROL_MAX_AGE).unwrap());
// The allowed methods/headers come from a HashSet, so only membership is checked,
// not order.
let allow_methods = resp
.headers()
.get(ACCESS_CONTROL_ALLOW_METHODS)
.unwrap()
.to_str()?;
assert!(allow_methods.contains("POST"));
assert!(allow_methods.contains("GET"));
assert!(allow_methods.contains("PUT"));
let allow_headers = resp
.headers()
.get(ACCESS_CONTROL_ALLOW_HEADERS)
.unwrap()
.to_str()?;
assert!(allow_headers.contains(CONTENT_TYPE.as_str()));
assert!(allow_headers.contains(AUTHORIZATION.as_str()));
assert_eq!("", resp.text().await?);
Ok(())
}
}