Skip to main content

burn_dataset/transform/
selection.rs

1use crate::Dataset;
2use crate::transform::RngSource;
3use rand::prelude::SliceRandom;
4use rand::rngs::StdRng;
5use std::marker::PhantomData;
6use std::sync::Arc;
7
/// Builds the identity index mapping `[0, 1, ..., size - 1]`.
///
/// # Arguments
///
/// * `size` - The size of the dataset.
///
/// # Returns
///
/// A vector with `size` elements whose element at position `i` is `i`
/// (empty when `size` is 0).
#[inline(always)]
pub fn iota(size: usize) -> Vec<usize> {
    let mut identity = Vec::with_capacity(size);
    identity.extend(0..size);
    identity
}
21
22/// Generates a shuffled vector of indices up to a size.
23///
24/// # Arguments
25///
26/// * `size` - The size of the dataset to shuffle.
27///
28/// # Returns
29///
30/// A vector of shuffled indices.
31#[inline(always)]
32pub fn shuffled_indices(size: usize, rng: &mut StdRng) -> Vec<usize> {
33    let mut indices = iota(size);
34    indices.shuffle(rng);
35    indices
36}
37
38/// A dataset that selects a subset of indices from an existing dataset.
39///
40/// Indices may appear multiple times, but they must be within the bounds of the original dataset.
41#[derive(Clone)]
42pub struct SelectionDataset<D, I>
43where
44    D: Dataset<I>,
45    I: Clone + Send + Sync,
46{
47    /// The wrapped dataset from which to select indices.
48    pub wrapped: Arc<D>,
49
50    /// The indices to select from the wrapped dataset.
51    pub indices: Vec<usize>,
52
53    input: PhantomData<I>,
54}
55
56impl<D, I> SelectionDataset<D, I>
57where
58    D: Dataset<I>,
59    I: Clone + Send + Sync,
60{
61    /// Creates a new selection dataset with the given dataset and indices.
62    ///
63    /// Checks that all indices are within the bounds of the dataset.
64    ///
65    /// # Arguments
66    ///
67    /// * `dataset` - The original dataset to select from.
68    /// * `indices` - A slice of indices to select from the dataset.
69    ///   These indices must be within the bounds of the dataset.
70    ///
71    /// # Panics
72    ///
73    /// Panics if any index is out of bounds for the dataset.
74    pub fn from_indices_checked<S>(dataset: S, indices: Vec<usize>) -> Self
75    where
76        S: Into<Arc<D>>,
77    {
78        let dataset = dataset.into();
79
80        let size = dataset.len();
81        if let Some(idx) = indices.iter().find(|&i| *i >= size) {
82            panic!("Index out of bounds for wrapped dataset size: {idx} >= {size}");
83        }
84
85        Self::from_indices_unchecked(dataset, indices)
86    }
87
88    /// Creates a new selection dataset with the given dataset and indices without checking bounds.
89    ///
90    /// # Arguments
91    ///
92    /// * `dataset` - The original dataset to select from.
93    /// * `indices` - A vector of indices to select from the dataset.
94    ///
95    /// # Safety
96    ///
97    /// This function does not check if the indices are within the bounds of the dataset.
98    pub fn from_indices_unchecked<S>(dataset: S, indices: Vec<usize>) -> Self
99    where
100        S: Into<Arc<D>>,
101    {
102        Self {
103            wrapped: dataset.into(),
104            indices,
105            input: PhantomData,
106        }
107    }
108
109    /// Creates a new selection dataset that selects all indices from the dataset.
110    ///
111    /// This allocates a 1-to-1 mapping of indices to the dataset size,
112    /// essentially functioning as a no-op selection. This is only useful
113    /// when the dataset will later be shuffled or transformed in place.
114    ///
115    /// # Arguments
116    ///
117    /// * `dataset` - The original dataset to select from.
118    ///
119    /// # Returns
120    ///
121    /// A new `SelectionDataset` that selects all indices from the dataset.
122    pub fn new_select_all<S>(dataset: S) -> Self
123    where
124        S: Into<Arc<D>>,
125    {
126        let dataset = dataset.into();
127        let size = dataset.len();
128        Self::from_indices_unchecked(dataset, iota(size))
129    }
130
131    /// Creates a new selection dataset with shuffled indices.
132    ///
133    /// Selects every index of the dataset and shuffles them
134    /// with randomness from the provided random number generator.
135    ///
136    /// # Arguments
137    ///
138    /// * `dataset` - The original dataset to select from.
139    /// * `rng` - A mutable reference to a random number generator.
140    ///
141    /// # Returns
142    ///
143    /// A new `SelectionDataset` with shuffled indices.
144    pub fn new_shuffled<S, R>(dataset: S, rng_source: R) -> Self
145    where
146        S: Into<Arc<D>>,
147        R: Into<RngSource>,
148    {
149        let mut this = Self::new_select_all(dataset);
150        this.shuffle(rng_source);
151        this
152    }
153
154    /// Shuffles the indices of the dataset using a mutable random number generator.
155    ///
156    /// This method modifies the dataset in place, shuffling the indices.
157    ///
158    /// # Arguments
159    ///
160    /// * `rng` - A mutable reference to a random number generator.
161    pub fn shuffle<R>(&mut self, rng_source: R)
162    where
163        R: Into<RngSource>,
164    {
165        let mut rng: StdRng = rng_source.into().into();
166        self.indices.shuffle(&mut rng)
167    }
168
169    /// Creates a new dataset that is a slice of the current selection dataset.
170    ///
171    /// Slices the *selection indices* from ``[start..end]``.
172    ///
173    /// Independent of future shuffles on the parent, but shares the same wrapped dataset.
174    ///
175    ///
176    /// # Arguments
177    ///
178    /// * `start` - The start of the range.
179    /// * `end` - The end of the range (exclusive).
180    // TODO: SliceArg in burn-tensor should be lifted to burn-std; this should use SliceArg.
181    pub fn slice(&self, start: usize, end: usize) -> Self {
182        Self::from_indices_unchecked(self.wrapped.clone(), self.indices[start..end].to_vec())
183    }
184
185    /// Split into `num` datasets by slicing the selection indices evenly.
186    ///
187    /// Split is done via `slice`, so the datasets share the same wrapped dataset.
188    ///
189    /// Independent of future shuffles on the parent, but shares the same wrapped dataset.
190    ///
191    /// # Arguments
192    ///
193    /// * `num` - The number of datasets to split into.
194    ///
195    /// # Returns
196    ///
197    /// A vector of `SelectionDataset` instances, each containing a subset of the indices.
198    pub fn split(&self, num: usize) -> Vec<Self> {
199        let n = self.indices.len();
200
201        let mut current = 0;
202        let mut datasets = Vec::with_capacity(num);
203
204        let batch_size = n / num;
205        for i in 0..num {
206            let start = current;
207            let mut end = current + batch_size;
208
209            if i == (num - 1) {
210                end = n;
211            }
212
213            let dataset = self.slice(start, end);
214
215            current += batch_size;
216            datasets.push(dataset);
217        }
218
219        datasets
220    }
221}
222
223impl<D, I> Dataset<I> for SelectionDataset<D, I>
224where
225    D: Dataset<I>,
226    I: Clone + Send + Sync,
227{
228    fn get(&self, index: usize) -> Option<I> {
229        let index = self.indices.get(index)?;
230        self.wrapped.get(*index)
231    }
232
233    fn len(&self) -> usize {
234        self.indices.len()
235    }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FakeDataset;
    use rand::SeedableRng;

    #[test]
    fn test_iota() {
        let size = 10;
        let indices = iota(size);
        assert_eq!(indices.len(), size);
        assert_eq!(indices, vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    #[test]
    fn test_shuffled_indices() {
        let size = 10;

        let mut rng1 = StdRng::seed_from_u64(10);
        let mut rng2 = rng1.clone();

        let mut expected = iota(size);
        expected.shuffle(&mut rng1);

        let indices = shuffled_indices(size, &mut rng2);

        assert_eq!(indices, expected);
    }

    #[should_panic(expected = "Index out of bounds for wrapped dataset size: 300 >= 27")]
    #[test]
    fn test_from_indices_checked_panics() {
        let source_dataset = FakeDataset::<String>::new(27);
        let indices: Vec<usize> = vec![15, 1, 12, 300];
        SelectionDataset::from_indices_checked(source_dataset, indices);
    }

    #[test]
    fn test_checked_selection_dataset() {
        let source_dataset = FakeDataset::<String>::new(27);

        let indices: Vec<usize> = vec![15, 1, 12, 12];
        let expected: Vec<String> = indices
            .iter()
            .map(|i| source_dataset.get(*i).unwrap())
            .collect();

        let selection = SelectionDataset::from_indices_checked(source_dataset, indices.clone());

        assert_eq!(&selection.indices, &indices);

        let items = selection.iter().collect::<Vec<_>>();

        assert_eq!(items, expected);
    }

    #[test]
    fn test_shuffled_dataset() {
        let dataset = FakeDataset::<String>::new(27);
        let source_items = dataset.iter().collect::<Vec<_>>();

        let selection = SelectionDataset::new_shuffled(dataset, 42);

        let indices = shuffled_indices(source_items.len(), &mut StdRng::seed_from_u64(42));

        assert_eq!(&selection.indices, &indices);
        assert_eq!(selection.len(), source_items.len());

        let expected_items: Vec<_> = indices
            .iter()
            .map(|&i| source_items[i].to_string())
            .collect();
        assert_eq!(&selection.iter().collect::<Vec<_>>(), &expected_items);
    }

    #[test]
    fn test_slice() {
        let dataset = FakeDataset::<String>::new(27);
        let source_items = dataset.iter().collect::<Vec<_>>();

        let selection = SelectionDataset::new_select_all(dataset);

        let start = 5;
        let end = 15;
        let sliced_selection = selection.slice(start, end);

        assert_eq!(sliced_selection.len(), end - start);

        // Position `offset` in the slice must map back to `start + offset` in the source.
        for (offset, expected) in source_items[start..end].iter().enumerate() {
            assert_eq!(sliced_selection.get(offset), Some(expected.clone()));
        }
    }

    #[test]
    fn test_split() {
        let dataset = FakeDataset::<String>::new(28);
        let source_items = dataset.iter().collect::<Vec<_>>();

        let selection = SelectionDataset::new_select_all(dataset);

        let split_contents: Vec<Vec<_>> = selection
            .split(3)
            .iter()
            .map(|d| d.iter().collect::<Vec<_>>())
            .collect();
        assert_eq!(
            split_contents,
            vec![
                source_items[0..9].to_vec(),
                source_items[9..18].to_vec(),
                source_items[18..28].to_vec(),
            ]
        );
    }

    #[test]
    fn test_split_remainder_goes_to_last() {
        let dataset = FakeDataset::<String>::new(10);
        let selection = SelectionDataset::new_select_all(dataset);

        let sizes: Vec<usize> = selection.split(4).iter().map(|d| d.len()).collect();
        assert_eq!(sizes, vec![2, 2, 2, 4]);
    }

    #[test]
    fn test_split_more_parts_than_items() {
        let dataset = FakeDataset::<String>::new(2);
        let selection = SelectionDataset::new_select_all(dataset);

        let sizes: Vec<usize> = selection.split(3).iter().map(|d| d.len()).collect();
        assert_eq!(sizes, vec![0, 0, 2]);
    }

    #[test]
    #[should_panic]
    fn test_split_zero_panics() {
        let dataset = FakeDataset::<String>::new(5);
        let selection = SelectionDataset::new_select_all(dataset);
        let _ = selection.split(0);
    }
}