// lunar_lib/progress.rs

1use std::{
2    panic::AssertUnwindSafe,
3    sync::{
4        Arc, LazyLock,
5        atomic::{AtomicUsize, Ordering},
6        mpsc::{Sender, channel},
7    },
8};
9
10use crate::error;
11
12/// A thread safe, thread stable progress bar, abstracted for any frontend
13#[derive(Clone)]
14pub struct ProgressBar {
15    handle: Arc<ProgressHandle>,
16}
17
18struct ProgressHandle {
19    value: AtomicUsize,
20    max: AtomicUsize,
21    renderer: Arc<dyn ProgressRenderer>,
22}
23
24impl ProgressBar {
25    /// Runs a custom closure for this progress bar.
26    ///
27    /// # Panics
28    ///
29    /// Panics if the internal `GLOBAL_SENDER` thread cannot be communicated with. This usually signals critical failure somewhere
30    ///
31    /// # Notes
32    ///
33    /// If the input closure panics, it will not kill the thread, it will instead log the error and return
34    pub fn custom<F>(&self, f: F)
35    where
36        F: FnOnce() + Send + 'static,
37    {
38        GLOBAL_SENDER
39            .send(ProgressEvent::Custom {
40                _handle: self.handle.clone(),
41                f: Box::new(f),
42            })
43            .expect("Progress handler thread died.");
44    }
45
46    /// Waits until the flush is processed, allowing you to wait until all previous events have been processed
47    ///
48    /// # Panics
49    ///
50    /// Panics if the internal `GLOBAL_SENDER` thread cannot be communicated with. This usually signals critical failure somewhere
51    pub fn flush(&self) {
52        let (tx, rx) = channel();
53        GLOBAL_SENDER
54            .send(ProgressEvent::Flush {
55                _handle: self.handle.clone(),
56                done: tx,
57            })
58            .expect("Progress handler thread died.");
59        let _ = rx.recv();
60    }
61
62    /// Increments the value of a progress bar by 1
63    ///
64    /// # Panics
65    ///
66    /// Panics if the internal `GLOBAL_SENDER` thread cannot be communicated with. This usually signals critical failure somewhere
67    pub fn increment(&self) {
68        GLOBAL_SENDER
69            .send(ProgressEvent::Increment {
70                handle: self.handle.clone(),
71            })
72            .expect("Progress handler thread died.");
73    }
74
75    /// Increments the value of the progress bar by 1 and runs a pre-write and post-write closure
76    ///
77    /// The `pre` closure, if some, will run before the value is incremented
78    /// The `post` closure, if some, will run after the value is incremented, and before [`ProgressRenderer::on_update()`] method is run
79    ///
80    /// # Panics
81    ///
82    /// Panics if the internal `GLOBAL_SENDER` thread cannot be communicated with. This usually signals critical failure somewhere
83    pub fn increment_and_run<P, Q>(&self, pre: Option<P>, post: Option<Q>)
84    where
85        P: FnOnce() + Send + 'static,
86        Q: FnOnce() + Send + 'static,
87    {
88        GLOBAL_SENDER
89            .send(ProgressEvent::IncrementAndRun {
90                handle: self.handle.clone(),
91                pre: pre.map(box_dyn),
92                post: post.map(box_dyn),
93            })
94            .expect("Progress handler thread died.");
95    }
96
97    /// Sets the max-value of the progress bar
98    ///
99    /// # Panics
100    ///
101    /// Panics if the internal `GLOBAL_SENDER` thread cannot be communicated with. This usually signals critical failure somewhere
102    pub fn set_max(&self, max: usize) {
103        GLOBAL_SENDER
104            .send(ProgressEvent::SetMax {
105                handle: self.handle.clone(),
106                max,
107            })
108            .expect("Progress handler thread died.");
109    }
110
111    /// Sets the max-value of the progress bar and runs a pre-write and post-write closure
112    ///
113    /// The `pre` closure, if some, will run before the value is incremented
114    /// The `post` closure, if some, will run after the value is incremented, and before [`ProgressRenderer::on_update()`] method is run
115    ///
116    /// # Panics
117    ///
118    /// Panics if the internal `GLOBAL_SENDER` thread cannot be communicated with. This usually signals critical failure somewhere
119    ///
120    /// # Notes
121    ///
122    /// If the any of the input closures panic, it will not kill the thread, it will instead log the error and ignore it
123    pub fn set_max_and_run<P, Q>(&self, max: usize, pre: Option<P>, post: Option<Q>)
124    where
125        P: FnOnce() + Send + 'static,
126        Q: FnOnce() + Send + 'static,
127    {
128        GLOBAL_SENDER
129            .send(ProgressEvent::SetMaxAndRun {
130                handle: self.handle.clone(),
131                max,
132                pre: pre.map(box_dyn),
133                post: post.map(box_dyn),
134            })
135            .expect("Progress handler thread died.");
136    }
137
138    /// Sets the value of the progress bar
139    ///
140    /// # Panics
141    ///
142    /// Panics if the internal `GLOBAL_SENDER` thread cannot be communicated with. This usually signals critical failure somewhere
143    pub fn set_value(&self, value: usize) {
144        GLOBAL_SENDER
145            .send(ProgressEvent::SetValue {
146                handle: self.handle.clone(),
147                value,
148            })
149            .expect("Progress handler thread died.");
150    }
151
152    /// Sets the value of the progress bar and runs a pre-write and post-write closure
153    ///
154    /// The `pre` closure, if some, will run before the value is incremented
155    /// The `post` closure, if some, will run after the value is incremented, and before [`ProgressRenderer::on_update()`] method is run
156    ///
157    /// # Panics
158    ///
159    /// Panics if the internal `GLOBAL_SENDER` thread cannot be communicated with. This usually signals critical failure somewhere
160    ///
161    /// # Notes
162    ///
163    /// If the any of the input closures panic, it will not kill the thread, it will instead log the error and ignore it
164    pub fn set_value_and_run<P, Q>(&self, value: usize, pre: Option<P>, post: Option<Q>)
165    where
166        P: FnOnce() + Send + 'static,
167        Q: FnOnce() + Send + 'static,
168    {
169        GLOBAL_SENDER
170            .send(ProgressEvent::SetValueAndRun {
171                handle: self.handle.clone(),
172                value,
173                pre: pre.map(box_dyn),
174                post: post.map(box_dyn),
175            })
176            .expect("Progress handler thread died.");
177    }
178
179    /// Creates a new progress bar instance using the renderer
180    ///
181    /// This progress bar can safely be shared and updated between threads
182    pub fn start(value: usize, max: usize, renderer: Arc<dyn ProgressRenderer>) -> Self {
183        let handle = Arc::new(ProgressHandle {
184            value: AtomicUsize::new(value),
185            max: AtomicUsize::new(max),
186            renderer: renderer.clone(),
187        });
188
189        renderer.on_start(value, max);
190
191        Self { handle }
192    }
193}
194
/// Defines how a [`ProgressBar`] is displayed, abstracted over any frontend.
///
/// Implementations must be `Send + Sync` because callbacks arrive from more than one thread.
/// `on_start` runs on the thread that creates the bar. `on_update` runs on the shared background
/// progress thread.
pub trait ProgressRenderer: Send + Sync {
    /// Called once by [`ProgressBar::start`] with the initial `value` and `max`.
    ///
    /// It runs synchronously on the thread that created the bar.
    fn on_start(&self, value: usize, max: usize);

