1use std::{
2 panic::AssertUnwindSafe,
3 sync::{
4 Arc,
5 atomic::{AtomicUsize, Ordering},
6 mpsc::{Sender, channel},
7 },
8};
9
10use crate::error;
11
12mod cli_progress_bar;
13pub use cli_progress_bar::*;
14
15mod null_progress_renderer;
16pub use null_progress_renderer::*;
17
18#[derive(Clone)]
20pub struct ProgressBar {
21 handle: Arc<ProgressHandle>,
22}
23
24struct ProgressHandle {
25 value: AtomicUsize,
26 max: AtomicUsize,
27 renderer: Arc<dyn ProgressRenderer>,
28 sender: Sender<ProgressEvent>,
29}
30
31impl Drop for ProgressHandle {
32 fn drop(&mut self) {
33 if !self.renderer.__is_null() {
34 let _ = self.sender.send(ProgressEvent::Exit);
35 self.renderer.on_finish();
36 }
37 }
38}
39
40impl ProgressBar {
41 pub fn custom<F>(&self, f: F)
51 where
52 F: FnOnce() + Send + 'static,
53 {
54 if self.handle.renderer.__is_null() {
55 return;
56 }
57
58 self.handle
59 .sender
60 .send(ProgressEvent::Custom {
61 f: Box::new(f),
62 _handle: self.handle.clone(),
63 })
64 .expect("Progress handler thread died.");
65 }
66
67 pub fn flush(&self) {
73 if self.handle.renderer.__is_null() {
74 return;
75 }
76
77 let (tx, rx) = channel();
78 self.handle
79 .sender
80 .send(ProgressEvent::Flush {
81 done: tx,
82 _handle: self.handle.clone(),
83 })
84 .expect("Progress handler thread died.");
85 let _ = rx.recv();
86 }
87
88 pub fn get_max(&self) -> usize {
94 self.handle.max.load(Ordering::SeqCst)
95 }
96
97 pub fn get_value(&self) -> usize {
103 self.handle.value.load(Ordering::SeqCst)
104 }
105
106 pub fn increment(&self) {
112 if self.handle.renderer.__is_null() {
113 return;
114 }
115
116 self.handle
117 .sender
118 .send(ProgressEvent::Increment {
119 handle: self.handle.clone(),
120 })
121 .expect("Progress handler thread died.");
122 }
123
124 pub fn increment_and_run<P, Q>(&self, pre: Option<P>, post: Option<Q>)
133 where
134 P: FnOnce() + Send + 'static,
135 Q: FnOnce() + Send + 'static,
136 {
137 if self.handle.renderer.__is_null() {
138 return;
139 }
140
141 self.handle
142 .sender
143 .send(ProgressEvent::IncrementAndRun {
144 handle: self.handle.clone(),
145 pre: pre.map(box_dyn),
146 post: post.map(box_dyn),
147 })
148 .expect("Progress handler thread died.");
149 }
150
151 pub fn new(value: usize, max: usize, renderer: Arc<dyn ProgressRenderer>) -> Self {
155 let (tx, rx) = channel();
156
157 let handle = Arc::new(ProgressHandle {
158 value: AtomicUsize::new(value),
159 max: AtomicUsize::new(max),
160 renderer: renderer.clone(),
161 sender: tx,
162 });
163
164 if !renderer.__is_null() {
165 spawn_progress_thread(rx);
166 renderer.on_start(value, max);
167 }
168
169 Self { handle }
170 }
171
172 pub fn set_max(&self, max: usize) {
178 if self.handle.renderer.__is_null() {
179 return;
180 }
181
182 self.handle
183 .sender
184 .send(ProgressEvent::SetMax {
185 handle: self.handle.clone(),
186 max,
187 })
188 .expect("Progress handler thread died.");
189 }
190
191 pub fn set_max_and_run<P, Q>(&self, max: usize, pre: Option<P>, post: Option<Q>)
204 where
205 P: FnOnce() + Send + 'static,
206 Q: FnOnce() + Send + 'static,
207 {
208 if self.handle.renderer.__is_null() {
209 return;
210 }
211
212 self.handle
213 .sender
214 .send(ProgressEvent::SetMaxAndRun {
215 handle: self.handle.clone(),
216 max,
217 pre: pre.map(box_dyn),
218 post: post.map(box_dyn),
219 })
220 .expect("Progress handler thread died.");
221 }
222
223 pub fn set_value(&self, value: usize) {
229 if self.handle.renderer.__is_null() {
230 return;
231 }
232
233 self.handle
234 .sender
235 .send(ProgressEvent::SetValue {
236 handle: self.handle.clone(),
237 value,
238 })
239 .expect("Progress handler thread died.");
240 }
241
242 pub fn set_value_and_run<P, Q>(&self, value: usize, pre: Option<P>, post: Option<Q>)
255 where
256 P: FnOnce() + Send + 'static,
257 Q: FnOnce() + Send + 'static,
258 {
259 if self.handle.renderer.__is_null() {
260 return;
261 }
262
263 self.handle
264 .sender
265 .send(ProgressEvent::SetValueAndRun {
266 handle: self.handle.clone(),
267 value,
268 pre: pre.map(box_dyn),
269 post: post.map(box_dyn),
270 })
271 .expect("Progress handler thread died.");
272 }
273
274 pub fn update(&self) {
280 if self.handle.renderer.__is_null() {
281 return;
282 }
283
284 self.handle
285 .sender
286 .send(ProgressEvent::Update {
287 handle: self.handle.clone(),
288 })
289 .expect("Progress handler thread died.")
290 }
291
292 pub fn is_alive(&self) -> bool {
298 if self.handle.renderer.__is_null() {
299 return false;
300 }
301
302 self.handle.sender.send(ProgressEvent::Check).is_ok()
303 }
304
305 pub fn set_label(&self, label: &str) {
307 self.handle.renderer.set_label(label);
308 }
309
310 pub fn notify(&self, message: impl ToString) {
316 if self.handle.renderer.__is_null() {
317 return;
318 }
319
320 self.handle
321 .sender
322 .send(ProgressEvent::Notify {
323 handle: self.handle.clone(),
324 msg: message.to_string(),
325 })
326 .expect("Progress handler thread died.")
327 }
328}
329
/// Display back end driven by a `ProgressBar`.
///
/// Implementations must be `Send + Sync`: the callbacks are invoked both from
/// the thread that owns the bar and from the background progress thread.
pub trait ProgressRenderer: Send + Sync {
    /// Called by `ProgressBar::new` with the initial value and maximum.
    fn on_start(&self, value: usize, max: usize);

