use crate::Dataset;
use crate::transform::RngSource;
use rand::prelude::SliceRandom;
use rand::rngs::StdRng;
use std::marker::PhantomData;
use std::sync::Arc;
/// Builds the identity index list `[0, 1, ..., size - 1]`.
///
/// Returns an empty vector when `size` is zero.
#[inline(always)]
pub fn iota(size: usize) -> Vec<usize> {
    let mut sequence = Vec::with_capacity(size);
    sequence.extend(0..size);
    sequence
}
#[inline(always)]
pub fn shuffled_indices(size: usize, rng: &mut StdRng) -> Vec<usize> {
let mut indices = iota(size);
indices.shuffle(rng);
indices
}
#[derive(Clone)]
pub struct SelectionDataset<D, I>
where
D: Dataset<I>,
I: Clone + Send + Sync,
{
pub wrapped: Arc<D>,
pub indices: Vec<usize>,
input: PhantomData<I>,
}
impl<D, I> SelectionDataset<D, I>
where
    D: Dataset<I>,
    I: Clone + Send + Sync,
{
    /// Creates a selection over `dataset` from `indices`, validating every index first.
    ///
    /// Indices may repeat and may appear in any order.
    ///
    /// # Panics
    ///
    /// Panics if any index is greater than or equal to the wrapped dataset's length.
    pub fn from_indices_checked<S>(dataset: S, indices: Vec<usize>) -> Self
    where
        S: Into<Arc<D>>,
    {
        let dataset = dataset.into();
        let size = dataset.len();
        if let Some(idx) = indices.iter().find(|&&i| i >= size) {
            panic!("Index out of bounds for wrapped dataset size: {idx} >= {size}");
        }
        Self::from_indices_unchecked(dataset, indices)
    }

    /// Creates a selection over `dataset` from `indices` without validating them.
    ///
    /// Any index is accepted here. Lookups through an out-of-range index are forwarded
    /// unchanged to the wrapped dataset's `get`.
    pub fn from_indices_unchecked<S>(dataset: S, indices: Vec<usize>) -> Self
    where
        S: Into<Arc<D>>,
    {
        Self {
            wrapped: dataset.into(),
            indices,
            input: PhantomData,
        }
    }

    /// Creates a selection of every item of `dataset`, in the wrapped dataset's order.
    pub fn new_select_all<S>(dataset: S) -> Self
    where
        S: Into<Arc<D>>,
    {
        let dataset = dataset.into();
        let size = dataset.len();
        Self::from_indices_unchecked(dataset, iota(size))
    }

    /// Creates a selection of every item of `dataset`, shuffled with `rng_source`.
    ///
    /// Equivalent to [`Self::new_select_all`] followed by [`Self::shuffle`].
    pub fn new_shuffled<S, R>(dataset: S, rng_source: R) -> Self
    where
        S: Into<Arc<D>>,
        R: Into<RngSource>,
    {
        let mut this = Self::new_select_all(dataset);
        this.shuffle(rng_source);
        this
    }

    /// Shuffles the selected indices in place.
    ///
    /// The generator is a [`StdRng`] obtained by converting `rng_source` through
    /// [`RngSource`].
    pub fn shuffle<R>(&mut self, rng_source: R)
    where
        R: Into<RngSource>,
    {
        let mut rng: StdRng = rng_source.into().into();
        self.indices.shuffle(&mut rng)
    }

    /// Returns a new selection over positions `start..end` of this selection.
    ///
    /// The wrapped dataset is shared (only the `Arc` is cloned). The chosen index range
    /// is copied.
    ///
    /// # Panics
    ///
    /// Panics if `start > end` or if `end` exceeds the number of selected indices.
    pub fn slice(&self, start: usize, end: usize) -> Self {
        Self::from_indices_unchecked(Arc::clone(&self.wrapped), self.indices[start..end].to_vec())
    }

    /// Splits this selection into `num` contiguous selections, in order.
    ///
    /// Each part receives `len / num` indices. The last part additionally receives the
    /// `len % num` leftover indices. Consequently, when `len < num`, every part except
    /// the last is empty.
    ///
    /// # Panics
    ///
    /// Panics if `num` is zero.
    pub fn split(&self, num: usize) -> Vec<Self> {
        // Without this, `num == 0` would fail on an opaque division-by-zero panic below.
        assert!(
            num > 0,
            "SelectionDataset::split requires at least one part (num > 0)"
        );
        let total = self.indices.len();
        let batch_size = total / num;
        (0..num)
            .map(|i| {
                let start = i * batch_size;
                // The last part absorbs the remainder so no index is dropped.
                let end = if i + 1 == num {
                    total
                } else {
                    start + batch_size
                };
                self.slice(start, end)
            })
            .collect()
    }
}
impl<D, I> Dataset<I> for SelectionDataset<D, I>
where
    D: Dataset<I>,
    I: Clone + Send + Sync,
{
    /// Looks up the item at position `index` of the selection.
    ///
    /// Returns `None` when `index` is past the end of the index list. Otherwise the
    /// result is whatever the wrapped dataset's `get` returns for the stored index.
    fn get(&self, index: usize) -> Option<I> {
        self.indices
            .get(index)
            .and_then(|&source_index| self.wrapped.get(source_index))
    }

    /// Number of selected positions (duplicates included). This is not the wrapped
    /// dataset's length.
    fn len(&self) -> usize {
        let selected = &self.indices;
        selected.len()
    }
}
/// Unit tests for index helpers and `SelectionDataset` construction, slicing and splitting.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FakeDataset;
    use rand::SeedableRng;

    #[test]
    fn test_iota() {
        let size = 10;
        let indices = iota(size);
        assert_eq!(indices.len(), size);
        assert_eq!(indices, vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    #[test]
    fn test_shuffled_indices_same_seed_is_deterministic() {
        let size = 10;
        let mut rng1 = StdRng::seed_from_u64(10);
        let mut rng2 = StdRng::seed_from_u64(10);
        let mut expected = iota(size);
        expected.shuffle(&mut rng1);
        let indices = shuffled_indices(size, &mut rng2);
        assert_eq!(indices, expected);
    }

    #[test]
    fn test_shuffled_indices_forked_rngs_differ() {
        let size = 10;
        let mut rng1 = StdRng::seed_from_u64(10);
        let mut rng2 = rng1.fork();
        let mut a = iota(size);
        let mut b = iota(size);
        a.shuffle(&mut rng1);
        b.shuffle(&mut rng2);
        assert_ne!(a, b);
    }

    #[should_panic(expected = "Index out of bounds for wrapped dataset size: 300 >= 27")]
    #[test]
    fn test_from_indices_checked_panics() {
        let source_dataset = FakeDataset::<String>::new(27);
        let indices: Vec<usize> = vec![15, 1, 12, 300];
        SelectionDataset::from_indices_checked(source_dataset, indices);
    }

    #[test]
    fn test_checked_selection_dataset() {
        let source_dataset = FakeDataset::<String>::new(27);
        let indices: Vec<usize> = vec![15, 1, 12, 12];
        let expected: Vec<String> = indices
            .iter()
            .map(|i| source_dataset.get(*i).unwrap())
            .collect();
        let selection = SelectionDataset::from_indices_checked(source_dataset, indices.clone());
        assert_eq!(&selection.indices, &indices);
        let items = selection.iter().collect::<Vec<_>>();
        assert_eq!(items, expected);
    }

    #[test]
    fn test_shuffled_dataset() {
        let dataset = FakeDataset::<String>::new(27);
        let source_items = dataset.iter().collect::<Vec<_>>();
        let selection = SelectionDataset::new_shuffled(dataset, 42);
        let indices = shuffled_indices(source_items.len(), &mut StdRng::seed_from_u64(42));
        assert_eq!(&selection.indices, &indices);
        assert_eq!(selection.len(), source_items.len());
        let expected_items: Vec<_> = indices.iter().map(|&i| source_items[i].clone()).collect();
        assert_eq!(&selection.iter().collect::<Vec<_>>(), &expected_items);
    }

    #[test]
    fn test_slice() {
        let dataset = FakeDataset::<String>::new(27);
        let source_items = dataset.iter().collect::<Vec<_>>();
        let selection = SelectionDataset::new_select_all(dataset);
        let start = 5;
        let end = 15;
        let sliced_selection = selection.slice(start, end);
        assert_eq!(sliced_selection.len(), end - start);
        // Position `offset` of the slice must map to source item `start + offset`.
        for (offset, expected) in source_items[start..end].iter().enumerate() {
            assert_eq!(sliced_selection.get(offset).as_ref(), Some(expected));
        }
    }

    #[test]
    fn test_split() {
        let dataset = FakeDataset::<String>::new(28);
        let source_items = dataset.iter().collect::<Vec<_>>();
        let selection = SelectionDataset::new_select_all(dataset);
        let split_contents: Vec<Vec<_>> = selection
            .split(3)
            .iter()
            .map(|d| d.iter().collect::<Vec<_>>())
            .collect();
        assert_eq!(
            split_contents,
            vec![
                source_items[0..9].to_vec(),
                source_items[9..18].to_vec(),
                source_items[18..28].to_vec(),
            ]
        );
    }

    // Only the fact that splitting into zero parts panics is pinned, not the message.
    #[test]
    #[should_panic]
    fn test_split_into_zero_parts_panics() {
        let selection: SelectionDataset<FakeDataset<String>, String> =
            SelectionDataset::new_select_all(FakeDataset::<String>::new(4));
        let _ = selection.split(0);
    }
}