imbalanced-sampling 0.1.0

Resampling algorithms for imbalanced datasets in Rust - SMOTE, ADASYN, RandomUnderSampler
Documentation
// imbalanced-sampling/src/random_undersampler.rs
use imbalanced_core::traits::*;
use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use rand::prelude::*;
use std::collections::HashMap;

/// Random under-sampling strategy.
///
/// Shrinks over-represented classes by keeping only a randomly chosen
/// subset of their samples. How many samples each class keeps is decided
/// by the stored [`SamplingMode`].
#[derive(Debug, Clone)]
pub struct RandomUnderSampler {
    // Strategy used to derive the per-class target counts; set once by the
    // constructors (`new`, `with_ratio`, `with_targets`).
    sampling_strategy: SamplingMode,
}

/// Sampling strategy configuration.
///
/// Selects how the per-class target sample counts are derived from the
/// observed class frequencies.
#[derive(Debug, Clone)]
pub enum SamplingMode {
    /// Reduce every class to the size of the smallest (minority) class.
    Auto,
    /// Cap every class at `ratio * largest_class_count` (truncated to an
    /// integer); classes already at or below that cap keep all samples.
    /// The ratio is expected to lie in `0.0..=1.0`;
    /// `RandomUnderSampler::with_ratio` clamps its argument into that range.
    Ratio(f64),
    /// Explicit target count per class label. Classes without an entry keep
    /// all of their samples; a target larger than the number of available
    /// samples is reported as a configuration error.
    Targets(HashMap<i32, usize>),
}

/// Configuration for RandomUnderSampler.
///
/// Passed to `resample` on every call; see the individual fields for what
/// each one controls.
#[derive(Debug, Clone)]
pub struct RandomUnderSamplerConfig {
    /// Sampling strategy.
    ///
    /// NOTE(review): in the code visible in this file the sampler derives
    /// its targets from its own strategy (chosen at construction) and does
    /// not read this field — confirm whether it is meant to override it.
    pub sampling_strategy: SamplingMode,
    /// Random seed. `Some(seed)` seeds the generator from that value;
    /// `None` seeds it from entropy instead.
    pub random_state: Option<u64>,
    /// Whether to sample with replacement. When `true`, a class that is
    /// being reduced may contain the same original sample more than once.
    pub replacement: bool,
}

impl Default for RandomUnderSamplerConfig {
    fn default() -> Self {
        Self {
            sampling_strategy: SamplingMode::Auto,
            random_state: None,
            replacement: false,
        }
    }
}

impl RandomUnderSampler {
    /// Create a new RandomUnderSampler with auto balancing.
    ///
    /// Every class is reduced to the size of the smallest class.
    pub fn new() -> Self {
        Self {
            sampling_strategy: SamplingMode::Auto,
        }
    }

    /// Create with a specific ratio (0.0 to 1.0).
    ///
    /// Values outside the range are clamped into it. A NaN ratio passes
    /// through `f64::clamp` unchanged; it is stored as-is and rejected with
    /// a `ConfigError` when the target counts are calculated.
    pub fn with_ratio(ratio: f64) -> Self {
        Self {
            sampling_strategy: SamplingMode::Ratio(ratio.clamp(0.0, 1.0)),
        }
    }

    /// Create with specific target counts.
    ///
    /// `targets` maps a class label to the number of samples to keep for
    /// that class. Classes without an entry keep all of their samples.
    pub fn with_targets(targets: HashMap<i32, usize>) -> Self {
        Self {
            sampling_strategy: SamplingMode::Targets(targets),
        }
    }

    /// Calculate target sample counts for each class.
    ///
    /// `class_counts` maps each class label to the number of samples
    /// observed for it. The returned map has one entry per class in
    /// `class_counts`, and no target exceeds the observed count.
    ///
    /// # Errors
    ///
    /// * `InvalidInput` if `class_counts` is empty (`Auto` and `Ratio`).
    /// * `ConfigError` if the ratio is NaN or outside `0.0..=1.0`.
    /// * `ConfigError` if an explicit target exceeds the observed count.
    fn calculate_target_counts(
        &self,
        class_counts: &HashMap<i32, usize>,
        _config: &RandomUnderSamplerConfig,
    ) -> Result<HashMap<i32, usize>, ResamplingError> {
        // Reported instead of panicking when there is no class to take a
        // minimum/maximum over.
        let no_classes = || {
            ResamplingError::InvalidInput(
                "Cannot calculate target counts without any class".to_string(),
            )
        };

        match &self.sampling_strategy {
            SamplingMode::Auto => {
                // Balance to minority class size
                let min_count = class_counts
                    .values()
                    .copied()
                    .min()
                    .ok_or_else(no_classes)?;
                Ok(class_counts
                    .keys()
                    .map(|&class| (class, min_count))
                    .collect())
            }
            SamplingMode::Ratio(ratio) => {
                let ratio = *ratio;
                // The range check also catches NaN, which would otherwise
                // cast to a target of 0 and silently drop every sample.
                if !(0.0..=1.0).contains(&ratio) {
                    return Err(ResamplingError::ConfigError(format!(
                        "Sampling ratio must be between 0.0 and 1.0, got {}",
                        ratio
                    )));
                }

                // Balance to ratio of majority class
                let max_count = class_counts
                    .values()
                    .copied()
                    .max()
                    .ok_or_else(no_classes)?;
                // Truncation toward zero is intended: the cap is a whole
                // number of samples.
                let target_count = (max_count as f64 * ratio) as usize;
                Ok(class_counts
                    .iter()
                    .map(|(&class, &count)| (class, count.min(target_count)))
                    .collect())
            }
            SamplingMode::Targets(targets) => {
                // Use provided targets, but don't exceed original counts
                let mut result = HashMap::with_capacity(class_counts.len());
                for (&class, &original_count) in class_counts {
                    // Classes without an explicit target are left untouched.
                    let target_count = targets.get(&class).copied().unwrap_or(original_count);
                    if target_count > original_count {
                        return Err(ResamplingError::ConfigError(format!(
                            "Target count {} exceeds original count {} for class {}",
                            target_count, original_count, class
                        )));
                    }
                    result.insert(class, target_count);
                }
                Ok(result)
            }
        }
    }