    /// Called on the shared background progress thread after every value or max change.
    ///
    /// It receives the bar's current `value` and `max`. Calls for one bar arrive in the order the
    /// changes were queued. Custom closures and flushes do not trigger it.
    fn on_update(&self, value: usize, max: usize);

    /// Called once, when the last reference to the bar's shared state is released.
    ///
    /// That is either the drop of the final [`ProgressBar`] clone, or, if updates were still queued
    /// at that moment, the point where the background thread finishes the last of them. It can
    /// therefore run on either the dropping thread or the background thread. It always runs after
    /// every queued update for this bar.
    ///
    /// # Notes
    ///
    /// [`Self::on_update()`] is NOT called automatically before finish. If you want a final
    /// update, render it from this method.
    fn on_finish(&self);
}
210
211impl Drop for ProgressHandle {
212    fn drop(&mut self) {
213        self.renderer.on_finish();
214    }
215}
216
// ==============================================================
// Background worker: queued progress events and the global thread that applies them
// ==============================================================
220
221enum ProgressEvent {
222    Increment {
223        handle: Arc<ProgressHandle>,
224    },
225    IncrementAndRun {
226        handle: Arc<ProgressHandle>,
227        pre: Option<Box<dyn FnOnce() + Send>>,
228        post: Option<Box<dyn FnOnce() + Send>>,
229    },
230    SetValue {
231        handle: Arc<ProgressHandle>,
232        value: usize,
233    },
234    SetValueAndRun {
235        handle: Arc<ProgressHandle>,
236        value: usize,
237        pre: Option<Box<dyn FnOnce() + Send>>,
238        post: Option<Box<dyn FnOnce() + Send>>,
239    },
240    SetMax {
241        handle: Arc<ProgressHandle>,
242        max: usize,
243    },
244    SetMaxAndRun {
245        handle: Arc<ProgressHandle>,
246        max: usize,
247        pre: Option<Box<dyn FnOnce() + Send>>,
248        post: Option<Box<dyn FnOnce() + Send>>,
249    },
250    Custom {
251        _handle: Arc<ProgressHandle>,
252        f: Box<dyn FnOnce() + Send>,
253    },
254    Flush {
255        _handle: Arc<ProgressHandle>,
256        done: std::sync::mpsc::Sender<()>,
257    },
258}
259
260static GLOBAL_SENDER: LazyLock<Sender<ProgressEvent>> = LazyLock::new(|| {
261    let (tx, rx) = channel::<ProgressEvent>();
262
263    std::thread::spawn(move || {
264        while let Ok(event) = rx.recv() {
265            match event {
266                ProgressEvent::Increment { handle } => {
267                    let value = handle.value.fetch_add(1, Ordering::SeqCst);
268                    let max = handle.max.load(Ordering::SeqCst);
269                    handle.renderer.on_update(value + 1, max);
270                }
271                ProgressEvent::IncrementAndRun { handle, pre, post } => {
272                    run_optional_log_panic(pre);
273                    let value = handle.value.fetch_add(1, Ordering::SeqCst);
274                    run_optional_log_panic(post);
275                    let max = handle.max.load(Ordering::SeqCst);
276                    handle.renderer.on_update(value + 1, max);
277                }
278                ProgressEvent::SetValue { handle, value } => {
279                    handle.value.store(value, Ordering::SeqCst);
280                    let max = handle.max.load(Ordering::SeqCst);
281                    handle.renderer.on_update(value, max);
282                }
283                ProgressEvent::SetValueAndRun {
284                    handle,
285                    value,
286                    pre,
287                    post,
288                } => {
289                    handle.value.store(value, Ordering::SeqCst);
290                    run_optional_log_panic(pre);
291                    let max = handle.max.load(Ordering::SeqCst);
292                    run_optional_log_panic(post);
293                    handle.renderer.on_update(value, max);
294                }
295                ProgressEvent::SetMax { handle, max } => {
296                    let value = handle.value.load(Ordering::SeqCst);
297                    handle.max.store(max, Ordering::SeqCst);
298                    handle.renderer.on_update(value, max);
299                }
300                ProgressEvent::SetMaxAndRun {
301                    handle,
302                    max,
303                    pre,
304                    post,
305                } => {
306                    let value = handle.value.load(Ordering::SeqCst);
307                    run_optional_log_panic(pre);
308                    handle.max.store(max, Ordering::SeqCst);
309                    run_optional_log_panic(post);
310                    handle.renderer.on_update(value, max);
311                }
312                ProgressEvent::Custom { _handle, f } => run_or_log_panic(f),
313                ProgressEvent::Flush { _handle, done } => done.send(()).unwrap(),
314            }
315        }
316    });
317
318    tx
319});
320
/// Type-erases a closure into the boxed form stored in [`ProgressEvent`] variants.
///
/// Used with `Option::map` so an `Option<F>` becomes an `Option<Box<dyn FnOnce() + Send>>`.
fn box_dyn<F>(f: F) -> Box<dyn FnOnce() + Send + 'static>
where
    F: FnOnce() + Send + 'static,
{
    Box::new(f) as Box<dyn FnOnce() + Send + 'static>
}
324
325#[inline(always)]
326fn run_or_log_panic<F: FnOnce() + Send + 'static>(f: F) {
327    if let Err(err) = std::panic::catch_unwind(AssertUnwindSafe(f)) {
328        error!("ProgressBar closure panicked: {err:?}")
329    }
330}
331
332#[inline(always)]
333fn run_optional_log_panic<F: FnOnce() + Send + 'static>(f: Option<F>) {
334    if let Some(f) = f {
335        run_or_log_panic(f);
336    }
337}
338
#[cfg(test)]
mod tests {
    use std::sync::{
        Arc, Mutex,
        atomic::{AtomicBool, Ordering},
    };

