#![forbid(unsafe_code)]
#![warn(
missing_docs,
missing_debug_implementations,
missing_copy_implementations,
trivial_casts,
trivial_numeric_casts,
unreachable_pub,
unsafe_code,
unstable_features,
unused_import_braces,
unused_qualifications,
rust_2018_idioms
)]
mod channel;
mod error;
mod stream_dropper;
use crate::stream_dropper::*;
use channel::Channel;
pub use error::MultiplexerError;
use async_mutex::Mutex;
use dashmap::DashMap;
use futures_util::future::FutureExt;
use futures_util::sink::{Sink, SinkExt};
use futures_util::stream::{FuturesUnordered, StreamExt, TryStream};
use parking_lot::RwLock;
use sharded_slab::Slab;
use std::sync::Arc;
/// Identifier of a stream/sink pair registered with a [`Multiplexer`].
///
/// It is the index handed back by the sink slab when the pair is added, and it
/// also keys the per-stream controls.
pub type StreamId = usize;
/// Identifier of a channel. Chosen by the caller when the channel is added and
/// used as the key under which a [`Multiplexer`] stores the channel.
pub type ChannelId = usize;
/// Message carried on the multiplexer's channel-change queue: a stream together
/// with the id of the channel it should be added to next.
struct ChannelChange<St> {
/// Channel the stream is to be added to.
pub(crate) next_channel_id: ChannelId,
/// Id of the stream being moved.
pub(crate) stream_id: StreamId,
/// The stream itself; it is wrapped again before being handed to the channel.
pub(crate) stream: St,
}
/// Sending half of the queue through which wrapped streams (`StreamDropper`s)
/// are handed to a channel. One is stored next to every channel.
type AddStreamToChannelTx<St> = async_channel::Sender<StreamDropper<St>>;
/// Routes items from incoming streams (`St`) into channels and sends outgoing
/// data through the sink (`Si`) paired with each stream.
///
/// All shared state sits behind `Arc`s, so clones of a `Multiplexer` operate on
/// the same streams, sinks and channels.
#[derive(Debug)]
pub struct Multiplexer<St, Si>
where
St: 'static,
{
/// Per-stream control handles, keyed by stream id; used to request that a
/// stream be dropped or moved to another channel.
stream_controls: Arc<DashMap<StreamId, StreamDropControl>>,
/// Outgoing sinks. The slab index of a sink is the `StreamId` of its pair.
sinks: Arc<RwLock<Slab<Arc<Mutex<Si>>>>>,
/// Registered channels: for each id, the sender used to add streams to the
/// channel and the channel itself behind an async mutex.
channels: Arc<DashMap<ChannelId, (AddStreamToChannelTx<St>, Mutex<Channel<St>>)>>,
/// Receiving end of the channel-change queue; drained inside `recv`.
channel_change_rx: async_channel::Receiver<ChannelChange<St>>,
/// Sending end of the channel-change queue; a clone is passed along every
/// time a stream is wrapped.
channel_change_tx: async_channel::Sender<ChannelChange<St>>,
}
/// Cloning produces another handle onto the same multiplexer state.
///
/// Written by hand rather than derived: `#[derive(Clone)]` would add
/// `St: Clone` and `Si: Clone` bounds, while only the `Arc`s and the queue
/// endpoints need to be cloned.
impl<St, Si> Clone for Multiplexer<St, Si> {
    fn clone(&self) -> Self {
        Self {
            stream_controls: Arc::clone(&self.stream_controls),
            sinks: Arc::clone(&self.sinks),
            channels: Arc::clone(&self.channels),
            channel_change_rx: self.channel_change_rx.clone(),
            channel_change_tx: self.channel_change_tx.clone(),
        }
    }
}
impl<St, Si> Multiplexer<St, Si>
where
St: TryStream + Unpin,
St::Ok: Clone,
Si: Sink<St::Ok> + Unpin,
Si::Error: std::fmt::Debug,
{
/// Creates an empty multiplexer: no channels, no streams, and a fresh
/// unbounded queue for channel-change messages.
pub fn new() -> Self {
    let (change_tx, change_rx) = async_channel::unbounded();
    let stream_controls = Arc::new(DashMap::new());
    let sinks = Arc::new(RwLock::new(Slab::new()));
    let channels = Arc::new(DashMap::new());
    Self {
        stream_controls,
        sinks,
        channels,
        channel_change_rx: change_rx,
        channel_change_tx: change_tx,
    }
}
/// Registers a new, empty channel under `channel_id`.
///
/// # Errors
///
/// Returns [`MultiplexerError::DuplicateChannel`] if a channel with this id is
/// already registered.
///
/// The existence check and the insertion are two separate map operations.
pub fn add_channel(
    &mut self,
    channel_id: ChannelId,
) -> Result<(), MultiplexerError<Si::Error>> {
    let already_registered = self.channels.contains_key(&channel_id);
    if already_registered {
        return Err(MultiplexerError::DuplicateChannel(channel_id));
    }

    // Keep the stream-adding sender next to the channel it feeds.
    let (stream_adder, new_channel) = Channel::new(channel_id);
    let entry = (stream_adder, Mutex::new(new_channel));
    self.channels.insert(channel_id, entry);
    Ok(())
}
/// Returns `true` if a channel is currently registered under `channel`.
#[inline]
pub fn has_channel(&self, channel: ChannelId) -> bool {
    let is_registered = self.channels.contains_key(&channel);
    is_registered
}
/// Removes the channel registered under `channel` and returns what the channel
/// map reports for the removal (presumably `true` when an entry was removed;
/// confirm against the dashmap release in use).
///
/// NOTE(review): this relies on `DashMap::remove` returning `bool`. On dashmap
/// releases where `remove` returns `Option<(K, V)>` the call needs
/// `.is_some()` appended.
pub fn remove_channel(&mut self, channel: ChannelId) -> bool {
    let removed = self.channels.remove(&channel);
    removed
}
pub async fn send(
&self,
stream_ids: impl IntoIterator<Item = StreamId>,
data: St::Ok,
) -> Vec<Result<StreamId, MultiplexerError<Si::Error>>> {
let futures: FuturesUnordered<_> = stream_ids
.into_iter()
.map(|stream_id| {
let data = data.clone();
async move {
let sink = {
let slab_read_guard = self.sinks.read(); let sink = slab_read_guard.get(stream_id); sink.map(|s| Arc::clone(&s))
};
match sink {
Some(sink) => {
sink.lock()
.await
.send(data)
.await
.map(|()| stream_id)
.map_err(|err| MultiplexerError::SendError(stream_id, err))
}
None => {
Err(MultiplexerError::UnknownStream(stream_id))
}
}
}
})
.collect();
futures.collect().await
}
/// Registers `sink` and `stream` as a pair and adds the stream to
/// `channel_id`. Returns the id assigned to the pair.
///
/// # Errors
///
/// * [`MultiplexerError::UnknownChannel`] if `channel_id` is not registered.
///   Nothing is stored in that case.
/// * [`MultiplexerError::ChannelFull`] if the sink slab does not hand out an
///   index for the sink.
/// * [`MultiplexerError::ChannelAdd`] if the wrapped stream cannot be queued on
///   the channel. The sink and the stream control stay registered under the id
///   carried in the error; call `remove_stream` with it to release them.
pub fn add_stream_pair(
    &mut self,
    sink: Si,
    stream: St,
    channel_id: ChannelId,
) -> Result<StreamId, MultiplexerError<Si::Error>> {
    // Look the channel up first so an unknown channel leaves no trace.
    let channel = match self.channels.get(&channel_id) {
        Some(found) => found,
        None => return Err(MultiplexerError::UnknownChannel(channel_id)),
    };

    // The slab index of the sink doubles as the id of the pair.
    let shared_sink = Arc::new(Mutex::new(sink));
    let slot = self.sinks.write().insert(shared_sink);
    let stream_id = match slot {
        Some(id) => id,
        None => return Err(MultiplexerError::ChannelFull(channel_id)),
    };

    let (control, dropper) =
        StreamDropControl::wrap(stream_id, stream, self.channel_change_tx.clone());
    self.stream_controls.insert(stream_id, control);

    if channel.0.try_send(dropper).is_err() {
        return Err(MultiplexerError::ChannelAdd(stream_id, channel_id));
    }
    Ok(stream_id)
}
/// Asks the stream `stream_id` to move to the channel `channel_id`.
///
/// This only records the request on the stream's control. The stream is
/// actually added to the new channel when its channel-change message is
/// processed inside `recv`.
///
/// # Errors
///
/// * [`MultiplexerError::UnknownChannel`] if `channel_id` is not registered
///   (checked first).
/// * [`MultiplexerError::UnknownStream`] if there is no control for
///   `stream_id`.
pub fn change_stream_channel(
    &self,
    stream_id: StreamId,
    channel_id: ChannelId,
) -> Result<(), MultiplexerError<Si::Error>> {
    if !self.channels.contains_key(&channel_id) {
        return Err(MultiplexerError::UnknownChannel(channel_id));
    }

    match self.stream_controls.get(&stream_id) {
        Some(control) => {
            control.change_channel(channel_id);
            Ok(())
        }
        None => Err(MultiplexerError::UnknownStream(stream_id)),
    }
}
/// Removes the sink registered under `stream_id` and asks the matching stream
/// to be dropped.
///
/// Returns `true` if a sink was registered under `stream_id`, `false`
/// otherwise. The drop request is issued whenever a control exists for the id,
/// independently of whether a sink was found.
pub fn remove_stream(&mut self, stream_id: StreamId) -> bool {
    log::trace!("Attempting to remove stream {}", stream_id);
    let res = self.sinks.write().take(stream_id).is_some();

    // Only claim the stream exists when the slab actually held a sink for it.
    if res {
        log::trace!("Stream {} exists", stream_id);
    } else {
        log::trace!("Stream {} has no registered sink", stream_id);
    }

    // NOTE(review): the control entry is left in `stream_controls` after the
    // drop request, so `change_stream_channel` still finds it for a removed
    // stream. Consider removing it once the interaction with
    // `process_add_channel` (which re-inserts controls) has been checked.
    if let Some(control) = self.stream_controls.get(&stream_id) {
        log::trace!("Stream {} is getting dropped()", stream_id);
        control.drop_stream();
    }
    res
}
/// Waits for the next item on the channel `channel_id`.
///
/// The channel's async mutex is held for the whole call. While waiting, the
/// channel-change queue is drained as well: every message received there is
/// passed to `process_add_channel`, after which waiting resumes.
///
/// On success returns what the channel yields: the id of the originating
/// stream and an `Option` of the stream's item (`None` presumably signals that
/// the stream ended; confirm against `Channel::next`).
///
/// # Errors
///
/// * [`MultiplexerError::UnknownChannel`] if `channel_id` is not registered.
/// * Any error returned by `process_add_channel` while handling a
///   channel-change message.
pub async fn recv(
    &mut self,
    channel_id: ChannelId,
) -> Result<(StreamId, Option<St::Item>), MultiplexerError<Si::Error>> {
    log::debug!("recv({})", channel_id);

    let channel_guard = match self.channels.get(&channel_id) {
        Some(guard) => guard,
        None => return Err(MultiplexerError::UnknownChannel(channel_id)),
    };

    log::debug!("recv({}) before loop {{}}", channel_id);
    let mut channel = channel_guard.value().1.lock().await;

    loop {
        log::debug!("recv({}) awaiting the select", channel_id);
        futures_util::select! {
            // An item (or end-of-stream marker) arrived on this channel.
            received = channel.next().fuse() => {
                return Ok(received);
            }
            // A stream asked to change channel; re-home it and keep waiting.
            change = self.channel_change_rx.next().fuse() => {
                log::debug!("recv({}) channel_change has message.", channel_id);
                if let Some(change) = change {
                    self.process_add_channel(change)?;
                }
            }
        }
    }
}
/// Handles one channel-change message: wraps the stream again and queues it on
/// the channel named in the message.
///
/// Takes `&self` because every operation used here (map lookups and inserts,
/// the slab read lock, cloning the queue sender) works through a shared
/// reference. This lets `recv` call it while it still holds a guard obtained
/// from `self.channels`.
///
/// If no sink is registered under the stream's id any more, the stream is
/// dropped and `Ok(())` is returned.
///
/// # Errors
///
/// Returns [`MultiplexerError::ChannelAdd`] if the target channel is not
/// registered or if the wrapped stream cannot be queued on it. The stream is
/// dropped in both cases.
fn process_add_channel(
    &self,
    channel_change: ChannelChange<St>,
) -> Result<(), MultiplexerError<Si::Error>> {
    let ChannelChange {
        next_channel_id,
        stream_id,
        stream,
    } = channel_change;

    let channel = self
        .channels
        .get(&next_channel_id)
        .ok_or_else(|| MultiplexerError::ChannelAdd(stream_id, next_channel_id))?;

    // No sink under this id means the pair was removed; let the stream drop.
    if !self.sinks.read().contains(stream_id) {
        return Ok(());
    }

    // A fresh control replaces the one stored for the previous wrapping.
    let (stream_control, stream_dropper) =
        StreamDropControl::wrap(stream_id, stream, self.channel_change_tx.clone());
    self.stream_controls.insert(stream_id, stream_control);

    if channel.0.try_send(stream_dropper).is_err() {
        return Err(MultiplexerError::ChannelAdd(stream_id, next_channel_id));
    }
    Ok(())
}
}