use crate::Dataset;
use crate::transform::{RngSource, SelectionDataset};
/// Dataset adapter that exposes the items of an inner dataset in a shuffled order.
///
/// The shuffling itself is done by a [`SelectionDataset`] created with
/// `SelectionDataset::new_shuffled`; this type gives that configuration a dedicated name
/// and forwards every [`Dataset`] call to it.
pub struct ShuffledDataset<D: Dataset<I>, I: Clone + Send + Sync> {
    /// Shuffled selection over the inner dataset; all lookups are delegated to it.
    wrapped: SelectionDataset<D, I>,
}
impl<D, I> ShuffledDataset<D, I>
where
    D: Dataset<I>,
    I: Clone + Send + Sync,
{
    /// Creates a dataset that yields the items of `dataset` in a shuffled order.
    ///
    /// # Arguments
    ///
    /// * `dataset` - The dataset whose items are exposed in shuffled order.
    /// * `rng_source` - Source of randomness used to build the permutation. Anything
    ///   convertible into [`RngSource`] is accepted, including a plain `u64` seed, which
    ///   gives a reproducible order.
    pub fn new<R>(dataset: D, rng_source: R) -> Self
    where
        R: Into<RngSource>,
    {
        Self {
            wrapped: SelectionDataset::new_shuffled(dataset, rng_source),
        }
    }

    /// Creates a shuffled dataset from a `u64` seed.
    ///
    /// Deprecated: [`ShuffledDataset::new`] accepts the seed directly and is equivalent.
    // Fixed: the note previously ended with a stray backtick ("instead`").
    #[deprecated(since = "0.19.0", note = "Use `new(dataset, seed)` instead")]
    pub fn with_seed(dataset: D, seed: u64) -> Self {
        Self::new(dataset, seed)
    }
}
/// Every [`Dataset`] operation is forwarded unchanged to the wrapped shuffled selection.
impl<D: Dataset<I>, I: Clone + Send + Sync> Dataset<I> for ShuffledDataset<D, I> {
    /// Returns the item at position `index` of the shuffled view, as reported by the
    /// wrapped selection.
    fn get(&self, index: usize) -> Option<I> {
        let Self { wrapped } = self;
        wrapped.get(index)
    }

    /// Returns the number of items in the shuffled view, as reported by the wrapped
    /// selection.
    fn len(&self) -> usize {
        let Self { wrapped } = self;
        wrapped.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FakeDataset;
    use crate::transform::selection::shuffled_indices;
    use rand::SeedableRng;
    use rand::prelude::StdRng;

    /// A seeded `ShuffledDataset` keeps every item and orders them exactly as
    /// `shuffled_indices` does for an `StdRng` built from the same seed.
    #[test]
    fn test_shuffled_dataset() {
        const SEED: u64 = 42;

        let dataset = FakeDataset::<String>::new(27);
        let originals: Vec<_> = dataset.iter().collect();

        // `with_seed` is deprecated but still public, so it is intentionally exercised here.
        #[allow(deprecated)]
        let shuffled = ShuffledDataset::with_seed(dataset, SEED);

        let mut reference_rng = StdRng::seed_from_u64(SEED);
        let order = shuffled_indices(originals.len(), &mut reference_rng);

        assert_eq!(shuffled.len(), originals.len());

        let mut expected = Vec::with_capacity(order.len());
        for &position in order.iter() {
            expected.push(originals[position].to_string());
        }

        let actual: Vec<_> = shuffled.iter().collect();
        assert_eq!(actual, expected);
    }
}