1use std::{
2 panic::AssertUnwindSafe,
3 sync::{
4 Arc, LazyLock,
5 atomic::{AtomicUsize, Ordering},
6 mpsc::{Sender, channel},
7 },
8};
9
10use crate::error;
11
12#[derive(Clone)]
14pub struct ProgressBar {
15 handle: Arc<ProgressHandle>,
16}
17
18struct ProgressHandle {
19 value: AtomicUsize,
20 max: AtomicUsize,
21 renderer: Arc<dyn ProgressRenderer>,
22}
23
24impl ProgressBar {
25 pub fn custom<F>(&self, f: F)
35 where
36 F: FnOnce() + Send + 'static,
37 {
38 GLOBAL_SENDER
39 .send(ProgressEvent::Custom {
40 _handle: self.handle.clone(),
41 f: Box::new(f),
42 })
43 .expect("Progress handler thread died.");
44 }
45
46 pub fn flush(&self) {
52 let (tx, rx) = channel();
53 GLOBAL_SENDER
54 .send(ProgressEvent::Flush {
55 _handle: self.handle.clone(),
56 done: tx,
57 })
58 .expect("Progress handler thread died.");
59 let _ = rx.recv();
60 }
61
62 pub fn increment(&self) {
68 GLOBAL_SENDER
69 .send(ProgressEvent::Increment {
70 handle: self.handle.clone(),
71 })
72 .expect("Progress handler thread died.");
73 }
74
75 pub fn increment_and_run<P, Q>(&self, pre: Option<P>, post: Option<Q>)
84 where
85 P: FnOnce() + Send + 'static,
86 Q: FnOnce() + Send + 'static,
87 {
88 GLOBAL_SENDER
89 .send(ProgressEvent::IncrementAndRun {
90 handle: self.handle.clone(),
91 pre: pre.map(box_dyn),
92 post: post.map(box_dyn),
93 })
94 .expect("Progress handler thread died.");
95 }
96
97 pub fn set_max(&self, max: usize) {
103 GLOBAL_SENDER
104 .send(ProgressEvent::SetMax {
105 handle: self.handle.clone(),
106 max,
107 })
108 .expect("Progress handler thread died.");
109 }
110
111 pub fn set_max_and_run<P, Q>(&self, max: usize, pre: Option<P>, post: Option<Q>)
124 where
125 P: FnOnce() + Send + 'static,
126 Q: FnOnce() + Send + 'static,
127 {
128 GLOBAL_SENDER
129 .send(ProgressEvent::SetMaxAndRun {
130 handle: self.handle.clone(),
131 max,
132 pre: pre.map(box_dyn),
133 post: post.map(box_dyn),
134 })
135 .expect("Progress handler thread died.");
136 }
137
138 pub fn set_value(&self, value: usize) {
144 GLOBAL_SENDER
145 .send(ProgressEvent::SetValue {
146 handle: self.handle.clone(),
147 value,
148 })
149 .expect("Progress handler thread died.");
150 }
151
152 pub fn set_value_and_run<P, Q>(&self, value: usize, pre: Option<P>, post: Option<Q>)
165 where
166 P: FnOnce() + Send + 'static,
167 Q: FnOnce() + Send + 'static,
168 {
169 GLOBAL_SENDER
170 .send(ProgressEvent::SetValueAndRun {
171 handle: self.handle.clone(),
172 value,
173 pre: pre.map(box_dyn),
174 post: post.map(box_dyn),
175 })
176 .expect("Progress handler thread died.");
177 }
178
179 pub fn start(value: usize, max: usize, renderer: Arc<dyn ProgressRenderer>) -> Self {
183 let handle = Arc::new(ProgressHandle {
184 value: AtomicUsize::new(value),
185 max: AtomicUsize::new(max),
186 renderer: renderer.clone(),
187 });
188
189 renderer.on_start(value, max);
190
191 Self { handle }
192 }
193}
194
/// Receives progress notifications from a [`ProgressBar`].
///
/// Each method is called on a different thread:
///
/// - [`on_start`](Self::on_start) runs on the thread that calls [`ProgressBar::start`].
/// - [`on_update`](Self::on_update) runs on the shared progress worker thread.
/// - [`on_finish`](Self::on_finish) runs on whichever thread drops the bar's last
///   reference.
///
/// This is why implementations must be `Send + Sync`.
pub trait ProgressRenderer: Send + Sync {
    /// Called once when the bar is created, with its initial `value` and `max`.
    fn on_start(&self, value: usize, max: usize);

