Skip to main content

elfo_core/time/
interval.rs

1use std::{
2    any::Any,
3    future::Future,
4    pin::Pin,
5    task::{self, Poll},
6};
7
8use pin_project::pin_project;
9use sealed::sealed;
10use tokio::time::{Duration, Instant, Sleep};
11
12use crate::{
13    envelope::{Envelope, MessageKind},
14    message::Message,
15    source::{SourceArc, SourceStream, UnattachedSource},
16    time::far_future,
17    tracing::TraceId,
18    Addr,
19};
20
/// A source that emits messages periodically.
/// Clones the message on every tick.
///
/// A freshly created interval is stopped: it emits nothing until one of the
/// `start*` methods is called. Ticks can be paused again with
/// [`Interval::stop()`] and the period adjusted with
/// [`Interval::set_period()`].
///
/// # Tracing
///
/// Every message starts a new trace, thus a new trace id is generated and
/// assigned to the current scope.
///
/// # Example
///
/// ```
/// # use std::time::Duration;
/// # use elfo_core as elfo;
/// # struct Config { period: Duration }
/// # async fn exec(mut ctx: elfo::Context<Config>) {
/// # use elfo::{message, msg};
/// use elfo::{time::Interval, messages::ConfigUpdated};
///
/// #[message]
/// struct MyTick;
///
/// let interval = ctx.attach(Interval::new(MyTick));
/// interval.start(ctx.config().period);
///
/// while let Some(envelope) = ctx.recv().await {
///     msg!(match envelope {
///         ConfigUpdated => {
///             interval.set_period(ctx.config().period);
///         },
///         MyTick => {
///             tracing::info!("tick!");
///         },
///     });
/// }
/// # }
/// ```
pub struct Interval<M> {
    // Shared handle to the timer state; all methods lock it to reconfigure
    // the message, the period, or the next deadline.
    source: SourceArc<IntervalSource<M>>,
}
60
// The handle keeps no termination state of its own: both methods simply
// forward to the shared source.
#[sealed]
impl<M: Message> crate::source::SourceHandle for Interval<M> {
    /// Returns whatever the underlying source reports as its termination
    /// status.
    fn is_terminated(&self) -> bool {
        self.source.is_terminated()
    }

