1use core::fmt::{self, Display};
9use core::net::SocketAddr;
10use core::pin::Pin;
11use core::task::{Context, Poll};
12use core::time::Duration;
13use std::collections::HashSet;
14use std::sync::Arc;
15
16use futures_util::{FutureExt, Stream, StreamExt, pin_mut, stream::FuturesUnordered};
17use tracing::{debug, trace, warn};
18
19use crate::error::NetError;
20use crate::proto::op::{DEFAULT_RETRY_FLOOR, DnsRequest, DnsResponse, Message, SerialMessage};
21#[cfg(feature = "__dnssec")]
22use crate::proto::rr::TSigner;
23use crate::runtime::{DnsUdpSocket, RuntimeProvider, Spawn, Time};
24use crate::udp::MAX_RECEIVE_BUFFER_SIZE;
25use crate::udp::udp_stream::NextRandomUdpSocket;
26use crate::xfer::{DnsExchange, DnsRequestSender, DnsResponseStream};
27
/// A UDP-based DNS request sender bound to a single remote name server.
///
/// Each request handed to [`DnsRequestSender::send_message`] is wrapped in a `UdpRequest`.
/// Attempts are launched on a retry timer (at most `max_retries` of them), and the whole
/// operation is bounded by `timeout`.
#[must_use = "futures do nothing unless polled"]
pub struct UdpClientStream<P> {
    // --- remote / local addressing ---
    /// The name server every request is sent to.
    name_server: SocketAddr,
    /// Optional local address handed to the socket helper when opening sockets.
    bind_addr: Option<SocketAddr>,
    /// Local ports handed to the socket helper; presumably excluded from random selection.
    avoid_local_ports: Arc<HashSet<u16>>,
    /// Forwarded to the socket helper; presumably lets the OS pick the local port.
    os_port_selection: bool,

    // --- timing / retry policy ---
    /// Upper bound on the whole exchange, covering every retry attempt.
    timeout: Duration,
    /// Maximum number of send attempts launched per request.
    max_retries: u8,
    /// Lower bound on the interval between successive attempts.
    retry_interval_floor: Duration,

    // --- lifecycle / runtime ---
    /// Set by `shutdown`; sending afterwards panics.
    is_shutdown: bool,
    /// Runtime used for timers and socket creation.
    provider: P,

