use crate::middleware::{Middleware, Next};
use crate::{Request, Response, Result};
use http::{header, Method, StatusCode};
use regex::Regex;
use std::collections::HashSet;
use std::sync::Arc;
use tracing::debug;
/// Builder for a CORS policy.
///
/// All settings live behind an `Arc`, so cloning a `Cors` is cheap; the builder
/// methods obtain a private copy via `Arc::make_mut` before changing anything.
/// Finish with [`Cors::build`] to obtain a [`CorsMiddleware`].
#[derive(Clone)]
pub struct Cors {
    /// Shared policy settings.
    inner: Arc<Inner>,
}
/// Policy settings shared by [`Cors`] and the middleware built from it.
///
/// `Clone` is required so the builder can copy-on-write through `Arc::make_mut`.
#[derive(Clone)]
struct Inner {
    // Origins.
    /// Which request origins are accepted.
    allowed_origins: AllowedOrigins,
    /// `true` exactly when `allowed_origins` is `AllowedOrigins::Any`; the
    /// builder methods update both fields together.
    allowed_origins_all: bool,

    // Preflight answers.
    /// Methods a preflight may request.
    allowed_methods: HashSet<Method>,
    /// Lower-case request header names a preflight may request.
    allowed_headers: HashSet<String>,
    /// Value for `Access-Control-Max-Age`, in seconds.
    max_age: Option<u64>,

    // Actual responses.
    /// Names listed in `Access-Control-Expose-Headers`.
    exposed_headers: Vec<String>,

    // Credentials.
    /// Gates `Access-Control-Allow-Credentials: true` and the wildcard check in
    /// `Cors::validate`.
    supports_credentials: bool,
    /// Always set together with `supports_credentials`.
    /// NOTE(review): not read anywhere in this file — likely redundant; confirm.
    allow_credentials: bool,
}
/// The origin-acceptance strategy of a CORS policy.
#[derive(Clone)]
enum AllowedOrigins {
    /// Every non-empty origin is accepted.
    Any,
    /// Only the listed origin strings are accepted (exact string comparison).
    Exact(HashSet<String>),
    /// Origins are tested against this compiled expression with `Regex::is_match`.
    Regex(Regex),
}
impl Cors {
    /// Creates a builder with the default policy.
    ///
    /// The defaults are:
    /// - any origin;
    /// - methods `GET`, `POST`, `PUT`, `DELETE`, `PATCH`, `OPTIONS` and `HEAD`;
    /// - request headers `content-type`, `authorization` and `x-requested-with`;
    /// - a preflight max-age of 86400 seconds;
    /// - credentials disabled.
    pub fn new() -> Self {
        let methods: HashSet<Method> = [
            Method::GET,
            Method::POST,
            Method::PUT,
            Method::DELETE,
            Method::PATCH,
            Method::OPTIONS,
            Method::HEAD,
        ]
        .iter()
        .cloned()
        .collect();
        let headers: HashSet<String> = ["content-type", "authorization", "x-requested-with"]
            .iter()
            .map(|name| name.to_string())
            .collect();
        Self {
            inner: Arc::new(Inner {
                allowed_origins: AllowedOrigins::Any,
                allowed_methods: methods,
                allowed_headers: headers,
                exposed_headers: Vec::new(),
                max_age: Some(86400),
                allow_credentials: false,
                supports_credentials: false,
                allowed_origins_all: true,
            }),
        }
    }

    /// Accepts requests from any origin (`Access-Control-Allow-Origin: *`).
    pub fn allow_any_origin(mut self) -> Self {
        let inner = Arc::make_mut(&mut self.inner);
        inner.allowed_origins = AllowedOrigins::Any;
        inner.allowed_origins_all = true;
        self
    }

    /// Adds `origin` to the set of exactly-matched allowed origins.
    ///
    /// If the policy currently allows any origin, or uses a regex, it is
    /// replaced by an exact set containing only `origin`.
    pub fn allowed_origin(mut self, origin: &str) -> Self {
        let inner = Arc::make_mut(&mut self.inner);
        if let AllowedOrigins::Exact(ref mut origins) = inner.allowed_origins {
            origins.insert(origin.to_string());
        } else {
            let mut origins = HashSet::new();
            origins.insert(origin.to_string());
            inner.allowed_origins = AllowedOrigins::Exact(origins);
            inner.allowed_origins_all = false;
        }
        self
    }

