wafrift-proxy 0.3.1

HTTP forward proxy with automatic WAF evasion and optional TLS interception support.
Documentation
//! Upstream destination policy: reject literal-IP bogons and screen DNS-resolved targets against SSRF.

use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;

/// Destination policy shared by the `CONNECT` and cleartext-forward paths.
///
/// `Default` gives the strict policy, with both switches off. Every check in
/// this module treats the two switches the same way: when either is set,
/// destination screening is skipped.
#[derive(Clone, Debug, Default)]
pub struct UpstreamPolicy {
    /// Permit private destinations (RFC1918, loopback, link-local), whether
    /// they are given as a literal IP or reached through DNS.
    pub allow_private_upstream: bool,
    /// Turn off every destination check. Lab use only.
    pub insecure_open_upstream: bool,
}

/// Re-export the workspace-canonical bogon classifier.
pub use wafrift_types::ip_addr_is_bogon;

/// True when this IP should never be the target of a proxy-initiated
/// outbound connection.
///
/// Extends [`ip_addr_is_bogon`] with IPv4 multicast (`224.0.0.0/4`).
/// The bogon crate intentionally leaves IPv4 multicast allowed because
/// scanner workloads legitimately probe multicast addresses; the proxy
/// forward/CONNECT path has no such use case and must refuse them to
/// prevent SSRF via multicast-capable LAN services.
#[must_use]
pub fn proxy_ip_is_forbidden(ip: IpAddr) -> bool {
    if ip_addr_is_bogon(ip) {
        return true;
    }
    // IPv4 multicast: 224.0.0.0/4 (first octet 224–239).
    if let IpAddr::V4(v4) = ip
        && v4.is_multicast()
    {
        return true;
    }
    false
}

/// Block forwarding when the URL host is a literal forbidden IP.
#[must_use]
pub fn upstream_literal_ip_forbidden(url: &str) -> bool {
    let Ok(u) = reqwest::Url::parse(url) else {
        return false;
    };
    let Some(host) = u.host_str() else {
        return false;
    };
    let Ok(ip) = host.parse::<IpAddr>() else {
        return false;
    };
    proxy_ip_is_forbidden(ip)
}

/// Resolves `host:port` and succeeds only if DNS returned at least one
/// address and none of them is rejected by [`proxy_ip_is_forbidden`].
///
/// Reports the first forbidden address in resolver order.
async fn resolve_host_all_public(host: &str, port: u16) -> Result<(), String> {
    let mut answers = tokio::net::lookup_host((host, port))
        .await
        .map_err(|e| format!("DNS resolution failed for {host}: {e}"))?
        .peekable();

    if answers.peek().is_none() {
        return Err(format!("refusing upstream: no addresses for {host}"));
    }

    for answer in answers {
        let ip = answer.ip();
        if proxy_ip_is_forbidden(ip) {
            return Err(format!(
                "refusing upstream: DNS for {host} includes non-public address {ip}"
            ));
        }
    }
    Ok(())
}

/// Validate `https?://…` (or absolute URL) before forwarding.
pub async fn assert_forward_url_allowed(url: &str, policy: &UpstreamPolicy) -> Result<(), String> {
    if policy.insecure_open_upstream {
        return Ok(());
    }
    if policy.allow_private_upstream {
        return Ok(());
    }
    if upstream_literal_ip_forbidden(url) {
        return Err(format!(
            "upstream URL uses a disallowed literal IP (private / loopback / link-local / RFC1918): {url}. \
             If you're intentionally targeting localhost or RFC1918 lab infrastructure, \
             restart wafrift-proxy with `--allow-private-upstream`."
        ));
    }
    let u = reqwest::Url::parse(url).map_err(|e| format!("invalid URL: {e}"))?;
    let Some(host) = u.host_str() else {
        return Err("upstream URL has no host".to_string());
    };
    if host.parse::<IpAddr>().is_ok() {
        return Ok(());
    }
    let port = u.port_or_known_default().unwrap_or(80);
    resolve_host_all_public(host, port).await?;
    Ok(())
}

