#[cfg(any(test, feature = "test-helpers"))]
use std::sync::atomic::{AtomicBool, Ordering};
use std::{
net::{IpAddr, SocketAddr},
sync::Arc,
};
use reqwest::dns::{Addrs, Name, Resolve, Resolving};
use tokio::net::lookup_host;
use crate::ssrf::{CompiledSsrfAllowlist, ip_block_reason};
/// Shared switch for the resolver's loopback bypass in `test` /
/// `test-helpers` builds: while the flag is `true`, DNS answers classified
/// as loopback are accepted (cloud-metadata answers are still rejected).
#[cfg(any(test, feature = "test-helpers"))]
pub(crate) type TestLoopbackBypass = Arc<AtomicBool>;
/// Unit placeholder in all other builds, where the loopback bypass is
/// hard-wired off and cannot be enabled at runtime.
#[cfg(not(any(test, feature = "test-helpers")))]
pub(crate) type TestLoopbackBypass = ();
/// reqwest DNS resolver that vets every lookup result against the SSRF
/// policy before any connection is attempted.
#[derive(Clone)]
pub(crate) struct SsrfScreeningResolver {
    /// Compiled allowlist; the `Arc` lets every clone share one copy.
    allowlist: Arc<CompiledSsrfAllowlist>,
    /// Loopback-bypass handle (an `Arc<AtomicBool>` in test builds, `()` otherwise).
    // The field is only read under `test` / `test-helpers`; in other builds it
    // is the `()` placeholder, so the dead-code lint is silenced there.
    #[cfg_attr(not(any(test, feature = "test-helpers")), allow(dead_code))]
    test_bypass: TestLoopbackBypass,
}

impl SsrfScreeningResolver {
    /// Creates a resolver from a shared compiled allowlist and the
    /// loopback-bypass handle for the current build flavour.
    pub(crate) fn new(
        allowlist: Arc<CompiledSsrfAllowlist>,
        test_bypass: TestLoopbackBypass,
    ) -> Self {
        Self {
            test_bypass,
            allowlist,
        }
    }
}
impl Resolve for SsrfScreeningResolver {
    /// Looks `name` up with `tokio::net::lookup_host` and passes the answers
    /// through [`screen_addrs`].
    ///
    /// The addresses handed back to reqwest are exactly the ones that were
    /// screened, so no second lookup happens between check and use. A
    /// screening failure is reported as an error whose message starts with
    /// `ssrf: `; lookup I/O errors are propagated unchanged.
    ///
    /// NOTE(review): URLs whose host is an IP literal may never reach a custom
    /// resolver — confirm they are screened elsewhere.
    fn resolve(&self, name: Name) -> Resolving {
        let allowlist = Arc::clone(&self.allowlist);
        #[cfg(any(test, feature = "test-helpers"))]
        let bypass_flag = Arc::clone(&self.test_bypass);
        Box::pin(async move {
            let host = String::from(name.as_str());
            // Port 0 is only a placeholder for the lookup itself.
            let resolved: Vec<SocketAddr> = lookup_host((host.as_str(), 0)).await?.collect();
            // Only test builds can switch the loopback bypass on; the flag is
            // read in isolation, so relaxed ordering is enough.
            #[cfg(any(test, feature = "test-helpers"))]
            let bypass_loopback = bypass_flag.load(Ordering::Relaxed);
            #[cfg(not(any(test, feature = "test-helpers")))]
            let bypass_loopback = false;
            screen_addrs(&resolved, &allowlist, &host, bypass_loopback)
                .map(|vetted| -> Addrs { Box::new(vetted.into_iter()) })
                .map_err(|diag| -> Box<dyn std::error::Error + Send + Sync> {
                    format!("ssrf: {diag}").into()
                })
        })
    }
}
/// Screens the DNS answers for `host` against the SSRF policy.
///
/// Every address must pass: a single blocked answer fails the whole
/// resolution, so a public address cannot be used to smuggle an internal
/// one through.
///
/// For an address that `ip_block_reason` flags, the rules are applied in order:
/// 1. `cloud_metadata` is always rejected — neither the allowlist nor the
///    loopback bypass can lift it.
/// 2. `loopback` is accepted when `bypass_loopback` is set.
/// 3. Otherwise the address is accepted only if the allowlist is non-empty
///    and either `host` is allowlisted or the address is allowlisted.
///
/// IPv4-mapped IPv6 answers (`::ffff:a.b.c.d`) are screened as their IPv4
/// equivalent, because a dual-stack socket connecting to one reaches the
/// IPv4 host. Error messages still show the address as resolved.
///
/// # Errors
///
/// Returns a human-readable diagnostic if `addrs` is empty, or if any
/// address is blocked and not permitted by the rules above.
pub(crate) fn screen_addrs(
    addrs: &[SocketAddr],
    allowlist: &CompiledSsrfAllowlist,
    host: &str,
    bypass_loopback: bool,
) -> Result<Vec<SocketAddr>, String> {
    if addrs.is_empty() {
        return Err(format!("DNS resolution for {host:?} returned no addresses"));
    }
    let host_allowed = !allowlist.is_empty() && allowlist.host_allowed(host);
    for addr in addrs {
        let ip: IpAddr = addr.ip();
        // Without this, e.g. `::ffff:169.254.169.254` could slip past
        // IPv4-only block rules.
        let screened = canonical_ip(ip);
        let Some(reason) = ip_block_reason(screened) else {
            continue;
        };
        if reason == "cloud_metadata" {
            return Err(format!(
                "{host:?} resolved to blocked IP {ip} (cloud_metadata)"
            ));
        }
        if bypass_loopback && reason == "loopback" {
            continue;
        }
        // An empty allowlist permits nothing beyond the rules above.
        let allowlisted =
            !allowlist.is_empty() && (host_allowed || allowlist.ip_allowed(screened));
        if !allowlisted {
            return Err(format!("{host:?} resolved to blocked IP {ip} ({reason})"));
        }
    }
    Ok(addrs.to_vec())
}

