/// Extension methods available on every [`Iterator`].
pub trait IteratorExt: Iterator {
    /// Like [`Iterator::filter`], but tallies every accept/reject decision in `count`.
    ///
    /// `count` is reset to [`FilterCount::default()`] as soon as this method is
    /// called. After that, it is only updated as the returned [`CountingFilter`]
    /// is advanced.
    fn filter_cnt<P>(self, count: &mut FilterCount, pred: P) -> CountingFilter<'_, P, Self>
    where
        Self: Sized,
        P: FnMut(&Self::Item) -> bool,
    {
        // Discard any totals left over from a previous use of `count`.
        let zeroed = FilterCount::default();
        *count = zeroed;
        CountingFilter {
            pred,
            count,
            inner: self,
        }
    }
}
/// Every iterator gets the [`IteratorExt`] methods for free.
impl<I: Iterator> IteratorExt for I {}
/// Tally of how many items a filter accepted and how many it rejected.
///
/// Filled in by [`IteratorExt::filter_cnt`], or updated by hand through
/// [`FilterCount::count`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
// The fields are deliberately public so callers can read and construct counts
// directly, so opting out of clippy's `exhaustive_structs` lint is intentional.
#[allow(clippy::exhaustive_structs)]
pub struct FilterCount {
    /// Number of items that were accepted.
    pub n_accepted: usize,
    /// Number of items that were rejected.
    pub n_rejected: usize,
}
/// Iterator adaptor returned by [`IteratorExt::filter_cnt`].
///
/// Yields the items of the wrapped iterator that satisfy the predicate and
/// records every accept/reject decision in the borrowed [`FilterCount`]. Like
/// the std adaptors, it is lazy: nothing is counted until it is advanced.
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct CountingFilter<'a, P, I> {
    /// The underlying iterator being filtered.
    inner: I,
    /// Predicate deciding whether an item is yielded.
    pred: P,
    /// Tally updated as items are examined.
    count: &'a mut FilterCount,
}
impl<'a, P, I> Iterator for CountingFilter<'a, P, I>
where
P: FnMut(&I::Item) -> bool,
I: Iterator,
{
type Item = I::Item;
fn next(&mut self) -> Option<Self::Item> {
for item in &mut self.inner {
if (self.pred)(&item) {
self.count.n_accepted += 1;
return Some(item);
} else {
self.count.n_rejected += 1;
}
}
None
}
}
impl FilterCount {
    /// Return a borrowed wrapper whose `Display` output is `rejected/total`.
    pub fn display_frac_rejected(&self) -> DisplayFracRejected<'_> {
        DisplayFracRejected(self)
    }

    /// Record a single decision and hand `accept` straight back.
    ///
    /// `true` increments `n_accepted`; `false` increments `n_rejected`.
    /// Returning the input lets this wrap a predicate inline, e.g.
    /// `iter.filter(|x| count.count(pred(x)))`.
    pub fn count(&mut self, accept: bool) -> bool {
        let slot = if accept {
            &mut self.n_accepted
        } else {
            &mut self.n_rejected
        };
        *slot += 1;
        accept
    }
}
/// Borrowed view of a [`FilterCount`] that formats as `rejected/total`.
///
/// Obtained from [`FilterCount::display_frac_rejected`].
#[derive(Clone, Debug)]
pub struct DisplayFracRejected<'a>(&'a FilterCount);
impl std::fmt::Display for DisplayFracRejected<'_> {
    /// Write `<rejected>/<total>`, where total is accepted plus rejected.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let counts = self.0;
        let rejected = counts.n_rejected;
        let total = counts.n_accepted + rejected;
        write!(f, "{}/{}", rejected, total)
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod test {
    use super::*;

    #[test]
    fn counting_filter() {
        let mut count = FilterCount::default();
        let v = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
        let first_even = v
            .iter()
            .filter_cnt(&mut count, |val| **val % 2 == 0)
            .next()
            .unwrap();
        assert_eq!(*first_even, 2);
        assert_eq!(count.n_accepted, 1);
        assert_eq!(count.n_rejected, 1);
        // `filter_cnt` resets `count`, so these totals cover only the second pass.
        let sum_even: usize = v.iter().filter_cnt(&mut count, |val| **val % 2 == 0).sum();
        assert_eq!(sum_even, 20);
        assert_eq!(count.n_accepted, 4);
        assert_eq!(count.n_rejected, 5);
    }

    #[test]
    fn counting_with_predicates() {
        let mut count = FilterCount::default();
        let v = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
        let first_even = v.iter().find(|val| count.count(**val % 2 == 0)).unwrap();
        assert_eq!(*first_even, 2);
        assert_eq!(count.n_accepted, 1);
        assert_eq!(count.n_rejected, 1);
        let mut count = FilterCount::default();
        let sum_even: usize = v.iter().filter(|val| count.count(**val % 2 == 0)).sum();
        assert_eq!(sum_even, 20);
        assert_eq!(count.n_accepted, 4);
        assert_eq!(count.n_rejected, 5);
    }

    #[test]
    fn display_frac_rejected() {
        let count = FilterCount {
            n_accepted: 3,
            n_rejected: 2,
        };
        assert_eq!(count.display_frac_rejected().to_string(), "2/5");
        assert_eq!(
            FilterCount::default().display_frac_rejected().to_string(),
            "0/0"
        );
    }
}