1use std::sync::Arc;
7
8use arrow::{array::RecordBatch, compute::concat_batches};
9#[cfg(feature = "shuffle")]
10use rand::{seq::SliceRandom, SeedableRng};
11
12use crate::{dataset::Dataset, error::Result};
13
14#[derive(Debug)]
37pub struct DataLoader<D: Dataset> {
38 dataset: Arc<D>,
39 batch_size: usize,
40 #[allow(dead_code)] shuffle: bool,
42 drop_last: bool,
43 #[allow(dead_code)] seed: Option<u64>,
45}
46
47impl<D: Dataset> DataLoader<D> {
48 pub fn new(dataset: D) -> Self {
56 Self {
57 dataset: Arc::new(dataset),
58 batch_size: 1,
59 shuffle: false,
60 drop_last: false,
61 seed: None,
62 }
63 }
64
65 #[must_use]
75 pub fn batch_size(mut self, size: usize) -> Self {
76 self.batch_size = size.max(1);
77 self
78 }
79
80 #[cfg(feature = "shuffle")]
85 #[must_use]
86 pub fn shuffle(mut self, shuffle: bool) -> Self {
87 self.shuffle = shuffle;
88 self
89 }
90
91 #[must_use]
96 pub fn drop_last(mut self, drop_last: bool) -> Self {
97 self.drop_last = drop_last;
98 self
99 }
100
101 #[cfg(feature = "shuffle")]
106 #[must_use]
107 pub fn seed(mut self, seed: u64) -> Self {
108 self.seed = Some(seed);
109 self
110 }
111
112 pub fn get_batch_size(&self) -> usize {
114 self.batch_size
115 }
116
117 pub fn is_shuffle(&self) -> bool {
119 self.shuffle
120 }
121
122 pub fn is_drop_last(&self) -> bool {
124 self.drop_last
125 }
126
127 pub fn num_batches(&self) -> usize {
129 let len = self.dataset.len();
130 if self.drop_last {
131 len / self.batch_size
132 } else {
133 len.div_ceil(self.batch_size)
134 }
135 }
136
137 pub fn len(&self) -> usize {
139 self.dataset.len()
140 }
141
142 pub fn is_empty(&self) -> bool {
144 self.dataset.is_empty()
145 }
146}
147
148impl<D: Dataset> IntoIterator for DataLoader<D> {
149 type Item = RecordBatch;
150 type IntoIter = DataLoaderIterator<D>;
151
152 fn into_iter(self) -> Self::IntoIter {
153 let indices: Vec<usize> = (0..self.dataset.len()).collect();
154
155 #[cfg(feature = "shuffle")]
156 let shuffled_indices = if self.shuffle {
157 let mut indices = indices;
158 let mut rng = match self.seed {
159 Some(seed) => rand::rngs::StdRng::seed_from_u64(seed),
160 None => rand::rngs::StdRng::from_entropy(),
161 };
162 indices.shuffle(&mut rng);
163 indices
164 } else {
165 indices
166 };
167
168 #[cfg(not(feature = "shuffle"))]
169 let shuffled_indices = indices;
170
171 DataLoaderIterator {
172 dataset: self.dataset,
173 batch_size: self.batch_size,
174 drop_last: self.drop_last,
175 indices: shuffled_indices,
176 position: 0,
177 }
178 }
179}
180
181pub struct DataLoaderIterator<D: Dataset> {
183 dataset: Arc<D>,
184 batch_size: usize,
185 drop_last: bool,
186 indices: Vec<usize>,
187 position: usize,
188}
189
190impl<D: Dataset> Iterator for DataLoaderIterator<D> {
191 type Item = RecordBatch;
192
193 fn next(&mut self) -> Option<Self::Item> {
194 if self.position >= self.indices.len() {
195 return None;
196 }
197
198 let remaining = self.indices.len() - self.position;
199 let batch_size = remaining.min(self.batch_size);
200
201 if self.drop_last && batch_size < self.batch_size {
203 return None;
204 }
205
206 let batch_indices = &self.indices[self.position..self.position + batch_size];
208 self.position += batch_size;
209
210 let rows: Vec<RecordBatch> = batch_indices
212 .iter()
213 .filter_map(|&idx| self.dataset.get(idx))
214 .collect();
215
216 if rows.is_empty() {
217 return None;
218 }
219
220 concat_batches(&self.dataset.schema(), &rows).ok()
222 }
223
224 fn size_hint(&self) -> (usize, Option<usize>) {
225 let remaining = self.indices.len().saturating_sub(self.position);
226 let batches = if self.drop_last {
227 remaining / self.batch_size
228 } else if remaining > 0 {
229 remaining.div_ceil(self.batch_size)
230 } else {
231 0
232 };
233 (batches, Some(batches))
234 }
235}
236
/// Collects optional [`DataLoader`] settings before a dataset is supplied.
///
/// Every field starts as `None` (via `Default`), meaning "not set"; `build`
/// only applies the options that were set.
#[derive(Debug, Default)]
pub struct DataLoaderBuilder {
    /// Requested batch size, if any.
    batch_size: Option<usize>,
    /// Requested shuffle flag, if any.
    // Only read by `build` when the `shuffle` feature is enabled; without the
    // feature the field is write-only, so silence `dead_code` for that
    // configuration only.
    #[cfg_attr(not(feature = "shuffle"), allow(dead_code))]
    shuffle: Option<bool>,
    /// Requested drop-last flag, if any.
    drop_last: Option<bool>,
    /// Requested shuffle seed, if any.
    // Same situation as `shuffle`: only read when the feature is enabled.
    #[cfg_attr(not(feature = "shuffle"), allow(dead_code))]
    seed: Option<u64>,
}
245
246impl DataLoaderBuilder {
247 pub fn new() -> Self {
249 Self::default()
250 }
251
252 #[must_use]
254 pub fn batch_size(mut self, size: usize) -> Self {
255 self.batch_size = Some(size);
256 self
257 }
258
259 #[must_use]
261 pub fn shuffle(mut self, shuffle: bool) -> Self {
262 self.shuffle = Some(shuffle);
263 self
264 }
265
266 #[must_use]
268 pub fn drop_last(mut self, drop_last: bool) -> Self {
269 self.drop_last = Some(drop_last);
270 self
271 }
272
273 #[must_use]
275 pub fn seed(mut self, seed: u64) -> Self {
276 self.seed = Some(seed);
277 self
278 }
279
280 pub fn build<D: Dataset>(self, dataset: D) -> Result<DataLoader<D>> {
286 let batch_size = self.batch_size.unwrap_or(1);
287 if batch_size == 0 {
288 return Err(crate::error::Error::invalid_config(
289 "batch_size must be greater than 0",
290 ));
291 }
292
293 let mut loader = DataLoader::new(dataset).batch_size(batch_size);
294
295 #[cfg(feature = "shuffle")]
296 if let Some(shuffle) = self.shuffle {
297 loader = loader.shuffle(shuffle);
298 }
299 if let Some(drop_last) = self.drop_last {
300 loader = loader.drop_last(drop_last);
301 }
302 #[cfg(feature = "shuffle")]
303 if let Some(seed) = self.seed {
304 loader = loader.seed(seed);
305 }
306
307 Ok(loader)
308 }
309}
310
/// Unit tests for `DataLoader`, `DataLoaderIterator` and `DataLoaderBuilder`.
///
/// `DataLoader::shuffle` and `DataLoader::seed` only exist when the `shuffle`
/// feature is enabled, so tests that call them are gated on that feature.
/// Builder tests compile either way and compare the resulting shuffle flag
/// against `cfg!(feature = "shuffle")`.
#[cfg(test)]
#[allow(
    clippy::cast_possible_truncation,
    clippy::cast_possible_wrap,
    clippy::uninlined_format_args
)]
mod tests {
    use arrow::{
        array::{Int32Array, StringArray},
        datatypes::{DataType, Field, Schema},
    };

