1use crate::error::Result;
8use nono::net_filter::{FilterResult, HostFilter};
9use std::net::{IpAddr, SocketAddr};
10use tracing::debug;
11
12pub struct CheckResult {
18 pub result: FilterResult,
20 pub resolved_addrs: Vec<SocketAddr>,
22}
23
24#[derive(Debug, Clone)]
26pub struct ProxyFilter {
27 inner: HostFilter,
28}
29
30impl ProxyFilter {
31 #[must_use]
33 pub fn new(allowed_hosts: &[String]) -> Self {
34 Self {
35 inner: HostFilter::new(allowed_hosts),
36 }
37 }
38
39 #[must_use]
41 pub fn allow_all() -> Self {
42 Self {
43 inner: HostFilter::allow_all(),
44 }
45 }
46
47 pub async fn check_host(&self, host: &str, port: u16) -> Result<CheckResult> {
58 let addr_str = format!("{}:{}", host, port);
60 let resolved: Vec<SocketAddr> = match tokio::net::lookup_host(&addr_str).await {
61 Ok(addrs) => addrs.collect(),
62 Err(e) => {
63 debug!("DNS resolution failed for {}: {}", host, e);
64 Vec::new()
67 }
68 };
69
70 let resolved_ips: Vec<IpAddr> = resolved.iter().map(|a| a.ip()).collect();
71 let result = self.inner.check_host(host, &resolved_ips);
72
73 let addrs = if result.is_allowed() {
75 resolved
76 } else {
77 Vec::new()
78 };
79
80 Ok(CheckResult {
81 result,
82 resolved_addrs: addrs,
83 })
84 }
85
86 #[must_use]
88 pub fn check_host_with_ips(&self, host: &str, resolved_ips: &[IpAddr]) -> FilterResult {
89 self.inner.check_host(host, resolved_ips)
90 }
91
92 #[must_use]
94 pub fn allowed_count(&self) -> usize {
95 self.inner.allowed_count()
96 }
97}
98
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Single-element address list holding the IPv4 address `a.b.c.d`.
    fn v4(a: u8, b: u8, c: u8, d: u8) -> Vec<IpAddr> {
        vec![IpAddr::V4(Ipv4Addr::new(a, b, c, d))]
    }

    /// A publicly routable address used as the "resolved" IP in these tests.
    fn public_ip() -> Vec<IpAddr> {
        v4(104, 18, 7, 96)
    }

    #[test]
    fn test_proxy_filter_delegates_to_host_filter() {
        let filter = ProxyFilter::new(&["api.openai.com".to_string()]);
        let ips = public_ip();

        // Allowlisted host passes; anything else is rejected.
        assert!(filter.check_host_with_ips("api.openai.com", &ips).is_allowed());
        assert!(!filter.check_host_with_ips("evil.com", &ips).is_allowed());
    }

    #[test]
    fn test_proxy_filter_allow_all() {
        let filter = ProxyFilter::allow_all();
        assert!(filter
            .check_host_with_ips("anything.com", &public_ip())
            .is_allowed());
    }

    #[test]
    fn test_proxy_filter_allows_private_networks() {
        let filter = ProxyFilter::allow_all();
        let private = v4(10, 0, 0, 1);
        assert!(filter.check_host_with_ips("corp.internal", &private).is_allowed());
    }

    #[test]
    fn test_proxy_filter_denies_link_local() {
        // Cloud metadata endpoint: must stay blocked even with allow_all.
        let filter = ProxyFilter::allow_all();
        let link_local = v4(169, 254, 169, 254);
        assert!(!filter.check_host_with_ips("evil.com", &link_local).is_allowed());
    }
}