use std::{
sync::{Arc, Mutex},
task::Waker,
};
use futures_util::{Stream, StreamExt};
/// Shared bookkeeping behind every [`CopiedMultiStream`] handle produced by one
/// [`copied_multi_stream`] call.
///
/// `cache` and `wakers` hold one entry per handle, indexed by that handle's
/// position.
#[derive(Clone)]
struct CopiedMultiStreamState<S: Stream> {
    /// Per-handle delivery slot: `Some(item)` is a copy of the current item
    /// that this handle has not taken yet.
    cache: Box<[Option<S::Item>]>,
    /// Per-handle waker, stored when that handle returned `Pending`.
    wakers: Box<[Option<Waker>]>,
    /// The source stream; replaced with `None` once it has yielded `Ready(None)`.
    stream: Option<S>,
}
/// One consumer handle of a stream fanned out by [`copied_multi_stream`].
///
/// Every handle yields a clone of each item of the source stream. Handles
/// advance in lockstep: the source is not polled for a new item until every
/// handle has taken the current one. As a result, a handle that stops being
/// polled stalls its siblings.
///
/// This type does not release its slot on drop. Dropping a handle while the
/// others are still being polled can therefore leave them pending
/// indefinitely.
///
/// Cloning a handle yields a second handle for the *same* slot (same position,
/// same shared state). The two then compete for that slot's items instead of
/// each receiving every item.
#[must_use = "streams do nothing unless polled"]
#[derive(Clone)]
pub struct CopiedMultiStream<S: Stream> {
    state: Arc<Mutex<CopiedMultiStreamState<S>>>,
    pos: usize,
}
/// Fans `stream` out into `i` handles, each of which yields a clone of every
/// item `stream` produces.
///
/// All handles share `stream` behind a mutex. A new item is only pulled from
/// `stream` once every handle has taken the previous one. The handles
/// implement [`Stream`] only when `S: Unpin` and `S::Item: Clone`.
///
/// With `i == 0` the returned `Vec` is empty, and `stream` is dropped before
/// this function returns.
pub fn copied_multi_stream<S>(stream: S, i: usize) -> Vec<CopiedMultiStream<S>>
where
    S: Stream,
{
    let shared = Arc::new(Mutex::new(CopiedMultiStreamState {
        cache: std::iter::repeat_with(|| None).take(i).collect(),
        wakers: std::iter::repeat_with(|| None).take(i).collect(),
        stream: Some(stream),
    }));

    let mut handles = Vec::with_capacity(i);
    for pos in 0..i {
        handles.push(CopiedMultiStream {
            state: Arc::clone(&shared),
            pos,
        });
    }
    handles
}
impl<S> Stream for CopiedMultiStream<S>
where
S: Stream + Unpin,
S::Item: Clone,
{
type Item = S::Item;
fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
let mut state = self.state.lock().unwrap();
if let Some(v) = state.cache[self.pos].take() {
std::task::Poll::Ready(Some(v))
} else if state.cache.iter().any(Option::is_some) {
state.wakers[self.pos] = Some(cx.waker().clone());
std::task::Poll::Pending
} else if let Some(ref mut stream) = state.stream {
match stream.poll_next_unpin(cx) {
std::task::Poll::Ready(Some(v)) => {
state.cache.iter_mut().for_each(|c| *c = Some(v.clone()));
state.wakers.iter_mut().for_each(|waker| {
if let Some(waker) = waker.take() {
waker.wake_by_ref()
}
});
std::task::Poll::Ready(state.cache[self.pos].take())
}
std::task::Poll::Ready(None) => {
state.stream = None;
state.wakers.iter_mut().for_each(|waker| {
if let Some(waker) = waker.take() {
waker.wake_by_ref()
}
});
std::task::Poll::Ready(None)
}
std::task::Poll::Pending => {
state.wakers[self.pos] = Some(cx.waker().clone());
std::task::Poll::Pending
}
}
} else {
std::task::Poll::Ready(None)
}
}
}
#[cfg(test)]
mod tests {
    use std::pin::pin;

    use futures_util::stream::{self, BoxStream};
    use ntest_timeout::timeout;

    use crate::StreamUtils;

    use super::*;

    /// Checks that `handles` has `expected_len` entries, then interleaves them
    /// with `select_all` and collects everything they yield.
    async fn merge_all<St>(handles: Vec<St>, expected_len: usize) -> Vec<St::Item>
    where
        St: Stream + Unpin,
    {
        assert_eq!(handles.len(), expected_len);
        stream::select_all(handles).collect().await
    }

    #[tokio::test]
    async fn test_stream() {
        let size = 3;
        let source = stream::iter(0..3);
        let handles = source.copied_multi_stream(size);
        let merged: Vec<usize> = merge_all(handles, size).await;
        assert_eq!(merged, vec![0, 0, 0, 1, 1, 1, 2, 2, 2]);
    }

    #[tokio::test]
    async fn test_box_stream() {
        let size = 3;
        let source: BoxStream<usize> = Box::pin(stream::iter(0..3));
        let handles = source.copied_multi_stream(size);
        let merged: Vec<usize> = merge_all(handles, size).await;
        assert_eq!(merged, vec![0, 0, 0, 1, 1, 1, 2, 2, 2]);
    }

    #[tokio::test]
    async fn test_empty_stream() {
        let size = 3;
        let source = Box::pin(stream::iter(0..0));
        let handles = source.copied_multi_stream(size);
        let merged: Vec<usize> = merge_all(handles, size).await;
        assert!(merged.is_empty());
    }

    #[tokio::test]
    async fn test_zero_streams() {
        let size = 0;
        let source = stream::iter(0..3);
        let handles = source.copied_multi_stream(size);
        let merged: Vec<usize> = merge_all(handles, size).await;
        assert!(merged.is_empty());
    }

    #[tokio::test]
    async fn test_future_stream() {
        let size = 3;
        // Yields 0, 2, 4 (each step is produced by an async block).
        let source = stream::unfold(0, |n| async move { (n <= 2).then(|| (n * 2, n + 1)) });
        let source = pin!(source);
        let handles = source.copied_multi_stream(size);
        let merged: Vec<usize> = merge_all(handles, size).await;
        assert_eq!(merged, vec![0, 0, 0, 2, 2, 2, 4, 4, 4]);
    }

    #[tokio::test]
    #[timeout(200)]
    async fn test_async_pull() {
        let size = 5;
        let source = stream::iter(0..3);
        // Each handle is drained concurrently by its own task.
        let tasks: Vec<_> = source
            .copied_multi_stream(size)
            .into_iter()
            .map(|handle| tokio::task::spawn(async move { handle.collect::<Vec<usize>>().await }))
            .collect();
        for task in tasks {
            assert_eq!(task.await.unwrap(), vec![0, 1, 2]);
        }
    }
}