use alloc::vec::Vec;
use core::cmp::Ordering;
use core::time::Duration;
use std::sync::mpsc::{Receiver, RecvTimeoutError, Sender, channel};
use std::time::Instant;
/// Message carried from a [`SchedulerHandle`] to the scheduler over the channel.
enum RaiseMsg<E> {
    /// Queue the event to fire once the given instant has been reached.
    At(Instant, E),
    /// Ask the scheduler loop to stop.
    Stop,
}

/// Cloneable sending side of the scheduler: raises timed events and requests
/// shutdown, from any thread that holds a clone.
pub struct SchedulerHandle<E> {
    tx: Sender<RaiseMsg<E>>,
}

// Written by hand instead of derived so that cloning a handle does not
// require `E: Clone`; only the channel sender is duplicated.
impl<E> Clone for SchedulerHandle<E> {
    fn clone(&self) -> Self {
        let tx = Sender::clone(&self.tx);
        SchedulerHandle { tx }
    }
}

impl<E> SchedulerHandle<E> {
    /// Queues `event` to fire at `at`. An instant already in the past means
    /// "as soon as the scheduler next checks".
    ///
    /// Returns `false` (and drops `event`) when the receiving side has been
    /// dropped. `true` only means the message was queued on the channel.
    pub fn raise_at(&self, at: Instant, event: E) -> bool {
        let msg = RaiseMsg::At(at, event);
        match self.tx.send(msg) {
            Ok(()) => true,
            Err(_) => false,
        }
    }

    /// Queues `event` to fire `delay` from now; see [`raise_at`](Self::raise_at).
    ///
    /// # Panics
    ///
    /// Panics if `Instant::now() + delay` overflows (`Instant + Duration` panics
    /// when the result is not representable).
    pub fn raise_in(&self, delay: Duration, event: E) -> bool {
        let deadline = Instant::now() + delay;
        self.raise_at(deadline, event)
    }

    /// Queues `event` to fire immediately; see [`raise_at`](Self::raise_at).
    pub fn raise_now(&self, event: E) -> bool {
        let deadline = Instant::now();
        self.raise_at(deadline, event)
    }

    /// Requests that the scheduler stop. This is best effort: a scheduler
    /// that is already gone needs no stopping, so a send failure is ignored.
    pub fn stop(&self) {
        let _ = self.tx.send(RaiseMsg::Stop);
    }
}
/// One pending timer held in the scheduler's heap.
struct Entry<E> {
    /// Instant at which the event becomes due.
    deadline: Instant,
    /// Insertion counter; breaks ties between equal deadlines so that the
    /// entry inserted first compares as "earlier".
    seq: u64,
    /// Payload handed to the dispatcher when the entry fires.
    event: E,
}

// Equality and ordering look only at `(deadline, seq)`, never at `event`,
// so `E` needs no trait bounds.
impl<E> PartialEq for Entry<E> {
    fn eq(&self, other: &Self) -> bool {
        (self.deadline, self.seq) == (other.deadline, other.seq)
    }
}

impl<E> Eq for Entry<E> {}

// Reversed order: comparing `other` against `self` makes the entry with the
// smallest `(deadline, seq)` the greatest, so it sits on top of a
// `BinaryHeap` (a max-heap).
impl<E> Ord for Entry<E> {
    fn cmp(&self, other: &Self) -> Ordering {
        let mine = (self.deadline, self.seq);
        let theirs = (other.deadline, other.seq);
        theirs.cmp(&mine)
    }
}

impl<E> PartialOrd for Entry<E> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
/// Timer loop that delivers events raised through [`SchedulerHandle`]s once
/// their deadline has passed, earliest deadline first and FIFO among equal
/// deadlines.
pub struct Scheduler<E> {
    /// Receiving end of the channel shared by every handle.
    rx: Receiver<RaiseMsg<E>>,
    /// Pending timers. `Entry` orders itself in reverse, so the top of this
    /// max-heap is the entry with the smallest `(deadline, seq)`.
    heap: alloc::collections::BinaryHeap<Entry<E>>,
    /// Sequence number stamped on the next pushed entry (FIFO tie-breaker).
    seq: u64,
    /// Longest single park while no timer is pending.
    idle_floor: Duration,
}

