1use crate::core::MtopError;
2use crate::dns::core::{RecordClass, RecordType};
3use crate::dns::message::{Flags, Message, MessageId, Question, ResponseCode};
4use crate::dns::name::Name;
5use crate::net::tcp_connect;
6use crate::pool::{ClientFactory, ClientPool, ClientPoolConfig};
7use crate::timeout::Timeout;
8use async_trait::async_trait;
9use std::fmt;
10use std::io::{self, Cursor, Error};
11use std::net::{IpAddr, Ipv4Addr, SocketAddr};
12use std::pin::Pin;
13use std::sync::atomic::{AtomicUsize, Ordering};
14use std::task::{Context, Poll};
15use std::time::Duration;
16use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter, ReadBuf};
17use tokio::net::UdpSocket;
18
/// Nameserver used when no other is configured: the local resolver on the standard DNS port.
const DEFAULT_NAMESERVER: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 53);
/// Initial size in bytes of the per-connection scratch buffers, and the size of the buffer
/// each UDP datagram is received into.
const DEFAULT_MESSAGE_BUFFER: usize = 512;
21
/// Configuration for [`DefaultDnsClient`].
#[derive(Debug, Clone)]
pub struct DnsClientConfig {
    /// Nameservers that queries are sent to. Each attempt tries every nameserver once, in
    /// order, starting from an offset that advances between resolutions when `rotate` is set.
    pub nameservers: Vec<SocketAddr>,

    /// Timeout applied to each individual exchange with a nameserver. When a truncated UDP
    /// response is retried over TCP, the TCP exchange gets its own, separate timeout.
    pub timeout: Duration,

    /// Number of passes made over the full list of `nameservers` before giving up.
    pub attempts: u8,

    /// When `true`, successive resolutions start at a different nameserver (round-robin)
    /// instead of always starting with the first one.
    pub rotate: bool,

