1use std::sync::Arc;
7
8use arrow::{array::RecordBatch, compute::concat_batches};
9#[cfg(feature = "shuffle")]
10use rand::{distributions::WeightedIndex, prelude::Distribution, SeedableRng};
11
12use crate::{dataset::Dataset, error::Result, Error};
13
/// A data loader that samples rows from a [`Dataset`] with replacement,
/// proportionally to per-item weights, yielding [`RecordBatch`]es.
///
/// Construct with [`WeightedDataLoader::new`] or
/// [`WeightedDataLoader::with_reweight`], then configure via the
/// builder-style methods (`batch_size`, `num_samples`, `drop_last`, `seed`).
#[derive(Debug)]
pub struct WeightedDataLoader<D: Dataset> {
    // Shared with the iterator, which keeps sampling after `into_iter`
    // consumes the loader.
    dataset: Arc<D>,
    // One sampling weight per dataset item; validated non-negative and
    // length-matched in `new`.
    weights: Vec<f32>,
    // Rows per yielded batch; clamped to at least 1 by `batch_size()`.
    batch_size: usize,
    // Total samples drawn per iteration pass; defaults to the dataset length
    // but may be smaller (undersample) or larger (oversample).
    num_samples: usize,
    // When true, a trailing batch shorter than `batch_size` is not yielded.
    drop_last: bool,
    // RNG seed for reproducible sampling. Only read by the `shuffle`-gated
    // iterator, hence `dead_code` is allowed when that feature is off.
    #[allow(dead_code)] seed: Option<u64>,
}
50
51impl<D: Dataset> WeightedDataLoader<D> {
52 pub fn new(dataset: D, weights: Vec<f32>) -> Result<Self> {
64 let len = dataset.len();
65 if weights.len() != len {
66 return Err(Error::invalid_config(format!(
67 "weights length {} doesn't match dataset length {}",
68 weights.len(),
69 len
70 )));
71 }
72
73 if weights.iter().any(|&w| w < 0.0) {
74 return Err(Error::invalid_config("weights must be non-negative"));
75 }
76
77 Ok(Self {
78 dataset: Arc::new(dataset),
79 weights,
80 batch_size: 1,
81 num_samples: len,
82 drop_last: false,
83 seed: None,
84 })
85 }
86
87 pub fn with_reweight(dataset: D, reweight: f32) -> Result<Self> {
97 let len = dataset.len();
98 let weights = vec![reweight; len];
99 Self::new(dataset, weights)
100 }
101
102 #[must_use]
104 pub fn batch_size(mut self, size: usize) -> Self {
105 self.batch_size = size.max(1);
106 self
107 }
108
109 #[must_use]
114 pub fn num_samples(mut self, n: usize) -> Self {
115 self.num_samples = n;
116 self
117 }
118
119 #[must_use]
121 pub fn drop_last(mut self, drop_last: bool) -> Self {
122 self.drop_last = drop_last;
123 self
124 }
125
126 #[cfg(feature = "shuffle")]
128 #[must_use]
129 pub fn seed(mut self, seed: u64) -> Self {
130 self.seed = Some(seed);
131 self
132 }
133
134 pub fn get_batch_size(&self) -> usize {
136 self.batch_size
137 }
138
139 pub fn get_num_samples(&self) -> usize {
141 self.num_samples
142 }
143
144 pub fn weights(&self) -> &[f32] {
146 &self.weights
147 }
148
149 pub fn num_batches(&self) -> usize {
151 if self.drop_last {
152 self.num_samples / self.batch_size
153 } else {
154 self.num_samples.div_ceil(self.batch_size)
155 }
156 }
157
158 pub fn len(&self) -> usize {
160 self.dataset.len()
161 }
162
163 pub fn is_empty(&self) -> bool {
165 self.dataset.is_empty()
166 }
167}
168
#[cfg(feature = "shuffle")]
impl<D: Dataset> IntoIterator for WeightedDataLoader<D> {
    type Item = RecordBatch;
    type IntoIter = WeightedDataLoaderIterator<D>;

