1use crate::internal::Dispatcher;
2use crossbeam_channel::{Receiver, RecvTimeoutError, Sender, unbounded};
3use std::{
4 cmp::Ordering,
5 collections::{BinaryHeap, HashMap, HashSet},
6 thread::{self, JoinHandle},
7 time::Instant,
8};
9use widgetkit_core::{Duration, TimerId};
10
11pub struct Scheduler<'a, M> {
12 state: &'a mut SchedulerState<M>,
13 dispatcher: Dispatcher<M>,
14}
15
16impl<'a, M> Scheduler<'a, M>
17where
18 M: Send + 'static,
19{
20 pub(crate) fn new(state: &'a mut SchedulerState<M>, dispatcher: Dispatcher<M>) -> Self {
21 Self { state, dispatcher }
22 }
23
24 pub fn after(&mut self, duration: Duration, message: M) -> TimerId {
25 self.state.after(duration, message, self.dispatcher.clone())
26 }
27
28 pub fn every(&mut self, duration: Duration, message: M) -> TimerId
29 where
30 M: Clone,
31 {
32 self.state.every(duration, message, self.dispatcher.clone())
33 }
34
35 pub fn cancel(&mut self, timer_id: TimerId) -> bool {
36 self.state.cancel(timer_id)
37 }
38
39 pub fn clear(&mut self) {
40 self.state.clear();
41 }
42}
43
44pub(crate) struct SchedulerState<M> {
45 command_tx: Option<Sender<SchedulerCommand<M>>>,
46 active_timers: HashSet<TimerId>,
47 worker: Option<JoinHandle<()>>,
48}
49
50impl<M> SchedulerState<M>
51where
52 M: Send + 'static,
53{
54 pub(crate) fn new(dispatcher: Dispatcher<M>) -> Self {
55 let (command_tx, command_rx) = unbounded();
56 let worker = thread::spawn(move || scheduler_worker(dispatcher, command_rx));
57 Self {
58 command_tx: Some(command_tx),
59 active_timers: HashSet::new(),
60 worker: Some(worker),
61 }
62 }
63
64 fn after(&mut self, duration: Duration, message: M, _dispatcher: Dispatcher<M>) -> TimerId {
65 let timer_id = TimerId::new();
66 self.active_timers.insert(timer_id);
67 self.send_command(SchedulerCommand::Schedule {
68 timer_id,
69 deadline: Instant::now() + duration,
70 interval: None,
71 delivery: TimerDelivery::Once(Some(message)),
72 });
73 timer_id
74 }
75
76 fn every(&mut self, duration: Duration, message: M, _dispatcher: Dispatcher<M>) -> TimerId
77 where
78 M: Clone,
79 {
80 let timer_id = TimerId::new();
81 self.active_timers.insert(timer_id);
82 let factory: Box<dyn Fn() -> M + Send> = Box::new(move || message.clone());
83 self.send_command(SchedulerCommand::Schedule {
84 timer_id,
85 deadline: Instant::now() + duration,
86 interval: Some(duration),
87 delivery: TimerDelivery::Repeat(factory),
88 });
89 timer_id
90 }
91
92 fn cancel(&mut self, timer_id: TimerId) -> bool {
93 let existed = self.active_timers.remove(&timer_id);
94 if existed {
95 self.send_command(SchedulerCommand::Cancel { timer_id });
96 }
97 existed
98 }
99
100 pub(crate) fn reap(&mut self, timer_id: TimerId) {
101 self.active_timers.remove(&timer_id);
102 }
103
104 pub(crate) fn clear(&mut self) {
105 if self.active_timers.is_empty() {
106 return;
107 }
108 self.active_timers.clear();
109 self.send_command(SchedulerCommand::Clear);
110 }
111
112 pub(crate) fn shutdown(&mut self) {
113 self.active_timers.clear();
114 if let Some(command_tx) = self.command_tx.take() {
115 let _ = command_tx.send(SchedulerCommand::Shutdown);
116 }
117 if let Some(worker) = self.worker.take() {
118 let _ = worker.join();
119 }
120 }
121
122 #[cfg(test)]
123 pub(crate) fn active_count(&self) -> usize {
124 self.active_timers.len()
125 }
126
127 fn send_command(&self, command: SchedulerCommand<M>) {
128 if let Some(command_tx) = self.command_tx.as_ref() {
129 let _ = command_tx.send(command);
130 }
131 }
132}
133
134impl<M> Drop for SchedulerState<M> {
135 fn drop(&mut self) {
136 self.active_timers.clear();
137 if let Some(command_tx) = self.command_tx.take() {
138 let _ = command_tx.send(SchedulerCommand::Shutdown);
139 }
140 if let Some(worker) = self.worker.take() {
141 let _ = worker.join();
142 }
143 }
144}
145
146enum SchedulerCommand<M> {
147 Schedule {
148 timer_id: TimerId,
149 deadline: Instant,
150 interval: Option<Duration>,
151 delivery: TimerDelivery<M>,
152 },
153 Cancel {
154 timer_id: TimerId,
155 },
156 Clear,
157 Shutdown,
158}
159
/// Source of the message posted when a timer fires.
enum TimerDelivery<M> {
    /// One-shot payload; the message is taken out of the option on firing.
    Once(Option<M>),
    /// Repeating payload; the closure is called for every firing.
    Repeat(Box<dyn Fn() -> M + Send>),
}
164
165struct TimerEntry<M> {
166 deadline: Instant,
167 interval: Option<Duration>,
168 delivery: TimerDelivery<M>,
169}
170
171#[derive(Clone, Copy, Debug, Eq, PartialEq)]
172struct DeadlineKey {
173 deadline: Instant,
174 timer_id: TimerId,
175}
176
177impl Ord for DeadlineKey {
178 fn cmp(&self, other: &Self) -> Ordering {
179 other
180 .deadline
181 .cmp(&self.deadline)
182 .then_with(|| other.timer_id.into_raw().cmp(&self.timer_id.into_raw()))
183 }
184}
185
186impl PartialOrd for DeadlineKey {
187 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
188 Some(self.cmp(other))
189 }
190}
191
192fn scheduler_worker<M>(dispatcher: Dispatcher<M>, command_rx: Receiver<SchedulerCommand<M>>)
193where
194 M: Send + 'static,
195{
196 let mut entries: HashMap<TimerId, TimerEntry<M>> = HashMap::new();
197 let mut deadlines = BinaryHeap::new();
198
199 loop {
200 dispatch_due(&dispatcher, &mut entries, &mut deadlines);
201
202 let Some(timeout) = next_timeout(&entries, &mut deadlines) else {
203 match command_rx.recv() {
204 Ok(command) => {
205 if !apply_command(command, &mut entries, &mut deadlines) {
206 break;
207 }
208 }
209 Err(_) => break,
210 }
211 continue;
212 };
213
214 match command_rx.recv_timeout(timeout) {
215 Ok(command) => {
216 if !apply_command(command, &mut entries, &mut deadlines) {
217 break;
218 }
219 }
220 Err(RecvTimeoutError::Timeout) => continue,
221 Err(RecvTimeoutError::Disconnected) => break,
222 }
223 }
224
225 entries.clear();
226 deadlines.clear();
227}
228
229fn apply_command<M>(
230 command: SchedulerCommand<M>,
231 entries: &mut HashMap<TimerId, TimerEntry<M>>,
232 deadlines: &mut BinaryHeap<DeadlineKey>,
233) -> bool {
234 match command {
235 SchedulerCommand::Schedule {
236 timer_id,
237 deadline,
238 interval,
239 delivery,
240 } => {
241 entries.insert(
242 timer_id,
243 TimerEntry {
244 deadline,
245 interval,
246 delivery,
247 },
248 );
249 deadlines.push(DeadlineKey { deadline, timer_id });
250 true
251 }
252 SchedulerCommand::Cancel { timer_id } => {
253 entries.remove(&timer_id);
254 true
255 }
256 SchedulerCommand::Clear => {
257 entries.clear();
258 deadlines.clear();
259 true
260 }
261 SchedulerCommand::Shutdown => false,
262 }
263}
264
265fn dispatch_due<M>(
266 dispatcher: &Dispatcher<M>,
267 entries: &mut HashMap<TimerId, TimerEntry<M>>,
268 deadlines: &mut BinaryHeap<DeadlineKey>,
269) where
270 M: Send + 'static,
271{
272 let now = Instant::now();
273 loop {
274 prune_stale(entries, deadlines);
275 let Some(next) = deadlines.peek().copied() else {
276 break;
277 };
278 if next.deadline > now {
279 break;
280 }
281 let _ = deadlines.pop();
282
283 let Some(entry) = entries.get_mut(&next.timer_id) else {
284 continue;
285 };
286 if entry.deadline != next.deadline {
287 continue;
288 }
289
290 match &mut entry.delivery {
291 TimerDelivery::Once(message) => {
292 if let Some(message) = message.take() {
293 let _ = dispatcher.post_message(message);
294 }
295 entries.remove(&next.timer_id);
296 dispatcher.finish_timer(next.timer_id);
297 }
298 TimerDelivery::Repeat(factory) => {
299 let _ = dispatcher.post_message(factory());
300 let interval = entry.interval.expect("repeat timers must carry an interval");
301 entry.deadline = advance_deadline(entry.deadline, interval, now);
302 deadlines.push(DeadlineKey {
303 deadline: entry.deadline,
304 timer_id: next.timer_id,
305 });
306 }
307 }
308 }
309}
310
311fn next_timeout<M>(
312 entries: &HashMap<TimerId, TimerEntry<M>>,
313 deadlines: &mut BinaryHeap<DeadlineKey>,
314) -> Option<Duration> {
315 prune_stale(entries, deadlines);
316 deadlines
317 .peek()
318 .map(|next| next.deadline.saturating_duration_since(Instant::now()))
319}
320
321fn prune_stale<M>(entries: &HashMap<TimerId, TimerEntry<M>>, deadlines: &mut BinaryHeap<DeadlineKey>) {
322 while let Some(next) = deadlines.peek() {
323 let Some(entry) = entries.get(&next.timer_id) else {
324 let _ = deadlines.pop();
325 continue;
326 };
327 if entry.deadline != next.deadline {
328 let _ = deadlines.pop();
329 continue;
330 }
331 break;
332 }
333}
334
/// Computes the next firing time of a repeating timer.
///
/// The result is the earliest instant of the form
/// `previous_deadline + k * interval` (with `k >= 1`) that lies strictly
/// after `now`, so ticks missed while the worker was busy are skipped rather
/// than replayed. It is computed in constant time, however far behind the
/// timer has fallen.
///
/// A zero `interval` has no such instant, and stepping by it would never
/// pass `now`. It is therefore treated as a one-millisecond interval, which
/// keeps the returned deadline strictly after `now`.
fn advance_deadline(previous_deadline: Instant, interval: Duration, now: Instant) -> Instant {
    const ZERO_INTERVAL_FALLBACK: Duration = Duration::from_millis(1);
    const NANOS_PER_SEC: u128 = 1_000_000_000;

    let step = if interval.is_zero() {
        ZERO_INTERVAL_FALLBACK
    } else {
        interval
    };

    // Common case: the timer fired on time and the very next tick is ahead.
    let first = previous_deadline + step;
    if first > now {
        return first;
    }

    // `first <= now`: measure how far `now` sits past the last tick at or
    // before it, then jump to the tick that follows.
    let lag = now.duration_since(first);
    let overshoot_nanos = lag.as_nanos() % step.as_nanos();
    // `overshoot_nanos < step.as_nanos()` and `step` is itself a valid
    // `Duration`, so neither cast below can truncate.
    let overshoot = Duration::new(
        (overshoot_nanos / NANOS_PER_SEC) as u64,
        (overshoot_nanos % NANOS_PER_SEC) as u32,
    );
    now + (step - overshoot)
}