    /// Optional transaction signer applied to requests that ask for it.
    #[cfg(feature = "__dnssec")]
    signer: Option<TSigner>,
}
47
48impl<P: RuntimeProvider> UdpClientStream<P> {
49 pub fn builder(name_server: SocketAddr, provider: P) -> UdpClientStreamBuilder<P> {
51 UdpClientStreamBuilder {
52 name_server,
53 timeout: None,
54 #[cfg(feature = "__dnssec")]
55 signer: None,
56 bind_addr: None,
57 avoid_local_ports: Arc::default(),
58 os_port_selection: false,
59 provider,
60 max_retries: 3,
61 retry_interval_floor: DEFAULT_RETRY_FLOOR,
64 }
65 }
66}
67
68impl<P> Display for UdpClientStream<P> {
69 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
70 write!(formatter, "UDP({})", self.name_server)
71 }
72}
73
74impl<P: RuntimeProvider> DnsRequestSender for UdpClientStream<P> {
75 fn send_message(&mut self, request: DnsRequest) -> DnsResponseStream {
76 if self.is_shutdown {
77 panic!("can not send messages after stream is shutdown")
78 }
79
80 let retry_interval_time = request.options().retry_interval;
81 let request = UdpRequest::new(request, self);
82
83 let max_retries = self.max_retries;
84 let retry_interval = if retry_interval_time < self.retry_interval_floor {
85 self.retry_interval_floor
86 } else {
87 retry_interval_time
88 };
89
90 P::Timer::timeout(
91 self.timeout,
92 retry::<P>(request, retry_interval, max_retries.into()),
93 )
94 .into()
95 }
96
97 fn shutdown(&mut self) {
98 self.is_shutdown = true;
99 }
100
101 fn is_shutdown(&self) -> bool {
102 self.is_shutdown
103 }
104}
105
106impl<P> Stream for UdpClientStream<P> {
108 type Item = Result<(), NetError>;
109
110 fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
111 if self.is_shutdown {
113 Poll::Ready(None)
114 } else {
115 Poll::Ready(Some(Ok(())))
116 }
117 }
118}
119
120struct UdpRequest<P> {
122 avoid_local_ports: Arc<HashSet<u16>>,
123 name_server: SocketAddr,
124 request: DnsRequest,
125 provider: P,
126 #[cfg(feature = "__dnssec")]
127 signer: Option<TSigner>,
128 #[cfg(feature = "__dnssec")]
129 now: u64,
130 bind_addr: Option<SocketAddr>,
131 os_port_selection: bool,
132 case_randomization: bool,
133 recv_buf_size: usize,
134}
135
136impl<P: RuntimeProvider> UdpRequest<P> {
137 fn new(request: DnsRequest, stream: &UdpClientStream<P>) -> Self {
138 Self {
139 avoid_local_ports: stream.avoid_local_ports.clone(),
140 recv_buf_size: MAX_RECEIVE_BUFFER_SIZE.min(request.max_payload() as usize),
141 case_randomization: request.options().case_randomization,
142 name_server: stream.name_server,
143 #[cfg(feature = "__dnssec")]
145 signer: match &stream.signer {
146 Some(signer) if signer.should_sign_message(&request) => stream.signer.clone(),
147 _ => None,
148 },
149 request,
150 provider: stream.provider.clone(),
151 #[cfg(feature = "__dnssec")]
152 now: P::Timer::current_time(),
153 bind_addr: stream.bind_addr,
154 os_port_selection: stream.os_port_selection,
155 }
156 }
157}
158
159impl<P: RuntimeProvider> Request for UdpRequest<P> {
160 async fn send(&self) -> Result<DnsResponse, NetError> {
161 let original_query = self.request.original_query();
162 #[cfg_attr(not(feature = "__dnssec"), expect(unused_mut))]
163 let mut request = self.request.clone();
164
165 #[cfg(feature = "__dnssec")]
166 let mut verifier = None;
167 #[cfg(feature = "__dnssec")]
168 if let Some(signer) = &self.signer {
169 match request.finalize(signer, self.now) {
170 Ok(answer_verifier) => verifier = answer_verifier,
171 Err(e) => {
172 debug!("could not sign message: {}", e);
173 return Err(e.into());
174 }
175 }
176 }
177
178 let request_bytes = match request.to_vec() {
179 Ok(bytes) => bytes,
180 Err(err) => return Err(err.into()),
181 };
182
183 let msg_id = request.id;
184 let msg = SerialMessage::new(request_bytes, self.name_server);
185 let addr = msg.addr();
186 let final_message = match msg.to_message() {
187 Ok(m) => m,
188 Err(e) => return Err(e.into()),
189 };
190 debug!(%final_message, "final message");
191
192 let socket = NextRandomUdpSocket::new(
193 addr,
194 self.bind_addr,
195 self.avoid_local_ports.clone(),
196 self.os_port_selection,
197 self.provider.clone(),
198 )
199 .await?;
200
201 let bytes = msg.bytes();
202 let len_sent: usize = socket.send_to(bytes, addr).await?;
203
204 if bytes.len() != len_sent {
205 return Err(NetError::from(format!(
206 "Not all bytes of message sent, {} of {}",
207 len_sent,
208 bytes.len()
209 )));
210 }
211
212 trace!(
214 recv_buf_size = self.recv_buf_size,
215 "creating UDP receive buffer"
216 );
217 let mut recv_buf = vec![0; self.recv_buf_size];
218
219 for _ in 0..3 {
221 let (len, src) = socket.recv_from(&mut recv_buf).await?;
222
223 let response_bytes = &recv_buf[0..len];
225 let response_buffer = Vec::from(response_bytes);
226
227 let request_target = msg.addr();
229
230 if src.ip().to_canonical() != request_target.ip().to_canonical()
232 || src.port() != request_target.port()
233 {
234 warn!(
235 "ignoring response from {}:{} because it does not match name_server: {}:{}.",
236 src.ip().to_canonical(),
237 src.port(),
238 request_target.ip().to_canonical(),
239 request_target.port(),
240 );
241
242 continue;
244 }
245
246 let mut response = DnsResponse::from_buffer(response_buffer)?;
247
248 if msg_id != response.id {
250 warn!(
252 "expected message id: {} got: {}, dropped",
253 msg_id, response.id
254 );
255
256 continue;
258 }
259
260 let request_message = Message::from_vec(msg.bytes())?;
286 let request_queries = &request_message.queries;
287 let response_queries = &mut response.queries;
288
289 let question_matches = response_queries
290 .iter()
291 .all(|elem| request_queries.contains(elem));
292 if self.case_randomization
293 && question_matches
294 && !response_queries.iter().all(|elem| {
295 request_queries
296 .iter()
297 .any(|req_q| req_q == elem && req_q.name().eq_case(elem.name()))
298 })
299 {
300 warn!(
301 "case of question section did not match: we expected '{request_queries:?}', but received '{response_queries:?}' from server {src}"
302 );
303 return Err(NetError::QueryCaseMismatch);
304 }
305 if !question_matches {
306 warn!(
307 "detected forged question section: we expected '{request_queries:?}', but received '{response_queries:?}' from server {src}"
308 );
309 continue;
310 }
311
312 if self.case_randomization {
314 if let Some(original_query) = original_query {
315 for response_query in response_queries.iter_mut() {
316 if response_query == original_query {
317 *response_query = original_query.clone();
318 }
319 }
320 }
321 }
322
323 debug!("received message id: {}", response.id);
324 #[cfg(feature = "__dnssec")]
325 if let Some(mut verifier) = verifier {
326 return Ok(verifier.verify(response_bytes)?);
327 }
328 return Ok(response);
329 }
330
331 Err(NetError::from("udp receive attempts exceeded"))
332 }
333}
334
/// Builder for [`UdpClientStream`], obtained from `UdpClientStream::builder`.
pub struct UdpClientStreamBuilder<P> {
    // --- addressing ---
    /// Remote name server.
    name_server: SocketAddr,
    /// Optional local address for outgoing sockets.
    bind_addr: Option<SocketAddr>,
    /// Local ports forwarded to the socket helper; presumably excluded from random selection.
    avoid_local_ports: Arc<HashSet<u16>>,
    /// Forwarded to the socket helper; presumably lets the OS choose the local port.
    os_port_selection: bool,

