1use std::{
12 collections::BTreeMap,
13 error::Error,
14 fmt::Display,
15 future::Future,
16 io,
17 marker::PhantomData,
18 pin::Pin,
19 task::{Context, Poll, Waker},
20 time::{Duration, Instant},
21};
22
23use futures_core::Stream;
24use pin_project_lite::pin_project;
25
26use crate::{Id, REACTOR};
27
28pub(crate) struct TimerQueue {
29 current_id: Id,
30 timers: BTreeMap<(Instant, Id), Waker>,
34}
35
36impl TimerQueue {
37 pub(crate) const fn new() -> Self {
38 Self {
39 current_id: const { Id::new(1) },
40 timers: BTreeMap::new(),
41 }
42 }
43
44 pub(crate) fn register(&mut self, expiry: Instant, mut waker: Waker) -> Id {
48 loop {
49 let id = self.current_id;
50 self.current_id = id.overflowing_incr();
51 waker = match self.timers.insert((expiry, id), waker) {
52 None => break id,
53 Some(prev_waker) => self.timers.insert((expiry, id), prev_waker).unwrap(),
56 }
57 }
58 }
59
60 pub(crate) fn modify(&mut self, id: Id, expiry: Instant, waker: &Waker) {
62 if let Some(wk) = self.timers.get_mut(&(expiry, id)) {
63 wk.clone_from(waker)
64 } else {
65 log::error!(
66 "{:?} Modifying non-existent timer ID = {}",
67 std::thread::current().id(),
68 id.0
69 );
70 }
71 }
72
73 pub(crate) fn cancel(&mut self, id: Id, expiry: Instant) {
75 self.timers.remove(&(expiry, id));
77 }
78
79 pub(crate) fn next_timeout(&mut self) -> Option<Duration> {
80 let now = Instant::now();
81 self.timers
82 .first_key_value()
83 .map(|((expiry, _), _)| expiry.saturating_duration_since(now))
84 }
85
86 pub(crate) fn clear_expired(&mut self) {
87 let now = Instant::now();
88 while let Some(entry) = self.timers.first_entry() {
90 let expiry = entry.key().0;
91 if expiry <= now {
92 entry.remove().wake();
93 } else {
94 break;
95 }
96 }
97 }
98
99 #[cfg(test)]
100 pub(crate) fn is_empty(&self) -> bool {
101 self.timers.is_empty()
102 }
103}
104
/// Future that completes once the deadline `expiry` has passed, resolving to that deadline.
///
/// No reactor registration happens until a poll finds the deadline still in the future.
/// Completing the future or dropping it cancels any registration it still holds.
#[derive(Debug)]
#[must_use = "Futures do nothing unless polled"]
pub struct Timer {
    // Instant at which the timer fires; also the value the future resolves to.
    expiry: Instant,
    // ID issued by `REACTOR` when this timer registered a wake-up. `None` before the first
    // pending poll and after the timer completed or cancelled its registration.
    timer_id: Option<Id>,
    // Raw-pointer marker that removes the automatic `Send` and `Sync` impls. `Send` stays
    // off because `poll`/`drop` act on whichever `REACTOR` the current thread sees
    // (presumably a thread-local, given the `.with` accessor), while `timer_id` was issued
    // by the registering thread's queue.
    _phantom: PhantomData<*const ()>,
}

// SAFETY: within this module, a shared `&Timer` gives no way to reach the reactor. Every
// method that touches `REACTOR` (`register`, `Future::poll`, `Drop::drop`) requires
// `&mut self` or `Pin<&mut Self>`, so sharing references across threads only allows reading
// `expiry`/`timer_id` (e.g. through `Debug`). The timer itself stays `!Send`.
unsafe impl Sync for Timer {}
131
132impl Timer {
133 pub fn at(expiry: Instant) -> Self {
135 Timer {
136 expiry,
137 timer_id: None,
138 _phantom: PhantomData,
139 }
140 }
141
142 pub fn delay(delay: Duration) -> Self {
144 Self::at(Instant::now() + delay)
145 }
146
147 fn register(&mut self, cx: &mut Context<'_>) {
148 REACTOR.with(|r| match self.timer_id {
149 None => {
150 self.timer_id = Some(r.register_timer(self.expiry, cx.waker().clone()));
151 }
152 Some(id) => r.modify_timer(id, self.expiry, cx.waker()),
153 });
154 }
155}
156
157impl Future for Timer {
158 type Output = Instant;
159
160 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
161 if self.expiry <= Instant::now() {
162 if let Some(id) = self.timer_id.take() {
164 REACTOR.with(|r| r.cancel_timer(id, self.expiry));
165 }
166 return Poll::Ready(self.expiry);
167 }
168
169 self.register(cx);
170 Poll::Pending
171 }
172}
173
174impl Drop for Timer {
175 fn drop(&mut self) {
176 if let Some(id) = self.timer_id.take() {
177 REACTOR.with(|r| r.cancel_timer(id, self.expiry));
178 }
179 }
180}
181
182pub fn sleep(duration: Duration) -> Timer {
184 Timer::delay(duration)
185}
186
187#[must_use = "Streams do nothing unless polled"]
209pub struct Periodic {
210 timer: Timer,
211 period: Duration,
212}
213
214impl Periodic {
215 #[allow(clippy::self_named_constructors)]
217 pub fn periodic(period: Duration) -> Self {
218 Self {
219 timer: Timer::delay(period),
220 period,
221 }
222 }
223
224 pub fn periodic_at(start: Instant, period: Duration) -> Self {
226 Self {
227 timer: Timer::at(start),
228 period,
229 }
230 }
231
232 pub fn set_period(&mut self, period: Duration) {
234 self.period = period;
235 }
236}
237
238impl Stream for Periodic {
239 type Item = Instant;
240
241 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
242 if let Poll::Ready(expiry) = Pin::new(&mut self.timer).poll(cx) {
243 let next = expiry + self.period;
244 self.timer.expiry = next;
245 Poll::Ready(Some(expiry))
246 } else {
247 Poll::Pending
248 }
249 }
250}
251
/// Error produced by [`Timeout`] when its deadline passes before the inner future completes.
///
/// Converts into an [`io::Error`] of kind [`io::ErrorKind::TimedOut`].
#[derive(Debug)]
pub struct TimedOut(());

