use std::fmt;
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use quinn::udp::{RecvMeta, Transmit};
use quinn::{AsyncUdpSocket, UdpPoller};
use virtual_socket::VirtualUdpSocket;
/// Adapter that exposes a [`VirtualUdpSocket`] to quinn through the
/// [`AsyncUdpSocket`] trait.
///
/// Every quinn I/O call on this type is forwarded to the wrapped virtual
/// socket.
pub struct SharedSocket {
    /// The virtual socket that receives all forwarded calls.
    inner: Arc<VirtualUdpSocket>,
}

impl SharedSocket {
    /// Wraps `inner` and returns the adapter already behind an `Arc`, the
    /// receiver type that [`AsyncUdpSocket::create_io_poller`] consumes.
    #[must_use]
    pub fn new(inner: Arc<VirtualUdpSocket>) -> Arc<Self> {
        let adapter = Self { inner };
        Arc::new(adapter)
    }

    /// Borrows the wrapped virtual socket so callers can use it directly.
    #[must_use]
    pub fn virtual_socket(&self) -> &Arc<VirtualUdpSocket> {
        let Self { inner } = self;
        inner
    }
}

impl fmt::Debug for SharedSocket {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("SharedSocket");
        builder.field("inner", &self.inner);
        builder.finish()
    }
}
impl AsyncUdpSocket for SharedSocket {
    /// Returns a poller whose writability is delegated to the wrapped
    /// virtual socket.
    fn create_io_poller(self: Arc<Self>) -> Pin<Box<dyn UdpPoller>> {
        let poller = SharedSocketPoller { socket: self };
        Box::pin(poller)
    }

    /// Sends `transmit.contents` to `transmit.destination` via the virtual
    /// socket. Only errors are surfaced; the success value is discarded.
    /// The remaining `Transmit` fields (ECN, segment size, source IP) are not
    /// consulted.
    fn try_send(&self, transmit: &Transmit<'_>) -> io::Result<()> {
        let (payload, target) = (transmit.contents, transmit.destination);
        self.inner.try_send_to(payload, target)?;
        Ok(())
    }

    /// Receives up to `min(bufs.len(), meta.len())` datagrams.
    ///
    /// Only the first datagram is waited for (the call returns `Pending` while
    /// the queue is empty). Once one is available, further datagrams that are
    /// already queued are drained without waiting until the slots run out.
    /// A closed virtual socket is reported as `ConnectionAborted`. With no
    /// slots, returns `Ok(0)` immediately without touching the queue.
    fn poll_recv(
        &self,
        cx: &mut Context<'_>,
        bufs: &mut [io::IoSliceMut<'_>],
        meta: &mut [RecvMeta],
    ) -> Poll<io::Result<usize>> {
        let slots = std::cmp::min(bufs.len(), meta.len());
        if slots == 0 {
            return Poll::Ready(Ok(0));
        }

        let first = match self.inner.poll_dequeue(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(None) => {
                let closed =
                    io::Error::new(io::ErrorKind::ConnectionAborted, "virtual socket closed");
                return Poll::Ready(Err(closed));
            }
            Poll::Ready(Some(datagram)) => datagram,
        };

        // Fall back to 0.0.0.0:0 if the local address cannot be determined.
        let local = match self.inner.local_addr() {
            Ok(addr) => addr,
            Err(_) => SocketAddr::from(([0u8, 0, 0, 0], 0)),
        };

        fill_slot(0, first, bufs, meta, local);
        let mut received = 1;
        for slot in 1..slots {
            // `try_dequeue` takes no `Context`, so it never registers a
            // wakeup; the first empty result ends the batch.
            if let Some(datagram) = self.inner.try_dequeue() {
                fill_slot(slot, datagram, bufs, meta, local);
                received = slot + 1;
            } else {
                break;
            }
        }
        Poll::Ready(Ok(received))
    }

    /// Reports the local address of the wrapped virtual socket.
    fn local_addr(&self) -> io::Result<SocketAddr> {
        self.inner.local_addr()
    }
}
/// Copies one dequeued `(sender, payload)` datagram into receive slot `idx`.
///
/// The payload is copied into `bufs[idx]` and truncated to that buffer's
/// length if it does not fit. `meta[idx]` is then overwritten:
/// - `addr` is the sender;
/// - `len` and `stride` are the number of bytes actually copied;
/// - `ecn` is `None`;
/// - `dst_ip` is `local`'s IP.
///
/// # Panics
/// If `idx` is out of bounds for `bufs` or `meta`.
fn fill_slot(
    idx: usize,
    datagram: (SocketAddr, bytes::Bytes),
    bufs: &mut [io::IoSliceMut<'_>],
    meta: &mut [RecvMeta],
    local: SocketAddr,
) {
    let (sender, payload) = datagram;
    let dst = &mut bufs[idx];
    let copied = std::cmp::min(payload.len(), dst.len());
    dst[..copied].copy_from_slice(&payload[..copied]);
    meta[idx] = RecvMeta {
        addr: sender,
        len: copied,
        stride: copied,
        ecn: None,
        dst_ip: Some(local.ip()),
    };
}
/// [`UdpPoller`] returned by `SharedSocket::create_io_poller`; writability is
/// delegated to the wrapped virtual socket.
struct SharedSocketPoller {
    /// Shared handle to the adapter; holding the `Arc` keeps it alive while
    /// the poller exists.
    socket: Arc<SharedSocket>,
}

impl fmt::Debug for SharedSocketPoller {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The wrapped socket is not included in the output.
        let mut builder = f.debug_struct("SharedSocketPoller");
        builder.finish()
    }
}

impl UdpPoller for SharedSocketPoller {
    /// Resolves when the virtual socket reports it is ready to send.
    fn poll_writable(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let virt = &self.socket.inner;
        virt.poll_send_ready(cx)
    }
}
#[cfg(test)]
mod tests {
    use std::future::poll_fn;
    use std::net::Ipv4Addr;

