1#![allow(clippy::cast_precision_loss)]
24
25use std::collections::HashMap;
26
27use crate::{
28 dataset::{ArrowDataset, Dataset},
29 error::{Error, Result},
30};
31
/// Severity buckets for class imbalance, derived from the ratio between
/// the largest and smallest class counts.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ImbalanceSeverity {
    /// Ratio below 1.5 — effectively balanced.
    None,
    /// Ratio in `[1.5, 3.0)`.
    Low,
    /// Ratio in `[3.0, 10.0)`.
    Moderate,
    /// Ratio in `[10.0, 100.0)`.
    Severe,
    /// Ratio of 100.0 or more.
    Extreme,
}

impl ImbalanceSeverity {
    /// Buckets a majority/minority count ratio into a severity level.
    pub fn from_ratio(ratio: f64) -> Self {
        match ratio {
            r if r < 1.5 => Self::None,
            r if r < 3.0 => Self::Low,
            r if r < 10.0 => Self::Moderate,
            r if r < 100.0 => Self::Severe,
            _ => Self::Extreme,
        }
    }

    /// Returns `true` for every level except [`Self::None`].
    pub fn is_imbalanced(&self) -> bool {
        !matches!(self, Self::None)
    }

    /// Short human-readable label for this severity level.
    pub fn description(&self) -> &'static str {
        match self {
            Self::Extreme => "Extremely imbalanced",
            Self::Severe => "Severely imbalanced",
            Self::Moderate => "Moderately imbalanced",
            Self::Low => "Slightly imbalanced",
            Self::None => "Balanced",
        }
    }
}
79
/// Quantitative measures of class imbalance computed from per-class counts.
#[derive(Debug, Clone)]
pub struct ImbalanceMetrics {
    /// Ratio of the largest class count to the smallest; `f64::INFINITY`
    /// when the smallest class has zero samples.
    pub imbalance_ratio: f64,
    /// Shannon entropy (natural log) of the class distribution.
    pub entropy: f64,
    /// Entropy divided by the maximum possible entropy `ln(k)`;
    /// 1.0 means perfectly balanced (also used for the single-class case).
    pub normalized_entropy: f64,
    /// Gini impurity `1 - sum(p_i^2)` of the class distribution.
    pub gini: f64,
    /// Severity bucket derived from `imbalance_ratio`.
    pub severity: ImbalanceSeverity,
}
95
96impl ImbalanceMetrics {
97 pub fn from_counts(counts: &HashMap<String, usize>) -> Self {
99 if counts.is_empty() {
100 return Self {
101 imbalance_ratio: 1.0,
102 entropy: 0.0,
103 normalized_entropy: 1.0,
104 gini: 0.0,
105 severity: ImbalanceSeverity::None,
106 };
107 }
108
109 let total: usize = counts.values().sum();
110 if total == 0 {
111 return Self {
112 imbalance_ratio: 1.0,
113 entropy: 0.0,
114 normalized_entropy: 1.0,
115 gini: 0.0,
116 severity: ImbalanceSeverity::None,
117 };
118 }
119
120 let total_f = total as f64;
121 let n_classes = counts.len();
122
123 let max_count = counts.values().copied().max().unwrap_or(0);
125 let min_count = counts.values().copied().min().unwrap_or(0);
126 let imbalance_ratio = if min_count > 0 {
127 max_count as f64 / min_count as f64
128 } else {
129 f64::INFINITY
130 };
131
132 let entropy: f64 = counts
134 .values()
135 .map(|&c| {
136 if c > 0 {
137 let p = c as f64 / total_f;
138 -p * p.ln()
139 } else {
140 0.0
141 }
142 })
143 .sum();
144
145 let max_entropy = (n_classes as f64).ln();
147 let normalized_entropy = if max_entropy > 0.0 {
148 entropy / max_entropy
149 } else {
150 1.0
151 };
152
153 let gini: f64 = 1.0
155 - counts
156 .values()
157 .map(|&c| {
158 let p = c as f64 / total_f;
159 p * p
160 })
161 .sum::<f64>();
162
163 let severity = ImbalanceSeverity::from_ratio(imbalance_ratio);
164
165 Self {
166 imbalance_ratio,
167 entropy,
168 normalized_entropy,
169 gini,
170 severity,
171 }
172 }
173
174 pub fn is_imbalanced(&self) -> bool {
176 self.severity.is_imbalanced()
177 }
178}
179
/// Observed class distribution: raw counts, proportions, and the
/// majority/minority class labels.
#[derive(Debug, Clone)]
pub struct ClassDistribution {
    /// Raw sample count per class label.
    pub counts: HashMap<String, usize>,
    /// Fraction of the total per class label (0.0 when total is 0).
    pub proportions: HashMap<String, f64>,
    /// Total number of samples across all classes.
    pub total: usize,
    /// Number of distinct class labels.
    pub num_classes: usize,
    /// Label with the highest count (ties broken by smallest label).
    pub majority_class: Option<String>,
    /// Label with the lowest non-zero count (ties broken by smallest label).
    pub minority_class: Option<String>,
}

