mtop_client/dns/client.rs

1use crate::core::MtopError;
2use crate::dns::core::{RecordClass, RecordType};
3use crate::dns::message::{Flags, Message, MessageId, Question, ResponseCode};
4use crate::dns::name::Name;
5use crate::net::tcp_connect;
6use crate::pool::{ClientFactory, ClientPool, ClientPoolConfig};
7use crate::timeout::Timeout;
8use async_trait::async_trait;
9use std::fmt;
10use std::io::{self, Cursor, Error};
11use std::net::{IpAddr, Ipv4Addr, SocketAddr};
12use std::pin::Pin;
13use std::sync::atomic::{AtomicUsize, Ordering};
14use std::task::{Context, Poll};
15use std::time::Duration;
16use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter, ReadBuf};
17use tokio::net::UdpSocket;
18
// Nameserver used when none is configured: the local resolver on the standard DNS port.
const DEFAULT_NAMESERVER: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 53);
// Initial scratch-buffer capacity and UDP receive size, in bytes.
const DEFAULT_MESSAGE_BUFFER: usize = 512;

/// Settings that control how a `DnsClient` performs resolutions.
#[derive(Debug, Clone)]
pub struct DnsClientConfig {
    /// Nameservers to query (at least one is expected). Unless `rotate` is enabled,
    /// every resolution tries them in the order listed here.
    pub nameservers: Vec<SocketAddr>,

    /// Per-attempt timeout. Because the timeout applies to each attempt separately, a
    /// single `DnsClient::resolve` call can run longer than this depending on `attempts`.
    pub timeout: Duration,

    /// How many times to try resolving a single name. Any response received from a
    /// nameserver is a "success" for this purpose; only timeouts and network errors
    /// cause a retry.
    pub attempts: u8,

    /// When true, the starting nameserver is rotated round-robin across resolutions.
    /// When false, nameservers are always tried in their configured order.
    pub rotate: bool,

    /// Maximum number of idle sockets or connections per nameserver kept for reuse
    /// (passed to the UDP and TCP connection pools as `max_idle`). The default keeps one
    /// per nameserver. Set to 0 to disable this behavior.
    pub pool_max_idle: u64,
}

