imbalanced_sampling/
adasyn.rs

1// imbalanced-sampling/src/adasyn.rs
2use imbalanced_core::traits::*;
3use ndarray::{Array1, Array2, ArrayView1, ArrayView2, s};
4use rand::prelude::*;
5use std::collections::HashMap;
6
/// Adaptive Synthetic Sampling (ADASYN) oversampler.
///
/// Like SMOTE, ADASYN creates synthetic minority rows by interpolating between a
/// minority sample and one of its same-class neighbours. Unlike SMOTE, it picks
/// seed samples in proportion to how many of their k nearest neighbours belong to
/// other classes, so harder-to-learn minority samples receive more synthetic rows.
#[derive(Debug, Clone)]
pub struct AdasynStrategy {
    /// Number of nearest neighbours examined per minority sample.
    k_neighbors: usize,
    /// Fraction of each class's gap to the majority count that is filled
    /// (0.0 = generate nothing, 1.0 = target full balance).
    beta: f64,
}
17
/// Runtime options handed to the ADASYN `resample` call.
///
/// NOTE(review): `resample` currently reads only `random_state` from this
/// struct; `k_neighbors` and `beta` are taken from the `AdasynStrategy`
/// instance instead of from the fields below.
#[derive(Debug, Clone)]
pub struct AdasynConfig {
    /// How many nearest neighbours to consider per minority sample.
    pub k_neighbors: usize,
    /// Share of the gap to the majority class to fill
    /// (0.0 = no generation, 1.0 = perfect balance).
    pub beta: f64,
    /// Seed for the random number generator; `None` seeds from OS entropy.
    pub random_state: Option<u64>,
}
28
29impl Default for AdasynConfig {
30    fn default() -> Self {
31        Self {
32            k_neighbors: 5,
33            beta: 1.0,
34            random_state: None,
35        }
36    }
37}
38
39impl AdasynStrategy {
40    /// Create a new ADASYN strategy
41    pub fn new(k_neighbors: usize, beta: f64) -> Self {
42        Self { k_neighbors, beta }
43    }
44    
45    /// Create with default configuration
46    pub fn default() -> Self {
47        Self::new(5, 1.0)
48    }
49    
50    /// Calculate density distribution for minority samples
51    fn calculate_density_distribution(
52        &self,
53        x: ArrayView2<f64>,
54        y: ArrayView1<i32>,
55        minority_class: i32,
56        _majority_count: usize,
57    ) -> Result<Vec<f64>, ResamplingError> {
58        let minority_indices: Vec<usize> = y.iter()
59            .enumerate()
60            .filter(|(_, &label)| label == minority_class)
61            .map(|(idx, _)| idx)
62            .collect();
63            
64        if minority_indices.is_empty() {
65            return Err(ResamplingError::InsufficientSamples);
66        }
67        
68        let mut density_ratios = Vec::with_capacity(minority_indices.len());
69        
70        for &minority_idx in &minority_indices {
71            let sample = x.row(minority_idx);
72            
73            // Find k+1 nearest neighbors (including self)
74            let mut distances: Vec<(usize, f64)> = (0..x.nrows())
75                .map(|idx| {
76                    let neighbor = x.row(idx);
77                    let dist = sample.iter()
78                        .zip(neighbor.iter())
79                        .map(|(a, b)| (a - b).powi(2))
80                        .sum::<f64>()
81                        .sqrt();
82                    (idx, dist)
83                })
84                .collect();
85            
86            // Sort by distance and take k+1 nearest
87            distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
88            let k_plus_1 = std::cmp::min(self.k_neighbors + 1, distances.len());
89            
90            // Count majority class samples among k nearest neighbors (excluding self)
91            let majority_neighbors = distances[1..k_plus_1].iter()
92                .filter(|(idx, _)| y[*idx] != minority_class)
93                .count();
94            
95            // Density ratio: proportion of majority class neighbors
96            let density_ratio = majority_neighbors as f64 / self.k_neighbors as f64;
97            density_ratios.push(density_ratio);
98        }
99        
100        // Normalize density ratios
101        let sum_ratios: f64 = density_ratios.iter().sum();
102        if sum_ratios > 0.0 {
103            for ratio in &mut density_ratios {
104                *ratio /= sum_ratios;
105            }
106        }
107        
108        Ok(density_ratios)
109    }
110}
111
112impl ResamplingStrategy for AdasynStrategy {
113    type Input = ();
114    type Output = (Array2<f64>, Array1<i32>);
115    type Config = AdasynConfig;
116    
117    fn resample(
118        &self,
119        x: ArrayView2<f64>,
120        y: ArrayView1<i32>,
121        config: &Self::Config,
122    ) -> Result<(Array2<f64>, Array1<i32>), ResamplingError> {
123        if x.nrows() != y.len() {
124            return Err(ResamplingError::InvalidInput(
125                "Feature matrix and target array must have same number of samples".to_string()
126            ));
127        }
128        
129        if x.nrows() < self.k_neighbors {
130            return Err(ResamplingError::InsufficientSamples);
131        }
132        
133        // Count class frequencies
134        let mut class_counts = HashMap::new();
135        for &label in y.iter() {
136            *class_counts.entry(label).or_insert(0) += 1;
137        }
138        
139        if class_counts.len() < 2 {
140            return Err(ResamplingError::InvalidInput(
141                "Need at least 2 classes for resampling".to_string()
142            ));
143        }
144        
145        // Find majority and minority classes
146        let max_count = *class_counts.values().max().unwrap();
147        let minority_classes: Vec<_> = class_counts.iter()
148            .filter(|(_, &count)| count < max_count)
149            .map(|(&class, &count)| (class, count))
150            .collect();
151        
152        if minority_classes.is_empty() {
153            // Dataset is already balanced
154            return Ok((x.to_owned(), y.to_owned()));
155        }
156        
157        let mut synthetic_features = Vec::new();
158        let mut synthetic_labels = Vec::new();
159        
160        let mut rng = if let Some(seed) = config.random_state {
161            StdRng::seed_from_u64(seed)
162        } else {
163            StdRng::from_entropy()
164        };
165        
166        // Process each minority class
167        for (minority_class, minority_count) in minority_classes {
168            // Calculate number of synthetic samples to generate
169            let desired_samples = ((max_count - minority_count) as f64 * self.beta) as usize;
170            
171            if desired_samples == 0 {
172                continue;
173            }
174            
175            // Get minority class indices
176            let minority_indices: Vec<usize> = y.iter()
177                .enumerate()
178                .filter(|(_, &label)| label == minority_class)
179                .map(|(idx, _)| idx)
180                .collect();
181            
182            // Calculate density distribution
183            let density_ratios = self.calculate_density_distribution(
184                x, y, minority_class, max_count
185            )?;
186            
187            // Generate synthetic samples based on density distribution
188            for _ in 0..desired_samples {
189                // Select minority sample based on density distribution
190                let cumulative_prob = rng.gen::<f64>();
191                let mut cumulative_sum = 0.0;
192                let mut selected_idx = 0;
193                
194                for (i, &ratio) in density_ratios.iter().enumerate() {
195                    cumulative_sum += ratio;
196                    if cumulative_prob <= cumulative_sum {
197                        selected_idx = i;
198                        break;
199                    }
200                }
201                
202                let sample_idx = minority_indices[selected_idx];
203                let sample = x.row(sample_idx);
204                
205                // Find k nearest neighbors from same class
206                let mut distances: Vec<(usize, f64)> = minority_indices.iter()
207                    .map(|&idx| {
208                        let neighbor = x.row(idx);
209                        let dist = sample.iter()
210                            .zip(neighbor.iter())
211                            .map(|(a, b)| (a - b).powi(2))
212                            .sum::<f64>()
213                            .sqrt();
214                        (idx, dist)
215                    })
216                    .collect();
217                
218                // Sort by distance and take k nearest (excluding self)
219                distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
220                let k = std::cmp::min(self.k_neighbors, distances.len() - 1);
221                
222                if k == 0 {
223                    continue;
224                }
225                
226                // Select random neighbor from k nearest
227                let neighbor_idx = distances[1 + rng.gen_range(0..k)].0;
228                let neighbor = x.row(neighbor_idx);
229                
230                // Generate synthetic sample
231                let alpha = rng.gen::<f64>();
232                let synthetic_sample: Vec<f64> = sample.iter()
233                    .zip(neighbor.iter())
234                    .map(|(s, n)| s + alpha * (n - s))
235                    .collect();
236                
237                synthetic_features.push(synthetic_sample);
238                synthetic_labels.push(minority_class);
239            }
240        }
241        
242        // Combine original and synthetic data
243        let n_original = x.nrows();
244        let n_synthetic = synthetic_features.len();
245        let n_total = n_original + n_synthetic;
246        let n_features = x.ncols();
247        
248        let mut combined_x = Array2::zeros((n_total, n_features));
249        let mut combined_y = Array1::zeros(n_total);
250        
251        // Copy original data
252        combined_x.slice_mut(s![0..n_original, ..]).assign(&x);
253        combined_y.slice_mut(s![0..n_original]).assign(&y);
254        
255        // Add synthetic data
256        for (i, (features, label)) in synthetic_features.iter().zip(synthetic_labels.iter()).enumerate() {
257            let idx = n_original + i;
258            for (j, &feature) in features.iter().enumerate() {
259                combined_x[[idx, j]] = feature;
260            }
261            combined_y[idx] = *label;
262        }
263        
264        Ok((combined_x, combined_y))
265    }
266    
267    fn performance_hints(&self) -> PerformanceHints {
268        PerformanceHints::new()
269            .with_hint(PerformanceHint::Parallel)
270            .with_hint(PerformanceHint::CacheFriendly)
271    }
272}