use std::{
fmt::Debug,
future::Future,
io,
net::SocketAddr,
pin::Pin,
sync::Arc,
task::{ready, Context, Poll},
};
use anyhow::{bail, Context as _};
use netwatch::UdpSocket;
use quinn::AsyncUdpSocket;
use quinn_udp::{Transmit, UdpSockRef};
use tokio::io::Interest;
use tracing::{debug, trace};
#[derive(Clone, Debug)]
pub struct UdpConn {
io: Arc<UdpSocket>,
inner: Arc<quinn_udp::UdpSocketState>,
}
impl UdpConn {
pub(super) fn as_socket(&self) -> Arc<UdpSocket> {
self.io.clone()
}
pub(super) fn bind(addr: SocketAddr) -> anyhow::Result<Self> {
let sock = bind(addr)?;
let state = quinn_udp::UdpSocketState::new(quinn_udp::UdpSockRef::from(&sock))?;
Ok(Self {
io: Arc::new(sock),
inner: Arc::new(state),
})
}
pub fn port(&self) -> u16 {
self.local_addr().map(|p| p.port()).unwrap_or_default()
}
#[allow(clippy::unused_async)]
pub async fn close(&self) -> Result<(), io::Error> {
Ok(())
}
}
impl AsyncUdpSocket for UdpConn {
fn create_io_poller(self: Arc<Self>) -> Pin<Box<dyn quinn::UdpPoller>> {
let sock = self.io.clone();
Box::pin(IoPoller {
next_waiter: move || {
let sock = sock.clone();
async move { sock.writable().await }
},
waiter: None,
})
}
fn try_send(&self, transmit: &Transmit<'_>) -> io::Result<()> {
self.io.try_io(Interest::WRITABLE, || {
let sock_ref = UdpSockRef::from(&self.io);
self.inner.send(sock_ref, transmit)
})
}
fn poll_recv(
&self,
cx: &mut Context,
bufs: &mut [io::IoSliceMut<'_>],
meta: &mut [quinn_udp::RecvMeta],
) -> Poll<io::Result<usize>> {
loop {
ready!(self.io.poll_recv_ready(cx))?;
if let Ok(res) = self.io.try_io(Interest::READABLE, || {
self.inner.recv(Arc::as_ref(&self.io).into(), bufs, meta)
}) {
for meta in meta.iter().take(res) {
trace!(
src = %meta.addr,
len = meta.len,
count = meta.len / meta.stride,
dst = %meta.dst_ip.map(|x| x.to_string()).unwrap_or_default(),
"UDP recv"
);
}
return Poll::Ready(Ok(res));
}
}
}
fn local_addr(&self) -> io::Result<SocketAddr> {
self.io.local_addr()
}
fn may_fragment(&self) -> bool {
self.inner.may_fragment()
}
fn max_transmit_segments(&self) -> usize {
self.inner.max_gso_segments()
}
fn max_receive_segments(&self) -> usize {
self.inner.gro_segments()
}
}
fn bind(mut addr: SocketAddr) -> anyhow::Result<UdpSocket> {
debug!(%addr, "binding");
let mut ports = Vec::new();
if addr.port() != 0 {
ports.push(addr.port());
}
ports.push(0);
ports.dedup();
debug!(?ports, "candidate ports");
for port in &ports {
addr.set_port(*port);
match UdpSocket::bind_full(addr) {
Ok(pconn) => {
let local_addr = pconn.local_addr().context("UDP socket not bound")?;
debug!(%addr, %local_addr, "successfully bound");
return Ok(pconn);
}
Err(err) => {
debug!(%addr, "failed to bind: {err:#}");
continue;
}
}
}
bail!("failed to bind any ports on {:?} (tried {:?})", addr, ports);
}
#[derive(derive_more::Debug)]
#[pin_project::pin_project]
struct IoPoller<F, Fut>
where
F: Fn() -> Fut + Send + Sync + 'static,
Fut: Future<Output = io::Result<()>> + Send + Sync + 'static,
{
#[debug("next_waiter")]
next_waiter: F,
#[debug("waiter")]
#[pin]
waiter: Option<Fut>,
}
impl<F, Fut> quinn::UdpPoller for IoPoller<F, Fut>
where
F: Fn() -> Fut + Send + Sync + 'static,
Fut: Future<Output = io::Result<()>> + Send + Sync + 'static,
{
fn poll_writable(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
let mut this = self.project();
if this.waiter.is_none() {
this.waiter.set(Some((this.next_waiter)()));
}
let result = this
.waiter
.as_mut()
.as_pin_mut()
.expect("just set")
.poll(cx);
if result.is_ready() {
this.waiter.set(None);
}
result
}
}
#[cfg(test)]
mod tests {
use anyhow::Result;
use netwatch::IpFamily;
use tokio::sync::mpsc;
use tracing::{info_span, Instrument};
use super::*;
use crate::{key, tls};
const ALPN: &[u8] = b"n0/test/1";
fn wrap_socket(conn: impl AsyncUdpSocket) -> Result<(quinn::Endpoint, key::SecretKey)> {
let key = key::SecretKey::generate();
let quic_server_config = tls::make_server_config(&key, vec![ALPN.to_vec()], false)?;
let server_config = quinn::ServerConfig::with_crypto(Arc::new(quic_server_config));
let mut quic_ep = quinn::Endpoint::new_with_abstract_socket(
quinn::EndpointConfig::default(),
Some(server_config),
Arc::new(conn),
Arc::new(quinn::TokioRuntime),
)?;
let quic_client_config = tls::make_client_config(&key, None, vec![ALPN.to_vec()], false)?;
let client_config = quinn::ClientConfig::new(Arc::new(quic_client_config));
quic_ep.set_default_client_config(client_config);
Ok((quic_ep, key))
}
#[tokio::test]
async fn test_rebinding_conn_send_recv_ipv4() -> Result<()> {
let _guard = iroh_test::logging::setup();
rebinding_conn_send_recv(IpFamily::V4).await
}
#[tokio::test]
async fn test_rebinding_conn_send_recv_ipv6() -> Result<()> {
let _guard = iroh_test::logging::setup();
if !crate::netcheck::os_has_ipv6() {
return Ok(());
}
rebinding_conn_send_recv(IpFamily::V6).await
}
async fn rebinding_conn_send_recv(network: IpFamily) -> Result<()> {
let m1 = UdpConn::bind(SocketAddr::new(network.unspecified_addr(), 0))?;
let (m1, _m1_key) = wrap_socket(m1)?;
let m2 = UdpConn::bind(SocketAddr::new(network.unspecified_addr(), 0))?;
let (m2, _m2_key) = wrap_socket(m2)?;
let m1_addr = SocketAddr::new(network.local_addr(), m1.local_addr()?.port());
let (m1_send, mut m1_recv) = mpsc::channel(8);
let m1_task = tokio::task::spawn(
async move {
if let Some(conn) = m1.accept().await.and_then(|inc| inc.accept().ok()) {
let conn = conn.await?;
let (mut send_bi, mut recv_bi) = conn.accept_bi().await?;
let val = recv_bi.read_to_end(usize::MAX).await?;
m1_send.send(val).await?;
send_bi.finish()?;
send_bi.stopped().await?;
}
Ok::<_, anyhow::Error>(())
}
.instrument(info_span!("m1_task")),
);
let conn = m2.connect(m1_addr, "localhost")?.await?;
let (mut send_bi, mut recv_bi) = conn.open_bi().await?;
send_bi.write_all(b"hello").await?;
send_bi.finish()?;
let _ = recv_bi.read_to_end(usize::MAX).await?;
conn.close(0u32.into(), b"done");
m2.wait_idle().await;
drop(send_bi);
let val = m1_recv.recv().await.unwrap();
assert_eq!(val, b"hello");
m1_task.await??;
Ok(())
}
}