Skip to main content

lunar_lib/
progress.rs

1use std::{
2    panic::AssertUnwindSafe,
3    sync::{
4        Arc,
5        atomic::{AtomicUsize, Ordering},
6        mpsc::{Sender, channel},
7    },
8};
9
10use crate::error;
11
12mod cli_progress_bar;
13pub use cli_progress_bar::*;
14
15mod null_progress_renderer;
16pub use null_progress_renderer::*;
17
18/// A thread safe, thread stable progress bar, abstracted for any frontend
19#[derive(Clone)]
20pub struct ProgressBar {
21    handle: Arc<ProgressHandle>,
22}
23
/// Shared state behind every clone of a [`ProgressBar`].
///
/// Clones of the bar and every queued [`ProgressEvent`] that acts on it each hold an `Arc`
/// to this, so it is only dropped once all bars are gone and the handling thread has
/// released every pending event.
///
/// NOTE: fields drop in declaration order after `Drop::drop` runs; keep the order stable.
struct ProgressHandle {
    // Current progress value. After construction it is only written by the handling thread,
    // and it can be read from any thread through `ProgressBar::get_value`.
    value: AtomicUsize,
    // Current maximum. After construction it is only written by the handling thread, and it
    // can be read from any thread through `ProgressBar::get_max`.
    max: AtomicUsize,
    // Frontend that is told about every change.
    renderer: Arc<dyn ProgressRenderer>,
    // Producer end of the channel drained by the handling thread.
    sender: Sender<ProgressEvent>,
}
30
31impl Drop for ProgressHandle {
32    fn drop(&mut self) {
33        if !self.renderer.__is_null() {
34            let _ = self.sender.send(ProgressEvent::Exit);
35            self.renderer.on_finish();
36        }
37    }
38}
39
40impl ProgressBar {
41    /// Runs a custom closure for this progress bar.
42    ///
43    /// # Panics
44    ///
45    /// Panics if the internal `GLOBAL_SENDER` thread cannot be communicated with. This usually signals critical failure somewhere
46    ///
47    /// # Notes
48    ///
49    /// If the input closure panics, it will not kill the thread, it will instead log the error and return
50    pub fn custom<F>(&self, f: F)
51    where
52        F: FnOnce() + Send + 'static,
53    {
54        if self.handle.renderer.__is_null() {
55            return;
56        }
57
58        self.handle
59            .sender
60            .send(ProgressEvent::Custom {
61                f: Box::new(f),
62                _handle: self.handle.clone(),
63            })
64            .expect("Progress handler thread died.");
65    }
66
67    /// Waits until the flush is processed, allowing you to wait until all previous events have been processed
68    ///
69    /// # Panics
70    ///
71    /// Panics if the internal `GLOBAL_SENDER` thread cannot be communicated with. This usually signals critical failure somewhere
72    pub fn flush(&self) {
73        if self.handle.renderer.__is_null() {
74            return;
75        }
76
77        let (tx, rx) = channel();
78        self.handle
79            .sender
80            .send(ProgressEvent::Flush {
81                done: tx,
82                _handle: self.handle.clone(),
83            })
84            .expect("Progress handler thread died.");
85        let _ = rx.recv();
86    }
87
88    /// Gets the max value of the internal `ProgressHandle`
89    ///
90    /// # Notes
91    ///
92    /// While this will return the correct max value at time of call, In multi-threaded contexts, this max may have already changed by the time it has been used
93    pub fn get_max(&self) -> usize {
94        self.handle.max.load(Ordering::SeqCst)
95    }
96
97    /// Gets the value of the internal `ProgressHandle`
98    ///
99    /// # Notes
100    ///
101    /// While this will return the correct value at time of call, In multi-threaded contexts, this value may have already changed by the time it has been used
102    pub fn get_value(&self) -> usize {
103        self.handle.value.load(Ordering::SeqCst)
104    }
105
106    /// Increments the value of a progress bar by 1
107    ///
108    /// # Panics
109    ///
110    /// Panics if the internal `GLOBAL_SENDER` thread cannot be communicated with. This usually signals critical failure somewhere
111    pub fn increment(&self) {
112        if self.handle.renderer.__is_null() {
113            return;
114        }
115
116        self.handle
117            .sender
118            .send(ProgressEvent::Increment {
119                handle: self.handle.clone(),
120            })
121            .expect("Progress handler thread died.");
122    }
123
124    /// Increments the value of the progress bar by 1 and runs a pre-write and post-write closure
125    ///
126    /// The `pre` closure, if some, will run before the value is incremented
127    /// The `post` closure, if some, will run after the value is incremented, and before [`ProgressRenderer::on_update()`] method is run
128    ///
129    /// # Panics
130    ///
131    /// Panics if the internal `GLOBAL_SENDER` thread cannot be communicated with. This usually signals critical failure somewhere
132    pub fn increment_and_run<P, Q>(&self, pre: Option<P>, post: Option<Q>)
133    where
134        P: FnOnce() + Send + 'static,
135        Q: FnOnce() + Send + 'static,
136    {
137        if self.handle.renderer.__is_null() {
138            return;
139        }
140
141        self.handle
142            .sender
143            .send(ProgressEvent::IncrementAndRun {
144                handle: self.handle.clone(),
145                pre: pre.map(box_dyn),
146                post: post.map(box_dyn),
147            })
148            .expect("Progress handler thread died.");
149    }
150
151    /// Creates a new progress bar instance using the renderer
152    ///
153    /// This progress bar can safely be shared and updated between threads
154    pub fn new(value: usize, max: usize, renderer: Arc<dyn ProgressRenderer>) -> Self {
155        let (tx, rx) = channel();
156
157        let handle = Arc::new(ProgressHandle {
158            value: AtomicUsize::new(value),
159            max: AtomicUsize::new(max),
160            renderer: renderer.clone(),
161            sender: tx,
162        });
163
164        if !renderer.__is_null() {
165            spawn_progress_thread(rx);
166            renderer.on_start(value, max);
167        }
168
169        Self { handle }
170    }
171
172    /// Sets the max-value of the progress bar
173    ///
174    /// # Panics
175    ///
176    /// Panics if the internal `GLOBAL_SENDER` thread cannot be communicated with. This usually signals critical failure somewhere
177    pub fn set_max(&self, max: usize) {
178        if self.handle.renderer.__is_null() {
179            return;
180        }
181
182        self.handle
183            .sender
184            .send(ProgressEvent::SetMax {
185                handle: self.handle.clone(),
186                max,
187            })
188            .expect("Progress handler thread died.");
189    }
190
191    /// Sets the max-value of the progress bar and runs a pre-write and post-write closure
192    ///
193    /// The `pre` closure, if some, will run before the value is incremented
194    /// The `post` closure, if some, will run after the value is incremented, and before [`ProgressRenderer::on_update()`] method is run
195    ///
196    /// # Panics
197    ///
198    /// Panics if the internal `GLOBAL_SENDER` thread cannot be communicated with. This usually signals critical failure somewhere
199    ///
200    /// # Notes
201    ///
202    /// If the any of the input closures panic, it will not kill the thread, it will instead log the error and ignore it
203    pub fn set_max_and_run<P, Q>(&self, max: usize, pre: Option<P>, post: Option<Q>)
204    where
205        P: FnOnce() + Send + 'static,
206        Q: FnOnce() + Send + 'static,
207    {
208        if self.handle.renderer.__is_null() {
209            return;
210        }
211
212        self.handle
213            .sender
214            .send(ProgressEvent::SetMaxAndRun {
215                handle: self.handle.clone(),
216                max,
217                pre: pre.map(box_dyn),
218                post: post.map(box_dyn),
219            })
220            .expect("Progress handler thread died.");
221    }
222
223    /// Sets the value of the progress bar
224    ///
225    /// # Panics
226    ///
227    /// Panics if the internal `GLOBAL_SENDER` thread cannot be communicated with. This usually signals critical failure somewhere
228    pub fn set_value(&self, value: usize) {
229        if self.handle.renderer.__is_null() {
230            return;
231        }
232
233        self.handle
234            .sender
235            .send(ProgressEvent::SetValue {
236                handle: self.handle.clone(),
237                value,
238            })
239            .expect("Progress handler thread died.");
240    }
241
242    /// Sets the value of the progress bar and runs a pre-write and post-write closure
243    ///
244    /// The `pre` closure, if some, will run before the value is incremented
245    /// The `post` closure, if some, will run after the value is incremented, and before [`ProgressRenderer::on_update()`] method is run
246    ///
247    /// # Panics
248    ///
249    /// Panics if the internal `GLOBAL_SENDER` thread cannot be communicated with. This usually signals critical failure somewhere
250    ///
251    /// # Notes
252    ///
253    /// If the any of the input closures panic, it will not kill the thread, it will instead log the error and ignore it
254    pub fn set_value_and_run<P, Q>(&self, value: usize, pre: Option<P>, post: Option<Q>)
255    where
256        P: FnOnce() + Send + 'static,
257        Q: FnOnce() + Send + 'static,
258    {
259        if self.handle.renderer.__is_null() {
260            return;
261        }
262
263        self.handle
264            .sender
265            .send(ProgressEvent::SetValueAndRun {
266                handle: self.handle.clone(),
267                value,
268                pre: pre.map(box_dyn),
269                post: post.map(box_dyn),
270            })
271            .expect("Progress handler thread died.");
272    }
273
274    /// Forces the [`ProgressRenderer`] to upate with the new values
275    ///
276    /// # Panics
277    ///
278    /// Panics if the progress bars handling thread cannot be communicated with.
279    pub fn update(&self) {
280        if self.handle.renderer.__is_null() {
281            return;
282        }
283
284        self.handle
285            .sender
286            .send(ProgressEvent::Update {
287                handle: self.handle.clone(),
288            })
289            .expect("Progress handler thread died.")
290    }
291
292    /// Checks if the the handling thread is alive by sending a check to it and seeing if it returns [`Ok`]
293    ///
294    /// # Notes
295    ///
296    /// This function will correctly check if the handling thread is alive at the time of calling, but can not guarantee the thread will stay alive after calling or that it receives the check
297    pub fn is_alive(&self) -> bool {
298        if self.handle.renderer.__is_null() {
299            return false;
300        }
301
302        self.handle.sender.send(ProgressEvent::Check).is_ok()
303    }
304
305    /// Calls the [`ProgressRenderer::set_label()`] function for the held [`ProgressRenderer`]
306    pub fn set_label(&self, label: &str) {
307        self.handle.renderer.set_label(label);
308    }
309
310    /// Tells the handling thread to call the [`ProgressRenderer::on_notify()`] function for the held [`ProgressRenderer`]
311    ///
312    /// # Panics
313    ///
314    /// Panics if the progress bars handling thread cannot be communicated with.
315    pub fn notify(&self, message: impl ToString) {
316        if self.handle.renderer.__is_null() {
317            return;
318        }
319
320        self.handle
321            .sender
322            .send(ProgressEvent::Notify {
323                handle: self.handle.clone(),
324                msg: message.to_string(),
325            })
326            .expect("Progress handler thread died.")
327    }
328}
329
/// Defines how a progress bar should be rendered.
///
/// Implementations must be `Send + Sync`. [`Self::on_start()`] and [`Self::set_label()`] run
/// on the thread that calls into the [`ProgressBar`]. [`Self::on_update()`] and
/// [`Self::on_notify()`] run on the bar's handling thread.
pub trait ProgressRenderer: Send + Sync {
    /// Called once, on the thread that creates the [`ProgressBar`] holding [`Self`], right
    /// after that bar's handling thread has been spawned.
    fn on_start(&self, value: usize, max: usize);

