use std::pin::Pin;
use futures::{Stream, StreamExt, stream};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
pub use crate::WireframeError;
/// A pinned, boxed, `Send + 'static` trait-object stream whose items are
/// `Result<F, WireframeError<E>>`.
///
/// `F` is the frame type carried by successful items and `E` is the type
/// parameter forwarded to [`WireframeError`]; `E` defaults to `()`.
pub type FrameStream<F, E = ()> =
Pin<Box<dyn Stream<Item = Result<F, WireframeError<E>>> + Send + 'static>>;
/// The different shapes a response made of frames of type `F` can take.
///
/// `E` is only used by the [`Response::Stream`] variant, where it is the
/// parameter of the [`WireframeError`] carried by failed stream items; it
/// defaults to `()`.
pub enum Response<F, E = ()> {
/// Exactly one frame.
Single(F),
/// A vector of frames held in memory.
Vec(Vec<F>),
/// An already-boxed [`FrameStream`] whose items are `Result`s of frames.
Stream(FrameStream<F, E>),
/// The receiving half of a `tokio::sync::mpsc` channel of frames; the
/// matching sender is produced by `Response::with_channel`.
MultiPacket(mpsc::Receiver<F>),
/// No frames at all.
Empty,
}
impl<F: std::fmt::Debug, E> std::fmt::Debug for Response<F, E> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Response::Single(frame) => f.debug_tuple("Single").field(frame).finish(),
Response::Vec(v) => f.debug_tuple("Vec").field(v).finish(),
Response::Stream(_) => f.write_str("Stream(..)"),
Response::MultiPacket(_) => f.write_str("MultiPacket(..)"),
Response::Empty => f.write_str("Empty"),
}
}
}
impl<F, E> From<F> for Response<F, E> {
fn from(f: F) -> Self { Response::Single(f) }
}
impl<F, E> From<Vec<F>> for Response<F, E> {
fn from(v: Vec<F>) -> Self { Response::Vec(v) }
}
impl<F, E> Response<F, E>
where
    F: Send + 'static,
    E: Send + 'static,
{
    /// Creates a bounded `tokio::sync::mpsc` channel and returns its sender
    /// together with a [`Response::MultiPacket`] wrapping its receiver.
    ///
    /// `capacity` is passed unchanged to `tokio::sync::mpsc::channel`.
    ///
    /// # Panics
    ///
    /// Panics whenever `tokio::sync::mpsc::channel` does; tokio documents
    /// this for a `capacity` of zero (or one above its internal maximum).
    #[must_use]
    pub fn with_channel(capacity: usize) -> (mpsc::Sender<F>, Response<F, E>) {
        let (tx, rx) = mpsc::channel::<F>(capacity);
        (tx, Self::MultiPacket(rx))
    }

    /// Consumes the response and turns it into a boxed [`FrameStream`].
    ///
    /// * `Single` becomes a one-item stream yielding `Ok(frame)`.
    /// * `Vec` yields each frame, in vector order, wrapped in `Ok`.
    /// * `Stream` is handed back untouched.
    /// * `MultiPacket` wraps the receiver in a `ReceiverStream` and maps
    ///   every received frame to `Ok`.
    /// * `Empty` becomes a stream with no items.
    #[must_use]
    pub fn into_stream(self) -> FrameStream<F, E> {
        match self {
            Self::Single(frame) => {
                let only_item = async move { Ok::<F, WireframeError<E>>(frame) };
                stream::once(only_item).boxed()
            }
            Self::Vec(frames) => {
                let wrapped = frames.into_iter().map(Ok);
                stream::iter(wrapped).boxed()
            }
            Self::Stream(inner) => inner,
            Self::MultiPacket(receiver) => {
                let incoming = ReceiverStream::new(receiver);
                incoming.map(Ok).boxed()
            }
            Self::Empty => stream::empty().boxed(),
        }
    }
}
// Tests for the channel pair produced by `Response::with_channel`.
#[cfg(all(test, not(loom)))]
mod tests {
    use rstest::{fixture, rstest};
    use tokio::sync::mpsc::{self, error::TrySendError};

    use super::*;

    /// Sender/response pair backed by a channel created with capacity 1.
    #[fixture]
    fn single_capacity_channel() -> (mpsc::Sender<u8>, Response<u8, ()>) {
        Response::with_channel(1)
    }

    /// Extracts the receiver from a `MultiPacket` response and fails the
    /// test for any other variant.
    fn multi_packet_receiver(response: Response<u8, ()>) -> mpsc::Receiver<u8> {
        let Response::MultiPacket(receiver) = response else {
            panic!("with_channel did not return a MultiPacket response");
        };
        receiver
    }

    /// A full channel rejects `try_send`, accepts a frame again once it has
    /// been drained, and reports closure after the sender is dropped.
    #[rstest]
    #[tokio::test]
    async fn with_channel_streams_frames_and_respects_capacity(
        single_capacity_channel: (mpsc::Sender<u8>, Response<u8, ()>),
    ) {
        let (tx, response) = single_capacity_channel;
        let mut rx = multi_packet_receiver(response);

        // Fill the single slot, then confirm a second frame is refused.
        tx.send(1).await.expect("send first frame");
        let overflow = tx.try_send(2);
        assert!(matches!(overflow, Err(TrySendError::Full(2))));

        // Draining the slot makes room for the follow-up frame.
        let first = rx.recv().await;
        assert_eq!(first, Some(1));
        tx.send(3)
            .await
            .expect("send follow-up frame after draining");

        // After the sender is gone the buffered frame is still delivered,
        // followed by `None`.
        drop(tx);
        let second = rx.recv().await;
        assert_eq!(second, Some(3));
        let closed = rx.recv().await;
        assert_eq!(closed, None);
    }

    /// Dropping the response makes `try_send` fail with `Closed`.
    #[rstest]
    #[tokio::test]
    async fn with_channel_sender_errors_when_receiver_dropped(
        single_capacity_channel: (mpsc::Sender<u8>, Response<u8, ()>),
    ) {
        let (tx, response) = single_capacity_channel;
        drop(response);

        let outcome = tx.try_send(7);
        assert!(matches!(outcome, Err(TrySendError::Closed(7))));
    }

    /// Dropping the sender makes the receiver yield `None`.
    #[rstest]
    #[tokio::test]
    async fn with_channel_receiver_detects_sender_drop(
        single_capacity_channel: (mpsc::Sender<u8>, Response<u8, ()>),
    ) {
        let (tx, response) = single_capacity_channel;
        let mut rx = multi_packet_receiver(response);

        drop(tx);
        let closed = rx.recv().await;
        assert_eq!(closed, None);
    }
}