    /// Called after each state change with the current `value` and `max`.
    fn on_update(&self, value: usize, max: usize);

    /// Called once after the last reference to the bar has been released.
    ///
    /// This includes references held by events that are still queued.
    fn on_finish(&self);
}
210
211impl Drop for ProgressHandle {
212 fn drop(&mut self) {
213 self.renderer.on_finish();
214 }
215}
216
217enum ProgressEvent {
222 Increment {
223 handle: Arc<ProgressHandle>,
224 },
225 IncrementAndRun {
226 handle: Arc<ProgressHandle>,
227 pre: Option<Box<dyn FnOnce() + Send>>,
228 post: Option<Box<dyn FnOnce() + Send>>,
229 },
230 SetValue {
231 handle: Arc<ProgressHandle>,
232 value: usize,
233 },
234 SetValueAndRun {
235 handle: Arc<ProgressHandle>,
236 value: usize,
237 pre: Option<Box<dyn FnOnce() + Send>>,
238 post: Option<Box<dyn FnOnce() + Send>>,
239 },
240 SetMax {
241 handle: Arc<ProgressHandle>,
242 max: usize,
243 },
244 SetMaxAndRun {
245 handle: Arc<ProgressHandle>,
246 max: usize,
247 pre: Option<Box<dyn FnOnce() + Send>>,
248 post: Option<Box<dyn FnOnce() + Send>>,
249 },
250 Custom {
251 _handle: Arc<ProgressHandle>,
252 f: Box<dyn FnOnce() + Send>,
253 },
254 Flush {
255 _handle: Arc<ProgressHandle>,
256 done: std::sync::mpsc::Sender<()>,
257 },
258}
259
260static GLOBAL_SENDER: LazyLock<Sender<ProgressEvent>> = LazyLock::new(|| {
261 let (tx, rx) = channel::<ProgressEvent>();
262
263 std::thread::spawn(move || {
264 while let Ok(event) = rx.recv() {
265 match event {
266 ProgressEvent::Increment { handle } => {
267 let value = handle.value.fetch_add(1, Ordering::SeqCst);
268 let max = handle.max.load(Ordering::SeqCst);
269 handle.renderer.on_update(value + 1, max);
270 }
271 ProgressEvent::IncrementAndRun { handle, pre, post } => {
272 run_optional_log_panic(pre);
273 let value = handle.value.fetch_add(1, Ordering::SeqCst);
274 run_optional_log_panic(post);
275 let max = handle.max.load(Ordering::SeqCst);
276 handle.renderer.on_update(value + 1, max);
277 }
278 ProgressEvent::SetValue { handle, value } => {
279 handle.value.store(value, Ordering::SeqCst);
280 let max = handle.max.load(Ordering::SeqCst);
281 handle.renderer.on_update(value, max);
282 }
283 ProgressEvent::SetValueAndRun {
284 handle,
285 value,
286 pre,
287 post,
288 } => {
289 handle.value.store(value, Ordering::SeqCst);
290 run_optional_log_panic(pre);
291 let max = handle.max.load(Ordering::SeqCst);
292 run_optional_log_panic(post);
293 handle.renderer.on_update(value, max);
294 }
295 ProgressEvent::SetMax { handle, max } => {
296 let value = handle.value.load(Ordering::SeqCst);
297 handle.max.store(max, Ordering::SeqCst);
298 handle.renderer.on_update(value, max);
299 }
300 ProgressEvent::SetMaxAndRun {
301 handle,
302 max,
303 pre,
304 post,
305 } => {
306 let value = handle.value.load(Ordering::SeqCst);
307 run_optional_log_panic(pre);
308 handle.max.store(max, Ordering::SeqCst);
309 run_optional_log_panic(post);
310 handle.renderer.on_update(value, max);
311 }
312 ProgressEvent::Custom { _handle, f } => run_or_log_panic(f),
313 ProgressEvent::Flush { _handle, done } => done.send(()).unwrap(),
314 }
315 }
316 });
317
318 tx
319});
320
/// Erases a closure's concrete type so it can be stored as a boxed callback in an event.
///
/// Used with `Option::map` to turn an `Option<F>` into an `Option<Box<dyn FnOnce()>>`.
fn box_dyn<F>(f: F) -> Box<dyn FnOnce() + Send + 'static>
where
    F: FnOnce() + Send + 'static,
{
    Box::new(f)
}
324
325#[inline(always)]
326fn run_or_log_panic<F: FnOnce() + Send + 'static>(f: F) {
327 if let Err(err) = std::panic::catch_unwind(AssertUnwindSafe(f)) {
328 error!("ProgressBar closure panicked: {err:?}")
329 }
330}
331
332#[inline(always)]
333fn run_optional_log_panic<F: FnOnce() + Send + 'static>(f: Option<F>) {
334 if let Some(f) = f {
335 run_or_log_panic(f);
336 }
337}
338
#[cfg(test)]
mod tests {
    use std::sync::{
        Arc, Mutex,
        atomic::{AtomicBool, Ordering},
    };