    /// Forwards the termination request to the underlying source and returns
    /// its result unchanged.
    fn terminate_by_ref(&self) -> bool {
        self.source.terminate_by_ref()
    }
}
71
/// Sentinel period meaning "no ticks": the interval is stopped or has not
/// been started yet. Because zero is reserved for this purpose, every real
/// period passed by callers must be non-zero.
const NEVER: Duration = Duration::from_secs(0);
73
/// The state behind [`Interval`]: the message to emit and the timer that
/// paces it.
#[pin_project]
struct IntervalSource<M> {
    /// The message that is cloned into an envelope on every tick.
    message: M,
    /// Time between ticks; equals `NEVER` while the interval is stopped or
    /// has not been started yet.
    period: Duration,
    /// Set when the pending deadline was given explicitly (a delayed start
    /// or a stop) rather than derived from `period`. While set, changing the
    /// period does not reschedule the timer. Cleared when a tick fires.
    is_delayed: bool,
    /// Timer that completes at the deadline of the next tick.
    #[pin]
    sleep: Sleep,
}
82
83impl<M: Message> Interval<M> {
84    /// Creates an unattached instance of [`Interval`].
85    pub fn new(message: M) -> UnattachedSource<Self> {
86        let source = IntervalSource {
87            message,
88            period: NEVER,
89            is_delayed: false,
90            sleep: tokio::time::sleep_until(far_future()),
91        };
92
93        let source = SourceArc::new(source, false);
94        UnattachedSource::new(source, |source| Self { source })
95    }
96
97    /// Replaces a stored message with the provided one.
98    pub fn set_message(&self, message: M) {
99        let mut guard = ward!(self.source.lock());
100        *guard.stream().project().message = message;
101    }
102
103    // TODO: pub fn set_missed_tick_policy
104
105    /// Configures the period of ticks. Intended to be called on
106    /// `ConfigUpdated`.
107    ///
108    /// Does nothing if the timer is not started or the period hasn't been
109    /// changed.
110    ///
111    /// Unlike rescheduling (`start_*` methods), it only adjusts the current
112    /// period and doesn't change the time origin. For instance, if we have
113    /// a configured interval with period = 5s and try to call one of these
114    /// methods, the difference looks something like this:
115    ///
116    /// ```text
117    /// set_period(10s): | 5s | 5s | 5s |  # 10s  |   10s   |
118    /// start(10s):      | 5s | 5s | 5s |  #   10s   |   10s   |
119    ///                                    #
120    ///                               called here
121    /// ```
122    ///
123    /// # Panics
124    ///
125    /// If `period` is zero.
126    #[track_caller]
127    pub fn set_period(&self, period: Duration) {
128        assert_ne!(period, NEVER, "period must be non-zero");
129
130        let mut guard = ward!(self.source.lock());
131        let source = guard.stream().project();
132
133        // Do nothing if not started or the period hasn't been changed.
134        if *source.period == NEVER || period == *source.period {
135            return;
136        }
137
138        // Reschedule if inside the period.
139        if !*source.is_delayed {
140            let new_deadline = source.sleep.deadline() - *source.period + period;
141            source.sleep.reset(new_deadline);
142            *source.period = period;
143            guard.wake();
144        } else {
145            *source.period = period;
146        }
147    }
148
149    /// Schedules the timer to start emitting ticks every `period`.
150    /// The first tick will be emitted also after `period`.
151    ///
152    /// Reschedules the timer if it's already started.
153    ///
154    /// # Panics
155    ///
156    /// If `period` is zero.
157    #[track_caller]
158    pub fn start(&self, period: Duration) {
159        assert_ne!(period, NEVER, "period must be non-zero");
160        self.schedule(None, period);
161    }
162
163    /// Schedules the timer to start emitting ticks every `period`.
164    /// The first tick will be emitted after `delay`.
165    ///
166    /// Reschedules the timer if it's already started.
167    ///
168    /// # Panics
169    ///
170    /// If `period` is zero.
171    #[track_caller]
172    pub fn start_after(&self, delay: Duration, period: Duration) {
173        self.start_at(Instant::now() + delay, period);
174    }
175
176    /// Schedules the timer to start emitting ticks every `period`.
177    /// The first tick will be emitted at `when`.
178    ///
179    /// Reschedules the timer if it's already started.
180    ///
181    /// # Panics
182    ///
183    /// If `period` is zero.
184    ///
185    /// # Stability
186    ///
187    /// This method is unstable, because it accepts [`tokio::time::Instant`],
188    /// which will be replaced in the future to support other runtimes.
189    #[instability::unstable]
190    #[track_caller]
191    pub fn start_at(&self, when: Instant, period: Duration) {
192        assert_ne!(period, NEVER, "period must be non-zero");
193        self.schedule(Some(when), period);
194    }
195
196    /// Stops any ticks. To resume ticks use one of `start_*` methods.
197    ///
198    /// Note: it doesn't terminates the source. It means the source is present
199    /// in the source map until [`SourceHandle::terminate()`] is called.
200    ///
201    /// [`SourceHandle::terminate()`]: crate::SourceHandle::terminate()
202    pub fn stop(&self) {
203        self.schedule(Some(far_future()), NEVER);
204    }
205
206    fn schedule(&self, when: Option<Instant>, period: Duration) {
207        let mut guard = ward!(self.source.lock());
208        let source = guard.stream().project();
209
210        *source.is_delayed = when.is_some();
211        *source.period = period;
212
213        let new_deadline = when.unwrap_or_else(|| Instant::now() + period);
214        source.sleep.reset(new_deadline);
215        guard.wake();
216    }
217}
218
219impl<M: Message> SourceStream for IntervalSource<M> {
220    fn as_any_mut(self: Pin<&mut Self>) -> Pin<&mut dyn Any> {
221        // SAFETY: we only cast here, it cannot move data.
222        unsafe { self.map_unchecked_mut(|s| s) }
223    }
224
225    fn poll_recv(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Envelope>> {
226        let mut this = self.project();
227
228        // Do nothing if stopped or not configured.
229        if *this.period == NEVER {
230            return Poll::Pending;
231        }
232
233        // Wait for a tick from implementation.
234        if !this.sleep.as_mut().poll(cx).is_ready() {
235            return Poll::Pending;
236        }
237
238        // After first tick, the interval isn't delayed even if it was.
239        *this.is_delayed = false;
240
241        // Reset the underlying timer.
242        // It would be nice to use `reset_without_reregister` here, but it's private.
243        // TODO: consider moving to `tokio::time::Interval`, which uses it internally.
244        let new_deadline = this.sleep.deadline() + *this.period;
245        this.sleep.reset(new_deadline);
246
247        // Emit the message.
248        let message = this.message.clone();
249        let kind = MessageKind::regular(Addr::NULL);
250        let trace_id = TraceId::generate();
251        let envelope = Envelope::with_trace_id(message, kind, trace_id);
252
253        Poll::Ready(Some(envelope))
254    }
255}