    /// Consumes the loader and produces its sampling iterator.
    fn into_iter(self) -> Self::IntoIter {
        // `WeightedIndex::new` rejects degenerate weight vectors (e.g. all
        // zeros); `.ok()` maps that to `None`, which the iterator treats as
        // "fall back to sequential round-robin indices".
        let dist = WeightedIndex::new(&self.weights).ok();

        // A fixed seed gives reproducible sampling; otherwise seed from
        // OS entropy.
        let rng = self.seed.map_or_else(
            rand::rngs::StdRng::from_entropy,
            rand::rngs::StdRng::seed_from_u64,
        );

        WeightedDataLoaderIterator {
            dataset: self.dataset,
            dist,
            batch_size: self.batch_size,
            num_samples: self.num_samples,
            drop_last: self.drop_last,
            rng,
            samples_yielded: 0,
        }
    }
}
192
/// Iteration state for [`WeightedDataLoader`]; created by `into_iter`.
#[cfg(feature = "shuffle")]
pub struct WeightedDataLoaderIterator<D: Dataset> {
    dataset: Arc<D>,
    // `None` when the weights did not form a valid distribution (e.g. all
    // zero); the iterator then falls back to sequential round-robin indices.
    dist: Option<WeightedIndex<f32>>,
    batch_size: usize,
    num_samples: usize,
    drop_last: bool,
    rng: rand::rngs::StdRng,
    // Running count of samples drawn so far in this pass.
    samples_yielded: usize,
}
204
#[cfg(feature = "shuffle")]
impl<D: Dataset> Iterator for WeightedDataLoaderIterator<D> {
    type Item = RecordBatch;

    /// Draws the next batch of sampled rows, or `None` when the sample
    /// budget is exhausted (or a degenerate condition ends iteration early).
    fn next(&mut self) -> Option<Self::Item> {
        if self.samples_yielded >= self.num_samples {
            return None;
        }

        // The final batch may be smaller than the configured batch size.
        let remaining = self.num_samples - self.samples_yielded;
        let batch_size = remaining.min(self.batch_size);

        // With `drop_last`, a short trailing batch ends iteration instead
        // of being yielded.
        if self.drop_last && batch_size < self.batch_size {
            return None;
        }

        let indices: Vec<usize> = if let Some(dist) = &self.dist {
            // Weighted sampling with replacement; RNG state advances once
            // per drawn index, so statement order here is load-bearing for
            // seeded determinism.
            (0..batch_size)
                .map(|_| dist.sample(&mut self.rng))
                .collect()
        } else {
            // No valid weight distribution (e.g. all-zero weights): fall
            // back to sequential round-robin indices.
            let len = self.dataset.len();
            if len == 0 {
                return None;
            }
            (0..batch_size)
                .map(|i| (self.samples_yielded + i) % len)
                .collect()
        };

        // Counted before fetching rows: indices whose `get` fails below are
        // still charged against the sample budget.
        self.samples_yielded += batch_size;

        // Rows that fail to resolve are skipped silently.
        let rows: Vec<RecordBatch> = indices
            .iter()
            .filter_map(|&idx| self.dataset.get(idx))
            .collect();

        if rows.is_empty() {
            return None;
        }

        // NOTE(review): a `concat_batches` failure (e.g. schema mismatch
        // between rows) also ends iteration via `.ok()` — presumably a
        // dataset is single-schema; confirm against `Dataset` implementors.
        concat_batches(&self.dataset.schema(), &rows).ok()
    }

