use futures::stream::{Stream, StreamExt};
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tokio::time::{self, Duration, Instant, Sleep, sleep_until};
/// A stream adapter that ends its wrapped stream once a deadline is reached.
///
/// Created by [`until_deadline`]. Items from the inner stream are forwarded
/// until either the inner stream finishes or the deadline elapses, whichever
/// happens first.
#[derive(Debug)]
#[must_use = "streams do nothing unless polled"]
pub struct DeadlineStream<S> {
    /// The wrapped stream whose items are forwarded.
    stream: S,
    /// Timer that completes at the deadline. `Sleep` is `!Unpin`, so it is
    /// heap-pinned here, which keeps `DeadlineStream<S>` `Unpin` whenever `S` is.
    sleep: Pin<Box<Sleep>>,
}
impl<S: Stream + Unpin> Stream for DeadlineStream<S> {
    type Item = S::Item;

    /// Polls the deadline timer first and, only if it has not elapsed, the inner stream.
    ///
    /// Returns `Poll::Ready(None)` as soon as the deadline timer completes, even if the
    /// inner stream would have had an item ready on this same poll. Otherwise the inner
    /// stream's poll result is forwarded unchanged. Each inner poll outcome is logged at
    /// `trace` level via `tracing`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `S: Unpin` and `Pin<Box<Sleep>>` is `Unpin`, so `Self: Unpin` and we can
        // safely work through a plain `&mut Self`.
        let this = self.get_mut();

        // Polling the timer first also registers `cx`'s waker with it (per the `Future`
        // contract when it returns `Pending`), so this task is woken at the deadline even
        // if the inner stream stays pending.
        if this.sleep.as_mut().poll(cx).is_ready() {
            return Poll::Ready(None);
        }

        let val = this.stream.poll_next_unpin(cx);
        match &val {
            Poll::Ready(Some(_)) => tracing::trace!("DeadlineStream: received item"),
            Poll::Ready(None) => tracing::trace!("DeadlineStream: underlying stream ended"),
            Poll::Pending => tracing::trace!("DeadlineStream: waiting for next item"),
        }
        val
    }

    /// Bounds on the number of items still to be yielded.
    ///
    /// The deadline may cut the stream short at any point, so no lower bound is
    /// guaranteed. The adapter never yields more than the inner stream, so the inner
    /// stream's upper bound is forwarded.
    fn size_hint(&self) -> (usize, Option<usize>) {
        (0, self.stream.size_hint().1)
    }
}
/// Wraps `stream` so that it terminates no later than `deadline`.
///
/// The returned [`DeadlineStream`] forwards items from `stream` until that stream
/// ends or `deadline` is reached. The deadline timer is created eagerly with
/// [`sleep_until`], so, per tokio's documentation for that function, this panics if
/// called outside a tokio runtime that has its time driver enabled.
pub fn until_deadline<S>(stream: S, deadline: Instant) -> DeadlineStream<S>
where
    S: Stream + Unpin,
{
    let timer = Box::pin(sleep_until(deadline));
    DeadlineStream {
        sleep: timer,
        stream,
    }
}
#[cfg(test)]
mod tests {
    use futures::stream::{self, StreamExt};
    use tokio::pin;

    use super::*;

    // NOTE: these tests run on real time with roughly 100 ms of slack between
    // item arrival times and the deadline.

    /// Builds a stream that, for each entry of `sleep_times_ms`, sleeps that many
    /// milliseconds and then yields the entry. The stream is passed through
    /// `until_deadline` with a deadline `deadline_ms` from now, and every item that
    /// makes it through is collected.
    async fn run_deadline_test(sleep_times_ms: Vec<u64>, deadline_ms: u64) -> Vec<u64> {
        let stream = stream::iter(sleep_times_ms).then(|x| {
            let sleep = time::sleep(Duration::from_millis(x));
            async move {
                sleep.await;
                x
            }
        });
        let deadline = Instant::now() + Duration::from_millis(deadline_ms);
        // `then` produces a `!Unpin` stream (it holds async blocks); pin it on the
        // stack so it satisfies `until_deadline`'s `Unpin` bound.
        pin!(stream);
        until_deadline(stream, deadline).collect().await
    }

    #[tokio::test]
    async fn test_deadline_exceeded() {
        let sleep_times_ms = vec![100, 100, 200, 50];
        let deadline_ms = 300;
        let result = run_deadline_test(sleep_times_ms, deadline_ms).await;
        assert_eq!(result, vec![100, 100]);
    }

    #[tokio::test]
    async fn test_complete_before_deadline() {
        let sleep_times_ms = vec![100, 50, 50];
        let deadline_ms = 300;
        let result = run_deadline_test(sleep_times_ms, deadline_ms).await;
        assert_eq!(result, vec![100, 50, 50]);
    }

    #[tokio::test]
    async fn test_empty_stream() {
        let result = run_deadline_test(Vec::new(), 300).await;
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn test_deadline_already_reached() {
        // The deadline is "now", so it fires long before the first 100 ms item.
        let result = run_deadline_test(vec![100, 100], 0).await;
        assert!(result.is_empty());
    }
}