Skip to main content

alimentar/
weighted.rs

1//! Weighted DataLoader for importance sampling.
2//!
3//! Provides [`WeightedDataLoader`] for sampling with per-sample weights,
4//! enabling importance sampling for imbalanced datasets or CITL reweighting.
5
6use std::sync::Arc;
7
8use arrow::{array::RecordBatch, compute::concat_batches};
9#[cfg(feature = "shuffle")]
10use rand::{distributions::WeightedIndex, prelude::Distribution, SeedableRng};
11
12use crate::{dataset::Dataset, error::Result, Error};
13
14/// A data loader that samples with per-sample weights.
15///
16/// Unlike [`DataLoader`](crate::DataLoader) which samples uniformly,
17/// `WeightedDataLoader` samples proportional to the provided weights.
18/// This is useful for:
19/// - Importance sampling in imbalanced datasets
20/// - CITL reweighting (`--reweight 1.5` for compiler-verified labels)
21/// - Curriculum learning with difficulty-based sampling
22///
23/// # Example
24///
25/// ```no_run
26/// use alimentar::{ArrowDataset, Dataset, WeightedDataLoader};
27///
28/// let dataset = ArrowDataset::from_parquet("data.parquet").unwrap();
29/// let weights = vec![1.0; dataset.len()]; // Uniform weights
30///
31/// let loader = WeightedDataLoader::new(dataset, weights)
32///     .unwrap()
33///     .batch_size(32)
34///     .seed(42);
35///
36/// for batch in loader {
37///     println!("Batch with {} rows", batch.num_rows());
38/// }
39/// ```
40#[derive(Debug)]
41pub struct WeightedDataLoader<D: Dataset> {
42    dataset: Arc<D>,
43    weights: Vec<f32>,
44    batch_size: usize,
45    num_samples: usize,
46    drop_last: bool,
47    #[allow(dead_code)] // Used only with shuffle feature
48    seed: Option<u64>,
49}
50
51impl<D: Dataset> WeightedDataLoader<D> {
52    /// Creates a new weighted data loader.
53    ///
54    /// # Arguments
55    ///
56    /// * `dataset` - The dataset to sample from
57    /// * `weights` - Per-sample weights (must match dataset length)
58    ///
59    /// # Errors
60    ///
61    /// Returns an error if weights length doesn't match dataset length,
62    /// or if any weight is negative.
63    pub fn new(dataset: D, weights: Vec<f32>) -> Result<Self> {
64        let len = dataset.len();
65        if weights.len() != len {
66            return Err(Error::invalid_config(format!(
67                "weights length {} doesn't match dataset length {}",
68                weights.len(),
69                len
70            )));
71        }
72
73        if weights.iter().any(|&w| w < 0.0) {
74            return Err(Error::invalid_config("weights must be non-negative"));
75        }
76
77        Ok(Self {
78            dataset: Arc::new(dataset),
79            weights,
80            batch_size: 1,
81            num_samples: len,
82            drop_last: false,
83            seed: None,
84        })
85    }
86
87    /// Creates a weighted loader with a uniform reweight factor.
88    ///
89    /// Multiplies all weights by the given factor. Useful for CITL's
90    /// `--reweight 1.5` which boosts compiler-verified samples.
91    ///
92    /// # Arguments
93    ///
94    /// * `dataset` - The dataset to sample from
95    /// * `reweight` - Factor to multiply all weights by
96    pub fn with_reweight(dataset: D, reweight: f32) -> Result<Self> {
97        let len = dataset.len();
98        let weights = vec![reweight; len];
99        Self::new(dataset, weights)
100    }
101
102    /// Sets the batch size.
103    #[must_use]
104    pub fn batch_size(mut self, size: usize) -> Self {
105        self.batch_size = size.max(1);
106        self
107    }
108
109    /// Sets the total number of samples per epoch.
110    ///
111    /// By default, samples `len()` items per epoch. Set this to oversample
112    /// or undersample the dataset.
113    #[must_use]
114    pub fn num_samples(mut self, n: usize) -> Self {
115        self.num_samples = n;
116        self
117    }
118
119    /// Sets whether to drop the last incomplete batch.
120    #[must_use]
121    pub fn drop_last(mut self, drop_last: bool) -> Self {
122        self.drop_last = drop_last;
123        self
124    }
125
126    /// Sets the random seed for reproducibility.
127    #[cfg(feature = "shuffle")]
128    #[must_use]
129    pub fn seed(mut self, seed: u64) -> Self {
130        self.seed = Some(seed);
131        self
132    }
133
134    /// Returns the configured batch size.
135    pub fn get_batch_size(&self) -> usize {
136        self.batch_size
137    }
138
139    /// Returns the number of samples per epoch.
140    pub fn get_num_samples(&self) -> usize {
141        self.num_samples
142    }
143
144    /// Returns the weights.
145    pub fn weights(&self) -> &[f32] {
146        &self.weights
147    }
148
149    /// Returns the number of batches that will be yielded.
150    pub fn num_batches(&self) -> usize {
151        if self.drop_last {
152            self.num_samples / self.batch_size
153        } else {
154            self.num_samples.div_ceil(self.batch_size)
155        }
156    }
157
158    /// Returns the dataset length.
159    pub fn len(&self) -> usize {
160        self.dataset.len()
161    }
162
163    /// Returns true if the dataset is empty.
164    pub fn is_empty(&self) -> bool {
165        self.dataset.is_empty()
166    }
167}
168
#[cfg(feature = "shuffle")]
impl<D: Dataset> IntoIterator for WeightedDataLoader<D> {
    type Item = RecordBatch;
    type IntoIter = WeightedDataLoaderIterator<D>;

