use std::{collections::VecDeque, io, io::IoSliceMut, net::SocketAddr};
use ana_gotatun::packet::{Packet, PacketBufPool};
use quinn_udp::{RecvMeta, Transmit, UdpSockRef, UdpSocketState};
use tokio::{io::Interest, net::UdpSocket};
/// Upper bound accepted for the `BATCH_SIZE` parameter of [`UdpBatchReceiver`] and
/// [`UdpBatchSender`]; their constructors panic if it is exceeded.
const MAX_BATCH_SIZE: usize = 64;

/// Error returned by [`UdpBatchReceiver::recv_batch`].
#[derive(Debug)]
pub enum RecvBatchError<E> {
    /// Waiting for socket readiness or receiving the batch failed.
    Io(io::Error),
    /// The per-datagram handler failed; datagrams after the failing one in the same batch are
    /// not delivered.
    Handler(E),
}

impl<E: std::fmt::Display> std::fmt::Display for RecvBatchError<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The wrapped error is exposed through `Error::source`, so it is not repeated here
        // (avoids duplicated text when error chains are printed).
        match self {
            Self::Io(_) => f.write_str("failed to receive a UDP batch"),
            Self::Handler(_) => f.write_str("UDP batch handler failed"),
        }
    }
}

impl<E: std::error::Error + 'static> std::error::Error for RecvBatchError<E> {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(err) => Some(err),
            Self::Handler(err) => Some(err),
        }
    }
}
/// Error returned by [`UdpBatchSender::try_queue_packet`]; every variant hands the rejected
/// packet and its destination back to the caller.
#[derive(Debug)]
pub enum QueuePacketError {
    /// The send queue already holds `BATCH_SIZE` packets.
    Full {
        packet: Packet,
        target: SocketAddr,
    },
    /// The packet is longer than the sender's `MAX_PACKET_SIZE`.
    PacketTooLarge {
        packet: Packet,
        target: SocketAddr,
        packet_len: usize,
        max_packet_size: usize,
    },
}

impl QueuePacketError {
    /// Consumes the error and returns the rejected packet together with its destination.
    pub fn into_parts(self) -> (Packet, SocketAddr) {
        match self {
            Self::Full { packet, target } | Self::PacketTooLarge { packet, target, .. } => {
                (packet, target)
            }
        }
    }
}

impl std::fmt::Display for QueuePacketError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Full { target, .. } => {
                write!(f, "UDP send queue is full; cannot queue packet for {target}")
            }
            Self::PacketTooLarge {
                target,
                packet_len,
                max_packet_size,
                ..
            } => write!(
                f,
                "packet of {packet_len} bytes for {target} exceeds the maximum of {max_packet_size} bytes"
            ),
        }
    }
}

impl std::error::Error for QueuePacketError {}
/// Receives UDP datagrams in batches through `quinn_udp` and hands each datagram to a
/// caller-supplied handler as a [`Packet`] drawn from a [`PacketBufPool`].
///
/// `BATCH_SIZE` is the number of receive slots (must be in `1..=MAX_BATCH_SIZE`);
/// `BUFFER_SIZE` is the size of each pooled buffer.
///
/// NOTE(review): a single slot may hold several GRO-coalesced datagrams (see the `stride`
/// handling in `handle_received`). quinn-udp's documentation recommends receive buffers of at
/// least `gro_segments() * max datagram size`; confirm `BUFFER_SIZE` is large enough for the
/// deployment, otherwise coalesced datagrams may be truncated by the kernel.
pub struct UdpBatchReceiver<const BATCH_SIZE: usize, const BUFFER_SIZE: usize = 4096> {
    /// quinn-udp socket state used to perform the batched receive.
    state: UdpSocketState,
    /// Per-slot metadata written by the most recent `recv` call.
    recv_meta: [RecvMeta; BATCH_SIZE],
    /// One pooled buffer per slot; `try_recv` points the receive I/O slices at these buffers.
    recv_slots: [Packet; BATCH_SIZE],
}

