burn_core/data/dataloader/
batch.rs1use super::{BatchStrategy, DataLoader, DataLoaderIterator, Progress, batcher::Batcher};
2use burn_dataset::{
3 Dataset,
4 transform::{PartialDataset, ShuffledDataset},
5};
6use burn_tensor::backend::Backend;
7use std::ops::DerefMut;
8use std::sync::Arc;
9
10pub struct BatchDataLoader<B: Backend, I, O> {
12 strategy: Box<dyn BatchStrategy<I>>,
13 dataset: Arc<dyn Dataset<I>>,
14 batcher: Arc<dyn Batcher<B, I, O>>,
15 device: B::Device,
16 rng: Option<Arc<spin::Mutex<rand::rngs::StdRng>>>,
17}
18
19impl<B: Backend, I, O> Clone for BatchDataLoader<B, I, O> {
20 fn clone(&self) -> Self {
21 Self {
22 strategy: self.strategy.clone_dyn(),
23 dataset: self.dataset.clone(),
24 batcher: self.batcher.clone(),
25 device: self.device.clone(),
26 rng: self.rng.clone(),
27 }
28 }
29}
30
31impl<B: Backend, I, O> BatchDataLoader<B, I, O> {
32 pub fn new(
47 strategy: Box<dyn BatchStrategy<I>>,
48 dataset: Arc<dyn Dataset<I>>,
49 batcher: Arc<dyn Batcher<B, I, O>>,
50 device: B::Device,
51 rng: Option<rand::rngs::StdRng>,
52 ) -> Self {
53 Self {
54 strategy,
55 dataset,
56 batcher,
57 device,
58 rng: rng.map(|rng| Arc::new(spin::Mutex::new(rng))),
59 }
60 }
61}
62
63struct BatchDataloaderIterator<B: Backend, I, O> {
65 current_index: usize,
66 strategy: Box<dyn BatchStrategy<I>>,
67 dataset: Arc<dyn Dataset<I>>,
68 batcher: Arc<dyn Batcher<B, I, O>>,
69 device: B::Device,
70}
71
72impl<B, I, O> DataLoader<B, O> for BatchDataLoader<B, I, O>
73where
74 B: Backend,
75 I: Send + Sync + Clone + 'static,
76 O: Send + 'static,
77{
78 fn iter<'a>(&'a self) -> Box<dyn DataLoaderIterator<O> + 'a> {
79 let dataset = match &self.rng {
83 Some(rng) => Arc::new(ShuffledDataset::new(
84 self.dataset.clone(),
85 rng.lock().deref_mut(),
86 )),
87 None => self.dataset.clone(),
88 };
89 Box::new(BatchDataloaderIterator::new(
90 self.strategy.clone_dyn(),
91 dataset,
92 self.batcher.clone(),
93 self.device.clone(),
94 ))
95 }
96
97 fn num_items(&self) -> usize {
98 self.dataset.len()
99 }
100
101 fn to_device(&self, device: &B::Device) -> Arc<dyn DataLoader<B, O>> {
102 let rng = self.rng.as_ref().map(|rng| {
103 let rng = rng.lock();
104 rng.clone()
105 });
106 Arc::new(Self::new(
107 self.strategy.clone_dyn(),
108 self.dataset.clone(),
109 self.batcher.clone(),
110 device.clone(),
111 rng,
112 ))
113 }
114
115 fn slice(&self, start: usize, end: usize) -> Arc<dyn DataLoader<B, O>> {
116 let rng = self.rng.as_ref().map(|rng| {
117 let rng = rng.lock();
118 rng.clone()
119 });
120 let dataloader = Self::new(
121 self.strategy.clone_dyn(),
122 Arc::new(PartialDataset::new(self.dataset.clone(), start, end)),
123 self.batcher.clone(),
124 self.device.clone(),
125 rng,
126 );
127 Arc::new(dataloader)
128 }
129}
130
131impl<B: Backend, I, O> BatchDataloaderIterator<B, I, O> {
132 pub fn new(
145 strategy: Box<dyn BatchStrategy<I>>,
146 dataset: Arc<dyn Dataset<I>>,
147 batcher: Arc<dyn Batcher<B, I, O>>,
148 device: B::Device,
149 ) -> Self {
150 BatchDataloaderIterator {
151 current_index: 0,
152 strategy,
153 dataset,
154 batcher,
155 device,
156 }
157 }
158}
159
160impl<B: Backend, I, O> Iterator for BatchDataloaderIterator<B, I, O> {
161 type Item = O;
162
163 fn next(&mut self) -> Option<O> {
164 while let Some(item) = self.dataset.get(self.current_index) {
165 self.current_index += 1;
166 self.strategy.add(item);
167
168 if let Some(items) = self.strategy.batch(false) {
169 return Some(self.batcher.batch(items, &self.device));
170 }
171 }
172
173 if let Some(items) = self.strategy.batch(true) {
174 return Some(self.batcher.batch(items, &self.device));
175 }
176
177 None
178 }
179}
180
181impl<B: Backend, I, O> DataLoaderIterator<O> for BatchDataloaderIterator<B, I, O> {
182 fn progress(&self) -> Progress {
183 Progress::new(self.current_index, self.dataset.len())
184 }
185}
186
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;
    use crate::data::dataloader::FixBatchStrategy;
    use crate::data::dataloader::batcher::TestBatcher;
    use crate::data::dataset::FakeDataset;

    /// Iterating the loader must yield exactly the items of the dataset, including the
    /// incomplete trailing batch (27 items with a batch size of 5).
    #[test]
    fn test_batch_dataloader() {
        let batcher = Arc::new(TestBatcher::new());
        let dataset = Arc::new(FakeDataset::<String>::new(27));
        let dataloader = BatchDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset.clone(),
            batcher,
            Default::default(),
            None,
        );

        let items_dataset: HashSet<_> = dataset.iter().collect();
        // Each batch is a collection of items; flatten the batches back into items.
        let items_dataloader: HashSet<_> = dataloader.iter().flatten().collect();

        assert_eq!(items_dataset, items_dataloader);
    }

    /// A loader sliced with `(5, 15)` must yield the same items as positions `5..15` of the
    /// full, unshuffled loader.
    #[test]
    fn test_batch_dataloader_slice() {
        let batcher = Arc::new(TestBatcher::new());
        let dataset = Arc::new(FakeDataset::<String>::new(27));
        let dataloader = BatchDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset.clone(),
            batcher,
            Default::default(),
            None,
        );
        let dataloader_slice = dataloader.slice(5, 15);

        // Keep only the items whose overall position falls inside the sliced range.
        let items_dataloader: HashSet<_> = dataloader
            .iter()
            .flatten()
            .enumerate()
            .filter(|(idx, _)| (5..15).contains(idx))
            .map(|(_, item)| item)
            .collect();

        let items_dataloader_slice: HashSet<_> = dataloader_slice.iter().flatten().collect();

        assert_eq!(items_dataloader, items_dataloader_slice);
    }
}