use std::{
    collections::BTreeMap,
    sync::{
        atomic::{AtomicUsize, Ordering},
        Arc, OnceLock,
    },
};

use anyhow::{anyhow, Result};
use serde::Serialize;
use tokio::{sync::RwLock, task::JoinHandle};
12
13pub struct RegistryItem {
14 pub handle: JoinHandle<Result<()>>,
15 pub expected_len: usize,
16 pub progress: Arc<AtomicUsize>,
17 pub kind: String,
18}
19
20#[derive(Debug, Clone, Serialize)]
21pub struct TaskDescriptor {
22 pub len: usize,
23 pub progress: usize,
24 pub kind: String,
25 pub id: u64,
26}
27type RegistryCore = BTreeMap<u64, RegistryItem>;
28
29struct Registry(RegistryCore);
30
31impl Registry {
32 pub(super) fn new() -> Self {
33 Registry(BTreeMap::new())
34 }
35
36 pub(super) fn get_free_id(&self) -> u64 {
37 let mut values = self.0.iter().map(|(id, _)| *id).collect::<Vec<u64>>();
38
39 values.sort_by(|a, b| b.cmp(a));
40 if !values.is_empty() {
41 values[0] + 1
42 } else {
43 1
44 }
45 }
46
47 pub(super) fn push(&mut self, item: RegistryItem) -> u64 {
48 let id = self.get_free_id();
49 self.0.insert(id, item);
50 id
51 }
52
53 pub(super) fn ask(&self, id: u64) -> Option<(usize, usize, String)> {
54 self.0.get(&id).map(|task| {
55 (
56 task.progress.load(Ordering::SeqCst),
57 task.expected_len,
58 task.kind.to_string(),
59 )
60 })
61 }
62
63 pub(super) fn ask_percent(&self, id: u64) -> usize {
64 self.0
65 .get(&id)
66 .map(|task| task.progress.load(Ordering::SeqCst) * 100usize / task.expected_len)
67 .unwrap_or(0usize)
68 }
69
70 pub(super) fn ask_percent_float(&self, id: u64) -> f64 {
71 self.0
72 .get(&id)
73 .map(|task| {
74 let float_size = task.expected_len as f64;
75 let float_progress = task.progress.load(Ordering::SeqCst) as f64;
76
77 float_progress * 100f64 / float_size
78 })
79 .unwrap_or(0f64)
80 }
81
82 pub(super) fn list_all(&self) -> Vec<TaskDescriptor> {
83 self.0
84 .iter()
85 .map(|(id, item)| TaskDescriptor {
86 id: *id,
87 len: item.expected_len,
88 progress: item.progress.load(Ordering::SeqCst),
89 kind: item.kind.to_string(),
90 })
91 .collect()
92 }
93
94 pub(super) fn update(&self, id: u64, progress: usize) {
95 if let Some(task) = self.0.get(&id) {
96 task.progress.store(progress, Ordering::SeqCst);
97 }
98 }
99
100 pub(super) fn cancel(&mut self, id: u64) {
101 if let Some(task) = self.0.get(&id) {
102 task.handle.abort();
103 self.0.remove(&id);
104 }
105 }
106
107 pub(super) fn remove(&mut self, id: u64) {
108 self.0.remove(&id);
109 }
110
111 pub(super) fn count(&self) -> usize {
112 self.0.len()
113 }
114}
115
116static mut REGISTRY: Option<RwLock<Registry>> = None;
117
118pub(crate) async fn push(item: RegistryItem) -> Result<u64> {
119 if unsafe { REGISTRY.is_none() } {
120 unsafe {
121 REGISTRY = Some(RwLock::new(Registry::new()));
122 }
123 }
124
125 if let Some(lock) = unsafe { ®ISTRY } {
126 let mut reg = lock.write().await;
127
128 Ok(reg.push(item))
129 } else {
130 Err(anyhow!("Uninitialized Registry"))
131 }
132}
133
134pub(crate) async fn ask(id: u64) -> Result<Option<(usize, usize, String)>> {
135 if let Some(lock) = unsafe { ®ISTRY } {
136 let reg = lock.read().await;
137
138 Ok(reg.ask(id))
139 } else {
140 Err(anyhow!("Uninitialized Registry"))
141 }
142}
143
144pub(crate) async fn ask_percent(id: u64) -> Result<usize> {
145 if let Some(lock) = unsafe { ®ISTRY } {
146 let req = lock.read().await;
147
148 Ok(req.ask_percent(id))
149 } else {
150 Err(anyhow!("Uninitialized Registry"))
151 }
152}
153
154pub(crate) async fn ask_percent_float(id: u64) -> Result<f64> {
155 if let Some(lock) = unsafe { ®ISTRY } {
156 let req = lock.read().await;
157
158 Ok(req.ask_percent_float(id))
159 } else {
160 Err(anyhow!("Uninitialized Registry"))
161 }
162}
163
164pub(crate) async fn list_all() -> Vec<TaskDescriptor> {
165 if let Some(lock) = unsafe { ®ISTRY } {
166 let reg = lock.read().await;
167
168 reg.list_all()
169 } else {
170 vec![]
171 }
172}
173
174pub(crate) async fn cancel(id: u64) -> Result<()> {
175 if let Some(lock) = unsafe { ®ISTRY } {
176 let mut reg = lock.write().await;
177 reg.cancel(id);
178 Ok(())
179 } else {
180 Err(anyhow!("Uninitialized Registry"))
181 }
182}
183
184pub(crate) async fn update(id: u64, progress: usize) -> Result<()> {
185 if let Some(lock) = unsafe { ®ISTRY } {
186 let req = lock.read().await;
187 req.update(id, progress);
188 Ok(())
189 } else {
190 Err(anyhow!("Uninitialized Registry"))
191 }
192}
193
194pub(crate) async fn remove(id: u64) -> Result<()> {
195 if let Some(lock) = unsafe { ®ISTRY } {
196 let mut reg = lock.write().await;
197 reg.remove(id);
198 Ok(())
199 } else {
200 Err(anyhow!("Uninitialized Registry"))
201 }
202}
203
204pub(crate) async fn get_free_id() -> u64 {
205 if let Some(lock) = unsafe { ®ISTRY } {
206 let req = lock.write().await;
207 req.get_free_id()
208 } else {
209 1
210 }
211}
212
213pub(crate) async fn count() -> usize {
214 if let Some(lock) = unsafe { ®ISTRY } {
215 let reg = lock.read().await;
216 reg.count()
217 } else {
218 0
219 }
220}
221
#[cfg(test)]
mod registry_tests {
    use std::sync::{atomic::AtomicUsize, Arc};

