Skip to main content

hickory_net/udp/
udp_stream.rs

1// Copyright 2015-2018 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use core::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
9use core::pin::Pin;
10use core::task::{Context, Poll};
11use std::collections::HashSet;
12use std::io;
13use std::sync::Arc;
14
15use async_trait::async_trait;
16use futures_util::{
17    future::{BoxFuture, Future},
18    ready,
19    stream::Stream,
20};
21use tracing::{debug, trace, warn};
22
23use crate::error::NetError;
24use crate::proto::op::SerialMessage;
25use crate::runtime::{DnsUdpSocket, RuntimeProvider};
26use crate::udp::MAX_RECEIVE_BUFFER_SIZE;
27use crate::xfer::{BufDnsStreamHandle, StreamReceiver};
28
29/// Trait for UdpSocket
30#[async_trait]
31pub trait UdpSocket: DnsUdpSocket {
32    /// setups up a "client" udp connection that will only receive packets from the associated address
33    async fn connect(addr: SocketAddr) -> io::Result<Self>;
34
35    /// same as connect, but binds to the specified local address for sending address
36    async fn connect_with_bind(addr: SocketAddr, bind_addr: SocketAddr) -> io::Result<Self>;
37
38    /// a "server" UDP socket, that bind to the local listening address, and unbound remote address (can receive from anything)
39    async fn bind(addr: SocketAddr) -> io::Result<Self>;
40}
41
42/// A UDP stream of DNS binary packets
43#[must_use = "futures do nothing unless polled"]
44pub struct UdpStream<P: RuntimeProvider> {
45    socket: P::Udp,
46    outbound_messages: StreamReceiver,
47}
48
49impl<P: RuntimeProvider> UdpStream<P> {
50    /// This method is intended for client connections, see [`Self::with_bound`] for a method better
51    ///  for straight listening. It is expected that the resolver wrapper will be responsible for
52    ///  creating and managing new UdpStreams such that each new client would have a random port
53    ///  (reduce chance of cache poisoning). This will return a randomly assigned local port, unless
54    ///  a nonzero port number is specified in `bind_addr`.
55    ///
56    /// # Arguments
57    ///
58    /// * `remote_addr` - socket address for the remote connection (used to determine IPv4 or IPv6)
59    /// * `bind_addr` - optional local socket address to connect from (if a nonzero port number is
60    ///   specified, it will be used instead of randomly selecting a port)
61    /// * `os_port_selection` - Boolean parameter to specify whether to use the operating system's
62    ///   standard UDP port selection logic instead of Hickory's logic to
63    ///   securely select a random source port. We do not recommend using
64    ///   this option unless absolutely necessary, as the operating system
65    ///   may select ephemeral ports from a smaller range than Hickory, which
66    ///   can make response poisoning attacks easier to conduct. Some
67    ///   operating systems (notably, Windows) might display a user-prompt to
68    ///   allow a Hickory-specified port to be used, and setting this option
69    ///   will prevent those prompts from being displayed. If os_port_selection
70    ///   is true, avoid_local_udp_ports will be ignored.
71    /// * `provider` - async runtime provider, for I/O and timers
72    ///
73    /// # Return
74    ///
75    /// A tuple of a Future of a Stream which will handle sending and receiving messages, and a
76    ///  handle which can be used to send messages into the stream.
77    pub fn new(
78        remote_addr: SocketAddr,
79        bind_addr: Option<SocketAddr>,
80        avoid_local_ports: Option<Arc<HashSet<u16>>>,
81        os_port_selection: bool,
82        provider: P,
83    ) -> (
84        BoxFuture<'static, Result<Self, NetError>>,
85        BufDnsStreamHandle,
86    ) {
87        let (message_sender, outbound_messages) = BufDnsStreamHandle::new(remote_addr);
88
89        // constructs a future for getting the next randomly bound port to a UdpSocket
90        let next_socket = NextRandomUdpSocket::new(
91            remote_addr,
92            bind_addr,
93            avoid_local_ports.unwrap_or_default(),
94            os_port_selection,
95            provider,
96        );
97
98        // This set of futures collapses the next udp socket into a stream which can be used for
99        //  sending and receiving udp packets.
100        let stream = Box::pin(async {
101            Ok(Self {
102                socket: next_socket.await?,
103                outbound_messages,
104            })
105        });
106
107        (stream, message_sender)
108    }
109}
110
111impl<P: RuntimeProvider> UdpStream<P> {
112    /// Initialize the Stream with an already bound socket. Generally this should be only used for
113    ///  server listening sockets. See [`Self::new`] for a client oriented socket. Specifically,
114    ///  this requires there is already a bound socket, whereas `new` makes sure to randomize ports
115    ///  for additional cache poison prevention.
116    ///
117    /// # Arguments
118    ///
119    /// * `socket` - an already bound UDP socket
120    /// * `remote_addr` - remote side of this connection
121    ///
122    /// # Return
123    ///
124    /// A tuple of a Stream which will handle sending and receiving messages, and a handle which can
125    ///  be used to send messages into the stream.
126    pub fn with_bound(socket: P::Udp, remote_addr: SocketAddr) -> (Self, BufDnsStreamHandle) {
127        let (message_sender, outbound_messages) = BufDnsStreamHandle::new(remote_addr);
128        let stream = Self {
129            socket,
130            outbound_messages,
131        };
132
133        (stream, message_sender)
134    }
135
136    #[cfg(all(feature = "tokio", feature = "mdns"))]
137    pub(crate) fn from_parts(socket: P::Udp, outbound_messages: StreamReceiver) -> Self {
138        Self {
139            socket,
140            outbound_messages,
141        }
142    }
143}
144
145impl<P: RuntimeProvider> UdpStream<P> {
146    fn pollable_split(&mut self) -> (&mut P::Udp, &mut StreamReceiver) {
147        (&mut self.socket, &mut self.outbound_messages)
148    }
149}
150
151impl<P: RuntimeProvider> Stream for UdpStream<P> {
152    type Item = Result<SerialMessage, io::Error>;
153
154    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
155        let (socket, outbound_messages) = self.pollable_split();
156        let socket = Pin::new(socket);
157        let mut outbound_messages = Pin::new(outbound_messages);
158
159        // this will not accept incoming data while there is data to send
160        //  makes this self throttling.
161        while let Poll::Ready(Some(message)) = outbound_messages.as_mut().poll_peek(cx) {
162            // first try to send
163            let addr = message.addr();
164
165            // this will return if not ready,
166            //   meaning that sending will be preferred over receiving...
167
168            // TODO: shouldn't this return the error to send to the sender?
169            if let Err(e) = ready!(socket.poll_send_to(cx, message.bytes(), addr)) {
170                // Drop the UDP packet and continue
171                warn!(
172                    "error sending message to {} on udp_socket, dropping response: {}",
173                    addr, e
174                );
175            }
176
177            // message sent, need to pop the message
178            assert!(outbound_messages.as_mut().poll_next(cx).is_ready());
179        }
180
181        // For QoS, this will only accept one message and output that
182        // receive all inbound messages
183
184        // TODO: this should match edns settings
185        let mut buf = [0u8; MAX_RECEIVE_BUFFER_SIZE];
186        let (len, src) = ready!(socket.poll_recv_from(cx, &mut buf))?;
187
188        let serial_message = SerialMessage::new(buf.iter().take(len).cloned().collect(), src);
189        Poll::Ready(Some(Ok(serial_message)))
190    }
191}
192
193#[must_use = "futures do nothing unless polled"]
194pub(crate) struct NextRandomUdpSocket<P: RuntimeProvider> {
195    name_server: SocketAddr,
196    bind_address: SocketAddr,
197    provider: P,
198    /// Number of unsuccessful attempts to pick a port.
199    attempted: usize,
200    #[allow(clippy::type_complexity)]
201    future: Option<Pin<Box<dyn Send + Future<Output = Result<P::Udp, NetError>>>>>,
202    avoid_local_ports: Arc<HashSet<u16>>,
203    os_port_selection: bool,
204}
205
206impl<P: RuntimeProvider> NextRandomUdpSocket<P> {
207    /// Creates a future for randomly binding to a local socket address for client connections,
208    /// if no port is specified.
209    ///
210    /// If a port is specified in the bind address it is used.
211    pub(crate) fn new(
212        name_server: SocketAddr,
213        bind_addr: Option<SocketAddr>,
214        avoid_local_ports: Arc<HashSet<u16>>,
215        os_port_selection: bool,
216        provider: P,
217    ) -> Self {
218        let bind_address = match bind_addr {
219            Some(ba) => ba,
220            None => match name_server {
221                SocketAddr::V4(..) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
222                SocketAddr::V6(..) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
223            },
224        };
225
226        Self {
227            name_server,
228            bind_address,
229            provider,
230            attempted: 0,
231            future: None,
232            avoid_local_ports,
233            os_port_selection,
234        }
235    }
236}
237
238impl<P: RuntimeProvider> Future for NextRandomUdpSocket<P> {
239    type Output = Result<P::Udp, NetError>;
240
241    /// polls until there is an available next random UDP port,
242    /// if no port has been specified in bind_addr.
243    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
244        let this = self.get_mut();
245        loop {
246            this.future = match this.future.take() {
247                Some(mut future) => match future.as_mut().poll(cx) {
248                    Poll::Ready(Ok(socket)) => {
249                        debug!("created socket successfully");
250                        return Poll::Ready(Ok(socket));
251                    }
252                    Poll::Ready(Err(NetError::Io(io)))
253                        if matches!(
254                            io.kind(),
255                            io::ErrorKind::PermissionDenied | io::ErrorKind::AddrInUse
256                        ) && this.attempted < ATTEMPT_RANDOM + 1 =>
257                    {
258                        debug!("unable to bind port, attempt: {}: {io}", this.attempted);
259                        this.attempted += 1;
260                        None
261                    }
262                    Poll::Ready(Err(err)) => {
263                        debug!("failed to bind port: {err}");
264                        return Poll::Ready(Err(err));
265                    }
266                    Poll::Pending => {
267                        debug!("unable to bind port, attempt: {}", this.attempted);
268                        this.future = Some(future);
269                        return Poll::Pending;
270                    }
271                },
272                None => {
273                    let mut bind_addr = this.bind_address;
274
275                    if !this.os_port_selection && bind_addr.port() == 0 {
276                        while this.attempted < ATTEMPT_RANDOM {
277                            // Per RFC 6056 Section 3.2:
278                            //
279                            // As mentioned in Section 2.1, the dynamic ports consist of the range
280                            // 49152-65535.  However, ephemeral port selection algorithms should use
281                            // the whole range 1024-65535.
282                            let port = rand::random_range(1024..=u16::MAX);
283                            if this.avoid_local_ports.contains(&port) {
284                                // Count this against the total number of attempts to pick a port.
285                                // RFC 6056 Section 3.3.2 notes that this algorithm should find a
286                                // suitable port in one or two attempts with high probability in
287                                // common scenarios. If `avoid_local_ports` is pathologically large,
288                                // then incrementing the counter here will prevent an infinite loop.
289                                this.attempted += 1;
290                                continue;
291                            } else {
292                                bind_addr = SocketAddr::new(bind_addr.ip(), port);
293                                break;
294                            }
295                        }
296                    }
297
298                    trace!(port = bind_addr.port(), "binding UDP socket");
299                    let future = this.provider.bind_udp(bind_addr, this.name_server);
300                    Some(Box::pin(async move { Ok(future.await?) }))
301                }
302            }
303        }
304    }
305}
306
307const ATTEMPT_RANDOM: usize = 10;
308
#[cfg(feature = "tokio")]
#[async_trait]
impl UdpSocket for tokio::net::UdpSocket {
    /// Sets up a "client" UDP socket for talking to `addr`.
    ///
    /// The local side is bound to the unspecified address of the same family as `addr` with
    /// port 0, i.e. `0.0.0.0:0` for IPv4 and `[::]:0` for IPv6.
    async fn connect(addr: SocketAddr) -> io::Result<Self> {
        let unspecified = if addr.is_ipv4() {
            IpAddr::V4(Ipv4Addr::UNSPECIFIED)
        } else {
            IpAddr::V6(Ipv6Addr::UNSPECIFIED)
        };

        Self::connect_with_bind(addr, SocketAddr::new(unspecified, 0)).await
    }