    /// Called by the progress thread with the value and maximum to display
    /// after it has handled a state-changing, update or notify event.
    fn on_update(&self, value: usize, max: usize);

    /// Called when the shared progress state is dropped, i.e. once the last
    /// reference to it is released.
    fn on_finish(&self);

    /// Called by the progress thread with a message passed to
    /// `ProgressBar::notify`; an `on_update` call follows it.
    fn on_notify(&self, msg: String);

    /// Called directly by `ProgressBar::set_label` on the caller's thread.
    fn set_label(&self, msg: &str);

    /// Internal marker: when this returns `true`, `ProgressBar` starts no
    /// thread and queues no events for this renderer. Defaults to `false`.
    #[doc(hidden)]
    #[inline(always)]
    fn __is_null(&self) -> bool {
        false
    }
}
362
363enum ProgressEvent {
368 Check,
369 Exit,
370 Update {
371 handle: Arc<ProgressHandle>,
372 },
373 Increment {
374 handle: Arc<ProgressHandle>,
375 },
376 IncrementAndRun {
377 handle: Arc<ProgressHandle>,
378 pre: Option<Box<dyn FnOnce() + Send>>,
379 post: Option<Box<dyn FnOnce() + Send>>,
380 },
381 SetValue {
382 handle: Arc<ProgressHandle>,
383 value: usize,
384 },
385 SetValueAndRun {
386 handle: Arc<ProgressHandle>,
387 value: usize,
388 pre: Option<Box<dyn FnOnce() + Send>>,
389 post: Option<Box<dyn FnOnce() + Send>>,
390 },
391 SetMax {
392 handle: Arc<ProgressHandle>,
393 max: usize,
394 },
395 SetMaxAndRun {
396 handle: Arc<ProgressHandle>,
397 max: usize,
398 pre: Option<Box<dyn FnOnce() + Send>>,
399 post: Option<Box<dyn FnOnce() + Send>>,
400 },
401 Custom {
402 _handle: Arc<ProgressHandle>,
403 f: Box<dyn FnOnce() + Send>,
404 },
405 Flush {
406 _handle: Arc<ProgressHandle>,
407 done: std::sync::mpsc::Sender<()>,
408 },
409 Notify {
410 handle: Arc<ProgressHandle>,
411 msg: String,
412 },
413}
414
415fn spawn_progress_thread(rx: std::sync::mpsc::Receiver<ProgressEvent>) {
416 std::thread::spawn(move || {
417 while let Ok(event) = rx.recv() {
418 match event {
419 ProgressEvent::Check => {}
420 ProgressEvent::Exit => {
421 break;
422 }
423 ProgressEvent::Increment { handle } => {
424 let value = handle.value.fetch_add(1, Ordering::SeqCst);
425 let max = handle.max.load(Ordering::SeqCst);
426 handle.renderer.on_update(value + 1, max);
427 }
428 ProgressEvent::IncrementAndRun { handle, pre, post } => {
429 run_optional_log_panic(pre);
430 let value = handle.value.fetch_add(1, Ordering::SeqCst);
431 run_optional_log_panic(post);
432 let max = handle.max.load(Ordering::SeqCst);
433 handle.renderer.on_update(value + 1, max);
434 }
435 ProgressEvent::SetValue { handle, value } => {
436 handle.value.store(value, Ordering::SeqCst);
437 let max = handle.max.load(Ordering::SeqCst);
438 handle.renderer.on_update(value, max);
439 }
440 ProgressEvent::SetValueAndRun {
441 handle,
442 value,
443 pre,
444 post,
445 } => {
446 handle.value.store(value, Ordering::SeqCst);
447 run_optional_log_panic(pre);
448 let max = handle.max.load(Ordering::SeqCst);
449 run_optional_log_panic(post);
450 handle.renderer.on_update(value, max);
451 }
452 ProgressEvent::SetMax { max, handle } => {
453 let value = handle.value.load(Ordering::SeqCst);
454 handle.max.store(max, Ordering::SeqCst);
455 handle.renderer.on_update(value, max);
456 }
457 ProgressEvent::SetMaxAndRun {
458 handle,
459 max,
460 pre,
461 post,
462 } => {
463 let value = handle.value.load(Ordering::SeqCst);
464 run_optional_log_panic(pre);
465 handle.max.store(max, Ordering::SeqCst);
466 run_optional_log_panic(post);
467 handle.renderer.on_update(value, max);
468 }
469 ProgressEvent::Custom { f, _handle } => run_or_log_panic(f),
470 ProgressEvent::Flush { done, _handle } => done.send(()).unwrap(),
471 ProgressEvent::Update { handle } => {
472 let value = handle.value.load(Ordering::SeqCst);
473 let max = handle.max.load(Ordering::SeqCst);
474 handle.renderer.on_update(value, max);
475 }
476 ProgressEvent::Notify { handle, msg } => {
477 let value = handle.value.load(Ordering::SeqCst);
478 let max = handle.max.load(Ordering::SeqCst);
479 handle.renderer.on_notify(msg);
480 handle.renderer.on_update(value, max);
481 }
482 }
483 }
484 });
485}
486
/// Erases the concrete closure type by moving `f` into a boxed trait object.
///
/// Exists as a named function so it can be passed straight to `Option::map`.
fn box_dyn<F>(f: F) -> Box<dyn FnOnce() + Send + 'static>
where
    F: FnOnce() + Send + 'static,
{
    let boxed: Box<dyn FnOnce() + Send + 'static> = Box::new(f);
    boxed
}
490
491#[inline(always)]
492fn run_or_log_panic<F: FnOnce() + Send + 'static>(f: F) {
493 if let Err(err) = std::panic::catch_unwind(AssertUnwindSafe(f)) {
494 error!("ProgressBar closure panicked: {err:?}")
495 }
496}
497
498#[inline(always)]
499fn run_optional_log_panic<F: FnOnce() + Send + 'static>(f: Option<F>) {
500 if let Some(f) = f {
501 run_or_log_panic(f);
502 }
503}
504
#[cfg(test)]
mod tests {
    use std::sync::{
        Arc, Mutex,
        atomic::{AtomicBool, Ordering},
    };

