Skip to main content

burn_dataset/transform/
shuffle.rs

1use crate::Dataset;
2use crate::transform::{RngSource, SelectionDataset};
3
4/// A Shuffled a dataset.
5///
6/// This is a thin wrapper around a [SelectionDataset] which selects and shuffles
7/// the full indices of the original dataset.
8///
9/// Consider using [SelectionDataset] if you are only interested in
10/// shuffling mechanisms.
11///
12/// Consider using [sampler dataset](crate::transform::SamplerDataset) if you
13/// want a probability distribution which is computed lazily.
14pub struct ShuffledDataset<D, I>
15where
16    D: Dataset<I>,
17    I: Clone + Send + Sync,
18{
19    wrapped: SelectionDataset<D, I>,
20}
21
22impl<D, I> ShuffledDataset<D, I>
23where
24    D: Dataset<I>,
25    I: Clone + Send + Sync,
26{
27    /// Creates a new selection dataset with shuffled indices.
28    ///
29    /// This is a thin wrapper around `SelectionDataset::new_shuffled`.
30    ///
31    /// # Arguments
32    ///
33    /// * `dataset` - The original dataset to select from.
34    /// * `rng_source` - The source of the random number generator.
35    ///
36    /// # Returns
37    ///
38    /// A new `ShuffledDataset`.
39    pub fn new<R>(dataset: D, rng_source: R) -> Self
40    where
41        R: Into<RngSource>,
42    {
43        Self {
44            wrapped: SelectionDataset::new_shuffled(dataset, rng_source),
45        }
46    }
47
48    /// Creates a new selection dataset with shuffled indices using a fixed seed.
49    ///
50    /// This is a thin wrapper around `SelectionDataset::new_shuffled_with_seed`.
51    ///
52    /// # Arguments
53    ///
54    /// * `dataset` - The original dataset to select from.
55    /// * `seed` - A fixed seed for the random number generator.
56    ///
57    /// # Returns
58    ///
59    /// A new `ShuffledDataset`.
60    #[deprecated(since = "0.19.0", note = "Use `new(dataset, seed)` instead`")]
61    pub fn with_seed(dataset: D, seed: u64) -> Self {
62        Self::new(dataset, seed)
63    }
64}
65
66impl<D, I> Dataset<I> for ShuffledDataset<D, I>
67where
68    D: Dataset<I>,
69    I: Clone + Send + Sync,
70{
71    fn get(&self, index: usize) -> Option<I> {
72        self.wrapped.get(index)
73    }
74
75    fn len(&self) -> usize {
76        self.wrapped.len()
77    }
78}
79
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FakeDataset;
    use crate::transform::selection::shuffled_indices;
    use rand::SeedableRng;
    use rand::prelude::StdRng;

    /// Computes the order a `ShuffledDataset` built from `seed` is expected to
    /// produce, by replaying the index shuffle on a freshly seeded RNG.
    fn expected_order(source_items: &[String], seed: u64) -> Vec<String> {
        let mut rng = StdRng::seed_from_u64(seed);
        let indices = shuffled_indices(source_items.len(), &mut rng);

        indices.iter().map(|&i| source_items[i].clone()).collect()
    }

    #[test]
    fn test_shuffled_dataset() {
        let dataset = FakeDataset::<String>::new(27);
        let source_items = dataset.iter().collect::<Vec<_>>();
        let seed: u64 = 42;

        let shuffled = ShuffledDataset::new(dataset, seed);

        assert_eq!(shuffled.len(), source_items.len());
        assert_eq!(
            shuffled.iter().collect::<Vec<_>>(),
            expected_order(&source_items, seed)
        );
    }

    #[test]
    fn test_shuffled_dataset_with_seed() {
        let dataset = FakeDataset::<String>::new(27);
        let source_items = dataset.iter().collect::<Vec<_>>();
        let seed: u64 = 42;

        // `with_seed` is deprecated but still public API; keep it covered until it is removed.
        #[allow(deprecated)]
        let shuffled = ShuffledDataset::with_seed(dataset, seed);

        assert_eq!(shuffled.len(), source_items.len());
        assert_eq!(
            shuffled.iter().collect::<Vec<_>>(),
            expected_order(&source_items, seed)
        );
    }
}