Skip to main content

arc_malachitebft_engine/util/
timers.rs

1use core::fmt;
2use std::collections::hash_map::Entry;
3use std::collections::HashMap;
4use std::fmt::Debug;
5use std::hash::Hash;
6use std::ops::RangeFrom;
7use std::sync::Arc;
8use std::time::Duration;
9
10use tokio::task::JoinHandle;
11use tracing::trace;
12
13use super::output_port::{OutputPort, OutputPortSubscriber};
14
15#[derive(Debug)]
16struct Timer<Key> {
17    /// Message to give to the actor when the timer expires
18    key: Key,
19
20    // Task that will notify the actor that the timer has elapsed
21    task: JoinHandle<()>,
22
23    /// Generation counter to the timer to check if we received a timeout
24    /// message from an old timer that was enqueued in mailbox before canceled
25    generation: u64,
26}
27
/// Message published on a [`TimerScheduler`]'s output port when one of its timers elapses.
///
/// Receivers should pass it to [`TimerScheduler::intercept_timer_msg`], which tells
/// whether it comes from the currently active timer for its key or is stale.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct TimeoutElapsed<Key> {
    /// Key of the timer that elapsed.
    key: Key,
    /// Generation of the timer instance that produced this message, used to tell
    /// messages from canceled or replaced timers apart from the current one.
    generation: u64,
}
33
34impl<Key> TimeoutElapsed<Key> {
35    pub fn display_key(&self) -> &dyn fmt::Display
36    where
37        Key: fmt::Display,
38    {
39        &self.key
40    }
41}
42
43pub struct TimerScheduler<Key>
44where
45    Key: Clone + Eq + Hash + Send + 'static,
46{
47    output_port: Arc<OutputPort<TimeoutElapsed<Key>>>,
48    timers: HashMap<Key, Timer<Key>>,
49    generations: RangeFrom<u64>,
50}
51
52impl<Key> TimerScheduler<Key>
53where
54    Key: Clone + Eq + Hash + Send + 'static,
55{
56    pub fn new(subscriber: OutputPortSubscriber<TimeoutElapsed<Key>>) -> Self {
57        let output_port = OutputPort::with_capacity(32);
58        subscriber.subscribe_to_port(&output_port);
59
60        Self {
61            output_port: Arc::new(output_port),
62            timers: HashMap::new(),
63            generations: 1..,
64        }
65    }
66
67    /// Start a timer that will send `msg` once to the actor after the given `timeout`.
68    ///
69    /// Each timer has a key and if a new timer with same key is started
70    /// the previous is cancelled.
71    ///
72    /// # Warning
73    /// It is NOT guaranteed that a message from the previous timer is not received,
74    /// as it could already be enqueued in the mailbox when the new timer was started.
75    ///
76    /// When the actor receives a timeout message for timer from the scheduler, it should
77    /// check if the timer is still active by calling [`TimerScheduler::intercept_timer_msg`]
78    /// and ignore the message otherwise.
79    pub fn start_timer(&mut self, key: Key, timeout: Duration)
80    where
81        Key: Clone + Send + 'static,
82    {
83        self.cancel(&key);
84
85        let generation = self
86            .generations
87            .next()
88            .expect("generation counter overflowed");
89
90        let task = {
91            let key = key.clone();
92            let output_port = Arc::clone(&self.output_port);
93
94            tokio::spawn(async move {
95                tokio::time::sleep(timeout).await;
96                output_port.send(TimeoutElapsed { key, generation })
97            })
98        };
99
100        self.timers.insert(
101            key.clone(),
102            Timer {
103                key,
104                task,
105                generation,
106            },
107        );
108    }
109
110    /// Check if a timer with a given `key` is active, ie. it hasn't been canceled nor has it elapsed yet.
111    pub fn is_timer_active(&self, key: &Key) -> bool {
112        self.timers.contains_key(key)
113    }
114
115    /// Cancel a timer with a given `key`.
116    ///
117    /// If canceling a timer that was already canceled, or key never was used to start a timer
118    /// this operation will do nothing.
119    ///
120    /// # Warning
121    /// It is NOT guaranteed that a message from a canceled timer, including its previous incarnation
122    /// for the same key, will not be received by the actor, as the message might already
123    /// be enqueued in the mailbox when cancel is called.
124    ///
125    /// When the actor receives a timeout message for timer from the scheduler, it should
126    /// check if the timer is still active by calling [`TimerScheduler::intercept_timer_msg`]
127    /// and ignore the message otherwise.
128    pub fn cancel(&mut self, key: &Key) {
129        if let Some(timer) = self.timers.remove(key) {
130            timer.task.abort();
131        }
132    }
133
134    /// Cancel all timers.
135    pub fn cancel_all(&mut self) {
136        self.timers.drain().for_each(|(_, timer)| {
137            timer.task.abort();
138        });
139    }
140
141    /// Intercepts a timer message and checks the state of the timer associated with the provided `timer_msg`:
142    ///
143    /// 1. If the timer message was from canceled timer that was already enqueued in mailbox, returns `None`.
144    /// 2. If the timer message was from an old timer that was enqueued in mailbox before being canceled, returns `None`.
145    /// 3. Otherwise it is a valid timer message, returns the associated `Key` wrapped in `Some`.
146    pub fn intercept_timer_msg(&mut self, timer_msg: TimeoutElapsed<Key>) -> Option<Key>
147    where
148        Key: Debug,
149    {
150        match self.timers.entry(timer_msg.key) {
151            // The timer message was from canceled timer that was already enqueued in mailbox
152            Entry::Vacant(entry) => {
153                let key = entry.key();
154                trace!("Received timer {key:?} that has been removed, discarding");
155                None
156            }
157
158            // The timer message was from an old timer that was enqueued in mailbox before being canceled
159            Entry::Occupied(entry) if timer_msg.generation != entry.get().generation => {
160                let (key, timer) = (entry.key(), entry.get());
161
162                trace!(
163                    "Received timer {key:?} from old generation {}, expected generation {}, discarding",
164                    timer_msg.generation,
165                    timer.generation,
166                );
167
168                None
169            }
170
171            // Valid timer message
172            Entry::Occupied(entry) => {
173                let timer = entry.remove();
174                Some(timer.key)
175            }
176        }
177    }
178}
179
180impl<Key> Drop for TimerScheduler<Key>
181where
182    Key: Clone + Eq + Hash + Send + 'static,
183{
184    fn drop(&mut self) {
185        self.cancel_all();
186    }
187}
188
#[cfg(test)]
mod tests {
    use super::*;

