use crate::Dataset;
use crate::transform::{RngSource, SizeConfig};
use rand::prelude::SliceRandom;
use rand::{RngExt, distr::Uniform, rngs::StdRng, seq::IteratorRandom};
use std::{marker::PhantomData, ops::DerefMut, sync::Mutex};
/// Configuration used to build a [`SamplerDataset`].
///
/// Can be created from a `usize` (a fixed size, sampled with replacement) or
/// from an `Option` of anything convertible into options (`None` yields the
/// defaults).
#[derive(Debug, PartialEq)]
pub struct SamplerDatasetOptions {
/// `true` to sample with replacement, `false` to sample without replacement.
pub replace_samples: bool,
/// How the sampler's length is determined; it is resolved against the
/// length of the wrapped dataset when the sampler is built.
pub size_config: SizeConfig,
/// Source that is converted into the sampler's random number generator.
pub rng_source: RngSource,
}
impl Default for SamplerDatasetOptions {
    /// Returns the default options: sampling with replacement, the default
    /// size configuration and the default random number generator source.
    fn default() -> Self {
        let replace_samples = true;
        let size_config = SizeConfig::Default;
        let rng_source = RngSource::Default;

        Self {
            replace_samples,
            size_config,
            rng_source,
        }
    }
}
impl<T> From<Option<T>> for SamplerDatasetOptions
where
    T: Into<SamplerDatasetOptions>,
{
    /// Converts an optional value into options.
    ///
    /// `Some(value)` is converted through `value.into()`; `None` falls back to
    /// [`SamplerDatasetOptions::default`].
    fn from(option: Option<T>) -> Self {
        option.map_or_else(Self::default, Into::into)
    }
}
impl From<usize> for SamplerDatasetOptions {
    /// Builds options for sampling with replacement at a fixed `size`,
    /// starting from the defaults.
    fn from(size: usize) -> Self {
        let sampling_with_replacement = Self::default().with_replacement();
        sampling_with_replacement.with_fixed_size(size)
    }
}
impl SamplerDatasetOptions {
    /// Sets whether sampling is done with replacement, leaving every other
    /// option untouched.
    pub fn with_replace_samples(mut self, replace_samples: bool) -> Self {
        self.replace_samples = replace_samples;
        self
    }

    /// Enables sampling with replacement.
    pub fn with_replacement(self) -> Self {
        self.with_replace_samples(true)
    }

    /// Disables sampling with replacement.
    pub fn without_replacement(self) -> Self {
        self.with_replace_samples(false)
    }

    /// Sets the size configuration from anything convertible into a
    /// [`SizeConfig`], leaving every other option untouched.
    pub fn with_size<S>(mut self, source: S) -> Self
    where
        S: Into<SizeConfig>,
    {
        self.size_config = source.into();
        self
    }

    /// Uses [`SizeConfig::Default`] as the size configuration.
    pub fn with_source_size(self) -> Self {
        self.with_size(SizeConfig::Default)
    }

    /// Sets the size configuration from a fixed number of items.
    pub fn with_fixed_size(self, size: usize) -> Self {
        self.with_size(size)
    }

    /// Sets the size configuration from a ratio.
    pub fn with_size_ratio(self, size_ratio: f64) -> Self {
        self.with_size(size_ratio)
    }

    /// Sets the random number generator source from anything convertible into
    /// an [`RngSource`], leaving every other option untouched.
    pub fn with_rng<R>(mut self, rng: R) -> Self
    where
        R: Into<RngSource>,
    {
        self.rng_source = rng.into();
        self
    }

    /// Uses [`RngSource::Default`] as the random number generator source.
    // NOTE(review): presumably a system-seeded generator, judging by the
    // method name — confirm against `RngSource`.
    pub fn with_system_rng(self) -> Self {
        self.with_rng(RngSource::Default)
    }