    /// Batch-count bounds derived from the remaining sample budget.
    /// Assumes every remaining batch actually yields; the early-return
    /// paths in `next` can make the lower bound an overestimate in
    /// degenerate cases (empty dataset, concat failure).
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.num_samples.saturating_sub(self.samples_yielded);
        let batches = if self.drop_last {
            remaining / self.batch_size
        } else if remaining > 0 {
            remaining.div_ceil(self.batch_size)
        } else {
            0
        };
        (batches, Some(batches))
    }
}
265
#[cfg(feature = "shuffle")]
// Manual impl: `D` is not bounded by `Debug`, so the struct cannot derive it;
// only the counters are shown and the dataset/distribution/RNG state is
// elided via `finish_non_exhaustive`.
impl<D: Dataset> std::fmt::Debug for WeightedDataLoaderIterator<D> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("WeightedDataLoaderIterator")
            .field("batch_size", &self.batch_size)
            .field("num_samples", &self.num_samples)
            .field("samples_yielded", &self.samples_yielded)
            .finish_non_exhaustive()
    }
}
276
#[cfg(test)]
#[cfg(feature = "shuffle")]
#[allow(
    clippy::cast_possible_truncation,
    clippy::cast_possible_wrap,
    clippy::float_cmp
)]
mod tests {
    use std::collections::HashMap;

    use arrow::{
        array::{Int32Array, StringArray},
        datatypes::{DataType, Field, Schema},
    };

    use super::*;
    use crate::ArrowDataset;

