// alimentar/parallel.rs

1//! Parallel data loading with multi-worker support.
2//!
3//! Provides a parallel data loader that uses multiple threads to load data
4//! in parallel, similar to PyTorch's `DataLoader` with `num_workers > 0`.
5//!
6//! # Example
7//!
8//! ```no_run
9//! use alimentar::{parallel::ParallelDataLoader, ArrowDataset, Dataset};
10//!
11//! let dataset = ArrowDataset::from_parquet("data.parquet").unwrap();
12//! let loader = ParallelDataLoader::new(dataset)
13//!     .batch_size(32)
14//!     .num_workers(4)
15//!     .prefetch(2);
16//!
17//! for batch in loader {
18//!     println!("Batch has {} rows", batch.num_rows());
19//! }
20//! ```
21
22use std::{sync::Arc, thread};
23
24use arrow::record_batch::RecordBatch;
25
26use crate::{dataset::Dataset, error::Result};
27
/// Parallel data loader with multi-worker support.
///
/// When `num_workers > 0`, batches are loaded on a background thread and
/// handed to the consumer through a bounded prefetch channel; with
/// `num_workers == 0` loading happens inline on the consuming thread.
#[derive(Debug)]
pub struct ParallelDataLoader<D: Dataset> {
    // Shared via `Arc` so the background loader thread can read rows.
    dataset: Arc<D>,
    // Rows per emitted batch; clamped to >= 1 by `batch_size()`.
    batch_size: usize,
    // 0 = load on the consuming thread; > 0 = spawn background loading.
    num_workers: usize,
    // Capacity of the bounded channel between loader thread and consumer.
    prefetch: usize,
    // When true, row indices are shuffled before batching.
    #[cfg(feature = "shuffle")]
    shuffle: bool,
    // Fixed RNG seed for reproducible shuffles; `None` = entropy-seeded.
    #[cfg(feature = "shuffle")]
    seed: Option<u64>,
    // When true, a trailing batch smaller than `batch_size` is dropped.
    drop_last: bool,
}
44
45impl<D: Dataset + 'static> ParallelDataLoader<D> {
46    /// Creates a new parallel data loader.
47    pub fn new(dataset: D) -> Self {
48        Self {
49            dataset: Arc::new(dataset),
50            batch_size: 1,
51            num_workers: 0, // 0 = main thread only (no workers)
52            prefetch: 2,
53            #[cfg(feature = "shuffle")]
54            shuffle: false,
55            #[cfg(feature = "shuffle")]
56            seed: None,
57            drop_last: false,
58        }
59    }
60
61    /// Sets the batch size (minimum 1).
62    #[must_use]
63    pub fn batch_size(mut self, size: usize) -> Self {
64        self.batch_size = size.max(1);
65        self
66    }
67
68    /// Sets the number of worker threads (0 = main thread only).
69    ///
70    /// Note: On WASM targets, num_workers is always 0.
71    #[must_use]
72    pub fn num_workers(mut self, workers: usize) -> Self {
73        #[cfg(target_arch = "wasm32")]
74        {
75            let _ = workers;
76            self.num_workers = 0;
77        }
78        #[cfg(not(target_arch = "wasm32"))]
79        {
80            self.num_workers = workers;
81        }
82        self
83    }
84
85    /// Sets the prefetch buffer size.
86    #[must_use]
87    pub fn prefetch(mut self, size: usize) -> Self {
88        self.prefetch = size.max(1);
89        self
90    }
91
92    /// Enables or disables shuffling.
93    #[cfg(feature = "shuffle")]
94    #[must_use]
95    pub fn shuffle(mut self, enable: bool) -> Self {
96        self.shuffle = enable;
97        self
98    }
99
100    /// Sets the random seed for shuffling.
101    #[cfg(feature = "shuffle")]
102    #[must_use]
103    pub fn seed(mut self, seed: u64) -> Self {
104        self.seed = Some(seed);
105        self
106    }
107
108    /// Enables or disables dropping the last incomplete batch.
109    #[must_use]
110    pub fn drop_last(mut self, enable: bool) -> Self {
111        self.drop_last = enable;
112        self
113    }
114
115    /// Returns the batch size.
116    pub fn get_batch_size(&self) -> usize {
117        self.batch_size
118    }
119
120    /// Returns the number of workers.
121    pub fn get_num_workers(&self) -> usize {
122        self.num_workers
123    }
124
125    /// Returns the prefetch size.
126    pub fn get_prefetch(&self) -> usize {
127        self.prefetch
128    }
129
130    /// Returns the number of batches.
131    pub fn num_batches(&self) -> usize {
132        let total_rows = self.dataset.len();
133        if self.drop_last {
134            total_rows / self.batch_size
135        } else {
136            total_rows.div_ceil(self.batch_size)
137        }
138    }
139
140    /// Returns the total number of rows in the underlying dataset.
141    pub fn len(&self) -> usize {
142        self.dataset.len()
143    }
144
145    /// Returns true if the underlying dataset is empty.
146    pub fn is_empty(&self) -> bool {
147        self.dataset.is_empty()
148    }
149}
150
151impl<D: Dataset + 'static> IntoIterator for ParallelDataLoader<D> {
152    type Item = RecordBatch;
153    type IntoIter = ParallelDataLoaderIterator<D>;
154
155    fn into_iter(self) -> Self::IntoIter {
156        let total_rows = self.dataset.len();
157
158        // Generate indices
159        #[allow(unused_mut)]
160        let mut indices: Vec<usize> = (0..total_rows).collect();
161
162        #[cfg(feature = "shuffle")]
163        if self.shuffle {
164            use rand::{seq::SliceRandom, SeedableRng};
165
166            let mut rng = match self.seed {
167                Some(s) => rand::rngs::StdRng::seed_from_u64(s),
168                None => rand::rngs::StdRng::from_entropy(),
169            };
170            indices.shuffle(&mut rng);
171        }
172
173        if self.num_workers == 0 {
174            // Single-threaded path
175            ParallelDataLoaderIterator::SingleThreaded {
176                dataset: self.dataset,
177                indices,
178                batch_size: self.batch_size,
179                drop_last: self.drop_last,
180                position: 0,
181            }
182        } else {
183            // Multi-threaded path with channel
184            use std::sync::mpsc;
185
186            let (tx, rx) = mpsc::sync_channel(self.prefetch);
187            let dataset = self.dataset.clone();
188            let batch_size = self.batch_size;
189            let drop_last = self.drop_last;
190            let num_workers = self.num_workers;
191
192            // Spawn worker thread(s)
193            let handle = thread::spawn(move || {
194                // Simple round-robin distribution to workers
195                let chunks: Vec<Vec<usize>> = indices
196                    .chunks(batch_size)
197                    .filter(|chunk| !drop_last || chunk.len() == batch_size)
198                    .map(|chunk| chunk.to_vec())
199                    .collect();
200
201                // Use thread pool for parallel processing
202                let pool_size = num_workers.min(chunks.len());
203                if pool_size == 0 {
204                    return;
205                }
206
207                // Process chunks and send batches
208                for batch in chunks.iter().filter_map(|chunk_indices| {
209                    collect_batch_from_indices(&*dataset, chunk_indices)
210                }) {
211                    if tx.send(batch).is_err() {
212                        break;
213                    }
214                }
215            });
216
217            ParallelDataLoaderIterator::MultiThreaded {
218                receiver: rx,
219                _handle: handle,
220            }
221        }
222    }
223}
224
225/// Collects rows from dataset into a single batch.
226fn collect_batch_from_indices<D: Dataset>(dataset: &D, indices: &[usize]) -> Option<RecordBatch> {
227    use arrow::compute::concat_batches;
228
229    let rows: Vec<RecordBatch> = indices.iter().filter_map(|&idx| dataset.get(idx)).collect();
230
231    if rows.is_empty() {
232        return None;
233    }
234
235    let schema = dataset.schema();
236    concat_batches(&schema, &rows).ok()
237}
238
239/// Iterator for parallel data loader.
240#[allow(missing_docs)]
241pub enum ParallelDataLoaderIterator<D: Dataset> {
242    /// Single-threaded iteration (num_workers = 0)
243    SingleThreaded {
244        /// The dataset being iterated
245        dataset: Arc<D>,
246        /// Row indices to iterate
247        indices: Vec<usize>,
248        /// Batch size for iteration
249        batch_size: usize,
250        /// Whether to drop the last incomplete batch
251        drop_last: bool,
252        /// Current position in indices
253        position: usize,
254    },
255    /// Multi-threaded iteration with channel
256    MultiThreaded {
257        /// Receiver for batches from worker threads
258        receiver: std::sync::mpsc::Receiver<RecordBatch>,
259        /// Handle to the worker thread
260        _handle: thread::JoinHandle<()>,
261    },
262}
263
264impl<D: Dataset> std::fmt::Debug for ParallelDataLoaderIterator<D> {
265    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
266        match self {
267            Self::SingleThreaded {
268                position,
269                batch_size,
270                ..
271            } => f
272                .debug_struct("ParallelDataLoaderIterator::SingleThreaded")
273                .field("position", position)
274                .field("batch_size", batch_size)
275                .finish(),
276            Self::MultiThreaded { .. } => f
277                .debug_struct("ParallelDataLoaderIterator::MultiThreaded")
278                .finish(),
279        }
280    }
281}
282
283impl<D: Dataset + 'static> Iterator for ParallelDataLoaderIterator<D> {
284    type Item = RecordBatch;
285
286    fn next(&mut self) -> Option<Self::Item> {
287        match self {
288            Self::SingleThreaded {
289                dataset,
290                indices,
291                batch_size,
292                drop_last,
293                position,
294            } => {
295                if *position >= indices.len() {
296                    return None;
297                }
298
299                let end = (*position + *batch_size).min(indices.len());
300                let chunk_indices = &indices[*position..end];
301
302                if *drop_last && chunk_indices.len() < *batch_size {
303                    return None;
304                }
305
306                *position = end;
307                collect_batch_from_indices(&**dataset, chunk_indices)
308            }
309            Self::MultiThreaded { receiver, .. } => receiver.recv().ok(),
310        }
311    }
312
313    fn size_hint(&self) -> (usize, Option<usize>) {
314        match self {
315            Self::SingleThreaded {
316                indices,
317                batch_size,
318                drop_last,
319                position,
320                ..
321            } => {
322                let remaining = indices.len().saturating_sub(*position);
323                let batches = if *drop_last {
324                    remaining / *batch_size
325                } else {
326                    remaining.div_ceil(*batch_size)
327                };
328                (batches, Some(batches))
329            }
330            Self::MultiThreaded { .. } => (0, None),
331        }
332    }
333}
334
/// Builder for parallel data loader configuration.
///
/// All fields are optional; unset fields fall back to the defaults of
/// `ParallelDataLoader::new` when `build` is called.
#[derive(Debug, Default)]
pub struct ParallelDataLoaderBuilder {
    // Rows per batch; default 1.
    batch_size: Option<usize>,
    // Worker thread count; default 0 (inline loading).
    num_workers: Option<usize>,
    // Prefetch channel capacity; default 2.
    prefetch: Option<usize>,
    // Shuffle toggle; default false.
    #[cfg(feature = "shuffle")]
    shuffle: Option<bool>,
    // Shuffle RNG seed; default entropy-seeded.
    #[cfg(feature = "shuffle")]
    seed: Option<u64>,
    // Drop-last toggle; default false.
    drop_last: Option<bool>,
}
347
348impl ParallelDataLoaderBuilder {
349    /// Creates a new builder with default values.
350    pub fn new() -> Self {
351        Self::default()
352    }
353
354    /// Sets the batch size.
355    #[must_use]
356    pub fn batch_size(mut self, size: usize) -> Self {
357        self.batch_size = Some(size);
358        self
359    }
360
361    /// Sets the number of workers.
362    #[must_use]
363    pub fn num_workers(mut self, workers: usize) -> Self {
364        self.num_workers = Some(workers);
365        self
366    }
367
368    /// Sets the prefetch size.
369    #[must_use]
370    pub fn prefetch(mut self, size: usize) -> Self {
371        self.prefetch = Some(size);
372        self
373    }
374
375    /// Enables shuffling.
376    #[cfg(feature = "shuffle")]
377    #[must_use]
378    pub fn shuffle(mut self, enable: bool) -> Self {
379        self.shuffle = Some(enable);
380        self
381    }
382
383    /// Sets the random seed.
384    #[cfg(feature = "shuffle")]
385    #[must_use]
386    pub fn seed(mut self, seed: u64) -> Self {
387        self.seed = Some(seed);
388        self
389    }
390
391    /// Enables drop_last.
392    #[must_use]
393    pub fn drop_last(mut self, enable: bool) -> Self {
394        self.drop_last = Some(enable);
395        self
396    }
397
398    /// Builds the parallel data loader with the given dataset.
399    pub fn build<D: Dataset + 'static>(self, dataset: D) -> Result<ParallelDataLoader<D>> {
400        let mut loader = ParallelDataLoader::new(dataset);
401
402        if let Some(size) = self.batch_size {
403            loader = loader.batch_size(size);
404        }
405        if let Some(workers) = self.num_workers {
406            loader = loader.num_workers(workers);
407        }
408        if let Some(size) = self.prefetch {
409            loader = loader.prefetch(size);
410        }
411        #[cfg(feature = "shuffle")]
412        if let Some(enable) = self.shuffle {
413            loader = loader.shuffle(enable);
414        }
415        #[cfg(feature = "shuffle")]
416        if let Some(seed) = self.seed {
417            loader = loader.seed(seed);
418        }
419        if let Some(enable) = self.drop_last {
420            loader = loader.drop_last(enable);
421        }
422
423        Ok(loader)
424    }
425}
426
#[cfg(test)]
#[allow(
    clippy::cast_possible_truncation,
    clippy::cast_possible_wrap,
    clippy::uninlined_format_args,
    clippy::unwrap_used,
    clippy::expect_used
)]
mod tests {
    use std::collections::HashSet;

