use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs};
/// Policy deciding which resolved IP addresses outbound connections may target.
#[derive(Clone, Debug)]
pub struct DnsPolicy {
    /// When `true`, loopback, private, link-local, multicast and other
    /// reserved/non-public addresses are rejected; when `false`, every address passes.
    pub block_private: bool,
}
impl Default for DnsPolicy {
fn default() -> Self {
Self::block_private_ips()
}
}
impl DnsPolicy {
    /// Creates the restrictive policy that rejects private and reserved addresses.
    pub fn block_private_ips() -> Self {
        Self { block_private: true }
    }

    /// Creates a permissive policy that accepts every address.
    pub fn allow_all() -> Self {
        Self { block_private: false }
    }

    /// Returns `true` when this policy forbids connecting to `ip`.
    ///
    /// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are canonicalized to IPv4
    /// first so they are judged by the IPv4 rules. Always `false` when
    /// `block_private` is disabled (e.g. [`DnsPolicy::allow_all`]).
    pub fn is_blocked_ip(&self, ip: IpAddr) -> bool {
        if !self.block_private {
            return false;
        }
        match ip.to_canonical() {
            IpAddr::V4(v4) => is_blocked_ipv4(v4),
            IpAddr::V6(v6) => is_blocked_ipv6(v6),
        }
    }

    /// Resolves `host:port` via [`ToSocketAddrs`] and returns the first resolved
    /// address the policy permits (or simply the first address when blocking is off).
    ///
    /// Every blocked address is logged at debug level. Callers should connect to the
    /// returned address rather than re-resolving `host`; a second lookup may yield a
    /// different answer that was never validated.
    ///
    /// # Errors
    ///
    /// * [`DnsPolicyError::ResolutionFailed`] if the lookup itself fails.
    /// * [`DnsPolicyError::NoAddresses`] if the lookup returns no addresses.
    /// * [`DnsPolicyError::AllAddressesBlocked`] if every resolved address is blocked.
    pub fn resolve_and_validate(
        &self,
        host: &str,
        port: u16,
    ) -> Result<SocketAddr, DnsPolicyError> {
        let addrs: Vec<SocketAddr> = format!("{}:{}", host, port)
            .to_socket_addrs()
            .map_err(|e| DnsPolicyError::ResolutionFailed(host.to_string(), e.to_string()))?
            .collect();
        let Some(&first) = addrs.first() else {
            return Err(DnsPolicyError::NoAddresses(host.to_string()));
        };
        if !self.block_private {
            return Ok(first);
        }
        // Walk the whole list (instead of stopping at the first allowed entry) so
        // that every blocked address is logged.
        let mut first_allowed = None;
        for addr in &addrs {
            if self.is_blocked_ip(addr.ip()) {
                tracing::debug!(
                    ip = %addr.ip(),
                    host = host,
                    "Blocked private/reserved IP"
                );
            } else if first_allowed.is_none() {
                first_allowed = Some(*addr);
            }
        }
        first_allowed.ok_or_else(|| DnsPolicyError::AllAddressesBlocked(host.to_string()))
    }
}
#[derive(Debug, thiserror::Error)]
pub enum DnsPolicyError {
#[error("DNS resolution failed for {0}: {1}")]
ResolutionFailed(String, String),
#[error("No addresses resolved for {0}")]
NoAddresses(String),
#[error("All resolved addresses for {0} are in blocked IP ranges (private/reserved)")]
AllAddressesBlocked(String),
}
/// Returns `true` if an IPv4 address lies in a range outbound requests must not reach:
/// "this network", loopback, RFC 1918 private, link-local, shared (CGNAT), IETF
/// protocol assignments, documentation, benchmarking, multicast, and reserved
/// (which includes the limited broadcast address).
///
/// Ranges follow the IANA IPv4 Special-Purpose Address Registry (RFC 6890).
fn is_blocked_ipv4(ip: Ipv4Addr) -> bool {
    if ip.is_loopback()          // 127.0.0.0/8
        || ip.is_private()       // 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16
        || ip.is_link_local()    // 169.254.0.0/16
        || ip.is_documentation() // 192.0.2.0/24, 198.51.100.0/24, 203.0.113.0/24
        || ip.is_multicast()     // 224.0.0.0/4
    {
        return true;
    }
    matches!(
        ip.octets(),
        // 0.0.0.0/8 "this network"; includes the unspecified address 0.0.0.0.
        [0, ..]
            // 100.64.0.0/10 shared address space (carrier-grade NAT).
            | [100, 64..=127, ..]
            // 192.0.0.0/24 IETF protocol assignments.
            | [192, 0, 0, _]
            // 198.18.0.0/15 benchmarking.
            | [198, 18..=19, ..]
            // 240.0.0.0/4 reserved; also contains 255.255.255.255 broadcast.
            | [240..=255, ..]
    )
}

