easyfix_session/io/
time.rs1use std::{
2 fmt,
3 future::Future,
4 pin::Pin,
5 sync::atomic::{AtomicBool, Ordering},
6 task::{Context, Poll, ready},
7 time::{Duration, Instant},
8};
9
10use futures_core::Stream;
11use pin_project::pin_project;
12use tokio::time::interval_at;
13use tokio_stream::{StreamExt, adapters::Fuse};
14
/// Process-wide switch selecting the busy-wait timer implementation
/// (`false`, the default, means tokio timers are used).
static BUSYWAIT_TIMEOUTS: AtomicBool = AtomicBool::new(false);

/// Switches between tokio timers (`false`) and busy-wait timers (`true`).
///
/// The flag is read when a timeout future or timeout stream is created, so the
/// setting applies to timeouts created after this call.
#[doc(hidden)]
pub fn enable_busywait_timers(enable_busywait: bool) {
    let flag = &BUSYWAIT_TIMEOUTS;
    flag.store(enable_busywait, Ordering::Relaxed);
}
21
22pub async fn timeout<T>(
23 duration: Duration,
24 future: impl Future<Output = T>,
25) -> Result<T, TimeElapsed> {
26 if BUSYWAIT_TIMEOUTS.load(Ordering::Relaxed) {
27 BusywaitTimeout::new(future, duration).await
28 } else {
29 tokio::time::timeout(duration, future)
30 .await
31 .map_err(|_| TimeElapsed(()))
32 }
33}
34
35#[pin_project(project = TimeoutStreamProj)]
36pub enum TimeoutStream<S> {
37 Busywait(#[pin] BusywaitTimeoutStream<S>),
38 Tokio(#[pin] tokio_stream::adapters::TimeoutRepeating<S>),
39}
40
41impl<S: Stream> Stream for TimeoutStream<S> {
42 type Item = Result<S::Item, TimeElapsed>;
43
44 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
45 match self.project() {
46 TimeoutStreamProj::Busywait(stream) => stream.poll_next(cx),
47 TimeoutStreamProj::Tokio(stream) => {
48 let result = ready!(stream.poll_next(cx));
49 Poll::Ready(result.map(|r| r.map_err(|_| TimeElapsed(()))))
50 }
51 }
52 }
53}
54
55pub fn timeout_stream<S>(duration: Duration, stream: S) -> TimeoutStream<S>
56where
57 S: Stream,
58{
59 if BUSYWAIT_TIMEOUTS.load(Ordering::Relaxed) {
60 TimeoutStream::Busywait(BusywaitTimeoutStream::new(stream, duration))
61 } else {
62 let timeout_interval_start = tokio::time::Instant::now()
65 .checked_add(duration)
66 .expect("timeout value too long");
67 TimeoutStream::Tokio(
68 stream.timeout_repeating(interval_at(timeout_interval_start, duration)),
69 )
70 }
71}
72
/// Error reported when a timeout expires before the wrapped future or stream
/// produced a value.
#[derive(Debug)]
pub struct TimeElapsed(());

impl fmt::Display for TimeElapsed {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Time elapsed")
    }
}

impl std::error::Error for TimeElapsed {}

impl From<TimeElapsed> for std::io::Error {
    /// Converts the timeout into an I/O error of kind [`std::io::ErrorKind::TimedOut`].
    fn from(_: TimeElapsed) -> Self {
        Self::from(std::io::ErrorKind::TimedOut)
    }
}
89
/// Deadline future that busy-waits: while the deadline lies in the future, every
/// poll re-schedules the current task instead of registering a timer.
struct Sleep {
    wake_time: Instant,
}

impl Sleep {
    /// Computes `now + duration`, panicking with "sleep time too long" on overflow.
    fn deadline_after(duration: Duration) -> Instant {
        Instant::now()
            .checked_add(duration)
            .expect("sleep time too long")
    }

    fn new(duration: Duration) -> Sleep {
        let wake_time = Self::deadline_after(duration);
        Sleep { wake_time }
    }

    /// Moves the deadline to `duration` from now.
    fn reset(&mut self, duration: Duration) {
        self.wake_time = Self::deadline_after(duration);
    }
}

impl Future for Sleep {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        if Instant::now() >= self.wake_time {
            return Poll::Ready(());
        }
        // Request an immediate re-poll; this is what makes the timer "busy".
        cx.waker().wake_by_ref();
        Poll::Pending
    }
}
122
123#[pin_project]
124struct BusywaitTimeout<T> {
125 #[pin]
126 value: T,
127 #[pin]
128 delay: Sleep,
129}
130
131impl<T> BusywaitTimeout<T> {
132 pub fn new(value: T, delay: Duration) -> BusywaitTimeout<T> {
133 BusywaitTimeout {
134 value,
135 delay: Sleep::new(delay),
136 }
137 }
138}
139
140impl<T> Future for BusywaitTimeout<T>
141where
142 T: Future,
143{
144 type Output = Result<T::Output, TimeElapsed>;
145
146 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
147 let this = self.project();
148
149 if let Poll::Ready(value) = this.value.poll(cx) {
150 Poll::Ready(Ok(value))
151 } else {
152 match this.delay.poll(cx) {
153 Poll::Ready(()) => Poll::Ready(Err(TimeElapsed(()))),
154 Poll::Pending => Poll::Pending,
155 }
156 }
157 }
158}
159
160#[pin_project]
161pub struct BusywaitTimeoutStream<S> {
162 #[pin]
163 stream: Fuse<S>,
164 #[pin]
165 deadline: Sleep,
166 duration: Duration,
167 poll_deadline: bool,
168}
169
170impl<S: Stream> BusywaitTimeoutStream<S> {
171 fn new(stream: S, duration: Duration) -> Self {
172 BusywaitTimeoutStream {
173 stream: stream.fuse(),
174 deadline: Sleep::new(duration),
175 duration,
176 poll_deadline: true,
177 }
178 }
179}
180
181impl<S: Stream> Stream for BusywaitTimeoutStream<S> {
182 type Item = Result<S::Item, TimeElapsed>;
183
184 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
185 let mut this = self.project();
186
187 match this.stream.poll_next(cx) {
188 Poll::Ready(v) => {
189 if v.is_some() {
190 this.deadline.reset(*this.duration);
191 *this.poll_deadline = true;
192 }
193 Poll::Ready(v.map(Ok))
194 }
195 Poll::Pending => {
196 if *this.poll_deadline {
197 ready!(this.deadline.poll(cx));
198 *this.poll_deadline = false;
199 Poll::Ready(Some(Err(TimeElapsed(()))))
200 } else {
201 this.deadline.reset(*this.duration);
202 *this.poll_deadline = true;
203 Poll::Pending
204 }
205 }
206 }
207 }
208
209 fn size_hint(&self) -> (usize, Option<usize>) {
210 let (lower, upper) = self.stream.size_hint();
211
212 fn twice_plus_one(value: Option<usize>) -> Option<usize> {
217 value?.checked_mul(2)?.checked_add(1)
218 }
219
220 (lower, twice_plus_one(upper))
221 }
222}