reifydb_runtime/actor/timers/
scheduler.rs1#![allow(clippy::disallowed_methods)]
5
6use std::{
7 cmp::Ordering as CmpOrdering,
8 collections::BinaryHeap,
9 sync::{
10 Arc,
11 atomic::{AtomicBool, Ordering},
12 },
13 thread::{self, JoinHandle},
14 time::{Duration, Instant},
15};
16
17use crossbeam_channel::{Receiver, RecvTimeoutError, Sender, bounded};
18use rayon::ThreadPool;
19
20use super::{TimerHandle, next_timer_id};
21
22struct TimerEntry {
23 id: u64,
25 deadline: Instant,
27 kind: TimerKind,
29 cancelled: Arc<AtomicBool>,
31}
32
/// The action attached to a [`TimerEntry`].
enum TimerKind {
    /// Fires a single time; the callback is consumed when it is dispatched.
    Once { callback: Box<dyn FnOnce() + Send> },
    /// Fires every `interval` until the timer is cancelled or the callback
    /// returns `false`.
    Repeat {
        callback: Arc<dyn Fn() -> bool + Send + Sync>,
        interval: Duration,
    },
}
44
45impl Eq for TimerEntry {}
46
47impl PartialEq for TimerEntry {
48 fn eq(&self, other: &Self) -> bool {
49 self.deadline == other.deadline && self.id == other.id
50 }
51}
52
53impl Ord for TimerEntry {
54 fn cmp(&self, other: &Self) -> CmpOrdering {
56 other.deadline.cmp(&self.deadline).then_with(|| other.id.cmp(&self.id))
58 }
59}
60
61impl PartialOrd for TimerEntry {
62 fn partial_cmp(&self, other: &Self) -> Option<CmpOrdering> {
63 Some(self.cmp(other))
64 }
65}
66
/// Messages sent from a [`SchedulerHandle`] to the scheduler thread.
enum SchedulerCommand {
    /// Run `callback` once after `delay`. A zero delay is handed to the pool as
    /// soon as the command is processed.
    ScheduleOnce {
        id: u64,
        delay: Duration,
        callback: Box<dyn FnOnce() + Send>,
        cancelled: Arc<AtomicBool>,
    },
    /// Run `callback` every `interval`; the first firing is one interval after the
    /// command is processed.
    ScheduleRepeat {
        id: u64,
        interval: Duration,
        callback: Arc<dyn Fn() -> bool + Send + Sync>,
        cancelled: Arc<AtomicBool>,
    },
    /// Stop the scheduler thread; timers that are still pending are dropped without running.
    Shutdown,
}
86
87pub struct SchedulerHandle {
92 command_tx: Sender<SchedulerCommand>,
93 join_handle: Option<JoinHandle<()>>,
94}
95
96impl SchedulerHandle {
97 pub fn new(pool: Arc<ThreadPool>) -> Self {
101 let (command_tx, command_rx) = bounded(256);
102
103 let join_handle = thread::Builder::new()
104 .name("timer-scheduler".to_string())
105 .spawn(move || {
106 scheduler_loop(command_rx, pool);
107 })
108 .expect("failed to spawn timer scheduler thread");
109
110 Self {
111 command_tx,
112 join_handle: Some(join_handle),
113 }
114 }
115
116 pub fn schedule_once<F>(&self, delay: Duration, callback: F) -> TimerHandle
120 where
121 F: FnOnce() + Send + 'static,
122 {
123 let id = next_timer_id();
124 let handle = TimerHandle::new(id);
125 let cancelled = handle.cancelled_flag();
126
127 let _ = self.command_tx.send(SchedulerCommand::ScheduleOnce {
128 id,
129 delay,
130 callback: Box::new(callback),
131 cancelled,
132 });
133
134 handle
135 }
136
137 pub fn schedule_repeat<F>(&self, interval: Duration, callback: F) -> TimerHandle
142 where
143 F: Fn() -> bool + Send + Sync + 'static,
144 {
145 let id = next_timer_id();
146 let handle = TimerHandle::new(id);
147 let cancelled = handle.cancelled_flag();
148
149 let _ = self.command_tx.send(SchedulerCommand::ScheduleRepeat {
150 id,
151 interval,
152 callback: Arc::new(callback),
153 cancelled,
154 });
155
156 handle
157 }
158
159 pub fn shared(&self) -> Self {
160 Self {
161 command_tx: self.command_tx.clone(),
162 join_handle: None,
163 }
164 }
165
166 pub fn shutdown(&mut self) {
168 if let Some(handle) = self.join_handle.take() {
169 let _ = self.command_tx.send(SchedulerCommand::Shutdown);
170 let _ = handle.join();
171 }
172 }
173}
174
175impl Drop for SchedulerHandle {
176 fn drop(&mut self) {
177 if let Some(handle) = self.join_handle.take() {
178 let _ = self.command_tx.send(SchedulerCommand::Shutdown);
179 let _ = handle.join();
180 }
181 }
182}
183
184fn scheduler_loop(command_rx: Receiver<SchedulerCommand>, pool: Arc<ThreadPool>) {
186 let mut heap: BinaryHeap<TimerEntry> = BinaryHeap::new();
187
188 loop {
189 let timeout = heap.peek().map(|entry| {
191 let now = Instant::now();
192 if entry.deadline <= now {
193 Duration::ZERO
194 } else {
195 entry.deadline.duration_since(now)
196 }
197 });
198
199 let command = match timeout {
201 Some(Duration::ZERO) => {
202 command_rx.try_recv().ok()
204 }
205 Some(dur) => {
206 match command_rx.recv_timeout(dur) {
208 Ok(cmd) => Some(cmd),
209 Err(RecvTimeoutError::Timeout) => None,
210 Err(RecvTimeoutError::Disconnected) => {
211 return;
213 }
214 }
215 }
216 None => {
217 match command_rx.recv() {
219 Ok(cmd) => Some(cmd),
220 Err(_) => return, }
222 }
223 };
224
225 if let Some(cmd) = command {
227 match cmd {
228 SchedulerCommand::ScheduleOnce {
229 id,
230 delay,
231 callback,
232 cancelled,
233 } => {
234 let deadline = if delay.is_zero() {
235 if !cancelled.load(Ordering::SeqCst) {
237 pool.spawn(callback);
238 }
239 continue;
240 } else {
241 Instant::now() + delay
242 };
243
244 heap.push(TimerEntry {
245 id,
246 deadline,
247 kind: TimerKind::Once {
248 callback,
249 },
250 cancelled,
251 });
252 }
253 SchedulerCommand::ScheduleRepeat {
254 id,
255 interval,
256 callback,
257 cancelled,
258 } => {
259 let deadline = Instant::now() + interval;
260
261 heap.push(TimerEntry {
262 id,
263 deadline,
264 kind: TimerKind::Repeat {
265 callback,
266 interval,
267 },
268 cancelled,
269 });
270 }
271 SchedulerCommand::Shutdown => {
272 return;
273 }
274 }
275 }
276
277 let now = Instant::now();
279 while let Some(entry) = heap.peek() {
280 if entry.deadline > now {
281 break;
282 }
283
284 let entry = heap.pop().unwrap();
285
286 if entry.cancelled.load(Ordering::SeqCst) {
288 continue;
289 }
290
291 match entry.kind {
292 TimerKind::Once {
293 callback,
294 } => {
295 pool.spawn(callback);
296 }
297 TimerKind::Repeat {
298 callback,
299 interval,
300 } => {
301 let cancelled = entry.cancelled.clone();
302 let callback_clone = callback.clone();
303
304 pool.spawn(move || {
305 if !cancelled.load(Ordering::SeqCst) {
306 let continue_timer = callback_clone();
307 if !continue_timer {
308 cancelled.store(true, Ordering::SeqCst);
309 }
310 }
311 });
312
313 if !entry.cancelled.load(Ordering::SeqCst) {
315 heap.push(TimerEntry {
316 id: entry.id,
317 deadline: now + interval,
318 kind: TimerKind::Repeat {
319 callback,
320 interval,
321 },
322 cancelled: entry.cancelled,
323 });
324 }
325 }
326 }
327 }
328 }
329}
330
#[cfg(test)]
mod tests {
    use std::sync::{atomic::AtomicUsize, mpsc};

