easyfix_session/io/
time.rs

1use std::{
2    fmt,
3    future::Future,
4    pin::Pin,
5    sync::atomic::{AtomicBool, Ordering},
6    task::{Context, Poll, ready},
7    time::{Duration, Instant},
8};
9
10use futures_core::Stream;
11use pin_project::pin_project;
12use tokio::time::interval_at;
13use tokio_stream::{StreamExt, adapters::Fuse};
14
/// Process-wide switch that selects the timer backend used by this module.
///
/// It starts out `false`. `timeout`, `timeout_at` and `timeout_stream` read it
/// on every call: `true` selects the busy-waiting implementations from this
/// file, `false` selects the Tokio timers.
static BUSYWAIT_TIMEOUTS: AtomicBool = AtomicBool::new(false);

/// Switches the busy-waiting timer backend on (`true`) or off (`false`).
///
/// The flag is stored with relaxed ordering and is only consulted when a
/// timeout is created, so timeouts that already exist keep the backend they
/// were created with.
#[doc(hidden)]
pub fn enable_busywait_timers(enable_busywait: bool) {
    BUSYWAIT_TIMEOUTS.store(enable_busywait, Ordering::Relaxed);
}
21
/// Returns an [`Instant`] roughly 30 years after the moment of the call.
///
/// Serves as a stand-in for "never": the API offers neither a maximum
/// `Instant` nor a way to turn a specific future date into an instant.
/// The horizon is kept modest on purpose, because 1000 years overflows on
/// macOS and 100 years overflows on FreeBSD
/// (see tokio sources, `src/tokio/time/instant.rs`).
fn far_future() -> Instant {
    const SECONDS_PER_DAY: u64 = 86400;
    const DAYS_PER_YEAR: u64 = 365;
    const YEARS_AHEAD: u64 = 30;

    let horizon = Duration::from_secs(SECONDS_PER_DAY * DAYS_PER_YEAR * YEARS_AHEAD);
    Instant::now() + horizon
}
30
31pub async fn timeout<T>(
32    duration: Duration,
33    future: impl Future<Output = T>,
34) -> Result<T, TimeElapsed> {
35    if BUSYWAIT_TIMEOUTS.load(Ordering::Relaxed) {
36        BusywaitTimeout::new(future, duration).await
37    } else {
38        tokio::time::timeout(duration, future)
39            .await
40            .map_err(|_| TimeElapsed(()))
41    }
42}
43
44pub async fn timeout_at<T>(
45    deadline: Instant,
46    future: impl Future<Output = T>,
47) -> Result<T, TimeElapsed> {
48    if BUSYWAIT_TIMEOUTS.load(Ordering::Relaxed) {
49        BusywaitTimeout::with_deadline(future, deadline).await
50    } else {
51        tokio::time::timeout_at(deadline.into(), future)
52            .await
53            .map_err(|_| TimeElapsed(()))
54    }
55}
56
57#[pin_project(project = TimeoutStreamProj)]
58pub enum TimeoutStream<S> {
59    Busywait(#[pin] BusywaitTimeoutStream<S>),
60    Tokio(#[pin] tokio_stream::adapters::TimeoutRepeating<S>),
61}
62
63impl<S: Stream> Stream for TimeoutStream<S> {
64    type Item = Result<S::Item, TimeElapsed>;
65
66    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
67        match self.project() {
68            TimeoutStreamProj::Busywait(stream) => stream.poll_next(cx),
69            TimeoutStreamProj::Tokio(stream) => {
70                let result = ready!(stream.poll_next(cx));
71                Poll::Ready(result.map(|r| r.map_err(|_| TimeElapsed(()))))
72            }
73        }
74    }
75}
76
77pub fn timeout_stream<S>(duration: Duration, stream: S) -> TimeoutStream<S>
78where
79    S: Stream,
80{
81    if BUSYWAIT_TIMEOUTS.load(Ordering::Relaxed) {
82        TimeoutStream::Busywait(BusywaitTimeoutStream::new(stream, duration))
83    } else {
84        // skip first tick that would otherwise get timeout to trigger immediately
85        // during first poll operation
86        let timeout_interval_start = tokio::time::Instant::now()
87            .checked_add(duration)
88            .unwrap_or_else(|| far_future().into());
89        TimeoutStream::Tokio(
90            stream.timeout_repeating(interval_at(timeout_interval_start, duration)),
91        )
92    }
93}
94
/// Error reported when a timeout expires before the guarded operation
/// produced a result.
///
/// The private unit field keeps the type from being constructed outside this
/// module.
#[derive(Debug)]
pub struct TimeElapsed(());

impl fmt::Display for TimeElapsed {
    /// Writes the fixed message `Time elapsed`.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str("Time elapsed")
    }
}

