use std::net::IpAddr;
/// Reject URLs that could be abused for server-side request forgery (SSRF).
///
/// This is a static pre-flight check on `url`: it performs no DNS lookups, so
/// hostnames that merely *resolve* to internal addresses must still be stopped
/// at connect time (e.g. by [`SsrfSafeResolver`]).
///
/// A URL is refused when it:
/// - fails to parse, or uses a scheme other than `http`/`https`;
/// - has no host;
/// - names an IPv4/IPv6 literal that [`is_non_public_ip`] classifies as
///   non-public, or an IPv6 literal whose IPv4-mapped (`::ffff:a.b.c.d`) or
///   IPv4-compatible (`::a.b.c.d`) payload is non-public;
/// - names `localhost`, `localhost.*`, `*.localhost`, `169.254.169.254` or
///   `metadata.google.internal` (one trailing dot is ignored);
/// - names an obfuscated IPv4 literal (`127.1`, `0x7f000001`, …) that
///   normalises to a loopback/private address.
///
/// # Errors
/// Returns `Err` with a human-readable reason for any of the cases above.
pub fn reject_ssrf_url(url: &str) -> Result<(), String> {
    /// IPv4 address carried in the low 32 bits of an IPv4-mapped
    /// (`::ffff:a.b.c.d`) or IPv4-compatible (`::a.b.c.d`) IPv6 address.
    fn embedded_ipv4(v6: std::net::Ipv6Addr) -> Option<std::net::Ipv4Addr> {
        let s = v6.segments();
        let upper_zero = s[..5].iter().all(|&seg| seg == 0);
        if upper_zero && (s[5] == 0 || s[5] == 0xffff) {
            let [.., a, b, c, d] = v6.octets();
            Some(std::net::Ipv4Addr::new(a, b, c, d))
        } else {
            None
        }
    }

    let parsed = url::Url::parse(url).map_err(|e| format!("Invalid URL '{}': {}", url, e))?;
    if !matches!(parsed.scheme(), "http" | "https") {
        return Err(format!(
            "SSRF: only http/https schemes allowed, got '{}'",
            parsed.scheme()
        ));
    }

    // IP-literal hosts are fully decided here; domain names fall through to the
    // name-based checks below.
    match parsed.host() {
        Some(url::Host::Ipv4(v4)) => {
            if is_non_public_ip(IpAddr::V4(v4)) {
                return Err(format!("SSRF: refusing non-public IPv4 {}", v4));
            }
            return Ok(());
        }
        Some(url::Host::Ipv6(v6)) => {
            if is_non_public_ip(IpAddr::V6(v6)) {
                return Err(format!("SSRF: refusing non-public IPv6 {}", v6));
            }
            // No exemption for an embedded 0.0.0.0: `::ffff:0.0.0.0` is the
            // unspecified IPv4 address and must be refused (`::` itself is
            // already caught by the check above).
            if let Some(embedded) = embedded_ipv4(v6) {
                if is_non_public_ip(IpAddr::V4(embedded)) {
                    return Err(format!(
                        "SSRF: refusing IPv6 literal embedding non-public IPv4 {} ({})",
                        v6, embedded
                    ));
                }
            }
            return Ok(());
        }
        Some(url::Host::Domain(_)) | None => {}
    }

    let Some(host) = parsed.host_str() else {
        return Err("SSRF: URL has no host".to_string());
    };
    let host_lc = host.to_ascii_lowercase();
    // `name.` and `name` denote the same fully-qualified host, so the denylist
    // is compared against the name with one trailing dot removed.
    let name: &str = host_lc.strip_suffix('.').unwrap_or(host_lc.as_str());
    if matches!(
        name,
        "localhost" | "127.0.0.1" | "::1" | "[::1]" | "metadata.google.internal"
    ) || name.starts_with("localhost.")
        // RFC 6761 reserves every `*.localhost` name for loopback.
        || name.ends_with(".localhost")
    {
        return Err(format!("SSRF: refusing host '{}'", host));
    }
    if name == "169.254.169.254" {
        return Err("SSRF: cloud metadata endpoint refused".to_string());
    }
    if let Some(normalised) = normalise_legacy_ipv4(name) {
        if is_bad_ipv4(&normalised) {
            return Err(format!(
                "SSRF: refusing obfuscated loopback/private IPv4 literal '{}' ({})",
                host, normalised
            ));
        }
    }
    if let Some(embedded_v4) = extract_embedded_ipv4(&host_lc) {
        if is_bad_ipv4(&embedded_v4) {
            return Err(format!(
                "SSRF: refusing IPv6 literal embedding non-public IPv4 '{}'",
                host
            ));
        }
    }
    // Defence in depth in case a host string still parses as an IP literal.
    if let Ok(ip) = host.parse::<IpAddr>() {
        let embeds_non_public = match ip {
            IpAddr::V6(v6) => {
                embedded_ipv4(v6).is_some_and(|v4| is_non_public_ip(IpAddr::V4(v4)))
            }
            IpAddr::V4(_) => false,
        };
        if is_non_public_ip(ip) || embeds_non_public {
            return Err(format!("SSRF: refusing non-public IP {}", ip));
        }
    }
    Ok(())
}
/// Returns `true` when `ip` is not a globally routable unicast address, i.e. it
/// must never be contacted on behalf of untrusted input.
///
/// IPv4 ranges classified as non-public: 0.0.0.0/8 (incl. unspecified),
/// 127.0.0.0/8 loopback, RFC 1918 private, 169.254.0.0/16 link-local (incl.
/// the 169.254.169.254 metadata endpoint), 100.64.0.0/10 shared/CGNAT,
/// 198.18.0.0/15 benchmarking, documentation nets, 224.0.0.0/4 multicast and
/// 240.0.0.0/4 reserved (incl. broadcast).
///
/// IPv6: loopback, unspecified, unique-local, link-local, site-local, multicast,
/// 2001:db8::/32 documentation, plus any IPv4-mapped (`::ffff:a.b.c.d`),
/// IPv4-compatible (`::a.b.c.d`) or NAT64 (`64:ff9b::/96`) address whose
/// embedded IPv4 is itself non-public.
pub fn is_non_public_ip(ip: IpAddr) -> bool {
    /// IPv4 classification shared by the native and IPv6-embedded cases.
    fn v4_non_public(v4: std::net::Ipv4Addr) -> bool {
        let [a, b, _, _] = v4.octets();
        v4.is_loopback()
            || v4.is_private()
            || v4.is_link_local()
            || v4.is_documentation()
            || v4.is_multicast()
            // 0.0.0.0/8 "this network"; also covers the unspecified address.
            || a == 0
            // 240.0.0.0/4 reserved; also covers 255.255.255.255 broadcast.
            || a >= 240
            // 100.64.0.0/10 shared address space (carrier-grade NAT).
            || (a == 100 && (b & 0xc0) == 64)
            // 198.18.0.0/15 benchmarking.
            || (a == 198 && (b & 0xfe) == 18)
    }

    match ip {
        IpAddr::V4(v4) => v4_non_public(v4),
        IpAddr::V6(v6) => {
            let s = v6.segments();
            // These forms carry an IPv4 address in their low 32 bits; a
            // dual-stack socket or NAT64 gateway may deliver traffic to it.
            let embeds_v4 = (s[..5] == [0; 5] && (s[5] == 0 || s[5] == 0xffff))
                || s[..6] == [0x64, 0xff9b, 0, 0, 0, 0];
            if embeds_v4 {
                let [.., a, b, c, d] = v6.octets();
                if v4_non_public(std::net::Ipv4Addr::new(a, b, c, d)) {
                    return true;
                }
            }
            v6.is_loopback()
                || v6.is_unspecified()
                || v6.is_unique_local()
                || v6.is_unicast_link_local()
                || v6.is_multicast()
                // fec0::/10 deprecated site-local.
                || (s[0] & 0xffc0) == 0xfec0
                // 2001:db8::/32 documentation.
                || (s[0] == 0x2001 && s[1] == 0x0db8)
        }
    }
}
/// `reqwest` DNS resolver that only hands out public addresses.
///
/// Each lookup runs the system resolver (`std::net::ToSocketAddrs`) on Tokio's
/// blocking thread pool. Every resolved address for which [`is_non_public_ip`]
/// returns `true` is then dropped. If nothing survives, resolution fails and the
/// caller receives an error instead of an address list.
#[derive(Debug, Clone, Default)]
pub struct SsrfSafeResolver;