    /// Consumes the loader and returns an iterator over sampled batches.
    ///
    /// If a seed was configured the draw sequence is reproducible; otherwise
    /// the RNG is seeded from OS entropy. When the weights cannot form a
    /// `WeightedIndex` (for example, all zero), the distribution is stored as
    /// `None` and the iterator takes its non-weighted fallback path.
    fn into_iter(self) -> Self::IntoIter {
        let Self {
            dataset,
            weights,
            batch_size,
            num_samples,
            drop_last,
            seed,
        } = self;

        // Construction errors are mapped to `None` rather than propagated.
        let dist = WeightedIndex::new(&weights).ok();

        let rng = if let Some(fixed_seed) = seed {
            rand::rngs::StdRng::seed_from_u64(fixed_seed)
        } else {
            rand::rngs::StdRng::from_entropy()
        };

        WeightedDataLoaderIterator {
            dataset,
            dist,
            batch_size,
            num_samples,
            drop_last,
            rng,
            samples_yielded: 0,
        }
    }
}
192
/// Iterator that yields batches drawn by weighted sampling from a
/// [`WeightedDataLoader`].
#[cfg(feature = "shuffle")]
pub struct WeightedDataLoaderIterator<D: Dataset> {
    /// Dataset that sampled rows are fetched from.
    dataset: Arc<D>,
    /// Randomness source used when drawing indices.
    rng: rand::rngs::StdRng,
    /// Weighted sampling table, or `None` when the weights could not form one.
    dist: Option<WeightedIndex<f32>>,
    /// Rows per full batch.
    batch_size: usize,
    /// Number of row indices to draw this epoch.
    num_samples: usize,
    /// Number of row indices drawn so far.
    samples_yielded: usize,
    /// Whether a trailing short batch is suppressed.
    drop_last: bool,
}
204
#[cfg(feature = "shuffle")]
impl<D: Dataset> Iterator for WeightedDataLoaderIterator<D> {
    type Item = RecordBatch;