impl std::error::Error for TimeElapsed {}

impl From<TimeElapsed> for std::io::Error {
    /// Converts the timeout into an I/O error of kind
    /// [`std::io::ErrorKind::TimedOut`].
    fn from(_err: TimeElapsed) -> std::io::Error {
        std::io::Error::from(std::io::ErrorKind::TimedOut)
    }
}
111
112struct Sleep {
113    wake_time: Instant,
114}
115
116impl Sleep {
117    fn new(duration: Duration) -> Sleep {
118        Sleep {
119            wake_time: Instant::now()
120                .checked_add(duration)
121                .unwrap_or_else(far_future),
122        }
123    }
124
125    fn with_wake_time(wake_time: Instant) -> Sleep {
126        Sleep { wake_time }
127    }
128
129    fn reset(&mut self, duration: Duration) {
130        self.wake_time = Instant::now()
131            .checked_add(duration)
132            .unwrap_or_else(far_future)
133    }
134}
135
136impl Future for Sleep {
137    type Output = ();
138
139    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
140        if self.wake_time > Instant::now() {
141            cx.waker().wake_by_ref();
142            Poll::Pending
143        } else {
144            Poll::Ready(())
145        }
146    }
147}
148
149#[pin_project]
150struct BusywaitTimeout<T> {
151    #[pin]
152    value: T,
153    #[pin]
154    delay: Sleep,
155}
156
157impl<T> BusywaitTimeout<T> {
158    pub fn new(value: T, delay: Duration) -> BusywaitTimeout<T> {
159        BusywaitTimeout {
160            value,
161            delay: Sleep::new(delay),
162        }
163    }
164
165    pub fn with_deadline(value: T, deadline: Instant) -> BusywaitTimeout<T> {
166        BusywaitTimeout {
167            value,
168            delay: Sleep::with_wake_time(deadline),
169        }
170    }
171}
172
173impl<T> Future for BusywaitTimeout<T>
174where
175    T: Future,
176{
177    type Output = Result<T::Output, TimeElapsed>;
178
179    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
180        let this = self.project();
181
182        if let Poll::Ready(value) = this.value.poll(cx) {
183            Poll::Ready(Ok(value))
184        } else {
185            match this.delay.poll(cx) {
186                Poll::Ready(()) => Poll::Ready(Err(TimeElapsed(()))),
187                Poll::Pending => Poll::Pending,
188            }
189        }
190    }
191}
192
193#[pin_project]
194pub struct BusywaitTimeoutStream<S> {
195    #[pin]
196    stream: Fuse<S>,
197    #[pin]
198    deadline: Sleep,
199    duration: Duration,
200    poll_deadline: bool,
201}
202
203impl<S: Stream> BusywaitTimeoutStream<S> {
204    fn new(stream: S, duration: Duration) -> Self {
205        BusywaitTimeoutStream {
206            stream: stream.fuse(),
207            deadline: Sleep::new(duration),
208            duration,
209            poll_deadline: true,
210        }
211    }
212}
213
214impl<S: Stream> Stream for BusywaitTimeoutStream<S> {
215    type Item = Result<S::Item, TimeElapsed>;
216
217    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
218        let mut this = self.project();
219
220        match this.stream.poll_next(cx) {
221            Poll::Ready(v) => {
222                if v.is_some() {
223                    this.deadline.reset(*this.duration);
224                    *this.poll_deadline = true;
225                }
226                Poll::Ready(v.map(Ok))
227            }
228            Poll::Pending => {
229                if *this.poll_deadline {
230                    ready!(this.deadline.poll(cx));
231                    *this.poll_deadline = false;
232                    Poll::Ready(Some(Err(TimeElapsed(()))))
233                } else {
234                    this.deadline.reset(*this.duration);
235                    *this.poll_deadline = true;
236                    Poll::Pending
237                }
238            }
239        }
240    }
241
242    fn size_hint(&self) -> (usize, Option<usize>) {
243        let (lower, upper) = self.stream.size_hint();
244
245        // The timeout stream may insert an error before and after each message
246        // from the underlying stream, but no more than one error between each
247        // message. Hence the upper bound is computed as 2x+1.
248
249        fn twice_plus_one(value: Option<usize>) -> Option<usize> {
250            value?.checked_mul(2)?.checked_add(1)
251        }
252
253        (lower, twice_plus_one(upper))
254    }
255}