use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs};
use tracing::warn;
use url::Url;
use crate::error::NabError;
/// Default redirect limit (its consumer is not visible in this module).
pub const DEFAULT_MAX_REDIRECTS: u32 = 5;
/// Default body-size limit in bytes: 10 MiB.
pub const DEFAULT_MAX_BODY_SIZE: usize = 10 << 20;
/// Environment variable read by `SsrfPolicy::from_env`; a trimmed value of
/// `1` or `true` (any case) lets private/CGN/unique-local addresses through.
pub const ALLOW_PRIVATE_ENV: &str = "NAB_SSRF_ALLOW_PRIVATE";
/// Environment variable read by `SsrfPolicy::from_env`: a comma-separated
/// list of IP addresses / CIDR networks whose private addresses are allowed.
pub const ALLOWLIST_ENV: &str = "NAB_SSRF_ALLOWLIST";
/// An IPv4 or IPv6 network in CIDR notation, used for SSRF allowlist entries.
///
/// `network` always has the host bits (those past `prefix_len`) cleared, so
/// two specs naming the same network compare equal.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IpCidr {
    /// IPv4 network; `network` is the masked `u32::from(Ipv4Addr)` value.
    V4 {
        network: u32,
        prefix_len: u8,
    },
    /// IPv6 network; `network` is the masked `u128::from(Ipv6Addr)` value.
    V6 {
        network: u128,
        prefix_len: u8,
    },
}

impl IpCidr {
    /// Parses `addr` or `addr/len`.
    ///
    /// Whitespace around the whole spec and on either side of the `/` is
    /// ignored. A bare address is a single-host network (`/32` or `/128`).
    /// Host bits are cleared, so `10.1.2.3/8` equals `10.0.0.0/8`.
    ///
    /// # Errors
    ///
    /// Returns a message if the address is not a valid IPv4/IPv6 literal, the
    /// prefix is not an integer, or it exceeds 32 (IPv4) / 128 (IPv6).
    pub fn parse(spec: &str) -> Result<Self, String> {
        let spec = spec.trim();
        // Trim the address half as well; the prefix half is trimmed in
        // `parse_prefix`, so both sides of the `/` tolerate spaces.
        let (addr_part, prefix_part) = match spec.split_once('/') {
            Some((addr, prefix)) => (addr.trim(), Some(prefix)),
            None => (spec, None),
        };
        match addr_part.parse::<IpAddr>() {
            Ok(IpAddr::V4(v4)) => {
                let prefix_len = Self::parse_prefix(prefix_part, 32)?;
                Ok(Self::V4 {
                    network: mask_u32(u32::from(v4), prefix_len),
                    prefix_len,
                })
            }
            Ok(IpAddr::V6(v6)) => {
                let prefix_len = Self::parse_prefix(prefix_part, 128)?;
                Ok(Self::V6 {
                    network: mask_u128(u128::from(v6), prefix_len),
                    prefix_len,
                })
            }
            Err(e) => Err(format!("invalid IP address '{addr_part}': {e}")),
        }
    }

    /// Parses the optional prefix length, defaulting to `max` (single host).
    ///
    /// Rejects non-numeric input and values greater than `max`.
    fn parse_prefix(prefix_part: Option<&str>, max: u8) -> Result<u8, String> {
        match prefix_part {
            None => Ok(max),
            Some(p) => {
                let value: u8 = p
                    .trim()
                    .parse()
                    .map_err(|_| format!("invalid prefix length '{p}'"))?;
                if value > max {
                    return Err(format!("prefix /{value} exceeds maximum /{max}"));
                }
                Ok(value)
            }
        }
    }

