Skip to main content

alimentar/
imbalance.rs

//! Imbalanced dataset detection for ML pipelines
//!
//! Detects class imbalance in classification datasets and recommends
//! strategies for handling it. Helpers to act on those recommendations are
//! also provided: resampling (behind the `shuffle` feature) and sqrt-inverse
//! class weights.
//!
//! # Example
//!
//! ```ignore
//! use alimentar::imbalance::ImbalanceDetector;
//!
//! let detector = ImbalanceDetector::new("label");
//! let report = detector.analyze(&dataset)?;
//!
//! if report.is_imbalanced() {
//!     println!("Imbalance ratio: {:.2}", report.metrics.imbalance_ratio);
//!     for rec in &report.recommendations {
//!         println!("Recommendation: {}", rec);
//!     }
//! }
//! ```
21
22// Statistical computation requires usize->f64 casts
23#![allow(clippy::cast_precision_loss)]
24
25use std::collections::HashMap;
26
27use crate::{
28    dataset::{ArrowDataset, Dataset},
29    error::{Error, Result},
30};
31
/// How severe a class imbalance is, bucketed by the majority/minority ratio.
///
/// Variants are declared from least to most severe, so the derived ordering
/// can be used to compare severities (`Low < Severe`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ImbalanceSeverity {
    /// Ratio below 1.5: effectively balanced
    None,
    /// Ratio in [1.5, 3): slight imbalance
    Low,
    /// Ratio in [3, 10): moderate imbalance
    Moderate,
    /// Ratio in [10, 100): severe imbalance
    Severe,
    /// Ratio of 100 or more: extreme imbalance
    Extreme,
}

impl ImbalanceSeverity {
    /// Map an imbalance ratio (majority count / minority count) to a bucket.
    ///
    /// The thresholds 1.5, 3, 10 and 100 each open a new, half-open bucket.
    /// Ratios of 100 or more, including `f64::INFINITY`, map to
    /// [`ImbalanceSeverity::Extreme`]. So does `NaN`, because it fails every
    /// `<` comparison.
    pub fn from_ratio(ratio: f64) -> Self {
        match ratio {
            r if r < 1.5 => Self::None,
            r if r < 3.0 => Self::Low,
            r if r < 10.0 => Self::Moderate,
            r if r < 100.0 => Self::Severe,
            _ => Self::Extreme,
        }
    }

    /// `true` for every severity other than [`ImbalanceSeverity::None`].
    pub fn is_imbalanced(&self) -> bool {
        !matches!(self, Self::None)
    }

    /// Short human-readable label for this severity.
    pub fn description(&self) -> &'static str {
        let label = match *self {
            Self::None => "Balanced",
            Self::Low => "Slightly imbalanced",
            Self::Moderate => "Moderately imbalanced",
            Self::Severe => "Severely imbalanced",
            Self::Extreme => "Extremely imbalanced",
        };
        label
    }
}
79
80/// Metrics for measuring class imbalance
81#[derive(Debug, Clone)]
82pub struct ImbalanceMetrics {
83    /// Ratio of majority to minority class (>= 1.0)
84    pub imbalance_ratio: f64,
85    /// Shannon entropy of class distribution (0 = single class, log(n) =
86    /// uniform)
87    pub entropy: f64,
88    /// Normalized entropy (0-1, 1 = perfectly balanced)
89    pub normalized_entropy: f64,
90    /// Gini impurity (0 = single class, 1-1/n = uniform)
91    pub gini: f64,
92    /// Severity classification
93    pub severity: ImbalanceSeverity,
94}
95
96impl ImbalanceMetrics {
97    /// Create metrics from class counts
98    pub fn from_counts(counts: &HashMap<String, usize>) -> Self {
99        if counts.is_empty() {
100            return Self {
101                imbalance_ratio: 1.0,
102                entropy: 0.0,
103                normalized_entropy: 1.0,
104                gini: 0.0,
105                severity: ImbalanceSeverity::None,
106            };
107        }
108
109        let total: usize = counts.values().sum();
110        if total == 0 {
111            return Self {
112                imbalance_ratio: 1.0,
113                entropy: 0.0,
114                normalized_entropy: 1.0,
115                gini: 0.0,
116                severity: ImbalanceSeverity::None,
117            };
118        }
119
120        let total_f = total as f64;
121        let n_classes = counts.len();
122
123        // Imbalance ratio
124        let max_count = counts.values().copied().max().unwrap_or(0);
125        let min_count = counts.values().copied().min().unwrap_or(0);
126        let imbalance_ratio = if min_count > 0 {
127            max_count as f64 / min_count as f64
128        } else {
129            f64::INFINITY
130        };
131
132        // Shannon entropy: -sum(p * log(p))
133        let entropy: f64 = counts
134            .values()
135            .map(|&c| {
136                if c > 0 {
137                    let p = c as f64 / total_f;
138                    -p * p.ln()
139                } else {
140                    0.0
141                }
142            })
143            .sum();
144
145        // Normalized entropy (relative to maximum possible)
146        let max_entropy = (n_classes as f64).ln();
147        let normalized_entropy = if max_entropy > 0.0 {
148            entropy / max_entropy
149        } else {
150            1.0
151        };
152
153        // Gini impurity: 1 - sum(p^2)
154        let gini: f64 = 1.0
155            - counts
156                .values()
157                .map(|&c| {
158                    let p = c as f64 / total_f;
159                    p * p
160                })
161                .sum::<f64>();
162
163        let severity = ImbalanceSeverity::from_ratio(imbalance_ratio);
164
165        Self {
166            imbalance_ratio,
167            entropy,
168            normalized_entropy,
169            gini,
170            severity,
171        }
172    }
173
174    /// Check if the dataset is imbalanced
175    pub fn is_imbalanced(&self) -> bool {
176        self.severity.is_imbalanced()
177    }
178}
179
/// Distribution of classes in a dataset
#[derive(Debug, Clone)]
pub struct ClassDistribution {
    /// Count per class
    pub counts: HashMap<String, usize>,
    /// Proportion per class (0-1); all zero when `total` is 0
    pub proportions: HashMap<String, f64>,
    /// Total number of samples
    pub total: usize,
    /// Number of classes listed in `counts`, including any with a zero count
    pub num_classes: usize,
    /// Class with the largest count; ties go to the lexicographically smallest
    /// class name. `None` only when `counts` is empty.
    pub majority_class: Option<String>,
    /// Class with the smallest non-zero count; ties go to the lexicographically
    /// smallest class name. `None` when no class has a positive count.
    pub minority_class: Option<String>,
}

