use core::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use core::pin::Pin;
use core::task::{Context, Poll};
use std::collections::HashSet;
use std::io;
use std::sync::Arc;
use async_trait::async_trait;
use futures_util::{
future::{BoxFuture, Future},
ready,
stream::Stream,
};
use tracing::{debug, trace, warn};
use crate::error::NetError;
use crate::proto::op::SerialMessage;
use crate::runtime::{DnsUdpSocket, RuntimeProvider};
use crate::udp::MAX_RECEIVE_BUFFER_SIZE;
use crate::xfer::{BufDnsStreamHandle, StreamReceiver};
/// Constructors for a runtime UDP socket, layered on top of the polling operations provided by
/// [`DnsUdpSocket`].
#[async_trait]
pub trait UdpSocket: DnsUdpSocket {
    /// Binds a new socket to the local address `addr`.
    async fn bind(addr: SocketAddr) -> io::Result<Self>;

    /// Creates a socket for communicating with the remote `addr`, letting the implementation
    /// choose the local address. Implementations may simply bind without connecting.
    async fn connect(addr: SocketAddr) -> io::Result<Self>;

    /// Creates a socket for communicating with the remote `addr`, bound to the local
    /// `bind_addr`. Implementations may simply bind without connecting.
    async fn connect_with_bind(addr: SocketAddr, bind_addr: SocketAddr) -> io::Result<Self>;
}
/// A stream of UDP datagrams carrying serialized DNS messages.
///
/// Each poll first sends the messages queued through the paired [`BufDnsStreamHandle`] (returned
/// by the constructors), then yields the next datagram received on the socket.
#[must_use = "futures do nothing unless polled"]
pub struct UdpStream<P: RuntimeProvider> {
    /// Runtime-provided UDP socket, used for both sending and receiving.
    socket: P::Udp,
    /// Receiving end of the outbound message queue fed by the paired `BufDnsStreamHandle`.
    outbound_messages: StreamReceiver,
}
impl<P: RuntimeProvider> UdpStream<P> {
    /// Creates a future that binds a UDP socket for talking to `remote_addr`, along with a
    /// handle for queueing outbound messages on the resulting stream.
    ///
    /// * `remote_addr` - the remote (name server) address; it is also passed to
    ///   [`BufDnsStreamHandle::new`] when creating the returned handle.
    /// * `bind_addr` - local address to bind. When `None`, the unspecified address of the same
    ///   family as `remote_addr` with port 0 is used.
    /// * `avoid_local_ports` - ports that are never chosen when a random local port is picked;
    ///   `None` means no restriction.
    /// * `os_port_selection` - when `true`, a port of 0 is handed to the OS instead of picking a
    ///   random port.
    /// * `provider` - runtime used to bind the socket.
    ///
    /// The returned future resolves to the stream once the socket is bound, or to the bind error.
    pub fn new(
        remote_addr: SocketAddr,
        bind_addr: Option<SocketAddr>,
        avoid_local_ports: Option<Arc<HashSet<u16>>>,
        os_port_selection: bool,
        provider: P,
    ) -> (
        BoxFuture<'static, Result<Self, NetError>>,
        BufDnsStreamHandle,
    ) {
        let (handle, receiver) = BufDnsStreamHandle::new(remote_addr);

        let bind_future = NextRandomUdpSocket::new(
            remote_addr,
            bind_addr,
            avoid_local_ports.unwrap_or_default(),
            os_port_selection,
            provider,
        );

        let connect: BoxFuture<'static, Result<Self, NetError>> = Box::pin(async move {
            let socket = bind_future.await?;
            Ok(Self {
                socket,
                outbound_messages: receiver,
            })
        });

        (connect, handle)
    }
}
impl<P: RuntimeProvider> UdpStream<P> {
    /// Wraps an already-bound `socket` in a stream.
    ///
    /// Returns the stream together with a handle for queueing outbound messages. `remote_addr`
    /// is passed to [`BufDnsStreamHandle::new`] when creating that handle.
    pub fn with_bound(socket: P::Udp, remote_addr: SocketAddr) -> (Self, BufDnsStreamHandle) {
        let (handle, receiver) = BufDnsStreamHandle::new(remote_addr);
        let bound = Self {
            outbound_messages: receiver,
            socket,
        };
        (bound, handle)
    }

