use imbalanced_core::traits::*;
use ndarray::{Array1, Array2, ArrayView1, ArrayView2, s};
use rand::prelude::*;
use std::collections::HashMap;
/// ADASYN (Adaptive Synthetic Sampling) oversampling strategy.
///
/// Generates synthetic minority-class samples, concentrating them around
/// minority samples that are hardest to learn (those whose neighborhoods
/// are dominated by other classes).
#[derive(Debug, Clone)]
pub struct AdasynStrategy {
    /// Number of nearest neighbors examined around each minority sample.
    k_neighbors: usize,
    /// Fraction of the class imbalance to correct; `1.0` fully balances.
    beta: f64,
}
/// Runtime configuration for an ADASYN resampling run.
#[derive(Debug, Clone)]
pub struct AdasynConfig {
    /// Number of nearest neighbors examined around each minority sample.
    pub k_neighbors: usize,
    /// Fraction of the class imbalance to correct; `1.0` fully balances.
    pub beta: f64,
    /// Seed for reproducible sampling; `None` draws entropy from the OS.
    pub random_state: Option<u64>,
}
impl Default for AdasynConfig {
fn default() -> Self {
Self {
k_neighbors: 5,
beta: 1.0,
random_state: None,
}
}
}
impl AdasynStrategy {
    /// Creates a strategy with an explicit neighbor count and balancing level.
    ///
    /// `k_neighbors` is the number of nearest neighbors examined around each
    /// minority sample; `beta` scales how much of the class imbalance is
    /// corrected (`1.0` = fully balanced).
    pub fn new(k_neighbors: usize, beta: f64) -> Self {
        Self { k_neighbors, beta }
    }

    /// Conventional defaults: 5 neighbors, full balancing.
    // NOTE: kept as an inherent method for backward compatibility; clippy's
    // `should_implement_trait` would prefer an `impl Default` block.
    pub fn default() -> Self {
        Self::new(5, 1.0)
    }

    /// Computes the normalized density distribution `r_i` for each minority
    /// sample: the fraction of its `k_neighbors` nearest neighbors (searched
    /// over the *whole* data set) that belong to a different class, then
    /// normalized so the ratios sum to 1.
    ///
    /// Samples surrounded by other classes get larger ratios and therefore
    /// receive proportionally more synthetic samples. The returned vector is
    /// ordered like the minority indices in `y` (ascending row order).
    ///
    /// # Errors
    /// * `InvalidInput` — `k_neighbors` is zero (the ratio `0/0` would be NaN
    ///   and silently poison the whole distribution).
    /// * `InsufficientSamples` — `minority_class` does not occur in `y`.
    fn calculate_density_distribution(
        &self,
        x: ArrayView2<f64>,
        y: ArrayView1<i32>,
        minority_class: i32,
        _majority_count: usize,
    ) -> Result<Vec<f64>, ResamplingError> {
        if self.k_neighbors == 0 {
            return Err(ResamplingError::InvalidInput(
                "k_neighbors must be at least 1".to_string(),
            ));
        }
        let minority_indices: Vec<usize> = y
            .iter()
            .enumerate()
            .filter(|(_, &label)| label == minority_class)
            .map(|(idx, _)| idx)
            .collect();
        if minority_indices.is_empty() {
            return Err(ResamplingError::InsufficientSamples);
        }
        let mut density_ratios = Vec::with_capacity(minority_indices.len());
        for &minority_idx in &minority_indices {
            let sample = x.row(minority_idx);
            // Euclidean distance from this sample to every row, itself included.
            let mut distances: Vec<(usize, f64)> = (0..x.nrows())
                .map(|idx| {
                    let neighbor = x.row(idx);
                    let dist = sample
                        .iter()
                        .zip(neighbor.iter())
                        .map(|(a, b)| (a - b).powi(2))
                        .sum::<f64>()
                        .sqrt();
                    (idx, dist)
                })
                .collect();
            // `total_cmp` cannot panic on NaN (unlike `partial_cmp().unwrap()`)
            // and an unstable sort avoids the stable sort's allocation.
            distances.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
            // Skip index 0: the nearest "neighbor" is the sample itself.
            let k_plus_1 = std::cmp::min(self.k_neighbors + 1, distances.len());
            let majority_neighbors = distances[1..k_plus_1]
                .iter()
                .filter(|(idx, _)| y[*idx] != minority_class)
                .count();
            density_ratios.push(majority_neighbors as f64 / self.k_neighbors as f64);
        }
        // Normalize into a probability distribution. A zero sum (every
        // minority sample fully surrounded by its own class) is left as-is
        // and handled by the caller's selection fallback.
        let sum_ratios: f64 = density_ratios.iter().sum();
        if sum_ratios > 0.0 {
            for ratio in &mut density_ratios {
                *ratio /= sum_ratios;
            }
        }
        Ok(density_ratios)
    }
}
impl ResamplingStrategy for AdasynStrategy {
    type Input = ();
    type Output = (Array2<f64>, Array1<i32>);
    type Config = AdasynConfig;

