use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs};
use std::sync::Arc;
use ipnet::IpNet;
use reqwest::dns::{Addrs, Name, Resolve, Resolving};
use super::ResolvedIpRules;
/// Boxed, thread-safe error type used as the error of the future returned by
/// `GuardedResolver::resolve`; lookup, join, and policy errors are all erased into it.
type BoxError = Box<dyn std::error::Error + Send + Sync>;
/// Reports whether `ip` is a destination the egress guard must refuse.
///
/// Dispatches to the IPv4 or IPv6 rules. `allow_loopback` exempts loopback
/// destinations (`127.0.0.0/8`, `::1`, and IPv4 loopback embedded in an IPv6
/// address); every other blocked range stays blocked regardless of the flag.
pub fn is_disallowed_ip(ip: IpAddr, allow_loopback: bool) -> bool {
    match ip {
        IpAddr::V4(v4) => is_disallowed_ipv4(v4, allow_loopback),
        IpAddr::V6(v6) => is_disallowed_ipv6(v6, allow_loopback),
    }
}

/// Reports whether an IPv4 address must be refused.
///
/// Blocked: loopback `127.0.0.0/8` (unless `allow_loopback`), "this network"
/// `0.0.0.0/8`, RFC 1918 private ranges, link-local `169.254.0.0/16` (which
/// contains the cloud metadata endpoint `169.254.169.254` and is blocked even when
/// `allow_loopback` is set), multicast `224.0.0.0/4`, reserved `240.0.0.0/4`
/// (including broadcast), the TEST-NET documentation ranges, shared address space
/// `100.64.0.0/10`, and benchmarking `198.18.0.0/15`.
fn is_disallowed_ipv4(ip: Ipv4Addr, allow_loopback: bool) -> bool {
    if ip.is_loopback() {
        return !allow_loopback;
    }
    let [a, b, _, _] = ip.octets();
    // 0.0.0.0/8 ("this network", RFC 1122) is never a valid destination, and on
    // some platforms connecting to it reaches the local host. Covers 0.0.0.0.
    if a == 0 {
        return true;
    }
    // 240.0.0.0/4 is reserved (RFC 1112); it also contains 255.255.255.255.
    if a >= 240 {
        return true;
    }
    // Link-local also covers the 169.254.169.254 metadata endpoint.
    if ip.is_private() || ip.is_link_local() || ip.is_multicast() || ip.is_documentation() {
        return true;
    }
    // 100.64.0.0/10: shared address space for carrier-grade NAT (RFC 6598).
    if a == 100 && (64..=127).contains(&b) {
        return true;
    }
    // 198.18.0.0/15: network benchmarking (RFC 2544).
    a == 198 && (b == 18 || b == 19)
}

/// Reports whether an IPv6 address must be refused.
///
/// Blocked: loopback `::1` (unless `allow_loopback`), unspecified `::`, multicast
/// `ff00::/8`, link-local `fe80::/10`, deprecated site-local `fec0::/10`,
/// unique-local `fc00::/7`, and documentation `2001:db8::/32`. Addresses that carry
/// an IPv4 destination in their low 32 bits — IPv4-mapped `::ffff:0:0/96`,
/// deprecated IPv4-compatible `::/96`, and the NAT64 well-known prefix
/// `64:ff9b::/96` — are judged by the IPv4 rules.
fn is_disallowed_ipv6(ip: Ipv6Addr, allow_loopback: bool) -> bool {
    if ip.is_loopback() {
        return !allow_loopback;
    }
    if ip.is_unspecified() || ip.is_multicast() {
        return true;
    }
    let segments = ip.segments();
    // fe80::/10 link-local, plus fec0::/10 site-local (deprecated by RFC 3879 but
    // potentially still routed inside some networks).
    if matches!(segments[0] & 0xffc0, 0xfe80 | 0xfec0) {
        return true;
    }
    // fc00::/7 unique local addresses (RFC 4193).
    if (segments[0] & 0xfe00) == 0xfc00 {
        return true;
    }
    // 2001:db8::/32 documentation prefix (RFC 3849), mirroring the IPv4 TEST-NETs.
    if segments[0] == 0x2001 && segments[1] == 0x0db8 {
        return true;
    }
    match segments {
        // Mapped, compatible, and NAT64 forms embed the real IPv4 target in the low
        // 32 bits. Without re-checking it, e.g. `64:ff9b::a9fe:a9fe` would reach
        // 169.254.169.254 through a NAT64 gateway.
        [0, 0, 0, 0, 0, 0xffff | 0, hi, lo] | [0x64, 0xff9b, 0, 0, 0, 0, hi, lo] => {
            let embedded = Ipv4Addr::from((u32::from(hi) << 16) | u32::from(lo));
            is_disallowed_ipv4(embedded, allow_loopback)
        }
        _ => false,
    }
}
/// Blocking hostname lookup backend used by `GuardedResolver`; abstracted so tests
/// can substitute canned answers for real name resolution.
trait InnerResolver: Send + Sync {
    /// Returns every socket address `host` resolves to.
    fn lookup(&self, host: &str) -> std::io::Result<Vec<SocketAddr>>;
}

