burn_dataset/transform/
random.rs1use crate::Dataset;
2use rand::{SeedableRng, prelude::SliceRandom, rngs::StdRng};
3use std::marker::PhantomData;
4
/// A dataset view that exposes the items of an inner dataset in a shuffled order.
///
/// Only a permutation of indices is stored alongside the wrapped dataset: position `i` of
/// this view maps to position `indices[i]` of the inner dataset, so items are never copied
/// into the view itself.
#[derive(Debug)]
pub struct ShuffledDataset<D, I> {
    /// The wrapped dataset whose items are being reordered.
    dataset: D,
    /// Mapping from this view's positions to indices of `dataset`.
    indices: Vec<usize>,
    /// Ties the item type `I` to the struct without storing any value of it.
    input: PhantomData<I>,
}
12
13impl<D, I> ShuffledDataset<D, I>
14where
15 D: Dataset<I>,
16{
17 pub fn new(dataset: D, rng: &mut StdRng) -> Self {
19 let mut indices = Vec::with_capacity(dataset.len());
20 for i in 0..dataset.len() {
21 indices.push(i);
22 }
23 indices.shuffle(rng);
24
25 Self {
26 dataset,
27 indices,
28 input: PhantomData,
29 }
30 }
31
32 pub fn with_seed(dataset: D, seed: u64) -> Self {
34 let mut rng = StdRng::seed_from_u64(seed);
35 Self::new(dataset, &mut rng)
36 }
37}
38
39impl<D, I> Dataset<I> for ShuffledDataset<D, I>
40where
41 D: Dataset<I>,
42 I: Clone + Send + Sync,
43{
44 fn get(&self, index: usize) -> Option<I> {
45 let index = self.indices.get(index)?;
46 self.dataset.get(*index)
47 }
48
49 fn len(&self) -> usize {
50 self.dataset.len()
51 }
52}