dnsclientx/
resolve.rs

1#[cfg(feature = "smol-async")]
2use crate::smol_async::{query_raw_tcp, query_raw_udp};
3#[cfg(feature = "std-async")]
4use crate::std_async::{query_raw_tcp, query_raw_udp};
5#[cfg(feature = "sync")]
6use crate::sync::{query_raw_tcp, query_raw_udp};
7#[cfg(feature = "tokio-async")]
8use crate::tokio_async::{query_raw_tcp, query_raw_udp};
9use crate::{err::as_io_error, reverse::reverse_dns_query, tcp::tcp_query};
10use dnssector::constants::{Class, Type};
11use dnssector::*;
12use std::{
13    io::{self, Error, ErrorKind},
14    net::{IpAddr, SocketAddr},
15    time::Duration,
16};
17
/// Timeout applied to each upstream request by a freshly created
/// `DNSClient` (5 seconds).
pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5);
19
/// A simple DNS Client.
///
/// # Example
/// ```
/// # use std::net::SocketAddr;
/// # use std::str::FromStr;
/// use dnsclientx::DNSClient;
///
/// let nameservers = vec![SocketAddr::from_str("1.0.0.1:53").unwrap()];
/// let dns_client = DNSClient::new(nameservers);
/// let ips = dns_client.query_a("one.one.one.one").unwrap();
///
/// let expected = "1.1.1.1".parse().unwrap();
/// assert!(ips.contains(&expected));
/// ```
#[derive(Clone, Debug)]
pub struct DNSClient {
    /// Timeout handed to every UDP and TCP request sent upstream.
    upstream_server_timeout: Duration,
    /// Upstream name servers, tried in the order they are stored.
    upstream_servers: Vec<SocketAddr>,
    /// Local address used for UDP queries to IPv4 upstream servers.
    local_v4_addr: SocketAddr,
    /// Local address used for UDP queries to IPv6 upstream servers.
    local_v6_addr: SocketAddr,
}
43
44impl DNSClient {
45    /// Create a new DNSClient.
46    /// # Example
47    /// ```
48    /// # use std::net::SocketAddr;
49    /// # use std::str::FromStr;
50    /// use dnsclientx::DNSClient;
51    ///
52    /// let nameservers = vec![SocketAddr::from_str("1.0.0.1:53").unwrap()];
53    /// let dns_client = DNSClient::new(nameservers);
54    /// ```
55    pub fn new(upstream_servers: Vec<SocketAddr>) -> Self {
56        DNSClient {
57            upstream_server_timeout: DEFAULT_TIMEOUT,
58            upstream_servers,
59            local_v4_addr: ([0; 4], 0).into(),
60            local_v6_addr: ([0; 16], 0).into(),
61        }
62    }
63
64    /// Set the timeout used for DNS requests.
65    /// # Example
66    /// ```
67    /// # use std::net::SocketAddr;
68    /// # use std::str::FromStr;
69    /// # use dnsclientx::DNSClient;
70    /// # use std::time::Duration;
71    /// # let nameservers = vec![SocketAddr::from_str("1.0.0.1:53").unwrap()];
72    /// let mut dns_client = DNSClient::new(nameservers);
73    /// dns_client.set_timeout(Duration::from_secs(2));
74    /// ```
75    pub fn set_timeout(&mut self, timeout: Duration) {
76        self.upstream_server_timeout = timeout
77    }
78
79    /// Set the local IPV4 socket address for use with UDP queries.
80    /// # Example
81    /// ```
82    /// # use std::net::SocketAddr;
83    /// # use std::str::FromStr;
84    /// # use dnsclientx::DNSClient;
85    /// # let nameservers = vec![SocketAddr::from_str("1.0.0.1:53").unwrap()];
86    /// let mut dns_client = DNSClient::new(nameservers);
87    /// dns_client.set_local_v4_addr(([192, 168, 1, 28], 1234));
88    /// ```
89    pub fn set_local_v4_addr<T: Into<SocketAddr>>(&mut self, addr: T) {
90        self.local_v4_addr = addr.into()
91    }
92
93    /// Set the local IPV6 socket address for use with UDP queries.
94    /// # Example
95    /// ```
96    /// # use std::net::SocketAddr;
97    /// # use std::str::FromStr;
98    /// # use dnsclientx::DNSClient;
99    /// # let nameservers = vec![SocketAddr::from_str("1.0.0.1:53").unwrap()];
100    /// let mut dns_client = DNSClient::new(nameservers);
101    /// dns_client.set_local_v6_addr(([0; 16], 1234));
102    /// ```
103    pub fn set_local_v6_addr<T: Into<SocketAddr>>(&mut self, addr: T) {
104        self.local_v6_addr = addr.into()
105    }
106
107    /// Get the IPV4 address of the given domain name.
108    ///
109    /// Returns an empty Vec if no addresses were found.
110    /// # Example
111    /// ```
112    /// # use std::net::SocketAddr;
113    /// # use std::str::FromStr;
114    /// use dnsclientx::DNSClient;
115    ///
116    /// let nameservers = vec![SocketAddr::from_str("1.0.0.1:53").unwrap()];
117    /// let dns_client = DNSClient::new(nameservers);
118    /// let ips = dns_client.query_a("one.one.one.one").unwrap();
119    ///
120    /// let expected = "1.1.1.1".parse().unwrap();
121    /// assert!(ips.contains(&expected));
122    /// ```
123    #[maybe_async::maybe_async]
124    pub async fn query_a(&self, name: &str) -> io::Result<Vec<IpAddr>> {
125        let name = encode_name(name)?;
126        let query = dnssector::r#gen::query(name.as_bytes(), Type::A, Class::IN)
127            .map_err(as_io_error(ErrorKind::InvalidInput))?;
128        let response = self.query(query).await?;
129        extract_ips(response)
130    }
131
132    /// Get the IPV6 address of the given domain name.
133    ///
134    /// Returns an empty Vec if no addresses were found.
135    /// # Example
136    /// ```
137    /// # use std::net::SocketAddr;
138    /// # use std::str::FromStr;
139    /// use dnsclientx::DNSClient;
140    ///
141    /// let nameservers = vec![SocketAddr::from_str("1.0.0.1:53").unwrap()];
142    /// let dns_client = DNSClient::new(nameservers);
143    /// let ips = dns_client.query_aaaa("one.one.one.one").unwrap();
144    ///
145    /// let expected = "2606:4700:4700::1001".parse().unwrap();
146    /// assert!(ips.contains(&expected));
147    /// ```
148    #[maybe_async::maybe_async]
149    pub async fn query_aaaa(&self, name: &str) -> io::Result<Vec<IpAddr>> {
150        let name = encode_name(name)?;
151        let query = dnssector::r#gen::query(name.as_bytes(), Type::AAAA, Class::IN)
152            .map_err(as_io_error(ErrorKind::InvalidInput))?;
153        let response = self.query(query).await?;
154        extract_ips(response)
155    }
156
157    /// Do a reverse lookup on the given IPV4 or IPV6 address.
158    ///
159    /// # Examples
160    /// ```
161    /// # use std::net::SocketAddr;
162    /// # use std::str::FromStr;
163    /// # use dnsclientx::DNSClient;
164    /// # let nameservers = vec![SocketAddr::from_str("1.0.0.1:53").unwrap()];
165    /// let dns_client = DNSClient::new(nameservers);
166    /// let ip  = "1.1.1.1".parse().unwrap();
167    /// let name = dns_client.query_ptr(ip).unwrap();
168    ///
169    /// assert!(name == "one.one.one.one");
170    /// ```
171    /// Returns an error if no name exists.
172    /// ```
173    /// # use std::net::SocketAddr;
174    /// # use std::str::FromStr;
175    /// # use std::matches;
176    /// # use dnsclientx::DNSClient;
177    /// # let nameservers = vec![SocketAddr::from_str("1.0.0.1:53").unwrap()];
178    /// let dns_client = DNSClient::new(nameservers);
179    /// // An IP address that has no name
180    /// let ip  = "1.2.3.4".parse().unwrap();
181    /// let res = dns_client.query_ptr(ip);
182    ///
183    /// assert!(matches!(res, Err(e) if e.kind() == std::io::ErrorKind::NotFound));
184    /// ```
185    #[maybe_async::maybe_async]
186    pub async fn query_ptr(&self, ip: IpAddr) -> io::Result<String> {
187        let in_addr = reverse_dns_query(ip);
188        let query = dnssector::r#gen::query(&in_addr, Type::PTR, Class::IN)
189            .map_err(as_io_error(ErrorKind::InvalidInput))?;
190        let response = self.query(query).await?;
191        extract_names(response).map(|mut v| v.remove(0))
192    }
193
194    /// Get the name servers for the given domain.
195    ///
196    /// # Examples
197    /// ```
198    /// # use std::net::SocketAddr;
199    /// # use std::str::FromStr;
200    /// # use dnsclientx::DNSClient;
201    /// # let nameservers = vec![SocketAddr::from_str("1.0.0.1:53").unwrap()];
202    /// let dns_client = DNSClient::new(nameservers);
203    /// let ns = dns_client.query_ns("one.one.one").unwrap();
204    /// ```
205    #[maybe_async::maybe_async]
206    pub async fn query_ns(&self, domain: &str) -> io::Result<Vec<String>> {
207        let query = dnssector::r#gen::query(domain.as_bytes(), Type::NS, Class::IN)
208            .map_err(as_io_error(ErrorKind::InvalidInput))?;
209        let response = self.query(query).await?;
210        extract_names(response).or_else(|e| {
211            if e.kind() == ErrorKind::NotFound {
212                Ok(Vec::new())
213            } else {
214                Err(e)
215            }
216        })
217    }
218
219    #[maybe_async::maybe_async]
220    async fn query(&self, packet: ParsedPacket) -> io::Result<ParsedPacket> {
221        let is_compressed = matches!(
222            packet.qtype_qclass(),
223            Some((rr_type, _class)) if rr_type == Type::NS as u16
224        );
225        let raw_packet = packet.into_packet();
226        for i in 0..self.upstream_servers.len() {
227            let response = self
228                .query_upstream(&self.upstream_servers[i], &raw_packet, is_compressed)
229                .await;
230            if response.is_ok() || i >= self.upstream_servers.len() - 1 {
231                return response;
232            }
233        }
234        unreachable!("query must be ok or err");
235    }
236
237    #[maybe_async::maybe_async]
238    async fn query_upstream(
239        &self,
240        upstream: &SocketAddr,
241        packet: &[u8],
242        is_compressed_response: bool,
243    ) -> io::Result<ParsedPacket> {
244        let local_addr = match upstream {
245            SocketAddr::V4(_) => &self.local_v4_addr,
246            SocketAddr::V6(_) => &self.local_v6_addr,
247        };
248        let raw_response =
249            query_raw_udp(local_addr, upstream, packet, self.upstream_server_timeout).await?;
250        let response = parse_response(raw_response, is_compressed_response)?;
251        if response.flags() & DNS_FLAG_TC != DNS_FLAG_TC {
252            return Ok(response);
253        }
254        // If this point is reached -- upgrade to TCP
255        let tcp_packet = tcp_query(packet);
256        let raw_response =
257            query_raw_tcp(upstream, &tcp_packet, self.upstream_server_timeout).await?;
258        parse_response(raw_response, is_compressed_response)
259    }
260}
261
262fn parse_response(raw: Vec<u8>, is_compressed: bool) -> io::Result<ParsedPacket> {
263    let mut raw_response = raw;
264    if is_compressed {
265        raw_response =
266            Compress::uncompress(&raw_response).map_err(as_io_error(ErrorKind::InvalidData))?;
267    }
268    DNSSector::new(raw_response)
269        .map_err(as_io_error(ErrorKind::InvalidData))?
270        .parse()
271        .map_err(as_io_error(ErrorKind::InvalidData))
272}
273
274fn extract_ips(mut packet: ParsedPacket) -> io::Result<Vec<IpAddr>> {
275    use std::result::Result as StdResult;
276
277    let mut ips = Vec::new();
278    let mut response = packet.into_iter_answer();
279    while let Some(i) = response {
280        ips.push(i.rr_ip());
281        response = i.next();
282    }
283    let (ips, errors): (Vec<_>, Vec<_>) = ips.into_iter().partition(StdResult::is_ok);
284    if ips.is_empty() {
285        if let Some(Err(e)) = errors.into_iter().next() {
286            return Err(Error::new(ErrorKind::InvalidData, e));
287        }
288    }
289    let ips: Vec<_> = ips.into_iter().map(StdResult::unwrap).collect();
290    Ok(ips)
291}
292
293fn extract_names(mut packet: ParsedPacket) -> io::Result<Vec<String>> {
294    let mut response = packet.into_iter_answer();
295    let mut ret = Vec::new();
296    while let Some(i) = response {
297        let raw_name = &i.rdata_slice()[DNS_RR_HEADER_SIZE..];
298        let name = parse_tlv_name(raw_name);
299        ret.push(name);
300        response = i.next();
301    }
302    if ret.is_empty() {
303        return Err(ErrorKind::NotFound.into());
304    }
305    ret.iter().map(|i| decode_name(i)).collect()
306}
307
/// Convert a length-prefixed DNS name into its dotted byte representation.
///
/// `raw` is read as a sequence of labels, each made of one length octet
/// followed by that many data octets. Labels are joined with `.`. Parsing
/// stops at the first zero length octet or at the end of `raw`; a label that
/// is cut short by the end of `raw` contributes the bytes that are present.
///
/// Zero octets inside a label's data are kept: only a zero in length position
/// terminates the name.
///
/// NOTE(review): compression pointers are not interpreted here; the input is
/// presumably expected to be an uncompressed name — confirm against callers.
fn parse_tlv_name(raw: &[u8]) -> Vec<u8> {
    let mut result = Vec::with_capacity(raw.len());
    let mut rest = raw;
    while let Some((&len, tail)) = rest.split_first() {
        if len == 0 {
            // Root label: end of the name.
            break;
        }
        if !result.is_empty() {
            result.push(b'.');
        }
        // Tolerate truncated input by taking only what is available.
        let take = usize::from(len).min(tail.len());
        let (label, remainder) = tail.split_at(take);
        result.extend_from_slice(label);
        rest = remainder;
    }
    result
}
326
327fn encode_name(name: &str) -> io::Result<String> {
328    let parts: io::Result<Vec<String>> = name
329        .split('.')
330        .map(|part| {
331            if part.is_ascii() {
332                Ok(part.to_string())
333            } else {
334                unic_idna_punycode::encode_str(part)
335                    .map(|s| "xn--".to_string() + &s)
336                    .ok_or_else(|| ErrorKind::InvalidInput.into())
337            }
338        })
339        .collect();
340    let parts = parts?;
341    let ret = parts.join(".");
342    Ok(ret)
343}
344
345fn decode_name(name: &[u8]) -> io::Result<String> {
346    let parts: io::Result<Vec<String>> = name
347        .split(|ch| *ch == b'.')
348        .map(|part| {
349            if let Some(code) = part.strip_prefix(b"xn--") {
350                String::from_utf8(code.to_vec())
351                    .map_err(as_io_error(ErrorKind::InvalidData))
352                    .and_then(|code| {
353                        unic_idna_punycode::decode_to_string(&code)
354                            .ok_or_else(|| ErrorKind::InvalidData.into())
355                    })
356            } else {
357                String::from_utf8(part.to_vec()).map_err(as_io_error(ErrorKind::InvalidData))
358            }
359        })
360        .collect();
361    let parts = parts?;
362    let ret = parts.join(".");
363    Ok(ret)
364}
365
// Most tests below send real queries to public resolvers and therefore need
// network access. The `*_offline` tests exercise the pure helpers only.
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(not(feature = "sync"))]
    use std::future::Future;
    use std::{
        net::{Ipv4Addr, Ipv6Addr},
        str::FromStr,
    };

    const EXAMPLE_FQDN: &str = "one.one.one.one";
    const EXAMPLE_DOMAIN: &str = "one.one.one";
    const EXAMPLE_DOMAIN_NS: &str = "ns.cloudflare.com";
    const EXAMPLE_IPV4: IpAddr = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
    const EXAMPLE_IPV6: IpAddr =
        IpAddr::V6(Ipv6Addr::new(0x2606, 0x4700, 0x4700, 0, 0, 0, 0, 0x1111));
    const EXAMPLE_IDN: &str = "日本.icom.museum";
    const EXAMPLE_IDN_PUNYCODE: &str = "xn--wgv71a.icom.museum";
    const EXAMPLE_IDN_IP: IpAddr = IpAddr::V4(Ipv4Addr::new(81, 201, 190, 55));

    // One `block_on` function per async runtime feature; the `block_on!`
    // macro below hides the difference between async and sync builds.
    #[cfg(feature = "std-async")]
    fn block_on<F: Future>(future: F) -> F::Output {
        use async_std::task;
        task::block_on(future)
    }

    #[cfg(feature = "smol-async")]
    fn block_on<F: Future>(future: F) -> F::Output {
        smol::block_on(future)
    }

    #[cfg(feature = "tokio-async")]
    fn block_on<F: Future>(future: F) -> F::Output {
        use tokio::runtime;
        let rt = runtime::Builder::new_current_thread()
            .enable_time()
            .enable_io()
            .build()
            .unwrap();
        rt.block_on(future)
    }

    #[cfg(not(feature = "sync"))]
    macro_rules! block_on {
        ($b:expr) => {
            block_on(async move { $b.await })
        };
    }

    // In sync builds the client methods return their result directly.
    #[cfg(feature = "sync")]
    macro_rules! block_on {
        ($b:expr) => {
            $b
        };
    }

    fn dns_servers() -> Vec<SocketAddr> {
        vec![
            SocketAddr::from_str("1.0.0.1:53").unwrap(),
            SocketAddr::from_str("1.1.1.1:53").unwrap(),
        ]
    }

    fn slow_dns_servers() -> Vec<SocketAddr> {
        vec![
            // Yerevan
            SocketAddr::from_str("109.75.41.201:53").unwrap(),
            // NTT Japan
            SocketAddr::from_str("124.99.9.4:53").unwrap(),
        ]
    }

    #[test]
    fn query_a() {
        let dns_client = DNSClient::new(dns_servers());
        let r = block_on!(dns_client.query_a(EXAMPLE_FQDN)).unwrap();
        let expected = EXAMPLE_IPV4;
        assert!(r.contains(&expected), "Expected {} got {:?}", expected, r);
    }

    #[test]
    fn query_timeout() {
        let mut dns_client = DNSClient::new(slow_dns_servers());
        dns_client.set_timeout(Duration::from_millis(1));
        let r = block_on!(dns_client.query_a(EXAMPLE_FQDN));
        assert!(
            matches!(&r, Err(e) if e.kind() == ErrorKind::TimedOut || e.kind() == ErrorKind::WouldBlock),
            "Expected timeout got {:?}",
            r,
        );
    }

    #[test]
    fn query_utf8() {
        let dns_client = DNSClient::new(dns_servers());
        let jp_res = block_on!(dns_client.query_a(EXAMPLE_IDN)).unwrap();
        let expected = EXAMPLE_IDN_IP;
        assert!(
            jp_res.contains(&expected),
            "Expected {} for {} got {:?}",
            expected,
            EXAMPLE_IDN,
            jp_res
        );
    }

    #[test]
    fn query_aaaa() {
        let dns_client = DNSClient::new(dns_servers());
        let r = block_on!(dns_client.query_aaaa(EXAMPLE_FQDN)).unwrap();
        let expected = EXAMPLE_IPV6;
        assert!(r.contains(&expected), "Expected {} got {:?}", expected, r);
    }

    #[test]
    fn query_ptr_ipv4() {
        let dns_client = DNSClient::new(dns_servers());
        let r = block_on!(dns_client.query_ptr(EXAMPLE_IPV4)).unwrap();
        let expected = EXAMPLE_FQDN;
        assert!(r == expected, "Expected {} got {:?}", expected, r);
    }

    #[test]
    fn query_ptr_ipv6() {
        let dns_client = DNSClient::new(dns_servers());
        let r = block_on!(dns_client.query_ptr(EXAMPLE_IPV6)).unwrap();
        let expected = EXAMPLE_FQDN;
        assert!(r == expected, "Expected {} got {:?}", expected, r);
    }

    #[test]
    fn query_ptr_utf8() {
        // Are there any real examples of this?
        // For now, just test the puny decoder.
        let r = decode_name(EXAMPLE_IDN_PUNYCODE.as_bytes()).unwrap();
        let expected = EXAMPLE_IDN;
        assert!(r == expected, "Expected {} got {:?}", expected, r);
    }

    #[test]
    fn query_ns() {
        let dns_client = DNSClient::new(dns_servers());
        let r = block_on!(dns_client.query_ns(EXAMPLE_DOMAIN)).unwrap();
        assert!(
            r.iter().any(|n| n.ends_with(EXAMPLE_DOMAIN_NS)),
            "Expected {} got {:?}",
            EXAMPLE_DOMAIN_NS,
            r
        );
    }

    #[test]
    fn encode_name_offline() {
        // ASCII names pass through untouched; IDN labels become punycode.
        assert_eq!(encode_name(EXAMPLE_FQDN).unwrap(), EXAMPLE_FQDN);
        assert_eq!(encode_name(EXAMPLE_IDN).unwrap(), EXAMPLE_IDN_PUNYCODE);
    }

    #[test]
    fn parse_tlv_name_offline() {
        let raw = [3, b'o', b'n', b'e', 3, b'c', b'o', b'm', 0];
        assert_eq!(parse_tlv_name(&raw), b"one.com".to_vec());
        assert!(parse_tlv_name(&[]).is_empty());
    }
}