use bytes::Bytes;
use libwebrtc::{self as rtc, data_channel::DataChannel};
use std::collections::VecDeque;
use tokio::sync::{mpsc, watch};
/// Configuration consumed by [`DataChannelSender::new`].
pub struct DataChannelSenderOptions {
    /// Byte threshold for the sender's local bookkeeping: queued payloads are
    /// written to `dc` only while the tracked buffered amount is at or below
    /// this value; otherwise they wait for buffered-amount notifications.
    pub low_buffer_threshold: u64,
    /// Data channel that payloads are written to.
    pub dc: DataChannel,
    /// Stop signal; [`DataChannelSender::run`] exits once `changed()` resolves
    /// on this receiver.
    pub close_rx: watch::Receiver<bool>,
}
/// Task that forwards payloads to a data channel, holding them in a local
/// queue whenever its tracked buffered amount is above `low_buffer_threshold`.
///
/// Created with [`DataChannelSender::new`] and driven by [`DataChannelSender::run`].
pub struct DataChannelSender {
    /// Payloads submitted through the `mpsc::Sender` returned by `new`.
    send_rx: mpsc::Receiver<Bytes>,
    /// Notifications forwarded from the data channel's buffered-amount callback.
    dc_event_rx: mpsc::UnboundedReceiver<DataChannelEvent>,
    /// Strong handle held by the sender itself; the registered callback only
    /// keeps a downgraded (weak) copy of it.
    dc_event_tx: mpsc::UnboundedSender<DataChannelEvent>,
    /// Stop signal; `run` exits once `changed()` resolves on it.
    close_rx: watch::Receiver<bool>,
    /// Data channel that payloads are written to.
    dc: DataChannel,
    /// Sending pauses while `buffered_amount` is above this many bytes.
    low_buffer_threshold: u64,
    /// Running total of payload bytes handed to `dc`, decreased as
    /// `BytesSent` notifications arrive.
    buffered_amount: u64,
    /// Payloads waiting for `buffered_amount` to drop back to the threshold.
    send_queue: VecDeque<Bytes>,
}
impl DataChannelSender {
const CHANNEL_BUFFER_SIZE: usize = 128;
pub fn new(options: DataChannelSenderOptions) -> (Self, mpsc::Sender<Bytes>) {
let (send_tx, send_rx) = mpsc::channel(Self::CHANNEL_BUFFER_SIZE);
let (dc_event_tx, dc_event_rx) = mpsc::unbounded_channel();
let sender = Self {
low_buffer_threshold: options.low_buffer_threshold,
dc: options.dc,
send_rx,
dc_event_rx,
dc_event_tx,
close_rx: options.close_rx,
buffered_amount: 0,
send_queue: VecDeque::default(),
};
(sender, send_tx)
}
pub async fn run(mut self) {
log::debug!("Send task started for data channel '{}'", self.dc.label());
self.register_dc_callbacks();
loop {
tokio::select! {
Some(event) = self.dc_event_rx.recv() => {
let DataChannelEvent::BytesSent(bytes_sent) = event;
self.handle_bytes_sent(bytes_sent);
}
Some(payload) = self.send_rx.recv() => {
self.handle_enqueue_for_send(payload)
}
_ = self.close_rx.changed() => break
}
}
if !self.send_queue.is_empty() {
let unsent_bytes: usize =
self.send_queue.into_iter().map(|payload| payload.len()).sum();
log::info!("{} byte(s) remain in queue", unsent_bytes);
}
log::debug!("Send task ended for data channel '{}'", self.dc.label());
}
fn send_until_threshold(&mut self) {
while self.buffered_amount <= self.low_buffer_threshold {
let Some(payload) = self.send_queue.pop_front() else {
break;
};
self.buffered_amount += payload.len() as u64;
_ = self
.dc
.send(&payload, true)
.inspect_err(|err| log::error!("Failed to send data: {}", err));
}
}
fn handle_enqueue_for_send(&mut self, payload: Bytes) {
self.send_queue.push_back(payload);
self.send_until_threshold();
}
fn handle_bytes_sent(&mut self, bytes_sent: u64) {
if self.buffered_amount < bytes_sent {
log::error!("Unexpected buffer amount");
self.buffered_amount = 0;
return;
}
self.buffered_amount -= bytes_sent;
self.send_until_threshold();
}
fn register_dc_callbacks(&self) {
self.dc.on_buffered_amount_change(
on_buffered_amount_change(self.dc_event_tx.downgrade()).into(),
);
}
}
/// Message passed from the data channel's buffered-amount callback to the
/// sender task's event loop.
#[derive(Debug)]
enum DataChannelEvent {
    /// Byte count delivered to the buffered-amount callback; the sender task
    /// subtracts it from its tracked buffered amount.
    BytesSent(u64),
}
/// Builds the data channel's buffered-amount callback.
///
/// Each invocation forwards the reported byte count as
/// [`DataChannelEvent::BytesSent`]. The closure holds only a weak sender:
/// when the upgrade fails, the event is dropped silently. Send errors are
/// ignored as well.
fn on_buffered_amount_change(
    event_tx: mpsc::WeakUnboundedSender<DataChannelEvent>,
) -> rtc::data_channel::OnBufferedAmountChange {
    Box::new(move |bytes_sent| {
        if let Some(live_tx) = event_tx.upgrade() {
            let _ = live_tx.send(DataChannelEvent::BytesSent(bytes_sent));
        }
    })
}