    /// `max_idle` limit given to both the UDP and the TCP connection pools.
    /// NOTE(review): the pools are keyed by nameserver address, so this is presumably a
    /// per-nameserver limit — confirm against `ClientPool`.
    pub pool_max_idle: u64,
}
46
47impl Default for DnsClientConfig {
48 fn default() -> Self {
49 Self {
51 nameservers: vec![DEFAULT_NAMESERVER],
52 timeout: Duration::from_secs(5),
53 attempts: 2,
54 rotate: false,
55 pool_max_idle: 1,
56 }
57 }
58}
59
/// A client that can send DNS queries and return the resulting response message.
#[async_trait]
pub trait DnsClient {
    /// Query for records of type `rtype` and class `rclass` for `name`, using `id` as the
    /// message ID of the request, and return the response message.
    ///
    /// # Errors
    ///
    /// Returns an [`MtopError`] when a usable response cannot be obtained.
    async fn resolve(
        &self,
        id: MessageId,
        name: Name,
        rtype: RecordType,
        rclass: RecordClass,
    ) -> Result<Message, MtopError>;
}
74
/// [`DnsClient`] that queries the configured nameservers over UDP, retrying over TCP when a
/// UDP response is truncated. Connections are pooled per nameserver address and reused
/// after successful exchanges.
#[derive(Debug)]
pub struct DefaultDnsClient {
    config: DnsClientConfig,
    // Counter used to choose the first nameserver of each resolution; only incremented
    // when `config.rotate` is enabled. Wraps around on overflow.
    server_idx: AtomicUsize,
    // Connection pools keyed by nameserver address, one per transport.
    udp_pool: ClientPool<SocketAddr, UdpConnection>,
    tcp_pool: ClientPool<SocketAddr, TcpConnection>,
}
92
93impl DefaultDnsClient {
94 pub fn new<U, T>(config: DnsClientConfig, udp_factory: U, tcp_factory: T) -> Self
97 where
98 U: ClientFactory<SocketAddr, UdpConnection> + Send + Sync + 'static,
99 T: ClientFactory<SocketAddr, TcpConnection> + Send + Sync + 'static,
100 {
101 let udp_config = ClientPoolConfig {
102 name: "dns-udp".to_owned(),
103 max_idle: config.pool_max_idle,
104 };
105
106 let tcp_config = ClientPoolConfig {
107 name: "dns-tcp".to_owned(),
108 max_idle: config.pool_max_idle,
109 };
110
111 Self {
112 config,
113 server_idx: AtomicUsize::new(0),
114 udp_pool: ClientPool::new(udp_config, udp_factory),
115 tcp_pool: ClientPool::new(tcp_config, tcp_factory),
116 }
117 }
118
119 async fn exchange(&self, msg: &Message, server: &SocketAddr) -> Result<Message, MtopError> {
120 let res = async {
121 let mut conn = self.udp_pool.get(server).await?;
122 let res = conn.exchange(msg).await;
123 if res.is_ok() {
124 self.udp_pool.put(conn).await;
125 }
126
127 res
128 }
129 .timeout(self.config.timeout, format!("client.exchange udp://{}", server))
130 .await?;
131
132 if res.flags().is_truncated() {
135 tracing::debug!(message = "UDP response truncated, retrying with TCP", flags = ?res.flags(), server = %server);
136 async {
137 let mut conn = self.tcp_pool.get(server).await?;
138 let res = conn.exchange(msg).await;
139 if res.is_ok() {
140 self.tcp_pool.put(conn).await;
141 }
142
143 res
144 }
145 .timeout(self.config.timeout, format!("client.exchange tcp://{}", server))
146 .await
147 } else {
148 Ok(res)
149 }
150 }
151
152 fn starting_idx(&self) -> usize {
155 if self.config.rotate {
156 self.server_idx.fetch_add(1, Ordering::Relaxed)
157 } else {
158 0
159 }
160 }
161
162 fn nameserver_iterator(&self, idx: usize) -> impl Iterator<Item = &SocketAddr> {
164 self.config
165 .nameservers
166 .iter()
167 .cycle()
168 .skip(idx)
169 .take(self.config.nameservers.len())
170 }
171}
172
173#[async_trait]
174impl DnsClient for DefaultDnsClient {
175 async fn resolve(
176 &self,
177 id: MessageId,
178 name: Name,
179 rtype: RecordType,
180 rclass: RecordClass,
181 ) -> Result<Message, MtopError> {
182 let full = name.to_fqdn();
183 let flags = Flags::default().set_recursion_desired();
184 let question = Question::new(full.clone(), rtype).set_qclass(rclass);
185 let message = Message::new(id, flags).add_question(question);
186 let start = self.starting_idx();
187
188 let mut errors = Vec::new();
189 for attempt in 0..self.config.attempts {
190 for server in self.nameserver_iterator(start) {
191 match self.exchange(&message, server).await {
192 Ok(v) => {
193 let rc = v.flags().get_response_code();
197 if rc == ResponseCode::NoError || rc == ResponseCode::NameError {
198 return Ok(v);
199 }
200
201 tracing::debug!(message = "unsuitable response from nameserver, trying next one", server = %server, attempt = attempt + 1, max_attempts = self.config.attempts, response_code = ?rc);
202 errors.push(rc.to_string());
203 }
204 Err(e) => {
205 tracing::debug!(message = "nameserver failed, trying next one", server = %server, attempt = attempt + 1, max_attempts = self.config.attempts, err = %e);
206 errors.push(e.to_string());
207 }
208 }
209 }
210
211 if attempt + 1 < self.config.attempts {
212 tracing::debug!(
213 message = "all nameservers failed, retrying",
214 attempt = attempt + 1,
215 max_attempts = self.config.attempts
216 );
217 }
218 }
219
220 Err(MtopError::runtime(format!(
221 "no nameservers returned suitable responses for name {} type {} class {}: {}",
222 full,
223 rtype,
224 rclass,
225 errors.join("; ")
226 )))
227 }
228}
229
/// DNS connection over a stream transport (TCP). Every message in either direction is
/// preceded by its length as a two-byte unsigned integer.
pub struct TcpConnection {
    read: BufReader<Box<dyn AsyncRead + Send + Sync + Unpin>>,
    write: BufWriter<Box<dyn AsyncWrite + Send + Sync + Unpin>>,
    // Scratch buffer reused for serializing requests and for reading responses.
    buffer: Vec<u8>,
}
239
240impl TcpConnection {
241 pub fn new<R, W>(read: R, write: W) -> Self
242 where
243 R: AsyncRead + Unpin + Sync + Send + 'static,
244 W: AsyncWrite + Unpin + Sync + Send + 'static,
245 {
246 Self {
247 read: BufReader::new(Box::new(read)),
248 write: BufWriter::new(Box::new(write)),
249 buffer: Vec::with_capacity(DEFAULT_MESSAGE_BUFFER),
250 }
251 }
252
253 pub async fn exchange(&mut self, msg: &Message) -> Result<Message, MtopError> {
254 self.buffer.clear();
257 msg.write_network_bytes(&mut self.buffer)?;
258 self.write.write_u16(self.buffer.len() as u16).await?;
259 self.write.write_all(&self.buffer).await?;
260 self.write.flush().await?;
261
262 let sz = self.read.read_u16().await?;
265 self.buffer.clear();
266 self.buffer.resize(usize::from(sz), 0);
267 self.read.read_exact(&mut self.buffer).await?;
268
269 let mut cur = Cursor::new(&self.buffer);
270 let res = Message::read_network_bytes(&mut cur)?;
271 if res.id() != msg.id() {
272 Err(MtopError::runtime(format!(
273 "unexpected DNS MessageId; expected {}, got {}",
274 msg.id(),
275 res.id()
276 )))
277 } else {
278 Ok(res)
279 }
280 }
281}
282
283impl fmt::Debug for TcpConnection {
284 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
285 write!(f, "TcpConnection {{ ... }}")
286 }
287}
288
/// DNS connection over a datagram transport (UDP). Assumes each write sends one complete
/// message and each read receives one complete message.
pub struct UdpConnection {
    read: Box<dyn AsyncRead + Send + Sync + Unpin>,
    write: Box<dyn AsyncWrite + Send + Sync + Unpin>,
    // Scratch buffer reused for serializing requests and for receiving responses.
    buffer: Vec<u8>,
    // Number of bytes made available to each read of a response datagram.
    packet_size: usize,
}
299
300impl UdpConnection {
301 pub fn new<R, W>(read: R, write: W) -> Self
302 where
303 R: AsyncRead + Unpin + Sync + Send + 'static,
304 W: AsyncWrite + Unpin + Sync + Send + 'static,
305 {
306 Self {
307 read: Box::new(read),
308 write: Box::new(write),
309 buffer: Vec::with_capacity(DEFAULT_MESSAGE_BUFFER),
310 packet_size: DEFAULT_MESSAGE_BUFFER,
311 }
312 }
313
314 pub async fn exchange(&mut self, msg: &Message) -> Result<Message, MtopError> {
315 self.buffer.clear();
316 msg.write_network_bytes(&mut self.buffer)?;
317 let n = self.write.write(&self.buffer).await?;
319 if n != self.buffer.len() {
320 return Err(MtopError::runtime(format!(
321 "short write to UDP socket. expected {}, got {}",
322 self.buffer.len(),
323 n
324 )));
325 }
326 self.write.flush().await?;
327
328 self.buffer.clear();
331 self.buffer.resize(self.packet_size, 0);
332
333 loop {
334 let n = self.read.read(&mut self.buffer).await?;
335 let cur = Cursor::new(&self.buffer[0..n]);
336 let res = Message::read_network_bytes(cur)?;
337 if res.id() == msg.id() {
338 return Ok(res);
339 }
340 }
341 }
342}
343
344impl fmt::Debug for UdpConnection {
345 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
346 write!(f, "UdpConnection {{ ... }}")
347 }
348}
349
/// Exposes a [`UdpSocket`] through tokio's [`AsyncRead`] and [`AsyncWrite`] traits. Each
/// read receives one datagram and each write sends one datagram, using the socket's
/// connected peer.
pub(crate) struct SocketAdapter(UdpSocket);
354
355impl SocketAdapter {
356 pub(crate) fn new(sock: UdpSocket) -> Self {
357 Self(sock)
358 }
359}
360
361impl AsyncRead for SocketAdapter {
362 fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
363 self.0.poll_recv(cx, buf)
364 }
365}
366
367impl AsyncWrite for SocketAdapter {
368 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
369 self.0.poll_send(cx, buf)
370 }
371
372 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
373 Poll::Ready(Ok(()))
374 }
375
376 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
377 Poll::Ready(Ok(()))
378 }
379}
380
/// [`ClientFactory`] that creates [`UdpConnection`]s. It binds a UDP socket to an
/// OS-assigned local port and connects it to the nameserver address.
#[derive(Debug, Clone, Default)]
pub struct UdpConnectionFactory;
385
386#[async_trait]
387impl ClientFactory<SocketAddr, UdpConnection> for UdpConnectionFactory {
388 async fn make(&self, address: &SocketAddr) -> Result<UdpConnection, MtopError> {
389 let local = if address.is_ipv4() { "0.0.0.0:0" } else { "[::]:0" };
390 let sock = UdpSocket::bind(local).await?;
391 sock.connect(address).await?;
392
393 let adapter = SocketAdapter::new(sock);
394 let (read, write) = tokio::io::split(adapter);
395 Ok(UdpConnection::new(read, write))
396 }
397}
398
/// [`ClientFactory`] that creates [`TcpConnection`]s by connecting to the nameserver
/// address with `tcp_connect`.
#[derive(Debug, Clone, Default)]
pub struct TcpConnectionFactory;
403
404#[async_trait]
405impl ClientFactory<SocketAddr, TcpConnection> for TcpConnectionFactory {
406 async fn make(&self, address: &SocketAddr) -> Result<TcpConnection, MtopError> {
407 let (read, write) = tcp_connect(address).await?;
408 Ok(TcpConnection::new(read, write))
409 }
410}
411
#[cfg(test)]
mod test {
    use super::{DefaultDnsClient, DnsClient, DnsClientConfig, TcpConnection, UdpConnection};
    use crate::core::ErrorKind;
    use crate::dns::core::{RecordClass, RecordType};
    use crate::dns::message::{Flags, Message, MessageId, Question, Record, ResponseCode};
    use crate::dns::name::Name;
    use crate::dns::rdata::{RecordData, RecordDataA};
    use crate::dns::test::{TestTcpClientFactory, TestTcpSocket, TestUdpClientFactory, TestUdpSocket};
    use std::collections::HashMap;
    use std::io::Cursor;
    use std::net::{Ipv4Addr, SocketAddr};
    use std::str::FromStr;