impl Default for DnsClientConfig {
    fn default() -> Self {
        // Where a matching option exists, values follow `man 5 resolv.conf`.
        let nameservers = vec![DEFAULT_NAMESERVER];
        let timeout = Duration::from_secs(5);
        Self {
            nameservers,
            timeout,
            attempts: 2,
            rotate: false,
            pool_max_idle: 1,
        }
    }
}
59
/// Client for performing DNS queries and returning the results.
///
/// There is currently only a single non-test implementation because this
/// trait exists to make testing consumers easier.
#[async_trait]
pub trait DnsClient {
    /// Resolve `name` for the given record type (`rtype`) and class (`rclass`),
    /// using `id` as the message ID of the outgoing query.
    ///
    /// Returns the response message. Which response codes count as a successful
    /// result versus an error is decided by the implementation.
    async fn resolve(
        &self,
        id: MessageId,
        name: Name,
        rtype: RecordType,
        rclass: RecordClass,
    ) -> Result<Message, MtopError>;
}
74
75/// Implementation of a `DnsClient` that uses UDP with TCP fallback.
76///
77/// Supports nameserver rotation, retries, timeouts, and pooling of client
78/// connections. Names are assumed to already be fully qualified, meaning
79/// that they are not combined with a search domain.
80///
81/// Timeouts are handled by the client itself and so callers should _not_
82/// add a timeout on the `resolve` method. Note that timeouts are per-network
83/// operation. This means that a single call to `resolve` make take longer
84/// than the timeout since failed network operations are retried.
85#[derive(Debug)]
86pub struct DefaultDnsClient {
87    config: DnsClientConfig,
88    server_idx: AtomicUsize,
89    udp_pool: ClientPool<SocketAddr, UdpConnection>,
90    tcp_pool: ClientPool<SocketAddr, TcpConnection>,
91}
92
93impl DefaultDnsClient {
94    /// Create a new DnsClient that will resolve names using UDP or TCP connections
95    /// and behavior based on a resolv.conf configuration file.
96    pub fn new<U, T>(config: DnsClientConfig, udp_factory: U, tcp_factory: T) -> Self
97    where
98        U: ClientFactory<SocketAddr, UdpConnection> + Send + Sync + 'static,
99        T: ClientFactory<SocketAddr, TcpConnection> + Send + Sync + 'static,
100    {
101        let udp_config = ClientPoolConfig {
102            name: "dns-udp".to_owned(),
103            max_idle: config.pool_max_idle,
104        };
105
106        let tcp_config = ClientPoolConfig {
107            name: "dns-tcp".to_owned(),
108            max_idle: config.pool_max_idle,
109        };
110
111        Self {
112            config,
113            server_idx: AtomicUsize::new(0),
114            udp_pool: ClientPool::new(udp_config, udp_factory),
115            tcp_pool: ClientPool::new(tcp_config, tcp_factory),
116        }
117    }
118
119    async fn exchange(&self, msg: &Message, server: &SocketAddr) -> Result<Message, MtopError> {
120        let res = async {
121            let mut conn = self.udp_pool.get(server).await?;
122            let res = conn.exchange(msg).await;
123            if res.is_ok() {
124                self.udp_pool.put(conn).await;
125            }
126
127            res
128        }
129        .timeout(self.config.timeout, format!("client.exchange udp://{}", server))
130        .await?;
131
132        // If the UDP response indicates the message was truncated, we discard
133        // it and repeat the query using TCP.
134        if res.flags().is_truncated() {
135            tracing::debug!(message = "UDP response truncated, retrying with TCP", flags = ?res.flags(), server = %server);
136            async {
137                let mut conn = self.tcp_pool.get(server).await?;
138                let res = conn.exchange(msg).await;
139                if res.is_ok() {
140                    self.tcp_pool.put(conn).await;
141                }
142
143                res
144            }
145            .timeout(self.config.timeout, format!("client.exchange tcp://{}", server))
146            .await
147        } else {
148            Ok(res)
149        }
150    }
151
152    // Get the index of nameserver that should be used for a query based on if the client has
153    // been configured to roundrobin between nameservers or not.
154    fn starting_idx(&self) -> usize {
155        if self.config.rotate {
156            self.server_idx.fetch_add(1, Ordering::Relaxed)
157        } else {
158            0
159        }
160    }
161
162    // Get an iterator that will visit every nameserver once starting from the provided index.
163    fn nameserver_iterator(&self, idx: usize) -> impl Iterator<Item = &SocketAddr> {
164        self.config
165            .nameservers
166            .iter()
167            .cycle()
168            .skip(idx)
169            .take(self.config.nameservers.len())
170    }
171}
172
173#[async_trait]
174impl DnsClient for DefaultDnsClient {
175    async fn resolve(
176        &self,
177        id: MessageId,
178        name: Name,
179        rtype: RecordType,
180        rclass: RecordClass,
181    ) -> Result<Message, MtopError> {
182        let full = name.to_fqdn();
183        let flags = Flags::default().set_recursion_desired();
184        let question = Question::new(full.clone(), rtype).set_qclass(rclass);
185        let message = Message::new(id, flags).add_question(question);
186        let start = self.starting_idx();
187
188        let mut errors = Vec::new();
189        for attempt in 0..self.config.attempts {
190            for server in self.nameserver_iterator(start) {
191                match self.exchange(&message, server).await {
192                    Ok(v) => {
193                        // NoError or a NameError is a conclusive answer. We either have results
194                        // or this is a bad domain. Any other type of response means we have to try
195                        // another server.
196                        let rc = v.flags().get_response_code();
197                        if rc == ResponseCode::NoError || rc == ResponseCode::NameError {
198                            return Ok(v);
199                        }
200
201                        tracing::debug!(message = "unsuitable response from nameserver, trying next one", server = %server, attempt = attempt + 1, max_attempts = self.config.attempts, response_code = ?rc);
202                        errors.push(rc.to_string());
203                    }
204                    Err(e) => {
205                        tracing::debug!(message = "nameserver failed, trying next one", server = %server, attempt = attempt + 1, max_attempts = self.config.attempts, err = %e);
206                        errors.push(e.to_string());
207                    }
208                }
209            }
210
211            if attempt + 1 < self.config.attempts {
212                tracing::debug!(
213                    message = "all nameservers failed, retrying",
214                    attempt = attempt + 1,
215                    max_attempts = self.config.attempts
216                );
217            }
218        }
219
220        Err(MtopError::runtime(format!(
221            "no nameservers returned suitable responses for name {} type {} class {}: {}",
222            full,
223            rtype,
224            rclass,
225            errors.join("; ")
226        )))
227    }
228}
229
230/// Connection for unconditionally sending and receiving DNS messages using TCP streams.
231/// Messages are sent with a two byte prefix that indicates the size of the message.
232/// Responses are expected to have the same prefix. The message ID of responses is
233/// checked to ensure it matches the request ID. If it does not, an error is returned.
234pub struct TcpConnection {
235    read: BufReader<Box<dyn AsyncRead + Send + Sync + Unpin>>,
236    write: BufWriter<Box<dyn AsyncWrite + Send + Sync + Unpin>>,
237    buffer: Vec<u8>,
238}
239
240impl TcpConnection {
241    pub fn new<R, W>(read: R, write: W) -> Self
242    where
243        R: AsyncRead + Unpin + Sync + Send + 'static,
244        W: AsyncWrite + Unpin + Sync + Send + 'static,
245    {
246        Self {
247            read: BufReader::new(Box::new(read)),
248            write: BufWriter::new(Box::new(write)),
249            buffer: Vec::with_capacity(DEFAULT_MESSAGE_BUFFER),
250        }
251    }
252
253    pub async fn exchange(&mut self, msg: &Message) -> Result<Message, MtopError> {
254        // Write the message to a local buffer and then send it, prefixed
255        // with the size of the message.
256        self.buffer.clear();
257        msg.write_network_bytes(&mut self.buffer)?;
258        self.write.write_u16(self.buffer.len() as u16).await?;
259        self.write.write_all(&self.buffer).await?;
260        self.write.flush().await?;
261
262        // Read the prefixed size of the response in big-endian (network)
263        // order and then read exactly that many bytes into our buffer.
264        let sz = self.read.read_u16().await?;
265        self.buffer.clear();
266        self.buffer.resize(usize::from(sz), 0);
267        self.read.read_exact(&mut self.buffer).await?;
268
269        let mut cur = Cursor::new(&self.buffer);
270        let res = Message::read_network_bytes(&mut cur)?;
271        if res.id() != msg.id() {
272            Err(MtopError::runtime(format!(
273                "unexpected DNS MessageId; expected {}, got {}",
274                msg.id(),
275                res.id()
276            )))
277        } else {
278            Ok(res)
279        }
280    }
281}
282
283impl fmt::Debug for TcpConnection {
284    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
285        write!(f, "TcpConnection {{ ... }}")
286    }
287}
288
289/// Connection for unconditionally sending and receiving DNS messages using UDP packets.
290/// The message ID of responses is checked to ensure it matches the request ID. If it
291/// does not, the response is discarded and the client will wait for another response
292/// until it gets one with a matching ID.
293pub struct UdpConnection {
294    read: Box<dyn AsyncRead + Send + Sync + Unpin>,
295    write: Box<dyn AsyncWrite + Send + Sync + Unpin>,
296    buffer: Vec<u8>,
297    packet_size: usize,
298}
299
300impl UdpConnection {
301    pub fn new<R, W>(read: R, write: W) -> Self
302    where
303        R: AsyncRead + Unpin + Sync + Send + 'static,
304        W: AsyncWrite + Unpin + Sync + Send + 'static,
305    {
306        Self {
307            read: Box::new(read),
308            write: Box::new(write),
309            buffer: Vec::with_capacity(DEFAULT_MESSAGE_BUFFER),
310            packet_size: DEFAULT_MESSAGE_BUFFER,
311        }
312    }
313
314    pub async fn exchange(&mut self, msg: &Message) -> Result<Message, MtopError> {
315        self.buffer.clear();
316        msg.write_network_bytes(&mut self.buffer)?;
317        // We expect this to be a datagram socket so we only do a single write.
318        let n = self.write.write(&self.buffer).await?;
319        if n != self.buffer.len() {
320            return Err(MtopError::runtime(format!(
321                "short write to UDP socket. expected {}, got {}",
322                self.buffer.len(),
323                n
324            )));
325        }
326        self.write.flush().await?;
327
328        // Resize to our packet size since the .read() call will only read up to
329        // the size of the buffer at most.
330        self.buffer.clear();
331        self.buffer.resize(self.packet_size, 0);
332
333        loop {
334            let n = self.read.read(&mut self.buffer).await?;
335            let cur = Cursor::new(&self.buffer[0..n]);
336            let res = Message::read_network_bytes(cur)?;
337            if res.id() == msg.id() {
338                return Ok(res);
339            }
340        }
341    }
342}
343
344impl fmt::Debug for UdpConnection {
345    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
346        write!(f, "UdpConnection {{ ... }}")
347    }
348}
349
350/// Adapter for reading and writing to a `UdpSocket` using the `AsyncRead` and `AsyncWrite`
351/// traits. This exists to enable easier testing of `UdpConnection` by allowing alternate
352/// implementations of those traits to be used.
353pub(crate) struct SocketAdapter(UdpSocket);
354
355impl SocketAdapter {
356    pub(crate) fn new(sock: UdpSocket) -> Self {
357        Self(sock)
358    }
359}
360
361impl AsyncRead for SocketAdapter {
362    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
363        self.0.poll_recv(cx, buf)
364    }
365}
366
367impl AsyncWrite for SocketAdapter {
368    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
369        self.0.poll_send(cx, buf)
370    }
371
372    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
373        Poll::Ready(Ok(()))
374    }
375
376    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
377        Poll::Ready(Ok(()))
378    }
379}
380
381/// Implementation of `ClientFactory` for creating concrete `UdpConnection` instances
382/// that use a UDP socket.
383#[derive(Debug, Clone, Default)]
384pub struct UdpConnectionFactory;
385
386#[async_trait]
387impl ClientFactory<SocketAddr, UdpConnection> for UdpConnectionFactory {
388    async fn make(&self, address: &SocketAddr) -> Result<UdpConnection, MtopError> {
389        let local = if address.is_ipv4() { "0.0.0.0:0" } else { "[::]:0" };
390        let sock = UdpSocket::bind(local).await?;
391        sock.connect(address).await?;
392
393        let adapter = SocketAdapter::new(sock);
394        let (read, write) = tokio::io::split(adapter);
395        Ok(UdpConnection::new(read, write))
396    }
397}
398
399/// Implementation of `ClientFactory` for creating concrete `TcpConnection` instances
400/// that use a TCP socket.
401#[derive(Debug, Clone, Default)]
402pub struct TcpConnectionFactory;
403
404#[async_trait]
405impl ClientFactory<SocketAddr, TcpConnection> for TcpConnectionFactory {
406    async fn make(&self, address: &SocketAddr) -> Result<TcpConnection, MtopError> {
407        let (read, write) = tcp_connect(address).await?;
408        Ok(TcpConnection::new(read, write))
409    }
410}
411
// Unit tests for the TCP/UDP connections and `DefaultDnsClient`, driven by in-memory
// test sockets and client factories from `crate::dns::test`.
#[cfg(test)]
mod test {
    use super::{DefaultDnsClient, DnsClient, DnsClientConfig, TcpConnection, UdpConnection};
    use crate::core::ErrorKind;
    use crate::dns::core::{RecordClass, RecordType};
    use crate::dns::message::{Flags, Message, MessageId, Question, Record, ResponseCode};
    use crate::dns::name::Name;
    use crate::dns::rdata::{RecordData, RecordDataA};
    use crate::dns::test::{TestTcpClientFactory, TestTcpSocket, TestUdpClientFactory, TestUdpSocket};
    use std::collections::HashMap;
    use std::io::Cursor;
    use std::net::{Ipv4Addr, SocketAddr};
    use std::str::FromStr;