    /// Builds an in-memory dataset with `rows` rows of (`id`, `value`) pairs,
    /// where `id` is `0..rows` and `value` is `"val_{id}"`.
    fn create_test_dataset(rows: usize) -> ArrowDataset {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("value", DataType::Utf8, false),
        ]));

        let ids: Vec<i32> = (0..rows as i32).collect();
        let values: Vec<String> = ids.iter().map(|i| format!("val_{}", i)).collect();

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from(ids)),
                Arc::new(StringArray::from(values)),
            ],
        )
        .expect("Should create batch");

        ArrowDataset::from_batch(batch).expect("Should create dataset")
    }

    #[test]
    fn test_weighted_loader_creation() {
        let dataset = create_test_dataset(10);
        let weights = vec![1.0; 10];

        let loader = WeightedDataLoader::new(dataset, weights);
        assert!(loader.is_ok());

        let loader = loader.expect("Should create loader");
        assert_eq!(loader.len(), 10);
        assert_eq!(loader.get_num_samples(), 10);
    }

    #[test]
    fn test_weighted_loader_wrong_length() {
        let dataset = create_test_dataset(10);
        // 5 weights for a 10-row dataset must be rejected.
        let weights = vec![1.0; 5];

        let result = WeightedDataLoader::new(dataset, weights);
        assert!(result.is_err());
    }

    #[test]
    fn test_weighted_loader_negative_weight() {
        let dataset = create_test_dataset(10);
        let mut weights = vec![1.0; 10];
        weights[5] = -1.0;

        let result = WeightedDataLoader::new(dataset, weights);
        assert!(result.is_err());
    }

    #[test]
    fn test_weighted_loader_with_reweight() {
        let dataset = create_test_dataset(10);

        let loader =
            WeightedDataLoader::with_reweight(dataset, 1.5).expect("Should create loader");

        assert!(loader.weights().iter().all(|&w| w == 1.5));
    }

    #[test]
    fn test_weighted_loader_basic_iteration() {
        let dataset = create_test_dataset(10);
        let weights = vec![1.0; 10];

        let loader = WeightedDataLoader::new(dataset, weights)
            .expect("Should create loader")
            .batch_size(3)
            .seed(42);

        // 10 samples / batch size 3 => 3 full batches + 1 partial.
        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 4);

        let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 10);
    }

    #[test]
    fn test_weighted_loader_drop_last() {
        let dataset = create_test_dataset(10);
        let weights = vec![1.0; 10];

        let loader = WeightedDataLoader::new(dataset, weights)
            .expect("Should create loader")
            .batch_size(3)
            .drop_last(true)
            .seed(42);

        // The trailing partial batch is dropped, leaving 3 full batches.
        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 3);

        for batch in &batches {
            assert_eq!(batch.num_rows(), 3);
        }
    }

    #[test]
    fn test_weighted_loader_deterministic() {
        let dataset = create_test_dataset(100);
        let weights = vec![1.0; 100];

        let loader1 = WeightedDataLoader::new(dataset.clone(), weights.clone())
            .expect("Should create loader")
            .batch_size(10)
            .seed(42);
        let batches1: Vec<RecordBatch> = loader1.into_iter().collect();

        let loader2 = WeightedDataLoader::new(dataset, weights)
            .expect("Should create loader")
            .batch_size(10)
            .seed(42);
        let batches2: Vec<RecordBatch> = loader2.into_iter().collect();

        // Identical seeds must produce identical sample sequences.
        assert_eq!(batches1.len(), batches2.len());
        for (b1, b2) in batches1.iter().zip(batches2.iter()) {
            let ids1 = b1
                .column(0)
                .as_any()
                .downcast_ref::<Int32Array>()
                .expect("Should be Int32Array");
            let ids2 = b2
                .column(0)
                .as_any()
                .downcast_ref::<Int32Array>()
                .expect("Should be Int32Array");

            for i in 0..ids1.len() {
                assert_eq!(ids1.value(i), ids2.value(i));
            }
        }
    }

    #[test]
    fn test_weighted_loader_biased_sampling() {
        let dataset = create_test_dataset(10);
        // Item 0 is weighted 100x heavier than every other item.
        let mut weights = vec![0.1; 10];
        weights[0] = 10.0;

        let loader = WeightedDataLoader::new(dataset, weights)
            .expect("Should create loader")
            .batch_size(1)
            .num_samples(1000)
            .seed(42);

        let mut counts: HashMap<i32, usize> = HashMap::new();
        for batch in loader {
            let ids = batch
                .column(0)
                .as_any()
                .downcast_ref::<Int32Array>()
                .expect("Should be Int32Array");
            for i in 0..ids.len() {
                *counts.entry(ids.value(i)).or_insert(0) += 1;
            }
        }

        let count_0 = *counts.get(&0).unwrap_or(&0);
        let count_1 = *counts.get(&1).unwrap_or(&0);

        assert!(
            count_0 > count_1 * 10,
            "Item 0 ({}) should appear much more than item 1 ({})",
            count_0,
            count_1
        );
    }

    #[test]
    fn test_weighted_loader_num_samples() {
        let dataset = create_test_dataset(10);
        let weights = vec![1.0; 10];

        // Oversampling: 25 samples from a 10-row dataset.
        let loader = WeightedDataLoader::new(dataset, weights)
            .expect("Should create loader")
            .batch_size(5)
            .num_samples(25)
            .seed(42);

        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 5);

        let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 25);
    }

    #[test]
    fn test_weighted_loader_num_batches() {
        let dataset = create_test_dataset(10);
        let weights = vec![1.0; 10];

        let loader = WeightedDataLoader::new(dataset.clone(), weights.clone())
            .expect("Should create loader")
            .batch_size(3);
        assert_eq!(loader.num_batches(), 4);

        let loader = WeightedDataLoader::new(dataset, weights)
            .expect("Should create loader")
            .batch_size(3)
            .drop_last(true);
        assert_eq!(loader.num_batches(), 3);
    }

    #[test]
    fn test_weighted_loader_size_hint() {
        let dataset = create_test_dataset(10);
        let weights = vec![1.0; 10];

        let loader = WeightedDataLoader::new(dataset, weights)
            .expect("Should create loader")
            .batch_size(3)
            .seed(42);

        let mut iter = loader.into_iter();
        assert_eq!(iter.size_hint(), (4, Some(4)));

        // The hint shrinks as batches are consumed.
        let _ = iter.next();
        assert_eq!(iter.size_hint(), (3, Some(3)));
    }

    #[test]
    fn test_weighted_loader_getters() {
        let dataset = create_test_dataset(10);
        let weights = vec![1.5; 10];

        let loader = WeightedDataLoader::new(dataset, weights)
            .expect("Should create loader")
            .batch_size(5)
            .num_samples(20);

        assert_eq!(loader.get_batch_size(), 5);
        assert_eq!(loader.get_num_samples(), 20);
        assert_eq!(loader.len(), 10);
        assert!(!loader.is_empty());
        assert!(loader.weights().iter().all(|&w| w == 1.5));
    }

    #[test]
    fn test_weighted_loader_batch_size_min_one() {
        let dataset = create_test_dataset(10);
        let weights = vec![1.0; 10];

        // A zero batch size is clamped up to 1 rather than rejected.
        let loader = WeightedDataLoader::new(dataset, weights)
            .expect("Should create loader")
            .batch_size(0);

        assert_eq!(loader.get_batch_size(), 1);
    }

    #[test]
    fn test_weighted_loader_debug() {
        let dataset = create_test_dataset(10);
        let weights = vec![1.0; 10];

        let loader = WeightedDataLoader::new(dataset, weights)
            .expect("Should create loader")
            .batch_size(5)
            .seed(42);

        let debug_str = format!("{:?}", loader);
        assert!(debug_str.contains("WeightedDataLoader"));

        let iter = loader.into_iter();
        let iter_debug = format!("{:?}", iter);
        assert!(iter_debug.contains("WeightedDataLoaderIterator"));
    }

    #[test]
    fn test_weighted_loader_all_zero_weights() {
        let dataset = create_test_dataset(10);
        // All-zero weights form no valid distribution; the iterator falls
        // back to sequential sampling instead of failing.
        let weights = vec![0.0; 10];

        let loader = WeightedDataLoader::new(dataset, weights)
            .expect("Should create loader")
            .batch_size(5)
            .num_samples(20)
            .seed(42);

        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 4);
    }

    #[test]
    fn test_weighted_loader_single_nonzero_weight() {
        let dataset = create_test_dataset(10);
        // Only item 5 has nonzero weight, so it must be the only item drawn.
        let mut weights = vec![0.0; 10];
        weights[5] = 1.0;

        let loader = WeightedDataLoader::new(dataset, weights)
            .expect("Should create loader")
            .batch_size(1)
            .num_samples(10)
            .seed(42);

        let mut all_are_item_5 = true;
        for batch in loader {
            let ids = batch
                .column(0)
                .as_any()
                .downcast_ref::<Int32Array>()
                .expect("Should be Int32Array");
            for i in 0..ids.len() {
                if ids.value(i) != 5 {
                    all_are_item_5 = false;
                }
            }
        }
        assert!(all_are_item_5, "All samples should be item 5");
    }

    #[test]
    fn test_weighted_loader_large_dataset() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("value", DataType::Utf8, false),
        ]));

        let ids: Vec<i32> = (0..10000).collect();
        let values: Vec<String> = ids.iter().map(|i| format!("item_{}", i)).collect();

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from(ids)),
                Arc::new(StringArray::from(values)),
            ],
        )
        .expect("Should create batch");

        let dataset = ArrowDataset::from_batch(batch).expect("Should create dataset");

        let weights: Vec<f32> = (0..10000).map(|i| (i % 10 + 1) as f32).collect();

        let loader = WeightedDataLoader::new(dataset, weights)
            .expect("Should create loader")
            .batch_size(100)
            .num_samples(5000)
            .seed(42);

        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 5000);
    }

    #[test]
    fn test_weighted_loader_very_small_weights() {
        let dataset = create_test_dataset(10);
        // Tiny but positive weights should still form a valid distribution.
        let weights: Vec<f32> = (0..10).map(|i| (i + 1) as f32 * 1e-10).collect();

        let loader = WeightedDataLoader::new(dataset, weights)
            .expect("Should create loader")
            .batch_size(5)
            .num_samples(20)
            .seed(42);

        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 4);
    }

    #[test]
    fn test_weighted_loader_mixed_zero_nonzero() {
        let dataset = create_test_dataset(10);
        // Items 0..5 have zero weight, items 5..10 weight 1.
        let weights: Vec<f32> = (0..10).map(|i| if i < 5 { 0.0 } else { 1.0 }).collect();

        let loader = WeightedDataLoader::new(dataset, weights)
            .expect("Should create loader")
            .batch_size(1)
            .num_samples(100)
            .seed(42);

        let mut counts: HashMap<i32, usize> = HashMap::new();
        for batch in loader {
            let ids = batch
                .column(0)
                .as_any()
                .downcast_ref::<Int32Array>()
                .expect("Should be Int32Array");
            for i in 0..ids.len() {
                *counts.entry(ids.value(i)).or_insert(0) += 1;
            }
        }

        // Zero-weight items must never be drawn; with 100 draws over 5
        // uniform items, each positive-weight item should appear.
        for i in 0..5 {
            assert_eq!(
                *counts.get(&i).unwrap_or(&0),
                0,
                "Item {} should not be sampled",
                i
            );
        }
        for i in 5..10 {
            assert!(
                *counts.get(&i).unwrap_or(&0) > 0,
                "Item {} should be sampled",
                i
            );
        }
    }

    #[test]
    fn test_weighted_loader_undersample() {
        // Draw fewer samples (20) than the dataset size (100).
        let dataset = create_test_dataset(100);
        let weights = vec![1.0; 100];

        let loader = WeightedDataLoader::new(dataset, weights)
            .expect("Should create loader")
            .batch_size(5)
            .num_samples(20)
            .seed(42);

        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 20);
    }

    #[test]
    fn test_weighted_loader_exact_batch_multiple() {
        let dataset = create_test_dataset(100);
        let weights = vec![1.0; 100];

        // 50 samples / batch size 10 divides evenly: no partial batch.
        let loader = WeightedDataLoader::new(dataset, weights)
            .expect("Should create loader")
            .batch_size(10)
            .num_samples(50);

        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 5);
        for batch in &batches {
            assert_eq!(batch.num_rows(), 10);
        }
    }

    #[test]
    fn test_weighted_loader_negative_weight_error() {
        let dataset = create_test_dataset(10);
        let weights = vec![1.0, 2.0, -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];

        let result = WeightedDataLoader::new(dataset, weights);
        assert!(result.is_err());
    }

    #[test]
    fn test_weighted_loader_single_item() {
        let dataset = create_test_dataset(1);
        let weights = vec![1.0];

        // Sampling with replacement lets a 1-item dataset yield 10 samples.
        let loader = WeightedDataLoader::new(dataset, weights)
            .expect("Should create loader")
            .batch_size(1)
            .num_samples(10);

        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 10);

        for batch in batches {
            assert_eq!(batch.num_rows(), 1);
        }
    }

    #[test]
    fn test_weighted_loader_oversample() {
        // Draw more samples (100) than the dataset size (5).
        let dataset = create_test_dataset(5);
        let weights = vec![1.0; 5];

        let loader = WeightedDataLoader::new(dataset, weights)
            .expect("Should create loader")
            .batch_size(10)
            .num_samples(100);

        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 100);
    }

    #[test]
    fn test_weighted_loader_is_empty() {
        let dataset = create_test_dataset(10);
        let weights = vec![1.0; 10];

        let loader = WeightedDataLoader::new(dataset, weights).expect("Should create loader");

        assert!(!loader.is_empty());
        assert_eq!(loader.len(), 10);
    }

    #[test]
    fn test_weighted_loader_len() {
        let dataset = create_test_dataset(100);
        let weights = vec![1.0; 100];

        let loader = WeightedDataLoader::new(dataset, weights)
            .expect("Should create loader")
            .num_samples(42);

        // `len` reports the dataset size, independent of `num_samples`.
        assert_eq!(loader.len(), 100);
        assert_eq!(loader.get_num_samples(), 42);
    }

    #[test]
    fn test_weighted_loader_weight_length_mismatch() {
        let dataset = create_test_dataset(10);
        let weights = vec![1.0; 5];

        let result = WeightedDataLoader::new(dataset, weights);
        assert!(result.is_err());
    }

    #[test]
    fn test_weighted_loader_very_large_weight() {
        let dataset = create_test_dataset(3);
        // Item 0 carries essentially all of the probability mass.
        let weights = vec![1e10, 1.0, 1.0];

        let loader = WeightedDataLoader::new(dataset, weights)
            .expect("Should create loader")
            .batch_size(1)
            .num_samples(100)
            .seed(42);

        let mut counts: HashMap<i32, usize> = HashMap::new();
        for batch in loader {
            let ids = batch
                .column(0)
                .as_any()
                .downcast_ref::<Int32Array>()
                .expect("Should be Int32Array");
            for i in 0..ids.len() {
                *counts.entry(ids.value(i)).or_insert(0) += 1;
            }
        }

        let first_count = *counts.get(&0).unwrap_or(&0);
        assert!(
            first_count > 95,
            "First item should dominate: {}",
            first_count
        );
    }

    #[test]
    fn test_weighted_loader_extreme_weight_ratio() {
        let dataset = create_test_dataset(2);
        let weights = vec![1000.0, 1.0];

        let loader = WeightedDataLoader::new(dataset, weights)
            .expect("Should create loader")
            .batch_size(1)
            .num_samples(1000)
            .seed(42);

        let mut counts: HashMap<i32, usize> = HashMap::new();
        for batch in loader {
            let ids = batch
                .column(0)
                .as_any()
                .downcast_ref::<Int32Array>()
                .expect("Should be Int32Array");
            for i in 0..ids.len() {
                *counts.entry(ids.value(i)).or_insert(0) += 1;
            }
        }

        let first = *counts.get(&0).unwrap_or(&0);
        let second = *counts.get(&1).unwrap_or(&0);

        // With a 1000:1 ratio, the first item should win ~99.9% of draws.
        assert!(
            first > 900,
            "First should dominate: {} vs {}",
            first,
            second
        );
    }

    #[test]
    fn test_weighted_loader_reweight_zero() {
        let dataset = create_test_dataset(5);
        // A zero reweight is valid (uniform zero weights, sequential fallback).
        let loader = WeightedDataLoader::with_reweight(dataset, 0.0);
        assert!(loader.is_ok());
        let loader = loader.expect("Should create loader");
        assert!(loader.weights().iter().all(|&w| w == 0.0));
    }

    #[test]
    fn test_weighted_loader_size_hint_drop_last_edge() {
        let dataset = create_test_dataset(10);
        let weights = vec![1.0; 10];

        let loader = WeightedDataLoader::new(dataset, weights)
            .expect("Should create loader")
            .batch_size(3)
            .num_samples(10)
            .drop_last(true);

        assert_eq!(loader.num_batches(), 3);
    }

    #[test]
    fn test_weighted_loader_size_hint_no_drop_last() {
        let dataset = create_test_dataset(10);
        let weights = vec![1.0; 10];

        let loader = WeightedDataLoader::new(dataset, weights)
            .expect("Should create loader")
            .batch_size(3)
            .num_samples(10)
            .drop_last(false);

        assert_eq!(loader.num_batches(), 4);
    }

    #[test]
    fn test_weighted_loader_iteration_with_drop_last() {
        let dataset = create_test_dataset(10);
        let weights = vec![1.0; 10];

        let loader = WeightedDataLoader::new(dataset, weights)
            .expect("Should create loader")
            .batch_size(4)
            .num_samples(10)
            .drop_last(true)
            .seed(42);

        // 10 samples / batch size 4 => two full batches; the 2-row remainder
        // is dropped.
        let batches: Vec<RecordBatch> = loader.into_iter().collect();
        assert_eq!(batches.len(), 2);
        for batch in batches {
            assert_eq!(batch.num_rows(), 4);
        }
    }
}
989}