use compact_str::CompactString;
use crate::{
progress::{Progress, ProgressType},
stack::ProgressStack,
};
/// Iterator adapter that advances a [`Progress`] handle as items are yielded.
///
/// Each item produced by the wrapped iterator increments the progress by one,
/// and exhaustion of the wrapped iterator finishes it. Construct one with
/// [`ProgressIter::new`] or through the [`ProgressIteratorExt`] methods.
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct ProgressIter<I> {
    /// The wrapped iterator whose items are passed through unchanged.
    iter: I,
    /// Progress handle updated on every call to `next`.
    progress: Progress,
}
impl<I> ProgressIter<I> {
    /// Pairs `iter` with an existing `progress` handle.
    ///
    /// The handle is stored as given; its kind, name and total are not
    /// adjusted here.
    pub const fn new(iter: I, progress: Progress) -> Self {
        Self {
            progress,
            iter,
        }
    }
}
impl<I: Iterator> Iterator for ProgressIter<I> {
    type Item = I::Item;

    /// Pulls the next item from the wrapped iterator.
    ///
    /// Increments the progress by one when an item is produced, and calls
    /// `finish` on the progress when the wrapped iterator returns `None`
    /// (this happens on every `None`, not just the first).
    fn next(&mut self) -> Option<Self::Item> {
        let item = self.iter.next();
        if item.is_some() {
            self.progress.inc(1u64);
        } else {
            self.progress.finish();
        }
        item
    }

    /// Reports the wrapped iterator's size hint unchanged.
    ///
    /// The adapter yields exactly the items of the inner iterator, so its hint
    /// is identical. Forwarding it lets `collect` preallocate and lets a
    /// wrapped iterator still be recognised as exact-length by
    /// `type_from_size_hint` (the default would be `(0, None)`).
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}
/// Extension methods that wrap an iterator in a [`ProgressIter`].
///
/// A blanket implementation is provided for every [`Iterator`].
pub trait ProgressIteratorExt: Sized {
    /// Wraps `self` using an already-constructed `progress` handle.
    fn progress_with(self, progress: Progress) -> ProgressIter<Self>;

    /// Wraps `self` in a new progress labelled `name`; the progress kind and
    /// total come from [`ProgressIteratorExt::type_from_size_hint`].
    fn progress_with_name(self, name: impl Into<CompactString>) -> ProgressIter<Self>;

    /// Like [`ProgressIteratorExt::progress_with_name`] with an empty name.
    fn progress(self) -> ProgressIter<Self>;

    /// Wraps `self` in an unnamed progress registered on `stack`: a bar when
    /// the length is known exactly, otherwise a spinner.
    fn progress_in(self, stack: &ProgressStack) -> ProgressIter<Self>;

    /// Picks the progress kind from the size hint: `(Bar, len)` when the lower
    /// and upper bounds agree, otherwise `(Spinner, 0)`.
    fn type_from_size_hint(&self) -> (ProgressType, u64);
}
impl<I: Iterator> ProgressIteratorExt for I {
    fn progress(self) -> ProgressIter<Self> {
        // Delegate with an empty name.
        self.progress_with_name(CompactString::default())
    }

    fn progress_with_name(self, name: impl Into<CompactString>) -> ProgressIter<Self> {
        // Compute the kind/total before `self` is moved into the adapter.
        let (kind, total) = self.type_from_size_hint();
        self.progress_with(Progress::new(kind, name, total))
    }

    fn progress_with(self, progress: Progress) -> ProgressIter<Self> {
        ProgressIter::new(self, progress)
    }

    fn progress_in(self, stack: &ProgressStack) -> ProgressIter<Self> {
        let handle = match self.type_from_size_hint() {
            (ProgressType::Bar, len) => stack.add_pb(CompactString::default(), len),
            (ProgressType::Spinner, _) => stack.add_spinner(CompactString::default()),
        };
        self.progress_with(handle)
    }

    fn type_from_size_hint(&self) -> (ProgressType, u64) {
        match self.size_hint() {
            // Bounds agree: the exact length is known, so a bar can show it.
            (lower, Some(upper)) if lower == upper => (ProgressType::Bar, upper as u64),
            // Unknown or inexact length: fall back to a spinner with no total.
            _ => (ProgressType::Spinner, 0),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::ProgressIteratorExt as _;

    /// Exhausting a wrapped slice iterator should count every item, finish the
    /// progress, and take its total from the iterator's exact size hint.
    #[test]
    fn test_iterator_adapter() {
        let data = [1, 2, 3, 4, 5];
        let mut count = 0;
        let iter = data.iter().progress_with_name("iter_test");
        // Clone of the adapter's handle, inspected after `iter` is consumed
        // (assumes clones share the underlying progress state).
        let progress_handle = iter.progress.clone();
        for _ in iter {
            count += 1;
        }
        assert_eq!(count, 5);
        assert_eq!(progress_handle.get_pos(), 5);
        assert!(
            progress_handle.is_finished(),
            "Iterator exhaustion should finish progress"
        );
        assert_eq!(
            progress_handle.get_total(),
            5,
            "Total should be inferred from the slice iterator's exact size_hint"
        );
    }
}