use super::Rng;
#[cfg(all(feature="alloc", not(feature="std")))] use alloc::vec::Vec;
#[cfg(feature="std")] use std::collections::HashMap;
#[cfg(all(feature="alloc", not(feature="std")))] use alloc::collections::BTreeMap;
/// Randomly sample exactly `amount` elements from `iterable` using reservoir sampling.
///
/// The whole iterator is consumed. The result is not shuffled: the reservoir
/// starts out holding the first `amount` elements in iteration order, and each
/// later element may overwrite one randomly chosen slot.
///
/// # Errors
///
/// If `iterable` yields fewer than `amount` elements, returns `Err` holding all
/// of them, in iteration order. In that case no random numbers are drawn.
pub fn sample_iter<T, I, R>(rng: &mut R, iterable: I, amount: usize) -> Result<Vec<T>, Vec<T>>
where
    I: IntoIterator<Item = T>,
    R: Rng + ?Sized,
{
    let mut iter = iterable.into_iter();
    // Do not preallocate `amount` blindly. A caller may pass a huge `amount`
    // (e.g. `usize::MAX`) to mean "as many as there are", and
    // `Vec::with_capacity(amount)` would then panic/abort instead of returning
    // `Err`. The iterator's lower size bound is always safe to reserve; `extend`
    // grows the vector further if more elements arrive.
    let capacity = amount.min(iter.size_hint().0);
    let mut reservoir = Vec::with_capacity(capacity);
    reservoir.extend(iter.by_ref().take(amount));

    // `take(amount)` yields at most `amount` items, so `<` means "too short".
    if reservoir.len() < amount {
        reservoir.shrink_to_fit();
        return Err(reservoir);
    }

    for (i, elem) in iter.enumerate() {
        // This is element number `amount + i` (0-based). Drawing `k` from
        // `0..amount + i + 1` makes it land in the reservoir only when `k < amount`.
        let k = rng.gen_range(0, i + 1 + amount);
        if let Some(spot) = reservoir.get_mut(k) {
            *spot = elem;
        }
    }
    Ok(reservoir)
}
/// Randomly sample `amount` elements from `slice`, returning clones of them.
///
/// Positions are chosen with `sample_indices`, so every returned element comes
/// from a distinct position of `slice`.
///
/// # Panics
///
/// Panics if `amount > slice.len()`.
pub fn sample_slice<R, T>(rng: &mut R, slice: &[T], amount: usize) -> Vec<T>
where
    R: Rng + ?Sized,
    T: Clone,
{
    let picked = sample_indices(rng, slice.len(), amount);
    picked.into_iter().map(|idx| slice[idx].clone()).collect()
}
/// Randomly sample `amount` elements from `slice`, returning references to them.
///
/// Like `sample_slice`, but borrows instead of cloning. Every reference points
/// at a distinct position of `slice`.
///
/// # Panics
///
/// Panics if `amount > slice.len()`.
pub fn sample_slice_ref<'a, R, T>(rng: &mut R, slice: &'a [T], amount: usize) -> Vec<&'a T>
where
    R: Rng + ?Sized,
{
    sample_indices(rng, slice.len(), amount)
        .into_iter()
        .map(|idx| &slice[idx])
        .collect()
}
/// Randomly choose `amount` distinct indices from `0..length`.
///
/// Two strategies implement the same partial Fisher–Yates shuffle:
/// `sample_indices_inplace` materializes all `length` indices, while
/// `sample_indices_cache` only records displaced entries in a map. The map-based
/// variant is used only when `amount` is small relative to `length` (below
/// `length / 20`); presumably a memory/speed tuning threshold.
///
/// # Panics
///
/// Panics if `amount > length`.
pub fn sample_indices<R>(rng: &mut R, length: usize, amount: usize) -> Vec<usize>
where
    R: Rng + ?Sized,
{
    assert!(amount <= length, "`amount` must be less than or equal to `slice.len()`");

    let use_cache = amount < length / 20;
    if use_cache {
        sample_indices_cache(rng, length, amount)
    } else {
        sample_indices_inplace(rng, length, amount)
    }
}
/// Choose `amount` distinct indices from `0..length` by partially shuffling a
/// full `0..length` index vector.
///
/// Allocates `length` slots up front. The caller must ensure `amount <= length`.
fn sample_indices_inplace<R>(rng: &mut R, length: usize, amount: usize) -> Vec<usize>
where
    R: Rng + ?Sized,
{
    debug_assert!(amount <= length);
    // Start from the identity permutation of `0..length`.
    let mut pool: Vec<usize> = (0..length).collect();
    // Partial Fisher–Yates: step `pos` fixes slot `pos` by swapping in an entry
    // chosen with `gen_range` from the not-yet-fixed tail `pos..length`.
    for pos in 0..amount {
        let pick: usize = rng.gen_range(pos, length);
        pool.swap(pos, pick);
    }
    // Only the fixed prefix is the sample.
    pool.truncate(amount);
    debug_assert_eq!(pool.len(), amount);
    pool
}
/// Choose `amount` distinct indices from `0..length` by simulating the same
/// partial Fisher–Yates shuffle as `sample_indices_inplace`, without
/// materializing the whole index array.
///
/// The map stores only slots whose value has been displaced. A missing key
/// means the slot still holds its own index. The RNG is called with the same
/// arguments, in the same order, as in the in-place variant. The caller must
/// ensure `amount <= length`.
fn sample_indices_cache<R>(rng: &mut R, length: usize, amount: usize) -> Vec<usize>
where
    R: Rng + ?Sized,
{
    debug_assert!(amount <= length);
    #[cfg(feature="std")]
    let mut cache: HashMap<usize, usize> = HashMap::with_capacity(amount);
    #[cfg(not(feature="std"))]
    let mut cache: BTreeMap<usize, usize> = BTreeMap::new();

    let mut out = Vec::with_capacity(amount);
    for pos in 0..amount {
        let pick: usize = rng.gen_range(pos, length);
        // Read both virtual slots before writing, so `pos == pick` behaves
        // like a no-op swap.
        let at_pos = *cache.get(&pos).unwrap_or(&pos);
        let at_pick = *cache.get(&pick).unwrap_or(&pick);
        // Slot `pos` is final and is emitted directly; only `pick` needs recording.
        cache.insert(pick, at_pos);
        out.push(at_pick);
    }
    debug_assert_eq!(out.len(), amount);
    out
}
#[cfg(test)]
mod test {
    use super::*;
    use {XorShiftRng, Rng, SeedableRng};
    #[cfg(not(feature="std"))]
    use alloc::vec::Vec;