/// Resolve a forward URL to the socket addresses the upstream connection
/// must use.
///
/// Callers (especially the stealth TLS path) must use these pinned
/// addresses instead of re-resolving DNS after an intercept wait, which
/// would reopen a DNS-rebinding TOCTOU window.
///
/// Under a strict policy, every returned address has passed
/// [`proxy_ip_is_forbidden`]. If either `policy` switch is set, addresses are
/// returned unfiltered. The port is the URL's explicit port, else the
/// scheme's default, else 80. Literal hosts, including bracketed IPv6
/// (`http://[2001:db8::1]/`), are pinned directly without DNS.
///
/// # Errors
/// Returns a human-readable message when the URL is invalid or host-less,
/// DNS fails or yields no addresses, or (strict policy only) the literal IP
/// or any resolved address is forbidden.
pub async fn resolve_forward_url_pinned(
    url: &str,
    policy: &UpstreamPolicy,
) -> Result<Vec<SocketAddr>, String> {
    let permissive = policy.insecure_open_upstream || policy.allow_private_upstream;
    let u = reqwest::Url::parse(url).map_err(|e| format!("invalid URL: {e}"))?;
    let raw_host = u
        .host_str()
        .ok_or_else(|| "upstream URL has no host".to_string())?;
    // `Url::host_str` wraps IPv6 literals in brackets ("[::1]"). Neither
    // `IpAddr::from_str` nor `lookup_host` accepts that form.
    let host = raw_host
        .strip_prefix('[')
        .and_then(|h| h.strip_suffix(']'))
        .unwrap_or(raw_host);
    let port = u.port_or_known_default().unwrap_or(80);

    if let Ok(ip) = host.parse::<IpAddr>() {
        if !permissive && proxy_ip_is_forbidden(ip) {
            return Err(format!(
                "upstream URL uses a disallowed literal IP (private / loopback / link-local / RFC1918): {url}"
            ));
        }
        return Ok(vec![SocketAddr::new(ip, port)]);
    }

    let resolved: Vec<SocketAddr> = tokio::net::lookup_host((host, port))
        .await
        .map_err(|e| format!("DNS resolution failed for {host}: {e}"))?
        .collect();
    // Strict mode rejects the whole host if ANY answer is forbidden, rather
    // than silently dropping it, so a mixed public/private answer cannot be
    // used to smuggle a private target through.
    let forbidden = resolved
        .iter()
        .find(|sa| !permissive && proxy_ip_is_forbidden(sa.ip()));
    if let Some(bad) = forbidden {
        return Err(format!(
            "refusing upstream: DNS for {host} includes non-public address {}",
            bad.ip()
        ));
    }
    if resolved.is_empty() {
        return Err(format!("refusing upstream: no addresses for {host}"));
    }
    Ok(resolved)
}

/// Validate a `CONNECT` authority (`host:port`) before tunnelling or MITM,
/// discarding the resolved addresses.
///
/// Thin wrapper over [`resolve_connect_target_allowed`]. Prefer calling that
/// function directly when a connection follows, so the validated addresses
/// are reused instead of being resolved again.
pub async fn assert_connect_target_allowed(
    addr: &str,
    policy: &UpstreamPolicy,
) -> Result<(), String> {
    resolve_connect_target_allowed(addr, policy)
        .await
        .map(|_validated| ())
}

