1// Copyright 2026 Anapaya Systems
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//   http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Helpers for batched UDP receive and transmit operations used by SNAP tunnel I/O.
16//!
17//! `UdpBatchReceiver` batches receive-side work with a fixed compile-time batch size,
18//! while `UdpBatchSender` batches same-sized datagrams to the same destination so the
19//! underlying socket can take advantage of UDP segmentation offload when available.
20//!
21//! Both helpers are intended to be created once per socket and reused. They keep
22//! batch-sized scratch state alive so repeated calls can reuse packet buffers and
23//! socket state instead of rebuilding that state for every receive or flush cycle.
24
25use std::{collections::VecDeque, io, io::IoSliceMut, net::SocketAddr};
26
27use ana_gotatun::packet::{Packet, PacketBufPool};
28use quinn_udp::{RecvMeta, Transmit, UdpSockRef, UdpSocketState};
29use tokio::{io::Interest, net::UdpSocket};
30
/// Upper bound on the compile-time batch size accepted by both the receiver and
/// the sender; enforced by the `assert!`s in their constructors.
const MAX_BATCH_SIZE: usize = 64;
32
/// Errors returned while receiving and processing a UDP batch.
///
/// Derives `Debug` (when `E: Debug`) for consistency with [`QueuePacketError`]
/// so callers can log or `expect` on it.
#[derive(Debug)]
pub enum RecvBatchError<E> {
    /// The socket operation itself failed.
    Io(io::Error),
    /// The caller-provided packet handler failed.
    Handler(E),
}
40
/// Errors returned while queueing packets for batched transmission.
///
/// Both variants hand the rejected packet (and its intended target) back to the
/// caller, so the pooled packet buffer is never silently dropped and the caller
/// can decide whether to retry, reroute, or discard it.
#[derive(Debug)]
pub enum QueuePacketError {
    /// The sender queue is full and cannot accept another packet right now.
    /// Flushing the sender frees queue slots.
    Full {
        /// The unsent packet.
        packet: Packet,
        /// The original target address of the unsent packet.
        target: SocketAddr,
    },
    /// The packet is larger than the configured sender scratch budget.
    /// Such a packet can never be queued on this sender.
    PacketTooLarge {
        /// The oversized packet.
        packet: Packet,
        /// The original target address of the oversized packet.
        target: SocketAddr,
        /// The packet length in bytes.
        packet_len: usize,
        /// The configured maximum packet size.
        max_packet_size: usize,
    },
}
63
/// UdpBatchReceiver wraps a standard UDP socket and provides batched receive operations.
///
/// It receives up to `BATCH_SIZE` UDP datagrams in one socket read cycle and is
/// intended to be reused for as long as that socket is active. Reusing it keeps
/// the receive slots checked out from the pool so repeated receive calls can stay
/// on the fast path.
///
/// `BUFFER_SIZE` controls the size of packet buffers drawn from the provided pool.
/// `BATCH_SIZE * BUFFER_SIZE` bytes of memory will be reserved for the receive buffer.
pub struct UdpBatchReceiver<const BATCH_SIZE: usize, const BUFFER_SIZE: usize = 4096> {
    /// Platform socket state used by `quinn_udp` to perform batched receives.
    state: UdpSocketState,
    /// Per-slot receive metadata (source address, length, stride) refilled by
    /// every batched receive call.
    recv_meta: [RecvMeta; BATCH_SIZE],
    /// Packet buffers checked out from the pool and kept populated for the
    /// receiver's lifetime; slot `i` backs `recv_meta[i]`.
    recv_slots: [Packet; BATCH_SIZE],
}
78
79impl<const BATCH_SIZE: usize, const BUFFER_SIZE: usize> UdpBatchReceiver<BATCH_SIZE, BUFFER_SIZE> {
80    /// Creates a receiver configured for a fixed compile-time batch size.
81    ///
82    /// The receiver keeps `BATCH_SIZE` packet buffers checked out from `pool` until
83    /// it is dropped, so callers should typically create one receiver per socket and
84    /// reuse it across receive calls.
85    pub fn new(socket: &UdpSocket, pool: &PacketBufPool<BUFFER_SIZE>) -> io::Result<Self> {
86        assert!(
87            BATCH_SIZE > 0,
88            "UdpBatchReceiver BATCH_SIZE must be greater than zero"
89        );
90        assert!(
91            BATCH_SIZE <= MAX_BATCH_SIZE,
92            "UdpBatchReceiver BATCH_SIZE must not exceed MAX_BATCH_SIZE"
93        );
94        let state = UdpSocketState::new(UdpSockRef::from(socket))?;
95        let recv_slots = std::array::from_fn(|_| pool.get());
96        Ok(Self {
97            state,
98            recv_meta: std::array::from_fn(|_| RecvMeta::default()),
99            recv_slots,
100        })
101    }
102
103    /// Receives a batch of packets and invokes `handler` for each decoded datagram.
104    pub async fn recv_batch<E, F>(
105        &mut self,
106        socket: &UdpSocket,
107        pool: &PacketBufPool<BUFFER_SIZE>,
108        mut handler: F,
109    ) -> Result<(), RecvBatchError<E>>
110    where
111        F: FnMut(Packet, SocketAddr) -> Result<(), E>,
112    {
113        let received = loop {
114            socket.readable().await.map_err(RecvBatchError::Io)?;
115            match socket.try_io(Interest::READABLE, || self.try_recv(socket)) {
116                Ok(count) => break count,
117                Err(err) if err.kind() == io::ErrorKind::WouldBlock => continue,
118                Err(err) => return Err(RecvBatchError::Io(err)),
119            }
120        };
121
122        for index in 0..received {
123            self.handle_received(index, pool, &mut handler)
124                .map_err(RecvBatchError::Handler)?;
125        }
126
127        Ok(())
128    }
129
130    fn handle_received<E, F>(
131        &mut self,
132        index: usize,
133        pool: &PacketBufPool<BUFFER_SIZE>,
134        handler: &mut F,
135    ) -> Result<(), E>
136    where
137        F: FnMut(Packet, SocketAddr) -> Result<(), E>,
138    {
139        // `quinn_udp` can report one large receive buffer together with a stride
140        // when the kernel coalesced multiple datagrams. Split that back into
141        // logical packets here so downstream code keeps its usual one-packet-at-a-time view.
142        let meta = self.recv_meta[index];
143        if meta.len == 0 {
144            return Ok(());
145        }
146        let stride = if meta.stride == 0 {
147            meta.len
148        } else {
149            meta.stride
150        };
151        if stride >= meta.len {
152            // Hand ownership of the filled slot to the caller and immediately put a
153            // fresh buffer back into the slot so the next batch can reuse the same layout.
154            let mut packet = std::mem::replace(&mut self.recv_slots[index], pool.get());
155            packet.truncate(meta.len);
156            handler(packet, meta.addr)?;
157            return Ok(());
158        }
159
160        // Keep the receive slots permanently populated and carve a coalesced buffer
161        // into individually owned segments only when the kernel told us multiple
162        // datagrams were packed into one receive slot.
163        let packet = std::mem::replace(&mut self.recv_slots[index], pool.get());
164        for chunk in packet[..meta.len].chunks(stride) {
165            let mut segment = pool.get();
166            segment[..chunk.len()].copy_from_slice(chunk);
167            segment.truncate(chunk.len());
168            handler(segment, meta.addr)?;
169        }
170        Ok(())
171    }
172
173    fn try_recv(&mut self, socket: &UdpSocket) -> io::Result<usize> {
174        // Keep the receive slots alive across calls and hand them directly to the
175        // socket so a steady-state receive loop does not need to re-acquire buffers
176        // from the pool on every readiness notification.
177        let mut bufs_uninit: [std::mem::MaybeUninit<IoSliceMut<'_>>; BATCH_SIZE] =
178            std::array::from_fn(|_| std::mem::MaybeUninit::uninit());
179        for (index, packet) in self.recv_slots.iter_mut().enumerate() {
180            bufs_uninit[index].write(IoSliceMut::new(packet.as_mut()));
181        }
182        // SAFETY: Every element of `bufs_uninit` was written in the loop above, so
183        // all `BATCH_SIZE` slots are fully initialised. `MaybeUninit<T>` is guaranteed
184        // to have the same size and alignment as `T`, so reinterpreting the
185        // pointer as `*mut IoSliceMut<'_>` is sound. The resulting slice covers
186        // exactly the `BATCH_SIZE` elements that were initialised, and the backing
187        // array lives for the duration of this function.
188        let bufs = unsafe {
189            std::slice::from_raw_parts_mut(
190                bufs_uninit.as_mut_ptr() as *mut IoSliceMut<'_>,
191                BATCH_SIZE,
192            )
193        };
194        self.state
195            .recv(UdpSockRef::from(socket), bufs, &mut self.recv_meta)
196    }
197}
198
/// Queues up to `BATCH_SIZE` packets for batched UDP transmission.
///
/// The sender is intended to be reused for the lifetime of a socket. It keeps a
/// reusable scratch buffer and a small transmit queue so successive flushes do not
/// need to rebuild that state from scratch.
///
/// `MAX_PACKET_SIZE` determines the capacity reserved for the transmit scratch buffer.
pub struct UdpBatchSender<const BATCH_SIZE: usize, const MAX_PACKET_SIZE: usize = 4096> {
    /// Platform socket state used by `quinn_udp` to perform sends (including
    /// segmented sends where the platform supports them).
    state: UdpSocketState,
    /// FIFO of `(destination, packet)` pairs awaiting transmission.
    queued_packets: VecDeque<(SocketAddr, Packet)>,
    /// Reusable buffer into which same-destination, same-size packets at the
    /// front of the queue are coalesced before being handed to the socket.
    scratch: Vec<u8>,
}
211
212impl<const BATCH_SIZE: usize, const MAX_PACKET_SIZE: usize>
213    UdpBatchSender<BATCH_SIZE, MAX_PACKET_SIZE>
214{
215    /// Creates a sender configured for a fixed compile-time batch size.
216    ///
217    /// Callers should generally create one sender per socket and reuse it across
218    /// queue/flush cycles so the queue and scratch storage stay hot.
219    pub fn new(socket: &UdpSocket) -> io::Result<Self> {
220        assert!(
221            BATCH_SIZE > 0,
222            "UdpBatchSender BATCH_SIZE must be greater than zero"
223        );
224        assert!(
225            BATCH_SIZE <= MAX_BATCH_SIZE,
226            "UdpBatchSender BATCH_SIZE must not exceed MAX_BATCH_SIZE"
227        );
228        Ok(Self {
229            state: UdpSocketState::new(UdpSockRef::from(socket))?,
230            queued_packets: VecDeque::with_capacity(BATCH_SIZE),
231            scratch: Vec::with_capacity(MAX_PACKET_SIZE * BATCH_SIZE),
232        })
233    }
234
235    /// Returns whether no packets are currently queued for transmission.
236    pub fn is_empty(&self) -> bool {
237        self.queued_packets.is_empty()
238    }
239
240    /// Returns whether the sender queue has reached its configured capacity.
241    pub fn is_full(&self) -> bool {
242        self.queued_packets.len() == BATCH_SIZE
243    }
244
245    /// Queues one packet for transmission to `target`.
246    ///
247    /// Returns an error when the sender queue is full or when `packet` exceeds
248    /// `MAX_PACKET_SIZE`, which would otherwise force the scratch buffer to grow.
249    pub fn try_queue_packet(
250        &mut self,
251        packet: Packet,
252        target: SocketAddr,
253    ) -> Result<(), QueuePacketError> {
254        let packet_len = packet.len();
255        if packet.len() > MAX_PACKET_SIZE {
256            return Err(QueuePacketError::PacketTooLarge {
257                packet,
258                target,
259                packet_len,
260                max_packet_size: MAX_PACKET_SIZE,
261            });
262        }
263        if self.is_full() {
264            return Err(QueuePacketError::Full { packet, target });
265        }
266        self.queued_packets.push_back((target, packet));
267        Ok(())
268    }
269
270    /// Attempts to flush queued packets without waiting for the socket to become writable.
271    pub fn try_flush_best_effort(&mut self, socket: &UdpSocket) -> io::Result<()> {
272        while !self.is_empty() {
273            match socket.try_io(Interest::WRITABLE, || self.try_send_next(socket)) {
274                Ok(sent) => self.drop_prefix(sent),
275                Err(err) if err.kind() == io::ErrorKind::WouldBlock => return Err(err),
276                Err(err) => return Err(err),
277            }
278        }
279        Ok(())
280    }
281
282    /// Flushes queued packets, waiting asynchronously until the socket becomes writable.
283    pub async fn flush(&mut self, socket: &UdpSocket) -> io::Result<()> {
284        while !self.is_empty() {
285            socket.writable().await?;
286            match socket.try_io(Interest::WRITABLE, || self.try_send_next(socket)) {
287                Ok(sent) => self.drop_prefix(sent),
288                Err(err) if err.kind() == io::ErrorKind::WouldBlock => continue,
289                Err(err) => return Err(err),
290            }
291        }
292        Ok(())
293    }
294
295    fn drop_prefix(&mut self, count: usize) {
296        self.queued_packets.drain(..count);
297    }
298
299    fn try_send_next(&mut self, socket: &UdpSocket) -> io::Result<usize> {
300        self.scratch.clear();
301        let (target, first_packet) = self
302            .queued_packets
303            .front()
304            .expect("try_send_next requires a non-empty queue");
305        let target = *target;
306        let segment_size = first_packet.len();
307        let mut segments = 0;
308        let max_segments = self.state.max_gso_segments().min(BATCH_SIZE);
309
310        // Only coalesce the segments at the front with matching destination and
311        // segment size so queue order stays intact and we can drop exactly the
312        // packets that were handed to the kernel.
313        for (queued_target, packet) in self.queued_packets.iter().take(max_segments) {
314            if *queued_target != target || packet.len() != segment_size {
315                break;
316            }
317            self.scratch.extend_from_slice(&packet[..]);
318            segments += 1;
319        }
320
321        let transmit = Transmit {
322            destination: target,
323            ecn: None,
324            contents: &self.scratch,
325            segment_size: (segments > 1).then_some(segment_size),
326            src_ip: None,
327        };
328        self.state.try_send(UdpSockRef::from(socket), &transmit)?;
329        Ok(segments)
330    }
331}
332
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use ana_gotatun::packet::PacketBufPool;
    use tokio::net::UdpSocket;

    use super::{MAX_BATCH_SIZE, UdpBatchReceiver, UdpBatchSender};

    // Small buffer size shared by all test pools and scratch budgets.
    const TEST_PACKET_SIZE: usize = 128;

    /// Builds a pool with one buffer per batch slot.
    fn packet_pool() -> PacketBufPool<TEST_PACKET_SIZE> {
        PacketBufPool::new(MAX_BATCH_SIZE)
    }

    /// Binds a fresh UDP socket on an ephemeral loopback port.
    async fn bound_socket() -> UdpSocket {
        UdpSocket::bind("127.0.0.1:0").await.unwrap()
    }

    /// Checks a buffer out of `pool`, fills it with `bytes`, and trims it to length.
    fn packet_from_bytes(
        pool: &PacketBufPool<TEST_PACKET_SIZE>,
        bytes: &[u8],
    ) -> ana_gotatun::packet::Packet {
        let mut packet = pool.get();
        packet[..bytes.len()].copy_from_slice(bytes);
        packet.truncate(bytes.len());
        packet
    }

    // A flush with fewer queued packets than BATCH_SIZE must still send
    // everything and leave the queue empty.
    #[tokio::test]
    async fn flushes_partially_full_sender_batch() {
        let sender_socket = bound_socket().await;
        let receiver_socket = bound_socket().await;
        let pool = packet_pool();
        let mut sender =
            UdpBatchSender::<MAX_BATCH_SIZE, TEST_PACKET_SIZE>::new(&sender_socket).unwrap();

        sender
            .try_queue_packet(
                packet_from_bytes(&pool, b"one"),
                receiver_socket.local_addr().unwrap(),
            )
            .unwrap();
        sender
            .try_queue_packet(
                packet_from_bytes(&pool, b"two"),
                receiver_socket.local_addr().unwrap(),
            )
            .unwrap();

        sender.flush(&sender_socket).await.unwrap();

        // The two datagrams must arrive individually and in queue order.
        let mut buf = [0u8; TEST_PACKET_SIZE];
        let (n1, _) = receiver_socket.recv_from(&mut buf).await.unwrap();
        let first = buf[..n1].to_vec();
        let (n2, _) = receiver_socket.recv_from(&mut buf).await.unwrap();
        let second = buf[..n2].to_vec();

        assert!(sender.is_empty());
        assert_eq!(vec![first, second], vec![b"one".to_vec(), b"two".to_vec()]);
    }

    // Packets addressed to different targets must not be coalesced together;
    // each destination receives exactly its own datagrams, in queue order.
    #[tokio::test]
    async fn flushes_sender_batch_with_mixed_targets() {
        let sender_socket = bound_socket().await;
        let first_target = bound_socket().await;
        let second_target = bound_socket().await;
        let pool = packet_pool();
        let mut sender =
            UdpBatchSender::<MAX_BATCH_SIZE, TEST_PACKET_SIZE>::new(&sender_socket).unwrap();

        sender
            .try_queue_packet(
                packet_from_bytes(&pool, b"alpha"),
                first_target.local_addr().unwrap(),
            )
            .unwrap();
        sender
            .try_queue_packet(
                packet_from_bytes(&pool, b"beta"),
                second_target.local_addr().unwrap(),
            )
            .unwrap();
        sender
            .try_queue_packet(
                packet_from_bytes(&pool, b"gamma"),
                first_target.local_addr().unwrap(),
            )
            .unwrap();

        sender.flush(&sender_socket).await.unwrap();

        let mut buf = [0u8; TEST_PACKET_SIZE];
        let (n_first_a, _) = first_target.recv_from(&mut buf).await.unwrap();
        let first_a = buf[..n_first_a].to_vec();
        let (n_second, _) = second_target.recv_from(&mut buf).await.unwrap();
        let second = buf[..n_second].to_vec();
        let (n_first_b, _) = first_target.recv_from(&mut buf).await.unwrap();
        let first_b = buf[..n_first_b].to_vec();

        assert_eq!(first_a, b"alpha".to_vec());
        assert_eq!(second, b"beta".to_vec());
        assert_eq!(first_b, b"gamma".to_vec());
    }

    // Simulates a kernel-coalesced receive (stride < len) by poking the private
    // slot/meta state directly; the handler must see one packet per stride,
    // with a shorter final remainder.
    #[tokio::test]
    async fn receive_with_stride_smaller_than_length_splits_segments() {
        let socket = bound_socket().await;
        let pool = packet_pool();
        let mut receiver =
            UdpBatchReceiver::<MAX_BATCH_SIZE, TEST_PACKET_SIZE>::new(&socket, &pool).unwrap();
        let source = "127.0.0.1:30000".parse::<SocketAddr>().unwrap();

        receiver.recv_meta[0].addr = source;
        receiver.recv_meta[0].len = 10;
        receiver.recv_meta[0].stride = 4;
        receiver.recv_slots[0][..10].copy_from_slice(b"abcdefghij");

        let mut seen = Vec::new();
        receiver
            .handle_received(0, &pool, &mut |packet, addr| {
                seen.push((packet[..].to_vec(), addr));
                Ok::<(), ()>(())
            })
            .unwrap();

        assert_eq!(
            seen,
            vec![
                (b"abcd".to_vec(), source),
                (b"efgh".to_vec(), source),
                (b"ij".to_vec(), source),
            ]
        );
    }

    // When stride >= len (no coalescing), the whole slot is delivered as a
    // single packet without copying into per-segment buffers.
    #[tokio::test]
    async fn receive_with_stride_at_least_length_uses_single_packet() {
        let socket = bound_socket().await;
        let pool = packet_pool();
        let mut receiver =
            UdpBatchReceiver::<MAX_BATCH_SIZE, TEST_PACKET_SIZE>::new(&socket, &pool).unwrap();
        let source = "127.0.0.1:30001".parse::<SocketAddr>().unwrap();

        receiver.recv_meta[0].addr = source;
        receiver.recv_meta[0].len = 5;
        receiver.recv_meta[0].stride = 5;
        receiver.recv_slots[0][..5].copy_from_slice(b"hello");

        let mut seen = Vec::new();
        receiver
            .handle_received(0, &pool, &mut |packet, addr| {
                seen.push((packet[..].to_vec(), addr));
                Ok::<(), ()>(())
            })
            .unwrap();

        assert_eq!(seen, vec![(b"hello".to_vec(), source)]);
    }

    // Filling the queue to BATCH_SIZE and queueing one more must fail rather
    // than grow the queue past its configured capacity.
    #[test]
    fn refuses_to_grow_beyond_batch_capacity() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        runtime.block_on(async {
            let socket = bound_socket().await;
            let pool = packet_pool();
            let mut sender =
                UdpBatchSender::<MAX_BATCH_SIZE, TEST_PACKET_SIZE>::new(&socket).unwrap();

            for _ in 0..MAX_BATCH_SIZE {
                sender
                    .try_queue_packet(packet_from_bytes(&pool, b"x"), socket.local_addr().unwrap())
                    .unwrap();
            }

            assert!(
                sender
                    .try_queue_packet(
                        packet_from_bytes(&pool, b"overflow"),
                        socket.local_addr().unwrap()
                    )
                    .is_err()
            );
        });
    }
}