Skip to main content

nono_proxy/
filter.rs

//! Async host filtering wrapping the library's [`HostFilter`](nono::HostFilter).
//!
//! Performs DNS resolution via `tokio::net::lookup_host()`, checks resolved
//! IPs against the link-local range (cloud metadata SSRF protection), and
//! validates the hostname against the cloud metadata deny list and allowlist.
6
7use crate::error::Result;
8use nono::net_filter::{FilterResult, HostFilter};
9use std::net::{IpAddr, SocketAddr};
10use tracing::debug;
11
12/// Result of a filter check including resolved socket addresses.
13///
14/// When the filter allows a host, `resolved_addrs` contains the DNS-resolved
15/// addresses. Callers MUST connect to these addresses (not re-resolve the
16/// hostname) to prevent DNS rebinding TOCTOU attacks.
17pub struct CheckResult {
18    /// The filter decision
19    pub result: FilterResult,
20    /// DNS-resolved addresses (empty if denied or DNS failed)
21    pub resolved_addrs: Vec<SocketAddr>,
22}
23
24/// Async wrapper around `HostFilter` that performs DNS resolution.
25#[derive(Debug, Clone)]
26pub struct ProxyFilter {
27    inner: HostFilter,
28}
29
30impl ProxyFilter {
31    /// Create a new proxy filter with the given allowed hosts.
32    #[must_use]
33    pub fn new(allowed_hosts: &[String]) -> Self {
34        Self {
35            inner: HostFilter::new(allowed_hosts),
36        }
37    }
38
39    /// Create a filter that allows all hosts (except cloud metadata).
40    #[must_use]
41    pub fn allow_all() -> Self {
42        Self {
43            inner: HostFilter::allow_all(),
44        }
45    }
46
47    /// Check a host against the filter with async DNS resolution.
48    ///
49    /// Resolves the hostname to IP addresses, then checks all resolved IPs
50    /// against the link-local deny range (cloud metadata SSRF protection).
51    /// If any resolved IP is link-local, the request is blocked.
52    ///
53    /// On success, returns both the filter result and the resolved socket
54    /// addresses. Callers MUST use `resolved_addrs` to connect to the upstream
55    /// instead of re-resolving the hostname, eliminating the DNS rebinding
56    /// TOCTOU window.
57    pub async fn check_host(&self, host: &str, port: u16) -> Result<CheckResult> {
58        // Resolve DNS
59        let addr_str = format!("{}:{}", host, port);
60        let resolved: Vec<SocketAddr> = match tokio::net::lookup_host(&addr_str).await {
61            Ok(addrs) => addrs.collect(),
62            Err(e) => {
63                debug!("DNS resolution failed for {}: {}", host, e);
64                // If DNS fails, we still check the hostname against deny list
65                // (cloud metadata hostnames don't need DNS resolution to be blocked)
66                Vec::new()
67            }
68        };
69
70        let resolved_ips: Vec<IpAddr> = resolved.iter().map(|a| a.ip()).collect();
71        let result = self.inner.check_host(host, &resolved_ips);
72
73        // Only return resolved addrs on allow to prevent misuse
74        let addrs = if result.is_allowed() {
75            resolved
76        } else {
77            Vec::new()
78        };
79
80        Ok(CheckResult {
81            result,
82            resolved_addrs: addrs,
83        })
84    }
85
86    /// Check a host with pre-resolved IPs (no DNS lookup).
87    #[must_use]
88    pub fn check_host_with_ips(&self, host: &str, resolved_ips: &[IpAddr]) -> FilterResult {
89        self.inner.check_host(host, resolved_ips)
90    }
91
92    /// Number of allowed hosts configured.
93    #[must_use]
94    pub fn allowed_count(&self) -> usize {
95        self.inner.allowed_count()
96    }
97}
98
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// A public (non-private, non-link-local) IPv4 address used as a
    /// stand-in DNS answer.
    const PUBLIC_IP: IpAddr = IpAddr::V4(Ipv4Addr::new(104, 18, 7, 96));

    #[test]
    fn test_proxy_filter_delegates_to_host_filter() {
        let filter = ProxyFilter::new(&["api.openai.com".to_string()]);

        let result = filter.check_host_with_ips("api.openai.com", &[PUBLIC_IP]);
        assert!(result.is_allowed());

        let result = filter.check_host_with_ips("evil.com", &[PUBLIC_IP]);
        assert!(!result.is_allowed());
    }

    #[test]
    fn test_proxy_filter_allow_all() {
        let filter = ProxyFilter::allow_all();
        let result = filter.check_host_with_ips("anything.com", &[PUBLIC_IP]);
        assert!(result.is_allowed());
    }

    #[test]
    fn test_proxy_filter_allows_private_networks() {
        let filter = ProxyFilter::allow_all();
        let private_ip = [IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))];
        let result = filter.check_host_with_ips("corp.internal", &private_ip);
        assert!(result.is_allowed());
    }

    #[test]
    fn test_proxy_filter_denies_link_local() {
        let filter = ProxyFilter::allow_all();
        let link_local = [IpAddr::V4(Ipv4Addr::new(169, 254, 169, 254))];
        let result = filter.check_host_with_ips("evil.com", &link_local);
        assert!(!result.is_allowed());
    }
}