1use crate::collate::{stack_tensors, Collate};
10use crate::dataset::Dataset;
11use crate::sampler::{RandomSampler, Sampler, SequentialSampler};
12use axonml_tensor::Tensor;
13use rayon::prelude::*;
14use std::marker::PhantomData;
15
/// A single mini-batch of stacked samples produced by a data loader.
#[derive(Debug, Clone)]
pub struct Batch {
    /// Stacked input tensors; the leading dimension is the batch size.
    pub data: Tensor<f32>,
    /// Stacked target tensors, in the same sample order as `data`.
    pub targets: Tensor<f32>,
    /// Number of samples in this batch (leading dimension of `data`).
    pub size: usize,
}
30
31impl Batch {
32 #[must_use]
34 pub fn new(data: Tensor<f32>, targets: Tensor<f32>) -> Self {
35 let size = data.shape()[0];
36 Self {
37 data,
38 targets,
39 size,
40 }
41 }
42
43 #[must_use]
45 pub fn len(&self) -> usize {
46 self.size
47 }
48
49 #[must_use]
51 pub fn is_empty(&self) -> bool {
52 self.size == 0
53 }
54}
55
/// Iterates over a [`Dataset`] of `(data, target)` tensor pairs in batches,
/// with optional shuffling, partial-batch dropping, and parallel fetching.
pub struct DataLoader<D>
where
    D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
{
    /// The underlying dataset samples are drawn from.
    dataset: D,
    /// Number of samples per batch.
    batch_size: usize,
    /// When `true`, the sample order is randomized on every `iter()` call.
    shuffle: bool,
    /// When `true`, a trailing batch smaller than `batch_size` is discarded.
    drop_last: bool,
    /// When > 0, samples within a batch are fetched in parallel via rayon.
    num_workers: usize,
}
78
79impl<D> DataLoader<D>
80where
81 D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
82{
83 pub fn new(dataset: D, batch_size: usize) -> Self {
85 Self {
86 dataset,
87 batch_size,
88 shuffle: false,
89 drop_last: false,
90 num_workers: 0,
91 }
92 }
93
94 pub fn shuffle(mut self, shuffle: bool) -> Self {
96 self.shuffle = shuffle;
97 self
98 }
99
100 pub fn drop_last(mut self, drop_last: bool) -> Self {
102 self.drop_last = drop_last;
103 self
104 }
105
106 pub fn num_workers(mut self, num_workers: usize) -> Self {
108 self.num_workers = num_workers;
109 self
110 }
111
112 pub fn batch_size(&self) -> usize {
114 self.batch_size
115 }
116
117 pub fn len(&self) -> usize {
119 let total = self.dataset.len();
120 if self.drop_last {
121 total / self.batch_size
122 } else {
123 total.div_ceil(self.batch_size)
124 }
125 }
126
127 pub fn is_empty(&self) -> bool {
129 self.dataset.is_empty()
130 }
131
132 pub fn dataset_len(&self) -> usize {
134 self.dataset.len()
135 }
136
137 pub fn iter(&self) -> DataLoaderIter<'_, D> {
139 let indices: Vec<usize> = if self.shuffle {
140 let sampler = RandomSampler::new(self.dataset.len());
141 sampler.iter().collect()
142 } else {
143 let sampler = SequentialSampler::new(self.dataset.len());
144 sampler.iter().collect()
145 };
146
147 DataLoaderIter {
148 dataset: &self.dataset,
149 indices,
150 batch_size: self.batch_size,
151 drop_last: self.drop_last,
152 position: 0,
153 num_workers: self.num_workers,
154 }
155 }
156}
157
/// Iterator over [`Batch`]es, created by `DataLoader::iter`.
pub struct DataLoaderIter<'a, D>
where
    D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
{
    /// Dataset samples are fetched from.
    dataset: &'a D,
    /// Pre-computed sample order for this pass.
    indices: Vec<usize>,
    /// Target number of samples per batch.
    batch_size: usize,
    /// Whether a trailing partial batch is discarded.
    drop_last: bool,
    /// Index into `indices` where the next batch starts.
    position: usize,
    /// When > 0, samples are fetched in parallel via rayon.
    num_workers: usize,
}
174
175impl<D> Iterator for DataLoaderIter<'_, D>
176where
177 D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
178{
179 type Item = Batch;
180
181 fn next(&mut self) -> Option<Self::Item> {
182 if self.position >= self.indices.len() {
183 return None;
184 }
185
186 let end = (self.position + self.batch_size).min(self.indices.len());
187 let batch_indices = &self.indices[self.position..end];
188
189 if batch_indices.len() < self.batch_size && self.drop_last {
191 return None;
192 }
193
194 let samples: Vec<(Tensor<f32>, Tensor<f32>)> = if self.num_workers > 0 {
196 batch_indices
198 .par_iter()
199 .filter_map(|&idx| self.dataset.get(idx))
200 .collect()
201 } else {
202 batch_indices
204 .iter()
205 .filter_map(|&idx| self.dataset.get(idx))
206 .collect()
207 };
208
209 if samples.is_empty() {
210 return None;
211 }
212
213 let data_samples: Vec<Tensor<f32>> = samples.iter().map(|(x, _)| x.clone()).collect();
215 let target_samples: Vec<Tensor<f32>> = samples.iter().map(|(_, y)| y.clone()).collect();
216
217 let data = stack_tensors(&data_samples);
219 let targets = stack_tensors(&target_samples);
220
221 self.position = end;
222
223 Some(Batch::new(data, targets))
224 }
225}
226
227impl<D> DataLoaderIter<'_, D>
228where
229 D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
230{
231 #[must_use]
233 pub fn remaining(&self) -> usize {
234 let remaining_samples = self.indices.len().saturating_sub(self.position);
235 if self.drop_last {
236 remaining_samples / self.batch_size
237 } else {
238 remaining_samples.div_ceil(self.batch_size)
239 }
240 }
241}
242
/// A batching loader over any [`Dataset`] item type, using a [`Collate`]
/// implementation to combine raw samples into batch values.
pub struct GenericDataLoader<D, C, T>
where
    D: Dataset<Item = T>,
    C: Collate<T>,
    T: Send,
{
    /// The underlying dataset items are drawn from.
    dataset: D,
    /// Combines a `Vec<T>` of samples into one batch value.
    collate_fn: C,
    /// Number of samples per batch.
    batch_size: usize,
    /// When `true`, the sample order is randomized on every `iter()` call.
    shuffle: bool,
    /// When `true`, a trailing batch smaller than `batch_size` is discarded.
    drop_last: bool,
    /// When > 0, samples within a batch are fetched in parallel via rayon.
    num_workers: usize,
    /// Ties the otherwise-unused item type `T` to this struct.
    _phantom: PhantomData<T>,
}
262
263impl<D, C, T> GenericDataLoader<D, C, T>
264where
265 D: Dataset<Item = T>,
266 C: Collate<T>,
267 T: Send,
268{
269 pub fn new(dataset: D, collate_fn: C, batch_size: usize) -> Self {
271 Self {
272 dataset,
273 collate_fn,
274 batch_size,
275 shuffle: false,
276 drop_last: false,
277 num_workers: 0,
278 _phantom: PhantomData,
279 }
280 }
281
282 pub fn num_workers(mut self, num_workers: usize) -> Self {
284 self.num_workers = num_workers;
285 self
286 }
287
288 pub fn shuffle(mut self, shuffle: bool) -> Self {
290 self.shuffle = shuffle;
291 self
292 }
293
294 pub fn drop_last(mut self, drop_last: bool) -> Self {
296 self.drop_last = drop_last;
297 self
298 }
299
300 pub fn len(&self) -> usize {
302 let total = self.dataset.len();
303 if self.drop_last {
304 total / self.batch_size
305 } else {
306 total.div_ceil(self.batch_size)
307 }
308 }
309
310 pub fn is_empty(&self) -> bool {
312 self.dataset.is_empty()
313 }
314
315 pub fn iter(&self) -> GenericDataLoaderIter<'_, D, C, T> {
317 let indices: Vec<usize> = if self.shuffle {
318 let sampler = RandomSampler::new(self.dataset.len());
319 sampler.iter().collect()
320 } else {
321 (0..self.dataset.len()).collect()
322 };
323
324 GenericDataLoaderIter {
325 dataset: &self.dataset,
326 collate_fn: &self.collate_fn,
327 indices,
328 batch_size: self.batch_size,
329 drop_last: self.drop_last,
330 position: 0,
331 num_workers: self.num_workers,
332 _phantom: PhantomData,
333 }
334 }
335}
336
/// Iterator over collated batches, created by `GenericDataLoader::iter`.
pub struct GenericDataLoaderIter<'a, D, C, T>
where
    D: Dataset<Item = T>,
    C: Collate<T>,
    T: Send,
{
    /// Dataset items are fetched from.
    dataset: &'a D,
    /// Combines a `Vec<T>` of samples into one batch value.
    collate_fn: &'a C,
    /// Pre-computed sample order for this pass.
    indices: Vec<usize>,
    /// Target number of samples per batch.
    batch_size: usize,
    /// Whether a trailing partial batch is discarded.
    drop_last: bool,
    /// Index into `indices` where the next batch starts.
    position: usize,
    /// When > 0, samples are fetched in parallel via rayon.
    num_workers: usize,
    /// Ties the otherwise-unused item type `T` to this struct.
    _phantom: PhantomData<T>,
}
353
354impl<D, C, T> Iterator for GenericDataLoaderIter<'_, D, C, T>
355where
356 D: Dataset<Item = T>,
357 C: Collate<T>,
358 T: Send + Sync,
359{
360 type Item = C::Output;
361
362 fn next(&mut self) -> Option<Self::Item> {
363 if self.position >= self.indices.len() {
364 return None;
365 }
366
367 let end = (self.position + self.batch_size).min(self.indices.len());
368 let batch_indices = &self.indices[self.position..end];
369
370 if batch_indices.len() < self.batch_size && self.drop_last {
371 return None;
372 }
373
374 let samples: Vec<T> = if self.num_workers > 0 {
376 batch_indices
377 .par_iter()
378 .filter_map(|&idx| self.dataset.get(idx))
379 .collect()
380 } else {
381 batch_indices
382 .iter()
383 .filter_map(|&idx| self.dataset.get(idx))
384 .collect()
385 };
386
387 if samples.is_empty() {
388 return None;
389 }
390
391 self.position = end;
392
393 Some(self.collate_fn.collate(samples))
394 }
395}
396
#[cfg(test)]
mod tests {
    use super::*;
    use crate::collate::DefaultCollate;
    use crate::dataset::TensorDataset;