impl<E> Scheduler<E> {
    /// Creates an empty scheduler together with the first handle that feeds it.
    ///
    /// `idle_floor` bounds how long one park lasts while no timer is pending;
    /// incoming messages still wake the scheduler immediately.
    #[must_use]
    pub fn new(idle_floor: Duration) -> (Self, SchedulerHandle<E>) {
        let (tx, rx) = channel();
        let sched = Self {
            rx,
            heap: alloc::collections::BinaryHeap::new(),
            seq: 0,
            idle_floor,
        };
        (sched, SchedulerHandle { tx })
    }

    /// Adds `event` to the heap under `deadline`, stamped with the next
    /// sequence number.
    fn push(&mut self, deadline: Instant, event: E) {
        let seq = self.seq;
        self.seq = self.seq.wrapping_add(1);
        self.heap.push(Entry {
            deadline,
            seq,
            event,
        });
    }

    /// Moves every message already waiting in the channel into the heap,
    /// without blocking.
    ///
    /// Returns `true` if a `Stop` was among them. Events queued after the
    /// `Stop` are still moved into the heap.
    fn drain_channel(&mut self) -> bool {
        let mut stop = false;
        while let Ok(msg) = self.rx.try_recv() {
            match msg {
                RaiseMsg::At(at, ev) => self.push(at, ev),
                RaiseMsg::Stop => stop = true,
            }
        }
        stop
    }

    /// Pops every entry whose deadline is at or before `now`, earliest first.
    fn drain_due(&mut self, now: Instant) -> Vec<E> {
        let mut due = Vec::new();
        while self.heap.peek().is_some_and(|top| top.deadline <= now) {
            if let Some(entry) = self.heap.pop() {
                due.push(entry.event);
            }
        }
        due
    }

    /// Hands every entry that is due as of now to `dispatch`, in order.
    fn dispatch_due<F: FnMut(E)>(&mut self, dispatch: &mut F) {
        for ev in self.drain_due(Instant::now()) {
            dispatch(ev);
        }
    }

    /// How long to block waiting for messages. This is the time until the
    /// earliest pending deadline (zero if it has already passed), or
    /// `idle_floor` when no timer is pending.
    fn park_timeout(&self) -> Duration {
        match self.heap.peek() {
            Some(top) => top.deadline.saturating_duration_since(Instant::now()),
            None => self.idle_floor,
        }
    }

    /// Performs one step of the scheduler, for callers that drive their own
    /// loop instead of using [`run`](Self::run).
    ///
    /// If nothing is due and no stop is pending, blocks until the earliest
    /// deadline, the next incoming message, or `idle_floor` (when no timer is
    /// pending), whichever comes first.
    ///
    /// Returns `(due, stop)`:
    /// * `due`: events whose deadline has passed, earliest first. It may be
    ///   empty, e.g. after an idle timeout or when a freshly raised event is
    ///   not yet due.
    /// * `stop`: `true` once a `Stop` has been received or every handle has
    ///   been dropped. `due` can be non-empty even when `stop` is `true`, so
    ///   callers should dispatch it before exiting.
    pub fn park_due_batch(&mut self) -> (Vec<E>, bool) {
        let stop = self.drain_channel();
        let due = self.drain_due(Instant::now());
        if stop || !due.is_empty() {
            return (due, stop);
        }
        let stop = match self.rx.recv_timeout(self.park_timeout()) {
            Ok(RaiseMsg::At(at, ev)) => {
                self.push(at, ev);
                // A `Stop` queued right behind the waking event must be
                // reported, not silently consumed.
                self.drain_channel()
            }
            // A `Stop` may have arrived just after the timeout fired.
            Err(RecvTimeoutError::Timeout) => self.drain_channel(),
            Ok(RaiseMsg::Stop) | Err(RecvTimeoutError::Disconnected) => true,
        };
        (self.drain_due(Instant::now()), stop)
    }