    use arrow::{
        array::{Int32Array, StringArray},
        datatypes::{DataType, Field, Schema},
    };

    use super::*;
    use crate::ArrowDataset;

    /// Builds an in-memory dataset with `rows` rows of (id, "item_{id}").
    fn create_test_dataset(rows: usize) -> ArrowDataset {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("value", DataType::Utf8, false),
        ]));

        let ids: Vec<i32> = (0..rows as i32).collect();
        let values: Vec<String> = ids.iter().map(|i| format!("item_{}", i)).collect();

        // `.expect` (rather than `.ok().unwrap_or_else(panic!)`) preserves
        // the underlying error in the panic message on failure.
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from(ids)),
                Arc::new(StringArray::from(values)),
            ],
        )
        .expect("Should create batch");

        ArrowDataset::from_batch(batch).expect("Should create dataset")
    }

    #[test]
    fn test_parallel_loader_single_threaded() {
        let dataset = create_test_dataset(100);
        let loader = ParallelDataLoader::new(dataset)
            .batch_size(10)
            .num_workers(0);

        assert_eq!(loader.get_batch_size(), 10);
        assert_eq!(loader.get_num_workers(), 0);
        assert_eq!(loader.num_batches(), 10);

        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 10);

        let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 100);
    }

    #[test]
    fn test_parallel_loader_multi_threaded() {
        let dataset = create_test_dataset(100);
        let loader = ParallelDataLoader::new(dataset)
            .batch_size(10)
            .num_workers(2)
            .prefetch(4);

        assert_eq!(loader.get_num_workers(), 2);
        assert_eq!(loader.get_prefetch(), 4);

        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 10);

        let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 100);
    }

    #[test]
    fn test_parallel_loader_drop_last() {
        let dataset = create_test_dataset(25);
        let loader = ParallelDataLoader::new(dataset)
            .batch_size(10)
            .drop_last(true);

        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 2);

        let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 20);
    }

    #[test]
    #[cfg(feature = "shuffle")]
    fn test_parallel_loader_shuffle() {
        let dataset = create_test_dataset(100);
        let loader1 = ParallelDataLoader::new(dataset.clone())
            .batch_size(10)
            .shuffle(true)
            .seed(42);

        let loader2 = ParallelDataLoader::new(dataset)
            .batch_size(10)
            .shuffle(true)
            .seed(42);

        let batches1: Vec<RecordBatch> = loader1.into_iter().collect();
        let batches2: Vec<RecordBatch> = loader2.into_iter().collect();

        // Same seed should produce same order
        for (b1, b2) in batches1.iter().zip(batches2.iter()) {
            let ids1 = b1.column(0).as_any().downcast_ref::<Int32Array>().unwrap();
            let ids2 = b2.column(0).as_any().downcast_ref::<Int32Array>().unwrap();

            for i in 0..ids1.len() {
                assert_eq!(ids1.value(i), ids2.value(i));
            }
        }
    }

    #[test]
    fn test_parallel_loader_all_rows() {
        let dataset = create_test_dataset(50);
        let loader = ParallelDataLoader::new(dataset)
            .batch_size(7)
            .num_workers(2);

        let mut seen_ids: HashSet<i32> = HashSet::new();
        for batch in loader {
            let ids = batch
                .column(0)
                .as_any()
                .downcast_ref::<Int32Array>()
                .unwrap();
            for i in 0..ids.len() {
                seen_ids.insert(ids.value(i));
            }
        }

        // All 50 IDs should be present
        assert_eq!(seen_ids.len(), 50);
        for i in 0..50 {
            assert!(seen_ids.contains(&i), "Missing id: {}", i);
        }
    }

    #[test]
    fn test_parallel_loader_getters() {
        let dataset = create_test_dataset(100);
        let loader = ParallelDataLoader::new(dataset)
            .batch_size(20)
            .num_workers(4)
            .prefetch(8);

        assert_eq!(loader.get_batch_size(), 20);
        assert_eq!(loader.get_num_workers(), 4);
        assert_eq!(loader.get_prefetch(), 8);
        assert_eq!(loader.len(), 100);
        assert!(!loader.is_empty());
    }

    #[test]
    fn test_parallel_loader_builder() {
        let dataset = create_test_dataset(100);
        let loader = ParallelDataLoaderBuilder::new()
            .batch_size(25)
            .num_workers(2)
            .prefetch(4)
            .drop_last(true)
            .build(dataset)
            .expect("Should build");

        assert_eq!(loader.get_batch_size(), 25);
        assert_eq!(loader.get_num_workers(), 2);
        assert_eq!(loader.num_batches(), 4);
    }

    #[test]
    fn test_parallel_loader_empty_dataset() {
        // Create dataset with at least 1 row for valid ArrowDataset
        let dataset = create_test_dataset(1);
        let loader = ParallelDataLoader::new(dataset)
            .batch_size(10)
            .num_workers(0);

        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 1);
    }

    #[test]
    fn test_parallel_loader_batch_size_min() {
        let dataset = create_test_dataset(10);
        let loader = ParallelDataLoader::new(dataset).batch_size(0);

        assert_eq!(loader.get_batch_size(), 1);
    }

    #[test]
    fn test_parallel_loader_debug() {
        let dataset = create_test_dataset(10);
        let loader = ParallelDataLoader::new(dataset)
            .batch_size(5)
            .num_workers(2);

        let debug_str = format!("{:?}", loader);
        assert!(debug_str.contains("ParallelDataLoader"));

        let iter = loader.into_iter();
        let iter_debug = format!("{:?}", iter);
        assert!(iter_debug.contains("ParallelDataLoaderIterator"));
    }

    #[test]
    fn test_parallel_loader_size_hint() {
        let dataset = create_test_dataset(25);
        let loader = ParallelDataLoader::new(dataset)
            .batch_size(10)
            .num_workers(0);

        let mut iter = loader.into_iter();
        assert_eq!(iter.size_hint(), (3, Some(3)));

        let _ = iter.next();
        assert_eq!(iter.size_hint(), (2, Some(2)));
    }

    #[test]
    fn test_builder_debug() {
        let builder = ParallelDataLoaderBuilder::new()
            .batch_size(32)
            .num_workers(4);

        let debug_str = format!("{:?}", builder);
        assert!(debug_str.contains("ParallelDataLoaderBuilder"));
    }

    #[test]
    fn test_parallel_loader_single_row() {
        let dataset = create_test_dataset(1);
        let loader = ParallelDataLoader::new(dataset)
            .batch_size(10)
            .num_workers(2);

        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 1);
    }

    #[test]
    fn test_parallel_loader_batch_equals_dataset() {
        let dataset = create_test_dataset(50);
        let loader = ParallelDataLoader::new(dataset)
            .batch_size(50)
            .num_workers(0);

        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 50);
    }

    #[test]
    fn test_parallel_loader_batch_larger_than_dataset() {
        let dataset = create_test_dataset(10);
        let loader = ParallelDataLoader::new(dataset)
            .batch_size(100)
            .num_workers(0);

        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 10);
    }

    #[test]
    fn test_parallel_loader_drop_last_exact_fit() {
        let dataset = create_test_dataset(100);
        let loader = ParallelDataLoader::new(dataset)
            .batch_size(25)
            .drop_last(true)
            .num_workers(0);

        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 4); // 100 / 25 = 4, no remainder
    }

    #[test]
    fn test_parallel_loader_drop_last_with_remainder() {
        let dataset = create_test_dataset(100);
        let loader = ParallelDataLoader::new(dataset)
            .batch_size(30)
            .drop_last(true)
            .num_workers(0);

        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 3); // 100 / 30 = 3, remainder dropped
    }

    #[test]
    fn test_parallel_loader_num_batches_calculation() {
        let dataset = create_test_dataset(100);

        // Without drop_last: ceil(100/30) = 4
        let loader1 = ParallelDataLoader::new(dataset.clone())
            .batch_size(30)
            .num_workers(0);
        assert_eq!(loader1.num_batches(), 4);

        // With drop_last: floor(100/30) = 3
        let loader2 = ParallelDataLoader::new(dataset)
            .batch_size(30)
            .drop_last(true)
            .num_workers(0);
        assert_eq!(loader2.num_batches(), 3);
    }

    #[test]
    fn test_parallel_loader_prefetch_setting() {
        let dataset = create_test_dataset(100);
        let loader = ParallelDataLoader::new(dataset).batch_size(10).prefetch(16);

        assert_eq!(loader.get_prefetch(), 16);
    }

    #[test]
    fn test_parallel_loader_iterator_exhaustion() {
        let dataset = create_test_dataset(30);
        let loader = ParallelDataLoader::new(dataset)
            .batch_size(10)
            .num_workers(0);

        let mut iter = loader.into_iter();

        // Should yield 3 batches
        assert!(iter.next().is_some());
        assert!(iter.next().is_some());
        assert!(iter.next().is_some());
        // Should be exhausted
        assert!(iter.next().is_none());
        // Should stay exhausted
        assert!(iter.next().is_none());
    }

    #[test]
    fn test_parallel_loader_total_rows_preserved() {
        let dataset = create_test_dataset(97);
        let loader = ParallelDataLoader::new(dataset)
            .batch_size(10)
            .num_workers(0);

        let total: usize = loader.into_iter().map(|b| b.num_rows()).sum();
        assert_eq!(total, 97);
    }

    #[test]
    fn test_parallel_loader_builder_defaults() {
        let dataset = create_test_dataset(50);
        let loader = ParallelDataLoaderBuilder::new()
            .build(dataset)
            .expect("build");

        // Defaults from ParallelDataLoader::new()
        assert_eq!(loader.get_batch_size(), 1);
        assert_eq!(loader.get_prefetch(), 2);
    }

    #[test]
    fn test_parallel_loader_builder_with_shuffle() {
        let dataset = create_test_dataset(50);
        let loader = ParallelDataLoaderBuilder::new()
            .batch_size(10)
            .shuffle(true)
            .seed(42)
            .build(dataset)
            .expect("build");

        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 5);
    }

    #[test]
    fn test_parallel_loader_zero_workers_single_threaded() {
        let dataset = create_test_dataset(100);
        let loader = ParallelDataLoader::new(dataset)
            .batch_size(20)
            .num_workers(0);

        assert_eq!(loader.get_num_workers(), 0);

        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 5);
    }
}