Skip to main content

timer_lib/
registry.rs

1use std::collections::HashMap;
2use std::sync::{
3    atomic::{AtomicU64, Ordering},
4    Arc,
5};
6use std::time::Duration;
7
8use tokio::sync::RwLock;
9
10use crate::errors::TimerError;
11use crate::timer::{RecurringSchedule, Timer, TimerCallback, TimerOutcome, TimerState};
12
13/// A registry for tracking timers by identifier.
14#[derive(Clone, Default)]
15pub struct TimerRegistry {
16    timers: Arc<RwLock<HashMap<u64, Timer>>>,
17    next_id: Arc<AtomicU64>,
18}
19
20impl TimerRegistry {
21    /// Creates a new timer registry.
22    pub fn new() -> Self {
23        Self {
24            timers: Arc::new(RwLock::new(HashMap::new())),
25            next_id: Arc::new(AtomicU64::new(0)),
26        }
27    }
28
29    /// Inserts an existing timer and returns its identifier.
30    pub async fn insert(&self, timer: Timer) -> u64 {
31        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
32        self.timers.write().await.insert(id, timer);
33        id
34    }
35
36    /// Starts and registers a one-time timer.
37    pub async fn start_once<F>(
38        &self,
39        delay: Duration,
40        callback: F,
41    ) -> Result<(u64, Timer), TimerError>
42    where
43        F: TimerCallback + 'static,
44    {
45        let timer = Timer::new();
46        let _ = timer.start_once(delay, callback).await?;
47        let id = self.insert(timer.clone()).await;
48        Ok((id, timer))
49    }
50
51    /// Starts and registers a recurring timer.
52    pub async fn start_recurring<F>(
53        &self,
54        schedule: RecurringSchedule,
55        callback: F,
56    ) -> Result<(u64, Timer), TimerError>
57    where
58        F: TimerCallback + 'static,
59    {
60        let timer = Timer::new();
61        let _ = timer.start_recurring(schedule, callback).await?;
62        let id = self.insert(timer.clone()).await;
63        Ok((id, timer))
64    }
65
66    /// Removes a timer from the registry and returns it.
67    pub async fn remove(&self, id: u64) -> Option<Timer> {
68        self.timers.write().await.remove(&id)
69    }
70
71    /// Returns true when the registry tracks the given timer identifier.
72    pub async fn contains(&self, id: u64) -> bool {
73        self.timers.read().await.contains_key(&id)
74    }
75
76    /// Stops a timer by identifier when it exists.
77    pub async fn stop(&self, id: u64) -> Result<Option<TimerOutcome>, TimerError> {
78        let timer = self.get(id).await;
79        match timer {
80            Some(timer) => timer.stop().await.map(Some),
81            None => Ok(None),
82        }
83    }
84
85    /// Cancels a timer by identifier when it exists.
86    pub async fn cancel(&self, id: u64) -> Result<Option<TimerOutcome>, TimerError> {
87        let timer = self.get(id).await;
88        match timer {
89            Some(timer) => timer.cancel().await.map(Some),
90            None => Ok(None),
91        }
92    }
93
94    /// Pauses a timer by identifier when it exists.
95    pub async fn pause(&self, id: u64) -> Result<bool, TimerError> {
96        let timer = self.get(id).await;
97        match timer {
98            Some(timer) => {
99                timer.pause().await?;
100                Ok(true)
101            }
102            None => Ok(false),
103        }
104    }
105
106    /// Resumes a timer by identifier when it exists.
107    pub async fn resume(&self, id: u64) -> Result<bool, TimerError> {
108        let timer = self.get(id).await;
109        match timer {
110            Some(timer) => {
111                timer.resume().await?;
112                Ok(true)
113            }
114            None => Ok(false),
115        }
116    }
117
118    /// Stops all timers currently tracked by the registry.
119    pub async fn stop_all(&self) {
120        let timers: Vec<Timer> = self.timers.read().await.values().cloned().collect();
121        for timer in timers {
122            let _ = timer.stop().await;
123        }
124    }
125
126    /// Pauses all running timers currently tracked by the registry.
127    pub async fn pause_all(&self) {
128        let timers: Vec<Timer> = self.timers.read().await.values().cloned().collect();
129        for timer in timers {
130            let _ = timer.pause().await;
131        }
132    }
133
134    /// Waits for all tracked timers that have a joinable outcome.
135    pub async fn join_all(&self) -> Vec<(u64, TimerOutcome)> {
136        let timers: Vec<(u64, Timer)> = self
137            .timers
138            .read()
139            .await
140            .iter()
141            .map(|(id, timer)| (*id, timer.clone()))
142            .collect();
143
144        let mut outcomes = Vec::with_capacity(timers.len());
145        for (id, timer) in timers {
146            if let Ok(outcome) = timer.join().await {
147                outcomes.push((id, outcome));
148            }
149        }
150
151        outcomes
152    }
153
154    /// Cancels all timers currently tracked by the registry.
155    pub async fn cancel_all(&self) {
156        let timers: Vec<Timer> = self.timers.read().await.values().cloned().collect();
157        for timer in timers {
158            let _ = timer.cancel().await;
159        }
160    }
161
162    /// Resumes all paused timers currently tracked by the registry.
163    pub async fn resume_all(&self) {
164        let timers: Vec<Timer> = self.timers.read().await.values().cloned().collect();
165        for timer in timers {
166            let _ = timer.resume().await;
167        }
168    }
169
170    /// Lists all active timers.
171    pub async fn active_ids(&self) -> Vec<u64> {
172        let timers: Vec<(u64, Timer)> = self
173            .timers
174            .read()
175            .await
176            .iter()
177            .map(|(id, timer)| (*id, timer.clone()))
178            .collect();
179
180        let mut active = Vec::new();
181        for (id, timer) in timers {
182            if timer.get_state().await != TimerState::Stopped {
183                active.push(id);
184            }
185        }
186        active
187    }
188
189    /// Retrieves a timer by ID.
190    pub async fn get(&self, id: u64) -> Option<Timer> {
191        self.timers.read().await.get(&id).cloned()
192    }
193
194    /// Returns the number of tracked timers.
195    pub async fn len(&self) -> usize {
196        self.timers.read().await.len()
197    }
198
199    /// Returns true when the registry is empty.
200    pub async fn is_empty(&self) -> bool {
201        self.len().await == 0
202    }
203
204    /// Removes all tracked timers and returns the number removed.
205    pub async fn clear(&self) -> usize {
206        let mut timers = self.timers.write().await;
207        let removed = timers.len();
208        timers.clear();
209        removed
210    }
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use crate::timer::TimerFinishReason;
    use tokio::task::yield_now;
    use tokio::time::advance;

    /// Yields to the single-threaded Tokio scheduler five times.
    ///
    /// Presumably this gives the spawned timer tasks a chance to observe clock
    /// changes made by `advance` before the test inspects timer state.
    async fn settle() {
        for _ in 0..5 {
            yield_now().await;
        }
    }

    /// The `start_once` / `start_recurring` helpers allocate distinct ids,
    /// register both timers, and return working handles.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn registry_start_helpers_are_easy_to_use() {
        let registry = TimerRegistry::new();
        let (once_id, once_timer) = registry
            .start_once(Duration::from_secs(1), || async { Ok(()) })
            .await
            .unwrap();
        let (recurring_id, recurring_timer) = registry
            .start_recurring(RecurringSchedule::new(Duration::from_secs(2)), || async {
                Ok(())
            })
            .await
            .unwrap();

        assert_ne!(once_id, recurring_id);
        assert_eq!(registry.len().await, 2);
        assert!(registry.get(once_id).await.is_some());

        advance(Duration::from_secs(1)).await;
        settle().await;
        assert_eq!(
            once_timer.join().await.unwrap().reason,
            TimerFinishReason::Completed
        );

        let active = registry.active_ids().await;
        assert!(active.contains(&recurring_id));

        let _ = recurring_timer.cancel().await.unwrap();
    }

    /// Per-id controls reach the registered timer, and `clear` reports how many
    /// entries it dropped.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn registry_supports_direct_timer_controls() {
        let registry = TimerRegistry::new();
        let (timer_id, _timer) = registry
            .start_once(Duration::from_secs(5), || async { Ok(()) })
            .await
            .unwrap();

        assert!(registry.contains(timer_id).await);
        let outcome = registry.cancel(timer_id).await.unwrap().unwrap();
        assert_eq!(outcome.reason, TimerFinishReason::Cancelled);
        assert_eq!(registry.clear().await, 1);
        assert!(registry.is_empty().await);
    }

    /// Pausing through the registry suspends execution until it is resumed.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn registry_can_pause_and_resume_tracked_timers() {
        let registry = TimerRegistry::new();
        let (timer_id, timer) = registry
            .start_recurring(
                RecurringSchedule::new(Duration::from_secs(2)).with_expiration_count(1),
                || async { Ok(()) },
            )
            .await
            .unwrap();
        settle().await;

        assert!(registry.pause(timer_id).await.unwrap());
        assert_eq!(timer.get_state().await, TimerState::Paused);

        advance(Duration::from_secs(5)).await;
        settle().await;
        assert_eq!(timer.get_statistics().await.execution_count, 0);

        assert!(registry.resume(timer_id).await.unwrap());
        advance(Duration::from_secs(2)).await;
        settle().await;
        assert_eq!(
            timer.join().await.unwrap().reason,
            TimerFinishReason::Completed
        );
    }

    /// Operations on identifiers that were never registered are no-ops that
    /// report absence instead of failing.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn registry_reports_unknown_ids_as_absent() {
        let registry = TimerRegistry::new();

        assert!(registry.stop(42).await.unwrap().is_none());
        assert!(registry.cancel(42).await.unwrap().is_none());
        assert!(!registry.pause(42).await.unwrap());
        assert!(!registry.resume(42).await.unwrap());
        assert!(registry.get(42).await.is_none());
        assert!(!registry.contains(42).await);
        assert!(registry.active_ids().await.is_empty());
        assert!(registry.join_all().await.is_empty());
    }

    /// `insert` allocates consecutive ids and `remove` only drops the requested
    /// entry.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn registry_insert_and_remove_track_ids() {
        let registry = TimerRegistry::new();
        let first = registry.insert(Timer::new()).await;
        let second = registry.insert(Timer::new()).await;
        assert_eq!(second, first + 1);

        assert!(registry.remove(first).await.is_some());
        assert!(!registry.contains(first).await);
        assert!(registry.remove(first).await.is_none());
        assert!(registry.contains(second).await);
        assert_eq!(registry.len().await, 1);
    }
}