    /// Oversamples every minority class toward the majority-class count by
    /// interpolating between minority samples and their within-class nearest
    /// neighbors, weighting harder-to-learn samples (via the density
    /// distribution) more heavily.
    ///
    /// Returns the original rows followed by the synthetic rows, with labels
    /// aligned.
    ///
    /// # Errors
    /// * `InvalidInput` — `x` and `y` disagree on sample count, or fewer than
    ///   two classes are present.
    /// * `InsufficientSamples` — fewer rows than `k_neighbors`.
    fn resample(
        &self,
        x: ArrayView2<f64>,
        y: ArrayView1<i32>,
        config: &Self::Config,
    ) -> Result<(Array2<f64>, Array1<i32>), ResamplingError> {
        // NOTE(review): only `config.random_state` is honored here; the
        // `k_neighbors`/`beta` actually used are the strategy's own fields.
        // Confirm whether the config values were meant to override them.
        if x.nrows() != y.len() {
            return Err(ResamplingError::InvalidInput(
                "Feature matrix and target array must have same number of samples".to_string(),
            ));
        }
        if x.nrows() < self.k_neighbors {
            return Err(ResamplingError::InsufficientSamples);
        }
        // Tally samples per class; the largest class defines the target size.
        let mut class_counts = HashMap::new();
        for &label in y.iter() {
            *class_counts.entry(label).or_insert(0) += 1;
        }
        if class_counts.len() < 2 {
            return Err(ResamplingError::InvalidInput(
                "Need at least 2 classes for resampling".to_string(),
            ));
        }
        let max_count = *class_counts.values().max().unwrap();
        let minority_classes: Vec<_> = class_counts
            .iter()
            .filter(|(_, &count)| count < max_count)
            .map(|(&class, &count)| (class, count))
            .collect();
        // Every class already at the majority count: nothing to synthesize.
        if minority_classes.is_empty() {
            return Ok((x.to_owned(), y.to_owned()));
        }
        let mut synthetic_features = Vec::new();
        let mut synthetic_labels = Vec::new();
        let mut rng = match config.random_state {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::from_entropy(),
        };
        for (minority_class, minority_count) in minority_classes {
            let desired_samples = ((max_count - minority_count) as f64 * self.beta) as usize;
            if desired_samples == 0 {
                continue;
            }
            let minority_indices: Vec<usize> = y
                .iter()
                .enumerate()
                .filter(|(_, &label)| label == minority_class)
                .map(|(idx, _)| idx)
                .collect();
            // A lone minority sample has no within-class neighbor to
            // interpolate toward; skip the class entirely.
            let k = std::cmp::min(self.k_neighbors, minority_indices.len() - 1);
            if k == 0 {
                continue;
            }
            let density_ratios =
                self.calculate_density_distribution(x, y, minority_class, max_count)?;
            // Precompute each minority sample's k nearest within-class
            // neighbors ONCE. The original recomputed this O(m log m) search
            // for every synthetic sample even though it never changes.
            let neighbor_lists: Vec<Vec<usize>> = minority_indices
                .iter()
                .map(|&base_idx| {
                    let base = x.row(base_idx);
                    let mut dists: Vec<(usize, f64)> = minority_indices
                        .iter()
                        .map(|&idx| {
                            let other = x.row(idx);
                            let dist = base
                                .iter()
                                .zip(other.iter())
                                .map(|(a, b)| (a - b).powi(2))
                                .sum::<f64>()
                                .sqrt();
                            (idx, dist)
                        })
                        .collect();
                    // `total_cmp` cannot panic on NaN, unlike the previous
                    // `partial_cmp().unwrap()`.
                    dists.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
                    // Index 0 is the sample itself (distance 0); keep the next k.
                    dists[1..=k].iter().map(|&(idx, _)| idx).collect()
                })
                .collect();
            for _ in 0..desired_samples {
                // Roulette-wheel selection proportional to the density ratios.
                // Default to the LAST index so floating-point rounding (or an
                // all-zero distribution) can never leave the draw unassigned;
                // the old code silently fell back to index 0, biasing the
                // first minority sample.
                let draw = rng.gen::<f64>();
                let mut cumulative = 0.0;
                let mut selected = density_ratios.len() - 1;
                for (i, &ratio) in density_ratios.iter().enumerate() {
                    cumulative += ratio;
                    if draw <= cumulative {
                        selected = i;
                        break;
                    }
                }
                let sample = x.row(minority_indices[selected]);
                let neighbor_idx = neighbor_lists[selected][rng.gen_range(0..k)];
                let neighbor = x.row(neighbor_idx);
                // Linear interpolation between the sample and its neighbor.
                let alpha = rng.gen::<f64>();
                let synthetic: Vec<f64> = sample
                    .iter()
                    .zip(neighbor.iter())
                    .map(|(s, n)| s + alpha * (n - s))
                    .collect();
                synthetic_features.push(synthetic);
                synthetic_labels.push(minority_class);
            }
        }
        // Stack the original rows on top of the synthetic rows.
        let n_original = x.nrows();
        let n_total = n_original + synthetic_features.len();
        let mut combined_x = Array2::zeros((n_total, x.ncols()));
        let mut combined_y = Array1::zeros(n_total);
        combined_x.slice_mut(s![0..n_original, ..]).assign(&x);
        combined_y.slice_mut(s![0..n_original]).assign(&y);
        for (i, (features, label)) in synthetic_features
            .iter()
            .zip(synthetic_labels.iter())
            .enumerate()
        {
            let row = n_original + i;
            for (j, &value) in features.iter().enumerate() {
                combined_x[[row, j]] = value;
            }
            combined_y[row] = *label;
        }
        Ok((combined_x, combined_y))
    }

    /// Advertises that the distance searches parallelize cleanly and access
    /// rows in a cache-friendly order.
    fn performance_hints(&self) -> PerformanceHints {
        PerformanceHints::new()
            .with_hint(PerformanceHint::Parallel)
            .with_hint(PerformanceHint::CacheFriendly)
    }
}