imbalanced_sampling/
adasyn.rs1use imbalanced_core::traits::*;
3use ndarray::{Array1, Array2, ArrayView1, ArrayView2, s};
4use rand::prelude::*;
5use std::collections::HashMap;
6
/// ADASYN (adaptive synthetic) oversampling strategy.
///
/// Holds the two tuning knobs the resampling code in this file reads from
/// `self`: the neighbourhood size and the oversampling multiplier.
#[derive(Debug, Clone, PartialEq)]
pub struct AdasynStrategy {
    /// Number of nearest neighbours inspected around each minority sample,
    /// both when weighting samples and when picking an interpolation partner.
    k_neighbors: usize,
    /// Multiplier applied to the gap between the largest class and a minority
    /// class to obtain the number of synthetic samples to generate.
    beta: f64,
}
17
/// Per-call configuration passed to the ADASYN resampler.
///
/// NOTE(review): the `resample` implementation in this file only reads
/// `random_state`; `k_neighbors` and `beta` are taken from the strategy value
/// itself. Confirm whether the config fields are meant to override them.
#[derive(Debug, Clone, PartialEq)]
pub struct AdasynConfig {
    /// Neighbourhood size (defaults to 5).
    pub k_neighbors: usize,
    /// Oversampling multiplier (defaults to 1.0).
    pub beta: f64,
    /// Seed for the random number generator; `None` seeds from OS entropy,
    /// so results differ between runs.
    pub random_state: Option<u64>,
}
28
29impl Default for AdasynConfig {
30 fn default() -> Self {
31 Self {
32 k_neighbors: 5,
33 beta: 1.0,
34 random_state: None,
35 }
36 }
37}
38
39impl AdasynStrategy {
40 pub fn new(k_neighbors: usize, beta: f64) -> Self {
42 Self { k_neighbors, beta }
43 }
44
45 pub fn default() -> Self {
47 Self::new(5, 1.0)
48 }
49
50 fn calculate_density_distribution(
52 &self,
53 x: ArrayView2<f64>,
54 y: ArrayView1<i32>,
55 minority_class: i32,
56 _majority_count: usize,
57 ) -> Result<Vec<f64>, ResamplingError> {
58 let minority_indices: Vec<usize> = y.iter()
59 .enumerate()
60 .filter(|(_, &label)| label == minority_class)
61 .map(|(idx, _)| idx)
62 .collect();
63
64 if minority_indices.is_empty() {
65 return Err(ResamplingError::InsufficientSamples);
66 }
67
68 let mut density_ratios = Vec::with_capacity(minority_indices.len());
69
70 for &minority_idx in &minority_indices {
71 let sample = x.row(minority_idx);
72
73 let mut distances: Vec<(usize, f64)> = (0..x.nrows())
75 .map(|idx| {
76 let neighbor = x.row(idx);
77 let dist = sample.iter()
78 .zip(neighbor.iter())
79 .map(|(a, b)| (a - b).powi(2))
80 .sum::<f64>()
81 .sqrt();
82 (idx, dist)
83 })
84 .collect();
85
86 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
88 let k_plus_1 = std::cmp::min(self.k_neighbors + 1, distances.len());
89
90 let majority_neighbors = distances[1..k_plus_1].iter()
92 .filter(|(idx, _)| y[*idx] != minority_class)
93 .count();
94
95 let density_ratio = majority_neighbors as f64 / self.k_neighbors as f64;
97 density_ratios.push(density_ratio);
98 }
99
100 let sum_ratios: f64 = density_ratios.iter().sum();
102 if sum_ratios > 0.0 {
103 for ratio in &mut density_ratios {
104 *ratio /= sum_ratios;
105 }
106 }
107
108 Ok(density_ratios)
109 }
110}
111
112impl ResamplingStrategy for AdasynStrategy {
113 type Input = ();
114 type Output = (Array2<f64>, Array1<i32>);
115 type Config = AdasynConfig;
116
117 fn resample(
118 &self,
119 x: ArrayView2<f64>,
120 y: ArrayView1<i32>,
121 config: &Self::Config,
122 ) -> Result<(Array2<f64>, Array1<i32>), ResamplingError> {
123 if x.nrows() != y.len() {
124 return Err(ResamplingError::InvalidInput(
125 "Feature matrix and target array must have same number of samples".to_string()
126 ));
127 }
128
129 if x.nrows() < self.k_neighbors {
130 return Err(ResamplingError::InsufficientSamples);
131 }
132
133 let mut class_counts = HashMap::new();
135 for &label in y.iter() {
136 *class_counts.entry(label).or_insert(0) += 1;
137 }
138
139 if class_counts.len() < 2 {
140 return Err(ResamplingError::InvalidInput(
141 "Need at least 2 classes for resampling".to_string()
142 ));
143 }
144
145 let max_count = *class_counts.values().max().unwrap();
147 let minority_classes: Vec<_> = class_counts.iter()
148 .filter(|(_, &count)| count < max_count)
149 .map(|(&class, &count)| (class, count))
150 .collect();
151
152 if minority_classes.is_empty() {
153 return Ok((x.to_owned(), y.to_owned()));
155 }
156
157 let mut synthetic_features = Vec::new();
158 let mut synthetic_labels = Vec::new();
159
160 let mut rng = if let Some(seed) = config.random_state {
161 StdRng::seed_from_u64(seed)
162 } else {
163 StdRng::from_entropy()
164 };
165
166 for (minority_class, minority_count) in minority_classes {
168 let desired_samples = ((max_count - minority_count) as f64 * self.beta) as usize;
170
171 if desired_samples == 0 {
172 continue;
173 }
174
175 let minority_indices: Vec<usize> = y.iter()
177 .enumerate()
178 .filter(|(_, &label)| label == minority_class)
179 .map(|(idx, _)| idx)
180 .collect();
181
182 let density_ratios = self.calculate_density_distribution(
184 x, y, minority_class, max_count
185 )?;
186
187 for _ in 0..desired_samples {
189 let cumulative_prob = rng.gen::<f64>();
191 let mut cumulative_sum = 0.0;
192 let mut selected_idx = 0;
193
194 for (i, &ratio) in density_ratios.iter().enumerate() {
195 cumulative_sum += ratio;
196 if cumulative_prob <= cumulative_sum {
197 selected_idx = i;
198 break;
199 }
200 }
201
202 let sample_idx = minority_indices[selected_idx];
203 let sample = x.row(sample_idx);
204
205 let mut distances: Vec<(usize, f64)> = minority_indices.iter()
207 .map(|&idx| {
208 let neighbor = x.row(idx);
209 let dist = sample.iter()
210 .zip(neighbor.iter())
211 .map(|(a, b)| (a - b).powi(2))
212 .sum::<f64>()
213 .sqrt();
214 (idx, dist)
215 })
216 .collect();
217
218 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
220 let k = std::cmp::min(self.k_neighbors, distances.len() - 1);
221
222 if k == 0 {
223 continue;
224 }
225
226 let neighbor_idx = distances[1 + rng.gen_range(0..k)].0;
228 let neighbor = x.row(neighbor_idx);
229
230 let alpha = rng.gen::<f64>();
232 let synthetic_sample: Vec<f64> = sample.iter()
233 .zip(neighbor.iter())
234 .map(|(s, n)| s + alpha * (n - s))
235 .collect();
236
237 synthetic_features.push(synthetic_sample);
238 synthetic_labels.push(minority_class);
239 }
240 }
241
242 let n_original = x.nrows();
244 let n_synthetic = synthetic_features.len();
245 let n_total = n_original + n_synthetic;
246 let n_features = x.ncols();
247
248 let mut combined_x = Array2::zeros((n_total, n_features));
249 let mut combined_y = Array1::zeros(n_total);
250
251 combined_x.slice_mut(s![0..n_original, ..]).assign(&x);
253 combined_y.slice_mut(s![0..n_original]).assign(&y);
254
255 for (i, (features, label)) in synthetic_features.iter().zip(synthetic_labels.iter()).enumerate() {
257 let idx = n_original + i;
258 for (j, &feature) in features.iter().enumerate() {
259 combined_x[[idx, j]] = feature;
260 }
261 combined_y[idx] = *label;
262 }
263
264 Ok((combined_x, combined_y))
265 }
266
267 fn performance_hints(&self) -> PerformanceHints {
268 PerformanceHints::new()
269 .with_hint(PerformanceHint::Parallel)
270 .with_hint(PerformanceHint::CacheFriendly)
271 }
272}