Skip to main content

hickory_net/udp/
udp_client_stream.rs

1// Copyright 2015-2016 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use core::fmt::{self, Display};
9use core::net::SocketAddr;
10use core::pin::Pin;
11use core::task::{Context, Poll};
12use core::time::Duration;
13use std::collections::HashSet;
14use std::sync::Arc;
15
16use futures_util::{FutureExt, Stream, StreamExt, pin_mut, stream::FuturesUnordered};
17use tracing::{debug, trace, warn};
18
19use crate::error::NetError;
20use crate::proto::op::{DEFAULT_RETRY_FLOOR, DnsRequest, DnsResponse, Message, SerialMessage};
21#[cfg(feature = "__dnssec")]
22use crate::proto::rr::TSigner;
23use crate::runtime::{DnsUdpSocket, RuntimeProvider, Spawn, Time};
24use crate::udp::MAX_RECEIVE_BUFFER_SIZE;
25use crate::udp::udp_stream::NextRandomUdpSocket;
26use crate::xfer::{DnsExchange, DnsRequestSender, DnsResponseStream};
27
/// A UDP client stream of DNS binary packets.
///
/// The resolver wrapper is expected to create and manage new UDP client streams so that each
/// request is sent from a random port, which helps avoid cache poisoning through UDP spoofing
/// attacks.
#[must_use = "futures do nothing unless polled"]
pub struct UdpClientStream<P> {
    /// Address of the DNS server that requests are sent to.
    name_server: SocketAddr,
    /// Runtime used to create sockets and timers.
    provider: P,
    /// Upper bound on the time spent on a single request, retries included.
    timeout: Duration,
    /// Maximum number of sends per request, counting the first one.
    max_retries: u8,
    /// Lower bound applied to each request's retry interval.
    retry_interval_floor: Duration,
    /// Local address to bind sending sockets to, if any.
    bind_addr: Option<SocketAddr>,
    /// Local ports that must not be used for sending sockets.
    avoid_local_ports: Arc<HashSet<u16>>,
    /// Whether the OS, rather than Hickory DNS, chooses the local port.
    os_port_selection: bool,
    /// Signer applied to requests that should be signed.
    #[cfg(feature = "__dnssec")]
    signer: Option<TSigner>,
    /// Set once the stream has been shut down.
    is_shutdown: bool,
}
47
48impl<P: RuntimeProvider> UdpClientStream<P> {
49    /// Construct a new [`UdpClientStream`] via a [`UdpClientStreamBuilder`].
50    pub fn builder(name_server: SocketAddr, provider: P) -> UdpClientStreamBuilder<P> {
51        UdpClientStreamBuilder {
52            name_server,
53            timeout: None,
54            #[cfg(feature = "__dnssec")]
55            signer: None,
56            bind_addr: None,
57            avoid_local_ports: Arc::default(),
58            os_port_selection: false,
59            provider,
60            max_retries: 3,
61            // This is the default value to use for the retry interval floor, which acts as a lower
62            // bound on the retry interval.
63            retry_interval_floor: DEFAULT_RETRY_FLOOR,
64        }
65    }
66}
67
68impl<P> Display for UdpClientStream<P> {
69    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
70        write!(formatter, "UDP({})", self.name_server)
71    }
72}
73
74impl<P: RuntimeProvider> DnsRequestSender for UdpClientStream<P> {
75    fn send_message(&mut self, request: DnsRequest) -> DnsResponseStream {
76        if self.is_shutdown {
77            panic!("can not send messages after stream is shutdown")
78        }
79
80        let retry_interval_time = request.options().retry_interval;
81        let request = UdpRequest::new(request, self);
82
83        let max_retries = self.max_retries;
84        let retry_interval = if retry_interval_time < self.retry_interval_floor {
85            self.retry_interval_floor
86        } else {
87            retry_interval_time
88        };
89
90        P::Timer::timeout(
91            self.timeout,
92            retry::<P>(request, retry_interval, max_retries.into()),
93        )
94        .into()
95    }
96
97    fn shutdown(&mut self) {
98        self.is_shutdown = true;
99    }
100
101    fn is_shutdown(&self) -> bool {
102        self.is_shutdown
103    }
104}
105
106// TODO: is this impl necessary? there's nothing being driven here...
107impl<P> Stream for UdpClientStream<P> {
108    type Item = Result<(), NetError>;
109
110    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
111        // Technically the Stream doesn't actually do anything.
112        if self.is_shutdown {
113            Poll::Ready(None)
114        } else {
115            Poll::Ready(Some(Ok(())))
116        }
117    }
118}
119
120/// Request context for data send_udp_message needs via the retry handler closure
121struct UdpRequest<P> {
122    avoid_local_ports: Arc<HashSet<u16>>,
123    name_server: SocketAddr,
124    request: DnsRequest,
125    provider: P,
126    #[cfg(feature = "__dnssec")]
127    signer: Option<TSigner>,
128    #[cfg(feature = "__dnssec")]
129    now: u64,
130    bind_addr: Option<SocketAddr>,
131    os_port_selection: bool,
132    case_randomization: bool,
133    recv_buf_size: usize,
134}
135
136impl<P: RuntimeProvider> UdpRequest<P> {
137    fn new(request: DnsRequest, stream: &UdpClientStream<P>) -> Self {
138        Self {
139            avoid_local_ports: stream.avoid_local_ports.clone(),
140            recv_buf_size: MAX_RECEIVE_BUFFER_SIZE.min(request.max_payload() as usize),
141            case_randomization: request.options().case_randomization,
142            name_server: stream.name_server,
143            // Only smuggle in the signer if we are going to use it.
144            #[cfg(feature = "__dnssec")]
145            signer: match &stream.signer {
146                Some(signer) if signer.should_sign_message(&request) => stream.signer.clone(),
147                _ => None,
148            },
149            request,
150            provider: stream.provider.clone(),
151            #[cfg(feature = "__dnssec")]
152            now: P::Timer::current_time(),
153            bind_addr: stream.bind_addr,
154            os_port_selection: stream.os_port_selection,
155        }
156    }
157}
158
159impl<P: RuntimeProvider> Request for UdpRequest<P> {
160    async fn send(&self) -> Result<DnsResponse, NetError> {
161        let original_query = self.request.original_query();
162        #[cfg_attr(not(feature = "__dnssec"), expect(unused_mut))]
163        let mut request = self.request.clone();
164
165        #[cfg(feature = "__dnssec")]
166        let mut verifier = None;
167        #[cfg(feature = "__dnssec")]
168        if let Some(signer) = &self.signer {
169            match request.finalize(signer, self.now) {
170                Ok(answer_verifier) => verifier = answer_verifier,
171                Err(e) => {
172                    debug!("could not sign message: {}", e);
173                    return Err(e.into());
174                }
175            }
176        }
177
178        let request_bytes = match request.to_vec() {
179            Ok(bytes) => bytes,
180            Err(err) => return Err(err.into()),
181        };
182
183        let msg_id = request.id;
184        let msg = SerialMessage::new(request_bytes, self.name_server);
185        let addr = msg.addr();
186        let final_message = match msg.to_message() {
187            Ok(m) => m,
188            Err(e) => return Err(e.into()),
189        };
190        debug!(%final_message, "final message");
191
192        let socket = NextRandomUdpSocket::new(
193            addr,
194            self.bind_addr,
195            self.avoid_local_ports.clone(),
196            self.os_port_selection,
197            self.provider.clone(),
198        )
199        .await?;
200
201        let bytes = msg.bytes();
202        let len_sent: usize = socket.send_to(bytes, addr).await?;
203
204        if bytes.len() != len_sent {
205            return Err(NetError::from(format!(
206                "Not all bytes of message sent, {} of {}",
207                len_sent,
208                bytes.len()
209            )));
210        }
211
212        // Create the receive buffer.
213        trace!(
214            recv_buf_size = self.recv_buf_size,
215            "creating UDP receive buffer"
216        );
217        let mut recv_buf = vec![0; self.recv_buf_size];
218
219        // Try to process up to 3 responses
220        for _ in 0..3 {
221            let (len, src) = socket.recv_from(&mut recv_buf).await?;
222
223            // Copy the slice of read bytes.
224            let response_bytes = &recv_buf[0..len];
225            let response_buffer = Vec::from(response_bytes);
226
227            // compare expected src to received packet
228            let request_target = msg.addr();
229
230            // Comparing the IP and Port directly as internal information about the link is stored with the IpAddr, see https://github.com/hickory-dns/hickory-dns/issues/2081
231            if src.ip().to_canonical() != request_target.ip().to_canonical()
232                || src.port() != request_target.port()
233            {
234                warn!(
235                    "ignoring response from {}:{} because it does not match name_server: {}:{}.",
236                    src.ip().to_canonical(),
237                    src.port(),
238                    request_target.ip().to_canonical(),
239                    request_target.port(),
240                );
241
242                // await an answer from the correct NameServer
243                continue;
244            }
245
246            let mut response = DnsResponse::from_buffer(response_buffer)?;
247
248            // Validate the message id in the response matches the value chosen for the query.
249            if msg_id != response.id {
250                // on wrong id, attempted poison?
251                warn!(
252                    "expected message id: {} got: {}, dropped",
253                    msg_id, response.id
254                );
255
256                // await an answer with the correct message id
257                continue;
258            }
259
260            // Validate the returned query name.
261            //
262            // This currently checks that each response query name was present in the original query, but not that
263            // every original question is present.
264            //
265            // References:
266            //
267            // RFC 1035 7.3:
268            //
269            // The next step is to match the response to a current resolver request.
270            // The recommended strategy is to do a preliminary matching using the ID
271            // field in the domain header, and then to verify that the question section
272            // corresponds to the information currently desired.
273            //
274            // RFC 1035 7.4:
275            //
276            // In general, we expect a resolver to cache all data which it receives in
277            // responses since it may be useful in answering future client requests.
278            // However, there are several types of data which should not be cached:
279            //
280            // ...
281            //
282            //  - RR data in responses of dubious reliability.  When a resolver
283            // receives unsolicited responses or RR data other than that
284            // requested, it should discard it without caching it.
285            let request_message = Message::from_vec(msg.bytes())?;
286            let request_queries = &request_message.queries;
287            let response_queries = &mut response.queries;
288
289            let question_matches = response_queries
290                .iter()
291                .all(|elem| request_queries.contains(elem));
292            if self.case_randomization
293                && question_matches
294                && !response_queries.iter().all(|elem| {
295                    request_queries
296                        .iter()
297                        .any(|req_q| req_q == elem && req_q.name().eq_case(elem.name()))
298                })
299            {
300                warn!(
301                    "case of question section did not match: we expected '{request_queries:?}', but received '{response_queries:?}' from server {src}"
302                );
303                return Err(NetError::QueryCaseMismatch);
304            }
305            if !question_matches {
306                warn!(
307                    "detected forged question section: we expected '{request_queries:?}', but received '{response_queries:?}' from server {src}"
308                );
309                continue;
310            }
311
312            // overwrite the query with the original query if case randomization may have been used
313            if self.case_randomization {
314                if let Some(original_query) = original_query {
315                    for response_query in response_queries.iter_mut() {
316                        if response_query == original_query {
317                            *response_query = original_query.clone();
318                        }
319                    }
320                }
321            }
322
323            debug!("received message id: {}", response.id);
324            #[cfg(feature = "__dnssec")]
325            if let Some(mut verifier) = verifier {
326                return Ok(verifier.verify(response_bytes)?);
327            }
328            return Ok(response);
329        }
330
331        Err(NetError::from("udp receive attempts exceeded"))
332    }
333}
334
/// Collects the settings for a [`UdpClientStream`] before it is built.
///
/// Obtain one from [`UdpClientStream::builder`].
pub struct UdpClientStreamBuilder<P> {
    /// Address of the DNS server that requests will be sent to.
    name_server: SocketAddr,
    /// Runtime used by the built stream.
    provider: P,
    /// Per-request timeout; left as `None`, a default is chosen when building.
    timeout: Option<Duration>,
    /// Maximum number of sends per request, counting the first one.
    max_retries: u8,
    /// Lower bound on the interval between retries.
    retry_interval_floor: Duration,
    /// Local address to bind sending sockets to, if any.
    bind_addr: Option<SocketAddr>,
    /// Local ports that must not be used for sending sockets.
    avoid_local_ports: Arc<HashSet<u16>>,
    /// Whether the OS, rather than Hickory DNS, chooses the local port.
    os_port_selection: bool,
    /// Signer to apply to requests that should be signed.
    #[cfg(feature = "__dnssec")]
    signer: Option<TSigner>,
}
350
351impl<P: RuntimeProvider> UdpClientStreamBuilder<P> {
352    /// Sets the connection timeout.
353    pub fn with_timeout(mut self, timeout: Option<Duration>) -> Self {
354        self.timeout = timeout;
355        self
356    }
357
358    /// Sets the message finalizer to be applied to queries.
359    #[cfg(feature = "__dnssec")]
360    pub fn with_signer(self, signer: Option<TSigner>) -> Self {
361        Self {
362            name_server: self.name_server,
363            timeout: self.timeout,
364            signer,
365            bind_addr: self.bind_addr,
366            avoid_local_ports: self.avoid_local_ports,
367            os_port_selection: self.os_port_selection,
368            provider: self.provider,
369            max_retries: self.max_retries,
370            retry_interval_floor: self.retry_interval_floor,
371        }
372    }
373
374    /// Sets the local socket address to connect from.
375    ///
376    /// If the port number is 0, a random port number will be chosen to defend against spoofing
377    /// attacks. If the port number is nonzero, it will be used instead.
378    pub fn with_bind_addr(mut self, bind_addr: Option<SocketAddr>) -> Self {
379        self.bind_addr = bind_addr;
380        self
381    }
382
383    /// Configures a list of local UDP ports that should not be used when making outgoing
384    /// connections.
385    pub fn avoid_local_ports(mut self, avoid_local_ports: Arc<HashSet<u16>>) -> Self {
386        self.avoid_local_ports = avoid_local_ports;
387        self
388    }
389
390    /// Configures that OS should provide the ephemeral port, not the Hickory DNS
391    pub fn with_os_port_selection(mut self, os_port_selection: bool) -> Self {
392        self.os_port_selection = os_port_selection;
393        self
394    }
395
396    /// Sets the maximum number of retries for a single request
397    pub fn with_max_retries(mut self, max_retries: u8) -> Self {
398        self.max_retries = max_retries;
399        self
400    }
401
402    /// Sets the retry interval floor
403    pub fn with_retry_interval_floor(mut self, floor: u64) -> Self {
404        self.retry_interval_floor = Duration::from_millis(floor);
405        self
406    }
407
408    /// Wrap a [`DnsExchange`] around the built [`UdpClientStream`]
409    pub fn exchange(self) -> DnsExchange<P> {
410        let mut handle = self.provider.create_handle();
411        let stream = self.build();
412        let (exchange, bg) = DnsExchange::from_stream(stream);
413        handle.spawn_bg(bg);
414        exchange
415    }
416
417    /// Construct a new UDP client stream.
418    ///
419    /// Returns a future that outputs the client stream.
420    pub fn build(self) -> UdpClientStream<P> {
421        UdpClientStream {
422            name_server: self.name_server,
423            timeout: self.timeout.unwrap_or(Duration::from_secs(5)),
424            is_shutdown: false,
425            #[cfg(feature = "__dnssec")]
426            signer: self.signer,
427            bind_addr: self.bind_addr,
428            avoid_local_ports: self.avoid_local_ports.clone(),
429            os_port_selection: self.os_port_selection,
430            provider: self.provider,
431            max_retries: self.max_retries,
432            retry_interval_floor: self.retry_interval_floor,
433        }
434    }
435}
436
437/// This implements a retry handler for tasks that might not complete successfully (e.g.,
438/// DNS requests made via UDP.) It starts a task future immediately, then every
439/// retry_interval_time period up to a maximum of max_tasks. It will immediately return
440/// the first task that completes successfully, or an error if no tasks succeed.
441/// It does not implement an overall timeout to bound the work.
442async fn retry<Provider: RuntimeProvider>(
443    request: impl Request,
444    retry_interval_time: Duration,
445    max_tasks: usize,
446) -> Result<DnsResponse, NetError> {
447    let mut futures = FuturesUnordered::new();
448
449    let retry_timer = Provider::Timer::delay_for(retry_interval_time).fuse();
450    pin_mut!(retry_timer);
451
452    futures.push(request.send());
453    let mut tasks = 1;
454
455    loop {
456        futures_util::select! {
457            result = futures.next() => {
458                match result {
459                    Some(result) => return result,
460                    None => return Err(NetError::from("no tasks successful")),
461                }
462            }
463            _ = &mut retry_timer => {
464                if tasks < max_tasks {
465                    tasks += 1;
466                    futures.push(request.send());
467                    retry_timer.set(Provider::Timer::delay_for(retry_interval_time).fuse());
468                }
469            }
470        }
471    }
472}
473
474trait Request {
475    async fn send(&self) -> Result<DnsResponse, NetError>;
476}
477
#[cfg(all(test, feature = "tokio"))]
mod tests {
    use core::{
        net::{IpAddr, Ipv4Addr, Ipv6Addr},
        sync::atomic::{AtomicU8, Ordering},
    };
    use std::io;