impl Display for TimedOut {
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const MESSAGE: &str = "Future timed out";
        out.write_str(MESSAGE)
    }
}

impl Error for TimedOut {}

impl From<TimedOut> for io::Error {
    fn from(_timed_out: TimedOut) -> Self {
        // The marker carries no data, so only the error kind survives the conversion.
        io::ErrorKind::TimedOut.into()
    }
}
267
pin_project! {
    /// Future returned by [`timeout`] and [`timeout_at`]: resolves to `Ok` with the wrapped
    /// future's output, or to `Err(TimedOut)` once the deadline timer fires first. Each poll
    /// checks `fut` before `timer`.
    #[derive(Debug)]
    #[must_use = "Futures do nothing unless polled"]
    pub struct Timeout<F> {
        // Deadline timer; its completion means the wrapped future timed out.
        #[pin]
        timer: Timer,
        // The wrapped future whose output is returned on success.
        #[pin]
        fut: F,
    }
}
279
280impl<F: Future> Future for Timeout<F> {
281 type Output = Result<F::Output, TimedOut>;
282
283 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
284 if let Poll::Ready(out) = self.as_mut().project().fut.poll(cx) {
285 return Poll::Ready(Ok(out));
286 }
287 if self.as_mut().project().timer.poll(cx).is_ready() {
288 return Poll::Ready(Err(TimedOut(())));
289 }
290 Poll::Pending
291 }
292}
293
294pub fn timeout<F: Future>(fut: F, timeout: Duration) -> Timeout<F> {
317 Timeout {
318 timer: Timer::delay(timeout),
319 fut,
320 }
321}
322
323pub fn timeout_at<F: Future>(fut: F, expiry: Instant) -> Timeout<F> {
327 Timeout {
328 timer: Timer::at(expiry),
329 fut,
330 }
331}
332
#[cfg(test)]
mod tests {
    use std::{
        pin::{pin, Pin},
        sync::Arc,
    };

    // Project-local test waker; as used below, `get()` appears to report whether it was woken.
    use crate::test::MockWaker;

    use super::*;

    // NOTE(review): several tests rely on wall-clock timing (short sleeps and millisecond
    // margins between registering a timer and checking it). They may be flaky on a heavily
    // loaded machine.

    /// `next_timeout` is zero while a due timer is queued; `clear_expired` wakes and removes
    /// exactly the due timers and leaves future ones pending.
    #[test]
    fn next_timeout() {
        let wakers: Vec<_> = (0..3).map(|_| Arc::new(MockWaker::default())).collect();
        let mut tq = TimerQueue::new();
        assert!(tq.next_timeout().is_none());

        // Two timers already due (now, and one second ago) plus one 50 ms in the future.
        tq.register(Instant::now(), wakers[0].clone().into());
        tq.register(
            Instant::now() - Duration::from_secs(1),
            wakers[1].clone().into(),
        );
        tq.register(
            Instant::now() + Duration::from_millis(50),
            wakers[2].clone().into(),
        );
        assert_eq!(tq.next_timeout().unwrap(), Duration::ZERO);

        // Only the two due timers fire; the remaining one is still more than 40 ms away
        // (assumes less than 10 ms elapsed since registration).
        tq.clear_expired();
        assert!(tq.next_timeout().unwrap() > Duration::from_millis(40));
        assert!(wakers[0].get());
        assert!(wakers[1].get());
        assert!(!wakers[2].get());

        // After sleeping past the last deadline, it fires too and the queue drains.
        std::thread::sleep(Duration::from_millis(50));
        tq.clear_expired();
        assert!(tq.next_timeout().is_none());
        assert!(wakers[2].get());

        assert!(tq.timers.is_empty());
    }

