1use std::collections::hash_map::Entry;
2use std::collections::HashMap;
3
4use ic_cdk::export::candid::{CandidType, Deserialize, Result as CandidResult};
5
6use crate::types::{
7 Iterations, ScheduledTask, SchedulingOptions, TaskExecutionQueue, TaskId, TaskTimestamp,
8};
9
/// Keeps every scheduled task together with the queue of their pending runs.
///
/// The scheduler derives `CandidType` and `Deserialize`, so it can be Candid-encoded and
/// decoded as a whole (e.g. with `encode_one` / `decode_one`).
#[derive(Default, CandidType, Deserialize, Clone)]
pub struct TaskScheduler {
    /// All tasks that have not been removed yet, keyed by their id.
    pub tasks: HashMap<TaskId, ScheduledTask>,
    /// Id that will be handed to the next enqueued task; bumped on every `enqueue`.
    pub task_id_counter: TaskId,

    /// Pending runs, one `TaskTimestamp` (task id + due timestamp) per queued execution.
    /// Entries are not removed by `dequeue`, so this may reference tasks that no longer
    /// exist in `tasks`.
    pub queue: TaskExecutionQueue,
}
17
18impl TaskScheduler {
19 pub fn enqueue<TaskPayload: CandidType>(
20 &mut self,
21 payload: TaskPayload,
22 scheduling_interval: SchedulingOptions,
23 timestamp: u64,
24 ) -> CandidResult<TaskId> {
25 let id = self.generate_task_id();
26 let task = ScheduledTask::new(id, payload, timestamp, None, scheduling_interval)?;
27
28 match task.scheduling_options.iterations {
29 Iterations::Exact(times) => {
30 if times > 0 {
31 self.queue.push(TaskTimestamp {
32 task_id: id,
33 timestamp: timestamp + task.scheduling_options.delay_nano,
34 })
35 }
36 }
37 Iterations::Infinite => self.queue.push(TaskTimestamp {
38 task_id: id,
39 timestamp: timestamp + task.scheduling_options.delay_nano,
40 }),
41 };
42
43 self.tasks.insert(id, task);
44
45 Ok(id)
46 }
47
48 pub fn iterate(&mut self, timestamp: u64) -> Vec<ScheduledTask> {
49 let mut tasks = vec![];
50
51 for task_id in self
52 .queue
53 .pop_ready(timestamp)
54 .into_iter()
55 .map(|it| it.task_id)
56 {
57 let mut should_remove = false;
58
59 match self.tasks.entry(task_id) {
60 Entry::Occupied(mut entry) => {
61 let task = entry.get_mut();
62
63 match task.scheduling_options.iterations {
64 Iterations::Infinite => {
65 let new_rescheduled_at = if task.delay_passed {
66 if let Some(rescheduled_at) = task.rescheduled_at {
67 rescheduled_at + task.scheduling_options.interval_nano
68 } else {
69 task.scheduled_at + task.scheduling_options.interval_nano
70 }
71 } else {
72 task.delay_passed = true;
73
74 if let Some(rescheduled_at) = task.rescheduled_at {
75 rescheduled_at + task.scheduling_options.delay_nano
76 } else {
77 task.scheduled_at + task.scheduling_options.delay_nano
78 }
79 };
80
81 task.rescheduled_at = Some(new_rescheduled_at);
82
83 self.queue.push(TaskTimestamp {
84 task_id,
85 timestamp: new_rescheduled_at
86 + task.scheduling_options.interval_nano,
87 });
88 }
89 Iterations::Exact(times_left) => {
90 if times_left > 1 {
91 let new_rescheduled_at = if task.delay_passed {
92 if let Some(rescheduled_at) = task.rescheduled_at {
93 rescheduled_at + task.scheduling_options.interval_nano
94 } else {
95 task.scheduled_at + task.scheduling_options.interval_nano
96 }
97 } else {
98 task.delay_passed = true;
99
100 if let Some(rescheduled_at) = task.rescheduled_at {
101 rescheduled_at + task.scheduling_options.delay_nano
102 } else {
103 task.scheduled_at + task.scheduling_options.delay_nano
104 }
105 };
106
107 task.rescheduled_at = Some(new_rescheduled_at);
108
109 self.queue.push(TaskTimestamp {
110 task_id,
111 timestamp: new_rescheduled_at
112 + task.scheduling_options.interval_nano,
113 });
114
115 task.scheduling_options.iterations =
116 Iterations::Exact(times_left - 1);
117 } else {
118 should_remove = true;
119 }
120 }
121 };
122
123 tasks.push(task.clone());
124 }
125 Entry::Vacant(_) => {}
126 }
127
128 if should_remove {
129 self.tasks.remove(&task_id);
130 }
131 }
132
133 tasks
134 }
135
136 pub fn dequeue(&mut self, task_id: TaskId) -> Option<ScheduledTask> {
137 self.tasks.remove(&task_id)
138 }
139
140 pub fn is_empty(&self) -> bool {
141 self.queue.is_empty()
142 }
143
144 pub fn get_task(&self, task_id: &TaskId) -> Option<&ScheduledTask> {
145 self.tasks.get(task_id)
146 }
147
148 pub fn get_task_mut(&mut self, task_id: &TaskId) -> Option<&mut ScheduledTask> {
149 self.tasks.get_mut(task_id)
150 }
151
152 pub fn get_task_by_id_cloned(&self, task_id: &TaskId) -> Option<ScheduledTask> {
153 self.get_task(task_id).cloned()
154 }
155
156 pub fn get_tasks_cloned(&self) -> Vec<ScheduledTask> {
157 self.tasks.values().cloned().collect()
158 }
159
160 fn generate_task_id(&mut self) -> TaskId {
161 let res = self.task_id_counter;
162 self.task_id_counter += 1;
163
164 res
165 }
166}
167
#[cfg(test)]
mod tests {
    use ic_cdk::export::candid::{decode_one, encode_one, CandidType, Deserialize};

