Skip to main content

hickory_net/xfer/
dns_multiplexer.rs

1// Copyright 2015-2023 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! `DnsMultiplexer` and associated types implement the state machines for sending DNS messages while using the underlying streams.
9
10use core::{
11    pin::Pin,
12    task::{Context, Poll},
13    time::Duration,
14};
15use std::collections::{HashMap, hash_map::Entry};
16
17use futures_channel::mpsc;
18use futures_util::{
19    FutureExt,
20    future::BoxFuture,
21    stream::{Stream, StreamExt},
22};
23use rand::RngExt;
24use tracing::debug;
25
26use super::{
27    BufDnsStreamHandle, DnsClientStream, DnsRequestSender, DnsResponseStream, ignore_send,
28};
29use crate::proto::op::{DnsRequest, DnsResponse, SerialMessage};
30#[cfg(feature = "__dnssec")]
31use crate::proto::rr::{TSigVerifier, TSigner};
32use crate::{DnsStreamHandle, error::NetError, runtime::Time};
33
/// Book-keeping for one in-flight request sent through a [`DnsMultiplexer`].
///
/// An entry is inserted into `DnsMultiplexer::active_requests` (keyed by `request_id`) only
/// after the serialized request was handed to the stream handle successfully. It is removed
/// when the requestor closes its end of `completion`, when `timeout` fires, or when the
/// underlying stream closes. Responses are delivered through `completion` without removing
/// the entry, so one request may receive several responses (e.g. a multi-message AXFR).
struct ActiveRequest {
    // the completion is the channel for a response to the original request
    completion: mpsc::Sender<Result<DnsResponse, NetError>>,
    // DNS message id assigned to the outgoing request; also its key in `active_requests`
    request_id: u16,
    // resolves once this request's deadline has elapsed
    timeout: BoxFuture<'static, ()>,
    // TSIG verifier applied to responses when the outgoing request was signed
    #[cfg(feature = "__dnssec")]
    verifier: Option<TSigVerifier>,
}
42
43impl ActiveRequest {
44    fn new(
45        completion: mpsc::Sender<Result<DnsResponse, NetError>>,
46        request_id: u16,
47        timeout: BoxFuture<'static, ()>,
48        #[cfg(feature = "__dnssec")] verifier: Option<TSigVerifier>,
49    ) -> Self {
50        Self {
51            completion,
52            request_id,
53            // request,
54            timeout,
55            #[cfg(feature = "__dnssec")]
56            verifier,
57        }
58    }
59
60    /// polls the timeout and converts the error
61    fn poll_timeout(&mut self, cx: &mut Context<'_>) -> Poll<()> {
62        self.timeout.poll_unpin(cx)
63    }
64
65    /// Returns true of the other side canceled the request
66    fn is_canceled(&self) -> bool {
67        self.completion.is_closed()
68    }
69
70    /// the request id of the message that was sent
71    fn request_id(&self) -> u16 {
72        self.request_id
73    }
74
75    /// Sends an error
76    fn complete_with_error(mut self, error: NetError) {
77        ignore_send(self.completion.try_send(Err(error)));
78    }
79}
80
/// A DNS Client implemented over futures-rs.
///
/// This Client is generic and capable of wrapping UDP, TCP, and other underlying DNS protocol
///  implementations. This should be used for underlying protocols that do not natively support
///  multiplexed sessions.
///
/// Requests are sent with [`DnsRequestSender::send_message`]. The multiplexer itself is a
/// [`Stream`] that must be polled to do its work:
/// - read responses off the underlying stream;
/// - route each response to the in-flight request with the same message id;
/// - expire requests that timed out or whose requestor went away.
#[must_use = "futures do nothing unless polled"]
pub struct DnsMultiplexer<S> {
    // underlying stream that responses are read from
    stream: S,
    // deadline applied to each request when it is sent
    timeout_duration: Duration,
    // handle used to write serialized requests
    stream_handle: BufDnsStreamHandle,
    // in-flight requests keyed by the DNS message id assigned in `send_message`
    active_requests: HashMap<u16, ActiveRequest>,
    // cap on `active_requests.len()`; `send_message` returns `NetError::Busy` once reached
    max_active_requests: usize,
    // optional TSIG signer applied to outgoing requests
    #[cfg(feature = "__dnssec")]
    signer: Option<TSigner>,
    // set by `shutdown()`, or when the underlying stream ends or fails
    is_shutdown: bool,
}
97
98impl<S: DnsClientStream> DnsMultiplexer<S> {
99    /// Spawns a new DnsMultiplexer Stream.
100    ///
101    /// This uses a default timeout of 5 seconds for all requests unless changed with
102    /// [`Self::with_timeout()`]. At most 32 in-flight requests are allowed unless
103    /// changed with [`Self::with_max_active_requests()`].
104    ///
105    /// # Arguments
106    ///
107    /// * `stream` - A stream of bytes that can be used to send/receive DNS messages
108    ///   (see TcpClientStream or UdpClientStream)
109    /// * `stream_handle` - The handle for the `stream` on which bytes can be sent/received.
110    pub fn new(stream: S, stream_handle: BufDnsStreamHandle) -> Self {
111        Self {
112            stream,
113            timeout_duration: Duration::from_secs(5),
114            stream_handle,
115            active_requests: HashMap::default(),
116            max_active_requests: 32,
117            #[cfg(feature = "__dnssec")]
118            signer: None,
119            is_shutdown: false,
120        }
121    }
122
123    /// Change the default timeout of the DnsMultiplexer stream.
124    pub fn with_timeout(mut self, timeout: Duration) -> Self {
125        self.timeout_duration = timeout;
126        self
127    }
128
129    /// Set the maximum number of active (in-flight) requests.
130    ///
131    /// This limits how many DNS queries can be simultaneously pending on this
132    /// multiplexed connection. When the limit is reached, new requests will
133    /// return [`NetError::Busy`].
134    pub fn with_max_active_requests(mut self, max: usize) -> Self {
135        self.max_active_requests = max;
136        self
137    }
138
139    /// Specify an optional signer to TSIG authenticate requests.
140    #[cfg(feature = "__dnssec")]
141    pub fn with_signer(mut self, signer: TSigner) -> Self {
142        self.signer = Some(signer);
143        self
144    }
145
146    /// loop over active_requests and remove cancelled requests
147    ///  this should free up space if we already had 4096 active requests
148    fn drop_cancelled(&mut self, cx: &mut Context<'_>) {
149        let mut canceled = HashMap::<u16, NetError>::new();
150        for (&id, active_req) in &mut self.active_requests {
151            if active_req.is_canceled() {
152                canceled.insert(id, NetError::from("requestor canceled"));
153            }
154
155            // check for timeouts...
156            match active_req.poll_timeout(cx) {
157                Poll::Ready(()) => {
158                    debug!("request timed out: {}", id);
159                    canceled.insert(id, NetError::Timeout);
160                }
161                Poll::Pending => (),
162            }
163        }
164
165        // drop all the canceled requests
166        for (id, error) in canceled {
167            if let Some(active_request) = self.active_requests.remove(&id) {
168                // complete the request, it's failed...
169                active_request.complete_with_error(error);
170            }
171        }
172    }
173
174    /// creates random query_id, validates against all active queries
175    fn next_random_query_id(&self) -> Result<u16, NetError> {
176        let mut rand = rand::rng();
177
178        for _ in 0..100 {
179            let id: u16 = rand.random(); // the range is [0 ... u16::max]
180
181            if !self.active_requests.contains_key(&id) {
182                return Ok(id);
183            }
184        }
185
186        Err(NetError::from(
187            "id space exhausted, consider filing an issue",
188        ))
189    }
190
191    /// Closes all outstanding completes with a closed stream error
192    fn stream_closed_close_all(&mut self, error: NetError) {
193        debug!(%error, addr = %self.stream.name_server_addr());
194        for (_, active_request) in self.active_requests.drain() {
195            // complete the request, it's failed...
196            active_request.complete_with_error(error.clone());
197        }
198    }
199}
200
201impl<S: DnsClientStream> DnsRequestSender for DnsMultiplexer<S> {
202    fn send_message(&mut self, request: DnsRequest) -> DnsResponseStream {
203        if self.is_shutdown {
204            panic!("can not send messages after stream is shutdown")
205        }
206
207        if self.active_requests.len() >= self.max_active_requests {
208            return NetError::Busy.into();
209        }
210
211        let query_id = match self.next_random_query_id() {
212            Ok(id) => id,
213            Err(e) => return e.into(),
214        };
215
216        let (mut request, _) = request.into_parts();
217        request.metadata.id = query_id;
218
219        #[cfg(feature = "__dnssec")]
220        let mut verifier = None;
221        #[cfg(feature = "__dnssec")]
222        if let Some(signer) = &self.signer {
223            if signer.should_sign_message(&request) {
224                match request.finalize(signer, S::Time::current_time()) {
225                    Ok(answer_verifier) => verifier = answer_verifier,
226                    Err(e) => {
227                        debug!("could not sign message: {}", e);
228                        return NetError::from(e).into();
229                    }
230                }
231            }
232        }
233
234        // store a Timeout for this message before sending
235        let timeout = S::Time::delay_for(self.timeout_duration);
236
237        let (complete, receiver) = mpsc::channel(QUERY_RESPONSE_BUFFER_SIZE);
238
239        // send the message
240        let active_request = ActiveRequest::new(
241            complete,
242            request.id,
243            timeout,
244            #[cfg(feature = "__dnssec")]
245            verifier,
246        );
247
248        match request.to_vec() {
249            Ok(buffer) => {
250                debug!(id = %active_request.request_id(), "sending message");
251                let serial_message = SerialMessage::new(buffer, self.stream.name_server_addr());
252
253                debug!(
254                    "final message: {}",
255                    serial_message
256                        .to_message()
257                        .expect("bizarre we just made this message")
258                );
259
260                // add to the map -after- the client send b/c we don't want to put it in the map if
261                //  we ended up returning an error from the send.
262                match self.stream_handle.send(serial_message) {
263                    Ok(()) => self
264                        .active_requests
265                        .insert(active_request.request_id(), active_request),
266                    Err(err) => return err.into(),
267                };
268            }
269            Err(error) => {
270                debug!(
271                    id = %active_request.request_id(),
272                    %error,
273                    "error message"
274                );
275                // complete with the error, don't add to the map of active requests
276                return NetError::from(error).into();
277            }
278        }
279
280        receiver.into()
281    }
282
283    fn shutdown(&mut self) {
284        self.is_shutdown = true;
285    }
286
287    fn is_shutdown(&self) -> bool {
288        self.is_shutdown
289    }
290}
291
292impl<S: DnsClientStream> Stream for DnsMultiplexer<S> {
293    type Item = Result<(), NetError>;
294
295    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
296        // Always drop the cancelled queries first
297        self.drop_cancelled(cx);
298
299        if self.is_shutdown && self.active_requests.is_empty() {
300            debug!("stream is done: {}", self.stream.name_server_addr());
301            return Poll::Ready(None);
302        }
303
304        // Collect all inbound requests, max 100 at a time for QoS
305        //   by having a max we will guarantee that the client can't be DOSed in this loop
306        // TODO: make the QoS configurable
307        let mut messages_received = 0;
308        for i in 0..QOS_MAX_RECEIVE_MSGS {
309            match self.stream.poll_next_unpin(cx) {
310                Poll::Ready(Some(Ok(buffer))) => {
311                    messages_received = i;
312
313                    //   deserialize or log decode_error
314                    match DnsResponse::from_buffer(buffer.into_parts().0) {
315                        Ok(response) => match self.active_requests.entry(response.id) {
316                            Entry::Occupied(mut request_entry) => {
317                                // send the response, complete the request...
318                                let active_request = request_entry.get_mut();
319                                #[cfg(feature = "__dnssec")]
320                                if let Some(verifier) = &mut active_request.verifier {
321                                    ignore_send(
322                                        active_request.completion.try_send(
323                                            verifier
324                                                .verify(response.as_buffer())
325                                                .map_err(NetError::from),
326                                        ),
327                                    );
328                                } else {
329                                    ignore_send(active_request.completion.try_send(Ok(response)));
330                                }
331                                #[cfg(not(feature = "__dnssec"))]
332                                ignore_send(active_request.completion.try_send(Ok(response)));
333                            }
334                            Entry::Vacant(..) => debug!("unexpected request_id: {}", response.id),
335                        },
336                        // TODO: return src address for diagnostics
337                        Err(error) => debug!(%error, "error decoding message"),
338                    }
339                }
340                Poll::Ready(err) => {
341                    let err = match err {
342                        Some(Err(e)) => e,
343                        None => NetError::from("stream closed"),
344                        _ => unreachable!(),
345                    };
346
347                    self.stream_closed_close_all(err);
348                    self.is_shutdown = true;
349                    return Poll::Ready(None);
350                }
351                Poll::Pending => break,
352            }
353        }
354
355        // If still active, then if the qos (for _ in 0..100 loop) limit
356        // was hit then "yield". This'll make sure that the future is
357        // woken up immediately on the next turn of the event loop.
358        if messages_received == QOS_MAX_RECEIVE_MSGS {
359            // FIXME: this was a task::current().notify(); is this right?
360            cx.waker().wake_by_ref();
361        }
362
363        // Finally, return not ready to keep the 'driver task' alive.
364        Poll::Pending
365    }
366}
367
/// Maximum number of messages read from the underlying stream (UDP, TCP or any other
/// `DnsClientStream`) in a single poll of a `DnsMultiplexer`.
///
/// Bounding the read loop keeps a busy stream from monopolising the task that drives the
/// multiplexer.
const QOS_MAX_RECEIVE_MSGS: usize = 100;

