doh_dns/
dns.rs

1use crate::client::{DnsClient, HyperDnsClient};
2use crate::error::{DnsError, QueryError};
3use crate::status::RCode;
4use crate::{Dns, DnsAnswer, DnsHttpsServer, DnsResponse};
5use hyper::Uri;
6use idna;
7use log::error;
8use std::time::Duration;
9use tokio::time::timeout;
10
11impl Default for Dns<HyperDnsClient> {
12    fn default() -> Dns<HyperDnsClient> {
13        Dns {
14            client: HyperDnsClient::default(),
15            servers: vec![
16                DnsHttpsServer::Google(Duration::from_secs(3)),
17                DnsHttpsServer::Cloudflare1_1_1_1(Duration::from_secs(10)),
18            ],
19        }
20    }
21}
22
23impl<C: DnsClient> Dns<C> {
24    /// Creates an instance with the given servers along with their respective timeouts
25    /// (in seconds). These servers are tried in the given order. If a request fails on
26    /// the first one, each subsequent server is tried. Only on certain failures a new
27    /// request is retried such as a connection failure or certain server return codes.
28    pub fn with_servers(servers: &[DnsHttpsServer]) -> Result<Dns<C>, DnsError> {
29        if servers.is_empty() {
30            return Err(DnsError::NoServers);
31        }
32        Ok(Dns {
33            client: C::default(),
34            servers: servers.to_vec(),
35        })
36    }
37
38    /// Returns MX records in order of priority for the given name. It removes the priorities
39    /// from the data.
40    pub async fn resolve_mx_and_sort(&self, domain: &str) -> Result<Vec<DnsAnswer>, DnsError> {
41        match self.client_request(domain, &RTYPE_mx).await {
42            Err(e) => Err(DnsError::Query(e)),
43            Ok(res) => match num::FromPrimitive::from_u32(res.Status) {
44                Some(RCode::NoError) => {
45                    let mut mxs = res
46                        .Answer
47                        .unwrap_or_else(|| vec![])
48                        .iter()
49                        .filter_map(|a| {
50                            // Get only MX records.
51                            if a.r#type == RTYPE_mx.0 {
52                                // Get only the records that have a priority.
53                                let mut parts = a.data.split_ascii_whitespace();
54                                if let Some(part_1) = parts.next() {
55                                    // Convert priority to an integer.
56                                    if let Ok(priority) = part_1.parse::<u32>() {
57                                        if let Some(mx) = parts.next() {
58                                            // Change data from "priority name" -> "name".
59                                            let mut m = a.clone();
60                                            m.data = mx.to_string();
61                                            return Some((m, priority));
62                                        }
63                                    }
64                                }
65                            }
66                            None
67                        })
68                        .collect::<Vec<_>>();
69                    // Order MX records by priority.
70                    mxs.sort_unstable_by_key(|x| x.1);
71                    Ok(mxs.into_iter().map(|x| x.0).collect())
72                }
73                Some(code) => Err(DnsError::Status(code)),
74                None => Err(DnsError::Status(RCode::Unknown)),
75            },
76        }
77    }
78
79    // Generates the DNS over HTTPS request on the given name for rtype. It filters out
80    // results that are not of the given rtype with the exception of `ANY`.
81    async fn request_and_process(
82        &self,
83        name: &str,
84        rtype: &Rtype,
85    ) -> Result<Vec<DnsAnswer>, DnsError> {
86        match self.client_request(name, rtype).await {
87            Err(e) => Err(DnsError::Query(e)),
88            Ok(res) => match num::FromPrimitive::from_u32(res.Status) {
89                Some(RCode::NoError) => Ok(res
90                    .Answer
91                    .unwrap_or_else(|| vec![])
92                    .into_iter()
93                    // Get only the record types requested. There is only exception and that is
94                    // the ANY record which has a value of 0.
95                    .filter(|a| a.r#type == rtype.0 || rtype.0 == 0)
96                    .collect::<Vec<_>>()),
97                Some(code) => Err(DnsError::Status(code)),
98                None => Err(DnsError::Status(RCode::Unknown)),
99            },
100        }
101    }
102
103    // Creates the HTTPS request to the server. In certain occasions, it retries to a new server
104    // if one is available.
105    async fn client_request(&self, name: &str, rtype: &Rtype) -> Result<DnsResponse, QueryError> {
106        // Name has to be puny encoded.
107        let name = match idna::domain_to_ascii(name) {
108            Ok(name) => name,
109            Err(e) => return Err(QueryError::InvalidName(format!("{:?}", e))),
110        };
111        let mut error = QueryError::Unknown;
112        for server in self.servers.iter() {
113            let url = format!("{}?name={}&type={}", server.uri(), name, rtype.1);
114            let endpoint = match url.parse::<Uri>() {
115                Err(e) => return Err(QueryError::InvalidEndpoint(e.to_string())),
116                Ok(endpoint) => endpoint,
117            };
118
119            error = match timeout(server.timeout(), self.client.get(endpoint)).await {
120                Ok(Err(e)) => QueryError::Connection(e.to_string()),
121                Ok(Ok(res)) => {
122                    match res.status().as_u16() {
123                        200 => match hyper::body::to_bytes(res).await {
124                            Err(e) => QueryError::ReadResponse(e.to_string()),
125                            Ok(body) => match serde_json::from_slice::<DnsResponse>(&body) {
126                                Err(e) => QueryError::ParseResponse(e.to_string()),
127                                Ok(res) => {
128                                    return Ok(res);
129                                }
130                            },
131                        },
132                        400 => return Err(QueryError::BadRequest400),
133                        413 => return Err(QueryError::PayloadTooLarge413),
134                        414 => return Err(QueryError::UriTooLong414),
135                        415 => return Err(QueryError::UnsupportedMediaType415),
136                        501 => return Err(QueryError::NotImplemented501),
137                        // If the following errors occur, the request will be retried on
138                        // the next server if one is available.
139                        429 => QueryError::TooManyRequests429,
140                        500 => QueryError::InternalServerError500,
141                        502 => QueryError::BadGateway502,
142                        504 => QueryError::ResolverTimeout504,
143                        _ => QueryError::Unknown,
144                    }
145                }
146                Err(_) => QueryError::Connection(format!(
147                    "connection timeout after {:?}",
148                    server.timeout()
149                )),
150            };
151            error!("request error on URL {}: {}", url, error);
152        }
153        Err(error)
154    }
155}
156
/// A DNS record type: its numeric code (`.0`) and the lowercase name (`.1`) that is sent
/// as the `type` query parameter of a DNS-over-HTTPS request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Rtype(pub u32, pub &'static str);
158
// Generates the record-type API from a list of `(name, number);` entries, each optionally
// preceded by doc comments. For every entry it emits:
//   * `Dns::resolve_<name>`: a public method carrying the entry's doc comments. It queries
//     that record type through `request_and_process`.
//   * `RTYPE_<name>`: an `Rtype` constant pairing the number with the lowercase name.
// It also emits `Dns::resolve_str_type` and `Dns::rtype_to_name`. Both are generated from
// the same entries, so the name, number and method mappings cannot drift apart.
macro_rules! rtypes {
    (
        $(
            $(#[$docs:meta])*
            ($konst:ident, $num:expr);
        )+
    ) => {
        // `paste::item!` turns `[<resolve_ $konst>]` / `[<RTYPE_ $konst>]` into single
        // identifiers such as `resolve_mx` / `RTYPE_mx`.
        paste::item! {
            impl<C: DnsClient> Dns<C> {
                $(
                    $(#[$docs])*
                    pub async fn [<resolve_ $konst>](&self, name: &str) -> Result<Vec<DnsAnswer>, DnsError> {
                        self.request_and_process(name, &[<RTYPE_ $konst>]).await
                    }
                )+

                /// Queries `name` for the record type named by `rtype` (e.g. `"mx"` or `"AAAA"`).
                ///
                /// `rtype` is matched ASCII case-insensitively against the supported record
                /// type names. An unrecognised name yields `Err(DnsError::InvalidRecordType)`
                /// and no request is sent.
                pub async fn resolve_str_type(&self, name: &str, rtype: &str) -> Result<Vec<DnsAnswer>, DnsError> {
                    match rtype.to_ascii_lowercase().as_ref() {
                        $(
                        stringify!($konst) => self.[<resolve_ $konst>](name).await,
                        )+
                        _ => Err(DnsError::InvalidRecordType),
                    }
                }

                /// Converts the given numeric record type to its uppercase name (e.g. `15` to
                /// `"MX"`). Numbers not in the generated list yield `"UNKNOWN"`.
                pub fn rtype_to_name(&self, rtype: u32) -> String {
                    // `$num` is reused here as a match pattern, so each entry's number must be
                    // an integer literal.
                    let name = match rtype {
                        $(
                        $num => stringify!($konst),
                        )+
                        _ => "unknown",
                    };
                    name.to_ascii_uppercase()
                }
            }
        $(
            // The lowercase suffix (e.g. `RTYPE_mx`) keeps the constant aligned with its
            // method name, hence the lint allowance.
            #[allow(non_upper_case_globals)]
            const [<RTYPE_ $konst>]: Rtype = Rtype($num, stringify!($konst));
        )+
        }
    }
}
202
// The record type numbers below were taken from the IANA "Resource Record (RR) TYPEs"
// registry:
// https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-4
//
// NOTE(review): IANA registers ANY (`*`) as type 255, but this list uses 0. The number is
// only used locally: `request_and_process` treats 0 as "keep answers of every type", and
// `rtype_to_name` maps it back to "ANY". The request itself sends the name (`any`).
// Changing it requires updating `request_and_process` too — confirm before touching.
rtypes! {
    /// Queries a host address (IPv4) record for the given name.
    (a, 1);
    /// Queries an IPv6 address record for the given name.
    (aaaa, 28);
    /// Queries all record types for the given name.
    (any, 0);
    /// Queries a certification authority authorization (CAA) record for the given name.
    (caa, 257);
    /// Queries a child DS (CDS) record for the given name.
    (cds, 59);
    /// Queries a CERT record for the given name.
    (cert, 37);
    /// Queries the canonical name for an alias for the given name.
    (cname, 5);
    /// Queries a DNAME record for the given name.
    (dname, 39);
    /// Queries a DNSKEY record for the given name.
    (dnskey, 48);
    /// Queries a delegation signer record for the given name.
    (ds, 43);
    /// Queries a host information record for the given name.
    (hinfo, 13);
    /// Queries an IPSECKEY record for the given name.
    (ipseckey, 45);
    /// Queries a mail exchange record for the given name.
    (mx, 15);
    /// Queries a naming authority pointer record for the given name.
    (naptr, 35);
    /// Queries an authoritative name server record for the given name.
    (ns, 2);
    /// Queries an NSEC record for the given name.
    (nsec, 47);
    /// Queries an NSEC3 record for the given name.
    (nsec3, 50);
    /// Queries an NSEC3PARAM record for the given name.
    (nsec3param, 51);
    /// Queries a domain name pointer record for the given name.
    (ptr, 12);
    /// Queries a responsible person record for the given name.
    (rp, 17);
    /// Queries an RRSIG record for the given name.
    (rrsig, 46);
    /// Queries the start of a zone of authority record for the given name.
    (soa, 6);
    /// Queries an SPF record for the given name. See RFC 7208.
    (spf, 99);
    /// Queries a server selection record for the given name.
    (srv, 33);
    /// Queries an SSH key fingerprint record for the given name.
    (sshfp, 44);
    /// Queries a TLSA record for the given name.
    (tlsa, 52);
    /// Queries a text strings record for the given name.
    (txt, 16);
    /// Queries a well known service description record for the given name.
    (wks, 11);
}
263
/// Unit tests for `Dns`, driven by a mock HTTP client that replays canned
/// DNS-over-HTTPS JSON responses instead of contacting real servers.
#[cfg(test)]
pub mod tests {
    use async_trait::async_trait;
    use hyper::StatusCode;
    use hyper::{error::Result as HyperResult, Body, Response, Uri};
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };
    /// Test double for `DnsClient`: the n-th call to `get` returns the n-th
    /// `(body, status)` pair from `response`, whatever URI was requested.
    struct MockDnsClient {
        // Canned `(body, status)` pairs, consumed in call order.
        response: Vec<(String, StatusCode)>,
        // Number of `get` calls made so far; used as the index into `response`.
        counter: Arc<AtomicUsize>,
    }

    impl MockDnsClient {
        /// Creates a mock that answers successive `get` calls with `response`, in order.
        fn new(response: &[(String, StatusCode)]) -> MockDnsClient {
            MockDnsClient {
                response: response.to_vec(),
                counter: Arc::new(AtomicUsize::new(0)),
            }
        }
    }

    #[async_trait]
    impl DnsClient for MockDnsClient {
        async fn get(&self, _uri: Uri) -> HyperResult<Response<Body>> {
            let counter = Arc::clone(&self.counter);
            let index = counter.fetch_add(1, Ordering::SeqCst);
            // If more calls than results are given, an out of bounds error should be obtained.
            let chunks: Vec<Result<_, ::std::io::Error>> = vec![Ok(self.response[index].0.clone())];
            // Serve the canned body as a single-chunk stream with the canned status code.
            let stream = futures_util::stream::iter(chunks);
            let body = Body::wrap_stream(stream);
            let mut response = Response::new(body);
            *response.status_mut() = self.response[index].1;
            Ok(response)
        }
    }

    impl Default for MockDnsClient {
        // An empty mock: any call to `get` on it panics on the out-of-bounds index.
        fn default() -> MockDnsClient {
            MockDnsClient {
                response: vec![],
                counter: Arc::new(AtomicUsize::new(0)),
            }
        }
    }

    use super::*;

    /// An A query drops the CNAME answer (type 5) and keeps the four A records in the
    /// order the server returned them.
    #[tokio::test]
    async fn test_a() {
        let response = String::from(
            r#"
        {
  "Status": 0,
  "TC": false,
  "RD": true,
  "RA": true,
  "AD": false,
  "CD": false,
  "Question": [
    {
      "name": "www.sendgrid.com.",
      "type": 1
    }
  ],
  "Answer": [
    {
      "name": "www.sendgrid.com.",
      "type": 5,
      "TTL": 988,
      "data": "sendgrid.com."
    },
    {
      "name": "sendgrid.com.",
      "type": 1,
      "TTL": 89,
      "data": "169.45.113.198"
    },
    {
      "name": "sendgrid.com.",
      "type": 1,
      "TTL": 89,
      "data": "167.89.118.63"
    },
    {
      "name": "sendgrid.com.",
      "type": 1,
      "TTL": 89,
      "data": "169.45.89.183"
    },
    {
      "name": "sendgrid.com.",
      "type": 1,
      "TTL": 89,
      "data": "167.89.118.65"
    }
  ],
  "Comment": "Response from 2600:1801:13::1."
    }"#,
        );
        let d = Dns {
            client: MockDnsClient::new(&[(response, StatusCode::OK)]),
            servers: vec![DnsHttpsServer::Google(Duration::from_secs(5))],
        };
        let r = d.resolve_a("sendgrid.com").await.unwrap();
        assert_eq!(r.len(), 4);
        assert_eq!(r[0].name, "sendgrid.com.");
        assert_eq!(r[0].data, "169.45.113.198");
        assert_eq!(r[0].r#type, 1);
        assert_eq!(r[0].TTL, 89);
        assert_eq!(r[1].name, "sendgrid.com.");
        assert_eq!(r[1].data, "167.89.118.63");
        assert_eq!(r[1].r#type, 1);
        assert_eq!(r[1].TTL, 89);
        assert_eq!(r[2].name, "sendgrid.com.");
        assert_eq!(r[2].data, "169.45.89.183");
        assert_eq!(r[2].r#type, 1);
        assert_eq!(r[2].TTL, 89);
        assert_eq!(r[3].name, "sendgrid.com.");
        assert_eq!(r[3].data, "167.89.118.65");
        assert_eq!(r[3].r#type, 1);
        assert_eq!(r[3].TTL, 89);
    }

    /// `resolve_mx_and_sort` strips the priorities and orders records by them, while
    /// `resolve_mx` returns the raw `"priority name"` data in server order.
    #[tokio::test]
    async fn test_mx() {
        let response = String::from(
            r#"
        {
  "Status": 0,
  "TC": false,
  "RD": true,
  "RA": true,
  "AD": false,
  "CD": false,
  "Question": [
    {
      "name": "gmail.com.",
      "type": 15
    }
  ],
  "Answer": [
    {
      "name": "gmail.com.",
      "type": 15,
      "TTL": 3599,
      "data": "30 alt3.gmail-smtp-in.l.google.com."
    },
    {
      "name": "gmail.com.",
      "type": 15,
      "TTL": 3599,
      "data": "5 gmail-smtp-in.l.google.com."
    },
    {
      "name": "gmail.com.",
      "type": 15,
      "TTL": 3599,
      "data": "40 alt4.gmail-smtp-in.l.google.com."
    },
    {
      "name": "gmail.com.",
      "type": 15,
      "TTL": 3599,
      "data": "10 alt1.gmail-smtp-in.l.google.com."
    },
    {
      "name": "gmail.com.",
      "type": 15,
      "TTL": 3599,
      "data": "20 alt2.gmail-smtp-in.l.google.com."
    }
  ],
  "Comment": "Response from 2001:4860:4802:32::a."
}"#,
        );
        let d = Dns {
            client: MockDnsClient::new(&[(response.clone(), StatusCode::OK)]),
            servers: vec![DnsHttpsServer::Google(Duration::from_secs(5))],
        };
        let r = d.resolve_mx_and_sort("gmail.com").await.unwrap();
        assert_eq!(r.len(), 5);
        assert_eq!(r[0].name, "gmail.com.");
        assert_eq!(r[0].data, "gmail-smtp-in.l.google.com.");
        assert_eq!(r[0].r#type, 15);
        assert_eq!(r[0].TTL, 3599);
        assert_eq!(r[1].name, "gmail.com.");
        assert_eq!(r[1].data, "alt1.gmail-smtp-in.l.google.com.");
        assert_eq!(r[1].r#type, 15);
        assert_eq!(r[1].TTL, 3599);
        assert_eq!(r[2].name, "gmail.com.");
        assert_eq!(r[2].data, "alt2.gmail-smtp-in.l.google.com.");
        assert_eq!(r[2].r#type, 15);
        assert_eq!(r[2].TTL, 3599);
        assert_eq!(r[3].name, "gmail.com.");
        assert_eq!(r[3].data, "alt3.gmail-smtp-in.l.google.com.");
        assert_eq!(r[3].r#type, 15);
        assert_eq!(r[3].TTL, 3599);
        assert_eq!(r[4].name, "gmail.com.");
        assert_eq!(r[4].data, "alt4.gmail-smtp-in.l.google.com.");
        assert_eq!(r[4].r#type, 15);
        assert_eq!(r[4].TTL, 3599);

        // A fresh mock is needed: the first one has already consumed its only response.
        let d = Dns {
            client: MockDnsClient::new(&[(response, StatusCode::OK)]),
            servers: vec![DnsHttpsServer::Google(Duration::from_secs(5))],
        };
        let r = d.resolve_mx("gmail.com").await.unwrap();
        assert_eq!(r.len(), 5);
        assert_eq!(r[0].name, "gmail.com.");
        assert_eq!(r[0].data, "30 alt3.gmail-smtp-in.l.google.com.");
        assert_eq!(r[0].r#type, 15);
        assert_eq!(r[0].TTL, 3599);
        assert_eq!(r[1].name, "gmail.com.");
        assert_eq!(r[1].data, "5 gmail-smtp-in.l.google.com.");
        assert_eq!(r[1].r#type, 15);
        assert_eq!(r[1].TTL, 3599);
        assert_eq!(r[2].name, "gmail.com.");
        assert_eq!(r[2].data, "40 alt4.gmail-smtp-in.l.google.com.");
        assert_eq!(r[2].r#type, 15);
        assert_eq!(r[2].TTL, 3599);
        assert_eq!(r[3].name, "gmail.com.");
        assert_eq!(r[3].data, "10 alt1.gmail-smtp-in.l.google.com.");
        assert_eq!(r[3].r#type, 15);
        assert_eq!(r[3].TTL, 3599);
        assert_eq!(r[4].name, "gmail.com.");
        assert_eq!(r[4].data, "20 alt2.gmail-smtp-in.l.google.com.");
        assert_eq!(r[4].r#type, 15);
        assert_eq!(r[4].TTL, 3599);
    }

    /// TXT data is returned verbatim, including its surrounding double quotes.
    #[tokio::test]
    async fn test_txt() {
        let response = String::from(
            r#"
        {
  "Status": 0,
  "TC": false,
  "RD": true,
  "RA": true,
  "AD": false,
  "CD": false,
  "Question": [
    {
      "name": "google.com.",
      "type": 16
    }
  ],
  "Answer": [
    {
      "name": "google.com.",
      "type": 16,
      "TTL": 3599,
      "data": "\"facebook-domain-verification=22rm551cu4k0ab0bxsw536tlds4h95\""
    },
    {
      "name": "google.com.",
      "type": 16,
      "TTL": 3599,
      "data": "\"globalsign-smime-dv=CDYX+XFHUw2wml6/Gb8+59BsH31KzUr6c1l2BPvqKX8=\""
    },
    {
      "name": "google.com.",
      "type": 16,
      "TTL": 299,
      "data": "\"docusign=05958488-4752-4ef2-95eb-aa7ba8a3bd0e\""
    },
    {
      "name": "google.com.",
      "type": 16,
      "TTL": 299,
      "data": "\"docusign=1b0a6754-49b1-4db5-8540-d2c12664b289\""
    },
    {
      "name": "google.com.",
      "type": 16,
      "TTL": 3599,
      "data": "\"v=spf1 include:_spf.google.com ~all\""
    }
  ],
  "Comment": "Response from 216.239.36.10."
}"#,
        );
        let d = Dns {
            client: MockDnsClient::new(&[(response, StatusCode::OK)]),
            servers: vec![DnsHttpsServer::Google(Duration::from_secs(5))],
        };
        let r = d.resolve_txt("google.com").await.unwrap();
        assert_eq!(r.len(), 5);
        assert_eq!(r[0].name, "google.com.");
        assert_eq!(
            r[0].data,
            "\"facebook-domain-verification=22rm551cu4k0ab0bxsw536tlds4h95\""
        );
        assert_eq!(r[0].r#type, 16);
        assert_eq!(r[0].TTL, 3599);
        assert_eq!(r[1].name, "google.com.");
        assert_eq!(
            r[1].data,
            "\"globalsign-smime-dv=CDYX+XFHUw2wml6/Gb8+59BsH31KzUr6c1l2BPvqKX8=\""
        );
        assert_eq!(r[1].r#type, 16);
        assert_eq!(r[1].TTL, 3599);
        assert_eq!(r[2].name, "google.com.");
        assert_eq!(
            r[2].data,
            "\"docusign=05958488-4752-4ef2-95eb-aa7ba8a3bd0e\""
        );
        assert_eq!(r[2].r#type, 16);
        assert_eq!(r[2].TTL, 299);
        assert_eq!(r[3].name, "google.com.");
        assert_eq!(
            r[3].data,
            "\"docusign=1b0a6754-49b1-4db5-8540-d2c12664b289\""
        );
        assert_eq!(r[3].r#type, 16);
        assert_eq!(r[3].TTL, 299);
        assert_eq!(r[4].name, "google.com.");
        assert_eq!(r[4].data, "\"v=spf1 include:_spf.google.com ~all\"");
        assert_eq!(r[4].r#type, 16);
        assert_eq!(r[4].TTL, 3599);
    }

    /// Server fallback behavior:
    /// - a retryable failure (HTTP 500) falls through to the next server;
    /// - a non-retryable one (HTTP 400) is returned immediately;
    /// - a retryable failure with no further server is still an error.
    #[tokio::test]
    async fn test_retries() {
        let response = String::from(
            r#"
{
  "Status": 0,
  "TC": false,
  "RD": true,
  "RA": true,
  "AD": false,
  "CD": false,
  "Question": [
    {
      "name": "www.google.com.",
      "type": 1
    }
  ],
  "Answer": [
    {
      "name": "www.google.com.",
      "type": 1,
      "TTL": 163,
      "data": "172.217.11.164"
    }
  ]
}"#,
        );
        // Retry if more than one server is given.
        let d = Dns {
            client: MockDnsClient::new(&[
                ("".to_owned(), StatusCode::INTERNAL_SERVER_ERROR),
                (response.clone(), StatusCode::OK),
            ]),
            servers: vec![
                DnsHttpsServer::Google(Duration::from_secs(5)),
                DnsHttpsServer::Cloudflare1_1_1_1(Duration::from_secs(5)),
            ],
        };
        let r = d.resolve_a("www.google.com").await.unwrap();
        assert_eq!(r.len(), 1);
        assert_eq!(r[0].name, "www.google.com.");
        assert_eq!(r[0].data, "172.217.11.164");
        assert_eq!(r[0].r#type, 1);
        assert_eq!(r[0].TTL, 163);

        // Not all errors should be retried: had the 400 been retried, the second
        // (successful) response would have been used and the result would be `Ok`.
        let d = Dns {
            client: MockDnsClient::new(&[
                ("".to_owned(), StatusCode::BAD_REQUEST),
                (response.clone(), StatusCode::OK),
            ]),
            servers: vec![
                DnsHttpsServer::Google(Duration::from_secs(5)),
                DnsHttpsServer::Cloudflare1_1_1_1(Duration::from_secs(5)),
            ],
        };
        let r = d.resolve_a("www.google.com").await;
        assert!(r.is_err());

        // If only one server is given, an error should be received.
        let d = Dns {
            client: MockDnsClient::new(&[
                ("".to_owned(), StatusCode::INTERNAL_SERVER_ERROR),
                (response.clone(), StatusCode::OK),
            ]),
            servers: vec![DnsHttpsServer::Google(Duration::from_secs(5))],
        };
        let r = d.resolve_a("www.google.com").await;
        assert!(r.is_err());
    }
}