    /// Replaces the origin policy with exactly the given `origins`.
    pub fn allowed_origins(mut self, origins: &[&str]) -> Self {
        let inner = Arc::make_mut(&mut self.inner);
        inner.allowed_origins =
            AllowedOrigins::Exact(origins.iter().map(|&origin| origin.to_string()).collect());
        inner.allowed_origins_all = false;
        self
    }

    /// Allows every origin that the regular expression `pattern` matches in full.
    ///
    /// The pattern is anchored as `^(?:pattern)$` before compiling. An unanchored
    /// `is_match` would accept any origin that merely *contains* a match. For
    /// example, `https://example\.com` would otherwise also allow
    /// `https://example.com.attacker.net`.
    ///
    /// # Panics
    ///
    /// Panics if `pattern` is not a valid regular expression.
    pub fn allowed_origin_regex(mut self, pattern: &str) -> Self {
        let anchored = format!("^(?:{})$", pattern);
        let regex = Regex::new(&anchored).expect("Invalid regex pattern for CORS origins");
        let inner = Arc::make_mut(&mut self.inner);
        inner.allowed_origins = AllowedOrigins::Regex(regex);
        inner.allowed_origins_all = false;
        self
    }

    /// Replaces the set of methods a preflight may request.
    pub fn allowed_methods(mut self, methods: &[Method]) -> Self {
        let inner = Arc::make_mut(&mut self.inner);
        inner.allowed_methods = methods.iter().cloned().collect();
        self
    }

    /// Replaces the set of request headers a preflight may request.
    ///
    /// Names are stored lower-cased; preflight checks compare lower-cased names.
    pub fn allowed_headers(mut self, headers: &[&str]) -> Self {
        let inner = Arc::make_mut(&mut self.inner);
        inner.allowed_headers = headers.iter().map(|&name| name.to_lowercase()).collect();
        self
    }

    /// Sets the response headers listed in `Access-Control-Expose-Headers`.
    pub fn expose_headers(mut self, headers: &[&str]) -> Self {
        let inner = Arc::make_mut(&mut self.inner);
        inner.exposed_headers = headers.iter().map(|&name| name.to_string()).collect();
        self
    }

    /// Sets `Access-Control-Max-Age` (seconds) for preflight responses.
    pub fn max_age(mut self, seconds: u64) -> Self {
        let inner = Arc::make_mut(&mut self.inner);
        inner.max_age = Some(seconds);
        self
    }

    /// Enables `Access-Control-Allow-Credentials: true`.
    ///
    /// [`Cors::build`] rejects this in combination with a wildcard origin.
    pub fn allow_credentials(mut self) -> Self {
        let inner = Arc::make_mut(&mut self.inner);
        inner.allow_credentials = true;
        inner.supports_credentials = true;
        self
    }

    /// Disables `Access-Control-Allow-Credentials`.
    pub fn disable_credentials(mut self) -> Self {
        let inner = Arc::make_mut(&mut self.inner);
        inner.allow_credentials = false;
        inner.supports_credentials = false;
        self
    }

    /// Validates the policy and wraps it in a [`CorsMiddleware`].
    ///
    /// # Errors
    ///
    /// Returns `crate::Error::Internal` when credentials are enabled together
    /// with a wildcard origin. A wildcard is either `allow_any_origin`, or a
    /// literal `"*"` entry in the exact origin set.
    pub fn build(self) -> Result<CorsMiddleware> {
        self.validate()?;
        Ok(CorsMiddleware { cors: self })
    }

