use futures::{Stream, StreamExt, stream::FusedStream};
use pin_project_lite::pin_project;
use std::{
pin::Pin,
task::{self, Poll},
};
pin_project! {
    /// Stream that emits the most recent item of an inner stream each time a
    /// sampler stream yields.
    ///
    /// Items produced between two sampler ticks replace one another, so only the
    /// latest one is emitted; a tick that finds no new item emits nothing. The
    /// combined stream ends once either the inner stream or the sampler ends
    /// (after possibly emitting one final buffered item).
    #[must_use = "streams do nothing unless polled"]
    pub struct Sample<T, S>
    where
        T: Stream,
    {
        // Source being sampled. Reset to `None` once the source or the sampler
        // has ended; `None` means this stream is terminated.
        #[pin]
        inner: Option<T>,
        // Each item from this stream releases the buffered value, if any.
        #[pin]
        sampler: S,
        // Latest item taken from `inner` that has not been emitted yet.
        value: Option<T::Item>
    }
}
impl<T, S> Sample<T, S>
where
    T: Stream,
{
    /// Creates a `Sample` that has not yet buffered any item from `stream`.
    ///
    /// Neither stream is polled here; both are only driven from `poll_next`.
    pub(super) fn new(stream: T, sampler: S) -> Self
    where
        S: Stream,
    {
        let inner = Some(stream);
        let value = None;
        Self {
            value,
            sampler,
            inner,
        }
    }
}
impl<T: Stream, S: Stream> Stream for Sample<T, S> {
    type Item = T::Item;

    /// Drains every ready item from the inner stream, keeping only the most
    /// recent one. Then drains the sampler: the first tick that finds a
    /// buffered value emits it.
    ///
    /// Termination:
    /// - once the inner stream has ended, this stream ends too. A value that is
    ///   still buffered at that point is emitted only if the sampler ticks
    ///   during the same poll; otherwise it is dropped.
    /// - when the sampler ends, the inner stream is dropped and the buffered
    ///   value (if any) is emitted as the final item.
    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();
        let Some(mut inner) = this.inner.as_mut().as_pin_mut() else {
            return Poll::Ready(None);
        };

        // Pull everything the source has ready; only the newest item is kept.
        // This loop exits either with the source Pending (waker registered) or
        // with the source exhausted and cleared.
        while let Poll::Ready(ready) = inner.poll_next_unpin(cx) {
            match ready {
                Some(value) => *this.value = Some(value),
                None => {
                    this.inner.set(None);
                    break;
                }
            }
        }

        // Source exhausted with nothing buffered: no further item can ever be
        // produced, so finish without polling the sampler. Polling it here
        // would spin forever on an always-ready sampler (e.g. `stream::repeat`).
        if this.inner.is_none() && this.value.is_none() {
            return Poll::Ready(None);
        }

        // Drain the sampler. Ticks that find no buffered value are discarded so
        // that stale ticks cannot release a later value early.
        // NOTE(review): while nothing is buffered, a sampler that never returns
        // Pending keeps this loop running; such a sampler cannot be drained.
        while let Poll::Ready(tick) = this.sampler.poll_next_unpin(cx) {
            match tick {
                Some(_) => {
                    if let Some(value) = this.value.take() {
                        return Poll::Ready(Some(value));
                    }
                }
                None => {
                    // Sampler ended: release the source, then flush whatever is
                    // buffered (or end immediately if nothing is).
                    this.inner.set(None);
                    return Poll::Ready(this.value.take());
                }
            }
        }

