use anyhow::{Result, bail};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use tracing::{debug, warn};
/// Policy describing which outbound network destinations may be contacted.
///
/// The policy is consulted through `validate_url` and `validate_domain`;
/// it only validates destinations and performs no I/O itself.
#[derive(Debug, Clone)]
pub struct NetworkRestrictions {
/// Master switch: when `true`, `validate_url` and `validate_domain` reject
/// everything before looking at any other field.
pub block_all: bool,
/// Whitelisted domain names. A host matches an entry when it is the entry
/// itself or ends with `.` followed by the entry (i.e. a subdomain).
/// An empty list matches no domain.
pub allowed_domains: Vec<String>,
/// IP addresses that IP-literal hosts are checked against.
pub allowed_ips: Vec<IpAddr>,
/// When `true`, `http://` URLs are rejected; only `https://` passes.
pub https_only: bool,
/// When `true`, IP-literal hosts classified as private are rejected.
pub block_private_ips: bool,
/// When `true`, localhost names and loopback IP literals are rejected.
pub block_localhost: bool,
/// Request budget. NOTE(review): no code in this file reads this field;
/// presumably it is enforced by the caller — confirm.
pub max_requests: usize,
}
impl Default for NetworkRestrictions {
fn default() -> Self {
Self {
block_all: true,
allowed_domains: vec![],
allowed_ips: vec![],
https_only: true,
block_private_ips: true,
block_localhost: true,
max_requests: 0,
}
}
}
impl NetworkRestrictions {
pub fn deny_all() -> Self {
Self {
block_all: true,
..Default::default()
}
}
pub fn allow_domains(domains: Vec<String>) -> Self {
Self {
block_all: false,
allowed_domains: domains,
https_only: true,
block_private_ips: true,
block_localhost: true,
max_requests: 10,
..Default::default()
}
}
pub fn validate_url(&self, url: &str) -> Result<()> {
if self.block_all {
bail!(NetworkSecurityError::NetworkAccessDenied {
reason: "All network access is blocked".to_string()
});
}
let parsed = url::Url::parse(url).map_err(|e| NetworkSecurityError::InvalidUrl {
url: url.to_string(),
reason: e.to_string(),
})?;
self.validate_scheme(&parsed)?;
if let Some(host) = parsed.host_str() {
self.validate_host(host)?;
} else {
bail!(NetworkSecurityError::InvalidUrl {
url: url.to_string(),
reason: "No host specified".to_string()
});
}
debug!("URL validated: {}", url);
Ok(())
}
pub fn validate_domain(&self, domain: &str) -> Result<()> {
if self.block_all {
bail!(NetworkSecurityError::NetworkAccessDenied {
reason: "All network access is blocked".to_string()
});
}
if !self.is_domain_allowed(domain) {
warn!("Domain access denied: {} (not in whitelist)", domain);
bail!(NetworkSecurityError::DomainNotInWhitelist {
domain: domain.to_string(),
allowed_domains: self.allowed_domains.clone()
});
}
if self.block_localhost && is_localhost(domain) {
bail!(NetworkSecurityError::LocalhostAccessDenied {
domain: domain.to_string()
});
}
debug!("Domain validated: {}", domain);
Ok(())
}
fn validate_scheme(&self, url: &url::Url) -> Result<()> {
let scheme = url.scheme();
match scheme {
"https" => Ok(()),
"http" => {
if self.https_only {
bail!(NetworkSecurityError::HttpNotAllowed {
url: url.to_string()
});
}
Ok(())
}
_ => bail!(NetworkSecurityError::UnsupportedProtocol {
protocol: scheme.to_string(),
url: url.to_string()
}),
}
}
fn validate_host(&self, host: &str) -> Result<()> {
if let Ok(ip) = host.parse::<IpAddr>() {
return self.validate_ip(&ip);
}
self.validate_domain(host)
}
fn validate_ip(&self, ip: &IpAddr) -> Result<()> {
if !self.allowed_ips.is_empty() && !self.allowed_ips.contains(ip) {
bail!(NetworkSecurityError::IpNotInWhitelist {
ip: ip.to_string(),
allowed_ips: self.allowed_ips.iter().map(|i| i.to_string()).collect()
});
}
if self.block_localhost && is_localhost_ip(ip) {
bail!(NetworkSecurityError::LocalhostAccessDenied {
domain: ip.to_string()
});
}
if self.block_private_ips && is_private_ip(ip) {
bail!(NetworkSecurityError::PrivateIpAccessDenied { ip: ip.to_string() });
}
Ok(())
}
fn is_domain_allowed(&self, domain: &str) -> bool {
if self.allowed_domains.is_empty() {
return false;
}
if self.allowed_domains.contains(&domain.to_string()) {
return true;
}
for allowed in &self.allowed_domains {
if domain.ends_with(&format!(".{}", allowed)) {
return true;
}
}
false
}
}
/// Returns `true` when `domain` is a textual alias for the local machine.
///
/// Matching is case-insensitive and covers `localhost`,
/// `localhost.localdomain`, the literals `127.0.0.1`, `::1` and `0.0.0.0`,
/// and any name under the reserved `.localhost` zone. A single trailing dot
/// (the fully-qualified spelling, e.g. `localhost.`) is ignored.
fn is_localhost(domain: &str) -> bool {
    // `localhost.` and `localhost` are the same name to a resolver.
    let name = domain.strip_suffix('.').unwrap_or(domain).to_lowercase();

    matches!(
        name.as_str(),
        "localhost" | "localhost.localdomain" | "127.0.0.1" | "::1" | "0.0.0.0"
    ) || name.ends_with(".localhost")
}
/// Returns `true` when `ip` addresses the local machine.
///
/// Besides the loopback ranges this covers the unspecified address
/// (`0.0.0.0` / `::`), which the textual localhost check in this file also
/// treats as local, and IPv4-mapped IPv6 forms such as `::ffff:127.0.0.1`
/// that would otherwise slip past the IPv6 loopback test.
fn is_localhost_ip(ip: &IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => v4.is_loopback() || v4.is_unspecified(),
        IpAddr::V6(v6) => {
            v6.is_loopback()
                || v6.is_unspecified()
                || matches!(
                    v6.to_ipv4_mapped(),
                    Some(v4) if v4.is_loopback() || v4.is_unspecified()
                )
        }
    }
}
/// Reports whether `ip` is classified as non-public, delegating to the
/// checker for its address family.
fn is_private_ip(ip: &IpAddr) -> bool {
    match *ip {
        IpAddr::V4(ref v4) => is_private_ipv4(v4),
        IpAddr::V6(ref v6) => is_private_ipv6(v6),
    }
}
/// Returns `true` when `ip` is not a publicly routable unicast IPv4 address.
///
/// Covers the RFC 1918 private ranges, loopback, link-local, broadcast and
/// documentation ranges, plus three ranges with no standard-library
/// predicate on stable Rust or previously overlooked:
/// `0.0.0.0/8` ("this network", includes the unspecified address),
/// `100.64.0.0/10` (carrier-grade NAT shared space) and multicast.
fn is_private_ipv4(ip: &Ipv4Addr) -> bool {
    let [first, second, ..] = ip.octets();

    let this_network = first == 0;
    // 100.64.0.0/10: first octet 100, top two bits of the second are `01`.
    let shared_cgnat = first == 100 && (second & 0b1100_0000) == 0b0100_0000;

    ip.is_private()
        || ip.is_loopback()
        || ip.is_link_local()
        || ip.is_broadcast()
        || ip.is_documentation()
        || ip.is_multicast()
        || this_network
        || shared_cgnat
}
/// Returns `true` when `ip` is not a publicly routable unicast IPv6 address.
///
/// Covers loopback (`::1`), unspecified (`::`), unique-local (`fc00::/7`),
/// link-local (`fe80::/10`), multicast (`ff00::/8`) and the documentation
/// prefix (`2001:db8::/32`). IPv4-mapped addresses (`::ffff:a.b.c.d`) are
/// judged by the embedded IPv4 address, so `::ffff:192.168.1.1` cannot be
/// used to reach a private IPv4 host.
fn is_private_ipv6(ip: &Ipv6Addr) -> bool {
    if let Some(v4) = ip.to_ipv4_mapped() {
        // Deliberately uses only standard-library predicates so this check
        // does not depend on any other function in the file.
        return v4.is_private()
            || v4.is_loopback()
            || v4.is_link_local()
            || v4.is_broadcast()
            || v4.is_documentation()
            || v4.is_unspecified();
    }

    let segments = ip.segments();
    let unique_local = (segments[0] & 0xfe00) == 0xfc00;
    let link_local = (segments[0] & 0xffc0) == 0xfe80;
    let documentation = segments[0] == 0x2001 && segments[1] == 0x0db8;

    ip.is_loopback()
        || ip.is_unspecified()
        || ip.is_multicast()
        || unique_local
        || link_local
        || documentation
}
/// Reasons a destination can be rejected by `NetworkRestrictions`.
///
/// The human-readable message for each variant comes from its `#[error]`
/// attribute.
#[derive(Debug, thiserror::Error)]
pub enum NetworkSecurityError {
/// The policy has `block_all` set, so nothing is permitted.
#[error("Network access denied: {reason}")]
NetworkAccessDenied { reason: String },
/// The URL could not be parsed, or it parsed but carries no host.
#[error("Invalid URL: {url} - {reason}")]
InvalidUrl { url: String, reason: String },
/// An `http://` URL was presented while `https_only` is set.
#[error("HTTP not allowed (HTTPS only): {url}")]
HttpNotAllowed { url: String },
/// The URL scheme is neither `http` nor `https`.
#[error("Unsupported protocol: {protocol} in URL: {url}")]
UnsupportedProtocol { protocol: String, url: String },
/// The domain matched no whitelist entry; `allowed_domains` is a copy of
/// the whitelist at the time of the check.
#[error("Domain not in whitelist: {domain} (allowed: {allowed_domains:?})")]
DomainNotInWhitelist {
domain: String,
allowed_domains: Vec<String>,
},
/// The IP literal is not in the IP whitelist; `allowed_ips` holds the
/// whitelist rendered as strings.
#[error("IP not in whitelist: {ip} (allowed: {allowed_ips:?})")]
IpNotInWhitelist {
ip: String,
allowed_ips: Vec<String>,
},
/// The destination is the local machine and `block_localhost` is set.
/// `domain` holds either the host name or the textual IP address.
#[error("Localhost access denied: {domain}")]
LocalhostAccessDenied { domain: String },
/// The IP literal is in a private range and `block_private_ips` is set.
#[error("Private IP access denied: {ip}")]
PrivateIpAccessDenied { ip: String },
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Policy whitelisting only `example.com`, shared by several tests.
    fn example_com_only() -> NetworkRestrictions {
        NetworkRestrictions::allow_domains(vec![String::from("example.com")])
    }

