imbalanced-sampling 0.1.0

Resampling algorithms for imbalanced datasets in Rust - SMOTE, ADASYN, RandomUnderSampler
Documentation
// imbalanced-sampling/src/adasyn.rs
use imbalanced_core::traits::*;
use ndarray::{Array1, Array2, ArrayView1, ArrayView2, s};
use rand::prelude::*;
use std::collections::HashMap;

/// ADASYN (Adaptive Synthetic Sampling) implementation
///
/// ADASYN improves upon SMOTE by focusing synthetic sample generation
/// on minority class examples that are harder to learn, determined by
/// their k-nearest neighbor density distribution.
///
/// The strategy itself only stores the two tuning parameters below; the
/// random seed is supplied separately through [`AdasynConfig`] when
/// `resample` is called.
#[derive(Debug, Clone)]
pub struct AdasynStrategy {
    /// Number of nearest neighbors inspected per minority sample, both when
    /// estimating the share of other-class neighbors and when picking the
    /// same-class neighbor to interpolate towards.
    k_neighbors: usize,
    /// Balance level after generation (intended range 0.0 to 1.0; the
    /// constructor does not validate it). `resample` multiplies the gap
    /// between the largest class count and a minority class count by this
    /// factor to obtain the number of synthetic samples for that class.
    beta: f64,
}

/// Configuration for ADASYN
///
/// Passed to `resample` as the per-call configuration.
///
/// NOTE(review): the `resample` implementation in this file reads only
/// `random_state` from this struct; the neighbor count and balance level it
/// uses come from the `AdasynStrategy` instance. Confirm whether
/// `k_neighbors` and `beta` here are meant to override those values.
#[derive(Debug, Clone)]
pub struct AdasynConfig {
    /// Number of nearest neighbors to use
    pub k_neighbors: usize,
    /// Balance level (0.0 = no generation, 1.0 = perfect balance)
    pub beta: f64,
    /// Random seed. `Some(seed)` seeds the generator deterministically;
    /// `None` seeds it from system entropy.
    pub random_state: Option<u64>,
}

impl Default for AdasynConfig {
    fn default() -> Self {
        Self {
            k_neighbors: 5,
            beta: 1.0,
            random_state: None,
        }
    }
}

impl AdasynStrategy {
    /// Create a new ADASYN strategy.
    ///
    /// * `k_neighbors` - number of nearest neighbors inspected per minority
    ///   sample.
    /// * `beta` - balance level; scales how many synthetic samples are
    ///   generated for each minority class.
    ///
    /// Neither value is validated here.
    pub fn new(k_neighbors: usize, beta: f64) -> Self {
        Self { k_neighbors, beta }
    }

    /// Create with default configuration (`k_neighbors = 5`, `beta = 1.0`).
    // Kept as an inherent method (rather than a `Default` impl) so existing
    // `AdasynStrategy::default()` call sites keep resolving to this function.
    #[allow(clippy::should_implement_trait)]
    pub fn default() -> Self {
        Self::new(5, 1.0)
    }

    /// Calculate the sampling distribution over the samples of
    /// `minority_class`.
    ///
    /// For every row of `x` labelled `minority_class`, the `k_neighbors`
    /// nearest *other* rows (Euclidean distance, any class) are located and
    /// the number of them carrying a different label is counted. The counts
    /// are then normalized to sum to 1, so samples surrounded by more
    /// other-class neighbors receive a larger share.
    ///
    /// The returned vector has one entry per minority sample, in the order
    /// those samples appear in `y`.
    ///
    /// If no minority sample has any other-class neighbor (or `k_neighbors`
    /// is 0) the counts cannot be normalized; a uniform distribution is
    /// returned instead so that callers sampling from it do not collapse onto
    /// a single sample.
    ///
    /// `_majority_count` is currently unused.
    ///
    /// # Errors
    ///
    /// Returns [`ResamplingError::InsufficientSamples`] when `y` contains no
    /// sample of `minority_class`.
    fn calculate_density_distribution(
        &self,
        x: ArrayView2<f64>,
        y: ArrayView1<i32>,
        minority_class: i32,
        _majority_count: usize,
    ) -> Result<Vec<f64>, ResamplingError> {
        let minority_indices: Vec<usize> = y
            .iter()
            .enumerate()
            .filter(|&(_, &label)| label == minority_class)
            .map(|(idx, _)| idx)
            .collect();

        if minority_indices.is_empty() {
            return Err(ResamplingError::InsufficientSamples);
        }

        let mut density_ratios = Vec::with_capacity(minority_indices.len());
        // Scratch buffer reused for every minority sample.
        let mut distances: Vec<(usize, f64)> = Vec::with_capacity(x.nrows().saturating_sub(1));

        for &minority_idx in &minority_indices {
            let sample = x.row(minority_idx);

            // Distances to every other row. The sample itself is excluded by
            // index rather than by assuming it sorts first: an exact duplicate
            // row is also at distance 0 and could otherwise take its place.
            distances.clear();
            distances.extend(
                (0..x.nrows())
                    .filter(|&idx| idx != minority_idx)
                    .map(|idx| {
                        let neighbor = x.row(idx);
                        // Squared distance is enough for ranking; the square
                        // root is monotonic and would not change the order.
                        let sq_dist = sample
                            .iter()
                            .zip(neighbor.iter())
                            .map(|(a, b)| (a - b).powi(2))
                            .sum::<f64>();
                        (idx, sq_dist)
                    }),
            );

            // `total_cmp` gives a total order, so NaN features rank last
            // instead of panicking. The stable sort keeps ties in row order.
            distances.sort_by(|a, b| a.1.total_cmp(&b.1));

            // Count other-class samples among the k nearest neighbors.
            let majority_neighbors = distances
                .iter()
                .take(self.k_neighbors)
                .filter(|(idx, _)| y[*idx] != minority_class)
                .count();

            // The conventional 1/k factor is the same for every sample and
            // cancels in the normalization below, so the raw count is stored.
            density_ratios.push(majority_neighbors as f64);
        }

        // Normalize into a probability distribution.
        let total: f64 = density_ratios.iter().sum();
        if total > 0.0 {
            for ratio in &mut density_ratios {
                *ratio /= total;
            }
        } else {
            // No minority sample borders another class: fall back to uniform.
            let uniform = 1.0 / density_ratios.len() as f64;
            density_ratios.fill(uniform);
        }

        Ok(density_ratios)
    }
}