    use test_support::subscribe;
    use tokio::time::sleep;

    use super::*;
    use crate::{
        proto::op::ResponseCode,
        runtime::{TokioRuntimeProvider, TokioTime},
        udp::tests::{
            udp_client_stream_bad_id_test, udp_client_stream_response_limit_test,
            udp_client_stream_test,
        },
    };

    #[tokio::test]
    async fn test_udp_client_stream_ipv4() {
        subscribe();
        udp_client_stream_test(IpAddr::V4(Ipv4Addr::LOCALHOST), TokioRuntimeProvider::new()).await;
    }

    #[tokio::test]
    async fn test_udp_client_stream_ipv4_bad_id() {
        subscribe();
        udp_client_stream_bad_id_test(IpAddr::V4(Ipv4Addr::LOCALHOST), TokioRuntimeProvider::new())
            .await;
    }

    #[tokio::test]
    async fn test_udp_client_stream_ipv4_resp_limit() {
        subscribe();
        udp_client_stream_response_limit_test(
            IpAddr::V4(Ipv4Addr::LOCALHOST),
            TokioRuntimeProvider::new(),
        )
        .await;
    }

    #[tokio::test]
    async fn test_udp_client_stream_ipv6() {
        subscribe();
        udp_client_stream_test(IpAddr::V6(Ipv6Addr::LOCALHOST), TokioRuntimeProvider::new()).await;
    }