    /// A deny-all policy rejects even a well-formed HTTPS URL.
    #[test]
    fn test_deny_all() {
        let policy = NetworkRestrictions::deny_all();
        assert!(policy.validate_url("https://example.com").is_err());
    }

    /// `allow_domains` turns on HTTPS-only, so plain HTTP is refused.
    #[test]
    fn test_https_only() {
        let policy = example_com_only();
        let secure = policy.validate_url("https://example.com");
        let plain = policy.validate_url("http://example.com");
        assert!(secure.is_ok());
        assert!(plain.is_err());
    }

    /// Listed domains and their subdomains pass; unrelated domains do not.
    #[test]
    fn test_domain_whitelist() {
        let policy = example_com_only();
        assert!(policy.validate_url("https://example.com").is_ok());
        assert!(policy.validate_url("https://api.example.com").is_ok());
        assert!(policy.validate_url("https://evil.com").is_err());
    }

    /// Whitelisting `localhost` does not override `block_localhost`.
    #[test]
    fn test_localhost_blocking() {
        let policy = NetworkRestrictions {
            block_localhost: true,
            ..NetworkRestrictions::allow_domains(vec![String::from("localhost")])
        };
        for name in ["localhost", "127.0.0.1"] {
            assert!(policy.validate_domain(name).is_err());
        }
    }

