//! Batched UDP send and receive built on `quinn_udp`, using GSO/GRO where the
//! platform provides it.

use std::{collections::VecDeque, io, io::IoSliceMut, net::SocketAddr};

use ana_gotatun::packet::{Packet, PacketBufPool};
use quinn_udp::{RecvMeta, Transmit, UdpSockRef, UdpSocketState};
use tokio::{io::Interest, net::UdpSocket};

/// Upper bound on the number of datagrams handled in a single batch.
const MAX_BATCH_SIZE: usize = 64;
/// Errors that can occur while receiving and dispatching a batch of datagrams.
pub enum RecvBatchError<E> {
    /// The underlying socket read failed.
    Io(io::Error),
    /// The per-packet handler returned an error.
    Handler(E),
}

/// Errors returned by [`UdpBatchSender::try_queue_packet`].
#[derive(Debug)]
pub enum QueuePacketError {
    /// The queue already holds `BATCH_SIZE` packets.
    Full {
        /// The rejected packet, returned so the caller can retry it.
        packet: Packet,
        /// The destination the packet was addressed to.
        target: SocketAddr,
    },
    /// The packet is larger than `MAX_PACKET_SIZE` and can never be sent.
    PacketTooLarge {
        /// The rejected packet.
        packet: Packet,
        /// The destination the packet was addressed to.
        target: SocketAddr,
        /// The length of the rejected packet in bytes.
        packet_len: usize,
        /// The sender's `MAX_PACKET_SIZE` limit.
        max_packet_size: usize,
    },
}
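
// A hedged sketch of how a caller might handle these errors; `sender`,
// `packet`, `target`, and `requeue_later` are assumed placeholders:
//
//     match sender.try_queue_packet(packet, target) {
//         Ok(()) => {}
//         // The batch is full: keep the packet and retry after a flush.
//         Err(QueuePacketError::Full { packet, target }) => requeue_later(packet, target),
//         // Oversized packets can never be sent; drop them.
//         Err(QueuePacketError::PacketTooLarge { .. }) => {}
//     }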

/// Receives UDP datagrams in batches of up to `BATCH_SIZE`, reusing
/// `BUFFER_SIZE`-byte buffers from a [`PacketBufPool`].
pub struct UdpBatchReceiver<const BATCH_SIZE: usize, const BUFFER_SIZE: usize = 4096> {
    /// Platform-specific socket state (GRO support, etc.) from `quinn_udp`.
    state: UdpSocketState,
    /// Metadata written by the kernel for each received datagram.
    recv_meta: [RecvMeta; BATCH_SIZE],
    /// Pooled buffers the kernel writes datagrams into.
    recv_slots: [Packet; BATCH_SIZE],
}

impl<const BATCH_SIZE: usize, const BUFFER_SIZE: usize> UdpBatchReceiver<BATCH_SIZE, BUFFER_SIZE> {
    /// Creates a receiver for `socket`, pre-filling every receive slot with a
    /// buffer from `pool`.
    ///
    /// # Panics
    ///
    /// Panics if `BATCH_SIZE` is zero or exceeds [`MAX_BATCH_SIZE`].
    pub fn new(socket: &UdpSocket, pool: &PacketBufPool<BUFFER_SIZE>) -> io::Result<Self> {
        assert!(
            BATCH_SIZE > 0,
            "UdpBatchReceiver BATCH_SIZE must be greater than zero"
        );
        assert!(
            BATCH_SIZE <= MAX_BATCH_SIZE,
            "UdpBatchReceiver BATCH_SIZE must not exceed MAX_BATCH_SIZE"
        );
        let state = UdpSocketState::new(UdpSockRef::from(socket))?;
        let recv_slots = std::array::from_fn(|_| pool.get());
        Ok(Self {
            state,
            recv_meta: std::array::from_fn(|_| RecvMeta::default()),
            recv_slots,
        })
    }

    /// Waits until `socket` is readable, receives up to `BATCH_SIZE`
    /// datagrams, and invokes `handler` once per packet (or per GRO segment).
    pub async fn recv_batch<E, F>(
        &mut self,
        socket: &UdpSocket,
        pool: &PacketBufPool<BUFFER_SIZE>,
        mut handler: F,
    ) -> Result<(), RecvBatchError<E>>
    where
        F: FnMut(Packet, SocketAddr) -> Result<(), E>,
    {
        let received = loop {
            socket.readable().await.map_err(RecvBatchError::Io)?;
            match socket.try_io(Interest::READABLE, || self.try_recv(socket)) {
                Ok(count) => break count,
                // The readiness event was spurious or already consumed; wait
                // for the socket to become readable again.
                Err(err) if err.kind() == io::ErrorKind::WouldBlock => continue,
                Err(err) => return Err(RecvBatchError::Io(err)),
            }
        };

        for index in 0..received {
            self.handle_received(index, pool, &mut handler)
                .map_err(RecvBatchError::Handler)?;
        }

        Ok(())
    }

    /// Dispatches the datagram in `recv_slots[index]` to `handler`, splitting
    /// GRO-coalesced receives (`stride < len`) into individual segments.
    fn handle_received<E, F>(
        &mut self,
        index: usize,
        pool: &PacketBufPool<BUFFER_SIZE>,
        handler: &mut F,
    ) -> Result<(), E>
    where
        F: FnMut(Packet, SocketAddr) -> Result<(), E>,
    {
        let meta = self.recv_meta[index];
        if meta.len == 0 {
            return Ok(());
        }
        // A stride of zero means the kernel did not report coalescing; treat
        // the whole datagram as a single segment.
        let stride = if meta.stride == 0 {
            meta.len
        } else {
            meta.stride
        };
        if stride >= meta.len {
            // Single segment: hand the buffer to the handler directly,
            // replacing the slot with a fresh buffer from the pool.
            let mut packet = std::mem::replace(&mut self.recv_slots[index], pool.get());
            packet.truncate(meta.len);
            handler(packet, meta.addr)?;
            return Ok(());
        }

        // Coalesced receive: copy each stride-sized segment into its own
        // pooled buffer. The final segment may be shorter than `stride`.
        let packet = std::mem::replace(&mut self.recv_slots[index], pool.get());
        for chunk in packet[..meta.len].chunks(stride) {
            let mut segment = pool.get();
            segment[..chunk.len()].copy_from_slice(chunk);
            segment.truncate(chunk.len());
            handler(segment, meta.addr)?;
        }
        Ok(())
    }

    /// Performs a single non-blocking batched receive into `recv_slots`,
    /// returning the number of datagrams received.
    fn try_recv(&mut self, socket: &UdpSocket) -> io::Result<usize> {
        // `IoSliceMut` has no `Default`, so the borrow array is built through
        // `MaybeUninit` and filled slot by slot from `recv_slots`.
        let mut bufs_uninit: [std::mem::MaybeUninit<IoSliceMut<'_>>; BATCH_SIZE] =
            std::array::from_fn(|_| std::mem::MaybeUninit::uninit());
        for (index, packet) in self.recv_slots.iter_mut().enumerate() {
            bufs_uninit[index].write(IoSliceMut::new(packet.as_mut()));
        }
        // SAFETY: the loop above initialized all BATCH_SIZE elements
        // (`recv_slots` has exactly BATCH_SIZE entries), so reinterpreting the
        // array as initialized `IoSliceMut`s is sound.
        let bufs = unsafe {
            std::slice::from_raw_parts_mut(
                bufs_uninit.as_mut_ptr() as *mut IoSliceMut<'_>,
                BATCH_SIZE,
            )
        };
        self.state
            .recv(UdpSockRef::from(socket), bufs, &mut self.recv_meta)
    }
}
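
// A minimal usage sketch (hypothetical caller code; `process` stands in for
// application logic and is not defined in this module):
//
//     let pool = PacketBufPool::<4096>::new(MAX_BATCH_SIZE);
//     let mut receiver = UdpBatchReceiver::<32>::new(&socket, &pool)?;
//     loop {
//         receiver
//             .recv_batch(&socket, &pool, |packet, addr| process(packet, addr))
//             .await?;
//     }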

/// Sends UDP datagrams in batches of up to `BATCH_SIZE`, using GSO to combine
/// consecutive equal-sized packets to the same target into a single syscall.
pub struct UdpBatchSender<const BATCH_SIZE: usize, const MAX_PACKET_SIZE: usize = 4096> {
    /// Platform-specific socket state (GSO support, etc.) from `quinn_udp`.
    state: UdpSocketState,
    /// Packets waiting to be flushed, in FIFO order.
    queued_packets: VecDeque<(SocketAddr, Packet)>,
    /// Reusable buffer holding the concatenated segments of one transmit.
    scratch: Vec<u8>,
}

impl<const BATCH_SIZE: usize, const MAX_PACKET_SIZE: usize>
    UdpBatchSender<BATCH_SIZE, MAX_PACKET_SIZE>
{
    /// Creates a sender for `socket`.
    ///
    /// # Panics
    ///
    /// Panics if `BATCH_SIZE` is zero or exceeds [`MAX_BATCH_SIZE`].
    pub fn new(socket: &UdpSocket) -> io::Result<Self> {
        assert!(
            BATCH_SIZE > 0,
            "UdpBatchSender BATCH_SIZE must be greater than zero"
        );
        assert!(
            BATCH_SIZE <= MAX_BATCH_SIZE,
            "UdpBatchSender BATCH_SIZE must not exceed MAX_BATCH_SIZE"
        );
        Ok(Self {
            state: UdpSocketState::new(UdpSockRef::from(socket))?,
            queued_packets: VecDeque::with_capacity(BATCH_SIZE),
            scratch: Vec::with_capacity(MAX_PACKET_SIZE * BATCH_SIZE),
        })
    }

    /// Returns `true` if no packets are queued.
    pub fn is_empty(&self) -> bool {
        self.queued_packets.is_empty()
    }

    /// Returns `true` if the queue holds `BATCH_SIZE` packets.
    pub fn is_full(&self) -> bool {
        self.queued_packets.len() == BATCH_SIZE
    }

    /// Queues `packet` for delivery to `target`.
    ///
    /// Returns the packet to the caller if it is larger than
    /// `MAX_PACKET_SIZE` or if the queue is already full.
    pub fn try_queue_packet(
        &mut self,
        packet: Packet,
        target: SocketAddr,
    ) -> Result<(), QueuePacketError> {
        let packet_len = packet.len();
        if packet_len > MAX_PACKET_SIZE {
            return Err(QueuePacketError::PacketTooLarge {
                packet,
                target,
                packet_len,
                max_packet_size: MAX_PACKET_SIZE,
            });
        }
        if self.is_full() {
            return Err(QueuePacketError::Full { packet, target });
        }
        self.queued_packets.push_back((target, packet));
        Ok(())
    }

    /// Flushes as many queued packets as the socket will accept without
    /// blocking; packets that cannot be sent yet stay queued.
    pub fn try_flush_best_effort(&mut self, socket: &UdpSocket) -> io::Result<()> {
        while !self.is_empty() {
            match socket.try_io(Interest::WRITABLE, || self.try_send_next(socket)) {
                Ok(sent) => self.drop_prefix(sent),
                // Best effort: the socket is not writable right now, so keep
                // the remaining packets queued instead of reporting an error.
                Err(err) if err.kind() == io::ErrorKind::WouldBlock => return Ok(()),
                Err(err) => return Err(err),
            }
        }
        Ok(())
    }

    /// Flushes every queued packet, waiting for the socket to become writable
    /// as needed.
    pub async fn flush(&mut self, socket: &UdpSocket) -> io::Result<()> {
        while !self.is_empty() {
            socket.writable().await?;
            match socket.try_io(Interest::WRITABLE, || self.try_send_next(socket)) {
                Ok(sent) => self.drop_prefix(sent),
                Err(err) if err.kind() == io::ErrorKind::WouldBlock => continue,
                Err(err) => return Err(err),
            }
        }
        Ok(())
    }

    /// Removes the first `count` queued packets after a successful send.
    fn drop_prefix(&mut self, count: usize) {
        self.queued_packets.drain(..count);
    }

    /// Sends the longest GSO-compatible prefix of the queue (consecutive
    /// packets sharing one target and length) as a single transmit, returning
    /// the number of packets sent.
    fn try_send_next(&mut self, socket: &UdpSocket) -> io::Result<usize> {
        self.scratch.clear();
        let (target, first_packet) = self
            .queued_packets
            .front()
            .expect("try_send_next requires a non-empty queue");
        let target = *target;
        let segment_size = first_packet.len();
        let mut segments = 0;
        let max_segments = self.state.max_gso_segments().min(BATCH_SIZE);

        // GSO requires every segment (except possibly the last) to share one
        // size and destination, so stop at the first mismatch.
        for (queued_target, packet) in self.queued_packets.iter().take(max_segments) {
            if *queued_target != target || packet.len() != segment_size {
                break;
            }
            self.scratch.extend_from_slice(&packet[..]);
            segments += 1;
        }

        let transmit = Transmit {
            destination: target,
            ecn: None,
            contents: &self.scratch,
            // A single datagram is sent without GSO segmentation.
            segment_size: (segments > 1).then_some(segment_size),
            src_ip: None,
        };
        self.state.try_send(UdpSockRef::from(socket), &transmit)?;
        Ok(segments)
    }
}
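
// A minimal usage sketch (hypothetical caller code): queue packets until the
// batch fills, flushing whenever it does, then flush the remainder.
//
//     let mut sender = UdpBatchSender::<32>::new(&socket)?;
//     for (packet, target) in outgoing {
//         sender.try_queue_packet(packet, target)?;
//         if sender.is_full() {
//             sender.flush(&socket).await?;
//         }
//     }
//     sender.flush(&socket).await?;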

#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use ana_gotatun::packet::PacketBufPool;
    use tokio::net::UdpSocket;

    use super::{MAX_BATCH_SIZE, UdpBatchReceiver, UdpBatchSender};

    const TEST_PACKET_SIZE: usize = 128;

    fn packet_pool() -> PacketBufPool<TEST_PACKET_SIZE> {
        PacketBufPool::new(MAX_BATCH_SIZE)
    }

    async fn bound_socket() -> UdpSocket {
        UdpSocket::bind("127.0.0.1:0").await.unwrap()
    }

    fn packet_from_bytes(
        pool: &PacketBufPool<TEST_PACKET_SIZE>,
        bytes: &[u8],
    ) -> ana_gotatun::packet::Packet {
        let mut packet = pool.get();
        packet[..bytes.len()].copy_from_slice(bytes);
        packet.truncate(bytes.len());
        packet
    }

    #[tokio::test]
    async fn flushes_partially_full_sender_batch() {
        let sender_socket = bound_socket().await;
        let receiver_socket = bound_socket().await;
        let pool = packet_pool();
        let mut sender =
            UdpBatchSender::<MAX_BATCH_SIZE, TEST_PACKET_SIZE>::new(&sender_socket).unwrap();

        sender
            .try_queue_packet(
                packet_from_bytes(&pool, b"one"),
                receiver_socket.local_addr().unwrap(),
            )
            .unwrap();
        sender
            .try_queue_packet(
                packet_from_bytes(&pool, b"two"),
                receiver_socket.local_addr().unwrap(),
            )
            .unwrap();

        sender.flush(&sender_socket).await.unwrap();

        let mut buf = [0u8; TEST_PACKET_SIZE];
        let (n1, _) = receiver_socket.recv_from(&mut buf).await.unwrap();
        let first = buf[..n1].to_vec();
        let (n2, _) = receiver_socket.recv_from(&mut buf).await.unwrap();
        let second = buf[..n2].to_vec();

        assert!(sender.is_empty());
        assert_eq!(vec![first, second], vec![b"one".to_vec(), b"two".to_vec()]);
    }

    #[tokio::test]
    async fn flushes_sender_batch_with_mixed_targets() {
        let sender_socket = bound_socket().await;
        let first_target = bound_socket().await;
        let second_target = bound_socket().await;
        let pool = packet_pool();
        let mut sender =
            UdpBatchSender::<MAX_BATCH_SIZE, TEST_PACKET_SIZE>::new(&sender_socket).unwrap();

        sender
            .try_queue_packet(
                packet_from_bytes(&pool, b"alpha"),
                first_target.local_addr().unwrap(),
            )
            .unwrap();
        sender
            .try_queue_packet(
                packet_from_bytes(&pool, b"beta"),
                second_target.local_addr().unwrap(),
            )
            .unwrap();
        sender
            .try_queue_packet(
                packet_from_bytes(&pool, b"gamma"),
                first_target.local_addr().unwrap(),
            )
            .unwrap();

        sender.flush(&sender_socket).await.unwrap();

        let mut buf = [0u8; TEST_PACKET_SIZE];
        let (n_first_a, _) = first_target.recv_from(&mut buf).await.unwrap();
        let first_a = buf[..n_first_a].to_vec();
        let (n_second, _) = second_target.recv_from(&mut buf).await.unwrap();
        let second = buf[..n_second].to_vec();
        let (n_first_b, _) = first_target.recv_from(&mut buf).await.unwrap();
        let first_b = buf[..n_first_b].to_vec();

        assert_eq!(first_a, b"alpha".to_vec());
        assert_eq!(second, b"beta".to_vec());
        assert_eq!(first_b, b"gamma".to_vec());
    }

    #[tokio::test]
    async fn receive_with_stride_smaller_than_length_splits_segments() {
        let socket = bound_socket().await;
        let pool = packet_pool();
        let mut receiver =
            UdpBatchReceiver::<MAX_BATCH_SIZE, TEST_PACKET_SIZE>::new(&socket, &pool).unwrap();
        let source = "127.0.0.1:30000".parse::<SocketAddr>().unwrap();

        receiver.recv_meta[0].addr = source;
        receiver.recv_meta[0].len = 10;
        receiver.recv_meta[0].stride = 4;
        receiver.recv_slots[0][..10].copy_from_slice(b"abcdefghij");

        let mut seen = Vec::new();
        receiver
            .handle_received(0, &pool, &mut |packet, addr| {
                seen.push((packet[..].to_vec(), addr));
                Ok::<(), ()>(())
            })
            .unwrap();

        assert_eq!(
            seen,
            vec![
                (b"abcd".to_vec(), source),
                (b"efgh".to_vec(), source),
                (b"ij".to_vec(), source),
            ]
        );
    }

    #[tokio::test]
    async fn receive_with_stride_at_least_length_uses_single_packet() {
        let socket = bound_socket().await;
        let pool = packet_pool();
        let mut receiver =
            UdpBatchReceiver::<MAX_BATCH_SIZE, TEST_PACKET_SIZE>::new(&socket, &pool).unwrap();
        let source = "127.0.0.1:30001".parse::<SocketAddr>().unwrap();

        receiver.recv_meta[0].addr = source;
        receiver.recv_meta[0].len = 5;
        receiver.recv_meta[0].stride = 5;
        receiver.recv_slots[0][..5].copy_from_slice(b"hello");

        let mut seen = Vec::new();
        receiver
            .handle_received(0, &pool, &mut |packet, addr| {
                seen.push((packet[..].to_vec(), addr));
                Ok::<(), ()>(())
            })
            .unwrap();

        assert_eq!(seen, vec![(b"hello".to_vec(), source)]);
    }

    #[tokio::test]
    async fn refuses_to_grow_beyond_batch_capacity() {
        let socket = bound_socket().await;
        let pool = packet_pool();
        let mut sender =
            UdpBatchSender::<MAX_BATCH_SIZE, TEST_PACKET_SIZE>::new(&socket).unwrap();

        for _ in 0..MAX_BATCH_SIZE {
            sender
                .try_queue_packet(packet_from_bytes(&pool, b"x"), socket.local_addr().unwrap())
                .unwrap();
        }

        assert!(
            sender
                .try_queue_packet(
                    packet_from_bytes(&pool, b"overflow"),
                    socket.local_addr().unwrap()
                )
                .is_err()
        );
    }
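
    // A hedged sketch of a regression test for the size guard (not part of
    // the original suite): a sender with MAX_PACKET_SIZE = 4 must reject a
    // five-byte packet with `PacketTooLarge`.
    #[tokio::test]
    async fn rejects_packet_larger_than_max_packet_size() {
        let socket = bound_socket().await;
        let pool = packet_pool();
        let mut sender = UdpBatchSender::<MAX_BATCH_SIZE, 4>::new(&socket).unwrap();

        let result = sender.try_queue_packet(
            packet_from_bytes(&pool, b"hello"),
            socket.local_addr().unwrap(),
        );
        assert!(matches!(
            result,
            Err(super::QueuePacketError::PacketTooLarge {
                packet_len: 5,
                max_packet_size: 4,
                ..
            })
        ));
    }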
}