    /// `sample_iter` returns exactly `amount` items when enough exist, and the
    /// whole (unmodified) input as `Err` otherwise.
    #[test]
    fn test_sample_iter() {
        let min_val = 1;
        let max_val = 100;

        let mut r = ::test::rng(401);
        let vals = (min_val..max_val).collect::<Vec<i32>>();
        let small_sample = sample_iter(&mut r, vals.iter(), 5).unwrap();
        let large_sample = sample_iter(&mut r, vals.iter(), vals.len() + 5).unwrap_err();

        assert_eq!(small_sample.len(), 5);
        assert_eq!(large_sample.len(), vals.len());
        // A too-short input comes back whole and in iteration order.
        assert_eq!(large_sample, vals.iter().collect::<Vec<_>>());

        // `vals` was built from the half-open range `min_val..max_val`.
        assert!(small_sample.iter().all(|e| {
            **e >= min_val && **e < max_val
        }));
    }

    /// Degenerate sizes for every sampler, plus a rough fairness check for
    /// picking one of two elements.
    #[test]
    fn test_sample_slice_boundaries() {
        let empty: &[u8] = &[];

        let mut r = ::test::rng(402);

        // sample 0 items
        assert_eq!(&sample_slice(&mut r, empty, 0)[..], [0u8; 0]);
        assert_eq!(&sample_slice(&mut r, &[42, 2, 42], 0)[..], [0u8; 0]);

        // sample 1 item
        assert_eq!(&sample_slice(&mut r, &[42], 1)[..], [42]);
        let v = sample_slice(&mut r, &[1, 42], 1)[0];
        assert!(v == 1 || v == 42);

        // sample "all" the items
        let v = sample_slice(&mut r, &[42, 133], 2);
        assert!(&v[..] == [42, 133] || v[..] == [133, 42]);

        assert_eq!(&sample_indices_inplace(&mut r, 0, 0)[..], [0usize; 0]);
        assert_eq!(&sample_indices_inplace(&mut r, 1, 0)[..], [0usize; 0]);
        assert_eq!(&sample_indices_inplace(&mut r, 1, 1)[..], [0]);

        assert_eq!(&sample_indices_cache(&mut r, 0, 0)[..], [0usize; 0]);
        assert_eq!(&sample_indices_cache(&mut r, 1, 0)[..], [0usize; 0]);
        assert_eq!(&sample_indices_cache(&mut r, 1, 1)[..], [0]);

        // Make sure lucky 777's aren't lucky
        let slice = &[42, 777];
        let mut num_42 = 0;
        let total = 1000;
        for _ in 0..total {
            let v = sample_slice(&mut r, slice, 1);
            assert_eq!(v.len(), 1);
            let v = v[0];
            assert!(v == 42 || v == 777);
            if v == 42 {
                num_42 += 1;
            }
        }
        let ratio_42 = num_42 as f64 / total as f64;
        // Both bounds must hold; with `||` this assertion could never fail.
        assert!(0.4 <= ratio_42 && ratio_42 <= 0.6, "{}", ratio_42);
    }

    /// With identical seeds, both index strategies, the dispatcher and the
    /// slice helpers must all agree.
    #[test]
    fn test_sample_slice() {
        let xor_rng = XorShiftRng::from_seed;

        let max_range = 100;
        let mut r = ::test::rng(403);

        for length in 1usize..max_range {
            // `length + 1` so that `amount == length` is also exercised.
            let amount = r.gen_range(0, length + 1);
            let mut seed = [0u8; 16];
            r.fill(&mut seed);

            // assert the basics work
            let inplace = sample_indices_inplace(
                &mut xor_rng(seed), length, amount);
            let cache = sample_indices_cache(
                &mut xor_rng(seed), length, amount);
            assert_eq!(inplace, cache);

            let regular = sample_indices(
                &mut xor_rng(seed), length, amount);
            assert_eq!(regular.len(), amount);
            assert!(regular.iter().all(|e| *e < length));
            assert_eq!(regular, inplace);

            // also test that sampling the slice works
            let vec: Vec<usize> = (0..length).collect();
            {
                let result = sample_slice(&mut xor_rng(seed), &vec, amount);
                assert_eq!(result, regular);
            }

            {
                let result = sample_slice_ref(&mut xor_rng(seed), &vec, amount);
                let expected = regular.iter().collect::<Vec<_>>();
                assert_eq!(result, expected);
            }
        }
    }
}