Skip to main content

alimentar/
dataloader.rs

//! DataLoader for batched iteration over datasets.
//!
//! The [`DataLoader`] provides configurable batch iteration with support
//! for shuffling and dropping incomplete last batches.
6use std::sync::Arc;
7
8use arrow::{array::RecordBatch, compute::concat_batches};
9#[cfg(feature = "shuffle")]
10use rand::{seq::SliceRandom, SeedableRng};
11
12use crate::{dataset::Dataset, error::Result};
13
/// A data loader that provides batched iteration over a dataset.
///
/// The DataLoader wraps a dataset and provides:
/// - Configurable batch sizes
/// - Optional shuffling with reproducible seeds
/// - Option to drop incomplete final batches
///
/// # Example
///
/// ```no_run
/// use alimentar::{ArrowDataset, DataLoader};
///
/// let dataset = ArrowDataset::from_parquet("data.parquet").unwrap();
/// let loader = DataLoader::new(dataset)
///     .batch_size(32)
///     .shuffle(true)
///     .seed(42);
///
/// for batch in loader {
///     println!("Processing batch with {} rows", batch.num_rows());
/// }
/// ```
#[derive(Debug)]
pub struct DataLoader<D: Dataset> {
    // Wrapped dataset; Arc so the consuming iterator can take shared ownership.
    dataset: Arc<D>,
    // Maximum rows per yielded batch; always >= 1 (clamped in `batch_size`).
    batch_size: usize,
    // Whether row order is randomized each epoch (shuffle feature only).
    #[allow(dead_code)] // Used only with shuffle feature
    shuffle: bool,
    // When true, a trailing batch smaller than `batch_size` is skipped.
    drop_last: bool,
    // RNG seed for reproducible shuffling; None means entropy-seeded.
    #[allow(dead_code)] // Used only with shuffle feature
    seed: Option<u64>,
}
46
impl<D: Dataset> DataLoader<D> {
    /// Creates a new DataLoader wrapping the given dataset.
    ///
    /// Default configuration:
    /// - batch_size: 1
    /// - shuffle: false
    /// - drop_last: false
    /// - seed: None (random)
    pub fn new(dataset: D) -> Self {
        Self {
            dataset: Arc::new(dataset),
            batch_size: 1,
            shuffle: false,
            drop_last: false,
            seed: None,
        }
    }

    /// Sets the batch size.
    ///
    /// Each iteration will yield a RecordBatch with at most this many rows.
    /// The final batch may have fewer rows unless `drop_last` is set.
    ///
    /// A size of 0 is silently clamped to 1, so the effective batch size is
    /// always at least 1.
    #[must_use]
    pub fn batch_size(mut self, size: usize) -> Self {
        self.batch_size = size.max(1);
        self
    }

    /// Enables or disables shuffling.
    ///
    /// When enabled, the row order is randomized before each epoch.
    /// Requires the `shuffle` feature.
    #[cfg(feature = "shuffle")]
    #[must_use]
    pub fn shuffle(mut self, shuffle: bool) -> Self {
        self.shuffle = shuffle;
        self
    }

    /// Sets whether to drop the last incomplete batch.
    ///
    /// When true, if the dataset size is not evenly divisible by the batch
    /// size, the final partial batch is skipped.
    #[must_use]
    pub fn drop_last(mut self, drop_last: bool) -> Self {
        self.drop_last = drop_last;
        self
    }