impl ClassDistribution {
    /// Create a distribution from per-class counts.
    ///
    /// Majority and minority selection is deterministic. `HashMap` iteration
    /// order is randomized per map instance, so breaking ties by "whichever
    /// entry the map yields first/last" would change between runs. Ties are
    /// broken by class name instead.
    pub fn from_counts(counts: HashMap<String, usize>) -> Self {
        let total: usize = counts.values().sum();
        let num_classes = counts.len();

        let proportions: HashMap<String, f64> = counts
            .iter()
            .map(|(k, &v)| {
                let prop = if total > 0 {
                    v as f64 / total as f64
                } else {
                    0.0
                };
                (k.clone(), prop)
            })
            .collect();

        // Highest count wins; among equal counts the smaller name is treated as
        // "greater" so that it is the one `max_by` returns.
        let majority_class = counts
            .iter()
            .max_by(|a, b| a.1.cmp(b.1).then_with(|| b.0.cmp(a.0)))
            .map(|(k, _)| k.clone());

        // Lowest positive count wins; among equal counts the smaller name wins.
        let minority_class = counts
            .iter()
            .filter(|(_, &v)| v > 0)
            .min_by(|a, b| a.1.cmp(b.1).then_with(|| a.0.cmp(b.0)))
            .map(|(k, _)| k.clone());

        Self {
            counts,
            proportions,
            total,
            num_classes,
            majority_class,
            minority_class,
        }
    }

    /// Get the count for a specific class (0 if the class is not present).
    pub fn get_count(&self, class: &str) -> usize {
        self.counts.get(class).copied().unwrap_or(0)
    }

    /// Get the proportion for a specific class (0.0 if the class is not present).
    pub fn get_proportion(&self, class: &str) -> f64 {
        self.proportions.get(class).copied().unwrap_or(0.0)
    }
}
246
/// A suggested strategy for dealing with an imbalanced label distribution.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ImbalanceRecommendation {
    /// The distribution is balanced; nothing to do
    NoAction,
    /// Split train/test data with stratified sampling
    UseStratifiedSplit,
    /// Weight classes in the training loss
    UseClassWeights,
    /// Duplicate minority-class samples
    ConsiderOversampling,
    /// Drop majority-class samples
    ConsiderUndersampling,
    /// Generate synthetic minority samples (SMOTE and similar)
    ConsiderSMOTE,
    /// Gather more minority-class data
    CollectMoreData,
    /// Evaluate with F1 / AUC-ROC / precision-recall rather than accuracy
    UseAppropriateMetrics,
    /// Treat the minority class as anomalies to detect
    ConsiderAnomalyDetection,
}

impl ImbalanceRecommendation {
    /// One-sentence, human-readable explanation of this recommendation.
    pub fn description(&self) -> &'static str {
        let text = match self {
            Self::NoAction => "No action needed - dataset is balanced",
            Self::UseStratifiedSplit => "Use stratified sampling for train/test splits",
            Self::UseClassWeights => "Apply class weights during model training",
            Self::ConsiderOversampling => "Consider oversampling the minority class",
            Self::ConsiderUndersampling => "Consider undersampling the majority class",
            Self::ConsiderSMOTE => "Consider SMOTE or synthetic minority oversampling",
            Self::CollectMoreData => "Collect more samples for minority classes",
            Self::UseAppropriateMetrics => {
                "Use F1-score, AUC-ROC, or precision-recall instead of accuracy"
            }
            Self::ConsiderAnomalyDetection => "Consider framing as anomaly detection problem",
        };
        text
    }
}