    /// Called from the handling thread, with the bar's current value and max, whenever the
    /// [`ProgressBar`] holding [`Self`] changes or is explicitly asked to update.
    fn on_update(&self, value: usize, max: usize);

    /// Called once the last reference to the [`ProgressBar`]'s shared state is released.
    /// By then every clone of the bar is gone and the handling thread has dropped all queued
    /// events. This may run on any thread, including the handling thread.
    ///
    /// # Notes
    ///
    /// [`Self::on_update()`] is NOT automatically called before finish. If you want a final
    /// update before the renderer is dropped, do it in your implementation of this function.
    fn on_finish(&self);

    /// Generic notification. Can be sent from anything with access to the [`ProgressBar`]
    /// that holds [`Self`]. It is delivered on the handling thread and immediately followed
    /// by a call to [`Self::on_update()`].
    ///
    /// The progress bar never sends notifications on its own; this exists for other
    /// libraries that want to pass messages to your renderer. If you do not care, make it a
    /// no-op.
    fn on_notify(&self, msg: String);

    /// Sets the label of [`Self`]. Can be called from anything with access to the
    /// [`ProgressBar`] that holds [`Self`].
    ///
    /// Unlike the other callbacks, this is invoked directly on the caller's thread rather than
    /// queued behind pending updates. The progress bar never sets a label on its own; if you
    /// do not care, make it a no-op.
    fn set_label(&self, msg: &str);

