use bytes::Bytes;
use libwebrtc::{self as rtc, data_channel::DataChannel};
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
use tokio::sync::{mpsc, watch, Notify};
/// All packets belonging to a single data-track frame; the queue moves a
/// whole frame's packets together (see the `send`/`recv` tests below).
pub type DataTrackFramePackets = Vec<Bytes>;
/// Configuration for constructing a [`DataChannelSender`].
pub struct DataChannelSenderOptions {
    /// Packet dispatch proceeds only while the locally tracked buffered
    /// amount is at or below this many bytes.
    pub low_buffer_threshold: u64,
    /// The WebRTC data channel packets are written to.
    pub dc: DataChannel,
    /// The send loop exits when this receiver observes a change.
    /// NOTE(review): presumably flips to `true` on close — confirm with the sender side.
    pub close_rx: watch::Receiver<bool>,
}
/// Cheaply cloneable handle to a bounded frame queue; all clones share the
/// same underlying storage via `Arc`.
#[derive(Clone)]
pub struct DataTrackSendQueue {
    inner: Arc<DataTrackSendQueueInner>,
}
/// Shared state behind [`DataTrackSendQueue`].
struct DataTrackSendQueueInner {
    // Frames awaiting dispatch; each entry holds one whole frame's packets.
    queue: Mutex<VecDeque<DataTrackFramePackets>>,
    // Wakes the consumer blocked in `recv` when a frame is pushed.
    notify: Notify,
    // Maximum queued frames; `send` evicts the oldest frame when full.
    capacity: usize,
}
impl DataTrackSendQueue {
    /// Creates a queue that retains at most `capacity` frames.
    fn new(capacity: usize) -> Self {
        debug_assert!(capacity >= 1);
        let inner = DataTrackSendQueueInner {
            queue: Mutex::new(VecDeque::with_capacity(capacity)),
            notify: Notify::new(),
            capacity,
        };
        Self { inner: Arc::new(inner) }
    }

    /// Enqueues one frame's packets. When the queue is already full, the
    /// oldest frame is evicted and returned so the caller can account for the
    /// drop. An empty frame is ignored and `None` is returned.
    pub fn send(&self, packets: DataTrackFramePackets) -> Option<DataTrackFramePackets> {
        if packets.is_empty() {
            return None;
        }
        let evicted = {
            let mut guard = self.inner.queue.lock().expect("send queue mutex poisoned");
            let evicted =
                (guard.len() >= self.inner.capacity).then(|| guard.pop_front()).flatten();
            guard.push_back(packets);
            evicted
        };
        // Wake the consumer only after the lock has been released.
        self.inner.notify.notify_one();
        evicted
    }

    /// Removes and returns the oldest queued frame, if any.
    fn try_pop(&self) -> Option<DataTrackFramePackets> {
        let mut guard = self.inner.queue.lock().expect("send queue mutex poisoned");
        guard.pop_front()
    }

    /// Empties the queue and returns everything that was still pending.
    fn drain(&self) -> VecDeque<DataTrackFramePackets> {
        let mut guard = self.inner.queue.lock().expect("send queue mutex poisoned");
        std::mem::take(&mut *guard)
    }