    /// Sets the random seed for shuffling.
    ///
    /// Setting a seed makes shuffling deterministic and reproducible.
    /// Requires the `shuffle` feature.
    #[cfg(feature = "shuffle")]
    #[must_use]
    pub fn seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }

    /// Returns the configured batch size.
    pub fn get_batch_size(&self) -> usize {
        self.batch_size
    }

    /// Returns whether shuffling is enabled.
    pub fn is_shuffle(&self) -> bool {
        self.shuffle
    }

    /// Returns whether drop_last is enabled.
    pub fn is_drop_last(&self) -> bool {
        self.drop_last
    }

    /// Returns the number of batches that will be yielded.
    pub fn num_batches(&self) -> usize {
        let len = self.dataset.len();
        if self.drop_last {
            // Only complete batches count; the remainder is discarded.
            len / self.batch_size
        } else {
            // div_ceil also counts a trailing partial batch.
            len.div_ceil(self.batch_size)
        }
    }

    /// Returns the total number of rows in the underlying dataset.
    pub fn len(&self) -> usize {
        self.dataset.len()
    }

    /// Returns true if the dataset is empty.
    pub fn is_empty(&self) -> bool {
        self.dataset.is_empty()
    }
}
147
148impl<D: Dataset> IntoIterator for DataLoader<D> {
149    type Item = RecordBatch;
150    type IntoIter = DataLoaderIterator<D>;
151
152    fn into_iter(self) -> Self::IntoIter {
153        let indices: Vec<usize> = (0..self.dataset.len()).collect();
154
155        #[cfg(feature = "shuffle")]
156        let shuffled_indices = if self.shuffle {
157            let mut indices = indices;
158            let mut rng = match self.seed {
159                Some(seed) => rand::rngs::StdRng::seed_from_u64(seed),
160                None => rand::rngs::StdRng::from_entropy(),
161            };
162            indices.shuffle(&mut rng);
163            indices
164        } else {
165            indices
166        };
167
168        #[cfg(not(feature = "shuffle"))]
169        let shuffled_indices = indices;
170
171        DataLoaderIterator {
172            dataset: self.dataset,
173            batch_size: self.batch_size,
174            drop_last: self.drop_last,
175            indices: shuffled_indices,
176            position: 0,
177        }
178    }
179}
180
/// Iterator over batched data from a DataLoader.
pub struct DataLoaderIterator<D: Dataset> {
    // Shared handle to the dataset rows are fetched from.
    dataset: Arc<D>,
    // Maximum rows per yielded batch.
    batch_size: usize,
    // When true, a trailing batch smaller than `batch_size` is not yielded.
    drop_last: bool,
    // Row visitation order, fixed (and possibly shuffled) at creation.
    indices: Vec<usize>,
    // Cursor into `indices`; rows before this offset are already consumed.
    position: usize,
}
189
190impl<D: Dataset> Iterator for DataLoaderIterator<D> {
191    type Item = RecordBatch;
192
193    fn next(&mut self) -> Option<Self::Item> {
194        if self.position >= self.indices.len() {
195            return None;
196        }
197
198        let remaining = self.indices.len() - self.position;
199        let batch_size = remaining.min(self.batch_size);
200
201        // Skip incomplete batch if drop_last is set
202        if self.drop_last && batch_size < self.batch_size {
203            return None;
204        }
205
206        // Collect rows for this batch
207        let batch_indices = &self.indices[self.position..self.position + batch_size];
208        self.position += batch_size;
209
210        // Get individual rows and concatenate
211        let rows: Vec<RecordBatch> = batch_indices
212            .iter()
213            .filter_map(|&idx| self.dataset.get(idx))
214            .collect();
215
216        if rows.is_empty() {
217            return None;
218        }
219
220        // Concatenate all rows into a single batch
221        concat_batches(&self.dataset.schema(), &rows).ok()
222    }
223
224    fn size_hint(&self) -> (usize, Option<usize>) {
225        let remaining = self.indices.len().saturating_sub(self.position);
226        let batches = if self.drop_last {
227            remaining / self.batch_size
228        } else if remaining > 0 {
229            remaining.div_ceil(self.batch_size)
230        } else {
231            0
232        };
233        (batches, Some(batches))
234    }
235}
236
/// Builder for creating DataLoaders with more complex configurations.
#[derive(Debug, Default)]
pub struct DataLoaderBuilder {
    // Each option is None until explicitly set; `build` applies defaults.
    batch_size: Option<usize>,
    shuffle: Option<bool>,
    drop_last: Option<bool>,
    seed: Option<u64>,
}
245
246impl DataLoaderBuilder {
247    /// Creates a new builder with default values.
248    pub fn new() -> Self {
249        Self::default()
250    }
251
252    /// Sets the batch size.
253    #[must_use]
254    pub fn batch_size(mut self, size: usize) -> Self {
255        self.batch_size = Some(size);
256        self
257    }
258
259    /// Sets whether to shuffle.
260    #[must_use]
261    pub fn shuffle(mut self, shuffle: bool) -> Self {
262        self.shuffle = Some(shuffle);
263        self
264    }
265
266    /// Sets whether to drop the last incomplete batch.
267    #[must_use]
268    pub fn drop_last(mut self, drop_last: bool) -> Self {
269        self.drop_last = Some(drop_last);
270        self
271    }
272
273    /// Sets the random seed.
274    #[must_use]
275    pub fn seed(mut self, seed: u64) -> Self {
276        self.seed = Some(seed);
277        self
278    }
279
280    /// Builds a DataLoader with the given dataset.
281    ///
282    /// # Errors
283    ///
284    /// Returns an error if the batch size is zero.
285    pub fn build<D: Dataset>(self, dataset: D) -> Result<DataLoader<D>> {
286        let batch_size = self.batch_size.unwrap_or(1);
287        if batch_size == 0 {
288            return Err(crate::error::Error::invalid_config(
289                "batch_size must be greater than 0",
290            ));
291        }
292
293        let mut loader = DataLoader::new(dataset).batch_size(batch_size);
294
295        #[cfg(feature = "shuffle")]
296        if let Some(shuffle) = self.shuffle {
297            loader = loader.shuffle(shuffle);
298        }
299        if let Some(drop_last) = self.drop_last {
300            loader = loader.drop_last(drop_last);
301        }
302        #[cfg(feature = "shuffle")]
303        if let Some(seed) = self.seed {
304            loader = loader.seed(seed);
305        }
306
307        Ok(loader)
308    }
309}
310
#[cfg(test)]
#[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
mod tests {
    use std::collections::HashSet;

