Skip to main content

hickory_net/xfer/
mod.rs

1//! DNS high level transit implementations.
2//!
//! There are two primary types of interest in this module: `DnsMultiplexer` and `DnsHandle`. `DnsMultiplexer` can be thought of as the state machine responsible for sending and receiving DNS messages. `DnsHandle` is the type given to API users of the `hickory-proto` library to send messages into the `DnsMultiplexer` for delivery. Finally, there is the `DnsRequest` type, which allows the delivery of messages via a `DnsMultiplexer` to be customized through `DnsRequestOptions`.
4//!
5//! TODO: this module needs some serious refactoring and normalization.
6
7use core::fmt::Display;
8use core::fmt::{self, Debug};
9use core::future::Future;
10use core::marker::PhantomData;
11use core::net::SocketAddr;
12use core::pin::Pin;
13use core::task::{Context, Poll};
14use core::time::Duration;
15use std::io;
16
17use futures_channel::mpsc;
18use futures_channel::oneshot;
19use futures_util::future::BoxFuture;
20use futures_util::ready;
21use futures_util::stream::{Fuse, Peekable};
22use futures_util::stream::{Stream, StreamExt};
23#[cfg(feature = "serde")]
24use serde::{Deserialize, Serialize};
25use tracing::{debug, warn};
26
27use crate::error::NetError;
28use crate::proto::ProtoError;
29use crate::proto::op::{DnsRequest, DnsResponse, SerialMessage};
30use crate::runtime::{RuntimeProvider, Time};
31
32mod dns_exchange;
33pub use dns_exchange::{DnsExchange, DnsExchangeBackground, DnsExchangeSend};
34
35pub mod dns_handle;
36pub use dns_handle::{DnsHandle, DnsStreamHandle};
37
38pub mod dns_multiplexer;
39pub use dns_multiplexer::DnsMultiplexer;
40
41pub mod retry_dns_handle;
42pub use retry_dns_handle::RetryDnsHandle;
43
44/// A stream returning DNS responses
45pub struct DnsResponseStream {
46    inner: DnsResponseStreamInner,
47    done: bool,
48}
49
50impl DnsResponseStream {
51    fn new(inner: DnsResponseStreamInner) -> Self {
52        Self { inner, done: false }
53    }
54}
55
56impl Stream for DnsResponseStream {
57    type Item = Result<DnsResponse, NetError>;
58
59    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
60        use DnsResponseStreamInner::*;
61
62        // if the standard futures are done, don't poll again
63        if self.done {
64            return Poll::Ready(None);
65        }
66
67        // split mutable refs to Self
68        let Self { inner, done } = self.get_mut();
69
70        let result = match inner {
71            Timeout(fut) => {
72                let x = match ready!(fut.as_mut().poll(cx)) {
73                    Ok(x) => x,
74                    Err(e) => Err(e.into()),
75                };
76                *done = true;
77                x
78            }
79            Receiver(fut) => match ready!(Pin::new(fut).poll_next(cx)) {
80                Some(Ok(x)) => Ok(x),
81                Some(Err(e)) => Err(e),
82                None => return Poll::Ready(None),
83            },
84            Error(err) => {
85                *done = true;
86                Err(err.take().expect("cannot poll after complete"))
87            }
88            Boxed(fut) => {
89                let x = ready!(fut.as_mut().poll(cx));
90                *done = true;
91                x
92            }
93        };
94
95        match result {
96            Err(NetError::Timeout) => Poll::Ready(None),
97            r => Poll::Ready(Some(r)),
98        }
99    }
100}
101
102impl From<TimeoutFuture> for DnsResponseStream {
103    fn from(f: TimeoutFuture) -> Self {
104        Self::new(DnsResponseStreamInner::Timeout(f))
105    }
106}
107
108impl From<mpsc::Receiver<Result<DnsResponse, NetError>>> for DnsResponseStream {
109    fn from(receiver: mpsc::Receiver<Result<DnsResponse, NetError>>) -> Self {
110        Self::new(DnsResponseStreamInner::Receiver(receiver))
111    }
112}
113
114impl From<NetError> for DnsResponseStream {
115    fn from(e: NetError) -> Self {
116        Self::new(DnsResponseStreamInner::Error(Some(e)))
117    }
118}
119
120impl<F> From<Pin<Box<F>>> for DnsResponseStream
121where
122    F: Future<Output = Result<DnsResponse, NetError>> + Send + 'static,
123{
124    fn from(f: Pin<Box<F>>) -> Self {
125        Self::new(DnsResponseStreamInner::Boxed(f))
126    }
127}
128
129enum DnsResponseStreamInner {
130    Timeout(TimeoutFuture),
131    Receiver(mpsc::Receiver<Result<DnsResponse, NetError>>),
132    Error(Option<NetError>),
133    Boxed(BoxFuture<'static, Result<DnsResponse, NetError>>),
134}
135
136type TimeoutFuture = BoxFuture<'static, Result<Result<DnsResponse, NetError>, io::Error>>;
137
138/// Ignores the result of a send operation and logs and ignores errors
139fn ignore_send<M, T>(result: Result<M, mpsc::TrySendError<T>>) {
140    if let Err(error) = result {
141        if error.is_disconnected() {
142            debug!("ignoring send error on disconnected stream");
143            return;
144        }
145
146        warn!("error notifying wait, possible future leak: {:?}", error);
147    }
148}
149
150/// A non-multiplexed stream of Serialized DNS messages
151pub trait DnsClientStream:
152    Stream<Item = Result<SerialMessage, NetError>> + Unpin + Send + 'static
153{
154    /// Time implementation for this impl
155    type Time: Time;
156
157    /// The remote name server address
158    fn name_server_addr(&self) -> SocketAddr;
159}
160
161/// Receiver handle for peekable fused SerialMessage channel
162pub type StreamReceiver = Peekable<Fuse<mpsc::Receiver<SerialMessage>>>;
163
164/// A buffered channel for outbound DNS messages on a connection.
165///
166/// Used to queue messages for sending over a DNS connection. On the client/resolver side,
167/// this buffers outbound queries to nameservers. On the server side, this buffers outbound
168/// responses to clients.
169#[derive(Clone)]
170pub struct BufDnsStreamHandle {
171    remote_addr: SocketAddr,
172    sender: mpsc::Sender<SerialMessage>,
173}
174
175impl BufDnsStreamHandle {
176    /// Constructs a new buffered stream handle with the default buffer size (32).
177    ///
178    /// # Arguments
179    ///
180    /// * `remote_addr` - address of the DNS peer (nameserver for clients, client for servers)
181    pub fn new(remote_addr: SocketAddr) -> (Self, StreamReceiver) {
182        Self::with_buffer_size(remote_addr, DEFAULT_STREAM_BUFFER_SIZE)
183    }
184
185    /// Constructs a new buffered stream handle with an explicit buffer size.
186    ///
187    /// Use this when you need a larger buffer to handle high message rates without
188    /// dropping messages due to backpressure.
189    ///
190    /// # Arguments
191    ///
192    /// * `remote_addr` - address of the DNS peer (nameserver for clients, client for servers)
193    /// * `buffer_size` - maximum number of messages that can be queued for sending
194    pub fn with_buffer_size(remote_addr: SocketAddr, buffer_size: usize) -> (Self, StreamReceiver) {
195        let (sender, receiver) = mpsc::channel(buffer_size);
196        let receiver = receiver.fuse().peekable();
197
198        let this = Self {
199            remote_addr,
200            sender,
201        };
202
203        (this, receiver)
204    }
205
206    /// Associates a different remote address for any responses.
207    ///
208    /// This is mainly useful in server use cases where the incoming address is only known after receiving a packet.
209    pub fn with_remote_addr(&self, remote_addr: SocketAddr) -> Self {
210        Self {
211            remote_addr,
212            sender: self.sender.clone(),
213        }
214    }
215}
216
217impl DnsStreamHandle for BufDnsStreamHandle {
218    fn send(&mut self, buffer: SerialMessage) -> Result<(), NetError> {
219        let sender: &mut _ = &mut self.sender;
220        sender
221            .try_send(SerialMessage::new(buffer.into_parts().0, self.remote_addr))
222            .map_err(|e| NetError::from(format!("mpsc::SendError {e}")))
223    }
224}
225
226/// Types that implement this are capable of sending a serialized DNS message on a stream
227///
228/// The underlying Stream implementation should yield `Some(())` whenever it is ready to send a message,
229///   NotReady, if it is not ready to send a message, and `Err` or `None` in the case that the stream is
230///   done, and should be shutdown.
231pub trait DnsRequestSender: Stream<Item = Result<(), NetError>> + Send + Unpin + 'static {
232    /// Send a message, and return a stream of response
233    ///
234    /// # Return
235    ///
236    /// A stream which will resolve to SerialMessage responses
237    fn send_message(&mut self, request: DnsRequest) -> DnsResponseStream;
238
239    /// Allows the upstream user to inform the underling stream that it should shutdown.
240    ///
241    /// After this is called, the next time `poll` is called on the stream it would be correct to return `Poll::Ready(Ok(()))`. This is not required though, if there are say outstanding requests that are not yet complete, then it would be correct to first wait for those results.
242    fn shutdown(&mut self);
243
244    /// Returns true if the stream has been shutdown with `shutdown`
245    fn is_shutdown(&self) -> bool;
246}
247
248/// Used for associating a name_server to a DnsRequestStreamHandle
249#[derive(Clone)]
250
251pub struct BufDnsRequestStreamHandle<P> {
252    sender: mpsc::Sender<OneshotDnsRequest>,
253    _phantom: PhantomData<P>,
254}
255
256impl<P: RuntimeProvider> DnsHandle for BufDnsRequestStreamHandle<P> {
257    type Response = DnsResponseReceiver;
258    type Runtime = P;
259
260    fn send(&self, request: DnsRequest) -> Self::Response {
261        debug!(
262            "enqueueing message:{}:{:?}",
263            request.op_code, request.queries
264        );
265
266        let (request, oneshot) = OneshotDnsRequest::oneshot(request);
267        let mut sender = self.sender.clone();
268        let try_send = sender.try_send(request).map_err(|_| {
269            debug!("unable to enqueue message");
270            NetError::Busy
271        });
272
273        match try_send {
274            Ok(val) => val,
275            Err(err) => return DnsResponseReceiver::Err(Some(err)),
276        }
277
278        DnsResponseReceiver::Receiver(oneshot)
279    }
280}
281
282// TODO: this future should return the origin message in the response on errors
283/// A OneshotDnsRequest creates a channel for a response to message
284pub struct OneshotDnsRequest {
285    dns_request: DnsRequest,
286    sender_for_response: oneshot::Sender<DnsResponseStream>,
287}
288
289impl OneshotDnsRequest {
290    fn oneshot(dns_request: DnsRequest) -> (Self, oneshot::Receiver<DnsResponseStream>) {
291        let (sender_for_response, receiver) = oneshot::channel();
292
293        (
294            Self {
295                dns_request,
296                sender_for_response,
297            },
298            receiver,
299        )
300    }
301
302    fn into_parts(self) -> (DnsRequest, OneshotDnsResponse) {
303        (
304            self.dns_request,
305            OneshotDnsResponse(self.sender_for_response),
306        )
307    }
308}
309
310struct OneshotDnsResponse(oneshot::Sender<DnsResponseStream>);
311
312impl OneshotDnsResponse {
313    fn send_response(self, serial_response: DnsResponseStream) -> Result<(), DnsResponseStream> {
314        self.0.send(serial_response)
315    }
316}
317
318/// A Stream that wraps a [`oneshot::Receiver<Stream>`] and resolves to items in the inner Stream
319pub enum DnsResponseReceiver {
320    /// The receiver
321    Receiver(oneshot::Receiver<DnsResponseStream>),
322    /// The stream once received
323    Received(DnsResponseStream),
324    /// Error during the send operation
325    Err(Option<NetError>),
326}
327
328impl Stream for DnsResponseReceiver {
329    type Item = Result<DnsResponse, NetError>;
330
331    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
332        loop {
333            *self = match &mut *self {
334                Self::Receiver(receiver) => {
335                    let receiver = Pin::new(receiver);
336                    let future = ready!(
337                        receiver
338                            .poll(cx)
339                            .map_err(|_| NetError::from("receiver was canceled"))
340                    )?;
341                    Self::Received(future)
342                }
343                Self::Received(stream) => {
344                    return stream.poll_next_unpin(cx);
345                }
346                Self::Err(err) => return Poll::Ready(err.take().map(Err)),
347            };
348        }
349    }
350}
351
352/// Helper trait to convert a Stream of dns response into a Future
353pub trait FirstAnswer<T, E: From<ProtoError>>: Stream<Item = Result<T, E>> + Unpin + Sized {
354    /// Convert a Stream of dns response into a Future yielding the first answer,
355    /// discarding others if any.
356    fn first_answer(self) -> FirstAnswerFuture<Self> {
357        FirstAnswerFuture { stream: Some(self) }
358    }
359}
360
361impl<E, S, T> FirstAnswer<T, E> for S
362where
363    S: Stream<Item = Result<T, E>> + Unpin + Sized,
364    E: From<ProtoError>,
365{
366}
367
/// Future returned by [`FirstAnswer::first_answer`].
///
/// Holds the wrapped stream until the first item has been produced.
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[derive(Debug)]
pub struct FirstAnswerFuture<S> {
    stream: Option<S>,
}
374
375impl<S: Stream<Item = Result<T, NetError>> + Unpin, T> Future for FirstAnswerFuture<S>
376where
377    S: Stream<Item = Result<T, NetError>> + Unpin + Sized,
378{
379    type Output = S::Item;
380
381    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
382        let s = self
383            .stream
384            .as_mut()
385            .expect("polling FirstAnswerFuture twice");
386        let item = match ready!(s.poll_next_unpin(cx)) {
387            Some(r) => r,
388            None => Err(NetError::Timeout),
389        };
390        self.stream.take();
391        Poll::Ready(item)
392    }
393}
394
/// The transport protocol used to communicate with a name server.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(rename_all = "lowercase"))]
#[non_exhaustive]
pub enum Protocol {
    /// Plain UDP, the traditional DNS transport; this is generally the correct choice
    Udp,
    /// Plain TCP, useful for large messages, though not every name server supports it
    Tcp,
    /// DNS over TLS
    #[cfg(feature = "__tls")]
    Tls,
    /// DNS over HTTPS
    #[cfg(feature = "__https")]
    Https,
    /// DNS over QUIC
    #[cfg(feature = "__quic")]
    Quic,
    /// DNS over HTTP/3
    #[cfg(feature = "__h3")]
    H3,
}

