// reifydb_runtime/actor/timers/scheduler.rs

// SPDX-License-Identifier: Apache-2.0
// Copyright (c) 2025 ReifyDB

#![allow(clippy::disallowed_methods)]

use std::{
	cmp::Ordering as CmpOrdering,
	collections::BinaryHeap,
	sync::{
		Arc,
		atomic::{AtomicBool, Ordering},
	},
	thread::{self, JoinHandle},
	time::{Duration, Instant},
};

use crossbeam_channel::{Receiver, RecvTimeoutError, Sender, bounded};
use rayon::ThreadPool;

use super::{TimerHandle, next_timer_id};

22struct TimerEntry {
23	/// Unique timer ID.
24	id: u64,
25	/// When the timer should fire.
26	deadline: Instant,
27	/// The kind of timer
28	kind: TimerKind,
29	/// Shared flag to check if cancelled.
30	cancelled: Arc<AtomicBool>,
31}
32
33enum TimerKind {
34	/// Fire once and remove.
35	Once {
36		callback: Box<dyn FnOnce() + Send>,
37	},
38	/// Fire repeatedly until cancelled or callback returns false.
39	Repeat {
40		callback: Arc<dyn Fn() -> bool + Send + Sync>,
41		interval: Duration,
42	},
43}
44
45impl Eq for TimerEntry {}
46
47impl PartialEq for TimerEntry {
48	fn eq(&self, other: &Self) -> bool {
49		self.deadline == other.deadline && self.id == other.id
50	}
51}
52
53impl Ord for TimerEntry {
54	// BinaryHeap is a max-heap, so we reverse the ordering to get a min-heap by deadline.
55	fn cmp(&self, other: &Self) -> CmpOrdering {
56		// Reverse ordering for min-heap behavior
57		other.deadline.cmp(&self.deadline).then_with(|| other.id.cmp(&self.id))
58	}
59}
60
61impl PartialOrd for TimerEntry {
62	fn partial_cmp(&self, other: &Self) -> Option<CmpOrdering> {
63		Some(self.cmp(other))
64	}
65}
66
67/// Commands sent to the scheduler coordinator thread.
68enum SchedulerCommand {
69	/// Schedule a one-shot timer.
70	ScheduleOnce {
71		id: u64,
72		delay: Duration,
73		callback: Box<dyn FnOnce() + Send>,
74		cancelled: Arc<AtomicBool>,
75	},
76	/// Schedule a repeating timer.
77	ScheduleRepeat {
78		id: u64,
79		interval: Duration,
80		callback: Arc<dyn Fn() -> bool + Send + Sync>,
81		cancelled: Arc<AtomicBool>,
82	},
83	/// Shutdown the scheduler.
84	Shutdown,
85}
86
87/// Handle to the timer scheduler.
88///
89/// Used to schedule timers and shutdown the scheduler.
90/// Cloning the handle creates another reference to the same scheduler.
91pub struct SchedulerHandle {
92	command_tx: Sender<SchedulerCommand>,
93	join_handle: Option<JoinHandle<()>>,
94}
95
96impl SchedulerHandle {
97	/// Create and start a new scheduler.
98	///
99	/// Timer callbacks are dispatched to the given rayon thread pool.
100	pub fn new(pool: Arc<ThreadPool>) -> Self {
101		let (command_tx, command_rx) = bounded(256);
102
103		let join_handle = thread::Builder::new()
104			.name("timer-scheduler".to_string())
105			.spawn(move || {
106				scheduler_loop(command_rx, pool);
107			})
108			.expect("failed to spawn timer scheduler thread");
109
110		Self {
111			command_tx,
112			join_handle: Some(join_handle),
113		}
114	}
115
116	/// Schedule a callback to fire once after a delay.
117	///
118	/// Returns a handle that can be used to cancel the timer.
119	pub fn schedule_once<F>(&self, delay: Duration, callback: F) -> TimerHandle
120	where
121		F: FnOnce() + Send + 'static,
122	{
123		let id = next_timer_id();
124		let handle = TimerHandle::new(id);
125		let cancelled = handle.cancelled_flag();
126
127		let _ = self.command_tx.send(SchedulerCommand::ScheduleOnce {
128			id,
129			delay,
130			callback: Box::new(callback),
131			cancelled,
132		});
133
134		handle
135	}
136
137	/// Schedule a callback to fire repeatedly at an interval.
138	///
139	/// The callback returns `true` to continue or `false` to stop.
140	/// Returns a handle that can be used to cancel the timer.
141	pub fn schedule_repeat<F>(&self, interval: Duration, callback: F) -> TimerHandle
142	where
143		F: Fn() -> bool + Send + Sync + 'static,
144	{
145		let id = next_timer_id();
146		let handle = TimerHandle::new(id);
147		let cancelled = handle.cancelled_flag();
148
149		let _ = self.command_tx.send(SchedulerCommand::ScheduleRepeat {
150			id,
151			interval,
152			callback: Arc::new(callback),
153			cancelled,
154		});
155
156		handle
157	}
158
159	pub fn shared(&self) -> Self {
160		Self {
161			command_tx: self.command_tx.clone(),
162			join_handle: None,
163		}
164	}
165
166	/// Shutdown the scheduler and wait for it to complete.
167	pub fn shutdown(&mut self) {
168		if let Some(handle) = self.join_handle.take() {
169			let _ = self.command_tx.send(SchedulerCommand::Shutdown);
170			let _ = handle.join();
171		}
172	}
173}
174
175impl Drop for SchedulerHandle {
176	fn drop(&mut self) {
177		if let Some(handle) = self.join_handle.take() {
178			let _ = self.command_tx.send(SchedulerCommand::Shutdown);
179			let _ = handle.join();
180		}
181	}
182}
183
184/// The main scheduler loop running on the coordinator thread.
185fn scheduler_loop(command_rx: Receiver<SchedulerCommand>, pool: Arc<ThreadPool>) {
186	let mut heap: BinaryHeap<TimerEntry> = BinaryHeap::new();
187
188	loop {
189		// Calculate timeout until next timer
190		let timeout = heap.peek().map(|entry| {
191			let now = Instant::now();
192			if entry.deadline <= now {
193				Duration::ZERO
194			} else {
195				entry.deadline.duration_since(now)
196			}
197		});
198
199		// Wait for command or timeout
200		let command = match timeout {
201			Some(Duration::ZERO) => {
202				// Timer(s) ready to fire - check for commands without blocking
203				command_rx.try_recv().ok()
204			}
205			Some(dur) => {
206				// Wait until next timer or command
207				match command_rx.recv_timeout(dur) {
208					Ok(cmd) => Some(cmd),
209					Err(RecvTimeoutError::Timeout) => None,
210					Err(RecvTimeoutError::Disconnected) => {
211						// Channel closed, exit
212						return;
213					}
214				}
215			}
216			None => {
217				// No timers - block until command
218				match command_rx.recv() {
219					Ok(cmd) => Some(cmd),
220					Err(_) => return, // Channel closed
221				}
222			}
223		};
224
225		// Process command if received
226		if let Some(cmd) = command {
227			match cmd {
228				SchedulerCommand::ScheduleOnce {
229					id,
230					delay,
231					callback,
232					cancelled,
233				} => {
234					let deadline = if delay.is_zero() {
235						// Zero delay - fire immediately
236						if !cancelled.load(Ordering::SeqCst) {
237							pool.spawn(callback);
238						}
239						continue;
240					} else {
241						Instant::now() + delay
242					};
243
244					heap.push(TimerEntry {
245						id,
246						deadline,
247						kind: TimerKind::Once {
248							callback,
249						},
250						cancelled,
251					});
252				}
253				SchedulerCommand::ScheduleRepeat {
254					id,
255					interval,
256					callback,
257					cancelled,
258				} => {
259					let deadline = Instant::now() + interval;
260
261					heap.push(TimerEntry {
262						id,
263						deadline,
264						kind: TimerKind::Repeat {
265							callback,
266							interval,
267						},
268						cancelled,
269					});
270				}
271				SchedulerCommand::Shutdown => {
272					return;
273				}
274			}
275		}
276
277		// Fire all due timers
278		let now = Instant::now();
279		while let Some(entry) = heap.peek() {
280			if entry.deadline > now {
281				break;
282			}
283
284			let entry = heap.pop().unwrap();
285
286			// Check if cancelled
287			if entry.cancelled.load(Ordering::SeqCst) {
288				continue;
289			}
290
291			match entry.kind {
292				TimerKind::Once {
293					callback,
294				} => {
295					pool.spawn(callback);
296				}
297				TimerKind::Repeat {
298					callback,
299					interval,
300				} => {
301					let cancelled = entry.cancelled.clone();
302					let callback_clone = callback.clone();
303
304					pool.spawn(move || {
305						if !cancelled.load(Ordering::SeqCst) {
306							let continue_timer = callback_clone();
307							if !continue_timer {
308								cancelled.store(true, Ordering::SeqCst);
309							}
310						}
311					});
312
313					// Re-schedule if not cancelled
314					if !entry.cancelled.load(Ordering::SeqCst) {
315						heap.push(TimerEntry {
316							id: entry.id,
317							deadline: now + interval,
318							kind: TimerKind::Repeat {
319								callback,
320								interval,
321							},
322							cancelled: entry.cancelled,
323						});
324					}
325				}
326			}
327		}
328	}
329}
330
331#[cfg(test)]
332mod tests {
333	use std::sync::{atomic::AtomicUsize, mpsc};
334
335	use parking_lot::Mutex;
336	use rayon::ThreadPoolBuilder;
337
338	fn test_pool() -> Arc<ThreadPool> {
339		Arc::new(ThreadPoolBuilder::new().num_threads(1).build().unwrap())
340	}
341
342	use super::*;
343
344	#[test]
345	fn test_schedule_once() {
346		let mut scheduler = SchedulerHandle::new(test_pool());
347
348		let (tx, rx) = mpsc::channel();
349		scheduler.schedule_once(Duration::from_millis(10), move || {
350			tx.send(()).unwrap();
351		});
352
353		rx.recv_timeout(Duration::from_secs(1)).unwrap();
354		scheduler.shutdown();
355	}
356
357	#[test]
358	fn test_schedule_once_zero_delay() {
359		let mut scheduler = SchedulerHandle::new(test_pool());
360
361		let (tx, rx) = mpsc::channel();
362		scheduler.schedule_once(Duration::ZERO, move || {
363			tx.send(()).unwrap();
364		});
365
366		rx.recv_timeout(Duration::from_secs(1)).unwrap();
367		scheduler.shutdown();
368	}
369
370	#[test]
371	fn test_schedule_repeat() {
372		let mut scheduler = SchedulerHandle::new(test_pool());
373
374		let counter = Arc::new(AtomicUsize::new(0));
375		let counter_clone = counter.clone();
376
377		let handle = scheduler.schedule_repeat(Duration::from_millis(10), move || {
378			counter_clone.fetch_add(1, Ordering::SeqCst);
379			true // Continue
380		});
381
382		// Wait for several iterations
383		thread::sleep(Duration::from_millis(50));
384		handle.cancel();
385
386		let count = counter.load(Ordering::SeqCst);
387		assert!(count >= 3, "Expected at least 3 iterations, got {}", count);
388
389		scheduler.shutdown();
390	}
391
392	#[test]
393	fn test_schedule_repeat_stops_on_false() {
394		let mut scheduler = SchedulerHandle::new(test_pool());
395
396		let counter = Arc::new(AtomicUsize::new(0));
397		let counter_clone = counter.clone();
398
399		scheduler.schedule_repeat(Duration::from_millis(10), move || {
400			let count = counter_clone.fetch_add(1, Ordering::SeqCst);
401			count < 3 // Stop after 3 iterations
402		});
403
404		// Wait enough time for many iterations
405		thread::sleep(Duration::from_millis(100));
406
407		// Should have stopped at 3
408		let count = counter.load(Ordering::SeqCst);
409		assert!(count <= 4, "Expected at most 4 iterations, got {}", count);
410
411		scheduler.shutdown();
412	}
413
414	#[test]
415	fn test_cancel_before_fire() {
416		let mut scheduler = SchedulerHandle::new(test_pool());
417
418		let (tx, rx) = mpsc::channel();
419		let handle = scheduler.schedule_once(Duration::from_millis(50), move || {
420			tx.send(()).unwrap();
421		});
422
423		// Cancel immediately
424		handle.cancel();
425
426		// Should not receive anything
427		assert!(rx.recv_timeout(Duration::from_millis(100)).is_err());
428
429		scheduler.shutdown();
430	}
431
432	#[test]
433	fn test_multiple_timers() {
434		let mut scheduler = SchedulerHandle::new(test_pool());
435
436		let results = Arc::new(Mutex::new(Vec::new()));
437
438		for i in 0..5 {
439			let results_clone = results.clone();
440			let delay = Duration::from_millis((5 - i) * 10); // Reverse order
441			scheduler.schedule_once(delay, move || {
442				results_clone.lock().push(i);
443			});
444		}
445
446		thread::sleep(Duration::from_millis(100));
447
448		let results = results.lock();
449		// Timers should fire in deadline order (4, 3, 2, 1, 0)
450		assert_eq!(*results, vec![4, 3, 2, 1, 0]);
451
452		scheduler.shutdown();
453	}
454}