1use std::{sync::Arc, thread};
23
24use arrow::record_batch::RecordBatch;
25
26use crate::{dataset::Dataset, error::Result};
27
/// Configurable loader that yields [`RecordBatch`]es from a [`Dataset`].
///
/// Constructed with [`ParallelDataLoader::new`] and tuned via the
/// builder-style setters; iteration happens through the [`IntoIterator`]
/// impl. With `num_workers == 0` batches are assembled lazily on the
/// calling thread; with a positive value, a background producer thread
/// prefetches batches through a bounded channel.
#[derive(Debug)]
pub struct ParallelDataLoader<D: Dataset> {
    /// Shared dataset that batches are drawn from.
    dataset: Arc<D>,
    /// Number of rows per emitted batch (clamped to a minimum of 1).
    batch_size: usize,
    /// 0 = synchronous iteration; > 0 = background prefetching.
    num_workers: usize,
    /// Capacity of the bounded channel used when prefetching.
    prefetch: usize,
    /// Whether to shuffle row order before batching.
    #[cfg(feature = "shuffle")]
    shuffle: bool,
    /// Optional RNG seed for reproducible shuffling.
    #[cfg(feature = "shuffle")]
    seed: Option<u64>,
    /// Drop the trailing batch when it has fewer than `batch_size` rows.
    drop_last: bool,
}
44
45impl<D: Dataset + 'static> ParallelDataLoader<D> {
46 pub fn new(dataset: D) -> Self {
48 Self {
49 dataset: Arc::new(dataset),
50 batch_size: 1,
51 num_workers: 0, prefetch: 2,
53 #[cfg(feature = "shuffle")]
54 shuffle: false,
55 #[cfg(feature = "shuffle")]
56 seed: None,
57 drop_last: false,
58 }
59 }
60
61 #[must_use]
63 pub fn batch_size(mut self, size: usize) -> Self {
64 self.batch_size = size.max(1);
65 self
66 }
67
68 #[must_use]
72 pub fn num_workers(mut self, workers: usize) -> Self {
73 #[cfg(target_arch = "wasm32")]
74 {
75 let _ = workers;
76 self.num_workers = 0;
77 }
78 #[cfg(not(target_arch = "wasm32"))]
79 {
80 self.num_workers = workers;
81 }
82 self
83 }
84
85 #[must_use]
87 pub fn prefetch(mut self, size: usize) -> Self {
88 self.prefetch = size.max(1);
89 self
90 }
91
92 #[cfg(feature = "shuffle")]
94 #[must_use]
95 pub fn shuffle(mut self, enable: bool) -> Self {
96 self.shuffle = enable;
97 self
98 }
99
100 #[cfg(feature = "shuffle")]
102 #[must_use]
103 pub fn seed(mut self, seed: u64) -> Self {
104 self.seed = Some(seed);
105 self
106 }
107
108 #[must_use]
110 pub fn drop_last(mut self, enable: bool) -> Self {
111 self.drop_last = enable;
112 self
113 }
114
115 pub fn get_batch_size(&self) -> usize {
117 self.batch_size
118 }
119
120 pub fn get_num_workers(&self) -> usize {
122 self.num_workers
123 }
124
125 pub fn get_prefetch(&self) -> usize {
127 self.prefetch
128 }
129
130 pub fn num_batches(&self) -> usize {
132 let total_rows = self.dataset.len();
133 if self.drop_last {
134 total_rows / self.batch_size
135 } else {
136 total_rows.div_ceil(self.batch_size)
137 }
138 }
139
140 pub fn len(&self) -> usize {
142 self.dataset.len()
143 }
144
145 pub fn is_empty(&self) -> bool {
147 self.dataset.is_empty()
148 }
149}
150
151impl<D: Dataset + 'static> IntoIterator for ParallelDataLoader<D> {
152 type Item = RecordBatch;
153 type IntoIter = ParallelDataLoaderIterator<D>;
154
155 fn into_iter(self) -> Self::IntoIter {
156 let total_rows = self.dataset.len();
157
158 #[allow(unused_mut)]
160 let mut indices: Vec<usize> = (0..total_rows).collect();
161
162 #[cfg(feature = "shuffle")]
163 if self.shuffle {
164 use rand::{seq::SliceRandom, SeedableRng};
165
166 let mut rng = match self.seed {
167 Some(s) => rand::rngs::StdRng::seed_from_u64(s),
168 None => rand::rngs::StdRng::from_entropy(),
169 };
170 indices.shuffle(&mut rng);
171 }
172
173 if self.num_workers == 0 {
174 ParallelDataLoaderIterator::SingleThreaded {
176 dataset: self.dataset,
177 indices,
178 batch_size: self.batch_size,
179 drop_last: self.drop_last,
180 position: 0,
181 }
182 } else {
183 use std::sync::mpsc;
185
186 let (tx, rx) = mpsc::sync_channel(self.prefetch);
187 let dataset = self.dataset.clone();
188 let batch_size = self.batch_size;
189 let drop_last = self.drop_last;
190 let num_workers = self.num_workers;
191
192 let handle = thread::spawn(move || {
194 let chunks: Vec<Vec<usize>> = indices
196 .chunks(batch_size)
197 .filter(|chunk| !drop_last || chunk.len() == batch_size)
198 .map(|chunk| chunk.to_vec())
199 .collect();
200
201 let pool_size = num_workers.min(chunks.len());
203 if pool_size == 0 {
204 return;
205 }
206
207 for batch in chunks.iter().filter_map(|chunk_indices| {
209 collect_batch_from_indices(&*dataset, chunk_indices)
210 }) {
211 if tx.send(batch).is_err() {
212 break;
213 }
214 }
215 });
216
217 ParallelDataLoaderIterator::MultiThreaded {
218 receiver: rx,
219 _handle: handle,
220 }
221 }
222 }
223}
224
225fn collect_batch_from_indices<D: Dataset>(dataset: &D, indices: &[usize]) -> Option<RecordBatch> {
227 use arrow::compute::concat_batches;
228
229 let rows: Vec<RecordBatch> = indices.iter().filter_map(|&idx| dataset.get(idx)).collect();
230
231 if rows.is_empty() {
232 return None;
233 }
234
235 let schema = dataset.schema();
236 concat_batches(&schema, &rows).ok()
237}
238
239#[allow(missing_docs)]
241pub enum ParallelDataLoaderIterator<D: Dataset> {
242 SingleThreaded {
244 dataset: Arc<D>,
246 indices: Vec<usize>,
248 batch_size: usize,
250 drop_last: bool,
252 position: usize,
254 },
255 MultiThreaded {
257 receiver: std::sync::mpsc::Receiver<RecordBatch>,
259 _handle: thread::JoinHandle<()>,
261 },
262}
263
264impl<D: Dataset> std::fmt::Debug for ParallelDataLoaderIterator<D> {
265 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
266 match self {
267 Self::SingleThreaded {
268 position,
269 batch_size,
270 ..
271 } => f
272 .debug_struct("ParallelDataLoaderIterator::SingleThreaded")
273 .field("position", position)
274 .field("batch_size", batch_size)
275 .finish(),
276 Self::MultiThreaded { .. } => f
277 .debug_struct("ParallelDataLoaderIterator::MultiThreaded")
278 .finish(),
279 }
280 }
281}
282
283impl<D: Dataset + 'static> Iterator for ParallelDataLoaderIterator<D> {
284 type Item = RecordBatch;
285
286 fn next(&mut self) -> Option<Self::Item> {
287 match self {
288 Self::SingleThreaded {
289 dataset,
290 indices,
291 batch_size,
292 drop_last,
293 position,
294 } => {
295 if *position >= indices.len() {
296 return None;
297 }
298
299 let end = (*position + *batch_size).min(indices.len());
300 let chunk_indices = &indices[*position..end];
301
302 if *drop_last && chunk_indices.len() < *batch_size {
303 return None;
304 }
305
306 *position = end;
307 collect_batch_from_indices(&**dataset, chunk_indices)
308 }
309 Self::MultiThreaded { receiver, .. } => receiver.recv().ok(),
310 }
311 }
312
313 fn size_hint(&self) -> (usize, Option<usize>) {
314 match self {
315 Self::SingleThreaded {
316 indices,
317 batch_size,
318 drop_last,
319 position,
320 ..
321 } => {
322 let remaining = indices.len().saturating_sub(*position);
323 let batches = if *drop_last {
324 remaining / *batch_size
325 } else {
326 remaining.div_ceil(*batch_size)
327 };
328 (batches, Some(batches))
329 }
330 Self::MultiThreaded { .. } => (0, None),
331 }
332 }
333}
334
/// Builder that records optional overrides and applies them to a fresh
/// [`ParallelDataLoader`] in [`ParallelDataLoaderBuilder::build`].
#[derive(Debug, Default)]
pub struct ParallelDataLoaderBuilder {
    /// Rows per batch; loader default (1) when `None`.
    batch_size: Option<usize>,
    /// Worker count; loader default (0, synchronous) when `None`.
    num_workers: Option<usize>,
    /// Prefetch channel capacity; loader default (2) when `None`.
    prefetch: Option<usize>,
    /// Shuffle toggle; loader default (off) when `None`.
    #[cfg(feature = "shuffle")]
    shuffle: Option<bool>,
    /// RNG seed for shuffling; unseeded when `None`.
    #[cfg(feature = "shuffle")]
    seed: Option<u64>,
    /// Drop-last toggle; loader default (keep partial batch) when `None`.
    drop_last: Option<bool>,
}
347
348impl ParallelDataLoaderBuilder {
349 pub fn new() -> Self {
351 Self::default()
352 }
353
354 #[must_use]
356 pub fn batch_size(mut self, size: usize) -> Self {
357 self.batch_size = Some(size);
358 self
359 }
360
361 #[must_use]
363 pub fn num_workers(mut self, workers: usize) -> Self {
364 self.num_workers = Some(workers);
365 self
366 }
367
368 #[must_use]
370 pub fn prefetch(mut self, size: usize) -> Self {
371 self.prefetch = Some(size);
372 self
373 }
374
375 #[cfg(feature = "shuffle")]
377 #[must_use]
378 pub fn shuffle(mut self, enable: bool) -> Self {
379 self.shuffle = Some(enable);
380 self
381 }
382
383 #[cfg(feature = "shuffle")]
385 #[must_use]
386 pub fn seed(mut self, seed: u64) -> Self {
387 self.seed = Some(seed);
388 self
389 }
390
391 #[must_use]
393 pub fn drop_last(mut self, enable: bool) -> Self {
394 self.drop_last = Some(enable);
395 self
396 }
397
398 pub fn build<D: Dataset + 'static>(self, dataset: D) -> Result<ParallelDataLoader<D>> {
400 let mut loader = ParallelDataLoader::new(dataset);
401
402 if let Some(size) = self.batch_size {
403 loader = loader.batch_size(size);
404 }
405 if let Some(workers) = self.num_workers {
406 loader = loader.num_workers(workers);
407 }
408 if let Some(size) = self.prefetch {
409 loader = loader.prefetch(size);
410 }
411 #[cfg(feature = "shuffle")]
412 if let Some(enable) = self.shuffle {
413 loader = loader.shuffle(enable);
414 }
415 #[cfg(feature = "shuffle")]
416 if let Some(seed) = self.seed {
417 loader = loader.seed(seed);
418 }
419 if let Some(enable) = self.drop_last {
420 loader = loader.drop_last(enable);
421 }
422
423 Ok(loader)
424 }
425}
426
#[cfg(test)]
#[allow(
    clippy::cast_possible_truncation,
    clippy::cast_possible_wrap,
    clippy::uninlined_format_args,
    clippy::unwrap_used,
    clippy::expect_used
)]
mod tests {
    //! Unit tests for `ParallelDataLoader`, covering the single-threaded
    //! path, the background-producer path, shuffling, and the builder.