    /// Same as [`UdpSocket::connect`], but binds the local side to `bind_addr`.
    ///
    /// The remote address is currently unused: the socket is bound but not connected.
    async fn connect_with_bind(_addr: SocketAddr, bind_addr: SocketAddr) -> io::Result<Self> {
        // TODO: research connect more, it appears to break UDP receiving tests, etc...
        // socket.connect(addr).await?;

        Self::bind(bind_addr).await
    }

    /// Binds a socket to the local address `addr`.
    async fn bind(addr: SocketAddr) -> io::Result<Self> {
        // Resolves to tokio's inherent `UdpSocket::bind`, which takes precedence over this
        // trait method, so this does not recurse.
        Self::bind(addr).await
    }
}
338
#[cfg(feature = "tokio")]
#[async_trait]
impl DnsUdpSocket for tokio::net::UdpSocket {
    type Time = crate::runtime::TokioTime;

    /// Receives one datagram into `buf`, returning the number of bytes written and the
    /// sender's address.
    fn poll_recv_from(
        &self,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<(usize, SocketAddr)>> {
        // tokio's inherent method fills a `ReadBuf` and reports only the sender address; the
        // length is read back from the filled part of the buffer.
        let mut read_buf = tokio::io::ReadBuf::new(buf);
        match Self::poll_recv_from(self, cx, &mut read_buf) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
            Poll::Ready(Ok(sender)) => Poll::Ready(Ok((read_buf.filled().len(), sender))),
        }
    }