    // --- timing / retry policy ---
    /// Overall timeout; `None` means `build` applies its default.
    timeout: Option<Duration>,
    /// Maximum number of send attempts per request.
    max_retries: u8,
    /// Minimum interval between attempts.
    retry_interval_floor: Duration,

    // --- runtime / signing ---
    /// Runtime used for timers, sockets, and spawning.
    provider: P,
    /// Optional transaction signer.
    #[cfg(feature = "__dnssec")]
    signer: Option<TSigner>,
}
350
351impl<P: RuntimeProvider> UdpClientStreamBuilder<P> {
352 pub fn with_timeout(mut self, timeout: Option<Duration>) -> Self {
354 self.timeout = timeout;
355 self
356 }
357
358 #[cfg(feature = "__dnssec")]
360 pub fn with_signer(self, signer: Option<TSigner>) -> Self {
361 Self {
362 name_server: self.name_server,
363 timeout: self.timeout,
364 signer,
365 bind_addr: self.bind_addr,
366 avoid_local_ports: self.avoid_local_ports,
367 os_port_selection: self.os_port_selection,
368 provider: self.provider,
369 max_retries: self.max_retries,
370 retry_interval_floor: self.retry_interval_floor,
371 }
372 }
373
374 pub fn with_bind_addr(mut self, bind_addr: Option<SocketAddr>) -> Self {
379 self.bind_addr = bind_addr;
380 self
381 }
382
383 pub fn avoid_local_ports(mut self, avoid_local_ports: Arc<HashSet<u16>>) -> Self {
386 self.avoid_local_ports = avoid_local_ports;
387 self
388 }
389
390 pub fn with_os_port_selection(mut self, os_port_selection: bool) -> Self {
392 self.os_port_selection = os_port_selection;
393 self
394 }
395
396 pub fn with_max_retries(mut self, max_retries: u8) -> Self {
398 self.max_retries = max_retries;
399 self
400 }
401
402 pub fn with_retry_interval_floor(mut self, floor: u64) -> Self {
404 self.retry_interval_floor = Duration::from_millis(floor);
405 self
406 }
407
408 pub fn exchange(self) -> DnsExchange<P> {
410 let mut handle = self.provider.create_handle();
411 let stream = self.build();
412 let (exchange, bg) = DnsExchange::from_stream(stream);
413 handle.spawn_bg(bg);
414 exchange
415 }
416
417 pub fn build(self) -> UdpClientStream<P> {
421 UdpClientStream {
422 name_server: self.name_server,
423 timeout: self.timeout.unwrap_or(Duration::from_secs(5)),
424 is_shutdown: false,
425 #[cfg(feature = "__dnssec")]
426 signer: self.signer,
427 bind_addr: self.bind_addr,
428 avoid_local_ports: self.avoid_local_ports.clone(),
429 os_port_selection: self.os_port_selection,
430 provider: self.provider,
431 max_retries: self.max_retries,
432 retry_interval_floor: self.retry_interval_floor,
433 }
434 }
435}
436
437async fn retry<Provider: RuntimeProvider>(
443 request: impl Request,
444 retry_interval_time: Duration,
445 max_tasks: usize,
446) -> Result<DnsResponse, NetError> {
447 let mut futures = FuturesUnordered::new();
448
449 let retry_timer = Provider::Timer::delay_for(retry_interval_time).fuse();
450 pin_mut!(retry_timer);
451
452 futures.push(request.send());
453 let mut tasks = 1;
454
455 loop {
456 futures_util::select! {
457 result = futures.next() => {
458 match result {
459 Some(result) => return result,
460 None => return Err(NetError::from("no tasks successful")),
461 }
462 }
463 _ = &mut retry_timer => {
464 if tasks < max_tasks {
465 tasks += 1;
466 futures.push(request.send());
467 retry_timer.set(Provider::Timer::delay_for(retry_interval_time).fuse());
468 }
469 }
470 }
471 }
472}
473
/// A single attempt at obtaining a DNS response.
///
/// `retry` calls `send` repeatedly on the same value and races the resulting futures against
/// each other.
trait Request {
    /// Performs one attempt. It may be invoked several times concurrently via `&self`.
    async fn send(&self) -> Result<DnsResponse, NetError>;
}
477
// Tests for the UDP client stream and the `retry` scheduler (tokio runtime only).
#[cfg(all(test, feature = "tokio"))]
mod tests {
    #![allow(clippy::dbg_macro, clippy::print_stdout)]