/// Returns the IPv4 form of an IPv4-mapped IPv6 address; any other address
/// is returned unchanged.
fn canonical_ip(ip: IpAddr) -> IpAddr {
    match ip {
        IpAddr::V6(v6) => v6.to_ipv4_mapped().map_or(ip, IpAddr::V4),
        IpAddr::V4(_) => ip,
    }
}
#[cfg(test)]
mod tests {
    use std::net::{Ipv4Addr, Ipv6Addr};

    use super::*;
    use crate::ssrf::{CidrEntry, CompiledSsrfAllowlist};

    /// Wraps an IP in a `SocketAddr` with the placeholder port used by the resolver.
    fn sa(ip: IpAddr) -> SocketAddr {
        SocketAddr::new(ip, 0)
    }

    fn empty_allowlist() -> CompiledSsrfAllowlist {
        CompiledSsrfAllowlist::default()
    }

    /// Builds an allowlist from host names (lowercased) and CIDR strings.
    fn allowlist_with(hosts: &[&str], cidrs: &[&str]) -> CompiledSsrfAllowlist {
        let hosts = hosts.iter().map(|h| (*h).to_lowercase()).collect();
        let cidrs = cidrs
            .iter()
            .map(|c| CidrEntry::parse(c).expect("test CIDR parses"))
            .collect();
        CompiledSsrfAllowlist::new(hosts, cidrs)
    }

    #[test]
    fn rejects_empty_addrs() {
        let err = screen_addrs(&[], &empty_allowlist(), "example.com", false)
            .expect_err("empty resolution must error");
        assert!(err.contains("returned no addresses"), "{err}");
    }

