// agent_fetch/dns.rs

1use std::net::{IpAddr, SocketAddr};
2
3use hickory_resolver::TokioResolver;
4
5use crate::error::FetchError;
6use crate::ip_check::is_private_ip;
7
8/// DNS resolver that validates all resolved IPs against SSRF rules.
9pub struct SafeDnsResolver {
10    resolver: TokioResolver,
11    deny_private_ips: bool,
12}
13
14impl SafeDnsResolver {
15    pub fn new(deny_private_ips: bool) -> Self {
16        let resolver = TokioResolver::builder_tokio()
17            .expect("failed to read system DNS config")
18            .build();
19
20        Self {
21            resolver,
22            deny_private_ips,
23        }
24    }
25
26    /// Resolve a hostname and validate all returned IPs.
27    /// Returns the set of validated socket addresses.
28    pub async fn resolve(&self, host: &str, port: u16) -> Result<Vec<SocketAddr>, FetchError> {
29        if let Ok(ip) = host.parse::<IpAddr>() {
30            if self.deny_private_ips && is_private_ip(ip) {
31                return Err(FetchError::PrivateIpBlocked {
32                    host: host.to_string(),
33                    resolved_ip: ip,
34                });
35            }
36            return Ok(vec![SocketAddr::new(ip, port)]);
37        }
38
39        let response =
40            self.resolver
41                .lookup_ip(host)
42                .await
43                .map_err(|e: hickory_resolver::ResolveError| {
44                    FetchError::DnsResolutionFailed(e.to_string())
45                })?;
46
47        let ips: Vec<IpAddr> = response.iter().collect();
48
49        if ips.is_empty() {
50            return Err(FetchError::DnsResolutionFailed(format!(
51                "no addresses found for {host}"
52            )));
53        }
54
55        if self.deny_private_ips {
56            for &ip in &ips {
57                if is_private_ip(ip) {
58                    return Err(FetchError::PrivateIpBlocked {
59                        host: host.to_string(),
60                        resolved_ip: ip,
61                    });
62                }
63            }
64        }
65
66        Ok(ips
67            .into_iter()
68            .map(|ip| SocketAddr::new(ip, port))
69            .collect())
70    }
71}