impl<const BATCH_SIZE: usize, const BUFFER_SIZE: usize> UdpBatchReceiver<BATCH_SIZE, BUFFER_SIZE> {
    /// Creates a receiver for `socket`, filling every slot with a buffer taken from `pool`.
    ///
    /// # Errors
    /// Returns the error from `UdpSocketState::new` if the socket cannot be set up.
    ///
    /// # Panics
    /// Panics if `BATCH_SIZE` is zero or greater than `MAX_BATCH_SIZE`.
    pub fn new(socket: &UdpSocket, pool: &PacketBufPool<BUFFER_SIZE>) -> io::Result<Self> {
        assert!(
            BATCH_SIZE > 0,
            "UdpBatchReceiver BATCH_SIZE must be greater than zero"
        );
        assert!(
            BATCH_SIZE <= MAX_BATCH_SIZE,
            "UdpBatchReceiver BATCH_SIZE must not exceed MAX_BATCH_SIZE"
        );
        let state = UdpSocketState::new(UdpSockRef::from(socket))?;
        Ok(Self {
            state,
            recv_meta: std::array::from_fn(|_| RecvMeta::default()),
            recv_slots: std::array::from_fn(|_| pool.get()),
        })
    }

    /// Waits until `socket` is readable, receives one batch, and passes every datagram to
    /// `handler` in arrival order.
    ///
    /// Slots carrying several coalesced datagrams are split at `stride` boundaries, so the
    /// handler always sees individual datagrams. Zero-length slots are skipped.
    ///
    /// # Errors
    /// - [`RecvBatchError::Io`] if waiting for readiness or receiving fails.
    /// - [`RecvBatchError::Handler`] with the first handler error; the remaining datagrams of
    ///   this batch are not delivered.
    pub async fn recv_batch<E, F>(
        &mut self,
        socket: &UdpSocket,
        pool: &PacketBufPool<BUFFER_SIZE>,
        mut handler: F,
    ) -> Result<(), RecvBatchError<E>>
    where
        F: FnMut(Packet, SocketAddr) -> Result<(), E>,
    {
        let received = loop {
            socket.readable().await.map_err(RecvBatchError::Io)?;
            match socket.try_io(Interest::READABLE, || self.try_recv(socket)) {
                Ok(count) => break count,
                // Spurious readiness: `try_io` clears it on `WouldBlock`, so wait again.
                Err(err) if err.kind() == io::ErrorKind::WouldBlock => continue,
                Err(err) => return Err(RecvBatchError::Io(err)),
            }
        };
        for index in 0..received {
            self.handle_received(index, pool, &mut handler)
                .map_err(RecvBatchError::Handler)?;
        }
        Ok(())
    }

    /// Delivers the datagram(s) stored in slot `index` to `handler`.
    ///
    /// A slot holding a single datagram is handed over without copying and replaced with a
    /// fresh buffer from `pool`. A slot holding several coalesced datagrams (`stride < len`) is
    /// copied out segment by segment into new pool buffers; the slot buffer itself stays in
    /// place and is reused by the next receive.
    fn handle_received<E, F>(
        &mut self,
        index: usize,
        pool: &PacketBufPool<BUFFER_SIZE>,
        handler: &mut F,
    ) -> Result<(), E>
    where
        F: FnMut(Packet, SocketAddr) -> Result<(), E>,
    {
        let meta = self.recv_meta[index];
        if meta.len == 0 {
            return Ok(());
        }
        // A stride of 0 means "no segmentation information": treat the slot as one datagram.
        let stride = if meta.stride == 0 {
            meta.len
        } else {
            meta.stride
        };
        if stride >= meta.len {
            let mut packet = std::mem::replace(&mut self.recv_slots[index], pool.get());
            packet.truncate(meta.len);
            return handler(packet, meta.addr);
        }
        // Coalesced datagrams: boundaries are at `stride` increments; the last may be shorter.
        for chunk in self.recv_slots[index][..meta.len].chunks(stride) {
            let mut segment = pool.get();
            segment[..chunk.len()].copy_from_slice(chunk);
            segment.truncate(chunk.len());
            handler(segment, meta.addr)?;
        }
        Ok(())
    }