impl Protocol {
    /// Returns true for datagram-oriented protocols (currently only UDP).
    pub fn is_datagram(self) -> bool {
        self == Self::Udp
    }

    /// Returns true for stream-oriented protocols, i.e. every non-datagram protocol.
    pub fn is_stream(self) -> bool {
        !self.is_datagram()
    }

    /// Returns true for encrypted protocols (TLS, HTTPS, QUIC and HTTP/3).
    pub fn is_encrypted(self) -> bool {
        match self {
            Self::Udp | Self::Tcp => false,
            #[cfg(feature = "__tls")]
            Self::Tls => true,
            #[cfg(feature = "__https")]
            Self::Https => true,
            #[cfg(feature = "__quic")]
            Self::Quic => true,
            #[cfg(feature = "__h3")]
            Self::H3 => true,
        }
    }
}

impl Display for Protocol {
    /// Writes the lowercase protocol name, e.g. `udp` or `tcp`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Udp => "udp",
            Self::Tcp => "tcp",
            #[cfg(feature = "__tls")]
            Self::Tls => "tls",
            #[cfg(feature = "__https")]
            Self::Https => "https",
            #[cfg(feature = "__quic")]
            Self::Quic => "quic",
            #[cfg(feature = "__h3")]
            Self::H3 => "h3",
        };
        f.write_str(name)
    }
}
466
/// Connection timeout of five seconds.
///
/// NOTE(review): presumably the limit for establishing a new connection to a name server.
/// It is not referenced in this file, so confirm against its users.
#[allow(unused)] // May be unused depending on features
pub(crate) const CONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);
469
/// Default buffer size for outbound DNS message streams.
///
/// This is the `buffer` argument of the `futures_channel::mpsc` channel created by
/// `BufDnsStreamHandle::new`. It is the number of shared slots available for queuing DNS
/// messages on a single connection. That channel additionally reserves one guaranteed
/// slot per live sender, so the effective capacity is this value plus the number of
/// handles sharing the channel.
///
/// The buffer is used in two contexts:
///
/// - **Server side**: buffers outbound responses to clients
/// - **Client/resolver side**: buffers outbound queries to nameservers
///
/// Under high load, a larger buffer prevents messages from being rejected due to backpressure.
const DEFAULT_STREAM_BUFFER_SIZE: usize = 32;
475/// - **Server side**: buffers outbound responses to clients
476/// - **Client/resolver side**: buffers outbound queries to nameservers
477///
478/// Under high load, a larger buffer prevents messages from being dropped due to backpressure.
479const DEFAULT_STREAM_BUFFER_SIZE: usize = 32;