impl ClassDistribution {
    /// Builds a distribution from per-class counts.
    ///
    /// Majority/minority ties are broken deterministically by choosing
    /// the lexicographically smallest label: `HashMap` iteration order
    /// is unspecified, so plain `max_by_key`/`min_by_key` would resolve
    /// ties differently from run to run.
    pub fn from_counts(counts: HashMap<String, usize>) -> Self {
        let total: usize = counts.values().sum();
        let num_classes = counts.len();

        let proportions: HashMap<String, f64> = counts
            .iter()
            .map(|(k, &v)| {
                let prop = if total > 0 {
                    v as f64 / total as f64
                } else {
                    0.0
                };
                (k.clone(), prop)
            })
            .collect();

        // Highest count wins; on equal counts the smaller label is
        // preferred (reversed key comparison makes it "greater").
        let majority_class = counts
            .iter()
            .max_by(|(ka, va), (kb, vb)| va.cmp(vb).then_with(|| kb.cmp(ka)))
            .map(|(k, _)| k.clone());

        // Lowest non-zero count wins, with the same deterministic
        // tie-breaking on the label.
        let minority_class = counts
            .iter()
            .filter(|(_, &v)| v > 0)
            .min_by(|(ka, va), (kb, vb)| va.cmp(vb).then_with(|| ka.cmp(kb)))
            .map(|(k, _)| k.clone());

        Self {
            counts,
            proportions,
            total,
            num_classes,
            majority_class,
            minority_class,
        }
    }

    /// Count for `class`, or 0 if the class is absent.
    pub fn get_count(&self, class: &str) -> usize {
        self.counts.get(class).copied().unwrap_or(0)
    }

    /// Proportion for `class`, or 0.0 if the class is absent.
    pub fn get_proportion(&self, class: &str) -> f64 {
        self.proportions.get(class).copied().unwrap_or(0.0)
    }
}
246
/// Actionable suggestions for mitigating class imbalance.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ImbalanceRecommendation {
    /// Dataset is balanced; nothing to do.
    NoAction,
    /// Keep class proportions equal across splits.
    UseStratifiedSplit,
    /// Weight the loss per class during training.
    UseClassWeights,
    /// Duplicate/augment minority-class samples.
    ConsiderOversampling,
    /// Drop majority-class samples.
    ConsiderUndersampling,
    /// Generate synthetic minority samples.
    ConsiderSMOTE,
    /// Gather more real minority-class data.
    CollectMoreData,
    /// Accuracy is misleading under imbalance.
    UseAppropriateMetrics,
    /// Extreme imbalance may fit anomaly detection better.
    ConsiderAnomalyDetection,
}

