use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs};
use std::time::Duration;
use thiserror::Error;
use tracing::{debug, warn};
use url::Url;
/// Errors reported while validating a URL/IP against an SSRF policy or while
/// fetching a validated URL.
#[derive(Debug, Clone, Error)]
pub enum SsrfError {
/// Generic URL-level failure. In this file it is used for parse errors,
/// URLs without a host, hosts missing from the hostname allowlist, and also
/// to wrap HTTP client construction / request / body-read failures.
#[error("Invalid URL: {0}")]
InvalidUrl(String),
/// The URL scheme was rejected; carries the offending scheme.
#[error("URL scheme not allowed: {0} (only https is permitted)")]
InvalidScheme(String),
/// An IP address was rejected; carries the address and a short reason.
#[error("IP address blocked: {0} ({1})")]
BlockedIpAddress(IpAddr, String),
/// Hostname resolution failed or produced no addresses.
#[error("Failed to resolve hostname: {0}")]
ResolutionFailed(String),
/// A hostname resolved to more than one address.
/// NOTE(review): not constructed in the code visible here; the resolver
/// path only logs a warning for this condition.
#[error("Multiple IP addresses resolved for hostname (potential DNS rebinding): {0}")]
MultipleIpAddresses(String),
/// The response was larger than allowed: (observed size, configured max), in bytes.
#[error("Response size limit exceeded: {0} bytes (max: {1} bytes)")]
ResponseSizeLimitExceeded(usize, usize),
/// The HTTP request timed out; carries the configured timeout.
#[error("Request timeout after {0:?}")]
Timeout(Duration),
/// The address is the cloud metadata endpoint and the policy forbids it.
#[error("Access to cloud metadata endpoint blocked: {0}")]
CloudMetadataBlocked(IpAddr),
/// Rate limit hit for a URL.
/// NOTE(review): not constructed in the code visible here.
#[error("Rate limit exceeded for URL: {0}")]
RateLimitExceeded(String),
}
/// Configuration consulted by `SsrfValidator` when deciding whether a URL or
/// IP address may be contacted and how the HTTP client is set up.
#[derive(Debug, Clone)]
pub struct SsrfPolicy {
/// When `false`, RFC 1918 IPv4 addresses and IPv6 `fc00::/7` are rejected.
pub allow_private_networks: bool,
/// When `false`, loopback addresses (IPv4 and IPv6) are rejected.
pub allow_localhost: bool,
/// When `false`, link-local addresses (IPv4 and IPv6 `fe80::/10`) are rejected.
pub allow_link_local: bool,
/// When `false`, `169.254.169.254` is rejected as a cloud metadata endpoint.
pub allow_cloud_metadata: bool,
/// Maximum accepted response body size, in bytes.
pub max_response_size: usize,
/// Timeout applied to the HTTP client built for a fetch.
pub request_timeout: Duration,
/// When `true`, only `https` URLs pass URL validation.
pub require_https: bool,
/// Whether the HTTP client may follow redirects at all.
pub allow_redirects: bool,
/// Redirect limit handed to the client when `allow_redirects` is `true`.
pub max_redirects: u32,
/// When `Some`, an IP is accepted if and only if it is in this list; the
/// denylist, metadata and range checks are skipped in that case.
pub ip_allowlist: Option<Vec<IpAddr>>,
/// IPs that are always rejected (consulted only when no IP allowlist is set).
pub ip_denylist: Vec<IpAddr>,
/// When `Some`, the URL host must equal one of these entries.
pub hostname_allowlist: Option<Vec<String>>,
}
impl Default for SsrfPolicy {
fn default() -> Self {
Self {
allow_private_networks: false,
allow_localhost: false,
allow_link_local: false,
allow_cloud_metadata: false,
max_response_size: 5 * 1024, request_timeout: Duration::from_secs(5),
require_https: true,
allow_redirects: false, max_redirects: 0,
ip_allowlist: None,
ip_denylist: vec![
IpAddr::V4(Ipv4Addr::new(169, 254, 169, 254)),
IpAddr::V4(Ipv4Addr::LOCALHOST),
IpAddr::V6(Ipv6Addr::LOCALHOST),
],
hostname_allowlist: None,
}
}
}
impl SsrfPolicy {
pub fn builder() -> SsrfPolicyBuilder {
SsrfPolicyBuilder::default()
}
#[cfg(test)]
pub fn permissive() -> Self {
Self {
allow_private_networks: true,
allow_localhost: true,
allow_link_local: true,
allow_cloud_metadata: false, max_response_size: 1024 * 1024, request_timeout: Duration::from_secs(30),
require_https: false,
allow_redirects: true,
max_redirects: 5,
ip_allowlist: None,
ip_denylist: vec![],
hostname_allowlist: None,
}
}
}
/// Builder for `SsrfPolicy`. Every field starts as `None` (via the derived
/// `Default`); a `None` field means "not set by the caller".
#[derive(Debug, Default)]
pub struct SsrfPolicyBuilder {
allow_private_networks: Option<bool>,
allow_localhost: Option<bool>,
allow_link_local: Option<bool>,
allow_cloud_metadata: Option<bool>,
max_response_size: Option<usize>,
request_timeout: Option<Duration>,
require_https: Option<bool>,
allow_redirects: Option<bool>,
max_redirects: Option<u32>,
// Outer `Option`: whether the caller set the field; inner `Option`: the
// policy value itself (the policy field is an `Option<Vec<_>>`).
ip_allowlist: Option<Option<Vec<IpAddr>>>,
ip_denylist: Option<Vec<IpAddr>>,
// Same double-`Option` layout as `ip_allowlist`.
hostname_allowlist: Option<Option<Vec<String>>>,
}
impl SsrfPolicyBuilder {
    /// Sets whether private-network addresses are allowed.
    pub fn allow_private_networks(self, allow: bool) -> Self {
        Self {
            allow_private_networks: Some(allow),
            ..self
        }
    }

