arc_malachitebft_engine/util/
timers.rs1use core::fmt;
2use std::collections::hash_map::Entry;
3use std::collections::HashMap;
4use std::fmt::Debug;
5use std::hash::Hash;
6use std::ops::RangeFrom;
7use std::sync::Arc;
8use std::time::Duration;
9
10use tokio::task::JoinHandle;
11use tracing::trace;
12
13use super::output_port::{OutputPort, OutputPortSubscriber};
14
15#[derive(Debug)]
16struct Timer<Key> {
17 key: Key,
19
20 task: JoinHandle<()>,
22
23 generation: u64,
26}
27
/// Message published once a timer started via [`TimerScheduler::start_timer`]
/// has elapsed.
///
/// Besides the timer key it records the generation the timer was started with,
/// which lets [`TimerScheduler::intercept_timer_msg`] distinguish the live timer
/// from one that was cancelled or restarted in the meantime.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct TimeoutElapsed<K> {
    key: K,
    generation: u64,
}
33
34impl<Key> TimeoutElapsed<Key> {
35 pub fn display_key(&self) -> &dyn fmt::Display
36 where
37 Key: fmt::Display,
38 {
39 &self.key
40 }
41}
42
43pub struct TimerScheduler<Key>
44where
45 Key: Clone + Eq + Hash + Send + 'static,
46{
47 output_port: Arc<OutputPort<TimeoutElapsed<Key>>>,
48 timers: HashMap<Key, Timer<Key>>,
49 generations: RangeFrom<u64>,
50}
51
52impl<Key> TimerScheduler<Key>
53where
54 Key: Clone + Eq + Hash + Send + 'static,
55{
56 pub fn new(subscriber: OutputPortSubscriber<TimeoutElapsed<Key>>) -> Self {
57 let output_port = OutputPort::with_capacity(32);
58 subscriber.subscribe_to_port(&output_port);
59
60 Self {
61 output_port: Arc::new(output_port),
62 timers: HashMap::new(),
63 generations: 1..,
64 }
65 }
66
67 pub fn start_timer(&mut self, key: Key, timeout: Duration)
80 where
81 Key: Clone + Send + 'static,
82 {
83 self.cancel(&key);
84
85 let generation = self
86 .generations
87 .next()
88 .expect("generation counter overflowed");
89
90 let task = {
91 let key = key.clone();
92 let output_port = Arc::clone(&self.output_port);
93
94 tokio::spawn(async move {
95 tokio::time::sleep(timeout).await;
96 output_port.send(TimeoutElapsed { key, generation })
97 })
98 };
99
100 self.timers.insert(
101 key.clone(),
102 Timer {
103 key,
104 task,
105 generation,
106 },
107 );
108 }
109
110 pub fn is_timer_active(&self, key: &Key) -> bool {
112 self.timers.contains_key(key)
113 }
114
115 pub fn cancel(&mut self, key: &Key) {
129 if let Some(timer) = self.timers.remove(key) {
130 timer.task.abort();
131 }
132 }
133
134 pub fn cancel_all(&mut self) {
136 self.timers.drain().for_each(|(_, timer)| {
137 timer.task.abort();
138 });
139 }
140
141 pub fn intercept_timer_msg(&mut self, timer_msg: TimeoutElapsed<Key>) -> Option<Key>
147 where
148 Key: Debug,
149 {
150 match self.timers.entry(timer_msg.key) {
151 Entry::Vacant(entry) => {
153 let key = entry.key();
154 trace!("Received timer {key:?} that has been removed, discarding");
155 None
156 }
157
158 Entry::Occupied(entry) if timer_msg.generation != entry.get().generation => {
160 let (key, timer) = (entry.key(), entry.get());
161
162 trace!(
163 "Received timer {key:?} from old generation {}, expected generation {}, discarding",
164 timer_msg.generation,
165 timer.generation,
166 );
167
168 None
169 }
170
171 Entry::Occupied(entry) => {
173 let timer = entry.remove();
174 Some(timer.key)
175 }
176 }
177 }
178}
179
180impl<Key> Drop for TimerScheduler<Key>
181where
182 Key: Clone + Eq + Hash + Send + 'static,
183{
184 fn drop(&mut self) {
185 self.cancel_all();
186 }
187}
188
#[cfg(test)]
mod tests {
    use super::*;

    use ractor::{Actor, ActorRef};
    use std::time::Duration;
    use tokio::time::sleep;

