use futures::Stream;
use futures::task;
use pin_project_lite::pin_project;
use std::sync::Arc;
use std::task::{Context, Poll};
use crate::outer_waker::OuterWaker;
pin_project! {
    /// Stream adapter that flattens a stream of streams by "switching": it only
    /// yields items from the most recently produced inner stream.
    ///
    /// Whenever `outer` yields a new inner stream, the previously active one is
    /// dropped and replaced. When the active inner stream finishes it is
    /// discarded and the adapter waits for `outer` to produce another one. The
    /// adapter itself ends as soon as `outer` ends, even if an inner stream is
    /// still active at that moment.
    #[must_use = "streams do nothing unless polled"]
    pub struct FlattenSwitch<St>
    where
        St: Stream,
        St::Item: Stream
    {
        // The stream of inner streams.
        #[pin]
        outer: St,
        // Shared waker used only when polling `outer`; the task's waker is
        // handed to it via `set_parent_waker` on every poll.
        // NOTE(review): its exact wake/ready semantics live in
        // `crate::outer_waker` — confirm there.
        outer_waker: Arc<OuterWaker>,
        // The currently active inner stream: `None` before `outer` has produced
        // one and after the active one has finished.
        #[pin]
        inner: Option<<St as Stream>::Item>
    }
}
impl<St> FlattenSwitch<St>
where
St: Stream,
St::Item: Stream,
{
pub(super) fn new(stream: St) -> Self {
Self {
outer: stream,
outer_waker: Arc::default(),
inner: None,
}
}
}
impl<St> Stream for FlattenSwitch<St>
where
    St: Stream,
    St::Item: Stream,
{
    type Item = <St::Item as Stream>::Item;

    /// Yields the next item of the most recently received inner stream.
    ///
    /// When the shared `outer_waker` reports readiness, `outer` is drained first:
    /// every inner stream it yields replaces (and drops) the previous one, and
    /// `outer` ending terminates this stream with `Ready(None)` immediately, even
    /// if an inner stream is still active. The active inner stream is then polled
    /// with the caller's context; once it finishes it is dropped and `Pending` is
    /// returned until `outer` produces a replacement.
    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        let mut this = self.project();

        // Hand the caller's waker to the shared `OuterWaker`; the returned flag
        // decides whether `outer` is polled on this call.
        // NOTE(review): presumably `true` only after `outer` signalled a wakeup
        // since its last poll — see `crate::outer_waker`.
        let poll_outer = this.outer_waker.set_parent_waker(cx.waker().clone());
        if poll_outer {
            // `outer` is polled through the proxy waker rather than the caller's
            // waker directly, so its wakeups go through `outer_waker`.
            let proxy = task::waker(Arc::clone(this.outer_waker));
            let mut outer_cx = Context::from_waker(&proxy);
            loop {
                match this.outer.as_mut().poll_next(&mut outer_cx) {
                    // Only the latest inner stream survives; `set` drops the old one.
                    Poll::Ready(Some(latest)) => this.inner.set(Some(latest)),
                    Poll::Ready(None) => return Poll::Ready(None),
                    Poll::Pending => break,
                }
            }
        }

        let polled = match this.inner.as_mut().as_pin_mut() {
            Some(active) => active.poll_next(cx),
            // No inner stream yet (or the last one finished): wait on `outer`.
            None => return Poll::Pending,
        };

        if matches!(polled, Poll::Ready(None)) {
            // The active inner stream is exhausted: drop it so it is never polled
            // again, and keep waiting for the next one from `outer`.
            this.inner.set(None);
            return Poll::Pending;
        }
        polled
    }
}
impl<S> std::fmt::Debug for FlattenSwitch<S>
where
    S: Stream + std::fmt::Debug,
    S::Item: Stream + std::fmt::Debug,
{
    /// Formats the outer stream (under the label `stream`) and the currently
    /// active inner stream; the shared `outer_waker` is not included.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("FlattenSwitch");
        builder.field("stream", &self.outer);
        builder.field("inner", &self.inner);
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
    use std::future;

    use futures::{FutureExt, StreamExt, stream};
    use parking_lot::Mutex;
    use tokio_test::{assert_pending, assert_ready_eq};

    use super::*;

    pin_project! {
        // Test wrapper around a stream that records, in `polled`, that it was polled.
        struct MockStream<S: Stream> {
            #[pin]
            inner: S,
            polled: Arc<Mutex<bool>>
        }
    }

    impl<S: Stream> Stream for MockStream<S> {
        type Item = S::Item;

        fn poll_next(
            self: std::pin::Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Self::Item>> {
            let this = self.project();
            let result = this.inner.poll_next(cx);
            *this.polled.lock() = true;
            result
        }
    }

    /// Returns whether `flag` was set, resetting it to `false`.
    fn take_flag(flag: &Mutex<bool>) -> bool {
        std::mem::take(&mut *flag.lock())
    }

    /// Drives `FlattenSwitch` through three successive inner streams and checks
    /// both the yielded items and exactly when the outer stream gets polled.
    #[tokio::test]
    async fn test_flatten_switch() {
        use futures::{SinkExt, channel::mpsc};
        use tokio::sync::broadcast::{self, error::SendError};
        use tokio_stream::wrappers::BroadcastStream;

        let waker = futures::task::noop_waker_ref();
        let mut cx = std::task::Context::from_waker(waker);

        let (tx_inner1, rx_inner1) = broadcast::channel(32);
        let (tx_inner2, rx_inner2) = broadcast::channel(32);
        let (tx_inner3, rx_inner3) = broadcast::channel(32);
        let (mut tx, rx) = mpsc::unbounded();

        let outer_polled = Arc::new(Mutex::new(false));
        let assert_outer_polled = || assert!(take_flag(&outer_polled));
        let assert_outer_not_polled = || assert!(!take_flag(&outer_polled));

        let outer_stream = MockStream {
            inner: rx,
            polled: Arc::clone(&outer_polled),
        };
        let mut switch_stream = FlattenSwitch::new(outer_stream);

        // No inner stream yet: the outer stream is polled, nothing is yielded.
        assert_eq!(switch_stream.poll_next_unpin(&mut cx), Poll::Pending);
        assert_outer_polled();

        // First inner stream arrives.
        tx.send(
            BroadcastStream::new(rx_inner1)
                .map(|r: Result<_, _>| r.unwrap())
                .boxed(),
        )
        .await
        .unwrap();
        assert_eq!(switch_stream.poll_next_unpin(&mut cx), Poll::Pending);
        assert_outer_polled();

        // Items from the active inner stream flow through without re-polling outer.
        tx_inner1.send(10).unwrap();
        assert_eq!(
            switch_stream.poll_next_unpin(&mut cx),
            Poll::Ready(Some(10))
        );
        assert_outer_not_polled();
        assert_eq!(switch_stream.poll_next_unpin(&mut cx), Poll::Pending);
        assert_outer_not_polled();
        tx_inner1.send(20).unwrap();
        assert_eq!(
            switch_stream.poll_next_unpin(&mut cx),
            Poll::Ready(Some(20))
        );
        assert_outer_not_polled();

        // Switch to the second inner stream.
        tx.send(
            BroadcastStream::new(rx_inner2)
                .map(|r: Result<_, _>| r.unwrap())
                .boxed(),
        )
        .await
        .unwrap();
        assert_eq!(switch_stream.poll_next_unpin(&mut cx), Poll::Pending);
        assert_outer_polled();

        // The switch dropped the first inner stream, and with it the only receiver
        // of its channel, so sending on that channel must now fail.
        assert!(matches!(tx_inner1.send(30), Err(SendError(_))));
        assert_eq!(switch_stream.poll_next_unpin(&mut cx), Poll::Pending);
        assert_outer_not_polled();

        // The second inner stream ends: the adapter keeps waiting on outer.
        drop(tx_inner2);
        assert_eq!(switch_stream.poll_next_unpin(&mut cx), Poll::Pending);
        assert_outer_not_polled();

        // A third inner stream with an item already buffered is picked up at once.
        tx.send(
            BroadcastStream::new(rx_inner3)
                .map(|r: Result<_, _>| r.unwrap())
                .boxed(),
        )
        .await
        .unwrap();
        tx_inner3.send(100).unwrap();
        assert_eq!(
            switch_stream.poll_next_unpin(&mut cx),
            Poll::Ready(Some(100))
        );
        assert_outer_polled();
        tx_inner3.send(110).unwrap();
        assert_eq!(
            switch_stream.poll_next_unpin(&mut cx),
            Poll::Ready(Some(110))
        );
        assert_outer_not_polled();

        // Ending the outer stream ends the adapter even though inner 3 is alive.
        drop(tx);
        assert_eq!(switch_stream.poll_next_unpin(&mut cx), Poll::Ready(None));
        assert_outer_polled();
    }

    /// Once an inner stream has returned `None` it must be dropped and never
    /// polled again.
    #[tokio::test]
    async fn test_inner_not_polled_twice_after_termination() {
        let inner_polled = Arc::new(Mutex::new(false));
        let assert_inner_polled = || assert!(take_flag(&inner_polled));
        let assert_inner_not_polled = || assert!(!take_flag(&inner_polled));

        let first_inner = MockStream {
            inner: stream::once(future::ready(1)),
            polled: Arc::clone(&inner_polled),
        };
        // The outer stream yields one inner stream and then never completes.
        let outer_stream =
            stream::once(future::ready(first_inner)).chain(future::pending().into_stream());
        let mut stream = FlattenSwitch::new(outer_stream);

        let waker = futures::task::noop_waker_ref();
        let mut cx = std::task::Context::from_waker(waker);

        assert_ready_eq!(stream.poll_next_unpin(&mut cx), Some(1));
        assert_inner_polled();
        // This poll observes the inner stream's termination.
        assert_pending!(stream.poll_next_unpin(&mut cx));
        assert_inner_polled();
        // The terminated inner stream is gone and must not be polled again.
        assert_pending!(stream.poll_next_unpin(&mut cx));
        assert_inner_not_polled();
    }
}