use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
/// Operator-selected policy controlling which upstream addresses the proxy may contact.
///
/// Both flags default to `false` (strict mode). In this file the two flags have the
/// same effect: when either one is set, the address checks below are skipped.
#[derive(Debug, Clone, Default)]
pub struct UpstreamPolicy {
/// When `true`, upstream/CONNECT targets are accepted without the forbidden-address
/// check. The error messages in this file refer to the `--allow-private-upstream`
/// option of `wafrift-proxy` as the way to enable it.
pub allow_private_upstream: bool,
/// When `true`, address filtering is likewise skipped by every check in this file.
pub insecure_open_upstream: bool,
}
pub use wafrift_types::ip_addr_is_bogon;
#[must_use]
pub fn proxy_ip_is_forbidden(ip: IpAddr) -> bool {
if ip_addr_is_bogon(ip) {
return true;
}
if let IpAddr::V4(v4) = ip
&& v4.is_multicast()
{
return true;
}
false
}
#[must_use]
pub fn upstream_literal_ip_forbidden(url: &str) -> bool {
let Ok(u) = reqwest::Url::parse(url) else {
return false;
};
let Some(host) = u.host_str() else {
return false;
};
let Ok(ip) = host.parse::<IpAddr>() else {
return false;
};
proxy_ip_is_forbidden(ip)
}
/// Resolves `host` and succeeds only if every returned address is allowed.
///
/// # Errors
///
/// Returns a message when resolution fails, when any resolved address is
/// rejected by `proxy_ip_is_forbidden`, or when resolution yields no addresses.
async fn resolve_host_all_public(host: &str, port: u16) -> Result<(), String> {
    let resolved = match tokio::net::lookup_host((host, port)).await {
        Ok(addrs) => addrs,
        Err(e) => return Err(format!("DNS resolution failed for {host}: {e}")),
    };

    let mut seen = 0usize;
    for candidate in resolved {
        seen += 1;
        let ip = candidate.ip();
        // A single forbidden answer is enough to reject the whole hostname.
        if proxy_ip_is_forbidden(ip) {
            return Err(format!(
                "refusing upstream: DNS for {host} includes non-public address {}",
                ip
            ));
        }
    }

    if seen == 0 {
        Err(format!("refusing upstream: no addresses for {host}"))
    } else {
        Ok(())
    }
}
pub async fn assert_forward_url_allowed(url: &str, policy: &UpstreamPolicy) -> Result<(), String> {
if policy.insecure_open_upstream {
return Ok(());
}
if policy.allow_private_upstream {
return Ok(());
}
if upstream_literal_ip_forbidden(url) {
return Err(format!(
"upstream URL uses a disallowed literal IP (private / loopback / link-local / RFC1918): {url}. \
If you're intentionally targeting localhost or RFC1918 lab infrastructure, \
restart wafrift-proxy with `--allow-private-upstream`."
));
}
let u = reqwest::Url::parse(url).map_err(|e| format!("invalid URL: {e}"))?;
let Some(host) = u.host_str() else {
return Err("upstream URL has no host".to_string());
};
if host.parse::<IpAddr>().is_ok() {
return Ok(());
}
let port = u.port_or_known_default().unwrap_or(80);
resolve_host_all_public(host, port).await?;
Ok(())
}
/// Resolves the host of `url` to the socket addresses the proxy should connect to.
///
/// The port is the URL's explicit port, else the scheme's known default, else 80.
/// A literal-IP host (IPv4, or IPv6 in `[...]` form) yields exactly one address
/// without a DNS lookup; a DNS name yields every address the resolver returns.
///
/// Unless `policy.insecure_open_upstream` or `policy.allow_private_upstream` is
/// set, any address rejected by `proxy_ip_is_forbidden` causes the whole URL to
/// be refused.
///
/// # Errors
///
/// Returns a human-readable message when the URL is invalid, has no host, names
/// or resolves to a forbidden address (strict mode only), fails to resolve, or
/// resolves to no addresses.
pub async fn resolve_forward_url_pinned(
    url: &str,
    policy: &UpstreamPolicy,
) -> Result<Vec<SocketAddr>, String> {
    let unrestricted = policy.insecure_open_upstream || policy.allow_private_upstream;

    let u = reqwest::Url::parse(url).map_err(|e| format!("invalid URL: {e}"))?;
    let raw_host = u
        .host_str()
        .ok_or_else(|| "upstream URL has no host".to_string())?;
    // `host_str()` wraps IPv6 literals in square brackets; unwrap them so the
    // literal is recognised as an IP instead of being sent to the DNS resolver.
    let host = raw_host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .filter(|inner| inner.parse::<std::net::Ipv6Addr>().is_ok())
        .unwrap_or(raw_host);
    let port = u.port_or_known_default().unwrap_or(80);

    if let Ok(ip) = host.parse::<IpAddr>() {
        if !unrestricted && proxy_ip_is_forbidden(ip) {
            return Err(format!(
                "upstream URL uses a disallowed literal IP (private / loopback / link-local / RFC1918): {url}"
            ));
        }
        return Ok(vec![SocketAddr::new(ip, port)]);
    }

    let lookups = tokio::net::lookup_host((host, port))
        .await
        .map_err(|e| format!("DNS resolution failed for {host}: {e}"))?;
    let mut pinned = Vec::new();
    for sa in lookups {
        // In strict mode one forbidden answer rejects the whole hostname.
        if !unrestricted && proxy_ip_is_forbidden(sa.ip()) {
            return Err(format!(
                "refusing upstream: DNS for {host} includes non-public address {}",
                sa.ip()
            ));
        }
        pinned.push(sa);
    }
    if pinned.is_empty() {
        return Err(format!("refusing upstream: no addresses for {host}"));
    }
    Ok(pinned)
}
/// Validates a CONNECT target without returning the resolved addresses.
///
/// Runs `resolve_connect_target_allowed` with the same arguments and discards
/// the address list it produces.
///
/// # Errors
///
/// Propagates the error message from `resolve_connect_target_allowed` unchanged.
pub async fn assert_connect_target_allowed(
    addr: &str,
    policy: &UpstreamPolicy,
) -> Result<(), String> {
    resolve_connect_target_allowed(addr, policy)
        .await
        .map(|_pinned| ())
}
/// Resolves a CONNECT authority (`host[:port]`) to the socket addresses the
/// proxy should connect to, enforcing `policy`.
///
/// The port defaults to 443 when the authority does not carry one. IPv6
/// literals are accepted in the bracketed authority form (`[::1]:443`).
///
/// When `policy.insecure_open_upstream` or `policy.allow_private_upstream` is
/// set, every resolved address is returned unfiltered. Otherwise a literal IP
/// or any resolved address rejected by `proxy_ip_is_forbidden` causes the
/// target to be refused.
///
/// # Errors
///
/// Returns a human-readable message when the authority is invalid, the target
/// is or resolves to a forbidden address (strict mode only), DNS resolution
/// fails, or resolution yields no addresses.
pub async fn resolve_connect_target_allowed(
    addr: &str,
    policy: &UpstreamPolicy,
) -> Result<Vec<SocketAddr>, String> {
    let authority = addr
        .parse::<hyper::http::uri::Authority>()
        .map_err(|_| format!("invalid CONNECT authority: {addr}"))?;
    let raw_host = authority.host();
    // `Authority::host()` keeps the square brackets around IPv6 literals, which
    // neither `IpAddr::from_str` nor the resolver understand. Unwrap them, but
    // only when the content really is an IPv6 literal, so the literal-IP check
    // below is applied instead of being skipped.
    let host = raw_host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .filter(|inner| inner.parse::<std::net::Ipv6Addr>().is_ok())
        .unwrap_or(raw_host);
    let port = authority.port_u16().unwrap_or(443);

    if policy.insecure_open_upstream || policy.allow_private_upstream {
        let lookups = tokio::net::lookup_host((host, port))
            .await
            .map_err(|e| format!("DNS resolution failed for {host}: {e}"))?;
        let resolved: Vec<SocketAddr> = lookups.collect();
        if resolved.is_empty() {
            return Err(format!("no addresses for {host}"));
        }
        return Ok(resolved);
    }

    if let Ok(ip) = host.parse::<IpAddr>() {
        if proxy_ip_is_forbidden(ip) {
            return Err(format!(
                "refusing CONNECT to non-public literal IP {ip}. \
                 If you're targeting a localhost or RFC1918 lab service, \
                 restart wafrift-proxy with `--allow-private-upstream`."
            ));
        }
        return Ok(vec![SocketAddr::new(ip, port)]);
    }

    let lookups = tokio::net::lookup_host((host, port))
        .await
        .map_err(|e| format!("DNS resolution failed for {host}: {e}"))?;
    let mut pinned = Vec::new();
    for sa in lookups {
        // One forbidden answer rejects the whole hostname.
        if proxy_ip_is_forbidden(sa.ip()) {
            return Err(format!(
                "refusing upstream: DNS for {host} includes non-public address {}",
                sa.ip()
            ));
        }
        pinned.push(sa);
    }
    if pinned.is_empty() {
        return Err(format!("refusing upstream: no addresses for {host}"));
    }
    Ok(pinned)
}
/// DNS resolver for `reqwest` that applies the proxy's address policy to every
/// lookup (see the `reqwest::dns::Resolve` implementation for this type).
pub struct BogonFilteringResolver {
/// Shared policy consulted on each resolution to decide whether forbidden
/// addresses are filtered out of the answer.
pub policy: Arc<UpstreamPolicy>,
}
impl reqwest::dns::Resolve for BogonFilteringResolver {
fn resolve(&self, name: reqwest::dns::Name) -> reqwest::dns::Resolving {
let policy = self.policy.clone();
let host = name.as_str().to_string();
Box::pin(async move {
let lookups = tokio::net::lookup_host((host.as_str(), 0))
.await
.map_err(|e| -> Box<dyn std::error::Error + Send + Sync> { Box::new(e) })?;
let allow_private = policy.allow_private_upstream || policy.insecure_open_upstream;
let filtered: Vec<SocketAddr> = lookups
.into_iter()
.filter(|sa| allow_private || !proxy_ip_is_forbidden(sa.ip()))
.collect();
if filtered.is_empty() {
return Err(Box::<dyn std::error::Error + Send + Sync>::from(format!(
"DNS rebinding refused: every address for {host} is in the bogon set"
)));
}
let iter: reqwest::dns::Addrs = Box::new(filtered.into_iter());
Ok(iter)
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses an address literal used as test data; panics on malformed input.
    fn parse_ip(text: &str) -> IpAddr {
        text.parse().unwrap()
    }

    // --- Plain IPv4 ---------------------------------------------------------

    #[test]
    fn bogon_v4_loopback() {
        assert!(ip_addr_is_bogon(parse_ip("127.0.0.1")));
    }

    #[test]
    fn public_v4_ok() {
        assert!(!ip_addr_is_bogon(parse_ip("8.8.8.8")));
    }

    // --- Proxy policy on top of the bogon predicate -------------------------

    #[test]
    fn proxy_forbidden_blocks_multicast_224() {
        for a in [224u8, 225, 239] {
            let ip = IpAddr::from([a, 0, 0, 1]);
            assert!(
                proxy_ip_is_forbidden(ip),
                "{ip} in 224–239 multicast must be forbidden by proxy policy"
            );
        }
    }

    #[test]
    fn proxy_forbidden_passes_public_not_multicast() {
        for addr in ["8.8.8.8", "1.1.1.1", "2001:4860:4860::8888"] {
            let ip = parse_ip(addr);
            assert!(
                !proxy_ip_is_forbidden(ip),
                "{ip} is public and must not be blocked by proxy policy"
            );
        }
    }

    #[test]
    fn proxy_forbidden_inherits_all_bogon_ranges() {
        let inherited = [
            "127.0.0.1",
            "169.254.169.254",
            "10.0.0.1",
            "192.168.1.1",
            "::1",
        ];
        for addr in inherited {
            let ip = parse_ip(addr);
            assert!(
                proxy_ip_is_forbidden(ip),
                "{ip} must be blocked by proxy policy (inherited from bogon)"
            );
        }
    }

    // --- IPv4-mapped IPv6 ---------------------------------------------------

    #[test]
    fn ipv4_mapped_v6_loopback_is_bogon() {
        assert!(ip_addr_is_bogon(parse_ip("::ffff:127.0.0.1")));
    }

    #[test]
    fn ipv4_mapped_v6_imds_is_bogon() {
        assert!(ip_addr_is_bogon(parse_ip("::ffff:169.254.169.254")));
    }

    #[test]
    fn ipv4_mapped_v6_rfc1918_is_bogon() {
        for literal in ["::ffff:10.0.0.1", "::ffff:192.168.1.1", "::ffff:172.16.0.1"] {
            assert!(ip_addr_is_bogon(parse_ip(literal)));
        }
    }

    #[test]
    fn ipv4_mapped_v6_public_ok() {
        assert!(!ip_addr_is_bogon(parse_ip("::ffff:8.8.8.8")));
    }

    // --- Native IPv6 and 6to4 -----------------------------------------------

    #[test]
    fn rfc3849_documentation_v6_is_bogon() {
        for literal in ["2001:db8::1", "2001:db8:cafe::1"] {
            assert!(ip_addr_is_bogon(parse_ip(literal)));
        }
    }

    #[test]
    fn six_to_four_with_private_v4_is_bogon() {
        for literal in ["2002:7f00:1::", "2002:c0a8:101::", "2002:a9fe:a9fe::"] {
            assert!(ip_addr_is_bogon(parse_ip(literal)));
        }
    }

    #[test]
    fn six_to_four_with_public_v4_ok() {
        assert!(!ip_addr_is_bogon(parse_ip("2002:808:808::")));
    }

    #[test]
    fn public_v6_google_dns_ok() {
        assert!(!ip_addr_is_bogon(parse_ip("2001:4860:4860::8888")));
    }
}