    /// Timer key used throughout the tests.
    #[derive(Copy, Debug, Clone, PartialEq, Eq, Hash)]
    struct TestKey(&'static str);

    /// Actor message wrapping an elapsed-timer notification.
    #[derive(Debug)]
    struct TestMsg(TimeoutElapsed<TestKey>);

    impl From<TimeoutElapsed<TestKey>> for TestMsg {
        fn from(timer_msg: TimeoutElapsed<TestKey>) -> Self {
            TestMsg(timer_msg)
        }
    }

    /// Minimal actor used as the subscriber of the scheduler's output port.
    struct TestActor;

    #[async_trait::async_trait]
    impl Actor for TestActor {
        type State = ();
        type Arguments = ();
        type Msg = TestMsg;

        async fn pre_start(
            &self,
            _myself: ActorRef<TestMsg>,
            _args: (),
        ) -> Result<(), ractor::ActorProcessingErr> {
            Ok(())
        }

        async fn handle(
            &self,
            _myself: ActorRef<TestMsg>,
            TestMsg(elapsed): TestMsg,
            _state: &mut (),
        ) -> Result<(), ractor::ActorProcessingErr> {
            println!("Received timer message: {elapsed:?}");
            Ok(())
        }
    }

    /// Spawns a [`TestActor`] and builds a scheduler that publishes to it.
    async fn spawn() -> TimerScheduler<TestKey> {
        let actor_ref = TestActor::spawn(None, TestActor, ()).await.unwrap().0;
        TimerScheduler::new(Box::new(actor_ref))
    }

    #[tokio::test]
    async fn test_start_timer() {
        let mut scheduler = spawn().await;
        let key = TestKey("timer1");

        scheduler.start_timer(key, Duration::from_millis(100));
        assert!(scheduler.is_timer_active(&key));

        sleep(Duration::from_millis(150)).await;
        let elapsed_key = scheduler.intercept_timer_msg(TimeoutElapsed { key, generation: 1 });
        assert_eq!(elapsed_key, Some(key));

        assert!(!scheduler.is_timer_active(&key));
    }

    #[tokio::test]
    async fn test_cancel_timer() {
        let mut scheduler = spawn().await;
        let key = TestKey("timer1");

        scheduler.start_timer(key, Duration::from_millis(100));
        scheduler.cancel(&key);

        assert!(!scheduler.is_timer_active(&key));
    }

    #[tokio::test]
    async fn test_cancel_all_timers() {
        let mut scheduler = spawn().await;

        scheduler.start_timer(TestKey("timer1"), Duration::from_millis(100));
        scheduler.start_timer(TestKey("timer2"), Duration::from_millis(200));

        scheduler.cancel_all();

        assert!(!scheduler.is_timer_active(&TestKey("timer1")));
        assert!(!scheduler.is_timer_active(&TestKey("timer2")));
    }

    #[tokio::test]
    async fn test_intercept_timer_msg_valid() {
        let mut scheduler = spawn().await;
        let key = TestKey("timer1");

        scheduler.start_timer(key, Duration::from_millis(100));
        sleep(Duration::from_millis(150)).await;

        let timer_msg = TimeoutElapsed { key, generation: 1 };

        let intercepted_msg = scheduler.intercept_timer_msg(timer_msg);

        assert_eq!(intercepted_msg, Some(key));
    }

    /// A message from a replaced timer is discarded, and the restarted timer's own
    /// message is still accepted afterwards.
    #[tokio::test]
    async fn test_intercept_timer_msg_invalid_generation() {
        let mut scheduler = spawn().await;
        let key = TestKey("timer1");

        scheduler.start_timer(key, Duration::from_millis(100));
        scheduler.start_timer(key, Duration::from_millis(200));

        let timer_msg = TimeoutElapsed { key, generation: 1 };

        let intercepted_msg = scheduler.intercept_timer_msg(timer_msg);

        assert_eq!(intercepted_msg, None);

        // Discarding the stale message must not disturb the live timer...
        assert!(scheduler.is_timer_active(&key));

        // ...whose own (second-generation) message is accepted.
        let current_msg = TimeoutElapsed { key, generation: 2 };
        assert_eq!(scheduler.intercept_timer_msg(current_msg), Some(key));
        assert!(!scheduler.is_timer_active(&key));
    }

    #[tokio::test]
    async fn test_intercept_timer_msg_cancelled() {
        let mut scheduler = spawn().await;
        let key = TestKey("timer1");

        scheduler.start_timer(key, Duration::from_millis(100));
        scheduler.cancel(&key);

        let timer_msg = TimeoutElapsed { key, generation: 1 };

        let intercepted_msg = scheduler.intercept_timer_msg(timer_msg);

        assert_eq!(intercepted_msg, None);
    }

    /// An elapsed message is accepted at most once: the first interception stops
    /// tracking the timer, so a duplicate is discarded.
    #[tokio::test]
    async fn test_intercept_timer_msg_only_once() {
        let mut scheduler = spawn().await;
        let key = TestKey("timer1");

        scheduler.start_timer(key, Duration::from_millis(100));

        let timer_msg = TimeoutElapsed { key, generation: 1 };

        assert_eq!(scheduler.intercept_timer_msg(timer_msg), Some(key));
        assert_eq!(scheduler.intercept_timer_msg(timer_msg), None);
        assert!(!scheduler.is_timer_active(&key));
    }
}