    /// Assembles a stream directly from a socket and an existing outbound message receiver.
    #[cfg(all(feature = "tokio", feature = "mdns"))]
    pub(crate) fn from_parts(socket: P::Udp, outbound_messages: StreamReceiver) -> Self {
        Self {
            outbound_messages,
            socket,
        }
    }
}
impl<P: RuntimeProvider> UdpStream<P> {
    /// Mutably borrows the socket and the outbound queue at the same time.
    ///
    /// `poll_next` uses this to drive both fields from a single `&mut Self`.
    fn pollable_split(&mut self) -> (&mut P::Udp, &mut StreamReceiver) {
        let Self {
            socket,
            outbound_messages,
        } = self;
        (socket, outbound_messages)
    }
}
impl<P: RuntimeProvider> Stream for UdpStream<P> {
    type Item = Result<SerialMessage, io::Error>;

    /// Sends every queued outbound message, then waits for the next inbound datagram.
    ///
    /// - A send that fails is logged with `warn!` and the message is discarded.
    /// - A send that cannot make progress yet leaves the message queued and returns
    ///   `Poll::Pending`.
    /// - Receive errors are yielded as `Some(Err(..))`.
    /// - This stream never yields `None`.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let (socket, outbound_messages) = self.pollable_split();
        let socket = Pin::new(socket);
        let mut outbound_messages = Pin::new(outbound_messages);

        // Flush the outbound queue. A message is only peeked until its send completes, so a
        // pending send keeps it queued for the next poll.
        loop {
            let message = match outbound_messages.as_mut().poll_peek(cx) {
                Poll::Ready(Some(message)) => message,
                Poll::Ready(None) | Poll::Pending => break,
            };

            let addr = message.addr();
            let sent = ready!(socket.poll_send_to(cx, message.bytes(), addr));
            if let Err(e) = sent {
                warn!(
                    "error sending message to {} on udp_socket, dropping response: {}",
                    addr, e
                );
            }

            // The message was just peeked, so dequeuing it must complete immediately.
            assert!(outbound_messages.as_mut().poll_next(cx).is_ready());
        }

        let mut buf = [0u8; MAX_RECEIVE_BUFFER_SIZE];
        let (len, src) = ready!(socket.poll_recv_from(cx, &mut buf))?;
        // Clamp defensively so an over-reported length never indexes past the buffer.
        let received = &buf[..len.min(buf.len())];
        Poll::Ready(Some(Ok(SerialMessage::new(received.to_vec(), src))))
    }
}
/// Future that binds a UDP socket for talking to `name_server`, picking a random local port
/// (unless OS port selection is requested) when no port was specified.
#[must_use = "futures do nothing unless polled"]
pub(crate) struct NextRandomUdpSocket<P: RuntimeProvider> {
    /// Remote address, passed to the provider alongside the local bind address.
    name_server: SocketAddr,
    /// Local address to bind; a port of 0 means the port is chosen at poll time.
    bind_address: SocketAddr,
    /// Runtime used to create the socket.
    provider: P,
    /// Number of port picks / bind attempts consumed so far.
    attempted: usize,
    // The bind future currently in flight, if any. The spelled-out trait-object type trips
    // clippy's `type_complexity` lint; it is used only for this one field, so the lint is allowed.
    #[allow(clippy::type_complexity)]
    future: Option<Pin<Box<dyn Send + Future<Output = Result<P::Udp, NetError>>>>>,
    /// Local ports that must never be chosen when a random port is picked.
    avoid_local_ports: Arc<HashSet<u16>>,
    /// When `true`, a port of 0 is handed to the OS instead of picking a random port.
    os_port_selection: bool,
}
impl<P: RuntimeProvider> NextRandomUdpSocket<P> {
    /// Prepares a bind of a UDP socket for talking to `name_server`.
    ///
    /// Nothing happens until the future is polled. When `bind_addr` is `None`, the unspecified
    /// address of `name_server`'s family with port 0 is used, so the local port is chosen at
    /// poll time.
    pub(crate) fn new(
        name_server: SocketAddr,
        bind_addr: Option<SocketAddr>,
        avoid_local_ports: Arc<HashSet<u16>>,
        os_port_selection: bool,
        provider: P,
    ) -> Self {
        let bind_address = bind_addr.unwrap_or_else(|| {
            let any_ip = if name_server.is_ipv4() {
                IpAddr::V4(Ipv4Addr::UNSPECIFIED)
            } else {
                IpAddr::V6(Ipv6Addr::UNSPECIFIED)
            };
            SocketAddr::new(any_ip, 0)
        });

        Self {
            future: None,
            attempted: 0,
            name_server,
            bind_address,
            avoid_local_ports,
            os_port_selection,
            provider,
        }
    }
}
impl<P: RuntimeProvider> NextRandomUdpSocket<P> {
    /// Whether each bind attempt draws a fresh random local port.
    ///
    /// This is the only configuration in which retrying after a bind failure can help. An
    /// explicitly configured port, or port 0 handed to the OS, would just be bound again.
    fn picks_random_ports(&self) -> bool {
        !self.os_port_selection && self.bind_address.port() == 0
    }

