use crossbeam::channel::{self, Receiver, Sender};
use crossbeam::queue::ArrayQueue;
use std::{
io,
net::{Ipv4Addr, SocketAddr, SocketAddrV4, UdpSocket},
sync::Arc,
thread::spawn,
};
use tracing::{error, warn};
use crate::packet::Packet;
use bon::Builder;
/// Configuration for receiving a MoldUDP64 feed over UDP multicast, with gap
/// re-requests sent to a set of retransmission servers.
///
/// Construct it with the builder generated by `bon`'s `#[derive(Builder)]`, then call
/// [`MoldUDP64::start`] (which creates the sockets) or
/// [`MoldUDP64::start_with_sockets`] (which uses caller-supplied sockets).
#[derive(Builder)]
pub struct MoldUDP64 {
    /// Multicast group and port to listen on. `start` binds `0.0.0.0:<port>` and
    /// joins this group's IP.
    multicast_addr: SocketAddrV4,
    /// Local interface address passed to `join_multicast_v4` when joining the group.
    interface_addr: Ipv4Addr,
    /// Servers that gap re-requests are sent to; `start` spawns one sender thread
    /// per address.
    rerequest_server_addrs: Vec<SocketAddr>,
    /// Session identifier expected before the first packet is seen, if known.
    /// Gap detection only happens once both this and `expected_seq_num` are set;
    /// both are then updated from incoming multicast packets.
    expected_session_ident: Option<String>,
    /// Sequence number expected on the next multicast packet, if known.
    expected_seq_num: Option<u64>,
    /// Upper bound on send attempts for a single re-request. Defaults to 100.
    #[builder(default = 100)]
    max_rerequest_retries: u8,
}
impl MoldUDP64 {
#[must_use]
pub fn start(&self) -> io::Result<(Receiver<Datagram>, Sender<RetransmissionRequest>)> {
let mcast_socket = UdpSocket::bind(SocketAddrV4::new(
Ipv4Addr::UNSPECIFIED,
self.multicast_addr.port(),
))?;
mcast_socket.join_multicast_v4(self.multicast_addr.ip(), &self.interface_addr)?;
let rereq_socket = UdpSocket::bind((Ipv4Addr::UNSPECIFIED, 0))?;
self.start_with_sockets(mcast_socket, rereq_socket, &self.rerequest_server_addrs)
}
pub fn start_with_sockets(
&self,
downstream: UdpSocket,
rereq: UdpSocket,
servers: &[SocketAddr],
) -> io::Result<(Receiver<Datagram>, Sender<RetransmissionRequest>)> {
let pool: Pool = Arc::new(ArrayQueue::new(POOL_SIZE));
for _ in 0..POOL_SIZE {
let _ = pool.push(vec![0u8; BUF_SIZE].into_boxed_slice());
}
let (data_tx, data_rx) = channel::bounded::<Datagram>(POOL_SIZE);
let (req_tx, req_rx) = channel::bounded::<RetransmissionRequest>(POOL_SIZE);
{
let pool = Arc::clone(&pool);
let data_tx = data_tx.clone();
let mut session = self.expected_session_ident.clone();
let mut seq = self.expected_seq_num;
let req_tx = req_tx.clone();
spawn(move || {
multicast_recv_loop(downstream, pool, data_tx, req_tx, &mut session, &mut seq);
});
}
let rereq_socket = Arc::new(rereq);
for &server_addr in servers {
let socket = Arc::clone(&rereq_socket);
let req_rx = req_rx.clone();
let req_tx = req_tx.clone();
let max_rerequest_retries = self.max_rerequest_retries;
spawn(move || {
let mut buf = [0u8; 20];
while let Ok(RetransmissionRequest { req, attempts }) = req_rx.recv()
&& attempts < max_rerequest_retries
{
req.serialize_into(&mut buf);
if let Err(e) = socket.send_to(buf.as_slice(), &server_addr) {
warn!("failed to send re-request to {server_addr}: {e}");
if req_tx
.try_send(RetransmissionRequest {
attempts: attempts + 1,
req: req,
})
.is_err()
{
error!("re-request queue full or disconnected");
}
}
}
});
}
drop(req_rx);
{
let pool = Arc::clone(&pool);
let data_tx = data_tx.clone();
let socket = Arc::clone(&rereq_socket);
spawn(move || rerequest_recv_loop(socket, pool, data_tx));
}
Ok((data_rx, req_tx))
}
}
fn multicast_recv_loop(
socket: UdpSocket,
pool: Pool,
data_tx: Sender<Datagram>,
req_tx: Sender<RetransmissionRequest>,
expected_session_ident: &mut Option<String>,
expected_seq_num: &mut Option<u64>,
) {
loop {
let mut buf = pool
.pop()
.unwrap_or_else(|| vec![0u8; BUF_SIZE].into_boxed_slice());
let n = match socket.recv(&mut buf[..]) {
Ok(n) => n,
Err(e) => {
error!("multicast recv error: {e}");
let _ = pool.push(buf);
break;
}
};
if n < Packet::MIN_PACKET_LEN {
error!("incomplete multicast datagram");
let _ = pool.push(buf);
continue;
}
let packet = Packet::new(&buf);
if let (Some(exp_session), Some(exp_seq)) =
(expected_session_ident.as_deref(), *expected_seq_num)
{
let session_matches = exp_session == packet.session_ident();
if session_matches && packet.seq_num() > exp_seq {
let gap = packet.seq_num() - exp_seq;
let msg_count = gap.min(u16::MAX as u64) as u16;
let req = RetransmissionPacket {
session: *packet.session_ident_raw(),
seq_num: exp_seq,
msg_count: msg_count,
};
if req_tx.try_send(RetransmissionRequest::new(req)).is_err() {
error!("re-request queue full or disconnected");
}
}
}
*expected_seq_num = Some(packet.seq_num() + packet.msg_count() as u64);
match expected_session_ident {
Some(s) => {
s.clear();
s.push_str(packet.session_ident());
}
None => *expected_session_ident = Some(packet.session_ident().to_owned()),
}
forward(&data_tx, &pool, buf, n, "multicast");
}
}
fn rerequest_recv_loop(socket: Arc<UdpSocket>, pool: Pool, data_tx: Sender<Datagram>) {
loop {
let mut buf = pool
.pop()
.unwrap_or_else(|| vec![0u8; BUF_SIZE].into_boxed_slice());
let n = match socket.recv(&mut buf[..]) {
Ok(n) => n,
Err(e) => {
error!("re-request recv error: {e}");
let _ = pool.push(buf);
break;
}
};
if n < Packet::MIN_PACKET_LEN {
error!("incomplete retransmission datagram");
let _ = pool.push(buf);
continue;
}
forward(&data_tx, &pool, buf, n, "retx");
}
}
/// Wraps `buf` (holding `len` valid bytes) in a pooled [`Datagram`] and offers it to the
/// consumer without blocking.
///
/// If the channel is full or the consumer is gone, a warning tagged with `src` is
/// logged and the rejected datagram is dropped, which returns `buf` to `pool`.
#[inline]
fn forward(data_tx: &Sender<Datagram>, pool: &Pool, buf: Buffer, len: usize, src: &'static str) {
    let outcome = data_tx.try_send(Datagram {
        buf: Some(buf),
        len,
        pool: Arc::clone(pool),
    });
    if let Err(rejected) = outcome {
        if rejected.is_full() {
            warn!("datagram consumer full ({src})");
        } else {
            warn!("datagram consumer dropped ({src})");
        }
    }
}
/// Size in bytes of each pooled receive buffer.
///
/// An IPv4 UDP datagram carries at most 65,507 payload bytes (65,535 minus the 20-byte
/// IPv4 header and the 8-byte UDP header), so 64 KiB holds anything `recv` can return.
/// Keep this tight: `POOL_SIZE` of these are allocated up front (1024 × 64 KiB = 64 MiB).
const BUF_SIZE: usize = 65_536;
/// Number of buffers pre-allocated in the pool. Also the capacity of the data and
/// re-request channels.
const POOL_SIZE: usize = 1024;
/// Bounded free-list of receive buffers, shared between the receive threads and every
/// outstanding [`Datagram`].
type Pool = Arc<ArrayQueue<Buffer>>;
/// A single fixed-size, heap-allocated receive buffer.
type Buffer = Box<[u8]>;
/// One received UDP payload, stored in a buffer taken from a shared pool.
///
/// Dropping a `Datagram` pushes its buffer back onto that pool. If the pool is already
/// full, the buffer is simply deallocated.
pub struct Datagram {
    /// Backing storage. It is `Some` for the value's whole life and is taken only in `drop`.
    buf: Option<Buffer>,
    /// Number of valid bytes at the start of `buf`.
    len: usize,
    /// Pool that `buf` is handed back to on drop.
    pool: Pool,
}

