hakuban 0.8.5

Data-object sharing library
Documentation
use std::{
	collections::HashMap,
	num::NonZeroU64,
	pin::{pin, Pin},
	sync::Mutex,
};

use bincode::{Decode, Encode};
use futures::{
	channel::mpsc::{self},
	Sink, SinkExt, Stream, StreamExt,
};

use super::{
	message::{Message, MessageSink, MessageStream},
	termination::{CoTerminatingSet, ConnectionTerminationReason},
};
use crate::utils::{BoolUtils, Generator};

/// Identifier of one logical channel inside a multiplexed connection.
///
/// The variant records which side allocated the id, so the two sides can number
/// their own channels independently without colliding.
// Variant order is part of the bincode encoding (variant index) — do not reorder.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
#[derive(Encode, Decode)]
pub enum ChannelId {
	/// Allocated by `multiplex` when called with `upstream == false`.
	UpstreamOriginated(NonZeroU64),
	/// Allocated by `multiplex` when called with `upstream == true`.
	DownstreamOriginated(NonZeroU64),
}

/// A single unit of traffic exchanged over a multiplexed connection.
// Variant order is part of the bincode encoding (variant index) — do not reorder.
#[derive(Debug)]
#[derive(Encode, Decode)]
pub enum MultiplexedChannelEvent {
	/// The sender opened a new channel under this id.
	Open(ChannelId),
	/// One message carried on an already-opened channel.
	Message(ChannelId, Message),
	/// The sender's side of the channel ended; it sends no further `Message`s on it.
	Close(ChannelId),
}