    /// Sets whether loopback addresses are allowed.
    pub fn allow_localhost(self, allow: bool) -> Self {
        Self {
            allow_localhost: Some(allow),
            ..self
        }
    }

    /// Sets whether link-local addresses are allowed.
    pub fn allow_link_local(self, allow: bool) -> Self {
        Self {
            allow_link_local: Some(allow),
            ..self
        }
    }

    /// Sets whether the cloud metadata endpoint is allowed.
    pub fn allow_cloud_metadata(self, allow: bool) -> Self {
        Self {
            allow_cloud_metadata: Some(allow),
            ..self
        }
    }

    /// Sets the maximum response size in bytes.
    pub fn max_response_size(self, size: usize) -> Self {
        Self {
            max_response_size: Some(size),
            ..self
        }
    }

    /// Sets the request timeout.
    pub fn request_timeout(self, timeout: Duration) -> Self {
        Self {
            request_timeout: Some(timeout),
            ..self
        }
    }

    /// Sets whether only `https` URLs are accepted.
    pub fn require_https(self, require: bool) -> Self {
        Self {
            require_https: Some(require),
            ..self
        }
    }

    /// Sets whether redirects may be followed.
    pub fn allow_redirects(self, allow: bool) -> Self {
        Self {
            allow_redirects: Some(allow),
            ..self
        }
    }

    /// Sets the redirect limit.
    pub fn max_redirects(self, max: u32) -> Self {
        Self {
            max_redirects: Some(max),
            ..self
        }
    }

    /// Restricts acceptable IPs to exactly `ips`.
    pub fn ip_allowlist(self, ips: Vec<IpAddr>) -> Self {
        Self {
            ip_allowlist: Some(Some(ips)),
            ..self
        }
    }

    /// Replaces the IP denylist with `ips`.
    pub fn ip_denylist(self, ips: Vec<IpAddr>) -> Self {
        Self {
            ip_denylist: Some(ips),
            ..self
        }
    }

    /// Restricts acceptable hosts to exactly `hostnames`.
    pub fn hostname_allowlist(self, hostnames: Vec<String>) -> Self {
        Self {
            hostname_allowlist: Some(Some(hostnames)),
            ..self
        }
    }

