use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use thiserror::Error;
/// Reasons [`validate_target_url`] can refuse a target URL.
#[derive(Debug, Error)]
pub enum SsrfError {
    /// The input string could not be parsed as a URL; carries the parser's message.
    #[error("invalid URL: {0}")]
    InvalidUrl(String),

    /// The URL's scheme is something other than `http` or `https`.
    #[error("URL scheme '{0}' not allowed — only http/https")]
    DisallowedScheme(String),

    /// The parsed URL carries no host.
    #[error("URL has no host component")]
    MissingHost,

    /// Resolving the hostname failed; the underlying I/O error is exposed as the source.
    #[error("DNS resolution failed for '{host}': {source}")]
    DnsResolutionFailed {
        /// Hostname as written in the URL.
        host: String,
        /// Error reported by the resolver.
        #[source]
        source: std::io::Error,
    },

    /// Resolution succeeded but produced an empty address list.
    #[error("DNS resolution returned no addresses for '{0}'")]
    NoAddressesResolved(String),

    /// The host is, or resolves to, an address inside a blocked range.
    #[error(
        "target '{host}' resolves to {ip} which is in a blocked range ({reason}); \
         pointing the cloud runner at internal addresses is not allowed"
    )]
    BlockedAddress {
        /// Hostname as written in the URL.
        host: String,
        /// The specific address that was rejected.
        ip: IpAddr,
        /// Human-readable name of the matched range.
        reason: &'static str,
    },
}
/// Settings that relax the SSRF blocklist used by [`validate_target_url`].
///
/// The derived `Default` has `allow_loopback == false`, i.e. the same as
/// [`Policy::strict`].
#[derive(Debug, Clone, Copy, Default)]
pub struct Policy {
    /// When `true`, addresses classified as loopback are permitted.
    /// All other blocked ranges stay blocked regardless.
    pub allow_loopback: bool,
}

impl Policy {
    /// Production setting: loopback is rejected like every other internal range.
    pub const fn strict() -> Self {
        Policy { allow_loopback: false }
    }

