Skip to main content

es_entity/clock/
sleep.rs

1use pin_project::{pin_project, pinned_drop};
2use tokio::time::Sleep;
3
4use std::{
5    future::Future,
6    pin::Pin,
7    sync::{
8        Arc,
9        atomic::{AtomicU64, Ordering},
10    },
11    task::{Context, Poll},
12    time::Duration,
13};
14
15use super::{inner::ClockInner, manual::ManualClock};
16
/// Monotonic source of sleep identifiers shared by every `ClockSleep` in the process.
static NEXT_SLEEP_ID: AtomicU64 = AtomicU64::new(0);

/// Hands out a sleep ID that no earlier call has returned (until the counter wraps at `u64::MAX`).
fn next_sleep_id() -> u64 {
    // `Relaxed` is enough: callers only need distinct values, not ordering with other memory.
    AtomicU64::fetch_add(&NEXT_SLEEP_ID, 1, Ordering::Relaxed)
}
24
25/// A future that completes after a duration has elapsed on the clock.
26///
27/// Created by [`ClockHandle::sleep`](crate::ClockHandle::sleep).
28#[pin_project(PinnedDrop)]
29pub struct ClockSleep {
30    #[pin]
31    inner: ClockSleepInner,
32}
33
34#[pin_project(project = ClockSleepInnerProj)]
35enum ClockSleepInner {
36    Realtime {
37        #[pin]
38        sleep: Sleep,
39    },
40    Manual {
41        wake_at_ms: i64,
42        sleep_id: u64,
43        clock: Arc<ManualClock>,
44        registered: bool,
45        /// If true, this wake is registered in coalesce_wakes instead of pending_wakes.
46        coalesceable: bool,
47    },
48}
49
50impl ClockSleep {
51    pub(crate) fn new(clock_inner: &ClockInner, duration: Duration) -> Self {
52        Self::new_inner(clock_inner, duration, false)
53    }
54
55    pub(crate) fn new_coalesceable(clock_inner: &ClockInner, duration: Duration) -> Self {
56        Self::new_inner(clock_inner, duration, true)
57    }
58
59    fn new_inner(clock_inner: &ClockInner, duration: Duration, coalesceable: bool) -> Self {
60        let inner = match clock_inner {
61            ClockInner::Realtime(rt) => ClockSleepInner::Realtime {
62                sleep: rt.sleep(duration),
63            },
64            ClockInner::Manual(manual) => {
65                let wake_at_ms = manual.now_ms() + duration.as_millis() as i64;
66
67                ClockSleepInner::Manual {
68                    wake_at_ms,
69                    sleep_id: next_sleep_id(),
70                    clock: Arc::clone(manual),
71                    registered: false,
72                    coalesceable,
73                }
74            }
75        };
76
77        Self { inner }
78    }
79}
80
81impl Future for ClockSleep {
82    type Output = ();
83
84    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
85        let this = self.project();
86
87        match this.inner.project() {
88            ClockSleepInnerProj::Realtime { sleep } => sleep.poll(cx),
89
90            ClockSleepInnerProj::Manual {
91                wake_at_ms,
92                sleep_id,
93                clock,
94                registered,
95                coalesceable,
96            } => {
97                // Check if we've reached wake time
98                if clock.now_ms() >= *wake_at_ms {
99                    return Poll::Ready(());
100                }
101
102                // Register for wake notification if not already done
103                if !*registered {
104                    if *coalesceable {
105                        clock.register_coalesce_wake(*wake_at_ms, *sleep_id, cx.waker().clone());
106                    } else {
107                        clock.register_wake(*wake_at_ms, *sleep_id, cx.waker().clone());
108                    }
109                    *registered = true;
110                }
111
112                Poll::Pending
113            }
114        }
115    }
116}
117
118#[pinned_drop]
119impl PinnedDrop for ClockSleep {
120    fn drop(self: Pin<&mut Self>) {
121        // Clean up pending wake registration if cancelled
122        if let ClockSleepInner::Manual {
123            sleep_id,
124            clock,
125            registered: true,
126            ..
127        } = &self.inner
128        {
129            clock.cancel_wake(*sleep_id);
130        }
131    }
132}
133
134/// A future that completes with a timeout after a duration has elapsed on the clock.
135///
136/// Created by [`ClockHandle::timeout`](crate::ClockHandle::timeout).
137#[pin_project]
138pub struct ClockTimeout<F> {
139    #[pin]
140    future: F,
141    #[pin]
142    sleep: ClockSleep,
143    completed: bool,
144}
145
146impl<F> ClockTimeout<F> {
147    pub(crate) fn new(clock_inner: &ClockInner, duration: Duration, future: F) -> Self {
148        Self {
149            future,
150            sleep: ClockSleep::new(clock_inner, duration),
151            completed: false,
152        }
153    }
154}
155
156impl<F: Future> Future for ClockTimeout<F> {
157    type Output = Result<F::Output, Elapsed>;
158
159    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
160        let this = self.project();
161
162        if *this.completed {
163            panic!("ClockTimeout polled after completion");
164        }
165
166        // Check the future first
167        if let Poll::Ready(output) = this.future.poll(cx) {
168            *this.completed = true;
169            return Poll::Ready(Ok(output));
170        }
171
172        // Check if timeout elapsed
173        if let Poll::Ready(()) = this.sleep.poll(cx) {
174            *this.completed = true;
175            return Poll::Ready(Err(Elapsed));
176        }
177
178        Poll::Pending
179    }
180}
181
/// Error produced by `ClockTimeout` when its deadline passes before the wrapped future
/// completes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Elapsed;

impl std::fmt::Display for Elapsed {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("deadline has elapsed")
    }
}

// No underlying cause to expose, so the default `source()` (returning `None`) is kept.
impl std::error::Error for Elapsed {}