// ogre_stream_ext/timeout_ext.rs

use std::fmt::{Debug, Display, Formatter};
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering::Relaxed;
use std::task::{Context, Poll};
use std::time::{Duration, Instant};

use async_io::Timer;
use futures::stream::FusedStream;
use futures::Stream;

/// Error yielded by [`StreamWithItemTimeout`] when no upstream item arrived before its item
/// timer fired.
///
/// The adapter yields this error at most once; afterwards it reports end-of-stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ItemTimeoutErr {
    /// The `Instant` reported by the item timer when it fired.
    pub previous_instant: Instant,
}

impl Display for ItemTimeoutErr {
    /// Human-readable description (the `Debug` derive remains available for diagnostics).
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "no stream item arrived before the item timeout elapsed (timer fired at {:?})",
            self.previous_instant
        )
    }
}

impl std::error::Error for ItemTimeoutErr {}

30pub trait StreamExtCloseOnItemTimeout: Stream + Sized {
48
49 fn close_stream_on_item_timeout(
50 self,
51 timeout: Duration,
52 ) -> StreamWithItemTimeout<Self> {
53 StreamWithItemTimeout::new(self, timeout)
54 }
55}
56
57impl<S: Stream> StreamExtCloseOnItemTimeout for S {}
58
59
60pub struct StreamWithItemTimeout<UpstreamType>
63where
64 UpstreamType: Stream,
65{
66 upstream: UpstreamType,
67 timeout: Duration,
68 timer: async_io::Timer,
69 timedout: AtomicBool,
71}
72
73impl<UpstreamType> StreamWithItemTimeout<UpstreamType>
74where
75 UpstreamType: Stream,
76{
77 pub fn new(upstream: UpstreamType, timeout: Duration) -> Self {
78 StreamWithItemTimeout {
79 upstream,
80 timeout,
81 timer: Timer::after(timeout),
82 timedout: AtomicBool::new(false),
83 }
84 }
85}
86
87impl<UpstreamType, ItemType> Stream for StreamWithItemTimeout<UpstreamType>
88where
89 UpstreamType: Stream<Item = ItemType> + Unpin,
90{
91 type Item = Result<ItemType, ItemTimeoutErr>;
92
93 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
94 if self.timedout.load(Relaxed) {
95 return Poll::Ready(None)
97 }
98 let timeout = self.timeout;
99 match Pin::new(&mut self.upstream).poll_next(cx) {
100 Poll::Ready(Some(item)) => {
101 _ = std::mem::replace(&mut self.timer, Timer::after(timeout));
103 Poll::Ready(Some(Ok(item)))
104 },
105
106 Poll::Ready(None) => {
107 Poll::Ready(None)
109 }
110
111 Poll::Pending => {
112 match Pin::new(&mut self.timer).poll(cx) {
114 Poll::Pending => Poll::Pending, Poll::Ready(instant) => {
116 self.timedout.store(true, Relaxed);
118 Poll::Ready(Some(Err(ItemTimeoutErr { previous_instant: instant })))
119 },
120 }
121 },
122 }
123 }
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{SinkExt, StreamExt};
    use std::time::Duration;

    /// Item timeout shared by all tests.
    const TIMEOUT: Duration = Duration::from_millis(100);

    /// Delay before sending item `i` from a well-behaved producer. It grows by 9ms per item and
    /// peaks at 81ms, leaving about 19ms of headroom under `TIMEOUT`, so ordinary scheduling
    /// jitter cannot cause a spurious timeout.
    fn regular_gap(i: u64) -> Duration {
        Duration::from_millis((i % 10) * 9)
    }

    /// Items arriving within the timeout pass through untouched; the first gap that clearly
    /// exceeds the timeout yields exactly one error, after which the stream ends.
    #[tokio::test]
    async fn basic_timeout_requirements() {
        let (mut tx, rx) = futures::channel::mpsc::channel(0);
        let mut out_stream = rx.boxed().close_stream_on_item_timeout(TIMEOUT);
        _ = tokio::spawn(async move {
            for i in 0..15_u64 {
                // Item 10 is delayed far past the timeout (rather than by ~1ms), so the test
                // does not hinge on millisecond-precise scheduling.
                let gap = if i < 10 { regular_gap(i) } else { TIMEOUT * 2 };
                tokio::time::sleep(gap).await;
                tx.send(i).await.expect("Error sending an element");
            }
        });

        for expected_item in 0..=9_u64 {
            let observed_item = out_stream
                .next()
                .await
                .unwrap_or_else(|| panic!("Stream ended prematurely at #{expected_item}"))
                .unwrap_or_else(|err| panic!("Timeout happened prematurely at #{expected_item}: {err}"));
            assert_eq!(observed_item, expected_item, "Received item is wrong");
        }

        let observed_timeout_result = out_stream
            .next()
            .await
            .expect("Stream ended prematurely -- without yielding the Timeout error");
        assert!(
            observed_timeout_result.is_err(),
            "item of value '10' was yielded without timing out. Yielded result: {observed_timeout_result:?}"
        );

        assert!(out_stream.next().await.is_none(), "Stream did not end after a timeout was detected");
    }

    /// A stream that never produces anything times out roughly `TIMEOUT` after construction.
    #[tokio::test]
    async fn timeout_before_first_element() {
        // Generous allowance for scheduler/timer latency; a 1ms window is flaky on busy machines.
        const TOLERANCE: Duration = Duration::from_millis(50);

        let (_tx, rx) = futures::channel::mpsc::channel::<()>(0);
        // Start the stopwatch before building the adapter: its timer is armed on construction,
        // so the timeout should not be observed earlier than `TIMEOUT` after this point.
        let stopwatch = Instant::now();
        let mut out_stream = rx.boxed().close_stream_on_item_timeout(TIMEOUT);

        let observed_result = out_stream
            .next()
            .await
            .expect("Stream ended prematurely -- without yielding the Timeout error");
        let elapsed_time = stopwatch.elapsed();
        assert!(
            observed_result.is_err(),
            "an item was yielded without timing out. Yielded result: {observed_result:?}"
        );
        assert!(
            elapsed_time >= TIMEOUT && elapsed_time < TIMEOUT + TOLERANCE,
            "The Timeout error did not happen at the right time: elapsed {elapsed_time:?}, expected about {TIMEOUT:?}"
        );

        assert!(out_stream.next().await.is_none(), "Stream did not end after a timeout was detected");
    }

    /// A producer that always sends within the timeout and then closes is passed through intact.
    #[tokio::test]
    async fn regular_stream_usage() {
        let (mut tx, rx) = futures::channel::mpsc::channel(0);
        let mut out_stream = rx.boxed().close_stream_on_item_timeout(TIMEOUT);
        _ = tokio::spawn(async move {
            for i in 0..15_u64 {
                tokio::time::sleep(regular_gap(i)).await;
                tx.send(i).await.expect("Error sending an element");
            }
            tx.close_channel();
        });

        for expected_item in 0..15_u64 {
            let observed_item = out_stream
                .next()
                .await
                .unwrap_or_else(|| panic!("Stream ended prematurely at #{expected_item}"))
                .unwrap_or_else(|err| panic!("Timeout happened prematurely at #{expected_item}: {err}"));
            assert_eq!(observed_item, expected_item, "Received item is wrong");
        }

        assert!(out_stream.next().await.is_none(), "Sanity check failed: Stream did not end");
    }
}