use crate::cidr::CIDRBlock;
use crate::error::AntiSSRFError;
use crate::ip_address_ranges;
use std::net::IpAddr;
/// Preset configurations accepted by `AntiSSRFPolicy::new`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PolicyConfigOptions {
    /// No address or header rules are installed.
    None,
    /// Every address that is not explicitly allowlisted is denied.
    InternalOnly,
    /// Denies the V1 recommended address ranges and enables `X-Forwarded-For` injection.
    ExternalOnlyV1,
    /// Currently configured identically to `ExternalOnlyV1` (same range list).
    ExternalOnlyLatest,
}
/// Policy describing which outbound destinations, schemes and headers are acceptable.
///
/// The policy can be configured until it is first used for a check; after that it is
/// locked and every mutating setter returns `AntiSSRFError::PolicyLocked`.
#[derive(Clone, Debug, PartialEq)]
pub struct AntiSSRFPolicy {
    /// CIDR blocks that are always permitted; consulted before any deny rule.
    allowed_addresses: Vec<CIDRBlock>,
    /// CIDR blocks that are rejected unless they are also allowlisted.
    denied_addresses: Vec<CIDRBlock>,
    /// When `true`, every address not matched by `allowed_addresses` is rejected.
    deny_all_unspecified_ips: bool,
    /// Lower-cased header names that a validated request must carry.
    required_headers: Vec<String>,
    /// Lower-cased header names that must not appear on a validated request.
    denied_headers: Vec<String>,
    /// Whether request validation appends an `X-Forwarded-For` header when none is present.
    add_xff_header: bool,
    /// Whether the `http:` scheme is accepted in addition to `https:`.
    allow_plaintext_http: bool,
    /// Set once the policy has been used; further configuration is then refused.
    locked: bool,
}
impl AntiSSRFPolicy {
    /// Creates a policy pre-configured according to `config`.
    ///
    /// * `PolicyConfigOptions::None` installs no address or header rules.
    /// * `PolicyConfigOptions::InternalOnly` denies every address that is not allowlisted.
    /// * `PolicyConfigOptions::ExternalOnlyV1` / `ExternalOnlyLatest` deny the ranges in
    ///   `ip_address_ranges::RECOMMENDEDV1` and enable `X-Forwarded-For` injection.
    ///
    /// In every case plaintext HTTP is disallowed and the policy starts unlocked.
    pub fn new(config: PolicyConfigOptions) -> Self {
        let mut policy = Self {
            allowed_addresses: Vec::new(),
            denied_addresses: Vec::new(),
            deny_all_unspecified_ips: false,
            required_headers: Vec::new(),
            denied_headers: Vec::new(),
            add_xff_header: false,
            allow_plaintext_http: false,
            locked: false,
        };
        match config {
            PolicyConfigOptions::None => {}
            PolicyConfigOptions::InternalOnly => policy.deny_all_unspecified_ips = true,
            PolicyConfigOptions::ExternalOnlyV1 | PolicyConfigOptions::ExternalOnlyLatest => {
                // NOTE(review): `ExternalOnlyLatest` currently reuses the V1 range list.
                policy.add_denied_addresses_from_slice(ip_address_ranges::RECOMMENDEDV1);
                policy.add_xff_header = true;
            }
        }
        policy
    }

    /// CIDR blocks that are always permitted.
    pub fn allowed_addresses(&self) -> &[CIDRBlock] {
        &self.allowed_addresses
    }

    /// CIDR blocks that are rejected unless also allowlisted.
    pub fn denied_addresses(&self) -> &[CIDRBlock] {
        &self.denied_addresses
    }

    /// Whether every non-allowlisted address is rejected.
    pub fn deny_all_unspecified_ips(&self) -> bool {
        self.deny_all_unspecified_ips
    }

    /// Lower-cased header names a request must carry.
    pub fn required_headers(&self) -> &[String] {
        &self.required_headers
    }

    /// Lower-cased header names a request must not carry.
    pub fn denied_headers(&self) -> &[String] {
        &self.denied_headers
    }

    /// Whether `validate_request` appends `X-Forwarded-For` when it is missing.
    pub fn add_xff_header(&self) -> bool {
        self.add_xff_header
    }

    /// Whether the `http:` scheme is accepted.
    pub fn allow_plaintext_http(&self) -> bool {
        self.allow_plaintext_http
    }