    #[test]
    fn allows_public_ipv4() {
        let addrs = vec![sa(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)))];
        let out = screen_addrs(&addrs, &empty_allowlist(), "dns.google", false)
            .expect("public IPv4 must pass");
        assert_eq!(out, addrs);
    }

    #[test]
    fn returns_all_public_addresses_in_order() {
        let addrs = vec![
            sa(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1))),
            sa(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))),
            sa(IpAddr::V4(Ipv4Addr::new(9, 9, 9, 9))),
        ];
        let out = screen_addrs(&addrs, &empty_allowlist(), "multi.example", false)
            .expect("all-public resolution must pass");
        assert_eq!(out, addrs);
    }

    #[test]
    fn rejects_loopback_under_empty_allowlist() {
        let addrs = vec![sa(IpAddr::V4(Ipv4Addr::LOCALHOST))];
        let err = screen_addrs(&addrs, &empty_allowlist(), "localhost", false)
            .expect_err("loopback must be blocked");
        assert!(err.contains("loopback"), "{err}");
    }

    #[test]
    fn loopback_rejected_when_outside_allowlist() {
        let addrs = vec![sa(IpAddr::V4(Ipv4Addr::LOCALHOST))];
        let allowlist = allowlist_with(&[], &["10.0.0.0/8"]);
        let err = screen_addrs(&addrs, &allowlist, "localhost", false)
            .expect_err("loopback outside the allowlisted range must fail");
        assert!(err.contains("loopback"), "{err}");
    }

    #[test]
    fn rejects_private_under_empty_allowlist() {
        let addrs = vec![sa(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)))];
        let err = screen_addrs(&addrs, &empty_allowlist(), "internal", false)
            .expect_err("private RFC1918 must be blocked");
        assert!(err.contains("private_rfc1918"), "{err}");
    }

    #[test]
    fn rejects_cloud_metadata_even_with_full_allowlist() {
        let addrs = vec![sa(IpAddr::V4(Ipv4Addr::new(169, 254, 169, 254)))];
        let allowlist = allowlist_with(&["meta.example"], &["169.254.0.0/16"]);
        let err = screen_addrs(&addrs, &allowlist, "meta.example", false)
            .expect_err("cloud_metadata must be unbypassable");
        assert!(err.contains("cloud_metadata"), "{err}");
    }

    #[test]
    fn rejects_cloud_metadata_even_with_loopback_bypass() {
        let addrs = vec![sa(IpAddr::V4(Ipv4Addr::new(169, 254, 169, 254)))];
        let err = screen_addrs(&addrs, &empty_allowlist(), "meta", true)
            .expect_err("cloud_metadata must survive loopback bypass");
        assert!(err.contains("cloud_metadata"), "{err}");
    }

    #[test]
    fn rejects_cloud_metadata_mixed_with_public() {
        let addrs = vec![
            sa(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))),
            sa(IpAddr::V4(Ipv4Addr::new(169, 254, 169, 254))),
        ];
        let allowlist = allowlist_with(&["meta.example"], &["169.254.0.0/16"]);
        let err = screen_addrs(&addrs, &allowlist, "meta.example", true)
            .expect_err("a metadata answer must fail the whole resolution");
        assert!(err.contains("cloud_metadata"), "{err}");
    }

    #[test]
    fn fails_any_blocked_when_mixed() {
        let addrs = vec![
            sa(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))),
            sa(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))),
        ];
        let err = screen_addrs(&addrs, &empty_allowlist(), "split-horizon", false)
            .expect_err("any blocked address must fail the whole resolution");
        assert!(err.contains("private_rfc1918"), "{err}");
    }

    #[test]
    fn host_allowlist_permits_private() {
        let addrs = vec![sa(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)))];
        let allowlist = allowlist_with(&["internal.corp"], &[]);
        let out = screen_addrs(&addrs, &allowlist, "internal.corp", false)
            .expect("host allowlist must permit private IP");
        assert_eq!(out, addrs);
    }

    #[test]
    fn host_allowlist_is_scoped_to_listed_hosts() {
        let addrs = vec![sa(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)))];
        let allowlist = allowlist_with(&["internal.corp"], &[]);
        let err = screen_addrs(&addrs, &allowlist, "other.corp", false)
            .expect_err("allowlisting one host must not permit another");
        assert!(err.contains("private_rfc1918"), "{err}");
    }

    #[test]
    fn cidr_allowlist_permits_private() {
        let addrs = vec![sa(IpAddr::V4(Ipv4Addr::new(10, 1, 2, 3)))];
        let allowlist = allowlist_with(&[], &["10.0.0.0/8"]);
        let out = screen_addrs(&addrs, &allowlist, "internal", false)
            .expect("CIDR allowlist must permit IP in range");
        assert_eq!(out, addrs);
    }

    #[test]
    fn cidr_allowlist_rejects_out_of_range() {
        let addrs = vec![sa(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)))];
        let allowlist = allowlist_with(&[], &["10.0.0.0/8"]);
        let err = screen_addrs(&addrs, &allowlist, "elsewhere", false)
            .expect_err("non-allowlisted private IP must fail");
        assert!(err.contains("private_rfc1918"), "{err}");
    }

    #[test]
    fn loopback_bypass_permits_only_loopback() {
        let addrs = vec![sa(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)))];
        let err = screen_addrs(&addrs, &empty_allowlist(), "internal", true)
            .expect_err("loopback bypass must not allow RFC1918");
        assert!(err.contains("private_rfc1918"), "{err}");
    }

    #[test]
    fn loopback_bypass_permits_127_0_0_1() {
        let addrs = vec![sa(IpAddr::V4(Ipv4Addr::LOCALHOST))];
        let out = screen_addrs(&addrs, &empty_allowlist(), "localhost", true)
            .expect("loopback bypass must permit 127.0.0.1");
        assert_eq!(out, addrs);
    }

    #[test]
    fn ipv6_loopback_blocked_without_bypass() {
        let addrs = vec![sa(IpAddr::V6(Ipv6Addr::LOCALHOST))];
        let err = screen_addrs(&addrs, &empty_allowlist(), "localhost", false)
            .expect_err("IPv6 loopback must be blocked");
        assert!(err.contains("loopback"), "{err}");
    }
}