    // Build a recursive A query for example.com. with the given message ID.
    fn new_request(id: MessageId) -> Message {
        let flags = Flags::default().set_query().set_recursion_desired();
        let question = Question::new(Name::from_str("example.com.").unwrap(), RecordType::A);
        Message::new(id, flags).add_question(question)
    }

    // Build a response for example.com. A with the given ID and no answer records.
    fn new_empty_response(id: MessageId) -> Message {
        let flags = Flags::default()
            .set_response()
            .set_recursion_desired()
            .set_recursion_available();
        let question = Question::new(Name::from_str("example.com.").unwrap(), RecordType::A);
        Message::new(id, flags).add_question(question)
    }

    // Build a response for example.com. A with a single 127.0.0.1 answer record (TTL 300).
    fn new_response(id: MessageId) -> Message {
        let response = new_empty_response(id);
        let answer = Record::new(
            Name::from_str("example.com.").unwrap(),
            RecordType::A,
            RecordClass::INET,
            300,
            RecordData::A(RecordDataA::new(Ipv4Addr::new(127, 0, 0, 1))),
        );

        response.add_answer(answer)
    }

    // An empty stream means the two byte length prefix can't be read: expect an IO error.
    #[tokio::test]
    async fn test_tcp_client_eof_reading_length() {
        let write = Vec::new();
        let read = Cursor::new(Vec::new());

        let id = MessageId::from(123);
        let request = new_request(id);

        let mut client = TcpConnection::new(read, write);

        let res = client.exchange(&request).await;
        let err = res.unwrap_err();
        assert_eq!(ErrorKind::IO, err.kind());
    }

