use gotham::{
handler::HandlerFuture,
helpers::http::response::create_empty_response,
hyper::{
header::{
HeaderMap, HeaderName, HeaderValue, ACCESS_CONTROL_ALLOW_CREDENTIALS,
ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS,
ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_MAX_AGE, ACCESS_CONTROL_REQUEST_HEADERS,
ACCESS_CONTROL_REQUEST_METHOD, ORIGIN, VARY
},
Body, Method, Response, StatusCode
},
middleware::Middleware,
pipeline::PipelineHandleChain,
prelude::*,
router::{builder::ExtendRouteMatcher, route::matcher::AccessControlRequestMethodMatcher},
state::State
};
use std::{panic::RefUnwindSafe, pin::Pin};
/// Specifies the value of the `Access-Control-Allow-Origin` header sent with CORS responses.
#[derive(Debug, Clone)]
pub enum Origin {
    /// Do not send an `Access-Control-Allow-Origin` header at all.
    None,
    /// Send `Access-Control-Allow-Origin: *`.
    Star,
    /// Send the given string as the allowed origin.
    Single(String),
    /// Echo back the request's `Origin` header; responses then also get `origin` added to `Vary`.
    Copy
}

impl Default for Origin {
    /// The default is [`Origin::None`], i.e. no `Access-Control-Allow-Origin` header.
    fn default() -> Self {
        Origin::None
    }
}
impl Origin {
fn header_value(&self, state: &State) -> Option<HeaderValue> {
match self {
Self::None => None,
Self::Star => Some("*".parse().unwrap()),
Self::Single(origin) => Some(origin.parse().unwrap()),
Self::Copy => {
let headers = HeaderMap::borrow_from(state);
headers.get(ORIGIN).cloned()
}
}
}
fn varies(&self) -> bool {
matches!(self, Self::Copy)
}
}
/// Specifies which request headers a CORS preflight response allows
/// (the `Access-Control-Allow-Headers` header).
#[derive(Debug, Clone)]
pub enum Headers {
    /// Do not send an `Access-Control-Allow-Headers` header.
    None,
    /// Allow exactly the listed header names.
    List(Vec<HeaderName>),
    /// Allow whatever the request lists in `Access-Control-Request-Headers`.
    Copy
}

impl Default for Headers {
    /// The default is [`Headers::None`], i.e. no `Access-Control-Allow-Headers` header.
    fn default() -> Self {
        Headers::None
    }
}
impl Headers {
fn header_value(&self, state: &State) -> Option<HeaderValue> {
match self {
Self::None => None,
Self::List(list) => Some(list.join(",").parse().unwrap()),
Self::Copy => {
let headers = HeaderMap::borrow_from(state);
headers.get(ACCESS_CONTROL_REQUEST_HEADERS).cloned()
}
}
}
fn varies(&self) -> bool {
matches!(self, Self::Copy)
}
}
/// CORS configuration, usable directly as gotham middleware.
///
/// When the middleware runs, it stores a copy of this configuration in the request
/// state, where [`handle_cors`] and the preflight handler read it back.
#[derive(Debug, Clone, Default, NewMiddleware, StateData)]
pub struct CorsConfig {
    /// Source of the `Access-Control-Allow-Origin` header.
    pub origin: Origin,
    /// Source of the `Access-Control-Allow-Headers` header on preflight responses.
    pub headers: Headers,
    /// Value of the `Access-Control-Max-Age` header on preflight responses
    /// (the header is defined in seconds).
    pub max_age: u64,
    /// Whether to send `Access-Control-Allow-Credentials: true`.
    pub credentials: bool
}
impl Middleware for CorsConfig {
    /// Stores this configuration in the request [`State`], then hands the state to the
    /// rest of the pipeline.
    fn call<Next>(self, mut state: State, next: Next) -> Pin<Box<HandlerFuture>>
    where
        Next: FnOnce(State) -> Pin<Box<HandlerFuture>>
    {
        // `handle_cors` and the preflight handler borrow the config back from the state.
        state.put(self);
        next(state)
    }
}
pub fn handle_cors(state: &State, res: &mut Response<Body>) {
let config = CorsConfig::try_borrow_from(state);
if let Some(cfg) = config {
let headers = res.headers_mut();
if let Some(header) = cfg.origin.header_value(state) {
headers.insert(ACCESS_CONTROL_ALLOW_ORIGIN, header);
}
if cfg.origin.varies() {
let vary = headers
.get(VARY)
.map(|vary| format!("{},origin", vary.to_str().unwrap()));
headers.insert(VARY, vary.as_deref().unwrap_or("origin").parse().unwrap());
}
if cfg.credentials {
headers.insert(
ACCESS_CONTROL_ALLOW_CREDENTIALS,
HeaderValue::from_static("true")
);
}
}
}
/// Extension trait for gotham route builders (anything implementing `DrawRoutes`)
/// that registers CORS preflight routes.
pub trait CorsRoute<C, P>
where
    C: PipelineHandleChain<P> + Copy + Send + Sync + 'static,
    P: RefUnwindSafe + Send + Sync + 'static
{
    /// Registers a handler for CORS preflight (`OPTIONS`) requests to `path` whose
    /// `Access-Control-Request-Method` asks for `method`.
    fn cors(&mut self, path: &str, method: Method);
}
pub(crate) fn cors_preflight_handler(state: State) -> (State, Response<Body>) {
let config = CorsConfig::try_borrow_from(&state);
let mut res = create_empty_response(&state, StatusCode::NO_CONTENT);
let headers = res.headers_mut();
let mut vary: Vec<HeaderName> = Vec::new();
let method = HeaderMap::borrow_from(&state)
.get(ACCESS_CONTROL_REQUEST_METHOD)
.unwrap()
.clone();
headers.insert(ACCESS_CONTROL_ALLOW_METHODS, method);
vary.push(ACCESS_CONTROL_REQUEST_METHOD);
if let Some(cfg) = config {
if let Some(header) = cfg.headers.header_value(&state) {
headers.insert(ACCESS_CONTROL_ALLOW_HEADERS, header);
}
if cfg.headers.varies() {
vary.push(ACCESS_CONTROL_REQUEST_HEADERS);
}
if let Some(age) = config.map(|cfg| cfg.max_age) {
headers.insert(ACCESS_CONTROL_MAX_AGE, age.into());
}
}
headers.insert(VARY, vary.join(",").parse().unwrap());
handle_cors(&state, &mut res);
(state, res)
}
impl<D, C, P> CorsRoute<C, P> for D
where
    D: DrawRoutes<C, P>,
    C: PipelineHandleChain<P> + Copy + Send + Sync + 'static,
    P: RefUnwindSafe + Send + Sync + 'static
{
    /// Adds an `OPTIONS` route on `path`, guarded by an [`AccessControlRequestMethodMatcher`]
    /// for `method`, and dispatches it to [`cors_preflight_handler`].
    fn cors(&mut self, path: &str, method: Method) {
        self.options(path)
            .extend_route_matcher(AccessControlRequestMethodMatcher::new(method))
            .to(cors_preflight_handler);
    }
}