use crate::rng::Rng;
/// A source of dataset indices for an epoch over `len()` items.
///
/// The `Send` supertrait lets a sampler be moved to another thread.
pub trait Sampler: Send {
    /// Number of items this sampler draws indices from.
    fn len(&self) -> usize;

    /// Returns `true` when the sampler covers zero items.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// Produces the indices to visit for the given `epoch`.
    fn indices(&mut self, epoch: usize) -> Vec<usize>;
}
/// Sampler that shuffles the indices `0..n`, reseeding the shuffle for each epoch.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RandomSampler {
    /// Number of items whose indices are shuffled.
    n: usize,
    /// Base seed; combined with the epoch number to seed each epoch's shuffle.
    seed: u64,
}
impl RandomSampler {
pub fn new(n: usize, seed: u64) -> Self {
RandomSampler { n, seed }
}
}
impl Sampler for RandomSampler {
    /// Returns the number of items, `n`.
    fn len(&self) -> usize {
        self.n
    }

    /// Returns the indices `0..n` shuffled by an `Rng` seeded from `(seed, epoch)`.
    ///
    /// A new `Rng` is built on every call and `self` is not modified, so the
    /// same `(seed, epoch)` always reproduces the same order no matter how
    /// many times `indices` was called before.
    fn indices(&mut self, epoch: usize) -> Vec<usize> {
        // 2^64 / golden ratio: an odd constant whose multiples are spread across all 64 bits.
        const EPOCH_SEED_STRIDE: u64 = 0x9E37_79B9_7F4A_7C15;
        // Previously the epoch was *added* to the seed, so (seed, epoch + 1) and
        // (seed + 1, epoch) produced the identical shuffle, and runs with adjacent
        // seeds replayed each other's orders shifted by one epoch. XOR-ing in a
        // widely spread multiple of the epoch avoids that. Epoch 0 still maps to
        // `seed` itself. `usize -> u64` is lossless on all targets Rust supports.
        let epoch_seed = self.seed ^ (epoch as u64).wrapping_mul(EPOCH_SEED_STRIDE);
        let mut rng = Rng::seed(epoch_seed);
        let mut idx: Vec<usize> = (0..self.n).collect();
        rng.shuffle(&mut idx);
        idx
    }
}
/// Sampler that always yields the indices `0..n` in ascending order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SequentialSampler {
    /// Number of items to index.
    n: usize,
}
impl SequentialSampler {
pub fn new(n: usize) -> Self {
SequentialSampler { n }
}
}
impl Sampler for SequentialSampler {
    /// Returns the number of items, `n`.
    fn len(&self) -> usize {
        self.n
    }

    /// Returns `0, 1, ..., n - 1`; `epoch` has no effect on the order.
    fn indices(&mut self, _epoch: usize) -> Vec<usize> {
        let mut order = Vec::with_capacity(self.n);
        order.extend(0..self.n);
        order
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_random_sampler_permutation() {
        let mut sampler = RandomSampler::new(10, 42);
        let idx = sampler.indices(0);
        assert_eq!(idx.len(), 10);
        let mut sorted = idx.clone();
        sorted.sort_unstable();
        assert_eq!(sorted, (0..10).collect::<Vec<_>>());
    }

    #[test]
    fn test_random_sampler_different_epochs() {
        let mut sampler = RandomSampler::new(100, 42);
        let epoch0 = sampler.indices(0);
        let epoch1 = sampler.indices(1);
        assert_ne!(epoch0, epoch1);
    }

    #[test]
    fn test_random_sampler_reproducible() {
        let mut s1 = RandomSampler::new(100, 42);
        let mut s2 = RandomSampler::new(100, 42);
        assert_eq!(s1.indices(5), s2.indices(5));
    }

    #[test]
    fn test_random_sampler_different_seeds() {
        let mut s1 = RandomSampler::new(100, 42);
        let mut s2 = RandomSampler::new(100, 99);
        assert_ne!(s1.indices(0), s2.indices(0));
    }

    // `indices` takes `&mut self`; check that repeated or interleaved calls on
    // one instance do not change the order returned for a given epoch.
    #[test]
    fn test_random_sampler_stateless_across_calls() {
        let mut sampler = RandomSampler::new(50, 7);
        let first = sampler.indices(3);
        let _ = sampler.indices(4);
        assert_eq!(sampler.indices(3), first);
    }

    #[test]
    fn test_sequential_sampler() {
        let mut sampler = SequentialSampler::new(5);
        assert_eq!(sampler.indices(0), vec![0, 1, 2, 3, 4]);
        assert_eq!(sampler.indices(10), vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_sequential_sampler_stable() {
        let mut sampler = SequentialSampler::new(20);
        let a = sampler.indices(0);
        let b = sampler.indices(1);
        assert_eq!(a, b);
    }

    #[test]
    fn test_sequential_sampler_empty() {
        let mut sampler = SequentialSampler::new(0);
        assert!(sampler.indices(0).is_empty());
    }

    #[test]
    fn test_sampler_len() {
        let s1 = RandomSampler::new(50, 0);
        assert_eq!(s1.len(), 50);
        let s2 = SequentialSampler::new(30);
        assert_eq!(s2.len(), 30);
    }

    // Exercises the trait's default `is_empty`, which delegates to `len`.
    #[test]
    fn test_sampler_is_empty() {
        assert!(SequentialSampler::new(0).is_empty());
        assert!(!SequentialSampler::new(3).is_empty());
        assert!(RandomSampler::new(0, 1).is_empty());
        assert!(!RandomSampler::new(1, 1).is_empty());
    }
}