/// Validate a `CONNECT` authority (`host:port`) and return the socket
/// addresses the tunnel should dial.
///
/// Callers should pass these straight to `TcpStream::connect_to_addr`
/// instead of reusing `host:port`. That way a DNS-rebinding flip between
/// validation and connect cannot land. Re-resolving at connect time reopens
/// the TOCTOU window this function exists to close.
///
/// The port defaults to 443 when the authority omits it. Literal hosts,
/// including bracketed IPv6 (`[2001:db8::1]:443`), are used directly without
/// DNS. Under a strict policy, the literal IP or every resolved address must
/// pass [`proxy_ip_is_forbidden`]. If either `policy` switch is set, no
/// filtering is applied.
///
/// # Errors
/// Returns a human-readable message when the authority does not parse, DNS
/// fails or yields no addresses, or (strict policy only) the target is
/// forbidden.
pub async fn resolve_connect_target_allowed(
    addr: &str,
    policy: &UpstreamPolicy,
) -> Result<Vec<SocketAddr>, String> {
    let authority = addr
        .parse::<hyper::http::uri::Authority>()
        .map_err(|_| format!("invalid CONNECT authority: {addr}"))?;
    let port = authority.port_u16().unwrap_or(443);
    // `Authority::host` keeps the brackets around IPv6 literals ("[::1]").
    // Neither `IpAddr::from_str` nor `lookup_host` accepts that form, so
    // without stripping them every IPv6 CONNECT would fail as a DNS error.
    let raw_host = authority.host();
    let host = raw_host
        .strip_prefix('[')
        .and_then(|h| h.strip_suffix(']'))
        .unwrap_or(raw_host);
    let permissive = policy.insecure_open_upstream || policy.allow_private_upstream;

    if let Ok(ip) = host.parse::<IpAddr>() {
        if !permissive && proxy_ip_is_forbidden(ip) {
            return Err(format!(
                "refusing CONNECT to non-public literal IP {ip}. \
                 If you're targeting a localhost or RFC1918 lab service, \
                 restart wafrift-proxy with `--allow-private-upstream`."
            ));
        }
        return Ok(vec![SocketAddr::new(ip, port)]);
    }

    // Permissive mode still resolves here, so the caller has addresses to
    // connect to without doing its own lookup. It just skips the filter.
    let resolved: Vec<SocketAddr> = tokio::net::lookup_host((host, port))
        .await
        .map_err(|e| format!("DNS resolution failed for {host}: {e}"))?
        .collect();
    let forbidden = resolved
        .iter()
        .find(|sa| !permissive && proxy_ip_is_forbidden(sa.ip()));
    if let Some(bad) = forbidden {
        return Err(format!(
            "refusing upstream: DNS for {host} includes non-public address {}",
            bad.ip()
        ));
    }
    if resolved.is_empty() {
        return Err(if permissive {
            format!("no addresses for {host}")
        } else {
            format!("refusing upstream: no addresses for {host}")
        });
    }
    Ok(resolved)
}

/// `reqwest::dns::Resolve` impl that wraps the system resolver and drops
/// every address rejected by [`proxy_ip_is_forbidden`] (the bogon set plus
/// multicast).
///
/// This closes the DNS-rebinding TOCTOU between `assert_forward_url_allowed`
/// (first lookup) and reqwest's connection-time lookup (second lookup). Both
/// now go through the same filter, so a hostname that resolved to a public
/// IP at policy-check time can't suddenly resolve to 169.254.169.254 /
/// 127.0.0.1 / RFC1918 at fetch time.
///
/// The pre-flight checks reject a hostname if *any* of its addresses is
/// forbidden. This resolver instead keeps the permitted addresses and fails
/// only when none remain.
///
/// Filtering is disabled when either `allow_private_upstream` or
/// `insecure_open_upstream` is set. Callers flip those switches when they
/// target localhost on purpose (e.g. lab tests).
pub struct BogonFilteringResolver {
    pub policy: Arc<UpstreamPolicy>,
}