    use crate::task_scheduler::TaskScheduler;
    use crate::types::{Iterations, SchedulingOptions, TaskId};

    #[derive(CandidType, Deserialize)]
    pub struct TestPayload {
        pub a: bool,
    }

    /// Enqueues `TestPayload { a }` at timestamp 0 with the given scheduling options and
    /// returns its id, failing the test with the underlying Candid error otherwise.
    fn enqueue_at_zero(
        scheduler: &mut TaskScheduler,
        a: bool,
        delay_nano: u64,
        interval_nano: u64,
        iterations: Iterations,
    ) -> TaskId {
        scheduler
            .enqueue(
                TestPayload { a },
                SchedulingOptions {
                    delay_nano,
                    interval_nano,
                    iterations,
                },
                0,
            )
            .expect("Should be able to enqueue a task")
    }

    #[test]
    fn main_flow_works_fine() {
        let mut scheduler = TaskScheduler::default();

        let task_id_1 = enqueue_at_zero(&mut scheduler, true, 10, 10, Iterations::Exact(1));
        let task_id_2 = enqueue_at_zero(&mut scheduler, true, 10, 10, Iterations::Infinite);
        let task_id_3 = enqueue_at_zero(&mut scheduler, false, 20, 20, Iterations::Exact(2));

        assert!(!scheduler.is_empty(), "Scheduler should not be empty");

        let tasks_emp = scheduler.iterate(5);
        assert!(
            tasks_emp.is_empty(),
            "There should not be any tasks at timestamp 5"
        );

        let tasks_1_2 = scheduler.iterate(10);
        assert_eq!(
            tasks_1_2.len(),
            2,
            "At timestamp 10 there should be 2 tasks"
        );
        assert!(
            tasks_1_2.iter().any(|t| t.id == task_id_1),
            "Should contain task 1"
        );
        assert!(
            tasks_1_2.iter().any(|t| t.id == task_id_2),
            "Should contain task 2"
        );
        assert!(
            scheduler.get_task(&task_id_1).is_none(),
            "Task 1 should be removed after its only run"
        );

        let tasks_emp = scheduler.iterate(15);
        assert!(
            tasks_emp.is_empty(),
            "There should not be any tasks at timestamp 15"
        );

        let tasks_2_3 = scheduler.iterate(20);
        assert_eq!(
            tasks_2_3.len(),
            2,
            "At timestamp 20 there should be 2 tasks"
        );
        assert!(
            tasks_2_3.iter().any(|t| t.id == task_id_2),
            "Should contain task 2"
        );
        assert!(
            tasks_2_3.iter().any(|t| t.id == task_id_3),
            "Should contain task 3"
        );

        let tasks_2 = scheduler.iterate(30);
        assert_eq!(
            tasks_2.len(),
            1,
            "There should be a single task at timestamp 30"
        );
        assert_eq!(tasks_2[0].id, task_id_2, "Should contain task 2");

        let tasks_2_3 = scheduler.iterate(42);
        assert_eq!(
            tasks_2_3.len(),
            2,
            "At timestamp 42 there should be 2 tasks"
        );
        assert!(
            tasks_2_3.iter().any(|t| t.id == task_id_2),
            "Should contain task 2"
        );
        assert!(
            tasks_2_3.iter().any(|t| t.id == task_id_3),
            "Should contain task 3"
        );
        assert!(
            scheduler.get_task(&task_id_3).is_none(),
            "Task 3 should be removed after its second run"
        );

        let tasks_2 = scheduler.iterate(55);
        assert_eq!(
            tasks_2.len(),
            1,
            "There should be a single task at timestamp 55"
        );
        assert_eq!(tasks_2[0].id, task_id_2, "Should contain task 2");

        let tasks_2 = scheduler.iterate(60);
        assert_eq!(
            tasks_2.len(),
            1,
            "There should be a single task at timestamp 60"
        );
        assert_eq!(tasks_2[0].id, task_id_2, "Should contain task 2");

        scheduler
            .dequeue(task_id_2)
            .expect("Task 2 should still be scheduled");
        assert!(
            scheduler.iterate(70).is_empty(),
            "A dequeued task should not fire anymore"
        );

        let task_id_4 = enqueue_at_zero(&mut scheduler, true, 10, 10, Iterations::Exact(1));
        assert_ne!(task_id_4, task_id_2, "Task ids should not be reused");

        let tasks_4 = scheduler.iterate(80);
        assert_eq!(
            tasks_4.len(),
            1,
            "A task enqueued after a dequeue should still fire"
        );
        assert_eq!(tasks_4[0].id, task_id_4, "Should contain task 4");
        assert!(
            scheduler.is_empty(),
            "Scheduler should be empty once every run has fired"
        );
    }

