#[cfg(test)]
mod tests;
use crate::data_producer::{DataProducer, DataProducerId, WeakDataProducer};
use crate::data_structures::{AppData, WebRtcMessage};
use crate::messages::{
DataConsumerCloseRequest, DataConsumerDumpRequest, DataConsumerGetBufferedAmountRequest,
DataConsumerGetStatsRequest, DataConsumerSendRequest,
DataConsumerSetBufferedAmountLowThresholdRequest,
};
use crate::sctp_parameters::SctpStreamParameters;
use crate::transport::Transport;
use crate::uuid_based_wrapper_type;
use crate::worker::{Channel, PayloadChannel, RequestError, SubscriptionHandler};
use async_executor::Executor;
use event_listener_primitives::{Bag, BagOnce, HandlerId};
use log::{debug, error};
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use std::borrow::Cow;
use std::fmt;
use std::fmt::Debug;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Weak};
// Declares the `DataConsumerId` type through the crate's `uuid_based_wrapper_type!` macro.
// NOTE(review): presumably a UUID-backed newtype identifier - confirm against the macro definition.
uuid_based_wrapper_type!(
DataConsumerId
);
/// Options describing a data consumer to be created for a given data producer.
///
/// Only `app_data` is public; the remaining fields are visible to the parent module only and are
/// filled in by the `new_*` constructors of this type.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct DataConsumerOptions {
/// Identifier of the data producer these options refer to.
pub(super) data_producer_id: DataProducerId,
/// Optional ordering setting.
/// NOTE(review): `None` presumably means "use the data producer's own setting" - confirm in the
/// code that consumes these options.
pub(super) ordered: Option<bool>,
/// Optional maximum packet life time.
/// NOTE(review): unit is presumably milliseconds - confirm against the SCTP parameter docs.
pub(super) max_packet_life_time: Option<u16>,
/// Optional maximum number of retransmissions.
pub(super) max_retransmits: Option<u16>,
/// Custom application data; every constructor initializes it with `AppData::default()`.
pub app_data: AppData,
}
impl DataConsumerOptions {
    /// Options that override nothing: `ordered`, `max_packet_life_time` and `max_retransmits`
    /// are all `None` and `app_data` is `AppData::default()`.
    ///
    /// NOTE(review): `None` presumably means "inherit the value from the data producer" - confirm
    /// in the transport code that reads these options.
    #[must_use]
    pub fn new_sctp(data_producer_id: DataProducerId) -> Self {
        Self {
            data_producer_id,
            ordered: None,
            max_packet_life_time: None,
            max_retransmits: None,
            app_data: AppData::default(),
        }
    }

    /// Options for a direct data consumer: `ordered` is `Some(true)`, no life time or retransmit
    /// limit is set.
    #[must_use]
    pub fn new_direct(data_producer_id: DataProducerId) -> Self {
        Self {
            ordered: Some(true),
            ..Self::new_sctp(data_producer_id)
        }
    }

    /// Options explicitly requesting ordered delivery: `ordered` is `Some(true)`.
    ///
    /// Previously this left `ordered` as `None`, which made it indistinguishable from
    /// [`DataConsumerOptions::new_sctp`] and did not request ordering at all.
    #[must_use]
    pub fn new_sctp_ordered(data_producer_id: DataProducerId) -> Self {
        Self {
            ordered: Some(true),
            ..Self::new_sctp(data_producer_id)
        }
    }

    /// Options explicitly requesting unordered delivery (`ordered` is `Some(false)`) with the
    /// given `max_packet_life_time`.
    ///
    /// Previously `ordered` was left as `None`, so the "unordered" part of the name was never
    /// actually expressed in the options.
    #[must_use]
    pub fn new_sctp_unordered_with_life_time(
        data_producer_id: DataProducerId,
        max_packet_life_time: u16,
    ) -> Self {
        Self {
            ordered: Some(false),
            max_packet_life_time: Some(max_packet_life_time),
            ..Self::new_sctp(data_producer_id)
        }
    }