    use bytes::Bytes;
    use quinn::AsyncUdpSocket;
    use tokio::net::UdpSocket;

    use super::*;

    /// Binds a physical UDP socket to an ephemeral loopback port.
    async fn bound() -> Arc<UdpSocket> {
        Arc::new(UdpSocket::bind((Ipv4Addr::LOCALHOST, 0)).await.expect("bind"))
    }

    /// Drives `poll_recv` (through the `AsyncUdpSocket` trait) until it is ready.
    async fn recv(
        shared: &SharedSocket,
        bufs: &mut [io::IoSliceMut<'_>],
        metas: &mut [RecvMeta],
    ) -> io::Result<usize> {
        poll_fn(|cx| {
            <SharedSocket as AsyncUdpSocket>::poll_recv(shared, cx, &mut *bufs, &mut *metas)
        })
        .await
    }

    #[tokio::test]
    async fn local_addr_passes_through() {
        let phys = bound().await;
        let want = phys.local_addr().expect("local addr");
        let virt = VirtualUdpSocket::new(phys);
        let shared = SharedSocket::new(virt);
        assert_eq!(<SharedSocket as AsyncUdpSocket>::local_addr(&shared).unwrap(), want);
    }

    #[tokio::test]
    async fn poll_recv_pending_when_queue_empty() {
        let phys = bound().await;
        let virt = VirtualUdpSocket::new(phys);
        let shared = SharedSocket::new(virt);
        let mut storage = [0u8; 64];
        let mut bufs = [io::IoSliceMut::new(&mut storage)];
        let mut metas = [RecvMeta::default()];
        // Poll exactly once: an empty queue must yield `Pending`.
        poll_fn(|cx| {
            let polled =
                <SharedSocket as AsyncUdpSocket>::poll_recv(&shared, cx, &mut bufs, &mut metas);
            assert!(polled.is_pending(), "expected Pending, got {polled:?}");
            Poll::Ready(())
        })
        .await;
    }

    #[tokio::test]
    async fn poll_recv_returns_zero_without_slots_and_keeps_queue() {
        let phys = bound().await;
        let virt = VirtualUdpSocket::new(phys);
        let peer: SocketAddr = "192.0.2.13:443".parse().unwrap();
        virt.enqueue_inbound(peer, Bytes::from_static(b"X"));
        let shared = SharedSocket::new(virt);

        let mut no_bufs: [io::IoSliceMut<'_>; 0] = [];
        let mut no_metas: [RecvMeta; 0] = [];
        let n = recv(&shared, &mut no_bufs, &mut no_metas).await.expect("poll_recv ok");
        assert_eq!(n, 0);

        // The zero-slot call must not have consumed the queued datagram.
        let mut buf = [0u8; 16];
        let mut bufs = [io::IoSliceMut::new(&mut buf)];
        let mut metas = [RecvMeta::default()];
        let n = recv(&shared, &mut bufs, &mut metas).await.expect("poll_recv ok");
        assert_eq!(n, 1);
        assert_eq!(metas[0].addr, peer);
    }

    #[tokio::test]
    async fn poll_recv_returns_queued_datagram() {
        let phys = bound().await;
        let virt = VirtualUdpSocket::new(phys);
        let peer: SocketAddr = "192.0.2.10:443".parse().unwrap();
        virt.enqueue_inbound(peer, Bytes::from_static(b"INIT"));
        let shared = SharedSocket::new(virt);
        let mut buf = [0u8; 64];
        let mut bufs = [io::IoSliceMut::new(&mut buf)];
        let mut metas = [RecvMeta::default()];
        let n = recv(&shared, &mut bufs, &mut metas).await.expect("poll_recv ok");
        assert_eq!(n, 1);
        assert_eq!(metas[0].addr, peer);
        assert_eq!(metas[0].len, 4);
        assert_eq!(&buf[..4], b"INIT");
    }

    #[tokio::test]
    async fn poll_recv_truncates_to_buffer_len() {
        let phys = bound().await;
        let virt = VirtualUdpSocket::new(phys);
        let peer: SocketAddr = "192.0.2.30:443".parse().unwrap();
        virt.enqueue_inbound(peer, Bytes::from_static(b"HELLO"));
        let shared = SharedSocket::new(virt);
        let mut buf = [0u8; 2];
        let mut bufs = [io::IoSliceMut::new(&mut buf)];
        let mut metas = [RecvMeta::default()];
        let n = recv(&shared, &mut bufs, &mut metas).await.expect("poll_recv ok");
        assert_eq!(n, 1);
        assert_eq!(metas[0].addr, peer);
        assert_eq!(metas[0].len, 2);
        assert_eq!(metas[0].stride, 2);
        assert_eq!(&buf, b"HE");
    }

    #[tokio::test]
    async fn poll_recv_drains_burst_into_multi_slot_call() {
        let phys = bound().await;
        let virt = VirtualUdpSocket::new(phys);
        let peer1: SocketAddr = "192.0.2.11:443".parse().unwrap();
        let peer2: SocketAddr = "192.0.2.12:443".parse().unwrap();
        virt.enqueue_inbound(peer1, Bytes::from_static(b"A"));
        virt.enqueue_inbound(peer2, Bytes::from_static(b"BB"));
        let shared = SharedSocket::new(virt);
        let mut b1 = [0u8; 16];
        let mut b2 = [0u8; 16];
        let mut bufs = [io::IoSliceMut::new(&mut b1), io::IoSliceMut::new(&mut b2)];
        let mut metas = [RecvMeta::default(), RecvMeta::default()];
        let n = recv(&shared, &mut bufs, &mut metas).await.expect("poll_recv ok");
        assert_eq!(n, 2);
        assert_eq!(metas[0].addr, peer1);
        assert_eq!(metas[1].addr, peer2);
        assert_eq!(&b1[..1], b"A");
        assert_eq!(&b2[..2], b"BB");
    }

    #[tokio::test]
    async fn poll_recv_leaves_excess_datagrams_queued() {
        let phys = bound().await;
        let virt = VirtualUdpSocket::new(phys);
        let peers: [SocketAddr; 3] = [
            "192.0.2.21:443".parse().unwrap(),
            "192.0.2.22:443".parse().unwrap(),
            "192.0.2.23:443".parse().unwrap(),
        ];
        for peer in peers {
            virt.enqueue_inbound(peer, Bytes::from_static(b"P"));
        }
        let shared = SharedSocket::new(virt);

        let mut b1 = [0u8; 16];
        let mut b2 = [0u8; 16];
        let mut bufs = [io::IoSliceMut::new(&mut b1), io::IoSliceMut::new(&mut b2)];
        let mut metas = [RecvMeta::default(), RecvMeta::default()];
        let n = recv(&shared, &mut bufs, &mut metas).await.expect("first recv");
        assert_eq!(n, 2);
        assert_eq!(metas[0].addr, peers[0]);
        assert_eq!(metas[1].addr, peers[1]);

        // Two buffers but one metadata slot: the batch is capped by the shorter slice.
        let mut one_meta = [RecvMeta::default()];
        let n = recv(&shared, &mut bufs, &mut one_meta).await.expect("second recv");
        assert_eq!(n, 1);
        assert_eq!(one_meta[0].addr, peers[2]);
    }

    #[tokio::test]
    async fn poll_recv_surfaces_close_as_connection_aborted() {
        let phys = bound().await;
        let virt = VirtualUdpSocket::new(phys);
        virt.close();
        let shared = SharedSocket::new(virt);
        let mut buf = [0u8; 16];
        let mut bufs = [io::IoSliceMut::new(&mut buf)];
        let mut metas = [RecvMeta::default()];
        let err = recv(&shared, &mut bufs, &mut metas).await.expect_err("close => err");
        assert_eq!(err.kind(), io::ErrorKind::ConnectionAborted);
    }

    #[tokio::test]
    async fn try_send_proxies_to_physical() {
        let phys_src = bound().await;
        let phys_dst = bound().await;
        let dst_addr = phys_dst.local_addr().unwrap();
        let virt = VirtualUdpSocket::new(phys_src);
        let shared = SharedSocket::new(virt);
        poll_fn(|cx| shared.virtual_socket().poll_send_ready(cx)).await.expect("ready");
        <SharedSocket as AsyncUdpSocket>::try_send(
            &shared,
            &Transmit {
                destination: dst_addr,
                ecn: None,
                contents: b"PING",
                segment_size: None,
                src_ip: None,
            },
        )
        .expect("try_send");
        let mut got = [0u8; 16];
        let (n, _) = phys_dst.recv_from(&mut got).await.expect("recv");
        assert_eq!(&got[..n], b"PING");
    }
}