treg/
registry.rs

use std::{
    collections::BTreeMap,
    sync::{
        atomic::{AtomicUsize, Ordering},
        Arc, OnceLock,
    },
};

use anyhow::{anyhow, Result};
use serde::Serialize;
use tokio::{sync::RwLock, task::JoinHandle};
12
13pub struct RegistryItem {
14    pub handle: JoinHandle<Result<()>>,
15    pub expected_len: usize,
16    pub progress: Arc<AtomicUsize>,
17    pub kind: String,
18}
19
20#[derive(Debug, Clone, Serialize)]
21pub struct TaskDescriptor {
22    pub len: usize,
23    pub progress: usize,
24    pub kind: String,
25    pub id: u64,
26}
27type RegistryCore = BTreeMap<u64, RegistryItem>;
28
29struct Registry(RegistryCore);
30
31impl Registry {
32    pub(super) fn new() -> Self {
33        Registry(BTreeMap::new())
34    }
35
36    pub(super) fn get_free_id(&self) -> u64 {
37        let mut values = self.0.iter().map(|(id, _)| *id).collect::<Vec<u64>>();
38
39        values.sort_by(|a, b| b.cmp(a));
40        if !values.is_empty() {
41            values[0] + 1
42        } else {
43            1
44        }
45    }
46
47    pub(super) fn push(&mut self, item: RegistryItem) -> u64 {
48        let id = self.get_free_id();
49        self.0.insert(id, item);
50        id
51    }
52
53    pub(super) fn ask(&self, id: u64) -> Option<(usize, usize, String)> {
54        self.0.get(&id).map(|task| {
55            (
56                task.progress.load(Ordering::SeqCst),
57                task.expected_len,
58                task.kind.to_string(),
59            )
60        })
61    }
62
63    pub(super) fn ask_percent(&self, id: u64) -> usize {
64        self.0
65            .get(&id)
66            .map(|task| task.progress.load(Ordering::SeqCst) * 100usize / task.expected_len)
67            .unwrap_or(0usize)
68    }
69
70    pub(super) fn ask_percent_float(&self, id: u64) -> f64 {
71        self.0
72            .get(&id)
73            .map(|task| {
74                let float_size = task.expected_len as f64;
75                let float_progress = task.progress.load(Ordering::SeqCst) as f64;
76
77                float_progress * 100f64 / float_size
78            })
79            .unwrap_or(0f64)
80    }
81
82    pub(super) fn list_all(&self) -> Vec<TaskDescriptor> {
83        self.0
84            .iter()
85            .map(|(id, item)| TaskDescriptor {
86                id: *id,
87                len: item.expected_len,
88                progress: item.progress.load(Ordering::SeqCst),
89                kind: item.kind.to_string(),
90            })
91            .collect()
92    }
93
94    pub(super) fn update(&self, id: u64, progress: usize) {
95        if let Some(task) = self.0.get(&id) {
96            task.progress.store(progress, Ordering::SeqCst);
97        }
98    }
99
100    pub(super) fn cancel(&mut self, id: u64) {
101        if let Some(task) = self.0.get(&id) {
102            task.handle.abort();
103            self.0.remove(&id);
104        }
105    }
106
107    pub(super) fn remove(&mut self, id: u64) {
108        self.0.remove(&id);
109    }
110
111    pub(super) fn count(&self) -> usize {
112        self.0.len()
113    }
114}
115
116static mut REGISTRY: Option<RwLock<Registry>> = None;
117
118pub(crate) async fn push(item: RegistryItem) -> Result<u64> {
119    if unsafe { REGISTRY.is_none() } {
120        unsafe {
121            REGISTRY = Some(RwLock::new(Registry::new()));
122        }
123    }
124
125    if let Some(lock) = unsafe { &REGISTRY } {
126        let mut reg = lock.write().await;
127
128        Ok(reg.push(item))
129    } else {
130        Err(anyhow!("Uninitialized Registry"))
131    }
132}
133
134pub(crate) async fn ask(id: u64) -> Result<Option<(usize, usize, String)>> {
135    if let Some(lock) = unsafe { &REGISTRY } {
136        let reg = lock.read().await;
137
138        Ok(reg.ask(id))
139    } else {
140        Err(anyhow!("Uninitialized Registry"))
141    }
142}
143
144pub(crate) async fn ask_percent(id: u64) -> Result<usize> {
145    if let Some(lock) = unsafe { &REGISTRY } {
146        let req = lock.read().await;
147
148        Ok(req.ask_percent(id))
149    } else {
150        Err(anyhow!("Uninitialized Registry"))
151    }
152}
153
154pub(crate) async fn ask_percent_float(id: u64) -> Result<f64> {
155    if let Some(lock) = unsafe { &REGISTRY } {
156        let req = lock.read().await;
157
158        Ok(req.ask_percent_float(id))
159    } else {
160        Err(anyhow!("Uninitialized Registry"))
161    }
162}
163
164pub(crate) async fn list_all() -> Vec<TaskDescriptor> {
165    if let Some(lock) = unsafe { &REGISTRY } {
166        let reg = lock.read().await;
167
168        reg.list_all()
169    } else {
170        vec![]
171    }
172}
173
174pub(crate) async fn cancel(id: u64) -> Result<()> {
175    if let Some(lock) = unsafe { &REGISTRY } {
176        let mut reg = lock.write().await;
177        reg.cancel(id);
178        Ok(())
179    } else {
180        Err(anyhow!("Uninitialized Registry"))
181    }
182}
183
184pub(crate) async fn update(id: u64, progress: usize) -> Result<()> {
185    if let Some(lock) = unsafe { &REGISTRY } {
186        let req = lock.read().await;
187        req.update(id, progress);
188        Ok(())
189    } else {
190        Err(anyhow!("Uninitialized Registry"))
191    }
192}
193
194pub(crate) async fn remove(id: u64) -> Result<()> {
195    if let Some(lock) = unsafe { &REGISTRY } {
196        let mut reg = lock.write().await;
197        reg.remove(id);
198        Ok(())
199    } else {
200        Err(anyhow!("Uninitialized Registry"))
201    }
202}
203
204pub(crate) async fn get_free_id() -> u64 {
205    if let Some(lock) = unsafe { &REGISTRY } {
206        let req = lock.write().await;
207        req.get_free_id()
208    } else {
209        1
210    }
211}
212
213pub(crate) async fn count() -> usize {
214    if let Some(lock) = unsafe { &REGISTRY } {
215        let reg = lock.read().await;
216        reg.count()
217    } else {
218        0
219    }
220}
221
#[cfg(test)]
mod registry_tests {
    use std::sync::{atomic::AtomicUsize, Arc};

