use super::{PolicyError, Unrestricted};
/// Decides whether an outbound HTTP request may be made.
///
/// The `Send + Sync + 'static` supertraits let a policy be shared across threads and
/// stored without borrowing.
pub trait HttpPolicy: Send + Sync + 'static {
    /// A human-readable name for this policy.
    ///
    /// Defaults to [`std::any::type_name`] of the implementing type; std documents that
    /// string as best-effort and not guaranteed to be stable.
    fn policy_name(&self) -> &'static str {
        ::std::any::type_name::<Self>()
    }

    /// Checks a request using HTTP `method` against `url`.
    ///
    /// # Errors
    ///
    /// Returns a [`PolicyError`] when the request is not permitted.
    fn check_url(&self, url: &str, method: &str) -> Result<(), PolicyError>;
}
/// `Unrestricted` imposes no HTTP restrictions: every URL and method is accepted.
impl HttpPolicy for Unrestricted {
    fn check_url(&self, _: &str, _: &str) -> Result<(), PolicyError> {
        Ok(())
    }
}
/// An HTTP policy that only permits requests whose URL host matches one of a fixed
/// set of host patterns.
#[derive(Debug)]
pub struct HttpAllowList {
    // Host patterns exactly as passed to `HttpAllowList::new`, in the order given.
    allowed_hosts: Vec<String>,
}

impl HttpAllowList {
    /// Builds an allow-list from any iterable of host patterns (`&str`, `String`, ...).
    ///
    /// Patterns are stored verbatim; no normalization is performed here.
    pub fn new<I, S>(hosts: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        let allowed_hosts: Vec<String> = hosts.into_iter().map(|host| host.into()).collect();
        HttpAllowList { allowed_hosts }
    }
}
impl HttpPolicy for HttpAllowList {
    /// Permits the request only if the URL's host matches one of the allowed patterns.
    ///
    /// # Errors
    ///
    /// Returns a [`PolicyError`] if no host can be extracted from `url`, if the extracted
    /// host is empty, or if the host matches none of the configured patterns.
    fn check_url(&self, url: &str, method: &str) -> Result<(), PolicyError> {
        let allowed = match extract_url_host(url) {
            Some(host) if !host.is_empty() => self
                .allowed_hosts
                .iter()
                .any(|pattern| host_matches(host, pattern)),
            // Fail closed: a URL with no recognizable (or an empty) host is denied outright
            // instead of being compared against the patterns as `""`, which an empty
            // pattern would otherwise match.
            _ => false,
        };
        if allowed {
            Ok(())
        } else {
            Err(PolicyError::new(format!(
                "{method} denied: URL '{url}' does not match any allowed host"
            )))
        }
    }
}
/// Returns `true` if `host` equals `pattern` or is a subdomain of it.
///
/// Host names are case-insensitive, so the comparison ignores ASCII case
/// (`API.Example.com` matches `example.com`). A subdomain match requires a `.`
/// immediately before the matched suffix, so `badexample.com` does not match
/// `example.com`.
///
/// An empty `host` or `pattern` never matches; without this guard an empty pattern
/// would match an empty host and every host ending in `.`.
fn host_matches(host: &str, pattern: &str) -> bool {
    if host.is_empty() || pattern.is_empty() {
        return false;
    }
    if host.eq_ignore_ascii_case(pattern) {
        return true;
    }
    // Compare bytes so the suffix slice never has to land on a char boundary
    // (hosts may contain non-ASCII text).
    let (host, pattern) = (host.as_bytes(), pattern.as_bytes());
    match host.len().checked_sub(pattern.len() + 1) {
        Some(dot) => host[dot] == b'.' && host[dot + 1..].eq_ignore_ascii_case(pattern),
        // `host` is not longer than `pattern`, so it cannot be a subdomain of it.
        None => false,
    }
}
/// Extracts the host from an absolute URL of the form `scheme://[userinfo@]host[:port]...`.
///
/// This is a deliberately small parser, not a full URL implementation. It is written to
/// agree with WHATWG-conformant URL parsers on the inputs that matter for allow-listing:
///
/// * Leading and trailing C0 control characters and spaces are ignored.
/// * The text before the first `://` must be a valid scheme
///   (`ALPHA *( ALPHA / DIGIT / "+" / "-" / "." )`, RFC 3986 section 3.1). Otherwise a
///   `://` buried in a path or query (e.g. `http:\\evil.com/?x=://good.com`) would be
///   mistaken for the start of the authority.
/// * For the special schemes (`http`, `https`, `ws`, `wss`, `ftp`, `file`) a `\` ends the
///   authority just like `/`, because WHATWG parsers treat the two alike; otherwise
///   `http://evil.com\@good.com/` would be read as host `good.com`.
/// * If userinfo is present, the host starts after the *last* `@` in the authority.
/// * IPv6 literals are returned without their brackets, and any `:port` suffix is dropped.
///
/// The host is returned verbatim: no case folding, percent-decoding or IDNA mapping.
/// Returns `None` if there is no valid `scheme://` prefix or an IPv6 literal is
/// unterminated. The returned host may be empty (e.g. for `http:///path`).
pub(super) fn extract_url_host(url: &str) -> Option<&str> {
    // Schemes for which WHATWG parsers treat `\` as equivalent to `/`.
    const SPECIAL_SCHEMES: [&str; 6] = ["http", "https", "ws", "wss", "ftp", "file"];

    let url = url.trim_matches(|c: char| c <= ' ');
    let scheme_end = url.find("://")?;
    let scheme = &url[..scheme_end];
    let mut scheme_chars = scheme.chars();
    let scheme_is_valid = matches!(scheme_chars.next(), Some(c) if c.is_ascii_alphabetic())
        && scheme_chars.all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '-' | '.'));
    if !scheme_is_valid {
        return None;
    }

    let rest = &url[scheme_end + 3..];
    let is_special = SPECIAL_SCHEMES
        .iter()
        .any(|special| scheme.eq_ignore_ascii_case(special));
    let authority_end = if is_special {
        rest.find(['/', '\\', '?', '#'])
    } else {
        rest.find(['/', '?', '#'])
    }
    .unwrap_or(rest.len());
    let authority = &rest[..authority_end];

    let host_start = authority.rfind('@').map_or(0, |at| at + 1);
    let host_part = &authority[host_start..];

    match host_part.strip_prefix('[') {
        // IPv6 literal: everything up to the closing bracket; unterminated => None.
        Some(bracketed) => bracketed.find(']').map(|end| &bracketed[..end]),
        // `split` always yields at least one item, so this is always `Some`.
        None => host_part.split(':').next(),
    }
}