    /// Performs one non-blocking batched receive into all slots and returns the number of
    /// filled slots (their metadata is written to `recv_meta`).
    fn try_recv(&mut self, socket: &UdpSocket) -> io::Result<usize> {
        let mut slots = self.recv_slots.iter_mut();
        // `recv_slots` has exactly `BATCH_SIZE` entries, so the iterator never runs dry here.
        let mut bufs: [IoSliceMut<'_>; BATCH_SIZE] = std::array::from_fn(|_| {
            let packet = slots
                .next()
                .expect("recv_slots has exactly BATCH_SIZE entries");
            IoSliceMut::new(packet.as_mut())
        });
        self.state
            .recv(UdpSockRef::from(socket), &mut bufs, &mut self.recv_meta)
    }
}
/// Queues outgoing datagrams and sends them through `quinn_udp`.
///
/// Runs of equally sized datagrams to the same destination are concatenated into a single
/// segmented (GSO) transmit, bounded by `max_gso_segments()`, `BATCH_SIZE`, and the maximum
/// UDP payload size.
///
/// `BATCH_SIZE` bounds the queue length (must be in `1..=MAX_BATCH_SIZE`); `MAX_PACKET_SIZE`
/// is the largest datagram accepted by [`UdpBatchSender::try_queue_packet`].
pub struct UdpBatchSender<const BATCH_SIZE: usize, const MAX_PACKET_SIZE: usize = 4096> {
    /// quinn-udp socket state; also reports how many segments one transmit may carry.
    state: UdpSocketState,
    /// FIFO of `(destination, datagram)` pairs awaiting transmission; never longer than
    /// `BATCH_SIZE`.
    queued_packets: VecDeque<(SocketAddr, Packet)>,
    /// Reusable buffer into which the datagrams of one transmit are concatenated.
    scratch: Vec<u8>,
}

