use imbalanced_core::traits::*;
use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use rand::prelude::*;
use std::collections::HashMap;
/// SMOTE (Synthetic Minority Over-sampling Technique) resampling strategy.
///
/// Oversamples minority classes by interpolating between an existing minority
/// sample and one of its `k_neighbors` nearest minority-class neighbors.
#[derive(Debug, Clone)]
pub struct SmoteStrategy {
// Number of nearest minority-class neighbors considered when choosing an
// interpolation partner for each synthetic sample.
k_neighbors: usize,
}
/// Per-call configuration for SMOTE resampling.
#[derive(Debug, Clone)]
pub struct SmoteConfig {
// Requested neighbor count.
// NOTE(review): `resample` currently reads `self.k_neighbors` from the
// strategy, not this field — confirm which one is meant to take effect.
pub k_neighbors: usize,
// Seed for the random generator; `None` draws fresh entropy, making the
// run non-reproducible.
pub random_state: Option<u64>,
}
impl Default for SmoteConfig {
fn default() -> Self {
Self {
k_neighbors: 5,
random_state: None,
}
}
}
impl SmoteStrategy {
pub fn new(k_neighbors: usize) -> Self {
Self { k_neighbors }
}
pub fn default() -> Self {
Self::new(5)
}
}
impl ResamplingStrategy for SmoteStrategy {
    type Input = ();
    type Output = (Array2<f64>, Array1<i32>);
    type Config = SmoteConfig;

    /// Oversamples every non-majority class up to the majority-class count by
    /// generating synthetic samples: each synthetic sample is a linear
    /// interpolation between a randomly chosen minority sample and one of its
    /// `k` nearest minority-class neighbors (Euclidean distance).
    ///
    /// Returns the original rows followed by the synthetic rows, with labels.
    ///
    /// # Errors
    /// - `InvalidInput` if `x`/`y` lengths disagree or fewer than 2 classes.
    /// - `InsufficientSamples` if there are fewer rows than `k_neighbors`.
    fn resample(
        &self,
        x: ArrayView2<f64>,
        y: ArrayView1<i32>,
        config: &Self::Config,
    ) -> Result<(Array2<f64>, Array1<i32>), ResamplingError> {
        if x.nrows() != y.len() {
            return Err(ResamplingError::InvalidInput(
                "Feature matrix and target array must have same number of samples".to_string()
            ));
        }
        // NOTE(review): this uses self.k_neighbors; config.k_neighbors is
        // never read anywhere in this method — confirm which should win.
        if x.nrows() < self.k_neighbors {
            return Err(ResamplingError::InsufficientSamples);
        }

        // Count samples per class label.
        let mut class_counts: HashMap<i32, usize> = HashMap::new();
        for &label in y.iter() {
            *class_counts.entry(label).or_insert(0) += 1;
        }
        if class_counts.len() < 2 {
            return Err(ResamplingError::InvalidInput(
                "Need at least 2 classes for resampling".to_string()
            ));
        }
        let max_count = *class_counts.values().max().unwrap();

        let mut rng = match config.random_state {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::from_entropy(),
        };

        let mut synthetic_features: Vec<Vec<f64>> = Vec::new();
        let mut synthetic_labels: Vec<i32> = Vec::new();

        // Iterate classes in sorted label order. Iterating the HashMap
        // directly would make seeded runs non-reproducible, since HashMap
        // iteration order is unspecified and varies between runs.
        let mut labels: Vec<i32> = class_counts.keys().copied().collect();
        labels.sort_unstable();

        for class_label in labels {
            let count = class_counts[&class_label];
            if count >= max_count {
                continue; // majority class: nothing to synthesize
            }
            let n_synthetic = max_count - count;
            let minority_indices: Vec<usize> = y.iter()
                .enumerate()
                .filter(|(_, &label)| label == class_label)
                .map(|(idx, _)| idx)
                .collect();

            // Precompute each minority sample's neighbor list once, instead
            // of recomputing and re-sorting all distances for every synthetic
            // sample. neighbors[i] holds the OTHER minority indices sorted by
            // ascending Euclidean distance to minority_indices[i]; excluding
            // self explicitly also avoids the original's assumption that self
            // always sorts to position 0 (false when duplicate points tie at
            // distance zero).
            let neighbors: Vec<Vec<usize>> = minority_indices.iter()
                .map(|&i| {
                    let sample = x.row(i);
                    let mut dists: Vec<(usize, f64)> = minority_indices.iter()
                        .filter(|&&j| j != i)
                        .map(|&j| {
                            let dist = sample.iter()
                                .zip(x.row(j).iter())
                                .map(|(a, b)| (a - b).powi(2))
                                .sum::<f64>()
                                .sqrt();
                            (j, dist)
                        })
                        .collect();
                    // partial_cmp is safe to unwrap: distances are never NaN
                    // for finite inputs (sqrt of a sum of squares).
                    dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
                    dists.into_iter().map(|(j, _)| j).collect()
                })
                .collect();

            for _ in 0..n_synthetic {
                // RNG call order per synthetic sample matches the original:
                // sample pick, neighbor pick, interpolation factor.
                let pick = rng.gen_range(0..minority_indices.len());
                let sample_idx = minority_indices[pick];
                let candidate_neighbors = &neighbors[pick];
                let k = std::cmp::min(self.k_neighbors, candidate_neighbors.len());
                if k == 0 {
                    // Class has a single sample: nothing to interpolate with.
                    continue;
                }
                let neighbor_idx = candidate_neighbors[rng.gen_range(0..k)];
                let sample = x.row(sample_idx);
                let neighbor = x.row(neighbor_idx);
                // Interpolate: s + alpha * (n - s) with alpha in [0, 1).
                let alpha = rng.gen::<f64>();
                let synthetic_sample: Vec<f64> = sample.iter()
                    .zip(neighbor.iter())
                    .map(|(s, n)| s + alpha * (n - s))
                    .collect();
                synthetic_features.push(synthetic_sample);
                synthetic_labels.push(class_label);
            }
        }

        // Stack original rows first, then the synthetic rows. Row-wise
        // assignment avoids the `s!` slicing macro entirely.
        let n_original = x.nrows();
        let n_total = n_original + synthetic_features.len();
        let n_features = x.ncols();
        let mut combined_x = Array2::zeros((n_total, n_features));
        let mut combined_y = Array1::zeros(n_total);
        for i in 0..n_original {
            combined_x.row_mut(i).assign(&x.row(i));
            combined_y[i] = y[i];
        }
        for (i, (features, &label)) in synthetic_features.iter().zip(synthetic_labels.iter()).enumerate() {
            let idx = n_original + i;
            for (j, &feature) in features.iter().enumerate() {
                combined_x[[idx, j]] = feature;
            }
            combined_y[idx] = label;
        }
        Ok((combined_x, combined_y))
    }

    fn performance_hints(&self) -> PerformanceHints {
        PerformanceHints::new()
            .with_hint(PerformanceHint::Parallel)
            .with_hint(PerformanceHint::CacheFriendly)
    }
}