use indicatif::{
MultiProgress, ProgressBar, ProgressBarIter, ProgressDrawTarget, ProgressIterator,
ProgressStyle,
};
use mistralrs_quant::get_immediate_isq;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use rayon::prelude::*;
use std::iter::Iterator;
use std::sync::atomic::{AtomicUsize, Ordering};
use tqdm::Iter;
/// Number of live [`ProgressScopeGuard`]s that requested suppression.
///
/// This is a single process-global counter (not thread-local): a silent guard created on
/// any thread makes [`progress_suppressed`] return `true` for every thread.
static PROGRESS_SUPPRESS_COUNT: AtomicUsize = AtomicUsize::new(0);

/// RAII guard that requests progress output be suppressed for as long as it is alive.
///
/// Guards nest: suppression stays active until every silent guard has been dropped.
/// A guard created with `silent == false` does not touch the counter and is inert.
#[must_use = "progress is only suppressed while the guard is alive; bind it to a named variable"]
#[derive(Debug)]
pub struct ProgressScopeGuard {
    /// Whether this guard incremented the suppression counter and must undo it on drop.
    suppressed: bool,
}

impl ProgressScopeGuard {
    /// Creates a guard. When `silent` is `true`, [`progress_suppressed`] reports `true`
    /// until this guard (and any other silent guard) is dropped.
    pub fn new(silent: bool) -> Self {
        if silent {
            PROGRESS_SUPPRESS_COUNT.fetch_add(1, Ordering::SeqCst);
        }
        Self { suppressed: silent }
    }
}

impl Drop for ProgressScopeGuard {
    /// Releases this guard's suppression request, if it made one.
    fn drop(&mut self) {
        if self.suppressed {
            PROGRESS_SUPPRESS_COUNT.fetch_sub(1, Ordering::SeqCst);
        }
    }
}

