1#[cfg(feature = "smol-async")]
2use crate::smol_async::{query_raw_tcp, query_raw_udp};
3#[cfg(feature = "std-async")]
4use crate::std_async::{query_raw_tcp, query_raw_udp};
5#[cfg(feature = "sync")]
6use crate::sync::{query_raw_tcp, query_raw_udp};
7#[cfg(feature = "tokio-async")]
8use crate::tokio_async::{query_raw_tcp, query_raw_udp};
9use crate::{err::as_io_error, reverse::reverse_dns_query, tcp::tcp_query};
10use dnssector::constants::{Class, Type};
11use dnssector::*;
12use std::{
13 io::{self, Error, ErrorKind},
14 net::{IpAddr, SocketAddr},
15 time::Duration,
16};
17
/// Per-query timeout that a freshly created [`DNSClient`] starts with (five seconds).
pub const DEFAULT_TIMEOUT: Duration = Duration::from_millis(5_000);
19
/// DNS client that sends queries to a list of upstream resolvers.
///
/// Queries go out over UDP first and are retried over TCP when the UDP answer is truncated.
/// Upstream servers are tried in order until one of them answers successfully.
#[derive(Clone, Debug)]
pub struct DNSClient {
    /// Timeout applied to each individual upstream query (both the UDP query and the TCP retry).
    upstream_server_timeout: Duration,
    /// Upstream resolvers, tried in order until one returns a successful response.
    upstream_servers: Vec<SocketAddr>,
    /// Local address passed to UDP queries sent to IPv4 upstream servers.
    local_v4_addr: SocketAddr,
    /// Local address passed to UDP queries sent to IPv6 upstream servers.
    local_v6_addr: SocketAddr,
}
43
44impl DNSClient {
45 pub fn new(upstream_servers: Vec<SocketAddr>) -> Self {
56 DNSClient {
57 upstream_server_timeout: DEFAULT_TIMEOUT,
58 upstream_servers,
59 local_v4_addr: ([0; 4], 0).into(),
60 local_v6_addr: ([0; 16], 0).into(),
61 }
62 }
63
64 pub fn set_timeout(&mut self, timeout: Duration) {
76 self.upstream_server_timeout = timeout
77 }
78
79 pub fn set_local_v4_addr<T: Into<SocketAddr>>(&mut self, addr: T) {
90 self.local_v4_addr = addr.into()
91 }
92
93 pub fn set_local_v6_addr<T: Into<SocketAddr>>(&mut self, addr: T) {
104 self.local_v6_addr = addr.into()
105 }
106
107 #[maybe_async::maybe_async]
124 pub async fn query_a(&self, name: &str) -> io::Result<Vec<IpAddr>> {
125 let name = encode_name(name)?;
126 let query = dnssector::r#gen::query(name.as_bytes(), Type::A, Class::IN)
127 .map_err(as_io_error(ErrorKind::InvalidInput))?;
128 let response = self.query(query).await?;
129 extract_ips(response)
130 }
131
132 #[maybe_async::maybe_async]
149 pub async fn query_aaaa(&self, name: &str) -> io::Result<Vec<IpAddr>> {
150 let name = encode_name(name)?;
151 let query = dnssector::r#gen::query(name.as_bytes(), Type::AAAA, Class::IN)
152 .map_err(as_io_error(ErrorKind::InvalidInput))?;
153 let response = self.query(query).await?;
154 extract_ips(response)
155 }
156
157 #[maybe_async::maybe_async]
186 pub async fn query_ptr(&self, ip: IpAddr) -> io::Result<String> {
187 let in_addr = reverse_dns_query(ip);
188 let query = dnssector::r#gen::query(&in_addr, Type::PTR, Class::IN)
189 .map_err(as_io_error(ErrorKind::InvalidInput))?;
190 let response = self.query(query).await?;
191 extract_names(response).map(|mut v| v.remove(0))
192 }
193
194 #[maybe_async::maybe_async]
206 pub async fn query_ns(&self, domain: &str) -> io::Result<Vec<String>> {
207 let query = dnssector::r#gen::query(domain.as_bytes(), Type::NS, Class::IN)
208 .map_err(as_io_error(ErrorKind::InvalidInput))?;
209 let response = self.query(query).await?;
210 extract_names(response).or_else(|e| {
211 if e.kind() == ErrorKind::NotFound {
212 Ok(Vec::new())
213 } else {
214 Err(e)
215 }
216 })
217 }
218
219 #[maybe_async::maybe_async]
220 async fn query(&self, packet: ParsedPacket) -> io::Result<ParsedPacket> {
221 let is_compressed = matches!(
222 packet.qtype_qclass(),
223 Some((rr_type, _class)) if rr_type == Type::NS as u16
224 );
225 let raw_packet = packet.into_packet();
226 for i in 0..self.upstream_servers.len() {
227 let response = self
228 .query_upstream(&self.upstream_servers[i], &raw_packet, is_compressed)
229 .await;
230 if response.is_ok() || i >= self.upstream_servers.len() - 1 {
231 return response;
232 }
233 }
234 unreachable!("query must be ok or err");
235 }
236
237 #[maybe_async::maybe_async]
238 async fn query_upstream(
239 &self,
240 upstream: &SocketAddr,
241 packet: &[u8],
242 is_compressed_response: bool,
243 ) -> io::Result<ParsedPacket> {
244 let local_addr = match upstream {
245 SocketAddr::V4(_) => &self.local_v4_addr,
246 SocketAddr::V6(_) => &self.local_v6_addr,
247 };
248 let raw_response =
249 query_raw_udp(local_addr, upstream, packet, self.upstream_server_timeout).await?;
250 let response = parse_response(raw_response, is_compressed_response)?;
251 if response.flags() & DNS_FLAG_TC != DNS_FLAG_TC {
252 return Ok(response);
253 }
254 let tcp_packet = tcp_query(packet);
256 let raw_response =
257 query_raw_tcp(upstream, &tcp_packet, self.upstream_server_timeout).await?;
258 parse_response(raw_response, is_compressed_response)
259 }
260}
261
262fn parse_response(raw: Vec<u8>, is_compressed: bool) -> io::Result<ParsedPacket> {
263 let mut raw_response = raw;
264 if is_compressed {
265 raw_response =
266 Compress::uncompress(&raw_response).map_err(as_io_error(ErrorKind::InvalidData))?;
267 }
268 DNSSector::new(raw_response)
269 .map_err(as_io_error(ErrorKind::InvalidData))?
270 .parse()
271 .map_err(as_io_error(ErrorKind::InvalidData))
272}
273
274fn extract_ips(mut packet: ParsedPacket) -> io::Result<Vec<IpAddr>> {
275 use std::result::Result as StdResult;
276
277 let mut ips = Vec::new();
278 let mut response = packet.into_iter_answer();
279 while let Some(i) = response {
280 ips.push(i.rr_ip());
281 response = i.next();
282 }
283 let (ips, errors): (Vec<_>, Vec<_>) = ips.into_iter().partition(StdResult::is_ok);
284 if ips.is_empty() {
285 if let Some(Err(e)) = errors.into_iter().next() {
286 return Err(Error::new(ErrorKind::InvalidData, e));
287 }
288 }
289 let ips: Vec<_> = ips.into_iter().map(StdResult::unwrap).collect();
290 Ok(ips)
291}
292
293fn extract_names(mut packet: ParsedPacket) -> io::Result<Vec<String>> {
294 let mut response = packet.into_iter_answer();
295 let mut ret = Vec::new();
296 while let Some(i) = response {
297 let raw_name = &i.rdata_slice()[DNS_RR_HEADER_SIZE..];
298 let name = parse_tlv_name(raw_name);
299 ret.push(name);
300 response = i.next();
301 }
302 if ret.is_empty() {
303 return Err(ErrorKind::NotFound.into());
304 }
305 ret.iter().map(|i| decode_name(i)).collect()
306}
307
/// Converts a wire-format name (length-prefixed labels) into its dotted form.
///
/// Each length byte starts a new label and becomes a `.` (except before the first label).
/// Scanning stops at the first zero byte anywhere in `raw`, or at the end of `raw`. A label
/// whose declared length runs past the end is truncated. Compression pointers are not followed.
fn parse_tlv_name(raw: &[u8]) -> Vec<u8> {
    let mut dotted = Vec::with_capacity(raw.len());
    // Bytes still to copy from the current label; 0 means the next byte is a length byte.
    let mut label_left: u8 = 0;
    for (offset, &byte) in raw.iter().enumerate().take_while(|&(_, &b)| b != 0) {
        if label_left == 0 {
            if offset > 0 {
                dotted.push(b'.');
            }
            label_left = byte;
        } else {
            dotted.push(byte);
            label_left -= 1;
        }
    }
    dotted
}
326
327fn encode_name(name: &str) -> io::Result<String> {
328 let parts: io::Result<Vec<String>> = name
329 .split('.')
330 .map(|part| {
331 if part.is_ascii() {
332 Ok(part.to_string())
333 } else {
334 unic_idna_punycode::encode_str(part)
335 .map(|s| "xn--".to_string() + &s)
336 .ok_or_else(|| ErrorKind::InvalidInput.into())
337 }
338 })
339 .collect();
340 let parts = parts?;
341 let ret = parts.join(".");
342 Ok(ret)
343}
344
345fn decode_name(name: &[u8]) -> io::Result<String> {
346 let parts: io::Result<Vec<String>> = name
347 .split(|ch| *ch == b'.')
348 .map(|part| {
349 if let Some(code) = part.strip_prefix(b"xn--") {
350 String::from_utf8(code.to_vec())
351 .map_err(as_io_error(ErrorKind::InvalidData))
352 .and_then(|code| {
353 unic_idna_punycode::decode_to_string(&code)
354 .ok_or_else(|| ErrorKind::InvalidData.into())
355 })
356 } else {
357 String::from_utf8(part.to_vec()).map_err(as_io_error(ErrorKind::InvalidData))
358 }
359 })
360 .collect();
361 let parts = parts?;
362 let ret = parts.join(".");
363 Ok(ret)
364}
365
#[cfg(test)]
mod tests {
    // Most of these tests send real queries to Cloudflare's public resolvers (1.0.0.1 and
    // 1.1.1.1) or to the hosts in `slow_dns_servers`, so they need network access.

