burn_dataset/transform/
selection.rs1use crate::Dataset;
2use crate::transform::RngSource;
3use rand::prelude::SliceRandom;
4use rand::rngs::StdRng;
5use std::marker::PhantomData;
6use std::sync::Arc;
7
/// Builds the identity index list `[0, 1, ..., size - 1]`.
///
/// # Arguments
///
/// * `size` - how many consecutive indices to generate, starting at zero.
///
/// # Returns
///
/// A vector holding `size` ascending indices; empty when `size` is zero.
#[inline(always)]
pub fn iota(size: usize) -> Vec<usize> {
    let mut indices = Vec::with_capacity(size);
    indices.extend(0..size);
    indices
}
21
22#[inline(always)]
32pub fn shuffled_indices(size: usize, rng: &mut StdRng) -> Vec<usize> {
33 let mut indices = iota(size);
34 indices.shuffle(rng);
35 indices
36}
37
38#[derive(Clone)]
42pub struct SelectionDataset<D, I>
43where
44 D: Dataset<I>,
45 I: Clone + Send + Sync,
46{
47 pub wrapped: Arc<D>,
49
50 pub indices: Vec<usize>,
52
53 input: PhantomData<I>,
54}
55
56impl<D, I> SelectionDataset<D, I>
57where
58 D: Dataset<I>,
59 I: Clone + Send + Sync,
60{
61 pub fn from_indices_checked<S>(dataset: S, indices: Vec<usize>) -> Self
75 where
76 S: Into<Arc<D>>,
77 {
78 let dataset = dataset.into();
79
80 let size = dataset.len();
81 if let Some(idx) = indices.iter().find(|&i| *i >= size) {
82 panic!("Index out of bounds for wrapped dataset size: {idx} >= {size}");
83 }
84
85 Self::from_indices_unchecked(dataset, indices)
86 }
87
88 pub fn from_indices_unchecked<S>(dataset: S, indices: Vec<usize>) -> Self
99 where
100 S: Into<Arc<D>>,
101 {
102 Self {
103 wrapped: dataset.into(),
104 indices,
105 input: PhantomData,
106 }
107 }
108
109 pub fn new_select_all<S>(dataset: S) -> Self
123 where
124 S: Into<Arc<D>>,
125 {
126 let dataset = dataset.into();
127 let size = dataset.len();
128 Self::from_indices_unchecked(dataset, iota(size))
129 }
130
131 pub fn new_shuffled<S, R>(dataset: S, rng_source: R) -> Self
145 where
146 S: Into<Arc<D>>,
147 R: Into<RngSource>,
148 {
149 let mut this = Self::new_select_all(dataset);
150 this.shuffle(rng_source);
151 this
152 }
153
154 pub fn shuffle<R>(&mut self, rng_source: R)
162 where
163 R: Into<RngSource>,
164 {
165 let mut rng: StdRng = rng_source.into().into();
166 self.indices.shuffle(&mut rng)
167 }
168
169 pub fn slice(&self, start: usize, end: usize) -> Self {
182 Self::from_indices_unchecked(self.wrapped.clone(), self.indices[start..end].to_vec())
183 }
184
185 pub fn split(&self, num: usize) -> Vec<Self> {
199 let n = self.indices.len();
200
201 let mut current = 0;
202 let mut datasets = Vec::with_capacity(num);
203
204 let batch_size = n / num;
205 for i in 0..num {
206 let start = current;
207 let mut end = current + batch_size;
208
209 if i == (num - 1) {
210 end = n;
211 }
212
213 let dataset = self.slice(start, end);
214
215 current += batch_size;
216 datasets.push(dataset);
217 }
218
219 datasets
220 }
221}
222
223impl<D, I> Dataset<I> for SelectionDataset<D, I>
224where
225 D: Dataset<I>,
226 I: Clone + Send + Sync,
227{
228 fn get(&self, index: usize) -> Option<I> {
229 let index = self.indices.get(index)?;
230 self.wrapped.get(*index)
231 }
232
233 fn len(&self) -> usize {
234 self.indices.len()
235 }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FakeDataset;
    use rand::SeedableRng;

    /// `iota` yields the ascending run of indices starting at zero.
    #[test]
    fn test_iota() {
        let indices = iota(10);

        assert_eq!(indices.len(), 10);
        assert_eq!(indices, vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    /// `shuffled_indices` matches shuffling `iota` with an identical generator.
    #[test]
    fn test_shuffled_indices() {
        let size = 10;

        // Two generators in the same state must yield the same permutation.
        let mut reference_rng = StdRng::seed_from_u64(10);
        let mut rng_under_test = reference_rng.clone();

        let actual = shuffled_indices(size, &mut rng_under_test);

        let mut expected = iota(size);
        expected.shuffle(&mut reference_rng);

        assert_eq!(actual, expected);
    }

    /// The checked constructor rejects an index beyond the wrapped dataset.
    #[test]
    #[should_panic(expected = "Index out of bounds for wrapped dataset size: 300 >= 27")]
    fn test_from_indices_checked_panics() {
        let source_dataset = FakeDataset::<String>::new(27);
        // 300 is the only index that does not fit into 27 items.
        let indices: Vec<usize> = vec![15, 1, 12, 300];
        SelectionDataset::from_indices_checked(source_dataset, indices);
    }

    /// A checked selection returns the wrapped items in index order,
    /// including repeated indices.
    #[test]
    fn test_checked_selection_dataset() {
        let source_dataset = FakeDataset::<String>::new(27);
        let indices: Vec<usize> = vec![15, 1, 12, 12];

        // Collect the expected items before the dataset is moved into the selection.
        let mut expected: Vec<String> = Vec::with_capacity(indices.len());
        for &index in &indices {
            expected.push(source_dataset.get(index).unwrap());
        }

        let selection = SelectionDataset::from_indices_checked(source_dataset, indices.clone());
        assert_eq!(&selection.indices, &indices);

        let items: Vec<String> = selection.iter().collect();
        assert_eq!(items, expected);
    }

    /// A shuffled selection uses the same permutation as `shuffled_indices`
    /// with the same seed, and still covers every item.
    #[test]
    fn test_shuffled_dataset() {
        let dataset = FakeDataset::<String>::new(27);
        let source_items: Vec<String> = dataset.iter().collect();

        let selection = SelectionDataset::new_shuffled(dataset, 42);

        let mut rng = StdRng::seed_from_u64(42);
        let indices = shuffled_indices(source_items.len(), &mut rng);

        assert_eq!(&selection.indices, &indices);
        assert_eq!(selection.len(), source_items.len());

        let mut expected_items: Vec<String> = Vec::with_capacity(indices.len());
        for &i in &indices {
            expected_items.push(source_items[i].clone());
        }
        let actual_items: Vec<String> = selection.iter().collect();
        assert_eq!(&actual_items, &expected_items);
    }

    /// Slicing keeps exactly the items at positions `start..end`.
    #[test]
    fn test_slice() {
        let dataset = FakeDataset::<String>::new(27);
        let source_items: Vec<String> = dataset.iter().collect();

        let selection = SelectionDataset::new_select_all(dataset);

        let (start, end) = (5, 15);
        let sliced_selection = selection.slice(start, end);

        assert_eq!(sliced_selection.len(), end - start);

        for (offset, expected) in source_items[start..end].iter().enumerate() {
            assert_eq!(sliced_selection.get(offset), Some(expected.clone()));
        }
    }

    /// Splitting 28 items three ways gives 9, 9 and 10 items: the last part
    /// takes the remainder.
    #[test]
    fn test_split() {
        let dataset = FakeDataset::<String>::new(28);
        let source_items: Vec<String> = dataset.iter().collect();

        let selection = SelectionDataset::new_select_all(dataset);

        let mut split_contents: Vec<Vec<String>> = Vec::new();
        for part in selection.split(3) {
            split_contents.push(part.iter().collect());
        }

        let expected = vec![
            source_items[0..9].to_vec(),
            source_items[9..18].to_vec(),
            source_items[18..28].to_vec(),
        ];
        assert_eq!(split_contents, expected);
    }
}