ic_cron/
task_scheduler.rs

1use std::collections::hash_map::Entry;
2use std::collections::HashMap;
3
4use ic_cdk::export::candid::{CandidType, Deserialize, Result as CandidResult};
5
6use crate::types::{
7    Iterations, ScheduledTask, SchedulingOptions, TaskExecutionQueue, TaskId, TaskTimestamp,
8};
9
10#[derive(Default, CandidType, Deserialize, Clone)]
11pub struct TaskScheduler {
12    pub tasks: HashMap<TaskId, ScheduledTask>,
13    pub task_id_counter: TaskId,
14
15    pub queue: TaskExecutionQueue,
16}
17
18impl TaskScheduler {
19    pub fn enqueue<TaskPayload: CandidType>(
20        &mut self,
21        payload: TaskPayload,
22        scheduling_interval: SchedulingOptions,
23        timestamp: u64,
24    ) -> CandidResult<TaskId> {
25        let id = self.generate_task_id();
26        let task = ScheduledTask::new(id, payload, timestamp, None, scheduling_interval)?;
27
28        match task.scheduling_options.iterations {
29            Iterations::Exact(times) => {
30                if times > 0 {
31                    self.queue.push(TaskTimestamp {
32                        task_id: id,
33                        timestamp: timestamp + task.scheduling_options.delay_nano,
34                    })
35                }
36            }
37            Iterations::Infinite => self.queue.push(TaskTimestamp {
38                task_id: id,
39                timestamp: timestamp + task.scheduling_options.delay_nano,
40            }),
41        };
42
43        self.tasks.insert(id, task);
44
45        Ok(id)
46    }
47
48    pub fn iterate(&mut self, timestamp: u64) -> Vec<ScheduledTask> {
49        let mut tasks = vec![];
50
51        for task_id in self
52            .queue
53            .pop_ready(timestamp)
54            .into_iter()
55            .map(|it| it.task_id)
56        {
57            let mut should_remove = false;
58
59            match self.tasks.entry(task_id) {
60                Entry::Occupied(mut entry) => {
61                    let task = entry.get_mut();
62
63                    match task.scheduling_options.iterations {
64                        Iterations::Infinite => {
65                            let new_rescheduled_at = if task.delay_passed {
66                                if let Some(rescheduled_at) = task.rescheduled_at {
67                                    rescheduled_at + task.scheduling_options.interval_nano
68                                } else {
69                                    task.scheduled_at + task.scheduling_options.interval_nano
70                                }
71                            } else {
72                                task.delay_passed = true;
73
74                                if let Some(rescheduled_at) = task.rescheduled_at {
75                                    rescheduled_at + task.scheduling_options.delay_nano
76                                } else {
77                                    task.scheduled_at + task.scheduling_options.delay_nano
78                                }
79                            };
80
81                            task.rescheduled_at = Some(new_rescheduled_at);
82
83                            self.queue.push(TaskTimestamp {
84                                task_id,
85                                timestamp: new_rescheduled_at
86                                    + task.scheduling_options.interval_nano,
87                            });
88                        }
89                        Iterations::Exact(times_left) => {
90                            if times_left > 1 {
91                                let new_rescheduled_at = if task.delay_passed {
92                                    if let Some(rescheduled_at) = task.rescheduled_at {
93                                        rescheduled_at + task.scheduling_options.interval_nano
94                                    } else {
95                                        task.scheduled_at + task.scheduling_options.interval_nano
96                                    }
97                                } else {
98                                    task.delay_passed = true;
99
100                                    if let Some(rescheduled_at) = task.rescheduled_at {
101                                        rescheduled_at + task.scheduling_options.delay_nano
102                                    } else {
103                                        task.scheduled_at + task.scheduling_options.delay_nano
104                                    }
105                                };
106
107                                task.rescheduled_at = Some(new_rescheduled_at);
108
109                                self.queue.push(TaskTimestamp {
110                                    task_id,
111                                    timestamp: new_rescheduled_at
112                                        + task.scheduling_options.interval_nano,
113                                });
114
115                                task.scheduling_options.iterations =
116                                    Iterations::Exact(times_left - 1);
117                            } else {
118                                should_remove = true;
119                            }
120                        }
121                    };
122
123                    tasks.push(task.clone());
124                }
125                Entry::Vacant(_) => {}
126            }
127
128            if should_remove {
129                self.tasks.remove(&task_id);
130            }
131        }
132
133        tasks
134    }
135
136    pub fn dequeue(&mut self, task_id: TaskId) -> Option<ScheduledTask> {
137        self.tasks.remove(&task_id)
138    }
139
140    pub fn is_empty(&self) -> bool {
141        self.queue.is_empty()
142    }
143
144    pub fn get_task(&self, task_id: &TaskId) -> Option<&ScheduledTask> {
145        self.tasks.get(task_id)
146    }
147
148    pub fn get_task_mut(&mut self, task_id: &TaskId) -> Option<&mut ScheduledTask> {
149        self.tasks.get_mut(task_id)
150    }
151
152    pub fn get_task_by_id_cloned(&self, task_id: &TaskId) -> Option<ScheduledTask> {
153        self.get_task(task_id).cloned()
154    }
155
156    pub fn get_tasks_cloned(&self) -> Vec<ScheduledTask> {
157        self.tasks.values().cloned().collect()
158    }
159
160    fn generate_task_id(&mut self) -> TaskId {
161        let res = self.task_id_counter;
162        self.task_id_counter += 1;
163
164        res
165    }
166}
167
#[cfg(test)]
mod tests {
    use ic_cdk::export::candid::{decode_one, encode_one};
    use ic_cdk::export::candid::{CandidType, Deserialize};