    /// Returns true if `ip` belongs to this network.
    ///
    /// An address never matches a network of the other family. In particular,
    /// an IPv4-mapped IPv6 address does not match a V4 network.
    #[must_use]
    pub fn contains(&self, ip: IpAddr) -> bool {
        match (self, ip) {
            (
                Self::V4 {
                    network,
                    prefix_len,
                },
                IpAddr::V4(v4),
            ) => mask_u32(u32::from(v4), *prefix_len) == *network,
            (
                Self::V6 {
                    network,
                    prefix_len,
                },
                IpAddr::V6(v6),
            ) => mask_u128(u128::from(v6), *prefix_len) == *network,
            _ => false,
        }
    }
}
/// Keeps only the top `prefix_len` bits of `bits`, zeroing the rest.
///
/// A prefix of 0 yields 0; a prefix of 32 or more returns `bits` unchanged.
fn mask_u32(bits: u32, prefix_len: u8) -> u32 {
    let kept = u32::from(prefix_len.min(32));
    // Shifting by the full width (prefix 0) makes checked_shl return None,
    // which here means "keep no bits".
    let mask = u32::MAX.checked_shl(32 - kept).unwrap_or(0);
    bits & mask
}
/// Keeps only the top `prefix_len` bits of `bits`, zeroing the rest.
///
/// A prefix of 0 yields 0; a prefix of 128 or more returns `bits` unchanged.
fn mask_u128(bits: u128, prefix_len: u8) -> u128 {
    let kept = u32::from(prefix_len.min(128));
    // A full-width shift (prefix 0) is rejected by checked_shl -> empty mask.
    let mask = u128::MAX.checked_shl(128 - kept).unwrap_or(0);
    bits & mask
}
/// Controls which normally-denied private ranges an SSRF check may allow.
///
/// The default value (identical to [`SsrfPolicy::deny_all`]) relaxes nothing.
/// A relaxation only applies to addresses for which `is_relaxable` holds:
/// RFC 1918 private IPv4, 100.64.0.0/10 CGN, and IPv6 fc00::/7.
#[derive(Debug, Clone, Default)]
pub struct SsrfPolicy {
    /// Permit every relaxable address unconditionally.
    allow_private: bool,
    /// Permit relaxable addresses that fall inside one of these networks.
    allowlist: Vec<IpCidr>,
}

impl SsrfPolicy {
    /// Strictest policy: no relaxation of any kind.
    #[must_use]
    pub fn deny_all() -> Self {
        Self {
            allow_private: false,
            allowlist: Vec::new(),
        }
    }

    /// Builds a policy from the `ALLOW_PRIVATE_ENV` and `ALLOWLIST_ENV`
    /// environment variables.
    ///
    /// The blanket opt-in is enabled when the trimmed value is `1` or `true`
    /// (any case). The allowlist is parsed by `parse_allowlist`, which logs
    /// and skips malformed entries. Unset or non-UTF-8 variables count as
    /// absent.
    #[must_use]
    pub fn from_env() -> Self {
        let allow_private = match std::env::var(ALLOW_PRIVATE_ENV) {
            Ok(raw) => {
                let value = raw.trim();
                value == "1" || value.eq_ignore_ascii_case("true")
            }
            Err(_) => false,
        };
        let allowlist = match std::env::var(ALLOWLIST_ENV) {
            Ok(raw) => parse_allowlist(&raw),
            Err(_) => Vec::new(),
        };
        Self {
            allow_private,
            allowlist,
        }
    }

    /// Enables the blanket private-range opt-in when `allow` is true.
    ///
    /// `false` leaves the current setting as it is. This builder can only
    /// widen a policy; for example, it never clears an opt-in loaded from
    /// the environment.
    #[must_use]
    pub fn with_allow_private(self, allow: bool) -> Self {
        Self {
            allow_private: self.allow_private || allow,
            ..self
        }
    }

    /// Appends allowlist networks parsed from `entries`.
    ///
    /// Blank entries are skipped silently. Malformed entries are logged via
    /// `warn!` and skipped; they do not abort the build.
    #[must_use]
    pub fn with_allowlist_entries<I, S>(mut self, entries: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let parsed = entries.into_iter().filter_map(|item| {
            let entry = item.as_ref();
            if entry.trim().is_empty() {
                return None;
            }
            match IpCidr::parse(entry) {
                Ok(cidr) => Some(cidr),
                Err(e) => {
                    warn!("SSRF: ignoring malformed allowlist entry '{entry}': {e}");
                    None
                }
            }
        });
        self.allowlist.extend(parsed);
        self
    }

