use std::collections::HashSet;
use std::net::IpAddr;
use url::Url;
/// Returns `true` if `ip` is an address outbound requests should not reach.
///
/// IPv4: loopback, RFC 1918 private, link-local, broadcast, `0.0.0.0/8`
/// ("this network", including the unspecified address) and `100.64.0.0/10`
/// (carrier-grade NAT). IPv6: loopback, unspecified, `fc00::/7` (unique local)
/// and `fe80::/10` (link-local).
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are classified by their
/// embedded IPv4 address, so `::ffff:127.0.0.1` is treated like `127.0.0.1`.
pub fn is_private_ip(ip: &IpAddr) -> bool {
    // IPv4 rules, shared by plain IPv4 addresses and IPv4-mapped IPv6 addresses.
    fn is_private_v4(v4: &std::net::Ipv4Addr) -> bool {
        let [a, b, ..] = v4.octets();
        v4.is_loopback()
            || v4.is_private()
            || v4.is_link_local()
            || v4.is_broadcast()
            // 0.0.0.0/8 "this network"; subsumes the unspecified address 0.0.0.0
            || a == 0
            // 100.64.0.0/10 carrier-grade NAT shared address space
            || (a == 100 && (b & 0xC0) == 64)
    }

    match ip {
        IpAddr::V4(v4) => is_private_v4(v4),
        IpAddr::V6(v6) => {
            // On dual-stack sockets a mapped address reaches the embedded IPv4 host,
            // so `[::ffff:127.0.0.1]` must not slip past the IPv4 rules.
            if let Some(v4) = v6.to_ipv4_mapped() {
                return is_private_v4(&v4);
            }
            let head = v6.segments()[0];
            v6.is_loopback()
                || v6.is_unspecified()
                // fc00::/7 unique local addresses
                || (head & 0xfe00) == 0xfc00
                // fe80::/10 link-local unicast
                || (head & 0xffc0) == 0xfe80
        }
    }
}
/// Re-serializes `url` with any embedded credentials masked, for use in
/// human-readable messages.
///
/// If the URL carries a username or a password, the username becomes `***`
/// and the password is removed. Input that does not parse as a URL yields the
/// fixed placeholder `[invalid URL]` instead of echoing the raw string.
fn redact_url(url: &str) -> String {
    let Ok(mut parsed) = Url::parse(url) else {
        return "[invalid URL]".to_string();
    };
    let has_credentials = parsed.password().is_some() || !parsed.username().is_empty();
    if has_credentials {
        // Setter failures are ignored: masking is best-effort.
        let _ = parsed.set_username("***");
        let _ = parsed.set_password(None);
    }
    parsed.to_string()
}
/// Allowlist deciding which URLs may be requested.
///
/// It holds a set of URL prefix patterns, an allow-everything switch, and an
/// option to reject URLs whose host is a literal private/reserved IP address.
/// Build it with `NetworkAllowlist::new` or `NetworkAllowlist::allow_all` plus
/// the builder methods, then query it with `check` or `is_allowed`.
#[derive(Debug, Clone)]
pub struct NetworkAllowlist {
    /// Allowed URL patterns, stored as the raw strings given to `allow` / `allow_many`.
    patterns: HashSet<String>,
    /// When `true`, `check` allows every URL without consulting `patterns`.
    allow_all: bool,
    /// When `true`, URLs whose host is a private/reserved IP literal are rejected.
    block_private_ips: bool,
}
impl Default for NetworkAllowlist {
    /// Returns an allowlist with no patterns (so every URL is blocked),
    /// allow-all mode off, and private-IP blocking on.
    fn default() -> Self {
        let block_private_ips = true;
        Self {
            allow_all: false,
            block_private_ips,
            patterns: HashSet::default(),
        }
    }
}
/// Outcome of `NetworkAllowlist::check`.
#[derive(Debug, Clone, PartialEq)]
pub enum UrlMatch {
    /// The URL may be requested.
    Allowed,
    /// The URL is not permitted (empty allowlist, private IP host, or no
    /// matching pattern); `reason` is a human-readable explanation.
    Blocked { reason: String },
    /// The URL could not be parsed; `reason` includes the parser error.
    Invalid { reason: String },
}
impl NetworkAllowlist {
    /// Creates an empty allowlist that blocks every URL until patterns are added.
    ///
    /// Equivalent to `NetworkAllowlist::default()`; private-IP blocking is enabled.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates an allowlist that permits every URL.
    ///
    /// In this mode `check` returns `UrlMatch::Allowed` for *any* input, including
    /// unparseable URLs and literal private IPs. The `allow_all` short-circuit runs
    /// before `block_private_ips` is consulted, even though that flag is set to
    /// `true` here.
    // NOTE(review): if SSRF protection is meant to apply in allow-all mode too, `check`
    // must consult `block_private_ips` before the `allow_all` short-circuit. Confirm intent.
    pub fn allow_all() -> Self {
        Self {
            patterns: HashSet::new(),
            allow_all: true,
            block_private_ips: true,
        }
    }