    /// Rejects configurations the CORS protocol forbids (credentials + `*`).
    fn validate(&self) -> Result<()> {
        let inner = &self.inner;
        let wildcard_origin = match &inner.allowed_origins {
            AllowedOrigins::Any => true,
            AllowedOrigins::Exact(origins) => origins.contains("*"),
            AllowedOrigins::Regex(_) => false,
        };
        if inner.supports_credentials && wildcard_origin {
            return Err(crate::Error::Internal(
                "Cannot use wildcard origin with credentials".into(),
            ));
        }
        debug!("CORS configuration validated successfully");
        Ok(())
    }
}
/// CORS middleware.
///
/// It is produced by [`Cors::build`], or by the [`CorsMiddleware::permissive`] /
/// [`CorsMiddleware::strict`] presets. Cloning is cheap because the wrapped
/// [`Cors`] keeps its settings behind an `Arc`.
#[derive(Clone)]
pub struct CorsMiddleware {
    /// The policy this middleware enforces.
    cors: Cors,
}
impl CorsMiddleware {
pub fn new() -> Cors {
Cors::new()
}
pub fn permissive() -> Self {
Self {
cors: Cors::new().allow_any_origin().max_age(3600),
}
}
pub fn strict(origins: &[&str]) -> Self {
Self {
cors: Cors::new()
.allowed_origins(origins)
.allowed_methods(&[Method::GET, Method::POST, Method::PUT, Method::DELETE])
.max_age(3600),
}
}
}
#[async_trait::async_trait]
impl Middleware for CorsMiddleware {
async fn handle(&self, req: Request, next: Next) -> Response {
if req.method == Method::OPTIONS {
if self.is_preflight_request(&req) {
return self.handle_preflight(&req);
}
}
let mut res = next.run(req).await;
self.add_cors_headers(&mut res);
res
}
}
impl CorsMiddleware {
fn is_preflight_request(&self, req: &Request) -> bool {
req.header("access-control-request-method").is_some()
}
fn handle_preflight(&self, req: &Request) -> Response {
let origin = match req.header("origin") {
Some(origin) => origin,
None => {
return Response::new(StatusCode::BAD_REQUEST);
}
};
if !self.is_origin_allowed(origin) {
debug!("CORS preflight rejected: origin not allowed: {}", origin);
return Response::new(StatusCode::FORBIDDEN);
}
if let Some(request_method) = req.header("access-control-request-method") {
if let Ok(method) = Method::from_bytes(request_method.as_bytes()) {
if !self.cors.inner.allowed_methods.contains(&method) {
debug!(
"CORS preflight rejected: method not allowed: {}",
request_method
);
return Response::new(StatusCode::METHOD_NOT_ALLOWED);
}
}
}
if let Some(request_headers) = req.header("access-control-request-headers") {
for header in request_headers.split(',') {
let header = header.trim().to_lowercase();
if !self.cors.inner.allowed_headers.contains(&header) {
debug!("CORS preflight rejected: header not allowed: {}", header);
return Response::new(StatusCode::BAD_REQUEST);
}
}
}
let mut response = Response::new(StatusCode::NO_CONTENT);
self.add_preflight_headers(&mut response, origin);
debug!("CORS preflight request handled successfully");
response
}
fn is_origin_allowed(&self, origin: &str) -> bool {
if origin.is_empty() {
return false;
}
match &self.cors.inner.allowed_origins {
AllowedOrigins::Any => true,
AllowedOrigins::Exact(origins) => origins.contains(origin),
AllowedOrigins::Regex(regex) => regex.is_match(origin),
}
}
fn add_cors_headers(&self, response: &mut Response) {
let inner = &self.cors.inner;
if inner.allowed_origins_all {
response
.headers
.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*".parse().unwrap());
}
if !inner.exposed_headers.is_empty() {
response.headers.insert(
header::ACCESS_CONTROL_EXPOSE_HEADERS,
inner.exposed_headers.join(", ").parse().unwrap(),
);
}
if inner.supports_credentials {
response.headers.insert(
header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
"true".parse().unwrap(),
);
}
}
fn add_preflight_headers(&self, response: &mut Response, origin: &str) {
let inner = &self.cors.inner;
response.headers.insert(
header::ACCESS_CONTROL_ALLOW_ORIGIN,
if inner.allowed_origins_all {
"*".parse().unwrap()
} else {
origin.parse().unwrap()
},
);
let methods: Vec<String> = inner
.allowed_methods
.iter()
.map(|m| m.as_str().to_string())
.collect();
response.headers.insert(
header::ACCESS_CONTROL_ALLOW_METHODS,
methods.join(", ").parse().unwrap(),
);
if !inner.allowed_headers.is_empty() {
response.headers.insert(
header::ACCESS_CONTROL_ALLOW_HEADERS,
inner
.allowed_headers
.iter()
.cloned()
.collect::<Vec<_>>()
.join(", ")
.parse()
.unwrap(),
);
}
if let Some(max_age) = inner.max_age {
response.headers.insert(
header::ACCESS_CONTROL_MAX_AGE,
max_age.to_string().parse().unwrap(),
);
}
if inner.supports_credentials {
response.headers.insert(
header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
"true".parse().unwrap(),
);
}
}
}
impl Default for CorsMiddleware {
fn default() -> Self {
Self::permissive()
}
}