// Source file: es_entity/clock/sleep.rs

1use pin_project::{pin_project, pinned_drop};
2use tokio::time::Sleep;
3
4use std::{
5    future::Future,
6    pin::Pin,
7    sync::{
8        Arc,
9        atomic::{AtomicU64, Ordering},
10    },
11    task::{Context, Poll},
12    time::Duration,
13};
14
15use super::{inner::ClockInner, manual::ManualClock};
16
/// Process-wide counter from which sleep identifiers are drawn.
static NEXT_SLEEP_ID: AtomicU64 = AtomicU64::new(0);

/// Hands out a sleep identifier not previously returned by this process
/// (until the `u64` counter wraps around).
fn next_sleep_id() -> u64 {
    // Only uniqueness matters here; `fetch_add` is atomic under every ordering,
    // so `Relaxed` suffices and no synchronization with other memory is implied.
    AtomicU64::fetch_add(&NEXT_SLEEP_ID, 1, Ordering::Relaxed)
}
24
25/// A future that completes after a duration has elapsed on the clock.
26///
27/// Created by [`ClockHandle::sleep`](crate::ClockHandle::sleep).
28#[pin_project(PinnedDrop)]
29pub struct ClockSleep {
30    #[pin]
31    inner: ClockSleepInner,
32}
33
34#[pin_project(project = ClockSleepInnerProj)]
35enum ClockSleepInner {
36    Realtime {
37        #[pin]
38        sleep: Sleep,
39    },
40    Manual {
41        wake_at_ms: i64,
42        sleep_id: u64,
43        clock: Arc<ManualClock>,
44        registered: bool,
45    },
46}
47
48impl ClockSleep {
49    pub(crate) fn new(clock_inner: &ClockInner, duration: Duration) -> Self {
50        let inner = match clock_inner {
51            ClockInner::Realtime(rt) => ClockSleepInner::Realtime {
52                sleep: rt.sleep(duration),
53            },
54            ClockInner::Manual(manual) => {
55                let wake_at_ms = manual.now_ms() + duration.as_millis() as i64;
56
57                ClockSleepInner::Manual {
58                    wake_at_ms,
59                    sleep_id: next_sleep_id(),
60                    clock: Arc::clone(manual),
61                    registered: false,
62                }
63            }
64        };
65
66        Self { inner }
67    }
68}
69
70impl Future for ClockSleep {
71    type Output = ();
72
73    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
74        let this = self.project();
75
76        match this.inner.project() {
77            ClockSleepInnerProj::Realtime { sleep } => sleep.poll(cx),
78
79            ClockSleepInnerProj::Manual {
80                wake_at_ms,
81                sleep_id,
82                clock,
83                registered,
84            } => {
85                // Check if we've reached wake time
86                if clock.now_ms() >= *wake_at_ms {
87                    return Poll::Ready(());
88                }
89
90                // Register for wake notification if not already done
91                if !*registered {
92                    clock.register_wake(*wake_at_ms, *sleep_id, cx.waker().clone());
93                    *registered = true;
94                }
95
96                Poll::Pending
97            }
98        }
99    }
100}
101
102#[pinned_drop]
103impl PinnedDrop for ClockSleep {
104    fn drop(self: Pin<&mut Self>) {
105        // Clean up pending wake registration if cancelled
106        if let ClockSleepInner::Manual {
107            sleep_id,
108            clock,
109            registered: true,
110            ..
111        } = &self.inner
112        {
113            clock.cancel_wake(*sleep_id);
114        }
115    }
116}
117
118/// A future that completes with a timeout after a duration has elapsed on the clock.
119///
120/// Created by [`ClockHandle::timeout`](crate::ClockHandle::timeout).
121#[pin_project]
122pub struct ClockTimeout<F> {
123    #[pin]
124    future: F,
125    #[pin]
126    sleep: ClockSleep,
127    completed: bool,
128}
129
130impl<F> ClockTimeout<F> {
131    pub(crate) fn new(clock_inner: &ClockInner, duration: Duration, future: F) -> Self {
132        Self {
133            future,
134            sleep: ClockSleep::new(clock_inner, duration),
135            completed: false,
136        }
137    }
138}
139
140impl<F: Future> Future for ClockTimeout<F> {
141    type Output = Result<F::Output, Elapsed>;
142
143    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
144        let this = self.project();
145
146        if *this.completed {
147            panic!("ClockTimeout polled after completion");
148        }
149
150        // Check the future first
151        if let Poll::Ready(output) = this.future.poll(cx) {
152            *this.completed = true;
153            return Poll::Ready(Ok(output));
154        }
155
156        // Check if timeout elapsed
157        if let Poll::Ready(()) = this.sleep.poll(cx) {
158            *this.completed = true;
159            return Poll::Ready(Err(Elapsed));
160        }
161
162        Poll::Pending
163    }
164}
165
/// Error returned when a timeout expires.
///
/// A [`ClockTimeout`] yields this when its deadline passes before the wrapped
/// future completes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Elapsed;

impl std::fmt::Display for Elapsed {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("deadline has elapsed")
    }
}

impl std::error::Error for Elapsed {}