    /// Whether the policy has been used and can no longer be reconfigured.
    pub fn is_locked(&self) -> bool {
        self.locked
    }

    /// Freezes the configuration; called by every checking method.
    fn lock(&mut self) {
        self.locked = true;
    }

    /// Returns `Err(PolicyLocked)` once the policy has been used.
    fn assert_not_locked(&self) -> Result<(), AntiSSRFError> {
        if self.locked {
            Err(AntiSSRFError::PolicyLocked)
        } else {
            Ok(())
        }
    }

    /// Adds addresses (single IPs or CIDR ranges) to the allowlist.
    ///
    /// Bare IPs are widened to a `/32` (IPv4) or `/128` (IPv6) block. The operation is
    /// all-or-nothing: if any entry fails to parse, nothing is added.
    ///
    /// # Errors
    /// `PolicyLocked` if the policy has been used, or the parse error of the first invalid entry.
    pub fn add_allowed_addresses(&mut self, addresses: &[&str]) -> Result<(), AntiSSRFError> {
        self.assert_not_locked()?;
        let blocks = Self::parse_address_list(addresses)?;
        self.allowed_addresses.extend(blocks);
        Ok(())
    }

    /// Adds addresses (single IPs or CIDR ranges) to the denylist.
    ///
    /// The operation is all-or-nothing: if any entry fails to parse, nothing is added.
    ///
    /// # Errors
    /// `PolicyLocked` if the policy has been used, `ConflictingConfiguration` if
    /// `deny_all_unspecified_ips` is set (a denylist would be redundant), or the parse error
    /// of the first invalid entry.
    pub fn add_denied_addresses(&mut self, addresses: &[&str]) -> Result<(), AntiSSRFError> {
        self.assert_not_locked()?;
        if self.deny_all_unspecified_ips {
            return Err(AntiSSRFError::ConflictingConfiguration);
        }
        let blocks = Self::parse_address_list(addresses)?;
        self.denied_addresses.extend(blocks);
        Ok(())
    }

    /// Checks whether a connection to every address in `ipaddresses` is permitted, and locks
    /// the policy.
    ///
    /// Each address is allowed if it matches the allowlist. Otherwise it is rejected when
    /// `deny_all_unspecified_ips` is set or it matches the denylist. IPv4-mapped IPv6 addresses
    /// (`::ffff:a.b.c.d`) are also checked against the denylist in their IPv4 form, so they
    /// cannot be used to bypass an IPv4 deny entry. An empty slice yields `Ok(true)`.
    ///
    /// # Errors
    /// `InvalidIP` for the first address that does not parse. Addresses are processed in order,
    /// so an earlier rejected address returns `Ok(false)` before a later invalid one is seen.
    pub fn is_network_connection_allowed(
        &mut self,
        ipaddresses: &[&str],
    ) -> Result<bool, AntiSSRFError> {
        self.lock();
        for ip_str in ipaddresses {
            let ip: IpAddr = ip_str
                .parse()
                .map_err(|_| AntiSSRFError::InvalidIP(ip_str.to_string()))?;
            if self.allowed_addresses.iter().any(|a| a.contains(ip)) {
                continue;
            }
            if self.deny_all_unspecified_ips || self.is_denied(ip) {
                return Ok(false);
            }
        }
        Ok(true)
    }

    /// Returns whether `ip`, or its IPv4 form when it is IPv4-mapped, matches the denylist.
    fn is_denied(&self, ip: IpAddr) -> bool {
        if self.denied_addresses.iter().any(|d| d.contains(ip)) {
            return true;
        }
        match ip {
            IpAddr::V6(v6) => match v6.to_ipv4_mapped() {
                Some(v4) => {
                    let mapped = IpAddr::V4(v4);
                    self.denied_addresses.iter().any(|d| d.contains(mapped))
                }
                None => false,
            },
            IpAddr::V4(_) => false,
        }
    }

    /// Enables or disables rejecting every non-allowlisted address.
    ///
    /// # Errors
    /// `PolicyLocked` if the policy has been used.
    pub fn set_deny_all_unspecified_ips(&mut self, value: bool) -> Result<(), AntiSSRFError> {
        self.assert_not_locked()?;
        self.deny_all_unspecified_ips = value;
        Ok(())
    }