    /// True when any relaxation is configured: the blanket opt-in or at least
    /// one allowlist network.
    #[must_use]
    pub fn is_relaxed(&self) -> bool {
        !self.allowlist.is_empty() || self.allow_private
    }

    /// Whether this policy lets the relaxable address `ip` through.
    ///
    /// Always false for addresses outside the relaxable ranges, whatever the
    /// configuration.
    fn permits_relaxable(&self, ip: IpAddr) -> bool {
        is_relaxable(ip)
            && (self.allow_private || self.allowlist.iter().any(|cidr| cidr.contains(ip)))
    }
}
/// Parses a comma-separated list of IP/CIDR specs, such as the value of
/// `ALLOWLIST_ENV`.
///
/// Blank items are skipped. Malformed items are logged with `warn!` and
/// dropped, so a single typo does not discard the rest of the list.
fn parse_allowlist(raw: &str) -> Vec<IpCidr> {
    let mut cidrs = Vec::new();
    for spec in raw.split(',') {
        let spec = spec.trim();
        if spec.is_empty() {
            continue;
        }
        match IpCidr::parse(spec) {
            Ok(cidr) => cidrs.push(cidr),
            Err(e) => warn!("SSRF: ignoring malformed allowlist entry '{spec}': {e}"),
        }
    }
    cidrs
}
/// Whether `ip` lies in a range that a relaxed `SsrfPolicy` may unblock.
///
/// The ranges are RFC 1918 private IPv4, CGN shared space 100.64.0.0/10, and
/// IPv6 unique-local fc00::/7.
fn is_relaxable(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V6(v6) => {
            // fc00::/7: the top seven bits of the first byte are 1111_110.
            let first_byte = v6.octets()[0];
            first_byte & 0xfe == 0xfc
        }
        IpAddr::V4(v4) => is_ipv4_cgn(v4) || v4.is_private(),
    }
}
/// Strict IPv4 check: true if `ip` is blocked under [`SsrfPolicy::deny_all`].
#[must_use]
pub fn is_denied_ipv4(ip: Ipv4Addr) -> bool {
    let strict = SsrfPolicy::deny_all();
    is_denied_ipv4_with_policy(ip, &strict)
}
/// Returns true if connecting to `ip` must be refused under `policy`.
///
/// Ranges fall into two tiers:
/// * always denied: 0.0.0.0/8 ("this network"), loopback, link-local
///   169.254.0.0/16 (which includes the 169.254.169.254 metadata address),
///   broadcast, multicast, documentation, benchmarking, IETF protocol
///   assignments, the 6to4 relay anycast block, and 240.0.0.0/4;
/// * relaxable: RFC 1918 private and 100.64.0.0/10 CGN space. These are
///   denied unless `policy` permits the address.
///
/// Everything else is allowed.
#[must_use]
pub fn is_denied_ipv4_with_policy(ip: Ipv4Addr, policy: &SsrfPolicy) -> bool {
    // 0.0.0.0/8 is "this host on this network" (RFC 6890) and is never a
    // legitimate remote destination. It also covers the unspecified address
    // 0.0.0.0, which common stacks (e.g. Linux) route to the local host.
    let this_network = ip.octets()[0] == 0;
    let always_denied = this_network
        || ip.is_loopback()
        || ip.is_link_local()
        || ip.is_broadcast()
        || ip.is_multicast()
        || is_ipv4_documentation(ip)
        || is_ipv4_benchmarking(ip)
        || is_ipv4_protocol_assignments(ip)
        || is_ipv4_6to4_relay(ip)
        || is_ipv4_reserved(ip);
    if always_denied {
        return true;
    }
    // Private and CGN space are the only ranges a policy can open up.
    if ip.is_private() || is_ipv4_cgn(ip) {
        return !policy.permits_relaxable(IpAddr::V4(ip));
    }
    false
}
/// TEST-NET documentation blocks (RFC 5737): 192.0.2.0/24, 198.51.100.0/24
/// and 203.0.113.0/24.
fn is_ipv4_documentation(ip: Ipv4Addr) -> bool {
    const DOC_NETS: [[u8; 3]; 3] = [[192, 0, 2], [198, 51, 100], [203, 0, 113]];
    let [a, b, c, _] = ip.octets();
    DOC_NETS.contains(&[a, b, c])
}
/// Benchmarking block 198.18.0.0/15 (RFC 2544): 198.18.x.x and 198.19.x.x.
fn is_ipv4_benchmarking(ip: Ipv4Addr) -> bool {
    matches!(ip.octets(), [198, 18 | 19, _, _])
}
/// Carrier-grade NAT shared address space 100.64.0.0/10 (RFC 6598), i.e. a
/// second octet in 64..=127.
fn is_ipv4_cgn(ip: Ipv4Addr) -> bool {
    matches!(ip.octets(), [100, 64..=127, _, _])
}
/// IETF protocol assignments block 192.0.0.0/24 (RFC 6890).
fn is_ipv4_protocol_assignments(ip: Ipv4Addr) -> bool {
    matches!(ip.octets(), [192, 0, 0, _])
}
/// Deprecated 6to4 relay anycast block 192.88.99.0/24 (RFC 7526).
fn is_ipv4_6to4_relay(ip: Ipv4Addr) -> bool {
    matches!(ip.octets(), [192, 88, 99, _])
}
/// Reserved space 240.0.0.0/4, which also contains the limited broadcast
/// address 255.255.255.255.
fn is_ipv4_reserved(ip: Ipv4Addr) -> bool {
    u32::from(ip) >= 0xF000_0000
}
/// Strict IPv6 check: true if `ip` is blocked under [`SsrfPolicy::deny_all`].
#[must_use]
pub fn is_denied_ipv6(ip: Ipv6Addr) -> bool {
    let strict = SsrfPolicy::deny_all();
    is_denied_ipv6_with_policy(ip, &strict)
}
/// Returns true if connecting to `ip` must be refused under `policy`.
///
/// The following addresses are judged by the IPv4 address they embed, so
/// they cannot smuggle a denied IPv4 target:
/// * IPv4-mapped (`::ffff:a.b.c.d`);
/// * IPv4-compatible (`::a.b.c.d`);
/// * NAT64 well-known-prefix (`64:ff9b::/96`).
///
/// Unique-local fc00::/7 is the only IPv6 range a policy may relax. Every
/// other special-purpose block checked below is always denied.
#[must_use]
pub fn is_denied_ipv6_with_policy(ip: Ipv6Addr, policy: &SsrfPolicy) -> bool {
    if ip.is_loopback() || ip.is_unspecified() || ip.is_multicast() {
        return true;
    }
    if let Some(ipv4) = extract_mapped_ipv4(&ip) {
        return is_denied_ipv4_with_policy(ipv4, policy);
    }
    let segments = ip.segments();
    // fe80::/10 link-local.
    if segments[0] & 0xffc0 == 0xfe80 {
        return true;
    }
    // fec0::/10 deprecated site-local.
    if segments[0] & 0xffc0 == 0xfec0 {
        return true;
    }
    // fc00::/7 unique-local: denied unless the policy relaxes it.
    if segments[0] & 0xfe00 == 0xfc00 {
        return !policy.permits_relaxable(IpAddr::V6(ip));
    }
    // 2001:db8::/32 documentation.
    if segments[0] == 0x2001 && segments[1] == 0x0db8 {
        return true;
    }
    // 3fff::/20 documentation (RFC 9637).
    if segments[0] == 0x3fff && (segments[1] & 0xf000) == 0 {
        return true;
    }
    // 100::/64 discard-only.
    if segments[0] == 0x0100 && segments[1..4] == [0, 0, 0] {
        return true;
    }
    // 2001::/32 Teredo.
    if segments[0] == 0x2001 && segments[1] == 0x0000 {
        return true;
    }
    // 2001:2::/48 benchmarking (RFC 5180), mirroring the IPv4 198.18.0.0/15
    // denial.
    if segments[0] == 0x2001 && segments[1] == 0x0002 && segments[2] == 0 {
        return true;
    }
    // 2001:20::/28 ORCHIDv2.
    if segments[0] == 0x2001 && (segments[1] & 0xfff0) == 0x0020 {
        return true;
    }
    // 2002::/16 6to4.
    if segments[0] == 0x2002 {
        return true;
    }
    // 64:ff9b::/96 NAT64 well-known prefix: judge the embedded IPv4 address.
    if segments[0] == 0x0064 && segments[1] == 0xff9b && segments[2..6] == [0, 0, 0, 0] {
        let embedded = Ipv4Addr::new(
            (segments[6] >> 8) as u8,
            (segments[6] & 0xff) as u8,
            (segments[7] >> 8) as u8,
            (segments[7] & 0xff) as u8,
        );
        return is_denied_ipv4_with_policy(embedded, policy);
    }
    // 64:ff9b:1::/48 local-use NAT64.
    if segments[0] == 0x0064 && segments[1] == 0xff9b && segments[2] == 0x0001 {
        return true;
    }
    false
}
/// Returns the IPv4 address embedded in an IPv4-mapped (`::ffff:a.b.c.d`) or
/// IPv4-compatible (`::a.b.c.d`) IPv6 address.
///
/// The compatible form excludes `::` and `::1`, which are the IPv6
/// unspecified and loopback addresses rather than embedded IPv4. Any other
/// address yields `None`.
pub fn extract_mapped_ipv4(ip: &Ipv6Addr) -> Option<Ipv4Addr> {
    if let Some(mapped) = ip.to_ipv4_mapped() {
        return Some(mapped);
    }
    let octets = ip.octets();
    let (head, tail) = octets.split_at(12);
    if head.iter().any(|&byte| byte != 0) {
        return None;
    }
    let embedded = Ipv4Addr::new(tail[0], tail[1], tail[2], tail[3]);
    // Values 0 and 1 are `::` / `::1`, not IPv4-compatible addresses.
    (u32::from(embedded) > 1).then_some(embedded)
}
/// Validates `ip` under the strict [`SsrfPolicy::deny_all`] policy.
///
/// # Errors
///
/// `NabError::SsrfBlocked` if `ip` is in any denied range.
pub fn validate_ip(ip: IpAddr) -> Result<(), NabError> {
    let strict = SsrfPolicy::deny_all();
    validate_ip_with_policy(ip, &strict)
}
pub fn validate_ip_with_policy(ip: IpAddr, policy: &SsrfPolicy) -> Result<(), NabError> {
let denied = match ip {
IpAddr::V4(v4) => is_denied_ipv4_with_policy(v4, policy),
IpAddr::V6(v6) => is_denied_ipv6_with_policy(v6, policy),
};
if denied {
return Err(NabError::SsrfBlocked(format!(
"IP address {ip} is in a denied range"
)));
}
if policy.is_relaxed() && is_relaxable(ip) {
warn!(
allowed_ip = %ip,
"SSRF: allowing private/internal address via NAB_SSRF opt-out (loopback/metadata stay blocked)"
);
}
Ok(())
}
/// Resolves `host:port` and returns the first address allowed by the strict
/// [`SsrfPolicy::deny_all`] policy.
///
/// # Errors
///
/// `NabError::SsrfBlocked` if resolution fails or every address is denied.
pub fn resolve_and_validate(host: &str, port: u16) -> Result<SocketAddr, NabError> {
    let strict = SsrfPolicy::deny_all();
    resolve_and_validate_with_policy(host, port, &strict)
}
/// Resolves `host:port` via `ToSocketAddrs` (a blocking system lookup) and
/// returns the first resolved address that `policy` allows.
///
/// Denied addresses are logged with `warn!` and skipped. The caller
/// presumably connects to the returned address directly; re-resolving `host`
/// could yield a different answer.
///
/// # Errors
///
/// `NabError::SsrfBlocked` if resolution fails, returns no addresses, or
/// every resolved address is denied.
pub fn resolve_and_validate_with_policy(
    host: &str,
    port: u16,
    policy: &SsrfPolicy,
) -> Result<SocketAddr, NabError> {
    let target = format!("{host}:{port}");
    let addrs: Vec<SocketAddr> = match target.to_socket_addrs() {
        Ok(resolved) => resolved.collect(),
        Err(e) => {
            return Err(NabError::SsrfBlocked(format!(
                "DNS resolution failed for {host}: {e}"
            )))
        }
    };
    if addrs.is_empty() {
        return Err(NabError::SsrfBlocked(format!(
            "DNS resolution returned no addresses for {host}"
        )));
    }
    // Stop at the first permitted address, logging each rejected one.
    let permitted = addrs.iter().copied().find(|addr| {
        match validate_ip_with_policy(addr.ip(), policy) {
            Ok(()) => true,
            Err(e) => {
                warn!("SSRF: skipping {addr} for {host}: {e}");
                false
            }
        }
    });
    permitted.ok_or_else(|| {
        NabError::SsrfBlocked(format!(
            "all resolved addresses for {host} are in denied ranges: {addrs:?}"
        ))
    })
}
/// Validates `url`'s host under the strict [`SsrfPolicy::deny_all`] policy and
/// returns the socket address to connect to.
///
/// # Errors
///
/// `NabError::InvalidUrl` if the URL has no host. `NabError::SsrfBlocked` if
/// the host resolves only to denied addresses or cannot be resolved.
pub fn validate_url(url: &Url) -> Result<SocketAddr, NabError> {
    let strict = SsrfPolicy::deny_all();
    validate_url_with_policy(url, &strict)
}
/// Validates `url`'s host under `policy` and returns the socket address to
/// connect to.
///
/// IP-literal hosts are checked directly, without DNS. IPv6 hosts come back
/// from `host_str` in brackets, so the brackets are stripped before a second
/// parse attempt. Any other host is resolved through
/// `resolve_and_validate_with_policy`. The port is the URL's explicit or
/// scheme-default port, falling back to 443.
///
/// # Errors
///
/// `NabError::InvalidUrl` if the URL has no host. `NabError::SsrfBlocked` if
/// the address is denied or resolution fails.
pub fn validate_url_with_policy(url: &Url, policy: &SsrfPolicy) -> Result<SocketAddr, NabError> {
    let Some(host) = url.host_str() else {
        return Err(NabError::InvalidUrl(format!("URL has no host: {url}")));
    };
    let port = url.port_or_known_default().unwrap_or(443);
    let literal = host.parse::<IpAddr>().ok().or_else(|| {
        host.trim_start_matches('[')
            .trim_end_matches(']')
            .parse::<IpAddr>()
            .ok()
    });
    match literal {
        Some(ip) => {
            validate_ip_with_policy(ip, policy)?;
            Ok(SocketAddr::new(ip, port))
        }
        None => resolve_and_validate_with_policy(host, port, policy),
    }
}
/// Validates a redirect target under the strict [`SsrfPolicy::deny_all`]
/// policy.
///
/// # Errors
///
/// `NabError::SsrfBlocked` for a non-HTTP(S) scheme or a denied host.
/// `NabError::InvalidUrl` if the URL has no host.
pub fn validate_redirect_target(url: &Url) -> Result<(), NabError> {
    let strict = SsrfPolicy::deny_all();
    validate_redirect_target_with_policy(url, &strict)
}
/// Validates a redirect target under `policy`.
///
/// Only `http` and `https` schemes are allowed. The host is then checked with
/// `validate_url_with_policy`, and the resolved address is discarded.
///
/// # Errors
///
/// `NabError::SsrfBlocked` for any other scheme or a denied host.
/// `NabError::InvalidUrl` if the URL has no host.
pub fn validate_redirect_target_with_policy(
    url: &Url,
    policy: &SsrfPolicy,
) -> Result<(), NabError> {
    let scheme = url.scheme();
    if !matches!(scheme, "http" | "https") {
        return Err(NabError::SsrfBlocked(format!(
            "disallowed redirect scheme '{scheme}'"
        )));
    }
    validate_url_with_policy(url, policy)?;
    Ok(())
}