use std::io;
use std::net::{IpAddr, SocketAddr, ToSocketAddrs};
/// Hostnames and IP literals of cloud instance-metadata services.
///
/// Entries are lowercase and compared verbatim against the request host after it has been
/// stripped of IPv6 brackets and lowercased, so they are rejected before any DNS lookup.
/// This is defense in depth: resolved addresses are still vetted separately.
const METADATA_HOST_BLOCKLIST: &[&str] = &[
    // GCP metadata server names.
    "metadata.google.internal",
    "metadata",
    "metadata.goog",
    // Link-local metadata address shared by AWS, GCP, Azure and others.
    "169.254.169.254",
    // AWS IMDS IPv6 endpoint (bracket-stripped form of `[fd00:ec2::254]`).
    "fd00:ec2::254",
    // Alibaba Cloud ECS metadata address.
    "100.100.100.200",
];
/// DNS resolver that refuses to hand out addresses usable for server-side request forgery.
///
/// Every resolved address is vetted by [`SsrfResolver::validate_addrs`]; optionally, only
/// hostnames listed in `allow_hosts` may be resolved at all.
#[derive(Debug, Clone, Default)]
pub struct SsrfResolver {
    /// When `Some`, only these hostnames are permitted. `None` disables the allowlist.
    pub allow_hosts: Option<Vec<String>>,
}

impl SsrfResolver {
    /// Creates a resolver with no hostname allowlist; only address checks apply.
    pub fn new() -> Self {
        Self { allow_hosts: None }
    }

    /// Creates a resolver that only permits the given hostnames.
    ///
    /// Entries are lowercased so they match the lowercased request host.
    pub fn with_allow_hosts(hosts: Vec<String>) -> Self {
        Self {
            allow_hosts: Some(hosts.into_iter().map(|h| h.to_lowercase()).collect()),
        }
    }

