use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::sync::Arc;
use url::Url;
/// Caller-supplied allow/deny hook for outbound URLs.
///
/// [`decide`] invokes it only after the host and scheme checks have passed.
/// Any `Err` makes `decide` return
/// `SsrfDecision::Block(SsrfBlockReason::ValidatorDenied)`; the `String`
/// carried by the error is not inspected by `decide`. The `Send + Sync`
/// bounds allow a single validator to be shared across threads through the `Arc`.
pub type OutboundUrlValidator = Arc<dyn Fn(&Url) -> Result<(), String> + Send + Sync>;
/// Outcome of [`decide`] for a single outbound URL.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum SsrfDecision {
/// The request may proceed. `resolved_ip` is one of the addresses the caller
/// passed to `decide` in `resolved_ips`.
/// NOTE(review): presumably the caller should connect to exactly this address
/// instead of re-resolving the host (DNS-rebinding defence) — confirm at call sites.
Allow { resolved_ip: IpAddr },
/// The request must not be made; the payload names the check that failed.
Block(SsrfBlockReason),
}
/// Why [`decide`] (or [`validate_scheme`]) refused an outbound URL.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum SsrfBlockReason {
/// `resolved_ips` was non-empty but every entry was rejected by
/// [`is_blocked_ip`]. Despite the name, that covers more than RFC 1918
/// private ranges (loopback, link-local, multicast, unspecified, ...).
PrivateIp,
/// The URL scheme is neither `https` nor (when insecure mode is allowed) `http`.
InsecureScheme,
/// The caller-supplied [`OutboundUrlValidator`] returned `Err`.
ValidatorDenied,
/// The URL has no host component (e.g. a `data:` URL).
InvalidUrl,
/// `resolved_ips` was empty. `decide` performs no lookup itself; this only
/// reflects the list the caller passed in.
DnsResolutionFailed,
}
/// Returns `true` when `ip` must not be used as an outbound destination.
///
/// IPv4: unspecified, broadcast, loopback, link-local (incl. cloud metadata
/// `169.254.169.254`), RFC 1918 private, multicast, shared/CGNAT
/// `100.64.0.0/10`, "this network" `0.0.0.0/8`, IETF protocol assignments
/// `192.0.0.0/24`, benchmarking `198.18.0.0/15` and reserved `240.0.0.0/4`.
///
/// IPv6: unspecified, loopback, multicast, link-local `fe80::/10`, deprecated
/// site-local `fec0::/10` and unique-local `fc00::/7`. Addresses that embed an
/// IPv4 address are judged by that embedded address, so they cannot smuggle a
/// blocked IPv4 destination past the check. This applies to IPv4-mapped
/// `::ffff:a.b.c.d`, IPv4-compatible `::a.b.c.d` and NAT64 `64:ff9b::/96`.
///
/// Documentation ranges (e.g. `203.0.113.0/24`, `2001:db8::/32`) are
/// deliberately left allowed.
pub fn is_blocked_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => is_blocked_ipv4(v4),
        IpAddr::V6(v6) => is_blocked_ipv6(v6),
    }
}

/// IPv4 half of [`is_blocked_ip`].
fn is_blocked_ipv4(v4: Ipv4Addr) -> bool {
    if v4.is_unspecified()
        || v4.is_broadcast()
        || v4.is_loopback()
        || v4.is_link_local()
        || v4.is_private()
        || v4.is_multicast()
    {
        return true;
    }
    // Special-purpose ranges that std has no stable predicate for.
    match v4.octets() {
        // 0.0.0.0/8 "this network" (RFC 1122): never a legitimate remote target.
        [0, ..] => true,
        // 100.64.0.0/10 shared address space (carrier-grade NAT, RFC 6598).
        [100, second, ..] => (second & 0xc0) == 0x40,
        // 192.0.0.0/24 IETF protocol assignments (RFC 6890).
        [192, 0, 0, _] => true,
        // 198.18.0.0/15 benchmarking (RFC 2544).
        [198, 18 | 19, ..] => true,
        // 240.0.0.0/4 reserved (also covers 255.255.255.255).
        [240..=255, ..] => true,
        _ => false,
    }
}

