1use candid::{CandidType, Deserialize};
7use ic_cdk_timers::{clear_timer, set_timer, set_timer_interval, TimerId};
8use serde::Serialize;
9use std::cell::RefCell;
10use std::collections::BTreeMap;
11use std::sync::Arc;
12
13thread_local! {
14 static TIMER_REGISTRY: RefCell<TimerRegistry> = RefCell::new(TimerRegistry::new());
16}
17
/// Maximum number of timers the registry tracks at once. `TimerRegistry::add_timer`
/// rejects new entries with `TimerError::TooManyTimers` once this many are registered.
const MAX_TIMERS: usize = 100;
20
21#[derive(Default)]
23struct TimerRegistry {
24 timers: BTreeMap<TimerId, TimerInfo>,
25 next_id: u64,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize, CandidType)]
30pub struct TimerInfo {
31 pub id: u64,
32 pub name: String,
33 pub timer_type: TimerType,
34 pub created_at: u64,
35 pub interval_secs: Option<u64>,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize, CandidType)]
40pub enum TimerType {
41 Once,
42 Periodic,
43}
44
45impl TimerRegistry {
46 fn new() -> Self {
47 Self {
48 timers: BTreeMap::new(),
49 next_id: 1,
50 }
51 }
52
53 fn add_timer(&mut self, timer_id: TimerId, info: TimerInfo) -> Result<(), TimerError> {
54 if self.timers.len() >= MAX_TIMERS {
55 return Err(TimerError::TooManyTimers {
56 max: MAX_TIMERS,
57 current: self.timers.len(),
58 });
59 }
60 self.timers.insert(timer_id, info);
61 self.next_id += 1;
62 Ok(())
63 }
64
65 fn remove_timer(&mut self, timer_id: TimerId) -> Option<TimerInfo> {
66 self.timers.remove(&timer_id)
67 }
68
69 fn list_timers(&self) -> Vec<TimerInfo> {
70 self.timers.values().cloned().collect()
71 }
72
73 fn clear_all(&mut self) {
74 for timer_id in self.timers.keys().copied().collect::<Vec<_>>() {
75 clear_timer(timer_id);
76 }
77 self.timers.clear();
78 }
79}
80
81#[derive(Debug, thiserror::Error)]
83pub enum TimerError {
84 #[error("Too many timers: {current}/{max}")]
85 TooManyTimers { max: usize, current: usize },
86
87 #[error("Timer not found: {0}")]
88 TimerNotFound(String),
89
90 #[error("Invalid interval: {0} seconds")]
91 InvalidInterval(u64),
92}
93
94pub fn schedule_once<F>(delay_secs: u64, name: &str, task: F) -> Result<TimerId, TimerError>
106where
107 F: FnOnce() + 'static,
108{
109 if delay_secs == 0 {
110 return Err(TimerError::InvalidInterval(delay_secs));
111 }
112
113 let timer_id = set_timer(std::time::Duration::from_secs(delay_secs), move || {
114 task();
116
117 });
120
121 let info = TimerInfo {
123 id: TIMER_REGISTRY.with(|r| r.borrow().next_id),
124 name: name.to_string(),
125 timer_type: TimerType::Once,
126 created_at: ic_cdk::api::time(),
127 interval_secs: Some(delay_secs),
128 };
129
130 TIMER_REGISTRY.with(|r| r.borrow_mut().add_timer(timer_id, info))?;
131
132 Ok(timer_id)
133}
134
135pub fn schedule_periodic<F>(interval_secs: u64, name: &str, task: F) -> Result<TimerId, TimerError>
147where
148 F: Fn() + 'static,
149{
150 if interval_secs == 0 {
151 return Err(TimerError::InvalidInterval(interval_secs));
152 }
153
154 let task = Arc::new(task);
156
157 let timer_id = set_timer_interval(std::time::Duration::from_secs(interval_secs), move || {
158 task();
159 });
160
161 let info = TimerInfo {
163 id: TIMER_REGISTRY.with(|r| r.borrow().next_id),
164 name: name.to_string(),
165 timer_type: TimerType::Periodic,
166 created_at: ic_cdk::api::time(),
167 interval_secs: Some(interval_secs),
168 };
169
170 TIMER_REGISTRY.with(|r| r.borrow_mut().add_timer(timer_id, info))?;
171
172 Ok(timer_id)
173}
174
175pub fn cancel_timer(timer_id: TimerId) -> Result<(), TimerError> {
186 TIMER_REGISTRY.with(|r| {
187 if r.borrow_mut().remove_timer(timer_id).is_some() {
188 clear_timer(timer_id);
189 Ok(())
190 } else {
191 Err(TimerError::TimerNotFound(format!("{:?}", timer_id)))
192 }
193 })
194}
195
196pub fn list_active_timers() -> Vec<TimerInfo> {
208 TIMER_REGISTRY.with(|r| r.borrow().list_timers())
209}
210
211pub fn cancel_all_timers() {
215 TIMER_REGISTRY.with(|r| r.borrow_mut().clear_all());
216}
217
218pub fn active_timer_count() -> usize {
220 TIMER_REGISTRY.with(|r| r.borrow().timers.len())
221}
222
223pub fn schedule_with_backoff<F>(
243 initial_delay_secs: u64,
244 max_retries: u32,
245 backoff_multiplier: f64,
246 name: &str,
247 task: F,
248) -> Result<TimerId, TimerError>
249where
250 F: FnMut(u32) -> bool + 'static,
251{
252 struct BackoffState<G: FnMut(u32) -> bool> {
255 task: G,
256 attempt: u32,
257 max_retries: u32,
258 current_delay: u64,
259 backoff_multiplier: f64,
260 name: String,
261 }
262
263 let state = Arc::new(RefCell::new(BackoffState {
264 task,
265 attempt: 0,
266 max_retries,
267 current_delay: initial_delay_secs,
268 backoff_multiplier,
269 name: name.to_string(),
270 }));
271
272 let state_clone = state.clone();
273
274 schedule_once(initial_delay_secs, name, move || {
275 let mut state = state_clone.borrow_mut();
276 state.attempt += 1;
277
278 let attempt = state.attempt;
280 let should_stop = (state.task)(attempt);
281
282 if !should_stop && state.attempt < state.max_retries {
283 state.current_delay = (state.current_delay as f64 * state.backoff_multiplier) as u64;
285
286 ic_cdk::print(format!(
289 "Would reschedule {} with delay {} seconds",
290 state.name, state.current_delay
291 ));
292 }
293 })
294}
295
/// Schedules `$body` to run once after `$delay` seconds via `timers::schedule_once`,
/// registered under `$name`.
///
/// The generated closure is `move`. `schedule_once` requires a `'static` closure, which
/// rules out borrowing locals, so captured locals are moved (or copied) in instead.
#[macro_export]
macro_rules! timer_once {
    ($delay:expr, $name:expr, $body:expr) => {
        $crate::timers::schedule_once($delay, $name, move || $body)
    };
}
303
/// Schedules `$body` to run every `$interval` seconds via `timers::schedule_periodic`,
/// registered under `$name`.
///
/// The generated closure is `move`. `schedule_periodic` requires a `'static` closure,
/// which rules out borrowing locals, so captured locals are moved (or copied) in instead.
/// `$body` runs repeatedly, so it must not consume what it captures.
#[macro_export]
macro_rules! timer_periodic {
    ($interval:expr, $name:expr, $body:expr) => {
        $crate::timers::schedule_periodic($interval, $name, move || $body)
    };
}
311
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_timer_info_creation() {
        let info = TimerInfo {
            id: 1,
            name: "test_timer".to_string(),
            timer_type: TimerType::Once,
            created_at: 0,
            interval_secs: Some(60),
        };

        assert_eq!(info.id, 1);
        assert_eq!(info.name, "test_timer");
        // `matches!` only produces a bool; it has to be wrapped in `assert!` to check anything.
        assert!(matches!(info.timer_type, TimerType::Once));
        assert_eq!(info.created_at, 0);
        assert_eq!(info.interval_secs, Some(60));
    }

    #[test]
    fn test_timer_registry() {
        let registry = TimerRegistry::new();
        assert!(registry.timers.is_empty());
        assert_eq!(registry.next_id, 1);
        assert!(registry.list_timers().is_empty());
    }

    #[test]
    fn test_max_timers_limit() {
        let registry = TimerRegistry::new();
        // A fresh registry starts below the cap, so the first registration can succeed.
        assert!(registry.timers.len() < MAX_TIMERS);

        let err = TimerError::TooManyTimers {
            max: MAX_TIMERS,
            current: MAX_TIMERS,
        };
        assert_eq!(
            err.to_string(),
            format!("Too many timers: {}/{}", MAX_TIMERS, MAX_TIMERS)
        );
    }

    #[test]
    fn test_zero_delay_is_rejected() {
        // These paths return before any IC system API is touched, so they run natively.
        assert!(matches!(
            schedule_once(0, "zero_once", || {}),
            Err(TimerError::InvalidInterval(0))
        ));
        assert!(matches!(
            schedule_periodic(0, "zero_periodic", || {}),
            Err(TimerError::InvalidInterval(0))
        ));
        assert!(matches!(
            schedule_with_backoff(0, 3, 2.0, "zero_backoff", |_| true),
            Err(TimerError::InvalidInterval(0))
        ));
        // Rejected requests must not leave anything in the registry.
        assert_eq!(active_timer_count(), 0);
        assert!(list_active_timers().is_empty());
    }
}