use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use thiserror::Error;
use url::Url;
use crate::metrics::SecurityMetrics;
pub const ENV_SSRF_ALLOWLIST: &str = "SSRF_ALLOWLIST";
pub const ENV_SSRF_DENYLIST: &str = "SSRF_DENYLIST";
pub const ENV_SSRF_ALLOW_PRIVATE: &str = "SSRF_ALLOW_PRIVATE";
pub const ENV_SSRF_ALLOW_LOOPBACK: &str = "SSRF_ALLOW_LOOPBACK";
pub const ENV_SSRF_ALLOW_LINK_LOCAL: &str = "SSRF_ALLOW_LINK_LOCAL";
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IpClass {
Public,
Private,
Loopback,
LinkLocal,
Multicast,
Reserved,
}
#[derive(Debug, Error)]
pub enum SsrfError {
#[error("URL has no host component: {0}")]
MissingHost(String),
#[error("DNS resolution failed for host '{host}': {source}")]
DnsFailure {
host: String,
#[source]
source: std::io::Error,
},
#[error("DNS resolution returned no addresses for host '{host}'")]
NoAddresses { host: String },
#[error("host '{host}' (resolved to {ip}) is denylisted")]
Denylisted { host: String, ip: IpAddr },
#[error("host '{host}' (resolved to {ip}) blocked: {class:?}")]
BlockedClass {
host: String,
ip: IpAddr,
class: IpClass,
},
}
#[derive(Debug, Clone)]
pub struct SsrfPolicy {
allow_private: bool,
allow_loopback: bool,
allow_link_local: bool,
allowlist: Vec<String>,
denylist: Vec<String>,
metrics: Option<SecurityMetrics>,
}
impl SsrfPolicy {
pub fn new() -> Self {
Self {
allow_private: false,
allow_loopback: false,
allow_link_local: false,
allowlist: Vec::new(),
denylist: Vec::new(),
metrics: None,
}
}
pub fn from_env() -> Self {
Self {
allow_private: env_bool(ENV_SSRF_ALLOW_PRIVATE),
allow_loopback: env_bool(ENV_SSRF_ALLOW_LOOPBACK),
allow_link_local: env_bool(ENV_SSRF_ALLOW_LINK_LOCAL),
allowlist: env_csv(ENV_SSRF_ALLOWLIST),
denylist: env_csv(ENV_SSRF_DENYLIST),
metrics: None,
}
}
pub fn with_metrics(mut self, metrics: SecurityMetrics) -> Self {
self.metrics = Some(metrics);
self
}
pub fn with_allowlist(mut self, hosts: Vec<String>) -> Self {
self.allowlist = hosts;
self
}
pub fn with_denylist(mut self, hosts: Vec<String>) -> Self {
self.denylist = hosts;
self
}
pub fn with_allow_private(mut self, allow: bool) -> Self {
self.allow_private = allow;
self
}
pub fn with_allow_loopback(mut self, allow: bool) -> Self {
self.allow_loopback = allow;
self
}
pub fn with_allow_link_local(mut self, allow: bool) -> Self {
self.allow_link_local = allow;
self
}
pub fn classify(ip: IpAddr) -> IpClass {
match ip {
IpAddr::V4(v4) => classify_v4(v4),
IpAddr::V6(v6) => classify_v6(v6),
}
}
pub async fn resolve_and_check(&self, url: &Url) -> Result<IpAddr, SsrfError> {
let host = url
.host_str()
.ok_or_else(|| SsrfError::MissingHost(url.to_string()))?;
let host_lower = host.to_ascii_lowercase();
let port = url.port_or_known_default().unwrap_or(80);
let lookup_target = format!("{host}:{port}");
let mut addrs = tokio::net::lookup_host(&lookup_target)
.await
.map_err(|e| SsrfError::DnsFailure {
host: host.to_string(),
source: e,
})?;
let sock_addr = addrs.next().ok_or_else(|| SsrfError::NoAddresses {
host: host.to_string(),
})?;
let ip = sock_addr.ip();
if list_contains_host(&self.denylist, &host_lower) {
self.record_block(IpClass::Reserved);
return Err(SsrfError::Denylisted {
host: host.to_string(),
ip,
});
}
if list_contains_host(&self.allowlist, &host_lower) {
return Ok(ip);
}
let class = Self::classify(ip);
let permitted = match class {
IpClass::Public => true,
IpClass::Private => self.allow_private,
IpClass::Loopback => self.allow_loopback,
IpClass::LinkLocal => self.allow_link_local,
IpClass::Multicast | IpClass::Reserved => false,
};
if !permitted {
self.record_block(class);
return Err(SsrfError::BlockedClass {
host: host.to_string(),
ip,
class,
});
}
Ok(ip)
}
fn record_block(&self, class: IpClass) {
if let Some(m) = &self.metrics {
m.record_ssrf_block(class);
}
}
}
impl Default for SsrfPolicy {
fn default() -> Self {
Self::new()
}
}
fn classify_v4(v4: Ipv4Addr) -> IpClass {
let o = v4.octets();
if o == [169, 254, 169, 254] {
return IpClass::Reserved;
}
if v4.is_unspecified() || v4.is_broadcast() || v4.is_documentation() {
return IpClass::Reserved;
}
if v4.is_loopback() {
return IpClass::Loopback;
}
if v4.is_link_local() {
return IpClass::LinkLocal;
}
if v4.is_multicast() {
return IpClass::Multicast;
}
if v4.is_private() {
return IpClass::Private;
}
match o[0] {
0 => return IpClass::Reserved,
100 if (o[1] & 0xC0) == 0x40 => return IpClass::Reserved, 192 if o[1] == 0 && o[2] == 0 => return IpClass::Reserved,
192 if o[1] == 88 && o[2] == 99 => return IpClass::Reserved,
198 if o[1] == 18 || o[1] == 19 => return IpClass::Reserved,
240..=255 => return IpClass::Reserved,
_ => {}
}
IpClass::Public
}
fn classify_v6(v6: Ipv6Addr) -> IpClass {
let segs = v6.segments();
if segs == [0xfd00, 0x0ec2, 0, 0, 0, 0, 0, 0x0254] {
return IpClass::Reserved;
}
if v6.is_unspecified() {
return IpClass::Reserved;
}
if v6.is_loopback() {
return IpClass::Loopback;
}
if v6.is_multicast() {
return IpClass::Multicast;
}
if let Some(v4) = v6.to_ipv4_mapped() {
return classify_v4(v4);
}
let first = segs[0];
if (first & 0xFFC0) == 0xFE80 {
return IpClass::LinkLocal;
}
if (first & 0xFE00) == 0xFC00 {
return IpClass::Private;
}
if (first & 0xFFC0) == 0xFEC0 {
return IpClass::Private;
}
if first == 0x0100 && segs[1] == 0 && segs[2] == 0 && segs[3] == 0 {
return IpClass::Reserved;
}
if first == 0x2001 && segs[1] == 0x0db8 {
return IpClass::Reserved;
}
IpClass::Public
}
fn list_contains_host(list: &[String], host_lower: &str) -> bool {
list.iter().any(|entry| {
let e = entry.trim().to_ascii_lowercase();
let e_host = e.split(':').next().unwrap_or(&e);
!e_host.is_empty() && e_host == host_lower
})
}
fn env_bool(key: &str) -> bool {
std::env::var(key)
.ok()
.map(|v| {
let v = v.trim().to_ascii_lowercase();
matches!(v.as_str(), "1" | "true" | "yes" | "on")
})
.unwrap_or(false)
}
fn env_csv(key: &str) -> Vec<String> {
std::env::var(key)
.ok()
.map(|raw| {
raw.split(',')
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
.collect()
})
.unwrap_or_default()
}
pub fn is_safe_url(url: &str) -> Result<(), SsrfError> {
let parsed = Url::parse(url).map_err(|_| SsrfError::MissingHost(url.to_string()))?;
let host = parsed
.host()
.ok_or_else(|| SsrfError::MissingHost(url.to_string()))?;
match host {
url::Host::Ipv4(v4) => check_ip_safe(&v4.to_string(), IpAddr::V4(v4)),
url::Host::Ipv6(v6) => check_ip_safe(&v6.to_string(), IpAddr::V6(v6)),
url::Host::Domain(d) => {
if is_known_metadata_hostname(d) {
return Err(SsrfError::BlockedClass {
host: d.to_string(),
ip: IpAddr::V4(Ipv4Addr::new(169, 254, 169, 254)),
class: IpClass::Reserved,
});
}
Ok(())
}
}
}
pub async fn resolve_and_check(host: &str) -> Result<IpAddr, SsrfError> {
if is_known_metadata_hostname(host) {
return Err(SsrfError::BlockedClass {
host: host.to_string(),
ip: IpAddr::V4(Ipv4Addr::new(169, 254, 169, 254)),
class: IpClass::Reserved,
});
}
let lookup_target = if host.contains(':') {
host.to_string()
} else {
format!("{host}:80")
};
let addrs = tokio::net::lookup_host(&lookup_target)
.await
.map_err(|e| SsrfError::DnsFailure {
host: host.to_string(),
source: e,
})?;
let mut first: Option<IpAddr> = None;
for sock in addrs {
let ip = sock.ip();
check_ip_safe(host, ip)?;
if first.is_none() {
first = Some(ip);
}
}
first.ok_or_else(|| SsrfError::NoAddresses {
host: host.to_string(),
})
}
fn check_ip_safe(host: &str, ip: IpAddr) -> Result<(), SsrfError> {
let class = SsrfPolicy::classify(ip);
match class {
IpClass::Public => Ok(()),
IpClass::Private
| IpClass::Loopback
| IpClass::LinkLocal
| IpClass::Multicast
| IpClass::Reserved => Err(SsrfError::BlockedClass {
host: host.to_string(),
ip,
class,
}),
}
}
fn is_known_metadata_hostname(host: &str) -> bool {
let host_only = host.split(':').next().unwrap_or(host);
let lc = host_only.to_ascii_lowercase();
matches!(
lc.as_str(),
"metadata.google.internal" | "metadata" | "metadata.goog"
)
}
#[cfg(test)]
mod tests {
use super::*;
use std::net::{Ipv4Addr, Ipv6Addr};
#[test]
fn classify_rfc1918_private() {
assert_eq!(
SsrfPolicy::classify(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))),
IpClass::Private
);
assert_eq!(
SsrfPolicy::classify(IpAddr::V4(Ipv4Addr::new(172, 16, 0, 1))),
IpClass::Private
);
assert_eq!(
SsrfPolicy::classify(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))),
IpClass::Private
);
}
#[test]
fn classify_loopback() {
assert_eq!(
SsrfPolicy::classify(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))),
IpClass::Loopback
);
assert_eq!(
SsrfPolicy::classify(IpAddr::V6(Ipv6Addr::LOCALHOST)),
IpClass::Loopback
);
}
#[test]
fn classify_public() {
assert_eq!(
SsrfPolicy::classify(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))),
IpClass::Public
);
assert_eq!(
SsrfPolicy::classify(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1))),
IpClass::Public
);
}
#[test]
fn classify_cloud_metadata() {
assert_eq!(
SsrfPolicy::classify(IpAddr::V4(Ipv4Addr::new(169, 254, 169, 254))),
IpClass::Reserved
);
}
#[test]
fn classify_ipv6_link_local() {
assert_eq!(
SsrfPolicy::classify(IpAddr::V6("fe80::1".parse().unwrap())),
IpClass::LinkLocal
);
}
#[test]
fn classify_ipv6_ula() {
assert_eq!(
SsrfPolicy::classify(IpAddr::V6("fc00::1".parse().unwrap())),
IpClass::Private
);
assert_eq!(
SsrfPolicy::classify(IpAddr::V6("fd12:3456:789a::1".parse().unwrap())),
IpClass::Private
);
}
#[test]
fn classify_ipv6_public() {
assert_eq!(
SsrfPolicy::classify(IpAddr::V6("2001:4860:4860::8888".parse().unwrap())),
IpClass::Public
);
}
#[test]
fn default_policy_blocks_private() {
let p = SsrfPolicy::new();
assert!(!p.allow_private);
assert!(!p.allow_loopback);
assert!(!p.allow_link_local);
}
fn assert_blocked(url: &str, want_class: IpClass) {
match is_safe_url(url) {
Err(SsrfError::BlockedClass { class, .. }) => assert_eq!(
class, want_class,
"url {url} blocked with {class:?}, wanted {want_class:?}"
),
Err(other) => panic!("url {url} rejected with unexpected error: {other}"),
Ok(()) => panic!("url {url} accepted but expected block for {want_class:?}"),
}
}
#[test]
fn blocks_rfc1918_addresses() {
let cases = [
"http://10.0.0.1/",
"http://10.255.255.255/",
"http://172.16.0.1/",
"http://172.31.255.255/",
"http://192.168.0.1/",
"http://192.168.255.255/",
"http://[fc00::1]/",
"http://[fd00::1]/",
];
for url in cases {
assert_blocked(url, IpClass::Private);
}
}
#[test]
fn blocks_loopback() {
assert_blocked("http://127.0.0.1/", IpClass::Loopback);
assert_blocked("http://127.255.255.254/", IpClass::Loopback);
assert_blocked("http://[::1]/", IpClass::Loopback);
}
#[test]
fn blocks_link_local() {
assert_blocked("http://169.254.1.1/", IpClass::LinkLocal);
assert_blocked("http://169.254.254.254/", IpClass::LinkLocal);
assert_blocked("http://[fe80::1]/", IpClass::LinkLocal);
}
#[test]
fn blocks_aws_metadata_ip() {
assert_blocked("http://169.254.169.254/latest/meta-data/", IpClass::Reserved);
assert_blocked("http://[fd00:ec2::254]/latest/meta-data/", IpClass::Reserved);
}
#[tokio::test]
async fn blocks_aws_metadata_hostname() {
assert_blocked(
"http://metadata.google.internal/computeMetadata/v1/",
IpClass::Reserved,
);
match resolve_and_check("metadata.google.internal").await {
Err(SsrfError::BlockedClass { class, .. }) => assert_eq!(class, IpClass::Reserved),
other => panic!("expected BlockedClass for metadata.google.internal, got {other:?}"),
}
match resolve_and_check("metadata").await {
Err(SsrfError::BlockedClass { class, .. }) => assert_eq!(class, IpClass::Reserved),
other => panic!("expected BlockedClass for bare 'metadata', got {other:?}"),
}
}
#[test]
fn allows_public_ipv4() {
assert!(is_safe_url("https://8.8.8.8/").is_ok());
assert!(is_safe_url("https://1.1.1.1/").is_ok());
assert!(is_safe_url("https://93.184.216.34/").is_ok());
}
#[test]
fn allows_public_ipv6() {
assert!(is_safe_url("https://[2001:4860:4860::8888]/").is_ok());
assert!(is_safe_url("https://[2606:4700:4700::1111]/").is_ok());
}
#[test]
fn rejects_malformed_url() {
match is_safe_url("not a url") {
Err(SsrfError::MissingHost(_)) => {}
other => panic!("expected MissingHost for malformed url, got {other:?}"),
}
match is_safe_url("") {
Err(SsrfError::MissingHost(_)) => {}
other => panic!("expected MissingHost for empty url, got {other:?}"),
}
}
#[test]
fn rejects_http_without_host() {
match is_safe_url("file:///etc/passwd") {
Err(SsrfError::MissingHost(_)) => {}
other => panic!("expected MissingHost for file URL, got {other:?}"),
}
}
#[tokio::test]
async fn dns_failure_blocks_request() {
let result = resolve_and_check("this-host-does-not-exist.invalid").await;
match result {
Err(SsrfError::DnsFailure { host, .. }) => {
assert_eq!(host, "this-host-does-not-exist.invalid");
}
Err(SsrfError::NoAddresses { host, .. }) => {
assert_eq!(host, "this-host-does-not-exist.invalid");
}
Err(other) => panic!(
"expected DnsFailure or NoAddresses for unresolvable host, got {other:?}"
),
Ok(ip) => panic!(
"expected DNS failure for unresolvable host, got Ok({ip})"
),
}
}
#[tokio::test]
async fn policy_dns_failure_blocks_request() {
let policy = SsrfPolicy::new();
let url = Url::parse("https://this-host-does-not-exist.invalid/resource")
.expect("valid URL");
let result = policy.resolve_and_check(&url).await;
match result {
Err(SsrfError::DnsFailure { host, .. }) => {
assert!(host.contains("this-host-does-not-exist.invalid"));
}
Err(SsrfError::NoAddresses { host, .. }) => {
assert!(host.contains("this-host-does-not-exist.invalid"));
}
Err(other) => panic!(
"expected DnsFailure/NoAddresses through policy, got {other:?}"
),
Ok(ip) => panic!(
"expected DNS failure through policy, got Ok({ip})"
),
}
}
}