    use crate::progress::{NullProgressRenderer, ProgressBar, ProgressRenderer};

    /// Renderer that records every callback as a line of text so tests can
    /// assert on the exact sequence of calls.
    struct TestRenderer {
        log: Arc<Mutex<Vec<String>>>,
    }

    impl TestRenderer {
        fn new() -> Self {
            Self {
                log: Arc::new(Mutex::new(Vec::new())),
            }
        }

        /// Returns a snapshot of everything recorded so far.
        fn log(&self) -> Vec<String> {
            self.log.lock().unwrap().clone()
        }

        fn push_log(&self, log: String) {
            self.log.lock().unwrap().push(log);
        }
    }

    impl ProgressRenderer for TestRenderer {
        fn on_start(&self, value: usize, max: usize) {
            self.push_log(format!("Start: [{value} / {max}]"));
        }

        fn on_update(&self, value: usize, max: usize) {
            self.push_log(format!("Update: [{value} / {max}]"));
        }

        fn on_finish(&self) {
            self.push_log("Finish".to_owned());
        }

        fn on_notify(&self, msg: String) {
            // `msg` is already owned; record it as-is.
            self.push_log(msg);
        }

        fn set_label(&self, _msg: &str) {}
    }

    /// With the null renderer no thread exists, so every call must be a no-op
    /// rather than a failed send.
    #[test]
    fn progress_null_renderer_doesnt_panic() {
        let renderer = Arc::new(NullProgressRenderer);
        let bar = ProgressBar::new(0, 3, renderer.clone());

        bar.increment();
        bar.update();
        bar.set_value(0);
        bar.set_max(0);
        bar.is_alive();
        bar.notify("");
        bar.flush();
    }