    /// Builds a `size`-row dataset with 2 features per row and 0/1 targets.
    fn create_test_dataset(size: usize) -> TensorDataset {
        let features: Vec<f32> = (0..size * 2).map(|i| i as f32).collect();
        let labels: Vec<f32> = (0..size).map(|i| (i % 2) as f32).collect();

        let x = Tensor::from_vec(features, &[size, 2]).unwrap();
        let y = Tensor::from_vec(labels, &[size]).unwrap();
        TensorDataset::new(x, y)
    }

    #[test]
    fn test_dataloader_basic() {
        let loader = DataLoader::new(create_test_dataset(10), 3);

        assert_eq!(loader.batch_size(), 3);
        assert_eq!(loader.len(), 4);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 4);

        // Three full batches, then a partial batch holding the leftover row.
        for full in &batches[..3] {
            assert_eq!(full.len(), 3);
        }
        assert_eq!(batches[3].len(), 1);
    }

    #[test]
    fn test_dataloader_drop_last() {
        let loader = DataLoader::new(create_test_dataset(10), 3).drop_last(true);

        assert_eq!(loader.len(), 3);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 3);
        // The partial 10th sample was dropped; every batch is complete.
        assert!(batches.iter().all(|b| b.len() == 3));
    }

    #[test]
    fn test_dataloader_shuffle() {
        let loader = DataLoader::new(create_test_dataset(100), 10).shuffle(true);

        // Shuffled contents are nondeterministic, so only check that two
        // independent passes both produce data.
        let first: Vec<Batch> = loader.iter().take(1).collect();
        let second: Vec<Batch> = loader.iter().take(1).collect();

        assert!(!first.is_empty());
        assert!(!second.is_empty());
    }

    #[test]
    fn test_dataloader_exact_batches() {
        let loader = DataLoader::new(create_test_dataset(9), 3);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 3);
        assert!(batches.iter().all(|b| b.len() == 3));
    }

    #[test]
    fn test_batch_struct() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let targets = Tensor::from_vec(vec![0.0, 1.0], &[2]).unwrap();

        let batch = Batch::new(data, targets);
        assert_eq!(batch.len(), 2);
        assert!(!batch.is_empty());
    }

    #[test]
    fn test_dataloader_empty() {
        let x = Tensor::from_vec(vec![], &[0, 2]).unwrap();
        let y = Tensor::from_vec(vec![], &[0]).unwrap();
        let loader = DataLoader::new(TensorDataset::new(x, y), 3);

        assert!(loader.is_empty());
        let batches: Vec<Batch> = loader.iter().collect();
        assert!(batches.is_empty());
    }

    #[test]
    fn test_dataloader_single_item() {
        let loader = DataLoader::new(create_test_dataset(1), 3);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].len(), 1);
    }

    #[test]
    fn test_dataloader_iteration_order() {
        let loader = DataLoader::new(create_test_dataset(6), 2).shuffle(false);

        let batches: Vec<Batch> = loader.iter().collect();

        // Without shuffling, rows come back in dataset order.
        assert_eq!(batches[0].data.to_vec(), vec![0.0, 1.0, 2.0, 3.0]);
        assert_eq!(batches[1].data.to_vec(), vec![4.0, 5.0, 6.0, 7.0]);
        assert_eq!(batches[2].data.to_vec(), vec![8.0, 9.0, 10.0, 11.0]);
    }

    #[test]
    fn test_generic_dataloader() {
        let loader = GenericDataLoader::new(create_test_dataset(6), DefaultCollate::new(), 2);

        let batches: Vec<_> = loader.iter().collect();
        assert_eq!(batches.len(), 3);
    }

    #[test]
    fn test_dataloader_remaining() {
        let loader = DataLoader::new(create_test_dataset(10), 3);

        let mut iter = loader.iter();
        assert_eq!(iter.remaining(), 4);

        iter.next();
        assert_eq!(iter.remaining(), 3);

        iter.next();
        assert_eq!(iter.remaining(), 2);
    }

    #[test]
    fn test_parallel_dataloader() {
        let loader = DataLoader::new(create_test_dataset(100), 10).num_workers(4);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 10);

        // No sample is lost or duplicated by the parallel path.
        let total_samples: usize = batches.iter().map(Batch::len).sum();
        assert_eq!(total_samples, 100);
    }

    #[test]
    fn test_parallel_vs_sequential_equivalence() {
        let sequential = DataLoader::new(create_test_dataset(50), 5).num_workers(0);
        let parallel = DataLoader::new(create_test_dataset(50), 5).num_workers(4);

        let batches_seq: Vec<Batch> = sequential.iter().collect();
        let batches_par: Vec<Batch> = parallel.iter().collect();

        assert_eq!(batches_seq.len(), batches_par.len());

        // Parallel fetching must not change batch contents or their order.
        for (seq, par) in batches_seq.iter().zip(&batches_par) {
            assert_eq!(seq.data.to_vec(), par.data.to_vec());
            assert_eq!(seq.targets.to_vec(), par.targets.to_vec());
        }
    }

    #[test]
    fn test_parallel_dataloader_drop_last() {
        let loader = DataLoader::new(create_test_dataset(95), 10)
            .drop_last(true)
            .num_workers(4);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 9);
        assert!(batches.iter().all(|b| b.len() == 10));
    }

    #[test]
    fn test_parallel_generic_dataloader() {
        let loader = GenericDataLoader::new(create_test_dataset(60), DefaultCollate::new(), 10)
            .num_workers(4);

        let batches: Vec<_> = loader.iter().collect();
        assert_eq!(batches.len(), 6);
    }
}