    use super::*;
    #[cfg(not(feature = "sync"))]
    use std::future::Future;
    use std::{
        net::{Ipv4Addr, Ipv6Addr},
        str::FromStr,
    };

    const EXAMPLE_FQDN: &str = "one.one.one.one";
    const EXAMPLE_DOMAIN: &str = "one.one.one";
    const EXAMPLE_DOMAIN_NS: &str = "ns.cloudflare.com";
    const EXAMPLE_IPV4: IpAddr = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
    const EXAMPLE_IPV6: IpAddr =
        IpAddr::V6(Ipv6Addr::new(0x2606, 0x4700, 0x4700, 0, 0, 0, 0, 0x1111));
    const EXAMPLE_IDN: &str = "日本.icom.museum";
    const EXAMPLE_IDN_PUNYCODE: &str = "xn--wgv71a.icom.museum";
    const EXAMPLE_IDN_IP: IpAddr = IpAddr::V4(Ipv4Addr::new(81, 201, 190, 55));

    // Drives `future` to completion on the async-std executor.
    #[cfg(feature = "std-async")]
    fn block_on<F: Future>(future: F) -> F::Output {
        use async_std::task;
        task::block_on(future)
    }

    // Drives `future` to completion on the smol executor.
    #[cfg(feature = "smol-async")]
    fn block_on<F: Future>(future: F) -> F::Output {
        smol::block_on(future)
    }