    // Builds a recursive query for the A record of `example.com.` with the given ID.
    fn new_request(id: MessageId) -> Message {
        let flags = Flags::default().set_query().set_recursion_desired();
        let question = Question::new(Name::from_str("example.com.").unwrap(), RecordType::A);
        Message::new(id, flags).add_question(question)
    }

    // Builds a response (recursion desired + available) that echoes the `example.com.` A
    // question but carries no answers.
    fn new_empty_response(id: MessageId) -> Message {
        let flags = Flags::default()
            .set_response()
            .set_recursion_desired()
            .set_recursion_available();
        let question = Question::new(Name::from_str("example.com.").unwrap(), RecordType::A);
        Message::new(id, flags).add_question(question)
    }

    // Like `new_empty_response`, plus one A record answer of 127.0.0.1 with a TTL of 300.
    fn new_response(id: MessageId) -> Message {
        let response = new_empty_response(id);
        let answer = Record::new(
            Name::from_str("example.com.").unwrap(),
            RecordType::A,
            RecordClass::INET,
            300,
            RecordData::A(RecordDataA::new(Ipv4Addr::new(127, 0, 0, 1))),
        );

        response.add_answer(answer)
    }

    // The reader is empty, so reading the two-byte length prefix hits EOF; that must
    // surface as an IO error.
    #[tokio::test]
    async fn test_tcp_client_eof_reading_length() {
        let write = Vec::new();
        let read = Cursor::new(Vec::new());

        let id = MessageId::from(123);
        let request = new_request(id);

        let mut client = TcpConnection::new(read, write);

        let res = client.exchange(&request).await;
        let err = res.unwrap_err();
        assert_eq!(ErrorKind::IO, err.kind());
    }