    #[test]
    fn progress_increment() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::new(0, 3, renderer.clone());

        bar.increment();
        bar.increment();
        // Wait for the progress thread to drain the queue before inspecting.
        bar.flush();

        let log = renderer.log();
        assert_eq!(log.len(), 3);
        assert_eq!(log[0], "Start: [0 / 3]");
        assert_eq!(log[1], "Update: [1 / 3]");
        assert_eq!(log[2], "Update: [2 / 3]");
    }

    #[test]
    fn progress_notify() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::new(0, 3, renderer.clone());

        bar.increment();
        bar.flush();
        bar.notify("notification");
        bar.flush();

        let log = renderer.log();

        assert_eq!(log.len(), 4);
        assert_eq!(log[0], "Start: [0 / 3]");
        assert_eq!(log[1], "Update: [1 / 3]");
        assert_eq!(log[2], "notification");
        // A notification is followed by a redraw with unchanged numbers.
        assert_eq!(log[3], "Update: [1 / 3]");
    }

    #[test]
    fn progress_finishes_on_drop() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::new(0, 0, renderer.clone());

        // Last reference: `on_finish` runs synchronously during this drop.
        drop(bar);

        let log = renderer.log();
        assert_eq!(log.len(), 2);
        assert_eq!(log[0], "Start: [0 / 0]");
        assert_eq!(log[1], "Finish")
    }

    #[test]
    fn progress_closures_run_on_increment() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::new(0, 1, renderer.clone());

        let pre_ran = Arc::new(AtomicBool::new(false));
        let post_ran = Arc::new(AtomicBool::new(false));

        let pre_flag = pre_ran.clone();
        let post_flag = post_ran.clone();

        bar.increment_and_run(
            Some(move || pre_flag.store(true, Ordering::SeqCst)),
            Some(move || post_flag.store(true, Ordering::SeqCst)),
        );
        bar.flush();

        assert!(pre_ran.load(Ordering::SeqCst));
        assert!(post_ran.load(Ordering::SeqCst));

        let log = renderer.log();
        assert_eq!(log.len(), 2);
        assert_eq!(log[0], "Start: [0 / 1]");
        assert_eq!(log[1], "Update: [1 / 1]");
    }

    #[test]
    fn progress_closures_run_on_set_value() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::new(0, 5, renderer.clone());

        let pre_ran = Arc::new(AtomicBool::new(false));
        let post_ran = Arc::new(AtomicBool::new(false));

        let pre_flag = pre_ran.clone();
        let post_flag = post_ran.clone();

        bar.set_value_and_run(
            5,
            Some(move || pre_flag.store(true, Ordering::SeqCst)),
            Some(move || post_flag.store(true, Ordering::SeqCst)),
        );
        bar.flush();

        assert!(pre_ran.load(Ordering::SeqCst));
        assert!(post_ran.load(Ordering::SeqCst));

        let log = renderer.log();
        assert_eq!(log.len(), 2);
        assert_eq!(log[0], "Start: [0 / 5]");
        assert_eq!(log[1], "Update: [5 / 5]");
    }

    #[test]
    fn progress_closures_run_on_set_max() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::new(0, 3, renderer.clone());

        let pre_ran = Arc::new(AtomicBool::new(false));
        let post_ran = Arc::new(AtomicBool::new(false));

        let pre_flag = pre_ran.clone();
        let post_flag = post_ran.clone();

        bar.set_max_and_run(
            5,
            Some(move || pre_flag.store(true, Ordering::SeqCst)),
            Some(move || post_flag.store(true, Ordering::SeqCst)),
        );
        bar.flush();

        assert!(pre_ran.load(Ordering::SeqCst));
        assert!(post_ran.load(Ordering::SeqCst));

        let log = renderer.log();
        assert_eq!(log.len(), 2);
        assert_eq!(log[0], "Start: [0 / 3]");
        assert_eq!(log[1], "Update: [0 / 5]");
    }

    /// A panicking closure must not kill the progress thread: the following
    /// `flush` would panic on send if it had.
    #[test]
    fn progress_thread_safe_on_closure_panic() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::new(0, 100, renderer.clone());

        bar.increment_and_run(Some(|| panic!()), Some(|| ()));
        bar.flush();
    }

    /// Increments issued concurrently from several threads must each produce
    /// exactly one update, with strictly increasing values.
    #[test]
    fn progress_thread_stability() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::new(0, 100, renderer.clone());

        let handles: Vec<_> = (0..5)
            .map(|_| {
                let bar = bar.clone();
                std::thread::spawn(move || {
                    for _ in 0..20 {
                        bar.increment();
                    }
                })
            })
            .collect();

        for h in handles {
            h.join().unwrap()
        }

        bar.flush();

        let log = renderer.log();
        assert_eq!(log.len(), 101);

        assert_eq!(log[0], "Start: [0 / 100]");
        for (i, entry) in log.iter().skip(1).enumerate() {
            assert_eq!(entry, &format!("Update: [{i} / 100]", i = i + 1))
        }
    }
}