/// Returns `true` while at least one silent [`ProgressScopeGuard`] is alive.
#[inline]
pub fn progress_suppressed() -> bool {
    PROGRESS_SUPPRESS_COUNT.load(Ordering::SeqCst) > 0
}
/// Hides `bar` when a silent [`ProgressScopeGuard`] is active; otherwise leaves its
/// draw target untouched.
#[inline]
pub fn configure_progress_bar(bar: &ProgressBar) {
    if !progress_suppressed() {
        return;
    }
    bar.set_draw_target(ProgressDrawTarget::hidden());
}
/// Builds a [`MultiProgress`], hiding it up front if progress output is currently
/// suppressed by a silent [`ProgressScopeGuard`].
pub fn new_multi_progress() -> MultiProgress {
    let manager = MultiProgress::new();
    if !progress_suppressed() {
        return manager;
    }
    manager.set_draw_target(ProgressDrawTarget::hidden());
    manager
}
pub trait IterWithProgress<'a, T>: Iterator<Item = T> + 'a {
fn with_progress(self, is_silent: bool) -> Box<dyn Iterator<Item = T> + 'a>
where
Self: Sized,
{
if is_silent {
Box::new(self)
} else {
Box::new(self.tqdm())
}
}
}
impl<'a, T: Iterator + 'a> IterWithProgress<'a, T::Item> for T {}
/// Wraps an exact-size iterator so that iterating it (sequentially via `IntoIterator`
/// or in parallel via rayon's `IntoParallelIterator`) drives a labelled progress bar.
///
/// `COLOR` selects the bar color: `'b'` (blue, the default), `'g'` (green) or `'r'`
/// (red). Any other char causes a panic when the wrapper is iterated.
pub struct NiceProgressBar<'a, T: ExactSizeIterator, const COLOR: char = 'b'>(
    /// Iterator whose items are yielded; its `len()` sizes the bar.
    pub T,
    /// Label rendered in front of the bar.
    pub &'static str,
    /// Multi-progress container the bar is normally added to.
    pub &'a MultiProgress,
);
impl<T: ExactSizeIterator, const COLOR: char> IntoIterator for NiceProgressBar<'_, T, COLOR> {
type IntoIter = ProgressBarIter<T>;
type Item = T::Item;
fn into_iter(self) -> Self::IntoIter {
let color = match COLOR {
'b' => "blue",
'g' => "green",
'r' => "red",
other => panic!("Color char `{other}` not supported"),
};
let bar = ProgressBar::new(self.0.len() as u64);
configure_progress_bar(&bar);
bar.set_style(
ProgressStyle::default_bar()
.template(&format!(
"{}: [{{elapsed_precise}}] [{{bar:40.{color}/{color}}}] {{pos}}/{{len}} ({{eta}})",
self.1
))
.unwrap()
.progress_chars("#>-"),
);
self.2.add(bar.clone());
self.0.progress_with(bar)
}
}
/// Rayon parallel-iterator adapter that advances `bar` by one for every item the
/// wrapped iterator yields.
pub struct ParProgress<I> {
    /// The wrapped parallel iterator.
    iter: I,
    /// Bar incremented once per item.
    bar: ProgressBar,
}
/// Unindexed driving: items pass through unchanged, and the bar is incremented once each.
impl<I> ParallelIterator for ParProgress<I>
where
    I: ParallelIterator,
    I::Item: Send,
{
    type Item = I::Item;

    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: rayon::iter::plumbing::UnindexedConsumer<Self::Item>,
    {
        let bar = self.bar.clone();
        let iter = self.iter.map(move |item| {
            bar.inc(1);
            item
        });
        iter.drive_unindexed(consumer)
    }

    /// Forwards the wrapped iterator's length hint.
    ///
    /// The default implementation returns `None`, which hides the length of an indexed
    /// inner iterator. That forces consumers such as `collect::<Vec<_>>()` onto rayon's
    /// slower unindexed path. Delegating is accurate because `drive_unindexed` simply
    /// forwards to the inner iterator through a length-preserving `map`.
    fn opt_len(&self) -> Option<usize> {
        self.iter.opt_len()
    }
}
impl<I> IndexedParallelIterator for ParProgress<I>
where
I: IndexedParallelIterator,
I::Item: Send,
{
fn len(&self) -> usize {
self.iter.len()
}
fn drive<C>(self, consumer: C) -> C::Result
where
C: rayon::iter::plumbing::Consumer<Self::Item>,
{
let bar = self.bar.clone();
let iter = self.iter.map(move |item| {
bar.inc(1);
item
});
iter.drive(consumer)
}
fn with_producer<CB>(self, callback: CB) -> CB::Output
where
CB: rayon::iter::plumbing::ProducerCallback<Self::Item>,
{
let bar = self.bar.clone();
let iter = self.iter.map(move |item| {
bar.inc(1);
item
});
iter.with_producer(callback)
}
}
impl<'a, T, const COLOR: char> IntoParallelIterator for NiceProgressBar<'a, T, COLOR>
where
T: ExactSizeIterator + IntoParallelIterator + Send + Sync + 'a,
<T as IntoParallelIterator>::Item: Send + 'a,
T::Iter: ParallelIterator<Item = <T as IntoParallelIterator>::Item>
+ IndexedParallelIterator<Item = <T as IntoParallelIterator>::Item>
+ Send,
{
type Iter = ParProgress<T::Iter>;
type Item = <T as IntoParallelIterator>::Item;
fn into_par_iter(self) -> Self::Iter {
let color = match COLOR {
'b' => "blue",
'g' => "green",
'r' => "red",
other => panic!("Color char `{other}` not supported"),
};
let bar = ProgressBar::new(self.0.len() as u64);
configure_progress_bar(&bar);
bar.set_style(
ProgressStyle::default_bar()
.template(&format!(
"{}: [{{elapsed_precise}}] [{{bar:40.{color}/{color}}}] {{pos}}/{{len}} ({{eta}})",
self.1
))
.unwrap()
.progress_chars("#>-"),
);
self.2.add(bar.clone());
ParProgress {
iter: self.0.into_par_iter(),
bar,
}
}
}
/// Convenience drivers that map a fallible closure over every wrapped item while the
/// progress bar advances.
impl<'a, T, const COLOR: char> NiceProgressBar<'a, T, COLOR>
where
    T: ExactSizeIterator + IntoParallelIterator + Send + Sync + 'a,
    <T as IntoParallelIterator>::Item: Send + 'a,
    T::Iter: ParallelIterator<Item = <T as IntoParallelIterator>::Item>
        + IndexedParallelIterator<Item = <T as IntoParallelIterator>::Item>
        + Send,
    T: IntoParallelIterator<Item = <T as Iterator>::Item>,
{
    /// Applies `f` to each item in order, returning all results or the first error.
    ///
    /// NOTE: `_is_parallel` is currently ignored. Items are always processed
    /// sequentially, on the calling thread, through the sequential progress-bar iterator.
    pub fn run<F, U>(self, _is_parallel: bool, f: F) -> candle_core::Result<Vec<U>>
    where
        F: Fn(<T as IntoParallelIterator>::Item) -> candle_core::Result<U> + Sync + Send,
        U: Send,
    {
        let mut outputs = Vec::new();
        for item in self {
            outputs.push(f(item)?);
        }
        Ok(outputs)
    }

    /// Calls [`Self::run`], passing as the parallelism hint whether an immediate ISQ
    /// setting with a quantization type is configured (per `get_immediate_isq`).
    ///
    /// Because `run` currently ignores that hint, this also processes items sequentially.
    pub fn par_iter_if_isq<F, U>(self, f: F) -> candle_core::Result<Vec<U>>
    where
        F: Fn(<T as IntoParallelIterator>::Item) -> candle_core::Result<U> + Sync + Send,
        U: Send,
    {
        let isq_requested = get_immediate_isq().is_some_and(|isq| isq.ty.is_some());
        self.run(isq_requested, f)
    }
}