use rand::Rng;
use rand::seq::SliceRandom;
/// A source of `usize` indices.
///
/// Implementors report how many indices a pass is expected to produce and hand
/// out a boxed iterator over them. The `Send + Sync` supertraits require every
/// sampler to be shareable across threads.
pub trait Sampler: Send + Sync {
    /// Number of indices this sampler reports for one pass.
    fn len(&self) -> usize;

    /// Returns `true` when [`Sampler::len`] reports zero.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// Starts a new pass and returns an iterator that may borrow from `self`.
    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_>;
}
/// Sampler over the index range `0..len`.
pub struct SequentialSampler {
    /// Exclusive upper bound of the index range.
    len: usize,
}

impl SequentialSampler {
    /// Creates a sampler covering the indices `0..len`.
    #[must_use]
    pub fn new(len: usize) -> Self {
        SequentialSampler { len }
    }
}
impl Sampler for SequentialSampler {
    /// Returns the `len` supplied at construction.
    fn len(&self) -> usize {
        self.len
    }

    /// Yields `0, 1, ..., len - 1` in ascending order.
    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
        let ascending = 0..self.len;
        Box::new(ascending)
    }
}
/// Sampler that either shuffles the indices `0..len` or draws from them with
/// replacement.
pub struct RandomSampler {
    /// Size of the index population (`0..len`).
    len: usize,
    /// Whether indices are drawn with replacement.
    replacement: bool,
    /// Requested number of draws; `None` falls back to `len`.
    num_samples: Option<usize>,
}

impl RandomSampler {
    /// Creates a sampler without replacement over `0..len`.
    #[must_use]
    pub fn new(len: usize) -> Self {
        let replacement = false;
        let num_samples = None;
        RandomSampler {
            len,
            replacement,
            num_samples,
        }
    }

    /// Creates a sampler that draws `num_samples` indices from `0..len` with
    /// replacement.
    #[must_use]
    pub fn with_replacement(len: usize, num_samples: usize) -> Self {
        let num_samples = Some(num_samples);
        RandomSampler {
            len,
            replacement: true,
            num_samples,
        }
    }
}
impl Sampler for RandomSampler {
    /// Number of indices one pass yields.
    ///
    /// An empty population (`len == 0`) can never yield an index, so this
    /// reports `0` regardless of the requested `num_samples`. Otherwise it is
    /// `num_samples` when set and the population size when not.
    fn len(&self) -> usize {
        if self.len == 0 {
            return 0;
        }
        self.num_samples.unwrap_or(self.len)
    }

    /// Starts a new pass.
    ///
    /// With replacement, each item is drawn from `0..len` via
    /// `RandomWithReplacementIter`; without replacement, the indices `0..len`
    /// are shuffled and yielded once each. An empty population yields nothing.
    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
        // Guard first: drawing from the empty range `0..0` would panic inside
        // `gen_range`, so an empty population short-circuits to an empty pass.
        if self.len == 0 {
            return Box::new(std::iter::empty());
        }
        if self.replacement {
            let draws = self.num_samples.unwrap_or(self.len);
            Box::new(RandomWithReplacementIter::new(self.len, draws))
        } else {
            let mut indices: Vec<usize> = (0..self.len).collect();
            indices.shuffle(&mut rand::thread_rng());
            Box::new(indices.into_iter())
        }
    }
}
/// Iterator that draws a fixed number of indices from `0..len`, with
/// replacement, using the thread-local RNG.
struct RandomWithReplacementIter {
    /// Size of the index population (`0..len`).
    len: usize,
    /// Draws still to be produced.
    remaining: usize,
}

impl RandomWithReplacementIter {
    /// Creates an iterator that yields `num_samples` draws from `0..len`.
    fn new(len: usize, num_samples: usize) -> Self {
        Self {
            len,
            remaining: num_samples,
        }
    }
}

impl Iterator for RandomWithReplacementIter {
    type Item = usize;

    /// Returns the next draw, or `None` once all draws are consumed or when the
    /// population is empty.
    fn next(&mut self) -> Option<Self::Item> {
        // `gen_range` panics on an empty range, so an empty population must end
        // the iteration instead of attempting a draw.
        if self.len == 0 || self.remaining == 0 {
            return None;
        }
        self.remaining -= 1;
        Some(rand::thread_rng().gen_range(0..self.len))
    }

