#![cfg_attr(docsrs, doc(cfg(feature = "plugins")))]
use std::fmt;
use std::sync::Arc;
use anyhow::Result;
use http::HeaderName;
use http::HeaderValue;
use http::Method;
use http::StatusCode;
use http::header::ACCESS_CONTROL_ALLOW_CREDENTIALS;
use http::header::ACCESS_CONTROL_ALLOW_HEADERS;
use http::header::ACCESS_CONTROL_ALLOW_METHODS;
use http::header::ACCESS_CONTROL_ALLOW_ORIGIN;
use http::header::ACCESS_CONTROL_MAX_AGE;
use http::header::ACCESS_CONTROL_REQUEST_HEADERS;
use http::header::ACCESS_CONTROL_REQUEST_METHOD;
use http::header::ORIGIN;
use http::header::VARY;
use tako_rs_core::body::TakoBody;
use tako_rs_core::middleware::Next;
use tako_rs_core::plugins::TakoPlugin;
use tako_rs_core::responder::Responder;
use tako_rs_core::router::Router;
use tako_rs_core::types::Request;
use tako_rs_core::types::Response;
/// Rule used to decide whether a request's `Origin` header is allowed.
///
/// Consulted by `Config::origin_allowed` after the exact-match `origins` list.
#[derive(Clone)]
pub enum OriginMatcher {
    /// Matches when the origin string is byte-for-byte equal to the stored value
    /// (e.g. `"https://app.example.com"`).
    Exact(String),
    /// Matches when the origin's parsed host equals the stored domain or ends with
    /// `.` followed by it (a subdomain); scheme and port are not compared.
    Suffix(String),
    /// Matches when the user-supplied predicate returns `true` for the raw origin string.
    Custom(Arc<dyn Fn(&str) -> bool + Send + Sync + 'static>),
}
impl OriginMatcher {
    /// Returns `true` when `origin` (the raw `Origin` request header) satisfies this rule.
    ///
    /// For [`OriginMatcher::Suffix`] the origin is parsed as a URL and only its host is
    /// compared, ASCII case-insensitively, against the configured domain. A leading `.`
    /// on the configured domain is ignored, so `".example.com"` and `"example.com"` both
    /// accept `example.com` and any of its subdomains. Unparseable origins, origins
    /// without a host (e.g. `null`) and an empty domain never match.
    fn matches(&self, origin: &str) -> bool {
        match self {
            Self::Exact(s) => s == origin,
            Self::Suffix(s) => {
                // `Url` lowercases domain hosts, so without this normalisation a
                // mixed-case or dot-prefixed suffix would silently never match.
                let suffix = s.trim_start_matches('.');
                if suffix.is_empty() {
                    return false;
                }
                let Ok(url) = url::Url::parse(origin) else {
                    return false;
                };
                url.host_str()
                    .is_some_and(|host| Self::host_has_suffix(host, suffix))
            }
            Self::Custom(f) => f(origin),
        }
    }

    /// `true` if `host` equals `suffix`, or ends with `.` immediately followed by
    /// `suffix` (ASCII case-insensitive). Works on bytes so it can never panic on a
    /// char boundary.
    fn host_has_suffix(host: &str, suffix: &str) -> bool {
        let (host, suffix) = (host.as_bytes(), suffix.as_bytes());
        match host.len().checked_sub(suffix.len()) {
            Some(0) => host.eq_ignore_ascii_case(suffix),
            // `split >= 1` here, so `split - 1` indexes the byte before the suffix.
            Some(split) => host[split - 1] == b'.' && host[split..].eq_ignore_ascii_case(suffix),
            None => false,
        }
    }
}
impl<S: Into<String>> From<S> for OriginMatcher {
fn from(value: S) -> Self {
Self::Exact(value.into())
}
}
/// Settings that drive the CORS middleware. Usually assembled through [`CorsBuilder`].
#[derive(Clone)]
pub struct Config {
    /// Origins allowed by exact string comparison with the request's `Origin` header.
    /// When both this list and `origin_matchers` are empty, every origin is allowed
    /// and `Access-Control-Allow-Origin: *` is sent.
    pub origins: Vec<String>,
    /// Additional origin rules, consulted when no entry of `origins` matches.
    pub origin_matchers: Vec<OriginMatcher>,
    /// Methods advertised, comma-joined, in `Access-Control-Allow-Methods`; the header
    /// is omitted when this is empty.
    pub methods: Vec<Method>,
    /// Headers advertised in `Access-Control-Allow-Headers`. When empty, `*` is sent,
    /// or, if `allow_credentials` is set, the request's
    /// `Access-Control-Request-Headers` value is echoed back.
    pub headers: Vec<HeaderName>,
    /// Emit `Access-Control-Allow-Credentials: true`. `Config::validate` rejects this
    /// when neither `origins` nor `origin_matchers` has an entry.
    pub allow_credentials: bool,
    /// Value for `Access-Control-Max-Age`, in seconds; `None` omits the header.
    pub max_age_secs: Option<u32>,
    /// Answer preflights carrying `Access-Control-Request-Private-Network: true` with
    /// `Access-Control-Allow-Private-Network: true`.
    pub allow_private_network: bool,
}
impl Default for Config {
fn default() -> Self {
Self {
origins: Vec::new(),
origin_matchers: Vec::new(),
methods: vec![
Method::GET,
Method::POST,
Method::PUT,
Method::PATCH,
Method::DELETE,
Method::OPTIONS,
],
headers: Vec::new(),
allow_credentials: false,
max_age_secs: Some(3600),
allow_private_network: false,
}
}
}
impl Config {
    /// Checks the configuration for combinations browsers refuse.
    ///
    /// # Errors
    ///
    /// Returns [`CorsConfigError::CredentialsWithWildcardOrigin`] when
    /// `allow_credentials` is set but neither `origins` nor `origin_matchers`
    /// holds an entry.
    pub fn validate(&self) -> Result<(), CorsConfigError> {
        let wildcard = self.origins.is_empty() && self.origin_matchers.is_empty();
        if wildcard && self.allow_credentials {
            Err(CorsConfigError::CredentialsWithWildcardOrigin)
        } else {
            Ok(())
        }
    }