    /// Enables or disables rejection of URLs whose host is a literal
    /// private/reserved IP address (as classified by `is_private_ip`).
    pub fn block_private_ips(mut self, block: bool) -> Self {
        self.block_private_ips = block;
        self
    }

    /// Returns whether literal private/reserved IP hosts are rejected.
    pub fn is_blocking_private_ips(&self) -> bool {
        self.block_private_ips
    }

    /// Adds one allowed URL prefix pattern, e.g. `https://api.example.com/v1`.
    ///
    /// Patterns are stored verbatim and parsed on every `check`. A pattern that
    /// does not parse as a URL never matches anything.
    pub fn allow(mut self, pattern: impl Into<String>) -> Self {
        self.patterns.insert(pattern.into());
        self
    }

    /// Adds every pattern in `patterns`; see `allow`.
    pub fn allow_many(mut self, patterns: impl IntoIterator<Item = impl Into<String>>) -> Self {
        self.patterns.extend(patterns.into_iter().map(Into::<String>::into));
        self
    }

    /// Decides whether `url` may be requested.
    ///
    /// Evaluation order:
    /// 1. allow-all mode: `Allowed` (no parsing, no IP check);
    /// 2. no patterns: `Blocked`;
    /// 3. unparseable URL: `Invalid`;
    /// 4. private-IP blocking on and the host is a literal private/reserved IPv4 or
    ///    IPv6 address: `Blocked`, even if a pattern would match;
    /// 5. any pattern matches: `Allowed`, otherwise `Blocked`.
    ///
    /// Only literal IP hosts are inspected. Hostnames are not resolved here, so a
    /// domain that resolves to a private address is not caught by step 4. The
    /// "not in allowlist" reason shows the URL with credentials redacted.
    pub fn check(&self, url: &str) -> UrlMatch {
        if self.allow_all {
            return UrlMatch::Allowed;
        }
        if self.patterns.is_empty() {
            return UrlMatch::Blocked {
                reason: "no URLs are allowed (empty allowlist)".to_string(),
            };
        }
        let parsed = match Url::parse(url) {
            Ok(u) => u,
            Err(e) => {
                return UrlMatch::Invalid {
                    reason: format!("invalid URL: {}", e),
                };
            }
        };
        if self.block_private_ips
            && let Some(ip) = Self::literal_host_ip(&parsed)
            && is_private_ip(&ip)
        {
            return UrlMatch::Blocked {
                reason: format!(
                    "request to private/reserved IP {} blocked (SSRF protection)",
                    ip
                ),
            };
        }
        if self.patterns.iter().any(|pattern| self.matches_pattern(&parsed, pattern)) {
            return UrlMatch::Allowed;
        }
        UrlMatch::Blocked {
            reason: format!("URL not in allowlist: {}", redact_url(url)),
        }
    }

    /// Returns the URL's host as an IP address when the host is an IP literal.
    ///
    /// `Url::host_str` renders IPv6 literals in brackets (`[::1]`), which
    /// `IpAddr::from_str` rejects. The brackets are stripped before parsing;
    /// otherwise IPv6 hosts such as `[::1]` or `[fd00::1]` would never be checked.
    fn literal_host_ip(url: &Url) -> Option<IpAddr> {
        let host = url.host_str()?;
        let bare = host
            .strip_prefix('[')
            .and_then(|inner| inner.strip_suffix(']'))
            .unwrap_or(host);
        bare.parse().ok()
    }

