Skip to main content

quinn_shared_socket/
lib.rs

1//! Adapter that exposes a [`virtual_socket::VirtualUdpSocket`] as a
2//! [`quinn::AsyncUdpSocket`], so a [`quinn::Endpoint`] can run on a
3//! UDP socket that is shared with other consumers.
4//!
5//! Typical setup:
6//!
7//! 1. A "router" task owns the physical `tokio::net::UdpSocket` and
8//!    `recv_from`s it.
9//! 2. Each consumer (a `quinn::Endpoint`, an L4 forwarder, a DNS
10//!    parser, ...) gets its own [`virtual_socket::VirtualUdpSocket`]
11//!    backed by the same physical socket.
12//! 3. For inbound traffic, the router applies its own demultiplex
13//!    rule and pushes each datagram onto the matching virtual
14//!    socket's queue with [`virtual_socket::VirtualUdpSocket::enqueue_inbound`].
15//! 4. The QUIC consumer wraps its virtual socket in
16//!    [`SharedSocket::new`] and hands the wrapper to
17//!    [`quinn::Endpoint::new_with_abstract_socket`]. quinn then sees
18//!    the virtual socket as if it were exclusive — it can demux QUIC
19//!    Connection IDs internally without knowing about the shared
20//!    physical layer.
21//!
22//! For outbound, [`SharedSocket`] forwards every transmit through
23//! `virtual_socket`'s `try_send_to`, which in turn writes through to
24//! the physical socket. quinn never owns the physical socket, so
25//! other consumers can keep writing through it concurrently.
26
27use std::fmt;
28use std::io;
29use std::net::SocketAddr;
30use std::pin::Pin;
31use std::sync::Arc;
32use std::task::{Context, Poll};
33
34use quinn::udp::{RecvMeta, Transmit};
35use quinn::{AsyncUdpSocket, UdpPoller};
36use virtual_socket::VirtualUdpSocket;
37
38/// Adapter that satisfies [`quinn::AsyncUdpSocket`] for a
39/// [`virtual_socket::VirtualUdpSocket`].
40///
41/// Construct via [`SharedSocket::new`], then hand the resulting
42/// `Arc<SharedSocket>` to
43/// [`quinn::Endpoint::new_with_abstract_socket`].
44pub struct SharedSocket {
45	inner: Arc<VirtualUdpSocket>,
46}
47
48impl SharedSocket {
49	/// Wrap a [`virtual_socket::VirtualUdpSocket`].
50	#[must_use]
51	pub fn new(inner: Arc<VirtualUdpSocket>) -> Arc<Self> {
52		Arc::new(Self { inner })
53	}
54
55	/// Borrow the underlying virtual socket — useful if the caller
56	/// needs to reach the inbound-enqueue side from the same value.
57	#[must_use]
58	pub fn virtual_socket(&self) -> &Arc<VirtualUdpSocket> {
59		&self.inner
60	}
61}
62
63impl fmt::Debug for SharedSocket {
64	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65		f.debug_struct("SharedSocket").field("inner", &self.inner).finish()
66	}
67}
68
69impl AsyncUdpSocket for SharedSocket {
70	fn create_io_poller(self: Arc<Self>) -> Pin<Box<dyn UdpPoller>> {
71		Box::pin(SharedSocketPoller { socket: self })
72	}
73
74	fn try_send(&self, transmit: &Transmit<'_>) -> io::Result<()> {
75		self.inner.try_send_to(transmit.contents, transmit.destination).map(|_n| ())
76	}
77
78	fn poll_recv(
79		&self,
80		cx: &mut Context<'_>,
81		bufs: &mut [io::IoSliceMut<'_>],
82		meta: &mut [RecvMeta],
83	) -> Poll<io::Result<usize>> {
84		let max = bufs.len().min(meta.len());
85		if max == 0 {
86			return Poll::Ready(Ok(0));
87		}
88
89		// First slot: must register a waker if no datagram is queued.
90		// `Ready(None)` (closed + drained) surfaces as ConnectionAborted
91		// so quinn's accept loop tears the endpoint down cleanly.
92		let first = match self.inner.poll_dequeue(cx) {
93			Poll::Ready(Some(d)) => d,
94			Poll::Ready(None) => {
95				return Poll::Ready(Err(io::Error::new(
96					io::ErrorKind::ConnectionAborted,
97					"virtual socket closed",
98				)));
99			}
100			Poll::Pending => return Poll::Pending,
101		};
102
103		let local = self.inner.local_addr().unwrap_or_else(|_| SocketAddr::from(([0u8, 0, 0, 0], 0)));
104
105		fill_slot(0, first, bufs, meta, local);
106		let mut count = 1;
107
108		// Drain remaining buf slots non-blockingly so a burst of
109		// datagrams completes in one wake-up.
110		while count < max {
111			match self.inner.try_dequeue() {
112				Some(d) => {
113					fill_slot(count, d, bufs, meta, local);
114					count += 1;
115				}
116				None => break,
117			}
118		}
119		Poll::Ready(Ok(count))
120	}
121
122	fn local_addr(&self) -> io::Result<SocketAddr> {
123		self.inner.local_addr()
124	}
125}
126
127fn fill_slot(
128	idx: usize,
129	datagram: (SocketAddr, bytes::Bytes),
130	bufs: &mut [io::IoSliceMut<'_>],
131	meta: &mut [RecvMeta],
132	local: SocketAddr,
133) {
134	let (peer, payload) = datagram;
135	let n = payload.len().min(bufs[idx].len());
136	bufs[idx][..n].copy_from_slice(&payload[..n]);
137	meta[idx] = RecvMeta { addr: peer, len: n, stride: n, ecn: None, dst_ip: Some(local.ip()) };
138}
139
140/// Poller for [`SharedSocket`]. quinn calls this to register a waker
141/// for "socket writable"; we forward to the underlying physical
142/// socket via [`virtual_socket::VirtualUdpSocket::poll_send_ready`].
143struct SharedSocketPoller {
144	socket: Arc<SharedSocket>,
145}
146
147impl fmt::Debug for SharedSocketPoller {
148	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
149		f.debug_struct("SharedSocketPoller").finish()
150	}
151}
152
153impl UdpPoller for SharedSocketPoller {
154	fn poll_writable(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
155		self.socket.inner.poll_send_ready(cx)
156	}
157}
158
#[cfg(test)]
mod tests {
	// Each test binds a real loopback UDP socket so the virtual socket
	// has a physical socket to wrap.

	use std::future::poll_fn;
	use std::net::{IpAddr, Ipv4Addr};

	use bytes::Bytes;
	use quinn::AsyncUdpSocket;
	use tokio::net::UdpSocket;

	use super::*;

	/// Bind a fresh UDP socket on an ephemeral loopback port.
	async fn bound() -> Arc<UdpSocket> {
		Arc::new(UdpSocket::bind((Ipv4Addr::LOCALHOST, 0)).await.expect("bind"))
	}

	#[tokio::test]
	async fn local_addr_passes_through() {
		let phys = bound().await;
		let want = phys.local_addr().expect("local addr");
		let virt = VirtualUdpSocket::new(phys);
		let shared = SharedSocket::new(virt);
		assert_eq!(<SharedSocket as AsyncUdpSocket>::local_addr(&shared).unwrap(), want);
	}

	#[tokio::test]
	async fn poll_recv_pending_when_queue_empty() {
		let phys = bound().await;
		let virt = VirtualUdpSocket::new(phys);
		let shared = SharedSocket::new(virt);

		// Single poll: no datagrams queued => Pending.
		let mut storage = [0u8; 64];
		let mut bufs = [io::IoSliceMut::new(&mut storage)];
		let mut metas = [RecvMeta::default()];
		poll_fn(|cx| {
			match <SharedSocket as AsyncUdpSocket>::poll_recv(&shared, cx, &mut bufs, &mut metas) {
				Poll::Pending => Poll::Ready(()),
				ready @ Poll::Ready(_) => panic!("expected Pending, got {ready:?}"),
			}
		})
		.await;
	}

	#[tokio::test]
	async fn poll_recv_returns_queued_datagram() {
		let phys = bound().await;
		let virt = VirtualUdpSocket::new(phys);
		let peer: SocketAddr = "192.0.2.10:443".parse().unwrap();
		virt.enqueue_inbound(peer, Bytes::from_static(b"INIT"));
		let shared = SharedSocket::new(virt);

		let mut buf = [0u8; 64];
		let mut bufs = [io::IoSliceMut::new(&mut buf)];
		let mut metas = [RecvMeta::default()];
		let n =
			poll_fn(|cx| <SharedSocket as AsyncUdpSocket>::poll_recv(&shared, cx, &mut bufs, &mut metas))
				.await
				.expect("poll_recv ok");
		assert_eq!(n, 1);
		assert_eq!(metas[0].addr, peer);
		assert_eq!(metas[0].len, 4);
		// The socket is bound to a concrete loopback address, which is
		// reported as the datagram's destination IP.
		assert_eq!(metas[0].dst_ip, Some(IpAddr::V4(Ipv4Addr::LOCALHOST)));
		assert_eq!(&buf[..4], b"INIT");
	}

	#[tokio::test]
	async fn poll_recv_drains_burst_into_multi_slot_call() {
		let phys = bound().await;
		let virt = VirtualUdpSocket::new(phys);
		let peer1: SocketAddr = "192.0.2.11:443".parse().unwrap();
		let peer2: SocketAddr = "192.0.2.12:443".parse().unwrap();
		virt.enqueue_inbound(peer1, Bytes::from_static(b"A"));
		virt.enqueue_inbound(peer2, Bytes::from_static(b"BB"));
		let shared = SharedSocket::new(virt);

		let mut b1 = [0u8; 16];
		let mut b2 = [0u8; 16];
		let mut bufs = [io::IoSliceMut::new(&mut b1), io::IoSliceMut::new(&mut b2)];
		let mut metas = [RecvMeta::default(), RecvMeta::default()];
		let n =
			poll_fn(|cx| <SharedSocket as AsyncUdpSocket>::poll_recv(&shared, cx, &mut bufs, &mut metas))
				.await
				.expect("poll_recv ok");
		assert_eq!(n, 2);
		assert_eq!(metas[0].addr, peer1);
		assert_eq!(metas[1].addr, peer2);
		assert_eq!(metas[0].len, 1);
		assert_eq!(metas[1].len, 2);
		assert_eq!(&b1[..1], b"A");
		assert_eq!(&b2[..2], b"BB");
	}

	#[tokio::test]
	async fn poll_recv_surfaces_close_as_connection_aborted() {
		let phys = bound().await;
		let virt = VirtualUdpSocket::new(phys);
		virt.close();
		let shared = SharedSocket::new(virt);

		let mut buf = [0u8; 16];
		let mut bufs = [io::IoSliceMut::new(&mut buf)];
		let mut metas = [RecvMeta::default()];
		let r =
			poll_fn(|cx| <SharedSocket as AsyncUdpSocket>::poll_recv(&shared, cx, &mut bufs, &mut metas))
				.await;
		let err = r.expect_err("close => err");
		assert_eq!(err.kind(), io::ErrorKind::ConnectionAborted);
	}

	#[tokio::test]
	async fn try_send_proxies_to_physical() {
		let phys_src = bound().await;
		let phys_dst = bound().await;
		let dst_addr = phys_dst.local_addr().unwrap();
		let virt = VirtualUdpSocket::new(phys_src);
		let shared = SharedSocket::new(virt);

		// Wait for OS-side writability before try_send.
		poll_fn(|cx| shared.virtual_socket().poll_send_ready(cx)).await.expect("ready");
		<SharedSocket as AsyncUdpSocket>::try_send(
			&shared,
			&Transmit {
				destination: dst_addr,
				ecn: None,
				contents: b"PING",
				segment_size: None,
				src_ip: None,
			},
		)
		.expect("try_send");
		let mut got = [0u8; 16];
		let (n, _) = phys_dst.recv_from(&mut got).await.expect("recv");
		assert_eq!(&got[..n], b"PING");
	}
}