    // The length prefix claims 200 bytes but the stream then ends: expect an IO error.
    #[tokio::test]
    async fn test_tcp_client_eof_reading_message() {
        let write = Vec::new();
        let read = Cursor::new(vec![
            0, 200, // message length
        ]);

        let id = MessageId::from(123);
        let request = new_request(id);

        let mut client = TcpConnection::new(read, write);

        let res = client.exchange(&request).await;
        let err = res.unwrap_err();
        assert_eq!(ErrorKind::IO, err.kind());
    }

    // Unlike UDP, a TCP response with the wrong ID is an error rather than being skipped.
    #[tokio::test]
    async fn test_tcp_client_id_mismatch() {
        let response_id = MessageId::from(456);
        let response = new_response(response_id);

        let request_id = MessageId::from(123);
        let request = new_request(request_id);

        let sock = TestTcpSocket::new(vec![response]);
        let (read, write) = tokio::io::split(sock);
        let mut client = TcpConnection::new(read, write);

        let result = client.exchange(&request).await;
        let err = result.unwrap_err();
        assert_eq!(ErrorKind::Runtime, err.kind());
    }

    #[tokio::test]
    async fn test_tcp_client_single_message() {
        let id = MessageId::from(123);
        let response = new_response(id);
        let request = new_request(id);

        let sock = TestTcpSocket::new(vec![response.clone()]);
        let (read, write) = tokio::io::split(sock);
        let mut client = TcpConnection::new(read, write);

        let result = client.exchange(&request).await.unwrap();
        assert_eq!(response, result);
    }