//TODO: bring back the TryStream version, it was much prettier
/// Multiplexes many logical message channels over one pair of
/// [`MultiplexedChannelEvent`] sink/stream.
///
/// Returns `(sink, stream)`:
/// * the sink accepts events received from the peer; they are dispatched to the
///   per-channel sinks,
/// * the stream yields events that must be delivered to the peer.
///
/// Channel lifecycle:
/// * Every `(sink, stream)` pair yielded by `outgoing_channel_stream` becomes a
///   locally opened channel with a fresh id. An `Open` event is emitted, then one
///   `Message` per item of its stream, then `Close` once that stream ends.
///   Messages the peer sends on that id are forwarded into its sink.
/// * Every `Open` received from the peer pulls one `(sink, stream)` pair from
///   `incoming_channel_iterator`. Peer messages for that id go into the sink, and
///   the stream's items are sent back as `Message` events. A peer `Close` drops
///   the sink.
///
/// `upstream` selects which [`ChannelId`] variant this side allocates for its own
/// channels (`DownstreamOriginated` when `true`, `UpstreamOriginated` otherwise).
/// It also selects the opposite variant as the only one accepted in peer
/// `Open`/`Close` events.
///
/// Failures are reported through `termination` as `SeriousError`:
/// * protocol violations (wrong id direction, duplicate `Open`, `Close`/`Message`
///   for an unknown channel),
/// * exhaustion of `incoming_channel_iterator`,
/// * internal channel failures.
///
/// The internal work stops as soon as `termination.reason()` resolves or any of
/// the internal loops finishes.
///
/// NOTE(review): the work is attached to the returned stream via `with_generator`
/// (see `crate::utils::Generator`). Presumably it only progresses while that
/// stream is polled — confirm.
///
/// NOTE(review): sinks registered for locally opened channels are never removed
/// from the channel table. No event tells this side that the peer has finished
/// with them, so the table grows with the number of outgoing channels.
pub(crate) fn multiplex(
	upstream: bool,
	mut incoming_channel_iterator: impl Iterator<Item = (MessageSink, MessageStream)> + Send + Sync + 'static,
	outgoing_channel_stream: impl Stream<Item = (MessageSink, MessageStream)> + Send + Sync + 'static,
	termination: CoTerminatingSet,
) -> (impl Sink<MultiplexedChannelEvent, Error = mpsc::SendError> + Send + Sync, impl Stream<Item = MultiplexedChannelEvent> + Send + Sync) {
	// Events received from the peer (fed by the caller through the returned sink).
	let (input_message_sender, mut input_message_receiver) = mpsc::channel::<MultiplexedChannelEvent>(8);
	// Events destined for the peer (read by the caller through the returned stream).
	let (mut output_message_sender, output_message_receiver) = mpsc::channel::<MultiplexedChannelEvent>(8);
	(
		input_message_sender,
		output_message_receiver.with_generator(async move {
			// Per-channel sinks keyed by channel id. A value is `None` only while
			// `process_input_messages` has taken the sink out to await a send on it.
			let sinks = Mutex::new(HashMap::new());

			// Each channel's outgoing event stream is pushed here. They are merged,
			// in no particular order, into the single output channel.
			let (output_message_stream_sender, output_message_stream_receiver) =
				mpsc::unbounded::<Pin<Box<dyn Stream<Item = MultiplexedChannelEvent> + Send + Sync>>>();

			let process_output_messages = {
				let termination = termination.clone();
				async move {
					let mut outgoing_events_stream = output_message_stream_receiver.flatten_unordered(None);
					while let Some(event) = outgoing_events_stream.next().await {
						if let Err(error) = output_message_sender.send(event).await {
							termination.terminate(ConnectionTerminationReason::SeriousError(format!("Can't send outgoing message: {:?}", error)));
							return;
						}
					}
					termination.terminate(ConnectionTerminationReason::SeriousError("output message stream sender dropped".to_string()));
				}
			};

			let process_input_messages = {
				let mut output_message_stream_sender_for_input = output_message_stream_sender.clone();
				let sinks = &sinks;
				let termination = termination.clone();
				async move {
					while let Some(event) = input_message_receiver.next().await {
						match event {
							MultiplexedChannelEvent::Open(channel_id) => {
								if upstream != matches!(channel_id, ChannelId::UpstreamOriginated(_)) {
									termination.terminate(ConnectionTerminationReason::SeriousError(
										"Received channel Open action with wrong channel id direction".to_string(),
									));
									return;
								};
								// The Open comes from the peer; running out of local channel
								// endpoints must end the connection cleanly, not panic the task.
								let Some((new_sink, new_stream)) = incoming_channel_iterator.next() else {
									termination.terminate(ConnectionTerminationReason::SeriousError(
										"Incoming channel iterator exhausted".to_string(),
									));
									return;
								};
								if sinks.lock().unwrap().insert(channel_id, Some(new_sink)).is_some() {
									termination.terminate(ConnectionTerminationReason::SeriousError("Double channel register".to_string()));
									return;
								}
								if let Err(error) = output_message_stream_sender_for_input
									.send(Box::pin(new_stream.map(move |message| MultiplexedChannelEvent::Message(channel_id, message))))
									.await
								{
									termination.terminate(ConnectionTerminationReason::SeriousError(format!("Can't send incoming message: {:?}", error)));
									return;
								}
							}
							MultiplexedChannelEvent::Close(channel_id) => {
								if upstream != matches!(channel_id, ChannelId::UpstreamOriginated(_)) {
									termination.terminate(ConnectionTerminationReason::SeriousError(
										"Received channel Close action with wrong channel id direction".to_string(),
									));
									return;
								};
								if sinks.lock().unwrap().remove(&channel_id).is_none() {
									termination.terminate(ConnectionTerminationReason::SeriousError("Unregister on unknown channel".to_string()));
									return;
								}
							}
							MultiplexedChannelEvent::Message(channel_id, message) => {
								// Take the sink out so the lock is released before `send().await`.
								let taken = { sinks.lock().unwrap().get_mut(&channel_id).map(Option::take) };
								let Some(taken) = taken else {
									termination.terminate(ConnectionTerminationReason::SeriousError("Received message for unknown channel".to_string()));
									return;
								};
								// Events are handled strictly one at a time here, so the previous
								// Message on this channel has already put the sink back.
								let mut sink = taken.expect("channel sink is put back after every send");
								//Err here means the channel-servicing lambda has already exited (channel got closed from this side), so we just ignore this late message
								sink.send(message).await.ok();
								// Only this loop removes entries (on Close), so the entry still exists.
								*sinks.lock().unwrap().get_mut(&channel_id).unwrap() = Some(sink);
							}
						};
					}
					termination.terminate(ConnectionTerminationReason::SeriousError("Input message sender dropped".to_string()));
				}
			};

			let mut outgoing_channel_stream = pin!(outgoing_channel_stream);
			let process_outgoing_channels = async {
				let mut channel_id_sequence = 0u64;
				let mut output_message_stream_sender_for_output = output_message_stream_sender.clone();
				let termination = termination.clone();
				while let Some((new_sink, new_stream)) = outgoing_channel_stream.next().await {
					channel_id_sequence += 1;
					let sequence =
						NonZeroU64::new(channel_id_sequence).expect("channel id sequence is incremented before use, so it is never zero");
					let channel_id =
						if upstream { ChannelId::DownstreamOriginated(sequence) } else { ChannelId::UpstreamOriginated(sequence) };
					// Register the sink before announcing the channel, so that peer
					// messages arriving right after the Open already find it.
					sinks.lock().unwrap().insert(channel_id, Some(new_sink)).is_none().assert_true();
					if let Err(error) = output_message_stream_sender_for_output
						.send(Box::pin(
							Box::pin(futures::stream::once(async move { MultiplexedChannelEvent::Open(channel_id) }))
								.chain(new_stream.map(move |message| MultiplexedChannelEvent::Message(channel_id, message)))
								.chain(futures::stream::once(async move { MultiplexedChannelEvent::Close(channel_id) })),
						))
						.await
					{
						termination.terminate(ConnectionTerminationReason::SeriousError(format!("Can't send outgoing message stream: {:?}", error)));
						return;
					};
				}
				termination.terminate(ConnectionTerminationReason::SeriousError("outgoing channel stream ended".to_string()));
			};

			// Run until any loop finishes or termination is signalled.
			futures::future::select(
				futures::future::select(pin!(process_input_messages), pin!(process_output_messages)),
				futures::future::select(pin!(process_outgoing_channels), pin!(termination.reason())),
			)
			.await;
		}),
	)
}