    /// Waits until a frame is available and returns it.
    ///
    /// Uses the check / register-waiter / re-check pattern: a `notify_one`
    /// permit issued after `enable()` completes the awaited future
    /// immediately, so a push racing with registration cannot be lost.
    async fn recv(&self) -> DataTrackFramePackets {
        loop {
            if let Some(frame) = self.try_pop() {
                return frame;
            }
            let waiter = self.inner.notify.notified();
            tokio::pin!(waiter);
            // Register interest before checking the queue a second time.
            waiter.as_mut().enable();
            if let Some(frame) = self.try_pop() {
                return frame;
            }
            waiter.await;
        }
    }
}
/// Drains frames from a [`DataTrackSendQueue`] onto a WebRTC data channel,
/// throttling dispatch on the channel's buffered amount.
pub struct DataChannelSender {
    // Source of outgoing frames, shared with producers via clone.
    queue: DataTrackSendQueue,
    // Packets of the frame currently being dispatched.
    in_flight: VecDeque<Bytes>,
    // Receives buffered-amount events from the data-channel callback.
    dc_event_rx: mpsc::UnboundedReceiver<DataChannelEvent>,
    // Kept alive so a weak clone can be handed to the channel callback.
    dc_event_tx: mpsc::UnboundedSender<DataChannelEvent>,
    // The send loop exits when this receiver observes a change.
    close_rx: watch::Receiver<bool>,
    // The channel packets are written to.
    dc: DataChannel,
    // Dispatch pauses while `buffered_amount` exceeds this many bytes.
    low_buffer_threshold: u64,
    // Local estimate of bytes handed to `dc` but not yet flushed by it.
    buffered_amount: u64,
}
impl DataChannelSender {
const QUEUE_CAPACITY: usize = 1;
pub fn new(options: DataChannelSenderOptions) -> (Self, DataTrackSendQueue) {
let queue = DataTrackSendQueue::new(Self::QUEUE_CAPACITY);
let (dc_event_tx, dc_event_rx) = mpsc::unbounded_channel();
let sender = Self {
low_buffer_threshold: options.low_buffer_threshold,
dc: options.dc,
queue: queue.clone(),
in_flight: VecDeque::new(),
dc_event_rx,
dc_event_tx,
close_rx: options.close_rx,
buffered_amount: 0,
};
(sender, queue)
}
pub async fn run(mut self) {
log::debug!("Send task started for data channel '{}'", self.dc.label());
self.register_dc_callbacks();
loop {
tokio::select! {
Some(event) = self.dc_event_rx.recv() => {
let DataChannelEvent::BytesSent(bytes_sent) = event;
self.handle_bytes_sent(bytes_sent);
self.drain_in_flight();
}
packets = self.queue.recv(),
if self.in_flight.is_empty()
&& self.buffered_amount <= self.low_buffer_threshold =>
{
self.in_flight.extend(packets);
self.drain_in_flight();
}
_ = self.close_rx.changed() => break
}
}
let remaining = self.queue.drain();
if !remaining.is_empty() {
let unsent_bytes: usize =
remaining.iter().flat_map(|frame| frame.iter()).map(|p| p.len()).sum();
log::info!("{} byte(s) remain in queue", unsent_bytes);
}
log::debug!("Send task ended for data channel '{}'", self.dc.label());
}
fn drain_in_flight(&mut self) {
while self.buffered_amount <= self.low_buffer_threshold {
let Some(packet) = self.in_flight.pop_front() else { break };
self.dispatch(packet);
}
}
fn dispatch(&mut self, payload: Bytes) {
self.buffered_amount += payload.len() as u64;
_ = self
.dc
.send(&payload, true)
.inspect_err(|err| log::error!("Failed to send data: {}", err));
}
fn handle_bytes_sent(&mut self, bytes_sent: u64) {
if self.buffered_amount < bytes_sent {
log::error!("Unexpected buffer amount");
self.buffered_amount = 0;
return;
}
self.buffered_amount -= bytes_sent;
}
fn register_dc_callbacks(&self) {
self.dc.on_buffered_amount_change(
on_buffered_amount_change(self.dc_event_tx.downgrade()).into(),
);
}
}
/// Events surfaced from the data-channel callbacks into the send loop.
#[derive(Debug)]
enum DataChannelEvent {
    /// The channel's buffered amount dropped by this many bytes
    /// (subtracted from the sender's local counter in `handle_bytes_sent`).
    BytesSent(u64),
}
/// Builds the `OnBufferedAmountChange` callback installed on the data channel.
///
/// Captures only a weak sender so the callback cannot keep the event channel
/// alive after the owning `DataChannelSender` has been dropped.
fn on_buffered_amount_change(
    event_tx: mpsc::WeakUnboundedSender<DataChannelEvent>,
) -> rtc::data_channel::OnBufferedAmountChange {
    Box::new(move |bytes_sent| {
        if let Some(tx) = event_tx.upgrade() {
            // The receiver may already be gone; a failed send is harmless.
            _ = tx.send(DataChannelEvent::BytesSent(bytes_sent));
        }
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a single packet of `len` bytes, each equal to `byte`.
    fn make_packet(byte: u8, len: usize) -> Bytes {
        Bytes::from(vec![byte; len])
    }

    /// Builds a frame of `packet_count` identical packets of `packet_len` bytes.
    fn make_frame(byte: u8, packet_count: usize, packet_len: usize) -> DataTrackFramePackets {
        std::iter::repeat_with(|| make_packet(byte, packet_len)).take(packet_count).collect()
    }

    #[test]
    fn send_empty_frame_is_noop() {
        let queue = DataTrackSendQueue::new(1);
        assert!(queue.send(Vec::new()).is_none());
        assert!(queue.try_pop().is_none());
    }

    #[test]
    fn send_keeps_multi_packet_frame_intact() {
        let queue = DataTrackSendQueue::new(1);
        let sent = make_frame(0xAA, 13, 16_000);
        assert!(queue.send(sent.clone()).is_none());
        let popped = queue.try_pop().expect("frame should be queued");
        assert_eq!(popped.len(), 13, "all packets for the frame must remain together");
        assert!(popped.iter().all(|p| p.len() == 16_000 && p[0] == 0xAA));
    }

    #[test]
    fn send_drops_oldest_whole_frame_when_full() {
        let queue = DataTrackSendQueue::new(1);
        let first = make_frame(0x01, 4, 128);
        let second = make_frame(0x02, 3, 128);
        assert!(queue.send(first.clone()).is_none());
        let dropped = queue.send(second.clone()).expect("older frame should be evicted");
        assert_eq!(dropped.len(), first.len());
        assert!(dropped.iter().all(|p| p[0] == 0x01));
        let kept = queue.try_pop().expect("newer frame should remain");
        assert_eq!(kept.len(), second.len());
        assert!(kept.iter().all(|p| p[0] == 0x02));
        assert!(queue.try_pop().is_none());
    }

    #[tokio::test]
    async fn recv_returns_whole_frame() {
        let queue = DataTrackSendQueue::new(1);
        let expected = make_frame(0x33, 5, 64);
        let producer = queue.clone();
        let outgoing = expected.clone();
        tokio::spawn(async move {
            producer.send(outgoing);
        });
        let received = queue.recv().await;
        assert_eq!(received.len(), expected.len());
        assert!(received.iter().all(|p| p[0] == 0x33));
    }
}