// burn_core/data/dataloader/batch.rs
use super::{BatchStrategy, DataLoader, DataLoaderIterator, Progress, batcher::Batcher};
2use burn_dataset::{
3 Dataset,
4 transform::{PartialDataset, ShuffledDataset},
5};
6use burn_tensor::backend::Backend;
7use rand::SeedableRng;
8use std::ops::DerefMut;
9use std::sync::Arc;
10
/// A data loader that produces batches of type `O` from a dataset of items `I`,
/// grouping items according to a [`BatchStrategy`] and converting each group
/// with a [`Batcher`] on a given device.
pub struct BatchDataLoader<B: Backend, I, O> {
    /// Decides how many items accumulate before a batch is emitted.
    strategy: Box<dyn BatchStrategy<I>>,
    /// Source of individual items.
    dataset: Arc<dyn Dataset<I>>,
    /// Converts a `Vec<I>` of items into a single batch `O`.
    batcher: Arc<dyn Batcher<B, I, O>>,
    /// Device batches are created on.
    device: B::Device,
    /// Optional RNG used to reshuffle the dataset on each `iter()` call;
    /// `None` means iteration follows the dataset's natural order.
    /// Shared behind a lock so clones observe the same evolving RNG state.
    rng: Option<Arc<spin::Mutex<rand::rngs::StdRng>>>,
}
19
20impl<B: Backend, I, O> Clone for BatchDataLoader<B, I, O> {
21 fn clone(&self) -> Self {
22 Self {
23 strategy: self.strategy.clone_dyn(),
24 dataset: self.dataset.clone(),
25 batcher: self.batcher.clone(),
26 device: self.device.clone(),
27 rng: self.rng.clone(),
28 }
29 }
30}
31
32impl<B: Backend, I, O> BatchDataLoader<B, I, O> {
33 pub fn new(
48 strategy: Box<dyn BatchStrategy<I>>,
49 dataset: Arc<dyn Dataset<I>>,
50 batcher: Arc<dyn Batcher<B, I, O>>,
51 device: B::Device,
52 rng: Option<rand::rngs::StdRng>,
53 ) -> Self {
54 Self {
55 strategy,
56 dataset,
57 batcher,
58 device,
59 rng: rng.map(|rng| Arc::new(spin::Mutex::new(rng))),
60 }
61 }
62}
63
/// Iterator state for a [`BatchDataLoader`]: walks the dataset sequentially,
/// feeding items to the strategy and yielding batches as they fill up.
struct BatchDataloaderIterator<B: Backend, I, O> {
    /// Index of the next dataset item to consume.
    current_index: usize,
    /// Strategy accumulating items until a batch is ready.
    strategy: Box<dyn BatchStrategy<I>>,
    /// Dataset being iterated (possibly a shuffled view — see `DataLoader::iter`).
    dataset: Arc<dyn Dataset<I>>,
    /// Converts accumulated items into a batch.
    batcher: Arc<dyn Batcher<B, I, O>>,
    /// Device passed to the batcher for batch creation.
    device: B::Device,
}
72
73impl<B, I, O> DataLoader<B, O> for BatchDataLoader<B, I, O>
74where
75 B: Backend,
76 I: Send + Sync + Clone + 'static,
77 O: Send + 'static,
78{
79 fn iter<'a>(&'a self) -> Box<dyn DataLoaderIterator<O> + 'a> {
80 let dataset = match &self.rng {
84 Some(rng) => Arc::new(ShuffledDataset::new(
85 self.dataset.clone(),
86 rng.lock().deref_mut(),
87 )),
88 None => self.dataset.clone(),
89 };
90 Box::new(BatchDataloaderIterator::new(
91 self.strategy.clone_dyn(),
92 dataset,
93 self.batcher.clone(),
94 self.device.clone(),
95 ))
96 }
97
98 fn num_items(&self) -> usize {
99 self.dataset.len()
100 }
101
102 fn to_device(&self, device: &B::Device) -> Arc<dyn DataLoader<B, O>> {
103 let rng = self.rng.as_ref().map(|rng| {
104 let mut rng = rng.lock();
105 rng.fork()
106 });
107 Arc::new(Self::new(
108 self.strategy.clone_dyn(),
109 self.dataset.clone(),
110 self.batcher.clone(),
111 device.clone(),
112 rng,
113 ))
114 }
115
116 fn slice(&self, start: usize, end: usize) -> Arc<dyn DataLoader<B, O>> {
117 let rng = self.rng.as_ref().map(|rng| {
118 let mut rng = rng.lock();
119 rng.fork()
120 });
121 let dataloader = Self::new(
122 self.strategy.clone_dyn(),
123 Arc::new(PartialDataset::new(self.dataset.clone(), start, end)),
124 self.batcher.clone(),
125 self.device.clone(),
126 rng,
127 );
128 Arc::new(dataloader)
129 }
130}
131
132impl<B: Backend, I, O> BatchDataloaderIterator<B, I, O> {
133 pub fn new(
146 strategy: Box<dyn BatchStrategy<I>>,
147 dataset: Arc<dyn Dataset<I>>,
148 batcher: Arc<dyn Batcher<B, I, O>>,
149 device: B::Device,
150 ) -> Self {
151 BatchDataloaderIterator {
152 current_index: 0,
153 strategy,
154 dataset,
155 batcher,
156 device,
157 }
158 }
159}
160
161impl<B: Backend, I, O> Iterator for BatchDataloaderIterator<B, I, O> {
162 type Item = O;
163
164 fn next(&mut self) -> Option<O> {
165 while let Some(item) = self.dataset.get(self.current_index) {
166 self.current_index += 1;
167 self.strategy.add(item);
168
169 if let Some(items) = self.strategy.batch(false) {
170 return Some(self.batcher.batch(items, &self.device));
171 }
172 }
173
174 if let Some(items) = self.strategy.batch(true) {
175 return Some(self.batcher.batch(items, &self.device));
176 }
177
178 None
179 }
180}
181
impl<B: Backend, I, O> DataLoaderIterator<O> for BatchDataloaderIterator<B, I, O> {
    /// Reports progress as items consumed so far out of the dataset length.
    fn progress(&self) -> Progress {
        Progress::new(self.current_index, self.dataset.len())
    }
}
187
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;
    use crate::data::dataloader::FixBatchStrategy;
    use crate::data::dataloader::batcher::TestBatcher;
    use crate::data::dataset::FakeDataset;

    /// The dataloader must yield exactly the dataset's items, just batched.
    #[test]
    fn test_batch_dataloader() {
        let batcher = Arc::new(TestBatcher::new());
        let dataset = Arc::new(FakeDataset::<String>::new(27));
        let dataloader = BatchDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset.clone(),
            batcher,
            Default::default(),
            None,
        );

        let items_dataset: HashSet<_> = dataset.iter().collect();

        let mut items_dataloader = HashSet::new();
        for batch in dataloader.iter() {
            items_dataloader.extend(batch);
        }

        assert_eq!(items_dataset, items_dataloader);
    }

    /// Slicing `[5, 15)` must yield the same items as positions 5..15 of the
    /// full (unshuffled) dataloader.
    #[test]
    fn test_batch_dataloader_slice() {
        let batcher = Arc::new(TestBatcher::new());
        let dataset = Arc::new(FakeDataset::<String>::new(27));
        let dataloader = BatchDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset.clone(),
            batcher,
            Default::default(),
            None,
        );
        let dataloader_slice = dataloader.slice(5, 15);

        // Gather items at flat positions 5..15 from the full dataloader.
        let mut items_dataloader = HashSet::new();
        let mut position = 0;
        for batch in dataloader.iter() {
            for item in batch {
                if (5..15).contains(&position) {
                    items_dataloader.insert(item);
                }
                position += 1;
            }
        }

        let mut items_dataloader_slice = HashSet::new();
        for batch in dataloader_slice.iter() {
            items_dataloader_slice.extend(batch);
        }

        assert_eq!(items_dataloader, items_dataloader_slice);
    }
}