// ogre_stream_ext/timeout_ext.rs

//! Adds `.close_stream_on_item_timeout(max_duration_for_next_element)` to `Stream`s,
//! closing the Stream immediately after yielding the timeout error if no items arrive on time.
//!
//! NOTE: although the crate `futures_time` has a `.timeout()` method for Streams -- which appears to
//!       give what we want -- it does not: it does yield a timeout error once the duration has
//!       elapsed **after** the last element was yielded, but it doesn't close the Stream,
//!       still allowing further elements to be yielded.
8
9use std::fmt::{Debug, Display, Formatter};
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::atomic::AtomicBool;
13use std::sync::atomic::Ordering::Relaxed;
14use std::task::{Context, Poll};
15use std::time::{Duration, Instant};
16use async_io::Timer;
17use futures::Stream;
18
/// Error yielded in place of an item when the upstream `Stream` fails to produce
/// its next element before the configured timeout elapses.
#[derive(Debug)]
pub struct ItemTimeoutErr {
    /// The `Instant` the timeout timer resolved to when it fired
    pub previous_instant: Instant,
}

impl std::error::Error for ItemTimeoutErr {}

impl Display for ItemTimeoutErr {
    /// The `Display` rendering is deliberately the same as the `Debug` one
    /// (formatter flags, such as `#`, are forwarded untouched).
    fn fmt(&self, formatter: &mut Formatter<'_>) -> std::fmt::Result {
        Debug::fmt(self, formatter)
    }
}
29
30/// The extension trait that adds `.close_stream_on_item_timeout(max_duration_for_next_element)` to all `Stream`s.
31///
32/// To use it, do:
33///    use crate::stream_timeout::StreamExtCloseOnItemTimeout;
34///    use futures::StreamExt; // if you also want `.map()`, `.filter()`, etc.
35///
36/// Then:
37///    my_stream
38///       .map(|x| …)
39///       .close_stream_on_item_timeout(Duration::from_secs(...));
40///      .filter_map(|timeout_result| match timeout_result {
41///        Ok(original_element) => Some(original_element),  // unwrap and yield the original element downstream
42///        Err(_timeout_err) => {
43///          // do whatever you need when the timeout is detected and the stream is about to be closed
44///          None  // do not yield the timeout error downstream
45///        }
46///      })
47pub trait StreamExtCloseOnItemTimeout: Stream + Sized {
48
49    fn close_stream_on_item_timeout(
50        self,
51        timeout: Duration,
52    ) -> StreamWithItemTimeout<Self> {
53        StreamWithItemTimeout::new(self, timeout)
54    }
55}
56
57impl<S: Stream> StreamExtCloseOnItemTimeout for S {}
58
59
60/// A Stream wrapper that will yield a timeout error if an item is not generated
61/// within a maximum specified time, closing the Stream in the next `.poll()`
62pub struct StreamWithItemTimeout<UpstreamType>
63where
64    UpstreamType: Stream,
65{
66    upstream: UpstreamType,
67    timeout: Duration,
68    timer: async_io::Timer,
69    /// Flag: did the timeout happen? Close on the next `.poll()`
70    timedout: AtomicBool,
71}
72
73impl<UpstreamType> StreamWithItemTimeout<UpstreamType>
74where
75    UpstreamType: Stream,
76{
77    pub fn new(upstream: UpstreamType, timeout: Duration) -> Self {
78        StreamWithItemTimeout {
79            upstream,
80            timeout,
81            timer: Timer::after(timeout),
82            timedout: AtomicBool::new(false),
83        }
84    }
85}
86
87impl<UpstreamType, ItemType> Stream for StreamWithItemTimeout<UpstreamType>
88where
89    UpstreamType: Stream<Item = ItemType> + Unpin,
90{
91    type Item = Result<ItemType, ItemTimeoutErr>;
92
93    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
94        if self.timedout.load(Relaxed) {
95            // A timeout happened earlier. Close the stream regardless if there are more elements
96            return Poll::Ready(None)
97        }
98        let timeout = self.timeout;
99        match Pin::new(&mut self.upstream).poll_next(cx) {
100            Poll::Ready(Some(item)) => {
101                // reset the timer & return the element
102                _ = std::mem::replace(&mut self.timer, Timer::after(timeout));
103                Poll::Ready(Some(Ok(item)))
104            },
105
106            Poll::Ready(None) => {
107                // stream ended spontaneously without any timeout
108                Poll::Ready(None)
109            }
110
111            Poll::Pending => {
112                // no element available in the stream -- check for timeout
113                match Pin::new(&mut self.timer).poll(cx) {
114                    Poll::Pending => Poll::Pending,     // didn't time out yet
115                    Poll::Ready(instant) => {
116                        // the timer fired -- we have a timeout: yield the error and schedule for stream termination
117                        self.timedout.store(true, Relaxed);
118                        Poll::Ready(Some(Err(ItemTimeoutErr { previous_instant: instant })))
119                    },
120                }
121            },
122        }
123    }
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use futures::{SinkExt, StreamExt};

    /// Asserts our basic timeout functionality works as expected:
    ///   1. If an element takes more than the timeout duration to be yielded, the timeout error will be yielded instead
    ///   2. After the timeout error is yielded, the stream will be closed immediately, regardless if there are more elements available
    #[tokio::test]
    async fn basic_timeout_requirements() {
        let (mut tx, rx) = futures::channel::mpsc::channel(0);
        let mut out_stream = rx
            .boxed()
            .close_stream_on_item_timeout(Duration::from_millis(100));  // set the timeout for yielded items

        // task to publish items -- each item `i` is delayed by ~`i * 10.1`ms, so
        // item of value '10' (and beyond) exceeds the timeout and should not be received
        _ = tokio::spawn(async move {
            for i in 0..15 {
                tokio::time::sleep(Duration::from_millis(((i as f64)*10.1) as u64)).await;
                tx.send(i).await.expect("Error sending an element");
            }
        });

        // consume the 10 OK results (items from 0 to 9)
        for expected_item in 0..=9 {
            let observed_item = out_stream.next().await
                .unwrap_or_else(|| panic!("Stream ended prematurely at #{expected_item}"))
                .unwrap_or_else(|err| panic!("Timeout happened prematurely at #{expected_item}: {err}"));
            assert_eq!(observed_item, expected_item, "Received item is wrong");
        }

        // consume the timeout result (item of value 10 is published after 101ms -- more than the 100ms timeout)
        let observed_timeout_result = out_stream.next().await
            .expect("Stream ended prematurely -- without yielding the Timeout error");
        assert!(observed_timeout_result.is_err(), "item of value '10' was yielded without timing out. Yielded result: {observed_timeout_result:?}");

        // assert the stream ended -- even when there were other elements available
        assert!(out_stream.next().await.is_none(), "Stream did not end after a timeout was detected");
    }

    /// Asserts the timeout still happens if the first element takes too long to be yielded
    #[tokio::test]
    async fn timeout_before_first_element() {

        const TIMEOUT: Duration = Duration::from_millis(100);
        /// Accepted wall-clock deviation: wide enough to absorb timer & scheduler jitter on a
        /// loaded machine, yet small enough (a fraction of `TIMEOUT`) to catch a misbehaving timer
        const TOLERANCE: Duration = Duration::from_millis(20);

        let (_tx, rx) = futures::channel::mpsc::channel::<()>(0);
        let mut out_stream = rx
            .boxed()
            .close_stream_on_item_timeout(TIMEOUT);  // set the timeout for yielded items

        // no elements will be published on the stream

        // consume -- a single timeout item will be yielded after `TIMEOUT` has elapsed
        let stopwatcher = Instant::now();
        let observed_result = out_stream.next().await
            .expect("Stream ended prematurely -- without yielding the Timeout error");
        // take the measurement right away, so nothing else is accounted for in the elapsed time
        let elapsed_time = stopwatcher.elapsed();
        assert!(observed_result.is_err(), "an item was yielded without timing out. Yielded result: {observed_result:?}");
        let deviation = if elapsed_time > TIMEOUT {
            elapsed_time - TIMEOUT
        } else {
            TIMEOUT - elapsed_time
        };
        assert!(deviation < TOLERANCE, "The Timeout error did not happen at the right time: expected ~{TIMEOUT:?}, observed {elapsed_time:?}");

        // assert the stream ended
        assert!(out_stream.next().await.is_none(), "Stream did not end after a timeout was detected");
    }

    /// Asserts we are still able to use the Stream without ever timing out
    #[tokio::test]
    async fn regular_stream_usage() {
        let (mut tx, rx) = futures::channel::mpsc::channel(0);
        let mut out_stream = rx
            .boxed()
            .close_stream_on_item_timeout(Duration::from_millis(100));  // set the timeout for yielded items

        // task to publish items -- delays get closer to the timeout (up to ~90ms), but never reach it
        _ = tokio::spawn(async move {
            for i in 0..15 {
                tokio::time::sleep(Duration::from_millis((((i % 10) as f64)*10.1) as u64)).await;
                tx.send(i).await.expect("Error sending an element");
            }
            // `tx` was moved to this task. Once it ends, the Stream will be closed.
            // So the following is not really necessary
            tx.close_channel();
        });

        // consume the 15 OK results
        for expected_item in 0..15 {
            let observed_item = out_stream.next().await
                .unwrap_or_else(|| panic!("Stream ended prematurely at #{expected_item}"))
                .unwrap_or_else(|err| panic!("Timeout happened prematurely at #{expected_item}: {err}"));
            assert_eq!(observed_item, expected_item, "Received item is wrong");
        }

        // sanity check: assert the stream really ended
        assert!(out_stream.next().await.is_none(), "Sanity check failed: Stream did not end");
    }

}