use std::convert::Infallible;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::Mutex;
use futures::channel::oneshot::channel;
use futures::channel::oneshot::Canceled;
use futures::channel::oneshot::Sender;
use futures::select_biased;
use futures::sink::SinkExt as _;
use futures::stream::FusedStream;
use futures::task::Context;
use futures::task::Poll;
use futures::Future;
use futures::FutureExt as _;
use futures::Sink;
use futures::Stream;
use futures::StreamExt as _;
/// The outcome of classifying a message: it is either a message meant
/// for the user (`U`) or a control message (`C`).
#[derive(Debug)]
pub enum Classification<U, C> {
/// A user message. `MessageStream` yields these to its consumer.
UserMessage(U),
/// A control message. `MessageStream` does not yield these; it hands
/// them to a waiting `Subscription` instead.
ControlMessage(C),
}
/// A message that can be split into user and control messages.
pub trait Message {
/// The type of message that is passed on to the user.
type UserMessage;
/// The type of message that is routed to a `Subscription`.
type ControlMessage;
/// Consume the message and classify it as either a user message or
/// a control message.
fn classify(self) -> Classification<Self::UserMessage, Self::ControlMessage>;
/// Report whether the given user message is to be treated as an
/// error. `MessageStream` signals `Err(())` to a waiting
/// `Subscription` for such messages and `drive` stops with them.
fn is_error(user_message: &Self::UserMessage) -> bool;
}
// The slot shared between a `MessageStream` and its `Subscription`. It
// holds the one-shot sender registered by the subscription, if any. The
// stream sends `Some(Ok(control))` for a control message, `Some(Err(()))`
// for an error user message, and `None` once the stream is exhausted.
type SharedState<M> = Arc<Mutex<Option<Sender<Option<Result<M, ()>>>>>>;
/// A stream wrapper that yields only the user messages of the wrapped
/// stream and forwards control messages to the associated
/// `Subscription`.
#[derive(Debug)]
pub struct MessageStream<S, M>
where
M: Message,
{
/// The wrapped stream producing unclassified messages.
stream: S,
/// The slot shared with the `Subscription`, holding its registered
/// one-shot sender (if any).
shared: SharedState<M::ControlMessage>,
}
impl<S, M> MessageStream<S, M>
where
    M: Message,
{
    /// Hand `message` to the `Subscription`, if it is currently waiting.
    ///
    /// The registered one-shot sender (if any) is removed from `shared`
    /// and used to deliver `message`. A poisoned mutex is tolerated by
    /// recovering its guard.
    fn inform_subscription(
        shared: &SharedState<M::ControlMessage>,
        message: Option<Result<M::ControlMessage, ()>>,
    ) {
        let mut slot = match shared.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let pending = slot.take();
        // Release the lock before delivering the message.
        drop(slot);

        match pending {
            Some(notifier) => {
                // A failure only means that the receiving end is gone
                // already, which is not an error for the stream.
                let _ = notifier.send(message);
            },
            None => (),
        }
    }
}
impl<S, M> Stream for MessageStream<S, M>
where
    S: Stream<Item = M> + Unpin,
    M: Message,
{
    type Item = M::UserMessage;

    /// Poll the wrapped stream until a user message is available.
    ///
    /// Control messages are forwarded to the `Subscription` and never
    /// yielded. A user message that `M::is_error` flags is yielded as
    /// usual, but a waiting `Subscription` is informed with `Err(())`
    /// first. When the wrapped stream ends, a waiting `Subscription`
    /// is informed with `None`.
    fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `stream` and `shared` are disjoint fields, so both can be
        // borrowed through `this` without cloning the `Arc` per poll.
        let this = self.get_mut();

        loop {
            let message = match this.stream.poll_next_unpin(ctx) {
                // Nothing available yet; bubble that up.
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => {
                    // The stream is exhausted; wake up the subscription.
                    Self::inform_subscription(&this.shared, None);
                    return Poll::Ready(None)
                },
                Poll::Ready(Some(message)) => message,
            };

            match message.classify() {
                Classification::UserMessage(user_message) => {
                    if M::is_error(&user_message) {
                        Self::inform_subscription(&this.shared, Some(Err(())));
                    }
                    return Poll::Ready(Some(user_message))
                },
                Classification::ControlMessage(control_message) => {
                    Self::inform_subscription(&this.shared, Some(Ok(control_message)));
                    // Control messages are not yielded; keep polling.
                },
            }
        }
    }

    /// Report the bounds on the number of remaining user messages.
    ///
    /// Any item of the wrapped stream may turn out to be a control
    /// message, which is not yielded. Hence only the upper bound of
    /// the wrapped stream carries over; the lower bound is zero.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let (_lower, upper) = Stream::size_hint(&self.stream);
        (0, upper)
    }
}
impl<S, M> FusedStream for MessageStream<S, M>
where
S: FusedStream<Item = M> + Unpin,
M: Message,
{
#[inline]
fn is_terminated(&self) -> bool {
self.stream.is_terminated()
}
}
/// The counterpart to a `MessageStream`: it sends items over a control
/// channel and receives the control messages that the stream filters
/// out.
#[derive(Debug)]
pub struct Subscription<S, M, I>
where
M: Message,
{
/// The sink over which items are sent.
sink: S,
/// The slot shared with the `MessageStream`, in which the one-shot
/// sender for the next notification is registered.
shared: SharedState<M::ControlMessage>,
/// Marker tying the subscription to the item type `I` accepted by
/// the sink.
_phantom: PhantomData<I>,
}
impl<S, M, I> Subscription<S, M, I>
where
    S: Sink<I> + Unpin,
    M: Message,
{
    /// Register a one-shot channel with the message stream, run `f` on
    /// the sink, and wait for the stream's next notification.
    ///
    /// # Errors
    /// Returns the error produced by the future that `f` created. The
    /// registered sender is removed again in that case.
    async fn with_channel<'slf, F, G, E>(
        &'slf mut self,
        f: F,
    ) -> Result<Option<Result<M::ControlMessage, ()>>, E>
    where
        F: FnOnce(&'slf mut S) -> G,
        G: Future<Output = Result<(), E>>,
    {
        let (sender, receiver) = channel();

        // The slot may still hold the sender of an earlier call whose
        // future got dropped before it completed. Its receiver is gone,
        // so that sender is simply replaced and discarded.
        let stale = self
            .shared
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .replace(sender);
        drop(stale);

        if let Err(err) = f(&mut self.sink).await {
            // Make sure our sender does not linger. The stream may have
            // consumed it concurrently already, in which case there is
            // nothing left to clear.
            let unused = self
                .shared
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner())
                .take();
            drop(unused);
            return Err(err)
        }

        let result: Result<_, Canceled> = receiver.await;
        // The receiver only resolves once our sender was taken out of
        // the slot (and then used or dropped), so the slot is empty.
        debug_assert!(self
            .shared
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .is_none());

        // A dropped sender is reported the same way as an ended stream.
        Ok(result.unwrap_or(None))
    }

    /// Send `item` over the control channel and wait for the next
    /// notification from the message stream.
    ///
    /// The associated `MessageStream` has to be polled for the returned
    /// future to resolve. The result is `Some(Ok(control))` for a
    /// control message, `Some(Err(()))` if the stream produced an error
    /// user message, and `None` if the stream ended or the sender was
    /// dropped.
    ///
    /// # Errors
    /// Returns the sink's error if sending `item` failed.
    pub async fn send(&mut self, item: I) -> Result<Option<Result<M::ControlMessage, ()>>, S::Error> {
        self
            .with_channel(|sink| async move { sink.send(item).await })
            .await
    }

    /// Wait for the next notification from the message stream without
    /// sending anything.
    ///
    /// The result has the same meaning as the `Ok` value of `send`.
    pub async fn read(&mut self) -> Option<Result<M::ControlMessage, ()>> {
        let result: Result<_, Infallible> = self.with_channel(|_sink| async { Ok(()) }).await;
        match result {
            Ok(message) => message,
            // `Infallible` has no values, so this arm cannot be reached.
            Err(never) => match never {},
        }
    }
}
/// Wrap `stream` and `control_channel` into a connected
/// `MessageStream`/`Subscription` pair.
///
/// Both halves share one initially empty slot through which the stream
/// reports control messages to the subscription.
pub fn subscribe<M, I, St, Si>(
    stream: St,
    control_channel: Si,
) -> (MessageStream<St, M>, Subscription<Si, M, I>)
where
    M: Message,
    St: Stream<Item = M>,
    Si: Sink<I>,
{
    let stream_half = Arc::new(Mutex::new(None));
    let subscription_half = Arc::clone(&stream_half);

    (
        MessageStream {
            stream,
            shared: stream_half,
        },
        Subscription {
            sink: control_channel,
            shared: subscription_half,
            _phantom: PhantomData,
        },
    )
}
pub async fn drive<M, F, S>(future: F, stream: &mut S) -> Result<F::Output, M::UserMessage>
where
M: Message,
F: Future + Unpin,
S: FusedStream<Item = M::UserMessage> + Unpin,
{
let mut future = future.fuse();
'l: loop {
select_biased! {
output = future => break 'l Ok(output),
user_message = stream.next() => {
if let Some(user_message) = user_message {
if M::is_error(&user_message) {
break 'l Err(user_message)
}
}
},
}
}
}
// Tests for the `subscribe`/`drive` machinery, using in-memory `mpsc`
// channels as stream and sink.
#[cfg(test)]
mod tests {
use super::*;
use futures::channel::mpsc::channel;
use futures::stream::iter;
use test_log::test;
// A mock message: `Value` is a user message, `Close` a control message.
#[derive(Debug)]
enum MockMessage<T> {
Value(T),
Close(u8),
}
impl<T> Message for MockMessage<T> {
type UserMessage = T;
type ControlMessage = u8;
fn classify(self) -> Classification<Self::UserMessage, Self::ControlMessage> {
match self {
MockMessage::Value(x) => Classification::UserMessage(x),
MockMessage::Close(x) => Classification::ControlMessage(x),
}
}
// No user message of this mock type counts as an error.
#[inline]
fn is_error(_user_message: &Self::UserMessage) -> bool {
false
}
}
// Check that sending an item resolves with the first control message
// (200) found on the stream.
#[test(tokio::test)]
async fn send_recv() {
let mut it = iter([
MockMessage::Value(1u64),
MockMessage::Value(2u64),
MockMessage::Value(3u64),
MockMessage::Close(200),
MockMessage::Close(201),
MockMessage::Value(4u64),
])
.map(Ok);
let (mut send, recv) = channel::<MockMessage<u64>>(16);
let () = send.send_all(&mut it).await.unwrap();
let (mut message_stream, mut subscription) = subscribe(recv, send);
let close = subscription.send(MockMessage::Close(42)).boxed_local();
let message = drive::<MockMessage<u64>, _, _>(close, &mut message_stream)
.await
.unwrap()
.unwrap();
assert_eq!(message, Some(Ok(200)));
}
// Check that `read` resolves with the first control message (200)
// without sending anything.
#[test(tokio::test)]
async fn read() {
let mut it = iter([
MockMessage::Value(1u64),
MockMessage::Value(2u64),
MockMessage::Value(3u64),
MockMessage::Close(200),
MockMessage::Close(201),
MockMessage::Value(4u64),
])
.map(Ok);
let (mut send, recv) = channel::<MockMessage<u64>>(16);
let () = send.send_all(&mut it).await.unwrap();
let (mut message_stream, mut subscription) = subscribe(recv, send);
let close = subscription.read().boxed_local();
let message = drive::<MockMessage<u64>, _, _>(close, &mut message_stream)
.await
.unwrap();
assert_eq!(message, Some(Ok(200)));
}
// Check that sending fails, repeatedly, once the message stream (which
// owns the channel's receiving end) has been dropped.
#[test(tokio::test)]
async fn stream_drop() {
let (send, recv) = channel::<MockMessage<u64>>(1);
let (message_stream, mut subscription) = subscribe(recv, send);
drop(message_stream);
let result = subscription.send(MockMessage::Close(42)).await;
assert!(result.is_err());
let result = subscription.send(MockMessage::Close(41)).await;
assert!(result.is_err());
}
// Check that sending fails, repeatedly, when the control channel was
// closed before the subscription was created.
#[test(tokio::test)]
async fn control_channel_closed() {
let (mut send, recv) = channel::<MockMessage<u64>>(1);
send.close_channel();
let (_message_stream, mut subscription) = subscribe(recv, send);
let result = subscription.send(MockMessage::Close(42)).await;
assert!(result.is_err());
let result = subscription.send(MockMessage::Close(41)).await;
assert!(result.is_err());
}
// Check that the stream still yields exactly the user messages when the
// subscription has been dropped.
#[test(tokio::test)]
async fn stream_processing_with_dropped_subscription() {
let mut it = iter([
MockMessage::Value(1u64),
MockMessage::Close(200),
MockMessage::Value(4u64),
])
.map(Ok);
let (mut send, recv) = channel::<MockMessage<u64>>(4);
let () = send.send_all(&mut it).await.unwrap();
let (message_stream, subscription) = subscribe(recv, send);
drop(subscription);
let vec = message_stream.collect::<Vec<_>>().await;
assert_eq!(vec, vec![1u64, 4u64]);
}
// A message type with two error layers. Both layers are passed on as
// user messages; only `Close` values are control messages.
impl<T> Message for Result<Result<MockMessage<T>, String>, u64> {
type UserMessage = Result<Result<T, String>, u64>;
type ControlMessage = u8;
fn classify(self) -> Classification<Self::UserMessage, Self::ControlMessage> {
match self {
Ok(Ok(MockMessage::Value(x))) => Classification::UserMessage(Ok(Ok(x))),
Ok(Ok(MockMessage::Close(x))) => Classification::ControlMessage(x),
Ok(Err(err)) => Classification::UserMessage(Ok(Err(err))),
Err(err) => Classification::UserMessage(Err(err)),
}
}
// Only an inner error counts as an error; an outer error does not.
fn is_error(user_message: &Self::UserMessage) -> bool {
user_message
.as_ref()
.map(|inner| inner.is_err())
.unwrap_or(false)
}
}
// Check sending and receiving with the layered message type when no
// errors occur.
#[test(tokio::test)]
async fn send_recv_with_errors() {
let mut it = iter([
Ok(Ok(MockMessage::Value(1u64))),
Ok(Ok(MockMessage::Value(2u64))),
Ok(Ok(MockMessage::Value(3u64))),
Ok(Ok(MockMessage::Close(200))),
Ok(Ok(MockMessage::Close(201))),
Ok(Ok(MockMessage::Value(4u64))),
])
.map(Ok);
let (mut send, recv) = channel::<Result<Result<MockMessage<u64>, String>, u64>>(16);
let () = send.send_all(&mut it).await.unwrap();
let (mut message_stream, mut subscription) = subscribe(recv, send);
let close = subscription
.send(Ok(Ok(MockMessage::Close(42))))
.boxed_local();
let message =
drive::<Result<Result<MockMessage<u64>, String>, u64>, _, _>(close, &mut message_stream)
.await
.unwrap()
.unwrap();
assert_eq!(message, Some(Ok(200)));
}
// Check that an inner error makes `drive` stop and return the
// offending user message.
#[test(tokio::test)]
async fn inner_error() {
let mut it = iter([
Ok(Ok(MockMessage::Value(1u64))),
Ok(Err("error".to_string())),
Ok(Ok(MockMessage::Close(200))),
])
.map(Ok);
let (mut send, recv) = channel::<Result<Result<MockMessage<u64>, String>, u64>>(16);
let () = send.send_all(&mut it).await.unwrap();
let (mut message_stream, mut subscription) = subscribe(recv, send);
let close = subscription
.send(Ok(Ok(MockMessage::Close(42))))
.boxed_local();
let message =
drive::<Result<Result<MockMessage<u64>, String>, u64>, _, _>(close, &mut message_stream)
.await
.unwrap_err()
.unwrap();
assert_eq!(message, Err("error".to_string()));
}
// Check that an outer error is skipped like a regular user message and
// the subsequent control message is still received.
#[test(tokio::test)]
async fn outer_error() {
let mut it = iter([
Ok(Ok(MockMessage::Value(1u64))),
Err(1337),
Ok(Ok(MockMessage::Close(200))),
])
.map(Ok);
let (mut send, recv) = channel::<Result<Result<MockMessage<u64>, String>, u64>>(16);
let () = send.send_all(&mut it).await.unwrap();
let (mut message_stream, mut subscription) = subscribe(recv, send);
let close = subscription
.send(Ok(Ok(MockMessage::Close(42))))
.boxed_local();
let message =
drive::<Result<Result<MockMessage<u64>, String>, u64>, _, _>(close, &mut message_stream)
.await
.unwrap()
.unwrap();
assert_eq!(message, Some(Ok(200)));
}
}