imbalanced_sampling/
smote.rs1use imbalanced_core::traits::*;
3use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
4use rand::prelude::*;
5use std::collections::HashMap;
6
/// Oversampling strategy in the style of SMOTE: new samples for an
/// under-represented class are interpolated between an existing sample of
/// that class and one of its nearest same-class neighbours.
#[derive(Clone, Debug)]
pub struct SmoteStrategy {
    /// How many nearest same-class neighbours are candidates for interpolation.
    k_neighbors: usize,
}
12
/// Options passed to a SMOTE resampling call.
#[derive(Clone, Debug)]
pub struct SmoteConfig {
    /// Neighbour count carried by the configuration.
    ///
    /// NOTE(review): `resample` uses the strategy's own `k_neighbors` and does
    /// not read this field; confirm which value is meant to win.
    pub k_neighbors: usize,
    /// Seed for the random number generator. `Some(seed)` makes sampling
    /// reproducible; `None` seeds the generator from entropy.
    pub random_state: Option<u64>,
}
21
22impl Default for SmoteConfig {
23 fn default() -> Self {
24 Self {
25 k_neighbors: 5,
26 random_state: None,
27 }
28 }
29}
30
31impl SmoteStrategy {
32 pub fn new(k_neighbors: usize) -> Self {
34 Self { k_neighbors }
35 }
36
37 pub fn default() -> Self {
39 Self::new(5)
40 }
41}
42
43impl ResamplingStrategy for SmoteStrategy {
44 type Input = ();
45 type Output = (Array2<f64>, Array1<i32>);
46 type Config = SmoteConfig;
47
48 fn resample(
49 &self,
50 x: ArrayView2<f64>,
51 y: ArrayView1<i32>,
52 config: &Self::Config,
53 ) -> Result<(Array2<f64>, Array1<i32>), ResamplingError> {
54 if x.nrows() != y.len() {
55 return Err(ResamplingError::InvalidInput(
56 "Feature matrix and target array must have same number of samples".to_string()
57 ));
58 }
59
60 if x.nrows() < self.k_neighbors {
61 return Err(ResamplingError::InsufficientSamples);
62 }
63
64 let mut class_counts = HashMap::new();
66 for &label in y.iter() {
67 *class_counts.entry(label).or_insert(0) += 1;
68 }
69
70 if class_counts.len() < 2 {
71 return Err(ResamplingError::InvalidInput(
72 "Need at least 2 classes for resampling".to_string()
73 ));
74 }
75
76 let max_count = *class_counts.values().max().unwrap();
78
79 let mut synthetic_features = Vec::new();
81 let mut synthetic_labels = Vec::new();
82
83 let mut rng = if let Some(seed) = config.random_state {
84 StdRng::seed_from_u64(seed)
85 } else {
86 StdRng::from_entropy()
87 };
88
89 for (&class_label, &count) in &class_counts {
90 if count < max_count {
91 let n_synthetic = max_count - count;
92
93 let minority_indices: Vec<usize> = y.iter()
95 .enumerate()
96 .filter(|(_, &label)| label == class_label)
97 .map(|(idx, _)| idx)
98 .collect();
99
100 for _ in 0..n_synthetic {
102 let sample_idx = minority_indices[rng.gen_range(0..minority_indices.len())];
104 let sample = x.row(sample_idx);
105
106 let mut distances: Vec<(usize, f64)> = minority_indices.iter()
108 .map(|&idx| {
109 let neighbor = x.row(idx);
110 let dist = sample.iter()
111 .zip(neighbor.iter())
112 .map(|(a, b)| (a - b).powi(2))
113 .sum::<f64>()
114 .sqrt();
115 (idx, dist)
116 })
117 .collect();
118
119 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
121 let k = std::cmp::min(self.k_neighbors, distances.len() - 1);
122
123 if k == 0 {
124 continue; }
126
127 let neighbor_idx = distances[1 + rng.gen_range(0..k)].0; let neighbor = x.row(neighbor_idx);
130
131 let alpha = rng.gen::<f64>(); let synthetic_sample: Vec<f64> = sample.iter()
134 .zip(neighbor.iter())
135 .map(|(s, n)| s + alpha * (n - s))
136 .collect();
137
138 synthetic_features.push(synthetic_sample);
139 synthetic_labels.push(class_label);
140 }
141 }
142 }
143
144 let n_original = x.nrows();
146 let n_synthetic = synthetic_features.len();
147 let n_total = n_original + n_synthetic;
148 let n_features = x.ncols();
149
150 let mut combined_x = Array2::zeros((n_total, n_features));
151 let mut combined_y = Array1::zeros(n_total);
152
153 combined_x.slice_mut(s![0..n_original, ..]).assign(&x);
155 combined_y.slice_mut(s![0..n_original]).assign(&y);
156
157 for (i, (features, label)) in synthetic_features.iter().zip(synthetic_labels.iter()).enumerate() {
159 let idx = n_original + i;
160 for (j, &feature) in features.iter().enumerate() {
161 combined_x[[idx, j]] = feature;
162 }
163 combined_y[idx] = *label;
164 }
165
166 Ok((combined_x, combined_y))
167 }
168
169 fn performance_hints(&self) -> PerformanceHints {
170 PerformanceHints::new()
171 .with_hint(PerformanceHint::Parallel)
172 .with_hint(PerformanceHint::CacheFriendly)
173 }
174}
175
176use ndarray::s;