use bytes::Bytes;
use http::{header, Method, StatusCode};
use http_body_util::Full;
use rustapi_core::middleware::{BoxedNext, MiddlewareLayer};
use rustapi_core::{Request, Response, ResponseBody};
use std::future::Future;
use std::pin::Pin;
use std::time::Duration;
/// Which request origins a CORS policy accepts.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AllowedOrigins {
    /// Every origin is accepted.
    Any,
    /// Only origins exactly equal to one of these strings are accepted.
    List(Vec<String>),
}
impl Default for AllowedOrigins {
fn default() -> Self {
Self::List(Vec::new())
}
}
/// CORS middleware configuration, assembled through the constructors and builder methods.
#[derive(Clone, Debug)]
pub struct CorsLayer {
    /// Origins whose requests receive CORS response headers.
    origins: AllowedOrigins,
    /// Methods advertised in `Access-Control-Allow-Methods`.
    methods: Vec<Method>,
    /// Headers advertised in `Access-Control-Allow-Headers`.
    headers: Vec<String>,
    /// Whether `Access-Control-Allow-Credentials: true` is emitted.
    credentials: bool,
    /// Optional value for `Access-Control-Max-Age` on preflight responses.
    max_age: Option<Duration>,
}
impl Default for CorsLayer {
fn default() -> Self {
Self::new()
}
}
impl CorsLayer {
    /// Creates a locked-down policy.
    ///
    /// No origin is allowed (the origin list is empty), and only `GET`, `HEAD` and `OPTIONS`
    /// are advertised. No allowed headers are configured, so preflights advertise
    /// `Content-Type, Authorization`. Credentials are disabled and no max-age is sent.
    #[must_use]
    pub fn new() -> Self {
        Self {
            origins: AllowedOrigins::default(),
            methods: vec![Method::GET, Method::HEAD, Method::OPTIONS],
            headers: Vec::new(),
            credentials: false,
            max_age: None,
        }
    }

    /// Creates a wide-open policy.
    ///
    /// Any origin is accepted. The common methods (`GET`, `POST`, `PUT`, `PATCH`, `DELETE`,
    /// `HEAD`, `OPTIONS`) are advertised, and the header list is the single wildcard `"*"`.
    /// Preflights may be cached for 86 400 seconds (24 hours). Credentials stay disabled.
    #[must_use]
    pub fn permissive() -> Self {
        Self {
            origins: AllowedOrigins::Any,
            methods: vec![
                Method::GET,
                Method::POST,
                Method::PUT,
                Method::PATCH,
                Method::DELETE,
                Method::HEAD,
                Method::OPTIONS,
            ],
            headers: vec!["*".to_string()],
            credentials: false,
            max_age: Some(Duration::from_secs(86400)),
        }
    }

    /// Alias for [`CorsLayer::new`] for call sites that want to state intent explicitly.
    #[must_use]
    pub fn restrictive() -> Self {
        Self::new()
    }

    /// Accepts requests from any origin, replacing any previously configured origin list.
    #[must_use]
    pub fn allow_any_origin(mut self) -> Self {
        self.origins = AllowedOrigins::Any;
        self
    }

    /// Accepts exactly the given origins, replacing any previous origin setting.
    ///
    /// Entries are compared byte-for-byte with the request's `Origin` header. Write them in
    /// serialized form, e.g. `"https://example.com"`, with no trailing slash.
    #[must_use]
    pub fn allow_origins<I, S>(mut self, origins: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.origins = AllowedOrigins::List(origins.into_iter().map(Into::into).collect());
        self
    }

    /// Replaces the methods advertised in `Access-Control-Allow-Methods`.
    #[must_use]
    pub fn allow_methods<I>(mut self, methods: I) -> Self
    where
        I: IntoIterator<Item = Method>,
    {
        self.methods = methods.into_iter().collect();
        self
    }