    #[test]
    fn delay_works_fine() {
        let mut scheduler = TaskScheduler::default();

        let task_id_1 = enqueue_at_zero(&mut scheduler, true, 10, 20, Iterations::Infinite);

        let tasks = scheduler.iterate(5);
        assert!(
            tasks.is_empty(),
            "There shouldn't be any task at this timestamp (5)"
        );

        let tasks = scheduler.iterate(10);
        assert_eq!(
            tasks.len(),
            1,
            "There should be a task that was triggered by a delay at this timestamp (10)"
        );
        assert_eq!(tasks[0].id, task_id_1, "Should contain task 1");

        let tasks = scheduler.iterate(20);
        assert!(
            tasks.is_empty(),
            "There shouldn't be any task at this timestamp (20)"
        );

        let tasks = scheduler.iterate(30);
        assert_eq!(
            tasks.len(),
            1,
            "There should be a task that was triggered by an interval at this timestamp (30)"
        );
        assert_eq!(tasks[0].id, task_id_1, "Should contain task 1");

        let tasks = scheduler.iterate(50);
        assert_eq!(
            tasks.len(),
            1,
            "There should be a task that was triggered by an interval at this timestamp (50)"
        );
        assert_eq!(tasks[0].id, task_id_1, "Should contain task 1");
    }

    #[test]
    fn ser_de_works_fine() {
        let mut scheduler = TaskScheduler::default();

        enqueue_at_zero(&mut scheduler, true, 10, 20, Iterations::Infinite);

        let bytes = encode_one(scheduler).expect("Should be able to encode task scheduler");
        let mut scheduler: TaskScheduler =
            decode_one(&bytes).expect("Should be able to decode task scheduler");

        let tasks = scheduler.iterate(10);

        assert_eq!(
            tasks.len(),
            1,
            "There should be a task that was triggered by a delay at this timestamp (10)"
        );
    }
}