    use rayon::ThreadPoolBuilder;

    use super::*;

    /// Upper bound on how long a test waits for timers to make progress.
    ///
    /// Tests poll for the condition they need instead of sleeping a fixed amount, so
    /// they finish quickly on fast machines and still pass on a loaded CI runner.
    const WAIT_LIMIT: Duration = Duration::from_secs(2);

    /// Pool with a single worker, so callbacks never run concurrently.
    fn test_pool() -> Arc<ThreadPool> {
        Arc::new(ThreadPoolBuilder::new().num_threads(1).build().unwrap())
    }

    /// Polls `condition` until it holds or [`WAIT_LIMIT`] elapses; returns whether it held.
    fn wait_until(condition: impl Fn() -> bool) -> bool {
        let deadline = Instant::now() + WAIT_LIMIT;
        loop {
            if condition() {
                return true;
            }
            if Instant::now() >= deadline {
                return false;
            }
            thread::sleep(Duration::from_millis(2));
        }
    }

    // Callbacks use `let _ = tx.send(..)` rather than `unwrap()`. A callback that
    // fires after its test has returned (receiver dropped) would otherwise panic on
    // the rayon pool, whose default panic handling aborts the whole test process.

    #[test]
    fn test_schedule_once() {
        let mut scheduler = SchedulerHandle::new(test_pool());

        let (tx, rx) = mpsc::channel();
        scheduler.schedule_once(Duration::from_millis(10), move || {
            let _ = tx.send(());
        });

        rx.recv_timeout(WAIT_LIMIT).unwrap();
        scheduler.shutdown();
    }