    /// Sample indices for a given class.
    ///
    /// Returns `target_count` entries drawn from `class_indices`. If
    /// `target_count` is at least the number of available indices, all of
    /// them are returned unchanged, regardless of `replacement`.
    ///
    /// With `replacement` each entry is drawn independently, so duplicates
    /// are possible; otherwise the indices are shuffled and the first
    /// `target_count` are kept, so every entry is distinct.
    fn sample_indices(
        &self,
        class_indices: &[usize],
        target_count: usize,
        replacement: bool,
        rng: &mut StdRng,
    ) -> Vec<usize> {
        if target_count >= class_indices.len() {
            return class_indices.to_vec();
        }

        if replacement {
            // Sample with replacement
            (0..target_count)
                .map(|_| class_indices[rng.gen_range(0..class_indices.len())])
                .collect()
        } else {
            // Sample without replacement
            let mut indices = class_indices.to_vec();
            indices.shuffle(rng);
            indices.truncate(target_count);
            indices
        }
    }
}

impl Default for RandomUnderSampler {
    fn default() -> Self {
        Self::new()
    }
}

impl ResamplingStrategy for RandomUnderSampler {
    type Input = ();
    type Output = (Array2<f64>, Array1<i32>);
    type Config = RandomUnderSamplerConfig;

    /// Under-sample `x`/`y` so that each class keeps only its target count.
    ///
    /// `x` holds one sample per row and `y` the matching class label per
    /// sample. The returned pair contains the kept rows and labels, ordered
    /// by their original row index.
    ///
    /// For a given `config.random_state` seed the result is reproducible:
    /// classes are sampled in ascending label order, so the random number
    /// generator is always consumed in the same sequence.
    ///
    /// # Errors
    ///
    /// * `InvalidInput` if `x` and `y` disagree on the number of samples.
    /// * `InvalidInput` if `y` contains fewer than two distinct classes.
    /// * Any error produced while calculating the per-class target counts.
    fn resample(
        &self,
        x: ArrayView2<f64>,
        y: ArrayView1<i32>,
        config: &Self::Config,
    ) -> Result<(Array2<f64>, Array1<i32>), ResamplingError> {
        if x.nrows() != y.len() {
            return Err(ResamplingError::InvalidInput(
                "Feature matrix and target array must have same number of samples".to_string()
            ));
        }

        // Group sample indices by class in a single pass; the class
        // frequencies are simply the lengths of these groups.
        let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();
        for (idx, &label) in y.iter().enumerate() {
            class_indices.entry(label).or_default().push(idx);
        }

        if class_indices.len() < 2 {
            return Err(ResamplingError::InvalidInput(
                "Need at least 2 classes for resampling".to_string()
            ));
        }

        let class_counts: HashMap<i32, usize> = class_indices
            .iter()
            .map(|(&class, indices)| (class, indices.len()))
            .collect();

        // Calculate target counts
        let target_counts = self.calculate_target_counts(&class_counts, config)?;

        let mut rng = match config.random_state {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::from_entropy(),
        };

        // HashMap iteration order is arbitrary and can differ between runs.
        // All classes draw from the same generator, so they must be visited
        // in a fixed order for a seed to yield the same selection each time.
        let mut ordered_targets: Vec<(i32, usize)> = target_counts.into_iter().collect();
        ordered_targets.sort_unstable_by_key(|&(class, _)| class);

        // Sample indices for each class
        let expected_total: usize = ordered_targets.iter().map(|&(_, count)| count).sum();
        let mut selected_indices = Vec::with_capacity(expected_total);
        for (class, target_count) in ordered_targets {
            if let Some(indices) = class_indices.get(&class) {
                selected_indices.extend(self.sample_indices(
                    indices,
                    target_count,
                    config.replacement,
                    &mut rng,
                ));
            }
        }

        // Keep the surviving samples in their original relative order
        selected_indices.sort_unstable();

        let n_samples = selected_indices.len();
        let n_features = x.ncols();

        // Create resampled arrays
        let mut resampled_x = Array2::zeros((n_samples, n_features));
        let mut resampled_y = Array1::zeros(n_samples);

        for (new_idx, &original_idx) in selected_indices.iter().enumerate() {
            resampled_x.row_mut(new_idx).assign(&x.row(original_idx));
            resampled_y[new_idx] = y[original_idx];
        }

        Ok((resampled_x, resampled_y))
    }

    /// Performance hints advertised for this strategy.
    fn performance_hints(&self) -> PerformanceHints {
        PerformanceHints::new()
            .with_hint(PerformanceHint::CacheFriendly)
    }
}