    /// Produces the policy: starts from `SsrfPolicy::default()` and overrides
    /// only the fields that were explicitly set on the builder.
    pub fn build(self) -> SsrfPolicy {
        let mut policy = SsrfPolicy::default();

        if let Some(value) = self.allow_private_networks {
            policy.allow_private_networks = value;
        }
        if let Some(value) = self.allow_localhost {
            policy.allow_localhost = value;
        }
        if let Some(value) = self.allow_link_local {
            policy.allow_link_local = value;
        }
        if let Some(value) = self.allow_cloud_metadata {
            policy.allow_cloud_metadata = value;
        }
        if let Some(value) = self.max_response_size {
            policy.max_response_size = value;
        }
        if let Some(value) = self.request_timeout {
            policy.request_timeout = value;
        }
        if let Some(value) = self.require_https {
            policy.require_https = value;
        }
        if let Some(value) = self.allow_redirects {
            policy.allow_redirects = value;
        }
        if let Some(value) = self.max_redirects {
            policy.max_redirects = value;
        }
        if let Some(value) = self.ip_allowlist {
            policy.ip_allowlist = value;
        }
        if let Some(value) = self.ip_denylist {
            policy.ip_denylist = value;
        }
        if let Some(value) = self.hostname_allowlist {
            policy.hostname_allowlist = value;
        }

        policy
    }
}
/// Validates URLs and IP addresses against an `SsrfPolicy` and performs
/// fetches through an HTTP client configured from that policy.
#[derive(Debug, Clone)]
pub struct SsrfValidator {
// The policy every check in this validator consults.
policy: SsrfPolicy,
}
impl Default for SsrfValidator {
    /// Builds a validator that enforces `SsrfPolicy::default()`.
    fn default() -> Self {
        let policy = SsrfPolicy::default();
        Self { policy }
    }
}
impl SsrfValidator {
pub fn new(policy: SsrfPolicy) -> Self {
Self { policy }
}
pub fn validate_url(&self, url_str: &str) -> Result<(), SsrfError> {
let url = Url::parse(url_str)
.map_err(|e| SsrfError::InvalidUrl(format!("Failed to parse URL: {}", e)))?;
if self.policy.require_https && url.scheme() != "https" {
return Err(SsrfError::InvalidScheme(url.scheme().to_string()));
}
if let Some(ref allowlist) = self.policy.hostname_allowlist
&& let Some(host) = url.host_str()
&& !allowlist.iter().any(|allowed| host == allowed)
{
debug!("Hostname not in allowlist: {}", host);
return Err(SsrfError::InvalidUrl(format!(
"Hostname not in allowlist: {}",
host
)));
}
if let Some(host) = url.host_str() {
self.validate_hostname(host)?;
} else {
return Err(SsrfError::InvalidUrl("URL has no host".to_string()));
}
Ok(())
}
fn validate_hostname(&self, hostname: &str) -> Result<(), SsrfError> {
let addr_str = format!("{}:443", hostname); let addrs: Vec<_> = addr_str
.to_socket_addrs()
.map_err(|e| SsrfError::ResolutionFailed(format!("{}: {}", hostname, e)))?
.collect();
if addrs.is_empty() {
return Err(SsrfError::ResolutionFailed(format!(
"No IP addresses resolved for: {}",
hostname
)));
}
if addrs.len() > 1 {
warn!(
"Multiple IP addresses resolved for hostname (potential DNS rebinding): {} -> {:?}",
hostname, addrs
);
}
for socket_addr in addrs {
let ip = socket_addr.ip();
self.validate_ip_address(&ip)?;
}
Ok(())
}
pub fn validate_ip_address(&self, ip: &IpAddr) -> Result<(), SsrfError> {
if let Some(ref allowlist) = self.policy.ip_allowlist {
if !allowlist.contains(ip) {
debug!("IP not in allowlist: {}", ip);
return Err(SsrfError::BlockedIpAddress(
*ip,
"IP not in allowlist".to_string(),
));
}
return Ok(());
}
if !self.policy.allow_cloud_metadata
&& let IpAddr::V4(ipv4) = ip
&& *ipv4 == Ipv4Addr::new(169, 254, 169, 254)
{
warn!("Cloud metadata endpoint access attempt: {}", ip);
return Err(SsrfError::CloudMetadataBlocked(*ip));
}
if self.policy.ip_denylist.contains(ip) {
warn!("IP in denylist: {}", ip);
return Err(SsrfError::BlockedIpAddress(
*ip,
"IP in denylist".to_string(),
));
}
match ip {
IpAddr::V4(ipv4) => self.validate_ipv4(ipv4)?,
IpAddr::V6(ipv6) => self.validate_ipv6(ipv6)?,
}
Ok(())
}
/// Applies the IPv4 range rules of the policy to `ip`.
///
/// Policy-gated ranges:
/// - `allow_private_networks`: RFC 1918 ranges and the shared address space
///   `100.64.0.0/10` (RFC 6598), which is likewise not publicly routable.
/// - `allow_localhost`: loopback (`127.0.0.0/8`).
/// - `allow_link_local`: link-local (`169.254.0.0/16`).
///
/// Always rejected: `0.0.0.0/8`, broadcast, documentation ranges, multicast
/// (`224.0.0.0/4`) and the reserved block `240.0.0.0/4`.
///
/// # Errors
/// [`SsrfError::BlockedIpAddress`] with a short reason naming the range.
fn validate_ipv4(&self, ip: &Ipv4Addr) -> Result<(), SsrfError> {
    let policy = &self.policy;
    let block = |reason: &str| -> Result<(), SsrfError> {
        Err(SsrfError::BlockedIpAddress(
            IpAddr::V4(*ip),
            reason.to_string(),
        ))
    };
    let [first, second, ..] = ip.octets();

    if !policy.allow_private_networks {
        if ip.is_private() {
            debug!("Private network access blocked: {}", ip);
            return block("Private network (RFC 1918)");
        }
        // 100.64.0.0/10: second octet is 64..=127.
        if first == 100 && (second & 0xc0) == 64 {
            debug!("Shared address space access blocked: {}", ip);
            return block("Shared address space (100.64.0.0/10)");
        }
    }
    if !policy.allow_localhost && ip.is_loopback() {
        debug!("Localhost access blocked: {}", ip);
        return block("Localhost");
    }
    if !policy.allow_link_local && ip.is_link_local() {
        debug!("Link-local access blocked: {}", ip);
        return block("Link-local");
    }

    // The remaining ranges are never valid fetch targets, whatever the policy.
    if ip.is_unspecified() {
        return block("Unspecified address (0.0.0.0)");
    }
    if first == 0 {
        return block("\"This network\" range (0.0.0.0/8)");
    }
    if ip.is_broadcast() {
        return block("Broadcast address");
    }
    if ip.is_documentation() {
        return block("Documentation address range");
    }
    if ip.is_multicast() {
        return block("Multicast address range (224.0.0.0/4)");
    }
    // Checked after broadcast so 255.255.255.255 keeps its specific reason.
    if first >= 240 {
        return block("Reserved address range (240.0.0.0/4)");
    }
    Ok(())
}
/// Applies the IPv6 range rules of the policy to `ip`.
///
/// An IPv4-mapped address (`::ffff:a.b.c.d`) is checked with the IPv4 rules
/// for the embedded address, since it reaches the same host.
///
/// Policy-gated ranges:
/// - `allow_localhost`: loopback (`::1`).
/// - `allow_private_networks`: unique local addresses (`fc00::/7`).
/// - `allow_link_local`: link-local (`fe80::/10`).
///
/// Always rejected: the unspecified address (`::`), multicast (`ff00::/8`)
/// and the documentation prefix (`2001:db8::/32`).
///
/// # Errors
/// [`SsrfError::BlockedIpAddress`] with a short reason naming the range.
fn validate_ipv6(&self, ip: &Ipv6Addr) -> Result<(), SsrfError> {
    if let Some(embedded) = ip.to_ipv4_mapped() {
        return self.validate_ipv4(&embedded);
    }

    let policy = &self.policy;
    let block = |reason: &str| -> Result<(), SsrfError> {
        Err(SsrfError::BlockedIpAddress(
            IpAddr::V6(*ip),
            reason.to_string(),
        ))
    };
    let segments = ip.segments();
    let prefix = segments[0];

    if !policy.allow_localhost && ip.is_loopback() {
        debug!("Localhost access blocked: {}", ip);
        return block("Localhost (::1)");
    }
    if ip.is_unspecified() {
        return block("Unspecified address (::)");
    }
    if !policy.allow_private_networks && (prefix & 0xfe00) == 0xfc00 {
        debug!("Private network access blocked: {}", ip);
        return block("Unique local address (fc00::/7)");
    }
    if !policy.allow_link_local && (prefix & 0xffc0) == 0xfe80 {
        debug!("Link-local access blocked: {}", ip);
        return block("Link-local (fe80::/10)");
    }
    if ip.is_multicast() {
        return block("Multicast address range (ff00::/8)");
    }
    if prefix == 0x2001 && segments[1] == 0x0db8 {
        return block("Documentation address range (2001:db8::/32)");
    }
    Ok(())
}
pub fn policy(&self) -> &SsrfPolicy {
&self.policy
}
pub fn create_pinned_client(
&self,
url_str: &str,
) -> Result<(reqwest::Client, String), SsrfError> {
let url = Url::parse(url_str)
.map_err(|e| SsrfError::InvalidUrl(format!("Failed to parse URL: {}", e)))?;
let hostname = url
.host_str()
.ok_or_else(|| SsrfError::InvalidUrl("URL has no host".to_string()))?;
let port = url
.port()
.unwrap_or_else(|| if url.scheme() == "https" { 443 } else { 80 });
let addr_str = format!("{}:{}", hostname, port);
let addrs: Vec<SocketAddr> = addr_str
.to_socket_addrs()
.map_err(|e| SsrfError::ResolutionFailed(format!("{}: {}", hostname, e)))?
.collect();
if addrs.is_empty() {
return Err(SsrfError::ResolutionFailed(format!(
"No IP addresses resolved for: {}",
hostname
)));
}
for socket_addr in &addrs {
let ip = socket_addr.ip();
self.validate_ip_address(&ip)?;
}
let mut client_builder = reqwest::Client::builder().timeout(self.policy.request_timeout);
if let Some(first_addr) = addrs.first() {
debug!(
hostname = hostname,
resolved_ip = %first_addr.ip(),
"Pinning DNS resolution to validated IP"
);
client_builder = client_builder.resolve(hostname, *first_addr);
}
if !self.policy.allow_redirects {
client_builder = client_builder.redirect(reqwest::redirect::Policy::none());
} else {
client_builder = client_builder.redirect(reqwest::redirect::Policy::limited(
self.policy.max_redirects as usize,
));
}
let client = client_builder
.build()
.map_err(|e| SsrfError::InvalidUrl(format!("Failed to create HTTP client: {}", e)))?;
Ok((client, url_str.to_string()))
}
/// Validates `url`, fetches it with a DNS-pinned client and returns the
/// response body.
///
/// The size limit is enforced twice: up front against the declared
/// `Content-Length` (when present), and again while the body is streamed, so
/// a response without a length header or with a wrong one is cut off as soon
/// as it exceeds `max_response_size` instead of being buffered in full.
///
/// Note: validation and client construction perform blocking DNS lookups.
///
/// # Errors
/// - Any error from [`Self::validate_url`] or [`Self::create_pinned_client`].
/// - [`SsrfError::Timeout`] when the request or body read times out.
/// - [`SsrfError::InvalidUrl`] for other request or body-read failures.
/// - [`SsrfError::ResponseSizeLimitExceeded`] when the body is too large; in
///   the streaming case the reported size is the number of bytes received
///   when the limit was crossed.
pub async fn fetch(&self, url: &str) -> Result<Vec<u8>, SsrfError> {
    self.validate_url(url)?;
    let (client, final_url) = self.create_pinned_client(url)?;
    let max_size = self.policy.max_response_size;
    let timeout = self.policy.request_timeout;

    let mut response = client.get(&final_url).send().await.map_err(|e| {
        if e.is_timeout() {
            SsrfError::Timeout(timeout)
        } else {
            SsrfError::InvalidUrl(format!("HTTP request failed: {}", e))
        }
    })?;

    // Cheap early rejection based on what the server claims.
    if let Some(declared) = response.content_length() {
        // Saturate instead of truncating on targets where usize < u64.
        let declared = usize::try_from(declared).unwrap_or(usize::MAX);
        if declared > max_size {
            return Err(SsrfError::ResponseSizeLimitExceeded(declared, max_size));
        }
    }

    // Read chunk by chunk so memory use stays bounded by the limit.
    let mut body = Vec::new();
    loop {
        let chunk = response.chunk().await.map_err(|e| {
            if e.is_timeout() {
                SsrfError::Timeout(timeout)
            } else {
                SsrfError::InvalidUrl(format!("Failed to read response: {}", e))
            }
        })?;
        let Some(chunk) = chunk else { break };

        let received = body.len().saturating_add(chunk.len());
        if received > max_size {
            return Err(SsrfError::ResponseSizeLimitExceeded(received, max_size));
        }
        body.extend_from_slice(&chunk);
    }
    Ok(body)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an IPv4 `IpAddr`.
    fn v4(a: u8, b: u8, c: u8, d: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(a, b, c, d))
    }

