1use crate::collate::{stack_tensors, Collate};
10use crate::dataset::Dataset;
11use crate::sampler::{RandomSampler, Sampler, SequentialSampler};
12use axonml_tensor::Tensor;
13use std::marker::PhantomData;
14
15#[derive(Debug, Clone)]
21pub struct Batch {
22 pub data: Tensor<f32>,
24 pub targets: Tensor<f32>,
26 pub size: usize,
28}
29
30impl Batch {
31 #[must_use] pub fn new(data: Tensor<f32>, targets: Tensor<f32>) -> Self {
33 let size = data.shape()[0];
34 Self {
35 data,
36 targets,
37 size,
38 }
39 }
40
41 #[must_use] pub fn len(&self) -> usize {
43 self.size
44 }
45
46 #[must_use] pub fn is_empty(&self) -> bool {
48 self.size == 0
49 }
50}
51
52pub struct DataLoader<D>
60where
61 D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
62{
63 dataset: D,
65 batch_size: usize,
67 shuffle: bool,
69 drop_last: bool,
71 num_workers: usize,
73}
74
75impl<D> DataLoader<D>
76where
77 D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
78{
79 pub fn new(dataset: D, batch_size: usize) -> Self {
81 Self {
82 dataset,
83 batch_size,
84 shuffle: false,
85 drop_last: false,
86 num_workers: 0,
87 }
88 }
89
90 pub fn shuffle(mut self, shuffle: bool) -> Self {
92 self.shuffle = shuffle;
93 self
94 }
95
96 pub fn drop_last(mut self, drop_last: bool) -> Self {
98 self.drop_last = drop_last;
99 self
100 }
101
102 pub fn num_workers(mut self, num_workers: usize) -> Self {
104 self.num_workers = num_workers;
105 self
106 }
107
108 pub fn batch_size(&self) -> usize {
110 self.batch_size
111 }
112
113 pub fn len(&self) -> usize {
115 let total = self.dataset.len();
116 if self.drop_last {
117 total / self.batch_size
118 } else {
119 total.div_ceil(self.batch_size)
120 }
121 }
122
123 pub fn is_empty(&self) -> bool {
125 self.dataset.is_empty()
126 }
127
128 pub fn dataset_len(&self) -> usize {
130 self.dataset.len()
131 }
132
133 pub fn iter(&self) -> DataLoaderIter<'_, D> {
135 let indices: Vec<usize> = if self.shuffle {
136 let sampler = RandomSampler::new(self.dataset.len());
137 sampler.iter().collect()
138 } else {
139 let sampler = SequentialSampler::new(self.dataset.len());
140 sampler.iter().collect()
141 };
142
143 DataLoaderIter {
144 dataset: &self.dataset,
145 indices,
146 batch_size: self.batch_size,
147 drop_last: self.drop_last,
148 position: 0,
149 }
150 }
151}
152
153pub struct DataLoaderIter<'a, D>
159where
160 D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
161{
162 dataset: &'a D,
163 indices: Vec<usize>,
164 batch_size: usize,
165 drop_last: bool,
166 position: usize,
167}
168
169impl<D> Iterator for DataLoaderIter<'_, D>
170where
171 D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
172{
173 type Item = Batch;
174
175 fn next(&mut self) -> Option<Self::Item> {
176 if self.position >= self.indices.len() {
177 return None;
178 }
179
180 let end = (self.position + self.batch_size).min(self.indices.len());
181 let batch_indices = &self.indices[self.position..end];
182
183 if batch_indices.len() < self.batch_size && self.drop_last {
185 return None;
186 }
187
188 let mut data_samples = Vec::with_capacity(batch_indices.len());
190 let mut target_samples = Vec::with_capacity(batch_indices.len());
191
192 for &idx in batch_indices {
193 if let Some((x, y)) = self.dataset.get(idx) {
194 data_samples.push(x);
195 target_samples.push(y);
196 }
197 }
198
199 if data_samples.is_empty() {
200 return None;
201 }
202
203 let data = stack_tensors(&data_samples);
205 let targets = stack_tensors(&target_samples);
206
207 self.position = end;
208
209 Some(Batch::new(data, targets))
210 }
211}
212
213impl<D> DataLoaderIter<'_, D>
214where
215 D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
216{
217 #[must_use] pub fn remaining(&self) -> usize {
219 let remaining_samples = self.indices.len().saturating_sub(self.position);
220 if self.drop_last {
221 remaining_samples / self.batch_size
222 } else {
223 remaining_samples.div_ceil(self.batch_size)
224 }
225 }
226}
227
228pub struct GenericDataLoader<D, C, T>
234where
235 D: Dataset<Item = T>,
236 C: Collate<T>,
237 T: Send,
238{
239 dataset: D,
240 collate_fn: C,
241 batch_size: usize,
242 shuffle: bool,
243 drop_last: bool,
244 _phantom: PhantomData<T>,
245}
246
247impl<D, C, T> GenericDataLoader<D, C, T>
248where
249 D: Dataset<Item = T>,
250 C: Collate<T>,
251 T: Send,
252{
253 pub fn new(dataset: D, collate_fn: C, batch_size: usize) -> Self {
255 Self {
256 dataset,
257 collate_fn,
258 batch_size,
259 shuffle: false,
260 drop_last: false,
261 _phantom: PhantomData,
262 }
263 }
264
265 pub fn shuffle(mut self, shuffle: bool) -> Self {
267 self.shuffle = shuffle;
268 self
269 }
270
271 pub fn drop_last(mut self, drop_last: bool) -> Self {
273 self.drop_last = drop_last;
274 self
275 }
276
277 pub fn len(&self) -> usize {
279 let total = self.dataset.len();
280 if self.drop_last {
281 total / self.batch_size
282 } else {
283 total.div_ceil(self.batch_size)
284 }
285 }
286
287 pub fn is_empty(&self) -> bool {
289 self.dataset.is_empty()
290 }
291
292 pub fn iter(&self) -> GenericDataLoaderIter<'_, D, C, T> {
294 let indices: Vec<usize> = if self.shuffle {
295 let sampler = RandomSampler::new(self.dataset.len());
296 sampler.iter().collect()
297 } else {
298 (0..self.dataset.len()).collect()
299 };
300
301 GenericDataLoaderIter {
302 dataset: &self.dataset,
303 collate_fn: &self.collate_fn,
304 indices,
305 batch_size: self.batch_size,
306 drop_last: self.drop_last,
307 position: 0,
308 _phantom: PhantomData,
309 }
310 }
311}
312
313pub struct GenericDataLoaderIter<'a, D, C, T>
315where
316 D: Dataset<Item = T>,
317 C: Collate<T>,
318 T: Send,
319{
320 dataset: &'a D,
321 collate_fn: &'a C,
322 indices: Vec<usize>,
323 batch_size: usize,
324 drop_last: bool,
325 position: usize,
326 _phantom: PhantomData<T>,
327}
328
329impl<D, C, T> Iterator for GenericDataLoaderIter<'_, D, C, T>
330where
331 D: Dataset<Item = T>,
332 C: Collate<T>,
333 T: Send,
334{
335 type Item = C::Output;
336
337 fn next(&mut self) -> Option<Self::Item> {
338 if self.position >= self.indices.len() {
339 return None;
340 }
341
342 let end = (self.position + self.batch_size).min(self.indices.len());
343 let batch_indices = &self.indices[self.position..end];
344
345 if batch_indices.len() < self.batch_size && self.drop_last {
346 return None;
347 }
348
349 let samples: Vec<T> = batch_indices
351 .iter()
352 .filter_map(|&idx| self.dataset.get(idx))
353 .collect();
354
355 if samples.is_empty() {
356 return None;
357 }
358
359 self.position = end;
360
361 Some(self.collate_fn.collate(samples))
362 }
363}
364
#[cfg(test)]
mod tests {
    use super::*;
    use crate::collate::DefaultCollate;
    use crate::dataset::TensorDataset;