    // Two exchanges over the same connection. NOTE(review): response1 is listed last
    // yet expected first, which suggests the test socket yields messages from the end of
    // the vec; confirm against `TestTcpSocket`.
    #[tokio::test]
    async fn test_tcp_client_multiple_message() {
        let id1 = MessageId::from(123);
        let response1 = new_response(id1);
        let request1 = new_request(id1);
        let id2 = MessageId::from(456);
        let response2 = new_response(id2);
        let request2 = new_request(id2);

        let sock = TestTcpSocket::new(vec![response2.clone(), response1.clone()]);
        let (read, write) = tokio::io::split(sock);
        let mut client = TcpConnection::new(read, write);

        let result1 = client.exchange(&request1).await.unwrap();
        assert_eq!(response1, result1);

        let result2 = client.exchange(&request2).await.unwrap();
        assert_eq!(response2, result2);
    }

    #[tokio::test]
    async fn test_udp_client_success() {
        let id = MessageId::from(123);
        let response = new_response(id);
        let request = new_request(id);

        let sock = TestUdpSocket::new(vec![response.clone()]);
        let (read, write) = tokio::io::split(sock);
        let mut client = UdpConnection::new(read, write);

        let result = client.exchange(&request).await.unwrap();
        assert_eq!(response, result);
    }