    /// Test setting: like [`Policy::strict`], but loopback targets are allowed,
    /// e.g. for servers bound to `127.0.0.1` inside a test.
    pub const fn for_test() -> Self {
        Policy { allow_loopback: true }
    }
}
/// Checks that `url` is an `http`/`https` URL whose host does not point at an
/// address in a blocked range under `policy`.
///
/// IP-literal hosts (IPv4, or bracketed IPv6 such as `[::1]`) are checked
/// directly without DNS. Hostnames are resolved with
/// [`tokio::net::lookup_host`], and *every* returned address must pass. A name
/// with one public and one internal record is therefore rejected.
///
/// # Errors
///
/// - [`SsrfError::InvalidUrl`] if `url` does not parse.
/// - [`SsrfError::DisallowedScheme`] for any scheme other than `http`/`https`.
/// - [`SsrfError::MissingHost`] if the URL has no host.
/// - [`SsrfError::DnsResolutionFailed`] if the resolver returns an error.
/// - [`SsrfError::NoAddressesResolved`] if it returns no addresses.
/// - [`SsrfError::BlockedAddress`] if any candidate address is in a blocked range.
///
/// # Security
///
/// NOTE(review): this validates only the addresses returned by *this* lookup.
/// If the caller later connects by hostname, the name is resolved again, and a
/// DNS-rebinding attacker can answer with an internal address the second time.
/// Confirm that callers connect to the validated addresses, or re-check at
/// connect time.
pub async fn validate_target_url(url: &str, policy: Policy) -> Result<(), SsrfError> {
    let parsed = url::Url::parse(url).map_err(|e| SsrfError::InvalidUrl(e.to_string()))?;

    let scheme = parsed.scheme();
    if scheme != "http" && scheme != "https" {
        return Err(SsrfError::DisallowedScheme(scheme.to_string()));
    }

    // The host as the URL spells it (IPv6 stays bracketed); used in error messages.
    let host = parsed.host_str().ok_or(SsrfError::MissingHost)?.to_string();

    // `url` exposes IP-literal hosts as typed addresses. Matching on the typed
    // host catches bracketed IPv6 literals too; those do not parse via
    // `host.parse::<IpAddr>()`, so they would otherwise fall through to DNS.
    let literal_ip = match parsed.host() {
        Some(url::Host::Ipv4(v4)) => Some(IpAddr::V4(v4)),
        Some(url::Host::Ipv6(v6)) => Some(IpAddr::V6(v6)),
        Some(url::Host::Domain(_)) | None => None,
    };
    if let Some(ip) = literal_ip {
        return check_ip(&host, ip, policy);
    }

    // http/https always have a known default port; 80 is only a defensive fallback.
    let port = parsed.port_or_known_default().unwrap_or(80);
    let addrs: Vec<std::net::SocketAddr> = tokio::net::lookup_host((host.as_str(), port))
        .await
        .map_err(|source| SsrfError::DnsResolutionFailed {
            host: host.clone(),
            source,
        })?
        .collect();
    if addrs.is_empty() {
        return Err(SsrfError::NoAddressesResolved(host));
    }

    // Reject if *any* record is blocked: the HTTP client may connect to any of them.
    addrs
        .iter()
        .try_for_each(|addr| check_ip(&host, addr.ip(), policy))
}
/// Turns a blocklist hit for `ip` into [`SsrfError::BlockedAddress`].
///
/// `host` is only copied into the error for reporting. Returns `Ok(())` when
/// `policy` permits `ip`.
fn check_ip(host: &str, ip: IpAddr, policy: Policy) -> Result<(), SsrfError> {
    match blocked_reason(ip, policy) {
        None => Ok(()),
        Some(reason) => Err(SsrfError::BlockedAddress {
            host: host.to_owned(),
            ip,
            reason,
        }),
    }
}
/// Routes `ip` to the blocklist for its address family.
///
/// Returns `Some(reason)` when the address must not be contacted, `None` when
/// it is allowed.
fn blocked_reason(ip: IpAddr, policy: Policy) -> Option<&'static str> {
    match ip {
        IpAddr::V6(addr) => blocked_reason_v6(addr, policy),
        IpAddr::V4(addr) => blocked_reason_v4(addr, policy),
    }
}
/// Classifies an IPv4 address against the SSRF blocklist.
///
/// Returns `Some(reason)`, a human-readable name of the matched range, when
/// `ip` must not be contacted, or `None` when it is allowed. Loopback is
/// permitted only when `policy.allow_loopback` is set. Every other listed range
/// is blocked unconditionally.
fn blocked_reason_v4(ip: Ipv4Addr, policy: Policy) -> Option<&'static str> {
    if ip.is_loopback() {
        if policy.allow_loopback {
            return None;
        }
        return Some("IPv4 loopback (127.0.0.0/8)");
    }
    if ip.is_unspecified() {
        return Some("IPv4 unspecified (0.0.0.0)");
    }
    if ip.is_broadcast() {
        return Some("IPv4 broadcast");
    }
    if ip.is_link_local() {
        return Some("IPv4 link-local (169.254.0.0/16, includes cloud metadata IP)");
    }
    if ip.is_private() {
        return Some("IPv4 RFC1918 private (10/8, 172.16/12, 192.168/16)");
    }
    if ip.is_documentation() {
        return Some("IPv4 documentation range (RFC5737)");
    }
    if ip.is_multicast() {
        return Some("IPv4 multicast (224.0.0.0/4)");
    }
    // These ranges have no stable std predicate, so match on the octets directly.
    let [a, b, _, _] = ip.octets();
    // "This network" (RFC 1122). 0.0.0.0 itself was already reported as unspecified.
    if a == 0 {
        return Some("IPv4 this-network (0.0.0.0/8)");
    }
    if a == 100 && (64..=127).contains(&b) {
        return Some("IPv4 CGNAT (100.64.0.0/10)");
    }
    if a == 198 && (b == 18 || b == 19) {
        return Some("IPv4 benchmark (198.18.0.0/15)");
    }
    // Reserved for future use (RFC 1112). 255.255.255.255 was already reported as broadcast.
    if a >= 240 {
        return Some("IPv4 reserved (240.0.0.0/4)");
    }
    None
}
fn blocked_reason_v6(ip: Ipv6Addr, policy: Policy) -> Option<&'static str> {
if ip.is_loopback() {
if policy.allow_loopback {
return None;
}
return Some("IPv6 loopback (::1)");
}
if ip.is_unspecified() {
return Some("IPv6 unspecified (::)");
}
let segments = ip.segments();
if (segments[0] & 0xffc0) == 0xfe80 {
return Some("IPv6 link-local (fe80::/10)");
}
if (segments[0] & 0xfe00) == 0xfc00 {
return Some("IPv6 unique-local (fc00::/7)");
}
if let Some(v4) = ip.to_ipv4_mapped() {
return blocked_reason_v4(v4, policy);
}
None
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `addr` is blocked under `policy` and that the reason mentions `fragment`.
    fn assert_blocked(addr: &str, policy: Policy, fragment: &str) {
        let ip: IpAddr = addr.parse().unwrap();
        let reason =
            blocked_reason(ip, policy).unwrap_or_else(|| panic!("expected {addr} to be blocked"));
        assert!(
            reason.contains(fragment),
            "{addr} blocked but reason '{reason}' missing fragment '{fragment}'"
        );
    }

    /// Asserts that `addr` passes the blocklist under `policy`.
    fn assert_allowed(addr: &str, policy: Policy) {
        let ip: IpAddr = addr.parse().unwrap();
        assert!(blocked_reason(ip, policy).is_none(), "{addr} unexpectedly blocked");
    }

    #[test]
    fn blocks_loopback_v4_strict() {
        assert_blocked("127.0.0.1", Policy::strict(), "loopback");
        assert_blocked("127.255.255.254", Policy::strict(), "loopback");
    }

    #[test]
    fn allows_loopback_v4_in_test_policy() {
        assert_allowed("127.0.0.1", Policy::for_test());
    }

    #[test]
    fn blocks_unspecified() {
        assert_blocked("0.0.0.0", Policy::strict(), "unspecified");
        assert_blocked("::", Policy::strict(), "unspecified");
    }

    #[test]
    fn blocks_link_local_aws_metadata() {
        assert_blocked("169.254.169.254", Policy::strict(), "link-local");
    }

    #[test]
    fn blocks_rfc1918_ranges() {
        assert_blocked("10.0.0.1", Policy::strict(), "RFC1918");
        assert_blocked("172.16.0.1", Policy::strict(), "RFC1918");
        assert_blocked("172.31.255.255", Policy::strict(), "RFC1918");
        assert_blocked("192.168.0.1", Policy::strict(), "RFC1918");
    }

    #[test]
    fn blocks_documentation_v4() {
        assert_blocked("192.0.2.1", Policy::strict(), "documentation");
        assert_blocked("203.0.113.7", Policy::strict(), "documentation");
    }

    #[test]
    fn blocks_cgnat() {
        assert_blocked("100.64.0.1", Policy::strict(), "CGNAT");
        assert_blocked("100.127.255.255", Policy::strict(), "CGNAT");
    }

    #[test]
    fn allows_ranges_outside_cgnat() {
        assert_allowed("100.63.255.255", Policy::strict());
        assert_allowed("100.128.0.1", Policy::strict());
    }

    #[test]
    fn blocks_benchmark_range() {
        assert_blocked("198.18.0.1", Policy::strict(), "benchmark");
        assert_blocked("198.19.255.255", Policy::strict(), "benchmark");
    }

    #[test]
    fn allows_public_v4() {
        assert_allowed("8.8.8.8", Policy::strict());
        assert_allowed("1.1.1.1", Policy::strict());
        assert_allowed("142.250.190.78", Policy::strict());
    }

    #[test]
    fn blocks_loopback_v6_strict() {
        assert_blocked("::1", Policy::strict(), "loopback");
    }

    #[test]
    fn allows_loopback_v6_in_test_policy() {
        assert_allowed("::1", Policy::for_test());
        assert_allowed("::ffff:127.0.0.1", Policy::for_test());
    }

    #[test]
    fn blocks_link_local_v6() {
        assert_blocked("fe80::1", Policy::strict(), "link-local");
        assert_blocked("febf::1", Policy::strict(), "link-local");
    }

    #[test]
    fn blocks_ula() {
        assert_blocked("fc00::1", Policy::strict(), "unique-local");
        assert_blocked("fd12:3456::1", Policy::strict(), "unique-local");
        // AWS instance-metadata service IPv6 endpoint.
        assert_blocked("fd00:ec2::254", Policy::strict(), "unique-local");
    }

    #[test]
    fn blocks_ipv4_mapped_private() {
        assert_blocked("::ffff:10.0.0.1", Policy::strict(), "RFC1918");
        assert_blocked("::ffff:127.0.0.1", Policy::strict(), "loopback");
        assert_blocked("::ffff:169.254.169.254", Policy::strict(), "link-local");
    }

    #[test]
    fn allows_public_v6() {
        assert_allowed("2606:4700:4700::1111", Policy::strict());
        assert_allowed("2001:4860:4860::8888", Policy::strict());
    }

    #[tokio::test]
    async fn validate_rejects_non_http_scheme() {
        let err = validate_target_url("file:///etc/passwd", Policy::strict())
            .await
            .unwrap_err();
        assert!(matches!(err, SsrfError::DisallowedScheme(s) if s == "file"));
    }

    #[tokio::test]
    async fn validate_rejects_garbage_url() {
        let err = validate_target_url("not a url", Policy::strict())
            .await
            .unwrap_err();
        assert!(matches!(err, SsrfError::InvalidUrl(_)));
    }

    #[tokio::test]
    async fn validate_rejects_literal_loopback() {
        let err = validate_target_url("http://127.0.0.1/", Policy::strict())
            .await
            .unwrap_err();
        assert!(matches!(err, SsrfError::BlockedAddress { .. }));
    }

    #[tokio::test]
    async fn validate_rejects_numeric_loopback_encoding() {
        // The URL parser normalises this single-number IPv4 form to 127.0.0.1.
        let err = validate_target_url("http://0x7f000001/", Policy::strict())
            .await
            .unwrap_err();
        assert!(matches!(err, SsrfError::BlockedAddress { .. }));
    }

    #[tokio::test]
    async fn validate_rejects_literal_ipv6_loopback() {
        let err = validate_target_url("http://[::1]/", Policy::strict())
            .await
            .unwrap_err();
        assert!(matches!(err, SsrfError::BlockedAddress { .. }));
    }

    #[tokio::test]
    async fn validate_rejects_ipv4_mapped_metadata_literal() {
        let err = validate_target_url("http://[::ffff:169.254.169.254]/", Policy::strict())
            .await
            .unwrap_err();
        match err {
            SsrfError::BlockedAddress { reason, .. } => assert!(reason.contains("link-local")),
            other => panic!("expected BlockedAddress, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn validate_rejects_literal_metadata_ip() {
        let err = validate_target_url("http://169.254.169.254/latest/meta-data/", Policy::strict())
            .await
            .unwrap_err();
        match err {
            SsrfError::BlockedAddress { reason, .. } => assert!(reason.contains("link-local")),
            other => panic!("expected BlockedAddress, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn validate_rejects_literal_rfc1918() {
        let err = validate_target_url("http://10.0.0.1/", Policy::strict())
            .await
            .unwrap_err();
        assert!(matches!(err, SsrfError::BlockedAddress { .. }));
    }

    #[tokio::test]
    async fn validate_allows_loopback_in_test_policy() {
        validate_target_url("http://127.0.0.1:8080/", Policy::for_test())
            .await
            .unwrap();
    }
}