use std::{net::SocketAddr, sync::Arc};
use futures::{FutureExt, select};
use tokio::{io, sync::mpsc};
use crate::packet::{Packet, PacketBufPool};
use crate::task::Task;
use crate::udp::{UdpRecv, UdpSend};
#[derive(Clone)]
pub struct BufferedUdpSend {
_send_task: Arc<Task>,
send_tx_v4: mpsc::Sender<(Packet, SocketAddr)>,
send_tx_v6: mpsc::Sender<(Packet, SocketAddr)>,
}
impl BufferedUdpSend {
pub fn new(capacity: usize, udp_tx: impl UdpSend + 'static) -> Self {
let (send_tx_v4, mut send_rx_v4) = mpsc::channel::<(Packet, SocketAddr)>(capacity);
let (send_tx_v6, mut send_rx_v6) = mpsc::channel::<(Packet, SocketAddr)>(capacity);
let send_task = Task::spawn("buffered UDP send", async move {
let mut buf_v4 = vec![];
let mut buf_v6 = vec![];
let max_packet_count = udp_tx.max_number_of_packets_to_send();
let mut send_many_buf = Default::default();
loop {
let (count, buf) = select! {
n = send_rx_v4.recv_many(&mut buf_v4, max_packet_count).fuse() => (n, &mut buf_v4),
n = send_rx_v6.recv_many(&mut buf_v6, max_packet_count).fuse() => (n, &mut buf_v6),
};
match count {
0 => break,
1 => {
let (packet, addr) =
buf.pop().expect("recv_meny received 1 packet into buf");
let _ = udp_tx
.send_to(packet, addr)
.await
.inspect_err(|e| log::trace!("send_to_err: {e:#}"));
}
2.. => {
if let Err(e) = udp_tx.send_many_to(&mut send_many_buf, buf).await {
log::trace!("send_to_many_err: {e:#}");
if !buf.is_empty() {
log::trace!(
"send_to_many dropping {} packets due to error.",
buf.len()
);
buf.clear(); }
}
}
}
}
});
Self {
_send_task: Arc::new(send_task),
send_tx_v4,
send_tx_v6,
}
}
}
impl UdpSend for BufferedUdpSend {
type SendManyBuf = ();
async fn send_to(&self, packet: Packet, destination: SocketAddr) -> io::Result<()> {
let tx = match destination {
SocketAddr::V4(..) => &self.send_tx_v4,
SocketAddr::V6(..) => &self.send_tx_v6,
};
tx.send((packet, destination))
.await
.expect("receiver task is never stopped while Self exists");
Ok(())
}
fn max_number_of_packets_to_send(&self) -> usize {
debug_assert_eq!(
self.send_tx_v4.max_capacity(),
self.send_tx_v6.max_capacity(),
);
self.send_tx_v4.max_capacity()
}
}
pub struct BufferedUdpReceive {
_recv_task: Arc<Task>,
recv_rx: mpsc::Receiver<(Packet, SocketAddr)>,
}
impl BufferedUdpReceive {
pub fn new<U: UdpRecv + 'static>(
capacity: usize,
mut udp_rx: impl UdpRecv + 'static,
mut recv_pool: PacketBufPool,
) -> Self {
let (recv_tx, recv_rx) = mpsc::channel::<(Packet, SocketAddr)>(capacity);
let recv_task = Task::spawn("buffered UDP receive", async move {
let mut recv_many_buf = Default::default();
let mut packet_bufs = vec![];
loop {
let Ok(()) = udp_rx
.recv_many_from(&mut recv_many_buf, &mut recv_pool, &mut packet_bufs)
.await
else {
return;
};
for (packet_buf, src) in packet_bufs.drain(..) {
match recv_tx.try_send((packet_buf, src)) {
Ok(()) => (),
Err(mpsc::error::TrySendError::Full((packet_buf, addr))) => {
if recv_tx.send((packet_buf, addr)).await.is_err() {
return;
}
}
Err(mpsc::error::TrySendError::Closed(_)) => return,
}
}
}
});
Self {
_recv_task: Arc::new(recv_task),
recv_rx,
}
}
}
impl UdpRecv for BufferedUdpReceive {
type RecvManyBuf = ();
async fn recv_from(&mut self, _pool: &mut PacketBufPool) -> io::Result<(Packet, SocketAddr)> {
let Some((rx_packet, src)) = self.recv_rx.recv().await else {
return Err(io::Error::other("No packet available"));
};
Ok((rx_packet, src))
}
}