    // The length prefix promises 200 bytes but no message bytes follow, so reading the
    // message body must fail with an IO error.
    #[tokio::test]
    async fn test_tcp_client_eof_reading_message() {
        let write = Vec::new();
        let read = Cursor::new(vec![
            0, 200, // length prefix of 200 bytes, with nothing after it
        ]);

        let id = MessageId::from(123);
        let request = new_request(id);

        let mut client = TcpConnection::new(read, write);

        let res = client.exchange(&request).await;
        let err = res.unwrap_err();
        assert_eq!(ErrorKind::IO, err.kind());
    }

    // Over TCP, a response whose ID differs from the request is a runtime error rather
    // than something to skip.
    #[tokio::test]
    async fn test_tcp_client_id_mismatch() {
        let response_id = MessageId::from(456);
        let response = new_response(response_id);

        let request_id = MessageId::from(123);
        let request = new_request(request_id);

        let sock = TestTcpSocket::new(vec![response]);
        let (read, write) = tokio::io::split(sock);
        let mut client = TcpConnection::new(read, write);

        let result = client.exchange(&request).await;
        let err = result.unwrap_err();
        assert_eq!(ErrorKind::Runtime, err.kind());
    }

    // A single request/response round trip over TCP returns the server's response unchanged.
    #[tokio::test]
    async fn test_tcp_client_single_message() {
        let id = MessageId::from(123);
        let response = new_response(id);
        let request = new_request(id);

        let sock = TestTcpSocket::new(vec![response.clone()]);
        let (read, write) = tokio::io::split(sock);
        let mut client = TcpConnection::new(read, write);

        let result = client.exchange(&request).await.unwrap();
        assert_eq!(response, result);
    }