/// Buffer size for per-query response channels.
///
/// Each outgoing DNS query gets its own channel to receive responses. Standard
/// DNS queries receive exactly one response so a small buffer is sufficient.
/// Responses are pushed with `try_send`, so this bounds how many responses can be queued
/// for a requestor that has not read them yet (e.g. multi-message zone transfers).
const QUERY_RESPONSE_BUFFER_SIZE: usize = 8;
375
#[cfg(test)]
mod test {
    use core::net::{Ipv4Addr, Ipv6Addr, SocketAddr};

    use futures_util::{
        future::{self, BoxFuture},
        ready,
        stream::TryStreamExt,
    };
    use test_support::subscribe;

    use super::*;
    use crate::proto::op::{DnsRequestOptions, Message, Query};
    use crate::proto::rr::rdata::{NS, SOA};
    use crate::proto::rr::{DNSClass, Name, RData, Record, RecordType};
    use crate::proto::serialize::binary::BinEncodable;
    use crate::xfer::{DnsClientStream, StreamReceiver};

    /// In-memory stand-in for a `DnsClientStream` that replays canned responses.
    ///
    /// It reads the first request the multiplexer writes (via `receiver`), remembers its
    /// id, and stamps that id onto every canned message. This lets the multiplexer route
    /// the replies to that request.
    struct MockClientStream {
        // canned responses, stored reversed so `pop()` yields them in the original order
        messages: Vec<Message>,
        // address reported as the name server and used as the replies' source
        addr: SocketAddr,
        // id of the first request seen on `receiver`, once known
        id: Option<u16>,
        // receiving side of the multiplexer's `BufDnsStreamHandle`; set by the test helper
        receiver: Option<StreamReceiver>,
    }