    /// Replacing a timer's waker means only the new waker is woken on expiry.
    #[test]
    fn modify() {
        let wakers: Vec<_> = (0..2).map(|_| Arc::new(MockWaker::default())).collect();
        let mut tq = TimerQueue::new();

        // Assumes the first `clear_expired` runs within 10 ms of computing `expiry`.
        let expiry = Instant::now() + Duration::from_millis(10);
        let id = tq.register(expiry, wakers[0].clone().into());
        tq.clear_expired();
        assert!(tq.next_timeout().is_some());

        tq.modify(id, expiry, &wakers[1].clone().into());
        std::thread::sleep(Duration::from_millis(10));
        tq.clear_expired();
        assert!(tq.next_timeout().is_none());
        assert!(!wakers[0].get());
        assert!(wakers[1].get());

        assert!(tq.timers.is_empty());
    }

    /// A cancelled timer is removed without waking its waker.
    #[test]
    fn cancel() {
        let waker = Arc::new(MockWaker::default());
        let mut tq = TimerQueue::new();

        let expiry = Instant::now() + Duration::from_secs(10);
        let id = tq.register(expiry, waker.clone().into());
        tq.clear_expired();
        assert!(tq.next_timeout().is_some());

        tq.cancel(id, expiry);
        tq.clear_expired();
        assert!(tq.next_timeout().is_none());
        assert!(!waker.get());

        assert!(tq.timers.is_empty());
    }

    /// A timer whose deadline has already passed completes on its first poll without ever
    /// registering with the reactor.
    #[test]
    fn timer_expired() {
        let waker = Arc::new(MockWaker::default());
        let mut timer = Timer::at(Instant::now());

        assert!(Pin::new(&mut timer)
            .poll(&mut Context::from_waker(&waker.into()))
            .is_ready());
        assert!(timer.timer_id.is_none());

        // Presumably true when the reactor holds no timers.
        assert!(REACTOR.with(|r| r.is_empty()));
    }

    /// A future timer registers on its first (pending) poll and releases the registration
    /// once it completes.
    #[test]
    fn timer() {
        let waker = Arc::new(MockWaker::default());
        let mut timer = pin!(Timer::delay(Duration::from_millis(10)));

        assert!(timer
            .as_mut()
            .poll(&mut Context::from_waker(&waker.clone().into()))
            .is_pending());
        assert!(timer.timer_id.is_some());
        assert!(!REACTOR.with(|r| r.is_empty()));

        std::thread::sleep(Duration::from_millis(10));
        assert!(timer
            .as_mut()
            .poll(&mut Context::from_waker(&waker.into()))
            .is_ready());
        assert!(timer.timer_id.is_none());
        assert!(REACTOR.with(|r| r.is_empty()));
    }

    /// A periodic stream is pending before the first deadline, then ticks once per period,
    /// leaving no reactor registration behind after each tick.
    #[test]
    fn periodic() {
        let waker = Arc::new(MockWaker::default());
        let mut periodic = pin!(Periodic::periodic(Duration::from_millis(5)));

        assert!(periodic
            .as_mut()
            .poll_next(&mut Context::from_waker(&waker.clone().into()))
            .is_pending());
        assert!(!REACTOR.with(|r| r.is_empty()));

        std::thread::sleep(Duration::from_millis(5));
        assert!(periodic
            .as_mut()
            .poll_next(&mut Context::from_waker(&waker.clone().into()))
            .is_ready());
        assert!(REACTOR.with(|r| r.is_empty()));

        // The second deadline is one period after the first, so another 5 ms sleep reaches it.
        std::thread::sleep(Duration::from_millis(5));
        assert!(periodic
            .as_mut()
            .poll_next(&mut Context::from_waker(&waker.clone().into()))
            .is_ready());
        assert!(REACTOR.with(|r| r.is_empty()));
    }

    /// `timeout` returns `Ok` when the inner future is immediately ready; `timeout_at` with a
    /// past deadline returns `Err` when the inner future is still pending.
    #[test]
    fn timeouts() {
        let waker = Arc::new(MockWaker::default()).into();

        let res1 = Pin::new(&mut timeout(
            Timer::at(Instant::now()),
            Duration::from_secs(10),
        ))
        .poll(&mut Context::from_waker(&waker));
        assert!(matches!(res1, Poll::Ready(Ok(_))));

        let res2 = Pin::new(&mut timeout_at(
            Timer::delay(Duration::from_secs(10)),
            Instant::now(),
        ))
        .poll(&mut Context::from_waker(&waker));
        assert!(matches!(res2, Poll::Ready(Err(_))));
    }
}