#[cfg(feature = "client")]
use std::net::SocketAddr;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use crate::error::AcdpError;
#[cfg(feature = "client")]
use std::sync::Arc;
pub use crate::limits::{MAX_CONTEXT_BYTES, MAX_METADATA_BYTES, MAX_REDIRECTS};
/// Stable, machine-readable category explaining why the SSRF policy rejected a
/// URL, a redirect, or an IP address. The snake_case label is available via
/// [`SsrfReason::as_str`] and via `Display`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum SsrfReason {
    /// The URL scheme is not permitted (only `https` unless `allow_http` is set).
    NonHttps,
    /// The URL host is an IPv4/IPv6 literal while `reject_ip_literals` is set.
    IpLiteral,
    /// A URL (original or redirect target) failed to parse, has no host, or has
    /// a hostname of invalid length.
    InvalidUrl,
    /// Loopback address (e.g. `127.0.0.0/8`, `::1`).
    Loopback,
    /// Private-use / shared address space (e.g. `10.0.0.0/8`, `172.16.0.0/12`,
    /// `192.168.0.0/16`, `100.64.0.0/10`, `fc00::/7`).
    Private,
    /// Link-local ranges where instance-metadata services live (e.g.
    /// `169.254.0.0/16`, `fe80::/10`); also used for the NAT64 prefix
    /// `64:ff9b::/32`.
    Imds,
    /// Unspecified, multicast, or otherwise reserved address (e.g. `0.0.0.0/8`,
    /// `224.0.0.0/4`, `240.0.0.0/4`, IPv6 multicast).
    MulticastOrReserved,
    /// A redirect changes scheme, host, or effective port.
    CrossAuthority,
}
impl SsrfReason {
pub fn as_str(&self) -> &'static str {
match self {
SsrfReason::NonHttps => "non_https",
SsrfReason::IpLiteral => "ip_literal",
SsrfReason::InvalidUrl => "invalid_url",
SsrfReason::Loopback => "loopback",
SsrfReason::Private => "private",
SsrfReason::Imds => "imds",
SsrfReason::MulticastOrReserved => "multicast_or_reserved",
SsrfReason::CrossAuthority => "cross_authority",
}
}
}
impl std::fmt::Display for SsrfReason {
    /// Writes the same label as [`SsrfReason::as_str`]; width/fill flags from
    /// the outer format spec are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.as_str())
    }
}
/// A policy rejection: a stable [`SsrfReason`] plus a human-readable message.
#[derive(Debug, Clone)]
pub struct SsrfRejection {
    /// Machine-readable category of the rejection.
    pub reason: SsrfReason,
    /// Human-readable description. This is the text carried into `AcdpError`
    /// by the `From<SsrfRejection>` conversion.
    pub detail: String,
}
impl std::fmt::Display for SsrfRejection {
    /// Renders as `<detail> [<reason label>]`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { reason, detail } = self;
        write!(f, "{detail} [{reason}]")
    }
}
impl From<SsrfRejection> for AcdpError {
fn from(r: SsrfRejection) -> Self {
AcdpError::SchemaViolation(r.detail)
}
}
/// Configuration for the SSRF checks.
///
/// `Default` is the strict setting: IP-literal hosts rejected, HTTPS only,
/// loopback addresses rejected.
#[derive(Debug, Clone)]
pub struct SsrfPolicy {
    /// Reject URLs whose host is an IPv4/IPv6 literal (default `true`).
    pub reject_ip_literals: bool,
    /// Relax the HTTPS-only scheme check (default `false`).
    pub allow_http: bool,
    /// Accept addresses for which `is_loopback()` is true when classifying IPs
    /// (default `false`; `allow_test_loopback` turns it on).
    pub allow_loopback_resolved: bool,
}
impl Default for SsrfPolicy {
fn default() -> Self {
Self {
reject_ip_literals: true,
allow_http: false,
allow_loopback_resolved: false,
}
}
}
impl SsrfPolicy {
    /// The default policy with loopback addresses accepted during IP
    /// classification; every other setting is unchanged. Hidden from docs
    /// because it is meant for tests.
    #[doc(hidden)]
    pub fn allow_test_loopback() -> Self {
        let mut policy = Self::default();
        policy.allow_loopback_resolved = true;
        policy
    }
}
impl SsrfPolicy {
    /// Validates `url` against this policy.
    ///
    /// Thin wrapper over [`SsrfPolicy::classify_url`] for callers that only
    /// need an [`AcdpError`].
    ///
    /// # Errors
    ///
    /// Returns the `AcdpError` produced from the rejection (via
    /// `From<SsrfRejection>`, i.e. `AcdpError::SchemaViolation`).
    pub fn check_url(&self, url: &str) -> Result<(), AcdpError> {
        self.classify_url(url).map_err(AcdpError::from)
    }