    /// Whether `origin` is accepted: first by exact match against `origins`, then by
    /// the `origin_matchers` in order, stopping at the first hit.
    fn origin_allowed(&self, origin: &str) -> bool {
        let exact_hit = self
            .origins
            .iter()
            .any(|allowed| allowed.as_str() == origin);
        exact_hit || self.origin_matchers.iter().any(|rule| rule.matches(origin))
    }
}
/// Reasons a CORS configuration can be rejected by `Config::validate`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CorsConfigError {
    /// `allow_credentials` is enabled while no origin (exact or matcher) is configured.
    CredentialsWithWildcardOrigin,
}

impl fmt::Display for CorsConfigError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            CorsConfigError::CredentialsWithWildcardOrigin => concat!(
                "CORS misconfiguration: allow_credentials = true requires at least one explicit ",
                "allowed origin; reflecting `*` together with credentials is rejected by browsers"
            ),
        };
        f.write_str(message)
    }
}

/// Carries no underlying cause, so the default `source()` (always `None`) applies.
impl std::error::Error for CorsConfigError {}
/// Fluent builder for [`CorsPlugin`]; every setter consumes the builder and returns it.
#[must_use]
pub struct CorsBuilder(Config);
impl Default for CorsBuilder {
#[inline]
fn default() -> Self {
Self::new()
}
}
impl CorsBuilder {
#[inline]
pub fn new() -> Self {
Self(Config::default())
}
#[inline]
pub fn allow_origin(mut self, o: impl Into<String>) -> Self {
self.0.origins.push(o.into());
self
}
#[inline]
pub fn allow_methods(mut self, m: &[Method]) -> Self {
self.0.methods = m.to_vec();
self
}
#[inline]
pub fn allow_headers(mut self, h: &[HeaderName]) -> Self {
self.0.headers = h.to_vec();
self
}
#[inline]
pub fn allow_credentials(mut self, allow: bool) -> Self {
self.0.allow_credentials = allow;
self
}
#[inline]
pub fn max_age_secs(mut self, secs: u32) -> Self {
self.0.max_age_secs = Some(secs);
self
}
#[inline]
pub fn allow_origin_suffix(mut self, suffix: impl Into<String>) -> Self {
self
.0
.origin_matchers
.push(OriginMatcher::Suffix(suffix.into()));
self
}
#[inline]
pub fn allow_origin_predicate<F>(mut self, f: F) -> Self
where
F: Fn(&str) -> bool + Send + Sync + 'static,
{
self
.0
.origin_matchers
.push(OriginMatcher::Custom(Arc::new(f)));
self
}
#[inline]
pub fn allow_private_network(mut self, yes: bool) -> Self {
self.0.allow_private_network = yes;
self
}
#[inline]
pub fn build(self) -> CorsPlugin {
self.try_build().expect("invalid CORS configuration")
}
#[inline]
pub fn try_build(self) -> Result<CorsPlugin, CorsConfigError> {
self.0.validate()?;
Ok(CorsPlugin { cfg: self.0 })
}
}
#[derive(Clone)]
#[doc(alias = "cors")]
pub struct CorsPlugin {
cfg: Config,
}
impl Default for CorsPlugin {
fn default() -> Self {
Self {
cfg: Config::default(),
}
}
}
impl TakoPlugin for CorsPlugin {
fn name(&self) -> &'static str {
"CorsPlugin"
}
fn setup(&self, router: &Router) -> Result<()> {
let cfg = self.cfg.clone();
router.middleware(move |req, next| {
let cfg = cfg.clone();
async move { handle_cors(req, next, cfg).await }
});
Ok(())
}
}
async fn handle_cors(req: Request, next: Next, cfg: Config) -> impl Responder {
let origin = req.headers().get(ORIGIN).cloned();
let request_headers = req.headers().get(ACCESS_CONTROL_REQUEST_HEADERS).cloned();
let pna_request = req
.headers()
.get("access-control-request-private-network")
.and_then(|v| v.to_str().ok())
.is_some_and(|v| v.eq_ignore_ascii_case("true"));
let is_preflight = req.method() == Method::OPTIONS
&& origin.is_some()
&& req.headers().contains_key(ACCESS_CONTROL_REQUEST_METHOD);
if is_preflight {
let mut resp = http::Response::builder()
.status(StatusCode::NO_CONTENT)
.body(TakoBody::empty())
.expect("valid CORS preflight response");
add_cors_headers(
&cfg,
origin,
request_headers.as_ref(),
pna_request,
&mut resp,
);
return resp.into_response();
}
let mut resp = next.run(req).await;
add_cors_headers(&cfg, origin, request_headers.as_ref(), false, &mut resp);
resp.into_response()
}
fn add_cors_headers(
cfg: &Config,
origin: Option<HeaderValue>,
request_headers: Option<&HeaderValue>,
pna_request: bool,
resp: &mut Response,
) {
let allow_anything = cfg.origins.is_empty() && cfg.origin_matchers.is_empty();
let (allow_origin, mirrored_origin) = if allow_anything {
("*".to_string(), false)
} else if let Some(o) = &origin {
let Ok(s) = o.to_str() else {
return;
};
if cfg.origin_allowed(s) {
(s.to_string(), true)
} else {
return;
}
} else {
return;
};
let Ok(value) = HeaderValue::from_str(&allow_origin) else {
return;
};
resp
.headers_mut()
.insert(ACCESS_CONTROL_ALLOW_ORIGIN, value);
if mirrored_origin {
resp
.headers_mut()
.append(VARY, HeaderValue::from_static("Origin"));
}
let methods = if cfg.methods.is_empty() {
None
} else {
Some(
cfg
.methods
.iter()
.map(http::Method::as_str)
.collect::<Vec<_>>()
.join(","),
)
};
if let Some(v) = methods
&& let Ok(hv) = HeaderValue::from_str(&v)
{
resp.headers_mut().insert(ACCESS_CONTROL_ALLOW_METHODS, hv);
}
if cfg.headers.is_empty() {
if cfg.allow_credentials {
static WARNED: std::sync::OnceLock<()> = std::sync::OnceLock::new();
let () = WARNED.get_or_init(|| {
tracing::warn!(
"CORS reflects `Access-Control-Request-Headers` while `allow_credentials=true` and no explicit `headers(...)` list is configured — set an explicit allow-list to harden the preflight policy",
);
});
if let Some(req_h) = request_headers {
resp
.headers_mut()
.insert(ACCESS_CONTROL_ALLOW_HEADERS, req_h.clone());
resp.headers_mut().append(
VARY,
HeaderValue::from_static("Access-Control-Request-Headers"),
);
}
} else {
resp
.headers_mut()
.insert(ACCESS_CONTROL_ALLOW_HEADERS, HeaderValue::from_static("*"));
}
} else {
let h = cfg
.headers
.iter()
.map(http::HeaderName::as_str)
.collect::<Vec<_>>()
.join(",");
if let Ok(hv) = HeaderValue::from_str(&h) {
resp.headers_mut().insert(ACCESS_CONTROL_ALLOW_HEADERS, hv);
}
}
if cfg.allow_credentials {
resp.headers_mut().insert(
ACCESS_CONTROL_ALLOW_CREDENTIALS,
HeaderValue::from_static("true"),
);
}
if let Some(secs) = cfg.max_age_secs
&& let Ok(hv) = HeaderValue::from_str(&secs.to_string())
{
resp.headers_mut().insert(ACCESS_CONTROL_MAX_AGE, hv);
}
if cfg.allow_private_network && pna_request {
resp.headers_mut().insert(
"access-control-allow-private-network",
HeaderValue::from_static("true"),
);
}
}