1use core::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
9use core::pin::Pin;
10use core::task::{Context, Poll};
11use std::collections::HashSet;
12use std::io;
13use std::sync::Arc;
14
15use async_trait::async_trait;
16use futures_util::{
17 future::{BoxFuture, Future},
18 ready,
19 stream::Stream,
20};
21use tracing::{debug, trace, warn};
22
23use crate::error::NetError;
24use crate::proto::op::SerialMessage;
25use crate::runtime::{DnsUdpSocket, RuntimeProvider};
26use crate::udp::MAX_RECEIVE_BUFFER_SIZE;
27use crate::xfer::{BufDnsStreamHandle, StreamReceiver};
28
29#[async_trait]
31pub trait UdpSocket: DnsUdpSocket {
32 async fn connect(addr: SocketAddr) -> io::Result<Self>;
34
35 async fn connect_with_bind(addr: SocketAddr, bind_addr: SocketAddr) -> io::Result<Self>;
37
38 async fn bind(addr: SocketAddr) -> io::Result<Self>;
40}
41
/// A stream of DNS messages over UDP.
///
/// Polling it first sends any messages queued through the paired [`BufDnsStreamHandle`],
/// then yields datagrams received on `socket` as [`SerialMessage`]s.
#[must_use = "futures do nothing unless polled"]
pub struct UdpStream<P: RuntimeProvider> {
    /// Bound UDP socket used for both sending and receiving.
    socket: P::Udp,
    /// Outbound messages waiting to be sent on `socket`.
    outbound_messages: StreamReceiver,
}
48
49impl<P: RuntimeProvider> UdpStream<P> {
50 pub fn new(
78 remote_addr: SocketAddr,
79 bind_addr: Option<SocketAddr>,
80 avoid_local_ports: Option<Arc<HashSet<u16>>>,
81 os_port_selection: bool,
82 provider: P,
83 ) -> (
84 BoxFuture<'static, Result<Self, NetError>>,
85 BufDnsStreamHandle,
86 ) {
87 let (message_sender, outbound_messages) = BufDnsStreamHandle::new(remote_addr);
88
89 let next_socket = NextRandomUdpSocket::new(
91 remote_addr,
92 bind_addr,
93 avoid_local_ports.unwrap_or_default(),
94 os_port_selection,
95 provider,
96 );
97
98 let stream = Box::pin(async {
101 Ok(Self {
102 socket: next_socket.await?,
103 outbound_messages,
104 })
105 });
106
107 (stream, message_sender)
108 }
109}
110
111impl<P: RuntimeProvider> UdpStream<P> {
112 pub fn with_bound(socket: P::Udp, remote_addr: SocketAddr) -> (Self, BufDnsStreamHandle) {
127 let (message_sender, outbound_messages) = BufDnsStreamHandle::new(remote_addr);
128 let stream = Self {
129 socket,
130 outbound_messages,
131 };
132
133 (stream, message_sender)
134 }
135
136 #[cfg(all(feature = "tokio", feature = "mdns"))]
137 pub(crate) fn from_parts(socket: P::Udp, outbound_messages: StreamReceiver) -> Self {
138 Self {
139 socket,
140 outbound_messages,
141 }
142 }
143}
144
145impl<P: RuntimeProvider> UdpStream<P> {
146 fn pollable_split(&mut self) -> (&mut P::Udp, &mut StreamReceiver) {
147 (&mut self.socket, &mut self.outbound_messages)
148 }
149}
150
151impl<P: RuntimeProvider> Stream for UdpStream<P> {
152 type Item = Result<SerialMessage, io::Error>;
153
154 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
155 let (socket, outbound_messages) = self.pollable_split();
156 let socket = Pin::new(socket);
157 let mut outbound_messages = Pin::new(outbound_messages);
158
159 while let Poll::Ready(Some(message)) = outbound_messages.as_mut().poll_peek(cx) {
162 let addr = message.addr();
164
165 if let Err(e) = ready!(socket.poll_send_to(cx, message.bytes(), addr)) {
170 warn!(
172 "error sending message to {} on udp_socket, dropping response: {}",
173 addr, e
174 );
175 }
176
177 assert!(outbound_messages.as_mut().poll_next(cx).is_ready());
179 }
180
181 let mut buf = [0u8; MAX_RECEIVE_BUFFER_SIZE];
186 let (len, src) = ready!(socket.poll_recv_from(cx, &mut buf))?;
187
188 let serial_message = SerialMessage::new(buf.iter().take(len).cloned().collect(), src);
189 Poll::Ready(Some(Ok(serial_message)))
190 }
191}
192
/// Future that binds a UDP socket for talking to `name_server`.
///
/// When asked to, it picks a random local port for each attempt (see its `Future` impl).
#[must_use = "futures do nothing unless polled"]
pub(crate) struct NextRandomUdpSocket<P: RuntimeProvider> {
    /// Remote address, passed to `RuntimeProvider::bind_udp` alongside the local address.
    name_server: SocketAddr,
    /// Local address to bind. A port of `0` makes the future pick a random port, unless
    /// `os_port_selection` is set.
    bind_address: SocketAddr,
    /// Runtime used to create the socket.
    provider: P,
    /// Count of attempts used so far. It grows both when a random port hits the avoid list
    /// and when a retryable bind failure occurs.
    attempted: usize,
    /// The in-flight bind, if one has been started and not yet completed.
    // The boxed, pinned `Send` future type is long, but a type alias for a single field would
    // not make it clearer; hence the lint is allowed here.
    #[allow(clippy::type_complexity)]
    future: Option<Pin<Box<dyn Send + Future<Output = Result<P::Udp, NetError>>>>>,
    /// Local ports that must never be chosen when selecting a random port.
    avoid_local_ports: Arc<HashSet<u16>>,
    /// When true, skip random port selection and bind `bind_address` unchanged.
    os_port_selection: bool,
}
205
206impl<P: RuntimeProvider> NextRandomUdpSocket<P> {
207 pub(crate) fn new(
212 name_server: SocketAddr,
213 bind_addr: Option<SocketAddr>,
214 avoid_local_ports: Arc<HashSet<u16>>,
215 os_port_selection: bool,
216 provider: P,
217 ) -> Self {
218 let bind_address = match bind_addr {
219 Some(ba) => ba,
220 None => match name_server {
221 SocketAddr::V4(..) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
222 SocketAddr::V6(..) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
223 },
224 };
225
226 Self {
227 name_server,
228 bind_address,
229 provider,
230 attempted: 0,
231 future: None,
232 avoid_local_ports,
233 os_port_selection,
234 }
235 }
236}
237
238impl<P: RuntimeProvider> Future for NextRandomUdpSocket<P> {
239 type Output = Result<P::Udp, NetError>;
240
241 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
244 let this = self.get_mut();
245 loop {
246 this.future = match this.future.take() {
247 Some(mut future) => match future.as_mut().poll(cx) {
248 Poll::Ready(Ok(socket)) => {
249 debug!("created socket successfully");
250 return Poll::Ready(Ok(socket));
251 }
252 Poll::Ready(Err(NetError::Io(io)))
253 if matches!(
254 io.kind(),
255 io::ErrorKind::PermissionDenied | io::ErrorKind::AddrInUse
256 ) && this.attempted < ATTEMPT_RANDOM + 1 =>
257 {
258 debug!("unable to bind port, attempt: {}: {io}", this.attempted);
259 this.attempted += 1;
260 None
261 }
262 Poll::Ready(Err(err)) => {
263 debug!("failed to bind port: {err}");
264 return Poll::Ready(Err(err));
265 }
266 Poll::Pending => {
267 debug!("unable to bind port, attempt: {}", this.attempted);
268 this.future = Some(future);
269 return Poll::Pending;
270 }
271 },
272 None => {
273 let mut bind_addr = this.bind_address;
274
275 if !this.os_port_selection && bind_addr.port() == 0 {
276 while this.attempted < ATTEMPT_RANDOM {
277 let port = rand::random_range(1024..=u16::MAX);
283 if this.avoid_local_ports.contains(&port) {
284 this.attempted += 1;
290 continue;
291 } else {
292 bind_addr = SocketAddr::new(bind_addr.ip(), port);
293 break;
294 }
295 }
296 }
297
298 trace!(port = bind_addr.port(), "binding UDP socket");
299 let future = this.provider.bind_udp(bind_addr, this.name_server);
300 Some(Box::pin(async move { Ok(future.await?) }))
301 }
302 }
303 }
304 }
305}
306
/// Attempt budget used by `NextRandomUdpSocket` when choosing a random local port.
///
/// Once `attempted` reaches this value, no further random port is drawn. The bind address is
/// then used with its original port.
const ATTEMPT_RANDOM: usize = 10;
308
#[cfg(feature = "tokio")]
#[async_trait]
impl UdpSocket for tokio::net::UdpSocket {
    /// Binds to the unspecified address of `addr`'s family on port `0`.
    async fn connect(addr: SocketAddr) -> io::Result<Self> {
        let unspecified = if addr.is_ipv4() {
            IpAddr::V4(Ipv4Addr::UNSPECIFIED)
        } else {
            IpAddr::V6(Ipv6Addr::UNSPECIFIED)
        };

        Self::connect_with_bind(addr, SocketAddr::new(unspecified, 0)).await
    }