    /// Draws the local address for the next bind attempt under random port selection.
    ///
    /// Ports found in `avoid_local_ports` are skipped, and each skip counts as an attempt, so a
    /// very large avoid set cannot loop forever. Once `ATTEMPT_RANDOM` attempts are used up,
    /// `bind_address` (port 0) is returned unchanged and the OS picks the port.
    fn next_random_bind_addr(&mut self) -> SocketAddr {
        while self.attempted < ATTEMPT_RANDOM {
            // RFC 6056 Section 3.2: ephemeral port selection should draw from the whole
            // 1024-65535 range rather than only the 49152-65535 dynamic range.
            let port = rand::random_range(1024..=u16::MAX);
            if !self.avoid_local_ports.contains(&port) {
                return SocketAddr::new(self.bind_address.ip(), port);
            }
            self.attempted += 1;
        }
        self.bind_address
    }
}

impl<P: RuntimeProvider> Future for NextRandomUdpSocket<P> {
    type Output = Result<P::Udp, NetError>;

    /// Drives the socket bind to completion.
    ///
    /// Under random port selection (no OS port selection and a port of 0 in `bind_address`), a
    /// bind that fails with `AddrInUse` or `PermissionDenied` is retried with a newly drawn port
    /// while `attempted <= ATTEMPT_RANDOM`. Once random picks are exhausted, the bind falls back
    /// to port 0 so the OS chooses.
    ///
    /// # Errors
    ///
    /// Returns the bind error when the retry budget is exhausted. It is also returned immediately
    /// for any other error kind, or when the port is fixed by the caller or left to the OS,
    /// because re-binding the same address cannot succeed.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        let random_ports = this.picks_random_ports();
        loop {
            this.future = match this.future.take() {
                Some(mut future) => match future.as_mut().poll(cx) {
                    Poll::Ready(Ok(socket)) => {
                        debug!("created socket successfully");
                        return Poll::Ready(Ok(socket));
                    }
                    // Retry only when the local port is chosen per attempt; otherwise the
                    // next attempt would bind exactly the same address again.
                    Poll::Ready(Err(NetError::Io(io)))
                        if random_ports
                            && matches!(
                                io.kind(),
                                io::ErrorKind::PermissionDenied | io::ErrorKind::AddrInUse
                            )
                            && this.attempted <= ATTEMPT_RANDOM =>
                    {
                        debug!("unable to bind port, attempt: {}: {io}", this.attempted);
                        this.attempted += 1;
                        None
                    }
                    Poll::Ready(Err(err)) => {
                        debug!("failed to bind port: {err}");
                        return Poll::Ready(Err(err));
                    }
                    Poll::Pending => {
                        // Not a failure: the bind simply has not completed yet.
                        trace!(attempt = this.attempted, "UDP socket bind pending");
                        this.future = Some(future);
                        return Poll::Pending;
                    }
                },
                None => {
                    let bind_addr = if random_ports {
                        this.next_random_bind_addr()
                    } else {
                        this.bind_address
                    };
                    trace!(port = bind_addr.port(), "binding UDP socket");
                    let future = this.provider.bind_udp(bind_addr, this.name_server);
                    Some(Box::pin(async move { Ok(future.await?) }))
                }
            }
        }
    }
}