    /// Returns `true` when `url` falls under `pattern`.
    ///
    /// Scheme, host and effective port (explicit, or the scheme default) must all be
    /// equal. A pattern path of `/` (or empty) matches every path. Otherwise the URL
    /// path must equal the pattern path or continue it at a `/` boundary, so `/v1`
    /// matches `/v1` and `/v1/users` but not `/v10`. A pattern path ending in `/`
    /// acts as a plain prefix. Patterns that fail to parse never match.
    fn matches_pattern(&self, url: &Url, pattern: &str) -> bool {
        let Ok(pattern_url) = Url::parse(pattern) else {
            return false;
        };
        if url.scheme() != pattern_url.scheme() {
            return false;
        }
        match (url.host_str(), pattern_url.host_str()) {
            (Some(url_host), Some(pattern_host)) if url_host == pattern_host => {}
            _ => return false,
        }
        if url.port_or_known_default() != pattern_url.port_or_known_default() {
            return false;
        }
        let pattern_path = pattern_url.path();
        if pattern_path == "/" || pattern_path.is_empty() {
            return true;
        }
        match url.path().strip_prefix(pattern_path) {
            None => false,
            // `Url::path()` never contains a raw `?` or `#` (they start the query and
            // fragment), so `/` is the only segment boundary to look for.
            Some(rest) => {
                pattern_path.ends_with('/') || rest.is_empty() || rest.starts_with('/')
            }
        }
    }

    /// Convenience wrapper: `true` iff `check(url)` is `UrlMatch::Allowed`.
    pub fn is_allowed(&self, url: &str) -> bool {
        matches!(self.check(url), UrlMatch::Allowed)
    }