/// IPv6 half of [`is_blocked_ip`].
fn is_blocked_ipv6(v6: Ipv6Addr) -> bool {
    if v6.is_unspecified() || v6.is_loopback() || v6.is_multicast() {
        return true;
    }

    let first = v6.segments()[0];
    let link_local = (first & 0xffc0) == 0xfe80; // fe80::/10
    let site_local = (first & 0xffc0) == 0xfec0; // fec0::/10 (deprecated, still routable on-site)
    let unique_local = (first & 0xfe00) == 0xfc00; // fc00::/7
    if link_local || site_local || unique_local {
        return true;
    }

    // IPv4-mapped (`::ffff:a.b.c.d`) and IPv4-compatible (`::a.b.c.d`).
    // `::` and `::1` also convert here, but they were already rejected above.
    if let Some(v4) = v6.to_ipv4() {
        return is_blocked_ipv4(v4);
    }

    // 64:ff9b::/96 NAT64 well-known prefix (RFC 6052): a NAT64 gateway
    // forwards such traffic to the IPv4 address held in the low 32 bits.
    let octets = v6.octets();
    if octets[..12] == [0x00, 0x64, 0xff, 0x9b, 0, 0, 0, 0, 0, 0, 0, 0] {
        let embedded = Ipv4Addr::new(octets[12], octets[13], octets[14], octets[15]);
        return is_blocked_ipv4(embedded);
    }

    false
}
/// Checks that `url` uses a scheme acceptable for outbound requests.
///
/// `https` is always accepted. `http` is accepted only when `allow_insecure`
/// is `true`.
///
/// # Errors
///
/// Returns [`SsrfBlockReason::InsecureScheme`] for any other scheme
/// (e.g. `file`, `ftp`, `gopher`, `ldap`), regardless of `allow_insecure`.
pub fn validate_scheme(url: &Url, allow_insecure: bool) -> Result<(), SsrfBlockReason> {
    let scheme = url.scheme();
    // Plain HTTP is only tolerated when the caller explicitly opted in.
    let permitted = scheme == "https" || (allow_insecure && scheme == "http");
    if permitted {
        Ok(())
    } else {
        Err(SsrfBlockReason::InsecureScheme)
    }
}
pub fn decide(
url: &Url,
resolved_ips: &[IpAddr],
allow_insecure: bool,
outbound_validator: Option<&OutboundUrlValidator>,
) -> SsrfDecision {
if url.host_str().is_none() {
return SsrfDecision::Block(SsrfBlockReason::InvalidUrl);
}
if let Err(reason) = validate_scheme(url, allow_insecure) {
return SsrfDecision::Block(reason);
}
if let Some(validator) = outbound_validator {
if validator(url).is_err() {
return SsrfDecision::Block(SsrfBlockReason::ValidatorDenied);
}
}
for ip in resolved_ips {
if allow_insecure {
return SsrfDecision::Allow { resolved_ip: *ip };
}
if !is_blocked_ip(*ip) {
return SsrfDecision::Allow { resolved_ip: *ip };
}
}
if resolved_ips.is_empty() {
SsrfDecision::Block(SsrfBlockReason::DnsResolutionFailed)
} else {
SsrfDecision::Block(SsrfBlockReason::PrivateIp)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an IPv4 `IpAddr`.
    fn ipv4(a: u8, b: u8, c: u8, d: u8) -> IpAddr {
        Ipv4Addr::new(a, b, c, d).into()
    }

    /// Parses an IPv6 literal into an `IpAddr`; panics on malformed input.
    fn ipv6(text: &str) -> IpAddr {
        IpAddr::V6(text.parse().unwrap())
    }

    /// Parses a URL literal; panics on malformed input.
    fn parse_url(text: &str) -> Url {
        Url::parse(text).unwrap()
    }

    /// Asserts that `is_blocked_ip` rejects every address, in order.
    fn assert_blocked(ips: &[IpAddr]) {
        for &ip in ips {
            assert!(is_blocked_ip(ip), "expected {ip} to be blocked");
        }
    }

    /// Asserts that `is_blocked_ip` accepts every address, in order.
    fn assert_allowed(ips: &[IpAddr]) {
        for &ip in ips {
            assert!(!is_blocked_ip(ip), "expected {ip} to be allowed");
        }
    }

    #[test]
    fn blocks_ipv4_loopback_and_localhost() {
        assert_blocked(&[ipv4(127, 0, 0, 1), ipv4(127, 255, 255, 254)]);
    }

    #[test]
    fn blocks_ipv4_rfc1918_private() {
        assert_blocked(&[
            ipv4(10, 0, 0, 5),
            ipv4(10, 255, 255, 255),
            ipv4(172, 16, 0, 1),
            ipv4(172, 31, 255, 255),
            ipv4(192, 168, 1, 1),
        ]);
    }

    #[test]
    fn blocks_ipv4_link_local_and_metadata() {
        assert_blocked(&[ipv4(169, 254, 0, 1), ipv4(169, 254, 169, 254)]);
    }

    #[test]
    fn blocks_ipv4_unspecified_broadcast_multicast() {
        assert_blocked(&[
            ipv4(0, 0, 0, 0),
            ipv4(255, 255, 255, 255),
            ipv4(224, 0, 0, 1),
            ipv4(239, 255, 255, 255),
        ]);
    }

    #[test]
    fn blocks_ipv4_carrier_grade_nat() {
        assert_blocked(&[ipv4(100, 64, 0, 1), ipv4(100, 127, 255, 255)]);
        // Addresses immediately outside 100.64.0.0/10 on either side.
        assert_allowed(&[ipv4(100, 63, 0, 0), ipv4(100, 128, 0, 0)]);
    }

    #[test]
    fn allows_ipv4_public() {
        assert_allowed(&[ipv4(8, 8, 8, 8), ipv4(1, 1, 1, 1), ipv4(203, 0, 113, 10)]);
    }

    #[test]
    fn blocks_ipv6_loopback_and_link_local() {
        assert_blocked(&[IpAddr::V6(Ipv6Addr::LOCALHOST), ipv6("fe80::1234")]);
    }

    #[test]
    fn blocks_ipv6_unique_local_and_multicast() {
        assert_blocked(&[ipv6("fc00::1"), ipv6("fd00:ec2::254"), ipv6("ff02::1")]);
    }

    #[test]
    fn blocks_ipv6_mapped_private() {
        assert_blocked(&[ipv6("::ffff:7f00:0001"), ipv6("::ffff:10.0.0.5")]);
    }

    #[test]
    fn allows_ipv6_public() {
        assert_allowed(&[ipv6("2001:db8::1"), ipv6("2606:4700:4700::1111")]);
    }

    #[test]
    fn scheme_https_always_allowed() {
        let target = parse_url("https://webhook.example.com/deliver");
        assert!(validate_scheme(&target, false).is_ok());
        assert!(validate_scheme(&target, true).is_ok());
    }

    #[test]
    fn scheme_http_rejected_in_production() {
        let target = parse_url("http://webhook.example.com/deliver");
        let outcome = validate_scheme(&target, false);
        assert!(matches!(outcome, Err(SsrfBlockReason::InsecureScheme)));
    }

    #[test]
    fn scheme_http_allowed_in_dev_mode() {
        let target = parse_url("http://localhost:8080/");
        assert!(validate_scheme(&target, true).is_ok());
    }

    #[test]
    fn scheme_other_always_rejected() {
        let exotic = ["file:///etc/passwd", "gopher://x", "ftp://x", "ldap://x"];
        for raw in exotic.iter() {
            let outcome = validate_scheme(&parse_url(raw), true);
            assert!(matches!(outcome, Err(SsrfBlockReason::InsecureScheme)));
        }
    }

    #[test]
    fn decide_allows_public_https() {
        let target = parse_url("https://webhook.example.com/deliver");
        let verdict = decide(&target, &[ipv4(203, 0, 113, 10)], false, None);
        assert!(matches!(verdict, SsrfDecision::Allow { .. }));
    }

    #[test]
    fn decide_blocks_private_ip() {
        let target = parse_url("https://metadata.attacker.example/");
        let verdict = decide(&target, &[ipv4(169, 254, 169, 254)], false, None);
        assert!(matches!(
            verdict,
            SsrfDecision::Block(SsrfBlockReason::PrivateIp)
        ));
    }

    #[test]
    fn decide_prefers_first_public_ip_when_mixed() {
        let target = parse_url("https://webhook.example.com/");
        let candidates = [ipv4(10, 0, 0, 5), ipv4(203, 0, 113, 10)];
        match decide(&target, &candidates, false, None) {
            SsrfDecision::Allow { resolved_ip } => {
                assert_eq!(resolved_ip, ipv4(203, 0, 113, 10))
            }
            other => panic!("expected Allow on public IP, got {other:?}"),
        }
    }

    #[test]
    fn decide_allows_private_in_insecure_mode() {
        let target = parse_url("http://localhost:3000/");
        let verdict = decide(&target, &[ipv4(127, 0, 0, 1)], true, None);
        assert!(matches!(verdict, SsrfDecision::Allow { .. }));
    }

    #[test]
    fn decide_empty_dns_is_blocked() {
        let target = parse_url("https://nothing.example/");
        let verdict = decide(&target, &[], false, None);
        assert!(matches!(
            verdict,
            SsrfDecision::Block(SsrfBlockReason::DnsResolutionFailed)
        ));
    }

    #[test]
    fn decide_respects_outbound_validator() {
        let allowlist: OutboundUrlValidator = Arc::new(|u: &Url| match u.host_str() {
            Some("webhook.example.com") => Ok(()),
            _ => Err(format!(
                "host {} not in allowlist",
                u.host_str().unwrap_or("")
            )),
        });
        let public = [ipv4(203, 0, 113, 10)];

        let denied_target = parse_url("https://evil.attacker.com/");
        let denied = decide(&denied_target, &public, false, Some(&allowlist));
        assert!(matches!(
            denied,
            SsrfDecision::Block(SsrfBlockReason::ValidatorDenied)
        ));

        let ok_target = parse_url("https://webhook.example.com/");
        let permitted = decide(&ok_target, &public, false, Some(&allowlist));
        assert!(matches!(permitted, SsrfDecision::Allow { .. }));
    }

    #[test]
    fn decide_rejects_non_https_in_production() {
        let target = parse_url("http://webhook.example.com/");
        let verdict = decide(&target, &[ipv4(203, 0, 113, 10)], false, None);
        assert!(matches!(
            verdict,
            SsrfDecision::Block(SsrfBlockReason::InsecureScheme)
        ));
    }

    #[test]
    fn decide_rejects_malformed_url() {
        // `data:` URLs parse successfully but carry no host.
        let target = parse_url("data:text/plain,hello");
        let verdict = decide(&target, &[ipv4(203, 0, 113, 10)], false, None);
        assert!(matches!(
            verdict,
            SsrfDecision::Block(SsrfBlockReason::InvalidUrl)
        ));
    }
}