use crate::messages::Request;
use crate::worker::common::{EventHandlers, SubscriptionTarget, WeakEventHandlers};
use crate::worker::utils;
use crate::worker::utils::{PreparedChannelRead, PreparedChannelWrite};
use crate::worker::{RequestError, SubscriptionHandler};
use atomic_take::AtomicTake;
use hash_hasher::HashedMap;
use log::{debug, error, trace, warn};
use lru::LruCache;
use mediasoup_sys::UvAsyncT;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::VecDeque;
use std::fmt::Debug;
use std::sync::{Arc, Weak};
/// Out-of-band message from the worker that is neither a notification nor a request response.
///
/// Values are built by hand in `deserialize_message` (from the first byte of the raw message)
/// and forwarded to the receiver returned by `Channel::get_internal_message_receiver`.
///
/// Every variant is `#[serde(skip)]`, so serde never constructs this type; the `Deserialize`
/// derive presumably exists only so that `ChannelReceiveMessage`, which embeds it, can derive
/// `Deserialize` as well.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub(super) enum InternalMessage {
/// Text of a message prefixed with `D` (prefix removed).
#[serde(skip)]
Debug(String),
/// Text of a message prefixed with `W` (prefix removed).
#[serde(skip)]
Warn(String),
/// Text of a message prefixed with `E` (prefix removed).
#[serde(skip)]
Error(String),
/// Text of a message prefixed with `X` (prefix removed).
#[serde(skip)]
Dump(String),
/// Raw bytes of a message that was not recognized.
#[serde(skip)]
Unexpected(Vec<u8>),
}
pub(crate) struct BufferMessagesGuard {
target_id: SubscriptionTarget,
buffered_notifications_for: Arc<Mutex<HashedMap<SubscriptionTarget, Vec<Vec<u8>>>>>,
event_handlers_weak: WeakEventHandlers<Arc<dyn Fn(&[u8]) + Send + Sync + 'static>>,
}
impl Drop for BufferMessagesGuard {
fn drop(&mut self) {
let mut buffered_notifications_for = self.buffered_notifications_for.lock();
if let Some(notifications) = buffered_notifications_for.remove(&self.target_id) {
if let Some(event_handlers) = self.event_handlers_weak.upgrade() {
for notification in notifications {
event_handlers.call_callbacks_with_single_value(&self.target_id, ¬ification);
}
}
}
}
}
/// A decoded message received from the worker.
///
/// For JSON messages, `#[serde(untagged)]` tries the variants in declaration order, so the order
/// below is significant: notification first, then success response, then error response.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ChannelReceiveMessage {
/// Notification addressed to `targetId` (camelCase in JSON). Only the target is decoded here;
/// subscribers are handed the full raw message.
#[serde(rename_all = "camelCase")]
Notification {
target_id: SubscriptionTarget,
},
/// Successful response to the request with the same `id`, with an optional `data` payload.
ResponseSuccess {
id: u32,
// Required for this variant to match, but never read.
#[allow(dead_code)]
accepted: bool,
data: Option<Value>,
},
/// Failed response to the request with the same `id`, carrying the worker's `reason`.
ResponseError {
id: u32,
reason: String,
},
/// Log/dump line or unrecognized data; built by `deserialize_message` rather than by serde.
Event(InternalMessage),
}
fn deserialize_message(bytes: &[u8]) -> ChannelReceiveMessage {
match bytes[0] {
b'{' => serde_json::from_slice(bytes).unwrap(),
b'D' => ChannelReceiveMessage::Event(InternalMessage::Debug(
String::from_utf8(Vec::from(&bytes[1..])).unwrap(),
)),
b'W' => ChannelReceiveMessage::Event(InternalMessage::Warn(
String::from_utf8(Vec::from(&bytes[1..])).unwrap(),
)),
b'E' => ChannelReceiveMessage::Event(InternalMessage::Error(
String::from_utf8(Vec::from(&bytes[1..])).unwrap(),
)),
b'X' => ChannelReceiveMessage::Event(InternalMessage::Dump(
String::from_utf8(Vec::from(&bytes[1..])).unwrap(),
)),
_ => ChannelReceiveMessage::Event(InternalMessage::Unexpected(Vec::from(bytes))),
}
}
/// Wire format of an outgoing request: a single JSON object containing `id` and `method`
/// followed by the fields of `request`, which are flattened into the same object.
#[derive(Debug, Serialize)]
struct RequestMessage<'a, R: Serialize> {
/// Id used to match the worker's response to this request.
id: u32,
/// Method name, as returned by `Request::as_method`.
method: &'static str,
#[serde(flatten)]
request: &'a R,
}
/// Error response received from the worker for a request.
struct ResponseError {
/// `reason` field of the worker's error response.
reason: String,
}
/// Outcome of a request as handed to the waiting `Channel::request` call: `Ok` with the
/// (optional) `data` payload of a success response, or `Err` with the worker's reason.
type ResponseResult<T> = Result<Option<T>, ResponseError>;
/// Cleans up after a pending request unless it is explicitly disarmed with
/// [`RequestDropGuard::remove`].
///
/// If dropped while still armed (e.g. the future returned by `Channel::request` is dropped before
/// a response arrives), it withdraws the queued message so it is skipped by the read callback if
/// not sent yet, and unregisters the response handler for `id`.
struct RequestDropGuard<'a> {
    /// Id under which the response handler was registered.
    id: u32,
    /// The queued outgoing message; taking its content withdraws it.
    message: Arc<AtomicTake<Vec<u8>>>,
    /// Channel the request was issued on.
    channel: &'a Channel,
    /// Set once the guard is disarmed; suppresses the cleanup in `drop`.
    removed: bool,
}