    /// Returns `true` if this allowlist can allow anything at all, i.e. it is in
    /// allow-all mode or has at least one pattern.
    pub fn is_enabled(&self) -> bool {
        self.allow_all || !self.patterns.is_empty()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Shared assertion helpers keep each test to a table of inputs.

    /// Asserts that `allowlist` accepts `url`.
    fn assert_allowed(allowlist: &NetworkAllowlist, url: &str) {
        assert_eq!(allowlist.check(url), UrlMatch::Allowed, "expected allowed: {url}");
    }

    /// Asserts that `allowlist` rejects `url` with `UrlMatch::Blocked`.
    fn assert_blocked(allowlist: &NetworkAllowlist, url: &str) {
        let verdict = allowlist.check(url);
        assert!(
            matches!(verdict, UrlMatch::Blocked { .. }),
            "expected blocked: {url}, got {verdict:?}"
        );
    }

    /// Parses an IP literal used as test input.
    fn ip(text: &str) -> std::net::IpAddr {
        text.parse().expect("test IP literal must parse")
    }

    #[test]
    fn test_empty_allowlist_blocks_all() {
        assert_blocked(&NetworkAllowlist::new(), "https://example.com");
    }

    #[test]
    fn test_allow_all() {
        let allowlist = NetworkAllowlist::allow_all();
        for url in ["https://example.com", "http://localhost:8080/anything"] {
            assert_allowed(&allowlist, url);
        }
    }

    #[test]
    fn test_exact_host_match() {
        let allowlist = NetworkAllowlist::new().allow("https://api.example.com");
        for url in [
            "https://api.example.com",
            "https://api.example.com/",
            "https://api.example.com/v1/users",
        ] {
            assert_allowed(&allowlist, url);
        }
        for url in ["http://api.example.com", "https://other.example.com"] {
            assert_blocked(&allowlist, url);
        }
    }

    #[test]
    fn test_path_prefix_match() {
        let allowlist = NetworkAllowlist::new().allow("https://api.example.com/v1");
        for url in [
            "https://api.example.com/v1",
            "https://api.example.com/v1/",
            "https://api.example.com/v1/users",
        ] {
            assert_allowed(&allowlist, url);
        }
        for url in ["https://api.example.com/v2", "https://api.example.com/v10"] {
            assert_blocked(&allowlist, url);
        }
    }

    #[test]
    fn test_port_matching() {
        let allowlist = NetworkAllowlist::new().allow("http://localhost:8080");
        assert_allowed(&allowlist, "http://localhost:8080/api");
        for url in ["http://localhost:3000", "http://localhost"] {
            assert_blocked(&allowlist, url);
        }
    }

    #[test]
    fn test_multiple_patterns() {
        let allowlist = NetworkAllowlist::new()
            .allow("https://api.example.com")
            .allow("https://cdn.example.com")
            .allow("http://localhost:3000");
        for url in [
            "https://api.example.com/v1",
            "https://cdn.example.com/assets/logo.png",
            "http://localhost:3000/health",
        ] {
            assert_allowed(&allowlist, url);
        }
        assert_blocked(&allowlist, "https://evil.com");
    }

    #[test]
    fn test_invalid_url() {
        let allowlist = NetworkAllowlist::new().allow("https://example.com");
        assert!(matches!(allowlist.check("not a url"), UrlMatch::Invalid { .. }));
    }

    #[test]
    fn test_is_enabled() {
        assert!(!NetworkAllowlist::new().is_enabled());
        assert!(NetworkAllowlist::new().allow("https://example.com").is_enabled());
        assert!(NetworkAllowlist::allow_all().is_enabled());
    }

    #[test]
    fn test_redact_url_strips_credentials() {
        let redacted = redact_url("https://user:secret@example.com/path");
        assert!(!redacted.contains("secret"), "password leaked: {redacted}");
        assert!(!redacted.contains("user"), "username leaked: {redacted}");
        assert!(redacted.contains("example.com/path"));
    }

    #[test]
    fn test_redact_url_preserves_clean_url() {
        let clean = "https://example.com/path?q=1";
        assert_eq!(redact_url(clean), clean);
    }

    #[test]
    fn test_blocked_message_no_credentials() {
        let allowlist = NetworkAllowlist::new().allow("https://allowed.com");
        let verdict = allowlist.check("https://user:pass@blocked.com/api");
        let UrlMatch::Blocked { reason } = verdict else {
            panic!("expected Blocked");
        };
        assert!(!reason.contains("pass"), "credentials leaked: {reason}");
    }

    #[test]
    fn test_path_boundary_check_byte_safe() {
        let allowlist = NetworkAllowlist::new().allow("https://example.com/api");
        assert_allowed(&allowlist, "https://example.com/api/v1");
        assert_blocked(&allowlist, "https://example.com/apix");
    }

    #[test]
    fn test_is_private_ip_loopback() {
        for addr in ["127.0.0.1", "127.0.0.2", "::1"] {
            assert!(is_private_ip(&ip(addr)), "{addr} should be private");
        }
    }

    #[test]
    fn test_is_private_ip_rfc1918() {
        for addr in [
            "10.0.0.1",
            "10.255.255.255",
            "172.16.0.1",
            "172.31.255.255",
            "192.168.0.1",
            "192.168.255.255",
        ] {
            assert!(is_private_ip(&ip(addr)), "{addr} should be private");
        }
    }

    #[test]
    fn test_is_private_ip_link_local() {
        for addr in ["169.254.0.1", "169.254.169.254"] {
            assert!(is_private_ip(&ip(addr)), "{addr} should be private");
        }
    }

    #[test]
    fn test_is_private_ip_public() {
        for addr in ["8.8.8.8", "1.1.1.1", "203.0.113.1"] {
            assert!(!is_private_ip(&ip(addr)), "{addr} should be public");
        }
    }

    #[test]
    fn test_is_private_ip_v6() {
        for addr in ["::1", "fd00::1", "fe80::1"] {
            assert!(is_private_ip(&ip(addr)), "{addr} should be private");
        }
        assert!(!is_private_ip(&ip("2001:db8::1")));
    }

    #[test]
    fn test_block_private_ips_default_true() {
        assert!(NetworkAllowlist::new().is_blocking_private_ips());
    }

    #[test]
    fn test_block_private_ips_disabled() {
        let allowlist = NetworkAllowlist::new().block_private_ips(false);
        assert!(!allowlist.is_blocking_private_ips());
    }
}