impl ImbalanceRecommendation {
    /// One-line human-readable explanation of the recommendation.
    pub fn description(&self) -> &'static str {
        match self {
            Self::ConsiderAnomalyDetection => "Consider framing as anomaly detection problem",
            Self::UseAppropriateMetrics => {
                "Use F1-score, AUC-ROC, or precision-recall instead of accuracy"
            }
            Self::CollectMoreData => "Collect more samples for minority classes",
            Self::ConsiderSMOTE => "Consider SMOTE or synthetic minority oversampling",
            Self::ConsiderUndersampling => "Consider undersampling the majority class",
            Self::ConsiderOversampling => "Consider oversampling the minority class",
            Self::UseClassWeights => "Apply class weights during model training",
            Self::UseStratifiedSplit => "Use stratified sampling for train/test splits",
            Self::NoAction => "No action needed - dataset is balanced",
        }
    }
}
288
289impl std::fmt::Display for ImbalanceRecommendation {
290 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
291 write!(f, "{}", self.description())
292 }
293}
294
/// Full imbalance analysis for one label column: the observed class
/// distribution, derived metrics, and suggested mitigations.
#[derive(Debug, Clone)]
pub struct ImbalanceReport {
    /// Name of the analyzed label column.
    pub column: String,
    /// Observed class counts and proportions.
    pub distribution: ClassDistribution,
    /// Imbalance metrics computed from the distribution.
    pub metrics: ImbalanceMetrics,
    /// Recommendations derived from the metrics and distribution.
    pub recommendations: Vec<ImbalanceRecommendation>,
}
307
308impl ImbalanceReport {
309 pub fn from_distribution(column: impl Into<String>, distribution: ClassDistribution) -> Self {
311 let metrics = ImbalanceMetrics::from_counts(&distribution.counts);
312 let recommendations = generate_recommendations(&metrics, &distribution);
313
314 Self {
315 column: column.into(),
316 distribution,
317 metrics,
318 recommendations,
319 }
320 }
321
322 pub fn is_imbalanced(&self) -> bool {
324 self.metrics.is_imbalanced()
325 }
326
327 pub fn severity(&self) -> ImbalanceSeverity {
329 self.metrics.severity
330 }
331}
332
333fn generate_recommendations(
335 metrics: &ImbalanceMetrics,
336 distribution: &ClassDistribution,
337) -> Vec<ImbalanceRecommendation> {
338 let mut recs = Vec::new();
339
340 match metrics.severity {
341 ImbalanceSeverity::None => {
342 recs.push(ImbalanceRecommendation::NoAction);
343 }
344 ImbalanceSeverity::Low => {
345 recs.push(ImbalanceRecommendation::UseStratifiedSplit);
346 recs.push(ImbalanceRecommendation::UseAppropriateMetrics);
347 }
348 ImbalanceSeverity::Moderate => {
349 recs.push(ImbalanceRecommendation::UseStratifiedSplit);
350 recs.push(ImbalanceRecommendation::UseClassWeights);
351 recs.push(ImbalanceRecommendation::UseAppropriateMetrics);
352 if distribution.total < 10000 {
353 recs.push(ImbalanceRecommendation::ConsiderOversampling);
354 } else {
355 recs.push(ImbalanceRecommendation::ConsiderUndersampling);
356 }
357 }
358 ImbalanceSeverity::Severe => {
359 recs.push(ImbalanceRecommendation::UseStratifiedSplit);
360 recs.push(ImbalanceRecommendation::UseClassWeights);
361 recs.push(ImbalanceRecommendation::ConsiderSMOTE);
362 recs.push(ImbalanceRecommendation::UseAppropriateMetrics);
363 recs.push(ImbalanceRecommendation::CollectMoreData);
364 }
365 ImbalanceSeverity::Extreme => {
366 recs.push(ImbalanceRecommendation::ConsiderAnomalyDetection);
367 recs.push(ImbalanceRecommendation::UseStratifiedSplit);
368 recs.push(ImbalanceRecommendation::ConsiderSMOTE);
369 recs.push(ImbalanceRecommendation::CollectMoreData);
370 recs.push(ImbalanceRecommendation::UseAppropriateMetrics);
371 }
372 }
373
374 recs
375}
376
/// Analyzes an [`ArrowDataset`] for class imbalance in one label column.
pub struct ImbalanceDetector {
    // Name of the column holding class labels.
    label_column: String,
}
382
impl ImbalanceDetector {
    /// Creates a detector for the given label column.
    pub fn new(label_column: impl Into<String>) -> Self {
        Self {
            label_column: label_column.into(),
        }
    }

    /// Name of the label column this detector inspects.
    pub fn label_column(&self) -> &str {
        &self.label_column
    }

    /// Counts class labels in `dataset` and produces a full report
    /// (distribution, metrics, recommendations).
    ///
    /// # Errors
    ///
    /// Returns an error if the label column is missing, has an
    /// unsupported type, or contains no non-null values.
    pub fn analyze(&self, dataset: &ArrowDataset) -> Result<ImbalanceReport> {
        let counts = self.count_classes(dataset)?;
        let distribution = ClassDistribution::from_counts(counts);
        Ok(ImbalanceReport::from_distribution(
            &self.label_column,
            distribution,
        ))
    }

