1use std::cmp::{max, min, Ordering};
2use std::collections::BinaryHeap;
3
4use ic_cdk::export::candid::types::{Serializer, Type};
5use ic_cdk::export::candid::{
6 decode_one, encode_one, CandidType, Deserialize, Result as CandidResult,
7};
8
/// Identifier type used for tasks throughout this module (a plain `u64`).
pub type TaskId = u64;
10
11#[derive(Clone, CandidType, Deserialize)]
12pub struct Task {
13 pub data: Vec<u8>,
14}
15
16#[derive(Clone, Copy, CandidType, Deserialize)]
17pub enum Iterations {
18 Infinite,
19 Exact(u64),
20}
21
22#[derive(Clone, Copy, CandidType, Deserialize)]
23pub struct SchedulingOptions {
24 pub delay_nano: u64,
25 pub interval_nano: u64,
26 pub iterations: Iterations,
27}
28
29#[derive(Clone, CandidType, Deserialize)]
30pub struct ScheduledTask {
31 pub id: TaskId,
32 pub payload: Task,
33 pub scheduled_at: u64,
34 pub rescheduled_at: Option<u64>,
35 pub scheduling_options: SchedulingOptions,
36 pub delay_passed: bool,
37}
38
39impl ScheduledTask {
40 pub fn new<TaskPayload: CandidType>(
41 id: TaskId,
42 payload: TaskPayload,
43 scheduled_at: u64,
44 rescheduled_at: Option<u64>,
45 scheduling_interval: SchedulingOptions,
46 ) -> CandidResult<Self> {
47 let task = Task {
48 data: encode_one(payload).unwrap(),
49 };
50
51 Ok(Self {
52 id,
53 payload: task,
54 scheduled_at,
55 rescheduled_at,
56 scheduling_options: scheduling_interval,
57 delay_passed: false,
58 })
59 }
60
61 pub fn get_payload<'a, T>(&'a self) -> CandidResult<T>
62 where
63 T: Deserialize<'a> + CandidType,
64 {
65 decode_one(&self.payload.data)
66 }
67
68 pub fn set_payload<T: CandidType>(&mut self, payload: T) {
69 self.payload.data = encode_one(payload).unwrap()
70 }
71}
72
73#[derive(CandidType, Deserialize, Clone, Copy)]
74pub struct TaskTimestamp {
75 pub task_id: TaskId,
76 pub timestamp: u64,
77}
78
79impl PartialEq for TaskTimestamp {
80 fn eq(&self, other: &Self) -> bool {
81 self.timestamp.eq(&other.timestamp) && self.task_id.eq(&other.task_id)
82 }
83}
84
85impl Eq for TaskTimestamp {}
86
87impl PartialOrd for TaskTimestamp {
88 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
89 self.timestamp
90 .partial_cmp(&other.timestamp)
91 .map(|it| it.reverse())
92 }
93
94 fn lt(&self, other: &Self) -> bool {
95 self.timestamp.gt(&other.timestamp)
96 }
97
98 fn le(&self, other: &Self) -> bool {
99 self.timestamp.ge(&other.timestamp)
100 }
101
102 fn gt(&self, other: &Self) -> bool {
103 self.timestamp.lt(&other.timestamp)
104 }
105
106 fn ge(&self, other: &Self) -> bool {
107 self.timestamp.le(&other.timestamp)
108 }
109}
110
111impl Ord for TaskTimestamp {
112 fn cmp(&self, other: &Self) -> Ordering {
113 self.timestamp.cmp(&other.timestamp).reverse()
114 }
115
116 fn max(self, other: Self) -> Self
117 where
118 Self: Sized,
119 {
120 max(self, other)
121 }
122
123 fn min(self, other: Self) -> Self
124 where
125 Self: Sized,
126 {
127 min(self, other)
128 }
129
130 fn clamp(self, min: Self, max: Self) -> Self
131 where
132 Self: Sized,
133 {
134 if self.timestamp < max.timestamp {
135 max
136 } else if self.timestamp > min.timestamp {
137 min
138 } else {
139 self
140 }
141 }
142}
143
144#[derive(Default, Deserialize, Clone)]
145pub struct TaskExecutionQueue(BinaryHeap<TaskTimestamp>);
146
147impl TaskExecutionQueue {
148 #[inline(always)]
149 pub fn push(&mut self, task: TaskTimestamp) {
150 self.0.push(task);
151 }
152
153 pub fn pop_ready(&mut self, timestamp: u64) -> Vec<TaskTimestamp> {
154 let mut cur = self.0.peek();
155 if cur.is_none() {
156 return Vec::new();
157 }
158
159 let mut result = vec![];
160
161 while cur.unwrap().timestamp <= timestamp {
162 result.push(self.0.pop().unwrap());
163
164 cur = self.0.peek();
165 if cur.is_none() {
166 break;
167 }
168 }
169
170 result
171 }
172
173 #[inline(always)]
174 pub fn is_empty(&self) -> bool {
175 self.0.is_empty()
176 }
177
178 #[inline(always)]
179 pub fn len(&self) -> usize {
180 self.0.len()
181 }
182}
183
184impl CandidType for TaskExecutionQueue {
185 fn _ty() -> Type {
186 Type::Vec(Box::new(TaskTimestamp::_ty()))
187 }
188
189 fn ty() -> Type {
190 Self::_ty()
191 }
192
193 fn idl_serialize<S>(&self, serializer: S) -> Result<(), S::Error>
194 where
195 S: Serializer,
196 {
197 self.clone().0.into_sorted_vec().idl_serialize(serializer)
198 }
199}