impl Datagram {
    /// Returns the payload bytes of this datagram.
    #[inline]
    pub fn bytes(&self) -> &[u8] {
        let storage: &[u8] = self.buf.as_deref().unwrap();
        &storage[..self.len]
    }
}

impl Drop for Datagram {
    fn drop(&mut self) {
        let Some(storage) = self.buf.take() else {
            return;
        };
        // `push` hands the buffer back on a full pool; discarding it frees the memory.
        let _ = self.pool.push(storage);
    }
}
/// A re-request waiting to be sent, together with how many sends have already been
/// attempted for it.
pub struct RetransmissionRequest {
    req: RetransmissionPacket,
    attempts: u8,
}

impl RetransmissionRequest {
    /// Wraps `packet` as a new request that has not been sent yet.
    pub fn new(packet: RetransmissionPacket) -> Self {
        Self {
            attempts: 0,
            req: packet,
        }
    }
}
/// The 20-byte re-request payload: a session id, a starting sequence number and a
/// message count.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RetransmissionPacket {
    session: [u8; 10],
    seq_num: u64,
    msg_count: u16,
}

impl RetransmissionPacket {
    /// Builds a request for `msg_count` messages starting at `seq_num` in `session`.
    pub fn new(session: [u8; 10], seq_num: u64, msg_count: u16) -> RetransmissionPacket {
        RetransmissionPacket {
            session,
            seq_num,
            msg_count,
        }
    }

