dnsclient/
sync.rs

1use std::io;
2use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
3use std::time::Duration;
4
5use dnssector::constants::{Class, Type};
6use dnssector::*;
7use rand::{seq::SliceRandom, Rng};
8
9use crate::backend::sync::SyncBackend;
10use crate::upstream_server::UpstreamServer;
11
12#[derive(Clone, Debug)]
13pub struct DNSClient {
14    backend: SyncBackend,
15    upstream_servers: Vec<UpstreamServer>,
16    local_v4_addr: SocketAddr,
17    local_v6_addr: SocketAddr,
18    force_tcp: bool,
19}
20
21impl DNSClient {
22    pub fn new(upstream_servers: Vec<UpstreamServer>) -> Self {
23        DNSClient {
24            backend: SyncBackend::new(Duration::new(6, 0)),
25            upstream_servers,
26            local_v4_addr: ([0; 4], 0).into(),
27            local_v6_addr: ([0; 16], 0).into(),
28            force_tcp: false,
29        }
30    }
31
32    #[cfg(unix)]
33    pub fn new_with_system_resolvers() -> Result<Self, io::Error> {
34        Ok(DNSClient::new(crate::system::default_resolvers()?))
35    }
36
37    pub fn set_timeout(&mut self, timeout: Duration) {
38        self.backend.upstream_server_timeout = timeout
39    }
40
41    pub fn set_local_v4_addr<T: Into<SocketAddr>>(&mut self, addr: T) {
42        self.local_v4_addr = addr.into()
43    }
44
45    pub fn set_local_v6_addr<T: Into<SocketAddr>>(&mut self, addr: T) {
46        self.local_v6_addr = addr.into()
47    }
48
49    pub fn force_tcp(&mut self, force_tcp: bool) {
50        self.force_tcp = force_tcp;
51    }
52
53    fn send_query_to_upstream_server(
54        &self,
55        upstream_server: &UpstreamServer,
56        query_tid: u16,
57        query_question: &Option<(Vec<u8>, u16, u16)>,
58        query: &[u8],
59    ) -> Result<ParsedPacket, io::Error> {
60        let local_addr = match upstream_server.addr {
61            SocketAddr::V4(_) => &self.local_v4_addr,
62            SocketAddr::V6(_) => &self.local_v6_addr,
63        };
64        let response = if self.force_tcp {
65            self.backend
66                .dns_exchange_tcp(local_addr, upstream_server, query)?
67        } else {
68            self.backend
69                .dns_exchange_udp(local_addr, upstream_server, query)?
70        };
71        let mut parsed_response = DNSSector::new(response)
72            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?
73            .parse()
74            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?;
75        if !self.force_tcp && parsed_response.flags() & DNS_FLAG_TC == DNS_FLAG_TC {
76            parsed_response = {
77                let response = self
78                    .backend
79                    .dns_exchange_tcp(local_addr, upstream_server, query)?;
80                DNSSector::new(response)
81                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?
82                    .parse()
83                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?
84            };
85        }
86        if parsed_response.tid() != query_tid || &parsed_response.question() != query_question {
87            return Err(io::Error::new(
88                io::ErrorKind::PermissionDenied,
89                "Unexpected response",
90            ));
91        }
92        Ok(parsed_response)
93    }
94
95    fn query_from_parsed_query(
96        &self,
97        mut parsed_query: ParsedPacket,
98    ) -> Result<ParsedPacket, io::Error> {
99        let query_tid = parsed_query.tid();
100        let query_question = parsed_query.question();
101        if query_question.is_none() || parsed_query.flags() & DNS_FLAG_QR != 0 {
102            return Err(io::Error::new(
103                io::ErrorKind::InvalidInput,
104                "No DNS question",
105            ));
106        }
107        let valid_query = parsed_query.into_packet();
108        for upstream_server in &self.upstream_servers {
109            if let Ok(parsed_response) = self.send_query_to_upstream_server(
110                upstream_server,
111                query_tid,
112                &query_question,
113                &valid_query,
114            ) {
115                return Ok(parsed_response);
116            }
117        }
118        Err(io::Error::new(
119            io::ErrorKind::InvalidInput,
120            "No response received from any servers",
121        ))
122    }
123
124    /// Send a raw query to the DNS server and return the response.
125    pub fn query_raw(&self, query: &[u8], tid_masking: bool) -> Result<Vec<u8>, io::Error> {
126        let mut parsed_query = DNSSector::new(query.to_vec())
127            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?
128            .parse()
129            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?;
130        let mut tid = 0;
131        if tid_masking {
132            tid = parsed_query.tid();
133            let mut rnd = rand::rng();
134            let masked_tid: u16 = rnd.random();
135            parsed_query.set_tid(masked_tid);
136        }
137        let mut parsed_response = self.query_from_parsed_query(parsed_query)?;
138        if tid_masking {
139            parsed_response.set_tid(tid);
140        }
141        let response = parsed_response.into_packet();
142        Ok(response)
143    }
144
145    /// Return IPv4 addresses.
146    pub fn query_a(&self, name: &str) -> Result<Vec<Ipv4Addr>, io::Error> {
147        let parsed_query = dnssector::gen::query(
148            name.as_bytes(),
149            Type::from_string("A").unwrap(),
150            Class::from_string("IN").unwrap(),
151        )
152        .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?;
153        let mut parsed_response = self.query_from_parsed_query(parsed_query)?;
154        let mut ips = vec![];
155        {
156            let mut it = parsed_response.into_iter_answer();
157            while let Some(item) = it {
158                if let Ok(IpAddr::V4(addr)) = item.rr_ip() {
159                    ips.push(addr);
160                }
161                it = item.next();
162            }
163        }
164        ips.shuffle(&mut rand::rng());
165        Ok(ips)
166    }
167
168    /// Return IPv6 addresses.
169    pub fn query_aaaa(&self, name: &str) -> Result<Vec<Ipv6Addr>, io::Error> {
170        let parsed_query = dnssector::gen::query(
171            name.as_bytes(),
172            Type::from_string("AAAA").unwrap(),
173            Class::from_string("IN").unwrap(),
174        )
175        .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?;
176        let mut parsed_response = self.query_from_parsed_query(parsed_query)?;
177        let mut ips = vec![];
178        {
179            let mut it = parsed_response.into_iter_answer();
180            while let Some(item) = it {
181                if let Ok(IpAddr::V6(addr)) = item.rr_ip() {
182                    ips.push(addr);
183                }
184                it = item.next();
185            }
186        }
187        ips.shuffle(&mut rand::rng());
188        Ok(ips)
189    }
190
191    /// Return both IPv4 and IPv6 addresses.
192    pub fn query_addrs(&self, name: &str) -> Result<Vec<IpAddr>, io::Error> {
193        let ipv4_ips = self.query_a(name)?;
194        let ipv6_ips = self.query_aaaa(name)?;
195        let mut ips: Vec<_> = ipv4_ips
196            .into_iter()
197            .map(IpAddr::from)
198            .chain(ipv6_ips.into_iter().map(IpAddr::from))
199            .collect();
200        ips.shuffle(&mut rand::rng());
201        Ok(ips)
202    }
203
204    /// Return TXT records.
205    pub fn query_txt(&self, name: &str) -> Result<Vec<Vec<u8>>, io::Error> {
206        let rr_class = Class::from_string("IN").unwrap();
207        let rr_type = Type::from_string("TXT").unwrap();
208        let parsed_query = dnssector::gen::query(name.as_bytes(), rr_type, rr_class)
209            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?;
210        let mut parsed_response = self.query_from_parsed_query(parsed_query)?;
211        let mut txts: Vec<Vec<u8>> = vec![];
212
213        let mut it = parsed_response.into_iter_answer();
214        while let Some(item) = it {
215            if item.rr_class() != rr_class.into() || item.rr_type() != rr_type.into() {
216                it = item.next();
217                continue;
218            }
219            if let Ok(RawRRData::Data(data)) = item.rr_rd() {
220                let mut txt = vec![];
221                let mut it = data.iter();
222                while let Some(&len) = it.next() {
223                    for _ in 0..len {
224                        txt.push(*it.next().ok_or_else(|| {
225                            io::Error::new(io::ErrorKind::InvalidInput, "Invalid text record")
226                        })?)
227                    }
228                }
229                txts.push(txt);
230            }
231            it = item.next();
232        }
233        Ok(txts)
234    }
235
236    /// Reverse IP lookup.
237    pub fn query_ptr(&self, ip: &IpAddr) -> Result<Vec<String>, io::Error> {
238        let rr_class = Class::from_string("IN").unwrap();
239        let rr_type = Type::from_string("PTR").unwrap();
240        let rev_name = match ip {
241            IpAddr::V4(ip) => {
242                let mut octets = ip.octets();
243                octets.reverse();
244                format!(
245                    "{}.{}.{}.{}.in-addr.arpa",
246                    octets[0], octets[1], octets[2], octets[3]
247                )
248            }
249            IpAddr::V6(ip) => {
250                let mut octets = ip.octets();
251                octets.reverse();
252                let rev = octets
253                    .iter()
254                    .map(|x| x.to_string())
255                    .collect::<Vec<_>>()
256                    .join(".");
257                format!("{}.ip6.arpa", rev)
258            }
259        };
260        let parsed_query = dnssector::gen::query(rev_name.as_bytes(), rr_type, rr_class)
261            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?;
262        let mut parsed_response = self.query_from_parsed_query(parsed_query)?;
263        let mut names: Vec<String> = vec![];
264
265        let mut it = parsed_response.into_iter_answer();
266        while let Some(item) = it {
267            if item.rr_class() != rr_class.into() || item.rr_type() != rr_type.into() {
268                it = item.next();
269                continue;
270            }
271            if let Ok(RawRRData::Data(data)) = item.rr_rd() {
272                let mut name = vec![];
273                let mut it = data.iter();
274                while let Some(&len) = it.next() {
275                    if len != 0 && !name.is_empty() {
276                        name.push(b'.');
277                    }
278                    for _ in 0..len {
279                        name.push(*it.next().ok_or_else(|| {
280                            io::Error::new(io::ErrorKind::InvalidInput, "Invalid text record")
281                        })?)
282                    }
283                }
284                if name.is_empty() {
285                    name.push(b'.');
286                }
287                if let Ok(name) = String::from_utf8(name) {
288                    match ip {
289                        IpAddr::V4(ip) => {
290                            if self.query_a(&name)?.contains(ip) {
291                                names.push(name)
292                            }
293                        }
294                        IpAddr::V6(ip) => {
295                            if self.query_aaaa(&name)?.contains(ip) {
296                                names.push(name)
297                            }
298                        }
299                    };
300                }
301            }
302            it = item.next();
303        }
304        Ok(names)
305    }
306
307    /// Return the raw record data for the given query type.
308    pub fn query_rrs_data(
309        &self,
310        name: &str,
311        query_class: &str,
312        query_type: &str,
313    ) -> Result<Vec<Vec<u8>>, io::Error> {
314        let rr_class = Class::from_string(query_class)
315            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?;
316        let rr_type = Type::from_string(query_type)
317            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?;
318        let parsed_query = dnssector::gen::query(name.as_bytes(), rr_type, rr_class)
319            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?;
320        let mut parsed_response = self.query_from_parsed_query(parsed_query)?;
321        let mut raw_rrs = vec![];
322
323        let mut it = parsed_response.into_iter_answer();
324        while let Some(item) = it {
325            if item.rr_class() != rr_class.into() || item.rr_type() != rr_type.into() {
326                it = item.next();
327                continue;
328            }
329            if let Ok(RawRRData::Data(data)) = item.rr_rd() {
330                raw_rrs.push(data.to_vec());
331            }
332            it = item.next();
333        }
334        Ok(raw_rrs)
335    }
336}
337
// Network test: resolves `one.one.one.one` through the system resolvers, falling
// back to 1.0.0.1 and 1.1.1.1 when the system resolver list cannot be obtained.
#[test]
fn test_query_a() {
    let fallback = || {
        ["1.0.0.1:53", "1.1.1.1:53"]
            .iter()
            .map(|addr| UpstreamServer::new(addr.parse::<SocketAddr>().unwrap()))
            .collect::<Vec<_>>()
    };
    let resolvers = crate::system::default_resolvers().unwrap_or_else(|_| fallback());
    let client = DNSClient::new(resolvers);
    let addrs = client.query_a("one.one.one.one").unwrap();
    assert!(addrs.contains(&Ipv4Addr::new(1, 1, 1, 1)));
}
352
// Network test: reverse-resolves 1.1.1.1 through the system resolvers, falling
// back to 1.0.0.1 and 1.1.1.1 when the system resolver list cannot be obtained.
#[test]
fn test_query_ptr() {
    let fallback = || {
        ["1.0.0.1:53", "1.1.1.1:53"]
            .iter()
            .map(|addr| UpstreamServer::new(addr.parse::<SocketAddr>().unwrap()))
            .collect::<Vec<_>>()
    };
    let resolvers = crate::system::default_resolvers().unwrap_or_else(|_| fallback());
    let client = DNSClient::new(resolvers);
    let target: IpAddr = "1.1.1.1".parse().unwrap();
    let names = client.query_ptr(&target).unwrap();
    assert_eq!(names[0], "one.one.one.one");
}
368}