impl<'a> RequestDropGuard<'a> {
    /// Disarms the guard: it is consumed without performing any cleanup.
    ///
    /// The flag is set and then `self` is dropped normally (rather than forgotten), so the
    /// `message` reference count is still released.
    fn remove(mut self) {
        self.removed = true;
    }
}

impl Drop for RequestDropGuard<'_> {
    fn drop(&mut self) {
        if !self.removed {
            // Withdraw the message: the read callback skips entries whose content is gone.
            drop(self.message.take());

            let requests_container_weak = &self.channel.inner.requests_container_weak;
            if let Some(container) = requests_container_weak.upgrade() {
                container.lock().handlers.remove(&self.id);
            }
        }
    }
}
/// Bookkeeping for requests that are waiting for a response.
#[derive(Default)]
struct RequestsContainer {
/// Id to assign to the next request; incremented with wrap-around on overflow.
next_id: u32,
/// Response senders keyed by request id. An entry is removed when its response arrives
/// or when the request is abandoned (`RequestDropGuard`).
handlers: HashedMap<u32, async_oneshot::Sender<ResponseResult<Value>>>,
}
/// Queue of serialized requests waiting to be picked up by the channel read callback.
struct OutgoingMessageBuffer {
/// Async handle stored on the first invocation of the read callback; `Channel::request`
/// calls `uv_async_send` on it after queueing a message. `None` until then.
handle: Option<UvAsyncT>,
/// Pending messages in FIFO order. Entries whose content has already been taken (abandoned
/// requests) are skipped by the read callback.
messages: VecDeque<Arc<AtomicTake<Vec<u8>>>>,
}
/// State shared by all clones of a [`Channel`].
struct Inner {
/// Requests waiting to be handed to the read callback.
outgoing_message_buffer: Arc<Mutex<OutgoingMessageBuffer>>,
/// Receiving side for log/dump/unexpected messages forwarded by the write callback.
internal_message_receiver: async_channel::Receiver<InternalMessage>,
/// Weak reference to the pending requests; the strong reference is moved into the write
/// callback, so `Channel::request` fails with `ChannelClosed` once that callback is dropped.
requests_container_weak: Weak<Mutex<RequestsContainer>>,
/// Targets whose notifications are currently buffered (see `BufferMessagesGuard`).
buffered_notifications_for: Arc<Mutex<HashedMap<SubscriptionTarget, Vec<Vec<u8>>>>>,
/// Weak reference to the event handlers; the strong reference is owned by the write callback.
event_handlers_weak: WeakEventHandlers<Arc<dyn Fn(&[u8]) + Send + Sync + 'static>>,
}
impl Drop for Inner {
fn drop(&mut self) {
// Close the internal message channel once the last `Channel` clone is gone. This also
// affects receivers cloned via `get_internal_message_receiver`, and makes the write
// callback's `try_send` fail (that error is ignored there).
self.internal_message_receiver.close();
}
}
/// Cheaply cloneable handle for exchanging messages with the worker; all clones share the same
/// underlying state.
#[derive(Clone)]
pub(crate) struct Channel {
inner: Arc<Inner>,
}
impl Channel {
pub(super) fn new() -> (Self, PreparedChannelRead, PreparedChannelWrite) {
let outgoing_message_buffer = Arc::new(Mutex::new(OutgoingMessageBuffer {
handle: None,
messages: VecDeque::with_capacity(10),
}));
let requests_container = Arc::<Mutex<RequestsContainer>>::default();
let requests_container_weak = Arc::downgrade(&requests_container);
let buffered_notifications_for =
Arc::<Mutex<HashedMap<SubscriptionTarget, Vec<Vec<u8>>>>>::default();
let event_handlers = EventHandlers::new();
let event_handlers_weak = event_handlers.downgrade();
let prepared_channel_read = utils::prepare_channel_read_fn({
let outgoing_message_buffer = Arc::clone(&outgoing_message_buffer);
move |handle| {
let mut outgoing_message_buffer = outgoing_message_buffer.lock();
if outgoing_message_buffer.handle.is_none() {
outgoing_message_buffer.handle.replace(handle);
}
while let Some(maybe_message) = outgoing_message_buffer.messages.pop_front() {
if let Some(message) = maybe_message.take() {
return Some(message);
}
}
None
}
});
let (internal_message_sender, internal_message_receiver) = async_channel::unbounded();
let prepared_channel_write = utils::prepare_channel_write_fn({
let buffered_notifications_for = Arc::clone(&buffered_notifications_for);
let mut non_buffered_notifications = LruCache::<SubscriptionTarget, ()>::new(1000);
move |message| {
trace!("received raw message: {}", String::from_utf8_lossy(message));
match deserialize_message(message) {
ChannelReceiveMessage::Notification { target_id } => {
if !non_buffered_notifications.contains(&target_id) {
let mut buffer_notifications_for = buffered_notifications_for.lock();
if let Some(list) = buffer_notifications_for.get_mut(&target_id) {
list.push(Vec::from(message));
return;
}
non_buffered_notifications.put(target_id, ());
}
event_handlers.call_callbacks_with_single_value(&target_id, message);
}
ChannelReceiveMessage::ResponseSuccess { id, data, .. } => {
let sender = requests_container.lock().handlers.remove(&id);
if let Some(mut sender) = sender {
let _ = sender.send(Ok(data));
} else {
warn!(
"received success response does not match any sent request [id:{}]",
id,
);
}
}
ChannelReceiveMessage::ResponseError { id, reason } => {
let sender = requests_container.lock().handlers.remove(&id);
if let Some(mut sender) = sender {
let _ = sender.send(Err(ResponseError { reason }));
} else {
warn!(
"received error response does not match any sent request [id:{}]",
id,
);
}
}
ChannelReceiveMessage::Event(event_message) => {
let _ = internal_message_sender.try_send(event_message);
}
}
}
});
let inner = Arc::new(Inner {
outgoing_message_buffer,
internal_message_receiver,
requests_container_weak,
buffered_notifications_for,
event_handlers_weak,
});
(
Self { inner },
prepared_channel_read,
prepared_channel_write,
)
}
pub(super) fn get_internal_message_receiver(&self) -> async_channel::Receiver<InternalMessage> {
self.inner.internal_message_receiver.clone()
}
pub(crate) fn buffer_messages_for(&self, target_id: SubscriptionTarget) -> BufferMessagesGuard {
let buffered_notifications_for = Arc::clone(&self.inner.buffered_notifications_for);
let event_handlers_weak = self.inner.event_handlers_weak.clone();
buffered_notifications_for
.lock()
.entry(target_id)
.or_default();
BufferMessagesGuard {
target_id,
buffered_notifications_for,
event_handlers_weak,
}
}
pub(crate) async fn request<R>(&self, request: R) -> Result<R::Response, RequestError>
where
R: Request,
{
let method = request.as_method();
let id;
let (result_sender, result_receiver) = async_oneshot::oneshot();
{
let requests_container_lock = self
.inner
.requests_container_weak
.upgrade()
.ok_or(RequestError::ChannelClosed)?;
let mut requests_container = requests_container_lock.lock();
id = requests_container.next_id;
requests_container.next_id = requests_container.next_id.wrapping_add(1);
requests_container.handlers.insert(id, result_sender);
}
debug!("request() [method:{}, id:{}]: {:?}", method, id, request);
let message = Arc::new(AtomicTake::new(
serde_json::to_vec(&RequestMessage {
id,
method,
request: &request,
})
.unwrap(),
));
{
let mut outgoing_message_buffer = self.inner.outgoing_message_buffer.lock();
outgoing_message_buffer
.messages
.push_back(Arc::clone(&message));
if let Some(handle) = &outgoing_message_buffer.handle {
unsafe {
let ret = mediasoup_sys::uv_async_send(*handle);
if ret != 0 {
error!("uv_async_send call failed with code {}", ret);
return Err(RequestError::ChannelClosed);
}
}
}
}
let request_drop_guard = RequestDropGuard {
id,
message,
channel: self,
removed: false,
};
let response_result_fut = result_receiver.await;
request_drop_guard.remove();
match response_result_fut.map_err(|_| RequestError::ChannelClosed {})? {
Ok(data) => {
debug!("request succeeded [method:{}, id:{}]", method, id);
serde_json::from_value(data.unwrap_or_default()).map_err(|error| {
RequestError::FailedToParse {
error: error.to_string(),
}
})
}
Err(ResponseError { reason }) => {
debug!("request failed [method:{}, id:{}]: {}", method, id, reason);
Err(RequestError::Response { reason })
}
}
}
pub(crate) fn subscribe_to_notifications<F>(
&self,
target_id: SubscriptionTarget,
callback: F,
) -> Option<SubscriptionHandler>
where
F: Fn(&[u8]) + Send + Sync + 'static,
{
self.inner
.event_handlers_weak
.upgrade()
.map(|event_handlers| event_handlers.add(target_id, Arc::new(callback)))
}
}