quinn_shared_socket/
lib.rs1use std::fmt;
28use std::io;
29use std::net::SocketAddr;
30use std::pin::Pin;
31use std::sync::Arc;
32use std::task::{Context, Poll};
33
34use quinn::udp::{RecvMeta, Transmit};
35use quinn::{AsyncUdpSocket, UdpPoller};
36use virtual_socket::VirtualUdpSocket;
37
38pub struct SharedSocket {
45 inner: Arc<VirtualUdpSocket>,
46}
47
48impl SharedSocket {
49 #[must_use]
51 pub fn new(inner: Arc<VirtualUdpSocket>) -> Arc<Self> {
52 Arc::new(Self { inner })
53 }
54
55 #[must_use]
58 pub fn virtual_socket(&self) -> &Arc<VirtualUdpSocket> {
59 &self.inner
60 }
61}
62
63impl fmt::Debug for SharedSocket {
64 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65 f.debug_struct("SharedSocket").field("inner", &self.inner).finish()
66 }
67}
68
69impl AsyncUdpSocket for SharedSocket {
70 fn create_io_poller(self: Arc<Self>) -> Pin<Box<dyn UdpPoller>> {
71 Box::pin(SharedSocketPoller { socket: self })
72 }
73
74 fn try_send(&self, transmit: &Transmit<'_>) -> io::Result<()> {
75 self.inner.try_send_to(transmit.contents, transmit.destination).map(|_n| ())
76 }
77
78 fn poll_recv(
79 &self,
80 cx: &mut Context<'_>,
81 bufs: &mut [io::IoSliceMut<'_>],
82 meta: &mut [RecvMeta],
83 ) -> Poll<io::Result<usize>> {
84 let max = bufs.len().min(meta.len());
85 if max == 0 {
86 return Poll::Ready(Ok(0));
87 }
88
89 let first = match self.inner.poll_dequeue(cx) {
93 Poll::Ready(Some(d)) => d,
94 Poll::Ready(None) => {
95 return Poll::Ready(Err(io::Error::new(
96 io::ErrorKind::ConnectionAborted,
97 "virtual socket closed",
98 )));
99 }
100 Poll::Pending => return Poll::Pending,
101 };
102
103 let local = self.inner.local_addr().unwrap_or_else(|_| SocketAddr::from(([0u8, 0, 0, 0], 0)));
104
105 fill_slot(0, first, bufs, meta, local);
106 let mut count = 1;
107
108 while count < max {
111 match self.inner.try_dequeue() {
112 Some(d) => {
113 fill_slot(count, d, bufs, meta, local);
114 count += 1;
115 }
116 None => break,
117 }
118 }
119 Poll::Ready(Ok(count))
120 }
121
122 fn local_addr(&self) -> io::Result<SocketAddr> {
123 self.inner.local_addr()
124 }
125}
126
127fn fill_slot(
128 idx: usize,
129 datagram: (SocketAddr, bytes::Bytes),
130 bufs: &mut [io::IoSliceMut<'_>],
131 meta: &mut [RecvMeta],
132 local: SocketAddr,
133) {
134 let (peer, payload) = datagram;
135 let n = payload.len().min(bufs[idx].len());
136 bufs[idx][..n].copy_from_slice(&payload[..n]);
137 meta[idx] = RecvMeta { addr: peer, len: n, stride: n, ecn: None, dst_ip: Some(local.ip()) };
138}
139
140struct SharedSocketPoller {
144 socket: Arc<SharedSocket>,
145}
146
147impl fmt::Debug for SharedSocketPoller {
148 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
149 f.debug_struct("SharedSocketPoller").finish()
150 }
151}
152
153impl UdpPoller for SharedSocketPoller {
154 fn poll_writable(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
155 self.socket.inner.poll_send_ready(cx)
156 }
157}
158
#[cfg(test)]
mod tests {
    use std::future::poll_fn;
    use std::net::Ipv4Addr;

    use bytes::Bytes;
    use quinn::AsyncUdpSocket;
    use tokio::net::UdpSocket;

    use super::*;

    /// Binds a physical UDP socket on an ephemeral localhost port.
    async fn bound() -> Arc<UdpSocket> {
        let socket = UdpSocket::bind((Ipv4Addr::LOCALHOST, 0)).await.expect("bind");
        Arc::new(socket)
    }

    /// Builds a virtual socket over a freshly bound physical socket.
    async fn fresh_virtual() -> Arc<VirtualUdpSocket> {
        VirtualUdpSocket::new(bound().await)
    }

    /// Calls `poll_recv` through the `AsyncUdpSocket` trait.
    fn recv(
        shared: &SharedSocket,
        cx: &mut Context<'_>,
        bufs: &mut [io::IoSliceMut<'_>],
        metas: &mut [RecvMeta],
    ) -> Poll<io::Result<usize>> {
        <SharedSocket as AsyncUdpSocket>::poll_recv(shared, cx, bufs, metas)
    }

    #[tokio::test]
    async fn local_addr_passes_through() {
        let phys = bound().await;
        let expected = phys.local_addr().expect("local addr");
        let shared = SharedSocket::new(VirtualUdpSocket::new(phys));

        let actual = <SharedSocket as AsyncUdpSocket>::local_addr(&shared).unwrap();
        assert_eq!(actual, expected);
    }

    #[tokio::test]
    async fn poll_recv_pending_when_queue_empty() {
        let shared = SharedSocket::new(fresh_virtual().await);

        let mut storage = [0u8; 64];
        let mut bufs = [io::IoSliceMut::new(&mut storage)];
        let mut metas = [RecvMeta::default()];
        poll_fn(|cx| match recv(&shared, cx, &mut bufs, &mut metas) {
            Poll::Pending => Poll::Ready(()),
            ready @ Poll::Ready(_) => panic!("expected Pending, got {ready:?}"),
        })
        .await;
    }

    #[tokio::test]
    async fn poll_recv_returns_queued_datagram() {
        let virt = fresh_virtual().await;
        let peer: SocketAddr = "192.0.2.10:443".parse().unwrap();
        virt.enqueue_inbound(peer, Bytes::from_static(b"INIT"));
        let shared = SharedSocket::new(virt);

        let mut buf = [0u8; 64];
        let mut bufs = [io::IoSliceMut::new(&mut buf)];
        let mut metas = [RecvMeta::default()];
        let received = poll_fn(|cx| recv(&shared, cx, &mut bufs, &mut metas))
            .await
            .expect("poll_recv ok");

        assert_eq!(received, 1);
        assert_eq!(metas[0].addr, peer);
        assert_eq!(metas[0].len, 4);
        assert_eq!(&buf[..4], b"INIT");
    }

    #[tokio::test]
    async fn poll_recv_drains_burst_into_multi_slot_call() {
        let virt = fresh_virtual().await;
        let first_peer: SocketAddr = "192.0.2.11:443".parse().unwrap();
        let second_peer: SocketAddr = "192.0.2.12:443".parse().unwrap();
        virt.enqueue_inbound(first_peer, Bytes::from_static(b"A"));
        virt.enqueue_inbound(second_peer, Bytes::from_static(b"BB"));
        let shared = SharedSocket::new(virt);

        let mut slot_a = [0u8; 16];
        let mut slot_b = [0u8; 16];
        let mut bufs = [io::IoSliceMut::new(&mut slot_a), io::IoSliceMut::new(&mut slot_b)];
        let mut metas = [RecvMeta::default(), RecvMeta::default()];
        let received = poll_fn(|cx| recv(&shared, cx, &mut bufs, &mut metas))
            .await
            .expect("poll_recv ok");

        assert_eq!(received, 2);
        assert_eq!(metas[0].addr, first_peer);
        assert_eq!(metas[1].addr, second_peer);
        assert_eq!(&slot_a[..1], b"A");
        assert_eq!(&slot_b[..2], b"BB");
    }

    #[tokio::test]
    async fn poll_recv_surfaces_close_as_connection_aborted() {
        let virt = fresh_virtual().await;
        virt.close();
        let shared = SharedSocket::new(virt);

        let mut buf = [0u8; 16];
        let mut bufs = [io::IoSliceMut::new(&mut buf)];
        let mut metas = [RecvMeta::default()];
        let err = poll_fn(|cx| recv(&shared, cx, &mut bufs, &mut metas))
            .await
            .expect_err("close => err");

        assert_eq!(err.kind(), io::ErrorKind::ConnectionAborted);
    }

    #[tokio::test]
    async fn try_send_proxies_to_physical() {
        let phys_src = bound().await;
        let phys_dst = bound().await;
        let dst_addr = phys_dst.local_addr().unwrap();
        let shared = SharedSocket::new(VirtualUdpSocket::new(phys_src));

        poll_fn(|cx| shared.virtual_socket().poll_send_ready(cx)).await.expect("ready");
        let transmit = Transmit {
            destination: dst_addr,
            ecn: None,
            contents: b"PING",
            segment_size: None,
            src_ip: None,
        };
        <SharedSocket as AsyncUdpSocket>::try_send(&shared, &transmit).expect("try_send");

        let mut got = [0u8; 16];
        let (len, _) = phys_dst.recv_from(&mut got).await.expect("recv");
        assert_eq!(&got[..len], b"PING");
    }
}