1use std::net::{IpAddr, SocketAddr};
2
3use hickory_resolver::TokioResolver;
4
5use crate::error::FetchError;
6use crate::ip_check::is_private_ip;
7
8pub struct SafeDnsResolver {
10 resolver: TokioResolver,
11 deny_private_ips: bool,
12}
13
14impl SafeDnsResolver {
15 pub fn new(deny_private_ips: bool) -> Self {
16 let resolver = TokioResolver::builder_tokio()
17 .expect("failed to read system DNS config")
18 .build();
19
20 Self {
21 resolver,
22 deny_private_ips,
23 }
24 }
25
26 pub async fn resolve(&self, host: &str, port: u16) -> Result<Vec<SocketAddr>, FetchError> {
29 if let Ok(ip) = host.parse::<IpAddr>() {
30 if self.deny_private_ips && is_private_ip(ip) {
31 return Err(FetchError::PrivateIpBlocked {
32 host: host.to_string(),
33 resolved_ip: ip,
34 });
35 }
36 return Ok(vec![SocketAddr::new(ip, port)]);
37 }
38
39 let response =
40 self.resolver
41 .lookup_ip(host)
42 .await
43 .map_err(|e: hickory_resolver::ResolveError| {
44 FetchError::DnsResolutionFailed(e.to_string())
45 })?;
46
47 let ips: Vec<IpAddr> = response.iter().collect();
48
49 if ips.is_empty() {
50 return Err(FetchError::DnsResolutionFailed(format!(
51 "no addresses found for {host}"
52 )));
53 }
54
55 if self.deny_private_ips {
56 for &ip in &ips {
57 if is_private_ip(ip) {
58 return Err(FetchError::PrivateIpBlocked {
59 host: host.to_string(),
60 resolved_ip: ip,
61 });
62 }
63 }
64 }
65
66 Ok(ips
67 .into_iter()
68 .map(|ip| SocketAddr::new(ip, port))
69 .collect())
70 }
71}