impl<const BATCH_SIZE: usize, const MAX_PACKET_SIZE: usize>
    UdpBatchSender<BATCH_SIZE, MAX_PACKET_SIZE>
{
    /// Largest payload of a single UDP send over IPv4: 65535 minus the 20-byte IP header and
    /// the 8-byte UDP header. A segmented transmit passes through the stack as one UDP datagram
    /// before it is split, so its concatenated contents must not exceed this. IPv6 allows
    /// slightly more (65527), so the IPv4 value is safe for both.
    const MAX_GSO_PAYLOAD: usize = 65_507;

    /// Creates a sender for `socket` with an empty queue.
    ///
    /// # Errors
    /// Returns the error from `UdpSocketState::new` if the socket cannot be set up.
    ///
    /// # Panics
    /// Panics if `BATCH_SIZE` is zero or greater than `MAX_BATCH_SIZE`.
    pub fn new(socket: &UdpSocket) -> io::Result<Self> {
        assert!(
            BATCH_SIZE > 0,
            "UdpBatchSender BATCH_SIZE must be greater than zero"
        );
        assert!(
            BATCH_SIZE <= MAX_BATCH_SIZE,
            "UdpBatchSender BATCH_SIZE must not exceed MAX_BATCH_SIZE"
        );
        // One transmit never holds more than a full batch, nor more than the GSO payload cap
        // (except a single datagram, which is at most MAX_PACKET_SIZE).
        let scratch_capacity =
            (MAX_PACKET_SIZE * BATCH_SIZE).min(Self::MAX_GSO_PAYLOAD.max(MAX_PACKET_SIZE));
        Ok(Self {
            state: UdpSocketState::new(UdpSockRef::from(socket))?,
            queued_packets: VecDeque::with_capacity(BATCH_SIZE),
            scratch: Vec::with_capacity(scratch_capacity),
        })
    }

    /// Returns `true` when no datagrams are waiting to be sent.
    pub fn is_empty(&self) -> bool {
        self.queued_packets.is_empty()
    }

    /// Returns `true` when the queue holds `BATCH_SIZE` datagrams and cannot accept more.
    pub fn is_full(&self) -> bool {
        self.queued_packets.len() == BATCH_SIZE
    }

    /// Appends `packet` for `target` to the send queue.
    ///
    /// # Errors
    /// - [`QueuePacketError::PacketTooLarge`] if `packet` is longer than `MAX_PACKET_SIZE`.
    /// - [`QueuePacketError::Full`] if the queue already holds `BATCH_SIZE` datagrams.
    ///
    /// Both variants return the packet to the caller.
    pub fn try_queue_packet(
        &mut self,
        packet: Packet,
        target: SocketAddr,
    ) -> Result<(), QueuePacketError> {
        let packet_len = packet.len();
        if packet_len > MAX_PACKET_SIZE {
            return Err(QueuePacketError::PacketTooLarge {
                packet,
                target,
                packet_len,
                max_packet_size: MAX_PACKET_SIZE,
            });
        }
        if self.is_full() {
            return Err(QueuePacketError::Full { packet, target });
        }
        self.queued_packets.push_back((target, packet));
        Ok(())
    }

    /// Sends queued datagrams without waiting for the socket.
    ///
    /// Returns `Ok(())` once the queue is empty.
    ///
    /// # Errors
    /// Returns an error of kind `WouldBlock` when the socket cannot accept more data right now.
    /// Datagrams already sent are removed from the queue; the rest remain. Any other socket
    /// error is returned as-is, with the failed datagrams left at the head of the queue.
    pub fn try_flush_best_effort(&mut self, socket: &UdpSocket) -> io::Result<()> {
        while !self.is_empty() {
            let sent = socket.try_io(Interest::WRITABLE, || self.try_send_next(socket))?;
            self.drop_prefix(sent);
        }
        Ok(())
    }

    /// Sends every queued datagram, waiting for socket writability as needed.
    ///
    /// # Errors
    /// Returns the first non-`WouldBlock` socket error. The datagrams of the failed transmit
    /// stay at the head of the queue, so a later flush retries them.
    ///
    /// NOTE(review): if that error is persistent for the head datagram (e.g. it can never be
    /// sent to that destination), every later flush fails the same way and the queue stays
    /// full. Confirm callers drop or reset the sender in that case.
    pub async fn flush(&mut self, socket: &UdpSocket) -> io::Result<()> {
        while !self.is_empty() {
            socket.writable().await?;
            match socket.try_io(Interest::WRITABLE, || self.try_send_next(socket)) {
                Ok(sent) => self.drop_prefix(sent),
                // Spurious readiness: `try_io` clears it on `WouldBlock`, so wait again.
                Err(err) if err.kind() == io::ErrorKind::WouldBlock => continue,
                Err(err) => return Err(err),
            }
        }
        Ok(())
    }

    /// Removes the first `count` queued datagrams (those just sent).
    fn drop_prefix(&mut self, count: usize) {
        self.queued_packets.drain(..count);
    }

    /// Sends one transmit built from the head of the queue and returns how many queued
    /// datagrams it carried.
    ///
    /// The transmit covers the longest run of leading datagrams that share the head's
    /// destination and length. That run is capped by the socket's segment limit, by
    /// `BATCH_SIZE`, and by [`Self::MAX_GSO_PAYLOAD`].
    ///
    /// # Panics
    /// Panics if the queue is empty.
    fn try_send_next(&mut self, socket: &UdpSocket) -> io::Result<usize> {
        let (target, first_packet) = self
            .queued_packets
            .front()
            .expect("try_send_next requires a non-empty queue");
        let target = *target;
        let segment_size = first_packet.len();

        let max_segments = if segment_size == 0 {
            // Empty datagrams cannot be delimited inside one buffer: concatenating N of them
            // produces a single (empty) datagram on the wire while N would be counted as sent.
            1
        } else {
            self.state
                .max_gso_segments()
                .min(BATCH_SIZE)
                // Without this cap, e.g. 64 x 1420-byte datagrams exceed the UDP payload limit
                // and the kernel rejects the whole transmit.
                .min(Self::MAX_GSO_PAYLOAD / segment_size)
                // Always make progress, even for a datagram larger than the cap.
                .max(1)
        };

        self.scratch.clear();
        let mut segments = 0;
        for (queued_target, packet) in self.queued_packets.iter().take(max_segments) {
            if *queued_target != target || packet.len() != segment_size {
                break;
            }
            self.scratch.extend_from_slice(&packet[..]);
            segments += 1;
        }

        let transmit = Transmit {
            destination: target,
            ecn: None,
            contents: &self.scratch,
            segment_size: (segments > 1).then_some(segment_size),
            src_ip: None,
        };
        self.state.try_send(UdpSockRef::from(socket), &transmit)?;
        Ok(segments)
    }
}
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use ana_gotatun::packet::{Packet, PacketBufPool};
    use tokio::net::UdpSocket;

    use super::{MAX_BATCH_SIZE, QueuePacketError, UdpBatchReceiver, UdpBatchSender};

    const TEST_PACKET_SIZE: usize = 128;

    fn packet_pool() -> PacketBufPool<TEST_PACKET_SIZE> {
        PacketBufPool::new(MAX_BATCH_SIZE)
    }

    async fn bound_socket() -> UdpSocket {
        UdpSocket::bind("127.0.0.1:0").await.unwrap()
    }

    /// Builds a pooled packet whose contents are exactly `bytes`.
    fn packet_from_bytes(pool: &PacketBufPool<TEST_PACKET_SIZE>, bytes: &[u8]) -> Packet {
        let mut packet = pool.get();
        packet[..bytes.len()].copy_from_slice(bytes);
        packet.truncate(bytes.len());
        packet
    }

    /// Receives one datagram on `socket` and returns its payload.
    async fn recv_payload(socket: &UdpSocket) -> Vec<u8> {
        let mut buf = [0u8; TEST_PACKET_SIZE];
        let (len, _) = socket.recv_from(&mut buf).await.unwrap();
        buf[..len].to_vec()
    }

    #[tokio::test]
    async fn flushes_partially_full_sender_batch() {
        let sender_socket = bound_socket().await;
        let receiver_socket = bound_socket().await;
        let target = receiver_socket.local_addr().unwrap();
        let pool = packet_pool();
        let mut sender =
            UdpBatchSender::<MAX_BATCH_SIZE, TEST_PACKET_SIZE>::new(&sender_socket).unwrap();
        sender
            .try_queue_packet(packet_from_bytes(&pool, b"one"), target)
            .unwrap();
        sender
            .try_queue_packet(packet_from_bytes(&pool, b"two"), target)
            .unwrap();

        sender.flush(&sender_socket).await.unwrap();

        let first = recv_payload(&receiver_socket).await;
        let second = recv_payload(&receiver_socket).await;
        assert!(sender.is_empty());
        assert_eq!(vec![first, second], vec![b"one".to_vec(), b"two".to_vec()]);
    }

    #[tokio::test]
    async fn flushes_sender_batch_with_mixed_targets() {
        let sender_socket = bound_socket().await;
        let first_target = bound_socket().await;
        let second_target = bound_socket().await;
        let pool = packet_pool();
        let mut sender =
            UdpBatchSender::<MAX_BATCH_SIZE, TEST_PACKET_SIZE>::new(&sender_socket).unwrap();
        sender
            .try_queue_packet(
                packet_from_bytes(&pool, b"alpha"),
                first_target.local_addr().unwrap(),
            )
            .unwrap();
        sender
            .try_queue_packet(
                packet_from_bytes(&pool, b"beta"),
                second_target.local_addr().unwrap(),
            )
            .unwrap();
        sender
            .try_queue_packet(
                packet_from_bytes(&pool, b"gamma"),
                first_target.local_addr().unwrap(),
            )
            .unwrap();

        sender.flush(&sender_socket).await.unwrap();

        assert_eq!(recv_payload(&first_target).await, b"alpha".to_vec());
        assert_eq!(recv_payload(&second_target).await, b"beta".to_vec());
        assert_eq!(recv_payload(&first_target).await, b"gamma".to_vec());
        assert!(sender.is_empty());
    }

    #[tokio::test]
    async fn receive_with_stride_smaller_than_length_splits_segments() {
        let socket = bound_socket().await;
        let pool = packet_pool();
        let mut receiver =
            UdpBatchReceiver::<MAX_BATCH_SIZE, TEST_PACKET_SIZE>::new(&socket, &pool).unwrap();
        let source = "127.0.0.1:30000".parse::<SocketAddr>().unwrap();
        receiver.recv_meta[0].addr = source;
        receiver.recv_meta[0].len = 10;
        receiver.recv_meta[0].stride = 4;
        receiver.recv_slots[0][..10].copy_from_slice(b"abcdefghij");

        let mut seen = Vec::new();
        receiver
            .handle_received(0, &pool, &mut |packet, addr| {
                seen.push((packet[..].to_vec(), addr));
                Ok::<(), ()>(())
            })
            .unwrap();

        assert_eq!(
            seen,
            vec![
                (b"abcd".to_vec(), source),
                (b"efgh".to_vec(), source),
                (b"ij".to_vec(), source),
            ]
        );
    }

    #[tokio::test]
    async fn receive_with_stride_at_least_length_uses_single_packet() {
        let socket = bound_socket().await;
        let pool = packet_pool();
        let mut receiver =
            UdpBatchReceiver::<MAX_BATCH_SIZE, TEST_PACKET_SIZE>::new(&socket, &pool).unwrap();
        let source = "127.0.0.1:30001".parse::<SocketAddr>().unwrap();
        receiver.recv_meta[0].addr = source;
        receiver.recv_meta[0].len = 5;
        receiver.recv_meta[0].stride = 5;
        receiver.recv_slots[0][..5].copy_from_slice(b"hello");

        let mut seen = Vec::new();
        receiver
            .handle_received(0, &pool, &mut |packet, addr| {
                seen.push((packet[..].to_vec(), addr));
                Ok::<(), ()>(())
            })
            .unwrap();

        assert_eq!(seen, vec![(b"hello".to_vec(), source)]);
    }

    #[tokio::test]
    async fn refuses_to_grow_beyond_batch_capacity() {
        let socket = bound_socket().await;
        let target = socket.local_addr().unwrap();
        let pool = packet_pool();
        let mut sender =
            UdpBatchSender::<MAX_BATCH_SIZE, TEST_PACKET_SIZE>::new(&socket).unwrap();
        for _ in 0..MAX_BATCH_SIZE {
            sender
                .try_queue_packet(packet_from_bytes(&pool, b"x"), target)
                .unwrap();
        }
        assert!(sender.is_full());

        // The rejected packet and its destination must be handed back unchanged.
        match sender.try_queue_packet(packet_from_bytes(&pool, b"overflow"), target) {
            Err(QueuePacketError::Full {
                packet,
                target: rejected,
            }) => {
                assert_eq!(packet[..].to_vec(), b"overflow".to_vec());
                assert_eq!(rejected, target);
            }
            other => panic!("expected QueuePacketError::Full, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn rejects_packet_larger_than_max_packet_size() {
        let socket = bound_socket().await;
        let target = socket.local_addr().unwrap();
        let pool = packet_pool();
        let mut sender = UdpBatchSender::<MAX_BATCH_SIZE, 4>::new(&socket).unwrap();

        let result = sender.try_queue_packet(packet_from_bytes(&pool, b"overflow"), target);

        assert!(matches!(
            result,
            Err(QueuePacketError::PacketTooLarge {
                packet_len: 8,
                max_packet_size: 4,
                ..
            })
        ));
        assert!(sender.is_empty());
    }
}