    #[test]
    fn test_schedule_once_zero_delay() {
        let mut scheduler = SchedulerHandle::new(test_pool());

        let (tx, rx) = mpsc::channel();
        scheduler.schedule_once(Duration::ZERO, move || {
            let _ = tx.send(());
        });

        rx.recv_timeout(WAIT_LIMIT).unwrap();
        scheduler.shutdown();
    }

    #[test]
    fn test_schedule_repeat() {
        let mut scheduler = SchedulerHandle::new(test_pool());

        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = counter.clone();

        let handle = scheduler.schedule_repeat(Duration::from_millis(10), move || {
            counter_clone.fetch_add(1, Ordering::SeqCst);
            true
        });

        let reached = wait_until(|| counter.load(Ordering::SeqCst) >= 3);
        handle.cancel();

        let count = counter.load(Ordering::SeqCst);
        assert!(reached, "Expected at least 3 iterations, got {}", count);

        scheduler.shutdown();
    }

    #[test]
    fn test_schedule_repeat_stops_on_false() {
        let mut scheduler = SchedulerHandle::new(test_pool());

        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = counter.clone();

        // Calls see counts 0, 1, 2 (return true) and then 3 (return false).
        scheduler.schedule_repeat(Duration::from_millis(10), move || {
            let count = counter_clone.fetch_add(1, Ordering::SeqCst);
            count < 3
        });

        let reached = wait_until(|| counter.load(Ordering::SeqCst) >= 4);
        // Give the stopped timer several more intervals in which it must not fire again.
        thread::sleep(Duration::from_millis(50));

        let count = counter.load(Ordering::SeqCst);
        assert!(reached, "Expected the timer to reach its 4th call, got {}", count);
        assert!(count <= 4, "Expected at most 4 iterations, got {}", count);

        scheduler.shutdown();
    }

    #[test]
    fn test_cancel_before_fire() {
        let mut scheduler = SchedulerHandle::new(test_pool());

        let (tx, rx) = mpsc::channel();
        let handle = scheduler.schedule_once(Duration::from_millis(50), move || {
            let _ = tx.send(());
        });

        handle.cancel();

        // Wait past the original deadline; the cancelled callback must not run.
        assert!(rx.recv_timeout(Duration::from_millis(100)).is_err());

        scheduler.shutdown();
    }

    #[test]
    fn test_multiple_timers() {
        let mut scheduler = SchedulerHandle::new(test_pool());

        let (tx, rx) = mpsc::channel();
        // Timer `i` fires after (5 - i) * 10 ms, so they complete in reverse order.
        for i in 0..5u64 {
            let tx = tx.clone();
            scheduler.schedule_once(Duration::from_millis((5 - i) * 10), move || {
                let _ = tx.send(i);
            });
        }

        let fired: Vec<u64> = (0..5)
            .map(|_| rx.recv_timeout(WAIT_LIMIT).expect("timer did not fire in time"))
            .collect();
        assert_eq!(fired, vec![4, 3, 2, 1, 0]);

        scheduler.shutdown();
    }
}