dynamo_runtime/utils/
stream.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use futures::stream::{Stream, StreamExt};
17use std::{
18    future::Future,
19    pin::Pin,
20    task::{Context, Poll},
21};
22
23use tokio::time::{self, Duration, Instant, Sleep, sleep_until};
24
25pub struct DeadlineStream<S> {
26    stream: S,
27    sleep: Pin<Box<Sleep>>,
28}
29
30impl<S: Stream + Unpin> Stream for DeadlineStream<S> {
31    type Item = S::Item;
32
33    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
34        // Check if we've passed the deadline
35        if Pin::new(&mut self.sleep).poll(cx).is_ready() {
36            // The deadline expired; end the stream now
37            return Poll::Ready(None);
38        }
39
40        // Otherwise, poll the underlying stream
41        let val = self.as_mut().stream.poll_next_unpin(cx);
42        // Log the poll result and return it
43        match &val {
44            Poll::Ready(Some(_)) => tracing::trace!("DeadlineStream: received item"),
45            Poll::Ready(None) => tracing::trace!("DeadlineStream: underlying stream ended"),
46            Poll::Pending => tracing::trace!("DeadlineStream: waiting for next item"),
47        }
48        val
49    }
50}
51
52pub fn until_deadline<S: Stream + Unpin>(stream: S, deadline: Instant) -> DeadlineStream<S> {
53    DeadlineStream {
54        stream,
55        // Set an async task that sleeps until deadline and wakes up to cancel the stream
56        sleep: Box::pin(sleep_until(deadline)),
57    }
58}
59
#[cfg(test)]
mod tests {
    use futures::stream::{self, StreamExt};
    use tokio::pin;

    use super::*;

    /// Builds a stream whose i-th item takes `sleep_times_ms[i]` milliseconds
    /// to produce (after the previous item). The stream is wrapped with
    /// `until_deadline`, and every item that arrives before `deadline_ms`
    /// elapses is collected.
    async fn run_deadline_test(sleep_times_ms: Vec<u64>, deadline_ms: u64) -> Vec<u64> {
        let stream = stream::iter(sleep_times_ms).then(|x| {
            let sleep = time::sleep(Duration::from_millis(x));
            async move {
                sleep.await;
                x
            }
        });

        let deadline = Instant::now() + Duration::from_millis(deadline_ms);
        let mut result = Vec::new();

        pin!(stream);
        let mut stream = until_deadline(stream, deadline);

        while let Some(x) = stream.next().await {
            result.push(x);
        }

        result
    }

    #[tokio::test]
    async fn test_deadline_exceeded() {
        // Items arrive at ~50ms and ~100ms. The third would arrive at ~600ms,
        // well past the 250ms deadline. The wide gaps on both sides of the
        // deadline keep the test stable on loaded CI machines.
        let sleep_times_ms = vec![50, 50, 500, 50];
        let deadline_ms = 250;

        let result = run_deadline_test(sleep_times_ms, deadline_ms).await;
        // Since deadline is exceeded, only the items before deadline should be returned
        assert_eq!(result, vec![50, 50]);
    }

    #[tokio::test]
    async fn test_complete_before_deadline() {
        // The stream finishes at ~200ms. The generous deadline leaves ample
        // slack for scheduler delays without making the test any slower.
        let sleep_times_ms = vec![100, 50, 50];
        let deadline_ms = 1000;

        let result = run_deadline_test(sleep_times_ms, deadline_ms).await;
        // Since deadline is not exceeded, all items should be returned from stream
        assert_eq!(result, vec![100, 50, 50]);
    }

    #[tokio::test]
    async fn test_empty_stream_ends_immediately() {
        // An empty inner stream ends on the first poll; the adapter must not
        // wait for the deadline.
        let result = run_deadline_test(Vec::new(), 1000).await;
        assert_eq!(result, Vec::<u64>::new());
    }
}