1use pin_project::{pin_project, pinned_drop};
2use tokio::time::Sleep;
3
4use std::{
5 future::Future,
6 pin::Pin,
7 sync::Arc,
8 task::{Context, Poll},
9 time::Duration,
10};
11
12use super::{
13 artificial::{ArtificialClock, next_sleep_id},
14 inner::ClockInner,
15};
16
/// Future that completes once a sleep on the owning clock has elapsed.
///
/// The concrete waiting strategy depends on the kind of clock the sleep was
/// created from and is held in the private, pin-projected `inner` state.
/// Resolves to `()`.
///
/// Dropping the future cancels any wake-up it registered with a manually
/// driven artificial clock (see the `PinnedDrop` implementation).
#[pin_project(PinnedDrop)]
pub struct ClockSleep {
    /// Per-clock-kind state machine; structurally pinned because some variants
    /// embed a tokio `Sleep`.
    #[pin]
    inner: ClockSleepInner,
}
25
/// Internal state of a [`ClockSleep`], one variant per clock flavour.
#[pin_project(project = ClockSleepInnerProj)]
enum ClockSleepInner {
    /// Sleep on a realtime clock: polling simply delegates to the timer the
    /// clock produced.
    Realtime {
        /// Timer obtained from the realtime clock's `sleep`.
        #[pin]
        sleep: Sleep,
    },
    /// Sleep on an artificial clock that is not in manual mode.
    ///
    /// Completes as soon as the clock reports a time at or past `wake_at_ms`,
    /// or when the backing tokio timer fires, whichever is observed first.
    ArtificialAuto {
        /// Tokio timer armed for the clock's `real_duration` of the requested
        /// sleep.
        #[pin]
        sleep: Sleep,
        /// Deadline on the artificial clock, in milliseconds (`now_ms` at
        /// creation plus the requested duration).
        wake_at_ms: i64,
        /// Clock consulted on every poll.
        clock: Arc<ArtificialClock>,
    },
    /// Sleep on an artificial clock in manual mode.
    ///
    /// Instead of arming a timer, the future registers its waker with the
    /// clock and stays pending until the clock's time reaches `wake_at_ms`.
    ArtificialManual {
        /// Deadline on the artificial clock, in milliseconds.
        wake_at_ms: i64,
        /// Identifier passed to the clock's `register_wake` / `cancel_wake`.
        sleep_id: u64,
        /// Clock consulted on every poll and holding the registration.
        clock: Arc<ArtificialClock>,
        /// `true` once `register_wake` has been called for this sleep; drives
        /// the cancellation on drop.
        registered: bool,
        /// Tokio timer created lazily on the first poll that observes
        /// `clock.is_realtime()`; covers the time remaining until
        /// `wake_at_ms`.
        #[pin]
        realtime_fallback: Option<Sleep>,
    },
}
48
49impl ClockSleep {
50 pub(crate) fn new(clock_inner: &ClockInner, duration: Duration) -> Self {
51 let inner = match clock_inner {
52 ClockInner::Realtime(rt) => ClockSleepInner::Realtime {
53 sleep: rt.sleep(duration),
54 },
55 ClockInner::Artificial(artificial) => {
56 let wake_at_ms = artificial.now_ms() + duration.as_millis() as i64;
57
58 if artificial.is_manual() {
59 ClockSleepInner::ArtificialManual {
60 wake_at_ms,
61 sleep_id: next_sleep_id(),
62 clock: Arc::clone(artificial),
63 registered: false,
64 realtime_fallback: None,
65 }
66 } else {
67 let real_duration = artificial.real_duration(duration);
69 ClockSleepInner::ArtificialAuto {
70 sleep: tokio::time::sleep(real_duration),
71 wake_at_ms,
72 clock: Arc::clone(artificial),
73 }
74 }
75 }
76 };
77
78 Self { inner }
79 }
80}
81
82impl Future for ClockSleep {
83 type Output = ();
84
85 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
86 let this = self.project();
87
88 match this.inner.project() {
89 ClockSleepInnerProj::Realtime { sleep } => sleep.poll(cx),
90
91 ClockSleepInnerProj::ArtificialAuto {
92 sleep,
93 wake_at_ms,
94 clock,
95 } => {
96 if clock.now_ms() >= *wake_at_ms {
98 return Poll::Ready(());
99 }
100 sleep.poll(cx)
102 }
103
104 ClockSleepInnerProj::ArtificialManual {
105 wake_at_ms,
106 sleep_id,
107 clock,
108 registered,
109 mut realtime_fallback,
110 } => {
111 if clock.now_ms() >= *wake_at_ms {
113 return Poll::Ready(());
114 }
115
116 if clock.is_realtime() {
118 if realtime_fallback.is_none() {
120 let remaining_ms = (*wake_at_ms - clock.now_ms()).max(0) as u64;
121 realtime_fallback.set(Some(tokio::time::sleep(Duration::from_millis(
122 remaining_ms,
123 ))));
124 }
125
126 if let Some(sleep) = realtime_fallback.as_pin_mut() {
128 return sleep.poll(cx);
129 }
130 }
131
132 if !*registered {
134 clock.register_wake(*wake_at_ms, *sleep_id, cx.waker().clone());
135 *registered = true;
136 }
137
138 Poll::Pending
139 }
140 }
141 }
142}
143
144#[pinned_drop]
145impl PinnedDrop for ClockSleep {
146 fn drop(self: Pin<&mut Self>) {
147 if let ClockSleepInner::ArtificialManual {
149 sleep_id,
150 clock,
151 registered: true,
152 ..
153 } = &self.inner
154 {
155 clock.cancel_wake(*sleep_id);
156 }
157 }
158}
159
/// Future that races an inner future against a [`ClockSleep`] deadline.
///
/// Resolves to `Ok(output)` if the inner future completes first and to
/// `Err(Elapsed)` if the sleep completes first. On each poll the inner future
/// is polled before the sleep.
///
/// # Panics
///
/// Polling again after the future has returned `Poll::Ready` panics.
#[pin_project]
pub struct ClockTimeout<F> {
    /// The wrapped future whose completion is being awaited.
    #[pin]
    future: F,
    /// Deadline timer on the owning clock.
    #[pin]
    sleep: ClockSleep,
    /// Set once a result has been returned; guards against re-polling.
    completed: bool,
}
171
172impl<F> ClockTimeout<F> {
173 pub(crate) fn new(clock_inner: &ClockInner, duration: Duration, future: F) -> Self {
174 Self {
175 future,
176 sleep: ClockSleep::new(clock_inner, duration),
177 completed: false,
178 }
179 }
180}
181
182impl<F: Future> Future for ClockTimeout<F> {
183 type Output = Result<F::Output, Elapsed>;
184
185 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
186 let this = self.project();
187
188 if *this.completed {
189 panic!("ClockTimeout polled after completion");
190 }
191
192 if let Poll::Ready(output) = this.future.poll(cx) {
194 *this.completed = true;
195 return Poll::Ready(Ok(output));
196 }
197
198 if let Poll::Ready(()) = this.sleep.poll(cx) {
200 *this.completed = true;
201 return Poll::Ready(Err(Elapsed));
202 }
203
204 Poll::Pending
205 }
206}
207
/// Error produced when a deadline fires before the guarded future completes.
///
/// Carries no data; its `Display` text is `deadline has elapsed`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Elapsed;

impl std::fmt::Display for Elapsed {
    /// Writes the fixed, human-readable timeout message.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str("deadline has elapsed")
    }
}

/// `Elapsed` is a leaf error: it has no underlying source.
impl std::error::Error for Elapsed {}