// es_entity/clock/sleep.rs

1use pin_project::{pin_project, pinned_drop};
2use tokio::time::Sleep;
3
4use std::{
5    future::Future,
6    pin::Pin,
7    sync::Arc,
8    task::{Context, Poll},
9    time::Duration,
10};
11
12use super::{
13    inner::ClockInner,
14    manual::{ManualClock, next_sleep_id},
15};
16
17/// A future that completes after a duration has elapsed on the clock.
18///
19/// Created by [`ClockHandle::sleep`](crate::ClockHandle::sleep).
20#[pin_project(PinnedDrop)]
21pub struct ClockSleep {
22    #[pin]
23    inner: ClockSleepInner,
24}
25
26#[pin_project(project = ClockSleepInnerProj)]
27enum ClockSleepInner {
28    Realtime {
29        #[pin]
30        sleep: Sleep,
31    },
32    Manual {
33        wake_at_ms: i64,
34        sleep_id: u64,
35        clock: Arc<ManualClock>,
36        registered: bool,
37    },
38}
39
40impl ClockSleep {
41    pub(crate) fn new(clock_inner: &ClockInner, duration: Duration) -> Self {
42        let inner = match clock_inner {
43            ClockInner::Realtime(rt) => ClockSleepInner::Realtime {
44                sleep: rt.sleep(duration),
45            },
46            ClockInner::Manual(manual) => {
47                let wake_at_ms = manual.now_ms() + duration.as_millis() as i64;
48
49                ClockSleepInner::Manual {
50                    wake_at_ms,
51                    sleep_id: next_sleep_id(),
52                    clock: Arc::clone(manual),
53                    registered: false,
54                }
55            }
56        };
57
58        Self { inner }
59    }
60}
61
62impl Future for ClockSleep {
63    type Output = ();
64
65    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
66        let this = self.project();
67
68        match this.inner.project() {
69            ClockSleepInnerProj::Realtime { sleep } => sleep.poll(cx),
70
71            ClockSleepInnerProj::Manual {
72                wake_at_ms,
73                sleep_id,
74                clock,
75                registered,
76            } => {
77                // Check if we've reached wake time
78                if clock.now_ms() >= *wake_at_ms {
79                    return Poll::Ready(());
80                }
81
82                // Register for wake notification if not already done
83                if !*registered {
84                    clock.register_wake(*wake_at_ms, *sleep_id, cx.waker().clone());
85                    *registered = true;
86                }
87
88                Poll::Pending
89            }
90        }
91    }
92}
93
94#[pinned_drop]
95impl PinnedDrop for ClockSleep {
96    fn drop(self: Pin<&mut Self>) {
97        // Clean up pending wake registration if cancelled
98        if let ClockSleepInner::Manual {
99            sleep_id,
100            clock,
101            registered: true,
102            ..
103        } = &self.inner
104        {
105            clock.cancel_wake(*sleep_id);
106        }
107    }
108}
109
110/// A future that completes with a timeout after a duration has elapsed on the clock.
111///
112/// Created by [`ClockHandle::timeout`](crate::ClockHandle::timeout).
113#[pin_project]
114pub struct ClockTimeout<F> {
115    #[pin]
116    future: F,
117    #[pin]
118    sleep: ClockSleep,
119    completed: bool,
120}
121
122impl<F> ClockTimeout<F> {
123    pub(crate) fn new(clock_inner: &ClockInner, duration: Duration, future: F) -> Self {
124        Self {
125            future,
126            sleep: ClockSleep::new(clock_inner, duration),
127            completed: false,
128        }
129    }
130}
131
132impl<F: Future> Future for ClockTimeout<F> {
133    type Output = Result<F::Output, Elapsed>;
134
135    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
136        let this = self.project();
137
138        if *this.completed {
139            panic!("ClockTimeout polled after completion");
140        }
141
142        // Check the future first
143        if let Poll::Ready(output) = this.future.poll(cx) {
144            *this.completed = true;
145            return Poll::Ready(Ok(output));
146        }
147
148        // Check if timeout elapsed
149        if let Poll::Ready(()) = this.sleep.poll(cx) {
150            *this.completed = true;
151            return Poll::Ready(Err(Elapsed));
152        }
153
154        Poll::Pending
155    }
156}
157
/// Error returned when a timeout expires.
///
/// Produced by [`ClockTimeout`] when its deadline passes before the wrapped future
/// completes. Displays as `deadline has elapsed`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Elapsed;

impl std::fmt::Display for Elapsed {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("deadline has elapsed")
    }
}

impl std::error::Error for Elapsed {}