    // Drives `future` to completion on a fresh single-threaded tokio runtime with the time
    // and I/O drivers enabled.
    #[cfg(feature = "tokio-async")]
    fn block_on<F: Future>(future: F) -> F::Output {
        use tokio::runtime;
        let rt = runtime::Builder::new_current_thread()
            .enable_time()
            .enable_io()
            .build()
            .unwrap();
        rt.block_on(future)
    }

    // In async builds, awaits the client call inside the executor-specific `block_on`.
    #[cfg(not(feature = "sync"))]
    macro_rules! block_on {
        ($b:expr) => {
            block_on(async move { $b.await })
        };
    }

    // In the `sync` build the client methods are blocking, so the call is passed through.
    #[cfg(feature = "sync")]
    macro_rules! block_on {
        ($b:expr) => {
            $b
        };
    }

    fn dns_servers() -> Vec<SocketAddr> {
        vec![
            SocketAddr::from_str("1.0.0.1:53").unwrap(),
            SocketAddr::from_str("1.1.1.1:53").unwrap(),
        ]
    }

    // Used with a 1 ms timeout in `query_timeout`; these servers are expected not to answer
    // within that time.
    fn slow_dns_servers() -> Vec<SocketAddr> {
        vec![
            SocketAddr::from_str("109.75.41.201:53").unwrap(),
            SocketAddr::from_str("124.99.9.4:53").unwrap(),
        ]
    }

