
burn_dataset/transform/partial.rs

use crate::Dataset;
use std::{marker::PhantomData, sync::Arc};

/// Only use a fraction of an existing dataset lazily.
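///
/// The wrapped dataset is not copied: items are fetched lazily and indices are
/// shifted into the `[start_index, end_index)` window.
///
/// A minimal usage sketch (`source` is a hypothetical value implementing
/// `Dataset<String>`, e.g. the crate's `FakeDataset` used in the tests below):
///
/// ```ignore
/// // Keep only items 0..10 of the source dataset.
/// let partial = PartialDataset::new(source, 0, 10);
/// assert_eq!(partial.len(), 10);
/// assert!(partial.get(10).is_none()); // indices are relative to the window
/// ```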
// The `new` constructor derive comes from the `derive_new` crate (assumed in
// scope via the crate root); it skips the `PhantomData` field automatically.
#[derive(new, Clone)]
pub struct PartialDataset<D, I> {
    dataset: D,
    start_index: usize,
    end_index: usize,
    input: PhantomData<I>,
}

impl<D, I> PartialDataset<D, I>
where
    D: Dataset<I>,
{
    /// Splits a dataset into `num` partial datasets of roughly equal length.
    /// The last split absorbs any remainder, so it can be larger than the others.
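    ///
    /// A minimal sketch (`source` is a hypothetical `Dataset<I>` holding 27 items,
    /// mirroring the test at the bottom of this file):
    ///
    /// ```ignore
    /// // 27 / 4 = 6 items per split; the last split takes the remaining 9.
    /// let splits = PartialDataset::split(source, 4);
    /// let lens: Vec<usize> = splits.iter().map(|d| d.len()).collect();
    /// assert_eq!(lens, vec![6, 6, 6, 9]);
    /// ```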
    pub fn split(dataset: D, num: usize) -> Vec<PartialDataset<Arc<D>, I>> {
        let dataset = Arc::new(dataset); // Arc makes cloning cheap for each split.

        let mut current = 0;
        let mut datasets = Vec::with_capacity(num);

        let batch_size = dataset.len() / num;

        for i in 0..num {
            let start = current;
            let mut end = current + batch_size;

            // The last split absorbs the remainder.
            if i == (num - 1) {
                end = dataset.len();
            }

            let dataset = PartialDataset::new(dataset.clone(), start, end);

            current += batch_size;
            datasets.push(dataset);
        }

        datasets
    }

    /// Splits a dataset by distributing complete chunks/batches of `batch_size` items
    /// across multiple partial datasets. Extra batches go to the first splits, and only
    /// the final batch may be incomplete; empty splits are omitted, so fewer than `num`
    /// datasets may be returned.
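    ///
    /// A minimal sketch (`source` is a hypothetical `Dataset<I>` holding 27 items,
    /// mirroring the test at the bottom of this file):
    ///
    /// ```ignore
    /// // ceil(27 / 5) = 6 batches; the splits receive 2, 2, 1, and 1 of them.
    /// let splits = PartialDataset::split_chunks(source, 4, 5);
    /// let lens: Vec<usize> = splits.iter().map(|d| d.len()).collect();
    /// assert_eq!(lens, vec![10, 10, 5, 2]);
    /// ```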
    pub fn split_chunks(
        dataset: D,
        num: usize,
        batch_size: usize,
    ) -> Vec<PartialDataset<Arc<D>, I>> {
        let dataset = Arc::new(dataset); // Arc makes cloning cheap for each split.
        let total_items = dataset.len();

        // Total number of batches; only the last one may be incomplete.
        let total_batches = total_items.div_ceil(batch_size);
        let batches_per_split = total_batches / num;
        let extra_batches = total_batches % num;

        let mut datasets = Vec::with_capacity(num);
        let mut current_batch = 0;

        for i in 0..num {
            // Extra batches are distributed across the first splits.
            let split_batches = if i < extra_batches {
                batches_per_split + 1
            } else {
                batches_per_split
            };

            let start_batch = current_batch;
            let end_batch = start_batch + split_batches;

            let start_index = start_batch * batch_size;
            let end_index = core::cmp::min(end_batch * batch_size, total_items);

            // Skip empty splits (possible when there are fewer batches than splits).
            if start_index < total_items {
                datasets.push(PartialDataset::new(dataset.clone(), start_index, end_index));
            }

            current_batch = end_batch;
        }

        datasets
    }
}

impl<D, I> Dataset<I> for PartialDataset<D, I>
where
    D: Dataset<I>,
    I: Clone + Send + Sync,
{
    fn get(&self, index: usize) -> Option<I> {
        // Shift the requested index into the window of the underlying dataset.
        let index = index + self.start_index;
        if index < self.start_index || index >= self.end_index {
            return None;
        }
        self.dataset.get(index)
    }

    fn len(&self) -> usize {
        // Clamp to the underlying dataset's length in case the window overshoots it.
        usize::min(self.end_index - self.start_index, self.dataset.len())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::FakeDataset;
    use std::collections::HashSet;

    #[test]
    fn test_start_from_beginning() {
        let dataset_original = FakeDataset::<String>::new(27);
        let mut items_original_1 = HashSet::new();
        let mut items_original_2 = HashSet::new();
        let mut items_partial = HashSet::new();
        dataset_original.iter().enumerate().for_each(|(i, item)| {
            match i >= 10 {
                true => items_original_2.insert(item),
                false => items_original_1.insert(item),
            };
        });

        let dataset_partial = PartialDataset::new(dataset_original, 0, 10);

        for item in dataset_partial.iter() {
            items_partial.insert(item);
        }

        assert_eq!(dataset_partial.len(), 10);
        assert_eq!(items_original_1, items_partial);
        for item in items_original_2 {
            assert!(!items_partial.contains(&item));
        }
    }

    #[test]
    fn test_start_inside() {
        let dataset_original = FakeDataset::<String>::new(27);
        let mut items_original_1 = HashSet::new();
        let mut items_original_2 = HashSet::new();
        let mut items_partial = HashSet::new();

        dataset_original.iter().enumerate().for_each(|(i, item)| {
            match !(10..20).contains(&i) {
                true => items_original_2.insert(item),
                false => items_original_1.insert(item),
            };
        });

        let dataset_partial = PartialDataset::new(dataset_original, 10, 20);
        for item in dataset_partial.iter() {
            items_partial.insert(item);
        }

        assert_eq!(dataset_partial.len(), 10);
        assert_eq!(items_original_1, items_partial);
        for item in items_original_2 {
            assert!(!items_partial.contains(&item));
        }
    }

    #[test]
    fn test_split_contains_all_items_without_duplicates() {
        let dataset_original = FakeDataset::<String>::new(27);
        let mut items_original = Vec::new();
        let mut items_partial = Vec::new();
        for item in dataset_original.iter() {
            items_original.push(item);
        }

        let dataset_partials = PartialDataset::split(dataset_original, 4);
        let expected_len = [6, 6, 6, 9];

        for (i, dataset) in dataset_partials.iter().enumerate() {
            assert_eq!(dataset.len(), expected_len[i]);
            for item in dataset.iter() {
                items_partial.push(item);
            }
        }

        assert_eq!(items_original, items_partial);
    }

    #[test]
    fn test_split_chunks_contains_all_items_without_duplicates() {
        let dataset_original = FakeDataset::<String>::new(27);
        let mut items_original = Vec::new();
        let mut items_partial = Vec::new();
        for item in dataset_original.iter() {
            items_original.push(item);
        }

        let dataset_partials = PartialDataset::split_chunks(dataset_original, 4, 5);
        // [(2 * 5), (2 * 5), 5, 2] -> 5 complete chunks + 1 incomplete chunk with 2 remaining items.
        // By contrast, `split(dataset, 4)` would yield [6, 6, 6, 9] -> 4 complete chunks
        // + 4 incomplete chunks with [1, 1, 1, 4] items each.
        let expected_len = [10, 10, 5, 2];

        for (i, dataset) in dataset_partials.iter().enumerate() {
            assert_eq!(dataset.len(), expected_len[i]);
            for item in dataset.iter() {
                items_partial.push(item);
            }
        }

        assert_eq!(items_original, items_partial);
    }
}