    use tokio::time::{Duration, Instant};

    use super::{Registry, RegistryItem};

    /// Builds an entry whose task sleeps for one second, so it is still running
    /// while the test inspects the registry.
    fn sleeping_item(expected_len: usize, progress: usize, kind: &str) -> RegistryItem {
        RegistryItem {
            handle: tokio::spawn(async move {
                tokio::time::sleep(Duration::from_secs(1)).await;
                Ok(())
            }),
            expected_len,
            progress: Arc::new(AtomicUsize::new(progress)),
            kind: kind.to_string(),
        }
    }

    /// Builds an entry whose task polls in 10 ms steps for about one second, so
    /// it can be aborted at an await point.
    fn polling_item() -> RegistryItem {
        RegistryItem {
            handle: tokio::spawn(async move {
                let now = Instant::now();
                loop {
                    tokio::time::sleep(Duration::from_millis(10)).await;
                    if now.elapsed() >= Duration::from_secs(1) {
                        break;
                    }
                }
                Ok(())
            }),
            expected_len: 1,
            progress: Arc::new(AtomicUsize::new(0)),
            kind: "TEST".to_string(),
        }
    }

    #[tokio::test]
    async fn get_free_id_test() {
        let mut reg = Registry::new();

        assert_eq!(reg.get_free_id(), 1);
        reg.push(sleeping_item(1, 0, "TEST"));
        assert_eq!(reg.get_free_id(), 2);
    }

    #[tokio::test]
    async fn push_test() {
        let mut reg = Registry::new();
        reg.push(sleeping_item(1, 0, "TEST"));
        reg.push(sleeping_item(1, 0, "TEST"));

        assert_eq!(reg.count(), 2);
    }

    #[tokio::test]
    async fn ask_test() {
        let mut reg = Registry::new();
        assert!(reg.ask(1).is_none());

        let id = reg.push(sleeping_item(1, 0, "TEST"));
        let (progress, len, kind) = reg.ask(id).expect("pushed task must be registered");

        assert_eq!(progress, 0);
        assert_eq!(len, 1);
        assert_eq!(kind, "TEST");
    }

    #[tokio::test]
    async fn ask_percent_test() {
        let mut reg = Registry::new();
        let id = reg.push(sleeping_item(2, 1, "TEST"));

        assert_eq!(reg.ask_percent(id), 50usize);
    }

    #[tokio::test]
    async fn ask_percent_float_test() {
        let mut reg = Registry::new();
        let id = reg.push(sleeping_item(2, 1, "TEST"));

        // 1 * 100 / 2 is exactly representable, so exact comparison is safe.
        assert_eq!(reg.ask_percent_float(id), 50f64);
    }

    #[tokio::test]
    async fn list_all_test() {
        let mut reg = Registry::new();
        reg.push(sleeping_item(1, 0, "TEST"));
        reg.push(sleeping_item(2, 1, "TEST2"));

        let list = reg.list_all();
        assert_eq!(list.len(), 2usize);

        assert_eq!(list[0].id, 1);
        assert_eq!(list[0].len, 1);
        assert_eq!(list[0].progress, 0);
        assert_eq!(list[0].kind, "TEST");

        assert_eq!(list[1].id, 2);
        assert_eq!(list[1].len, 2);
        assert_eq!(list[1].progress, 1);
        assert_eq!(list[1].kind, "TEST2");
    }

    #[tokio::test]
    async fn update_test() {
        let mut reg = Registry::new();
        let id = reg.push(sleeping_item(1, 0, "TEST"));

        reg.update(id, 1);
        let (progress, _, _) = reg.ask(id).expect("pushed task must be registered");

        assert_eq!(progress, 1);
    }

    #[tokio::test]
    async fn cancel_test() {
        let mut reg = Registry::new();
        let id = reg.push(polling_item());

        reg.cancel(id);
        assert_eq!(reg.count(), 0);
    }

    #[tokio::test]
    async fn remove_test() {
        let mut reg = Registry::new();
        let id = reg.push(polling_item());

        assert_eq!(reg.count(), 1);
        reg.remove(id);
        assert_eq!(reg.count(), 0);
    }
}