use std::collections::{HashMap, HashSet};
use crate::rand::seq::index::sample as choose_range;
use crate::rand::Rng;
/// A container that can report how many elements it holds.
///
/// Implemented in this file for `Vec`, `HashSet` and `HashMap` by forwarding to their
/// inherent `len` methods.
pub trait HasLen {
    /// Returns the number of elements currently in the container.
    fn len(&self) -> usize;

    /// Returns `true` when the container holds no elements.
    ///
    /// Provided in terms of [`HasLen::len`], so existing implementors need no changes.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// A container that can be iterated through a shared borrow.
///
/// Both associated types are generic over the lifetime of that borrow, so an
/// implementor can hand out items that borrow from the container itself.
pub trait HasIter {
/// Type of the values produced while the container is borrowed for `'a`.
type Item<'a>
where
Self: 'a;
/// Iterator returned by [`HasIter::iter`]; it yields `Self::Item<'a>`.
type Iter<'a>: Iterator<Item = Self::Item<'a>>
where
Self: 'a;
/// Returns an iterator over the container's elements, borrowing `self`.
fn iter(&self) -> Self::Iter<'_>;
}
// Implements `HasLen` for a generic type written as `Name<A, B, ...>`.
//
// The listed identifiers become the generic parameters of the generated impl,
// without any bounds, and `len` forwards to the type's inherent `len` method.
macro_rules! impl_has_len {
($ty:ident < $($gen:ident),* >) => {
impl<$($gen),*> HasLen for $ty<$($gen),*> {
fn len(&self) -> usize {
// Fully qualified path: this names the type's own `len`, not `HasLen::len`.
<$ty<$($gen),*>>::len(self)
}
}
};
}
// Implements `HasIter` for a generic type written as `Name<A, B, ...>`.
//
// Arguments: the type, the concrete iterator type, and the item type. The
// invocations in this file spell the iterator and item types with `'a`, which
// is the lifetime parameter declared on the associated types below.
macro_rules! impl_has_iter {
($ty:ident < $($gen:ident),* >, $iter:ty, $item:ty) => {
impl<$($gen),*> HasIter for $ty<$($gen),*> {
type Item<'a> = $item where Self: 'a;
type Iter<'a> = $iter where Self: 'a;
fn iter(&self) -> Self::Iter<'_> {
// Forwards to the type's inherent `iter` method.
<$ty<$($gen),*>>::iter(self)
}
}
};
}
// `HasLen` for `Vec<T>`, forwarding to `Vec::len`.
impl_has_len!(Vec<T>);
/// `HasIter` for `Vec<T>`: iteration goes through the vector's slice view and yields
/// shared references to the elements in index order.
impl<T> HasIter for Vec<T> {
    type Item<'v> = &'v T where Self: 'v;
    type Iter<'v> = std::slice::Iter<'v, T> where Self: 'v;