    /// Shorthand for an IPv6 `IpAddr` whose first segment is `head` and whose
    /// last segment is `1`.
    fn v6_with_head(head: u16) -> IpAddr {
        IpAddr::V6(Ipv6Addr::new(head, 0, 0, 0, 0, 0, 0, 1))
    }

    /// True when `validator` rejects `ip`.
    fn rejects(validator: &SsrfValidator, ip: IpAddr) -> bool {
        validator.validate_ip_address(&ip).is_err()
    }

    /// True when `validator` accepts `ip`.
    fn accepts(validator: &SsrfValidator, ip: IpAddr) -> bool {
        validator.validate_ip_address(&ip).is_ok()
    }

    #[test]
    fn test_default_policy_blocks_private_networks() {
        let validator = SsrfValidator::default();
        assert!(rejects(&validator, v4(10, 0, 0, 1)));
        assert!(rejects(&validator, v4(172, 16, 0, 1)));
        assert!(rejects(&validator, v4(192, 168, 1, 1)));
    }

    #[test]
    fn test_default_policy_blocks_localhost() {
        let validator = SsrfValidator::default();
        assert!(rejects(&validator, v4(127, 0, 0, 1)));
        assert!(rejects(&validator, IpAddr::V6(Ipv6Addr::LOCALHOST)));
    }

