1use std::{
15 cmp::Ordering as ComparisonOrdering,
16 collections::BinaryHeap,
17 sync::{
18 atomic::{AtomicBool, Ordering as AtomicOrdering},
19 Arc, Weak,
20 },
21 time::{Duration, Instant},
22};
23
24use async_trait::async_trait;
25use flume::{bounded, Receiver, RecvError, Sender};
26use tokio::{runtime::Handle, select, sync::Mutex, task, time};
27use zenoh_core::zconfigurable;
28
// Capacity of the bounded channel through which `Timer::add` / `Timer::add_async`
// hand newly scheduled events to the background timer task. With the default of
// 1, a sender has to wait while an earlier event is still sitting in the channel.
// NOTE(review): `zconfigurable!` presumably allows overriding this default via
// zenoh configuration — confirm against `zenoh_core`.
zconfigurable! {
    static ref TIMER_EVENTS_CHANNEL_SIZE: usize = 1;
}
32
/// An action that a [`Timer`] can execute at a scheduled instant.
///
/// `run` takes `&mut self`, so the timer needs exclusive access to the value
/// at the moment it fires the event.
#[async_trait]
pub trait Timed {
    /// Executes the action; the timer task awaits it to completion before
    /// handling any further event.
    async fn run(&mut self);
}

/// Type-erased, reference-counted storage for the action wrapped by a [`TimedEvent`].
type TimedFuture = Arc<dyn Timed + Send + Sync>;
39
/// Cancellation handle for a [`TimedEvent`], obtained from `TimedEvent::get_handle`.
///
/// It only holds a weak reference to the event's fuse flag, so keeping a handle
/// around does not keep that flag alive.
#[derive(Clone)]
pub struct TimedHandle(Weak<AtomicBool>);