impl reqwest::dns::Resolve for SsrfSafeResolver {
    fn resolve(&self, name: reqwest::dns::Name) -> reqwest::dns::Resolving {
        type BoxError = Box<dyn std::error::Error + Send + Sync>;

        /// Look up `hostname` and keep only its public socket addresses.
        async fn public_addrs(hostname: String) -> Result<reqwest::dns::Addrs, BoxError> {
            use std::net::{SocketAddr, ToSocketAddrs};

            // `ToSocketAddrs` needs a port; 0 is a placeholder.
            // NOTE(review): assumes reqwest substitutes the request's own port.
            let resolved: Vec<SocketAddr> = tokio::task::spawn_blocking(move || {
                (hostname.as_str(), 0u16)
                    .to_socket_addrs()
                    .map(|found| found.collect::<Vec<SocketAddr>>())
            })
            .await
            .map_err(|join_err| std::io::Error::other(format!("dns join failed: {join_err}")))??;

            let public: Vec<SocketAddr> = resolved
                .into_iter()
                .filter(|addr| !is_non_public_ip(addr.ip()))
                .collect();
            if public.is_empty() {
                return Err("SSRF: hostname resolves only to non-public IPs — refusing"
                    .to_string()
                    .into());
            }
            let addrs: reqwest::dns::Addrs = Box::new(public.into_iter());
            Ok(addrs)
        }

        Box::pin(public_addrs(name.as_str().to_owned()))
    }
}
/// Build a `reqwest::Client` hardened against SSRF.
///
/// The client applies `timeout` to requests and never follows redirects
/// (`Policy::none()`). It resolves hostnames through [`SsrfSafeResolver`],
/// which discards non-public addresses.
///
/// NOTE(review): IP-literal hosts presumably never reach the DNS resolver, so
/// callers should still run `reject_ssrf_url` on each URL first. reqwest also
/// honours system proxy environment variables by default, which would move name
/// resolution to the proxy — confirm whether `.no_proxy()` is wanted here.
///
/// # Errors
/// Propagates any error from `reqwest::ClientBuilder::build`.
pub fn build_ssrf_safe_client(
    timeout: std::time::Duration,
) -> Result<reqwest::Client, reqwest::Error> {
    let resolver = std::sync::Arc::new(SsrfSafeResolver);
    let hardened = reqwest::Client::builder()
        .timeout(timeout)
        .redirect(reqwest::redirect::Policy::none());
    hardened.dns_resolver(resolver).build()
}
pub fn customise_ssrf_safe_client<F>(
timeout: std::time::Duration,
customise: F,
) -> Result<reqwest::Client, reqwest::Error>
where
F: FnOnce(reqwest::ClientBuilder) -> reqwest::ClientBuilder,
{
use std::sync::Arc;
let builder = reqwest::Client::builder().timeout(timeout);
let builder = customise(builder);
builder.dns_resolver(Arc::new(SsrfSafeResolver)).build()
}
/// Normalise an IPv4 literal written in any legacy `inet_aton`/WHATWG form into
/// canonical dotted-quad text.
///
/// `host` may have one to four dot-separated parts. Each part is one of:
/// - decimal;
/// - octal (leading `0`);
/// - hexadecimal (`0x`/`0X` prefix; a bare `0x` means zero).
///
/// Every part but the last is a single byte, and the last part fills all
/// remaining bytes. So `127.1`, `0x7f000001`, `2130706433` and `0177.0.0.1` all
/// normalise to `127.0.0.1`. One trailing dot (`127.1.`) is ignored.
///
/// Returns `None` when `host` is not such a literal (e.g. `example.com`, or a
/// part carrying a sign or stray characters) or when a part is out of range.
fn normalise_legacy_ipv4(host: &str) -> Option<String> {
    // The fully-qualified form (`10.0.0.1.`) names the same address.
    let host = host.strip_suffix('.').unwrap_or(host);
    if host.parse::<std::net::Ipv4Addr>().is_ok() {
        return Some(host.to_string());
    }
    // Digits are validated explicitly because `str::parse` and
    // `from_str_radix` would otherwise accept a leading `+`.
    let parse_part = |s: &str| -> Option<u32> {
        if let Some(hex) = s.strip_prefix("0x").or_else(|| s.strip_prefix("0X")) {
            if hex.is_empty() {
                Some(0)
            } else if hex.bytes().all(|b| b.is_ascii_hexdigit()) {
                u32::from_str_radix(hex, 16).ok()
            } else {
                None
            }
        } else if s.is_empty() || !s.bytes().all(|b| b.is_ascii_digit()) {
            None
        } else if s.len() > 1 && s.starts_with('0') {
            u32::from_str_radix(s, 8).ok()
        } else {
            s.parse::<u32>().ok()
        }
    };
    let parts = host
        .split('.')
        .map(parse_part)
        .collect::<Option<Vec<u32>>>()?;
    if parts.len() > 4 {
        return None;
    }
    let (&last, leading) = parts.split_last()?;
    if leading.iter().any(|&p| p > 0xff) {
        return None;
    }
    // The last part must fit in the bytes not already taken by `leading`.
    let tail_bytes = 4 - leading.len();
    if tail_bytes < 4 && last >> (8 * tail_bytes) != 0 {
        return None;
    }
    let value = leading
        .iter()
        .enumerate()
        .fold(last, |acc, (i, &p)| acc | (p << (8 * (3 - i))));
    Some(std::net::Ipv4Addr::from(value).to_string())
}
/// Returns `true` when `dotted` is a dotted-quad IPv4 address in a non-public
/// range.
///
/// The ranges checked are: 0.0.0.0/8 (incl. unspecified), loopback, RFC 1918
/// private, link-local (incl. 169.254.169.254), 100.64.0.0/10 shared/CGNAT,
/// 198.18.0.0/15 benchmarking, documentation nets, multicast, and 240.0.0.0/4
/// reserved (incl. broadcast).
///
/// Strings that are not a canonical dotted quad are not classified and yield
/// `false`.
fn is_bad_ipv4(dotted: &str) -> bool {
    let Ok(v4) = dotted.parse::<std::net::Ipv4Addr>() else {
        return false;
    };
    let [a, b, _, _] = v4.octets();
    v4.is_loopback()
        || v4.is_private()
        || v4.is_link_local()
        || v4.is_documentation()
        || v4.is_multicast()
        // 0.0.0.0/8 "this network"; also covers the unspecified address.
        || a == 0
        // 240.0.0.0/4 reserved; also covers 255.255.255.255 broadcast.
        || a >= 240
        // 100.64.0.0/10 shared address space (carrier-grade NAT).
        || (a == 100 && (b & 0xc0) == 64)
        // 198.18.0.0/15 benchmarking.
        || (a == 198 && (b & 0xfe) == 18)
}
/// Pull a trailing dotted-quad IPv4 address out of an IPv6-style host string.
///
/// Surrounding `[`/`]` brackets are stripped first. The text after the last `:`
/// is returned when it parses as an IPv4 address; for example,
/// `[::ffff:10.0.0.1]` yields `10.0.0.1`. Hosts without any `:` yield `None`.
fn extract_embedded_ipv4(host: &str) -> Option<String> {
    let unbracketed = host.trim_start_matches('[').trim_end_matches(']');
    let (_, last_group) = unbracketed.rsplit_once(':')?;
    last_group
        .parse::<std::net::Ipv4Addr>()
        .is_ok()
        .then(|| last_group.to_owned())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails the calling test if `reject_ssrf_url` accepts any URL in `urls`.
    fn assert_rejected(urls: &[&str]) {
        for &url in urls {
            assert!(reject_ssrf_url(url).is_err(), "expected rejection: {url}");
        }
    }

    /// Fails the calling test if `reject_ssrf_url` refuses any URL in `urls`.
    fn assert_allowed(urls: &[&str]) {
        for &url in urls {
            assert!(reject_ssrf_url(url).is_ok(), "expected acceptance: {url}");
        }
    }

    #[test]
    fn rejects_loopback() {
        assert_rejected(&["http://127.0.0.1/pub", "https://localhost", "http://[::1]:8080"]);
    }

    #[test]
    fn rejects_metadata() {
        assert_rejected(&["http://169.254.169.254/latest", "http://metadata.google.internal"]);
    }

    #[test]
    fn rejects_private() {
        assert_rejected(&["http://10.0.0.5/x", "http://192.168.1.1", "http://172.16.0.3"]);
    }

    #[test]
    fn rejects_non_http_scheme() {
        assert_rejected(&["file:///etc/passwd", "gopher://x"]);
    }

    #[test]
    fn rejects_url_that_fails_to_parse() {
        assert_rejected(&["not a url"]);
    }

    #[test]
    fn allows_public_https() {
        assert_allowed(&["https://example.com/pub", "https://8.8.8.8"]);
    }

    #[test]
    fn rejects_localhost_subdomains() {
        assert_rejected(&[
            "http://localhost.localdomain/",
            "http://localhost.example/",
            "http://LOCALHOST./",
        ]);
    }

    #[test]
    fn rejects_shorthand_ipv4() {
        assert_rejected(&[
            "http://127.1/",
            "http://0x7f000001/",
            "http://2130706433/",
            "http://0xa000001/",
        ]);
    }

    #[test]
    fn rejects_ipv6_with_embedded_loopback() {
        assert_rejected(&["http://[::127.0.0.1]/", "http://[::ffff:10.0.0.1]/"]);
    }

    #[test]
    fn normalise_legacy_ipv4_examples() {
        let cases = [
            ("127.1", "127.0.0.1"),
            ("0x7f000001", "127.0.0.1"),
            ("2130706433", "127.0.0.1"),
            ("192.168.1", "192.168.0.1"),
        ];
        for &(input, expected) in &cases {
            assert_eq!(
                normalise_legacy_ipv4(input).as_deref(),
                Some(expected),
                "input {input}"
            );
        }
        assert!(normalise_legacy_ipv4("example.com").is_none());
    }
}