    use crate::progress::{ProgressBar, ProgressRenderer};

    /// Renderer that records every callback as a line of text, in call order.
    struct TestRenderer {
        log: Arc<Mutex<Vec<String>>>,
    }

    impl TestRenderer {
        fn new() -> Self {
            Self {
                log: Arc::new(Mutex::new(Vec::new())),
            }
        }

        /// Snapshot of everything recorded so far.
        fn log(&self) -> Vec<String> {
            self.log.lock().unwrap().clone()
        }

        fn push_log(&self, log: String) {
            self.log.lock().unwrap().push(log);
        }
    }

    impl ProgressRenderer for TestRenderer {
        fn on_start(&self, value: usize, max: usize) {
            self.push_log(format!("Start: [{value} / {max}]"));
        }

        fn on_update(&self, value: usize, max: usize) {
            self.push_log(format!("Update: [{value} / {max}]"));
        }

        fn on_finish(&self) {
            self.push_log("Finish".to_string());
        }
    }

    /// Returns a callback that appends `entry` to `renderer`'s log when run.
    fn log_entry(renderer: &TestRenderer, entry: &'static str) -> impl FnOnce() + Send + 'static {
        let log = Arc::clone(&renderer.log);
        move || log.lock().unwrap().push(entry.to_string())
    }

    #[test]
    fn progress_increment() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::start(0, 3, renderer.clone());

        bar.increment();
        bar.increment();
        bar.flush();

        let log = renderer.log();
        assert_eq!(log.len(), 3);
        assert_eq!(log[0], "Start: [0 / 3]");
        assert_eq!(log[1], "Update: [1 / 3]");
        assert_eq!(log[2], "Update: [2 / 3]");
    }

    #[test]
    fn progress_set_value_and_max() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::start(0, 5, renderer.clone());

        bar.set_value(2);
        bar.set_max(7);
        bar.increment();
        bar.flush();