    fn iter(&self) -> Self::Iter<'_> {
        // `iter` is a slice method, so borrow the contents as a slice first.
        self.as_slice().iter()
    }
}
// `HashSet` (for any hasher parameter `H`): length plus iteration over `&T`.
impl_has_len!(HashSet<T, H>);
impl_has_iter!(HashSet<T, H>, std::collections::hash_set::Iter<'a, T>, &'a T);
// `HashMap` (for any hasher parameter `H`): length plus iteration over `(&K, &V)` pairs.
// The sampling functions in this file require `Item<'a> = &'a T`, which this
// tuple item does not match.
impl_has_len!(HashMap<K, V, H>);
impl_has_iter!(HashMap<K, V, H>, std::collections::hash_map::Iter<'a, K, V>, (&'a K, &'a V));
/// Draws one element uniformly at random from a container that knows its length.
///
/// One index in `0..set.len()` is drawn from `rng`, the container's iterator is advanced
/// to that position, and the element found there is cloned.
///
/// Returns `None` when the container is empty (no random number is drawn in that case),
/// or when the iterator yields fewer elements than `len` reported.
pub fn sample_single_from_known_length<'a, Container, R, T>(
    rng: &mut R,
    set: &'a Container,
) -> Option<T>
where
    R: Rng,
    Container: HasLen + HasIter<Item<'a> = &'a T>,
    T: Clone + 'static,
{
    let len = set.len();
    if len == 0 {
        return None;
    }
    // Lengths that fit in `u32` keep the 32-bit draw. Larger lengths must be sampled as
    // `usize`: truncating them to `u32` would make the tail unreachable and could even
    // produce the empty range `0..0`.
    let index = if len <= u32::MAX as usize {
        rng.random_range(0..len as u32) as usize
    } else {
        rng.random_range(0..len)
    };
    set.iter().nth(index).cloned()
}
/// Draws one element uniformly at random from a container of unknown length in a single
/// pass, using reservoir sampling with geometric skips (a reservoir of size one).
///
/// The first element is always taken as the initial candidate. After each pick the number
/// of elements to skip before the next pick is `floor(ln(u) / ln(1 - weight))` for a fresh
/// uniform `u`, and `weight` is then multiplied by another uniform draw.
///
/// Returns `None` when the container is empty; one value is still drawn from `rng` in
/// that case. Only the final candidate is cloned.
pub fn sample_single_l_reservoir<'a, Container, R, T>(rng: &mut R, set: &'a Container) -> Option<T>
where
    R: Rng,
    Container: HasIter<Item<'a> = &'a T>,
    T: Clone + 'static,
{
    let mut weight: f64 = rng.random_range(0.0..1.0);
    let mut chosen: Option<&'a T> = None;
    let mut items = set.iter();
    // Number of elements to pass over before the next pick; zero selects the first element.
    let mut skip: usize = 0;

    // `nth(skip)` discards `skip` elements and yields the following one, so iterators with
    // a cheap `nth` (such as slice iterators) never visit the skipped elements.
    while let Some(item) = items.nth(skip) {
        chosen = Some(item);

        let log_uniform = f64::ln(rng.random_range(0.0..1.0));
        let log_keep = f64::ln(1.0 - weight);
        skip = if log_keep == 0.0 {
            // `weight` is too small to register in `1.0 - weight`: the chance of another
            // pick is effectively zero, so skip everything that is left.
            usize::MAX
        } else {
            // The float-to-int cast saturates, so an infinite ratio (uniform draw of
            // exactly 0) becomes `usize::MAX` instead of overflowing.
            (log_uniform / log_keep).floor() as usize
        };
        weight *= rng.random_range(0.0..1.0);
    }

    chosen.cloned()
}
/// Draws up to `requested` distinct elements uniformly at random from a container that
/// knows its length.
///
/// Distinct indices are drawn up front, sorted, and then collected in one forward pass over
/// the container, so the result follows the container's iteration order.
///
/// When `requested` exceeds the container's length, every element is returned; when either
/// `requested` or the length is zero, the result is empty and `rng` is not used.
pub fn sample_multiple_from_known_length<'a, Container, R, T>(
    rng: &mut R,
    set: &'a Container,
    requested: usize,
) -> Vec<T>
where
    R: Rng,
    Container: HasLen + HasIter<Item<'a> = &'a T>,
    T: Clone + 'static,
{
    let population = set.len();
    // The index sampler rejects requests larger than the population, and the walk below
    // needs at least one index, so clamp and bail out early.
    let amount = requested.min(population);
    if amount == 0 {
        return Vec::new();
    }

    let mut indexes = Vec::with_capacity(amount);
    indexes.extend(choose_range(rng, population, amount));
    indexes.sort_unstable();

    let mut selected = Vec::with_capacity(amount);
    let mut items = set.iter();
    // Number of elements already consumed from `items`.
    let mut consumed = 0;
    for idx in indexes {
        // Indices are distinct and ascending, so `idx >= consumed` always holds here.
        match items.nth(idx - consumed) {
            Some(item) => selected.push(item.clone()),
            // The iterator ended before `len` elements were produced; return what we have.
            None => break,
        }
        consumed = idx + 1;
    }
    selected
}
/// Draws up to `requested` distinct elements uniformly at random from a container of
/// unknown length in a single pass, using reservoir sampling with geometric skips.
///
/// The first `requested` elements fill the reservoir. From then on, the number of elements
/// to skip before the next replacement is `floor(ln(u) / ln(1 - weight))` for a fresh
/// uniform `u`; each replacement evicts a uniformly chosen reservoir slot, and `weight` is
/// multiplied by `u'^(1 / requested)` for another uniform `u'`.
///
/// Returns an empty vector when `requested` is zero (without touching `rng`), and all
/// elements when the container holds fewer than `requested`. The order of the returned
/// elements is not the container's iteration order once replacements have happened.
pub fn sample_multiple_l_reservoir<'a, Container, R, T>(
    rng: &mut R,
    set: &'a Container,
    requested: usize,
) -> Vec<T>
where
    R: Rng,
    Container: HasIter<Item<'a> = &'a T>,
    T: Clone + 'static,
{
    if requested == 0 {
        return Vec::new();
    }
    let inverse_size = 1.0 / requested as f64;
    let mut weight: f64 = rng.random_range(0.0..1.0);
    weight = weight.powf(inverse_size);

    let mut items = set.iter();
    // Never reserve more than the iterator says it can produce: `requested` may be far
    // larger than the container.
    let capacity = items
        .size_hint()
        .1
        .map_or(requested, |upper| upper.min(requested));
    let mut reservoir = Vec::with_capacity(capacity);

    // Fill phase: the first `requested` elements are taken unconditionally.
    for item in items.by_ref().take(requested) {
        reservoir.push(item.clone());
    }
    if reservoir.len() < requested {
        return reservoir;
    }

    // Replacement phase: jump ahead by a geometric skip, then overwrite a random slot.
    loop {
        let log_uniform = f64::ln(rng.random_range(0.0..1.0));
        let log_keep = f64::ln(1.0 - weight);
        let skip = if log_keep == 0.0 {
            // `weight` no longer registers in `1.0 - weight`: further replacements are
            // effectively impossible, so skip everything that is left.
            usize::MAX
        } else {
            // The float-to-int cast saturates, so an infinite ratio (uniform draw of
            // exactly 0) becomes `usize::MAX` instead of overflowing.
            (log_uniform / log_keep).floor() as usize
        };
        let uniform_random: f64 = rng.random_range(0.0..1.0);
        weight *= uniform_random.powf(inverse_size);

        let item = match items.nth(skip) {
            Some(item) => item,
            None => break,
        };
        let to_remove = rng.random_range(0..reservoir.len());
        reservoir.swap_remove(to_remove);
        reservoir.push(item.clone());
    }
    reservoir
}
#[cfg(test)]
mod tests {
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    use super::*;
    use crate::hashing::{HashSet, HashSetExt};

