use std::pin::Pin;
use std::sync::atomic::{AtomicU8, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use futures::task::AtomicWaker;
use num_enum::{FromPrimitive, IntoPrimitive};
use tokio_stream::Stream;
/// A [`Stream`] adapter whose output can be run, paused or stopped from
/// elsewhere through a shared [`SharedStreamState`] handle.
///
/// The adapter yields exactly the items of the wrapped stream; the shared
/// state only decides whether the wrapped stream is polled at all.
pub struct ControlledStream<S>
where
    S: Stream,
{
    /// The wrapped stream that items are pulled from.
    inner: S,
    /// Control state, shared with every handle returned by `state()`.
    state: Arc<SharedStreamState>,
}
/// Control state shared between a `ControlledStream` and the handles obtained
/// from `ControlledStream::state`.
///
/// All access goes through atomics, so a handle can be used through a shared
/// reference (see `set`).
pub struct SharedStreamState {
/// Waker registered by `ControlledStream::poll_next` and woken by `set`.
waker: AtomicWaker,
/// Current `StreamState`, stored as its `u8` representation.
state: AtomicU8,
}
/// Control state of a [`ControlledStream`].
///
/// The enum is `repr(u8)` so it can be stored in an `AtomicU8`; the
/// `num_enum` derives provide the conversions to and from the raw byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq, IntoPrimitive, FromPrimitive)]
#[repr(u8)]
pub enum StreamState {
    /// The wrapped stream is polled and its items are forwarded.
    Run = 0,
    /// The wrapped stream is not polled; `poll_next` returns `Poll::Pending`.
    Pause = 1,
    /// The stream reports its end (`Poll::Ready(None)`) without polling the
    /// wrapped stream. Marked as the `num_enum` default, so any raw byte that
    /// matches no variant also converts to `Stop`.
    #[num_enum(default)]
    Stop = 2,
}
impl SharedStreamState {
    /// Switches the controlled stream to `state` and wakes the task that last
    /// registered its waker, so that it polls the stream again and observes
    /// the new state.
    ///
    /// The value is stored before the wake-up is issued.
    pub fn set(&self, state: StreamState) {
        let raw = u8::from(state);
        self.state.store(raw, Ordering::Relaxed);
        self.waker.wake();
    }
}
impl<S: Stream> ControlledStream<S> {
    /// Wraps `inner` in a controlled stream that starts out in
    /// [`StreamState::Run`].
    pub fn new(inner: S) -> Self {
        Self::with_initial_state(inner, StreamState::Run)
    }

    /// Wraps `inner` in a controlled stream that starts out in `state`.
    ///
    /// A fresh shared state with an empty waker slot is allocated for the new
    /// stream.
    pub fn with_initial_state(inner: S, state: StreamState) -> Self {
        let shared = SharedStreamState {
            waker: AtomicWaker::new(),
            state: AtomicU8::new(state as u8),
        };
        Self {
            inner,
            state: Arc::new(shared),
        }
    }

    /// Returns another handle to the shared control state of this stream.
    ///
    /// Only the `Arc` is cloned; every handle refers to the same state.
    pub fn state(&self) -> Arc<SharedStreamState> {
        Arc::clone(&self.state)
    }
}
impl<S: Stream> Stream for ControlledStream<S> {
type Item = S::Item;
fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> std::task::Poll<Option<Self::Item>> {
let (inner, state) = unsafe {
let this = self.get_unchecked_mut();
(Pin::new_unchecked(&mut this.inner), Pin::new_unchecked(&mut this.state))
};
let cur_state = state.state.load(Ordering::Relaxed).into();
match cur_state {
StreamState::Run => {
match Stream::poll_next(inner, cx) {
Poll::Ready(item) => Poll::Ready(item),
Poll::Pending => {
state.waker.register(cx.waker());
Poll::Pending
}
}
}
StreamState::Pause => {
state.waker.register(cx.waker());
Poll::Pending
}
StreamState::Stop => {
Poll::Ready(None)
}
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.inner.size_hint()
}
}
// Tests for `ControlledStream`: plain pass-through behaviour first, then
// pausing/resuming/stopping through the shared state handle, both from the
// test task itself and from a separate OS thread.
#[cfg(test)]
mod tests {
use std::{
pin::pin,
sync::atomic::Ordering,
task::Poll,
time::{Duration, Instant},
};
use tokio_stream::StreamExt;
// NOTE(review): this path assumes the file is mounted as
// `crate::pipeline::util::stream` - confirm against the module tree.
use crate::pipeline::util::stream::{ControlledStream, StreamState};
// An empty wrapped stream ends immediately.
#[tokio::test]
async fn empty_controlled_stream() {
let values: Vec<&'static str> = vec![];
let stream = tokio_stream::iter(values.clone());
let mut stream = ControlledStream::new(stream);
assert_eq!(stream.next().await, None);
}
// A single item is forwarded, then the stream ends.
#[tokio::test]
async fn single_controlled_stream() {
let values: Vec<&'static str> = vec!["the one"];
let stream = tokio_stream::iter(values.clone());
let mut stream = ControlledStream::new(stream);
assert_eq!(stream.next().await, Some("the one"));
assert_eq!(stream.next().await, None);
}
// Several items are forwarded in their original order, then the stream ends.
#[tokio::test]
async fn multi_controlled_stream() {
let values: Vec<&'static str> = vec!["a", "b", "c"];
let stream = tokio_stream::iter(values.clone());
let mut stream = ControlledStream::new(stream);
assert_eq!(stream.next().await, Some("a"));
assert_eq!(stream.next().await, Some("b"));
assert_eq!(stream.next().await, Some("c"));
assert_eq!(stream.next().await, None);
}
// Collecting a wrapped stream yields the same items as the source, checked
// for six items, one item and no items.
#[tokio::test]
async fn collect_controlled_stream() {
{
let values = vec!["0", "1", "2", "3", "4", "5"];
let stream = tokio_stream::iter(values.clone());
let stream = ControlledStream::new(stream);
let collected: Vec<_> = stream.collect().await;
assert_eq!(collected, values);
}
{
let values = vec!["abc"];
let stream = tokio_stream::iter(values.clone());
let stream = ControlledStream::new(stream);
let collected: Vec<_> = stream.collect().await;
assert_eq!(collected, values);
}
{
let values: Vec<&'static str> = vec![];
let stream = tokio_stream::iter(values.clone());
let stream = ControlledStream::new(stream);
let collected: Vec<_> = stream.collect().await;
assert_eq!(collected, values);
}
}
// While paused, a single poll of `next()` is `Pending` even though the
// wrapped stream is empty; after switching back to `Run` the end of the
// stream is reported.
#[tokio::test]
async fn pause_empty_controlled_stream() {
let mut stream = ControlledStream::new(tokio_stream::iter(Vec::<u8>::new()));
let stream_state = stream.state();
stream_state.set(StreamState::Pause);
let next = pin!(stream.next());
// `futures::poll!` polls exactly once instead of awaiting completion.
let polled = futures::poll!(next);
assert_eq!(Poll::Pending, polled);
stream_state.set(StreamState::Run);
let next = pin!(stream.next());
let polled = futures::poll!(next);
assert_eq!(Poll::Ready(None), polled);
}
// Same as above, but `Run` is set from another thread. The thread is joined
// before the second poll, so the state change is complete by then.
#[tokio::test]
async fn pause_empty_controlled_stream_from_other_thread() {
let mut stream = ControlledStream::new(tokio_stream::iter(Vec::<u8>::new()));
let stream_state = stream.state();
stream_state.set(StreamState::Pause);
let next = pin!(stream.next());
let polled = futures::poll!(next);
assert_eq!(Poll::Pending, polled);
let sc = stream_state.clone();
let thread = std::thread::spawn(move || {
sc.set(StreamState::Run);
});
thread.join().unwrap();
let next = pin!(stream.next());
let polled = futures::poll!(next);
assert_eq!(Poll::Ready(None), polled);
}
// `Run` is set from another thread while the test awaits `next()`. The thread
// is only joined after the await, so the state change may land before or
// after the await's first poll; either way the await must complete.
#[tokio::test]
async fn pause_empty_controlled_stream_from_other_thread2() {
let mut stream = ControlledStream::new(tokio_stream::iter(Vec::<u8>::new()));
let stream_state = stream.state();
stream_state.set(StreamState::Pause);
let next = pin!(stream.next());
let polled = futures::poll!(next);
assert_eq!(Poll::Pending, polled);
let sc = stream_state.clone();
let thread = std::thread::spawn(move || {
sc.set(StreamState::Run);
});
let fut = stream.next();
let next = fut.await;
assert_eq!(None, next);
thread.join().unwrap();
}
// The stream starts out paused via `with_initial_state`. The helper thread
// sleeps 100 ms before setting `Run`, so the await below is expected to park
// first and be woken by `set`.
#[tokio::test]
async fn pause_empty_controlled_stream_from_other_thread3() {
let mut stream = ControlledStream::with_initial_state(tokio_stream::iter(Vec::<u8>::new()), StreamState::Pause);
let stream_state = stream.state();
// Reads the private atomic directly to check the initial raw value.
assert_eq!(stream_state.state.load(Ordering::Relaxed), StreamState::Pause as u8);
let next = pin!(stream.next());
let polled = futures::poll!(next);
assert_eq!(Poll::Pending, polled);
let sc = stream_state.clone();
let thread = std::thread::spawn(move || {
std::thread::sleep(Duration::from_millis(100));
// Diagnostic output only: addresses of the handle and of the atomic.
println!("&sc: {:p}", &sc);
println!("&sc.atomic_state: {:p}", &sc.state);
sc.set(StreamState::Run);
});
let fut = stream.next();
let next = fut.await;
assert_eq!(None, next);
thread.join().unwrap();
}
// Drives a paused five-item stream through Run -> Pause -> Run -> Stop from a
// helper thread that sleeps 100 ms before each state change, and checks both
// the items received and roughly how long each `next().await` took.
// NOTE(review): the assertions rely on wall-clock timing and on the blocking
// `std::thread::sleep` calls in the test body lining up with the helper
// thread; they may be flaky on a heavily loaded machine.
#[tokio::test]
async fn pause_nonempty_controlled_stream() {
let values = vec![1, 2, 3, 4, 5];
let mut stream = Box::pin(ControlledStream::with_initial_state(
tokio_stream::iter(values.clone()),
StreamState::Pause,
));
let stream_state = stream.state();
println!("stream with state pause");
let polled = futures::poll!(pin!(stream.next()));
assert_eq!(Poll::Pending, polled);
println!("is pending");
let t0 = Instant::now();
let state_clone = stream_state.clone();
// Helper thread: state changes at roughly 100, 200, 300 and 400 ms after
// it is spawned.
let thread = std::thread::spawn(move || {
std::thread::sleep(Duration::from_millis(100));
println!("state.set(Run)");
state_clone.set(StreamState::Run);
std::thread::sleep(Duration::from_millis(100));
println!("state.set(Pause)");
state_clone.set(StreamState::Pause);
std::thread::sleep(Duration::from_millis(100));
println!("state.set(Run)");
state_clone.set(StreamState::Run);
std::thread::sleep(Duration::from_millis(100));
println!("state.set(Stop)");
state_clone.set(StreamState::Stop);
});
// First item: only available once the helper thread sets `Run`.
println!("next().await...");
let value = stream.next().await;
let t1 = Instant::now();
println!("=> {value:?} after {:?}", t1.duration_since(t0));
assert!(t1.duration_since(t0) > Duration::from_millis(100));
assert_eq!(Some(1), value);
// Second item: the stream is still running, so it should arrive quickly.
println!("next().await...");
let t0 = Instant::now();
let value = stream.next().await;
let t1 = Instant::now();
println!("=> {value:?} after {:?}", t1.duration_since(t0));
assert!(t1.duration_since(t0) < Duration::from_millis(100));
assert_eq!(Some(2), value);
// Block long enough for the helper thread to have set `Pause`.
std::thread::sleep(Duration::from_millis(105));
// Third item: has to wait for the helper thread's second `Run`.
println!("next().await...");
let t0 = Instant::now();
let last_value = stream.next().await;
let t1 = Instant::now();
println!("=> {last_value:?} after {:?}", t1.duration_since(t0));
assert!(t1.duration_since(t0) > Duration::from_millis(90));
assert_eq!(Some(3), last_value);
// Fourth item: running again, so it should arrive quickly.
println!("next().await...");
let t0 = Instant::now();
let value = stream.next().await;
let t1 = Instant::now();
println!("=> {value:?} after {:?}", t1.duration_since(t0));
assert!(t1.duration_since(t0) < Duration::from_millis(100));
assert_eq!(Some(4), value);
// Block long enough for the helper thread to have set `Stop`.
std::thread::sleep(Duration::from_millis(105));
// Stopped: the stream ends although the wrapped stream still holds item 5.
println!("next().await...");
let t0 = Instant::now();
let last_value = stream.next().await;
let t1 = Instant::now();
println!("=> {last_value:?} after {:?}", t1.duration_since(t0));
assert!(t1.duration_since(t0) < Duration::from_millis(90));
assert_eq!(None, last_value);
thread.join().unwrap();
}
}