use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;
use ahash::HashSet;
use rama_net::uri::Uri;
use super::ConfigError;
/// An `http`/`https` origin (scheme, host, port) used for CSRF origin checks.
///
/// The scheme is reduced to a `secure` flag, the host is stored
/// ASCII-lowercased, and an omitted port is replaced by the scheme's default
/// port. As a result, origins that differ only in scheme/host letter case or
/// in an explicit-vs-implicit default port compare and hash equal.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub(crate) struct CsrfOrigin {
    /// `true` for `https`, `false` for `http`.
    secure: bool,
    /// Host, ASCII-lowercased; never empty.
    host: Box<str>,
    /// Explicit port, or the scheme's default port when none was given.
    port: u16,
}

impl CsrfOrigin {
    /// Builds an origin from its parts.
    ///
    /// `scheme` is matched ASCII-case-insensitively and must be `http` or
    /// `https`; `port` defaults to 80 / 443 respectively when `None`.
    ///
    /// Returns `None` when the scheme is anything else, or when `host` is
    /// empty.
    pub(crate) fn from_parts(scheme: &str, host: &str, port: Option<u16>) -> Option<Self> {
        let secure = if scheme.eq_ignore_ascii_case("https") {
            true
        } else if scheme.eq_ignore_ascii_case("http") {
            false
        } else {
            return None;
        };
        // http(s) origins always carry a non-empty host; an empty one cannot
        // name any site, so reject it instead of building a bogus origin.
        if host.is_empty() {
            return None;
        }
        Some(Self {
            secure,
            host: host.to_ascii_lowercase().into_boxed_str(),
            port: port.unwrap_or(default_port(secure)),
        })
    }

    /// Reports whether `host` and `port` refer to this origin's host and port.
    ///
    /// The host comparison is ASCII-case-insensitive. A `None` port is taken
    /// to be the default port of this origin's own scheme.
    pub(crate) fn matches_host(&self, host: &str, port: Option<u16>) -> bool {
        self.host.as_ref().eq_ignore_ascii_case(host)
            && self.port == port.unwrap_or(default_port(self.secure))
    }
}

/// Default port for `https` (`secure == true`, 443) or `http` (80).
const fn default_port(secure: bool) -> u16 {
    if secure { 443 } else { 80 }
}
/// Parses one configured trusted origin, e.g. `https://example.com:8443`.
///
/// Only a scheme, a host and an optional port are accepted; an empty path or
/// a path of exactly `/` is tolerated.
///
/// # Errors
///
/// * [`ConfigError::InvalidOrigin`] if `input` cannot be parsed as a URI.
/// * [`ConfigError::InvalidOriginComponents`] if it contains userinfo, a
///   query, a fragment, or a path other than empty or `/`.
/// * [`ConfigError::OpaqueOrigin`] if it has no host, or if
///   [`CsrfOrigin::from_parts`] rejects its scheme/host.
pub(crate) fn parse_trusted_origin(input: &str) -> Result<CsrfOrigin, ConfigError> {
    let uri = match Uri::parse_canonical(input) {
        Ok(parsed) => parsed,
        Err(err) => {
            return Err(ConfigError::InvalidOrigin {
                origin: input.into(),
                message: err.to_string().into_boxed_str(),
            });
        }
    };

    // A trusted origin must be nothing more than scheme://host[:port].
    let path_is_meaningful = uri
        .path()
        .is_some_and(|path| !path.is_empty() && path != "/");
    let carries_extras = uri.userinfo().is_some()
        || uri.query().is_some()
        || uri.fragment().is_some()
        || path_is_meaningful;
    if carries_extras {
        return Err(ConfigError::InvalidOriginComponents {
            origin: input.into(),
        });
    }

    // A missing scheme becomes "", which `from_parts` rejects just like any
    // other non-http(s) scheme.
    let scheme = uri.scheme().map_or("", |scheme| scheme.as_str());

    uri.host()
        .map(|host| host.to_string())
        .and_then(|host| CsrfOrigin::from_parts(scheme, &host, uri.port_u16()))
        .ok_or_else(|| ConfigError::OpaqueOrigin {
            origin: input.into(),
        })
}
#[derive(Clone, Default)]
pub(crate) struct Origins(Arc<HashSet<CsrfOrigin>>);
impl Origins {
pub(crate) fn contains(&self, origin: &CsrfOrigin) -> bool {
self.0.contains(origin)
}
pub(crate) fn insert(&mut self, origin: CsrfOrigin) {
Arc::make_mut(&mut self.0).insert(origin);
}
}
impl Debug for Origins {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_set().entries(self.0.iter()).finish()
}
}