/// Budget of random local port picks before falling back to an OS-selected port.
///
/// Both ports skipped because they are in `avoid_local_ports` and failed bind attempts count
/// against this budget.
const ATTEMPT_RANDOM: usize = 10;
#[cfg(feature = "tokio")]
#[async_trait]
impl UdpSocket for tokio::net::UdpSocket {
    /// Binds to the unspecified address (port 0) of the same family as `addr`.
    async fn connect(addr: SocketAddr) -> io::Result<Self> {
        let any_ip: IpAddr = if addr.is_ipv4() {
            Ipv4Addr::UNSPECIFIED.into()
        } else {
            Ipv6Addr::UNSPECIFIED.into()
        };
        Self::connect_with_bind(addr, SocketAddr::new(any_ip, 0)).await
    }

    /// Binds to `bind_addr`. The remote address is ignored and the socket is not connected.
    async fn connect_with_bind(_addr: SocketAddr, bind_addr: SocketAddr) -> io::Result<Self> {
        // `Self::bind` resolves to tokio's inherent `UdpSocket::bind`, not this trait's `bind`.
        Self::bind(bind_addr).await
    }

    /// Delegates to tokio's inherent `UdpSocket::bind`.
    async fn bind(addr: SocketAddr) -> io::Result<Self> {
        // Inherent associated functions take precedence over trait ones, so this is not
        // a recursive call.
        Self::bind(addr).await
    }
}
#[cfg(feature = "tokio")]
#[async_trait]
impl DnsUdpSocket for tokio::net::UdpSocket {
    type Time = crate::runtime::TokioTime;

    /// Receives one datagram into `buf`, returning the number of bytes written and the sender.
    fn poll_recv_from(
        &self,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<(usize, SocketAddr)>> {
        let mut read_buf = tokio::io::ReadBuf::new(buf);
        // Tokio's inherent method fills a `ReadBuf` and reports only the sender address; the
        // byte count is read back from the filled portion.
        let outcome = ready!(Self::poll_recv_from(self, cx, &mut read_buf));
        Poll::Ready(outcome.map(|src| (read_buf.filled().len(), src)))
    }

    /// Sends `buf` to `target`, delegating to tokio's inherent `poll_send_to`.
    fn poll_send_to(
        &self,
        cx: &mut Context<'_>,
        buf: &[u8],
        target: SocketAddr,
    ) -> Poll<io::Result<usize>> {
        tokio::net::UdpSocket::poll_send_to(self, cx, buf, target)
    }
}
#[cfg(test)]
#[cfg(feature = "tokio")]
mod tests {
    //! Runs the shared UDP test scenarios against the tokio runtime provider.

    use core::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    use test_support::subscribe;

    use crate::{
        runtime::TokioRuntimeProvider,
        udp::tests::{next_random_socket_test, udp_stream_test},
    };

    #[tokio::test]
    async fn test_next_random_socket() {
        subscribe();
        next_random_socket_test(TokioRuntimeProvider::new()).await;
    }

    #[tokio::test]
    async fn test_udp_stream_ipv4() {
        subscribe();
        udp_stream_test(IpAddr::V4(Ipv4Addr::LOCALHOST), TokioRuntimeProvider::new()).await;
    }

    #[tokio::test]
    async fn test_udp_stream_ipv6() {
        subscribe();
        // `Ipv6Addr::LOCALHOST` is the loopback address `::1`.
        udp_stream_test(IpAddr::V6(Ipv6Addr::LOCALHOST), TokioRuntimeProvider::new()).await;
    }
}