    /// Adds header names that every validated request must carry (stored trimmed and
    /// lower-cased). All-or-nothing: an empty name rejects the whole batch.
    ///
    /// # Errors
    /// `PolicyLocked` if the policy has been used, `InvalidHeader` for an empty/blank name.
    pub fn add_required_headers(&mut self, headers: &[&str]) -> Result<(), AntiSSRFError> {
        self.assert_not_locked()?;
        let names = Self::normalize_header_list(headers)?;
        self.required_headers.extend(names);
        Ok(())
    }

    /// Adds header names that must not appear on a validated request (stored trimmed and
    /// lower-cased). All-or-nothing: an empty name rejects the whole batch.
    ///
    /// # Errors
    /// `PolicyLocked` if the policy has been used, `InvalidHeader` for an empty/blank name.
    pub fn add_denied_headers(&mut self, headers: &[&str]) -> Result<(), AntiSSRFError> {
        self.assert_not_locked()?;
        let names = Self::normalize_header_list(headers)?;
        self.denied_headers.extend(names);
        Ok(())
    }

    /// Enables or disables `X-Forwarded-For` injection in `validate_request`.
    ///
    /// # Errors
    /// `PolicyLocked` if the policy has been used.
    pub fn set_add_xff_header(&mut self, value: bool) -> Result<(), AntiSSRFError> {
        self.assert_not_locked()?;
        self.add_xff_header = value;
        Ok(())
    }

    /// Enables or disables acceptance of the `http:` scheme.
    ///
    /// # Errors
    /// `PolicyLocked` if the policy has been used.
    pub fn set_allow_plaintext_http(&mut self, value: bool) -> Result<(), AntiSSRFError> {
        self.assert_not_locked()?;
        self.allow_plaintext_http = value;
        Ok(())
    }

    /// Validates a request's scheme and headers against the policy, and locks the policy.
    ///
    /// `protocol` is compared exactly against `"https:"` and `"http:"`. When XFF injection is
    /// enabled and no `X-Forwarded-For` header is present (case-insensitive), the pair
    /// `("X-Forwarded-For", "true")` is appended to `headers`. Returns `Ok(true)` on success.
    ///
    /// # Errors
    /// * `SchemeDisallowed` for any scheme other than `https:`, or `http:` when plaintext
    ///   HTTP is not allowed.
    /// * `HeaderDenied` if any denied header is present.
    /// * `HeaderRequired` if any required header is missing.
    pub fn validate_request(
        &mut self,
        protocol: &str,
        headers: &mut Vec<(String, String)>,
    ) -> Result<bool, AntiSSRFError> {
        self.lock();
        let scheme_allowed = match protocol {
            "https:" => true,
            "http:" => self.allow_plaintext_http,
            _ => false,
        };
        if !scheme_allowed {
            return Err(AntiSSRFError::SchemeDisallowed);
        }
        let present = headers.as_slice();
        if self
            .denied_headers
            .iter()
            .any(|denied| Self::has_header(present, denied))
        {
            return Err(AntiSSRFError::HeaderDenied);
        }
        if !self
            .required_headers
            .iter()
            .all(|required| Self::has_header(present, required))
        {
            return Err(AntiSSRFError::HeaderRequired);
        }
        if self.needs_xff_header(headers.as_slice()) {
            headers.push(("X-Forwarded-For".to_string(), "true".to_string()));
        }
        Ok(true)
    }

    /// Returns whether XFF injection is enabled and `headers` lacks `X-Forwarded-For`
    /// (case-insensitive).
    pub fn needs_xff_header(&self, headers: &[(String, String)]) -> bool {
        self.add_xff_header && !Self::has_header(headers, "x-forwarded-for")
    }

    /// Case-insensitive (ASCII) check for a header name in a list of name/value pairs.
    fn has_header(headers: &[(String, String)], name: &str) -> bool {
        headers.iter().any(|(key, _)| key.eq_ignore_ascii_case(name))
    }

    /// Adds built-in ranges to the denylist, skipping entries that fail to parse.
    ///
    /// NOTE(review): a malformed entry in a built-in table is silently dropped here, which
    /// would weaken the denylist without notice; a unit test should pin that every entry parses.
    fn add_denied_addresses_from_slice(&mut self, addresses: &[&str]) {
        for addr in addresses {
            let normalized = Self::normalize_address(addr);
            if let Ok(block) = CIDRBlock::parse(&normalized) {
                self.denied_addresses.push(block);
            }
        }
    }