    use arrow::{
        array::{Int32Array, StringArray},
        datatypes::{DataType, Field, Schema},
    };

    use super::*;
    use crate::ArrowDataset;

    /// Builds an in-memory dataset with `rows` rows of (`id`, `value`) pairs.
    fn create_test_dataset(rows: usize) -> ArrowDataset {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("value", DataType::Utf8, false),
        ]));

        let ids: Vec<i32> = (0..rows as i32).collect();
        let values: Vec<String> = ids.iter().map(|i| format!("val_{i}")).collect();

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from(ids)),
                Arc::new(StringArray::from(values)),
            ],
        )
        .expect("Should create batch");

        ArrowDataset::from_batch(batch).expect("Should create dataset")
    }

    #[test]
    fn test_basic_iteration() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset).batch_size(3);

        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 4); // 3 + 3 + 3 + 1

        let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 10);
    }

    #[test]
    fn test_drop_last() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset).batch_size(3).drop_last(true);

        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 3); // Only full batches

        for batch in &batches {
            assert_eq!(batch.num_rows(), 3);
        }
    }

    #[test]
    fn test_shuffle_deterministic() {
        let dataset = create_test_dataset(100);

        let loader1 = DataLoader::new(dataset.clone())
            .batch_size(10)
            .shuffle(true)
            .seed(42);
        let batches1: Vec<RecordBatch> = loader1.into_iter().collect();

        let loader2 = DataLoader::new(dataset)
            .batch_size(10)
            .shuffle(true)
            .seed(42);
        let batches2: Vec<RecordBatch> = loader2.into_iter().collect();

        // Same seed should produce same order
        assert_eq!(batches1.len(), batches2.len());
        for (b1, b2) in batches1.iter().zip(batches2.iter()) {
            assert_eq!(b1.num_rows(), b2.num_rows());
        }
    }

    #[test]
    fn test_shuffle_different_seeds() {
        let dataset = create_test_dataset(100);

        let loader1 = DataLoader::new(dataset.clone())
            .batch_size(100)
            .shuffle(true)
            .seed(42);
        let batches1: Vec<RecordBatch> = loader1.into_iter().collect();

        let loader2 = DataLoader::new(dataset)
            .batch_size(100)
            .shuffle(true)
            .seed(123);
        let batches2: Vec<RecordBatch> = loader2.into_iter().collect();

        // Different seeds should likely produce different order
        // (we check that we got all the data, order may differ)
        assert_eq!(batches1.len(), batches2.len());
    }

    #[test]
    fn test_all_rows_covered() {
        let dataset = create_test_dataset(25);
        let loader = DataLoader::new(dataset)
            .batch_size(7)
            .shuffle(true)
            .seed(99);

        let mut seen_ids = HashSet::new();
        for batch in loader {
            let id_col = batch
                .column(0)
                .as_any()
                .downcast_ref::<Int32Array>()
                .expect("Should be Int32Array");
            for i in 0..id_col.len() {
                seen_ids.insert(id_col.value(i));
            }
        }

        assert_eq!(seen_ids.len(), 25);
        for i in 0..25i32 {
            assert!(seen_ids.contains(&i));
        }
    }

    #[test]
    fn test_num_batches() {
        let dataset = create_test_dataset(10);

        let loader = DataLoader::new(dataset.clone()).batch_size(3);
        assert_eq!(loader.num_batches(), 4);

        let loader = DataLoader::new(dataset).batch_size(3).drop_last(true);
        assert_eq!(loader.num_batches(), 3);
    }

    #[test]
    fn test_builder() {
        let dataset = create_test_dataset(10);
        let loader = DataLoaderBuilder::new()
            .batch_size(5)
            .shuffle(true)
            .seed(42)
            .build(dataset)
            .expect("Should build loader");

        assert_eq!(loader.get_batch_size(), 5);
        assert!(loader.is_shuffle());
    }

    #[test]
    fn test_builder_zero_batch_size_error() {
        let dataset = create_test_dataset(10);
        let result = DataLoaderBuilder::new().batch_size(0).build(dataset);
        assert!(result.is_err());
    }

    #[test]
    fn test_size_hint() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset).batch_size(3);

        let mut iter = loader.into_iter();
        assert_eq!(iter.size_hint(), (4, Some(4)));

        let _ = iter.next();
        assert_eq!(iter.size_hint(), (3, Some(3)));
    }

    #[test]
    fn test_getters() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset)
            .batch_size(5)
            .shuffle(true)
            .drop_last(true);

        assert_eq!(loader.get_batch_size(), 5);
        assert!(loader.is_shuffle());
        assert!(loader.is_drop_last());
        assert_eq!(loader.len(), 10);
        assert!(!loader.is_empty());
    }

    #[test]
    fn test_batch_size_min_one() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset).batch_size(0);
        assert_eq!(loader.get_batch_size(), 1);
    }

    #[test]
    fn test_empty_dataset() {
        let dataset = create_test_dataset(0);
        let loader = DataLoader::new(dataset).batch_size(3);
        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert!(batches.is_empty());
    }

    #[test]
    fn test_empty_dataset_drop_last() {
        let dataset = create_test_dataset(0);
        let loader = DataLoader::new(dataset).batch_size(3).drop_last(true);
        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert!(batches.is_empty());
    }

    #[test]
    fn test_is_empty() {
        let empty_dataset = create_test_dataset(0);
        let loader_empty = DataLoader::new(empty_dataset);
        assert!(loader_empty.is_empty());

        let dataset = create_test_dataset(5);
        let loader = DataLoader::new(dataset);
        assert!(!loader.is_empty());
    }

    #[test]
    fn test_len() {
        let dataset = create_test_dataset(42);
        let loader = DataLoader::new(dataset);
        assert_eq!(loader.len(), 42);
    }

    #[test]
    fn test_single_row_dataset() {
        let dataset = create_test_dataset(1);
        let loader = DataLoader::new(dataset).batch_size(5);
        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 1);
    }

    #[test]
    fn test_single_row_drop_last() {
        let dataset = create_test_dataset(1);
        let loader = DataLoader::new(dataset).batch_size(5).drop_last(true);
        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        // Single row is smaller than batch size, so it should be dropped
        assert!(batches.is_empty());
    }

    #[test]
    fn test_batch_size_equals_dataset_size() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset).batch_size(10);
        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 10);
    }

    #[test]
    fn test_batch_size_larger_than_dataset() {
        let dataset = create_test_dataset(5);
        let loader = DataLoader::new(dataset).batch_size(100);
        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 5);
    }

    #[test]
    fn test_batch_size_larger_than_dataset_drop_last() {
        let dataset = create_test_dataset(5);
        let loader = DataLoader::new(dataset).batch_size(100).drop_last(true);
        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        // Batch is incomplete, should be dropped
        assert!(batches.is_empty());
    }

    #[test]
    fn test_num_batches_with_drop_last() {
        let dataset = create_test_dataset(10);

        let loader_without_drop = DataLoader::new(dataset.clone()).batch_size(3);
        assert_eq!(loader_without_drop.num_batches(), 4); // 3 + 3 + 3 + 1

        let loader_with_drop = DataLoader::new(dataset).batch_size(3).drop_last(true);
        assert_eq!(loader_with_drop.num_batches(), 3); // 3 + 3 + 3
    }

    #[test]
    fn test_builder_all_options() {
        let dataset = create_test_dataset(10);
        let result = DataLoaderBuilder::new()
            .batch_size(4)
            .shuffle(true)
            .drop_last(true)
            .seed(42)
            .build(dataset);

        assert!(result.is_ok());
        let loader = result.expect("checked is_ok above");
        assert_eq!(loader.get_batch_size(), 4);
        assert!(loader.is_shuffle());
        assert!(loader.is_drop_last());
    }

    #[test]
    fn test_size_hint_empty_dataset() {
        let dataset = create_test_dataset(0);
        let loader = DataLoader::new(dataset).batch_size(3);
        let iter = loader.into_iter();
        assert_eq!(iter.size_hint(), (0, Some(0)));
    }

    #[test]
    fn test_iterator_exhaustion() {
        let dataset = create_test_dataset(5);
        let loader = DataLoader::new(dataset).batch_size(2);
        let mut iter = loader.into_iter();

        // Should yield 3 batches: 2, 2, 1
        assert!(iter.next().is_some());
        assert!(iter.next().is_some());
        assert!(iter.next().is_some());
        // Should be exhausted
        assert!(iter.next().is_none());
        // Should remain exhausted
        assert!(iter.next().is_none());
    }

    #[test]
    fn test_size_hint_during_iteration() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset).batch_size(3);
        let mut iter = loader.into_iter();

        // Initially 4 batches remaining
        assert_eq!(iter.size_hint(), (4, Some(4)));

        iter.next();
        assert_eq!(iter.size_hint(), (3, Some(3)));

        iter.next();
        assert_eq!(iter.size_hint(), (2, Some(2)));

        iter.next();
        assert_eq!(iter.size_hint(), (1, Some(1)));

        iter.next();
        assert_eq!(iter.size_hint(), (0, Some(0)));
    }

    #[test]
    fn test_debug_impl() {
        let dataset = create_test_dataset(5);
        let loader = DataLoader::new(dataset).batch_size(2);
        let debug_str = format!("{loader:?}");
        assert!(debug_str.contains("DataLoader"));
        assert!(debug_str.contains("batch_size: 2"));
    }

    #[test]
    fn test_builder_debug_impl() {
        let builder = DataLoaderBuilder::new().batch_size(10).drop_last(true);
        let debug_str = format!("{builder:?}");
        assert!(debug_str.contains("DataLoaderBuilder"));
    }
}
682}