    #[test]
    fn test_default_policy_blocks_cloud_metadata() {
        let validator = SsrfValidator::default();
        let outcome = validator.validate_ip_address(&v4(169, 254, 169, 254));
        assert!(matches!(outcome, Err(SsrfError::CloudMetadataBlocked(_))));
    }

    #[test]
    fn test_default_policy_allows_public_ip() {
        let validator = SsrfValidator::default();
        assert!(accepts(&validator, v4(8, 8, 8, 8)));
        assert!(accepts(&validator, v4(1, 1, 1, 1)));
    }

    #[test]
    fn test_url_validation_requires_https() {
        let validator = SsrfValidator::default();
        let outcome = validator.validate_url("http://example.com");
        assert!(matches!(outcome, Err(SsrfError::InvalidScheme(_))));
    }

    #[test]
    fn test_custom_policy_builder() {
        let validator = SsrfValidator::new(
            SsrfPolicy::builder()
                .allow_private_networks(true)
                .allow_localhost(false)
                .max_response_size(10 * 1024)
                .build(),
        );
        assert!(accepts(&validator, v4(192, 168, 1, 1)));
        assert!(rejects(&validator, IpAddr::V4(Ipv4Addr::LOCALHOST)));
    }

    #[test]
    fn test_ip_allowlist() {
        let validator = SsrfValidator::new(
            SsrfPolicy::builder()
                .ip_allowlist(vec![v4(192, 168, 1, 100)])
                .build(),
        );
        assert!(accepts(&validator, v4(192, 168, 1, 100)));
        assert!(rejects(&validator, v4(192, 168, 1, 101)));
    }

    #[test]
    fn test_ipv6_unique_local_blocked() {
        let validator = SsrfValidator::default();
        assert!(rejects(&validator, v6_with_head(0xfd00)));
    }

    #[test]
    fn test_link_local_blocked() {
        let validator = SsrfValidator::default();
        assert!(rejects(&validator, v4(169, 254, 1, 1)));
        assert!(rejects(&validator, v6_with_head(0xfe80)));
    }
}