    // Two sequential exchanges over the same TCP connection each get their own response.
    // NOTE(review): TestTcpSocket appears to serve queued responses from the end of the
    // vector, which is why `response1` is listed last — confirm.
    #[tokio::test]
    async fn test_tcp_client_multiple_message() {
        let id1 = MessageId::from(123);
        let response1 = new_response(id1);
        let request1 = new_request(id1);
        let id2 = MessageId::from(456);
        let response2 = new_response(id2);
        let request2 = new_request(id2);

        let sock = TestTcpSocket::new(vec![response2.clone(), response1.clone()]);
        let (read, write) = tokio::io::split(sock);
        let mut client = TcpConnection::new(read, write);

        let result1 = client.exchange(&request1).await.unwrap();
        assert_eq!(response1, result1);

        let result2 = client.exchange(&request2).await.unwrap();
        assert_eq!(response2, result2);
    }

    // A single request/response round trip over UDP returns the server's response unchanged.
    #[tokio::test]
    async fn test_udp_client_success() {
        let id = MessageId::from(123);
        let response = new_response(id);
        let request = new_request(id);

        let sock = TestUdpSocket::new(vec![response.clone()]);
        let (read, write) = tokio::io::split(sock);
        let mut client = UdpConnection::new(read, write);

        let result = client.exchange(&request).await.unwrap();
        assert_eq!(response, result);
    }

