#[cfg(test)]
#[path = "../../tests/unit/utils/iterators_test.rs"]
mod iterators_test;
use crate::utils::*;
use std::collections::HashMap;
use std::fmt::{Debug, Formatter};
use std::hash::Hash;
use std::sync::Arc;
/// Extension methods which gather the items of an iterator into groups stored in a hash map.
pub trait CollectGroupBy: Iterator {
    /// Groups the items of the iterator by the key which `f` derives from each of them.
    ///
    /// Items which share a key are kept in the order in which the iterator yields them.
    fn collect_group_by_key<K, V, FA>(self, f: FA) -> HashMap<K, Vec<V>>
    where
        Self: Sized + Iterator<Item = V>,
        K: Hash + Eq,
        FA: Fn(&V) -> K,
    {
        // The key has to be computed before the value is moved into the pair.
        let keyed = self.map(|value| {
            let key = f(&value);
            (key, value)
        });

        keyed.collect_group_by()
    }

    /// Groups the `(key, value)` pairs of the iterator by their key.
    ///
    /// Values which share a key are kept in the order in which the iterator yields them.
    fn collect_group_by<K, V>(self) -> HashMap<K, Vec<V>>
    where
        Self: Sized + Iterator<Item = (K, V)>,
        K: Hash + Eq,
    {
        self.fold(HashMap::<K, Vec<V>>::new(), |mut groups, (key, value)| {
            groups.entry(key).or_default().push(value);
            groups
        })
    }
}

/// Makes the grouping methods available on every iterator.
impl<T: Iterator> CollectGroupBy for T {}
/// Iterator adaptor which yields a random subset of the wrapped iterator's items while keeping
/// their relative order.
///
/// An item is accepted when `random.is_hit` returns true for the probability `needed / left`,
/// where `needed` is the amount of items still wanted and `left` is the amount of source items
/// which have not been examined yet.
pub struct SelectionSamplingIterator<I: Iterator> {
    /// Amount of items pulled from the wrapped iterator so far.
    processed: usize,
    /// Amount of items which may still be returned.
    needed: usize,
    /// Lower bound of the wrapped iterator's size hint, captured at construction time.
    size: usize,
    /// The wrapped iterator.
    iterator: I,
    /// Source of randomness used to accept or skip an item.
    random: Arc<dyn Random>,
}

impl<I: Iterator> SelectionSamplingIterator<I> {
    /// Creates a sampling iterator which returns at most `amount` items of `iterator`.
    ///
    /// The length of the source is taken from the lower bound of its size hint, so a source
    /// which reports zero there produces no items.
    ///
    /// # Panics
    ///
    /// Panics if `amount` is zero.
    pub fn new(iterator: I, amount: usize, random: Arc<dyn Random>) -> Self {
        assert!(amount > 0);

        let (size, _) = iterator.size_hint();

        Self { processed: 0, needed: amount, size, iterator, random }
    }
}

impl<I: Iterator> Iterator for SelectionSamplingIterator<I> {
    type Item = I::Item;

