use super::{AllowedOrigins, Config, CorsMiddleware};
use http::{
header::{self, HeaderName, HeaderValue},
Method,
};
use std::{collections::HashSet, time::Duration};
/// Builder that accumulates CORS settings and turns them into a
/// `CorsMiddleware` via `build`.
///
/// The derived `Default` starts with empty method/header sets, both boolean
/// flags `false`, no max age, and whatever `AllowedOrigins::default()` yields.
#[derive(Debug, Default, Clone)]
pub struct CorsBuilder {
/// Methods accumulated by `allow_methods`; kept as a set and also joined
/// into a single header value when the config is built.
allowed_methods: HashSet<Method>,
/// Origin policy set by `allow_origins`; passed through to `Config` as is.
allowed_origins: AllowedOrigins,
/// Header names accumulated by `allow_headers`; kept as a set and also
/// joined into a single header value when the config is built.
allowed_headers: HashSet<HeaderName>,
/// Flag set by `allow_credentials`; passed through to `Config` as is.
allow_credentials: bool,
/// Header names accumulated by `expose_headers`; joined into a header value
/// only when the set is non-empty.
exposed_headers: HashSet<HeaderName>,
/// Optional duration set by `max_age`; converted to whole seconds when the
/// config is built.
max_age: Option<Duration>,
/// Flag set by `prefer_wildcard`; passed through to `Config` as is.
/// NOTE(review): its effect is decided by the consumer of `Config`, which
/// is not visible in this file.
prefer_wildcard: bool,
}
impl CorsBuilder {
pub fn new() -> CorsBuilder {
Default::default()
}
pub fn allow_origins(mut self, origins: AllowedOrigins) -> Self {
self.allowed_origins = origins;
self
}
pub fn allow_methods<I>(mut self, methods: I) -> Self
where
I: IntoIterator,
I::Item: Into<Method>,
{
self.allowed_methods
.extend(methods.into_iter().map(Into::into));
self
}
pub fn allow_headers<I>(mut self, headers: I) -> Self
where
I: IntoIterator,
I::Item: Into<HeaderName>,
{
self.allowed_headers
.extend(headers.into_iter().map(Into::into));
self
}
pub fn allow_credentials(mut self, allow_credentials: bool) -> Self {
self.allow_credentials = allow_credentials;
self
}
pub fn expose_headers<I>(mut self, headers: I) -> Self
where
I: IntoIterator,
I::Item: Into<HeaderName>,
{
self.exposed_headers
.extend(headers.into_iter().map(Into::into));
self
}
pub fn max_age(mut self, max_age: Duration) -> Self {
self.max_age = Some(max_age);
self
}
pub fn prefer_wildcard(mut self, prefer_wildcard: bool) -> Self {
self.prefer_wildcard = prefer_wildcard;
self
}
pub(super) fn into_config(self) -> Config {
let Self {
allow_credentials,
allowed_headers,
allowed_methods,
allowed_origins,
exposed_headers,
max_age,
prefer_wildcard,
} = self;
let allowed_headers_header =
join_header_value(&allowed_headers).expect("Invalid allowed headers");
let allowed_methods_header =
join_header_value(&allowed_methods).expect("Invalid allowed methods");
let exposed_headers_header = if exposed_headers.is_empty() {
None
} else {
Some(join_header_value(&exposed_headers).expect("Invalid exposed headers"))
};
let max_age = max_age.map(|v| HeaderValue::from(v.as_secs()));
let vary_header = join_header_value(&[
header::ORIGIN,
header::ACCESS_CONTROL_REQUEST_METHOD,
header::ACCESS_CONTROL_REQUEST_HEADERS,
]).expect("Invalid vary");
Config {
allow_credentials,
allowed_headers,
allowed_headers_header,
allowed_methods,
allowed_methods_header,
allowed_origins,
exposed_headers_header,
max_age,
prefer_wildcard,
vary_header,
}
}
pub fn build(self) -> CorsMiddleware {
CorsMiddleware::new(self.into_config())
}
}
/// Joins the string form of every item in `values` with a bare `,` (no
/// surrounding whitespace) and builds a single `HeaderValue` from the bytes.
///
/// An empty `values` yields a value built from zero bytes.
///
/// # Errors
///
/// Returns the error from `HeaderValue::from_bytes` when the joined bytes are
/// rejected.
fn join_header_value<I>(values: I) -> Result<HeaderValue, header::InvalidHeaderValue>
where
    I: IntoIterator,
    I::Item: AsRef<str>,
{
    let mut joined: Vec<u8> = Vec::new();
    for (index, item) in values.into_iter().enumerate() {
        // Separator goes between items only, never before the first one.
        if index > 0 {
            joined.push(b',');
        }
        joined.extend_from_slice(item.as_ref().as_bytes());
    }
    HeaderValue::from_bytes(&joined)
}