    use super::{ProgressBar, ProgressRenderer};

    /// Renderer that records every callback as a line of text, for assertions.
    struct TestRenderer {
        log: Mutex<Vec<String>>,
    }

    impl TestRenderer {
        fn new() -> Self {
            Self {
                log: Mutex::new(Vec::new()),
            }
        }

        /// Snapshot of everything recorded so far.
        fn log(&self) -> Vec<String> {
            self.log.lock().unwrap().clone()
        }

        fn push_log(&self, log: String) {
            self.log.lock().unwrap().push(log);
        }
    }

    impl ProgressRenderer for TestRenderer {
        fn on_start(&self, value: usize, max: usize) {
            self.push_log(format!("Start: [{value} / {max}]"));
        }

        fn on_update(&self, value: usize, max: usize) {
            self.push_log(format!("Update: [{value} / {max}]"));
        }

        fn on_finish(&self) {
            self.push_log("Finish".to_string());
        }
    }

    #[test]
    fn progress_increment() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::start(0, 3, renderer.clone());

        bar.increment();
        bar.increment();
        bar.flush();

        assert_eq!(
            renderer.log(),
            ["Start: [0 / 3]", "Update: [1 / 3]", "Update: [2 / 3]"]
        );
    }

    #[test]
    fn progress_set_value_and_set_max() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::start(0, 10, renderer.clone());

        bar.set_value(4);
        bar.set_max(8);
        bar.flush();

        assert_eq!(
            renderer.log(),
            ["Start: [0 / 10]", "Update: [4 / 10]", "Update: [4 / 8]"]
        );
    }

    #[test]
    fn progress_finishes_on_drop() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::start(0, 0, renderer.clone());

        drop(bar);

        assert_eq!(renderer.log(), ["Start: [0 / 0]", "Finish"]);
    }

    #[test]
    fn progress_finishes_after_queued_updates() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::start(0, 2, renderer.clone());

        bar.increment();
        drop(bar);

        // All bars share one FIFO worker. Flushing an unrelated bar therefore also waits for the
        // increment queued above and for the release of the handle it was holding.
        let other = ProgressBar::start(0, 1, Arc::new(TestRenderer::new()));
        other.flush();

        assert_eq!(
            renderer.log(),
            ["Start: [0 / 2]", "Update: [1 / 2]", "Finish"]
        );
    }

    #[test]
    fn progress_custom_runs_without_rendering() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::start(0, 1, renderer.clone());

        let ran = Arc::new(AtomicBool::new(false));
        let flag = ran.clone();
        bar.custom(move || flag.store(true, Ordering::SeqCst));
        bar.flush();

        assert!(ran.load(Ordering::SeqCst));
        assert_eq!(renderer.log(), ["Start: [0 / 1]"]);
    }

    #[test]
    fn progress_closures_run_on_increment() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::start(0, 1, renderer.clone());

        let pre_ran = Arc::new(AtomicBool::new(false));
        let post_ran = Arc::new(AtomicBool::new(false));

        let pre_flag = pre_ran.clone();
        let post_flag = post_ran.clone();

        bar.increment_and_run(
            Some(move || pre_flag.store(true, Ordering::SeqCst)),
            Some(move || post_flag.store(true, Ordering::SeqCst)),
        );
        bar.flush();

        assert!(pre_ran.load(Ordering::SeqCst));
        assert!(post_ran.load(Ordering::SeqCst));
        assert_eq!(renderer.log(), ["Start: [0 / 1]", "Update: [1 / 1]"]);
    }

    #[test]
    fn progress_closures_run_on_set_value() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::start(0, 5, renderer.clone());

        let pre_ran = Arc::new(AtomicBool::new(false));
        let post_ran = Arc::new(AtomicBool::new(false));

        let pre_flag = pre_ran.clone();
        let post_flag = post_ran.clone();

        bar.set_value_and_run(
            5,
            Some(move || pre_flag.store(true, Ordering::SeqCst)),
            Some(move || post_flag.store(true, Ordering::SeqCst)),
        );
        bar.flush();

        assert!(pre_ran.load(Ordering::SeqCst));
        assert!(post_ran.load(Ordering::SeqCst));
        assert_eq!(renderer.log(), ["Start: [0 / 5]", "Update: [5 / 5]"]);
    }

    #[test]
    fn progress_closures_run_on_set_max() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::start(0, 3, renderer.clone());

        let pre_ran = Arc::new(AtomicBool::new(false));
        let post_ran = Arc::new(AtomicBool::new(false));

        let pre_flag = pre_ran.clone();
        let post_flag = post_ran.clone();

        bar.set_max_and_run(
            5,
            Some(move || pre_flag.store(true, Ordering::SeqCst)),
            Some(move || post_flag.store(true, Ordering::SeqCst)),
        );
        bar.flush();

        assert!(pre_ran.load(Ordering::SeqCst));
        assert!(post_ran.load(Ordering::SeqCst));
        assert_eq!(renderer.log(), ["Start: [0 / 3]", "Update: [0 / 5]"]);
    }

    #[test]
    fn progress_thread_safe_on_closure_panic() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::start(0, 100, renderer.clone());

        // A panicking `pre` is logged; the increment, `post` and render must still happen...
        bar.increment_and_run(Some(|| panic!()), Some(|| ()));
        // ...and the shared worker must still be alive for later events.
        bar.increment();
        bar.flush();

        assert_eq!(
            renderer.log(),
            ["Start: [0 / 100]", "Update: [1 / 100]", "Update: [2 / 100]"]
        );
    }

    #[test]
    fn progress_thread_stability() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::start(0, 100, renderer.clone());

        let handles: Vec<_> = (0..5)
            .map(|_| {
                let bar = bar.clone();
                std::thread::spawn(move || {
                    for _ in 0..20 {
                        bar.increment();
                    }
                })
            })
            .collect();

        for h in handles {
            h.join().unwrap();
        }

        bar.flush();

        let log = renderer.log();
        assert_eq!(log.len(), 101);
        assert_eq!(log[0], "Start: [0 / 100]");
        // Updates are applied one at a time on the worker, so the values are strictly sequential.
        for (i, entry) in log.iter().enumerate().skip(1) {
            assert_eq!(entry, &format!("Update: [{i} / 100]"));
        }
    }
}