1use pin_project::{pin_project, pinned_drop};
2use tokio::time::Sleep;
3
4use std::{
5 future::Future,
6 pin::Pin,
7 sync::Arc,
8 task::{Context, Poll},
9 time::Duration,
10};
11
12use super::{
13 inner::ClockInner,
14 manual::{ManualClock, next_sleep_id},
15};
16
17#[pin_project(PinnedDrop)]
21pub struct ClockSleep {
22 #[pin]
23 inner: ClockSleepInner,
24}
25
26#[pin_project(project = ClockSleepInnerProj)]
27enum ClockSleepInner {
28 Realtime {
29 #[pin]
30 sleep: Sleep,
31 },
32 Manual {
33 wake_at_ms: i64,
34 sleep_id: u64,
35 clock: Arc<ManualClock>,
36 registered: bool,
37 },
38}
39
40impl ClockSleep {
41 pub(crate) fn new(clock_inner: &ClockInner, duration: Duration) -> Self {
42 let inner = match clock_inner {
43 ClockInner::Realtime(rt) => ClockSleepInner::Realtime {
44 sleep: rt.sleep(duration),
45 },
46 ClockInner::Manual(manual) => {
47 let wake_at_ms = manual.now_ms() + duration.as_millis() as i64;
48
49 ClockSleepInner::Manual {
50 wake_at_ms,
51 sleep_id: next_sleep_id(),
52 clock: Arc::clone(manual),
53 registered: false,
54 }
55 }
56 };
57
58 Self { inner }
59 }
60}
61
62impl Future for ClockSleep {
63 type Output = ();
64
65 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
66 let this = self.project();
67
68 match this.inner.project() {
69 ClockSleepInnerProj::Realtime { sleep } => sleep.poll(cx),
70
71 ClockSleepInnerProj::Manual {
72 wake_at_ms,
73 sleep_id,
74 clock,
75 registered,
76 } => {
77 if clock.now_ms() >= *wake_at_ms {
79 return Poll::Ready(());
80 }
81
82 if !*registered {
84 clock.register_wake(*wake_at_ms, *sleep_id, cx.waker().clone());
85 *registered = true;
86 }
87
88 Poll::Pending
89 }
90 }
91 }
92}
93
94#[pinned_drop]
95impl PinnedDrop for ClockSleep {
96 fn drop(self: Pin<&mut Self>) {
97 if let ClockSleepInner::Manual {
99 sleep_id,
100 clock,
101 registered: true,
102 ..
103 } = &self.inner
104 {
105 clock.cancel_wake(*sleep_id);
106 }
107 }
108}
109
110#[pin_project]
114pub struct ClockTimeout<F> {
115 #[pin]
116 future: F,
117 #[pin]
118 sleep: ClockSleep,
119 completed: bool,
120}
121
122impl<F> ClockTimeout<F> {
123 pub(crate) fn new(clock_inner: &ClockInner, duration: Duration, future: F) -> Self {
124 Self {
125 future,
126 sleep: ClockSleep::new(clock_inner, duration),
127 completed: false,
128 }
129 }
130}
131
132impl<F: Future> Future for ClockTimeout<F> {
133 type Output = Result<F::Output, Elapsed>;
134
135 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
136 let this = self.project();
137
138 if *this.completed {
139 panic!("ClockTimeout polled after completion");
140 }
141
142 if let Poll::Ready(output) = this.future.poll(cx) {
144 *this.completed = true;
145 return Poll::Ready(Ok(output));
146 }
147
148 if let Poll::Ready(()) = this.sleep.poll(cx) {
150 *this.completed = true;
151 return Poll::Ready(Err(Elapsed));
152 }
153
154 Poll::Pending
155 }
156}
157
/// Error produced when a timeout's deadline passes before the wrapped future completes.
///
/// Carries no data: the only information is that the deadline was reached.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Elapsed;

impl std::fmt::Display for Elapsed {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("deadline has elapsed")
    }
}

// No underlying cause, so the default `source()` (which returns `None`) is kept.
impl std::error::Error for Elapsed {}