    #[test]
    fn query_a() {
        let dns_client = DNSClient::new(dns_servers());
        let r = block_on!(dns_client.query_a(EXAMPLE_FQDN)).unwrap();
        let expected = EXAMPLE_IPV4;
        assert!(r.contains(&expected), "Expected {} got {:?}", expected, r);
    }

    #[test]
    fn query_timeout() {
        let mut dns_client = DNSClient::new(slow_dns_servers());
        dns_client.set_timeout(Duration::from_millis(1));
        let r = block_on!(dns_client.query_a(EXAMPLE_FQDN));
        assert!(
            matches!(&r, Err(e) if e.kind() == ErrorKind::TimedOut || e.kind() == ErrorKind::WouldBlock),
            "Expected timeout got {:?}",
            r,
        );
    }

    #[test]
    fn query_utf8() {
        let dns_client = DNSClient::new(dns_servers());
        let jp_res = block_on!(dns_client.query_a(EXAMPLE_IDN)).unwrap();
        let expected = EXAMPLE_IDN_IP;
        assert!(
            jp_res.contains(&expected),
            "Expected {} for {} got {:?}",
            expected,
            EXAMPLE_IDN,
            jp_res
        );
    }

    #[test]
    fn query_aaaa() {
        let dns_client = DNSClient::new(dns_servers());
        let r = block_on!(dns_client.query_aaaa(EXAMPLE_FQDN)).unwrap();
        let expected = EXAMPLE_IPV6;
        assert!(r.contains(&expected), "Expected {} got {:?}", expected, r);
    }

    #[test]
    fn query_ptr_ipv4() {
        let dns_client = DNSClient::new(dns_servers());
        let r = block_on!(dns_client.query_ptr(EXAMPLE_IPV4)).unwrap();
        assert_eq!(r, EXAMPLE_FQDN);
    }

    #[test]
    fn query_ptr_ipv6() {
        let dns_client = DNSClient::new(dns_servers());
        let r = block_on!(dns_client.query_ptr(EXAMPLE_IPV6)).unwrap();
        assert_eq!(r, EXAMPLE_FQDN);
    }

    // Offline check of the punycode decoding used for PTR answers.
    #[test]
    fn query_ptr_utf8() {
        let r = decode_name(EXAMPLE_IDN_PUNYCODE.as_bytes()).unwrap();
        assert_eq!(r, EXAMPLE_IDN);
    }

    #[test]
    fn query_ns() {
        let dns_client = DNSClient::new(dns_servers());
        let r = block_on!(dns_client.query_ns(EXAMPLE_DOMAIN)).unwrap();
        assert!(
            r.iter().any(|n| n.ends_with(EXAMPLE_DOMAIN_NS)),
            "Expected {} got {:?}",
            EXAMPLE_DOMAIN_NS,
            r
        );
    }
}