    /// Statically classifies `url`; no DNS lookup is performed.
    ///
    /// Checks, in order:
    /// 1. the URL parses (`InvalidUrl`);
    /// 2. the scheme is `https`, or `http` when `allow_http` is set (`NonHttps`);
    /// 3. the URL has a host (`InvalidUrl`);
    /// 4. an IP-literal host is rejected outright when `reject_ip_literals` is
    ///    set (`IpLiteral`); otherwise it must pass [`SsrfPolicy::classify_ip`];
    /// 5. a domain host is between 1 and 253 bytes long (`InvalidUrl`).
    ///
    /// Domain names are accepted here without resolution; the addresses they
    /// resolve to must be checked separately (e.g. `pin_resolved_ip` or the
    /// client returned by `safe_client`).
    ///
    /// # Errors
    ///
    /// Returns an [`SsrfRejection`] whose `reason` names the failed check.
    pub fn classify_url(&self, url: &str) -> Result<(), SsrfRejection> {
        let parsed = url::Url::parse(url).map_err(|e| SsrfRejection {
            reason: SsrfReason::InvalidUrl,
            detail: format!("invalid URL: {e}"),
        })?;
        // `allow_http` only adds plain `http` to the allow-list. Every other scheme
        // (`ftp:`, `gopher:`, `file:`, ...) stays forbidden regardless of the flag.
        let scheme = parsed.scheme();
        let scheme_permitted = scheme == "https" || (self.allow_http && scheme == "http");
        if !scheme_permitted {
            let permitted = if self.allow_http { "http or https" } else { "https" };
            return Err(SsrfRejection {
                reason: SsrfReason::NonHttps,
                detail: format!("SSRF policy: scheme '{scheme}' not permitted; only {permitted}"),
            });
        }
        let host = parsed.host().ok_or_else(|| SsrfRejection {
            reason: SsrfReason::InvalidUrl,
            detail: format!("URL has no host: {url}"),
        })?;
        match host {
            url::Host::Ipv4(v4) => {
                if self.reject_ip_literals {
                    return Err(SsrfRejection {
                        reason: SsrfReason::IpLiteral,
                        detail: format!(
                            "SSRF policy: IPv4 literal '{v4}' not permitted; use a hostname"
                        ),
                    });
                }
                self.classify_ip(IpAddr::V4(v4))?;
            }
            url::Host::Ipv6(v6) => {
                if self.reject_ip_literals {
                    return Err(SsrfRejection {
                        reason: SsrfReason::IpLiteral,
                        detail: format!(
                            "SSRF policy: IPv6 literal '{v6}' not permitted; use a hostname"
                        ),
                    });
                }
                self.classify_ip(IpAddr::V6(v6))?;
            }
            url::Host::Domain(name) => {
                // 253 bytes is the maximum length of a DNS name in text form.
                if name.is_empty() || name.len() > 253 {
                    return Err(SsrfRejection {
                        reason: SsrfReason::InvalidUrl,
                        detail: format!("SSRF policy: invalid hostname length: {name}"),
                    });
                }
            }
        }
        Ok(())
    }

    /// Same as [`SsrfPolicy::check_ip`]; named for use on addresses obtained
    /// from DNS resolution.
    ///
    /// # Errors
    ///
    /// See [`SsrfPolicy::check_ip`].
    pub fn check_resolved_ip(&self, ip: IpAddr) -> Result<(), AcdpError> {
        self.check_ip(ip)
    }

    /// Validates a single IP address, mapping a rejection into an [`AcdpError`].
    ///
    /// # Errors
    ///
    /// Returns the `AcdpError` produced from the [`SsrfRejection`] reported by
    /// [`SsrfPolicy::classify_ip`].
    pub fn check_ip(&self, ip: IpAddr) -> Result<(), AcdpError> {
        self.classify_ip(ip).map_err(AcdpError::from)
    }

    /// Classifies a single IP address against the forbidden ranges.
    ///
    /// With `allow_loopback_resolved`, addresses whose `is_loopback()` is true
    /// (`127.0.0.0/8`, `::1`) are accepted and everything else is still checked.
    /// An IPv4-mapped loopback such as `::ffff:127.0.0.1` is not `::1`, so it
    /// remains rejected.
    ///
    /// # Errors
    ///
    /// Returns an [`SsrfRejection`] whose `reason` names the forbidden range.
    pub fn classify_ip(&self, ip: IpAddr) -> Result<(), SsrfRejection> {
        let reason = match ip {
            IpAddr::V4(v4) => {
                if self.allow_loopback_resolved && v4.is_loopback() {
                    None
                } else {
                    classify_unsafe_v4(v4)
                }
            }
            IpAddr::V6(v6) => {
                if self.allow_loopback_resolved && v6.is_loopback() {
                    None
                } else {
                    classify_unsafe_v6(v6)
                }
            }
        };
        match reason {
            Some(reason) => Err(SsrfRejection {
                reason,
                detail: format!("SSRF policy: IP address '{ip}' is in a forbidden range"),
            }),
            None => Ok(()),
        }
    }