    /// If this returns `true`, the [`ProgressBar`] treats this renderer as disabled. No
    /// handling thread is spawned, and no callbacks other than [`Self::set_label()`] are
    /// invoked.
    #[doc(hidden)]
    fn __is_null(&self) -> bool {
        false
    }
}
362
363// ==============================================================
364// Thread stuff
365// ==============================================================
366
367enum ProgressEvent {
368    Check,
369    Exit,
370    Update {
371        handle: Arc<ProgressHandle>,
372    },
373    Increment {
374        handle: Arc<ProgressHandle>,
375    },
376    IncrementAndRun {
377        handle: Arc<ProgressHandle>,
378        pre: Option<Box<dyn FnOnce() + Send>>,
379        post: Option<Box<dyn FnOnce() + Send>>,
380    },
381    SetValue {
382        handle: Arc<ProgressHandle>,
383        value: usize,
384    },
385    SetValueAndRun {
386        handle: Arc<ProgressHandle>,
387        value: usize,
388        pre: Option<Box<dyn FnOnce() + Send>>,
389        post: Option<Box<dyn FnOnce() + Send>>,
390    },
391    SetMax {
392        handle: Arc<ProgressHandle>,
393        max: usize,
394    },
395    SetMaxAndRun {
396        handle: Arc<ProgressHandle>,
397        max: usize,
398        pre: Option<Box<dyn FnOnce() + Send>>,
399        post: Option<Box<dyn FnOnce() + Send>>,
400    },
401    Custom {
402        _handle: Arc<ProgressHandle>,
403        f: Box<dyn FnOnce() + Send>,
404    },
405    Flush {
406        _handle: Arc<ProgressHandle>,
407        done: std::sync::mpsc::Sender<()>,
408    },
409    Notify {
410        handle: Arc<ProgressHandle>,
411        msg: String,
412    },
413}
414
415fn spawn_progress_thread(rx: std::sync::mpsc::Receiver<ProgressEvent>) {
416    std::thread::spawn(move || {
417        while let Ok(event) = rx.recv() {
418            match event {
419                ProgressEvent::Check => {}
420                ProgressEvent::Exit => {
421                    break;
422                }
423                ProgressEvent::Increment { handle } => {
424                    let value = handle.value.fetch_add(1, Ordering::SeqCst);
425                    let max = handle.max.load(Ordering::SeqCst);
426                    handle.renderer.on_update(value + 1, max);
427                }
428                ProgressEvent::IncrementAndRun { handle, pre, post } => {
429                    run_optional_log_panic(pre);
430                    let value = handle.value.fetch_add(1, Ordering::SeqCst);
431                    run_optional_log_panic(post);
432                    let max = handle.max.load(Ordering::SeqCst);
433                    handle.renderer.on_update(value + 1, max);
434                }
435                ProgressEvent::SetValue { handle, value } => {
436                    handle.value.store(value, Ordering::SeqCst);
437                    let max = handle.max.load(Ordering::SeqCst);
438                    handle.renderer.on_update(value, max);
439                }
440                ProgressEvent::SetValueAndRun {
441                    handle,
442                    value,
443                    pre,
444                    post,
445                } => {
446                    handle.value.store(value, Ordering::SeqCst);
447                    run_optional_log_panic(pre);
448                    let max = handle.max.load(Ordering::SeqCst);
449                    run_optional_log_panic(post);
450                    handle.renderer.on_update(value, max);
451                }
452                ProgressEvent::SetMax { max, handle } => {
453                    let value = handle.value.load(Ordering::SeqCst);
454                    handle.max.store(max, Ordering::SeqCst);
455                    handle.renderer.on_update(value, max);
456                }
457                ProgressEvent::SetMaxAndRun {
458                    handle,
459                    max,
460                    pre,
461                    post,
462                } => {
463                    let value = handle.value.load(Ordering::SeqCst);
464                    run_optional_log_panic(pre);
465                    handle.max.store(max, Ordering::SeqCst);
466                    run_optional_log_panic(post);
467                    handle.renderer.on_update(value, max);
468                }
469                ProgressEvent::Custom { f, _handle } => run_or_log_panic(f),
470                ProgressEvent::Flush { done, _handle } => done.send(()).unwrap(),
471                ProgressEvent::Update { handle } => {
472                    let value = handle.value.load(Ordering::SeqCst);
473                    let max = handle.max.load(Ordering::SeqCst);
474                    handle.renderer.on_update(value, max);
475                }
476                ProgressEvent::Notify { handle, msg } => {
477                    let value = handle.value.load(Ordering::SeqCst);
478                    let max = handle.max.load(Ordering::SeqCst);
479                    handle.renderer.on_notify(msg);
480                    handle.renderer.on_update(value, max);
481                }
482            }
483        }
484    });
485}
486
/// Erases a concrete closure type into the boxed trait object that progress events carry.
fn box_dyn<F>(f: F) -> Box<dyn FnOnce() + Send + 'static>
where
    F: FnOnce() + Send + 'static,
{
    Box::new(f)
}
490
491#[inline(always)]
492fn run_or_log_panic<F: FnOnce() + Send + 'static>(f: F) {
493    if let Err(err) = std::panic::catch_unwind(AssertUnwindSafe(f)) {
494        error!("ProgressBar closure panicked: {err:?}")
495    }
496}
497
498#[inline(always)]
499fn run_optional_log_panic<F: FnOnce() + Send + 'static>(f: Option<F>) {
500    if let Some(f) = f {
501        run_or_log_panic(f);
502    }
503}
504
#[cfg(test)]
mod tests {
    use std::sync::{
        Arc, Mutex,
        atomic::{AtomicBool, AtomicUsize, Ordering},
    };