    /// Sets the random number generator source from a `u64` seed.
    pub fn with_seed(self, seed: u64) -> Self {
        self.with_rng(seed)
    }
}
/// A dataset that yields randomly selected items from a wrapped dataset.
///
/// Its reported length is the configured sample size, which may differ from
/// the length of the wrapped dataset.
pub struct SamplerDataset<D, I> {
/// The wrapped dataset that items are drawn from.
dataset: D,
/// Number of items this sampler reports as its length.
size: usize,
/// Sampling state (generator and, without replacement, the pending indices);
/// kept behind a mutex because sampling mutates it through `&self`.
state: Mutex<SamplerState>,
/// Marker tying the sampler to the item type `I`.
input: PhantomData<I>,
}
/// Internal sampling state of a `SamplerDataset`.
enum SamplerState {
/// Every draw picks an index uniformly at random from the wrapped dataset.
WithReplacement(StdRng),
/// Draws pop from a shuffled pool of indices that is rebuilt whenever it
/// runs empty.
WithoutReplacement(StdRng, Vec<usize>),
}
impl<D, I> SamplerDataset<D, I>
where
D: Dataset<I>,
I: Send + Sync,
{
pub fn new<O>(dataset: D, options: O) -> Self
where
O: Into<SamplerDatasetOptions>,
{
let options = options.into();
let size = options.size_config.resolve(dataset.len());
let rng = options.rng_source.into();
Self {
dataset,
size,
state: Mutex::new(match options.replace_samples {
true => SamplerState::WithReplacement(rng),
false => SamplerState::WithoutReplacement(rng, Vec::with_capacity(size)),
}),
input: PhantomData,
}
}
pub fn with_replacement(dataset: D, size: usize) -> Self {
Self::new(
dataset,
SamplerDatasetOptions::default()
.with_replacement()
.with_fixed_size(size),
)
}
pub fn without_replacement(dataset: D, size: usize) -> Self {
Self::new(
dataset,
SamplerDatasetOptions::default()
.without_replacement()
.with_fixed_size(size),
)
}
pub fn is_with_replacement(&self) -> bool {
match self.state.lock().unwrap().deref_mut() {
SamplerState::WithReplacement(_) => true,
SamplerState::WithoutReplacement(_, _) => false,
}
}
fn index(&self) -> usize {
match self.state.lock().unwrap().deref_mut() {
SamplerState::WithReplacement(rng) => {
rng.sample(Uniform::new(0, self.dataset.len()).unwrap())
}
SamplerState::WithoutReplacement(rng, indices) => {
if indices.is_empty() {
let idx_range = 0..self.dataset.len();
for _ in 0..(self.size / self.dataset.len()) {
indices.extend(idx_range.clone())
}
indices.extend(idx_range.sample(rng, self.size - indices.len()));
indices.shuffle(rng);
}
indices.pop().expect("Indices are refilled when empty.")
}
}
}
}
impl<D, I> Dataset<I> for SamplerDataset<D, I>
where
    D: Dataset<I>,
    I: Send + Sync,
{
    /// Returns a randomly sampled item from the wrapped dataset.
    ///
    /// `index` is only used as a bounds check against the sampler's length;
    /// the item that is returned is selected by the sampler's internal state,
    /// not by `index`.
    ///
    /// Returns `None` when `index` is out of bounds, or when the wrapped
    /// dataset is empty and there is therefore nothing to sample from.
    fn get(&self, index: usize) -> Option<I> {
        let source_len = self.dataset.len();

        // An empty source cannot produce an index: drawing one would panic
        // (empty sampling range / division by zero), so bail out first.
        if index >= self.size || source_len == 0 {
            return None;
        }

        self.dataset.get(self.index())
    }

    /// Returns the configured sample size, not the wrapped dataset's length.
    fn len(&self) -> usize {
        self.size
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FakeDataset;
    use rand::SeedableRng;
    use std::collections::HashMap;

    /// Walks through every builder method and checks the field it sets.
    #[test]
    fn test_samplerdataset_options() {
        let defaults = SamplerDatasetOptions::default();
        assert!(defaults.replace_samples);
        assert_eq!(defaults.size_config, SizeConfig::Default);
        assert_eq!(defaults.rng_source, RngSource::Default);

        // Replacement toggles.
        let opts = defaults.with_replace_samples(false);
        assert!(!opts.replace_samples);
        let opts = opts.with_replacement();
        assert!(opts.replace_samples);
        let opts = opts.without_replacement();
        assert!(!opts.replace_samples);

        // Size configuration.
        let opts = opts.with_size(SizeConfig::Default);
        assert_eq!(opts.size_config, SizeConfig::Default);
        let opts = opts.with_source_size();
        assert_eq!(opts.size_config, SizeConfig::Default);
        let opts = opts.with_fixed_size(10);
        assert_eq!(opts.size_config, SizeConfig::Fixed(10));
        let opts = opts.with_size_ratio(1.5);
        assert_eq!(opts.size_config, SizeConfig::Ratio(1.5));

        // Random number generator source.
        let opts = opts.with_system_rng();
        assert_eq!(opts.rng_source, RngSource::Default);
        let opts = opts.with_seed(42);
        assert_eq!(opts.rng_source, RngSource::Seed(42));
        let seeded_rng = StdRng::seed_from_u64(9);
        let opts = opts.with_rng(seeded_rng);
        assert!(matches!(opts.rng_source, RngSource::Rng(_)));
    }

    /// Each constructor must report the requested size, keep the source
    /// dataset intact and select the expected replacement mode.
    #[test]
    fn sampler_dataset_constructors_test() {
        let source_len = 10;
        let sample_len = 15;

        let cases = [
            (
                SamplerDataset::new(FakeDataset::<u32>::new(source_len), sample_len),
                true,
            ),
            (
                SamplerDataset::with_replacement(FakeDataset::<u32>::new(source_len), sample_len),
                true,
            ),
            (
                SamplerDataset::without_replacement(
                    FakeDataset::<u32>::new(source_len),
                    sample_len,
                ),
                false,
            ),
        ];

        for (ds, expect_replacement) in cases {
            assert_eq!(ds.len(), sample_len);
            assert_eq!(ds.dataset.len(), source_len);
            assert_eq!(ds.is_with_replacement(), expect_replacement);
        }
    }

    /// Iterating a with-replacement sampler yields exactly `size` items.
    #[test]
    fn sampler_dataset_with_replacement_iter() {
        let factor = 3;
        let len_original = 10;
        let dataset_sampler = SamplerDataset::with_replacement(
            FakeDataset::<String>::new(len_original),
            len_original * factor,
        );

        let total = dataset_sampler.iter().count();

        assert_eq!(total, factor * len_original);
    }

    /// Without replacement and with a size that is a whole multiple of the
    /// source length, every distinct item is seen exactly `factor` times.
    #[test]
    fn sampler_dataset_without_replacement_bucket_test() {
        let factor = 3;
        let len_original = 10;
        let options = SamplerDatasetOptions::default()
            .without_replacement()
            .with_size_ratio(factor as f64);
        let dataset_sampler = SamplerDataset::new(FakeDataset::<String>::new(len_original), options);

        let mut buckets = HashMap::new();
        for item in dataset_sampler.iter() {
            *buckets.entry(item).or_insert(0) += 1;
        }

        let mut total = 0;
        for count in buckets.values() {
            assert_eq!(*count, factor);
            total += *count;
        }
        assert_eq!(total, factor * len_original);
    }

    /// The mean distance between consecutive sampled indices must stay within
    /// 25% of the reference value used by this test.
    #[test]
    fn sampler_dataset_without_replacement_uniform_order_test() {
        let size = 1000;
        let dataset_sampler =
            SamplerDataset::without_replacement(FakeDataset::<i32>::new(size), size);

        let mut drawn = Vec::with_capacity(size);
        for _ in 0..size {
            drawn.push(dataset_sampler.index());
        }

        let total_delta: usize = drawn
            .windows(2)
            .map(|window| window[1].abs_diff(window[0]))
            .sum();
        let mean_delta = total_delta as f64 / (size - 1) as f64;
        let expected = (size + 2) as f64 / 3.0;
        let tolerance = 0.25 * expected;

        assert!(
            (mean_delta - expected).abs() <= tolerance,
            "Sampled indices are not uniformly distributed: mean_delta: {mean_delta}, expected: {expected}"
        );
    }
}