    /// Draws and materializes the next batch.
    ///
    /// Row indices are drawn independently, with replacement, from the
    /// weighted distribution, so a batch may contain the same row more than
    /// once. The iterator is fused: once it returns `None` (epoch complete,
    /// trailing partial batch dropped, empty dataset, or a batch that could
    /// not be assembled), every later call also returns `None`. This keeps
    /// `size_hint` consistent with what is actually yielded.
    fn next(&mut self) -> Option<Self::Item> {
        if self.samples_yielded >= self.num_samples {
            return None;
        }

        let remaining = self.num_samples - self.samples_yielded;
        let batch_size = remaining.min(self.batch_size);

        // Skip incomplete batch if drop_last is set, and end the epoch.
        if self.drop_last && batch_size < self.batch_size {
            self.samples_yielded = self.num_samples;
            return None;
        }

        // Sample indices according to weights.
        let indices: Vec<usize> = if let Some(dist) = &self.dist {
            (0..batch_size)
                .map(|_| dist.sample(&mut self.rng))
                .collect()
        } else {
            // No usable weighted distribution (e.g. all weights are zero):
            // sweep the rows cyclically (0, 1, ..., len - 1, 0, ...). Every
            // row is covered equally often, but the order is deterministic,
            // not randomized.
            let len = self.dataset.len();
            if len == 0 {
                self.samples_yielded = self.num_samples;
                return None;
            }
            (0..batch_size)
                .map(|i| (self.samples_yielded + i) % len)
                .collect()
        };

        self.samples_yielded += batch_size;

        // Materialize each sampled row. Rows the dataset cannot return are
        // skipped, so a batch can come back shorter than `batch_size`.
        let rows: Vec<RecordBatch> = indices
            .iter()
            .filter_map(|&idx| self.dataset.get(idx))
            .collect();

        // `Item` cannot carry an error, so a batch that cannot be assembled
        // ends the epoch rather than leaving a gap mid-stream.
        let batch = if rows.is_empty() {
            None
        } else {
            concat_batches(&self.dataset.schema(), &rows).ok()
        };
        if batch.is_none() {
            self.samples_yielded = self.num_samples;
        }
        batch
    }

    /// Reports the number of batches still to come.
    ///
    /// The count is exact provided each remaining batch can be materialized
    /// from the dataset.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.num_samples.saturating_sub(self.samples_yielded);
        // An empty dataset can never produce a row, whatever `num_samples` says.
        if remaining == 0 || self.dataset.is_empty() {
            return (0, Some(0));
        }
        let batches = if self.drop_last {
            remaining / self.batch_size
        } else {
            remaining.div_ceil(self.batch_size)
        };
        (batches, Some(batches))
    }
}
265
#[cfg(feature = "shuffle")]
impl<D: Dataset> std::fmt::Debug for WeightedDataLoaderIterator<D> {
    /// Shows the progress counters only; the dataset, distribution and RNG
    /// are elided (rendered as `..`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            batch_size,
            num_samples,
            samples_yielded,
            ..
        } = self;

        let mut out = f.debug_struct("WeightedDataLoaderIterator");
        out.field("batch_size", batch_size);
        out.field("num_samples", num_samples);
        out.field("samples_yielded", samples_yielded);
        out.finish_non_exhaustive()
    }
}
276
#[cfg(test)]
#[cfg(feature = "shuffle")]
#[allow(
    clippy::cast_possible_truncation,
    clippy::cast_possible_wrap,
    clippy::float_cmp
)]
mod tests {
    use std::collections::HashMap;

    use arrow::{
        array::{Int32Array, StringArray},
        datatypes::{DataType, Field, Schema},
    };

    use super::*;
    use crate::ArrowDataset;