        // The sampler is Pending. If the source already ended, the buffered
        // value can no longer be sampled, so the stream ends here.
        if this.inner.is_none() {
            return Poll::Ready(None);
        }
        Poll::Pending
    }
}
impl<T, S> FusedStream for Sample<T, S>
where
    T: Stream,
    S: Stream,
{
    /// Reports whether this stream has finished.
    ///
    /// The inner stream is dropped as soon as either the source or the sampler
    /// ends, and `poll_next` returns `None` from then on.
    fn is_terminated(&self) -> bool {
        Option::is_none(&self.inner)
    }
}
impl<T, S> std::fmt::Debug for Sample<T, S>
where
    T: Stream + std::fmt::Debug,
    T::Item: std::fmt::Debug,
    S: std::fmt::Debug,
{
    /// Formats the inner stream, the sampler, and the currently buffered value.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Sample")
            .field("inner", &self.inner)
            .field("sampler", &self.sampler)
            // The `T::Item: Debug` bound exists so the buffered item can be shown.
            .field("value", &self.value)
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{SinkExt, StreamExt, stream};
    use std::future;
    use tokio_test::{assert_pending, assert_ready_eq};

    #[cfg(feature = "test-util")]
    use std::time::Duration;
    #[cfg(feature = "test-util")]
    use tokio_stream::wrappers::IntervalStream;

    /// Polling context whose waker does nothing, so each test drives the stream
    /// by hand and checks the outcome of every individual poll.
    fn noop_context() -> std::task::Context<'static> {
        std::task::Context::from_waker(futures::task::noop_waker_ref())
    }

    #[tokio::test]
    async fn test_sample() {
        let mut cx = noop_context();
        let (mut source_tx, source_rx) = futures::channel::mpsc::unbounded();
        let (mut tick_tx, tick_rx) = futures::channel::mpsc::unbounded();
        let mut sampled = Sample::new(source_rx, tick_rx);

        // Nothing buffered and no tick yet.
        assert_pending!(sampled.poll_next_unpin(&mut cx));

        // A value on its own is not emitted...
        source_tx.send(1).await.unwrap();
        assert_pending!(sampled.poll_next_unpin(&mut cx));
        // ...until the sampler ticks.
        tick_tx.send(()).await.unwrap();
        assert_ready_eq!(sampled.poll_next_unpin(&mut cx), Some(1));
        assert_pending!(sampled.poll_next_unpin(&mut cx));

        // A tick with nothing new buffered emits nothing.
        tick_tx.send(()).await.unwrap();
        assert_pending!(sampled.poll_next_unpin(&mut cx));

        // Only the latest of several buffered values is emitted.
        source_tx.send(2).await.unwrap();
        source_tx.send(3).await.unwrap();
        assert_pending!(sampled.poll_next_unpin(&mut cx));
        tick_tx.send(()).await.unwrap();
        assert_ready_eq!(sampled.poll_next_unpin(&mut cx), Some(3));
        assert_pending!(sampled.poll_next_unpin(&mut cx));

        // Ending the sampler flushes the buffered value, then ends the stream.
        source_tx.send(4).await.unwrap();
        drop(tick_tx);
        assert_ready_eq!(sampled.poll_next_unpin(&mut cx), Some(4));
        assert_ready_eq!(sampled.poll_next_unpin(&mut cx), None);
    }

    #[tokio::test]
    async fn test_sample_underlying_terminates() {
        let mut cx = noop_context();
        let (mut source_tx, source_rx) = futures::channel::mpsc::unbounded();
        let (mut tick_tx, tick_rx) = futures::channel::mpsc::unbounded();
        let mut sampled = Sample::new(source_rx, tick_rx);

        assert_pending!(sampled.poll_next_unpin(&mut cx));
        source_tx.send(1).await.unwrap();
        assert_pending!(sampled.poll_next_unpin(&mut cx));
        tick_tx.send(()).await.unwrap();
        assert_ready_eq!(sampled.poll_next_unpin(&mut cx), Some(1));
        assert_pending!(sampled.poll_next_unpin(&mut cx));

        // Closing the source ends the sampled stream while the sampler is still open.
        drop(source_tx);
        assert_ready_eq!(sampled.poll_next_unpin(&mut cx), None);
    }

    #[tokio::test]
    async fn test_sample_underlying_terminates_but_sample_yields() {
        let mut cx = noop_context();
        let (mut tick_tx, tick_rx) = futures::channel::mpsc::unbounded();
        let mut sampled = Sample::new(stream::once(future::ready(1)), tick_rx);

        // The source yields its only item and ends within the same poll; since a
        // tick is already waiting, that item is still emitted before the end.
        tick_tx.send(()).await.unwrap();
        assert_ready_eq!(sampled.poll_next_unpin(&mut cx), Some(1));
        assert_ready_eq!(sampled.poll_next_unpin(&mut cx), None);
    }

    #[cfg(feature = "test-util")]
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn test_sample_with_interval() {
        use crate::StreamTools;
        use crate::test_util::delay_items;

        // Sampler ticking every 1500ms, starting immediately.
        let sampler = IntervalStream::new(tokio::time::interval(Duration::from_millis(1500)));

        // Every item carries its own arrival time in milliseconds.
        let arrivals_ms = vec![
            0, 1000, 2000, 2999, 3000, 3001, 4000, 7000, 8000, 8999, 9500, 10500, 15001, 15500,
        ];
        let source = delay_items(
            arrivals_ms
                .into_iter()
                .map(|delay_ms| (Duration::from_millis(delay_ms), delay_ms)),
        );

        let sampled = Sample::new(source, sampler).record_delay();
        let results = sampled.collect::<Vec<_>>().await;

        // (tick time, latest item seen by then); items still buffered when the
        // source ends (15001, 15500) are never sampled.
        let expected_results = vec![
            (Duration::ZERO, 0),
            (Duration::from_millis(1500), 1000),
            (Duration::from_millis(3000), 3000),
            (Duration::from_millis(4500), 4000),
            (Duration::from_millis(7500), 7000),
            (Duration::from_millis(9000), 8999),
            (Duration::from_millis(10500), 10500),
        ];
        assert_eq!(expected_results, results);
    }
}