    /// Checks that every address resolved for `host` is a public, unicast destination.
    ///
    /// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are unwrapped and checked under the
    /// IPv4 rules, so e.g. `::ffff:127.0.0.1` cannot slip past the loopback check.
    ///
    /// # Errors
    ///
    /// * `NotFound` if `addrs` is empty.
    /// * `PermissionDenied` if any address is loopback, unspecified, multicast, or in a
    ///   non-public IPv4/IPv6 range. One bad address rejects the whole batch.
    pub fn validate_addrs(host: &str, addrs: &[SocketAddr]) -> io::Result<()> {
        if addrs.is_empty() {
            return Err(io::Error::new(
                io::ErrorKind::NotFound,
                format!("no addresses for {}", host),
            ));
        }
        for sa in addrs {
            // std's Ipv6Addr predicates do not look inside `::ffff:a.b.c.d`; without this,
            // a mapped loopback/private/metadata address would pass every check below.
            let ip = match sa.ip() {
                IpAddr::V6(v6) => v6.to_ipv4_mapped().map_or(IpAddr::V6(v6), IpAddr::V4),
                other => other,
            };
            if ip.is_loopback() || ip.is_unspecified() || ip.is_multicast() {
                return Err(io::Error::new(
                    io::ErrorKind::PermissionDenied,
                    format!("rejected IP {} for {}", ip, host),
                ));
            }
            match ip {
                IpAddr::V4(v4) => {
                    let [a, b, _, _] = v4.octets();
                    // 0.0.0.0/8: "this network" (RFC 6890), never a valid remote target.
                    let this_network = a == 0;
                    // 100.64.0.0/10: shared address space (RFC 6598); Alibaba Cloud's
                    // metadata service sits at 100.100.100.200 inside it.
                    let shared_space = a == 100 && (b & 0xc0) == 0x40;
                    // is_link_local covers 169.254.0.0/16, which includes the
                    // 169.254.169.254 metadata address.
                    if v4.is_private()
                        || v4.is_link_local()
                        || v4.is_broadcast()
                        || this_network
                        || shared_space
                    {
                        return Err(io::Error::new(
                            io::ErrorKind::PermissionDenied,
                            format!("rejected IPv4 {} for {}", v4, host),
                        ));
                    }
                }
                IpAddr::V6(v6) => {
                    let seg0 = v6.segments()[0];
                    let unique_local = (seg0 & 0xfe00) == 0xfc00; // fc00::/7
                    let link_local = (seg0 & 0xffc0) == 0xfe80; // fe80::/10
                    let site_local = (seg0 & 0xffc0) == 0xfec0; // fec0::/10, deprecated (RFC 3879)
                    if unique_local || link_local || site_local {
                        return Err(io::Error::new(
                            io::ErrorKind::PermissionDenied,
                            format!("rejected IPv6 {} for {}", v6, host),
                        ));
                    }
                }
            }
        }
        Ok(())
    }
}
impl ureq::Resolver for SsrfResolver {
    /// Resolves `netloc` and returns its addresses only if every one of them is safe.
    ///
    /// `netloc` is `host:port`, with IPv6 literals in brackets (`[::1]:80`). Before any DNS
    /// lookup the host is normalized (brackets and trailing root dots stripped, lowercased)
    /// and rejected if it appears in `METADATA_HOST_BLOCKLIST` or, when an allowlist is
    /// configured, is absent from it. The resolved addresses are then vetted with
    /// [`SsrfResolver::validate_addrs`] and returned unchanged, so a caller that connects
    /// to the returned addresses only ever reaches vetted IPs.
    ///
    /// # Errors
    ///
    /// * `InvalidInput` if `netloc` has no `:` separator.
    /// * `PermissionDenied` for blocklisted, non-allowlisted, or non-public targets.
    /// * `NotFound` if resolution yields no addresses; other I/O errors come from
    ///   `ToSocketAddrs`.
    fn resolve(&self, netloc: &str) -> io::Result<Vec<SocketAddr>> {
        let (raw_host, _) = netloc.rsplit_once(':').ok_or_else(|| {
            io::Error::new(io::ErrorKind::InvalidInput, format!("bad netloc: {}", netloc))
        })?;
        // A trailing dot denotes a fully-qualified name ("metadata.google.internal.") that
        // resolves to the same host, so strip it or it would bypass the name checks.
        let host = raw_host
            .trim_start_matches('[')
            .trim_end_matches(']')
            .trim_end_matches('.')
            .to_lowercase();
        if METADATA_HOST_BLOCKLIST.contains(&host.as_str()) {
            return Err(io::Error::new(
                io::ErrorKind::PermissionDenied,
                format!("rejected metadata host: {}", host),
            ));
        }
        if let Some(allow) = &self.allow_hosts {
            // Case-insensitive so entries pushed straight into the public field (bypassing
            // `with_allow_hosts` lowercasing) still match.
            if !allow.iter().any(|h| h.eq_ignore_ascii_case(&host)) {
                return Err(io::Error::new(
                    io::ErrorKind::PermissionDenied,
                    format!("host {} not in allowlist", host),
                ));
            }
        }
        let addrs: Vec<SocketAddr> = netloc.to_socket_addrs()?.collect();
        Self::validate_addrs(&host, &addrs)?;
        Ok(addrs)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ureq::Resolver;

    /// Parses a socket-address literal used as test input.
    fn addr(text: &str) -> SocketAddr {
        text.parse().unwrap()
    }

    /// Asserts that `outcome` failed with `PermissionDenied` and hands back the error.
    fn expect_denied<T: std::fmt::Debug>(outcome: io::Result<T>) -> io::Error {
        let failure = outcome.unwrap_err();
        assert_eq!(failure.kind(), io::ErrorKind::PermissionDenied);
        failure
    }

    #[test]
    fn rejects_metadata_host_before_dns() {
        let failure = expect_denied(SsrfResolver::new().resolve("metadata.google.internal:80"));
        assert!(failure.to_string().contains("metadata"));
    }

    #[test]
    fn rejects_aws_metadata_ip_as_hostname() {
        expect_denied(SsrfResolver::new().resolve("169.254.169.254:80"));
    }

    #[test]
    fn rejects_non_allowlist_host_before_dns() {
        let resolver = SsrfResolver::with_allow_hosts(vec!["example.com".into()]);
        let failure = expect_denied(resolver.resolve("not-on-list.invalid:80"));
        assert!(failure.to_string().contains("allowlist"));
    }

    #[test]
    fn validate_addrs_rejects_loopback() {
        expect_denied(SsrfResolver::validate_addrs("localhost", &[addr("127.0.0.1:80")]));
    }

    #[test]
    fn validate_addrs_rejects_private_ipv4() {
        let failure =
            expect_denied(SsrfResolver::validate_addrs("internal.corp", &[addr("10.0.0.5:80")]));
        assert!(failure.to_string().contains("IPv4"));
    }

    #[test]
    fn validate_addrs_rejects_link_local_ipv4() {
        expect_denied(SsrfResolver::validate_addrs("anywhere", &[addr("169.254.169.254:80")]));
    }

    #[test]
    fn validate_addrs_rejects_ipv6_ula() {
        expect_denied(SsrfResolver::validate_addrs("anywhere", &[addr("[fc00::1]:80")]));
    }

    #[test]
    fn validate_addrs_rejects_ipv6_link_local() {
        expect_denied(SsrfResolver::validate_addrs("anywhere", &[addr("[fe80::1]:80")]));
    }

    #[test]
    fn validate_addrs_rejects_empty_list() {
        let failure = SsrfResolver::validate_addrs("anywhere", &[]).unwrap_err();
        assert_eq!(failure.kind(), io::ErrorKind::NotFound);
    }

    #[test]
    fn validate_addrs_accepts_public_ipv4() {
        SsrfResolver::validate_addrs("example.com", &[addr("93.184.216.34:80")]).unwrap();
    }

    #[test]
    fn validate_addrs_rejects_batch_if_any_private() {
        let batch = [addr("93.184.216.34:80"), addr("10.0.0.1:80")];
        expect_denied(SsrfResolver::validate_addrs("dual.example", &batch));
    }

    #[test]
    fn ipv6_literal_netloc_strips_brackets() {
        expect_denied(SsrfResolver::new().resolve("[::1]:80"));
    }
}