1use crate::error::Result;
7use nono::net_filter::{FilterResult, HostFilter};
8use std::net::{IpAddr, SocketAddr};
9use tracing::debug;
10
11pub struct CheckResult {
17 pub result: FilterResult,
19 pub resolved_addrs: Vec<SocketAddr>,
21}
22
23#[derive(Debug, Clone)]
25pub struct ProxyFilter {
26 inner: HostFilter,
27}
28
29impl ProxyFilter {
30 #[must_use]
32 pub fn new(allowed_hosts: &[String]) -> Self {
33 Self {
34 inner: HostFilter::new(allowed_hosts),
35 }
36 }
37
38 #[must_use]
40 pub fn allow_all() -> Self {
41 Self {
42 inner: HostFilter::allow_all(),
43 }
44 }
45
46 pub async fn check_host(&self, host: &str, port: u16) -> Result<CheckResult> {
57 let addr_str = format!("{}:{}", host, port);
59 let resolved: Vec<SocketAddr> = match tokio::net::lookup_host(&addr_str).await {
60 Ok(addrs) => addrs.collect(),
61 Err(e) => {
62 debug!("DNS resolution failed for {}: {}", host, e);
63 Vec::new()
66 }
67 };
68
69 let resolved_ips: Vec<IpAddr> = resolved.iter().map(|a| a.ip()).collect();
70 let result = self.inner.check_host(host, &resolved_ips);
71
72 let addrs = if result.is_allowed() {
74 resolved
75 } else {
76 Vec::new()
77 };
78
79 Ok(CheckResult {
80 result,
81 resolved_addrs: addrs,
82 })
83 }
84
85 #[must_use]
87 pub fn check_host_with_ips(&self, host: &str, resolved_ips: &[IpAddr]) -> FilterResult {
88 self.inner.check_host(host, resolved_ips)
89 }
90
91 #[must_use]
93 pub fn allowed_count(&self) -> usize {
94 self.inner.allowed_count()
95 }
96}
97
#[cfg(test)]
// NOTE(review): presumably mirrors a crate-level `deny(clippy::unwrap_used)`;
// nothing in this module currently calls `unwrap`.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// A single public IPv4 address used as the "resolved" destination.
    fn public_ip() -> Vec<IpAddr> {
        vec![IpAddr::from(Ipv4Addr::new(104, 18, 7, 96))]
    }

    #[test]
    fn test_proxy_filter_delegates_to_host_filter() {
        let filter = ProxyFilter::new(&[String::from("api.openai.com")]);
        let ips = public_ip();

        // Listed host is allowed; an unlisted one is not.
        assert!(filter.check_host_with_ips("api.openai.com", &ips).is_allowed());
        assert!(!filter.check_host_with_ips("evil.com", &ips).is_allowed());
    }

    #[test]
    fn test_proxy_filter_allow_all() {
        let ips = public_ip();
        let verdict = ProxyFilter::allow_all().check_host_with_ips("anything.com", &ips);
        assert!(verdict.is_allowed());
    }

    #[test]
    fn test_proxy_filter_denies_private_ips() {
        // 10.0.0.1 lies in the RFC 1918 10.0.0.0/8 private range.
        let ips = vec![IpAddr::from(Ipv4Addr::new(10, 0, 0, 1))];
        let verdict = ProxyFilter::allow_all().check_host_with_ips("corp.internal", &ips);
        assert!(!verdict.is_allowed());
    }
}