use crossbeam::channel::{self, Receiver, Sender};
use crossbeam::queue::ArrayQueue;
use std::{
io,
net::{Ipv4Addr, SocketAddr, SocketAddrV4, UdpSocket},
sync::Arc,
thread::spawn,
};
use tracing::{error, warn};
use crate::packet::{Packet, Request};
use bon::Builder;
/// Configuration for a MoldUDP64 multicast receiver with gap-fill re-requests.
///
/// Build it with the generated builder, then call [`MoldUDP64::start`] to obtain a
/// channel of received datagrams.
#[derive(Builder, Debug, Clone)]
pub struct MoldUDP64 {
    /// Multicast group to join; its port is also the local port the feed socket binds to.
    multicast_addr: SocketAddrV4,
    /// Local interface address passed to `join_multicast_v4`.
    interface_addr: Ipv4Addr,
    /// Servers that re-requests are sent to; one sender thread is started per address.
    rerequest_server_addrs: Vec<SocketAddr>,
    /// Session the receiver initially expects. Gap detection only runs while both this
    /// and the expected sequence number are known; both are learned from the first
    /// packet when not provided.
    expected_session_ident: Option<String>,
    /// Sequence number the receiver initially expects next.
    expected_seq_num: Option<u64>,
    /// Number of failed re-request sends after which a server's sender thread stops.
    #[builder(default = 100)]
    max_rerequest_retries: u64,
}
impl MoldUDP64 {
#[must_use]
pub fn start(&self) -> io::Result<Receiver<Datagram>> {
let mcast_socket = UdpSocket::bind(SocketAddrV4::new(
Ipv4Addr::UNSPECIFIED,
self.multicast_addr.port(),
))?;
mcast_socket.join_multicast_v4(self.multicast_addr.ip(), &self.interface_addr)?;
let rereq_socket = UdpSocket::bind((Ipv4Addr::UNSPECIFIED, 0))?;
self.start_with_sockets(mcast_socket, rereq_socket, &self.rerequest_server_addrs)
}
pub(crate) fn start_with_sockets(
&self,
downstream: UdpSocket,
rereq: UdpSocket,
servers: &[SocketAddr],
) -> io::Result<Receiver<Datagram>> {
let pool: Pool = Arc::new(ArrayQueue::new(POOL_SIZE));
for _ in 0..POOL_SIZE {
let _ = pool.push(vec![0u8; BUF_SIZE].into_boxed_slice());
}
let (data_tx, data_rx) = channel::bounded::<Datagram>(POOL_SIZE);
let (req_tx, req_rx) = channel::bounded::<Request>(POOL_SIZE);
{
let pool = Arc::clone(&pool);
let data_tx = data_tx.clone();
let mut session = self.expected_session_ident.clone();
let mut seq = self.expected_seq_num;
let req_tx = req_tx.clone();
spawn(move || {
multicast_recv_loop(downstream, pool, data_tx, req_tx, &mut session, &mut seq);
});
}
let rereq_socket = Arc::new(rereq);
for &server_addr in servers {
let socket = Arc::clone(&rereq_socket);
let req_rx = req_rx.clone();
let req_tx = req_tx.clone();
let max_rerequest_retries = self.max_rerequest_retries;
spawn(move || {
let mut n_failures = 0;
while let Ok(req) = req_rx.recv()
&& n_failures < max_rerequest_retries
{
if let Err(e) = socket.send_to(req.as_bytes(), &server_addr) {
warn!("failed to send re-request to {server_addr}: {e}");
n_failures += 1;
if req_tx.try_send(req).is_err() {
error!("re-request queue full or disconnected");
}
}
}
});
}
drop(req_rx);
{
let pool = Arc::clone(&pool);
let data_tx = data_tx.clone();
let socket = Arc::clone(&rereq_socket);
spawn(move || rerequest_recv_loop(socket, pool, data_tx));
}
Ok(data_rx)
}
}
fn multicast_recv_loop(
socket: UdpSocket,
pool: Pool,
data_tx: Sender<Datagram>,
req_tx: Sender<Request>,
expected_session_ident: &mut Option<String>,
expected_seq_num: &mut Option<u64>,
) {
loop {
let mut buf = pool
.pop()
.unwrap_or_else(|| vec![0u8; BUF_SIZE].into_boxed_slice());
let n = match socket.recv(&mut buf[..]) {
Ok(n) => n,
Err(e) => {
error!("multicast recv error: {e}");
let _ = pool.push(buf);
break;
}
};
if n < Packet::MIN_PACKET_LEN {
error!("incomplete multicast datagram");
let _ = pool.push(buf);
continue;
}
let packet = Packet::new(&buf);
if let (Some(exp_session), Some(exp_seq)) =
(expected_session_ident.as_deref(), *expected_seq_num)
{
let session_matches = exp_session == packet.session_ident();
if session_matches && packet.seq_num() > exp_seq {
let gap = packet.seq_num() - exp_seq;
let msg_count = gap.min(u16::MAX as u64) as u16;
let req = Request::new(exp_session, exp_seq, msg_count);
if req_tx.try_send(req).is_err() {
error!("re-request queue full or disconnected");
}
}
}
*expected_seq_num = Some(packet.seq_num() + packet.msg_count() as u64);
match expected_session_ident {
Some(s) => {
s.clear();
s.push_str(packet.session_ident());
}
None => *expected_session_ident = Some(packet.session_ident().to_owned()),
}
forward(&data_tx, &pool, buf, n, "multicast");
}
}
fn rerequest_recv_loop(socket: Arc<UdpSocket>, pool: Pool, data_tx: Sender<Datagram>) {
loop {
let mut buf = pool
.pop()
.unwrap_or_else(|| vec![0u8; BUF_SIZE].into_boxed_slice());
let n = match socket.recv(&mut buf[..]) {
Ok(n) => n,
Err(e) => {
error!("re-request recv error: {e}");
let _ = pool.push(buf);
break;
}
};
if n < Packet::MIN_PACKET_LEN {
error!("incomplete retransmission datagram");
let _ = pool.push(buf);
continue;
}
forward(&data_tx, &pool, buf, n, "retx");
}
}
/// Wraps the first `len` bytes of `buf` in a pooled [`Datagram`] and offers it to the
/// consumer without blocking.
///
/// If the consumer channel is full or has been dropped, a warning tagged with `src` is
/// logged. The rejected datagram is then dropped, which hands its buffer back to `pool`.
#[inline]
fn forward(data_tx: &Sender<Datagram>, pool: &Pool, buf: Buffer, len: usize, src: &'static str) {
    let datagram = Datagram {
        pool: Arc::clone(pool),
        buf: Some(buf),
        len,
    };
    if let Err(rejected) = data_tx.try_send(datagram) {
        if rejected.is_full() {
            warn!("datagram consumer full ({src})");
        } else {
            warn!("datagram consumer dropped ({src})");
        }
    }
}
/// Size in bytes of each pooled receive buffer.
///
/// The UDP length field is 16 bits, so a (non-jumbogram) datagram payload is at most
/// 65,507 bytes over IPv4 and 65,527 bytes over IPv6. 64 KiB therefore holds any
/// datagram without truncation. The previous 512 KiB reserved 8x the memory for the
/// pool (512 MiB total) with no benefit.
const BUF_SIZE: usize = 65_536;
/// Number of buffers preallocated in the shared pool. It is also used as the capacity
/// of the datagram and re-request channels.
const POOL_SIZE: usize = 1024;
/// Lock-free, bounded pool of reusable receive buffers shared by all threads.
type Pool = Arc<ArrayQueue<Buffer>>;
/// Fixed-size heap buffer that a single datagram is received into.
type Buffer = Box<[u8]>;
/// A received UDP datagram backed by a buffer borrowed from the shared pool.
///
/// Only the first `len` bytes of the buffer are meaningful. Dropping the datagram
/// returns the buffer to `pool`.
pub struct Datagram {
    /// Pool the buffer is returned to when the datagram is dropped.
    pool: Pool,
    /// Backing buffer; only ever taken out by the `Drop` impl.
    buf: Option<Buffer>,
    /// Number of valid bytes at the start of `buf`.
    len: usize,
}
impl Datagram {
    /// Returns the received payload, i.e. the first `len` bytes of the pooled buffer.
    #[inline]
    #[must_use]
    pub fn bytes(&self) -> &[u8] {
        // Invariant: `buf` is `Some` from construction until `Drop` takes it, so it is
        // always present while `&self` is reachable.
        let buf = self
            .buf
            .as_deref()
            .expect("Datagram buffer is present until the datagram is dropped");
        &buf[..self.len]
    }
}
impl Drop for Datagram {
    /// Hands the buffer back to its pool so the receive loops can reuse it.
    fn drop(&mut self) {
        let Some(buffer) = self.buf.take() else {
            return;
        };
        // `ArrayQueue::push` returns the buffer as `Err` when the pool is already at
        // capacity; ignoring it simply frees that buffer.
        let _ = self.pool.push(buffer);
    }
}