burn_dataset/transform/
partial.rs

1use crate::Dataset;
2use std::{marker::PhantomData, sync::Arc};
3
/// Only use a fraction of an existing dataset lazily.
///
/// The view exposes the items of the wrapped dataset whose indices fall in
/// `start_index..end_index`; index `0` of the view maps to `start_index` of
/// the wrapped dataset.
#[derive(new)]
pub struct PartialDataset<D, I> {
    /// The wrapped dataset that items are read from.
    dataset: D,
    /// Index in `dataset` of the first item exposed by the view (inclusive).
    start_index: usize,
    /// Index in `dataset` one past the last item exposed by the view (exclusive).
    end_index: usize,
    /// Marker tying the view to the item type `I`; it stores no data.
    input: PhantomData<I>,
}
12
13impl<D, I> PartialDataset<D, I>
14where
15    D: Dataset<I>,
16{
17    /// Splits a dataset into multiple partial datasets.
18    pub fn split(dataset: D, num: usize) -> Vec<PartialDataset<Arc<D>, I>> {
19        let dataset = Arc::new(dataset); // cheap cloning.
20
21        let mut current = 0;
22        let mut datasets = Vec::with_capacity(num);
23
24        let batch_size = dataset.len() / num;
25
26        for i in 0..num {
27            let start = current;
28            let mut end = current + batch_size;
29
30            if i == (num - 1) {
31                end = dataset.len();
32            }
33
34            let dataset = PartialDataset::new(dataset.clone(), start, end);
35
36            current += batch_size;
37            datasets.push(dataset);
38        }
39
40        datasets
41    }
42}
43
44impl<D, I> Dataset<I> for PartialDataset<D, I>
45where
46    D: Dataset<I>,
47    I: Clone + Send + Sync,
48{
49    fn get(&self, index: usize) -> Option<I> {
50        let index = index + self.start_index;
51        if index < self.start_index || index >= self.end_index {
52            return None;
53        }
54        self.dataset.get(index)
55    }
56
57    fn len(&self) -> usize {
58        usize::min(self.end_index - self.start_index, self.dataset.len())
59    }
60}
61
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FakeDataset;
    use std::collections::HashSet;

    /// A view over the first ten items yields exactly those items.
    #[test]
    fn test_start_from_beginning() {
        let source = FakeDataset::<String>::new(27);
        let every_item: Vec<String> = source.iter().collect();

        // Items at positions 0..10 are expected in the view, the rest are not.
        let expected: HashSet<String> = every_item.iter().take(10).cloned().collect();
        let excluded: HashSet<String> = every_item.iter().skip(10).cloned().collect();

        let view = PartialDataset::new(source, 0, 10);
        let seen: HashSet<String> = view.iter().collect();

        assert_eq!(view.len(), 10);
        assert_eq!(expected, seen);
        assert!(excluded.iter().all(|item| !seen.contains(item)));
    }

    /// A view over positions 10..20 yields exactly the items stored there.
    #[test]
    fn test_start_inside() {
        let source = FakeDataset::<String>::new(27);
        let every_item: Vec<String> = source.iter().collect();

        let in_window = |position: &usize| (10..20).contains(position);
        let expected: HashSet<String> = every_item
            .iter()
            .enumerate()
            .filter(|(position, _)| in_window(position))
            .map(|(_, item)| item.clone())
            .collect();
        let excluded: HashSet<String> = every_item
            .iter()
            .enumerate()
            .filter(|(position, _)| !in_window(position))
            .map(|(_, item)| item.clone())
            .collect();

        let view = PartialDataset::new(source, 10, 20);
        let seen: HashSet<String> = view.iter().collect();

        assert_eq!(view.len(), 10);
        assert_eq!(expected, seen);
        assert!(excluded.iter().all(|item| !seen.contains(item)));
    }

    /// Reading the split parts back to back reproduces the original sequence.
    #[test]
    fn test_split_contains_all_items_without_duplicates() {
        let source = FakeDataset::<String>::new(27);
        let items_original: Vec<String> = source.iter().collect();

        let parts = PartialDataset::split(source, 4);
        let items_partial: Vec<String> = parts.iter().flat_map(|part| part.iter()).collect();

        assert_eq!(items_original, items_partial);
    }
}