/// Returns `true` if an IPv6 address lies in a range outbound requests must not reach,
/// including IPv6 transition formats that embed a blocked IPv4 address.
///
/// `DnsPolicy::is_blocked_ip` canonicalizes IPv4-mapped addresses before calling this,
/// but the mapped form is handled here too so the check is safe on its own.
fn is_blocked_ipv6(ip: Ipv6Addr) -> bool {
    if ip.is_unspecified() || ip.is_loopback() || ip.is_multicast() {
        return true;
    }
    let segments = ip.segments();
    let in_blocked_range = matches!(
        segments,
        // fe80::/10 link-local and fec0::/10 (deprecated) site-local.
        [0xfe80..=0xfeff, ..]
            // fc00::/7 unique local addresses.
            | [0xfc00..=0xfdff, ..]
            // 2001:db8::/32 documentation (mirrors the IPv4 documentation ranges).
            | [0x2001, 0x0db8, ..]
            // 2001:2::/48 benchmarking (mirrors 198.18.0.0/15).
            | [0x2001, 0x0002, 0x0000, ..]
            // 100::/64 discard-only.
            | [0x0100, 0, 0, 0, ..]
    );
    // Transition formats inherit the verdict of the IPv4 address they carry, so e.g.
    // `64:ff9b::7f00:1` (NAT64 for 127.0.0.1) cannot bypass the IPv4 rules.
    in_blocked_range || embedded_ipv4(segments).is_some_and(is_blocked_ipv4)
}