    /// Resolves `host:port` and returns one address to connect to, provided
    /// *every* address in the DNS answer passes [`SsrfPolicy::classify_ip`].
    /// A single forbidden address rejects the whole answer.
    ///
    /// The first IPv4 address is preferred; otherwise the first address returned
    /// by the resolver is used.
    ///
    /// # Errors
    ///
    /// `AcdpError::Http` if the lookup fails or yields no addresses;
    /// `AcdpError::SchemaViolation` if any resolved address is forbidden.
    #[cfg(feature = "client")]
    pub async fn pin_resolved_ip(&self, host: &str, port: u16) -> Result<SocketAddr, AcdpError> {
        let target = format!("{host}:{port}");
        let candidates: Vec<SocketAddr> = tokio::net::lookup_host(&target)
            .await
            .map_err(|e| AcdpError::Http(format!("DNS lookup for '{host}' failed: {e}")))?
            .collect();
        if candidates.is_empty() {
            return Err(AcdpError::Http(format!(
                "DNS lookup for '{host}' returned no addresses"
            )));
        }
        reject_if_any_forbidden(self, host, &candidates)?;
        let pinned = candidates
            .iter()
            .find(|a| a.is_ipv4())
            .or_else(|| candidates.first())
            .copied()
            .expect("candidates is non-empty");
        Ok(pinned)
    }

    /// Ensures that `redirect_url` stays on the same scheme, host, and effective
    /// port as `original_url`.
    ///
    /// # Errors
    ///
    /// Returns the `AcdpError` produced from the [`SsrfRejection`] reported by
    /// [`SsrfPolicy::classify_redirect_authority`].
    pub fn check_redirect_authority(
        &self,
        original_url: &url::Url,
        redirect_url: &str,
    ) -> Result<(), AcdpError> {
        self.classify_redirect_authority(original_url, redirect_url)
            .map_err(AcdpError::from)
    }

    /// Classifies a redirect from `original_url` to `redirect_url`.
    ///
    /// `redirect_url` is resolved against `original_url`, the same way an HTTP
    /// client resolves a `Location` header. Relative targets such as `/next` are
    /// therefore judged by the authority they actually point at, absolute targets
    /// are taken as-is, and scheme-relative ones (`//other.host/`) pick up the new
    /// host and are rejected.
    ///
    /// # Errors
    ///
    /// `InvalidUrl` if the target cannot be resolved; `CrossAuthority` if the
    /// scheme, host, or effective port differs from the original.
    pub fn classify_redirect_authority(
        &self,
        original_url: &url::Url,
        redirect_url: &str,
    ) -> Result<(), SsrfRejection> {
        let redirect = original_url.join(redirect_url).map_err(|e| SsrfRejection {
            reason: SsrfReason::InvalidUrl,
            detail: format!("invalid redirect URL: {e}"),
        })?;
        if !same_fetch_authority(original_url, &redirect) {
            return Err(SsrfRejection {
                reason: SsrfReason::CrossAuthority,
                detail: format!(
                    "SSRF policy: cross-authority redirect rejected: {original_url} → {redirect}"
                ),
            });
        }
        Ok(())
    }

