1use crate::error::Result;
8use nono::net_filter::{FilterResult, HostFilter};
9use std::net::{IpAddr, SocketAddr};
10use tracing::debug;
11
/// Value returned by [`ProxyFilter::check_host`]: the filter decision together
/// with the socket addresses DNS resolution produced for the host.
///
/// `ProxyFilter::check_host` only fills `resolved_addrs` when `result` is an
/// allow; on a deny it is left empty. It is also empty when DNS resolution
/// failed, even if the filter allowed the host.
///
/// NOTE(review): presumably callers are meant to connect to `resolved_addrs`
/// instead of re-resolving the hostname, so the IPs that were vetted are the
/// IPs actually used (DNS-rebinding protection) — confirm against callers.
pub struct CheckResult {
    /// The filter's decision for the requested host.
    pub result: FilterResult,
    /// Addresses obtained by resolving `host:port`; empty when the host was
    /// denied or when resolution failed.
    pub resolved_addrs: Vec<SocketAddr>,
}
23
/// Host filter used by the proxy.
///
/// A thin wrapper over [`nono::net_filter::HostFilter`]. It adds async DNS
/// resolution (see `ProxyFilter::check_host`) so the filter can judge both the
/// hostname and the IPs it resolves to.
#[derive(Debug, Clone)]
pub struct ProxyFilter {
    // The underlying allowlist/denylist logic; every decision is delegated here.
    inner: HostFilter,
}
29
30impl ProxyFilter {
31 #[must_use]
33 pub fn new(allowed_hosts: &[String]) -> Self {
34 Self {
35 inner: HostFilter::new(allowed_hosts),
36 }
37 }
38
39 #[must_use]
41 pub fn new_strict(allowed_hosts: &[String]) -> Self {
42 Self {
43 inner: HostFilter::new_strict(allowed_hosts),
44 }
45 }
46
47 #[must_use]
49 pub fn allow_all() -> Self {
50 Self {
51 inner: HostFilter::allow_all(),
52 }
53 }
54
55 pub async fn check_host(&self, host: &str, port: u16) -> Result<CheckResult> {
66 let addr_str = format!("{}:{}", host, port);
68 let resolved: Vec<SocketAddr> = match tokio::net::lookup_host(&addr_str).await {
69 Ok(addrs) => addrs.collect(),
70 Err(e) => {
71 debug!("DNS resolution failed for {}: {}", host, e);
72 Vec::new()
75 }
76 };
77
78 let resolved_ips: Vec<IpAddr> = resolved.iter().map(|a| a.ip()).collect();
79 let result = self.inner.check_host(host, &resolved_ips);
80
81 let addrs = if result.is_allowed() {
83 resolved
84 } else {
85 Vec::new()
86 };
87
88 Ok(CheckResult {
89 result,
90 resolved_addrs: addrs,
91 })
92 }
93
94 #[must_use]
96 pub fn check_host_with_ips(&self, host: &str, resolved_ips: &[IpAddr]) -> FilterResult {
97 self.inner.check_host(host, resolved_ips)
98 }
99
100 #[must_use]
102 pub fn allowed_count(&self) -> usize {
103 self.inner.allowed_count()
104 }
105}
106
#[cfg(test)]
// Relaxes clippy::unwrap_used for test code (the crate presumably enables
// this lint elsewhere).
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// One-element IP list holding the IPv4 address `a.b.c.d`.
    fn ipv4_list(a: u8, b: u8, c: u8, d: u8) -> Vec<IpAddr> {
        vec![IpAddr::V4(Ipv4Addr::new(a, b, c, d))]
    }

    /// A routable public address, used as an ordinary resolution result.
    fn public_ips() -> Vec<IpAddr> {
        ipv4_list(104, 18, 7, 96)
    }

    #[test]
    fn test_proxy_filter_delegates_to_host_filter() {
        let filter = ProxyFilter::new(&["api.openai.com".to_string()]);
        let ips = public_ips();

        // The listed host passes; an unlisted host with the same IPs does not.
        assert!(filter.check_host_with_ips("api.openai.com", &ips).is_allowed());
        assert!(!filter.check_host_with_ips("evil.com", &ips).is_allowed());
    }

    #[test]
    fn test_proxy_filter_allow_all() {
        let ips = public_ips();
        let verdict = ProxyFilter::allow_all().check_host_with_ips("anything.com", &ips);
        assert!(verdict.is_allowed());
    }

    #[test]
    fn test_proxy_filter_allows_private_networks() {
        let rfc1918 = ipv4_list(10, 0, 0, 1);
        let verdict = ProxyFilter::allow_all().check_host_with_ips("corp.internal", &rfc1918);
        assert!(verdict.is_allowed());
    }

    #[test]
    fn test_proxy_filter_denies_link_local() {
        let metadata = ipv4_list(169, 254, 169, 254);
        let verdict = ProxyFilter::allow_all().check_host_with_ips("evil.com", &metadata);
        assert!(!verdict.is_allowed());
    }
}