/// Production backend: resolves through std's `ToSocketAddrs` (the platform resolver).
struct GaiInnerResolver;

impl InnerResolver for GaiInnerResolver {
    fn lookup(&self, host: &str) -> std::io::Result<Vec<SocketAddr>> {
        // Port 0 is a placeholder; presumably the HTTP client substitutes the URL's
        // port on the returned addresses — confirm against reqwest's `Resolve` docs.
        let resolved = (host, 0_u16).to_socket_addrs()?;
        Ok(resolved.collect())
    }
}
/// DNS resolver for reqwest that drops resolved addresses rejected by the egress
/// policy before the client can connect to them.
pub struct GuardedResolver {
    /// Apply the built-in SSRF rules (`is_disallowed_ip`) to every resolved address.
    block_private: bool,
    /// Exempt loopback destinations from the SSRF rules.
    allow_loopback: bool,
    /// Extra CIDRs that are always rejected, independently of `block_private`.
    deny_nets: Vec<IpNet>,
    /// Backend performing the actual (blocking) lookup.
    inner: Arc<dyn InnerResolver>,
}

impl std::fmt::Debug for GuardedResolver {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("GuardedResolver");
        out.field("block_private", &self.block_private);
        out.field("allow_loopback", &self.allow_loopback);
        out.field("deny_nets", &self.deny_nets);
        // `inner` is a trait object with no `Debug` bound, so it is elided.
        out.finish_non_exhaustive()
    }
}
impl GuardedResolver {
pub fn new(allow_loopback: bool) -> Self {
Self {
block_private: true,
allow_loopback,
deny_nets: Vec::new(),
inner: Arc::new(GaiInnerResolver),
}
}
pub fn with_policy(block_private: bool, allow_loopback: bool, rules: &ResolvedIpRules) -> Self {
Self {
block_private,
allow_loopback,
deny_nets: rules.deny.clone(),
inner: Arc::new(GaiInnerResolver),
}
}
#[cfg(test)]
fn with_inner(allow_loopback: bool, inner: Arc<dyn InnerResolver>) -> Self {
Self {
block_private: true,
allow_loopback,
deny_nets: Vec::new(),
inner,
}
}
#[cfg(test)]
fn with_inner_full(
block_private: bool,
allow_loopback: bool,
deny_nets: Vec<IpNet>,
inner: Arc<dyn InnerResolver>,
) -> Self {
Self {
block_private,
allow_loopback,
deny_nets,
inner,
}
}
}
/// Error returned when every address a host resolved to was rejected by the egress
/// policy — either the built-in SSRF rules or the configured deny CIDRs.
///
/// Only the host name is reported; the rejected addresses are deliberately not
/// included so internal topology is not leaked to the caller.
#[derive(Debug)]
struct BlockedHostError {
    host: String,
}

