// easyfix_session/io/time.rs

1use std::{
2    fmt,
3    future::Future,
4    pin::Pin,
5    sync::atomic::{AtomicBool, Ordering},
6    task::{Context, Poll, ready},
7    time::{Duration, Instant},
8};
9
10use futures_core::Stream;
11use pin_project::pin_project;
12use tokio::time::interval_at;
13use tokio_stream::{StreamExt, adapters::Fuse};
14
/// Global switch selecting the timer implementation used by this module:
/// `true` selects the busy-wait timers, `false` (the initial value) selects tokio's timers.
static BUSYWAIT_TIMEOUTS: AtomicBool = AtomicBool::new(false);

/// Turns busy-wait timers on (`true`) or off (`false`).
///
/// The switch is read when a timeout is set up: when `timeout_stream` is called, or when the
/// future returned by `timeout` is first polled. Timers that are already set up are unaffected.
#[doc(hidden)]
pub fn enable_busywait_timers(enable_busywait: bool) {
    let switch = &BUSYWAIT_TIMEOUTS;
    switch.store(enable_busywait, Ordering::Relaxed);
}
21
22pub async fn timeout<T>(
23    duration: Duration,
24    future: impl Future<Output = T>,
25) -> Result<T, TimeElapsed> {
26    if BUSYWAIT_TIMEOUTS.load(Ordering::Relaxed) {
27        BusywaitTimeout::new(future, duration).await
28    } else {
29        tokio::time::timeout(duration, future)
30            .await
31            .map_err(|_| TimeElapsed(()))
32    }
33}
34
35#[pin_project(project = TimeoutStreamProj)]
36pub enum TimeoutStream<S> {
37    Busywait(#[pin] BusywaitTimeoutStream<S>),
38    Tokio(#[pin] tokio_stream::adapters::TimeoutRepeating<S>),
39}
40
41impl<S: Stream> Stream for TimeoutStream<S> {
42    type Item = Result<S::Item, TimeElapsed>;
43
44    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
45        match self.project() {
46            TimeoutStreamProj::Busywait(stream) => stream.poll_next(cx),
47            TimeoutStreamProj::Tokio(stream) => {
48                let result = ready!(stream.poll_next(cx));
49                Poll::Ready(result.map(|r| r.map_err(|_| TimeElapsed(()))))
50            }
51        }
52    }
53}
54
55pub fn timeout_stream<S>(duration: Duration, stream: S) -> TimeoutStream<S>
56where
57    S: Stream,
58{
59    if BUSYWAIT_TIMEOUTS.load(Ordering::Relaxed) {
60        TimeoutStream::Busywait(BusywaitTimeoutStream::new(stream, duration))
61    } else {
62        // skip first tick that would otherwise get timeout to trigger immediately
63        // during first poll operation
64        let timeout_interval_start = tokio::time::Instant::now()
65            .checked_add(duration)
66            .expect("timeout value too long");
67        TimeoutStream::Tokio(
68            stream.timeout_repeating(interval_at(timeout_interval_start, duration)),
69        )
70    }
71}
72
/// Error produced when a timeout set up by this module expires.
///
/// It carries no data, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TimeElapsed(());

impl fmt::Display for TimeElapsed {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Time elapsed")
    }
}

impl std::error::Error for TimeElapsed {}

/// Converts the timeout into an I/O error of kind [`std::io::ErrorKind::TimedOut`].
impl From<TimeElapsed> for std::io::Error {
    fn from(_err: TimeElapsed) -> std::io::Error {
        std::io::ErrorKind::TimedOut.into()
    }
}
89
/// Busy-wait timer: completes once the wall clock reaches `wake_time`.
///
/// While pending it re-wakes its own task on every poll, so the executor keeps polling it
/// continuously instead of relying on a timer driver.
struct Sleep {
    /// Instant at or after which the sleep is complete.
    wake_time: Instant,
}