impl std::fmt::Display for ImbalanceRecommendation {
    /// Formats as the text returned by [`ImbalanceRecommendation::description`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.description())
    }
}
294
295/// Report from imbalance analysis
296#[derive(Debug, Clone)]
297pub struct ImbalanceReport {
298    /// Column analyzed
299    pub column: String,
300    /// Class distribution
301    pub distribution: ClassDistribution,
302    /// Imbalance metrics
303    pub metrics: ImbalanceMetrics,
304    /// Recommendations
305    pub recommendations: Vec<ImbalanceRecommendation>,
306}
307
308impl ImbalanceReport {
309    /// Create report from distribution
310    pub fn from_distribution(column: impl Into<String>, distribution: ClassDistribution) -> Self {
311        let metrics = ImbalanceMetrics::from_counts(&distribution.counts);
312        let recommendations = generate_recommendations(&metrics, &distribution);
313
314        Self {
315            column: column.into(),
316            distribution,
317            metrics,
318            recommendations,
319        }
320    }
321
322    /// Check if the dataset is imbalanced
323    pub fn is_imbalanced(&self) -> bool {
324        self.metrics.is_imbalanced()
325    }
326
327    /// Get severity
328    pub fn severity(&self) -> ImbalanceSeverity {
329        self.metrics.severity
330    }
331}
332
333/// Generate recommendations based on metrics
334fn generate_recommendations(
335    metrics: &ImbalanceMetrics,
336    distribution: &ClassDistribution,
337) -> Vec<ImbalanceRecommendation> {
338    let mut recs = Vec::new();
339
340    match metrics.severity {
341        ImbalanceSeverity::None => {
342            recs.push(ImbalanceRecommendation::NoAction);
343        }
344        ImbalanceSeverity::Low => {
345            recs.push(ImbalanceRecommendation::UseStratifiedSplit);
346            recs.push(ImbalanceRecommendation::UseAppropriateMetrics);
347        }
348        ImbalanceSeverity::Moderate => {
349            recs.push(ImbalanceRecommendation::UseStratifiedSplit);
350            recs.push(ImbalanceRecommendation::UseClassWeights);
351            recs.push(ImbalanceRecommendation::UseAppropriateMetrics);
352            if distribution.total < 10000 {
353                recs.push(ImbalanceRecommendation::ConsiderOversampling);
354            } else {
355                recs.push(ImbalanceRecommendation::ConsiderUndersampling);
356            }
357        }
358        ImbalanceSeverity::Severe => {
359            recs.push(ImbalanceRecommendation::UseStratifiedSplit);
360            recs.push(ImbalanceRecommendation::UseClassWeights);
361            recs.push(ImbalanceRecommendation::ConsiderSMOTE);
362            recs.push(ImbalanceRecommendation::UseAppropriateMetrics);
363            recs.push(ImbalanceRecommendation::CollectMoreData);
364        }
365        ImbalanceSeverity::Extreme => {
366            recs.push(ImbalanceRecommendation::ConsiderAnomalyDetection);
367            recs.push(ImbalanceRecommendation::UseStratifiedSplit);
368            recs.push(ImbalanceRecommendation::ConsiderSMOTE);
369            recs.push(ImbalanceRecommendation::CollectMoreData);
370            recs.push(ImbalanceRecommendation::UseAppropriateMetrics);
371        }
372    }
373
374    recs
375}
376
377/// Detector for class imbalance in datasets
378pub struct ImbalanceDetector {
379    /// Label column name
380    label_column: String,
381}
382
383impl ImbalanceDetector {
384    /// Create a new imbalance detector
385    pub fn new(label_column: impl Into<String>) -> Self {
386        Self {
387            label_column: label_column.into(),
388        }
389    }
390
391    /// Get the label column name
392    pub fn label_column(&self) -> &str {
393        &self.label_column
394    }
395
396    /// Analyze a dataset for class imbalance
397    pub fn analyze(&self, dataset: &ArrowDataset) -> Result<ImbalanceReport> {
398        let counts = self.count_classes(dataset)?;
399        let distribution = ClassDistribution::from_counts(counts);
400        Ok(ImbalanceReport::from_distribution(
401            &self.label_column,
402            distribution,
403        ))
404    }
405
406    /// Count class occurrences
407    fn count_classes(&self, dataset: &ArrowDataset) -> Result<HashMap<String, usize>> {
408        use arrow::array::{Array, Int32Array, Int64Array, StringArray};
409
410        let schema = dataset.schema();
411        let col_idx = schema
412            .fields()
413            .iter()
414            .position(|f| f.name() == &self.label_column)
415            .ok_or_else(|| {
416                Error::invalid_config(format!(
417                    "Column '{}' not found in schema",
418                    self.label_column
419                ))
420            })?;
421
422        let mut counts: HashMap<String, usize> = HashMap::new();
423
424        for batch in dataset.iter() {
425            let array = batch.column(col_idx);
426
427            // Handle different array types
428            if let Some(arr) = array.as_any().downcast_ref::<StringArray>() {
429                for i in 0..arr.len() {
430                    if !arr.is_null(i) {
431                        let key = arr.value(i).to_string();
432                        *counts.entry(key).or_insert(0) += 1;
433                    }
434                }
435            } else if let Some(arr) = array.as_any().downcast_ref::<Int32Array>() {
436                for i in 0..arr.len() {
437                    if !arr.is_null(i) {
438                        let key = arr.value(i).to_string();
439                        *counts.entry(key).or_insert(0) += 1;
440                    }
441                }
442            } else if let Some(arr) = array.as_any().downcast_ref::<Int64Array>() {
443                for i in 0..arr.len() {
444                    if !arr.is_null(i) {
445                        let key = arr.value(i).to_string();
446                        *counts.entry(key).or_insert(0) += 1;
447                    }
448                }
449            } else {
450                return Err(Error::invalid_config(format!(
451                    "Unsupported column type for '{}'. Expected string or integer.",
452                    self.label_column
453                )));
454            }
455        }
456
457        if counts.is_empty() {
458            return Err(Error::invalid_config(format!(
459                "No valid values found in column '{}'",
460                self.label_column
461            )));
462        }
463
464        Ok(counts)
465    }
466}
467
/// How `resample` (available with the `shuffle` feature) equalises class sizes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResampleStrategy {
    /// Grow every class to the size of the largest class by repeating rows.
    Oversample,
    /// Shrink every class to the size of the smallest class by randomly
    /// dropping rows.
    Undersample,
}
476
/// Group row indices by label value from an Arrow array column.
///
/// Supports `Int64Array`, `Int32Array`, and `StringArray` label types. Labels
/// are keyed by their `to_string()` form, and null labels are skipped. Only
/// the first `n` rows are examined. `n` is clamped to the array length, so an
/// oversized `n` cannot index past the end of the array.
///
/// # Errors
/// Returns an error if the label array has an unsupported type, or if the
/// examined rows cannot all be addressed with `u32` indices.
#[cfg(feature = "shuffle")]
// The row count is checked against `u32` before any row is visited, so the
// `as u32` casts below cannot truncate.
#[allow(clippy::cast_possible_truncation)]
fn group_rows_by_label(
    label_array: &dyn arrow::array::Array,
    n: usize,
    label_column: &str,
) -> Result<HashMap<String, Vec<u32>>> {
    use arrow::array::{Array, Int32Array, Int64Array, StringArray};

    // Appends the row index of every non-null value to its label's group.
    fn push_rows<T: ToString>(
        groups: &mut HashMap<String, Vec<u32>>,
        values: impl Iterator<Item = Option<T>>,
    ) {
        for (row, value) in values.enumerate() {
            if let Some(label) = value {
                groups
                    .entry(label.to_string())
                    .or_default()
                    .push(row as u32);
            }
        }
    }

    let rows = Array::len(label_array).min(n);
    // Indices run from 0 to rows - 1 and must fit in the `u32` take indices.
    if rows > 0 && u32::try_from(rows - 1).is_err() {
        return Err(Error::invalid_config(format!(
            "Column '{label_column}' has too many rows ({rows}) to resample with 32-bit indices"
        )));
    }

    let mut groups: HashMap<String, Vec<u32>> = HashMap::new();
    let any = label_array.as_any();
    if let Some(arr) = any.downcast_ref::<Int64Array>() {
        push_rows(&mut groups, arr.iter().take(rows));
    } else if let Some(arr) = any.downcast_ref::<Int32Array>() {
        push_rows(&mut groups, arr.iter().take(rows));
    } else if let Some(arr) = any.downcast_ref::<StringArray>() {
        push_rows(&mut groups, arr.iter().take(rows));
    } else {
        return Err(Error::invalid_config(format!(
            "Unsupported column type for '{label_column}'"
        )));
    }

    Ok(groups)
}
540
/// Balance group indices to a target count via oversampling or undersampling.
///
/// Every non-empty group contributes exactly `target` row indices, and empty
/// groups contribute none:
/// - A group larger than `target` is randomly subsampled without replacement.
/// - A group smaller than `target` keeps all of its rows and is repeated whole
///   as many times as fits. The remaining shortfall is filled with a random
///   subset of its rows, so no row is systematically favoured.
/// - A group of exactly `target` rows is kept as-is.
///
/// Groups are visited in sorted label order. `HashMap` iteration order differs
/// between map instances, so visiting groups in map order would make the
/// result depend on more than `rng`. Sorting keeps the output reproducible for
/// a given seed. The combined indices are shuffled before being returned.
#[cfg(feature = "shuffle")]
fn balance_group_indices(
    groups: &std::collections::HashMap<String, Vec<u32>>,
    target: usize,
    rng: &mut rand::rngs::StdRng,
) -> Vec<u32> {
    use rand::seq::SliceRandom;

    let mut ordered: Vec<(&String, &Vec<u32>)> = groups.iter().collect();
    ordered.sort_unstable_by(|a, b| a.0.cmp(b.0));

    let mut all_indices: Vec<u32> = Vec::with_capacity(target.saturating_mul(ordered.len()));
    for (_, indices) in ordered {
        if indices.is_empty() {
            // Nothing to replicate; this also avoids dividing by zero below.
            continue;
        }
        match indices.len().cmp(&target) {
            std::cmp::Ordering::Equal => {
                all_indices.extend_from_slice(indices);
            }
            std::cmp::Ordering::Less => {
                // `copies >= 1` because the group is strictly smaller than target.
                let copies = target / indices.len();
                let remainder = target % indices.len();
                for _ in 0..copies {
                    all_indices.extend_from_slice(indices);
                }
                if remainder > 0 {
                    // Shuffle before truncating so the partial copy is a random
                    // subset rather than always the group's first rows.
                    let mut extra = indices.clone();
                    extra.shuffle(rng);
                    extra.truncate(remainder);
                    all_indices.extend(extra);
                }
            }
            std::cmp::Ordering::Greater => {
                let mut sampled = indices.clone();
                sampled.shuffle(rng);
                sampled.truncate(target);
                all_indices.extend(sampled);
            }
        }
    }
    all_indices.shuffle(rng);
    all_indices
}
580
/// Resample a classification dataset to address class imbalance.
///
/// Rows are grouped by the value of `label_column`, and every class is then
/// brought to a common size:
/// - [`ResampleStrategy::Oversample`] uses the size of the largest class;
///   rows of smaller classes are repeated.
/// - [`ResampleStrategy::Undersample`] uses the size of the smallest class;
///   rows of larger classes are randomly dropped.
///
/// All batches of `dataset` are concatenated first. The result is a single
/// batch with shuffled row order. Rows whose label is null are dropped.
///
/// # Arguments
/// * `dataset` - Source dataset
/// * `label_column` - Name of the `Int32`, `Int64` or string label column
/// * `strategy` - Oversample or Undersample
/// * `seed` - Seed for the random number generator used to subsample and
///   shuffle rows
///
/// # Errors
/// Returns an error if:
/// - the label column is not found or has an unsupported type;
/// - the dataset contains no non-null labels;
/// - an Arrow concatenation or `take` operation fails.
#[cfg(feature = "shuffle")]
pub fn resample(
    dataset: &ArrowDataset,
    label_column: &str,
    strategy: ResampleStrategy,
    seed: u64,
) -> Result<ArrowDataset> {
    use arrow::compute;
    use rand::SeedableRng;

    // Work on a single batch so that row indices are global across the dataset.
    let mut batches: Vec<arrow::array::RecordBatch> = dataset.iter().collect();
    let batch = if batches.len() == 1 {
        batches
            .pop()
            .ok_or_else(|| Error::empty_dataset("empty"))?
    } else {
        compute::concat_batches(&dataset.schema(), &batches).map_err(Error::Arrow)?
    };

    let schema = batch.schema();
    let col_idx = schema
        .fields()
        .iter()
        .position(|f| f.name() == label_column)
        .ok_or_else(|| {
            Error::invalid_config(format!("Column '{label_column}' not found in schema"))
        })?;

    let groups = group_rows_by_label(
        batch.column(col_idx).as_ref(),
        batch.num_rows(),
        label_column,
    )?;

    if groups.is_empty() {
        return Err(Error::empty_dataset("No valid labels found for resampling"));
    }

    // Every class is resized to this many rows.
    let target = match strategy {
        ResampleStrategy::Oversample => groups.values().map(|v| v.len()).max().unwrap_or(0),
        ResampleStrategy::Undersample => groups.values().map(|v| v.len()).min().unwrap_or(0),
    };

    let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
    let all_indices = balance_group_indices(&groups, target, &mut rng);

    // Gather the selected rows from every column.
    let indices_array = arrow::array::UInt32Array::from(all_indices);
    let columns: Vec<arrow::array::ArrayRef> = (0..batch.num_columns())
        .map(|i| compute::take(batch.column(i), &indices_array, None))
        .collect::<std::result::Result<Vec<_>, _>>()
        .map_err(|e| Error::invalid_config(format!("Arrow take failed: {e}")))?;

    let result_batch = arrow::array::RecordBatch::try_new(schema, columns)
        .map_err(|e| Error::invalid_config(format!("Failed to create resampled batch: {e}")))?;

    ArrowDataset::from_batch(result_batch)
}
642
/// Compute sqrt-inverse class weights for weighted loss.
///
/// Returns a vector of weights where `weights[i]` corresponds to class `i`.
/// Let `N` be the total count and `K` the number of classes. Each class with
/// `count_i > 0` gets a raw weight of `sqrt(N / (K * count_i))`, and classes
/// with a zero count get a raw weight of 0. The raw weights are then rescaled
/// to sum to `K`.
///
/// # Arguments
/// * `class_counts` - Ordered counts per class (index = class label)
///
/// # Returns
/// One weight per entry of `class_counts`. An empty slice yields an empty
/// vector. If every count is zero, uniform weights of `1.0` are returned, so
/// the result still has one entry per class.
pub fn sqrt_inverse_weights(class_counts: &[usize]) -> Vec<f32> {
    if class_counts.is_empty() {
        return Vec::new();
    }

    let k = class_counts.len() as f64;
    let n: f64 = class_counts.iter().sum::<usize>() as f64;

    if n == 0.0 {
        // No samples at all: fall back to uniform weights instead of an empty
        // vector, so callers indexing by class id never go out of bounds.
        return vec![1.0; class_counts.len()];
    }

    let raw: Vec<f64> = class_counts
        .iter()
        .map(|&c| {
            if c == 0 {
                0.0
            } else {
                (n / (k * c as f64)).sqrt()
            }
        })
        .collect();

    // Normalize so weights sum to K. With n > 0 at least one raw weight is
    // positive; the guard is purely defensive.
    let sum: f64 = raw.iter().sum();
    if sum == 0.0 {
        return vec![1.0; class_counts.len()];
    }

    raw.iter().map(|&w| (w * k / sum) as f32).collect()
}
681
682#[cfg(test)]
683mod tests {
684    use std::sync::Arc;
685
686    use arrow::{
687        array::{Int32Array, StringArray},
688        datatypes::{DataType, Field, Schema},
689        record_batch::RecordBatch,
690    };
691
692    use super::*;
693
694    // ========== ImbalanceSeverity tests ==========
695
696    #[test]
697    fn test_severity_from_ratio() {
698        assert_eq!(ImbalanceSeverity::from_ratio(1.0), ImbalanceSeverity::None);
699        assert_eq!(ImbalanceSeverity::from_ratio(1.4), ImbalanceSeverity::None);
700        assert_eq!(ImbalanceSeverity::from_ratio(1.5), ImbalanceSeverity::Low);
701        assert_eq!(ImbalanceSeverity::from_ratio(2.9), ImbalanceSeverity::Low);
702        assert_eq!(
703            ImbalanceSeverity::from_ratio(3.0),
704            ImbalanceSeverity::Moderate
705        );
706        assert_eq!(
707            ImbalanceSeverity::from_ratio(9.9),
708            ImbalanceSeverity::Moderate
709        );
710        assert_eq!(
711            ImbalanceSeverity::from_ratio(10.0),
712            ImbalanceSeverity::Severe
713        );
714        assert_eq!(
715            ImbalanceSeverity::from_ratio(99.0),
716            ImbalanceSeverity::Severe
717        );
718        assert_eq!(
719            ImbalanceSeverity::from_ratio(100.0),
720            ImbalanceSeverity::Extreme
721        );
722        assert_eq!(
723            ImbalanceSeverity::from_ratio(1000.0),
724            ImbalanceSeverity::Extreme
725        );
726    }
727
728    #[test]
729    fn test_severity_is_imbalanced() {
730        assert!(!ImbalanceSeverity::None.is_imbalanced());
731        assert!(ImbalanceSeverity::Low.is_imbalanced());
732        assert!(ImbalanceSeverity::Moderate.is_imbalanced());
733        assert!(ImbalanceSeverity::Severe.is_imbalanced());
734        assert!(ImbalanceSeverity::Extreme.is_imbalanced());
735    }
736
737    #[test]
738    fn test_severity_ordering() {
739        assert!(ImbalanceSeverity::None < ImbalanceSeverity::Low);
740        assert!(ImbalanceSeverity::Low < ImbalanceSeverity::Moderate);
741        assert!(ImbalanceSeverity::Moderate < ImbalanceSeverity::Severe);
742        assert!(ImbalanceSeverity::Severe < ImbalanceSeverity::Extreme);
743    }
744
745    #[test]
746    fn test_severity_description() {
747        assert_eq!(ImbalanceSeverity::None.description(), "Balanced");
748        assert_eq!(
749            ImbalanceSeverity::Extreme.description(),
750            "Extremely imbalanced"
751        );
752    }
753
754    // ========== ImbalanceMetrics tests ==========
755
756    #[test]
757    fn test_metrics_balanced() {
758        let counts: HashMap<String, usize> = [("A".to_string(), 100), ("B".to_string(), 100)]
759            .into_iter()
760            .collect();
761
762        let metrics = ImbalanceMetrics::from_counts(&counts);
763
764        assert!((metrics.imbalance_ratio - 1.0).abs() < 0.01);
765        assert!((metrics.normalized_entropy - 1.0).abs() < 0.01);
766        assert!((metrics.gini - 0.5).abs() < 0.01);
767        assert_eq!(metrics.severity, ImbalanceSeverity::None);
768        assert!(!metrics.is_imbalanced());
769    }
770
771    #[test]
772    fn test_metrics_imbalanced() {
773        let counts: HashMap<String, usize> = [("A".to_string(), 900), ("B".to_string(), 100)]
774            .into_iter()
775            .collect();
776
777        let metrics = ImbalanceMetrics::from_counts(&counts);
778
779        assert!((metrics.imbalance_ratio - 9.0).abs() < 0.01);
780        assert!(metrics.normalized_entropy < 0.8);
781        assert_eq!(metrics.severity, ImbalanceSeverity::Moderate);
782        assert!(metrics.is_imbalanced());
783    }
784
785    #[test]
786    fn test_metrics_severely_imbalanced() {
787        let counts: HashMap<String, usize> = [("A".to_string(), 990), ("B".to_string(), 10)]
788            .into_iter()
789            .collect();
790
791        let metrics = ImbalanceMetrics::from_counts(&counts);
792
793        assert!((metrics.imbalance_ratio - 99.0).abs() < 0.01);
794        assert_eq!(metrics.severity, ImbalanceSeverity::Severe);
795    }
796
797    #[test]
798    fn test_metrics_empty() {
799        let counts: HashMap<String, usize> = HashMap::new();
800        let metrics = ImbalanceMetrics::from_counts(&counts);
801
802        assert!((metrics.imbalance_ratio - 1.0).abs() < 0.01);
803        assert_eq!(metrics.severity, ImbalanceSeverity::None);
804    }
805
806    #[test]
807    fn test_metrics_single_class() {
808        let counts: HashMap<String, usize> = [("A".to_string(), 100)].into_iter().collect();
809
810        let metrics = ImbalanceMetrics::from_counts(&counts);
811
812        // Single class has infinite imbalance ratio (no minority)
813        assert!(metrics.imbalance_ratio.is_infinite() || metrics.imbalance_ratio >= 1.0);
814        assert!((metrics.entropy - 0.0).abs() < 0.01);
815        assert!((metrics.gini - 0.0).abs() < 0.01);
816    }
817
818    #[test]
819    fn test_metrics_multiclass_balanced() {
820        let counts: HashMap<String, usize> = [
821            ("A".to_string(), 100),
822            ("B".to_string(), 100),
823            ("C".to_string(), 100),
824        ]
825        .into_iter()
826        .collect();
827
828        let metrics = ImbalanceMetrics::from_counts(&counts);
829
830        assert!((metrics.imbalance_ratio - 1.0).abs() < 0.01);
831        assert!((metrics.normalized_entropy - 1.0).abs() < 0.01);
832    }
833
834    // ========== ClassDistribution tests ==========
835
836    #[test]
837    fn test_distribution_from_counts() {
838        let counts: HashMap<String, usize> = [("A".to_string(), 75), ("B".to_string(), 25)]
839            .into_iter()
840            .collect();
841
842        let dist = ClassDistribution::from_counts(counts);
843
844        assert_eq!(dist.total, 100);
845        assert_eq!(dist.num_classes, 2);
846        assert_eq!(dist.get_count("A"), 75);
847        assert_eq!(dist.get_count("B"), 25);
848        assert!((dist.get_proportion("A") - 0.75).abs() < 0.01);
849        assert!((dist.get_proportion("B") - 0.25).abs() < 0.01);
850        assert_eq!(dist.majority_class, Some("A".to_string()));
851        assert_eq!(dist.minority_class, Some("B".to_string()));
852    }
853
854    #[test]
855    fn test_distribution_missing_class() {
856        let counts: HashMap<String, usize> = [("A".to_string(), 100)].into_iter().collect();
857        let dist = ClassDistribution::from_counts(counts);
858
859        assert_eq!(dist.get_count("B"), 0);
860        assert!((dist.get_proportion("B") - 0.0).abs() < 0.01);
861    }
862
863    // ========== ImbalanceRecommendation tests ==========
864
865    #[test]
866    fn test_recommendation_display() {
867        let rec = ImbalanceRecommendation::UseStratifiedSplit;
868        assert!(rec.to_string().contains("stratified"));
869    }
870
871    // ========== ImbalanceReport tests ==========
872
873    #[test]
874    fn test_report_balanced() {
875        let counts: HashMap<String, usize> = [("A".to_string(), 100), ("B".to_string(), 100)]
876            .into_iter()
877            .collect();
878        let dist = ClassDistribution::from_counts(counts);
879        let report = ImbalanceReport::from_distribution("label", dist);
880
881        assert!(!report.is_imbalanced());
882        assert_eq!(report.severity(), ImbalanceSeverity::None);
883        assert!(report
884            .recommendations
885            .contains(&ImbalanceRecommendation::NoAction));
886    }
887
888    #[test]
889    fn test_report_imbalanced() {
890        let counts: HashMap<String, usize> = [("A".to_string(), 900), ("B".to_string(), 100)]
891            .into_iter()
892            .collect();
893        let dist = ClassDistribution::from_counts(counts);
894        let report = ImbalanceReport::from_distribution("label", dist);
895
896        assert!(report.is_imbalanced());
897        assert!(report
898            .recommendations
899            .contains(&ImbalanceRecommendation::UseStratifiedSplit));
900        assert!(report
901            .recommendations
902            .contains(&ImbalanceRecommendation::UseAppropriateMetrics));
903    }
904
905    #[test]
906    fn test_report_severely_imbalanced() {
907        let counts: HashMap<String, usize> = [("A".to_string(), 9900), ("B".to_string(), 100)]
908            .into_iter()
909            .collect();
910        let dist = ClassDistribution::from_counts(counts);
911        let report = ImbalanceReport::from_distribution("label", dist);
912
913        assert_eq!(report.severity(), ImbalanceSeverity::Severe);
914        assert!(report
915            .recommendations
916            .contains(&ImbalanceRecommendation::ConsiderSMOTE));
917        assert!(report
918            .recommendations
919            .contains(&ImbalanceRecommendation::CollectMoreData));
920    }
921
922    #[test]
923    fn test_report_extremely_imbalanced() {
924        let counts: HashMap<String, usize> = [("A".to_string(), 10000), ("B".to_string(), 10)]
925            .into_iter()
926            .collect();
927        let dist = ClassDistribution::from_counts(counts);
928        let report = ImbalanceReport::from_distribution("label", dist);
929
930        assert_eq!(report.severity(), ImbalanceSeverity::Extreme);
931        assert!(report
932            .recommendations
933            .contains(&ImbalanceRecommendation::ConsiderAnomalyDetection));
934    }
935
936    // ========== ImbalanceDetector tests ==========
937
938    fn make_string_dataset(labels: Vec<&str>) -> ArrowDataset {
939        let schema = Arc::new(Schema::new(vec![Field::new(
940            "label",
941            DataType::Utf8,
942            false,
943        )]));
944
945        let batch = RecordBatch::try_new(
946            Arc::clone(&schema),
947            vec![Arc::new(StringArray::from(labels))],
948        )
949        .expect("batch");
950
951        ArrowDataset::from_batch(batch).expect("dataset")
952    }
953
954    fn make_int_dataset(labels: Vec<i32>) -> ArrowDataset {
955        let schema = Arc::new(Schema::new(vec![Field::new(
956            "label",
957            DataType::Int32,
958            false,
959        )]));
960
961        let batch = RecordBatch::try_new(
962            Arc::clone(&schema),
963            vec![Arc::new(Int32Array::from(labels))],
964        )
965        .expect("batch");
966
967        ArrowDataset::from_batch(batch).expect("dataset")
968    }
969
970    #[test]
971    fn test_detector_new() {
972        let detector = ImbalanceDetector::new("label");
973        assert_eq!(detector.label_column(), "label");
974    }
975
976    #[test]
977    fn test_detector_analyze_balanced() {
978        let labels: Vec<&str> = (0..100).map(|i| if i < 50 { "A" } else { "B" }).collect();
979        let dataset = make_string_dataset(labels);
980
981        let detector = ImbalanceDetector::new("label");
982        let report = detector.analyze(&dataset).expect("analyze");
983
984        assert!(!report.is_imbalanced());
985        assert_eq!(report.distribution.total, 100);
986        assert_eq!(report.distribution.num_classes, 2);
987    }
988
989    #[test]
990    fn test_detector_analyze_imbalanced() {
991        let mut labels: Vec<&str> = vec!["A"; 90];
992        labels.extend(vec!["B"; 10]);
993        let dataset = make_string_dataset(labels);
994
995        let detector = ImbalanceDetector::new("label");
996        let report = detector.analyze(&dataset).expect("analyze");
997
998        assert!(report.is_imbalanced());
999        assert_eq!(report.distribution.majority_class, Some("A".to_string()));
1000        assert_eq!(report.distribution.minority_class, Some("B".to_string()));
1001    }
1002
1003    #[test]
1004    fn test_detector_analyze_int_labels() {
1005        let labels: Vec<i32> = (0..100).map(|i| if i < 80 { 0 } else { 1 }).collect();
1006        let dataset = make_int_dataset(labels);
1007
1008        let detector = ImbalanceDetector::new("label");
1009        let report = detector.analyze(&dataset).expect("analyze");
1010
1011        assert!(report.is_imbalanced());
1012        assert_eq!(report.distribution.get_count("0"), 80);
1013        assert_eq!(report.distribution.get_count("1"), 20);
1014    }
1015
1016    #[test]
1017    fn test_detector_missing_column() {
1018        let dataset = make_string_dataset(vec!["A", "B", "A"]);
1019
1020        let detector = ImbalanceDetector::new("nonexistent");
1021        let result = detector.analyze(&dataset);
1022
1023        assert!(result.is_err());
1024    }
1025
1026    #[test]
1027    fn test_detector_multiclass() {
1028        let mut labels = vec!["A"; 50];
1029        labels.extend(vec!["B"; 30]);
1030        labels.extend(vec!["C"; 20]);
1031        let dataset = make_string_dataset(labels);
1032
1033        let detector = ImbalanceDetector::new("label");
1034        let report = detector.analyze(&dataset).expect("analyze");
1035
1036        assert_eq!(report.distribution.num_classes, 3);
1037        assert_eq!(report.distribution.majority_class, Some("A".to_string()));
1038        assert_eq!(report.distribution.minority_class, Some("C".to_string()));
1039    }
1040
1041    #[cfg(feature = "shuffle")]
1042    #[test]
1043    fn test_resample_oversample() {
1044        // Create imbalanced dataset: A=10, B=2
1045        let schema = Arc::new(Schema::new(vec![
1046            Field::new("text", DataType::Utf8, false),
1047            Field::new("label", DataType::Int32, false),
1048        ]));
1049        let mut texts = Vec::new();
1050        let mut labels = Vec::new();
1051        for i in 0..10 {
1052            texts.push(format!("text_a_{i}"));
1053            labels.push(0i32);
1054        }
1055        for i in 0..2 {
1056            texts.push(format!("text_b_{i}"));
1057            labels.push(1i32);
1058        }
1059        let batch = RecordBatch::try_new(
1060            schema,
1061            vec![
1062                Arc::new(StringArray::from(texts)),
1063                Arc::new(Int32Array::from(labels)),
1064            ],
1065        )
1066        .unwrap();
1067        let ds = ArrowDataset::from_batch(batch).unwrap();
1068
1069        let result = resample(&ds, "label", ResampleStrategy::Oversample, 42).unwrap();
1070        // Both classes should have 10 samples each = 20 total
1071        assert_eq!(result.len(), 20);
1072    }
1073
1074    #[cfg(feature = "shuffle")]
1075    #[test]
1076    fn test_resample_undersample() {
1077        let schema = Arc::new(Schema::new(vec![
1078            Field::new("text", DataType::Utf8, false),
1079            Field::new("label", DataType::Int32, false),
1080        ]));
1081        let mut texts = Vec::new();
1082        let mut labels = Vec::new();
1083        for i in 0..10 {
1084            texts.push(format!("text_a_{i}"));
1085            labels.push(0i32);
1086        }
1087        for i in 0..2 {
1088            texts.push(format!("text_b_{i}"));
1089            labels.push(1i32);
1090        }
1091        let batch = RecordBatch::try_new(
1092            schema,
1093            vec![
1094                Arc::new(StringArray::from(texts)),
1095                Arc::new(Int32Array::from(labels)),
1096            ],
1097        )
1098        .unwrap();
1099        let ds = ArrowDataset::from_batch(batch).unwrap();
1100
1101        let result = resample(&ds, "label", ResampleStrategy::Undersample, 42).unwrap();
1102        // Both classes should have 2 samples each = 4 total
1103        assert_eq!(result.len(), 4);
1104    }
1105
1106    #[test]
1107    fn test_sqrt_inverse_weights() {
1108        // 5-class SSC-like distribution: [17252, 2402, 2858, 2875, 3920]
1109        let counts = vec![17252, 2402, 2858, 2875, 3920];
1110        let weights = sqrt_inverse_weights(&counts);
1111        assert_eq!(weights.len(), 5);
1112        // Safe class (majority) should have lowest weight
1113        assert!(weights[0] < weights[1]);
1114        // Weights should sum to approximately K=5
1115        let sum: f32 = weights.iter().sum();
1116        assert!((sum - 5.0).abs() < 0.01);
1117    }
1118}