    use core::{
        net::{IpAddr, Ipv4Addr, Ipv6Addr},
        sync::atomic::{AtomicU8, Ordering},
    };
    use std::io;

    use test_support::subscribe;
    use tokio::time::sleep;

    use super::*;
    use crate::{
        proto::op::ResponseCode,
        runtime::{TokioRuntimeProvider, TokioTime},
        udp::tests::{
            udp_client_stream_bad_id_test, udp_client_stream_response_limit_test,
            udp_client_stream_test,
        },
    };

    // The stream tests below delegate to shared scenarios in `udp::tests`, run once over IPv4
    // loopback and once over IPv6 loopback.

    #[tokio::test]
    async fn test_udp_client_stream_ipv4() {
        subscribe();
        udp_client_stream_test(IpAddr::V4(Ipv4Addr::LOCALHOST), TokioRuntimeProvider::new()).await;
    }

    #[tokio::test]
    async fn test_udp_client_stream_ipv4_bad_id() {
        subscribe();
        udp_client_stream_bad_id_test(IpAddr::V4(Ipv4Addr::LOCALHOST), TokioRuntimeProvider::new())
            .await;
    }

    #[tokio::test]
    async fn test_udp_client_stream_ipv4_resp_limit() {
        subscribe();
        udp_client_stream_response_limit_test(
            IpAddr::V4(Ipv4Addr::LOCALHOST),
            TokioRuntimeProvider::new(),
        )
        .await;
    }

