use std::{
collections::HashMap,
num::NonZeroU64,
pin::{pin, Pin},
sync::Mutex,
};
use bincode::{Decode, Encode};
use futures::{
channel::mpsc::{self},
Sink, SinkExt, Stream, StreamExt,
};
use super::{
message::{Message, MessageSink, MessageStream},
termination::{CoTerminatingSet, ConnectionTerminationReason},
};
use crate::utils::{BoolUtils, Generator};
/// Identifier of a multiplexed channel, tagged with the side that opened it.
///
/// Each side numbers the channels it opens from its own counter starting at 1, so the wrapped
/// number alone is not unique — the variant says whose counter produced it. In `multiplex`, a
/// side called with `upstream == false` tags the channels it opens as `UpstreamOriginated`,
/// while a side called with `upstream == true` tags them `DownstreamOriginated`.
///
/// NOTE(review): this type derives bincode `Encode`/`Decode`; the derived enum encoding
/// presumably depends on variant order, so do not reorder the variants — confirm.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
#[derive(Encode, Decode)]
pub enum ChannelId {
    /// Opened by the side running `multiplex` with `upstream == false`.
    UpstreamOriginated(NonZeroU64),
    /// Opened by the side running `multiplex` with `upstream == true`.
    DownstreamOriginated(NonZeroU64),
}
/// One unit of traffic exchanged over the multiplexed connection.
///
/// The side that opens a channel sends `Open`, then its messages as `Message`s, and finally
/// `Close` once its outgoing stream ends. `Message` may also flow in the opposite direction,
/// carrying replies on the same channel.
///
/// NOTE(review): bincode-encoded; keep the variant order stable — confirm wire-format policy.
#[derive(Debug)]
#[derive(Encode, Decode)]
pub enum MultiplexedChannelEvent {
    /// Announces a new channel opened by the sender.
    Open(ChannelId),
    /// Carries one message for the given, already-open channel (either direction).
    Message(ChannelId, Message),
    /// Announces that the sender's outgoing stream for a channel it opened has ended.
    Close(ChannelId),
}
pub(crate) fn multiplex(
upstream: bool,
mut incoming_channel_iterator: impl Iterator<Item = (MessageSink, MessageStream)> + Send + Sync + 'static,
outgoing_channel_stream: impl Stream<Item = (MessageSink, MessageStream)> + Send + Sync + 'static,
termination: CoTerminatingSet,
) -> (impl Sink<MultiplexedChannelEvent, Error = mpsc::SendError> + Send + Sync, impl Stream<Item = MultiplexedChannelEvent> + Send + Sync) {
let (input_message_sender, mut input_message_receiver) = mpsc::channel::<MultiplexedChannelEvent>(8);
let (mut output_message_sender, output_message_receiver) = mpsc::channel::<MultiplexedChannelEvent>(8);
(
input_message_sender, output_message_receiver.with_generator(async move {
let sinks = Mutex::new(HashMap::new());
let (output_message_stream_sender, output_message_stream_receiver) =
mpsc::unbounded::<Pin<Box<dyn Stream<Item = MultiplexedChannelEvent> + Send + Sync>>>();
let process_output_messages = {
let termination = termination.clone();
async move {
let mut outgoing_events_stream = output_message_stream_receiver.flatten_unordered(None);
while let Some(event) = outgoing_events_stream.next().await {
if let Err(error) = output_message_sender.send(event).await {
termination.terminate(ConnectionTerminationReason::SeriousError(format!("Can't send outgoing message: {:?}", error)));
return;
}
}
termination.terminate(ConnectionTerminationReason::SeriousError("output message stream sender dropped".to_string()));
}
};
let process_input_messages = {
let mut output_message_stream_sender_for_input = output_message_stream_sender.clone();
let sinks = &sinks;
let termination = termination.clone();
async move {
while let Some(event) = input_message_receiver.next().await {
match event {
MultiplexedChannelEvent::Open(channel_id) => {
if upstream != matches!(channel_id, ChannelId::UpstreamOriginated(_)) {
termination.terminate(ConnectionTerminationReason::SeriousError(
"Received channel Open action with wrong channel id direction".to_string(),
));
return;
};
let (new_sink, new_stream) = incoming_channel_iterator.next().unwrap();
if sinks.lock().unwrap().insert(channel_id, Some(new_sink)).is_some() {
termination.terminate(ConnectionTerminationReason::SeriousError("Double channel register".to_string()));
return;
}
if let Err(error) = output_message_stream_sender_for_input
.send(Box::pin(new_stream.map(move |message| MultiplexedChannelEvent::Message(channel_id, message))))
.await
{
termination.terminate(ConnectionTerminationReason::SeriousError(format!("Can't send incoming message: {:?}", error)));
return;
}
}
MultiplexedChannelEvent::Close(channel_id) => {
if upstream != matches!(channel_id, ChannelId::UpstreamOriginated(_)) {
termination.terminate(ConnectionTerminationReason::SeriousError(
"Received channel Close action with wrong channel id direction".to_string(),
));
return;
};
if sinks.lock().unwrap().remove(&channel_id).is_none() {
termination.terminate(ConnectionTerminationReason::SeriousError("Unregister on unknown channel".to_string()));
return;
}
}
MultiplexedChannelEvent::Message(channel_id, message) => {
let sink = { sinks.lock().expect("expect#1").get_mut(&channel_id).map(|opt| opt.take()) };
if let Some(sink) = sink {
let mut sink = sink.expect("expect#2");
sink.send(message).await.ok();
*sinks.lock().unwrap().get_mut(&channel_id).unwrap() = Some(sink);
} else {
termination.terminate(ConnectionTerminationReason::SeriousError("Received message for unknown channel".to_string()));
return;
};
}
};
}
termination.terminate(ConnectionTerminationReason::SeriousError("Input message sender dropped".to_string()));
}
};
let mut outgoing_channel_stream = std::pin::pin!(outgoing_channel_stream);
let process_outgoing_channels = async {
let mut channel_id_sequence = 0u64;
let mut output_message_stream_sender_for_output = output_message_stream_sender.clone();
let termination = termination.clone();
while let Some((new_sink, new_stream)) = outgoing_channel_stream.next().await {
channel_id_sequence += 1;
let channel_id = if upstream {
ChannelId::DownstreamOriginated(channel_id_sequence.try_into().unwrap())
} else {
ChannelId::UpstreamOriginated(channel_id_sequence.try_into().unwrap())
};
sinks.lock().unwrap().insert(channel_id, Some(Box::pin(new_sink))).is_none().assert_true();
if let Err(error) = output_message_stream_sender_for_output
.send(Box::pin(
Box::pin(futures::stream::once(async move { MultiplexedChannelEvent::Open(channel_id) }))
.chain(new_stream.map(move |message| MultiplexedChannelEvent::Message(channel_id, message)))
.chain(futures::stream::once(async move { MultiplexedChannelEvent::Close(channel_id) })),
))
.await
{
termination.terminate(ConnectionTerminationReason::SeriousError(format!("Can't send outgoing message stream: {:?}", error)));
return;
};
}
termination.terminate(ConnectionTerminationReason::SeriousError("outgoing channel stream ended".to_string()));
};
futures::future::select(
futures::future::select(pin!(process_input_messages), pin!(process_output_messages)),
futures::future::select(pin!(process_outgoing_channels), pin!(termination.reason())),
)
.await;
}),
)
}