/// Extracts the IPv4 address carried by an IPv6 transition format, if any.
fn embedded_ipv4(segments: [u16; 8]) -> Option<Ipv4Addr> {
    let join = |hi: u16, lo: u16| Ipv4Addr::from((u32::from(hi) << 16) | u32::from(lo));
    match segments {
        // ::a.b.c.d (deprecated IPv4-compatible) and ::ffff:a.b.c.d (IPv4-mapped).
        [0, 0, 0, 0, 0, 0 | 0xffff, hi, lo] => Some(join(hi, lo)),
        // 64:ff9b::a.b.c.d, the NAT64 well-known prefix (RFC 6052).
        [0x0064, 0xff9b, 0, 0, 0, 0, hi, lo] => Some(join(hi, lo)),
        // 2002:AABB:CCDD::/48, 6to4 (RFC 3056); IPv4 sits in segments 1-2.
        [0x2002, hi, lo, ..] => Some(join(hi, lo)),
        _ => None,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The restrictive policy exercised by most tests below.
    fn policy() -> DnsPolicy {
        DnsPolicy::block_private_ips()
    }

    /// Parses an IP literal, panicking on malformed test input.
    fn ip(literal: &str) -> IpAddr {
        literal.parse().unwrap()
    }

    /// Asserts that `policy` rejects every address in `literals`.
    #[track_caller]
    fn assert_blocked(policy: &DnsPolicy, literals: &[&str]) {
        for literal in literals {
            assert!(policy.is_blocked_ip(ip(literal)), "expected {} to be blocked", literal);
        }
    }

    /// Asserts that `policy` accepts every address in `literals`.
    #[track_caller]
    fn assert_allowed(policy: &DnsPolicy, literals: &[&str]) {
        for literal in literals {
            assert!(!policy.is_blocked_ip(ip(literal)), "expected {} to be allowed", literal);
        }
    }

    #[test]
    fn test_loopback_blocked() {
        assert_blocked(&policy(), &["127.0.0.1", "127.0.0.2", "127.255.255.255"]);
    }

    #[test]
    fn test_private_10_blocked() {
        assert_blocked(&policy(), &["10.0.0.1", "10.255.255.255"]);
    }

    #[test]
    fn test_private_172_blocked() {
        let p = policy();
        assert_blocked(&p, &["172.16.0.1", "172.31.255.255"]);
        assert_allowed(&p, &["172.15.0.1", "172.32.0.1"]);
    }

    #[test]
    fn test_private_192_168_blocked() {
        assert_blocked(&policy(), &["192.168.0.1", "192.168.255.255"]);
    }

    #[test]
    fn test_link_local_blocked() {
        assert_blocked(&policy(), &["169.254.0.1", "169.254.169.254", "169.254.255.255"]);
    }

    #[test]
    fn test_carrier_grade_nat_blocked() {
        let p = policy();
        assert_blocked(&p, &["100.64.0.1", "100.127.255.255"]);
        assert_allowed(&p, &["100.63.0.1", "100.128.0.1"]);
    }

    #[test]
    fn test_documentation_ranges_blocked() {
        assert_blocked(&policy(), &["192.0.2.1", "198.51.100.1", "203.0.113.1"]);
    }

    #[test]
    fn test_benchmarking_blocked() {
        let p = policy();
        assert_blocked(&p, &["198.18.0.1", "198.19.255.255"]);
        assert_allowed(&p, &["198.17.0.1", "198.20.0.1"]);
    }

    #[test]
    fn test_multicast_blocked() {
        assert_blocked(&policy(), &["224.0.0.1", "239.255.255.255"]);
    }

    #[test]
    fn test_broadcast_blocked() {
        assert_blocked(&policy(), &["255.255.255.255"]);
    }

    #[test]
    fn test_unspecified_blocked() {
        assert_blocked(&policy(), &["0.0.0.0"]);
    }

    #[test]
    fn test_ipv6_loopback_blocked() {
        assert_blocked(&policy(), &["::1"]);
    }

    #[test]
    fn test_ipv6_unspecified_blocked() {
        assert_blocked(&policy(), &["::"]);
    }

    #[test]
    fn test_ipv6_link_local_blocked() {
        assert_blocked(&policy(), &["fe80::1", "fe80::ffff:ffff:ffff:ffff"]);
    }

    #[test]
    fn test_ipv6_unique_local_blocked() {
        assert_blocked(&policy(), &["fc00::1", "fd00::1"]);
    }

    #[test]
    fn test_ipv6_multicast_blocked() {
        assert_blocked(&policy(), &["ff02::1"]);
    }

    #[test]
    fn test_ipv6_mapped_ipv4_blocked() {
        let p = policy();
        assert_blocked(
            &p,
            &["::ffff:127.0.0.1", "::ffff:10.0.0.1", "::ffff:169.254.169.254"],
        );
        assert_allowed(&p, &["::ffff:8.8.8.8"]);
    }

    #[test]
    fn test_public_ipv4_allowed() {
        assert_allowed(
            &policy(),
            &["8.8.8.8", "1.1.1.1", "93.184.216.34", "140.82.121.3"],
        );
    }

    #[test]
    fn test_public_ipv6_allowed() {
        assert_allowed(&policy(), &["2001:4860:4860::8888", "2606:4700:4700::1111"]);
    }

    #[test]
    fn test_allow_all_permits_private() {
        assert_allowed(
            &DnsPolicy::allow_all(),
            &["127.0.0.1", "10.0.0.1", "169.254.169.254"],
        );
    }

    #[test]
    fn test_default_blocks_private() {
        let p = DnsPolicy::default();
        assert_blocked(&p, &["127.0.0.1", "10.0.0.1", "169.254.169.254"]);
        assert_allowed(&p, &["8.8.8.8"]);
    }

    #[test]
    fn test_resolve_loopback_blocked() {
        let err = policy()
            .resolve_and_validate("127.0.0.1", 80)
            .expect_err("loopback must be rejected")
            .to_string();
        assert!(err.contains("blocked"), "Error was: {}", err);
    }

    #[test]
    fn test_resolve_private_blocked() {
        assert!(policy().resolve_and_validate("10.0.0.1", 80).is_err());
    }

    #[test]
    fn test_resolve_nonexistent_fails() {
        let outcome =
            policy().resolve_and_validate("this-host-definitely-does-not-exist.invalid", 80);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_resolve_allow_all_permits_loopback() {
        let addr = DnsPolicy::allow_all()
            .resolve_and_validate("127.0.0.1", 80)
            .expect("allow_all must accept loopback");
        assert_eq!(addr.ip(), "127.0.0.1".parse::<IpAddr>().unwrap());
    }
}