impl std::fmt::Display for BlockedHostError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The reason list includes deny-listed IPs: this error is also produced when
        // the only rejection came from the deny CIDRs (with SSRF rules disabled).
        write!(
            f,
            "egress: host `{}` resolves only to disallowed addresses \
             (private, loopback, link-local, metadata, or deny-listed IP)",
            self.host
        )
    }
}

impl std::error::Error for BlockedHostError {}
impl Resolve for GuardedResolver {
    /// Looks `name` up on tokio's blocking pool via the inner backend, then discards
    /// every address rejected by the SSRF rules (only when `block_private` is set) or
    /// by the deny CIDRs (always). Fails with `BlockedHostError` when nothing is left;
    /// lookup and task-join failures are passed through as boxed errors.
    ///
    /// NOTE(review): reqwest/hyper presumably skip the resolver for IP-literal hosts,
    /// so literal URLs likely need a separate check upstream — confirm at call sites.
    fn resolve(&self, name: Name) -> Resolving {
        let host = name.as_str().to_owned();
        let (block_private, allow_loopback) = (self.block_private, self.allow_loopback);
        let deny_nets = self.deny_nets.clone();
        let inner = Arc::clone(&self.inner);
        Box::pin(async move {
            let query = host.clone();
            let task = tokio::task::spawn_blocking(move || inner.lookup(&query));
            let looked_up = match task.await {
                Ok(result) => result,
                Err(join_err) => return Err(Box::new(join_err) as BoxError),
            };
            let mut addrs = looked_up.map_err(|io_err| Box::new(io_err) as BoxError)?;
            addrs.retain(|addr| {
                let ip = addr.ip();
                let rejected_by_ssrf = block_private && is_disallowed_ip(ip, allow_loopback);
                let rejected_by_deny = deny_nets.iter().any(|net| net.contains(&ip));
                !(rejected_by_ssrf || rejected_by_deny)
            });
            if addrs.is_empty() {
                Err(Box::new(BlockedHostError { host }) as BoxError)
            } else {
                Ok(Box::new(addrs.into_iter()) as Addrs)
            }
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses an IPv4 or IPv6 literal (test inputs are always well-formed).
    fn parse_ip(s: &str) -> IpAddr {
        s.parse().unwrap()
    }

    #[test]
    fn ipv4_blocks_loopback_unless_allowed() {
        assert!(is_disallowed_ip(parse_ip("127.0.0.1"), false));
        assert!(!is_disallowed_ip(parse_ip("127.0.0.1"), true));
        assert!(is_disallowed_ip(parse_ip("127.255.255.254"), false));
    }

    #[test]
    fn ipv4_blocks_rfc1918() {
        for addr in ["10.0.0.1", "172.16.5.5", "172.31.255.1", "192.168.1.1"] {
            assert!(is_disallowed_ip(parse_ip(addr), false), "{addr}");
        }
    }

    #[test]
    fn ipv4_blocks_link_local_and_metadata() {
        assert!(is_disallowed_ip(parse_ip("169.254.10.1"), false));
        assert!(is_disallowed_ip(parse_ip("169.254.169.254"), false));
        // The metadata endpoint stays blocked even with the loopback hatch open.
        assert!(is_disallowed_ip(parse_ip("169.254.169.254"), true));
    }

    #[test]
    fn ipv4_blocks_broadcast_unspecified_multicast_documentation() {
        assert!(is_disallowed_ip(parse_ip("255.255.255.255"), false));
        assert!(is_disallowed_ip(parse_ip("0.0.0.0"), false));
        assert!(is_disallowed_ip(parse_ip("224.0.0.1"), false));
        // TEST-NET-1/2/3 documentation ranges.
        assert!(is_disallowed_ip(parse_ip("192.0.2.1"), false));
        assert!(is_disallowed_ip(parse_ip("198.51.100.1"), false));
        assert!(is_disallowed_ip(parse_ip("203.0.113.1"), false));
    }

    #[test]
    fn ipv4_blocks_cgnat_and_benchmark() {
        assert!(is_disallowed_ip(parse_ip("100.64.0.1"), false));
        assert!(is_disallowed_ip(parse_ip("100.127.255.254"), false));
        assert!(!is_disallowed_ip(parse_ip("100.63.255.255"), false));
        assert!(!is_disallowed_ip(parse_ip("100.128.0.1"), false));
        assert!(is_disallowed_ip(parse_ip("198.18.0.1"), false));
        assert!(is_disallowed_ip(parse_ip("198.19.255.254"), false));
    }

    #[test]
    fn ipv4_allows_public() {
        for addr in ["8.8.8.8", "1.1.1.1", "93.184.216.34", "151.101.1.140"] {
            assert!(!is_disallowed_ip(parse_ip(addr), false), "{addr}");
        }
    }

    #[test]
    fn ipv6_blocks_loopback_unless_allowed() {
        assert!(is_disallowed_ip(parse_ip("::1"), false));
        assert!(!is_disallowed_ip(parse_ip("::1"), true));
    }

    #[test]
    fn ipv6_blocks_unspecified_multicast_link_local_ula() {
        assert!(is_disallowed_ip(parse_ip("::"), false));
        assert!(is_disallowed_ip(parse_ip("ff02::1"), false));
        // fe80::/10 spans fe80:: through febf::.
        assert!(is_disallowed_ip(parse_ip("fe80::1"), false));
        assert!(is_disallowed_ip(parse_ip("febf::1"), false));
        // fc00::/7 spans fc00:: through fdff::.
        assert!(is_disallowed_ip(parse_ip("fc00::1"), false));
        assert!(is_disallowed_ip(parse_ip("fd12:3456::1"), false));
    }

    #[test]
    fn ipv6_mapped_v4_reapplies_v4_rules() {
        assert!(is_disallowed_ip(parse_ip("::ffff:127.0.0.1"), false));
        assert!(is_disallowed_ip(parse_ip("::ffff:10.0.0.1"), false));
        assert!(is_disallowed_ip(parse_ip("::ffff:169.254.169.254"), false));
        assert!(!is_disallowed_ip(parse_ip("::ffff:8.8.8.8"), false));
        assert!(!is_disallowed_ip(parse_ip("::ffff:127.0.0.1"), true));
    }

    #[test]
    fn ipv6_allows_public() {
        assert!(!is_disallowed_ip(parse_ip("2001:4860:4860::8888"), false));
        assert!(!is_disallowed_ip(parse_ip("2606:4700:4700::1111"), false));
    }

    /// Lookup backend that returns a fixed address list for any host.
    struct StubResolver {
        addrs: Vec<SocketAddr>,
    }

    impl InnerResolver for StubResolver {
        fn lookup(&self, _host: &str) -> std::io::Result<Vec<SocketAddr>> {
            Ok(self.addrs.clone())
        }
    }

    /// Builds a resolver with the default SSRF policy over a stubbed answer set.
    fn stub(allow_loopback: bool, addrs: &[&str]) -> GuardedResolver {
        let addrs = addrs
            .iter()
            .map(|s| SocketAddr::new(s.parse().unwrap(), 0))
            .collect();
        GuardedResolver::with_inner(allow_loopback, Arc::new(StubResolver { addrs }))
    }

    /// Resolves a fixed test host, flattening the error to its display string.
    async fn resolve_host(resolver: &GuardedResolver) -> Result<Vec<SocketAddr>, String> {
        let name: Name = "example.test".parse().unwrap();
        resolver
            .resolve(name)
            .await
            .map(|addrs| addrs.collect())
            .map_err(|e| e.to_string())
    }

    #[tokio::test]
    async fn guarded_resolver_passes_public_addr() {
        let resolver = stub(false, &["93.184.216.34"]);
        let addrs = resolve_host(&resolver).await.expect("public allowed");
        assert_eq!(addrs.len(), 1);
        assert_eq!(addrs[0].ip(), parse_ip("93.184.216.34"));
    }

    #[tokio::test]
    async fn guarded_resolver_blocks_loopback() {
        let resolver = stub(false, &["127.0.0.1"]);
        let err = resolve_host(&resolver).await.expect_err("loopback blocked");
        assert!(err.contains("example.test"), "{err}");
    }

    #[tokio::test]
    async fn guarded_resolver_blocks_metadata() {
        let resolver = stub(false, &["169.254.169.254"]);
        let err = resolve_host(&resolver).await.expect_err("metadata blocked");
        assert!(err.contains("example.test"), "{err}");
    }

    #[tokio::test]
    async fn guarded_resolver_filters_mixed_to_public_only() {
        let resolver = stub(false, &["127.0.0.1", "8.8.8.8", "169.254.169.254"]);
        let addrs = resolve_host(&resolver).await.expect("public survives");
        assert_eq!(addrs.len(), 1);
        assert_eq!(addrs[0].ip(), parse_ip("8.8.8.8"));
    }

    #[tokio::test]
    async fn guarded_resolver_loopback_hatch_allows_loopback() {
        let resolver = stub(true, &["127.0.0.1"]);
        let addrs = resolve_host(&resolver)
            .await
            .expect("hatch allows loopback");
        assert_eq!(addrs.len(), 1);
    }

    #[tokio::test]
    async fn guarded_resolver_error_names_host_only() {
        let resolver = stub(false, &["10.0.0.1"]);
        let err = resolve_host(&resolver).await.expect_err("private blocked");
        assert!(err.contains("example.test"), "{err}");
        assert!(
            !err.contains("10.0.0.1"),
            "must not leak the address: {err}"
        );
    }

    /// Builds a resolver with every policy knob set explicitly.
    fn stub_with_deny(
        block_private: bool,
        allow_loopback: bool,
        deny: &[&str],
        addrs: &[&str],
    ) -> GuardedResolver {
        let deny_nets: Vec<IpNet> = deny.iter().map(|s| s.parse().unwrap()).collect();
        let addrs = addrs
            .iter()
            .map(|s| SocketAddr::new(s.parse().unwrap(), 0))
            .collect();
        GuardedResolver::with_inner_full(
            block_private,
            allow_loopback,
            deny_nets,
            Arc::new(StubResolver { addrs }),
        )
    }

    #[tokio::test]
    async fn guarded_resolver_blocks_resolved_addr_in_deny_cidr() {
        let resolver = stub_with_deny(false, false, &["203.0.113.0/24"], &["203.0.113.7"]);
        let err = resolve_host(&resolver)
            .await
            .expect_err("deny CIDR blocked");
        assert!(err.contains("example.test"), "{err}");
        assert!(!err.contains("203.0.113.7"), "must not leak address: {err}");
    }

    #[tokio::test]
    async fn guarded_resolver_keeps_addr_outside_deny_cidr() {
        let resolver = stub_with_deny(false, false, &["203.0.113.0/24"], &["8.8.8.8"]);
        let addrs = resolve_host(&resolver).await.expect("outside deny CIDR");
        assert_eq!(addrs.len(), 1);
        assert_eq!(addrs[0].ip(), parse_ip("8.8.8.8"));
    }

    #[tokio::test]
    async fn guarded_resolver_deny_only_does_not_block_private() {
        let resolver = stub_with_deny(false, false, &["203.0.113.0/24"], &["10.0.0.1"]);
        let addrs = resolve_host(&resolver)
            .await
            .expect("private passes when block off");
        assert_eq!(addrs.len(), 1);
        assert_eq!(addrs[0].ip(), parse_ip("10.0.0.1"));
    }

    #[tokio::test]
    async fn guarded_resolver_applies_both_ssrf_and_deny() {
        let resolver = stub_with_deny(
            true,
            false,
            &["203.0.113.0/24"],
            &["10.0.0.1", "203.0.113.5", "8.8.8.8"],
        );
        let addrs = resolve_host(&resolver).await.expect("one survivor");
        assert_eq!(addrs.len(), 1);
        assert_eq!(addrs[0].ip(), parse_ip("8.8.8.8"));
    }
}