    // A UDP response with a non-matching ID must be skipped in favor of the next one.
    #[tokio::test]
    async fn test_udp_client_one_id_mismatch() {
        let id1 = MessageId::from(456);
        let response1 = new_response(id1);
        let id2 = MessageId::from(123);
        let response2 = new_response(id2);

        // Note that the request has the ID of the second response because
        // we are testing that the first response is discarded due to the ID
        // not matching. NOTE(review): response1 is listed last, which suggests the
        // test socket delivers messages from the end of the vec; confirm against
        // `TestUdpSocket`.
        let request = new_request(id2);

        let sock = TestUdpSocket::new(vec![response2.clone(), response1.clone()]);
        let (read, write) = tokio::io::split(sock);
        let mut client = UdpConnection::new(read, write);

        let result = client.exchange(&request).await.unwrap();
        assert_eq!(response2, result);
    }

    // NameError (NXDOMAIN) is a conclusive answer and is returned as Ok, not retried.
    #[tokio::test]
    async fn test_default_dns_client_resolve_name_error() {
        let id = MessageId::from(123);
        let name = Name::from_str("example.com.").unwrap();
        let server = "127.0.0.1:53".parse().unwrap();

        let udp_response = new_empty_response(id);
        let flags = udp_response.flags().set_response_code(ResponseCode::NameError);
        let udp_response = udp_response.set_flags(flags);

        let mut udp_mapping: HashMap<SocketAddr, Vec<Message>> = HashMap::new();
        udp_mapping.entry(server).or_default().push(udp_response);
        let udp_factory = TestUdpClientFactory::new(udp_mapping);
        let tcp_factory = TestTcpClientFactory::new(HashMap::new());

        let cfg = DnsClientConfig::default();
        let client = DefaultDnsClient::new(cfg, udp_factory, tcp_factory);
        let result = client.resolve(id, name, RecordType::A, RecordClass::INET).await.unwrap();

        assert_eq!(ResponseCode::NameError, result.flags().get_response_code());
        assert!(result.answers().is_empty());
    }

    #[tokio::test]
    async fn test_default_dns_client_resolve_success() {
        let id = MessageId::from(123);
        let name = Name::from_str("example.com.").unwrap();
        let server = "127.0.0.1:53".parse().unwrap();

        let udp_response = new_response(id);
        let mut udp_mapping: HashMap<SocketAddr, Vec<Message>> = HashMap::new();
        udp_mapping.entry(server).or_default().push(udp_response.clone());
        let udp_factory = TestUdpClientFactory::new(udp_mapping);
        let tcp_factory = TestTcpClientFactory::new(HashMap::new());

        let cfg = DnsClientConfig::default();
        let client = DefaultDnsClient::new(cfg, udp_factory, tcp_factory);
        let result = client.resolve(id, name, RecordType::A, RecordClass::INET).await.unwrap();

        assert_eq!(udp_response, result);
    }

    // A ServerFailure on the first attempt is followed by a successful second attempt
    // (default config has 2 attempts and a single nameserver). NOTE(review): relies on
    // the test factory delivering the last-pushed message first; confirm.
    #[tokio::test]
    async fn test_default_dns_client_resolve_one_error() {
        let id = MessageId::from(123);
        let name = Name::from_str("example.com.").unwrap();
        let server = "127.0.0.1:53".parse().unwrap();

        let udp_response1 = new_empty_response(id);
        let flags = udp_response1.flags().set_response_code(ResponseCode::ServerFailure);
        let udp_response1 = udp_response1.set_flags(flags);
        let udp_response2 = new_response(id);

        let mut udp_mapping: HashMap<SocketAddr, Vec<Message>> = HashMap::new();
        let entry = udp_mapping.entry(server).or_default();
        entry.push(udp_response2.clone());
        entry.push(udp_response1);

        let udp_factory = TestUdpClientFactory::new(udp_mapping);
        let tcp_factory = TestTcpClientFactory::new(HashMap::new());

        let cfg = DnsClientConfig::default();
        let client = DefaultDnsClient::new(cfg, udp_factory, tcp_factory);
        let result = client.resolve(id, name, RecordType::A, RecordClass::INET).await.unwrap();

        assert_eq!(udp_response2, result);
    }

