burn_dataset/transform/
random.rs

1use crate::Dataset;
2use rand::{SeedableRng, prelude::SliceRandom, rngs::StdRng};
3use std::marker::PhantomData;
4
/// Shuffles a dataset by serving its items in a random order.
///
/// The permutation is computed once, eagerly, when the shuffled dataset is created.
/// Consider using the [sampler dataset](crate::transform::SamplerDataset) instead if you
/// want a probability distribution that is computed lazily.
pub struct ShuffledDataset<D, I> {
    /// The wrapped dataset whose items are served in shuffled order.
    dataset: D,
    /// Shuffled position -> index into `dataset`; built as a permutation of
    /// `0..dataset.len()` at construction time.
    indices: Vec<usize>,
    /// Marks the item type `I` without storing any value of it.
    input: PhantomData<I>,
}
12
13impl<D, I> ShuffledDataset<D, I>
14where
15    D: Dataset<I>,
16{
17    /// Creates a new shuffled dataset.
18    pub fn new(dataset: D, rng: &mut StdRng) -> Self {
19        let mut indices = Vec::with_capacity(dataset.len());
20        for i in 0..dataset.len() {
21            indices.push(i);
22        }
23        indices.shuffle(rng);
24
25        Self {
26            dataset,
27            indices,
28            input: PhantomData,
29        }
30    }
31
32    /// Creates a new shuffled dataset with a fixed seed.
33    pub fn with_seed(dataset: D, seed: u64) -> Self {
34        let mut rng = StdRng::seed_from_u64(seed);
35        Self::new(dataset, &mut rng)
36    }
37}
38
39impl<D, I> Dataset<I> for ShuffledDataset<D, I>
40where
41    D: Dataset<I>,
42    I: Clone + Send + Sync,
43{
44    fn get(&self, index: usize) -> Option<I> {
45        let index = self.indices.get(index)?;
46        self.dataset.get(*index)
47    }
48
49    fn len(&self) -> usize {
50        self.dataset.len()
51    }
52}