Skip to main content

burn_dataset/transform/
selection.rs

1use crate::Dataset;
2use crate::transform::RngSource;
3use rand::prelude::SliceRandom;
4use rand::rngs::StdRng;
5use std::marker::PhantomData;
6use std::sync::Arc;
7
/// Builds the identity index list `[0, 1, ..., size - 1]`.
///
/// # Arguments
///
/// * `size` - The number of indices to generate.
///
/// # Returns
///
/// A vector of length `size` whose element at position `i` is `i`;
/// the vector is empty when `size` is zero.
#[inline(always)]
pub fn iota(size: usize) -> Vec<usize> {
    let mut indices = Vec::with_capacity(size);
    indices.extend(0..size);
    indices
}
21
22/// Generates a shuffled vector of indices up to a size.
23///
24/// # Arguments
25///
26/// * `size` - The size of the dataset to shuffle.
27///
28/// # Returns
29///
30/// A vector of shuffled indices.
31#[inline(always)]
32pub fn shuffled_indices(size: usize, rng: &mut StdRng) -> Vec<usize> {
33    let mut indices = iota(size);
34    indices.shuffle(rng);
35    indices
36}
37
38/// A dataset that selects a subset of indices from an existing dataset.
39///
40/// Indices may appear multiple times, but they must be within the bounds of the original dataset.
41#[derive(Clone)]
42pub struct SelectionDataset<D, I>
43where
44    D: Dataset<I>,
45    I: Clone + Send + Sync,
46{
47    /// The wrapped dataset from which to select indices.
48    pub wrapped: Arc<D>,
49
50    /// The indices to select from the wrapped dataset.
51    pub indices: Vec<usize>,
52
53    input: PhantomData<I>,
54}
55
56impl<D, I> SelectionDataset<D, I>
57where
58    D: Dataset<I>,
59    I: Clone + Send + Sync,
60{
61    /// Creates a new selection dataset with the given dataset and indices.
62    ///
63    /// Checks that all indices are within the bounds of the dataset.
64    ///
65    /// # Arguments
66    ///
67    /// * `dataset` - The original dataset to select from.
68    /// * `indices` - A slice of indices to select from the dataset.
69    ///   These indices must be within the bounds of the dataset.
70    ///
71    /// # Panics
72    ///
73    /// Panics if any index is out of bounds for the dataset.
74    pub fn from_indices_checked<S>(dataset: S, indices: Vec<usize>) -> Self
75    where
76        S: Into<Arc<D>>,
77    {
78        let dataset = dataset.into();
79
80        let size = dataset.len();
81        if let Some(idx) = indices.iter().find(|&i| *i >= size) {
82            panic!("Index out of bounds for wrapped dataset size: {idx} >= {size}");
83        }
84
85        Self::from_indices_unchecked(dataset, indices)
86    }
87
88    /// Creates a new selection dataset with the given dataset and indices without checking bounds.
89    ///
90    /// # Arguments
91    ///
92    /// * `dataset` - The original dataset to select from.
93    /// * `indices` - A vector of indices to select from the dataset.
94    ///
95    /// # Safety
96    ///
97    /// This function does not check if the indices are within the bounds of the dataset.
98    pub fn from_indices_unchecked<S>(dataset: S, indices: Vec<usize>) -> Self
99    where
100        S: Into<Arc<D>>,
101    {
102        Self {
103            wrapped: dataset.into(),
104            indices,
105            input: PhantomData,
106        }
107    }
108
109    /// Creates a new selection dataset that selects all indices from the dataset.
110    ///
111    /// This allocates a 1-to-1 mapping of indices to the dataset size,
112    /// essentially functioning as a no-op selection. This is only useful
113    /// when the dataset will later be shuffled or transformed in place.
114    ///
115    /// # Arguments
116    ///
117    /// * `dataset` - The original dataset to select from.
118    ///
119    /// # Returns
120    ///
121    /// A new `SelectionDataset` that selects all indices from the dataset.
122    pub fn new_select_all<S>(dataset: S) -> Self
123    where
124        S: Into<Arc<D>>,
125    {
126        let dataset = dataset.into();
127        let size = dataset.len();
128        Self::from_indices_unchecked(dataset, iota(size))
129    }
130
131    /// Creates a new selection dataset with shuffled indices.
132    ///
133    /// Selects every index of the dataset and shuffles them
134    /// with randomness from the provided random number generator.
135    ///
136    /// # Arguments
137    ///
138    /// * `dataset` - The original dataset to select from.
139    /// * `rng` - A mutable reference to a random number generator.
140    ///
141    /// # Returns
142    ///
143    /// A new `SelectionDataset` with shuffled indices.
144    pub fn new_shuffled<S, R>(dataset: S, rng_source: R) -> Self
145    where
146        S: Into<Arc<D>>,
147        R: Into<RngSource>,
148    {
149        let mut this = Self::new_select_all(dataset);
150        this.shuffle(rng_source);
151        this
152    }
153
154    /// Shuffles the indices of the dataset using a mutable random number generator.
155    ///
156    /// This method modifies the dataset in place, shuffling the indices.
157    ///
158    /// # Arguments
159    ///
160    /// * `rng` - A mutable reference to a random number generator.
161    pub fn shuffle<R>(&mut self, rng_source: R)
162    where
163        R: Into<RngSource>,
164    {
165        let mut rng: StdRng = rng_source.into().into();
166        self.indices.shuffle(&mut rng)
167    }
168
169    /// Creates a new dataset that is a slice of the current selection dataset.
170    ///
171    /// Slices the *selection indices* from ``[start..end]``.
172    ///
173    /// Independent of future shuffles on the parent, but shares the same wrapped dataset.
174    ///
175    ///
176    /// # Arguments
177    ///
178    /// * `start` - The start of the range.
179    /// * `end` - The end of the range (exclusive).
180    // TODO: SliceArg in burn-tensor should be lifted to burn-std; this should use SliceArg.
181    pub fn slice(&self, start: usize, end: usize) -> Self {
182        Self::from_indices_unchecked(self.wrapped.clone(), self.indices[start..end].to_vec())
183    }
184
185    /// Split into `num` datasets by slicing the selection indices evenly.
186    ///
187    /// Split is done via `slice`, so the datasets share the same wrapped dataset.
188    ///
189    /// Independent of future shuffles on the parent, but shares the same wrapped dataset.
190    ///
191    /// # Arguments
192    ///
193    /// * `num` - The number of datasets to split into.
194    ///
195    /// # Returns
196    ///
197    /// A vector of `SelectionDataset` instances, each containing a subset of the indices.
198    pub fn split(&self, num: usize) -> Vec<Self> {
199        let n = self.indices.len();
200
201        let mut current = 0;
202        let mut datasets = Vec::with_capacity(num);
203
204        let batch_size = n / num;
205        for i in 0..num {
206            let start = current;
207            let mut end = current + batch_size;
208
209            if i == (num - 1) {
210                end = n;
211            }
212
213            let dataset = self.slice(start, end);
214
215            current += batch_size;
216            datasets.push(dataset);
217        }
218
219        datasets
220    }
221}
222
223impl<D, I> Dataset<I> for SelectionDataset<D, I>
224where
225    D: Dataset<I>,
226    I: Clone + Send + Sync,
227{
228    fn get(&self, index: usize) -> Option<I> {
229        let index = self.indices.get(index)?;
230        self.wrapped.get(*index)
231    }
232
233    fn len(&self) -> usize {
234        self.indices.len()
235    }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FakeDataset;
    use rand::SeedableRng;

    /// `iota` yields the consecutive indices `0..size`.
    #[test]
    fn test_iota() {
        let indices = iota(10);

        assert_eq!(indices.len(), 10);
        assert_eq!(indices, vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    /// Two generators built from the same seed must produce the same shuffle.
    #[test]
    fn test_shuffled_indices_same_seed_is_deterministic() {
        const SIZE: usize = 10;

        // `StdRng` is no longer `Clone`, so its internal state cannot be duplicated.
        // To test determinism, a second RNG is created explicitly from the same seed.
        let mut reference_rng = StdRng::seed_from_u64(10);
        let mut tested_rng = StdRng::seed_from_u64(10);

        let mut expected = iota(SIZE);
        expected.shuffle(&mut reference_rng);

        let actual = shuffled_indices(SIZE, &mut tested_rng);

        assert_eq!(actual, expected);
    }

    /// A forked generator must not replay the shuffle of its parent.
    #[test]
    fn test_shuffled_indices_forked_rngs_differ() {
        const SIZE: usize = 10;

        let mut parent_rng = StdRng::seed_from_u64(10);
        let mut forked_rng = parent_rng.fork();

        let mut from_parent = iota(SIZE);
        let mut from_fork = iota(SIZE);

        from_parent.shuffle(&mut parent_rng);
        from_fork.shuffle(&mut forked_rng);

        assert_ne!(from_parent, from_fork);
    }

    /// The checked constructor rejects an index beyond the wrapped dataset.
    #[test]
    #[should_panic(expected = "Index out of bounds for wrapped dataset size: 300 >= 27")]
    fn test_from_indices_checked_panics() {
        let source_dataset = FakeDataset::<String>::new(27);
        let indices: Vec<usize> = vec![15, 1, 12, 300];

        SelectionDataset::from_indices_checked(source_dataset, indices);
    }

    /// A checked selection returns the wrapped items in index order, duplicates included.
    #[test]
    fn test_checked_selection_dataset() {
        let source_dataset = FakeDataset::<String>::new(27);
        let indices: Vec<usize> = vec![15, 1, 12, 12];

        let mut expected: Vec<String> = Vec::with_capacity(indices.len());
        for &index in &indices {
            expected.push(source_dataset.get(index).unwrap());
        }

        let selection = SelectionDataset::from_indices_checked(source_dataset, indices.clone());
        assert_eq!(&selection.indices, &indices);

        let items: Vec<_> = selection.iter().collect();
        assert_eq!(items, expected);
    }

    /// `new_shuffled` with a seed matches `shuffled_indices` driven by the same seed.
    #[test]
    fn test_shuffled_dataset() {
        let dataset = FakeDataset::<String>::new(27);
        let source_items: Vec<_> = dataset.iter().collect();

        let selection = SelectionDataset::new_shuffled(dataset, 42);

        let mut reference_rng = StdRng::seed_from_u64(42);
        let indices = shuffled_indices(source_items.len(), &mut reference_rng);

        assert_eq!(&selection.indices, &indices);
        assert_eq!(selection.len(), source_items.len());

        let expected_items: Vec<_> = indices
            .iter()
            .map(|&index| source_items[index].clone())
            .collect();
        assert_eq!(&selection.iter().collect::<Vec<_>>(), &expected_items);
    }

    /// Slicing a select-all dataset exposes exactly the requested window of items.
    #[test]
    fn test_slice() {
        let dataset = FakeDataset::<String>::new(27);
        let source_items: Vec<_> = dataset.iter().collect();

        let selection = SelectionDataset::new_select_all(dataset);

        let (start, end) = (5, 15);
        let sliced_selection = selection.slice(start, end);

        assert_eq!(sliced_selection.len(), end - start);

        for (offset, item) in source_items[start..end].iter().enumerate() {
            assert_eq!(sliced_selection.get(offset), Some(item.clone()));
        }
    }

    /// Splitting 28 items three ways gives 9, 9 and 10 items, in order.
    #[test]
    fn test_split() {
        let dataset = FakeDataset::<String>::new(28);
        let source_items: Vec<_> = dataset.iter().collect();

        let selection = SelectionDataset::new_select_all(dataset);
        let parts = selection.split(3);

        let split_contents: Vec<Vec<_>> = parts
            .iter()
            .map(|part| part.iter().collect::<Vec<_>>())
            .collect();

        let expected = vec![
            source_items[0..9].to_vec(),
            source_items[9..18].to_vec(),
            source_items[18..28].to_vec(),
        ];
        assert_eq!(split_contents, expected);
    }
}