use std::net::IpAddr;
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FilterResult {
Allow,
DenyHost {
host: String,
},
DenyLinkLocal {
ip: IpAddr,
},
DenyNotAllowed {
host: String,
},
}
impl FilterResult {
#[must_use]
pub fn is_allowed(&self) -> bool {
matches!(self, FilterResult::Allow)
}
#[must_use]
pub fn reason(&self) -> String {
match self {
FilterResult::Allow => "allowed by host filter".to_string(),
FilterResult::DenyHost { host } => {
format!("host {} is in the deny list", host)
}
FilterResult::DenyLinkLocal { ip } => {
format!(
"resolved IP {} is in the link-local range (cloud metadata protection)",
ip
)
}
FilterResult::DenyNotAllowed { host } => {
format!("host {} is not in the allowlist", host)
}
}
}
}
fn is_link_local(ip: &IpAddr) -> bool {
match ip {
IpAddr::V4(v4) => v4.octets()[0] == 169 && v4.octets()[1] == 254,
IpAddr::V6(v6) => {
if (v6.segments()[0] & 0xffc0) == 0xfe80 {
return true;
}
if let Some(v4) = v6.to_ipv4_mapped() {
return v4.octets()[0] == 169 && v4.octets()[1] == 254;
}
false
}
}
}
const DENY_HOSTS: &[&str] = &[
"169.254.169.254",
"metadata.google.internal",
"metadata.azure.internal",
];
#[derive(Debug, Clone)]
pub struct HostFilter {
allowed_hosts: Vec<String>,
allowed_suffixes: Vec<String>,
deny_hosts: Vec<String>,
}
impl HostFilter {
#[must_use]
pub fn new(allowed_hosts: &[String]) -> Self {
let mut exact = Vec::new();
let mut suffixes = Vec::new();
for host in allowed_hosts {
let lower = host.to_lowercase();
if let Some(suffix) = lower.strip_prefix('*') {
suffixes.push(suffix.to_string());
} else {
exact.push(lower);
}
}
Self {
allowed_hosts: exact,
allowed_suffixes: suffixes,
deny_hosts: DENY_HOSTS.iter().map(|s| s.to_lowercase()).collect(),
}
}
#[must_use]
pub fn allow_all() -> Self {
Self {
allowed_hosts: Vec::new(),
allowed_suffixes: Vec::new(),
deny_hosts: DENY_HOSTS.iter().map(|s| s.to_lowercase()).collect(),
}
}
#[must_use]
pub fn check_host(&self, host: &str, resolved_ips: &[IpAddr]) -> FilterResult {
let lower_host = host.to_lowercase();
if self.deny_hosts.contains(&lower_host) {
return FilterResult::DenyHost {
host: host.to_string(),
};
}
for ip in resolved_ips {
if is_link_local(ip) {
return FilterResult::DenyLinkLocal { ip: *ip };
}
}
if self.allowed_hosts.is_empty() && self.allowed_suffixes.is_empty() {
return FilterResult::Allow;
}
if self.allowed_hosts.contains(&lower_host) {
return FilterResult::Allow;
}
for suffix in &self.allowed_suffixes {
if lower_host.ends_with(suffix.as_str()) && lower_host.len() > suffix.len() {
return FilterResult::Allow;
}
}
FilterResult::DenyNotAllowed {
host: host.to_string(),
}
}
#[must_use]
pub fn allowed_count(&self) -> usize {
self.allowed_hosts
.len()
.saturating_add(self.allowed_suffixes.len())
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
use super::*;
use std::net::{Ipv4Addr, Ipv6Addr};
fn public_ip() -> Vec<IpAddr> {
vec![IpAddr::V4(Ipv4Addr::new(104, 18, 7, 96))]
}
#[test]
fn test_exact_host_allowed() {
let filter = HostFilter::new(&["api.openai.com".to_string()]);
let result = filter.check_host("api.openai.com", &public_ip());
assert!(result.is_allowed());
}
#[test]
fn test_exact_host_case_insensitive() {
let filter = HostFilter::new(&["API.OpenAI.COM".to_string()]);
let result = filter.check_host("api.openai.com", &public_ip());
assert!(result.is_allowed());
}
#[test]
fn test_host_not_in_allowlist() {
let filter = HostFilter::new(&["api.openai.com".to_string()]);
let result = filter.check_host("evil.com", &public_ip());
assert!(!result.is_allowed());
assert!(matches!(result, FilterResult::DenyNotAllowed { .. }));
}
#[test]
fn test_wildcard_subdomain_match() {
let filter = HostFilter::new(&["*.googleapis.com".to_string()]);
let result = filter.check_host("storage.googleapis.com", &public_ip());
assert!(result.is_allowed());
let result = filter.check_host("us-central1-aiplatform.googleapis.com", &public_ip());
assert!(result.is_allowed());
}
#[test]
fn test_wildcard_does_not_match_bare_domain() {
let filter = HostFilter::new(&["*.googleapis.com".to_string()]);
let result = filter.check_host("googleapis.com", &public_ip());
assert!(!result.is_allowed());
}
#[test]
fn test_deny_cloud_metadata_hostname() {
let filter = HostFilter::new(&["169.254.169.254".to_string()]);
let result = filter.check_host("169.254.169.254", &public_ip());
assert!(!result.is_allowed());
assert!(matches!(result, FilterResult::DenyHost { .. }));
}
#[test]
fn test_deny_google_metadata() {
let filter = HostFilter::new(&["metadata.google.internal".to_string()]);
let result = filter.check_host("metadata.google.internal", &public_ip());
assert!(!result.is_allowed());
}
#[test]
fn test_allow_all_mode() {
let filter = HostFilter::allow_all();
let result = filter.check_host("any-host.example.com", &public_ip());
assert!(result.is_allowed());
}
#[test]
fn test_allow_all_allows_private_networks() {
let filter = HostFilter::allow_all();
let private_ip = vec![IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))];
let result = filter.check_host("internal.corp.com", &private_ip);
assert!(result.is_allowed());
}
#[test]
fn test_allow_all_allows_192_168() {
let filter = HostFilter::allow_all();
let private_ip = vec![IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))];
let result = filter.check_host("nas.local", &private_ip);
assert!(result.is_allowed());
}
#[test]
fn test_deny_link_local_ipv4() {
let filter = HostFilter::new(&["*.example.com".to_string()]);
let link_local = vec![IpAddr::V4(Ipv4Addr::new(169, 254, 1, 1))];
let result = filter.check_host("api.example.com", &link_local);
assert!(!result.is_allowed());
assert!(matches!(result, FilterResult::DenyLinkLocal { .. }));
}
#[test]
fn test_deny_link_local_ipv6() {
let filter = HostFilter::new(&["*.example.com".to_string()]);
let link_local = vec![IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 1))];
let result = filter.check_host("api.example.com", &link_local);
assert!(!result.is_allowed());
assert!(matches!(result, FilterResult::DenyLinkLocal { .. }));
}
#[test]
fn test_deny_ipv4_mapped_ipv6_link_local() {
let filter = HostFilter::new(&["attacker.com".to_string()]);
let mapped = vec![IpAddr::V6(Ipv6Addr::new(
0, 0, 0, 0, 0, 0xffff, 0xa9fe, 0xa9fe,
))];
let result = filter.check_host("attacker.com", &mapped);
assert!(!result.is_allowed());
assert!(matches!(result, FilterResult::DenyLinkLocal { .. }));
}
#[test]
fn test_deny_ipv4_mapped_ipv6_other_link_local() {
let filter = HostFilter::allow_all();
let mapped = vec![IpAddr::V6(Ipv6Addr::new(
0, 0, 0, 0, 0, 0xffff, 0xa9fe, 0x0001,
))];
let result = filter.check_host("evil.com", &mapped);
assert!(!result.is_allowed());
}
#[test]
fn test_ipv4_mapped_ipv6_non_link_local_allowed() {
let filter = HostFilter::allow_all();
let mapped = vec![IpAddr::V6(Ipv6Addr::new(
0, 0, 0, 0, 0, 0xffff, 0x6812, 0x0760,
))];
let result = filter.check_host("example.com", &mapped);
assert!(result.is_allowed());
}
#[test]
fn test_dns_rebinding_to_metadata_ip() {
let filter = HostFilter::new(&["attacker.com".to_string()]);
let metadata_ip = vec![IpAddr::V4(Ipv4Addr::new(169, 254, 169, 254))];
let result = filter.check_host("attacker.com", &metadata_ip);
assert!(!result.is_allowed());
assert!(matches!(result, FilterResult::DenyLinkLocal { .. }));
}
#[test]
fn test_dns_rebinding_allow_all_blocked() {
let filter = HostFilter::allow_all();
let metadata_ip = vec![IpAddr::V4(Ipv4Addr::new(169, 254, 169, 254))];
let result = filter.check_host("evil.com", &metadata_ip);
assert!(!result.is_allowed());
}
#[test]
fn test_empty_resolved_ips_skips_link_local_check() {
let filter = HostFilter::new(&["api.openai.com".to_string()]);
let result = filter.check_host("api.openai.com", &[]);
assert!(result.is_allowed());
}
#[test]
fn test_multiple_ips_any_link_local_denied() {
let filter = HostFilter::new(&["multi.example.com".to_string()]);
let ips = vec![
IpAddr::V4(Ipv4Addr::new(104, 18, 7, 96)),
IpAddr::V4(Ipv4Addr::new(169, 254, 0, 1)),
];
let result = filter.check_host("multi.example.com", &ips);
assert!(!result.is_allowed());
}
#[test]
fn test_allowed_count() {
let filter = HostFilter::new(&[
"api.openai.com".to_string(),
"*.googleapis.com".to_string(),
"github.com".to_string(),
]);
assert_eq!(filter.allowed_count(), 3);
}
#[test]
fn test_filter_result_reason() {
let allow = FilterResult::Allow;
assert!(allow.reason().contains("allowed"));
let deny = FilterResult::DenyNotAllowed {
host: "evil.com".to_string(),
};
assert!(deny.reason().contains("evil.com"));
let link_local = FilterResult::DenyLinkLocal {
ip: IpAddr::V4(Ipv4Addr::new(169, 254, 169, 254)),
};
assert!(link_local.reason().contains("link-local"));
}
}