    /// Normalizes and parses every entry, returning the first error without side effects.
    fn parse_address_list(addresses: &[&str]) -> Result<Vec<CIDRBlock>, AntiSSRFError> {
        let mut blocks = Vec::with_capacity(addresses.len());
        for addr in addresses {
            blocks.push(CIDRBlock::parse(&Self::normalize_address(addr))?);
        }
        Ok(blocks)
    }

    /// Trims and lower-cases every header name, rejecting empty ones without side effects.
    fn normalize_header_list(headers: &[&str]) -> Result<Vec<String>, AntiSSRFError> {
        headers
            .iter()
            .map(|header| {
                let name = header.trim().to_ascii_lowercase();
                if name.is_empty() {
                    Err(AntiSSRFError::InvalidHeader)
                } else {
                    Ok(name)
                }
            })
            .collect()
    }

    /// Trims `addr` and, when it has no prefix length, appends `/128` if it contains a `:`
    /// (IPv6) or `/32` otherwise.
    fn normalize_address(addr: &str) -> String {
        let trimmed = addr.trim();
        if trimmed.contains('/') {
            trimmed.to_string()
        } else if trimmed.contains(':') {
            format!("{}/128", trimmed)
        } else {
            format!("{}/32", trimmed)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_none_has_empty_lists() {
        let p = AntiSSRFPolicy::new(PolicyConfigOptions::None);
        assert!(p.allowed_addresses.is_empty());
        assert!(p.denied_addresses.is_empty());
        assert!(p.required_headers.is_empty());
        assert!(p.denied_headers.is_empty());
        assert!(!p.deny_all_unspecified_ips);
        assert!(!p.add_xff_header);
        assert!(!p.allow_plaintext_http);
        assert!(!p.is_locked());
    }

    #[test]
    fn new_internal_only_sets_deny_all() {
        let p = AntiSSRFPolicy::new(PolicyConfigOptions::InternalOnly);
        assert!(p.deny_all_unspecified_ips);
    }

    #[test]
    fn new_external_v1_populates_denylist() {
        let p = AntiSSRFPolicy::new(PolicyConfigOptions::ExternalOnlyV1);
        assert!(!p.denied_addresses.is_empty());
        assert!(p.add_xff_header);
    }

    #[test]
    fn new_external_latest_populates_denylist() {
        let p = AntiSSRFPolicy::new(PolicyConfigOptions::ExternalOnlyLatest);
        assert!(!p.denied_addresses.is_empty());
        assert!(p.add_xff_header);
    }

    #[test]
    fn recommended_v1_ranges_all_parse() {
        // Built-in entries that fail to parse are silently skipped by the constructor,
        // so guard the table here: every entry must become a denylist block.
        let p = AntiSSRFPolicy::new(PolicyConfigOptions::ExternalOnlyV1);
        assert_eq!(
            p.denied_addresses().len(),
            ip_address_ranges::RECOMMENDEDV1.len()
        );
    }

    #[test]
    fn add_allowed_addresses_works() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::InternalOnly);
        p.add_allowed_addresses(&["10.0.0.0/8"]).unwrap();
        assert_eq!(p.allowed_addresses.len(), 1);
        assert!(p.is_network_connection_allowed(&["10.0.0.0"]).unwrap());
        assert!(p.is_network_connection_allowed(&["10.0.0.1"]).unwrap());
        assert!(p.is_network_connection_allowed(&["10.1.2.3"]).unwrap());
        assert!(p.is_network_connection_allowed(&["10.255.255.255"]).unwrap());
        assert!(!p.is_network_connection_allowed(&["11.0.0.0"]).unwrap());
    }