    /// Replaces the headers advertised in `Access-Control-Allow-Headers`.
    ///
    /// - An empty list makes preflights advertise `Content-Type, Authorization`.
    /// - A list containing only `"*"` makes preflights echo the request's
    ///   `Access-Control-Request-Headers`, or send `*` when that header is absent or blank.
    #[must_use]
    pub fn allow_headers<I, S>(mut self, headers: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.headers = headers.into_iter().map(Into::into).collect();
        self
    }

    /// Enables or disables `Access-Control-Allow-Credentials: true`.
    #[must_use]
    pub fn allow_credentials(mut self, allow: bool) -> Self {
        self.credentials = allow;
        self
    }

    /// Sets how long browsers may cache a preflight result (`Access-Control-Max-Age`).
    ///
    /// The header is sent in whole seconds; any sub-second remainder is truncated.
    #[must_use]
    pub fn max_age(mut self, duration: Duration) -> Self {
        self.max_age = Some(duration);
        self
    }

    /// Returns the configured origin policy.
    pub fn origins(&self) -> &AllowedOrigins {
        &self.origins
    }

    /// Returns the methods advertised to preflight requests.
    pub fn methods(&self) -> &[Method] {
        &self.methods
    }

    /// Returns the configured allowed headers, exactly as given (possibly empty or `["*"]`).
    pub fn headers(&self) -> &[String] {
        &self.headers
    }

    /// Returns whether credentials are allowed.
    pub fn credentials(&self) -> bool {
        self.credentials
    }

    /// Returns the configured preflight max-age, if any.
    pub fn max_age_duration(&self) -> Option<Duration> {
        self.max_age
    }

    /// Renders the configured methods as a comma-separated header value.
    fn methods_header_value(&self) -> String {
        self.methods
            .iter()
            .map(|m| m.as_str())
            .collect::<Vec<_>>()
            .join(", ")
    }

