use std::fmt;
#[derive(Debug, Clone)]
pub struct DomainPattern {
pattern: String,
is_wildcard: bool,
base_domain: String,
}
impl DomainPattern {
pub fn new(pattern: &str) -> Self {
let is_wildcard = pattern.starts_with("*.");
let base_domain = if is_wildcard {
pattern[2..].to_lowercase()
} else {
pattern.to_lowercase()
};
Self {
pattern: pattern.to_string(),
is_wildcard,
base_domain,
}
}
pub fn matches(&self, host: &str) -> bool {
let host_lower = host.to_lowercase();
if self.is_wildcard {
host_lower == self.base_domain
|| host_lower.ends_with(&format!(".{}", self.base_domain))
} else {
host_lower == self.base_domain
}
}
pub fn pattern(&self) -> &str {
&self.pattern
}
}
impl fmt::Display for DomainPattern {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.pattern)
}
}
#[derive(Debug, Clone)]
pub enum DomainValidationResult {
Allowed,
Denied(String),
}
impl DomainValidationResult {
pub fn is_allowed(&self) -> bool {
matches!(self, DomainValidationResult::Allowed)
}
}
#[derive(Debug, Clone)]
pub struct DomainAllowlist {
patterns: Vec<DomainPattern>,
}
impl DomainAllowlist {
pub fn new(domains: &[String]) -> Self {
Self {
patterns: domains.iter().map(|d| DomainPattern::new(d)).collect(),
}
}
pub fn empty() -> Self {
Self { patterns: vec![] }
}
pub fn add(&mut self, pattern: &str) {
self.patterns.push(DomainPattern::new(pattern));
}
pub fn is_allowed(&self, host: &str) -> DomainValidationResult {
if self.patterns.is_empty() {
return DomainValidationResult::Denied("empty allowlist".to_string());
}
for pattern in &self.patterns {
if pattern.matches(host) {
return DomainValidationResult::Allowed;
}
}
DomainValidationResult::Denied(format!(
"host '{}' not in allowlist: [{}]",
host,
self.patterns
.iter()
.map(|p| p.pattern())
.collect::<Vec<_>>()
.join(", ")
))
}
pub fn patterns(&self) -> &[DomainPattern] {
&self.patterns
}
pub fn is_empty(&self) -> bool {
self.patterns.is_empty()
}
pub fn len(&self) -> usize {
self.patterns.len()
}
}
impl Default for DomainAllowlist {
fn default() -> Self {
Self::new(&crate::sandbox::config::default_allowlist())
}
}
pub fn extract_host(url: &str) -> Option<String> {
let rest = if let Some(stripped) = url.strip_prefix("https://") {
stripped
} else if let Some(stripped) = url.strip_prefix("http://") {
stripped
} else {
return None;
};
let host_end = rest.find('/').unwrap_or(rest.len());
let host_and_port = &rest[..host_end];
let host = if let Some(bracket_idx) = host_and_port.find('[') {
let close_bracket = host_and_port.find(']')?;
&host_and_port[bracket_idx + 1..close_bracket]
} else if let Some(colon_idx) = host_and_port.rfind(':') {
let after_colon = &host_and_port[colon_idx + 1..];
if after_colon.chars().all(|c| c.is_ascii_digit()) {
&host_and_port[..colon_idx]
} else {
host_and_port
}
} else {
host_and_port
};
if host.is_empty() {
None
} else {
Some(host.to_lowercase())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_exact_match() {
let pattern = DomainPattern::new("api.example.com");
assert!(pattern.matches("api.example.com"));
assert!(pattern.matches("API.EXAMPLE.COM"));
assert!(!pattern.matches("foo.api.example.com"));
assert!(!pattern.matches("example.com"));
}
#[test]
fn test_wildcard_match() {
let pattern = DomainPattern::new("*.example.com");
assert!(pattern.matches("api.example.com"));
assert!(pattern.matches("foo.bar.example.com"));
assert!(pattern.matches("example.com")); assert!(!pattern.matches("exampleXcom"));
assert!(!pattern.matches("other.com"));
}
#[test]
fn test_allowlist_allows() {
let allowlist =
DomainAllowlist::new(&["crates.io".to_string(), "*.github.com".to_string()]);
assert!(allowlist.is_allowed("crates.io").is_allowed());
assert!(allowlist.is_allowed("api.github.com").is_allowed());
assert!(
!allowlist
.is_allowed("raw.githubusercontent.com")
.is_allowed()
);
}
#[test]
fn test_allowlist_denies() {
let allowlist = DomainAllowlist::new(&["crates.io".to_string()]);
let result = allowlist.is_allowed("evil.com");
assert!(!result.is_allowed());
}
#[test]
fn test_empty_allowlist() {
let allowlist = DomainAllowlist::empty();
assert!(!allowlist.is_allowed("anything.com").is_allowed());
}
#[test]
fn test_extract_host() {
assert_eq!(
extract_host("https://api.example.com/v1/endpoint"),
Some("api.example.com".to_string())
);
assert_eq!(
extract_host("http://localhost:8080/api"),
Some("localhost".to_string())
);
assert_eq!(
extract_host("https://EXAMPLE.COM"),
Some("example.com".to_string())
);
assert_eq!(extract_host("not-a-url"), None);
}
}