    /// Pearson's χ² statistic of `observed` counts against a common `expected` count.
    fn chi_square_statistic(observed: &[usize], expected: f64) -> f64 {
        observed
            .iter()
            .map(|&obs| {
                let diff = (obs as f64) - expected;
                diff * diff / expected
            })
            .sum()
    }

    #[test]
    fn test_sample_single_l_reservoir_basic() {
        let data: Vec<u32> = (0..1000).collect();
        let seed: u64 = 42;
        let mut rng = StdRng::seed_from_u64(seed);
        let sample = sample_single_l_reservoir(&mut rng, &data);
        assert!(sample.is_some());
        let value = sample.unwrap();
        assert!(value < 1000);
    }

    #[test]
    fn test_sample_single_l_reservoir_empty() {
        let data: Vec<u32> = Vec::new();
        let mut rng = StdRng::seed_from_u64(42);
        let sample = sample_single_l_reservoir(&mut rng, &data);
        assert!(sample.is_none());
    }

    #[test]
    fn test_sample_single_l_reservoir_single_element() {
        let data: Vec<u32> = vec![42];
        let mut rng = StdRng::seed_from_u64(1);
        let sample = sample_single_l_reservoir(&mut rng, &data);
        assert_eq!(sample, Some(42));
    }

    /// Bins 10 000 single samples of `0..1000` into 10 equal ranges and checks the bin
    /// counts against a uniform expectation with a χ² test (9 degrees of freedom).
    #[test]
    fn test_sample_single_l_reservoir_uniformity() {
        let population: u32 = 1000;
        let data: Vec<u32> = (0..population).collect();
        let num_runs = 10000;
        let num_bins = 10;
        let mut counts = vec![0usize; num_bins];
        for run in 0..num_runs {
            let mut rng = StdRng::seed_from_u64(42 + run as u64);
            let sample = sample_single_l_reservoir(&mut rng, &data);
            if let Some(value) = sample {
                let bin = (value as usize) / (population as usize / num_bins);
                counts[bin] += 1;
            }
        }
        let expected = num_runs as f64 / num_bins as f64;
        let chi_square = chi_square_statistic(&counts, expected);
        // 0.999 quantile of the χ² distribution with 9 degrees of freedom.
        let critical = 27.877;
        println!("χ² = {}, counts = {:?}", chi_square, counts);
        assert!(
            chi_square < critical,
            "Single sample fails uniformity test: χ² = {}, counts = {:?}",
            chi_square,
            counts
        );
    }