    #[tokio::test]
    async fn test_udp_client_stream_ipv6_bad_id() {
        subscribe();
        udp_client_stream_bad_id_test(IpAddr::V6(Ipv6Addr::LOCALHOST), TokioRuntimeProvider::new())
            .await;
    }

    #[tokio::test]
    async fn test_udp_client_stream_ipv6_resp_limit() {
        subscribe();
        udp_client_stream_response_limit_test(
            IpAddr::V6(Ipv6Addr::LOCALHOST),
            TokioRuntimeProvider::new(),
        )
        .await;
    }

    /// An attempt that completes immediately is returned as the result.
    #[tokio::test(start_paused = true)]
    async fn retry_handler_test() -> Result<(), NetError> {
        subscribe();
        let ret = retry::<TokioRuntimeProvider>(
            FixedResponse {
                response: no_error_response()?,
            },
            Duration::from_millis(200),
            5,
        )
        .await?;
        assert_eq!(ret.response_code, ResponseCode::NoError);
        Ok(())
    }

    /// The retry timer must not start extra attempts before the retry interval elapses.
    #[tokio::test(start_paused = true)]
    async fn retry_handler_no_retry_before_interval() -> Result<(), NetError> {
        subscribe();
        let (req, tries) = DelayedResponse::new(no_error_response()?, Duration::from_millis(100));
        retry::<TokioRuntimeProvider>(req, Duration::from_millis(200), 5).await?;
        assert_eq!(tries.load(Ordering::Relaxed), 1);
        Ok(())
    }

