dynamo_runtime/utils/
stream.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use futures::stream::{Stream, StreamExt};
5use std::{
6    future::Future,
7    pin::Pin,
8    task::{Context, Poll},
9};
10
11use tokio::time::{self, Duration, Instant, Sleep, sleep_until};
12
13pub struct DeadlineStream<S> {
14    stream: S,
15    sleep: Pin<Box<Sleep>>,
16}
17
18impl<S: Stream + Unpin> Stream for DeadlineStream<S> {
19    type Item = S::Item;
20
21    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
22        // Check if we've passed the deadline
23        if Pin::new(&mut self.sleep).poll(cx).is_ready() {
24            // The deadline expired; end the stream now
25            return Poll::Ready(None);
26        }
27
28        // Otherwise, poll the underlying stream
29        let val = self.as_mut().stream.poll_next_unpin(cx);
30        // Log the poll result and return it
31        match &val {
32            Poll::Ready(Some(_)) => tracing::trace!("DeadlineStream: received item"),
33            Poll::Ready(None) => tracing::trace!("DeadlineStream: underlying stream ended"),
34            Poll::Pending => tracing::trace!("DeadlineStream: waiting for next item"),
35        }
36        val
37    }
38}
39
40pub fn until_deadline<S: Stream + Unpin>(stream: S, deadline: Instant) -> DeadlineStream<S> {
41    DeadlineStream {
42        stream,
43        // Set an async task that sleeps until deadline and wakes up to cancel the stream
44        sleep: Box::pin(sleep_until(deadline)),
45    }
46}
47
#[cfg(test)]
mod tests {
    use futures::stream::{self, StreamExt};
    use tokio::pin;

    use super::*;

    /// Drives a deadline-wrapped stream and returns every item it yielded.
    ///
    /// The `i`-th item becomes ready `sleep_times_ms[i]` ms after the previous
    /// one. Each sleep is created only when `then` pulls that item, so
    /// readiness times are cumulative. The deadline is `deadline_ms` from the
    /// moment the helper starts waiting.
    async fn run_deadline_test(sleep_times_ms: Vec<u64>, deadline_ms: u64) -> Vec<u64> {
        let stream = stream::iter(sleep_times_ms).then(|x| {
            let sleep = time::sleep(Duration::from_millis(x));
            async move {
                sleep.await;
                x
            }
        });

        let deadline = Instant::now() + Duration::from_millis(deadline_ms);

        // `then` yields a `!Unpin` stream; pinning it on the stack satisfies
        // `until_deadline`'s `S: Unpin` bound.
        pin!(stream);
        until_deadline(stream, deadline).collect().await
    }

    #[tokio::test]
    async fn test_deadline_exceeded() {
        // Items are ready at t = 100, 200, 400 and 450 ms; the deadline is 300 ms.
        let sleep_times_ms = vec![100, 100, 200, 50];
        let deadline_ms = 300;

        let result = run_deadline_test(sleep_times_ms, deadline_ms).await;
        // Only the items ready before the deadline are returned.
        assert_eq!(result, vec![100, 100]);
    }

    #[tokio::test]
    async fn test_complete_before_deadline() {
        // Items are ready at t = 100, 150 and 200 ms, all before the 300 ms deadline.
        let sleep_times_ms = vec![100, 50, 50];
        let deadline_ms = 300;

        let result = run_deadline_test(sleep_times_ms, deadline_ms).await;
        // The inner stream finishes first, so every item is returned.
        assert_eq!(result, vec![100, 50, 50]);
    }

    #[tokio::test]
    async fn test_empty_stream_ends_immediately() {
        // An empty inner stream ends on its own long before the deadline.
        let result = run_deadline_test(Vec::new(), 300).await;
        assert!(result.is_empty());
    }
}