    #[test]
    fn test_sample_single_l_reservoir_hashset() {
        let mut data = HashSet::new();
        for i in 0..100 {
            data.insert(i);
        }
        let mut rng = StdRng::seed_from_u64(42);
        let sample = sample_single_l_reservoir(&mut rng, &data);
        assert!(sample.is_some());
        let value = sample.unwrap();
        assert!(data.contains(&value));
    }

    #[test]
    fn test_sample_multiple_l_reservoir_basic() {
        let data: Vec<u32> = (0..1000).collect();
        let requested = 100;
        let seed: u64 = 42;
        let mut rng = StdRng::seed_from_u64(seed);
        let sample = sample_multiple_l_reservoir(&mut rng, &data, requested);
        assert_eq!(sample.len(), requested);
        assert!(sample.iter().all(|v| *v < 1000));
        let unique: HashSet<_> = sample.iter().collect();
        assert_eq!(unique.len(), sample.len());
    }

    #[test]
    fn test_sample_multiple_l_reservoir_empty() {
        let data: Vec<u32> = Vec::new();
        let mut rng = StdRng::seed_from_u64(42);
        let sample = sample_multiple_l_reservoir(&mut rng, &data, 10);
        assert_eq!(sample.len(), 0);
    }

    #[test]
    fn test_sample_multiple_l_reservoir_zero_requested() {
        let data: Vec<u32> = (0..100).collect();
        let mut rng = StdRng::seed_from_u64(42);
        let sample = sample_multiple_l_reservoir(&mut rng, &data, 0);
        assert_eq!(sample.len(), 0);
    }

    #[test]
    fn test_sample_multiple_l_reservoir_requested_exceeds_population() {
        let data: Vec<u32> = (0..50).collect();
        let requested = 100;
        let mut rng = StdRng::seed_from_u64(42);
        let sample = sample_multiple_l_reservoir(&mut rng, &data, requested);
        assert_eq!(sample.len(), 50);
        let unique: HashSet<_> = sample.iter().collect();
        assert_eq!(unique.len(), 50);
        assert!(sample.iter().all(|v| *v < 50));
    }

    #[test]
    fn test_sample_multiple_l_reservoir_exact_population() {
        let data: Vec<u32> = (0..100).collect();
        let mut rng = StdRng::seed_from_u64(42);
        let sample = sample_multiple_l_reservoir(&mut rng, &data, 100);
        assert_eq!(sample.len(), 100);
        let unique: HashSet<_> = sample.iter().collect();
        assert_eq!(unique.len(), 100);
    }

    #[test]
    fn test_sample_multiple_l_reservoir_single_element() {
        let data: Vec<u32> = vec![42];
        let mut rng = StdRng::seed_from_u64(1);
        let sample = sample_multiple_l_reservoir(&mut rng, &data, 1);
        assert_eq!(sample.len(), 1);
        assert_eq!(sample[0], 42);
    }