    use ractor::{Actor, ActorRef};
    use std::time::Duration;
    use tokio::time::sleep;

    #[derive(Copy, Debug, Clone, PartialEq, Eq, Hash)]
    struct TestKey(&'static str);

    /// Actor message wrapping the timeout notifications produced by the scheduler.
    #[derive(Debug)]
    struct TestMsg(TimeoutElapsed<TestKey>);

    impl From<TimeoutElapsed<TestKey>> for TestMsg {
        fn from(timer_msg: TimeoutElapsed<TestKey>) -> Self {
            TestMsg(timer_msg)
        }
    }

    /// Actor subscribed to the scheduler's output port; it only prints what it receives.
    struct TestActor;

    #[async_trait::async_trait]
    impl Actor for TestActor {
        type State = ();
        type Arguments = ();
        type Msg = TestMsg;

        async fn pre_start(
            &self,
            _myself: ActorRef<TestMsg>,
            _args: (),
        ) -> Result<(), ractor::ActorProcessingErr> {
            Ok(())
        }

        async fn handle(
            &self,
            _myself: ActorRef<TestMsg>,
            TestMsg(elapsed): TestMsg,
            _state: &mut (),
        ) -> Result<(), ractor::ActorProcessingErr> {
            println!("Received timer message: {elapsed:?}");
            Ok(())
        }
    }

    /// Spawns a `TestActor` and returns a scheduler that publishes to it.
    async fn spawn() -> TimerScheduler<TestKey> {
        let actor_ref = TestActor::spawn(None, TestActor, ()).await.unwrap().0;
        TimerScheduler::new(Box::new(actor_ref))
    }

    #[tokio::test]
    async fn test_start_timer() {
        let mut scheduler = spawn().await;
        let key = TestKey("timer1");

        scheduler.start_timer(key, Duration::from_millis(100));
        assert!(scheduler.is_timer_active(&key));

        sleep(Duration::from_millis(150)).await;
        let elapsed_key = scheduler.intercept_timer_msg(TimeoutElapsed { key, generation: 1 });
        assert_eq!(elapsed_key, Some(key));

        assert!(!scheduler.is_timer_active(&key));
    }

    #[tokio::test]
    async fn test_cancel_timer() {
        let mut scheduler = spawn().await;
        let key = TestKey("timer1");

        scheduler.start_timer(key, Duration::from_millis(100));
        scheduler.cancel(&key);

        assert!(!scheduler.is_timer_active(&key));
    }

    #[tokio::test]
    async fn test_cancel_all_timers() {
        let mut scheduler = spawn().await;

        scheduler.start_timer(TestKey("timer1"), Duration::from_millis(100));
        scheduler.start_timer(TestKey("timer2"), Duration::from_millis(200));

        scheduler.cancel_all();

        assert!(!scheduler.is_timer_active(&TestKey("timer1")));
        assert!(!scheduler.is_timer_active(&TestKey("timer2")));
    }

    #[tokio::test]
    async fn test_intercept_timer_msg_valid() {
        let mut scheduler = spawn().await;
        let key = TestKey("timer1");

        scheduler.start_timer(key, Duration::from_millis(100));
        sleep(Duration::from_millis(150)).await;

        let timer_msg = TimeoutElapsed { key, generation: 1 };

        let intercepted_msg = scheduler.intercept_timer_msg(timer_msg);

        assert_eq!(intercepted_msg, Some(key));
    }

    #[tokio::test]
    async fn test_intercept_timer_msg_invalid_generation() {
        let mut scheduler = spawn().await;
        let key = TestKey("timer1");

        scheduler.start_timer(key, Duration::from_millis(100));
        scheduler.start_timer(key, Duration::from_millis(200));

        let timer_msg = TimeoutElapsed { key, generation: 1 };

        let intercepted_msg = scheduler.intercept_timer_msg(timer_msg);

        assert_eq!(intercepted_msg, None);
        // A stale message must not consume the replacement timer.
        assert!(scheduler.is_timer_active(&key));
    }

    #[tokio::test]
    async fn test_intercept_timer_msg_latest_generation() {
        let mut scheduler = spawn().await;
        let key = TestKey("timer1");

        // The first start gets generation 1, the restart gets generation 2.
        scheduler.start_timer(key, Duration::from_millis(100));
        scheduler.start_timer(key, Duration::from_millis(200));

        let stale = scheduler.intercept_timer_msg(TimeoutElapsed { key, generation: 1 });
        assert_eq!(stale, None);

        let current = scheduler.intercept_timer_msg(TimeoutElapsed { key, generation: 2 });
        assert_eq!(current, Some(key));
        assert!(!scheduler.is_timer_active(&key));
    }

    #[tokio::test]
    async fn test_intercept_timer_msg_cancelled() {
        let mut scheduler = spawn().await;
        let key = TestKey("timer1");

        scheduler.start_timer(key, Duration::from_millis(100));
        scheduler.cancel(&key);

        let timer_msg = TimeoutElapsed { key, generation: 1 };

        let intercepted_msg = scheduler.intercept_timer_msg(timer_msg);

        assert_eq!(intercepted_msg, None);
        assert!(!scheduler.is_timer_active(&key));
    }
}