    #[test]
    fn add_allowed_addresses_normalizes_single_ip() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::InternalOnly);
        p.add_allowed_addresses(&["10.0.0.1"]).unwrap();
        assert_eq!(p.allowed_addresses.len(), 1);
        assert!(p.is_network_connection_allowed(&["10.0.0.1"]).unwrap());
        assert!(!p.is_network_connection_allowed(&["10.0.0.2"]).unwrap());
    }

    #[test]
    fn add_allowed_addresses_rejects_invalid() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::None);
        let result = p.add_allowed_addresses(&["not-an-ip"]);
        assert!(matches!(result, Err(AntiSSRFError::InvalidCIDR(_))));
    }

    #[test]
    fn add_denied_addresses_works() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::None);
        p.add_denied_addresses(&["169.254.169.254/32"]).unwrap();
        assert!(p.is_network_connection_allowed(&["169.254.169.253"]).unwrap());
        assert!(!p.is_network_connection_allowed(&["169.254.169.254"]).unwrap());
        assert_eq!(p.denied_addresses.len(), 1);
    }

    #[test]
    fn add_denied_addresses_fails_when_deny_all_set() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::InternalOnly);
        let result = p.add_denied_addresses(&["10.0.0.0/8"]);
        assert!(matches!(result, Err(AntiSSRFError::ConflictingConfiguration)));
    }

    #[test]
    fn is_network_connection_allowed_basic() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::None);
        assert!(p.is_network_connection_allowed(&["8.8.8.8"]).unwrap());
    }

    #[test]
    fn is_network_connection_allowed_empty_list() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::InternalOnly);
        assert!(p.is_network_connection_allowed(&[]).unwrap());
        assert!(p.is_locked());
    }

    #[test]
    fn is_network_connection_allowed_blocks_denylist() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::None);
        p.add_denied_addresses(&["169.254.169.254/32"]).unwrap();
        assert!(!p.is_network_connection_allowed(&["169.254.169.254"]).unwrap());
    }

    #[test]
    fn is_network_connection_allowed_allows_allowlist() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::None);
        p.add_denied_addresses(&["10.0.0.0/8"]).unwrap();
        p.add_allowed_addresses(&["10.0.0.1/32"]).unwrap();
        assert!(p.is_network_connection_allowed(&["10.0.0.1"]).unwrap());
        assert!(!p.is_network_connection_allowed(&["10.0.0.2"]).unwrap());
    }

    #[test]
    fn is_network_connection_allowed_deny_all_unspecified() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::InternalOnly);
        assert!(!p.is_network_connection_allowed(&["8.8.8.8"]).unwrap());
    }

    #[test]
    fn is_network_connection_allowed_deny_all_with_allowlist() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::InternalOnly);
        p.add_allowed_addresses(&["8.8.8.8/32"]).unwrap();
        assert!(p.is_network_connection_allowed(&["8.8.8.8"]).unwrap());
        assert!(!p.is_network_connection_allowed(&["8.8.8.9"]).unwrap());
    }

    #[test]
    fn is_network_connection_allowed_mixed_allowlist_and_denylist() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::None);
        p.add_denied_addresses(&["10.0.0.0/8"]).unwrap();
        p.add_allowed_addresses(&["127.0.0.1/32"]).unwrap();
        assert!(!p
            .is_network_connection_allowed(&["127.0.0.1", "10.0.0.1"])
            .unwrap());
    }

    #[test]
    fn is_network_connection_allowed_invalid_ip() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::None);
        let result = p.is_network_connection_allowed(&["not-an-ip"]);
        assert!(matches!(result, Err(AntiSSRFError::InvalidIP(_))));
    }

    #[test]
    fn is_network_connection_allowed_locks_policy() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::None);
        p.is_network_connection_allowed(&["8.8.8.8"]).unwrap();
        assert!(matches!(
            p.add_allowed_addresses(&["10.0.0.0/8"]),
            Err(AntiSSRFError::PolicyLocked)
        ));
    }

    #[test]
    fn ipv6_single_ip_normalized() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::None);
        p.add_allowed_addresses(&["::1"]).unwrap();
        assert_eq!(p.allowed_addresses.len(), 1);
        assert_eq!(p.allowed_addresses[0].to_string(), "::1/128");
    }

    #[test]
    fn external_v1_blocks_imds() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::ExternalOnlyV1);
        assert!(!p.is_network_connection_allowed(&["169.254.169.254"]).unwrap());
        assert!(!p.is_network_connection_allowed(&["168.63.129.16"]).unwrap());
    }

    #[test]
    fn external_latest_blocks_imds_and_wireserver() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::ExternalOnlyLatest);
        assert!(!p.is_network_connection_allowed(&["169.254.169.254"]).unwrap());
        assert!(!p.is_network_connection_allowed(&["168.63.129.16"]).unwrap());
    }

    #[test]
    fn external_v1_allows_external_ips() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::ExternalOnlyV1);
        assert!(p.is_network_connection_allowed(&["8.8.8.8"]).unwrap());
        assert!(p.is_network_connection_allowed(&["1.1.1.1"]).unwrap());
    }

    #[test]
    fn external_latest_allows_external_ips() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::ExternalOnlyLatest);
        assert!(p.is_network_connection_allowed(&["8.8.8.8"]).unwrap());
        assert!(p.is_network_connection_allowed(&["1.1.1.1"]).unwrap());
    }

    #[test]
    fn none_with_deny_all_blocks_unless_allowlisted() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::None);
        p.set_deny_all_unspecified_ips(true).unwrap();
        p.add_allowed_addresses(&["8.8.8.8/32"]).unwrap();
        assert!(p.is_network_connection_allowed(&["8.8.8.8"]).unwrap());
        assert!(!p.is_network_connection_allowed(&["1.1.1.1"]).unwrap());
    }

    #[test]
    fn add_required_headers_works() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::None);
        p.add_required_headers(&["Authorization", "X-Custom"]).unwrap();
        assert_eq!(p.required_headers, vec!["authorization", "x-custom"]);
    }

    #[test]
    fn add_required_headers_rejects_empty() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::None);
        let result = p.add_required_headers(&[""]);
        assert!(matches!(result, Err(AntiSSRFError::InvalidHeader)));
    }

    #[test]
    fn add_denied_headers_works() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::None);
        p.add_denied_headers(&["X-Secret"]).unwrap();
        assert_eq!(p.denied_headers, vec!["x-secret"]);
    }

    #[test]
    fn validate_request_https_always_allowed() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::None);
        assert!(p.validate_request("https:", &mut vec![]).unwrap());
    }

    #[test]
    fn validate_request_http_denied_by_default() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::None);
        let result = p.validate_request("http:", &mut vec![]);
        assert!(matches!(result, Err(AntiSSRFError::SchemeDisallowed)));
    }

    #[test]
    fn validate_request_http_allowed_when_configured() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::None);
        p.set_allow_plaintext_http(true).unwrap();
        assert!(p.validate_request("http:", &mut vec![]).unwrap());
    }

    #[test]
    fn validate_request_unknown_protocol_denied() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::None);
        let result = p.validate_request("ftp:", &mut vec![]);
        assert!(matches!(result, Err(AntiSSRFError::SchemeDisallowed)));
    }

    #[test]
    fn validate_request_unknown_protocol_denied_even_with_plaintext_allowed() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::None);
        p.set_allow_plaintext_http(true).unwrap();
        let result = p.validate_request("ftp:", &mut vec![]);
        assert!(matches!(result, Err(AntiSSRFError::SchemeDisallowed)));
    }

    #[test]
    fn validate_request_denied_header_found() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::None);
        p.add_denied_headers(&["x-secret"]).unwrap();
        let mut headers = vec![("X-Secret".to_string(), "value".to_string())];
        let result = p.validate_request("https:", &mut headers);
        assert!(matches!(result, Err(AntiSSRFError::HeaderDenied)));
    }

    #[test]
    fn validate_request_required_header_missing() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::None);
        p.add_required_headers(&["authorization"]).unwrap();
        let result = p.validate_request("https:", &mut vec![]);
        assert!(matches!(result, Err(AntiSSRFError::HeaderRequired)));
    }

    #[test]
    fn validate_request_required_header_present() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::None);
        p.add_required_headers(&["authorization"]).unwrap();
        let mut headers = vec![("Authorization".to_string(), "Bearer token".to_string())];
        assert!(p.validate_request("https:", &mut headers).unwrap());
    }

    #[test]
    fn validate_request_locks_policy() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::None);
        p.validate_request("https:", &mut vec![]).unwrap();
        assert!(matches!(
            p.add_required_headers(&["X-Test"]),
            Err(AntiSSRFError::PolicyLocked)
        ));
    }

    #[test]
    fn is_locked_returns_true_after_validate_request() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::None);
        assert!(!p.is_locked());
        p.validate_request("https:", &mut vec![]).unwrap();
        assert!(p.is_locked());
    }

    #[test]
    fn set_deny_all_unspecified_ips_works() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::None);
        p.set_deny_all_unspecified_ips(true).unwrap();
        assert!(p.deny_all_unspecified_ips());
    }

    #[test]
    fn set_add_xff_header_works() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::None);
        p.set_add_xff_header(true).unwrap();
        assert!(p.add_xff_header());
    }

    #[test]
    fn set_allow_plaintext_http_works() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::None);
        p.set_allow_plaintext_http(true).unwrap();
        assert!(p.allow_plaintext_http());
    }

    #[test]
    fn set_deny_all_unspecified_ips_locked() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::InternalOnly);
        p.is_network_connection_allowed(&["8.8.8.8"]).unwrap();
        assert!(matches!(
            p.set_deny_all_unspecified_ips(false),
            Err(AntiSSRFError::PolicyLocked)
        ));
    }

    #[test]
    fn needs_xff_header_when_enabled_and_missing() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::None);
        p.set_add_xff_header(true).unwrap();
        assert!(p.needs_xff_header(&[]));
    }

    #[test]
    fn needs_xff_header_when_already_present() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::None);
        p.set_add_xff_header(true).unwrap();
        let headers = vec![("X-Forwarded-For".to_string(), "1.2.3.4".to_string())];
        assert!(!p.needs_xff_header(&headers));
    }

    #[test]
    fn needs_xff_header_when_disabled() {
        let p = AntiSSRFPolicy::new(PolicyConfigOptions::None);
        assert!(!p.needs_xff_header(&[]));
    }

    #[test]
    fn needs_xff_header_case_insensitive() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::None);
        p.set_add_xff_header(true).unwrap();
        let headers: Vec<(String, String)> =
            vec![("x-forwarded-for".to_string(), "1.2.3.4".to_string())];
        assert!(!p.needs_xff_header(&headers));
    }

    #[test]
    fn allowed_addresses_getter() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::None);
        p.add_allowed_addresses(&["10.0.0.0/8", "192.168.1.1"]).unwrap();
        let addrs = p.allowed_addresses();
        assert_eq!(addrs.len(), 2);
        assert_eq!(addrs[0].to_string(), "10.0.0.0/8");
        assert_eq!(addrs[1].to_string(), "192.168.1.1/32");
    }

    #[test]
    fn denied_addresses_getter() {
        let p = AntiSSRFPolicy::new(PolicyConfigOptions::ExternalOnlyV1);
        let addrs = p.denied_addresses();
        assert!(!addrs.is_empty());
        assert!(addrs.iter().any(|a| a.to_string() == "169.254.0.0/16"));
    }

    #[test]
    fn required_headers_getter() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::None);
        p.add_required_headers(&["Authorization", "X-Custom"]).unwrap();
        let headers = p.required_headers();
        assert_eq!(headers, vec!["authorization", "x-custom"]);
    }

    #[test]
    fn denied_headers_getter() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::None);
        p.add_denied_headers(&["X-Secret"]).unwrap();
        let headers = p.denied_headers();
        assert_eq!(headers, vec!["x-secret"]);
    }

    #[test]
    fn validate_request_injects_xff_when_missing() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::None);
        p.set_add_xff_header(true).unwrap();
        let mut headers = vec![];
        let result = p.validate_request("https:", &mut headers).unwrap();
        assert!(result);
        assert!(headers
            .iter()
            .any(|(k, _)| k.eq_ignore_ascii_case("x-forwarded-for")));
        assert_eq!(headers.len(), 1);
    }

    #[test]
    fn validate_request_does_not_inject_xff_when_present() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::None);
        p.set_add_xff_header(true).unwrap();
        let mut headers = vec![("X-Forwarded-For".to_string(), "1.2.3.4".to_string())];
        let result = p.validate_request("https:", &mut headers).unwrap();
        assert!(result);
        assert_eq!(headers.len(), 1);
    }

    #[test]
    fn validate_request_does_not_inject_xff_when_disabled() {
        let mut p = AntiSSRFPolicy::new(PolicyConfigOptions::None);
        p.set_add_xff_header(false).unwrap();
        let mut headers = vec![];
        let result = p.validate_request("https:", &mut headers).unwrap();
        assert!(result);
        assert!(headers.is_empty());
    }
}