#![cfg_attr(docsrs, doc(cfg(feature = "plugins")))]
use anyhow::Result;
use http::HeaderName;
use http::HeaderValue;
use http::Method;
use http::StatusCode;
use http::header::ACCESS_CONTROL_ALLOW_CREDENTIALS;
use http::header::ACCESS_CONTROL_ALLOW_HEADERS;
use http::header::ACCESS_CONTROL_ALLOW_METHODS;
use http::header::ACCESS_CONTROL_ALLOW_ORIGIN;
use http::header::ACCESS_CONTROL_MAX_AGE;
use http::header::ORIGIN;
use crate::body::TakoBody;
use crate::middleware::Next;
use crate::plugins::TakoPlugin;
use crate::responder::Responder;
use crate::router::Router;
use crate::types::Request;
use crate::types::Response;
/// Settings that control which CORS headers [`CorsPlugin`] attaches to responses.
#[derive(Clone, Debug)]
pub struct Config {
    /// Allowed origins, compared exactly against the request's `Origin` header.
    /// When empty, any origin is accepted and `Access-Control-Allow-Origin: *` is sent.
    pub origins: Vec<String>,
    /// Methods advertised via `Access-Control-Allow-Methods` (comma-joined).
    /// The header is omitted when this list is empty.
    pub methods: Vec<Method>,
    /// Request headers advertised via `Access-Control-Allow-Headers` (comma-joined).
    /// `*` is sent when this list is empty.
    pub headers: Vec<HeaderName>,
    /// When `true`, `Access-Control-Allow-Credentials: true` is added to responses.
    pub allow_credentials: bool,
    /// Value for `Access-Control-Max-Age`, in seconds; the header is omitted when `None`.
    pub max_age_secs: Option<u32>,
}
impl Default for Config {
fn default() -> Self {
Self {
origins: Vec::new(),
methods: vec![
Method::GET,
Method::POST,
Method::PUT,
Method::PATCH,
Method::DELETE,
Method::OPTIONS,
],
headers: Vec::new(),
allow_credentials: false,
max_age_secs: Some(3600),
}
}
}
/// Fluent builder for a [`CorsPlugin`]; starts from `Config::default()` and is finished
/// with `build()`.
#[must_use]
pub struct CorsBuilder(Config);
impl Default for CorsBuilder {
#[inline]
fn default() -> Self {
Self::new()
}
}
impl CorsBuilder {
#[inline]
#[must_use]
pub fn new() -> Self {
Self(Config::default())
}
#[inline]
#[must_use]
pub fn allow_origin(mut self, o: impl Into<String>) -> Self {
self.0.origins.push(o.into());
self
}
#[inline]
#[must_use]
pub fn allow_methods(mut self, m: &[Method]) -> Self {
self.0.methods = m.to_vec();
self
}
#[inline]
#[must_use]
pub fn allow_headers(mut self, h: &[HeaderName]) -> Self {
self.0.headers = h.to_vec();
self
}
#[inline]
#[must_use]
pub fn allow_credentials(mut self, allow: bool) -> Self {
self.0.allow_credentials = allow;
self
}
#[inline]
#[must_use]
pub fn max_age_secs(mut self, secs: u32) -> Self {
self.0.max_age_secs = Some(secs);
self
}
#[inline]
#[must_use]
pub fn build(self) -> CorsPlugin {
CorsPlugin { cfg: self.0 }
}
}
#[derive(Clone)]
#[doc(alias = "cors")]
pub struct CorsPlugin {
cfg: Config,
}
impl Default for CorsPlugin {
fn default() -> Self {
Self {
cfg: Config::default(),
}
}
}
impl TakoPlugin for CorsPlugin {
    fn name(&self) -> &'static str {
        "CorsPlugin"
    }

    /// Registers the CORS middleware on `router`.
    ///
    /// The configuration is snapshotted once and shared behind an `Arc`. Each request then
    /// only bumps a reference count instead of deep-cloning the origin, method and header
    /// lists.
    fn setup(&self, router: &Router) -> Result<()> {
        let cfg = std::sync::Arc::new(self.cfg.clone());
        router.middleware(move |req, next| {
            let cfg = std::sync::Arc::clone(&cfg);
            async move { handle_cors(req, next, cfg).await }
        });
        Ok(())
    }
}

/// Middleware body that applies the CORS policy in `cfg` to one request.
///
/// Every `OPTIONS` request is treated as a preflight. It is answered directly with
/// `204 No Content`, plus CORS headers, without being forwarded to `next`. All other requests
/// are passed downstream, and the CORS headers are added to the response that comes back.
async fn handle_cors(req: Request, next: Next, cfg: std::sync::Arc<Config>) -> impl Responder {
    // Capture the Origin before `req` is moved into `next.run`.
    let origin = req.headers().get(ORIGIN).cloned();
    if req.method() == Method::OPTIONS {
        let mut resp = http::Response::builder()
            .status(StatusCode::NO_CONTENT)
            .body(TakoBody::empty())
            .expect("valid CORS preflight response");
        add_cors_headers(&cfg, origin, &mut resp);
        return resp.into_response();
    }
    let mut resp = next.run(req).await;
    add_cors_headers(&cfg, origin, &mut resp);
    resp.into_response()
}
fn add_cors_headers(cfg: &Config, origin: Option<HeaderValue>, resp: &mut Response) {
let allow_origin = if cfg.origins.is_empty() {
"*".to_string()
} else if let Some(o) = &origin {
let s = o.to_str().unwrap_or_default();
if cfg.origins.iter().any(|p| p == s) {
s.to_string()
} else {
return; }
} else {
return; };
resp.headers_mut().insert(
ACCESS_CONTROL_ALLOW_ORIGIN,
HeaderValue::from_str(&allow_origin).expect("valid origin header value"),
);
let methods = if cfg.methods.is_empty() {
None
} else {
Some(
cfg
.methods
.iter()
.map(|m| m.as_str())
.collect::<Vec<_>>()
.join(","),
)
};
if let Some(v) = methods {
resp.headers_mut().insert(
ACCESS_CONTROL_ALLOW_METHODS,
HeaderValue::from_str(&v).expect("valid methods header value"),
);
}
if cfg.headers.is_empty() {
resp
.headers_mut()
.insert(ACCESS_CONTROL_ALLOW_HEADERS, HeaderValue::from_static("*"));
} else {
let h = cfg
.headers
.iter()
.map(|h| h.as_str())
.collect::<Vec<_>>()
.join(",");
resp.headers_mut().insert(
ACCESS_CONTROL_ALLOW_HEADERS,
HeaderValue::from_str(&h).expect("valid headers header value"),
);
}
if cfg.allow_credentials {
resp.headers_mut().insert(
ACCESS_CONTROL_ALLOW_CREDENTIALS,
HeaderValue::from_static("true"),
);
}
if let Some(secs) = cfg.max_age_secs {
resp.headers_mut().insert(
ACCESS_CONTROL_MAX_AGE,
HeaderValue::from_str(&secs.to_string()).expect("valid max-age header value"),
);
}
}