    use crate::task_scheduler::TaskScheduler;
    use crate::types::{Iterations, SchedulingOptions};

    /// Minimal payload used to exercise the scheduler.
    #[derive(CandidType, Deserialize)]
    pub struct TestPayload {
        pub a: bool,
    }

    /// Walks three tasks (one-shot, infinite, two-shot) through a shared timeline and then
    /// checks that a dequeued task never fires again.
    #[test]
    fn main_flow_works_fine() {
        let mut scheduler = TaskScheduler::default();

        let task_id_1 = scheduler
            .enqueue(
                TestPayload { a: true },
                SchedulingOptions {
                    delay_nano: 10,
                    interval_nano: 10,
                    iterations: Iterations::Exact(1),
                },
                0,
            )
            .ok()
            .unwrap();

        let task_id_2 = scheduler
            .enqueue(
                TestPayload { a: true },
                SchedulingOptions {
                    delay_nano: 10,
                    interval_nano: 10,
                    iterations: Iterations::Infinite,
                },
                0,
            )
            .ok()
            .unwrap();

        let task_id_3 = scheduler
            .enqueue(
                TestPayload { a: false },
                SchedulingOptions {
                    delay_nano: 20,
                    interval_nano: 20,
                    iterations: Iterations::Exact(2),
                },
                0,
            )
            .ok()
            .unwrap();

        assert!(
            !scheduler.is_empty(),
            "Scheduler should not be empty after enqueueing tasks"
        );

        let tasks_emp = scheduler.iterate(5);
        assert!(
            tasks_emp.is_empty(),
            "There should not be any tasks at timestamp 5"
        );

        let tasks_1_2 = scheduler.iterate(10);
        assert_eq!(
            tasks_1_2.len(),
            2,
            "At timestamp 10 there should be 2 tasks"
        );
        assert!(
            tasks_1_2.iter().any(|t| t.id == task_id_1),
            "Should contain task 1"
        );
        assert!(
            tasks_1_2.iter().any(|t| t.id == task_id_2),
            "Should contain task 2"
        );

        let tasks_emp = scheduler.iterate(15);
        assert!(
            tasks_emp.is_empty(),
            "There should not be any tasks at timestamp 15"
        );

        let tasks_2_3 = scheduler.iterate(20);
        assert_eq!(
            tasks_2_3.len(),
            2,
            "At timestamp 20 there should be 2 tasks"
        );
        assert!(
            tasks_2_3.iter().any(|t| t.id == task_id_2),
            "Should contain task 2"
        );
        assert!(
            tasks_2_3.iter().any(|t| t.id == task_id_3),
            "Should contain task 3"
        );

        let tasks_2 = scheduler.iterate(30);
        assert_eq!(
            tasks_2.len(),
            1,
            "There should be a single task at timestamp 30"
        );
        assert_eq!(tasks_2[0].id, task_id_2, "Should contain task 2");

        // Called late on purpose: both tasks were due at 40.
        let tasks_2_3 = scheduler.iterate(42);
        assert_eq!(
            tasks_2_3.len(),
            2,
            "At timestamp 42 there should be 2 tasks (both were due at 40)"
        );
        assert!(
            tasks_2_3.iter().any(|t| t.id == task_id_2),
            "Should contain task 2"
        );
        assert!(
            tasks_2_3.iter().any(|t| t.id == task_id_3),
            "Should contain task 3"
        );

        // Called late again: task 2 was due at 50; the late call at 42 must not shift it.
        let tasks_2 = scheduler.iterate(55);
        assert_eq!(
            tasks_2.len(),
            1,
            "There should be a single task at timestamp 55 (due at 50)"
        );
        assert_eq!(tasks_2[0].id, task_id_2, "Should contain task 2");

        let tasks_2 = scheduler.iterate(60);
        assert_eq!(
            tasks_2.len(),
            1,
            "There should be a single task at timestamp 60"
        );
        assert_eq!(tasks_2[0].id, task_id_2, "Should contain task 2");

        scheduler.dequeue(task_id_2).unwrap();
        assert!(
            scheduler.get_task(&task_id_2).is_none(),
            "Task 2 should be gone after dequeue"
        );

        let task_id_4 = scheduler
            .enqueue(
                TestPayload { a: true },
                SchedulingOptions {
                    delay_nano: 10,
                    interval_nano: 10,
                    iterations: Iterations::Exact(1),
                },
                0,
            )
            .ok()
            .unwrap();

        // Task 2 still has a stale queue entry at 70; it must be skipped silently.
        let tasks_4 = scheduler.iterate(70);
        assert_eq!(
            tasks_4.len(),
            1,
            "Only the newly enqueued task should fire at timestamp 70"
        );
        assert_eq!(tasks_4[0].id, task_id_4, "Should contain task 4");
        assert!(
            scheduler.is_empty(),
            "Scheduler should be empty once every queue entry was popped"
        );
    }

