1use pin_project::{pin_project, pinned_drop};
2use tokio::time::Sleep;
3
4use std::{
5 future::Future,
6 pin::Pin,
7 sync::{
8 Arc,
9 atomic::{AtomicU64, Ordering},
10 },
11 task::{Context, Poll},
12 time::Duration,
13};
14
15use super::{inner::ClockInner, manual::ManualClock};
16
/// Process-wide counter backing [`next_sleep_id`]; starts at zero.
static NEXT_SLEEP_ID: AtomicU64 = AtomicU64::new(0);

/// Hands out the next sleep identifier.
///
/// Each call returns the counter's current value and bumps it by one, so
/// successive calls yield distinct, increasing ids. `Relaxed` ordering is
/// enough here: only the atomicity of the increment matters, not ordering
/// relative to other memory operations.
fn next_sleep_id() -> u64 {
    const STEP: u64 = 1;
    NEXT_SLEEP_ID.fetch_add(STEP, Ordering::Relaxed)
}
24
/// Future that completes (with `()`) once a sleep on a clock has elapsed.
///
/// Built through `ClockSleep::new`, which picks the real-time or manual
/// flavor based on the `ClockInner` it is given. The type has a pinned drop
/// handler: dropping a manual sleep that already registered a wake-up with
/// its clock cancels that registration.
#[pin_project(PinnedDrop)]
pub struct ClockSleep {
    /// Flavor-specific state; pinned because the real-time variant holds a
    /// pinned `tokio::time::Sleep`.
    #[pin]
    inner: ClockSleepInner,
}
33
/// Internal state of a [`ClockSleep`], one variant per clock flavor.
///
/// The pin projection of this enum is named `ClockSleepInnerProj`.
#[pin_project(project = ClockSleepInnerProj)]
enum ClockSleepInner {
    /// Sleep backed by a real-time clock; polling is delegated to `sleep`.
    Realtime {
        /// The underlying tokio sleep future.
        #[pin]
        sleep: Sleep,
    },
    /// Sleep driven by a [`ManualClock`].
    Manual {
        /// Deadline in milliseconds, compared against `ManualClock::now_ms()`.
        wake_at_ms: i64,
        /// Identifier obtained from `next_sleep_id()`; used when registering
        /// and cancelling the wake-up with the clock.
        sleep_id: u64,
        /// Shared handle to the clock this sleep waits on.
        clock: Arc<ManualClock>,
        /// `true` once a waker has been registered with `clock`; starts `false`.
        registered: bool,
    },
}
47
48impl ClockSleep {
49 pub(crate) fn new(clock_inner: &ClockInner, duration: Duration) -> Self {
50 let inner = match clock_inner {
51 ClockInner::Realtime(rt) => ClockSleepInner::Realtime {
52 sleep: rt.sleep(duration),
53 },
54 ClockInner::Manual(manual) => {
55 let wake_at_ms = manual.now_ms() + duration.as_millis() as i64;
56
57 ClockSleepInner::Manual {
58 wake_at_ms,
59 sleep_id: next_sleep_id(),
60 clock: Arc::clone(manual),
61 registered: false,
62 }
63 }
64 };
65
66 Self { inner }
67 }
68}
69
70impl Future for ClockSleep {
71 type Output = ();
72
73 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
74 let this = self.project();
75
76 match this.inner.project() {
77 ClockSleepInnerProj::Realtime { sleep } => sleep.poll(cx),
78
79 ClockSleepInnerProj::Manual {
80 wake_at_ms,
81 sleep_id,
82 clock,
83 registered,
84 } => {
85 if clock.now_ms() >= *wake_at_ms {
87 return Poll::Ready(());
88 }
89
90 if !*registered {
92 clock.register_wake(*wake_at_ms, *sleep_id, cx.waker().clone());
93 *registered = true;
94 }
95
96 Poll::Pending
97 }
98 }
99 }
100}
101
102#[pinned_drop]
103impl PinnedDrop for ClockSleep {
104 fn drop(self: Pin<&mut Self>) {
105 if let ClockSleepInner::Manual {
107 sleep_id,
108 clock,
109 registered: true,
110 ..
111 } = &self.inner
112 {
113 clock.cancel_wake(*sleep_id);
114 }
115 }
116}
117
/// Future that races an inner future against a [`ClockSleep`] deadline.
///
/// Resolves to `Ok(output)` if the inner future finishes, or `Err(Elapsed)` if
/// the sleep finishes first. The inner future is polled before the sleep on
/// every poll. Polling again after it has resolved panics.
#[pin_project]
pub struct ClockTimeout<F> {
    /// The wrapped future being given a deadline.
    #[pin]
    future: F,
    /// The deadline, expressed as a sleep on the clock.
    #[pin]
    sleep: ClockSleep,
    /// Set to `true` once this timeout has returned `Poll::Ready`.
    completed: bool,
}
129
130impl<F> ClockTimeout<F> {
131 pub(crate) fn new(clock_inner: &ClockInner, duration: Duration, future: F) -> Self {
132 Self {
133 future,
134 sleep: ClockSleep::new(clock_inner, duration),
135 completed: false,
136 }
137 }
138}
139
140impl<F: Future> Future for ClockTimeout<F> {
141 type Output = Result<F::Output, Elapsed>;
142
143 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
144 let this = self.project();
145
146 if *this.completed {
147 panic!("ClockTimeout polled after completion");
148 }
149
150 if let Poll::Ready(output) = this.future.poll(cx) {
152 *this.completed = true;
153 return Poll::Ready(Ok(output));
154 }
155
156 if let Poll::Ready(()) = this.sleep.poll(cx) {
158 *this.completed = true;
159 return Poll::Ready(Err(Elapsed));
160 }
161
162 Poll::Pending
163 }
164}
165
/// Error produced by [`ClockTimeout`] when the deadline passes before the
/// wrapped future completes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Elapsed;

impl std::fmt::Display for Elapsed {
    /// Writes the fixed message `deadline has elapsed`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("deadline has elapsed")
    }
}

impl std::error::Error for Elapsed {}