use thiserror::Error;
use url::Url;
#[derive(Debug, Clone, Error)]
pub enum GuardError {
#[error("invalid URL: {0}")]
InvalidUrl(String),
#[error("URL has no host: {0}")]
NoHost(String),
#[error("host not on allowlist: {host}")]
HostNotAllowed {
host: String,
},
#[error("scheme not allowed: {0}")]
SchemeNotAllowed(String),
}
/// One host-matching rule stored inside an [`Allowlist`].
#[derive(Clone, Debug)]
enum Rule {
    /// Matches a single host name exactly.
    Exact(String),
    /// Matches the apex domain itself as well as any host name beneath it.
    SubdomainsOf(String),
}
/// A set of rules deciding which URLs are acceptable, by scheme and by host.
///
/// Build one with [`Allowlist::new`] (or [`Default::default`], which is identical) plus the
/// builder methods [`Allowlist::domain`], [`Allowlist::subdomains_of`] and
/// [`Allowlist::allow_schemes`], then validate URLs with [`Allowlist::check`].
/// An allowlist with no host rules rejects every URL.
#[derive(Debug, Clone)]
pub struct Allowlist {
    /// Host rules; a host is allowed when any rule matches. Values are stored in the
    /// canonical form produced by `normalize_host`.
    rules: Vec<Rule>,
    /// Permitted URL schemes, lowercased. `new()` starts with `http` and `https`.
    allowed_schemes: Vec<String>,
}

impl Default for Allowlist {
    /// Equivalent to [`Allowlist::new`].
    ///
    /// Implemented by hand rather than derived: a derived `Default` would leave
    /// `allowed_schemes` empty, so every `check` would fail with
    /// [`GuardError::SchemeNotAllowed`], silently disagreeing with `new()`.
    fn default() -> Self {
        Self::new()
    }
}

impl Allowlist {
    /// Creates an allowlist with no host rules that permits the `http` and `https` schemes.
    pub fn new() -> Self {
        Self {
            rules: Vec::new(),
            allowed_schemes: vec!["http".into(), "https".into()],
        }
    }

    /// Allows exactly `host`.
    ///
    /// Matching is case-insensitive, and a single trailing `.` (fully-qualified form) is
    /// ignored on both the rule and the checked host.
    #[must_use]
    pub fn domain(mut self, host: impl Into<String>) -> Self {
        self.rules.push(Rule::Exact(Self::normalize_host(&host.into())));
        self
    }

    /// Allows `apex` itself and every host beneath it (e.g. `a.b.apex`).
    ///
    /// Matching is case-insensitive. A single leading `.` (as in `.example.com`) and a
    /// single trailing `.` are ignored. An empty apex matches nothing.
    #[must_use]
    pub fn subdomains_of(mut self, apex: impl Into<String>) -> Self {
        let raw = apex.into();
        // Accept the common `.example.com` spelling; otherwise the effective suffix would
        // be `..example.com` and the rule could never match.
        let trimmed = raw.strip_prefix('.').unwrap_or(&raw);
        self.rules.push(Rule::SubdomainsOf(Self::normalize_host(trimmed)));
        self
    }

    /// Replaces (does not extend) the set of permitted schemes. Schemes are lowercased.
    #[must_use]
    pub fn allow_schemes<I, S>(mut self, schemes: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.allowed_schemes = schemes
            .into_iter()
            .map(|s| s.into().to_lowercase())
            .collect();
        self
    }

    /// Parses `url` and verifies that both its scheme and its host are allowed.
    ///
    /// The scheme is checked before the host.
    ///
    /// # Errors
    ///
    /// - [`GuardError::InvalidUrl`] if `url` cannot be parsed.
    /// - [`GuardError::SchemeNotAllowed`] if the lowercased scheme is not permitted.
    /// - [`GuardError::NoHost`] if the parsed URL has no host.
    /// - [`GuardError::HostNotAllowed`] if no host rule matches.
    pub fn check(&self, url: &str) -> Result<(), GuardError> {
        let parsed = Url::parse(url).map_err(|_| GuardError::InvalidUrl(url.to_string()))?;
        let scheme = parsed.scheme().to_lowercase();
        if !self.allowed_schemes.contains(&scheme) {
            return Err(GuardError::SchemeNotAllowed(scheme));
        }
        let host = parsed
            .host_str()
            .ok_or_else(|| GuardError::NoHost(url.to_string()))?
            .to_lowercase();
        if self.is_allowed_host(&host) {
            Ok(())
        } else {
            Err(GuardError::HostNotAllowed { host })
        }
    }

    /// Returns `true` if `host` matches any rule.
    ///
    /// `host` is normalized the same way as the rules (lowercased, one trailing `.`
    /// removed). An empty host is never allowed.
    pub fn is_allowed_host(&self, host: &str) -> bool {
        let host = Self::normalize_host(host);
        if host.is_empty() {
            return false;
        }
        self.rules.iter().any(|rule| match rule {
            Rule::Exact(domain) => host == *domain,
            Rule::SubdomainsOf(apex) => {
                // An empty apex would turn the suffix test into "ends with a dot",
                // admitting unrelated hosts, so it deliberately matches nothing.
                if apex.is_empty() {
                    return false;
                }
                // Either the apex itself, or something ending in `.<apex>`; checked via
                // `strip_suffix` to avoid allocating a `.{apex}` string per rule per call.
                match host.strip_suffix(apex.as_str()) {
                    Some(rest) => rest.is_empty() || rest.ends_with('.'),
                    None => false,
                }
            }
        })
    }

    /// Canonical form used to compare host names: lowercased, with one trailing `.`
    /// removed, because `example.com.` (fully qualified) and `example.com` name the same
    /// DNS host.
    fn normalize_host(host: &str) -> String {
        let mut normalized = host.to_lowercase();
        if normalized.ends_with('.') {
            normalized.pop();
        }
        normalized
    }
}