es_entity/clock/sleep.rs

1use pin_project::{pin_project, pinned_drop};
2use tokio::time::Sleep;
3
4use std::{
5    future::Future,
6    pin::Pin,
7    sync::Arc,
8    task::{Context, Poll},
9    time::Duration,
10};
11
12use super::{
13    artificial::{ArtificialClock, next_sleep_id},
14    inner::ClockInner,
15};
16
17/// A future that completes after a duration has elapsed on the clock.
18///
19/// Created by [`ClockHandle::sleep`](crate::ClockHandle::sleep).
20#[pin_project(PinnedDrop)]
21pub struct ClockSleep {
22    #[pin]
23    inner: ClockSleepInner,
24}
25
26#[pin_project(project = ClockSleepInnerProj)]
27enum ClockSleepInner {
28    Realtime {
29        #[pin]
30        sleep: Sleep,
31    },
32    ArtificialAuto {
33        #[pin]
34        sleep: Sleep,
35        wake_at_ms: i64,
36        clock: Arc<ArtificialClock>,
37    },
38    ArtificialManual {
39        wake_at_ms: i64,
40        sleep_id: u64,
41        clock: Arc<ArtificialClock>,
42        registered: bool,
43        /// Fallback timer used when clock transitions to realtime mode.
44        #[pin]
45        realtime_fallback: Option<Sleep>,
46    },
47}
48
49impl ClockSleep {
50    pub(crate) fn new(clock_inner: &ClockInner, duration: Duration) -> Self {
51        let inner = match clock_inner {
52            ClockInner::Realtime(rt) => ClockSleepInner::Realtime {
53                sleep: rt.sleep(duration),
54            },
55            ClockInner::Artificial(artificial) => {
56                let wake_at_ms = artificial.now_ms() + duration.as_millis() as i64;
57
58                if artificial.is_manual() {
59                    ClockSleepInner::ArtificialManual {
60                        wake_at_ms,
61                        sleep_id: next_sleep_id(),
62                        clock: Arc::clone(artificial),
63                        registered: false,
64                        realtime_fallback: None,
65                    }
66                } else {
67                    // Auto-advance mode uses real tokio sleep with scaled duration
68                    let real_duration = artificial.real_duration(duration);
69                    ClockSleepInner::ArtificialAuto {
70                        sleep: tokio::time::sleep(real_duration),
71                        wake_at_ms,
72                        clock: Arc::clone(artificial),
73                    }
74                }
75            }
76        };
77
78        Self { inner }
79    }
80}
81
82impl Future for ClockSleep {
83    type Output = ();
84
85    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
86        let this = self.project();
87
88        match this.inner.project() {
89            ClockSleepInnerProj::Realtime { sleep } => sleep.poll(cx),
90
91            ClockSleepInnerProj::ArtificialAuto {
92                sleep,
93                wake_at_ms,
94                clock,
95            } => {
96                // Check if artificial time has reached wake time
97                if clock.now_ms() >= *wake_at_ms {
98                    return Poll::Ready(());
99                }
100                // Otherwise wait for real timer
101                sleep.poll(cx)
102            }
103
104            ClockSleepInnerProj::ArtificialManual {
105                wake_at_ms,
106                sleep_id,
107                clock,
108                registered,
109                mut realtime_fallback,
110            } => {
111                // Check if we've reached wake time
112                if clock.now_ms() >= *wake_at_ms {
113                    return Poll::Ready(());
114                }
115
116                // If clock has transitioned to realtime, use a real timer for remaining time
117                if clock.is_realtime() {
118                    // Create fallback timer if not already created
119                    if realtime_fallback.is_none() {
120                        let remaining_ms = (*wake_at_ms - clock.now_ms()).max(0) as u64;
121                        realtime_fallback.set(Some(tokio::time::sleep(Duration::from_millis(
122                            remaining_ms,
123                        ))));
124                    }
125
126                    // Poll the fallback timer
127                    if let Some(sleep) = realtime_fallback.as_pin_mut() {
128                        return sleep.poll(cx);
129                    }
130                }
131
132                // Register for wake notification if not already done
133                if !*registered {
134                    clock.register_wake(*wake_at_ms, *sleep_id, cx.waker().clone());
135                    *registered = true;
136                }
137
138                Poll::Pending
139            }
140        }
141    }
142}
143
144#[pinned_drop]
145impl PinnedDrop for ClockSleep {
146    fn drop(self: Pin<&mut Self>) {
147        // Clean up pending wake registration if cancelled
148        if let ClockSleepInner::ArtificialManual {
149            sleep_id,
150            clock,
151            registered: true,
152            ..
153        } = &self.inner
154        {
155            clock.cancel_wake(*sleep_id);
156        }
157    }
158}
159
160/// A future that completes with a timeout after a duration has elapsed on the clock.
161///
162/// Created by [`ClockHandle::timeout`](crate::ClockHandle::timeout).
163#[pin_project]
164pub struct ClockTimeout<F> {
165    #[pin]
166    future: F,
167    #[pin]
168    sleep: ClockSleep,
169    completed: bool,
170}
171
172impl<F> ClockTimeout<F> {
173    pub(crate) fn new(clock_inner: &ClockInner, duration: Duration, future: F) -> Self {
174        Self {
175            future,
176            sleep: ClockSleep::new(clock_inner, duration),
177            completed: false,
178        }
179    }
180}
181
182impl<F: Future> Future for ClockTimeout<F> {
183    type Output = Result<F::Output, Elapsed>;
184
185    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
186        let this = self.project();
187
188        if *this.completed {
189            panic!("ClockTimeout polled after completion");
190        }
191
192        // Check the future first
193        if let Poll::Ready(output) = this.future.poll(cx) {
194            *this.completed = true;
195            return Poll::Ready(Ok(output));
196        }
197
198        // Check if timeout elapsed
199        if let Poll::Ready(()) = this.sleep.poll(cx) {
200            *this.completed = true;
201            return Poll::Ready(Err(Elapsed));
202        }
203
204        Poll::Pending
205    }
206}
207
/// Error produced by [`ClockTimeout`] when its deadline passes before the wrapped future
/// completes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Elapsed;

impl std::fmt::Display for Elapsed {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("deadline has elapsed")
    }
}

impl std::error::Error for Elapsed {}