Skip to main content

nono_proxy/
filter.rs

1//! Async host filtering wrapping the library's [`HostFilter`](nono::HostFilter).
2//!
3//! Performs DNS resolution via `tokio::net::lookup_host()` and checks
4//! resolved IPs against deny CIDRs before checking the hostname allowlist.
5
6use crate::error::Result;
7use nono::net_filter::{FilterResult, HostFilter};
8use std::net::{IpAddr, SocketAddr};
9use tracing::debug;
10
11/// Result of a filter check including resolved socket addresses.
12///
13/// When the filter allows a host, `resolved_addrs` contains the DNS-resolved
14/// addresses. Callers MUST connect to these addresses (not re-resolve the
15/// hostname) to prevent DNS rebinding TOCTOU attacks.
16pub struct CheckResult {
17    /// The filter decision
18    pub result: FilterResult,
19    /// DNS-resolved addresses (empty if denied or DNS failed)
20    pub resolved_addrs: Vec<SocketAddr>,
21}
22
23/// Async wrapper around `HostFilter` that performs DNS resolution.
24#[derive(Debug, Clone)]
25pub struct ProxyFilter {
26    inner: HostFilter,
27}
28
29impl ProxyFilter {
30    /// Create a new proxy filter with the given allowed hosts.
31    #[must_use]
32    pub fn new(allowed_hosts: &[String]) -> Self {
33        Self {
34            inner: HostFilter::new(allowed_hosts),
35        }
36    }
37
38    /// Create a filter that allows all hosts (except deny list).
39    #[must_use]
40    pub fn allow_all() -> Self {
41        Self {
42            inner: HostFilter::allow_all(),
43        }
44    }
45
46    /// Check a host against the filter with async DNS resolution.
47    ///
48    /// Resolves the hostname to IP addresses, then checks all resolved IPs
49    /// against the deny CIDR list. If any IP is denied, the request is blocked
50    /// (DNS rebinding protection).
51    ///
52    /// On success, returns both the filter result and the resolved socket
53    /// addresses. Callers MUST use `resolved_addrs` to connect to the upstream
54    /// instead of re-resolving the hostname, eliminating the DNS rebinding
55    /// TOCTOU window.
56    pub async fn check_host(&self, host: &str, port: u16) -> Result<CheckResult> {
57        // Resolve DNS
58        let addr_str = format!("{}:{}", host, port);
59        let resolved: Vec<SocketAddr> = match tokio::net::lookup_host(&addr_str).await {
60            Ok(addrs) => addrs.collect(),
61            Err(e) => {
62                debug!("DNS resolution failed for {}: {}", host, e);
63                // If DNS fails, we still check the hostname against deny list
64                // (cloud metadata hostnames don't need DNS resolution to be blocked)
65                Vec::new()
66            }
67        };
68
69        let resolved_ips: Vec<IpAddr> = resolved.iter().map(|a| a.ip()).collect();
70        let result = self.inner.check_host(host, &resolved_ips);
71
72        // Only return resolved addrs on allow to prevent misuse
73        let addrs = if result.is_allowed() {
74            resolved
75        } else {
76            Vec::new()
77        };
78
79        Ok(CheckResult {
80            result,
81            resolved_addrs: addrs,
82        })
83    }
84
85    /// Check a host with pre-resolved IPs (no DNS lookup).
86    #[must_use]
87    pub fn check_host_with_ips(&self, host: &str, resolved_ips: &[IpAddr]) -> FilterResult {
88        self.inner.check_host(host, resolved_ips)
89    }
90
91    /// Number of allowed hosts configured.
92    #[must_use]
93    pub fn allowed_count(&self) -> usize {
94        self.inner.allowed_count()
95    }
96}
97
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// An allowlisted hostname passes and an unlisted one is rejected when
    /// both resolve to the same public IP.
    #[test]
    fn test_proxy_filter_delegates_to_host_filter() {
        let allowlist = ["api.openai.com".to_string()];
        let filter = ProxyFilter::new(&allowlist);
        let public_ips = [IpAddr::V4(Ipv4Addr::new(104, 18, 7, 96))];

        assert!(filter
            .check_host_with_ips("api.openai.com", &public_ips)
            .is_allowed());
        assert!(!filter
            .check_host_with_ips("evil.com", &public_ips)
            .is_allowed());
    }

    /// The allow-all filter accepts an arbitrary hostname on a public IP.
    #[test]
    fn test_proxy_filter_allow_all() {
        let public_ips = [IpAddr::V4(Ipv4Addr::new(104, 18, 7, 96))];
        let verdict = ProxyFilter::allow_all().check_host_with_ips("anything.com", &public_ips);
        assert!(verdict.is_allowed());
    }

    /// Even the allow-all filter rejects a host that resolves to a private
    /// (10.0.0.0/8) address.
    #[test]
    fn test_proxy_filter_denies_private_ips() {
        let private_ips = [IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))];
        let verdict = ProxyFilter::allow_all().check_host_with_ips("corp.internal", &private_ips);
        assert!(!verdict.is_allowed());
    }
}
131}