    /// Exact number of draws left, so `collect` can pre-allocate.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = if self.len == 0 { 0 } else { self.remaining };
        (left, Some(left))
    }
}
/// Sampler over an explicit list of indices.
pub struct SubsetRandomSampler {
    /// The indices this sampler draws from, as supplied by the caller.
    indices: Vec<usize>,
}

impl SubsetRandomSampler {
    /// Creates a sampler over the given `indices`, taking ownership of them.
    #[must_use]
    pub fn new(indices: Vec<usize>) -> Self {
        SubsetRandomSampler { indices }
    }
}
impl Sampler for SubsetRandomSampler {
    /// Number of stored indices.
    fn len(&self) -> usize {
        self.indices.len()
    }

    /// Yields every stored index exactly once, in a freshly shuffled order.
    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
        // Shuffle a copy so the sampler's own list keeps its original order.
        let mut order = self.indices.clone();
        let mut rng = rand::thread_rng();
        order.shuffle(&mut rng);
        Box::new(order.into_iter())
    }
}
/// Sampler that draws indices `0..weights.len()` with probability proportional
/// to each index's weight.
pub struct WeightedRandomSampler {
    /// Per-index weights as supplied by the caller.
    weights: Vec<f64>,
    /// Running sum of `weights`; the last entry is the total weight.
    cumulative: Vec<f64>,
    /// Number of indices requested per pass.
    num_samples: usize,
    /// Whether indices are drawn with replacement.
    replacement: bool,
}

impl WeightedRandomSampler {
    /// Creates a sampler over `weights` that reports `num_samples` per pass.
    ///
    /// The weights are not validated here; the cumulative table is built from
    /// them as given.
    #[must_use]
    pub fn new(weights: Vec<f64>, num_samples: usize, replacement: bool) -> Self {
        let cumulative = Self::build_cumulative(&weights);
        Self {
            weights,
            cumulative,
            num_samples,
            replacement,
        }
    }

    /// Returns the running (prefix) sums of `weights`, in order.
    fn build_cumulative(weights: &[f64]) -> Vec<f64> {
        weights
            .iter()
            .scan(0.0_f64, |total, &w| {
                *total += w;
                Some(*total)
            })
            .collect()
    }

    /// Draws one index by inverting the cumulative weight table.
    ///
    /// Picks the first entry whose cumulative weight is strictly greater than a
    /// uniform threshold in `[0, total)`. Never panics: with an empty weight
    /// table it returns `0`, so callers must not call it for an empty sampler
    /// if they need a valid index.
    fn sample_index(&self) -> usize {
        let total = self.cumulative.last().copied().unwrap_or(1.0);
        let threshold = rand::thread_rng().r#gen::<f64>() * total;
        // `partition_point` with `<=` skips entries whose cumulative value
        // equals the threshold, so a zero-weight entry is never chosen on an
        // exact hit, and comparisons involving NaN simply evaluate to false
        // instead of panicking.
        let first_above = self.cumulative.partition_point(|&c| c <= threshold);
        // Clamp for the floating-point edge where the threshold rounds up to
        // the total; `saturating_sub` keeps the empty table from underflowing.
        first_above.min(self.cumulative.len().saturating_sub(1))
    }
}
impl Sampler for WeightedRandomSampler {
    /// Number of indices requested per pass (`num_samples`).
    ///
    /// A pass can yield fewer: without replacement it stops once the
    /// positive-weight entries are exhausted, and a sampler whose total weight
    /// is not positive yields nothing.
    fn len(&self) -> usize {
        self.num_samples
    }

    /// Starts a new pass of weighted draws.
    ///
    /// With replacement, every draw uses the full weight table. Without
    /// replacement, each drawn index is removed and the table is rebuilt from
    /// the remaining weights before the next draw. If the total weight is zero,
    /// negative, or NaN (including the empty table), nothing is yielded.
    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
        if self.replacement {
            // Nothing can be drawn proportionally from an empty or non-positive
            // table; yield nothing rather than fabricate an index. This mirrors
            // the stop condition of the without-replacement path below.
            let total = self.cumulative.last().copied().unwrap_or(0.0);
            if total.is_nan() || total <= 0.0 {
                return Box::new(std::iter::empty());
            }
            return Box::new(WeightedIter::new(self));
        }

