use crate::error::Result;
use nono::net_filter::{FilterResult, HostFilter};
use std::net::{IpAddr, SocketAddr};
use tracing::debug;
/// Outcome of [`ProxyFilter::check_host`]: the filter verdict together with
/// the socket addresses obtained while resolving the host.
pub struct CheckResult {
/// Verdict returned by the wrapped `HostFilter` for the host and its resolved IPs.
pub result: FilterResult,
/// Socket addresses resolved for `host:port`.
///
/// `ProxyFilter::check_host` leaves this empty when the verdict is not
/// allowed, and also when DNS resolution failed (nothing was resolved).
pub resolved_addrs: Vec<SocketAddr>,
}
/// Proxy-side host filter: a thin wrapper around `nono::net_filter::HostFilter`.
///
/// The filtering decision itself is delegated to the inner filter; this type
/// adds DNS resolution in front of it (see the `check_host` method).
#[derive(Debug, Clone)]
pub struct ProxyFilter {
// The wrapped filter that every check in this type is forwarded to.
inner: HostFilter,
}
impl ProxyFilter {
    /// Builds a filter whose inner `HostFilter` is constructed from
    /// `allowed_hosts`.
    #[must_use]
    pub fn new(allowed_hosts: &[String]) -> Self {
        let inner = HostFilter::new(allowed_hosts);
        Self { inner }
    }

    /// Builds a filter backed by `HostFilter::allow_all()`.
    ///
    /// NOTE(review): the tests in this file show that such a filter still
    /// rejects a link-local address (169.254.169.254), so "allow all" refers
    /// to hostnames only — confirm against `HostFilter`.
    #[must_use]
    pub fn allow_all() -> Self {
        let inner = HostFilter::allow_all();
        Self { inner }
    }

    /// Resolves `host:port` and runs the inner filter on the host name plus
    /// every resolved IP address.
    ///
    /// A DNS failure is not an error here: it is logged at debug level and the
    /// filter is consulted with an empty IP list, so only whatever checks the
    /// inner filter performs without IPs can apply.
    ///
    /// The returned [`CheckResult`] carries the resolved socket addresses only
    /// when the verdict is "allowed"; otherwise the address list is empty.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`; the `Result` wrapper is
    /// part of the signature only.
    pub async fn check_host(&self, host: &str, port: u16) -> Result<CheckResult> {
        let target = format!("{}:{}", host, port);

        // Best-effort resolution: fall back to "no addresses" on failure.
        let candidates = tokio::net::lookup_host(&target)
            .await
            .map(|found| found.collect::<Vec<SocketAddr>>())
            .unwrap_or_else(|e| {
                debug!("DNS resolution failed for {}: {}", host, e);
                Vec::new()
            });

        // The inner filter works on bare IPs, so strip the port information.
        let candidate_ips: Vec<IpAddr> = candidates.iter().map(SocketAddr::ip).collect();
        let result = self.inner.check_host(host, &candidate_ips);

        // Never hand out addresses for a host that was not allowed.
        if !result.is_allowed() {
            return Ok(CheckResult {
                result,
                resolved_addrs: Vec::new(),
            });
        }

        Ok(CheckResult {
            result,
            resolved_addrs: candidates,
        })
    }

    /// Runs the inner filter on `host` with an already-resolved IP list,
    /// performing no DNS lookup.
    #[must_use]
    pub fn check_host_with_ips(&self, host: &str, resolved_ips: &[IpAddr]) -> FilterResult {
        let inner = &self.inner;
        inner.check_host(host, resolved_ips)
    }

    /// Returns the inner filter's `allowed_count()`.
    #[must_use]
    pub fn allowed_count(&self) -> usize {
        let inner = &self.inner;
        inner.allowed_count()
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Builds a one-element resolved-IP list from four IPv4 octets.
    fn single_v4(a: u8, b: u8, c: u8, d: u8) -> [IpAddr; 1] {
        [IpAddr::V4(Ipv4Addr::new(a, b, c, d))]
    }

    /// A host on the allow list passes; a host missing from it does not.
    #[test]
    fn test_proxy_filter_delegates_to_host_filter() {
        let allow_list = ["api.openai.com".to_string()];
        let filter = ProxyFilter::new(&allow_list);
        let ips = single_v4(104, 18, 7, 96);

        let listed = filter.check_host_with_ips("api.openai.com", &ips);
        assert!(listed.is_allowed());

        let unlisted = filter.check_host_with_ips("evil.com", &ips);
        assert!(!unlisted.is_allowed());
    }

    /// An allow-all filter accepts an arbitrary host with a public IP.
    #[test]
    fn test_proxy_filter_allow_all() {
        let filter = ProxyFilter::allow_all();
        let ips = single_v4(104, 18, 7, 96);

        let verdict = filter.check_host_with_ips("anything.com", &ips);
        assert!(verdict.is_allowed());
    }

    /// An allow-all filter accepts a host resolving to a 10.0.0.0/8 address.
    #[test]
    fn test_proxy_filter_allows_private_networks() {
        let filter = ProxyFilter::allow_all();
        let ips = single_v4(10, 0, 0, 1);

        let verdict = filter.check_host_with_ips("corp.internal", &ips);
        assert!(verdict.is_allowed());
    }

    /// Even an allow-all filter rejects a host resolving to 169.254.169.254.
    #[test]
    fn test_proxy_filter_denies_link_local() {
        let filter = ProxyFilter::allow_all();
        let ips = single_v4(169, 254, 169, 254);

        let verdict = filter.check_host_with_ips("evil.com", &ips);
        assert!(!verdict.is_allowed());
    }
}