    /// Writes the wire form into `buf`: the session bytes are copied verbatim, then
    /// `seq_num` and `msg_count` are written big-endian. Every byte of `buf` is overwritten.
    #[inline]
    fn serialize_into(&self, buf: &mut [u8; 20]) {
        // Every field is written with `offset..offset + length`. The session range used
        // to be `offset..length`, which is correct only while the offset happens to be 0.
        let end = Self::SESSION_OFFSET + Self::SESSION_LENGTH;
        buf[Self::SESSION_OFFSET..end].copy_from_slice(&self.session);
        let end = Self::SEQ_OFFSET + Self::SEQ_LENGTH;
        buf[Self::SEQ_OFFSET..end].copy_from_slice(&self.seq_num.to_be_bytes());
        let end = Self::MSG_COUNT_OFFSET + Self::MSG_COUNT_LENGTH;
        buf[Self::MSG_COUNT_OFFSET..end].copy_from_slice(&self.msg_count.to_be_bytes());
    }

    const SESSION_OFFSET: usize = 0;
    const SESSION_LENGTH: usize = 10;
    const SEQ_OFFSET: usize = 10;
    const SEQ_LENGTH: usize = 8;
    const MSG_COUNT_OFFSET: usize = 18;
    const MSG_COUNT_LENGTH: usize = 2;
}