imbalanced_sampling/
smote.rs

1// imbalanced-sampling/src/smote.rs
2use imbalanced_core::traits::*;
3use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
4use rand::prelude::*;
5use std::collections::HashMap;
6
/// Oversampling strategy implementing SMOTE (Synthetic Minority
/// Over-sampling Technique).
///
/// The only state carried by the strategy is the neighbourhood size used when
/// choosing an interpolation partner for a minority sample.
#[derive(Clone, Debug)]
pub struct SmoteStrategy {
    /// Upper bound on how many nearest same-class neighbours are candidates
    /// for interpolation.
    k_neighbors: usize,
}
12
/// Per-call settings for a SMOTE resampling run.
#[derive(Clone, Debug)]
pub struct SmoteConfig {
    /// Number of nearest neighbours to use.
    ///
    /// NOTE(review): `SmoteStrategy::resample` in this file takes the
    /// neighbour count from the strategy itself and does not read this field;
    /// confirm which of the two is meant to be authoritative.
    pub k_neighbors: usize,
    /// Seed for the random number generator. `None` seeds the generator from
    /// system entropy instead.
    pub random_state: Option<u64>,
}
21
22impl Default for SmoteConfig {
23    fn default() -> Self {
24        Self {
25            k_neighbors: 5,
26            random_state: None,
27        }
28    }
29}
30
31impl SmoteStrategy {
32    /// Create a new SMOTE strategy with default k=5 neighbors
33    pub fn new(k_neighbors: usize) -> Self {
34        Self { k_neighbors }
35    }
36    
37    /// Create with default configuration
38    pub fn default() -> Self {
39        Self::new(5)
40    }
41}
42
43impl ResamplingStrategy for SmoteStrategy {
44    type Input = ();
45    type Output = (Array2<f64>, Array1<i32>);
46    type Config = SmoteConfig;
47    
48    fn resample(
49        &self,
50        x: ArrayView2<f64>,
51        y: ArrayView1<i32>,
52        config: &Self::Config,
53    ) -> Result<(Array2<f64>, Array1<i32>), ResamplingError> {
54        if x.nrows() != y.len() {
55            return Err(ResamplingError::InvalidInput(
56                "Feature matrix and target array must have same number of samples".to_string()
57            ));
58        }
59        
60        if x.nrows() < self.k_neighbors {
61            return Err(ResamplingError::InsufficientSamples);
62        }
63        
64        // Count class frequencies
65        let mut class_counts = HashMap::new();
66        for &label in y.iter() {
67            *class_counts.entry(label).or_insert(0) += 1;
68        }
69        
70        if class_counts.len() < 2 {
71            return Err(ResamplingError::InvalidInput(
72                "Need at least 2 classes for resampling".to_string()
73            ));
74        }
75        
76        // Find majority class count
77        let max_count = *class_counts.values().max().unwrap();
78        
79        // Generate synthetic samples for minority classes
80        let mut synthetic_features = Vec::new();
81        let mut synthetic_labels = Vec::new();
82        
83        let mut rng = if let Some(seed) = config.random_state {
84            StdRng::seed_from_u64(seed)
85        } else {
86            StdRng::from_entropy()
87        };
88        
89        for (&class_label, &count) in &class_counts {
90            if count < max_count {
91                let n_synthetic = max_count - count;
92                
93                // Find indices of this minority class
94                let minority_indices: Vec<usize> = y.iter()
95                    .enumerate()
96                    .filter(|(_, &label)| label == class_label)
97                    .map(|(idx, _)| idx)
98                    .collect();
99                
100                // Generate synthetic samples
101                for _ in 0..n_synthetic {
102                    // Randomly select a minority sample
103                    let sample_idx = minority_indices[rng.gen_range(0..minority_indices.len())];
104                    let sample = x.row(sample_idx);
105                    
106                    // Find k nearest neighbors from the same class
107                    let mut distances: Vec<(usize, f64)> = minority_indices.iter()
108                        .map(|&idx| {
109                            let neighbor = x.row(idx);
110                            let dist = sample.iter()
111                                .zip(neighbor.iter())
112                                .map(|(a, b)| (a - b).powi(2))
113                                .sum::<f64>()
114                                .sqrt();
115                            (idx, dist)
116                        })
117                        .collect();
118                    
119                    // Sort by distance and take k nearest (excluding self)
120                    distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
121                    let k = std::cmp::min(self.k_neighbors, distances.len() - 1);
122                    
123                    if k == 0 {
124                        continue; // Skip if no neighbors available
125                    }
126                    
127                    // Select random neighbor from k nearest
128                    let neighbor_idx = distances[1 + rng.gen_range(0..k)].0; // Skip self at index 0
129                    let neighbor = x.row(neighbor_idx);
130                    
131                    // Generate synthetic sample between sample and neighbor
132                    let alpha = rng.gen::<f64>(); // Random value between 0 and 1
133                    let synthetic_sample: Vec<f64> = sample.iter()
134                        .zip(neighbor.iter())
135                        .map(|(s, n)| s + alpha * (n - s))
136                        .collect();
137                    
138                    synthetic_features.push(synthetic_sample);
139                    synthetic_labels.push(class_label);
140                }
141            }
142        }
143        
144        // Combine original and synthetic data
145        let n_original = x.nrows();
146        let n_synthetic = synthetic_features.len();
147        let n_total = n_original + n_synthetic;
148        let n_features = x.ncols();
149        
150        let mut combined_x = Array2::zeros((n_total, n_features));
151        let mut combined_y = Array1::zeros(n_total);
152        
153        // Copy original data
154        combined_x.slice_mut(s![0..n_original, ..]).assign(&x);
155        combined_y.slice_mut(s![0..n_original]).assign(&y);
156        
157        // Add synthetic data
158        for (i, (features, label)) in synthetic_features.iter().zip(synthetic_labels.iter()).enumerate() {
159            let idx = n_original + i;
160            for (j, &feature) in features.iter().enumerate() {
161                combined_x[[idx, j]] = feature;
162            }
163            combined_y[idx] = *label;
164        }
165        
166        Ok((combined_x, combined_y))
167    }
168    
169    fn performance_hints(&self) -> PerformanceHints {
170        PerformanceHints::new()
171            .with_hint(PerformanceHint::Parallel)
172            .with_hint(PerformanceHint::CacheFriendly)
173    }
174}
175
// `s!` is ndarray's slicing macro; `resample` uses it to copy the original rows into the combined arrays.
177use ndarray::s;