    /// Builds `size` samples where row `i` is `[2i, 2i + 1]` and its target is `i % 2`.
    fn create_test_dataset(size: usize) -> TensorDataset {
        let data: Vec<f32> = (0..size * 2).map(|i| i as f32).collect();
        let targets: Vec<f32> = (0..size).map(|i| (i % 2) as f32).collect();

        let x = Tensor::from_vec(data, &[size, 2]).unwrap();
        let y = Tensor::from_vec(targets, &[size]).unwrap();

        TensorDataset::new(x, y)
    }

    #[test]
    fn test_dataloader_basic() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset, 3);

        assert_eq!(loader.batch_size(), 3);
        // 10 samples in batches of 3: three full batches plus one partial.
        assert_eq!(loader.len(), 4);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 4);

        assert_eq!(batches[0].len(), 3);
        assert_eq!(batches[1].len(), 3);
        assert_eq!(batches[2].len(), 3);
        assert_eq!(batches[3].len(), 1);
    }

    #[test]
    fn test_dataloader_drop_last() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset, 3).drop_last(true);

        // The trailing single-sample batch is discarded.
        assert_eq!(loader.len(), 3);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 3);

        for batch in &batches {
            assert_eq!(batch.len(), 3);
        }
    }

    #[test]
    fn test_dataloader_drop_last_smaller_than_batch() {
        let dataset = create_test_dataset(2);
        let loader = DataLoader::new(dataset, 3).drop_last(true);

        assert_eq!(loader.len(), 0);
        assert_eq!(loader.iter().count(), 0);
    }

    #[test]
    fn test_dataloader_shuffle() {
        let dataset = create_test_dataset(100);
        let loader = DataLoader::new(dataset, 10).shuffle(true);

        // Whatever order the sampler picks, every batch must be full and each
        // row must stay paired with its own target. Check two separate passes.
        for _ in 0..2 {
            let batches: Vec<Batch> = loader.iter().collect();
            assert_eq!(batches.len(), 10);

            for batch in &batches {
                assert_eq!(batch.len(), 10);

                let data = batch.data.to_vec();
                let targets = batch.targets.to_vec();
                assert_eq!(data.len(), 2 * batch.len());
                assert_eq!(targets.len(), batch.len());

                for (row, &target) in data.chunks(2).zip(&targets) {
                    // Row i is [2i, 2i + 1] and carries target i % 2.
                    assert_eq!(row[1], row[0] + 1.0);
                    let i = (row[0] / 2.0) as usize;
                    assert!(i < 100);
                    assert_eq!(target, (i % 2) as f32);
                }
            }
        }
    }

    #[test]
    fn test_dataloader_exact_batches() {
        let dataset = create_test_dataset(9);
        let loader = DataLoader::new(dataset, 3);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 3);

        for batch in &batches {
            assert_eq!(batch.len(), 3);
        }
    }

    #[test]
    fn test_batch_struct() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let targets = Tensor::from_vec(vec![0.0, 1.0], &[2]).unwrap();

        let batch = Batch::new(data, targets);
        assert_eq!(batch.len(), 2);
        assert!(!batch.is_empty());
    }

    #[test]
    fn test_dataloader_empty() {
        let x = Tensor::from_vec(vec![], &[0, 2]).unwrap();
        let y = Tensor::from_vec(vec![], &[0]).unwrap();
        let dataset = TensorDataset::new(x, y);
        let loader = DataLoader::new(dataset, 3);

        assert!(loader.is_empty());
        let batches: Vec<Batch> = loader.iter().collect();
        assert!(batches.is_empty());
    }

    #[test]
    fn test_dataloader_single_item() {
        let dataset = create_test_dataset(1);
        let loader = DataLoader::new(dataset, 3);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].len(), 1);
    }

    #[test]
    fn test_dataloader_iteration_order() {
        let dataset = create_test_dataset(6);
        let loader = DataLoader::new(dataset, 2).shuffle(false);

        let batches: Vec<Batch> = loader.iter().collect();

        // Without shuffling, rows come out in dataset order, two per batch.
        assert_eq!(batches[0].data.to_vec(), vec![0.0, 1.0, 2.0, 3.0]);
        assert_eq!(batches[1].data.to_vec(), vec![4.0, 5.0, 6.0, 7.0]);
        assert_eq!(batches[2].data.to_vec(), vec![8.0, 9.0, 10.0, 11.0]);
    }

    #[test]
    fn test_generic_dataloader() {
        let dataset = create_test_dataset(6);
        let collate = DefaultCollate::new();
        let loader = GenericDataLoader::new(dataset, collate, 2);

        let batches: Vec<_> = loader.iter().collect();
        assert_eq!(batches.len(), 3);
    }

    #[test]
    fn test_generic_dataloader_drop_last() {
        let dataset = create_test_dataset(10);
        let loader = GenericDataLoader::new(dataset, DefaultCollate::new(), 3).drop_last(true);

        assert_eq!(loader.len(), 3);
        assert_eq!(loader.iter().count(), 3);
    }

    #[test]
    fn test_dataloader_remaining() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset, 3);

        let mut iter = loader.iter();
        assert_eq!(iter.remaining(), 4);

        iter.next();
        assert_eq!(iter.remaining(), 3);

        iter.next();
        assert_eq!(iter.remaining(), 2);
    }
}