impl Sleep {
    /// Creates a timer that completes `duration` from now.
    ///
    /// # Panics
    /// Panics if `duration` cannot be added to the current instant.
    fn new(duration: Duration) -> Sleep {
        Sleep {
            wake_time: Self::deadline_after(duration),
        }
    }

    /// Re-arms the timer to complete `duration` from now.
    ///
    /// # Panics
    /// Panics if `duration` cannot be added to the current instant.
    fn reset(&mut self, duration: Duration) {
        self.wake_time = Self::deadline_after(duration);
    }

    /// Returns the instant `duration` from now, panicking on overflow.
    fn deadline_after(duration: Duration) -> Instant {
        Instant::now()
            .checked_add(duration)
            .expect("sleep time too long")
    }
}

impl Future for Sleep {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        if Instant::now() >= self.wake_time {
            return Poll::Ready(());
        }
        // Not due yet: ask to be polled again straight away (the busy-wait).
        cx.waker().wake_by_ref();
        Poll::Pending
    }
}
122
123#[pin_project]
124struct BusywaitTimeout<T> {
125    #[pin]
126    value: T,
127    #[pin]
128    delay: Sleep,
129}
130
131impl<T> BusywaitTimeout<T> {
132    pub fn new(value: T, delay: Duration) -> BusywaitTimeout<T> {
133        BusywaitTimeout {
134            value,
135            delay: Sleep::new(delay),
136        }
137    }
138}
139
140impl<T> Future for BusywaitTimeout<T>
141where
142    T: Future,
143{
144    type Output = Result<T::Output, TimeElapsed>;
145
146    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
147        let this = self.project();
148
149        if let Poll::Ready(value) = this.value.poll(cx) {
150            Poll::Ready(Ok(value))
151        } else {
152            match this.delay.poll(cx) {
153                Poll::Ready(()) => Poll::Ready(Err(TimeElapsed(()))),
154                Poll::Pending => Poll::Pending,
155            }
156        }
157    }
158}
159
160#[pin_project]
161pub struct BusywaitTimeoutStream<S> {
162    #[pin]
163    stream: Fuse<S>,
164    #[pin]
165    deadline: Sleep,
166    duration: Duration,
167    poll_deadline: bool,
168}
169
170impl<S: Stream> BusywaitTimeoutStream<S> {
171    fn new(stream: S, duration: Duration) -> Self {
172        BusywaitTimeoutStream {
173            stream: stream.fuse(),
174            deadline: Sleep::new(duration),
175            duration,
176            poll_deadline: true,
177        }
178    }
179}
180
181impl<S: Stream> Stream for BusywaitTimeoutStream<S> {
182    type Item = Result<S::Item, TimeElapsed>;
183
184    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
185        let mut this = self.project();
186
187        match this.stream.poll_next(cx) {
188            Poll::Ready(v) => {
189                if v.is_some() {
190                    this.deadline.reset(*this.duration);
191                    *this.poll_deadline = true;
192                }
193                Poll::Ready(v.map(Ok))
194            }
195            Poll::Pending => {
196                if *this.poll_deadline {
197                    ready!(this.deadline.poll(cx));
198                    *this.poll_deadline = false;
199                    Poll::Ready(Some(Err(TimeElapsed(()))))
200                } else {
201                    this.deadline.reset(*this.duration);
202                    *this.poll_deadline = true;
203                    Poll::Pending
204                }
205            }
206        }
207    }
208
209    fn size_hint(&self) -> (usize, Option<usize>) {
210        let (lower, upper) = self.stream.size_hint();
211
212        // The timeout stream may insert an error before and after each message
213        // from the underlying stream, but no more than one error between each
214        // message. Hence the upper bound is computed as 2x+1.
215
216        fn twice_plus_one(value: Option<usize>) -> Option<usize> {
217            value?.checked_mul(2)?.checked_add(1)
218        }
219
220        (lower, twice_plus_one(upper))
221    }
222}