    /// String-based convenience over
    /// [`SsrfPolicy::classify_redirect_authority`].
    ///
    /// # Errors
    ///
    /// `InvalidUrl` if `from_url` does not parse; otherwise as
    /// [`SsrfPolicy::classify_redirect_authority`].
    pub fn classify_redirect(&self, from_url: &str, to_url: &str) -> Result<(), SsrfRejection> {
        let original = url::Url::parse(from_url).map_err(|e| SsrfRejection {
            reason: SsrfReason::InvalidUrl,
            detail: format!("invalid origin URL: {e}"),
        })?;
        self.classify_redirect_authority(&original, to_url)
    }
}
/// Reports whether `a` and `b` share scheme, host, and effective port (an
/// explicit port or the scheme's default), so that moving from one to the
/// other does not change which server is contacted.
pub(crate) fn same_fetch_authority(a: &url::Url, b: &url::Url) -> bool {
    // Pure getters, so comparing the whole tuple is equivalent to the
    // short-circuiting field-by-field comparison.
    fn authority(u: &url::Url) -> (&str, Option<&str>, Option<u16>) {
        (u.scheme(), u.host_str(), u.port_or_known_default())
    }
    authority(a) == authority(b)
}
/// Test-only helper: checks `ip` against the built-in forbidden ranges with no
/// policy relaxations (loopback is always rejected).
#[cfg(test)]
fn check_safe_ip(ip: IpAddr) -> Result<(), AcdpError> {
    let verdict = match ip {
        IpAddr::V4(v4) => classify_unsafe_v4(v4),
        IpAddr::V6(v6) => classify_unsafe_v6(v6),
    };
    match verdict {
        Some(_) => Err(AcdpError::SchemaViolation(format!(
            "SSRF policy: IP address '{ip}' is in a forbidden range"
        ))),
        None => Ok(()),
    }
}
/// Fails if *any* address in a DNS answer is forbidden by `policy`.
///
/// A mixed answer (public + private) is rejected as a whole rather than
/// filtered, so an attacker-controlled name cannot smuggle an internal
/// address alongside a public one.
///
/// # Errors
///
/// `AcdpError::SchemaViolation` naming `host` and the first offending address.
#[cfg(feature = "client")]
fn reject_if_any_forbidden(
    policy: &SsrfPolicy,
    host: &str,
    candidates: &[SocketAddr],
) -> Result<(), AcdpError> {
    candidates.iter().try_for_each(|candidate| {
        let ip = candidate.ip();
        policy.check_ip(ip).map_err(|e| {
            AcdpError::SchemaViolation(format!(
                "SSRF policy: DNS answer for '{host}' contains a forbidden address \
                 ({ip} is disallowed); rejecting the entire resolution. {e}"
            ))
        })
    })
}
/// `reqwest` DNS resolver that checks every address a lookup returns against an
/// [`SsrfPolicy`] and fails the lookup if any of them is forbidden.
#[cfg(feature = "client")]
pub(crate) struct SafeDnsResolver {
    /// Policy applied to each resolved address.
    policy: SsrfPolicy,
}
#[cfg(feature = "client")]
impl SafeDnsResolver {
    /// Builds a resolver enforcing `policy`, already wrapped in the `Arc` that
    /// `reqwest::ClientBuilder::dns_resolver` takes.
    pub(crate) fn arc(policy: SsrfPolicy) -> Arc<Self> {
        let resolver = SafeDnsResolver { policy };
        Arc::new(resolver)
    }
}
/// Builds a `reqwest::Client` hardened for fetching untrusted URLs.
///
/// The client:
/// - uses rustls TLS;
/// - has a 5 s connect timeout plus the caller's overall `timeout`;
/// - never follows redirects automatically;
/// - keeps no idle connections;
/// - resolves hostnames through a resolver that rejects any DNS answer
///   containing an address forbidden by `policy`.
///
/// NOTE(review): URLs whose host is an IP literal may be connected to directly
/// without invoking the custom resolver. Validate such URLs with
/// [`SsrfPolicy::check_url`] before issuing the request, and confirm this
/// against the reqwest/hyper version in use.
///
/// # Errors
///
/// `AcdpError::Http` if the client cannot be built.
#[cfg(feature = "client")]
pub fn safe_client(
    policy: &SsrfPolicy,
    timeout: std::time::Duration,
) -> Result<reqwest::Client, AcdpError> {
    // Fixed connect-phase bound, independent of the caller's overall timeout.
    let connect_timeout = std::time::Duration::from_secs(5);
    let resolver = SafeDnsResolver::arc(policy.clone());
    let builder = reqwest::Client::builder()
        .use_rustls_tls()
        .connect_timeout(connect_timeout)
        .timeout(timeout)
        // Redirects are surfaced to the caller; each hop should be validated,
        // e.g. with `SsrfPolicy::check_redirect_authority`.
        .redirect(reqwest::redirect::Policy::none())
        // No idle pooling, so a connection is not reused across requests.
        .pool_max_idle_per_host(0)
        .dns_resolver(resolver);
    builder.build().map_err(|e| AcdpError::Http(e.to_string()))
}
#[cfg(feature = "client")]
impl reqwest::dns::Resolve for SafeDnsResolver {
    /// Resolves `name` via `tokio::net::lookup_host` and hands the addresses to
    /// reqwest only if none of them is forbidden by the policy. A single bad
    /// address fails the whole lookup.
    fn resolve(&self, name: reqwest::dns::Name) -> reqwest::dns::Resolving {
        type BoxError = Box<dyn std::error::Error + Send + Sync>;
        // Owned copies so the returned future does not borrow `self`.
        let policy = self.policy.clone();
        let host = name.as_str().to_string();
        let lookup = async move {
            // `lookup_host` needs a `host:port` string; port 0 is a placeholder
            // (NOTE(review): presumably the connector applies the URL's port).
            let target = format!("{host}:0");
            let candidates: Vec<SocketAddr> = match tokio::net::lookup_host(&target).await {
                Ok(found) => found.collect(),
                Err(io_err) => return Err(BoxError::from(io_err)),
            };
            if candidates.is_empty() {
                return Err(BoxError::from(format!(
                    "DNS lookup for '{host}' returned no addresses"
                )));
            }
            reject_if_any_forbidden(&policy, &host, &candidates)
                .map_err(|e| BoxError::from(e.to_string()))?;
            let addrs: reqwest::dns::Addrs = Box::new(candidates.into_iter());
            Ok::<reqwest::dns::Addrs, BoxError>(addrs)
        };
        Box::pin(lookup)
    }
}
/// Returns why an IPv4 address must not be fetched, or `None` if it is
/// acceptable.
///
/// Documentation ranges (`192.0.2.0/24`, `198.51.100.0/24`, `203.0.113.0/24`)
/// are deliberately not listed and are therefore accepted.
fn classify_unsafe_v4(ip: Ipv4Addr) -> Option<SsrfReason> {
    match ip.octets() {
        // 0.0.0.0/8 "this network"
        [0, ..] => Some(SsrfReason::MulticastOrReserved),
        // 10.0.0.0/8
        [10, ..] => Some(SsrfReason::Private),
        // 100.64.0.0/10 carrier-grade NAT shared space
        [100, second, ..] if (second & 0xc0) == 64 => Some(SsrfReason::Private),
        // 127.0.0.0/8
        [127, ..] => Some(SsrfReason::Loopback),
        // 169.254.0.0/16 link-local (cloud metadata endpoints live here)
        [169, 254, ..] => Some(SsrfReason::Imds),
        // 172.16.0.0/12
        [172, second, ..] if (second & 0xf0) == 16 => Some(SsrfReason::Private),
        // 192.0.0.0/24 IETF protocol assignments
        [192, 0, 0, _] => Some(SsrfReason::MulticastOrReserved),
        // 192.168.0.0/16
        [192, 168, ..] => Some(SsrfReason::Private),
        // 198.18.0.0/15 benchmarking
        [198, 18 | 19, ..] => Some(SsrfReason::MulticastOrReserved),
        // 224.0.0.0/4 multicast
        [224..=239, ..] => Some(SsrfReason::MulticastOrReserved),
        // 240.0.0.0/4 reserved, including 255.255.255.255
        [240..=255, ..] => Some(SsrfReason::MulticastOrReserved),
        _ => None,
    }
}
/// Returns why an IPv6 address must not be fetched, or `None` if it is
/// acceptable.
///
/// Addresses that embed an IPv4 address are judged by [`classify_unsafe_v4`]
/// on the embedded address, so every IPv4 rule also applies to its IPv6
/// spellings. This covers IPv4-compatible `::a.b.c.d`, IPv4-mapped
/// `::ffff:a.b.c.d` and IPv4-translated `::ffff:0:a.b.c.d`.
fn classify_unsafe_v6(ip: Ipv6Addr) -> Option<SsrfReason> {
    if ip.is_loopback() {
        return Some(SsrfReason::Loopback);
    }
    if ip.is_unspecified() || ip.is_multicast() {
        return Some(SsrfReason::MulticastOrReserved);
    }
    let segments = ip.segments();
    let embeds_v4 = segments[0..4] == [0, 0, 0, 0]
        && matches!(
            (segments[4], segments[5]),
            (0, 0) | (0, 0xffff) | (0xffff, 0)
        );
    if embeds_v4 {
        // No exemption for an embedded 0.0.0.0. `::ffff:0.0.0.0` must be
        // rejected like `0.0.0.0` itself, which on many systems reaches the
        // local host. The all-zero `::` was already rejected above.
        let [a, b] = segments[6].to_be_bytes();
        let [c, d] = segments[7].to_be_bytes();
        return classify_unsafe_v4(Ipv4Addr::new(a, b, c, d));
    }
    // 64:ff9b::/32 covers the NAT64 well-known prefix and local-use
    // 64:ff9b:1::/48; rejected wholesale.
    if segments[0] == 0x0064 && segments[1] == 0xff9b {
        return Some(SsrfReason::Imds);
    }
    // fc00::/7 unique local addresses.
    if (segments[0] & 0xfe00) == 0xfc00 {
        return Some(SsrfReason::Private);
    }
    // fec0::/10 deprecated site-local addresses (RFC 3879), still internal-only.
    if (segments[0] & 0xffc0) == 0xfec0 {
        return Some(SsrfReason::Private);
    }
    // fe80::/10 link-local.
    if (segments[0] & 0xffc0) == 0xfe80 {
        return Some(SsrfReason::Imds);
    }
    None
}
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(feature = "client")]
    #[tokio::test]
    async fn safe_client_default_refuses_loopback() {
        // A hostname (not an IP literal) is used on purpose: HTTP stacks commonly
        // connect to literal addresses without consulting a custom resolver, in
        // which case a literal target would only prove that port 9 is closed.
        let client =
            safe_client(&SsrfPolicy::default(), std::time::Duration::from_secs(2)).unwrap();
        let result = client.get("http://localhost:9/").send().await;
        assert!(
            result.is_err(),
            "default policy must refuse a loopback target"
        );
    }

    #[cfg(feature = "client")]
    #[test]
    fn safe_client_builds_with_loopback_policy() {
        assert!(safe_client(
            &SsrfPolicy::allow_test_loopback(),
            std::time::Duration::from_secs(2)
        )
        .is_ok());
    }

    #[test]
    fn https_only_by_default() {
        let p = SsrfPolicy::default();
        assert!(p.check_url("https://registry.example.com").is_ok());
        assert!(p.check_url("http://registry.example.com").is_err());
        assert!(p.check_url("file:///etc/passwd").is_err());
    }

    #[test]
    fn rejects_ip_literals_by_default() {
        let p = SsrfPolicy::default();
        assert!(p.check_url("https://192.168.1.1").is_err());
        assert!(p.check_url("https://[::1]").is_err());
    }

    #[test]
    fn private_v4_ranges_rejected() {
        assert!(check_safe_ip("10.0.0.1".parse().unwrap()).is_err());
        assert!(check_safe_ip("172.16.5.5".parse().unwrap()).is_err());
        assert!(check_safe_ip("192.168.1.1".parse().unwrap()).is_err());
        assert!(check_safe_ip("127.0.0.1".parse().unwrap()).is_err());
        assert!(check_safe_ip("169.254.169.254".parse().unwrap()).is_err());
        assert!(check_safe_ip("239.0.0.1".parse().unwrap()).is_err());
        assert!(check_safe_ip("8.8.8.8".parse().unwrap()).is_ok());
        assert!(check_safe_ip("203.0.113.1".parse().unwrap()).is_ok());
    }

    #[test]
    fn unsafe_v6_rejected() {
        assert!(check_safe_ip("::1".parse().unwrap()).is_err());
        assert!(check_safe_ip("fc00::1".parse().unwrap()).is_err());
        assert!(check_safe_ip("fe80::1".parse().unwrap()).is_err());
        assert!(check_safe_ip("::ffff:10.0.0.1".parse().unwrap()).is_err());
        assert!(check_safe_ip("::127.0.0.1".parse().unwrap()).is_err());
        assert!(check_safe_ip("::7f00:1".parse().unwrap()).is_err());
        assert!(check_safe_ip("::169.254.169.254".parse().unwrap()).is_err());
        assert!(check_safe_ip("64:ff9b::a9fe:a9fe".parse().unwrap()).is_err());
        assert!(check_safe_ip("64:ff9b::169.254.169.254".parse().unwrap()).is_err());
        assert!(check_safe_ip("2001:db8::1".parse().unwrap()).is_ok());
        assert!(check_safe_ip("::93.184.216.34".parse().unwrap()).is_ok());
    }

    #[test]
    fn cross_authority_redirect_rejected() {
        let p = SsrfPolicy::default();
        let orig = url::Url::parse("https://registry.example.com/a").unwrap();
        let err = p
            .check_redirect_authority(&orig, "https://attacker.com/x")
            .unwrap_err();
        assert!(matches!(err, AcdpError::SchemaViolation(_)));
        p.check_redirect_authority(&orig, "https://registry.example.com/y")
            .unwrap();
    }

    fn u(s: &str) -> url::Url {
        url::Url::parse(s).unwrap()
    }

    #[test]
    fn same_host_same_implicit_port_allowed() {
        assert!(same_fetch_authority(
            &u("https://a.example/x"),
            &u("https://a.example/y")
        ));
    }

    #[test]
    fn same_host_explicit_443_same_as_implicit_allowed() {
        assert!(same_fetch_authority(
            &u("https://a.example/x"),
            &u("https://a.example:443/y")
        ));
    }

    #[test]
    fn same_host_different_port_rejected() {
        assert!(!same_fetch_authority(
            &u("https://a.example/x"),
            &u("https://a.example:8443/y")
        ));
    }

    #[test]
    fn https_to_http_same_host_rejected() {
        assert!(!same_fetch_authority(
            &u("https://a.example/x"),
            &u("http://a.example/y")
        ));
    }

    #[test]
    fn different_host_rejected() {
        assert!(!same_fetch_authority(
            &u("https://a.example/x"),
            &u("https://b.example/y")
        ));
    }

    #[test]
    fn host_comparison_ignores_ascii_case() {
        // `url` lowercases domain hosts of special schemes while parsing.
        assert!(same_fetch_authority(
            &u("https://A.Example/x"),
            &u("https://a.example/y")
        ));
    }

    #[test]
    fn check_redirect_authority_rejects_port_change() {
        let p = SsrfPolicy::default();
        let orig = u("https://registry.example.com/a");
        let err = p
            .check_redirect_authority(&orig, "https://registry.example.com:8443/b")
            .unwrap_err();
        assert!(matches!(err, AcdpError::SchemaViolation(_)));
    }

    #[cfg(feature = "client")]
    fn sock(s: &str) -> SocketAddr {
        s.parse().unwrap()
    }

    #[cfg(feature = "client")]
    #[test]
    fn mixed_public_private_dns_rejected_entirely() {
        let p = SsrfPolicy::default();
        let candidates = [sock("203.0.113.10:443"), sock("10.0.0.1:443")];
        assert!(reject_if_any_forbidden(&p, "evil.example", &candidates).is_err());
    }

    #[cfg(feature = "client")]
    #[test]
    fn mixed_public_loopback_rejected() {
        let p = SsrfPolicy::default();
        let candidates = [sock("198.51.100.1:443"), sock("127.0.0.1:443")];
        assert!(reject_if_any_forbidden(&p, "evil.example", &candidates).is_err());
    }

    #[cfg(feature = "client")]
    #[test]
    fn mixed_public_imds_rejected() {
        let p = SsrfPolicy::default();
        let candidates = [sock("198.51.100.1:443"), sock("169.254.169.254:443")];
        assert!(reject_if_any_forbidden(&p, "evil.example", &candidates).is_err());
    }

    #[cfg(feature = "client")]
    #[test]
    fn single_public_ip_allowed() {
        let p = SsrfPolicy::default();
        let candidates = [sock("203.0.113.10:443")];
        assert!(reject_if_any_forbidden(&p, "ok.example", &candidates).is_ok());
    }

    #[cfg(feature = "client")]
    #[test]
    fn all_public_ips_allowed() {
        let p = SsrfPolicy::default();
        let candidates = [sock("203.0.113.10:443"), sock("198.51.100.1:443")];
        assert!(reject_if_any_forbidden(&p, "ok.example", &candidates).is_ok());
    }

    #[test]
    fn allow_http_can_be_opted_into() {
        let p = SsrfPolicy {
            allow_http: true,
            ..SsrfPolicy::default()
        };
        assert!(p.check_url("http://registry.example.com").is_ok());
    }

    fn reason_for_ip(s: &str) -> SsrfReason {
        SsrfPolicy::default()
            .classify_ip(s.parse().unwrap())
            .unwrap_err()
            .reason
    }

    #[test]
    fn classify_ip_maps_stable_reasons() {
        assert_eq!(reason_for_ip("127.0.0.1"), SsrfReason::Loopback);
        assert_eq!(reason_for_ip("10.0.0.1"), SsrfReason::Private);
        assert_eq!(reason_for_ip("172.16.5.5"), SsrfReason::Private);
        assert_eq!(reason_for_ip("192.168.1.1"), SsrfReason::Private);
        assert_eq!(reason_for_ip("100.64.0.1"), SsrfReason::Private);
        assert_eq!(reason_for_ip("169.254.169.254"), SsrfReason::Imds);
        assert_eq!(reason_for_ip("239.0.0.1"), SsrfReason::MulticastOrReserved);
        assert_eq!(reason_for_ip("0.0.0.1"), SsrfReason::MulticastOrReserved);
        assert_eq!(reason_for_ip("240.0.0.1"), SsrfReason::MulticastOrReserved);
        assert_eq!(reason_for_ip("::1"), SsrfReason::Loopback);
        assert_eq!(reason_for_ip("fc00::1"), SsrfReason::Private);
        assert_eq!(reason_for_ip("fe80::1"), SsrfReason::Imds);
        assert_eq!(reason_for_ip("64:ff9b::a9fe:a9fe"), SsrfReason::Imds);
        assert_eq!(reason_for_ip("::ffff:10.0.0.1"), SsrfReason::Private);
        assert!(SsrfPolicy::default()
            .classify_ip("8.8.8.8".parse().unwrap())
            .is_ok());
        assert!(SsrfPolicy::default()
            .classify_ip("2001:db8::1".parse().unwrap())
            .is_ok());
    }

    #[test]
    fn test_loopback_policy_only_relaxes_loopback() {
        let p = SsrfPolicy::allow_test_loopback();
        assert!(p.classify_ip("127.0.0.1".parse().unwrap()).is_ok());
        assert!(p.classify_ip("::1".parse().unwrap()).is_ok());
        assert_eq!(
            p.classify_ip("10.0.0.1".parse().unwrap())
                .unwrap_err()
                .reason,
            SsrfReason::Private
        );
        assert_eq!(
            p.classify_ip("169.254.169.254".parse().unwrap())
                .unwrap_err()
                .reason,
            SsrfReason::Imds
        );
        assert!(p.reject_ip_literals);
        assert!(!p.allow_http);
    }

    #[test]
    fn classify_reason_as_str_is_stable() {
        assert_eq!(SsrfReason::NonHttps.as_str(), "non_https");
        assert_eq!(SsrfReason::IpLiteral.as_str(), "ip_literal");
        assert_eq!(SsrfReason::InvalidUrl.as_str(), "invalid_url");
        assert_eq!(SsrfReason::Loopback.as_str(), "loopback");
        assert_eq!(SsrfReason::Private.as_str(), "private");
        assert_eq!(SsrfReason::Imds.as_str(), "imds");
        assert_eq!(
            SsrfReason::MulticastOrReserved.as_str(),
            "multicast_or_reserved"
        );
        assert_eq!(SsrfReason::CrossAuthority.as_str(), "cross_authority");
    }

    #[test]
    fn rejection_display_appends_reason() {
        let r = SsrfRejection {
            reason: SsrfReason::Private,
            detail: "blocked".to_string(),
        };
        assert_eq!(r.to_string(), "blocked [private]");
    }

    #[test]
    fn classify_url_maps_stable_reasons() {
        let p = SsrfPolicy::default();
        assert_eq!(
            p.classify_url("http://registry.example.com")
                .unwrap_err()
                .reason,
            SsrfReason::NonHttps
        );
        assert_eq!(
            p.classify_url("https://192.168.1.1").unwrap_err().reason,
            SsrfReason::IpLiteral
        );
        assert_eq!(
            p.classify_url("https://[::1]").unwrap_err().reason,
            SsrfReason::IpLiteral
        );
        assert_eq!(
            p.classify_url("not a url").unwrap_err().reason,
            SsrfReason::InvalidUrl
        );
        assert!(p.classify_url("https://registry.example.com").is_ok());
    }

    #[test]
    fn numeric_ipv4_host_spellings_are_ip_literals() {
        // The WHATWG host parser used by `url` normalises decimal, hex and
        // shortened IPv4 spellings, so they cannot slip past the literal check.
        let p = SsrfPolicy::default();
        for url in ["https://2130706433/", "https://0x7f.0.0.1/", "https://127.1/"] {
            assert_eq!(
                p.classify_url(url).unwrap_err().reason,
                SsrfReason::IpLiteral,
                "{url}"
            );
        }
    }

    #[test]
    fn classify_redirect_reasons_and_port_parity() {
        let p = SsrfPolicy::default();
        assert_eq!(
            p.classify_redirect("https://a.example/x", "https://b.example/y")
                .unwrap_err()
                .reason,
            SsrfReason::CrossAuthority
        );
        assert_eq!(
            p.classify_redirect("https://a.example/x", "https://a.example:8443/y")
                .unwrap_err()
                .reason,
            SsrfReason::CrossAuthority
        );
        assert_eq!(
            p.classify_redirect("https://a.example/x", "http://a.example/y")
                .unwrap_err()
                .reason,
            SsrfReason::CrossAuthority
        );
        assert!(p
            .classify_redirect("https://a.example/x", "https://a.example:443/y")
            .is_ok());
        assert!(p
            .classify_redirect("https://a.example/x", "https://a.example/y")
            .is_ok());
        assert_eq!(
            p.classify_redirect("::not-a-url", "https://a.example/y")
                .unwrap_err()
                .reason,
            SsrfReason::InvalidUrl
        );
    }

    #[test]
    fn check_wrappers_preserve_schema_violation() {
        let p = SsrfPolicy::default();
        let err = p.check_url("http://registry.example.com").unwrap_err();
        assert!(matches!(err, AcdpError::SchemaViolation(_)));
        let err = p.check_ip("10.0.0.1".parse().unwrap()).unwrap_err();
        assert!(matches!(err, AcdpError::SchemaViolation(_)));
    }

    #[cfg(feature = "client")]
    #[tokio::test]
    async fn pin_resolved_ip_rejects_loopback_hostname() {
        let p = SsrfPolicy::default();
        let err = p.pin_resolved_ip("localhost", 443).await.unwrap_err();
        assert!(matches!(err, AcdpError::SchemaViolation(_)));
    }
}