    fn next(&mut self) -> Option<Self::Item> {
        // Keep pulling items until one is accepted, enough items were returned or the
        // expected amount of source items has been examined.
        while self.needed != 0 && self.processed < self.size {
            let left = self.size - self.processed;
            let probability = self.needed as Float / left as Float;

            self.processed += 1;

            match self.iterator.next() {
                Some(item) => {
                    if self.random.is_hit(probability) {
                        self.needed -= 1;
                        return Some(item);
                    }
                }
                // The source ended earlier than expected: no randomness is consulted.
                None => {
                    self.needed -= 1;
                    return None;
                }
            }
        }

        None
    }
}
/// Returns an iterator over one randomly chosen run of at most `sample_size` consecutive items.
///
/// The lower bound of the iterator's size hint is divided into consecutive windows of
/// `sample_size` items. `random.uniform_int` is asked for a window index between zero and the
/// index of the last full window, and the iterator is advanced to the start of that window.
pub fn create_range_sampling_iter<I: Iterator>(
    iterator: I,
    sample_size: usize,
    random: &(dyn Random),
) -> impl Iterator<Item = I::Item> {
    let (lower_bound, _) = iterator.size_hint();

    let windows = lower_bound as Float / sample_size as Float;
    // At least one window is assumed, hence the index of the last one is never negative.
    let last_window = windows.max(1.) - 1.;
    let chosen_window = random.uniform_int(0, last_window as i32) as usize;

    iterator.skip(chosen_window * sample_size).take(sample_size)
}
/// Extension trait which adds a sampling based search to cloneable iterators.
pub trait SelectionSamplingSearch: Iterator {
/// Searches for the best item by probing random samples of a progressively narrowed index range.
///
/// The search runs in rounds. Every round clones the iterator, restricts the clone to the
/// current inclusive index range `[left, right]`, samples it with `SelectionSamplingIterator`
/// and folds the sampled items into the search state, which may replace the best candidate
/// and adjust the range. Rounds repeat until `SearchState::is_terminal` returns true, that is
/// until the range has collapsed or the budget of repeated probes is used up.
///
/// # Arguments
///
/// * `sample_size` - amount of items requested from the sampler in every round; it is also
///   used as the initial budget of repeated probes.
/// * `random` - source of randomness handed to the sampler.
/// * `map_fn` - converts an item into the value which is compared and eventually returned;
///   it is called for the very first probe and afterwards only for indices not probed before.
/// * `index_fn` - returns the index of an item; the index has to lie inside of the range
///   sampled in the current round.
/// * `compare_fn` - called as `compare_fn(candidate, best)`; returning `true` makes the
///   candidate the new best value.
///
/// # Returns
///
/// The value kept as best when the search stops, or `None` when the lower bound of the
/// iterator's size hint is zero or `sample_size` is zero.
///
/// # Panics
///
/// Panics when `index_fn` returns an index outside of the range sampled in the current round.
fn sample_search<'a, T, R, FM, FI, FC>(
self,
sample_size: usize,
random: Arc<dyn Random>,
mut map_fn: FM,
index_fn: FI,
compare_fn: FC,
) -> Option<R>
where
Self: Sized + Clone + Iterator<Item = T> + 'a,
T: 'a,
R: 'a,
FM: FnMut(T) -> R,
FI: Fn(&T) -> usize,
FC: Fn(&R, &R) -> bool,
{
// Const parameter of the bit array in which the state remembers already probed indices.
// NOTE(review): the amount of indices it can hold depends on FixedBitArray - confirm there.
const N: usize = 32;
// Only the lower bound of the size hint is used as the length of the sequence.
let size = self.size_hint().0;
if size == 0 || sample_size == 0 {
return None;
}
// The initial range is [0, size - 1]; the sample size doubles as the budget of repeated probes.
let mut state = SearchState::<N, R>::new(sample_size, size);
loop {
// Restrict a fresh clone of the iterator to the inclusive index range [left, right].
let (skip, take) = (state.left, state.right - state.left + 1);
let iterator = self.clone().skip(skip).take(take);
// `orig_right` is the right bound at the start of the round; `last_probe_idx` is the value
// `probe_idx` is compared with below to recognize the final probe of the round.
let (orig_right, last_probe_idx) = (state.right, take.min(sample_size - 1));
state = SelectionSamplingIterator::new(iterator, sample_size, random.clone())
.enumerate()
.fold(state, |mut acc, (probe_idx, item)| {
let item_idx = index_fn(&item);
// Records the index and tells whether it is seen for the first time; a repeated index
// consumes one unit of the budget of repeated probes.
let is_new_item = acc.probe(item_idx);
assert!(
item_idx >= skip && item_idx <= orig_right,
"caller's index_fn returns an index outside of expected range"
);
match &acc.best {
// The very first probe of the whole search becomes the best candidate unconditionally.
BestItem::Unknown => acc.best = BestItem::Fresh((item_idx, map_fn(item))),
BestItem::Fresh((best_idx, best_value)) | BestItem::Stale((best_idx, best_value)) => {
if matches!(acc.best, BestItem::Stale(_)) {
// The best candidate comes from a previous round: move the bounds towards its index
// using the current probe; the existing range is never widened here.
acc.left = ((item_idx + 1).min(*best_idx)).max(acc.left);
acc.right = ((item_idx.max(1) - 1).max(*best_idx)).min(acc.right);
} else {
// The best candidate was found in this round: when the previous probe was that
// candidate, the range ends right before the current probe (`max(1) - 1` subtracts
// one without underflowing at index zero).
if acc.last == *best_idx {
acc.right = item_idx.max(1) - 1
}
}
// Indices probed before are not mapped and compared again.
if is_new_item {
let item_value = map_fn(item);
if compare_fn(&item_value, best_value) {
// A better item becomes the new best candidate and resets the range: it starts right
// after the previous probe (or stays as it is for the first probe of the round) and
// ends at the current probe (or at the round's right bound for the final probe).
acc.best = BestItem::Fresh((item_idx, item_value));
acc.left = if probe_idx == 0 { acc.left } else { acc.last + 1 };
acc.right = if probe_idx == last_probe_idx { orig_right } else { item_idx };
}
}
}
}
// Remember the index of this probe for the next one.
acc.last = item_idx;
acc
})
// Mark the best candidate as stale before the next round starts.
.next_range();
if state.is_terminal() {
break;
}
}
state.best.get_value()
}
}
/// Makes the sampling search available on every iterator.
impl<T: Iterator> SelectionSamplingSearch for T {}
/// Best candidate known to the search, stored together with the index it was found at.
enum BestItem<T> {
    /// No candidate has been found yet.
    Unknown,
    /// Candidate which was found in one of the previous sampling rounds.
    Stale((usize, T)),
    /// Candidate which was found in the current sampling round.
    Fresh((usize, T)),
}