impl reqwest::dns::Resolve for BogonFilteringResolver {
    fn resolve(&self, name: reqwest::dns::Name) -> reqwest::dns::Resolving {
        let policy = Arc::clone(&self.policy);
        let host = name.as_str().to_string();
        Box::pin(async move {
            // Port 0 is a placeholder: only the IPs are used here.
            // NOTE(review): relies on reqwest's connector applying the
            // request's real port to the returned addresses. Confirm.
            let resolved: Vec<SocketAddr> = tokio::net::lookup_host((host.as_str(), 0))
                .await
                .map_err(|e| -> Box<dyn std::error::Error + Send + Sync> { Box::new(e) })?
                .collect();
            // Report an empty answer as such, so it isn't mislabelled as a
            // rebinding attempt (and so permissive mode gets a sane message).
            if resolved.is_empty() {
                return Err(Box::<dyn std::error::Error + Send + Sync>::from(format!(
                    "DNS resolution returned no addresses for {host}"
                )));
            }
            let allow_private = policy.allow_private_upstream || policy.insecure_open_upstream;
            let filtered: Vec<SocketAddr> = resolved
                .into_iter()
                .filter(|sa| allow_private || !proxy_ip_is_forbidden(sa.ip()))
                .collect();
            if filtered.is_empty() {
                return Err(Box::<dyn std::error::Error + Send + Sync>::from(format!(
                    "DNS rebinding refused: every address for {host} is in the bogon set"
                )));
            }
            let iter: reqwest::dns::Addrs = Box::new(filtered.into_iter());
            Ok(iter)
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Parse a fixed address literal used by these tests.
    fn parse_ip(literal: &str) -> IpAddr {
        literal.parse().unwrap()
    }

    #[test]
    fn bogon_v4_loopback() {
        let loopback = parse_ip("127.0.0.1");
        assert!(ip_addr_is_bogon(loopback));
    }

    #[test]
    fn public_v4_ok() {
        let google_dns = parse_ip("8.8.8.8");
        assert!(!ip_addr_is_bogon(google_dns));
    }

    // ── proxy_ip_is_forbidden: extends bogon with IPv4 multicast ────────────

    #[test]
    fn proxy_forbidden_blocks_multicast_224() {
        // 224.0.0.0/4 is IPv4 multicast. The bogon classifier lets it through
        // (scanner workloads need it), so the proxy layer must catch it.
        let groups = [224u8, 225, 239].map(|first_octet| IpAddr::from([first_octet, 0, 0, 1]));
        for ip in groups {
            assert!(
                proxy_ip_is_forbidden(ip),
                "{ip} in 224–239 multicast must be forbidden by proxy policy"
            );
        }
    }

    #[test]
    fn proxy_forbidden_passes_public_not_multicast() {
        let public = ["8.8.8.8", "1.1.1.1", "2001:4860:4860::8888"].map(parse_ip);
        for ip in public {
            assert!(
                !proxy_ip_is_forbidden(ip),
                "{ip} is public and must not be blocked by proxy policy"
            );
        }
    }

    #[test]
    fn proxy_forbidden_inherits_all_bogon_ranges() {
        // The proxy check must be at least as strict as the bogon classifier
        // on the ranges that matter most for SSRF.
        let must_block = [
            "127.0.0.1",
            "169.254.169.254",
            "10.0.0.1",
            "192.168.1.1",
            "::1",
        ]
        .map(parse_ip);
        for ip in must_block {
            assert!(
                proxy_ip_is_forbidden(ip),
                "{ip} must be blocked by proxy policy (inherited from bogon)"
            );
        }
    }

    #[test]
    fn ipv4_mapped_v6_loopback_is_bogon() {
        // ::ffff:127.0.0.1 is not caught by Ipv6Addr::is_loopback (which
        // matches only ::1); the classifier must re-check the mapped IPv4.
        assert!(ip_addr_is_bogon(parse_ip("::ffff:127.0.0.1")));
    }

    #[test]
    fn ipv4_mapped_v6_imds_is_bogon() {
        // Mapped spelling of the AWS instance-metadata endpoint.
        assert!(ip_addr_is_bogon(parse_ip("::ffff:169.254.169.254")));
    }

    #[test]
    fn ipv4_mapped_v6_rfc1918_is_bogon() {
        for literal in ["::ffff:10.0.0.1", "::ffff:192.168.1.1", "::ffff:172.16.0.1"] {
            assert!(ip_addr_is_bogon(parse_ip(literal)));
        }
    }

    #[test]
    fn ipv4_mapped_v6_public_ok() {
        // The mapped form of a public address must not be flagged.
        assert!(!ip_addr_is_bogon(parse_ip("::ffff:8.8.8.8")));
    }

    #[test]
    fn rfc3849_documentation_v6_is_bogon() {
        // 2001:db8::/32 is the IPv6 documentation prefix. No real upstream
        // lives there, so a DNS answer in it is refused as a misconfiguration.
        for literal in ["2001:db8::1", "2001:db8:cafe::1"] {
            assert!(ip_addr_is_bogon(parse_ip(literal)));
        }
    }

    #[test]
    fn six_to_four_with_private_v4_is_bogon() {
        // 6to4 (RFC 3056) embeds an IPv4 in 2002:WWXX:YYZZ::/48:
        //   2002:7f00:1::    -> 127.0.0.1
        //   2002:c0a8:101::  -> 192.168.1.1
        //   2002:a9fe:a9fe:: -> 169.254.169.254 (AWS IMDS)
        for literal in ["2002:7f00:1::", "2002:c0a8:101::", "2002:a9fe:a9fe::"] {
            assert!(ip_addr_is_bogon(parse_ip(literal)));
        }
    }

    #[test]
    fn six_to_four_with_public_v4_ok() {
        // 2002:808:808:: embeds 8.8.8.8, so it is not a bogon.
        assert!(!ip_addr_is_bogon(parse_ip("2002:808:808::")));
    }

    #[test]
    fn public_v6_google_dns_ok() {
        assert!(!ip_addr_is_bogon(parse_ip("2001:4860:4860::8888")));
    }
}