    // Every attempt returns ServerFailure, so resolve gives up with a Runtime error.
    #[tokio::test]
    async fn test_default_dns_client_resolve_all_errors() {
        let id = MessageId::from(123);
        let name = Name::from_str("example.com.").unwrap();
        let server = "127.0.0.1:53".parse().unwrap();

        let udp_response1 = new_empty_response(id);
        let flags = udp_response1.flags().set_response_code(ResponseCode::ServerFailure);
        let udp_response1 = udp_response1.set_flags(flags);

        let udp_response2 = new_empty_response(id);
        let flags = udp_response2.flags().set_response_code(ResponseCode::ServerFailure);
        let udp_response2 = udp_response2.set_flags(flags);

        let mut udp_mapping: HashMap<SocketAddr, Vec<Message>> = HashMap::new();
        let entry = udp_mapping.entry(server).or_default();
        entry.push(udp_response2.clone());
        entry.push(udp_response1);

        let udp_factory = TestUdpClientFactory::new(udp_mapping);
        let tcp_factory = TestTcpClientFactory::new(HashMap::new());

        let cfg = DnsClientConfig::default();
        let client = DefaultDnsClient::new(cfg, udp_factory, tcp_factory);
        let err = client.resolve(id, name, RecordType::A, RecordClass::INET).await.unwrap_err();

        assert_eq!(ErrorKind::Runtime, err.kind());
    }

    // The first nameserver fails with ServerFailure and the second one answers, so the
    // second server's response is returned within the first attempt.
    #[tokio::test]
    async fn test_default_dns_client_resolve_one_bad_server() {
        let id = MessageId::from(123);
        let name = Name::from_str("example.com.").unwrap();
        let server1 = "127.0.0.1:53".parse().unwrap();
        let server2 = "127.0.0.2:53".parse().unwrap();

        let udp_response1 = new_empty_response(id);
        let flags = udp_response1.flags().set_response_code(ResponseCode::ServerFailure);
        let udp_response1 = udp_response1.set_flags(flags);
        let udp_response2 = new_response(id);

        let mut udp_mapping: HashMap<SocketAddr, Vec<Message>> = HashMap::new();
        udp_mapping.entry(server1).or_default().push(udp_response1);
        udp_mapping.entry(server2).or_default().push(udp_response2.clone());

        let udp_factory = TestUdpClientFactory::new(udp_mapping);
        let tcp_factory = TestTcpClientFactory::new(HashMap::new());

        let cfg = DnsClientConfig {
            nameservers: vec![server1, server2],
            ..Default::default()
        };
        let client = DefaultDnsClient::new(cfg, udp_factory, tcp_factory);
        let result = client.resolve(id, name, RecordType::A, RecordClass::INET).await.unwrap();

        assert_eq!(udp_response2, result);
    }

    // A UDP response with the truncated flag triggers a TCP retry whose response wins.
    #[tokio::test]
    async fn test_default_dns_client_resolve_udp_truncation() {
        let id = MessageId::from(123);
        let name = Name::from_str("example.com.").unwrap();
        let server = "127.0.0.1:53".parse().unwrap();

        let udp_response = new_empty_response(id);
        let flags = udp_response.flags().set_truncated();
        let udp_response = udp_response.set_flags(flags);
        let tcp_response = new_response(id);

        let mut udp_mapping: HashMap<SocketAddr, Vec<Message>> = HashMap::new();
        udp_mapping.entry(server).or_default().push(udp_response);

        let mut tcp_mapping: HashMap<SocketAddr, Vec<Message>> = HashMap::new();
        tcp_mapping.entry(server).or_default().push(tcp_response.clone());

        let udp_factory = TestUdpClientFactory::new(udp_mapping);
        let tcp_factory = TestTcpClientFactory::new(tcp_mapping);

        let cfg = DnsClientConfig::default();
        let client = DefaultDnsClient::new(cfg, udp_factory, tcp_factory);
        let result = client.resolve(id, name, RecordType::A, RecordClass::INET).await.unwrap();

        assert_eq!(tcp_response, result);
    }
}