    #[test]
    fn test_sample_multiple_l_reservoir_hashset() {
        let mut data = HashSet::new();
        for i in 0..100 {
            data.insert(i);
        }
        let mut rng = StdRng::seed_from_u64(42);
        let sample = sample_multiple_l_reservoir(&mut rng, &data, 10);
        assert_eq!(sample.len(), 10);
        assert!(sample.iter().all(|v| data.contains(v)));
        let unique: HashSet<_> = sample.iter().collect();
        assert_eq!(unique.len(), 10);
    }

    #[test]
    fn test_sample_multiple_l_reservoir_small_sample() {
        let data: Vec<u32> = (0..1000).collect();
        let requested = 5;
        let mut rng = StdRng::seed_from_u64(42);
        let sample = sample_multiple_l_reservoir(&mut rng, &data, requested);
        assert_eq!(sample.len(), requested);
        let unique: HashSet<_> = sample.iter().collect();
        assert_eq!(unique.len(), requested);
    }

    #[test]
    fn test_sample_multiple_l_reservoir_large_sample() {
        let data: Vec<u32> = (0..1000).collect();
        let requested = 900;
        let mut rng = StdRng::seed_from_u64(42);
        let sample = sample_multiple_l_reservoir(&mut rng, &data, requested);
        assert_eq!(sample.len(), requested);
        let unique: HashSet<_> = sample.iter().collect();
        assert_eq!(unique.len(), requested);
    }

    /// Two-level goodness-of-fit check: each run's sample is binned into 10 ranges and
    /// reduced to a χ² statistic; those statistics are then binned by the deciles of the
    /// χ² distribution with 9 degrees of freedom and tested for uniformity themselves.
    #[test]
    fn test_sample_multiple_l_reservoir_uniformity() {
        let population: u32 = 10000;
        let data: Vec<u32> = (0..population).collect();
        let requested = 100;
        let num_runs = 1000;
        let mut chi_squares = Vec::with_capacity(num_runs);
        for run in 0..num_runs {
            let mut rng = StdRng::seed_from_u64(42 + run as u64);
            let sample = sample_multiple_l_reservoir(&mut rng, &data, requested);
            let mut counts = [0usize; 10];
            for &value in &sample {
                let bin = (value as usize) / (population as usize / 10);
                counts[bin] += 1;
            }
            let expected = requested as f64 / 10.0;
            chi_squares.push(chi_square_statistic(&counts, expected));
        }
        // Decile boundaries of the χ² distribution with 9 degrees of freedom.
        let quantiles = [
            0.0,
            4.16816,
            5.38005,
            6.39331,
            7.35703,
            8.34283,
            9.41364,
            10.6564,
            12.2421,
            14.6837,
            f64::INFINITY,
        ];
        let num_bins = quantiles.len() - 1;
        let mut chi_square_counts = vec![0usize; num_bins];
        for &chi_sq in &chi_squares {
            // First half-open interval [lower, upper) that contains the statistic, if any.
            let bin = quantiles
                .windows(2)
                .position(|bounds| chi_sq >= bounds[0] && chi_sq < bounds[1]);
            if let Some(i) = bin {
                chi_square_counts[i] += 1;
            }
        }
        let expected_per_bin = num_runs as f64 / num_bins as f64;
        let chi_square_of_chi_squares = chi_square_statistic(&chi_square_counts, expected_per_bin);
        // 0.999 quantile of the χ² distribution with 9 degrees of freedom.
        let critical = 27.877;
        println!(
            "χ² = {}, counts = {:?}",
            chi_square_of_chi_squares, chi_square_counts
        );
        assert!(
            chi_square_of_chi_squares < critical,
            "Chi-square statistics fail to follow chi-square(9) distribution: χ² = {}, counts = {:?}",
            chi_square_of_chi_squares,
            chi_square_counts
        );
    }