impl TimedHandle {
    /// Disarms the associated event: once the flag is cleared the timer no
    /// longer runs (or reschedules) it.
    ///
    /// This is a no-op when the flag has already been dropped.
    pub fn defuse(self) {
        let Self(weak_flag) = self;
        if let Some(flag) = weak_flag.upgrade() {
            flag.store(false, AtomicOrdering::Release);
        }
    }
}
50
51#[derive(Clone)]
52pub struct TimedEvent {
53 when: Instant,
54 period: Option<Duration>,
55 future: TimedFuture,
56 fused: Arc<AtomicBool>,
57}
58
59impl TimedEvent {
60 pub fn once(when: Instant, event: impl Timed + Send + Sync + 'static) -> TimedEvent {
61 TimedEvent {
62 when,
63 period: None,
64 future: Arc::new(event),
65 fused: Arc::new(AtomicBool::new(true)),
66 }
67 }
68
69 pub fn periodic(interval: Duration, event: impl Timed + Send + Sync + 'static) -> TimedEvent {
70 TimedEvent {
71 when: Instant::now() + interval,
72 period: Some(interval),
73 future: Arc::new(event),
74 fused: Arc::new(AtomicBool::new(true)),
75 }
76 }
77
78 pub fn is_fused(&self) -> bool {
79 self.fused.load(AtomicOrdering::Acquire)
80 }
81
82 pub fn get_handle(&self) -> TimedHandle {
83 TimedHandle(Arc::downgrade(&self.fused))
84 }
85}
86
87impl Eq for TimedEvent {}
88
89impl Ord for TimedEvent {
90 fn cmp(&self, other: &Self) -> ComparisonOrdering {
91 other.when.cmp(&self.when)
97 }
98}
99
100impl PartialOrd for TimedEvent {
101 fn partial_cmp(&self, other: &Self) -> Option<ComparisonOrdering> {
102 Some(self.cmp(other))
103 }
104}
105
106impl PartialEq for TimedEvent {
107 fn eq(&self, other: &Self) -> bool {
108 self.when == other.when
109 }
110}
111
112async fn timer_task(
113 events: Arc<Mutex<BinaryHeap<TimedEvent>>>,
114 new_event: Receiver<(bool, TimedEvent)>,
115) -> Result<(), RecvError> {
116 let e = "Timer has been dropped. Unable to run timed events.";
118
119 let mut events = events.lock().await;
121
122 loop {
123 let new = new_event.recv_async();
125
126 match events.peek() {
127 Some(next) => {
128 let wait = async {
130 let next = next.clone();
131 let now = Instant::now();
132 if next.when > now {
133 time::sleep(next.when - now).await;
134 }
135 Ok((false, next))
136 };
137
138 let result = select! {
139 result = wait => { result },
140 result = new => { result },
141 };
142
143 match result {
144 Ok((is_new, mut ev)) => {
145 if is_new {
146 events.push(ev);
148 continue;
149 }
150
151 let _ = events.pop();
153
154 if ev.is_fused() {
156 Arc::get_mut(&mut ev.future).unwrap().run().await;
159
160 if let Some(interval) = ev.period {
162 ev.when = Instant::now() + interval;
163 events.push(ev);
164 }
165 }
166 }
167 Err(_) => {
168 tracing::trace!("{}", e);
170 return Ok(());
171 }
172 }
173 }
174 None => match new.await {
175 Ok((_, ev)) => {
176 events.push(ev);
177 continue;
178 }
179 Err(_) => {
180 tracing::trace!("{}", e);
182 return Ok(());
183 }
184 },
185 }
186 }
187}
188
189#[derive(Clone)]
190pub struct Timer {
191 events: Arc<Mutex<BinaryHeap<TimedEvent>>>,
192 sl_sender: Option<Sender<()>>,
193 ev_sender: Option<Sender<(bool, TimedEvent)>>,
194}
195
196impl Timer {
197 pub fn new(spawn_blocking: bool) -> Timer {
198 let (ev_sender, ev_receiver) = bounded::<(bool, TimedEvent)>(*TIMER_EVENTS_CHANNEL_SIZE);
200 let (sl_sender, sl_receiver) = bounded::<()>(1);
201
202 let timer = Timer {
204 events: Arc::new(Mutex::new(BinaryHeap::new())),
205 sl_sender: Some(sl_sender),
206 ev_sender: Some(ev_sender),
207 };
208
209 let c_e = timer.events.clone();
211 let fut = async move {
212 select! {
213 _ = sl_receiver.recv_async() => {},
214 _ = timer_task(c_e, ev_receiver) => {},
215 };
216 tracing::trace!("A - Timer task no longer running...");
217 };
218 if spawn_blocking {
219 task::spawn_blocking(|| Handle::current().block_on(fut));
220 } else {
221 task::spawn(fut);
222 }
223
224 timer
226 }
227
228 pub fn start(&mut self, spawn_blocking: bool) {
229 if self.sl_sender.is_none() {
230 let (ev_sender, ev_receiver) =
232 bounded::<(bool, TimedEvent)>(*TIMER_EVENTS_CHANNEL_SIZE);
233 let (sl_sender, sl_receiver) = bounded::<()>(1);
234
235 self.sl_sender = Some(sl_sender);
237 self.ev_sender = Some(ev_sender);
238
239 let c_e = self.events.clone();
241 let fut = async move {
242 select! {
243 _ = sl_receiver.recv_async() => {},
244 _ = timer_task(c_e, ev_receiver) => {},
245 };
246 tracing::trace!("A - Timer task no longer running...");
247 };
248 if spawn_blocking {
249 task::spawn_blocking(|| Handle::current().block_on(fut));
250 } else {
251 task::spawn(fut);
252 }
253 }
254 }
255
256 #[inline]
257 pub async fn start_async(&mut self, spawn_blocking: bool) {
258 self.start(spawn_blocking)
259 }
260
261 pub fn stop(&mut self) {
262 if let Some(sl_sender) = &self.sl_sender {
263 let _ = sl_sender.send(());
265
266 tracing::trace!("Stopping timer...");
267 self.sl_sender = None;
269 self.ev_sender = None;
270 }
271 }
272
273 pub async fn stop_async(&mut self) {
274 if let Some(sl_sender) = &self.sl_sender {
275 let _ = sl_sender.send_async(()).await;
277
278 tracing::trace!("Stopping timer...");
279 self.sl_sender = None;
281 self.ev_sender = None;
282 }
283 }
284
285 pub fn add(&self, event: TimedEvent) {
286 if let Some(ev_sender) = &self.ev_sender {
287 let _ = ev_sender.send((true, event));
288 }
289 }
290
291 pub async fn add_async(&self, event: TimedEvent) {
292 if let Some(ev_sender) = &self.ev_sender {
293 let _ = ev_sender.send_async((true, event)).await;
294 }
295 }
296}
297
298impl Default for Timer {
299 fn default() -> Self {
300 Self::new(false)
301 }
302}
303
#[cfg(test)]
mod tests {
    /// End-to-end scenario test of `Timer`: one-shot run, defused one-shot,
    /// periodic run + defuse, and periodic run across stop/start. It relies on
    /// wall-clock sleeps, so it takes several tens of seconds.
    #[test]
    fn timer() {
        use std::{
            sync::{
                atomic::{AtomicUsize, Ordering},
                Arc,
            },
            time::{Duration, Instant},
        };

        use async_trait::async_trait;
        use tokio::{runtime::Runtime, time};

        use super::{Timed, TimedEvent, Timer};

        /// Test event that counts how many times it has been run.
        #[derive(Clone)]
        struct MyEvent {
            counter: Arc<AtomicUsize>,
        }

        #[async_trait]
        impl Timed for MyEvent {
            async fn run(&mut self) {
                self.counter.fetch_add(1, Ordering::SeqCst);
            }
        }

        async fn run() {
            let mut timer = Timer::new(false);

            let counter = Arc::new(AtomicUsize::new(0));

            let myev = MyEvent {
                counter: counter.clone(),
            };

            let interval = Duration::from_secs(1);

            println!("Timer [1]: Once event and run");
            let now = Instant::now();
            let event = TimedEvent::once(now + (2 * interval), myev.clone());

            timer.add_async(event).await;

            time::sleep(3 * interval).await;

            let value = counter.swap(0, Ordering::SeqCst);
            assert_eq!(value, 1);

            println!("Timer [2]: Once event and defuse");
            let now = Instant::now();
            let event = TimedEvent::once(now + (2 * interval), myev.clone());
            let handle = event.get_handle();

            timer.add_async(event).await;
            handle.defuse();

            time::sleep(3 * interval).await;

            let value = counter.swap(0, Ordering::SeqCst);
            assert_eq!(value, 0);

            println!("Timer [3]: Periodic event run and defuse");
            let amount: usize = 3;

            let to_elapse = (2 * amount as u32) * interval;

            let event = TimedEvent::periodic(2 * interval, myev.clone());
            let handle = event.get_handle();

            timer.add_async(event).await;

            time::sleep(to_elapse + interval).await;

            let value = counter.swap(0, Ordering::SeqCst);
            assert_eq!(value, amount);

            handle.clone().defuse();
            handle.defuse();

            time::sleep(to_elapse).await;

            let value = counter.swap(0, Ordering::SeqCst);
            assert_eq!(value, 0);

            println!("Timer [4]: Periodic event and stop/start timer");
            let event = TimedEvent::periodic(2 * interval, myev);

            timer.add_async(event).await;

            time::sleep(to_elapse + interval).await;

            let value = counter.swap(0, Ordering::SeqCst);
            assert_eq!(value, amount);

            timer.stop_async().await;

            time::sleep(to_elapse).await;

            let value = counter.swap(0, Ordering::SeqCst);
            assert_eq!(value, 0);

            timer.start_async(false).await;

            // The periodic event's deadline passed while the timer was stopped,
            // so it fires right after the restart and then every `2 * interval`
            // (~0s, ~2s, ~4s, ~6s, ...). Sleeping a full `to_elapse` would race
            // with the run at ~6s, so wake up one `interval` earlier.
            time::sleep(to_elapse - interval).await;

            let value = counter.swap(0, Ordering::SeqCst);
            assert_eq!(value, amount);
        }

        let rt = Runtime::new().unwrap();
        rt.block_on(run());
    }
}