    /// Options explicitly requesting unordered delivery (`ordered` is `Some(false)`) with the
    /// given `max_retransmits`.
    ///
    /// Previously `ordered` was left as `None`, so the "unordered" part of the name was never
    /// actually expressed in the options.
    #[must_use]
    pub fn new_sctp_unordered_with_retransmits(
        data_producer_id: DataProducerId,
        max_retransmits: u16,
    ) -> Self {
        Self {
            ordered: Some(false),
            max_retransmits: Some(max_retransmits),
            ..Self::new_sctp(data_producer_id)
        }
    }
}
/// Snapshot of a data consumer's state, as returned by `DataConsumer::dump()`.
///
/// (De)serialized with camelCase field names and hidden from generated documentation.
#[derive(Debug, Clone, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
#[doc(hidden)]
#[non_exhaustive]
pub struct DataConsumerDump {
/// Identifier of the data consumer.
pub id: DataConsumerId,
/// Identifier of the associated data producer.
pub data_producer_id: DataProducerId,
/// Kind of data consumer (SCTP or direct).
pub r#type: DataConsumerType,
/// Label string reported for the data consumer.
pub label: String,
/// Protocol string reported for the data consumer.
pub protocol: String,
/// SCTP stream parameters, if any were reported.
pub sctp_stream_parameters: Option<SctpStreamParameters>,
/// Reported buffered-amount-low threshold.
/// NOTE(review): unit is presumably bytes - confirm against the worker.
pub buffered_amount_low_threshold: u32,
}
/// One statistics entry for a data consumer; `DataConsumer::get_stats()` returns a `Vec` of these.
///
/// (De)serialized with camelCase field names.
#[derive(Debug, Clone, PartialOrd, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
#[allow(missing_docs)]
pub struct DataConsumerStat {
// NOTE(review): timestamp epoch/unit is not visible in this file - confirm against the worker.
pub timestamp: u64,
// Label string reported for the data consumer.
pub label: String,
// Protocol string reported for the data consumer.
pub protocol: String,
// Number of messages sent, as reported in the stats payload.
pub messages_sent: usize,
// Number of bytes sent, as reported in the stats payload.
pub bytes_sent: usize,
// Buffered amount as reported in the stats payload.
// NOTE(review): unit is presumably bytes - confirm.
pub buffered_amount: u32,
}
/// Kind of data consumer.
///
/// (De)serialized using lowercase variant names.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum DataConsumerType {
/// SCTP data consumer.
Sctp,
/// Direct data consumer.
Direct,
}
/// Notifications parsed from the JSON bytes delivered by the `Channel` subscription set up in
/// `DataConsumer::new()`.
///
/// Adjacently tagged: the variant is selected by the `event` field (lowercase variant name) and
/// its payload, if any, is read from the `data` field.
#[derive(Debug, Deserialize)]
#[serde(tag = "event", rename_all = "lowercase", content = "data")]
enum Notification {
/// Triggers the `data_producer_close` handlers and closes the data consumer (unless it is
/// already closed).
DataProducerClose,
/// Triggers the `sctp_send_buffer_full` handlers.
SctpSendBufferFull,
/// Triggers the `buffered_amount_low` handlers with the carried value; payload fields are
/// read using camelCase names.
#[serde(rename_all = "camelCase")]
BufferedAmountLow {
buffered_amount: u32,
},
}
/// Notifications parsed from the JSON part of messages delivered by the `PayloadChannel`
/// subscription set up in `DataConsumer::new()`.
///
/// Adjacently tagged: the variant is selected by the `event` field (lowercase variant name) and
/// its fields are read from the `data` field.
#[derive(Debug, Deserialize)]
#[serde(tag = "event", rename_all = "lowercase", content = "data")]
enum PayloadNotification {
/// A message arrived; `ppid` is combined with the accompanying binary payload to build a
/// `WebRtcMessage`.
Message { ppid: u32 },
}
/// Event callbacks registered through the `DataConsumer::on_*` methods.
///
/// `Bag` entries hold `Fn` callbacks; `BagOnce` entries hold `FnOnce` callbacks.
#[derive(Default)]
// The boxed/arc'ed closure types below are what trips `clippy::type_complexity`.
#[allow(clippy::type_complexity)]
struct Handlers {
/// Called with every successfully decoded incoming message.
message: Bag<Arc<dyn Fn(&WebRtcMessage<'_>) + Send + Sync>>,
/// Called on `Notification::SctpSendBufferFull`.
sctp_send_buffer_full: Bag<Arc<dyn Fn() + Send + Sync>>,
/// Called on `Notification::BufferedAmountLow` with the reported buffered amount.
buffered_amount_low: Bag<Arc<dyn Fn(u32) + Send + Sync>>,
/// Called on `Notification::DataProducerClose` if the data consumer is not closed yet.
data_producer_close: BagOnce<Box<dyn FnOnce() + Send>>,
/// Called from the transport's close callback.
transport_close: BagOnce<Box<dyn FnOnce() + Send>>,
/// Called when the data consumer itself is closed (see `Inner::close`).
close: BagOnce<Box<dyn FnOnce() + Send>>,
}
/// Shared state behind every `DataConsumer` handle (held through `Arc<Inner>`).
///
/// Dropping the last `Arc<Inner>` closes the data consumer (see the `Drop` implementation).
struct Inner {
/// Identifier of this data consumer.
id: DataConsumerId,
/// Kind of data consumer.
r#type: DataConsumerType,
/// SCTP stream parameters, if any.
sctp_stream_parameters: Option<SctpStreamParameters>,
/// Label string.
label: String,
/// Protocol string.
protocol: String,
/// Identifier of the data producer being consumed.
data_producer_id: DataProducerId,
/// Selects which `DataConsumer` variant wraps this state (`Direct` when `true`).
direct: bool,
/// Executor used to spawn detached tasks (close request, deferred close).
executor: Arc<Executor<'static>>,
/// Channel used for requests such as dump/stats/close.
channel: Channel,
/// Payload channel used by `DirectDataConsumer::send`.
payload_channel: PayloadChannel,
/// Registered event callbacks.
handlers: Arc<Handlers>,
/// Custom application data supplied at creation.
app_data: AppData,
/// Transport this data consumer belongs to.
transport: Arc<dyn Transport>,
/// Weak handle to the data producer; checked before sending the close request.
weak_data_producer: WeakDataProducer,
/// Set to `true` exactly once, by the first call to `Inner::close`.
closed: Arc<AtomicBool>,
// Never read; held only so the subscriptions live as long as `Inner`.
// NOTE(review): presumably dropping a `SubscriptionHandler` unsubscribes - confirm.
_subscription_handlers: Mutex<Vec<Option<SubscriptionHandler>>>,
// Never read; held only so the transport close callback registration lives as long as `Inner`.
// NOTE(review): presumably dropping a `HandlerId` unregisters the callback - confirm.
_on_transport_close_handler: Mutex<HandlerId>,
}
impl Drop for Inner {
    /// Closes the data consumer once the last owner of `Inner` goes away.
    ///
    /// `close` is asked to also issue the close request; it is a no-op if the data consumer was
    /// already closed.
    fn drop(&mut self) {
        debug!("drop()");

        let send_close_request = true;
        self.close(send_close_request);
    }
}
impl Inner {
    /// Marks the data consumer as closed and runs the one-shot `close` handlers.
    ///
    /// Only the first call has any effect: `closed` is atomically swapped to `true`, and every
    /// later call returns immediately.
    ///
    /// When `close_request` is `true`, a detached task is spawned on `executor` that sends a
    /// `DataConsumerCloseRequest` addressed to the transport over `channel` - but only if the
    /// data producer can still be upgraded when the task runs. A failed request is logged, not
    /// returned.
    fn close(&self, close_request: bool) {
        let already_closed = self.closed.swap(true, Ordering::SeqCst);
        if already_closed {
            return;
        }

        debug!("close()");

        self.handlers.close.call_simple();

        if !close_request {
            return;
        }

        // Everything the task needs is copied out first, since it must not borrow `self`.
        let channel = self.channel.clone();
        let transport_id = self.transport.id();
        let request = DataConsumerCloseRequest {
            data_consumer_id: self.id,
        };
        let weak_data_producer = self.weak_data_producer.clone();

        let close_task = async move {
            if weak_data_producer.upgrade().is_none() {
                return;
            }

            match channel.request(transport_id, request).await {
                Ok(_) => {}
                Err(error) => {
                    error!("consumer closing failed on drop: {}", error);
                }
            }
        };

        self.executor.spawn(close_task).detach();
    }
}
/// Data consumer handle used for the non-direct case (`DataConsumer::Regular`).
///
/// Cloning yields another handle to the same shared `Inner` state.
#[derive(Clone)]
#[must_use = "Data consumer will be closed on drop, make sure to keep it around for as long as needed"]
pub struct RegularDataConsumer {
// Shared state; the data consumer is closed when the last reference is dropped.
inner: Arc<Inner>,
}
impl fmt::Debug for RegularDataConsumer {
    /// Formats the identifying fields of the shared state; channels, handlers and the executor
    /// are intentionally left out.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let inner = &self.inner;