    /// Counts how often each of 100 elements is selected over 10 000 runs of size 10 and
    /// checks the counts against a uniform expectation with a χ² test.
    #[test]
    fn test_sample_multiple_l_reservoir_element_probability() {
        let population: u32 = 100;
        let data: Vec<u32> = (0..population).collect();
        let requested = 10;
        let num_runs = 10000;
        let mut selection_counts = vec![0usize; population as usize];
        for run in 0..num_runs {
            let mut rng = StdRng::seed_from_u64(42 + run as u64);
            let sample = sample_multiple_l_reservoir(&mut rng, &data, requested);
            for &value in &sample {
                selection_counts[value as usize] += 1;
            }
        }
        let expected = (num_runs * requested) as f64 / population as f64;
        let chi_square = chi_square_statistic(&selection_counts, expected);
        // 0.999 quantile of the χ² distribution with 99 degrees of freedom.
        let critical = 148.23;
        println!(
            "χ² = {}, expected = {}, min = {}, max = {}",
            chi_square,
            expected,
            selection_counts.iter().min().unwrap(),
            selection_counts.iter().max().unwrap()
        );
        assert!(
            chi_square < critical,
            "Element selection probabilities are not uniform: χ² = {}",
            chi_square
        );
    }

    #[test]
    fn test_sample_multiple_l_reservoir_reproducibility() {
        let data: Vec<u32> = (0..1000).collect();
        let test_sizes = [1, 2, 5, 10, 100, 500];
        for &requested in &test_sizes {
            let seed: u64 = 12345;
            let mut rng1 = StdRng::seed_from_u64(seed);
            let sample1 = sample_multiple_l_reservoir(&mut rng1, &data, requested);
            let mut rng2 = StdRng::seed_from_u64(seed);
            let sample2 = sample_multiple_l_reservoir(&mut rng2, &data, requested);
            assert_eq!(
                sample1.len(),
                requested,
                "Sample size {} doesn't match requested size {}",
                sample1.len(),
                requested
            );
            assert_eq!(
                sample2.len(),
                requested,
                "Sample size {} doesn't match requested size {}",
                sample2.len(),
                requested
            );
            assert_eq!(
                sample1, sample2,
                "Reproducibility failed for requested={}",
                requested
            );
        }
    }

    #[test]
    fn test_sample_single_l_reservoir_reproducibility() {
        let data: Vec<u32> = (0..1000).collect();
        let seed: u64 = 12345;
        let mut rng1 = StdRng::seed_from_u64(seed);
        let sample1 = sample_single_l_reservoir(&mut rng1, &data);
        let mut rng2 = StdRng::seed_from_u64(seed);
        let sample2 = sample_single_l_reservoir(&mut rng2, &data);
        assert_eq!(sample1, sample2);
    }

    #[test]
    fn test_sample_single_from_known_length_basic() {
        let data: Vec<u32> = (0..1000).collect();
        let mut rng = StdRng::seed_from_u64(42);
        let sample = sample_single_from_known_length(&mut rng, &data);
        assert!(matches!(sample, Some(value) if value < 1000));
    }

    #[test]
    fn test_sample_single_from_known_length_empty() {
        let data: Vec<u32> = Vec::new();
        let mut rng = StdRng::seed_from_u64(42);
        let sample = sample_single_from_known_length(&mut rng, &data);
        assert!(sample.is_none());
    }

    #[test]
    fn test_sample_single_from_known_length_single_element() {
        let data: Vec<u32> = vec![42];
        let mut rng = StdRng::seed_from_u64(1);
        let sample = sample_single_from_known_length(&mut rng, &data);
        assert_eq!(sample, Some(42));
    }

    #[test]
    fn test_sample_multiple_from_known_length_basic() {
        let data: Vec<u32> = (0..1000).collect();
        let requested = 100;
        let mut rng = StdRng::seed_from_u64(42);
        let sample = sample_multiple_from_known_length(&mut rng, &data, requested);
        assert_eq!(sample.len(), requested);
        assert!(sample.iter().all(|v| *v < 1000));
        let unique: HashSet<_> = sample.iter().collect();
        assert_eq!(unique.len(), requested);
        // Elements are collected in one forward pass, so a `Vec` input yields them in order.
        assert!(sample.windows(2).all(|pair| pair[0] < pair[1]));
    }

    #[test]
    fn test_sample_multiple_from_known_length_exact_population() {
        let data: Vec<u32> = (0..100).collect();
        let mut rng = StdRng::seed_from_u64(42);
        let sample = sample_multiple_from_known_length(&mut rng, &data, 100);
        assert_eq!(sample, data);
    }
}