burn_dataset/transform/
shuffle.rs1use crate::Dataset;
2use crate::transform::{RngSource, SelectionDataset};
3
4pub struct ShuffledDataset<D, I>
15where
16 D: Dataset<I>,
17 I: Clone + Send + Sync,
18{
19 wrapped: SelectionDataset<D, I>,
20}
21
22impl<D, I> ShuffledDataset<D, I>
23where
24 D: Dataset<I>,
25 I: Clone + Send + Sync,
26{
27 pub fn new<R>(dataset: D, rng_source: R) -> Self
40 where
41 R: Into<RngSource>,
42 {
43 Self {
44 wrapped: SelectionDataset::new_shuffled(dataset, rng_source),
45 }
46 }
47
48 #[deprecated(since = "0.19.0", note = "Use `new(dataset, seed)` instead`")]
61 pub fn with_seed(dataset: D, seed: u64) -> Self {
62 Self::new(dataset, seed)
63 }
64}
65
66impl<D, I> Dataset<I> for ShuffledDataset<D, I>
67where
68 D: Dataset<I>,
69 I: Clone + Send + Sync,
70{
71 fn get(&self, index: usize) -> Option<I> {
72 self.wrapped.get(index)
73 }
74
75 fn len(&self) -> usize {
76 self.wrapped.len()
77 }
78}
79
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FakeDataset;
    use crate::transform::selection::shuffled_indices;
    use rand::SeedableRng;
    use rand::prelude::StdRng;

    /// A seeded `ShuffledDataset` must yield the source items in the order
    /// given by `shuffled_indices` driven by an `StdRng` with the same seed.
    #[test]
    fn test_shuffled_dataset() {
        let seed: u64 = 42;

        let source = FakeDataset::<String>::new(27);
        // Snapshot the items before the dataset is moved into the wrapper.
        let originals: Vec<_> = source.iter().collect();

        // The deprecated constructor is exercised on purpose so that the
        // compatibility path stays covered.
        #[allow(deprecated)]
        let shuffled = ShuffledDataset::with_seed(source, seed);

        // Reproduce the expected permutation independently of the wrapper.
        let mut reference_rng = StdRng::seed_from_u64(seed);
        let permutation = shuffled_indices(originals.len(), &mut reference_rng);

        assert_eq!(shuffled.len(), originals.len());

        let mut expected = Vec::with_capacity(permutation.len());
        for &position in permutation.iter() {
            expected.push(originals[position].clone());
        }

        let actual: Vec<_> = shuffled.iter().collect();
        assert_eq!(&actual, &expected);
    }
}