        let mut out = f.debug_struct("RegularDataConsumer");
        out.field("id", &inner.id);
        out.field("type", &inner.r#type);
        out.field("sctp_stream_parameters", &inner.sctp_stream_parameters);
        out.field("label", &inner.label);
        out.field("protocol", &inner.protocol);
        out.field("data_producer_id", &inner.data_producer_id);
        out.field("transport", &inner.transport);
        out.field("closed", &inner.closed);
        out.finish()
    }
}
impl From<RegularDataConsumer> for DataConsumer {
fn from(producer: RegularDataConsumer) -> Self {
DataConsumer::Regular(producer)
}
}
/// Data consumer handle used for the direct case (`DataConsumer::Direct`); additionally offers
/// `send()`.
///
/// Cloning yields another handle to the same shared `Inner` state.
#[derive(Clone)]
#[must_use = "Data consumer will be closed on drop, make sure to keep it around for as long as needed"]
pub struct DirectDataConsumer {
// Shared state; the data consumer is closed when the last reference is dropped.
inner: Arc<Inner>,
}
impl fmt::Debug for DirectDataConsumer {
    /// Formats the identifying fields of the shared state; channels, handlers and the executor
    /// are intentionally left out.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let inner = &self.inner;

        let mut out = f.debug_struct("DirectDataConsumer");
        out.field("id", &inner.id);
        out.field("type", &inner.r#type);
        out.field("sctp_stream_parameters", &inner.sctp_stream_parameters);
        out.field("label", &inner.label);
        out.field("protocol", &inner.protocol);
        out.field("data_producer_id", &inner.data_producer_id);
        out.field("transport", &inner.transport);
        out.field("closed", &inner.closed);
        out.finish()
    }
}
impl From<DirectDataConsumer> for DataConsumer {
fn from(producer: DirectDataConsumer) -> Self {
DataConsumer::Direct(producer)
}
}
/// A data consumer handle, either regular or direct.
///
/// Both variants wrap the same kind of shared state; which one is used is decided by the
/// `direct` flag given to `DataConsumer::new()`. Cloning yields another handle to the same state.
#[derive(Clone)]
#[non_exhaustive]
#[must_use = "Data consumer will be closed on drop, make sure to keep it around for as long as needed"]
pub enum DataConsumer {
/// Data consumer created with `direct == false`.
Regular(RegularDataConsumer),
/// Data consumer created with `direct == true`; exposes `DirectDataConsumer::send`.
Direct(DirectDataConsumer),
}
impl fmt::Debug for DataConsumer {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self {
DataConsumer::Regular(producer) => f.debug_tuple("Regular").field(&producer).finish(),
DataConsumer::Direct(producer) => f.debug_tuple("Direct").field(&producer).finish(),
}
}
}
impl DataConsumer {
#[allow(clippy::too_many_arguments)]
pub(super) fn new(
id: DataConsumerId,
r#type: DataConsumerType,
sctp_stream_parameters: Option<SctpStreamParameters>,
label: String,
protocol: String,
data_producer: DataProducer,
executor: Arc<Executor<'static>>,
channel: Channel,
payload_channel: PayloadChannel,
app_data: AppData,
transport: Arc<dyn Transport>,
direct: bool,
) -> Self {
debug!("new()");
let handlers = Arc::<Handlers>::default();
let closed = Arc::new(AtomicBool::new(false));
let inner_weak = Arc::<Mutex<Option<Weak<Inner>>>>::default();
let subscription_handler = {
let handlers = Arc::clone(&handlers);
let closed = Arc::clone(&closed);
let inner_weak = Arc::clone(&inner_weak);
channel.subscribe_to_notifications(id.into(), move |notification| {
match serde_json::from_slice::<Notification>(notification) {
Ok(notification) => match notification {
Notification::DataProducerClose => {
if !closed.load(Ordering::SeqCst) {
handlers.data_producer_close.call_simple();
let maybe_inner =
inner_weak.lock().as_ref().and_then(Weak::upgrade);
if let Some(inner) = maybe_inner {
inner
.executor
.clone()
.spawn(async move {
inner.close(false);
})
.detach();
}
}
}
Notification::SctpSendBufferFull => {
handlers.sctp_send_buffer_full.call_simple();
}
Notification::BufferedAmountLow { buffered_amount } => {
handlers.buffered_amount_low.call(|callback| {
callback(buffered_amount);
});
}
},
Err(error) => {
error!("Failed to parse notification: {}", error);
}
}
})
};
let payload_subscription_handler = {
let handlers = Arc::clone(&handlers);
payload_channel.subscribe_to_notifications(id.into(), move |message, payload| {
match serde_json::from_slice::<PayloadNotification>(message) {
Ok(notification) => match notification {
PayloadNotification::Message { ppid } => {
match WebRtcMessage::new(ppid, Cow::from(payload)) {
Ok(message) => {
handlers.message.call(|callback| {
callback(&message);
});
}
Err(ppid) => {
error!("Bad ppid {}", ppid);
}
}
}
},
Err(error) => {
error!("Failed to parse payload notification: {}", error);
}
}
})
};
let on_transport_close_handler = transport.on_close({
let inner_weak = Arc::clone(&inner_weak);
Box::new(move || {
let maybe_inner = inner_weak.lock().as_ref().and_then(Weak::upgrade);
if let Some(inner) = maybe_inner {
inner.handlers.transport_close.call_simple();
inner.close(false);
}
})
});
let inner = Arc::new(Inner {
id,
r#type,
sctp_stream_parameters,
label,
protocol,
data_producer_id: data_producer.id(),
direct,
executor,
channel,
payload_channel,
handlers,
app_data,
transport,
weak_data_producer: data_producer.downgrade(),
closed,
_subscription_handlers: Mutex::new(vec![
subscription_handler,
payload_subscription_handler,
]),
_on_transport_close_handler: Mutex::new(on_transport_close_handler),
});
inner_weak.lock().replace(Arc::downgrade(&inner));
if direct {
Self::Direct(DirectDataConsumer { inner })
} else {
Self::Regular(RegularDataConsumer { inner })
}
}
#[must_use]
pub fn id(&self) -> DataConsumerId {
self.inner().id
}
#[must_use]
pub fn data_producer_id(&self) -> DataProducerId {
self.inner().data_producer_id
}
pub fn transport(&self) -> &Arc<dyn Transport> {
&self.inner().transport
}
#[must_use]
pub fn r#type(&self) -> DataConsumerType {
self.inner().r#type
}
#[must_use]
pub fn sctp_stream_parameters(&self) -> Option<SctpStreamParameters> {
self.inner().sctp_stream_parameters
}
#[must_use]
pub fn label(&self) -> &String {
&self.inner().label
}
#[must_use]
pub fn protocol(&self) -> &String {
&self.inner().protocol
}
#[must_use]
pub fn app_data(&self) -> &AppData {
&self.inner().app_data
}
#[must_use]
pub fn closed(&self) -> bool {
self.inner().closed.load(Ordering::SeqCst)
}
#[doc(hidden)]
pub async fn dump(&self) -> Result<DataConsumerDump, RequestError> {
debug!("dump()");
self.inner()
.channel
.request(self.id(), DataConsumerDumpRequest {})
.await
}
pub async fn get_stats(&self) -> Result<Vec<DataConsumerStat>, RequestError> {
debug!("get_stats()");
self.inner()
.channel
.request(self.id(), DataConsumerGetStatsRequest {})
.await
}
pub async fn get_buffered_amount(&self) -> Result<u32, RequestError> {
debug!("get_buffered_amount()");
let response = self
.inner()
.channel
.request(self.id(), DataConsumerGetBufferedAmountRequest {})
.await?;
Ok(response.buffered_amount)
}
pub async fn set_buffered_amount_low_threshold(
&self,
threshold: u32,
) -> Result<(), RequestError> {
debug!(
"set_buffered_amount_low_threshold() [threshold:{}]",
threshold
);
self.inner()
.channel
.request(
self.id(),
DataConsumerSetBufferedAmountLowThresholdRequest { threshold },
)
.await
}
pub fn on_message<F: Fn(&WebRtcMessage<'_>) + Send + Sync + 'static>(
&self,
callback: F,
) -> HandlerId {
self.inner().handlers.message.add(Arc::new(callback))
}
pub fn on_sctp_send_buffer_full<F: Fn() + Send + Sync + 'static>(
&self,
callback: F,
) -> HandlerId {
self.inner()
.handlers
.sctp_send_buffer_full
.add(Arc::new(callback))
}
pub fn on_buffered_amount_low<F: Fn(u32) + Send + Sync + 'static>(
&self,
callback: F,
) -> HandlerId {
self.inner()
.handlers
.buffered_amount_low
.add(Arc::new(callback))
}
pub fn on_data_producer_close<F: FnOnce() + Send + 'static>(&self, callback: F) -> HandlerId {
self.inner()
.handlers
.data_producer_close
.add(Box::new(callback))
}
pub fn on_transport_close<F: FnOnce() + Send + 'static>(&self, callback: F) -> HandlerId {
self.inner()
.handlers
.transport_close
.add(Box::new(callback))
}
pub fn on_close<F: FnOnce() + Send + 'static>(&self, callback: F) -> HandlerId {
let handler_id = self.inner().handlers.close.add(Box::new(callback));
if self.inner().closed.load(Ordering::Relaxed) {
self.inner().handlers.close.call_simple();
}
handler_id
}
#[must_use]
pub fn downgrade(&self) -> WeakDataConsumer {
WeakDataConsumer {
inner: Arc::downgrade(self.inner()),
}
}
fn inner(&self) -> &Arc<Inner> {
match self {
DataConsumer::Regular(data_consumer) => &data_consumer.inner,
DataConsumer::Direct(data_consumer) => &data_consumer.inner,
}
}
}
impl DirectDataConsumer {
    /// Sends `message` through the payload channel.
    ///
    /// The message is split into its PPID, which goes into a `DataConsumerSendRequest`, and its
    /// payload, which is passed along as owned data.
    ///
    /// # Errors
    ///
    /// Returns the `RequestError` produced by the payload channel request.
    pub async fn send(&self, message: WebRtcMessage<'_>) -> Result<(), RequestError> {
        let inner = &self.inner;
        let (ppid, payload) = message.into_ppid_and_payload();
        let request = DataConsumerSendRequest { ppid };

        inner
            .payload_channel
            .request(inner.id, request, payload.into_owned())
            .await
    }
}
/// Weak handle to a data consumer, obtained from `DataConsumer::downgrade()`.
///
/// Holds only a `Weak` reference to the shared state; use `upgrade()` to get a `DataConsumer`
/// back while the state is still alive.
#[derive(Clone)]
pub struct WeakDataConsumer {
// Weak reference to the shared state of the data consumer.
inner: Weak<Inner>,
}
impl fmt::Debug for WeakDataConsumer {
    /// Formats as the bare type name; no fields are shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("WeakDataConsumer");
        out.finish()
    }
}
impl WeakDataConsumer {
    /// Attempts to turn the weak handle back into a `DataConsumer`.
    ///
    /// Returns `None` once the shared state has been dropped; otherwise the variant is picked
    /// from the state's `direct` flag.
    #[must_use]
    pub fn upgrade(&self) -> Option<DataConsumer> {
        self.inner.upgrade().map(|inner| {
            if inner.direct {
                DataConsumer::Direct(DirectDataConsumer { inner })
            } else {
                DataConsumer::Regular(RegularDataConsumer { inner })
            }
        })
    }
}