    impl MockClientStream {
        /// Builds a mock that will reply with `messages`, in order, from `addr`.
        fn new(
            mut messages: Vec<Message>,
            addr: SocketAddr,
        ) -> BoxFuture<'static, Result<Self, NetError>> {
            messages.reverse(); // so we can pop() and get messages in order
            Box::pin(future::ok(Self {
                messages,
                addr,
                id: None,
                receiver: None,
            }))
        }
    }

    impl Stream for MockClientStream {
        type Item = Result<SerialMessage, NetError>;

        /// First waits for an outgoing request to learn its id. It then yields one canned
        /// response per poll, and returns `Pending` once the canned responses run out.
        fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            let id = if let Some(id) = self.id {
                id
            } else {
                // Pending here until the multiplexer has written its request.
                let serial = ready!(
                    self.receiver
                        .as_mut()
                        .expect("should only be polled after receiver has been set")
                        .poll_next_unpin(cx)
                );
                let message = serial.unwrap().to_message().unwrap();
                self.id = Some(message.id);
                message.id
            };

            if let Some(mut message) = self.messages.pop() {
                // reply with the id of the request so the multiplexer can match it
                message.metadata.id = id;
                Poll::Ready(Some(Ok(SerialMessage::new(
                    message.to_bytes().unwrap(),
                    self.addr,
                ))))
            } else {
                // NOTE(review): no waker is registered on this path.
                Poll::Pending
            }
        }
    }

    impl DnsClientStream for MockClientStream {
        type Time = crate::runtime::TokioTime;

        fn name_server_addr(&self) -> SocketAddr {
            self.addr
        }
    }

    /// Builds a multiplexer with a 100 ms request timeout over a `MockClientStream` that
    /// replays `mock_response`.
    ///
    /// The mock is given the receiving side of the multiplexer's stream handle so it can see
    /// the request being sent.
    async fn get_mocked_multiplexer(
        mock_response: Vec<Message>,
    ) -> DnsMultiplexer<MockClientStream> {
        let addr = SocketAddr::from(([127, 0, 0, 1], 1234));
        let mock_response = MockClientStream::new(mock_response, addr).await.unwrap();
        let (handler, receiver) = BufDnsStreamHandle::new(addr);
        let mut multiplexer =
            DnsMultiplexer::new(mock_response, handler).with_timeout(Duration::from_millis(100));

        multiplexer.stream.receiver = Some(receiver); // so it can get the correct request id

        multiplexer
    }

    /// A recursive `A` query for `www.example.com.` and a one-message answer with a single
    /// `A` record.
    fn a_query_answer() -> (DnsRequest, Vec<Message>) {
        let name = Name::from_ascii("www.example.com.").unwrap();

        let mut request = Message::query();
        request.metadata.recursion_desired = true;
        request.add_query({
            let mut q = Query::query(name.clone(), RecordType::A);
            q.set_query_class(DNSClass::IN);
            q
        });

        let mut response = request.clone().into_response();
        response.add_answer(Record::from_rdata(
            name,
            86400,
            RData::A(Ipv4Addr::new(93, 184, 215, 14).into()),
        ));
        (
            DnsRequest::new(request, DnsRequestOptions::default()),
            vec![response],
        )
    }

    /// An `AXFR` query for `example.com.` (class IN, recursion desired).
    fn axfr_query() -> Message {
        let name = Name::from_ascii("example.com.").unwrap();

        let mut msg = Message::query();
        msg.metadata.recursion_desired = true;
        msg.add_query({
            let mut query = Query::query(name, RecordType::AXFR);
            query.set_query_class(DNSClass::IN);
            query
        });
        msg
    }

    /// The six records of a small `example.com.` zone transfer. The list starts and ends
    /// with the same SOA record, with NS, A and AAAA records in between.
    fn axfr_response() -> Vec<Record> {
        let origin = Name::from_ascii("example.com.").unwrap();
        let soa = Record::from_rdata(
            origin.clone(),
            3600,
            RData::SOA(SOA::new(
                Name::parse("sns.dns.icann.org.", None).unwrap(),
                Name::parse("noc.dns.icann.org.", None).unwrap(),
                2015082403,
                7200,
                3600,
                1209600,
                3600,
            )),
        );

        vec![
            soa.clone(),
            Record::from_rdata(
                origin.clone(),
                86400,
                RData::NS(NS(Name::parse("a.iana-servers.net.", None).unwrap())),
            ),
            Record::from_rdata(
                origin.clone(),
                86400,
                RData::NS(NS(Name::parse("b.iana-servers.net.", None).unwrap())),
            ),
            Record::from_rdata(
                origin.clone(),
                86400,
                RData::A(Ipv4Addr::new(93, 184, 215, 14).into()),
            ),
            Record::from_rdata(
                origin,
                86400,
                RData::AAAA(
                    Ipv6Addr::new(
                        0x2606, 0x2800, 0x21f, 0xcb07, 0x6820, 0x80da, 0xaf6b, 0x8b2c,
                    )
                    .into(),
                ),
            ),
            soa,
        ]
    }

    /// The AXFR query paired with a single response message carrying the whole transfer.
    fn axfr_query_answer() -> (DnsRequest, Vec<Message>) {
        let msg = axfr_query();

        let mut response = msg.clone().into_response();
        response.insert_answers(axfr_response());
        (
            DnsRequest::new(msg, DnsRequestOptions::default()),
            vec![response],
        )
    }

    /// The AXFR query paired with the transfer split over two response messages. The first
    /// message holds the first three records and the second holds the rest.
    fn axfr_query_answer_multi() -> (DnsRequest, Vec<Message>) {
        let base = axfr_query();

        let query = base.clone();
        let mut rr = axfr_response();
        let rr2 = rr.split_off(3);
        let mut msg1 = base.clone().into_response();
        msg1.insert_answers(rr);
        let mut msg2 = base.into_response();
        msg2.insert_answers(rr2);
        (
            DnsRequest::new(query, DnsRequestOptions::default()),
            vec![msg1, msg2],
        )
    }

    /// A one-message `A` answer is routed to the request; the response stream yields exactly
    /// one message.
    #[tokio::test]
    async fn test_multiplexer_a() {
        subscribe();
        let (query, answer) = a_query_answer();
        let mut multiplexer = get_mocked_multiplexer(answer).await;
        let response = multiplexer.send_message(query);
        let response = tokio::select! {
            _ = multiplexer.next() => {
                // polling multiplexer to make it run
                panic!("should never end")
            },
            r = response.try_collect::<Vec<_>>() => r.unwrap(),
        };
        assert_eq!(response.len(), 1);
    }

    /// A one-message AXFR answer yields one response carrying every transferred record.
    #[tokio::test]
    async fn test_multiplexer_axfr() {
        subscribe();
        let (query, answer) = axfr_query_answer();
        let mut multiplexer = get_mocked_multiplexer(answer).await;
        let response = multiplexer.send_message(query);
        let response = tokio::select! {
            _ = multiplexer.next() => {
                // polling multiplexer to make it run
                panic!("should never end")
            },
            r = response.try_collect::<Vec<_>>() => r.unwrap(),
        };
        assert_eq!(response.len(), 1);
        assert_eq!(response[0].answers.len(), axfr_response().len());
    }

    /// An AXFR answer split over two messages yields two responses that together carry
    /// every transferred record.
    #[tokio::test]
    async fn test_multiplexer_axfr_multi() {
        subscribe();
        let (query, answer) = axfr_query_answer_multi();
        let mut multiplexer = get_mocked_multiplexer(answer).await;
        let response = multiplexer.send_message(query);
        let response = tokio::select! {
            _ = multiplexer.next() => {
                // polling multiplexer to make it run
                panic!("should never end")
            },
            r = response.try_collect::<Vec<_>>() => r.unwrap(),
        };
        assert_eq!(response.len(), 2);
        assert_eq!(
            response.iter().map(|m| m.answers.len()).sum::<usize>(),
            axfr_response().len()
        );
    }
}