    /// Builds a single-batch dataset of `rows` rows: an `id` column holding
    /// `0..rows` and a `value` column holding `"{prefix}{id}"`.
    fn dataset_with_prefix(rows: usize, prefix: &str) -> ArrowDataset {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("value", DataType::Utf8, false),
        ]));

        let id_values: Vec<i32> = (0..rows as i32).collect();
        let labels: Vec<String> = id_values
            .iter()
            .map(|id| format!("{prefix}{id}"))
            .collect();

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from(id_values)),
                Arc::new(StringArray::from(labels)),
            ],
        )
        .ok()
        .unwrap_or_else(|| panic!("Should create batch"));

        ArrowDataset::from_batch(batch)
            .ok()
            .unwrap_or_else(|| panic!("Should create dataset"))
    }

    /// Standard fixture: `rows` rows with `value` = `"val_{id}"`.
    fn create_test_dataset(rows: usize) -> ArrowDataset {
        dataset_with_prefix(rows, "val_")
    }

    /// Builds a loader, panicking if the weights are rejected.
    fn make_loader(dataset: ArrowDataset, weights: Vec<f32>) -> WeightedDataLoader<ArrowDataset> {
        WeightedDataLoader::new(dataset, weights)
            .ok()
            .unwrap_or_else(|| panic!("Should create loader"))
    }

    /// Returns the `id` column of `batch` as a plain vector.
    fn batch_ids(batch: &RecordBatch) -> Vec<i32> {
        let column = batch
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap_or_else(|| panic!("Should be Int32Array"));
        (0..column.len()).map(|row| column.value(row)).collect()
    }

    /// Total number of rows across `batches`.
    fn total_rows(batches: &[RecordBatch]) -> usize {
        batches.iter().map(RecordBatch::num_rows).sum()
    }

    /// Drains `loader` and tallies how many times each `id` was drawn.
    fn tally_ids(loader: WeightedDataLoader<ArrowDataset>) -> HashMap<i32, usize> {
        let mut tally = HashMap::new();
        for id in loader.into_iter().flat_map(|batch| batch_ids(&batch)) {
            *tally.entry(id).or_insert(0) += 1;
        }
        tally
    }

    /// Count recorded for `id`, treating absent ids as zero.
    fn count(tally: &HashMap<i32, usize>, id: i32) -> usize {
        tally.get(&id).copied().unwrap_or(0)
    }

    #[test]
    fn test_weighted_loader_creation() {
        let created = WeightedDataLoader::new(create_test_dataset(10), vec![1.0; 10]);
        assert!(created.is_ok());

        let loader = created
            .ok()
            .unwrap_or_else(|| panic!("Should create loader"));
        assert_eq!(loader.len(), 10);
        assert_eq!(loader.get_num_samples(), 10);
    }

    #[test]
    fn test_weighted_loader_wrong_length() {
        // Five weights for ten rows.
        let outcome = WeightedDataLoader::new(create_test_dataset(10), vec![1.0; 5]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_weighted_loader_negative_weight() {
        let weights: Vec<f32> = (0..10).map(|i| if i == 5 { -1.0 } else { 1.0 }).collect();
        assert!(WeightedDataLoader::new(create_test_dataset(10), weights).is_err());
    }

    #[test]
    fn test_weighted_loader_with_reweight() {
        let loader = WeightedDataLoader::with_reweight(create_test_dataset(10), 1.5)
            .ok()
            .unwrap_or_else(|| panic!("Should create loader"));
        assert!(loader.weights().iter().all(|&w| w == 1.5));
    }

    #[test]
    fn test_weighted_loader_basic_iteration() {
        let batches: Vec<RecordBatch> = make_loader(create_test_dataset(10), vec![1.0; 10])
            .batch_size(3)
            .seed(42)
            .into_iter()
            .collect();

        assert_eq!(batches.len(), 4); // ceil(10 / 3)
        assert_eq!(total_rows(&batches), 10);
    }

    #[test]
    fn test_weighted_loader_drop_last() {
        let batches: Vec<RecordBatch> = make_loader(create_test_dataset(10), vec![1.0; 10])
            .batch_size(3)
            .drop_last(true)
            .seed(42)
            .into_iter()
            .collect();

        assert_eq!(batches.len(), 3); // floor(10 / 3) full batches
        assert!(batches.iter().all(|b| b.num_rows() == 3));
    }

    #[test]
    fn test_weighted_loader_deterministic() {
        let dataset = create_test_dataset(100);
        let weights = vec![1.0; 100];
        let draw = |ds: ArrowDataset, w: Vec<f32>| -> Vec<Vec<i32>> {
            make_loader(ds, w)
                .batch_size(10)
                .seed(42)
                .into_iter()
                .map(|batch| batch_ids(&batch))
                .collect()
        };

        let first = draw(dataset.clone(), weights.clone());
        let second = draw(dataset, weights);
        assert_eq!(first, second);
    }

    #[test]
    fn test_weighted_loader_biased_sampling() {
        // Item 0 carries 100x the weight of every other item.
        let mut weights = vec![0.1; 10];
        weights[0] = 10.0;

        let tally = tally_ids(
            make_loader(create_test_dataset(10), weights)
                .batch_size(1)
                .num_samples(1000)
                .seed(42),
        );

        let (count_0, count_1) = (count(&tally, 0), count(&tally, 1));
        assert!(
            count_0 > count_1 * 10,
            "Item 0 ({}) should appear much more than item 1 ({})",
            count_0,
            count_1
        );
    }

    #[test]
    fn test_weighted_loader_num_samples() {
        // Oversample: 25 draws from 10 rows.
        let batches: Vec<RecordBatch> = make_loader(create_test_dataset(10), vec![1.0; 10])
            .batch_size(5)
            .num_samples(25)
            .seed(42)
            .into_iter()
            .collect();

        assert_eq!(batches.len(), 5); // ceil(25 / 5)
        assert_eq!(total_rows(&batches), 25);
    }

    #[test]
    fn test_weighted_loader_num_batches() {
        let dataset = create_test_dataset(10);
        let weights = vec![1.0; 10];

        let keep_tail = make_loader(dataset.clone(), weights.clone()).batch_size(3);
        assert_eq!(keep_tail.num_batches(), 4);

        let drop_tail = make_loader(dataset, weights).batch_size(3).drop_last(true);
        assert_eq!(drop_tail.num_batches(), 3);
    }

    #[test]
    fn test_weighted_loader_size_hint() {
        let mut iter = make_loader(create_test_dataset(10), vec![1.0; 10])
            .batch_size(3)
            .seed(42)
            .into_iter();
        assert_eq!(iter.size_hint(), (4, Some(4)));

        let _ = iter.next();
        assert_eq!(iter.size_hint(), (3, Some(3)));
    }

    #[test]
    fn test_weighted_loader_getters() {
        let loader = make_loader(create_test_dataset(10), vec![1.5; 10])
            .batch_size(5)
            .num_samples(20);

        assert_eq!(loader.get_batch_size(), 5);
        assert_eq!(loader.get_num_samples(), 20);
        assert_eq!(loader.len(), 10);
        assert!(!loader.is_empty());
        assert!(loader.weights().iter().all(|&w| w == 1.5));
    }

    #[test]
    fn test_weighted_loader_batch_size_min_one() {
        let loader = make_loader(create_test_dataset(10), vec![1.0; 10]).batch_size(0);
        assert_eq!(loader.get_batch_size(), 1);
    }

    #[test]
    fn test_weighted_loader_debug() {
        let loader = make_loader(create_test_dataset(10), vec![1.0; 10])
            .batch_size(5)
            .seed(42);

        assert!(format!("{:?}", loader).contains("WeightedDataLoader"));

        let iter = loader.into_iter();
        assert!(format!("{:?}", iter).contains("WeightedDataLoaderIterator"));
    }

    #[test]
    fn test_weighted_loader_all_zero_weights() {
        // All-zero weights must still iterate via the non-weighted fallback.
        let batches: Vec<RecordBatch> = make_loader(create_test_dataset(10), vec![0.0; 10])
            .batch_size(5)
            .num_samples(20)
            .seed(42)
            .into_iter()
            .collect();

        assert_eq!(batches.len(), 4); // 20 samples / 5 per batch
    }

    #[test]
    fn test_weighted_loader_single_nonzero_weight() {
        // Only item 5 has weight, so it must be the only item sampled.
        let weights: Vec<f32> = (0..10).map(|i| if i == 5 { 1.0 } else { 0.0 }).collect();
        let tally = tally_ids(
            make_loader(create_test_dataset(10), weights)
                .batch_size(1)
                .num_samples(10)
                .seed(42),
        );

        assert!(tally.keys().all(|&id| id == 5), "All samples should be item 5");
    }

    #[test]
    fn test_weighted_loader_large_dataset() {
        // Larger dataset to exercise many draws and concatenations.
        let dataset = dataset_with_prefix(10000, "item_");
        let weights: Vec<f32> = (0..10000).map(|i| (i % 10 + 1) as f32).collect();

        let batches: Vec<RecordBatch> = make_loader(dataset, weights)
            .batch_size(100)
            .num_samples(5000)
            .seed(42)
            .into_iter()
            .collect();

        assert_eq!(total_rows(&batches), 5000);
    }

    #[test]
    fn test_weighted_loader_very_small_weights() {
        // Tiny but nonzero weights still form a valid distribution.
        let weights: Vec<f32> = (0..10).map(|i| (i + 1) as f32 * 1e-10).collect();

        let batches: Vec<RecordBatch> = make_loader(create_test_dataset(10), weights)
            .batch_size(5)
            .num_samples(20)
            .seed(42)
            .into_iter()
            .collect();

        assert_eq!(batches.len(), 4);
    }

    #[test]
    fn test_weighted_loader_mixed_zero_nonzero() {
        // Items 0-4 have zero weight, items 5-9 have weight 1.
        let weights: Vec<f32> = (0..10).map(|i| if i < 5 { 0.0 } else { 1.0 }).collect();
        let tally = tally_ids(
            make_loader(create_test_dataset(10), weights)
                .batch_size(1)
                .num_samples(100)
                .seed(42),
        );

        for id in 0..5 {
            assert_eq!(count(&tally, id), 0, "Item {} should not be sampled", id);
        }
        for id in 5..10 {
            assert!(count(&tally, id) > 0, "Item {} should be sampled", id);
        }
    }

    #[test]
    fn test_weighted_loader_undersample() {
        // Fewer draws than rows.
        let batches: Vec<RecordBatch> = make_loader(create_test_dataset(100), vec![1.0; 100])
            .batch_size(5)
            .num_samples(20)
            .seed(42)
            .into_iter()
            .collect();

        assert_eq!(total_rows(&batches), 20);
    }

    #[test]
    fn test_weighted_loader_exact_batch_multiple() {
        // 50 draws split evenly into batches of 10.
        let batches: Vec<RecordBatch> = make_loader(create_test_dataset(100), vec![1.0; 100])
            .batch_size(10)
            .num_samples(50)
            .into_iter()
            .collect();

        assert_eq!(batches.len(), 5);
        assert!(batches.iter().all(|b| b.num_rows() == 10));
    }

    #[test]
    fn test_weighted_loader_negative_weight_error() {
        let weights = vec![1.0, 2.0, -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        assert!(WeightedDataLoader::new(create_test_dataset(10), weights).is_err());
    }

    #[test]
    fn test_weighted_loader_single_item() {
        let batches: Vec<RecordBatch> = make_loader(create_test_dataset(1), vec![1.0])
            .batch_size(1)
            .num_samples(10)
            .into_iter()
            .collect();

        assert_eq!(batches.len(), 10);
        // Every batch is the same single row.
        assert!(batches.iter().all(|b| b.num_rows() == 1));
    }

    #[test]
    fn test_weighted_loader_oversample() {
        // 100 draws from only 5 rows.
        let batches: Vec<RecordBatch> = make_loader(create_test_dataset(5), vec![1.0; 5])
            .batch_size(10)
            .num_samples(100)
            .into_iter()
            .collect();

        assert_eq!(total_rows(&batches), 100);
    }

    #[test]
    fn test_weighted_loader_is_empty() {
        // is_empty() reflects the dataset, not num_samples.
        let loader = make_loader(create_test_dataset(10), vec![1.0; 10]);

        assert!(!loader.is_empty());
        assert_eq!(loader.len(), 10);
    }

    #[test]
    fn test_weighted_loader_len() {
        // len() is the dataset length; get_num_samples() is the configured draw count.
        let loader = make_loader(create_test_dataset(100), vec![1.0; 100]).num_samples(42);

        assert_eq!(loader.len(), 100);
        assert_eq!(loader.get_num_samples(), 42);
    }

    #[test]
    fn test_weighted_loader_weight_length_mismatch() {
        let outcome = WeightedDataLoader::new(create_test_dataset(10), vec![1.0; 5]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_weighted_loader_very_large_weight() {
        // The first item's weight dwarfs the others.
        let tally = tally_ids(
            make_loader(create_test_dataset(3), vec![1e10, 1.0, 1.0])
                .batch_size(1)
                .num_samples(100)
                .seed(42),
        );

        let first_count = count(&tally, 0);
        assert!(
            first_count > 95,
            "First item should dominate: {}",
            first_count
        );
    }

    #[test]
    fn test_weighted_loader_extreme_weight_ratio() {
        // 1000:1 weight ratio between the two rows.
        let tally = tally_ids(
            make_loader(create_test_dataset(2), vec![1000.0, 1.0])
                .batch_size(1)
                .num_samples(1000)
                .seed(42),
        );

        let (first, second) = (count(&tally, 0), count(&tally, 1));
        assert!(
            first > 900,
            "First should dominate: {} vs {}",
            first,
            second
        );
    }

    #[test]
    fn test_weighted_loader_reweight_zero() {
        // A zero reweight factor yields all-zero weights but is accepted.
        let created = WeightedDataLoader::with_reweight(create_test_dataset(5), 0.0);
        assert!(created.is_ok());

        let loader = created
            .ok()
            .unwrap_or_else(|| panic!("Should create loader"));
        assert!(loader.weights().iter().all(|&w| w == 0.0));
    }

    #[test]
    fn test_weighted_loader_size_hint_drop_last_edge() {
        // 10 samples, batch_size 3, drop_last -> 3 full batches.
        let loader = make_loader(create_test_dataset(10), vec![1.0; 10])
            .batch_size(3)
            .num_samples(10)
            .drop_last(true);

        assert_eq!(loader.num_batches(), 3);
    }

    #[test]
    fn test_weighted_loader_size_hint_no_drop_last() {
        // 10 samples, batch_size 3, keep tail -> 4 batches.
        let loader = make_loader(create_test_dataset(10), vec![1.0; 10])
            .batch_size(3)
            .num_samples(10)
            .drop_last(false);

        assert_eq!(loader.num_batches(), 4);
    }

    #[test]
    fn test_weighted_loader_iteration_with_drop_last() {
        // 10 samples in batches of 4 with drop_last -> 2 full batches.
        let batches: Vec<RecordBatch> = make_loader(create_test_dataset(10), vec![1.0; 10])
            .batch_size(4)
            .num_samples(10)
            .drop_last(true)
            .seed(42)
            .into_iter()
            .collect();

        assert_eq!(batches.len(), 2);
        assert!(batches.iter().all(|b| b.num_rows() == 4));
    }
}