1use crate::client::{DnsClient, HyperDnsClient};
2use crate::error::{DnsError, QueryError};
3use crate::status::RCode;
4use crate::{Dns, DnsAnswer, DnsHttpsServer, DnsResponse};
5use hyper::Uri;
6use idna;
7use log::error;
8use std::time::Duration;
9use tokio::time::timeout;
10
11impl Default for Dns<HyperDnsClient> {
12 fn default() -> Dns<HyperDnsClient> {
13 Dns {
14 client: HyperDnsClient::default(),
15 servers: vec![
16 DnsHttpsServer::Google(Duration::from_secs(3)),
17 DnsHttpsServer::Cloudflare1_1_1_1(Duration::from_secs(10)),
18 ],
19 }
20 }
21}
22
23impl<C: DnsClient> Dns<C> {
24 pub fn with_servers(servers: &[DnsHttpsServer]) -> Result<Dns<C>, DnsError> {
29 if servers.is_empty() {
30 return Err(DnsError::NoServers);
31 }
32 Ok(Dns {
33 client: C::default(),
34 servers: servers.to_vec(),
35 })
36 }
37
38 pub async fn resolve_mx_and_sort(&self, domain: &str) -> Result<Vec<DnsAnswer>, DnsError> {
41 match self.client_request(domain, &RTYPE_mx).await {
42 Err(e) => Err(DnsError::Query(e)),
43 Ok(res) => match num::FromPrimitive::from_u32(res.Status) {
44 Some(RCode::NoError) => {
45 let mut mxs = res
46 .Answer
47 .unwrap_or_else(|| vec![])
48 .iter()
49 .filter_map(|a| {
50 if a.r#type == RTYPE_mx.0 {
52 let mut parts = a.data.split_ascii_whitespace();
54 if let Some(part_1) = parts.next() {
55 if let Ok(priority) = part_1.parse::<u32>() {
57 if let Some(mx) = parts.next() {
58 let mut m = a.clone();
60 m.data = mx.to_string();
61 return Some((m, priority));
62 }
63 }
64 }
65 }
66 None
67 })
68 .collect::<Vec<_>>();
69 mxs.sort_unstable_by_key(|x| x.1);
71 Ok(mxs.into_iter().map(|x| x.0).collect())
72 }
73 Some(code) => Err(DnsError::Status(code)),
74 None => Err(DnsError::Status(RCode::Unknown)),
75 },
76 }
77 }
78
79 async fn request_and_process(
82 &self,
83 name: &str,
84 rtype: &Rtype,
85 ) -> Result<Vec<DnsAnswer>, DnsError> {
86 match self.client_request(name, rtype).await {
87 Err(e) => Err(DnsError::Query(e)),
88 Ok(res) => match num::FromPrimitive::from_u32(res.Status) {
89 Some(RCode::NoError) => Ok(res
90 .Answer
91 .unwrap_or_else(|| vec![])
92 .into_iter()
93 .filter(|a| a.r#type == rtype.0 || rtype.0 == 0)
96 .collect::<Vec<_>>()),
97 Some(code) => Err(DnsError::Status(code)),
98 None => Err(DnsError::Status(RCode::Unknown)),
99 },
100 }
101 }
102
103 async fn client_request(&self, name: &str, rtype: &Rtype) -> Result<DnsResponse, QueryError> {
106 let name = match idna::domain_to_ascii(name) {
108 Ok(name) => name,
109 Err(e) => return Err(QueryError::InvalidName(format!("{:?}", e))),
110 };
111 let mut error = QueryError::Unknown;
112 for server in self.servers.iter() {
113 let url = format!("{}?name={}&type={}", server.uri(), name, rtype.1);
114 let endpoint = match url.parse::<Uri>() {
115 Err(e) => return Err(QueryError::InvalidEndpoint(e.to_string())),
116 Ok(endpoint) => endpoint,
117 };
118
119 error = match timeout(server.timeout(), self.client.get(endpoint)).await {
120 Ok(Err(e)) => QueryError::Connection(e.to_string()),
121 Ok(Ok(res)) => {
122 match res.status().as_u16() {
123 200 => match hyper::body::to_bytes(res).await {
124 Err(e) => QueryError::ReadResponse(e.to_string()),
125 Ok(body) => match serde_json::from_slice::<DnsResponse>(&body) {
126 Err(e) => QueryError::ParseResponse(e.to_string()),
127 Ok(res) => {
128 return Ok(res);
129 }
130 },
131 },
132 400 => return Err(QueryError::BadRequest400),
133 413 => return Err(QueryError::PayloadTooLarge413),
134 414 => return Err(QueryError::UriTooLong414),
135 415 => return Err(QueryError::UnsupportedMediaType415),
136 501 => return Err(QueryError::NotImplemented501),
137 429 => QueryError::TooManyRequests429,
140 500 => QueryError::InternalServerError500,
141 502 => QueryError::BadGateway502,
142 504 => QueryError::ResolverTimeout504,
143 _ => QueryError::Unknown,
144 }
145 }
146 Err(_) => QueryError::Connection(format!(
147 "connection timeout after {:?}",
148 server.timeout()
149 )),
150 };
151 error!("request error on URL {}: {}", url, error);
152 }
153 Err(error)
154 }
155}
156
/// A DNS record type as used by the resolver.
///
/// `.0` is the numeric type code compared against each answer's `type` field
/// (`0` means "accept every type"); `.1` is the lowercase name sent as the
/// `type=` query parameter.
struct Rtype(pub u32, pub &'static str);
158
/// Generates the record-type API from a list of `(name, code);` entries.
///
/// For every entry this emits a private `RTYPE_<name>` constant and a public
/// `Dns::resolve_<name>` method (carrying the entry's doc attributes). It also
/// emits the `resolve_str_type` and `rtype_to_name` lookup helpers.
macro_rules! rtypes {
    (
        $(
            $(#[$docs:meta])*
            ($konst:ident, $num:expr);
        )+
    ) => {
        paste::item! {
            $(
                #[allow(non_upper_case_globals)]
                const [<RTYPE_ $konst>]: Rtype = Rtype($num, stringify!($konst));
            )+

            impl<C: DnsClient> Dns<C> {
                $(
                    $(#[$docs])*
                    pub async fn [<resolve_ $konst>](&self, name: &str) -> Result<Vec<DnsAnswer>, DnsError> {
                        self.request_and_process(name, &[<RTYPE_ $konst>]).await
                    }
                )+

                /// Resolves `name` using a record type given by its name,
                /// compared ASCII-case-insensitively (e.g. `"MX"` or `"aaaa"`).
                ///
                /// Returns [`DnsError::InvalidRecordType`] for an unknown name.
                pub async fn resolve_str_type(&self, name: &str, rtype: &str) -> Result<Vec<DnsAnswer>, DnsError> {
                    let wanted = rtype.to_ascii_lowercase();
                    $(
                        if wanted == stringify!($konst) {
                            return self.[<resolve_ $konst>](name).await;
                        }
                    )+
                    Err(DnsError::InvalidRecordType)
                }

                /// Returns the upper-case name of numeric record type `rtype`,
                /// or `"UNKNOWN"` if it is not in the generated table.
                pub fn rtype_to_name(&self, rtype: u32) -> String {
                    let known: Option<&'static str> = match rtype {
                        $(
                            $num => Some(stringify!($konst)),
                        )+
                        _ => None,
                    };
                    known.unwrap_or("unknown").to_ascii_uppercase()
                }
            }
        }
    };
}
202
// Record types exposed as `Dns::resolve_<name>`. Each generated method queries the
// configured servers and returns only answers of that type (other records in the
// answer section are dropped); answer `data` is returned as the server sent it.
rtypes! {
    /// Resolves the `A` (IPv4 address) records of `name`.
    (a, 1);
    /// Resolves the `AAAA` (IPv6 address) records of `name`.
    (aaaa, 28);
    /// Sends an `any` query for `name` and returns every answer record,
    /// whatever its type.
    // Code 0 is only used locally as the "accept every type" filter marker; the
    // server receives the textual type `any`. NOTE(review): the IANA code for ANY
    // is 255, so `rtype_to_name(255)` yields "UNKNOWN" — confirm this is intended.
    (any, 0);
    /// Resolves the `CAA` (certification authority authorization) records of `name`.
    (caa, 257);
    /// Resolves the `CDS` (child DS) records of `name`.
    (cds, 59);
    /// Resolves the `CERT` (certificate) records of `name`.
    (cert, 37);
    /// Resolves the `CNAME` (canonical name) records of `name`.
    (cname, 5);
    /// Resolves the `DNAME` (delegation name) records of `name`.
    (dname, 39);
    /// Resolves the `DNSKEY` (DNSSEC public key) records of `name`.
    (dnskey, 48);
    /// Resolves the `DS` (delegation signer) records of `name`.
    (ds, 43);
    /// Resolves the `HINFO` (host information) records of `name`.
    (hinfo, 13);
    /// Resolves the `IPSECKEY` records of `name`.
    (ipseckey, 45);
    /// Resolves the `MX` (mail exchanger) records of `name`.
    ///
    /// The data is left as `"<priority> <exchange>"` in server order; use
    /// `resolve_mx_and_sort` for parsed, priority-sorted exchanges.
    (mx, 15);
    /// Resolves the `NAPTR` (naming authority pointer) records of `name`.
    (naptr, 35);
    /// Resolves the `NS` (name server) records of `name`.
    (ns, 2);
    /// Resolves the `NSEC` records of `name`.
    (nsec, 47);
    /// Resolves the `NSEC3` records of `name`.
    (nsec3, 50);
    /// Resolves the `NSEC3PARAM` records of `name`.
    (nsec3param, 51);
    /// Resolves the `PTR` (pointer) records of `name`.
    (ptr, 12);
    /// Resolves the `RP` (responsible person) records of `name`.
    (rp, 17);
    /// Resolves the `RRSIG` (DNSSEC signature) records of `name`.
    (rrsig, 46);
    /// Resolves the `SOA` (start of authority) record of `name`.
    (soa, 6);
    /// Resolves the `SPF` records of `name`.
    (spf, 99);
    /// Resolves the `SRV` (service locator) records of `name`.
    (srv, 33);
    /// Resolves the `SSHFP` (SSH key fingerprint) records of `name`.
    (sshfp, 44);
    /// Resolves the `TLSA` records of `name`.
    (tlsa, 52);
    /// Resolves the `TXT` (text) records of `name`.
    (txt, 16);
    /// Resolves the `WKS` (well-known services) records of `name`.
    (wks, 11);
}
263
#[cfg(test)]
pub mod tests {
    use super::*;
    use async_trait::async_trait;
    use hyper::StatusCode;
    use hyper::{error::Result as HyperResult, Body, Response, Uri};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// A [`DnsClient`] that ignores the request URI and replays canned
    /// `(body, status)` pairs, one per `get` call, in the order given.
    struct MockDnsClient {
        responses: Vec<(String, StatusCode)>,
        calls: AtomicUsize,
    }

    impl MockDnsClient {
        fn new(responses: &[(String, StatusCode)]) -> MockDnsClient {
            MockDnsClient {
                responses: responses.to_vec(),
                calls: AtomicUsize::new(0),
            }
        }
    }

    #[async_trait]
    impl DnsClient for MockDnsClient {
        async fn get(&self, _uri: Uri) -> HyperResult<Response<Body>> {
            let call = self.calls.fetch_add(1, Ordering::SeqCst);
            // Fail loudly if the code under test makes more requests than scripted.
            let (body, status) = match self.responses.get(call) {
                Some(canned) => canned.clone(),
                None => panic!(
                    "MockDnsClient: no canned response for request #{}",
                    call + 1
                ),
            };
            let mut response = Response::new(Body::from(body));
            *response.status_mut() = status;
            Ok(response)
        }
    }

    impl Default for MockDnsClient {
        fn default() -> MockDnsClient {
            MockDnsClient::new(&[])
        }
    }

    // Asserts that `$answers` is exactly the listed `(name, data, type, TTL)`
    // records, in order (this also pins the number of records).
    macro_rules! assert_answers {
        ($answers:expr, [$(($name:expr, $data:expr, $rtype:expr, $ttl:expr)),+ $(,)?]) => {{
            let actual: Vec<_> = $answers
                .iter()
                .map(|a| (a.name.as_str(), a.data.as_str(), a.r#type, a.TTL))
                .collect();
            assert_eq!(actual, vec![$(($name, $data, $rtype, $ttl)),+]);
        }};
    }

    /// A resolver with a single Google endpoint whose one request returns `body`
    /// with HTTP 200.
    fn google_with(body: String) -> Dns<MockDnsClient> {
        Dns {
            client: MockDnsClient::new(&[(body, StatusCode::OK)]),
            servers: vec![DnsHttpsServer::Google(Duration::from_secs(5))],
        }
    }

    #[tokio::test]
    async fn test_a() {
        let response = String::from(
            r#"
            {
                "Status": 0,
                "TC": false,
                "RD": true,
                "RA": true,
                "AD": false,
                "CD": false,
                "Question": [
                    {
                        "name": "www.sendgrid.com.",
                        "type": 1
                    }
                ],
                "Answer": [
                    {
                        "name": "www.sendgrid.com.",
                        "type": 5,
                        "TTL": 988,
                        "data": "sendgrid.com."
                    },
                    {
                        "name": "sendgrid.com.",
                        "type": 1,
                        "TTL": 89,
                        "data": "169.45.113.198"
                    },
                    {
                        "name": "sendgrid.com.",
                        "type": 1,
                        "TTL": 89,
                        "data": "167.89.118.63"
                    },
                    {
                        "name": "sendgrid.com.",
                        "type": 1,
                        "TTL": 89,
                        "data": "169.45.89.183"
                    },
                    {
                        "name": "sendgrid.com.",
                        "type": 1,
                        "TTL": 89,
                        "data": "167.89.118.65"
                    }
                ],
                "Comment": "Response from 2600:1801:13::1."
            }"#,
        );
        let r = google_with(response).resolve_a("sendgrid.com").await.unwrap();
        // The leading CNAME record must be filtered out of an A lookup.
        assert_answers!(r, [
            ("sendgrid.com.", "169.45.113.198", 1, 89),
            ("sendgrid.com.", "167.89.118.63", 1, 89),
            ("sendgrid.com.", "169.45.89.183", 1, 89),
            ("sendgrid.com.", "167.89.118.65", 1, 89),
        ]);
    }

    #[tokio::test]
    async fn test_mx() {
        let response = String::from(
            r#"
            {
                "Status": 0,
                "TC": false,
                "RD": true,
                "RA": true,
                "AD": false,
                "CD": false,
                "Question": [
                    {
                        "name": "gmail.com.",
                        "type": 15
                    }
                ],
                "Answer": [
                    {
                        "name": "gmail.com.",
                        "type": 15,
                        "TTL": 3599,
                        "data": "30 alt3.gmail-smtp-in.l.google.com."
                    },
                    {
                        "name": "gmail.com.",
                        "type": 15,
                        "TTL": 3599,
                        "data": "5 gmail-smtp-in.l.google.com."
                    },
                    {
                        "name": "gmail.com.",
                        "type": 15,
                        "TTL": 3599,
                        "data": "40 alt4.gmail-smtp-in.l.google.com."
                    },
                    {
                        "name": "gmail.com.",
                        "type": 15,
                        "TTL": 3599,
                        "data": "10 alt1.gmail-smtp-in.l.google.com."
                    },
                    {
                        "name": "gmail.com.",
                        "type": 15,
                        "TTL": 3599,
                        "data": "20 alt2.gmail-smtp-in.l.google.com."
                    }
                ],
                "Comment": "Response from 2001:4860:4802:32::a."
            }"#,
        );

        // Sorted variant: priorities stripped, ordered 5, 10, 20, 30, 40.
        let r = google_with(response.clone())
            .resolve_mx_and_sort("gmail.com")
            .await
            .unwrap();
        assert_answers!(r, [
            ("gmail.com.", "gmail-smtp-in.l.google.com.", 15, 3599),
            ("gmail.com.", "alt1.gmail-smtp-in.l.google.com.", 15, 3599),
            ("gmail.com.", "alt2.gmail-smtp-in.l.google.com.", 15, 3599),
            ("gmail.com.", "alt3.gmail-smtp-in.l.google.com.", 15, 3599),
            ("gmail.com.", "alt4.gmail-smtp-in.l.google.com.", 15, 3599),
        ]);

        // Raw variant: data and order exactly as the server returned them.
        let r = google_with(response).resolve_mx("gmail.com").await.unwrap();
        assert_answers!(r, [
            ("gmail.com.", "30 alt3.gmail-smtp-in.l.google.com.", 15, 3599),
            ("gmail.com.", "5 gmail-smtp-in.l.google.com.", 15, 3599),
            ("gmail.com.", "40 alt4.gmail-smtp-in.l.google.com.", 15, 3599),
            ("gmail.com.", "10 alt1.gmail-smtp-in.l.google.com.", 15, 3599),
            ("gmail.com.", "20 alt2.gmail-smtp-in.l.google.com.", 15, 3599),
        ]);
    }

    #[tokio::test]
    async fn test_txt() {
        let response = String::from(
            r#"
            {
                "Status": 0,
                "TC": false,
                "RD": true,
                "RA": true,
                "AD": false,
                "CD": false,
                "Question": [
                    {
                        "name": "google.com.",
                        "type": 16
                    }
                ],
                "Answer": [
                    {
                        "name": "google.com.",
                        "type": 16,
                        "TTL": 3599,
                        "data": "\"facebook-domain-verification=22rm551cu4k0ab0bxsw536tlds4h95\""
                    },
                    {
                        "name": "google.com.",
                        "type": 16,
                        "TTL": 3599,
                        "data": "\"globalsign-smime-dv=CDYX+XFHUw2wml6/Gb8+59BsH31KzUr6c1l2BPvqKX8=\""
                    },
                    {
                        "name": "google.com.",
                        "type": 16,
                        "TTL": 299,
                        "data": "\"docusign=05958488-4752-4ef2-95eb-aa7ba8a3bd0e\""
                    },
                    {
                        "name": "google.com.",
                        "type": 16,
                        "TTL": 299,
                        "data": "\"docusign=1b0a6754-49b1-4db5-8540-d2c12664b289\""
                    },
                    {
                        "name": "google.com.",
                        "type": 16,
                        "TTL": 3599,
                        "data": "\"v=spf1 include:_spf.google.com ~all\""
                    }
                ],
                "Comment": "Response from 216.239.36.10."
            }"#,
        );
        let r = google_with(response).resolve_txt("google.com").await.unwrap();
        // TXT data keeps its surrounding quotes.
        assert_answers!(r, [
            (
                "google.com.",
                "\"facebook-domain-verification=22rm551cu4k0ab0bxsw536tlds4h95\"",
                16,
                3599
            ),
            (
                "google.com.",
                "\"globalsign-smime-dv=CDYX+XFHUw2wml6/Gb8+59BsH31KzUr6c1l2BPvqKX8=\"",
                16,
                3599
            ),
            (
                "google.com.",
                "\"docusign=05958488-4752-4ef2-95eb-aa7ba8a3bd0e\"",
                16,
                299
            ),
            (
                "google.com.",
                "\"docusign=1b0a6754-49b1-4db5-8540-d2c12664b289\"",
                16,
                299
            ),
            ("google.com.", "\"v=spf1 include:_spf.google.com ~all\"", 16, 3599),
        ]);
    }

    #[tokio::test]
    async fn test_retries() {
        let response = String::from(
            r#"
            {
                "Status": 0,
                "TC": false,
                "RD": true,
                "RA": true,
                "AD": false,
                "CD": false,
                "Question": [
                    {
                        "name": "www.google.com.",
                        "type": 1
                    }
                ],
                "Answer": [
                    {
                        "name": "www.google.com.",
                        "type": 1,
                        "TTL": 163,
                        "data": "172.217.11.164"
                    }
                ]
            }"#,
        );
        let two_servers = || {
            vec![
                DnsHttpsServer::Google(Duration::from_secs(5)),
                DnsHttpsServer::Cloudflare1_1_1_1(Duration::from_secs(5)),
            ]
        };

        // A 500 from the first server is retryable: the second server answers.
        let d = Dns {
            client: MockDnsClient::new(&[
                ("".to_owned(), StatusCode::INTERNAL_SERVER_ERROR),
                (response.clone(), StatusCode::OK),
            ]),
            servers: two_servers(),
        };
        let r = d.resolve_a("www.google.com").await.unwrap();
        assert_answers!(r, [("www.google.com.", "172.217.11.164", 1, 163)]);

        // A 400 aborts immediately; the second server is never consulted.
        let d = Dns {
            client: MockDnsClient::new(&[
                ("".to_owned(), StatusCode::BAD_REQUEST),
                (response.clone(), StatusCode::OK),
            ]),
            servers: two_servers(),
        };
        assert!(d.resolve_a("www.google.com").await.is_err());

        // A retryable 500 with no further servers left is still an error.
        let d = Dns {
            client: MockDnsClient::new(&[
                ("".to_owned(), StatusCode::INTERNAL_SERVER_ERROR),
                (response, StatusCode::OK),
            ]),
            servers: vec![DnsHttpsServer::Google(Duration::from_secs(5))],
        };
        assert!(d.resolve_a("www.google.com").await.is_err());
    }
}