        let mut rng = rand::thread_rng();
        // At most `weights.len()` indices can be drawn, so cap the allocation.
        let mut indices = Vec::with_capacity(self.num_samples.min(self.weights.len()));
        let mut available: Vec<usize> = (0..self.weights.len()).collect();
        let mut weights = self.weights.clone();
        let mut cumulative = self.cumulative.clone();
        while indices.len() < self.num_samples && !available.is_empty() {
            let total = cumulative.last().copied().unwrap_or(0.0);
            if total.is_nan() || total <= 0.0 {
                break;
            }
            let threshold = rng.r#gen::<f64>() * total;
            // First entry whose cumulative weight exceeds the threshold; `<=`
            // skips zero-weight entries on an exact hit and cannot panic on NaN.
            // `cumulative` is non-empty here because it mirrors `available`.
            let selected = cumulative
                .partition_point(|&c| c <= threshold)
                .min(cumulative.len() - 1);
            // `available` and `weights` are swap-removed at the same position
            // so they stay aligned.
            indices.push(available.swap_remove(selected));
            weights.swap_remove(selected);
            cumulative = Self::build_cumulative(&weights);
        }
        Box::new(indices.into_iter())
    }
}
/// Iterator that yields `num_samples` draws from a borrowed
/// `WeightedRandomSampler`.
struct WeightedIter<'a> {
    /// Sampler that performs each draw.
    sampler: &'a WeightedRandomSampler,
    /// Draws still to be produced.
    remaining: usize,
}

impl<'a> WeightedIter<'a> {
    /// Creates an iterator that will yield `sampler.num_samples` draws.
    fn new(sampler: &'a WeightedRandomSampler) -> Self {
        let remaining = sampler.num_samples;
        WeightedIter { sampler, remaining }
    }
}

impl Iterator for WeightedIter<'_> {
    type Item = usize;

    /// Returns the next draw from the sampler, or `None` once the requested
    /// number of draws has been produced.
    fn next(&mut self) -> Option<Self::Item> {
        // `checked_sub` is `None` exactly when no draws remain.
        let left = self.remaining.checked_sub(1)?;
        self.remaining = left;
        Some(self.sampler.sample_index())
    }
}
/// Groups the indices of an inner [`Sampler`] into batches.
pub struct BatchSampler<S: Sampler> {
    /// Sampler that supplies the indices.
    sampler: S,
    /// Maximum number of indices per batch.
    batch_size: usize,
    /// Whether a trailing batch shorter than `batch_size` is discarded.
    drop_last: bool,
}

impl<S: Sampler> BatchSampler<S> {
    /// Wraps `sampler` so its indices are served in batches of `batch_size`.
    ///
    /// A `batch_size` of zero is accepted and produces no batches.
    #[must_use]
    pub fn new(sampler: S, batch_size: usize, drop_last: bool) -> Self {
        Self {
            sampler,
            batch_size,
            drop_last,
        }
    }

    /// Collects one pass of the inner sampler and returns an iterator over its
    /// batches.
    pub fn iter(&self) -> BatchIter {
        // With a zero batch size no batch can ever be formed; hand the iterator
        // an empty index list so it ends immediately instead of yielding empty
        // batches forever.
        let indices: Vec<usize> = if self.batch_size == 0 {
            Vec::new()
        } else {
            self.sampler.iter().collect()
        };
        BatchIter {
            indices,
            batch_size: self.batch_size,
            drop_last: self.drop_last,
            position: 0,
        }
    }

    /// Number of batches, computed from the inner sampler's reported length.
    ///
    /// Rounds down when `drop_last` is set and up otherwise. Returns `0` for a
    /// zero `batch_size` rather than dividing by zero.
    pub fn len(&self) -> usize {
        if self.batch_size == 0 {
            return 0;
        }
        let total = self.sampler.len();
        if self.drop_last {
            total / self.batch_size
        } else {
            total.div_ceil(self.batch_size)
        }
    }

    /// Returns `true` when [`BatchSampler::len`] is zero.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Iterator over consecutive batches of a pre-collected index list.
pub struct BatchIter {
    /// All indices of the pass, in the order they will be served.
    indices: Vec<usize>,
    /// Maximum number of indices per batch.
    batch_size: usize,
    /// Whether a trailing batch shorter than `batch_size` is discarded.
    drop_last: bool,
    /// Offset into `indices` of the next batch.
    position: usize,
}

impl Iterator for BatchIter {
    type Item = Vec<usize>;