    // Over UDP, a response with the wrong ID is discarded and reading continues until the
    // matching response arrives. NOTE(review): assuming TestUdpSocket serves from the end of
    // the vector, the mismatched `response1` is read first — confirm.
    #[tokio::test]
    async fn test_udp_client_one_id_mismatch() {
        let id1 = MessageId::from(456);
        let response1 = new_response(id1);
        let id2 = MessageId::from(123);
        let response2 = new_response(id2);

        let request = new_request(id2);

        let sock = TestUdpSocket::new(vec![response2.clone(), response1.clone()]);
        let (read, write) = tokio::io::split(sock);
        let mut client = UdpConnection::new(read, write);

        let result = client.exchange(&request).await.unwrap();
        assert_eq!(response2, result);
    }

    // A NameError (NXDOMAIN) response is accepted and returned as a successful result,
    // with no answers.
    #[tokio::test]
    async fn test_default_dns_client_resolve_name_error() {
        let id = MessageId::from(123);
        let name = Name::from_str("example.com.").unwrap();
        let server = "127.0.0.1:53".parse().unwrap();

        let udp_response = new_empty_response(id);
        let flags = udp_response.flags().set_response_code(ResponseCode::NameError);
        let udp_response = udp_response.set_flags(flags);

        let mut udp_mapping: HashMap<SocketAddr, Vec<Message>> = HashMap::new();
        udp_mapping.entry(server).or_default().push(udp_response);
        let udp_factory = TestUdpClientFactory::new(udp_mapping);
        let tcp_factory = TestTcpClientFactory::new(HashMap::new());

        let cfg = DnsClientConfig::default();
        let client = DefaultDnsClient::new(cfg, udp_factory, tcp_factory);
        let result = client.resolve(id, name, RecordType::A, RecordClass::INET).await.unwrap();

        assert_eq!(ResponseCode::NameError, result.flags().get_response_code());
        assert!(result.answers().is_empty());
    }

    // The happy path: the default nameserver answers over UDP and that answer is returned.
    #[tokio::test]
    async fn test_default_dns_client_resolve_success() {
        let id = MessageId::from(123);
        let name = Name::from_str("example.com.").unwrap();
        let server = "127.0.0.1:53".parse().unwrap();

        let udp_response = new_response(id);
        let mut udp_mapping: HashMap<SocketAddr, Vec<Message>> = HashMap::new();
        udp_mapping.entry(server).or_default().push(udp_response.clone());
        let udp_factory = TestUdpClientFactory::new(udp_mapping);
        let tcp_factory = TestTcpClientFactory::new(HashMap::new());

        let cfg = DnsClientConfig::default();
        let client = DefaultDnsClient::new(cfg, udp_factory, tcp_factory);
        let result = client.resolve(id, name, RecordType::A, RecordClass::INET).await.unwrap();

        assert_eq!(udp_response, result);
    }

    // With a single nameserver and the default of 2 attempts, one ServerFailure response
    // is retried and the successful response is returned. NOTE(review): relies on the test
    // socket serving the last-pushed response (ServerFailure) first — confirm.
    #[tokio::test]
    async fn test_default_dns_client_resolve_one_error() {
        let id = MessageId::from(123);
        let name = Name::from_str("example.com.").unwrap();
        let server = "127.0.0.1:53".parse().unwrap();

        let udp_response1 = new_empty_response(id);
        let flags = udp_response1.flags().set_response_code(ResponseCode::ServerFailure);
        let udp_response1 = udp_response1.set_flags(flags);
        let udp_response2 = new_response(id);

        let mut udp_mapping: HashMap<SocketAddr, Vec<Message>> = HashMap::new();
        let entry = udp_mapping.entry(server).or_default();
        entry.push(udp_response2.clone());
        entry.push(udp_response1);

        let udp_factory = TestUdpClientFactory::new(udp_mapping);
        let tcp_factory = TestTcpClientFactory::new(HashMap::new());

        let cfg = DnsClientConfig::default();
        let client = DefaultDnsClient::new(cfg, udp_factory, tcp_factory);
        let result = client.resolve(id, name, RecordType::A, RecordClass::INET).await.unwrap();

        assert_eq!(udp_response2, result);
    }