    /// Runs the timer loop on the current thread, calling `dispatch` for each
    /// event as it becomes due.
    ///
    /// Returns once a `Stop` is received or every [`SchedulerHandle`] has
    /// been dropped. In both cases events already due at that moment are
    /// dispatched first; entries whose deadline still lies in the future stay
    /// in the heap and are not delivered.
    pub fn run<F: FnMut(E)>(&mut self, mut dispatch: F) {
        loop {
            let stop = self.drain_channel();
            self.dispatch_due(&mut dispatch);
            if stop {
                // Also deliver anything that came due while the batch above
                // was being dispatched.
                self.dispatch_due(&mut dispatch);
                return;
            }
            match self.rx.recv_timeout(self.park_timeout()) {
                Ok(RaiseMsg::At(at, ev)) => self.push(at, ev),
                Err(RecvTimeoutError::Timeout) => {}
                // Stop requested or every handle dropped: flush what is due and exit.
                Ok(RaiseMsg::Stop) | Err(RecvTimeoutError::Disconnected) => {
                    self.dispatch_due(&mut dispatch);
                    return;
                }
            }
        }
    }
}
#[cfg(test)]
// unwrap/expect are fine here: a panic is exactly how a test reports failure.
#[allow(clippy::expect_used, clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};
    use std::thread;

    #[derive(Debug, Clone, PartialEq, Eq)]
    enum Ev {
        A,
        B,
        C,
        Tick(u32),
    }

    /// Runs `sched` on a new thread, appending every dispatched event to `log`.
    fn run_in_thread(mut sched: Scheduler<Ev>, log: Arc<Mutex<Vec<Ev>>>) -> thread::JoinHandle<()> {
        thread::spawn(move || {
            sched.run(|ev| log.lock().unwrap().push(ev));
        })
    }

    #[test]
    fn fires_in_deadline_order_not_insertion_order() {
        let (mut sched, h) = Scheduler::<Ev>::new(Duration::from_secs(1));
        let now = Instant::now();
        sched_push_for_test(&mut sched, now + Duration::from_millis(60), Ev::C);
        sched_push_for_test(&mut sched, now + Duration::from_millis(20), Ev::A);
        sched_push_for_test(&mut sched, now + Duration::from_millis(40), Ev::B);
        let log = Arc::new(Mutex::new(Vec::new()));
        let jh = run_in_thread(sched, Arc::clone(&log));
        thread::sleep(Duration::from_millis(150));
        h.stop();
        jh.join().unwrap();
        assert_eq!(*log.lock().unwrap(), vec![Ev::A, Ev::B, Ev::C]);
    }

    #[test]
    fn raise_during_park_wakes_and_fires_early() {
        let (mut sched, h) = Scheduler::<Ev>::new(Duration::from_secs(1));
        sched_push_for_test(&mut sched, Instant::now() + Duration::from_secs(30), Ev::C);
        let log = Arc::new(Mutex::new(Vec::new()));
        let jh = run_in_thread(sched, Arc::clone(&log));
        thread::sleep(Duration::from_millis(20));
        let t0 = Instant::now();
        h.raise_in(Duration::from_millis(10), Ev::A);
        loop {
            if log.lock().unwrap().contains(&Ev::A) {
                break;
            }
            assert!(
                t0.elapsed() < Duration::from_secs(2),
                "raise must wake the park"
            );
            thread::sleep(Duration::from_millis(2));
        }
        assert!(
            t0.elapsed() < Duration::from_secs(1),
            "fired far before the 30s entry"
        );
        h.stop();
        jh.join().unwrap();
    }

    #[test]
    fn equal_deadline_breaks_fifo_by_seq() {
        let (mut sched, h) = Scheduler::<Ev>::new(Duration::from_secs(1));
        let at = Instant::now() + Duration::from_millis(20);
        sched_push_for_test(&mut sched, at, Ev::A);
        sched_push_for_test(&mut sched, at, Ev::B);
        sched_push_for_test(&mut sched, at, Ev::C);
        let log = Arc::new(Mutex::new(Vec::new()));
        let jh = run_in_thread(sched, Arc::clone(&log));
        thread::sleep(Duration::from_millis(120));
        h.stop();
        jh.join().unwrap();
        assert_eq!(*log.lock().unwrap(), vec![Ev::A, Ev::B, Ev::C]);
    }

    #[test]
    fn periodic_rearm_from_dispatch() {
        let (mut sched, h) = Scheduler::<Ev>::new(Duration::from_secs(1));
        h.raise_now(Ev::Tick(0));
        let log = Arc::new(Mutex::new(Vec::new()));
        // Keep our own reference so the log can be checked after the run ends.
        let log_in_loop = Arc::clone(&log);
        let h2 = h.clone();
        let jh = thread::spawn(move || {
            let mut n = 0u32;
            sched.run(|ev| {
                if let Ev::Tick(_) = ev {
                    n += 1;
                    if n < 5 {
                        h2.raise_in(Duration::from_millis(10), Ev::Tick(n));
                    }
                }
                log_in_loop.lock().unwrap().push(ev);
            });
        });
        thread::sleep(Duration::from_millis(200));
        h.stop();
        jh.join().unwrap();
        let expected: Vec<Ev> = (0..5).map(Ev::Tick).collect();
        assert_eq!(*log.lock().unwrap(), expected);
    }

    #[test]
    fn park_due_batch_returns_due_then_reports_stop() {
        let (mut sched, h) = Scheduler::<Ev>::new(Duration::from_secs(1));
        let past = Instant::now();
        h.raise_at(past, Ev::A);
        h.raise_at(past, Ev::B);
        h.raise_in(Duration::from_secs(60), Ev::C);
        let (due, stop) = sched.park_due_batch();
        assert_eq!(due, vec![Ev::A, Ev::B]);
        assert!(!stop);
        h.stop();
        let (due, stop) = sched.park_due_batch();
        assert!(due.is_empty());
        assert!(stop);
        // Dropping the last handle disconnects the channel, which also reads as a stop.
        drop(h);
        let (due, stop) = sched.park_due_batch();
        assert!(due.is_empty());
        assert!(stop);
    }

    #[test]
    fn raise_storm_parallel_to_fires_no_loss() {
        let (mut sched, h) = Scheduler::<u32>::new(Duration::from_millis(50));
        let count = Arc::new(Mutex::new(0u64));
        let c2 = Arc::clone(&count);
        let jh = thread::spawn(move || {
            sched.run(|_ev: u32| {
                *c2.lock().unwrap() += 1;
            });
        });
        const RAISERS: u32 = 8;
        const PER: u32 = 500;
        let mut handles = Vec::new();
        for _ in 0..RAISERS {
            let hc = h.clone();
            handles.push(thread::spawn(move || {
                for i in 0..PER {
                    hc.raise_in(Duration::from_millis(u64::from(i % 10)), i);
                }
            }));
        }
        for hh in handles {
            hh.join().unwrap();
        }
        thread::sleep(Duration::from_millis(300));
        h.stop();
        jh.join().unwrap();
        assert_eq!(*count.lock().unwrap(), u64::from(RAISERS) * u64::from(PER));
    }

    /// Pushes straight into the heap, bypassing the channel, so entries exist
    /// before `run` starts.
    fn sched_push_for_test<E>(s: &mut Scheduler<E>, at: Instant, ev: E) {
        s.push(at, ev);
    }
}