    /// Sends `buf` to `target` by delegating to tokio's inherent `poll_send_to`.
    fn poll_send_to(
        &self,
        cx: &mut Context<'_>,
        buf: &[u8],
        target: SocketAddr,
    ) -> Poll<io::Result<usize>> {
        Self::poll_send_to(self, cx, buf, target)
    }
}
365
#[cfg(all(test, feature = "tokio"))]
mod tests {
    use core::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    use test_support::subscribe;

    use crate::runtime::TokioRuntimeProvider;
    use crate::udp::tests::{next_random_socket_test, udp_stream_test};

    /// Runs the shared random-socket test with the tokio runtime provider.
    #[tokio::test]
    async fn test_next_random_socket() {
        subscribe();
        next_random_socket_test(TokioRuntimeProvider::new()).await;
    }

    /// Runs the shared UDP stream test over the IPv4 loopback address.
    #[tokio::test]
    async fn test_udp_stream_ipv4() {
        subscribe();
        let loopback = IpAddr::V4(Ipv4Addr::LOCALHOST);
        udp_stream_test(loopback, TokioRuntimeProvider::new()).await;
    }

    /// Runs the shared UDP stream test over the IPv6 loopback address (`::1`).
    #[tokio::test]
    async fn test_udp_stream_ipv6() {
        subscribe();
        let loopback = IpAddr::V6(Ipv6Addr::LOCALHOST);
        udp_stream_test(loopback, TokioRuntimeProvider::new()).await;
    }
}