impl<T> BestItem<T> {
    /// Consumes the item and returns the stored value, or `None` when nothing was found.
    fn get_value(self) -> Option<T> {
        match self {
            BestItem::Fresh((_, value)) => Some(value),
            BestItem::Stale((_, value)) => Some(value),
            BestItem::Unknown => None,
        }
    }
}

/// Prints only the index of the candidate, or `X` when there is none.
impl<T> Debug for BestItem<T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        if let BestItem::Stale((idx, _)) | BestItem::Fresh((idx, _)) = self {
            write!(f, "{idx}")
        } else {
            write!(f, "X")
        }
    }
}
/// State which the sampling search carries from one probe to the next one.
struct SearchState<const N: usize, T> {
    /// Inclusive lower bound of the index range which is still considered.
    left: usize,
    /// Inclusive upper bound of the index range which is still considered.
    right: usize,
    /// Index of the most recently probed item.
    last: usize,
    /// Best candidate found so far.
    best: BestItem<T>,
    /// Remembers which indices have been probed already.
    bit_array: FixedBitArray<N>,
    /// Amount of repeated probes which are still tolerated before the search stops.
    collisions_limit: i32,
}

impl<const N: usize, T> SearchState<N, T> {
    /// Creates a state which covers the indices `0..=size - 1` and has no best candidate.
    ///
    /// `size` is expected to be non-zero.
    pub fn new(collisions_limit: usize, size: usize) -> Self {
        let right = size - 1;
        let collisions_limit = collisions_limit as i32;

        Self {
            left: 0,
            right,
            last: 0,
            best: BestItem::Unknown,
            bit_array: FixedBitArray::default(),
            collisions_limit,
        }
    }

    /// Marks `index` as probed and returns `true` when it was not marked before.
    ///
    /// Probing an already marked index consumes one unit of the collisions limit.
    pub fn probe(&mut self, index: usize) -> bool {
        let seen_before = self.bit_array.replace(index, true);
        if seen_before {
            self.collisions_limit -= 1;
        }

        !seen_before
    }

    /// Prepares the state for the next sampling round by turning a known best candidate
    /// into a stale one; everything else is kept as it is.
    pub fn next_range(self) -> Self {
        let Self { left, right, last, best, bit_array, collisions_limit } = self;

        let best = match best {
            BestItem::Stale(item) | BestItem::Fresh(item) => BestItem::Stale(item),
            BestItem::Unknown => BestItem::Unknown,
        };

        Self { left, right, last, best, bit_array, collisions_limit }
    }

    /// Returns `true` when the range holds at most one index or no collisions are left.
    pub fn is_terminal(&self) -> bool {
        let has_range = self.left < self.right;
        let has_budget = self.collisions_limit > 0;

        !(has_range && has_budget)
    }
}

/// Prints the range, the remaining collisions, the index of the best candidate and the
/// binary representation of the probed indices.
impl<const N: usize, T> Debug for SearchState<N, T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let range = (self.left, self.right);

        let mut output = f.debug_struct(short_type_name::<Self>());
        output.field("range", &range);
        output.field("col_lim", &self.collisions_limit);
        output.field("best_idx", &self.best);
        output.field("bits", &format!("{:b}", self.bit_array));
        output.finish()
    }
}