    use std::collections::HashSet;

    use arrow::{
        array::{Int32Array, StringArray},
        datatypes::{DataType, Field, Schema},
    };

    use super::*;
    use crate::ArrowDataset;

    /// Builds an in-memory dataset of `rows` rows: (id: i32, value: "item_{id}").
    fn create_test_dataset(rows: usize) -> ArrowDataset {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("value", DataType::Utf8, false),
        ]));

        let ids: Vec<i32> = (0..rows as i32).collect();
        let values: Vec<String> = ids.iter().map(|i| format!("item_{}", i)).collect();

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from(ids)),
                Arc::new(StringArray::from(values)),
            ],
        )
        // expect() keeps the underlying error in the panic message,
        // unlike the old .ok().unwrap_or_else(|| panic!(...)) pattern.
        .expect("Should create batch");

        ArrowDataset::from_batch(batch).expect("Should create dataset")
    }

    #[test]
    fn test_parallel_loader_single_threaded() {
        let dataset = create_test_dataset(100);
        let loader = ParallelDataLoader::new(dataset)
            .batch_size(10)
            .num_workers(0);

        assert_eq!(loader.get_batch_size(), 10);
        assert_eq!(loader.get_num_workers(), 0);
        assert_eq!(loader.num_batches(), 10);

        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 10);

        let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 100);
    }

    #[test]
    fn test_parallel_loader_multi_threaded() {
        let dataset = create_test_dataset(100);
        let loader = ParallelDataLoader::new(dataset)
            .batch_size(10)
            .num_workers(2)
            .prefetch(4);

        assert_eq!(loader.get_num_workers(), 2);
        assert_eq!(loader.get_prefetch(), 4);

        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 10);

        let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 100);
    }

    #[test]
    fn test_parallel_loader_drop_last() {
        let dataset = create_test_dataset(25);
        let loader = ParallelDataLoader::new(dataset)
            .batch_size(10)
            .drop_last(true);

        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 2);

        // The 5-row remainder is dropped.
        let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 20);
    }

    #[test]
    #[cfg(feature = "shuffle")]
    fn test_parallel_loader_shuffle() {
        // Two loaders with the same seed must produce identical row order.
        let dataset = create_test_dataset(100);
        let loader1 = ParallelDataLoader::new(dataset.clone())
            .batch_size(10)
            .shuffle(true)
            .seed(42);

        let loader2 = ParallelDataLoader::new(dataset)
            .batch_size(10)
            .shuffle(true)
            .seed(42);

        let batches1: Vec<RecordBatch> = loader1.into_iter().collect();
        let batches2: Vec<RecordBatch> = loader2.into_iter().collect();

        for (b1, b2) in batches1.iter().zip(batches2.iter()) {
            let ids1 = b1.column(0).as_any().downcast_ref::<Int32Array>().unwrap();
            let ids2 = b2.column(0).as_any().downcast_ref::<Int32Array>().unwrap();

            for i in 0..ids1.len() {
                assert_eq!(ids1.value(i), ids2.value(i));
            }
        }
    }

    #[test]
    fn test_parallel_loader_all_rows() {
        // Every row must appear exactly once across all batches.
        let dataset = create_test_dataset(50);
        let loader = ParallelDataLoader::new(dataset)
            .batch_size(7)
            .num_workers(2);

        let mut seen_ids: HashSet<i32> = HashSet::new();
        for batch in loader {
            let ids = batch
                .column(0)
                .as_any()
                .downcast_ref::<Int32Array>()
                .unwrap();
            for i in 0..ids.len() {
                seen_ids.insert(ids.value(i));
            }
        }

        assert_eq!(seen_ids.len(), 50);
        for i in 0..50 {
            assert!(seen_ids.contains(&i), "Missing id: {}", i);
        }
    }

    #[test]
    fn test_parallel_loader_getters() {
        let dataset = create_test_dataset(100);
        let loader = ParallelDataLoader::new(dataset)
            .batch_size(20)
            .num_workers(4)
            .prefetch(8);

        assert_eq!(loader.get_batch_size(), 20);
        assert_eq!(loader.get_num_workers(), 4);
        assert_eq!(loader.get_prefetch(), 8);
        assert_eq!(loader.len(), 100);
        assert!(!loader.is_empty());
    }

    #[test]
    fn test_parallel_loader_builder() {
        let dataset = create_test_dataset(100);
        let loader = ParallelDataLoaderBuilder::new()
            .batch_size(25)
            .num_workers(2)
            .prefetch(4)
            .drop_last(true)
            .build(dataset)
            .expect("Should build");

        assert_eq!(loader.get_batch_size(), 25);
        assert_eq!(loader.get_num_workers(), 2);
        assert_eq!(loader.num_batches(), 4);
    }

    #[test]
    fn test_parallel_loader_dataset_smaller_than_batch() {
        // One row with batch size 10 yields a single partial batch.
        // NOTE: a truly empty (0-row) dataset is not covered here.
        let dataset = create_test_dataset(1);
        let loader = ParallelDataLoader::new(dataset)
            .batch_size(10)
            .num_workers(0);

        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 1);
    }

    #[test]
    fn test_parallel_loader_batch_size_min() {
        // batch_size(0) is clamped to 1.
        let dataset = create_test_dataset(10);
        let loader = ParallelDataLoader::new(dataset).batch_size(0);

        assert_eq!(loader.get_batch_size(), 1);
    }

    #[test]
    fn test_parallel_loader_debug() {
        let dataset = create_test_dataset(10);
        let loader = ParallelDataLoader::new(dataset)
            .batch_size(5)
            .num_workers(2);

        let debug_str = format!("{:?}", loader);
        assert!(debug_str.contains("ParallelDataLoader"));

        let iter = loader.into_iter();
        let iter_debug = format!("{:?}", iter);
        assert!(iter_debug.contains("ParallelDataLoaderIterator"));
    }

    #[test]
    fn test_parallel_loader_size_hint() {
        let dataset = create_test_dataset(25);
        let loader = ParallelDataLoader::new(dataset)
            .batch_size(10)
            .num_workers(0);

        let mut iter = loader.into_iter();
        assert_eq!(iter.size_hint(), (3, Some(3)));

        let _ = iter.next();
        assert_eq!(iter.size_hint(), (2, Some(2)));
    }

    #[test]
    fn test_builder_debug() {
        let builder = ParallelDataLoaderBuilder::new()
            .batch_size(32)
            .num_workers(4);

        let debug_str = format!("{:?}", builder);
        assert!(debug_str.contains("ParallelDataLoaderBuilder"));
    }

    #[test]
    fn test_parallel_loader_single_row() {
        let dataset = create_test_dataset(1);
        let loader = ParallelDataLoader::new(dataset)
            .batch_size(10)
            .num_workers(2);

        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 1);
    }

    #[test]
    fn test_parallel_loader_batch_equals_dataset() {
        let dataset = create_test_dataset(50);
        let loader = ParallelDataLoader::new(dataset)
            .batch_size(50)
            .num_workers(0);

        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 50);
    }

    #[test]
    fn test_parallel_loader_batch_larger_than_dataset() {
        let dataset = create_test_dataset(10);
        let loader = ParallelDataLoader::new(dataset)
            .batch_size(100)
            .num_workers(0);

        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 10);
    }

    #[test]
    fn test_parallel_loader_drop_last_exact_fit() {
        // 100 rows / 25 per batch divides exactly: nothing is dropped.
        let dataset = create_test_dataset(100);
        let loader = ParallelDataLoader::new(dataset)
            .batch_size(25)
            .drop_last(true)
            .num_workers(0);

        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 4);
    }

    #[test]
    fn test_parallel_loader_drop_last_with_remainder() {
        // 100 rows / 30 per batch leaves a 10-row tail, which is dropped.
        let dataset = create_test_dataset(100);
        let loader = ParallelDataLoader::new(dataset)
            .batch_size(30)
            .drop_last(true)
            .num_workers(0);

        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 3);
    }

    #[test]
    fn test_parallel_loader_num_batches_calculation() {
        let dataset = create_test_dataset(100);

        let loader1 = ParallelDataLoader::new(dataset.clone())
            .batch_size(30)
            .num_workers(0);
        assert_eq!(loader1.num_batches(), 4);

        let loader2 = ParallelDataLoader::new(dataset)
            .batch_size(30)
            .drop_last(true)
            .num_workers(0);
        assert_eq!(loader2.num_batches(), 3);
    }

    #[test]
    fn test_parallel_loader_prefetch_setting() {
        let dataset = create_test_dataset(100);
        let loader = ParallelDataLoader::new(dataset).batch_size(10).prefetch(16);

        assert_eq!(loader.get_prefetch(), 16);
    }

    #[test]
    fn test_parallel_loader_iterator_exhaustion() {
        // After the last batch the iterator is fused in practice:
        // repeated next() keeps returning None.
        let dataset = create_test_dataset(30);
        let loader = ParallelDataLoader::new(dataset)
            .batch_size(10)
            .num_workers(0);

        let mut iter = loader.into_iter();

        assert!(iter.next().is_some());
        assert!(iter.next().is_some());
        assert!(iter.next().is_some());
        assert!(iter.next().is_none());
        assert!(iter.next().is_none());
    }

    #[test]
    fn test_parallel_loader_total_rows_preserved() {
        let dataset = create_test_dataset(97);
        let loader = ParallelDataLoader::new(dataset)
            .batch_size(10)
            .num_workers(0);

        let total: usize = loader.into_iter().map(|b| b.num_rows()).sum();
        assert_eq!(total, 97);
    }

    #[test]
    fn test_parallel_loader_builder_defaults() {
        let dataset = create_test_dataset(50);
        let loader = ParallelDataLoaderBuilder::new()
            .build(dataset)
            .expect("build");

        assert_eq!(loader.get_batch_size(), 1);
        assert_eq!(loader.get_prefetch(), 2);
    }

    #[test]
    fn test_parallel_loader_builder_with_shuffle() {
        let dataset = create_test_dataset(50);
        let loader = ParallelDataLoaderBuilder::new()
            .batch_size(10)
            .shuffle(true)
            .seed(42)
            .build(dataset)
            .expect("build");

        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 5);
    }

    #[test]
    fn test_parallel_loader_zero_workers_single_threaded() {
        let dataset = create_test_dataset(100);
        let loader = ParallelDataLoader::new(dataset)
            .batch_size(20)
            .num_workers(0);

        assert_eq!(loader.get_num_workers(), 0);

        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 5);
    }
}