    use tokio::time::{Duration, Instant};

    use super::{Registry, RegistryItem};

    /// Builds an entry whose task sleeps for one second, so its handle is still
    /// live while the test inspects the registry.
    fn sleeping_item(expected_len: usize, progress: usize, kind: &str) -> RegistryItem {
        RegistryItem {
            handle: tokio::spawn(async {
                tokio::time::sleep(Duration::from_secs(1)).await;
                Ok(())
            }),
            expected_len,
            progress: Arc::new(AtomicUsize::new(progress)),
            kind: kind.to_string(),
        }
    }

    /// Builds an entry whose task wakes every 10 ms until one second has passed,
    /// giving `abort` frequent await points at which to stop it.
    fn polling_item() -> RegistryItem {
        RegistryItem {
            handle: tokio::spawn(async {
                let started = Instant::now();
                loop {
                    tokio::time::sleep(Duration::from_millis(10)).await;
                    if started.elapsed() >= Duration::from_secs(1) {
                        break;
                    }
                }
                Ok(())
            }),
            expected_len: 1,
            progress: Arc::new(AtomicUsize::new(0)),
            kind: "TEST".to_string(),
        }
    }

    #[tokio::test]
    async fn get_free_id_test() {
        let mut reg = Registry::new();

        assert_eq!(reg.get_free_id(), 1);
        reg.push(sleeping_item(1, 0, "TEST"));
        assert_eq!(reg.get_free_id(), 2);
    }

    #[tokio::test]
    async fn push_test() {
        let mut reg = Registry::new();
        reg.push(sleeping_item(1, 0, "TEST"));
        reg.push(sleeping_item(1, 0, "TEST"));

        assert_eq!(reg.count(), 2);
    }

    #[tokio::test]
    async fn ask_test() {
        let mut reg = Registry::new();
        assert!(reg.ask(1).is_none());

        let id = reg.push(sleeping_item(1, 0, "TEST"));
        let (progress, len, kind) = reg.ask(id).expect("pushed task must be registered");

        assert_eq!(progress, 0);
        assert_eq!(len, 1);
        assert_eq!(kind, "TEST");
    }

    #[tokio::test]
    async fn ask_percent_test() {
        let mut reg = Registry::new();
        let id = reg.push(sleeping_item(2, 1, "TEST"));

        assert_eq!(reg.ask_percent(id), 50usize);
    }

    #[tokio::test]
    async fn ask_percent_float_test() {
        let mut reg = Registry::new();
        let id = reg.push(sleeping_item(2, 1, "TEST"));

        assert_eq!(reg.ask_percent_float(id), 50f64);
    }

    #[tokio::test]
    async fn list_all_test() {
        let mut reg = Registry::new();
        reg.push(sleeping_item(1, 0, "TEST"));
        reg.push(sleeping_item(2, 1, "TEST2"));

        let list = reg.list_all();
        assert_eq!(list.len(), 2usize);
        assert_eq!(list[0].id, 1);
        assert_eq!(list[0].len, 1);
        assert_eq!(list[0].progress, 0);
        assert_eq!(list[0].kind, "TEST");
        assert_eq!(list[1].id, 2);
        assert_eq!(list[1].len, 2);
        assert_eq!(list[1].progress, 1);
        assert_eq!(list[1].kind, "TEST2");
    }

    #[tokio::test]
    async fn update_test() {
        let mut reg = Registry::new();
        let id = reg.push(sleeping_item(1, 0, "TEST"));

        reg.update(id, 1);
        let (progress, _, _) = reg.ask(id).expect("pushed task must be registered");

        assert_eq!(progress, 1);
    }

    #[tokio::test]
    async fn cancel_test() {
        let mut reg = Registry::new();
        let id = reg.push(polling_item());

        reg.cancel(id);
        assert_eq!(reg.count(), 0);
    }

    #[tokio::test]
    async fn remove_test() {
        let mut reg = Registry::new();
        let id = reg.push(polling_item());

        assert_eq!(reg.count(), 1);
        reg.remove(id);
        assert_eq!(reg.count(), 0);
    }

    #[test]
    fn unknown_id_test() {
        let mut reg = Registry::new();

        // Queries on ids that were never pushed fall back to neutral values, and
        // mutations on them are silent no-ops.
        assert!(reg.ask(7).is_none());
        assert_eq!(reg.ask_percent(7), 0);
        assert_eq!(reg.ask_percent_float(7), 0f64);
        reg.update(7, 3);
        reg.cancel(7);
        reg.remove(7);
        assert_eq!(reg.count(), 0);
    }
}