    use crate::progress::{NullProgressRenderer, ProgressBar, ProgressRenderer};

    /// Renderer that records every callback it receives, so tests can assert on call order.
    struct TestRenderer {
        log: Arc<Mutex<Vec<String>>>,
    }

    impl TestRenderer {
        fn new() -> Self {
            Self {
                log: Arc::new(Mutex::new(Vec::new())),
            }
        }

        /// Snapshot of everything recorded so far.
        fn log(&self) -> Vec<String> {
            self.log.lock().unwrap().clone()
        }

        fn push_log(&self, log: String) {
            self.log.lock().unwrap().push(log);
        }
    }

    impl ProgressRenderer for TestRenderer {
        fn on_start(&self, value: usize, max: usize) {
            self.push_log(format!("Start: [{value} / {max}]"));
        }

        fn on_update(&self, value: usize, max: usize) {
            self.push_log(format!("Update: [{value} / {max}]"));
        }

        fn on_finish(&self) {
            self.push_log("Finish".to_string());
        }

        fn on_notify(&self, msg: String) {
            self.push_log(msg);
        }

        fn set_label(&self, _msg: &str) {}
    }

    #[test]
    fn progress_null_renderer_doesnt_panic() {
        let renderer = Arc::new(NullProgressRenderer);
        let bar = ProgressBar::new(0, 3, renderer.clone());

        bar.increment();
        bar.update();
        bar.set_value(0);
        bar.set_max(0);
        bar.is_alive();
        bar.notify("");
        bar.flush();
    }

