Skip to main content

widgetkit_runtime/
scheduler.rs

1use crate::internal::Dispatcher;
2use crossbeam_channel::{Receiver, RecvTimeoutError, Sender, unbounded};
3use std::{
4    cmp::Ordering,
5    collections::{BinaryHeap, HashMap, HashSet},
6    thread::{self, JoinHandle},
7    time::Instant,
8};
9use widgetkit_core::{Duration, TimerId};
10
11pub struct Scheduler<'a, M> {
12    state: &'a mut SchedulerState<M>,
13    dispatcher: Dispatcher<M>,
14}
15
16impl<'a, M> Scheduler<'a, M>
17where
18    M: Send + 'static,
19{
20    pub(crate) fn new(state: &'a mut SchedulerState<M>, dispatcher: Dispatcher<M>) -> Self {
21        Self { state, dispatcher }
22    }
23
24    pub fn after(&mut self, duration: Duration, message: M) -> TimerId {
25        self.state.after(duration, message, self.dispatcher.clone())
26    }
27
28    pub fn every(&mut self, duration: Duration, message: M) -> TimerId
29    where
30        M: Clone,
31    {
32        self.state.every(duration, message, self.dispatcher.clone())
33    }
34
35    pub fn cancel(&mut self, timer_id: TimerId) -> bool {
36        self.state.cancel(timer_id)
37    }
38
39    pub fn clear(&mut self) {
40        self.state.clear();
41    }
42}
43
44pub(crate) struct SchedulerState<M> {
45    command_tx: Option<Sender<SchedulerCommand<M>>>,
46    active_timers: HashSet<TimerId>,
47    worker: Option<JoinHandle<()>>,
48}
49
50impl<M> SchedulerState<M>
51where
52    M: Send + 'static,
53{
54    pub(crate) fn new(dispatcher: Dispatcher<M>) -> Self {
55        let (command_tx, command_rx) = unbounded();
56        let worker = thread::spawn(move || scheduler_worker(dispatcher, command_rx));
57        Self {
58            command_tx: Some(command_tx),
59            active_timers: HashSet::new(),
60            worker: Some(worker),
61        }
62    }
63
64    fn after(&mut self, duration: Duration, message: M, _dispatcher: Dispatcher<M>) -> TimerId {
65        let timer_id = TimerId::new();
66        self.active_timers.insert(timer_id);
67        self.send_command(SchedulerCommand::Schedule {
68            timer_id,
69            deadline: Instant::now() + duration,
70            interval: None,
71            delivery: TimerDelivery::Once(Some(message)),
72        });
73        timer_id
74    }
75
76    fn every(&mut self, duration: Duration, message: M, _dispatcher: Dispatcher<M>) -> TimerId
77    where
78        M: Clone,
79    {
80        let timer_id = TimerId::new();
81        self.active_timers.insert(timer_id);
82        let factory: Box<dyn Fn() -> M + Send> = Box::new(move || message.clone());
83        self.send_command(SchedulerCommand::Schedule {
84            timer_id,
85            deadline: Instant::now() + duration,
86            interval: Some(duration),
87            delivery: TimerDelivery::Repeat(factory),
88        });
89        timer_id
90    }
91
92    fn cancel(&mut self, timer_id: TimerId) -> bool {
93        let existed = self.active_timers.remove(&timer_id);
94        if existed {
95            self.send_command(SchedulerCommand::Cancel { timer_id });
96        }
97        existed
98    }
99
100    pub(crate) fn reap(&mut self, timer_id: TimerId) {
101        self.active_timers.remove(&timer_id);
102    }
103
104    pub(crate) fn clear(&mut self) {
105        if self.active_timers.is_empty() {
106            return;
107        }
108        self.active_timers.clear();
109        self.send_command(SchedulerCommand::Clear);
110    }
111
112    pub(crate) fn shutdown(&mut self) {
113        self.active_timers.clear();
114        if let Some(command_tx) = self.command_tx.take() {
115            let _ = command_tx.send(SchedulerCommand::Shutdown);
116        }
117        if let Some(worker) = self.worker.take() {
118            let _ = worker.join();
119        }
120    }
121
122    #[cfg(test)]
123    pub(crate) fn active_count(&self) -> usize {
124        self.active_timers.len()
125    }
126
127    fn send_command(&self, command: SchedulerCommand<M>) {
128        if let Some(command_tx) = self.command_tx.as_ref() {
129            let _ = command_tx.send(command);
130        }
131    }
132}
133
134impl<M> Drop for SchedulerState<M> {
135    fn drop(&mut self) {
136        self.active_timers.clear();
137        if let Some(command_tx) = self.command_tx.take() {
138            let _ = command_tx.send(SchedulerCommand::Shutdown);
139        }
140        if let Some(worker) = self.worker.take() {
141            let _ = worker.join();
142        }
143    }
144}
145
/// Messages sent from `SchedulerState` on the owning thread to the worker thread.
enum SchedulerCommand<M> {
    /// Register a timer that first fires at `deadline`.
    ///
    /// `interval` is `Some` for repeating timers (built by `every`) and `None` for one-shots
    /// (built by `after`).
    Schedule {
        timer_id: TimerId,
        deadline: Instant,
        interval: Option<Duration>,
        delivery: TimerDelivery<M>,
    },
    /// Drop the timer with this id if the worker still holds it.
    Cancel {
        timer_id: TimerId,
    },
    /// Drop every pending timer.
    Clear,
    /// Make the worker loop exit.
    Shutdown,
}
159
/// How a timer produces the message it posts when it fires.
enum TimerDelivery<M> {
    /// One-shot: the message is moved out (`Option::take`) on the first fire, leaving `None`.
    Once(Option<M>),
    /// Repeating: a fresh message is built on every fire (`every` supplies a closure that
    /// clones the original message).
    Repeat(Box<dyn Fn() -> M + Send>),
}
164
/// Worker-side record for one scheduled timer, stored by `TimerId`.
struct TimerEntry<M> {
    /// When the timer should next fire. A heap key whose deadline differs from this is stale.
    deadline: Instant,
    /// Repeat period: `Some` for repeating timers, `None` for one-shots.
    interval: Option<Duration>,
    /// Source of the message to post when the timer fires.
    delivery: TimerDelivery<M>,
}
170
171#[derive(Clone, Copy, Debug, Eq, PartialEq)]
172struct DeadlineKey {
173    deadline: Instant,
174    timer_id: TimerId,
175}
176
177impl Ord for DeadlineKey {
178    fn cmp(&self, other: &Self) -> Ordering {
179        other
180            .deadline
181            .cmp(&self.deadline)
182            .then_with(|| other.timer_id.into_raw().cmp(&self.timer_id.into_raw()))
183    }
184}
185
186impl PartialOrd for DeadlineKey {
187    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
188        Some(self.cmp(other))
189    }
190}
191
192fn scheduler_worker<M>(dispatcher: Dispatcher<M>, command_rx: Receiver<SchedulerCommand<M>>)
193where
194    M: Send + 'static,
195{
196    let mut entries: HashMap<TimerId, TimerEntry<M>> = HashMap::new();
197    let mut deadlines = BinaryHeap::new();
198
199    loop {
200        dispatch_due(&dispatcher, &mut entries, &mut deadlines);
201
202        let Some(timeout) = next_timeout(&entries, &mut deadlines) else {
203            match command_rx.recv() {
204                Ok(command) => {
205                    if !apply_command(command, &mut entries, &mut deadlines) {
206                        break;
207                    }
208                }
209                Err(_) => break,
210            }
211            continue;
212        };
213
214        match command_rx.recv_timeout(timeout) {
215            Ok(command) => {
216                if !apply_command(command, &mut entries, &mut deadlines) {
217                    break;
218                }
219            }
220            Err(RecvTimeoutError::Timeout) => continue,
221            Err(RecvTimeoutError::Disconnected) => break,
222        }
223    }
224
225    entries.clear();
226    deadlines.clear();
227}
228
229fn apply_command<M>(
230    command: SchedulerCommand<M>,
231    entries: &mut HashMap<TimerId, TimerEntry<M>>,
232    deadlines: &mut BinaryHeap<DeadlineKey>,
233) -> bool {
234    match command {
235        SchedulerCommand::Schedule {
236            timer_id,
237            deadline,
238            interval,
239            delivery,
240        } => {
241            entries.insert(
242                timer_id,
243                TimerEntry {
244                    deadline,
245                    interval,
246                    delivery,
247                },
248            );
249            deadlines.push(DeadlineKey { deadline, timer_id });
250            true
251        }
252        SchedulerCommand::Cancel { timer_id } => {
253            entries.remove(&timer_id);
254            true
255        }
256        SchedulerCommand::Clear => {
257            entries.clear();
258            deadlines.clear();
259            true
260        }
261        SchedulerCommand::Shutdown => false,
262    }
263}
264
265fn dispatch_due<M>(
266    dispatcher: &Dispatcher<M>,
267    entries: &mut HashMap<TimerId, TimerEntry<M>>,
268    deadlines: &mut BinaryHeap<DeadlineKey>,
269) where
270    M: Send + 'static,
271{
272    let now = Instant::now();
273    loop {
274        prune_stale(entries, deadlines);
275        let Some(next) = deadlines.peek().copied() else {
276            break;
277        };
278        if next.deadline > now {
279            break;
280        }
281        let _ = deadlines.pop();
282
283        let Some(entry) = entries.get_mut(&next.timer_id) else {
284            continue;
285        };
286        if entry.deadline != next.deadline {
287            continue;
288        }
289
290        match &mut entry.delivery {
291            TimerDelivery::Once(message) => {
292                if let Some(message) = message.take() {
293                    let _ = dispatcher.post_message(message);
294                }
295                entries.remove(&next.timer_id);
296                dispatcher.finish_timer(next.timer_id);
297            }
298            TimerDelivery::Repeat(factory) => {
299                let _ = dispatcher.post_message(factory());
300                let interval = entry.interval.expect("repeat timers must carry an interval");
301                entry.deadline = advance_deadline(entry.deadline, interval, now);
302                deadlines.push(DeadlineKey {
303                    deadline: entry.deadline,
304                    timer_id: next.timer_id,
305                });
306            }
307        }
308    }
309}
310
311fn next_timeout<M>(
312    entries: &HashMap<TimerId, TimerEntry<M>>,
313    deadlines: &mut BinaryHeap<DeadlineKey>,
314) -> Option<Duration> {
315    prune_stale(entries, deadlines);
316    deadlines
317        .peek()
318        .map(|next| next.deadline.saturating_duration_since(Instant::now()))
319}
320
321fn prune_stale<M>(entries: &HashMap<TimerId, TimerEntry<M>>, deadlines: &mut BinaryHeap<DeadlineKey>) {
322    while let Some(next) = deadlines.peek() {
323        let Some(entry) = entries.get(&next.timer_id) else {
324            let _ = deadlines.pop();
325            continue;
326        };
327        if entry.deadline != next.deadline {
328            let _ = deadlines.pop();
329            continue;
330        }
331        break;
332    }
333}
334
/// Computes the next deadline for a repeating timer that has just fired.
///
/// Returns the earliest `previous_deadline + k * interval` (with `k >= 1`) that lies strictly
/// after `now`. A timer that fell behind therefore skips the missed ticks instead of firing a
/// burst to catch up.
///
/// A zero `interval` has no such multiple. In that case the timer is rescheduled one nanosecond
/// after `now`, because a deadline `<= now` would make `dispatch_due` re-fire it endlessly
/// within a single pass (and would hang the worker).
fn advance_deadline(previous_deadline: Instant, interval: Duration, now: Instant) -> Instant {
    const NANOS_PER_SEC: u128 = 1_000_000_000;

    if interval.is_zero() {
        return now + Duration::from_nanos(1);
    }

    let next_deadline = previous_deadline + interval;
    if next_deadline > now {
        return next_deadline;
    }

    // Jump past `now` in one step instead of adding `interval` once per missed tick.
    let interval_nanos = interval.as_nanos();
    let missed_ticks = now.duration_since(next_deadline).as_nanos() / interval_nanos + 1;
    let catch_up_nanos = missed_ticks * interval_nanos;
    // `catch_up_nanos` never exceeds `now - previous_deadline`, so the whole-second part fits
    // in `u64`. The sub-second remainder is < 1e9, so the `u32` cast cannot truncate.
    let catch_up = Duration::new(
        u64::try_from(catch_up_nanos / NANOS_PER_SEC)
            .expect("catch-up offset fits in u64 seconds"),
        (catch_up_nanos % NANOS_PER_SEC) as u32,
    );
    next_deadline + catch_up
}