// nono_proxy/filter.rs
//! Async host filtering wrapping the library's [`HostFilter`](nono::HostFilter).
//!
//! Resolves hostnames via `tokio::net::lookup_host()` and hands the hostname
//! plus every resolved IP to `HostFilter`, which rejects link-local IPs (cloud
//! metadata SSRF protection) and validates the hostname against the cloud
//! metadata deny list and the allowlist.

7use crate::error::Result;
8use nono::net_filter::{FilterResult, HostFilter};
9use std::net::{IpAddr, SocketAddr};
10use tracing::debug;
11
12/// Result of a filter check including resolved socket addresses.
13///
14/// When the filter allows a host, `resolved_addrs` contains the DNS-resolved
15/// addresses. Callers MUST connect to these addresses (not re-resolve the
16/// hostname) to prevent DNS rebinding TOCTOU attacks.
17pub struct CheckResult {
18    /// The filter decision
19    pub result: FilterResult,
20    /// DNS-resolved addresses (empty if denied or DNS failed)
21    pub resolved_addrs: Vec<SocketAddr>,
22}
23
24/// Async wrapper around `HostFilter` that performs DNS resolution.
25#[derive(Debug, Clone)]
26pub struct ProxyFilter {
27    inner: HostFilter,
28}
29
30impl ProxyFilter {
31    /// Create a new proxy filter with the given allowed hosts.
32    #[must_use]
33    pub fn new(allowed_hosts: &[String]) -> Self {
34        Self {
35            inner: HostFilter::new(allowed_hosts),
36        }
37    }
38
39    /// Create a strict proxy filter: an empty allowlist denies every host.
40    #[must_use]
41    pub fn new_strict(allowed_hosts: &[String]) -> Self {
42        Self {
43            inner: HostFilter::new_strict(allowed_hosts),
44        }
45    }
46
47    /// Create a filter that allows all hosts (except cloud metadata).
48    #[must_use]
49    pub fn allow_all() -> Self {
50        Self {
51            inner: HostFilter::allow_all(),
52        }
53    }
54
55    /// Check a host against the filter with async DNS resolution.
56    ///
57    /// Resolves the hostname to IP addresses, then checks all resolved IPs
58    /// against the link-local deny range (cloud metadata SSRF protection).
59    /// If any resolved IP is link-local, the request is blocked.
60    ///
61    /// On success, returns both the filter result and the resolved socket
62    /// addresses. Callers MUST use `resolved_addrs` to connect to the upstream
63    /// instead of re-resolving the hostname, eliminating the DNS rebinding
64    /// TOCTOU window.
65    pub async fn check_host(&self, host: &str, port: u16) -> Result<CheckResult> {
66        // Resolve DNS
67        let addr_str = format!("{}:{}", host, port);
68        let resolved: Vec<SocketAddr> = match tokio::net::lookup_host(&addr_str).await {
69            Ok(addrs) => addrs.collect(),
70            Err(e) => {
71                debug!("DNS resolution failed for {}: {}", host, e);
72                // If DNS fails, we still check the hostname against deny list
73                // (cloud metadata hostnames don't need DNS resolution to be blocked)
74                Vec::new()
75            }
76        };
77
78        let resolved_ips: Vec<IpAddr> = resolved.iter().map(|a| a.ip()).collect();
79        let result = self.inner.check_host(host, &resolved_ips);
80
81        // Only return resolved addrs on allow to prevent misuse
82        let addrs = if result.is_allowed() {
83            resolved
84        } else {
85            Vec::new()
86        };
87
88        Ok(CheckResult {
89            result,
90            resolved_addrs: addrs,
91        })
92    }
93
94    /// Check a host with pre-resolved IPs (no DNS lookup).
95    #[must_use]
96    pub fn check_host_with_ips(&self, host: &str, resolved_ips: &[IpAddr]) -> FilterResult {
97        self.inner.check_host(host, resolved_ips)
98    }
99
100    /// Number of allowed hosts configured.
101    #[must_use]
102    pub fn allowed_count(&self) -> usize {
103        self.inner.allowed_count()
104    }
105}
106
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// A publicly routable IPv4 address, used as the "resolved" IP in tests.
    fn public_ip() -> Vec<IpAddr> {
        vec![IpAddr::V4(Ipv4Addr::new(104, 18, 7, 96))]
    }

    #[test]
    fn test_proxy_filter_delegates_to_host_filter() {
        let filter = ProxyFilter::new(&["api.openai.com".to_string()]);
        let ips = public_ip();

        assert!(filter.check_host_with_ips("api.openai.com", &ips).is_allowed());
        assert!(!filter.check_host_with_ips("evil.com", &ips).is_allowed());
    }

    #[test]
    fn test_proxy_filter_allow_all() {
        let filter = ProxyFilter::allow_all();
        assert!(filter
            .check_host_with_ips("anything.com", &public_ip())
            .is_allowed());
    }

    #[test]
    fn test_proxy_filter_allows_private_networks() {
        let filter = ProxyFilter::allow_all();
        let private_ip = vec![IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))];
        assert!(filter
            .check_host_with_ips("corp.internal", &private_ip)
            .is_allowed());
    }

    #[test]
    fn test_proxy_filter_denies_link_local() {
        let filter = ProxyFilter::allow_all();
        // 169.254.169.254 is the classic cloud metadata endpoint.
        let link_local = vec![IpAddr::V4(Ipv4Addr::new(169, 254, 169, 254))];
        assert!(!filter.check_host_with_ips("evil.com", &link_local).is_allowed());
    }

    #[test]
    fn test_proxy_filter_strict_empty_allowlist_denies_all() {
        // `new_strict` is documented to deny every host when the allowlist is empty.
        let filter = ProxyFilter::new_strict(&[]);
        assert!(!filter
            .check_host_with_ips("api.openai.com", &public_ip())
            .is_allowed());
    }
}