    use super::*;
    use crate::ArrowDataset;

    /// Builds a dataset of `rows` rows with an `id` column holding `0..rows`
    /// and a `value` column holding `"val_<id>"`.
    fn create_test_dataset(rows: usize) -> ArrowDataset {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("value", DataType::Utf8, false),
        ]));

        let ids: Vec<i32> = (0..rows as i32).collect();
        let values: Vec<String> = ids.iter().map(|i| format!("val_{}", i)).collect();

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from(ids)),
                Arc::new(StringArray::from(values)),
            ],
        )
        .ok()
        .unwrap_or_else(|| panic!("Should create batch"));

        ArrowDataset::from_batch(batch)
            .ok()
            .unwrap_or_else(|| panic!("Should create dataset"))
    }

    /// Extracts the `id` column (column 0) of `batch` as plain integers.
    #[cfg(feature = "shuffle")]
    fn batch_ids(batch: &RecordBatch) -> Vec<i32> {
        let id_col = batch
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap_or_else(|| panic!("Should be Int32Array"));
        (0..id_col.len()).map(|row| id_col.value(row)).collect()
    }

    #[test]
    fn test_basic_iteration() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset).batch_size(3);

        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        // 3 + 3 + 3 + 1 rows.
        assert_eq!(batches.len(), 4);

        let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 10);
    }

    #[test]
    fn test_drop_last() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset).batch_size(3).drop_last(true);

        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        // The trailing single-row batch is dropped.
        assert_eq!(batches.len(), 3);

        for batch in &batches {
            assert_eq!(batch.num_rows(), 3);
        }
    }

    #[cfg(feature = "shuffle")]
    #[test]
    fn test_shuffle_deterministic() {
        let dataset = create_test_dataset(100);

        let loader1 = DataLoader::new(dataset.clone())
            .batch_size(10)
            .shuffle(true)
            .seed(42);
        let batches1: Vec<RecordBatch> = loader1.into_iter().collect();

        let loader2 = DataLoader::new(dataset)
            .batch_size(10)
            .shuffle(true)
            .seed(42);
        let batches2: Vec<RecordBatch> = loader2.into_iter().collect();

        assert_eq!(batches1.len(), batches2.len());

        // The same seed must reproduce the same row order, not just the same
        // batch sizes.
        for (b1, b2) in batches1.iter().zip(batches2.iter()) {
            assert_eq!(b1.num_rows(), b2.num_rows());
            assert_eq!(batch_ids(b1), batch_ids(b2));
        }
    }

    #[cfg(feature = "shuffle")]
    #[test]
    fn test_shuffle_different_seeds() {
        let dataset = create_test_dataset(100);

        let loader1 = DataLoader::new(dataset.clone())
            .batch_size(100)
            .shuffle(true)
            .seed(42);
        let batches1: Vec<RecordBatch> = loader1.into_iter().collect();

        let loader2 = DataLoader::new(dataset)
            .batch_size(100)
            .shuffle(true)
            .seed(123);
        let batches2: Vec<RecordBatch> = loader2.into_iter().collect();

        assert_eq!(batches1.len(), batches2.len());
        assert_eq!(batches1.len(), 1);

        // Different seeds should give different permutations of the 100 rows.
        assert_ne!(batch_ids(&batches1[0]), batch_ids(&batches2[0]));
    }

    #[cfg(feature = "shuffle")]
    #[test]
    fn test_all_rows_covered() {
        use std::collections::HashSet;

        let dataset = create_test_dataset(25);
        let loader = DataLoader::new(dataset)
            .batch_size(7)
            .shuffle(true)
            .seed(99);

        let mut seen_ids = HashSet::new();
        for batch in loader {
            seen_ids.extend(batch_ids(&batch));
        }

        assert_eq!(seen_ids.len(), 25);
        for i in 0..25i32 {
            assert!(seen_ids.contains(&i));
        }
    }

    #[test]
    fn test_num_batches() {
        let dataset = create_test_dataset(10);

        let loader = DataLoader::new(dataset.clone()).batch_size(3);
        assert_eq!(loader.num_batches(), 4);

        let loader = DataLoader::new(dataset).batch_size(3).drop_last(true);
        assert_eq!(loader.num_batches(), 3);
    }

    #[test]
    fn test_builder() {
        let dataset = create_test_dataset(10);
        let loader = DataLoaderBuilder::new()
            .batch_size(5)
            .shuffle(true)
            .seed(42)
            .build(dataset)
            .ok()
            .unwrap_or_else(|| panic!("Should build loader"));

        assert_eq!(loader.get_batch_size(), 5);
        // The builder only forwards the shuffle flag when the feature is on.
        assert_eq!(loader.is_shuffle(), cfg!(feature = "shuffle"));
    }

    #[test]
    fn test_builder_zero_batch_size_error() {
        let dataset = create_test_dataset(10);
        let result = DataLoaderBuilder::new().batch_size(0).build(dataset);
        assert!(result.is_err());
    }

    #[test]
    fn test_size_hint() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset).batch_size(3);

        let mut iter = loader.into_iter();
        assert_eq!(iter.size_hint(), (4, Some(4)));

        let _ = iter.next();
        assert_eq!(iter.size_hint(), (3, Some(3)));
    }

    #[test]
    fn test_getters() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset).batch_size(5).drop_last(true);

        assert_eq!(loader.get_batch_size(), 5);
        assert!(!loader.is_shuffle());
        assert!(loader.is_drop_last());
        assert_eq!(loader.len(), 10);
        assert!(!loader.is_empty());
    }

    #[cfg(feature = "shuffle")]
    #[test]
    fn test_getters_with_shuffle() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset)
            .batch_size(5)
            .shuffle(true)
            .drop_last(true);

        assert_eq!(loader.get_batch_size(), 5);
        assert!(loader.is_shuffle());
        assert!(loader.is_drop_last());
        assert_eq!(loader.len(), 10);
        assert!(!loader.is_empty());
    }

    #[test]
    fn test_batch_size_min_one() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset).batch_size(0);
        assert_eq!(loader.get_batch_size(), 1);
    }

    #[test]
    fn test_empty_dataset() {
        let dataset = create_test_dataset(0);
        let loader = DataLoader::new(dataset).batch_size(3);
        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert!(batches.is_empty());
    }

    #[test]
    fn test_empty_dataset_drop_last() {
        let dataset = create_test_dataset(0);
        let loader = DataLoader::new(dataset).batch_size(3).drop_last(true);
        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert!(batches.is_empty());
    }

    #[test]
    fn test_is_empty() {
        let empty_dataset = create_test_dataset(0);
        let loader_empty = DataLoader::new(empty_dataset);
        assert!(loader_empty.is_empty());

        let dataset = create_test_dataset(5);
        let loader = DataLoader::new(dataset);
        assert!(!loader.is_empty());
    }

    #[test]
    fn test_len() {
        let dataset = create_test_dataset(42);
        let loader = DataLoader::new(dataset);
        assert_eq!(loader.len(), 42);
    }

    #[test]
    fn test_single_row_dataset() {
        let dataset = create_test_dataset(1);
        let loader = DataLoader::new(dataset).batch_size(5);
        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 1);
    }

    #[test]
    fn test_single_row_drop_last() {
        let dataset = create_test_dataset(1);
        let loader = DataLoader::new(dataset).batch_size(5).drop_last(true);
        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        // The only batch is partial, so nothing is yielded.
        assert!(batches.is_empty());
    }

    #[test]
    fn test_batch_size_equals_dataset_size() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset).batch_size(10);
        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 10);
    }

    #[test]
    fn test_batch_size_larger_than_dataset() {
        let dataset = create_test_dataset(5);
        let loader = DataLoader::new(dataset).batch_size(100);
        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 5);
    }

    #[test]
    fn test_batch_size_larger_than_dataset_drop_last() {
        let dataset = create_test_dataset(5);
        let loader = DataLoader::new(dataset).batch_size(100).drop_last(true);
        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        // The only batch is partial, so nothing is yielded.
        assert!(batches.is_empty());
    }

    #[test]
    fn test_num_batches_with_drop_last() {
        let dataset = create_test_dataset(10);

        let loader_without_drop = DataLoader::new(dataset.clone()).batch_size(3);
        assert_eq!(loader_without_drop.num_batches(), 4);

        let loader_with_drop = DataLoader::new(dataset).batch_size(3).drop_last(true);
        assert_eq!(loader_with_drop.num_batches(), 3);
    }

    #[test]
    fn test_builder_all_options() {
        let dataset = create_test_dataset(10);
        let result = DataLoaderBuilder::new()
            .batch_size(4)
            .shuffle(true)
            .drop_last(true)
            .seed(42)
            .build(dataset);

        assert!(result.is_ok());
        let loader = result
            .ok()
            .unwrap_or_else(|| panic!("Should build loader"));
        assert_eq!(loader.get_batch_size(), 4);
        // The builder only forwards the shuffle flag when the feature is on.
        assert_eq!(loader.is_shuffle(), cfg!(feature = "shuffle"));
        assert!(loader.is_drop_last());
    }

    #[test]
    fn test_size_hint_empty_dataset() {
        let dataset = create_test_dataset(0);
        let loader = DataLoader::new(dataset).batch_size(3);
        let iter = loader.into_iter();
        assert_eq!(iter.size_hint(), (0, Some(0)));
    }

    #[test]
    fn test_iterator_exhaustion() {
        let dataset = create_test_dataset(5);
        let loader = DataLoader::new(dataset).batch_size(2);
        let mut iter = loader.into_iter();

        // 2 + 2 + 1 rows.
        assert!(iter.next().is_some());
        assert!(iter.next().is_some());
        assert!(iter.next().is_some());
        assert!(iter.next().is_none());
        // Stays exhausted on repeated calls.
        assert!(iter.next().is_none());
    }

    #[test]
    fn test_size_hint_during_iteration() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset).batch_size(3);
        let mut iter = loader.into_iter();

        assert_eq!(iter.size_hint(), (4, Some(4)));

        iter.next();
        assert_eq!(iter.size_hint(), (3, Some(3)));

        iter.next();
        assert_eq!(iter.size_hint(), (2, Some(2)));

        iter.next();
        assert_eq!(iter.size_hint(), (1, Some(1)));

        iter.next();
        assert_eq!(iter.size_hint(), (0, Some(0)));
    }

    #[test]
    fn test_debug_impl() {
        let dataset = create_test_dataset(5);
        let loader = DataLoader::new(dataset).batch_size(2);
        let debug_str = format!("{:?}", loader);
        assert!(debug_str.contains("DataLoader"));
        assert!(debug_str.contains("batch_size: 2"));
    }

    #[test]
    fn test_builder_debug_impl() {
        let builder = DataLoaderBuilder::new().batch_size(10).drop_last(true);
        let debug_str = format!("{:?}", builder);
        assert!(debug_str.contains("DataLoaderBuilder"));
    }
}