        assert_eq!(
            renderer.log(),
            [
                "Start: [0 / 5]",
                "Update: [2 / 5]",
                "Update: [2 / 7]",
                "Update: [3 / 7]",
            ]
        );
    }

    #[test]
    fn progress_finishes_on_drop() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::start(0, 0, renderer.clone());

        drop(bar);

        let log = renderer.log();
        assert_eq!(log.len(), 2);
        assert_eq!(log[0], "Start: [0 / 0]");
        assert_eq!(log[1], "Finish");
    }

    #[test]
    fn progress_finish_waits_for_queued_events() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::start(0, 1, renderer.clone());

        bar.increment();
        drop(bar);
        // Flush through an unrelated bar: the worker handles events in order, so everything
        // this thread queued above has been processed once this returns.
        ProgressBar::start(0, 0, Arc::new(TestRenderer::new())).flush();

        assert_eq!(
            renderer.log(),
            ["Start: [0 / 1]", "Update: [1 / 1]", "Finish"]
        );
    }

    #[test]
    fn progress_closures_run_on_increment() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::start(0, 1, renderer.clone());

        let pre_ran = Arc::new(AtomicBool::new(false));
        let post_ran = Arc::new(AtomicBool::new(false));

        let pre_flag = pre_ran.clone();
        let post_flag = post_ran.clone();

        bar.increment_and_run(
            Some(move || pre_flag.store(true, Ordering::SeqCst)),
            Some(move || post_flag.store(true, Ordering::SeqCst)),
        );
        bar.flush();

        assert!(pre_ran.load(Ordering::SeqCst));
        assert!(post_ran.load(Ordering::SeqCst));

        let log = renderer.log();
        assert_eq!(log.len(), 2);
        assert_eq!(log[0], "Start: [0 / 1]");
        assert_eq!(log[1], "Update: [1 / 1]");
    }

    #[test]
    fn progress_closures_run_on_set_value() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::start(0, 5, renderer.clone());

        let pre_ran = Arc::new(AtomicBool::new(false));
        let post_ran = Arc::new(AtomicBool::new(false));

        let pre_flag = pre_ran.clone();
        let post_flag = post_ran.clone();

        bar.set_value_and_run(
            5,
            Some(move || pre_flag.store(true, Ordering::SeqCst)),
            Some(move || post_flag.store(true, Ordering::SeqCst)),
        );
        bar.flush();

        assert!(pre_ran.load(Ordering::SeqCst));
        assert!(post_ran.load(Ordering::SeqCst));

        let log = renderer.log();
        assert_eq!(log.len(), 2);
        assert_eq!(log[0], "Start: [0 / 5]");
        assert_eq!(log[1], "Update: [5 / 5]");
    }

    #[test]
    fn progress_closures_run_on_set_max() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::start(0, 3, renderer.clone());

        let pre_ran = Arc::new(AtomicBool::new(false));
        let post_ran = Arc::new(AtomicBool::new(false));

        let pre_flag = pre_ran.clone();
        let post_flag = post_ran.clone();

        bar.set_max_and_run(
            5,
            Some(move || pre_flag.store(true, Ordering::SeqCst)),
            Some(move || post_flag.store(true, Ordering::SeqCst)),
        );
        bar.flush();

        assert!(pre_ran.load(Ordering::SeqCst));
        assert!(post_ran.load(Ordering::SeqCst));

        let log = renderer.log();
        assert_eq!(log.len(), 2);
        assert_eq!(log[0], "Start: [0 / 3]");
        assert_eq!(log[1], "Update: [0 / 5]");
    }

    #[test]
    fn progress_closures_run_in_order_before_render() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::start(0, 3, renderer.clone());

        bar.increment_and_run(
            Some(log_entry(&renderer, "pre 1")),
            Some(log_entry(&renderer, "post 1")),
        );
        bar.set_value_and_run(
            2,
            Some(log_entry(&renderer, "pre 2")),
            Some(log_entry(&renderer, "post 2")),
        );
        bar.set_max_and_run(
            4,
            Some(log_entry(&renderer, "pre 3")),
            Some(log_entry(&renderer, "post 3")),
        );
        bar.flush();

        assert_eq!(
            renderer.log(),
            [
                "Start: [0 / 3]",
                "pre 1",
                "post 1",
                "Update: [1 / 3]",
                "pre 2",
                "post 2",
                "Update: [2 / 3]",
                "pre 3",
                "post 3",
                "Update: [2 / 4]",
            ]
        );
    }

    #[test]
    fn progress_custom_runs_without_render() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::start(0, 2, renderer.clone());

        let ran = Arc::new(AtomicBool::new(false));
        let flag = ran.clone();
        bar.custom(move || flag.store(true, Ordering::SeqCst));
        bar.flush();

        assert!(ran.load(Ordering::SeqCst));
        assert_eq!(renderer.log(), ["Start: [0 / 2]"]);
    }

    #[test]
    fn progress_thread_safe_on_closure_panic() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::start(0, 100, renderer.clone());

        let post_ran = Arc::new(AtomicBool::new(false));
        let post_flag = post_ran.clone();

        bar.increment_and_run(
            Some(|| panic!("intentional panic from test closure")),
            Some(move || post_flag.store(true, Ordering::SeqCst)),
        );
        // The worker must survive the panic and keep serving later events.
        bar.increment();
        bar.flush();

        assert!(post_ran.load(Ordering::SeqCst));
        assert_eq!(
            renderer.log(),
            ["Start: [0 / 100]", "Update: [1 / 100]", "Update: [2 / 100]"]
        );
    }

    #[test]
    fn progress_thread_stability() {
        let renderer = Arc::new(TestRenderer::new());
        let bar = ProgressBar::start(0, 100, renderer.clone());

        let handles: Vec<_> = (0..5)
            .map(|_| {
                let bar = bar.clone();
                std::thread::spawn(move || {
                    for _ in 0..20 {
                        bar.increment();
                    }
                })
            })
            .collect();

        for h in handles {
            h.join().unwrap();
        }

        bar.flush();

        let log = renderer.log();
        assert_eq!(log.len(), 101);

        assert_eq!(log[0], "Start: [0 / 100]");
        for (i, entry) in log.iter().skip(1).enumerate() {
            assert_eq!(entry, &format!("Update: [{i} / 100]", i = i + 1));
        }
    }
}