    /// The retry timer starts extra attempts after each interval, up to the task limit.
    #[tokio::test(start_paused = true)]
    async fn retry_handler_retries_after_interval() -> Result<(), NetError> {
        subscribe();
        let (req, tries) = DelayedResponse::new(no_error_response()?, Duration::from_millis(1500));
        retry::<TokioRuntimeProvider>(req, Duration::from_millis(200), 5).await?;
        assert_eq!(tries.load(Ordering::Relaxed), 5);
        Ok(())
    }

    /// Retries stop being started once an enclosing `Time::timeout` fires.
    #[tokio::test(start_paused = true)]
    async fn retry_handler_under_timeout() -> Result<(), NetError> {
        subscribe();
        let (req, tries) = DelayedResponse::new(no_error_response()?, Duration::from_millis(1000));
        let timer_ret = TokioTime::timeout(
            Duration::from_millis(500),
            retry::<TokioRuntimeProvider>(req, Duration::from_millis(200), 5),
        )
        .await;

        let Err(e) = timer_ret else {
            panic!("timer did not timeout");
        };
        assert_eq!(e.kind(), io::ErrorKind::TimedOut);

        // Attempts start at 0ms, 200ms and 400ms; the timeout fires at 500ms.
        assert_eq!(tries.load(Ordering::Relaxed), 3);
        Ok(())
    }