    /// The first run happens after `delay_nano`, every later run after `interval_nano`.
    #[test]
    fn delay_works_fine() {
        let mut scheduler = TaskScheduler::default();

        let task_id_1 = scheduler
            .enqueue(
                TestPayload { a: true },
                SchedulingOptions {
                    delay_nano: 10,
                    interval_nano: 20,
                    iterations: Iterations::Infinite,
                },
                0,
            )
            .ok()
            .unwrap();

        let tasks = scheduler.iterate(5);

        assert!(
            tasks.is_empty(),
            "There shouldn't be any task at this timestamp (5)"
        );

        let tasks = scheduler.iterate(10);
        assert_eq!(
            tasks.len(),
            1,
            "There should be a task that was triggered by a delay at this timestamp (10)"
        );
        assert_eq!(tasks[0].id, task_id_1, "Should contain task 1");

        let tasks = scheduler.iterate(20);
        assert!(
            tasks.is_empty(),
            "There shouldn't be any task at this timestamp (20)"
        );

        let tasks = scheduler.iterate(30);
        assert_eq!(
            tasks.len(),
            1,
            "There should be a task that was triggered by an interval at this timestamp (30)"
        );

        let tasks = scheduler.iterate(50);
        assert_eq!(
            tasks.len(),
            1,
            "There should be a task that was triggered by an interval at this timestamp (50)"
        );
    }

    /// The scheduler keeps working after a Candid encode/decode round trip.
    #[test]
    fn ser_de_works_fine() {
        let mut scheduler = TaskScheduler::default();

        scheduler
            .enqueue(
                TestPayload { a: true },
                SchedulingOptions {
                    delay_nano: 10,
                    interval_nano: 20,
                    iterations: Iterations::Infinite,
                },
                0,
            )
            .ok()
            .unwrap();

        let bytes = encode_one(scheduler).expect("Should be able to encode task scheduler");
        let mut scheduler: TaskScheduler =
            decode_one(&bytes).expect("Should be able to decode task scheduler");

        let tasks = scheduler.iterate(10);

        assert_eq!(
            tasks.len(),
            1,
            "There should be a task that was triggered by a delay at this timestamp (10)"
        );
    }
}