use std::collections::HashSet;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use hickory_resolver::TokioResolver;
use hickory_resolver::config::{ResolverConfig, ResolverOpts};
use reqwest::dns::{Addrs, Name, Resolve, Resolving};
use crate::error::ToolError;
/// Returns `true` when the SSRF guard must refuse a connection to `ip`.
///
/// This is the family-agnostic entry point: IPv4 addresses are judged by the
/// IPv4 range checks and IPv6 addresses by the IPv6 range checks (which in
/// turn unwrap IPv4-mapped addresses and re-judge them as IPv4).
#[must_use]
pub fn is_ssrf_blocked(ip: &IpAddr) -> bool {
    match *ip {
        IpAddr::V6(addr) => is_ssrf_blocked_v6(addr),
        IpAddr::V4(addr) => is_ssrf_blocked_v4(addr),
    }
}
/// IPv4 half of [`is_ssrf_blocked`]: returns `true` when `v4` is not a
/// publicly routable unicast destination.
///
/// Blocked ranges:
/// - `0.0.0.0/8` "this network" (includes the unspecified `0.0.0.0`)
/// - `10.0.0.0/8`, `172.16.0.0/12`, `192.168.0.0/16` private (RFC 1918)
/// - `100.64.0.0/10` shared address space / carrier-grade NAT (RFC 6598)
/// - `127.0.0.0/8` loopback
/// - `169.254.0.0/16` link-local (includes `169.254.169.254`)
/// - `192.0.0.0/24` IETF protocol assignments
/// - `192.0.2.0/24`, `198.51.100.0/24`, `203.0.113.0/24` documentation
/// - `198.18.0.0/15` benchmarking
/// - `224.0.0.0/4` multicast
/// - `240.0.0.0/4` reserved, including broadcast `255.255.255.255`
fn is_ssrf_blocked_v4(v4: std::net::Ipv4Addr) -> bool {
    let [a, b, c, _] = v4.octets();
    v4.is_loopback()
        || v4.is_private()
        || v4.is_link_local()
        || v4.is_broadcast()
        || v4.is_unspecified()
        || v4.is_multicast()
        || v4.is_documentation()
        // 0.0.0.0/8: "this host on this network"; never a valid public
        // destination, and `is_unspecified` only covers 0.0.0.0 itself.
        || a == 0
        // 100.64.0.0/10: shared address space (carrier-grade NAT).
        || (a == 100 && (64..=127).contains(&b))
        // 192.0.0.0/24: IETF protocol assignments.
        || (a == 192 && b == 0 && c == 0)
        // 198.18.0.0/15: benchmarking networks; not globally routable.
        || (a == 198 && (b == 18 || b == 19))
        // 240.0.0.0/4: reserved for future use.
        || a >= 240
}
/// IPv6 half of [`is_ssrf_blocked`]: returns `true` for addresses that are
/// not publicly routable unicast, and judges IPv4-embedding formats by the
/// IPv4 address they carry.
///
/// Blocked outright: `::1`, `::`, `ff00::/8` multicast, `::/96`
/// IPv4-compatible (deprecated), `100::/64` discard-only, `fc00::/7` unique
/// local, `fe80::/10` link-local, `fec0::/10` site-local (deprecated),
/// `2001::/32` Teredo, `2001:db8::/32` documentation, and `2002::/16` 6to4.
/// Teredo and 6to4 are blocked unconditionally because they tunnel to an
/// embedded IPv4 endpoint that this check does not decode.
///
/// Re-checked as IPv4 via [`is_ssrf_blocked_v4`]: `::ffff:0:0/96`
/// IPv4-mapped, `::ffff:0:0:0/96` IPv4-translated, and the `64:ff9b::/96`
/// NAT64 well-known prefix.
fn is_ssrf_blocked_v6(v6: std::net::Ipv6Addr) -> bool {
    if v6.is_loopback() || v6.is_unspecified() || v6.is_multicast() {
        return true;
    }
    let segs = v6.segments();
    // Formats whose low 32 bits are the IPv4 destination actually reached
    // (by the local stack or a translator), so they are only as safe as that
    // IPv4 address: `::ffff:127.0.0.1` is blocked, `::ffff:8.8.8.8` is not.
    let embeds_reachable_v4 = matches!(
        segs,
        // ::ffff:0:0/96 IPv4-mapped
        [0, 0, 0, 0, 0, 0xffff, _, _]
            // ::ffff:0:0:0/96 IPv4-translated (SIIT)
            | [0, 0, 0, 0, 0xffff, 0, _, _]
            // 64:ff9b::/96 NAT64 well-known prefix
            | [0x0064, 0xff9b, 0, 0, 0, 0, _, _]
    );
    if embeds_reachable_v4 {
        let [.., a, b, c, d] = v6.octets();
        return is_ssrf_blocked_v4(std::net::Ipv4Addr::new(a, b, c, d));
    }
    let first = segs[0];
    // ::/96 IPv4-compatible (deprecated); `::` and `::1` were handled above.
    matches!(segs, [0, 0, 0, 0, 0, 0, _, _])
        // 100::/64 discard-only
        || matches!(segs, [0x0100, 0, 0, 0, ..])
        // fc00::/7 unique local
        || first & 0xfe00 == 0xfc00
        // fe80::/10 link-local
        || first & 0xffc0 == 0xfe80
        // fec0::/10 site-local (deprecated)
        || first & 0xffc0 == 0xfec0
        // 2001::/32 Teredo and 2001:db8::/32 documentation
        || (first == 0x2001 && (segs[1] == 0 || segs[1] == 0x0db8))
        // 2002::/16 6to4
        || first == 0x2002
}
/// `reqwest` DNS resolver that filters lookup results through
/// [`is_ssrf_blocked`].
///
/// Addresses rejected by the SSRF guard are dropped unless they appear in
/// `explicit_allow`; if no address survives filtering, resolution fails.
pub struct SsrfSafeDnsResolver {
    // Underlying hickory resolver that performs the actual DNS lookups.
    inner: TokioResolver,
    // IPs that bypass the SSRF block. Held in an `Arc` so each `resolve`
    // future can take a cheap shared handle to it.
    explicit_allow: Arc<HashSet<IpAddr>>,
}
impl SsrfSafeDnsResolver {
pub fn from_system() -> Result<Self, ToolError> {
let inner = TokioResolver::builder_tokio()
.map_err(|e| ToolError::Config {
message: format!("DNS: failed to read system config: {e}"),
source: Some(Box::new(e)),
})?
.build()
.map_err(|e| ToolError::Config {
message: format!("DNS: failed to construct resolver: {e}"),
source: Some(Box::new(e)),
})?;
Ok(Self {
inner,
explicit_allow: Arc::new(HashSet::new()),
})
}
pub fn from_config(config: ResolverConfig, opts: ResolverOpts) -> Result<Self, ToolError> {
let inner = TokioResolver::builder_with_config(
config,
hickory_resolver::net::runtime::TokioRuntimeProvider::default(),
)
.with_options(opts)
.build()
.map_err(|e| ToolError::Config {
message: format!("DNS: failed to construct resolver: {e}"),
source: Some(Box::new(e)),
})?;
Ok(Self {
inner,
explicit_allow: Arc::new(HashSet::new()),
})
}
#[must_use]
pub fn with_explicit_allow(mut self, ips: HashSet<IpAddr>) -> Self {
self.explicit_allow = Arc::new(ips);
self
}
}
#[allow(
    clippy::missing_fields_in_debug,
    reason = "TokioResolver carries a non-Debug closure; printed as the explicit-allow count instead"
)]
impl std::fmt::Debug for SsrfSafeDnsResolver {
    /// Renders the type name plus the size of the explicit allow-list; the
    /// inner resolver is intentionally left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let allow_count = self.explicit_allow.len();
        let mut out = f.debug_struct("SsrfSafeDnsResolver");
        out.field("explicit_allow_count", &allow_count);
        out.finish()
    }
}
impl Resolve for SsrfSafeDnsResolver {
    /// Looks `name` up with the inner hickory resolver and keeps only the
    /// addresses that are either in the explicit allow-list or not rejected
    /// by [`is_ssrf_blocked`], preserving the resolver's answer order.
    ///
    /// Returned socket addresses carry port 0. NOTE(review): presumably the
    /// caller (reqwest's connector) substitutes the URL's port; confirm.
    ///
    /// Fails if the DNS lookup itself fails, or if every resolved address is
    /// blocked. NOTE(review): reqwest's connector is generally expected to skip
    /// the resolver for IP-literal hosts (e.g. `http://127.0.0.1/`), so those
    /// URLs likely need a separate `is_ssrf_blocked` check; confirm.
    fn resolve(&self, name: Name) -> Resolving {
        let resolver = self.inner.clone();
        let allow = Arc::clone(&self.explicit_allow);
        Box::pin(async move {
            let host = name.as_str();
            let answers = resolver
                .lookup_ip(host)
                .await
                .map_err(|e| -> Box<dyn std::error::Error + Send + Sync> { Box::new(e) })?;
            let (permitted, blocked): (Vec<IpAddr>, Vec<IpAddr>) = answers
                .iter()
                .partition(|ip| allow.contains(ip) || !is_ssrf_blocked(ip));
            if permitted.is_empty() {
                let msg = format!(
                    "DNS for '{host}' resolved only to blocked IPs ({blocked:?}); \
                     refusing to connect (SSRF guard)"
                );
                return Err(msg.into());
            }
            let addrs: Addrs = Box::new(
                permitted
                    .into_iter()
                    .map(|ip| SocketAddr::new(ip, 0)),
            );
            Ok(addrs)
        })
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::ip_constant)]
mod tests {
    use std::net::{Ipv4Addr, Ipv6Addr};
    use super::*;
    #[test]
    fn ipv4_loopback_blocked() {
        assert!(is_ssrf_blocked(&IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))));
    }
    #[test]
    fn ipv4_metadata_endpoint_blocked() {
        assert!(is_ssrf_blocked(&IpAddr::V4(Ipv4Addr::new(
            169, 254, 169, 254
        ))));
    }
    #[test]
    fn ipv4_private_ranges_blocked() {
        assert!(is_ssrf_blocked(&IpAddr::V4(Ipv4Addr::new(10, 0, 0, 5))));
        assert!(is_ssrf_blocked(&IpAddr::V4(Ipv4Addr::new(172, 16, 1, 1))));
        assert!(is_ssrf_blocked(&IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))));
    }
    #[test]
    fn ipv4_cgnat_blocked() {
        assert!(is_ssrf_blocked(&IpAddr::V4(Ipv4Addr::new(100, 64, 0, 1))));
        assert!(is_ssrf_blocked(&IpAddr::V4(Ipv4Addr::new(100, 127, 1, 1))));
    }
    #[test]
    fn ipv4_public_address_passes() {
        assert!(!is_ssrf_blocked(&IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))));
        assert!(!is_ssrf_blocked(&IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1))));
    }
    #[test]
    fn ipv6_loopback_and_unspecified_blocked() {
        assert!(is_ssrf_blocked(&IpAddr::V6(Ipv6Addr::LOCALHOST)));
        assert!(is_ssrf_blocked(&IpAddr::V6(Ipv6Addr::UNSPECIFIED)));
    }
    #[test]
    fn ipv6_unique_local_and_link_local_blocked() {
        assert!(is_ssrf_blocked(&IpAddr::V6("fd00::1".parse().unwrap())));
        assert!(is_ssrf_blocked(&IpAddr::V6("fe80::1".parse().unwrap())));
    }
    #[test]
    fn ipv4_mapped_ipv6_routes_through_v4_block() {
        assert!(is_ssrf_blocked(&IpAddr::V6(
            "::ffff:127.0.0.1".parse().unwrap()
        )));
        assert!(is_ssrf_blocked(&IpAddr::V6(
            "::ffff:10.0.0.5".parse().unwrap()
        )));
        assert!(is_ssrf_blocked(&IpAddr::V6(
            "::ffff:169.254.169.254".parse().unwrap()
        )));
    }
    #[test]
    fn ipv4_mapped_public_v4_passes() {
        assert!(!is_ssrf_blocked(&IpAddr::V6(
            "::ffff:8.8.8.8".parse().unwrap()
        )));
    }
    #[test]
    fn six_to_four_prefix_blocked_unconditionally() {
        assert!(is_ssrf_blocked(&IpAddr::V6("2002::1".parse().unwrap())));
        assert!(is_ssrf_blocked(&IpAddr::V6(
            "2002:7f00:0001::".parse().unwrap()
        )));
        assert!(is_ssrf_blocked(&IpAddr::V6(
            "2002:0808:0808::".parse().unwrap()
        )));
    }
    #[test]
    fn teredo_prefix_blocked_unconditionally() {
        assert!(is_ssrf_blocked(&IpAddr::V6("2001::1".parse().unwrap())));
        assert!(is_ssrf_blocked(&IpAddr::V6(
            "2001:0:abcd:ef01::".parse().unwrap()
        )));
    }
    #[test]
    fn non_teredo_2001_prefix_allowed() {
        assert!(!is_ssrf_blocked(&IpAddr::V6(
            "2001:4860:4860::8888".parse().unwrap()
        )));
    }
    /// A public address outside the 2001::/16 space, so this does not merely
    /// repeat `non_teredo_2001_prefix_allowed`.
    #[test]
    fn ipv6_public_address_passes() {
        assert!(!is_ssrf_blocked(&IpAddr::V6(
            "2606:4700:4700::1111".parse().unwrap()
        )));
    }
    /// Construction from an explicit config must succeed, and the `Debug`
    /// output must reflect the explicit allow-list size before and after
    /// `with_explicit_allow`.
    #[tokio::test]
    async fn resolver_builds_and_debug_reports_allow_count() {
        let r =
            SsrfSafeDnsResolver::from_config(ResolverConfig::default(), ResolverOpts::default())
                .expect("resolver must build from the default config");
        let rendered = format!("{r:?}");
        assert!(rendered.contains("SsrfSafeDnsResolver"));
        assert!(rendered.contains("explicit_allow_count: 0"));
        let r = r.with_explicit_allow(HashSet::from([IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))]));
        assert!(format!("{r:?}").contains("explicit_allow_count: 1"));
    }
    /// Exercises the allow-or-not-blocked predicate that `Resolve::resolve`
    /// applies to each looked-up address.
    #[test]
    fn explicit_allow_overrides_block_for_listed_ips() {
        let mut allow = HashSet::new();
        allow.insert(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
        let allowed_ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
        let other_blocked = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 5));
        let safe_for_allowed = allow.contains(&allowed_ip) || !is_ssrf_blocked(&allowed_ip);
        let safe_for_other = allow.contains(&other_blocked) || !is_ssrf_blocked(&other_blocked);
        assert!(safe_for_allowed, "explicit_allow must override block");
        assert!(!safe_for_other, "non-allowlisted private IP stays blocked");
    }
}