    #[tokio::test]
    async fn test_udp_client_stream_ipv6() {
        subscribe();
        udp_client_stream_test(
            IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
            TokioRuntimeProvider::new(),
        )
        .await;
    }

    #[tokio::test]
    async fn test_udp_client_stream_ipv6_bad_id() {
        subscribe();
        udp_client_stream_bad_id_test(
            IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
            TokioRuntimeProvider::new(),
        )
        .await;
    }

    #[tokio::test]
    async fn test_udp_client_stream_ipv6_resp_limit() {
        subscribe();
        udp_client_stream_response_limit_test(
            IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
            TokioRuntimeProvider::new(),
        )
        .await;
    }

    // Runs with tokio's clock paused, so timers advance deterministically and the attempt
    // counts below are exact.
    #[tokio::test(start_paused = true)]
    async fn retry_handler_test() -> Result<(), NetError> {
        let mut message = Message::query().into_response();
        message.metadata.response_code = ResponseCode::NoError;

        // An attempt that answers immediately is returned as-is.
        let ret = retry::<TokioRuntimeProvider>(
            FixedResponse {
                response: DnsResponse::from_message(message.clone())?,
            },
            Duration::from_millis(200),
            5,
        )
        .await?;
        assert_eq!(ret.response_code, ResponseCode::NoError);

        // 100ms answer < 200ms retry interval: the first attempt wins before any retry starts.
        let (req, tries) = DelayedResponse::new(
            DnsResponse::from_message(message.clone()).unwrap(),
            Duration::from_millis(100),
            Arc::new(AtomicU8::new(0)),
        );
        retry::<TokioRuntimeProvider>(req, Duration::from_millis(200), 5).await?;
        assert_eq!(tries.load(Ordering::Relaxed), 1);

        // 1500ms answer: attempts start at 0/200/400/600/800ms and then hit the cap of 5.
        let (req, tries) = DelayedResponse::new(
            DnsResponse::from_message(message.clone()).unwrap(),
            Duration::from_millis(1500),
            Arc::new(AtomicU8::new(0)),
        );
        retry::<TokioRuntimeProvider>(req, Duration::from_millis(200), 5).await?;
        assert_eq!(tries.load(Ordering::Relaxed), 5);

        // An outer 500ms timeout fires first: only the attempts at 0/200/400ms were started.
        let (req, tries) = DelayedResponse::new(
            DnsResponse::from_message(message.clone()).unwrap(),
            Duration::from_millis(1000),
            Arc::new(AtomicU8::new(0)),
        );
        let timer_ret = TokioTime::timeout(
            Duration::from_millis(500),
            retry::<TokioRuntimeProvider>(req, Duration::from_millis(200), 5),
        )
        .await;

        if let Err(e) = timer_ret {
            assert_eq!(e.kind(), io::ErrorKind::TimedOut);
        } else {
            panic!("timer did not timeout");
        }

        assert_eq!(tries.load(Ordering::Relaxed), 3);

        Ok(())
    }

    /// `Request` stub that resolves immediately with a fixed response.
    struct FixedResponse {
        response: DnsResponse,
    }

    impl Request for FixedResponse {
        async fn send(&self) -> Result<DnsResponse, NetError> {
            Ok(self.response.clone())
        }
    }

    /// `Request` stub that counts each attempt and answers after `delay`.
    struct DelayedResponse {
        response: DnsResponse,
        delay: Duration,
        counter: Arc<AtomicU8>,
    }

    impl DelayedResponse {
        /// Builds the stub and also returns the shared counter, so the test can inspect how
        /// many attempts were made after the stub has been moved into `retry`.
        fn new(
            response: DnsResponse,
            delay: Duration,
            counter: Arc<AtomicU8>,
        ) -> (Self, Arc<AtomicU8>) {
            (
                Self {
                    response,
                    delay,
                    counter: counter.clone(),
                },
                counter,
            )
        }
    }

    impl Request for DelayedResponse {
        async fn send(&self) -> Result<DnsResponse, NetError> {
            // Count the attempt as soon as it is polled, before the delay.
            let _ = self.counter.fetch_add(1, Ordering::Relaxed);
            sleep(self.delay).await;
            Ok(self.response.clone())
        }
    }
}