use rand::RngExt;
use rand::rngs::StdRng;
use serde::{Deserialize, Serialize};
/// Selects which row-sampling strategy to use.
///
/// NOTE(review): the mapping from variant to sampler is not visible in this
/// file. Presumably `Random` corresponds to [`RandomSampler`] and `None` means
/// "no subsampling"; confirm against the caller that matches on this enum.
///
/// Variant order is intentionally unchanged: non-self-describing serde
/// formats encode unit variants by index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SampleMethod {
    None,
    Random,
}
/// A strategy for splitting a set of indices into two disjoint groups.
pub trait Sampler {
    /// Splits `index` into `(chosen, excluded)`.
    ///
    /// The random number generator is supplied by the caller rather than
    /// stored in the sampler. NOTE(review): presumably this is so the caller
    /// controls seeding and reproducibility; confirm.
    fn sample(&mut self, rng: &mut StdRng, index: &[usize]) -> (Vec<usize>, Vec<usize>);
}
/// A [`Sampler`] that keeps each index independently with probability
/// `subsample` (for `subsample` in `[0, 1]`).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RandomSampler {
    /// Threshold compared against one random `f32` draw in `[0, 1)` per index.
    /// Draws strictly below it select the index. Consequently `<= 0.0`
    /// (or NaN) selects nothing, and `>= 1.0` selects everything.
    subsample: f32,
}

impl RandomSampler {
    /// Creates a sampler that keeps each index with probability `subsample`.
    ///
    /// `subsample` is not validated; out-of-range values saturate as
    /// described on the field.
    // Within this file only the tests call `new`, so non-test builds of a
    // crate that never constructs it would warn; hence the allow.
    #[allow(dead_code)]
    pub fn new(subsample: f32) -> Self {
        Self { subsample }
    }
}
impl Sampler for RandomSampler {
    /// Routes every element of `index` to `chosen` or `excluded` with one
    /// random `f32` draw per element.
    ///
    /// An element is chosen when its draw is below `self.subsample`.
    /// Relative order is preserved in both outputs. Exactly `index.len()`
    /// values are drawn from `rng`, in order, whatever `subsample` is.
    fn sample(&mut self, rng: &mut StdRng, index: &[usize]) -> (Vec<usize>, Vec<usize>) {
        let threshold = self.subsample;
        // `partition` calls the predicate once per item in iteration order,
        // so the RNG stream is consumed exactly as by an explicit loop.
        index
            .iter()
            .copied()
            .partition(|_| rng.random::<f32>() < threshold)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;

    /// Asserts that `chosen` and `excluded` are an order-preserving split of
    /// `index`. Together they must contain every element exactly once, and
    /// each must stay in the original order. Requires `index` to be strictly
    /// ascending so that "original order" can be checked as "ascending".
    fn assert_order_preserving_split(index: &[usize], chosen: &[usize], excluded: &[usize]) {
        assert!(
            index.windows(2).all(|w| w[0] < w[1]),
            "helper requires a strictly ascending index"
        );
        assert!(chosen.windows(2).all(|w| w[0] < w[1]));
        assert!(excluded.windows(2).all(|w| w[0] < w[1]));
        let mut merged: Vec<usize> = chosen.iter().chain(excluded).copied().collect();
        merged.sort_unstable();
        assert_eq!(merged, index);
    }

    #[test]
    fn test_random_sampler() {
        let mut rng = StdRng::seed_from_u64(42);
        let index: Vec<usize> = (0..10).collect();

        let mut sampler = RandomSampler::new(0.5);
        let (chosen, excluded) = sampler.sample(&mut rng, &index);
        // These two checks rely on the fixed seed producing a non-trivial split.
        assert!(!chosen.is_empty());
        assert!(!excluded.is_empty());
        assert_order_preserving_split(&index, &chosen, &excluded);

        let mut sampler_all = RandomSampler::new(1.0);
        let (chosen_all, excluded_all) = sampler_all.sample(&mut rng, &index);
        assert_eq!(chosen_all, index);
        assert!(excluded_all.is_empty());

        let mut sampler_none = RandomSampler::new(0.0);
        let (chosen_none, excluded_none) = sampler_none.sample(&mut rng, &index);
        assert!(chosen_none.is_empty());
        assert_eq!(excluded_none, index);
    }
}