    /// Returns the next batch of up to `batch_size` indices.
    ///
    /// Ends when the indices are exhausted, when only a short trailing batch is
    /// left and `drop_last` is set, or immediately when `batch_size` is zero.
    fn next(&mut self) -> Option<Self::Item> {
        // A zero batch size would otherwise yield empty batches forever,
        // because the position could never advance.
        if self.batch_size == 0 {
            return None;
        }
        let rest = self.indices.get(self.position..)?;
        if rest.is_empty() {
            return None;
        }
        let take = rest.len().min(self.batch_size);
        // Decide before copying so a dropped trailing batch costs no allocation.
        if self.drop_last && take < self.batch_size {
            return None;
        }
        let batch = rest[..take].to_vec();
        self.position += take;
        Some(batch)
    }

    /// Exact number of batches still to come.
    fn size_hint(&self) -> (usize, Option<usize>) {
        if self.batch_size == 0 {
            return (0, Some(0));
        }
        let left = self.indices.len().saturating_sub(self.position);
        let batches = if self.drop_last {
            left / self.batch_size
        } else {
            left.div_ceil(self.batch_size)
        };
        (batches, Some(batches))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sequential sampling yields `0..len` in order.
    #[test]
    fn test_sequential_sampler() {
        let sampler = SequentialSampler::new(5);
        let drawn: Vec<usize> = sampler.iter().collect();
        assert_eq!(drawn, [0, 1, 2, 3, 4]);
    }

    /// Sampling without replacement yields every index exactly once.
    #[test]
    fn test_random_sampler() {
        let sampler = RandomSampler::new(10);
        let drawn: Vec<usize> = sampler.iter().collect();
        assert_eq!(drawn.len(), 10);
        let distinct: std::collections::BTreeSet<usize> = drawn.iter().copied().collect();
        assert_eq!(distinct.len(), 10);
    }

    /// Sampling with replacement yields the requested count, all in range.
    #[test]
    fn test_random_sampler_with_replacement() {
        let sampler = RandomSampler::with_replacement(5, 20);
        let drawn: Vec<usize> = sampler.iter().collect();
        assert_eq!(drawn.len(), 20);
        assert!(!drawn.iter().any(|&i| i >= 5));
    }

    /// Subset sampling yields a permutation of the supplied indices.
    #[test]
    fn test_subset_random_sampler() {
        let sampler = SubsetRandomSampler::new(vec![0, 5, 10, 15]);
        let mut drawn: Vec<usize> = sampler.iter().collect();
        assert_eq!(drawn.len(), 4);
        drawn.sort_unstable();
        assert_eq!(drawn, [0, 5, 10, 15]);
    }

    /// A dominant weight dominates the draws.
    #[test]
    fn test_weighted_random_sampler() {
        let sampler = WeightedRandomSampler::new(vec![100.0, 1.0, 1.0, 1.0], 100, true);
        let drawn: Vec<usize> = sampler.iter().collect();
        assert_eq!(drawn.len(), 100);
        let zeros = drawn.iter().filter(|&&i| i == 0).count();
        assert!(zeros > 50, "Expected mostly zeros, got {zeros}");
    }

    /// Without `drop_last`, the short trailing batch is kept.
    #[test]
    fn test_batch_sampler() {
        let sampler = BatchSampler::new(SequentialSampler::new(10), 3, false);
        let batches: Vec<Vec<usize>> = sampler.iter().collect();
        assert_eq!(batches.len(), 4);
        assert_eq!(batches[0], [0, 1, 2]);
        assert_eq!(batches[1], [3, 4, 5]);
        assert_eq!(batches[2], [6, 7, 8]);
        assert_eq!(batches[3], [9]);
    }

    /// With `drop_last`, the short trailing batch is discarded.
    #[test]
    fn test_batch_sampler_drop_last() {
        let sampler = BatchSampler::new(SequentialSampler::new(10), 3, true);
        let batches: Vec<Vec<usize>> = sampler.iter().collect();
        assert_eq!(batches.len(), 3);
        assert_eq!(batches[0], [0, 1, 2]);
        assert_eq!(batches[1], [3, 4, 5]);
        assert_eq!(batches[2], [6, 7, 8]);
    }
}