// burn_dataset/transform/partial.rs
use crate::Dataset;
use std::{marker::PhantomData, sync::Arc};

/// A dataset that exposes only the window `[start_index, end_index)` of an
/// underlying dataset.
///
/// `start_index` is inclusive and `end_index` is exclusive; indices passed to
/// the `Dataset` trait methods are relative to `start_index`.
#[derive(new, Clone)]
pub struct PartialDataset<D, I> {
    // The wrapped dataset (often an `Arc<D>` when produced by `split`).
    dataset: D,
    // Inclusive absolute start of the visible window.
    start_index: usize,
    // Exclusive absolute end of the visible window.
    end_index: usize,
    // Zero-sized marker tying the struct to item type `I`.
    input: PhantomData<I>,
}
12
13impl<D, I> PartialDataset<D, I>
14where
15 D: Dataset<I>,
16{
17 pub fn split(dataset: D, num: usize) -> Vec<PartialDataset<Arc<D>, I>> {
19 let dataset = Arc::new(dataset); let mut current = 0;
22 let mut datasets = Vec::with_capacity(num);
23
24 let batch_size = dataset.len() / num;
25
26 for i in 0..num {
27 let start = current;
28 let mut end = current + batch_size;
29
30 if i == (num - 1) {
31 end = dataset.len();
32 }
33
34 let dataset = PartialDataset::new(dataset.clone(), start, end);
35
36 current += batch_size;
37 datasets.push(dataset);
38 }
39
40 datasets
41 }
42
43 pub fn split_chunks(
45 dataset: D,
46 num: usize,
47 batch_size: usize,
48 ) -> Vec<PartialDataset<Arc<D>, I>> {
49 let dataset = Arc::new(dataset); let total_items = dataset.len();
51
52 let total_batches = total_items.div_ceil(batch_size);
54 let batches_per_split = total_batches / num;
55 let extra_batches = total_batches % num;
56
57 let mut datasets = Vec::with_capacity(num);
58 let mut current_batch = 0;
59
60 for i in 0..num {
61 let split_batches = if i < extra_batches {
63 batches_per_split + 1
64 } else {
65 batches_per_split
66 };
67
68 let start_batch = current_batch;
69 let end_batch = start_batch + split_batches;
70
71 let start_index = start_batch * batch_size;
72 let end_index = core::cmp::min(end_batch * batch_size, total_items);
73
74 if start_index < total_items {
75 datasets.push(PartialDataset::new(dataset.clone(), start_index, end_index));
76 }
77
78 current_batch = end_batch;
79 }
80
81 datasets
82 }
83}
84
85impl<D, I> Dataset<I> for PartialDataset<D, I>
86where
87 D: Dataset<I>,
88 I: Clone + Send + Sync,
89{
90 fn get(&self, index: usize) -> Option<I> {
91 let index = index + self.start_index;
92 if index < self.start_index || index >= self.end_index {
93 return None;
94 }
95 self.dataset.get(index)
96 }
97
98 fn len(&self) -> usize {
99 usize::min(self.end_index - self.start_index, self.dataset.len())
100 }
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FakeDataset;
    use std::collections::HashSet;

    /// A window starting at index 0 must expose exactly the first 10 items.
    #[test]
    fn test_start_from_beginning() {
        let dataset_original = FakeDataset::<String>::new(27);
        let mut items_original_1 = HashSet::new();
        let mut items_original_2 = HashSet::new();
        let mut items_partial = HashSet::new();

        // Partition the source items around index 10.
        for (i, item) in dataset_original.iter().enumerate() {
            if i >= 10 {
                items_original_2.insert(item);
            } else {
                items_original_1.insert(item);
            }
        }

        let dataset_partial = PartialDataset::new(dataset_original, 0, 10);
        items_partial.extend(dataset_partial.iter());

        assert_eq!(dataset_partial.len(), 10);
        assert_eq!(items_original_1, items_partial);
        for item in items_original_2 {
            assert!(!items_partial.contains(&item));
        }
    }

    /// A window in the interior must expose exactly items 10..20.
    #[test]
    fn test_start_inside() {
        let dataset_original = FakeDataset::<String>::new(27);
        let mut items_original_1 = HashSet::new();
        let mut items_original_2 = HashSet::new();
        let mut items_partial = HashSet::new();

        // Items inside 10..20 belong to the window; everything else outside.
        for (i, item) in dataset_original.iter().enumerate() {
            if (10..20).contains(&i) {
                items_original_1.insert(item);
            } else {
                items_original_2.insert(item);
            }
        }

        let dataset_partial = PartialDataset::new(dataset_original, 10, 20);
        items_partial.extend(dataset_partial.iter());

        assert_eq!(dataset_partial.len(), 10);
        assert_eq!(items_original_1, items_partial);
        for item in items_original_2 {
            assert!(!items_partial.contains(&item));
        }
    }

    /// `split` must cover every item exactly once, in order, with the
    /// remainder going to the last part (27 / 4 -> [6, 6, 6, 9]).
    #[test]
    fn test_split_contains_all_items_without_duplicates() {
        let dataset_original = FakeDataset::<String>::new(27);
        let items_original: Vec<_> = dataset_original.iter().collect();

        let dataset_partials = PartialDataset::split(dataset_original, 4);
        let expected_len = [6, 6, 6, 9];

        let mut items_partial = Vec::new();
        for (dataset, &expected) in dataset_partials.iter().zip(expected_len.iter()) {
            assert_eq!(dataset.len(), expected);
            items_partial.extend(dataset.iter());
        }

        assert_eq!(items_original, items_partial);
    }

    /// `split_chunks` must cover every item exactly once, in order, with
    /// batch-aligned boundaries (27 items, 4 splits, batch 5 -> [10, 10, 5, 2]).
    #[test]
    fn test_split_chunks_contains_all_items_without_duplicates() {
        let dataset_original = FakeDataset::<String>::new(27);
        let items_original: Vec<_> = dataset_original.iter().collect();

        let dataset_partials = PartialDataset::split_chunks(dataset_original, 4, 5);
        let expected_len = [10, 10, 5, 2];

        let mut items_partial = Vec::new();
        for (dataset, &expected) in dataset_partials.iter().zip(expected_len.iter()) {
            assert_eq!(dataset.len(), expected);
            items_partial.extend(dataset.iter());
        }

        assert_eq!(items_original, items_partial);
    }
}