impl ResamplingStrategy for AdasynStrategy {
    type Input = ();
    type Output = (Array2<f64>, Array1<i32>);
    type Config = AdasynConfig;

    /// Oversample every minority class of `(x, y)` with ADASYN.
    ///
    /// A class is treated as a minority class when it has fewer samples than
    /// the most frequent class. For each such class,
    /// `(max_count - class_count) * beta` synthetic samples (truncated to an
    /// integer) are generated: a seed sample is drawn according to the
    /// density distribution from `calculate_density_distribution`, one of its
    /// `k_neighbors` nearest same-class neighbors is picked uniformly, and a
    /// new point is interpolated at a random position on the segment between
    /// them. Classes with a single sample have no same-class neighbor and are
    /// left untouched.
    ///
    /// The result contains the original rows first, in their original order,
    /// followed by the synthetic rows grouped by class in ascending label
    /// order. With `config.random_state = Some(seed)` the output is
    /// reproducible.
    ///
    /// Only `config.random_state` is read; the neighbor count and balance
    /// level come from the strategy instance.
    ///
    /// # Errors
    ///
    /// * [`ResamplingError::InvalidInput`] if `x` and `y` differ in sample
    ///   count, or if `y` contains fewer than two classes.
    /// * [`ResamplingError::InsufficientSamples`] if there are fewer rows
    ///   than `k_neighbors`.
    fn resample(
        &self,
        x: ArrayView2<f64>,
        y: ArrayView1<i32>,
        config: &Self::Config,
    ) -> Result<(Array2<f64>, Array1<i32>), ResamplingError> {
        if x.nrows() != y.len() {
            return Err(ResamplingError::InvalidInput(
                "Feature matrix and target array must have same number of samples".to_string()
            ));
        }

        if x.nrows() < self.k_neighbors {
            return Err(ResamplingError::InsufficientSamples);
        }

        // Count class frequencies
        let mut class_counts: HashMap<i32, usize> = HashMap::new();
        for &label in y.iter() {
            *class_counts.entry(label).or_insert(0) += 1;
        }

        if class_counts.len() < 2 {
            return Err(ResamplingError::InvalidInput(
                "Need at least 2 classes for resampling".to_string()
            ));
        }

        // Find majority and minority classes
        let max_count = class_counts
            .values()
            .copied()
            .max()
            .expect("class_counts holds at least two classes");
        let mut minority_classes: Vec<(i32, usize)> = class_counts
            .iter()
            .filter(|&(_, &count)| count < max_count)
            .map(|(&class, &count)| (class, count))
            .collect();

        if minority_classes.is_empty() {
            // Dataset is already balanced
            return Ok((x.to_owned(), y.to_owned()));
        }

        // `HashMap` iteration order is unspecified and differs between runs.
        // Process classes in label order so that a fixed `random_state`
        // yields the same output every time.
        minority_classes.sort_unstable_by_key(|&(class, _)| class);

        let mut synthetic_features: Vec<Vec<f64>> = Vec::new();
        let mut synthetic_labels: Vec<i32> = Vec::new();

        let mut rng = match config.random_state {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::from_entropy(),
        };

        // Process each minority class
        for (minority_class, minority_count) in minority_classes {
            // Calculate number of synthetic samples to generate
            let desired_samples = ((max_count - minority_count) as f64 * self.beta) as usize;
            if desired_samples == 0 {
                continue;
            }

            // Get minority class indices
            let minority_indices: Vec<usize> = y
                .iter()
                .enumerate()
                .filter(|&(_, &label)| label == minority_class)
                .map(|(idx, _)| idx)
                .collect();

            // Number of same-class neighbors actually available per sample.
            let k = std::cmp::min(self.k_neighbors, minority_indices.len().saturating_sub(1));
            if k == 0 {
                // Nothing to interpolate towards.
                continue;
            }

            // Calculate density distribution
            let density_ratios = self.calculate_density_distribution(
                x, y, minority_class, max_count
            )?;

            // Cumulative weights, built once per class. Non-finite or
            // non-positive weights contribute nothing.
            let mut cumulative: Vec<f64> = Vec::with_capacity(density_ratios.len());
            let mut total_weight = 0.0_f64;
            for &ratio in &density_ratios {
                if ratio.is_finite() && ratio > 0.0 {
                    total_weight += ratio;
                }
                cumulative.push(total_weight);
            }
            let last_idx = minority_indices.len() - 1;

            // k nearest same-class neighbors of each seed sample, computed on
            // first use and then reused for later draws of the same seed.
            let mut neighbor_cache: Vec<Option<Vec<usize>>> = vec![None; minority_indices.len()];

            // Generate synthetic samples based on density distribution
            for _ in 0..desired_samples {
                // Select minority sample based on density distribution
                let selected_idx = if total_weight > 0.0 {
                    let target = rng.gen::<f64>() * total_weight;
                    // First entry whose cumulative weight exceeds the target.
                    // Clamped because rounding can push the target up to
                    // `total_weight` itself.
                    cumulative.partition_point(|&c| c <= target).min(last_idx)
                } else {
                    // Degenerate weights (all zero / NaN): pick uniformly
                    // instead of always landing on the first sample.
                    rng.gen_range(0..minority_indices.len())
                };

                let sample_idx = minority_indices[selected_idx];
                let sample = x.row(sample_idx);

                // Find k nearest neighbors from same class
                let neighbors = neighbor_cache[selected_idx].get_or_insert_with(|| {
                    // The seed is excluded by index, not by sort position, so
                    // an exact duplicate row cannot make it its own neighbor.
                    let mut distances: Vec<(usize, f64)> = minority_indices
                        .iter()
                        .filter(|&&idx| idx != sample_idx)
                        .map(|&idx| {
                            let neighbor = x.row(idx);
                            // Squared distance preserves the ranking.
                            let sq_dist = sample
                                .iter()
                                .zip(neighbor.iter())
                                .map(|(a, b)| (a - b).powi(2))
                                .sum::<f64>();
                            (idx, sq_dist)
                        })
                        .collect();

                    // Total order: NaN distances rank last instead of panicking.
                    distances.sort_by(|a, b| a.1.total_cmp(&b.1));
                    distances.truncate(k);
                    distances.into_iter().map(|(idx, _)| idx).collect()
                });

                // Select random neighbor from k nearest
                let neighbor_idx = neighbors[rng.gen_range(0..neighbors.len())];
                let neighbor = x.row(neighbor_idx);

                // Generate synthetic sample on the segment sample -> neighbor
                let alpha = rng.gen::<f64>();
                let synthetic_sample: Vec<f64> = sample
                    .iter()
                    .zip(neighbor.iter())
                    .map(|(s, n)| s + alpha * (n - s))
                    .collect();

                synthetic_features.push(synthetic_sample);
                synthetic_labels.push(minority_class);
            }
        }

        // Combine original and synthetic data
        let n_original = x.nrows();
        let n_synthetic = synthetic_features.len();
        let n_total = n_original + n_synthetic;
        let n_features = x.ncols();

        let mut combined_x = Array2::zeros((n_total, n_features));
        let mut combined_y = Array1::zeros(n_total);

        // Copy original data
        combined_x.slice_mut(s![0..n_original, ..]).assign(&x);
        combined_y.slice_mut(s![0..n_original]).assign(&y);

        // Add synthetic data
        for (i, (features, &label)) in synthetic_features
            .iter()
            .zip(synthetic_labels.iter())
            .enumerate()
        {
            let idx = n_original + i;
            for (j, &feature) in features.iter().enumerate() {
                combined_x[[idx, j]] = feature;
            }
            combined_y[idx] = label;
        }

        Ok((combined_x, combined_y))
    }

    /// Advertises the `Parallel` and `CacheFriendly` performance hints.
    fn performance_hints(&self) -> PerformanceHints {
        PerformanceHints::new()
            .with_hint(PerformanceHint::Parallel)
            .with_hint(PerformanceHint::CacheFriendly)
    }
}