    /// Addresses from private, loopback and link-local ranges are rejected.
    #[test]
    fn test_private_ip_blocking() {
        let policy = NetworkRestrictions {
            block_all: false,
            block_private_ips: true,
            ..Default::default()
        };
        for literal in [
            "10.0.0.1",
            "172.16.0.1",
            "192.168.1.1",
            "127.0.0.1",
            "169.254.1.1",
        ] {
            let addr: IpAddr = literal.parse().unwrap();
            assert!(policy.validate_ip(&addr).is_err());
        }
    }

    /// A whitelisted public address is accepted.
    #[test]
    fn test_public_ip_allowed() {
        let dns: IpAddr = "8.8.8.8".parse().unwrap();
        let policy = NetworkRestrictions {
            block_all: false,
            allowed_ips: vec!["8.8.8.8".parse().unwrap()],
            block_private_ips: true,
            ..Default::default()
        };
        assert!(policy.validate_ip(&dns).is_ok());
    }

    /// Localhost detection ignores case and recognises the IPv4 loopback literal.
    #[test]
    fn test_is_localhost() {
        for alias in ["localhost", "LOCALHOST", "127.0.0.1"] {
            assert!(is_localhost(alias));
        }
        assert!(!is_localhost("example.com"));
    }

    /// RFC 1918 space is private; a public resolver address is not.
    #[test]
    fn test_is_private_ipv4() {
        let lan = Ipv4Addr::new(192, 168, 1, 1);
        let internet = Ipv4Addr::new(8, 8, 8, 8);
        assert!(is_private_ipv4(&lan));
        assert!(!is_private_ipv4(&internet));
    }

    /// Subdomain matching requires a label boundary before the listed name.
    #[test]
    fn test_subdomain_matching() {
        let policy = example_com_only();
        for accepted in ["example.com", "api.example.com", "foo.bar.example.com"] {
            assert!(policy.is_domain_allowed(accepted));
        }
        for rejected in ["examplecom", "evil.com"] {
            assert!(!policy.is_domain_allowed(rejected));
        }
    }
}