    /// Binds to `bind_addr`.
    ///
    /// `_addr` is ignored: the socket is only bound, never connected to it.
    async fn connect_with_bind(_addr: SocketAddr, bind_addr: SocketAddr) -> io::Result<Self> {
        tokio::net::UdpSocket::bind(bind_addr).await
    }

    /// Binds to `addr`.
    async fn bind(addr: SocketAddr) -> io::Result<Self> {
        // Inherent associated functions take precedence over trait methods in path
        // resolution. This therefore calls tokio's own `bind` and does not recurse.
        tokio::net::UdpSocket::bind(addr).await
    }
}
338
#[cfg(feature = "tokio")]
#[async_trait]
impl DnsUdpSocket for tokio::net::UdpSocket {
    type Time = crate::runtime::TokioTime;

    /// Receives one datagram into `buf`.
    ///
    /// Returns the number of bytes written and the sender's address.
    fn poll_recv_from(
        &self,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<(usize, SocketAddr)>> {
        let mut read_buf = tokio::io::ReadBuf::new(buf);
        // Fully qualified path resolves to tokio's inherent `ReadBuf`-based method, not this
        // trait method.
        tokio::net::UdpSocket::poll_recv_from(self, cx, &mut read_buf)
            .map_ok(|sender| (read_buf.filled().len(), sender))
    }

    /// Sends `buf` to `target`, delegating to tokio's inherent `poll_send_to`.
    fn poll_send_to(
        &self,
        cx: &mut Context<'_>,
        buf: &[u8],
        target: SocketAddr,
    ) -> Poll<io::Result<usize>> {
        tokio::net::UdpSocket::poll_send_to(self, cx, buf, target)
    }
}
365
#[cfg(test)]
#[cfg(feature = "tokio")]
mod tests {
    use core::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    use test_support::subscribe;

    use crate::{
        runtime::TokioRuntimeProvider,
        udp::tests::{next_random_socket_test, udp_stream_test},
    };

    /// Installs the test log subscriber, then builds a fresh Tokio runtime provider.
    fn tokio_provider() -> TokioRuntimeProvider {
        subscribe();
        TokioRuntimeProvider::new()
    }

    #[tokio::test]
    async fn test_next_random_socket() {
        next_random_socket_test(tokio_provider()).await;
    }

    #[tokio::test]
    async fn test_udp_stream_ipv4() {
        udp_stream_test(IpAddr::V4(Ipv4Addr::LOCALHOST), tokio_provider()).await;
    }

    #[tokio::test]
    async fn test_udp_stream_ipv6() {
        udp_stream_test(IpAddr::V6(Ipv6Addr::LOCALHOST), tokio_provider()).await;
    }
}