1use pin_project::{pin_project, pinned_drop};
2use tokio::time::Sleep;
3
4use std::{
5 future::Future,
6 pin::Pin,
7 sync::{
8 Arc,
9 atomic::{AtomicU64, Ordering},
10 },
11 task::{Context, Poll},
12 time::Duration,
13};
14
15use super::{inner::ClockInner, manual::ManualClock};
16
/// Process-wide counter that hands out identifiers for manual-clock sleeps.
static NEXT_SLEEP_ID: AtomicU64 = AtomicU64::new(0);

/// Returns a fresh sleep identifier, distinct from every earlier one (until the `u64`
/// counter wraps).
///
/// `Relaxed` ordering is enough here: the counter only has to produce distinct values and
/// does not order any other memory accesses.
fn next_sleep_id() -> u64 {
    AtomicU64::fetch_add(&NEXT_SLEEP_ID, 1, Ordering::Relaxed)
}
24
25#[pin_project(PinnedDrop)]
29pub struct ClockSleep {
30 #[pin]
31 inner: ClockSleepInner,
32}
33
34#[pin_project(project = ClockSleepInnerProj)]
35enum ClockSleepInner {
36 Realtime {
37 #[pin]
38 sleep: Sleep,
39 },
40 Manual {
41 wake_at_ms: i64,
42 sleep_id: u64,
43 clock: Arc<ManualClock>,
44 registered: bool,
45 coalesceable: bool,
47 },
48}
49
50impl ClockSleep {
51 pub(crate) fn new(clock_inner: &ClockInner, duration: Duration) -> Self {
52 Self::new_inner(clock_inner, duration, false)
53 }
54
55 pub(crate) fn new_coalesceable(clock_inner: &ClockInner, duration: Duration) -> Self {
56 Self::new_inner(clock_inner, duration, true)
57 }
58
59 fn new_inner(clock_inner: &ClockInner, duration: Duration, coalesceable: bool) -> Self {
60 let inner = match clock_inner {
61 ClockInner::Realtime(rt) => ClockSleepInner::Realtime {
62 sleep: rt.sleep(duration),
63 },
64 ClockInner::Manual(manual) => {
65 let wake_at_ms = manual.now_ms() + duration.as_millis() as i64;
66
67 ClockSleepInner::Manual {
68 wake_at_ms,
69 sleep_id: next_sleep_id(),
70 clock: Arc::clone(manual),
71 registered: false,
72 coalesceable,
73 }
74 }
75 };
76
77 Self { inner }
78 }
79}
80
81impl Future for ClockSleep {
82 type Output = ();
83
84 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
85 let this = self.project();
86
87 match this.inner.project() {
88 ClockSleepInnerProj::Realtime { sleep } => sleep.poll(cx),
89
90 ClockSleepInnerProj::Manual {
91 wake_at_ms,
92 sleep_id,
93 clock,
94 registered,
95 coalesceable,
96 } => {
97 if clock.now_ms() >= *wake_at_ms {
99 return Poll::Ready(());
100 }
101
102 if !*registered {
104 if *coalesceable {
105 clock.register_coalesce_wake(*wake_at_ms, *sleep_id, cx.waker().clone());
106 } else {
107 clock.register_wake(*wake_at_ms, *sleep_id, cx.waker().clone());
108 }
109 *registered = true;
110 }
111
112 Poll::Pending
113 }
114 }
115 }
116}
117
118#[pinned_drop]
119impl PinnedDrop for ClockSleep {
120 fn drop(self: Pin<&mut Self>) {
121 if let ClockSleepInner::Manual {
123 sleep_id,
124 clock,
125 registered: true,
126 ..
127 } = &self.inner
128 {
129 clock.cancel_wake(*sleep_id);
130 }
131 }
132}
133
134#[pin_project]
138pub struct ClockTimeout<F> {
139 #[pin]
140 future: F,
141 #[pin]
142 sleep: ClockSleep,
143 completed: bool,
144}
145
146impl<F> ClockTimeout<F> {
147 pub(crate) fn new(clock_inner: &ClockInner, duration: Duration, future: F) -> Self {
148 Self {
149 future,
150 sleep: ClockSleep::new(clock_inner, duration),
151 completed: false,
152 }
153 }
154}
155
156impl<F: Future> Future for ClockTimeout<F> {
157 type Output = Result<F::Output, Elapsed>;
158
159 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
160 let this = self.project();
161
162 if *this.completed {
163 panic!("ClockTimeout polled after completion");
164 }
165
166 if let Poll::Ready(output) = this.future.poll(cx) {
168 *this.completed = true;
169 return Poll::Ready(Ok(output));
170 }
171
172 if let Poll::Ready(()) = this.sleep.poll(cx) {
174 *this.completed = true;
175 return Poll::Ready(Err(Elapsed));
176 }
177
178 Poll::Pending
179 }
180}
181
/// Error produced by `ClockTimeout` when its deadline passes before the wrapped future completes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Elapsed;
185
186impl std::fmt::Display for Elapsed {
187 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
188 write!(f, "deadline has elapsed")
189 }
190}
191
192impl std::error::Error for Elapsed {}