    // Every response is ServerFailure, so once all attempts are exhausted, resolve returns
    // a runtime error.
    #[tokio::test]
    async fn test_default_dns_client_resolve_all_errors() {
        let id = MessageId::from(123);
        let name = Name::from_str("example.com.").unwrap();
        let server = "127.0.0.1:53".parse().unwrap();

        let udp_response1 = new_empty_response(id);
        let flags = udp_response1.flags().set_response_code(ResponseCode::ServerFailure);
        let udp_response1 = udp_response1.set_flags(flags);

        let udp_response2 = new_empty_response(id);
        let flags = udp_response2.flags().set_response_code(ResponseCode::ServerFailure);
        let udp_response2 = udp_response2.set_flags(flags);

        let mut udp_mapping: HashMap<SocketAddr, Vec<Message>> = HashMap::new();
        let entry = udp_mapping.entry(server).or_default();
        entry.push(udp_response2.clone());
        entry.push(udp_response1);

        let udp_factory = TestUdpClientFactory::new(udp_mapping);
        let tcp_factory = TestTcpClientFactory::new(HashMap::new());

        let cfg = DnsClientConfig::default();
        let client = DefaultDnsClient::new(cfg, udp_factory, tcp_factory);
        let err = client.resolve(id, name, RecordType::A, RecordClass::INET).await.unwrap_err();

        assert_eq!(ErrorKind::Runtime, err.kind());
    }

    // The first nameserver returns ServerFailure, so the client moves on to the second,
    // which answers successfully.
    #[tokio::test]
    async fn test_default_dns_client_resolve_one_bad_server() {
        let id = MessageId::from(123);
        let name = Name::from_str("example.com.").unwrap();
        let server1 = "127.0.0.1:53".parse().unwrap();
        let server2 = "127.0.0.2:53".parse().unwrap();

        let udp_response1 = new_empty_response(id);
        let flags = udp_response1.flags().set_response_code(ResponseCode::ServerFailure);
        let udp_response1 = udp_response1.set_flags(flags);
        let udp_response2 = new_response(id);

        let mut udp_mapping: HashMap<SocketAddr, Vec<Message>> = HashMap::new();
        udp_mapping.entry(server1).or_default().push(udp_response1);
        udp_mapping.entry(server2).or_default().push(udp_response2.clone());

        let udp_factory = TestUdpClientFactory::new(udp_mapping);
        let tcp_factory = TestTcpClientFactory::new(HashMap::new());

        let cfg = DnsClientConfig {
            nameservers: vec![server1, server2],
            ..Default::default()
        };
        let client = DefaultDnsClient::new(cfg, udp_factory, tcp_factory);
        let result = client.resolve(id, name, RecordType::A, RecordClass::INET).await.unwrap();

        assert_eq!(udp_response2, result);
    }

    // A truncated UDP response causes the query to be retried over TCP, and the TCP
    // response is the one returned.
    #[tokio::test]
    async fn test_default_dns_client_resolve_udp_truncation() {
        let id = MessageId::from(123);
        let name = Name::from_str("example.com.").unwrap();
        let server = "127.0.0.1:53".parse().unwrap();

        let udp_response = new_empty_response(id);
        let flags = udp_response.flags().set_truncated();
        let udp_response = udp_response.set_flags(flags);
        let tcp_response = new_response(id);

        let mut udp_mapping: HashMap<SocketAddr, Vec<Message>> = HashMap::new();
        udp_mapping.entry(server).or_default().push(udp_response);

        let mut tcp_mapping: HashMap<SocketAddr, Vec<Message>> = HashMap::new();
        tcp_mapping.entry(server).or_default().push(tcp_response.clone());

        let udp_factory = TestUdpClientFactory::new(udp_mapping);
        let tcp_factory = TestTcpClientFactory::new(tcp_mapping);

        let cfg = DnsClientConfig::default();
        let client = DefaultDnsClient::new(cfg, udp_factory, tcp_factory);
        let result = client.resolve(id, name, RecordType::A, RecordClass::INET).await.unwrap();

        assert_eq!(tcp_response, result);
    }
}