use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs};
use tracing::warn;
use url::Url;
use crate::error::NabError;
/// Default upper bound on redirects: 5.
///
/// NOTE(review): the code that consumes this limit is outside this view —
/// presumably the HTTP client's redirect-following loop; confirm.
pub const DEFAULT_MAX_REDIRECTS: u32 = 5;
/// Default upper bound on body size: 10 * 1024 * 1024 (10 MiB if the unit is
/// bytes — presumably it is; confirm against the caller that enforces it).
pub const DEFAULT_MAX_BODY_SIZE: usize = 10 * 1024 * 1024;
/// Returns `true` when `ip` belongs to an IPv4 range that outbound requests
/// must never target.
///
/// Denied ranges:
/// - "this network" `0.0.0.0/8` (includes the unspecified address `0.0.0.0`)
/// - loopback `127.0.0.0/8`
/// - private `10.0.0.0/8`, `172.16.0.0/12`, `192.168.0.0/16`
/// - link-local `169.254.0.0/16`
/// - limited broadcast `255.255.255.255` and multicast `224.0.0.0/4`
/// - documentation, benchmarking, carrier-grade NAT, IETF protocol
///   assignments, 6to4 relay anycast and reserved space, as decided by the
///   private range helpers called below
pub fn is_denied_ipv4(ip: Ipv4Addr) -> bool {
    // The entire 0.0.0.0/8 block is special-purpose and not globally
    // reachable, not just 0.0.0.0 itself; checking the first octet therefore
    // subsumes the former `is_unspecified()` test while also rejecting
    // addresses such as 0.1.2.3.
    let is_this_network = ip.octets()[0] == 0;

    is_this_network
        || ip.is_loopback()
        || ip.is_private()
        || ip.is_link_local()
        || ip.is_broadcast()
        || ip.is_multicast()
        || is_ipv4_documentation(ip)
        || is_ipv4_benchmarking(ip)
        || is_ipv4_cgn(ip)
        || is_ipv4_protocol_assignments(ip)
        || is_ipv4_6to4_relay(ip)
        || is_ipv4_reserved(ip)
}
/// Returns `true` for the three documentation /24 networks:
/// `192.0.2.0/24`, `198.51.100.0/24` and `203.0.113.0/24`.
fn is_ipv4_documentation(ip: Ipv4Addr) -> bool {
    // Leading three octets of each documentation network.
    const DOC_PREFIXES: [[u8; 3]; 3] = [[192, 0, 2], [198, 51, 100], [203, 0, 113]];

    let [a, b, c, _] = ip.octets();
    DOC_PREFIXES.contains(&[a, b, c])
}
/// Returns `true` for the benchmarking range `198.18.0.0/15`
/// (second octet 18 or 19).
fn is_ipv4_benchmarking(ip: Ipv4Addr) -> bool {
    matches!(ip.octets(), [198, 18 | 19, _, _])
}
/// Returns `true` for the carrier-grade NAT shared address space
/// `100.64.0.0/10`, i.e. first octet 100 and second octet in 64..=127.
fn is_ipv4_cgn(ip: Ipv4Addr) -> bool {
    matches!(ip.octets(), [100, 64..=127, _, _])
}
/// Returns `true` for the IETF protocol assignments block `192.0.0.0/24`.
fn is_ipv4_protocol_assignments(ip: Ipv4Addr) -> bool {
    matches!(ip.octets(), [192, 0, 0, _])
}
/// Returns `true` for the 6to4 relay anycast block `192.88.99.0/24`.
fn is_ipv4_6to4_relay(ip: Ipv4Addr) -> bool {
    ip.octets().starts_with(&[192, 88, 99])
}
/// Returns `true` for the reserved block `240.0.0.0/4`, i.e. any address
/// whose first octet is 240 or greater.
fn is_ipv4_reserved(ip: Ipv4Addr) -> bool {
    // A first octet >= 240 is exactly "top four bits all set".
    ip.octets()[0] >> 4 == 0b1111
}
/// Returns `true` when `ip` belongs to an IPv6 range that outbound requests
/// must never target.
///
/// Evaluation order:
/// 1. Loopback, unspecified and multicast addresses are denied outright.
/// 2. If [`extract_mapped_ipv4`] finds an embedded IPv4 address, the verdict
///    is whatever [`is_denied_ipv4`] says about that address.
/// 3. Otherwise the address is matched against the prefixes listed in the
///    body; NAT64 `64:ff9b::/96` is judged by the IPv4 address in its low
///    32 bits, every other listed prefix is denied unconditionally.
pub fn is_denied_ipv6(ip: Ipv6Addr) -> bool {
    if ip.is_loopback() || ip.is_unspecified() || ip.is_multicast() {
        return true;
    }

    if let Some(embedded) = extract_mapped_ipv4(&ip) {
        return is_denied_ipv4(embedded);
    }

    match ip.segments() {
        // fe80::/10 — link-local unicast.
        [first, ..] if first & 0xffc0 == 0xfe80 => true,
        // fec0::/10 — deprecated site-local.
        [first, ..] if first & 0xffc0 == 0xfec0 => true,
        // fc00::/7 — unique local addresses.
        [first, ..] if first & 0xfe00 == 0xfc00 => true,
        // 2001:db8::/32 — documentation.
        [0x2001, 0x0db8, ..] => true,
        // 100::/64 — discard-only.
        [0x0100, 0, 0, 0, ..] => true,
        // 2001::/32 — Teredo.
        [0x2001, 0x0000, ..] => true,
        // 2001:20::/28 — ORCHIDv2.
        [0x2001, second, ..] if second & 0xfff0 == 0x0020 => true,
        // 2002::/16 — 6to4.
        [0x2002, ..] => true,
        // 64:ff9b::/96 — well-known NAT64 prefix; the last 32 bits carry an
        // IPv4 address, which decides the outcome.
        [0x0064, 0xff9b, 0, 0, 0, 0, high, low] => {
            let [a, b] = high.to_be_bytes();
            let [c, d] = low.to_be_bytes();
            is_denied_ipv4(Ipv4Addr::new(a, b, c, d))
        }
        // 64:ff9b:1::/48 — local-use NAT64 prefix.
        [0x0064, 0xff9b, 0x0001, ..] => true,
        _ => false,
    }
}
/// Extracts the IPv4 address embedded in the low 32 bits of `ip`, if any.
///
/// Two layouts are recognised:
/// - IPv4-mapped `::ffff:a.b.c.d` (segments 0..5 zero, segment 5 `0xffff`);
/// - IPv4-compatible `::a.b.c.d` (segments 0..6 zero), except `::` and `::1`,
///   which are left alone.
///
/// Returns `None` for every other address.
pub fn extract_mapped_ipv4(ip: &Ipv6Addr) -> Option<Ipv4Addr> {
    let (high, low) = match ip.segments() {
        [0, 0, 0, 0, 0, 0xffff, high, low] => (high, low),
        // The guard keeps `::` (low == 0) and `::1` (low == 1) out.
        [0, 0, 0, 0, 0, 0, high, low] if high != 0 || low > 1 => (high, low),
        _ => return None,
    };

    // Big-endian byte order yields the octets in dotted-quad order.
    let [a, b] = high.to_be_bytes();
    let [c, d] = low.to_be_bytes();
    Some(Ipv4Addr::new(a, b, c, d))
}
/// Checks a single IP address against the deny-lists.
///
/// # Errors
///
/// Returns `NabError::SsrfBlocked` when [`is_denied_ipv4`] /
/// [`is_denied_ipv6`] (whichever matches the address family) reports the
/// address as denied.
pub fn validate_ip(ip: IpAddr) -> Result<(), NabError> {
    let denial = match ip {
        IpAddr::V4(v4) if is_denied_ipv4(v4) => {
            Some(format!("IPv4 address {v4} is in a denied range"))
        }
        IpAddr::V6(v6) if is_denied_ipv6(v6) => {
            Some(format!("IPv6 address {v6} is in a denied range"))
        }
        _ => None,
    };

    match denial {
        Some(reason) => Err(NabError::SsrfBlocked(reason)),
        None => Ok(()),
    }
}
/// Resolves `host:port` and returns the first resolved address that passes
/// [`validate_ip`].
///
/// Addresses that fail validation are skipped with a `warn!` log entry;
/// resolution order is preserved.
///
/// # Errors
///
/// Returns `NabError::SsrfBlocked` when resolution fails, when it yields no
/// addresses, or when every resolved address is in a denied range.
pub fn resolve_and_validate(host: &str, port: u16) -> Result<SocketAddr, NabError> {
    let addrs = match format!("{host}:{port}").to_socket_addrs() {
        Ok(resolved) => resolved.collect::<Vec<SocketAddr>>(),
        Err(e) => {
            return Err(NabError::SsrfBlocked(format!(
                "DNS resolution failed for {host}: {e}"
            )));
        }
    };

    if addrs.is_empty() {
        return Err(NabError::SsrfBlocked(format!(
            "DNS resolution returned no addresses for {host}"
        )));
    }

    // `find` stops at the first acceptable address, so later candidates are
    // neither validated nor logged.
    let first_allowed = addrs.iter().find(|addr| match validate_ip(addr.ip()) {
        Ok(()) => true,
        Err(e) => {
            warn!("SSRF: skipping {addr} for {host}: {e}");
            false
        }
    });

    first_allowed.copied().ok_or_else(|| {
        NabError::SsrfBlocked(format!(
            "all resolved addresses for {host} are in denied ranges: {addrs:?}"
        ))
    })
}
/// Validates the host of `url` and returns the socket address to connect to.
///
/// The port is taken from `Url::port_or_known_default`, falling back to 443.
/// A host that is an IP literal (as-is, or after removing surrounding square
/// brackets) is checked directly with [`validate_ip`]; any other host goes
/// through [`resolve_and_validate`].
///
/// # Errors
///
/// Returns `NabError::InvalidUrl` when the URL has no host, and propagates
/// the error from [`validate_ip`] or [`resolve_and_validate`] otherwise.
pub fn validate_url(url: &Url) -> Result<SocketAddr, NabError> {
    let host = match url.host_str() {
        Some(h) => h,
        None => return Err(NabError::InvalidUrl(format!("URL has no host: {url}"))),
    };
    let port = url.port_or_known_default().unwrap_or(443);

    // Try the host verbatim first, then with any enclosing brackets removed
    // (the form an IPv6 literal would take inside a URL).
    let bracketless = host.trim_start_matches('[').trim_end_matches(']');
    let literal = [host, bracketless]
        .iter()
        .find_map(|candidate| candidate.parse::<IpAddr>().ok());

    match literal {
        Some(ip) => validate_ip(ip).map(|()| SocketAddr::new(ip, port)),
        None => resolve_and_validate(host, port),
    }
}
/// Validates a redirect destination.
///
/// Only the `http` and `https` schemes are accepted; the URL is then run
/// through [`validate_url`] and the resolved address is discarded.
///
/// # Errors
///
/// Returns `NabError::SsrfBlocked` for any other scheme, and propagates the
/// error from [`validate_url`] otherwise.
pub fn validate_redirect_target(url: &Url) -> Result<(), NabError> {
    let scheme = url.scheme();
    if !matches!(scheme, "http" | "https") {
        return Err(NabError::SsrfBlocked(format!(
            "disallowed redirect scheme '{scheme}'"
        )));
    }

    validate_url(url)?;
    Ok(())
}