    #[test]
    fn progress_increment() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::new(0, 3, renderer.clone());

        bar.increment();
        bar.increment();
        bar.flush();

        let log = renderer.log();
        assert_eq!(log.len(), 3);
        assert_eq!(log[0], "Start: [0 / 3]");
        assert_eq!(log[1], "Update: [1 / 3]");
        assert_eq!(log[2], "Update: [2 / 3]");
    }

    #[test]
    fn progress_notify() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::new(0, 3, renderer.clone());

        bar.increment();
        bar.flush();
        bar.notify("notification");
        bar.flush();

        let log = renderer.log();
        assert_eq!(log.len(), 4);
        assert_eq!(log[0], "Start: [0 / 3]");
        assert_eq!(log[1], "Update: [1 / 3]");
        assert_eq!(log[2], "notification");
        assert_eq!(log[3], "Update: [1 / 3]");
    }

    #[test]
    fn progress_finishes_on_drop() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::new(0, 0, renderer.clone());

        drop(bar);

        let log = renderer.log();
        assert_eq!(log.len(), 2);
        assert_eq!(log[0], "Start: [0 / 0]");
        assert_eq!(log[1], "Finish");
    }

    #[test]
    fn progress_closures_run_on_increment() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::new(0, 1, renderer.clone());

        let pre_ran = Arc::new(AtomicBool::new(false));
        let post_ran = Arc::new(AtomicBool::new(false));

        let pre_flag = pre_ran.clone();
        let post_flag = post_ran.clone();

        bar.increment_and_run(
            Some(move || pre_flag.store(true, Ordering::SeqCst)),
            Some(move || post_flag.store(true, Ordering::SeqCst)),
        );
        bar.flush();

        assert!(pre_ran.load(Ordering::SeqCst));
        assert!(post_ran.load(Ordering::SeqCst));

        let log = renderer.log();
        assert_eq!(log.len(), 2);
        assert_eq!(log[0], "Start: [0 / 1]");
        assert_eq!(log[1], "Update: [1 / 1]");
    }

    #[test]
    fn progress_closures_observe_value_around_increment() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::new(0, 1, renderer.clone());

        let seen_by_pre = Arc::new(AtomicUsize::new(usize::MAX));
        let seen_by_post = Arc::new(AtomicUsize::new(usize::MAX));

        let (pre_bar, pre_seen) = (bar.clone(), seen_by_pre.clone());
        let (post_bar, post_seen) = (bar.clone(), seen_by_post.clone());

        bar.increment_and_run(
            Some(move || pre_seen.store(pre_bar.get_value(), Ordering::SeqCst)),
            Some(move || post_seen.store(post_bar.get_value(), Ordering::SeqCst)),
        );
        bar.flush();

        // `pre` runs before the increment and `post` after it.
        assert_eq!(seen_by_pre.load(Ordering::SeqCst), 0);
        assert_eq!(seen_by_post.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn progress_closures_run_on_set_value() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::new(0, 5, renderer.clone());

        let pre_ran = Arc::new(AtomicBool::new(false));
        let post_ran = Arc::new(AtomicBool::new(false));

        let pre_flag = pre_ran.clone();
        let post_flag = post_ran.clone();

        bar.set_value_and_run(
            5,
            Some(move || pre_flag.store(true, Ordering::SeqCst)),
            Some(move || post_flag.store(true, Ordering::SeqCst)),
        );
        bar.flush();

        assert!(pre_ran.load(Ordering::SeqCst));
        assert!(post_ran.load(Ordering::SeqCst));

        let log = renderer.log();
        assert_eq!(log.len(), 2);
        assert_eq!(log[0], "Start: [0 / 5]");
        assert_eq!(log[1], "Update: [5 / 5]");
    }

    #[test]
    fn progress_closures_run_on_set_max() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::new(0, 3, renderer.clone());

        let pre_ran = Arc::new(AtomicBool::new(false));
        let post_ran = Arc::new(AtomicBool::new(false));

        let pre_flag = pre_ran.clone();
        let post_flag = post_ran.clone();

        bar.set_max_and_run(
            5,
            Some(move || pre_flag.store(true, Ordering::SeqCst)),
            Some(move || post_flag.store(true, Ordering::SeqCst)),
        );
        bar.flush();

        assert!(pre_ran.load(Ordering::SeqCst));
        assert!(post_ran.load(Ordering::SeqCst));

        let log = renderer.log();
        assert_eq!(log.len(), 2);
        assert_eq!(log[0], "Start: [0 / 3]");
        assert_eq!(log[1], "Update: [0 / 5]");
    }

    #[test]
    fn progress_thread_safe_on_closure_panic() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::new(0, 100, renderer.clone());

        bar.increment_and_run(Some(|| panic!()), Some(|| ()));
        bar.flush();

        // The panic is caught, so the thread survives and still applies the increment.
        assert!(bar.is_alive());
        let log = renderer.log();
        assert_eq!(log.len(), 2);
        assert_eq!(log[0], "Start: [0 / 100]");
        assert_eq!(log[1], "Update: [1 / 100]");
    }

    #[test]
    fn progress_thread_stability() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::new(0, 100, renderer.clone());

        let handles: Vec<_> = (0..5)
            .map(|_| {
                let bar = bar.clone();
                std::thread::spawn(move || {
                    for _ in 0..20 {
                        bar.increment();
                    }
                })
            })
            .collect();

        for h in handles {
            h.join().unwrap();
        }

        bar.flush();

        let log = renderer.log();
        assert_eq!(log.len(), 101);

        assert_eq!(log[0], "Start: [0 / 100]");
        for (i, entry) in log.iter().skip(1).enumerate() {
            assert_eq!(entry, &format!("Update: [{i} / 100]", i = i + 1));
        }
    }
}