imbalanced-sampling 0.1.0

Resampling algorithms for imbalanced datasets in Rust - SMOTE, ADASYN, RandomUnderSampler
Documentation
// imbalanced-sampling/src/smote.rs
use imbalanced_core::traits::*;
use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use rand::prelude::*;
use std::collections::HashMap;

/// Oversampler implementing SMOTE (Synthetic Minority Over-sampling Technique).
///
/// The only state it carries is the neighbourhood size used when interpolating new
/// minority-class rows.
#[derive(Clone, Debug)]
pub struct SmoteStrategy {
    /// Number of nearest same-class neighbours a synthetic row may be interpolated towards.
    k_neighbors: usize,
}

/// Run-time options handed to SMOTE resampling.
///
/// `SmoteConfig::default()` gives `k_neighbors = 5` and no fixed seed.
#[derive(Clone, Debug)]
pub struct SmoteConfig {
    /// Number of nearest neighbours to use.
    ///
    /// NOTE(review): `SmoteStrategy::resample` currently takes its neighbour count from the
    /// strategy itself rather than from this field — confirm the intended precedence.
    pub k_neighbors: usize,
    /// Seed for the random number generator; `None` seeds from entropy (non-reproducible).
    pub random_state: Option<u64>,
}

impl Default for SmoteConfig {
    fn default() -> Self {
        let random_state = None;
        let k_neighbors = 5;
        SmoteConfig {
            k_neighbors,
            random_state,
        }
    }
}

impl SmoteStrategy {
    /// Create a new SMOTE strategy with default k=5 neighbors
    pub fn new(k_neighbors: usize) -> Self {
        Self { k_neighbors }
    }
    
    /// Create with default configuration
    pub fn default() -> Self {
        Self::new(5)
    }
}

impl ResamplingStrategy for SmoteStrategy {
    type Input = ();
    type Output = (Array2<f64>, Array1<i32>);
    type Config = SmoteConfig;
    
    fn resample(
        &self,
        x: ArrayView2<f64>,
        y: ArrayView1<i32>,
        config: &Self::Config,
    ) -> Result<(Array2<f64>, Array1<i32>), ResamplingError> {
        if x.nrows() != y.len() {
            return Err(ResamplingError::InvalidInput(
                "Feature matrix and target array must have same number of samples".to_string()
            ));
        }
        
        if x.nrows() < self.k_neighbors {
            return Err(ResamplingError::InsufficientSamples);
        }
        
        // Count class frequencies
        let mut class_counts = HashMap::new();
        for &label in y.iter() {
            *class_counts.entry(label).or_insert(0) += 1;
        }
        
        if class_counts.len() < 2 {
            return Err(ResamplingError::InvalidInput(
                "Need at least 2 classes for resampling".to_string()
            ));
        }
        
        // Find majority class count
        let max_count = *class_counts.values().max().unwrap();
        
        // Generate synthetic samples for minority classes
        let mut synthetic_features = Vec::new();
        let mut synthetic_labels = Vec::new();
        
        let mut rng = if let Some(seed) = config.random_state {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::from_entropy()
        };
        
        for (&class_label, &count) in &class_counts {
            if count < max_count {
                let n_synthetic = max_count - count;
                
                // Find indices of this minority class
                let minority_indices: Vec<usize> = y.iter()
                    .enumerate()
                    .filter(|(_, &label)| label == class_label)
                    .map(|(idx, _)| idx)
                    .collect();
                
                // Generate synthetic samples
                for _ in 0..n_synthetic {
                    // Randomly select a minority sample
                    let sample_idx = minority_indices[rng.gen_range(0..minority_indices.len())];
                    let sample = x.row(sample_idx);
                    
                    // Find k nearest neighbors from the same class
                    let mut distances: Vec<(usize, f64)> = minority_indices.iter()
                        .map(|&idx| {
                            let neighbor = x.row(idx);
                            let dist = sample.iter()
                                .zip(neighbor.iter())
                                .map(|(a, b)| (a - b).powi(2))
                                .sum::<f64>()
                                .sqrt();
                            (idx, dist)
                        })
                        .collect();
                    
                    // Sort by distance and take k nearest (excluding self)
                    distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
                    let k = std::cmp::min(self.k_neighbors, distances.len() - 1);
                    
                    if k == 0 {
                        continue; // Skip if no neighbors available
                    }
                    
                    // Select random neighbor from k nearest
                    let neighbor_idx = distances[1 + rng.gen_range(0..k)].0; // Skip self at index 0
                    let neighbor = x.row(neighbor_idx);
                    
                    // Generate synthetic sample between sample and neighbor
                    let alpha = rng.gen::<f64>(); // Random value between 0 and 1
                    let synthetic_sample: Vec<f64> = sample.iter()
                        .zip(neighbor.iter())
                        .map(|(s, n)| s + alpha * (n - s))
                        .collect();
                    
                    synthetic_features.push(synthetic_sample);
                    synthetic_labels.push(class_label);
                }
            }
        }
        
        // Combine original and synthetic data
        let n_original = x.nrows();
        let n_synthetic = synthetic_features.len();
        let n_total = n_original + n_synthetic;
        let n_features = x.ncols();
        
        let mut combined_x = Array2::zeros((n_total, n_features));
        let mut combined_y = Array1::zeros(n_total);
        
        // Copy original data
        combined_x.slice_mut(s![0..n_original, ..]).assign(&x);
        combined_y.slice_mut(s![0..n_original]).assign(&y);
        
        // Add synthetic data
        for (i, (features, label)) in synthetic_features.iter().zip(synthetic_labels.iter()).enumerate() {
            let idx = n_original + i;
            for (j, &feature) in features.iter().enumerate() {
                combined_x[[idx, j]] = feature;
            }
            combined_y[idx] = *label;
        }
        
        Ok((combined_x, combined_y))
    }
    
    fn performance_hints(&self) -> PerformanceHints {
        PerformanceHints::new()
            .with_hint(PerformanceHint::Parallel)
            .with_hint(PerformanceHint::CacheFriendly)
    }
}

// Need to import slice syntax
use ndarray::s;