    /// Tallies non-null label values across all batches, keyed by their
    /// string representation (integer labels are stringified).
    fn count_classes(&self, dataset: &ArrowDataset) -> Result<HashMap<String, usize>> {
        use arrow::array::{Array, Int32Array, Int64Array, StringArray};

        // Resolve the label column's index in the schema up front.
        let schema = dataset.schema();
        let col_idx = schema
            .fields()
            .iter()
            .position(|f| f.name() == &self.label_column)
            .ok_or_else(|| {
                Error::invalid_config(format!(
                    "Column '{}' not found in schema",
                    self.label_column
                ))
            })?;

        let mut counts: HashMap<String, usize> = HashMap::new();

        for batch in dataset.iter() {
            let array = batch.column(col_idx);

            // Only string and integer label columns are supported; each
            // branch skips nulls and counts by the stringified value.
            if let Some(arr) = array.as_any().downcast_ref::<StringArray>() {
                for i in 0..arr.len() {
                    if !arr.is_null(i) {
                        let key = arr.value(i).to_string();
                        *counts.entry(key).or_insert(0) += 1;
                    }
                }
            } else if let Some(arr) = array.as_any().downcast_ref::<Int32Array>() {
                for i in 0..arr.len() {
                    if !arr.is_null(i) {
                        let key = arr.value(i).to_string();
                        *counts.entry(key).or_insert(0) += 1;
                    }
                }
            } else if let Some(arr) = array.as_any().downcast_ref::<Int64Array>() {
                for i in 0..arr.len() {
                    if !arr.is_null(i) {
                        let key = arr.value(i).to_string();
                        *counts.entry(key).or_insert(0) += 1;
                    }
                }
            } else {
                return Err(Error::invalid_config(format!(
                    "Unsupported column type for '{}'. Expected string or integer.",
                    self.label_column
                )));
            }
        }

        // An empty or all-null column yields no counts — surface that as
        // a configuration error rather than producing a meaningless report.
        if counts.is_empty() {
            return Err(Error::invalid_config(format!(
                "No valid values found in column '{}'",
                self.label_column
            )));
        }

        Ok(counts)
    }
}
467
/// Strategy used by [`resample`] to equalize class sizes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResampleStrategy {
    /// Duplicate minority-class rows up to the largest class size.
    Oversample,
    /// Drop majority-class rows down to the smallest class size.
    Undersample,
}
476
/// Groups row indices `0..n` by the stringified label value in
/// `label_array`, skipping null entries.
///
/// Supports Int64, Int32, and Utf8 label arrays; any other array type
/// is reported as a configuration error naming `label_column`.
/// Indices are stored as `u32` (truncation allowed explicitly below).
#[cfg(feature = "shuffle")]
#[allow(clippy::cast_possible_truncation)]
fn group_rows_by_label(
    label_array: &dyn arrow::array::Array,
    n: usize,
    label_column: &str,
) -> Result<std::collections::HashMap<String, Vec<u32>>> {
    use arrow::array::{Array, Int32Array, Int64Array, StringArray};

    let mut groups: std::collections::HashMap<String, Vec<u32>> = std::collections::HashMap::new();

    // Each branch handles one supported label type: record the row index
    // under the stringified label, skipping nulls.
    if let Some(arr) = label_array.as_any().downcast_ref::<Int64Array>() {
        for i in 0..n {
            if !arr.is_null(i) {
                groups
                    .entry(arr.value(i).to_string())
                    .or_default()
                    .push(i as u32);
            }
        }
    } else if let Some(arr) = label_array.as_any().downcast_ref::<Int32Array>() {
        for i in 0..n {
            if !arr.is_null(i) {
                groups
                    .entry(arr.value(i).to_string())
                    .or_default()
                    .push(i as u32);
            }
        }
    } else if let Some(arr) = label_array.as_any().downcast_ref::<StringArray>() {
        for i in 0..n {
            if !arr.is_null(i) {
                groups
                    .entry(arr.value(i).to_string())
                    .or_default()
                    .push(i as u32);
            }
        }
    } else {
        return Err(Error::invalid_config(format!(
            "Unsupported column type for '{label_column}'"
        )));
    }

    Ok(groups)
}
540
/// Builds an index list in which every class contributes exactly
/// `target` row indices, then shuffles the combined list.
///
/// Classes smaller than `target` are topped up by cycling through their
/// own indices (duplication); classes larger than `target` are randomly
/// downsampled.
#[cfg(feature = "shuffle")]
fn balance_group_indices(
    groups: &std::collections::HashMap<String, Vec<u32>>,
    target: usize,
    rng: &mut rand::rngs::StdRng,
) -> Vec<u32> {
    use rand::seq::SliceRandom;

    // Visit classes in sorted label order. HashMap iteration order is
    // unspecified, so consuming the seeded RNG in map order would make
    // the result nondeterministic across runs despite a fixed seed.
    let mut labels: Vec<&String> = groups.keys().collect();
    labels.sort_unstable();

    let mut all_indices: Vec<u32> = Vec::new();
    for label in labels {
        let indices = &groups[label];
        match indices.len().cmp(&target) {
            std::cmp::Ordering::Equal => {
                all_indices.extend_from_slice(indices);
            }
            std::cmp::Ordering::Less => {
                // Keep every original index once, then cycle through the
                // group to pad up to `target`, shuffling only the padding.
                all_indices.extend_from_slice(indices);
                let mut extra: Vec<u32> = Vec::with_capacity(target - indices.len());
                while extra.len() + indices.len() < target {
                    extra.extend_from_slice(indices);
                }
                extra.truncate(target - indices.len());
                extra.shuffle(rng);
                all_indices.extend(extra);
            }
            std::cmp::Ordering::Greater => {
                // Randomly pick `target` indices from the larger group.
                let mut sampled = indices.clone();
                sampled.shuffle(rng);
                sampled.truncate(target);
                all_indices.extend(sampled);
            }
        }
    }
    all_indices.shuffle(rng);
    all_indices
}
580
/// Resamples `dataset` so every class in `label_column` ends up with the
/// same number of rows, per the chosen [`ResampleStrategy`].
///
/// All batches are concatenated into one, rows are grouped by label,
/// a balanced index list is drawn with an RNG seeded from `seed`, and a
/// new dataset is materialized via Arrow's `take` kernel.
///
/// # Errors
///
/// Fails if the dataset is empty, the label column is missing or has an
/// unsupported type, or an Arrow kernel fails.
#[cfg(feature = "shuffle")]
pub fn resample(
    dataset: &ArrowDataset,
    label_column: &str,
    strategy: ResampleStrategy,
    seed: u64,
) -> Result<ArrowDataset> {
    use arrow::compute;
    use rand::SeedableRng;

    // Work on one contiguous batch so row indices are dataset-global.
    let batches: Vec<arrow::array::RecordBatch> = dataset.iter().collect();
    let batch = if batches.len() == 1 {
        batches
            .into_iter()
            .next()
            .ok_or_else(|| Error::empty_dataset("empty"))?
    } else {
        arrow::compute::concat_batches(&dataset.schema(), &batches).map_err(Error::Arrow)?
    };
    let schema = batch.schema();
    let col_idx = schema
        .fields()
        .iter()
        .position(|f| f.name() == label_column)
        .ok_or_else(|| {
            Error::invalid_config(format!("Column '{label_column}' not found in schema"))
        })?;

    let groups = group_rows_by_label(
        batch.column(col_idx).as_ref(),
        batch.num_rows(),
        label_column,
    )?;

    if groups.is_empty() {
        return Err(Error::empty_dataset("No valid labels found for resampling"));
    }

    // Oversampling targets the largest class size; undersampling the smallest.
    let target = match strategy {
        ResampleStrategy::Oversample => groups.values().map(|v| v.len()).max().unwrap_or(0),
        ResampleStrategy::Undersample => groups.values().map(|v| v.len()).min().unwrap_or(0),
    };

    let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
    let all_indices = balance_group_indices(&groups, target, &mut rng);

    // Gather the selected rows from every column with Arrow's `take`.
    let indices_array = arrow::array::UInt32Array::from(all_indices);
    let columns: Vec<arrow::array::ArrayRef> = (0..batch.num_columns())
        .map(|i| compute::take(batch.column(i), &indices_array, None))
        .collect::<std::result::Result<Vec<_>, _>>()
        .map_err(|e| Error::invalid_config(format!("Arrow take failed: {e}")))?;

    let result_batch = arrow::array::RecordBatch::try_new(schema, columns)
        .map_err(|e| Error::invalid_config(format!("Failed to create resampled batch: {e}")))?;

    ArrowDataset::from_batch(result_batch)
}
642
/// Computes per-class training weights proportional to
/// `sqrt(n / (k * count))`, rescaled so the `k` weights sum to `k`.
///
/// Classes with a zero count receive weight 0. Returns an empty vector
/// when there are no classes or no samples at all.
pub fn sqrt_inverse_weights(class_counts: &[usize]) -> Vec<f32> {
    let num_classes = class_counts.len();
    let total: usize = class_counts.iter().sum();
    if num_classes == 0 || total == 0 {
        return Vec::new();
    }

    let k = num_classes as f64;
    let n = total as f64;

    // Unnormalized inverse-sqrt-frequency weights.
    let raw: Vec<f64> = class_counts
        .iter()
        .map(|&count| {
            if count == 0 {
                0.0
            } else {
                (n / (k * count as f64)).sqrt()
            }
        })
        .collect();

    // Rescale so the weights sum to the number of classes; fall back to
    // uniform weights if every raw weight is zero.
    let norm: f64 = raw.iter().sum();
    if norm == 0.0 {
        vec![1.0; num_classes]
    } else {
        raw.into_iter().map(|w| (w * k / norm) as f32).collect()
    }
}
681
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{
        array::{Int32Array, StringArray},
        datatypes::{DataType, Field, Schema},
        record_batch::RecordBatch,
    };

    use super::*;

    // ---- ImbalanceSeverity ----

    #[test]
    fn test_severity_from_ratio() {
        // Thresholds are half-open: 1.5, 3.0, 10.0, and 100.0 each
        // belong to the next (more severe) bucket.
        assert_eq!(ImbalanceSeverity::from_ratio(1.0), ImbalanceSeverity::None);
        assert_eq!(ImbalanceSeverity::from_ratio(1.4), ImbalanceSeverity::None);
        assert_eq!(ImbalanceSeverity::from_ratio(1.5), ImbalanceSeverity::Low);
        assert_eq!(ImbalanceSeverity::from_ratio(2.9), ImbalanceSeverity::Low);
        assert_eq!(
            ImbalanceSeverity::from_ratio(3.0),
            ImbalanceSeverity::Moderate
        );
        assert_eq!(
            ImbalanceSeverity::from_ratio(9.9),
            ImbalanceSeverity::Moderate
        );
        assert_eq!(
            ImbalanceSeverity::from_ratio(10.0),
            ImbalanceSeverity::Severe
        );
        assert_eq!(
            ImbalanceSeverity::from_ratio(99.0),
            ImbalanceSeverity::Severe
        );
        assert_eq!(
            ImbalanceSeverity::from_ratio(100.0),
            ImbalanceSeverity::Extreme
        );
        assert_eq!(
            ImbalanceSeverity::from_ratio(1000.0),
            ImbalanceSeverity::Extreme
        );
    }

    #[test]
    fn test_severity_is_imbalanced() {
        // Only `None` counts as balanced.
        assert!(!ImbalanceSeverity::None.is_imbalanced());
        assert!(ImbalanceSeverity::Low.is_imbalanced());
        assert!(ImbalanceSeverity::Moderate.is_imbalanced());
        assert!(ImbalanceSeverity::Severe.is_imbalanced());
        assert!(ImbalanceSeverity::Extreme.is_imbalanced());
    }

    #[test]
    fn test_severity_ordering() {
        // Derived Ord follows variant declaration order.
        assert!(ImbalanceSeverity::None < ImbalanceSeverity::Low);
        assert!(ImbalanceSeverity::Low < ImbalanceSeverity::Moderate);
        assert!(ImbalanceSeverity::Moderate < ImbalanceSeverity::Severe);
        assert!(ImbalanceSeverity::Severe < ImbalanceSeverity::Extreme);
    }

    #[test]
    fn test_severity_description() {
        assert_eq!(ImbalanceSeverity::None.description(), "Balanced");
        assert_eq!(
            ImbalanceSeverity::Extreme.description(),
            "Extremely imbalanced"
        );
    }

    // ---- ImbalanceMetrics ----

    #[test]
    fn test_metrics_balanced() {
        let counts: HashMap<String, usize> = [("A".to_string(), 100), ("B".to_string(), 100)]
            .into_iter()
            .collect();

        let metrics = ImbalanceMetrics::from_counts(&counts);

        // Two equal classes: ratio 1, full entropy, Gini 0.5.
        assert!((metrics.imbalance_ratio - 1.0).abs() < 0.01);
        assert!((metrics.normalized_entropy - 1.0).abs() < 0.01);
        assert!((metrics.gini - 0.5).abs() < 0.01);
        assert_eq!(metrics.severity, ImbalanceSeverity::None);
        assert!(!metrics.is_imbalanced());
    }

    #[test]
    fn test_metrics_imbalanced() {
        let counts: HashMap<String, usize> = [("A".to_string(), 900), ("B".to_string(), 100)]
            .into_iter()
            .collect();

        let metrics = ImbalanceMetrics::from_counts(&counts);

        // 9:1 ratio lands in the Moderate bucket ([3, 10)).
        assert!((metrics.imbalance_ratio - 9.0).abs() < 0.01);
        assert!(metrics.normalized_entropy < 0.8);
        assert_eq!(metrics.severity, ImbalanceSeverity::Moderate);
        assert!(metrics.is_imbalanced());
    }

    #[test]
    fn test_metrics_severely_imbalanced() {
        let counts: HashMap<String, usize> = [("A".to_string(), 990), ("B".to_string(), 10)]
            .into_iter()
            .collect();

        let metrics = ImbalanceMetrics::from_counts(&counts);

        // 99:1 ratio lands in the Severe bucket ([10, 100)).
        assert!((metrics.imbalance_ratio - 99.0).abs() < 0.01);
        assert_eq!(metrics.severity, ImbalanceSeverity::Severe);
    }

    #[test]
    fn test_metrics_empty() {
        // No classes at all degrades to neutral metrics.
        let counts: HashMap<String, usize> = HashMap::new();
        let metrics = ImbalanceMetrics::from_counts(&counts);

        assert!((metrics.imbalance_ratio - 1.0).abs() < 0.01);
        assert_eq!(metrics.severity, ImbalanceSeverity::None);
    }

    #[test]
    fn test_metrics_single_class() {
        let counts: HashMap<String, usize> = [("A".to_string(), 100)].into_iter().collect();

        let metrics = ImbalanceMetrics::from_counts(&counts);

        // Single class: entropy and Gini collapse to 0.
        assert!(metrics.imbalance_ratio.is_infinite() || metrics.imbalance_ratio >= 1.0);
        assert!((metrics.entropy - 0.0).abs() < 0.01);
        assert!((metrics.gini - 0.0).abs() < 0.01);
    }

    #[test]
    fn test_metrics_multiclass_balanced() {
        let counts: HashMap<String, usize> = [
            ("A".to_string(), 100),
            ("B".to_string(), 100),
            ("C".to_string(), 100),
        ]
        .into_iter()
        .collect();

        let metrics = ImbalanceMetrics::from_counts(&counts);

        assert!((metrics.imbalance_ratio - 1.0).abs() < 0.01);
        assert!((metrics.normalized_entropy - 1.0).abs() < 0.01);
    }

    // ---- ClassDistribution ----

    #[test]
    fn test_distribution_from_counts() {
        let counts: HashMap<String, usize> = [("A".to_string(), 75), ("B".to_string(), 25)]
            .into_iter()
            .collect();

        let dist = ClassDistribution::from_counts(counts);

        assert_eq!(dist.total, 100);
        assert_eq!(dist.num_classes, 2);
        assert_eq!(dist.get_count("A"), 75);
        assert_eq!(dist.get_count("B"), 25);
        assert!((dist.get_proportion("A") - 0.75).abs() < 0.01);
        assert!((dist.get_proportion("B") - 0.25).abs() < 0.01);
        assert_eq!(dist.majority_class, Some("A".to_string()));
        assert_eq!(dist.minority_class, Some("B".to_string()));
    }

    #[test]
    fn test_distribution_missing_class() {
        // Unknown classes report zero count/proportion rather than panic.
        let counts: HashMap<String, usize> = [("A".to_string(), 100)].into_iter().collect();
        let dist = ClassDistribution::from_counts(counts);

        assert_eq!(dist.get_count("B"), 0);
        assert!((dist.get_proportion("B") - 0.0).abs() < 0.01);
    }

    // ---- ImbalanceRecommendation ----

    #[test]
    fn test_recommendation_display() {
        let rec = ImbalanceRecommendation::UseStratifiedSplit;
        assert!(rec.to_string().contains("stratified"));
    }

    // ---- ImbalanceReport ----

    #[test]
    fn test_report_balanced() {
        let counts: HashMap<String, usize> = [("A".to_string(), 100), ("B".to_string(), 100)]
            .into_iter()
            .collect();
        let dist = ClassDistribution::from_counts(counts);
        let report = ImbalanceReport::from_distribution("label", dist);

        assert!(!report.is_imbalanced());
        assert_eq!(report.severity(), ImbalanceSeverity::None);
        assert!(report
            .recommendations
            .contains(&ImbalanceRecommendation::NoAction));
    }

    #[test]
    fn test_report_imbalanced() {
        let counts: HashMap<String, usize> = [("A".to_string(), 900), ("B".to_string(), 100)]
            .into_iter()
            .collect();
        let dist = ClassDistribution::from_counts(counts);
        let report = ImbalanceReport::from_distribution("label", dist);

        assert!(report.is_imbalanced());
        assert!(report
            .recommendations
            .contains(&ImbalanceRecommendation::UseStratifiedSplit));
        assert!(report
            .recommendations
            .contains(&ImbalanceRecommendation::UseAppropriateMetrics));
    }

    #[test]
    fn test_report_severely_imbalanced() {
        let counts: HashMap<String, usize> = [("A".to_string(), 9900), ("B".to_string(), 100)]
            .into_iter()
            .collect();
        let dist = ClassDistribution::from_counts(counts);
        let report = ImbalanceReport::from_distribution("label", dist);

        assert_eq!(report.severity(), ImbalanceSeverity::Severe);
        assert!(report
            .recommendations
            .contains(&ImbalanceRecommendation::ConsiderSMOTE));
        assert!(report
            .recommendations
            .contains(&ImbalanceRecommendation::CollectMoreData));
    }

    #[test]
    fn test_report_extremely_imbalanced() {
        let counts: HashMap<String, usize> = [("A".to_string(), 10000), ("B".to_string(), 10)]
            .into_iter()
            .collect();
        let dist = ClassDistribution::from_counts(counts);
        let report = ImbalanceReport::from_distribution("label", dist);

        assert_eq!(report.severity(), ImbalanceSeverity::Extreme);
        assert!(report
            .recommendations
            .contains(&ImbalanceRecommendation::ConsiderAnomalyDetection));
    }

    // ---- ImbalanceDetector (Arrow-backed) ----

    // Builds a single-batch dataset with one Utf8 "label" column.
    fn make_string_dataset(labels: Vec<&str>) -> ArrowDataset {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "label",
            DataType::Utf8,
            false,
        )]));

        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(StringArray::from(labels))],
        )
        .expect("batch");

        ArrowDataset::from_batch(batch).expect("dataset")
    }

    // Builds a single-batch dataset with one Int32 "label" column.
    fn make_int_dataset(labels: Vec<i32>) -> ArrowDataset {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "label",
            DataType::Int32,
            false,
        )]));

        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int32Array::from(labels))],
        )
        .expect("batch");

        ArrowDataset::from_batch(batch).expect("dataset")
    }

    #[test]
    fn test_detector_new() {
        let detector = ImbalanceDetector::new("label");
        assert_eq!(detector.label_column(), "label");
    }

    #[test]
    fn test_detector_analyze_balanced() {
        let labels: Vec<&str> = (0..100).map(|i| if i < 50 { "A" } else { "B" }).collect();
        let dataset = make_string_dataset(labels);

        let detector = ImbalanceDetector::new("label");
        let report = detector.analyze(&dataset).expect("analyze");

        assert!(!report.is_imbalanced());
        assert_eq!(report.distribution.total, 100);
        assert_eq!(report.distribution.num_classes, 2);
    }

    #[test]
    fn test_detector_analyze_imbalanced() {
        let mut labels: Vec<&str> = vec!["A"; 90];
        labels.extend(vec!["B"; 10]);
        let dataset = make_string_dataset(labels);

        let detector = ImbalanceDetector::new("label");
        let report = detector.analyze(&dataset).expect("analyze");

        assert!(report.is_imbalanced());
        assert_eq!(report.distribution.majority_class, Some("A".to_string()));
        assert_eq!(report.distribution.minority_class, Some("B".to_string()));
    }

    #[test]
    fn test_detector_analyze_int_labels() {
        // Integer labels are stringified for counting.
        let labels: Vec<i32> = (0..100).map(|i| if i < 80 { 0 } else { 1 }).collect();
        let dataset = make_int_dataset(labels);

        let detector = ImbalanceDetector::new("label");
        let report = detector.analyze(&dataset).expect("analyze");

        assert!(report.is_imbalanced());
        assert_eq!(report.distribution.get_count("0"), 80);
        assert_eq!(report.distribution.get_count("1"), 20);
    }

    #[test]
    fn test_detector_missing_column() {
        let dataset = make_string_dataset(vec!["A", "B", "A"]);

        let detector = ImbalanceDetector::new("nonexistent");
        let result = detector.analyze(&dataset);

        assert!(result.is_err());
    }

    #[test]
    fn test_detector_multiclass() {
        let mut labels = vec!["A"; 50];
        labels.extend(vec!["B"; 30]);
        labels.extend(vec!["C"; 20]);
        let dataset = make_string_dataset(labels);

        let detector = ImbalanceDetector::new("label");
        let report = detector.analyze(&dataset).expect("analyze");

        assert_eq!(report.distribution.num_classes, 3);
        assert_eq!(report.distribution.majority_class, Some("A".to_string()));
        assert_eq!(report.distribution.minority_class, Some("C".to_string()));
    }

    // ---- resample (feature = "shuffle") ----

    #[cfg(feature = "shuffle")]
    #[test]
    fn test_resample_oversample() {
        // 10 rows of class 0, 2 rows of class 1 → oversample both to 10.
        let schema = Arc::new(Schema::new(vec![
            Field::new("text", DataType::Utf8, false),
            Field::new("label", DataType::Int32, false),
        ]));
        let mut texts = Vec::new();
        let mut labels = Vec::new();
        for i in 0..10 {
            texts.push(format!("text_a_{i}"));
            labels.push(0i32);
        }
        for i in 0..2 {
            texts.push(format!("text_b_{i}"));
            labels.push(1i32);
        }
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(texts)),
                Arc::new(Int32Array::from(labels)),
            ],
        )
        .unwrap();
        let ds = ArrowDataset::from_batch(batch).unwrap();

        let result = resample(&ds, "label", ResampleStrategy::Oversample, 42).unwrap();
        assert_eq!(result.len(), 20);
    }

    #[cfg(feature = "shuffle")]
    #[test]
    fn test_resample_undersample() {
        // 10 rows of class 0, 2 rows of class 1 → undersample both to 2.
        let schema = Arc::new(Schema::new(vec![
            Field::new("text", DataType::Utf8, false),
            Field::new("label", DataType::Int32, false),
        ]));
        let mut texts = Vec::new();
        let mut labels = Vec::new();
        for i in 0..10 {
            texts.push(format!("text_a_{i}"));
            labels.push(0i32);
        }
        for i in 0..2 {
            texts.push(format!("text_b_{i}"));
            labels.push(1i32);
        }
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(texts)),
                Arc::new(Int32Array::from(labels)),
            ],
        )
        .unwrap();
        let ds = ArrowDataset::from_batch(batch).unwrap();

        let result = resample(&ds, "label", ResampleStrategy::Undersample, 42).unwrap();
        assert_eq!(result.len(), 4);
    }

    // ---- sqrt_inverse_weights ----

    #[test]
    fn test_sqrt_inverse_weights() {
        let counts = vec![17252, 2402, 2858, 2875, 3920];
        let weights = sqrt_inverse_weights(&counts);
        assert_eq!(weights.len(), 5);
        // Larger classes get smaller weights.
        assert!(weights[0] < weights[1]);
        // Weights are normalized to sum to the number of classes.
        let sum: f32 = weights.iter().sum();
        assert!((sum - 5.0).abs() < 0.01);
    }
}
1118}