    /// Builds an empty `NoError` response for the retry handler tests.
    fn no_error_response() -> Result<DnsResponse, NetError> {
        let mut message = Message::query().into_response();
        message.metadata.response_code = ResponseCode::NoError;
        Ok(DnsResponse::from_message(message)?)
    }

    /// A request that immediately resolves to a fixed response.
    struct FixedResponse {
        response: DnsResponse,
    }

    impl Request for FixedResponse {
        async fn send(&self) -> Result<DnsResponse, NetError> {
            Ok(self.response.clone())
        }
    }

    /// A request that counts its attempts and resolves to a fixed response after a delay.
    struct DelayedResponse {
        response: DnsResponse,
        delay: Duration,
        counter: Arc<AtomicU8>,
    }

    impl DelayedResponse {
        /// Returns the request together with a handle to its attempt counter.
        fn new(response: DnsResponse, delay: Duration) -> (Self, Arc<AtomicU8>) {
            let counter = Arc::new(AtomicU8::new(0));
            let request = Self {
                response,
                delay,
                counter: Arc::clone(&counter),
            };
            (request, counter)
        }
    }

    impl Request for DelayedResponse {
        async fn send(&self) -> Result<DnsResponse, NetError> {
            let _ = self.counter.fetch_add(1, Ordering::Relaxed);
            sleep(self.delay).await;
            Ok(self.response.clone())
        }
    }
}