    /// Renders the configured headers as a comma-separated header value.
    ///
    /// Falls back to `Content-Type, Authorization` when no headers are configured.
    fn headers_header_value(&self) -> String {
        if self.headers.is_empty() {
            "Content-Type, Authorization".to_string()
        } else {
            self.headers.join(", ")
        }
    }
}
impl MiddlewareLayer for CorsLayer {
    /// Applies this CORS policy to one request.
    ///
    /// - **Preflight requests** (`OPTIONS` with an `Access-Control-Request-Method` header) are
    ///   answered here with `204 No Content` and never reach `next`. The response carries:
    ///   - `Access-Control-Allow-Methods` and `Access-Control-Allow-Headers`;
    ///   - `Access-Control-Allow-Origin`, when the origin is allowed;
    ///   - `Access-Control-Allow-Credentials` and `Access-Control-Max-Age`, when configured.
    /// - **Any other request** is forwarded to `next`. If its `Origin` is allowed, the response
    ///   gets `Access-Control-Allow-Origin` and `Access-Control-Expose-Headers`, plus
    ///   `Access-Control-Allow-Credentials` when enabled.
    ///
    /// `Access-Control-Allow-Origin` is the literal `*` only for [`AllowedOrigins::Any`]
    /// without credentials; otherwise the request's own origin is echoed. In the echoing case
    /// every response also gets `Vary: Origin`, so shared caches do not replay one origin's
    /// CORS headers to another.
    ///
    /// A configured method or header string that is not a valid HTTP header value causes
    /// that header to be omitted instead of panicking.
    fn call(
        &self,
        req: Request,
        next: BoxedNext,
    ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
        let credentials = self.credentials;
        let max_age = self.max_age;
        // Browsers reject `*` together with credentials, so only `Any` without credentials may
        // use the wildcard. Every other configuration echoes the request origin, which makes
        // the response depend on the `Origin` request header.
        let echoes_origin = !matches!(self.origins, AllowedOrigins::Any) || credentials;
        // A lone `"*"` entry means "allow whatever headers the preflight asks for".
        let wildcard_headers = matches!(self.headers.as_slice(), [only] if only == "*");

        // Value for `Access-Control-Allow-Origin`; `None` when the header must be omitted
        // (no `Origin` header, a non-UTF-8 one, or an origin outside the allowed list).
        let allow_origin = req
            .headers()
            .get(header::ORIGIN)
            .and_then(|v| v.to_str().ok())
            .map(String::from)
            .filter(|o| match &self.origins {
                AllowedOrigins::Any => true,
                AllowedOrigins::List(list) => list.iter().any(|allowed| allowed == o),
            })
            .map(|o| if echoes_origin { o } else { "*".to_string() });

        let is_preflight = req.method() == Method::OPTIONS
            && req
                .headers()
                .contains_key(header::ACCESS_CONTROL_REQUEST_METHOD);
        // Preflight-only header values; computed eagerly (while `req` is still borrowed here)
        // but skipped entirely for ordinary requests.
        let preflight = is_preflight.then(|| {
            let allow_headers = if wildcard_headers {
                req.headers()
                    .get(header::ACCESS_CONTROL_REQUEST_HEADERS)
                    .and_then(|value| value.to_str().ok())
                    .filter(|value| !value.trim().is_empty())
                    .map(str::to_string)
                    .unwrap_or_else(|| "*".to_string())
            } else {
                self.headers_header_value()
            };
            (self.methods_header_value(), allow_headers)
        });

        Box::pin(async move {
            let mut response = match preflight {
                Some((allow_methods, allow_headers)) => {
                    // Preflights are answered here and never reach the inner handler.
                    let mut response =
                        http::Response::new(ResponseBody::Full(Full::new(Bytes::new())));
                    *response.status_mut() = StatusCode::NO_CONTENT;
                    let headers = response.headers_mut();
                    if let Some(value) = allow_origin
                        .as_deref()
                        .and_then(|o| http::HeaderValue::from_str(o).ok())
                    {
                        headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, value);
                    }
                    // Configured values may contain bytes that are illegal in a header value;
                    // omit the header rather than panicking on every preflight.
                    if let Ok(value) = http::HeaderValue::from_str(&allow_methods) {
                        headers.insert(header::ACCESS_CONTROL_ALLOW_METHODS, value);
                    }
                    if let Ok(value) = http::HeaderValue::from_str(&allow_headers) {
                        headers.insert(header::ACCESS_CONTROL_ALLOW_HEADERS, value);
                    }
                    if credentials {
                        headers.insert(
                            header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
                            http::HeaderValue::from_static("true"),
                        );
                    }
                    if let Some(max_age) = max_age {
                        headers.insert(
                            header::ACCESS_CONTROL_MAX_AGE,
                            http::HeaderValue::from(max_age.as_secs()),
                        );
                    }
                    if wildcard_headers {
                        // The allowed headers were copied from the request, so caches must
                        // key this preflight response on them.
                        headers.append(
                            header::VARY,
                            http::HeaderValue::from_static("Access-Control-Request-Headers"),
                        );
                    }
                    response
                }
                None => {
                    let mut response = next(req).await;
                    if let Some(allow_origin) = &allow_origin {
                        let headers = response.headers_mut();
                        if let Ok(value) = http::HeaderValue::from_str(allow_origin) {
                            headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, value);
                        }
                        if credentials {
                            headers.insert(
                                header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
                                http::HeaderValue::from_static("true"),
                            );
                        }
                        headers.insert(
                            header::ACCESS_CONTROL_EXPOSE_HEADERS,
                            http::HeaderValue::from_static("Content-Length, Content-Type"),
                        );
                    }
                    response
                }
            };
            if echoes_origin {
                // The CORS headers depend on the request's `Origin` here, so shared caches must
                // not reuse this response across origins. `append` keeps any `Vary` values the
                // handler already set.
                response
                    .headers_mut()
                    .append(header::VARY, http::HeaderValue::from_static("Origin"));
            }
            response
        })
    }

    fn clone_box(&self) -> Box<dyn MiddlewareLayer> {
        Box::new(self.clone())
    }
}