use crate::error::{DatasetsError, Result};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::prelude::*;
use scirs2_core::random::prelude::*;
use scirs2_core::random::rngs::StdRng;
use scirs2_core::random::seq::SliceRandom;
use scirs2_core::random::Uniform;
use std::collections::HashMap;
/// Strategy selected by callers of `create_balanced_dataset` to equalise
/// the number of samples per class.
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
pub enum BalancingStrategy {
/// Grow the smaller classes by randomly repeating their existing samples
/// (dispatches to `random_oversample`).
RandomOversample,
/// Shrink the larger classes by keeping a random subset of their samples
/// (dispatches to `random_undersample`).
RandomUndersample,
/// Grow the smaller classes with synthetic samples interpolated between a
/// sample and one of its nearest same-class neighbours
/// (uses `generate_synthetic_samples`).
SMOTE {
/// Number of nearest same-class neighbours considered when picking the
/// interpolation partner for a synthetic sample.
k_neighbors: usize,
},
}
/// Balances class sizes by randomly duplicating samples of the smaller classes.
///
/// Targets are grouped into classes by rounding each value to the nearest
/// integer. Every class keeps all of its original rows; a class smaller than
/// the largest one is topped up with rows drawn uniformly at random (with
/// replacement) from that same class until it reaches the largest class size.
///
/// Classes are processed in ascending label order, so for a given
/// `random_seed` the returned rows (and their order) are reproducible and do
/// not depend on `HashMap` iteration order.
///
/// # Arguments
/// * `data` - Feature matrix with one row per sample.
/// * `targets` - Class label of each row of `data`.
/// * `random_seed` - Seed for the random generator; `None` seeds it from the
///   thread-local generator.
///
/// # Returns
/// The resampled `(data, targets)` pair, grouped by class in ascending label
/// order: first the original rows of a class, then its duplicated rows.
///
/// # Errors
/// Returns `DatasetsError::InvalidFormat` when the number of rows in `data`
/// differs from the length of `targets`, or when either input is empty.
#[allow(dead_code)]
pub fn random_oversample(
    data: &Array2<f64>,
    targets: &Array1<f64>,
    random_seed: Option<u64>,
) -> Result<(Array2<f64>, Array1<f64>)> {
    if data.nrows() != targets.len() {
        return Err(DatasetsError::InvalidFormat(
            "Data rows and targets length must match".to_string(),
        ));
    }
    if data.is_empty() || targets.is_empty() {
        return Err(DatasetsError::InvalidFormat(
            "Data and targets cannot be empty".to_string(),
        ));
    }

    // Group row indices by (rounded) class label.
    let mut class_indices: HashMap<i64, Vec<usize>> = HashMap::new();
    for (i, &target) in targets.iter().enumerate() {
        let class = target.round() as i64;
        class_indices.entry(class).or_default().push(i);
    }

    // HashMap iteration order is arbitrary; sort by label so that both the
    // output order and the sequence of random draws are reproducible.
    let mut classes: Vec<(i64, Vec<usize>)> = class_indices.into_iter().collect();
    classes.sort_unstable_by_key(|(class, _)| *class);

    // `targets` is non-empty here, so there is at least one class; the
    // fallback of 0 is never used in practice.
    let max_class_size = classes
        .iter()
        .map(|(_, indices)| indices.len())
        .max()
        .unwrap_or(0);

    let mut rng = match random_seed {
        Some(seed) => StdRng::seed_from_u64(seed),
        None => {
            let mut r = thread_rng();
            StdRng::seed_from_u64(r.next_u64())
        }
    };

    let mut resampled_indices: Vec<usize> = Vec::with_capacity(classes.len() * max_class_size);
    for (_, indices) in &classes {
        let class_size = indices.len();
        // Keep every original sample of the class...
        resampled_indices.extend_from_slice(indices);
        // ...then draw with replacement until the class is as large as the biggest one.
        if class_size < max_class_size {
            let samples_needed = max_class_size - class_size;
            for _ in 0..samples_needed {
                let random_idx = rng.sample(Uniform::new(0, class_size).expect("Operation failed"));
                resampled_indices.push(indices[random_idx]);
            }
        }
    }

    let resampled_data = data.select(scirs2_core::ndarray::Axis(0), &resampled_indices);
    let resampled_targets = targets.select(scirs2_core::ndarray::Axis(0), &resampled_indices);
    Ok((resampled_data, resampled_targets))
}
/// Balances class sizes by randomly discarding samples of the larger classes.
///
/// Targets are grouped into classes by rounding each value to the nearest
/// integer. Every class larger than the smallest one is shuffled and cut down
/// to the smallest class size; classes already at that size are kept as is.
///
/// Classes are processed in ascending label order, so for a given
/// `random_seed` the selected rows (and their order) are reproducible and do
/// not depend on `HashMap` iteration order.
///
/// # Arguments
/// * `data` - Feature matrix with one row per sample.
/// * `targets` - Class label of each row of `data`.
/// * `random_seed` - Seed for the random generator; `None` seeds it from the
///   thread-local generator.
///
/// # Returns
/// The reduced `(data, targets)` pair, grouped by class in ascending label order.
///
/// # Errors
/// Returns `DatasetsError::InvalidFormat` when the number of rows in `data`
/// differs from the length of `targets`, or when either input is empty.
#[allow(dead_code)]
pub fn random_undersample(
    data: &Array2<f64>,
    targets: &Array1<f64>,
    random_seed: Option<u64>,
) -> Result<(Array2<f64>, Array1<f64>)> {
    if data.nrows() != targets.len() {
        return Err(DatasetsError::InvalidFormat(
            "Data rows and targets length must match".to_string(),
        ));
    }
    if data.is_empty() || targets.is_empty() {
        return Err(DatasetsError::InvalidFormat(
            "Data and targets cannot be empty".to_string(),
        ));
    }

    // Group row indices by (rounded) class label.
    let mut class_indices: HashMap<i64, Vec<usize>> = HashMap::new();
    for (i, &target) in targets.iter().enumerate() {
        let class = target.round() as i64;
        class_indices.entry(class).or_default().push(i);
    }

    // HashMap iteration order is arbitrary; sort by label so that the shuffles
    // consume the random stream in a fixed order and the output is reproducible.
    let mut classes: Vec<(i64, Vec<usize>)> = class_indices.into_iter().collect();
    classes.sort_unstable_by_key(|(class, _)| *class);

    // `targets` is non-empty here, so there is at least one class; the
    // fallback of 0 is never used in practice.
    let min_class_size = classes
        .iter()
        .map(|(_, indices)| indices.len())
        .min()
        .unwrap_or(0);

    let mut rng = match random_seed {
        Some(seed) => StdRng::seed_from_u64(seed),
        None => {
            let mut r = thread_rng();
            StdRng::seed_from_u64(r.next_u64())
        }
    };

    let mut undersampled_indices: Vec<usize> = Vec::with_capacity(classes.len() * min_class_size);
    for (_, mut indices) in classes {
        // Only oversized classes are shuffled; the first `min_class_size`
        // entries of the shuffled list form the random subset that is kept.
        if indices.len() > min_class_size {
            indices.shuffle(&mut rng);
            indices.truncate(min_class_size);
        }
        undersampled_indices.extend(indices);
    }

    let undersampled_data = data.select(scirs2_core::ndarray::Axis(0), &undersampled_indices);
    let undersampled_targets = targets.select(scirs2_core::ndarray::Axis(0), &undersampled_indices);
    Ok((undersampled_data, undersampled_targets))
}
/// Generates synthetic samples for one class by interpolating between existing
/// samples of that class (SMOTE-style).
///
/// For each synthetic sample a random base sample of `target_class` is picked,
/// its `k_neighbors` nearest same-class samples (Euclidean distance) are
/// determined, one of them is chosen at random, and the new sample is placed at
/// a random position on the segment between the base sample and that neighbour.
///
/// # Arguments
/// * `data` - Feature matrix with one row per sample.
/// * `targets` - Class label of each row of `data`.
/// * `target_class` - Label of the class to synthesise; rows whose target is
///   within `1e-10` of this value are treated as members of the class.
/// * `n_synthetic` - Number of synthetic samples to create (must be > 0).
/// * `k_neighbors` - Number of nearest neighbours to choose from (must be > 0
///   and smaller than the number of samples in the class).
/// * `random_seed` - Seed for the random generator; `None` seeds it from the
///   thread-local generator.
///
/// # Returns
/// A `(n_synthetic, n_features)` matrix of synthetic samples and a target
/// vector of length `n_synthetic` filled with `target_class`.
///
/// # Errors
/// Returns `DatasetsError::InvalidFormat` when the input lengths differ, when
/// `n_synthetic` or `k_neighbors` is zero, when the class has fewer than two
/// samples, or when `k_neighbors` is not smaller than the class size.
#[allow(dead_code)]
pub fn generate_synthetic_samples(
    data: &Array2<f64>,
    targets: &Array1<f64>,
    target_class: f64,
    n_synthetic: usize,
    k_neighbors: usize,
    random_seed: Option<u64>,
) -> Result<(Array2<f64>, Array1<f64>)> {
    if data.nrows() != targets.len() {
        return Err(DatasetsError::InvalidFormat(
            "Data rows and targets length must match".to_string(),
        ));
    }
    if n_synthetic == 0 {
        return Err(DatasetsError::InvalidFormat(
            "Number of _synthetic samples must be > 0".to_string(),
        ));
    }
    if k_neighbors == 0 {
        return Err(DatasetsError::InvalidFormat(
            "Number of _neighbors must be > 0".to_string(),
        ));
    }

    // Row indices of all samples belonging to the requested class.
    let class_indices: Vec<usize> = targets
        .iter()
        .enumerate()
        .filter(|(_, &target)| (target - target_class).abs() < 1e-10)
        .map(|(i, _)| i)
        .collect();
    if class_indices.len() < 2 {
        return Err(DatasetsError::InvalidFormat(
            "Need at least 2 samples of the target _class for _synthetic generation".to_string(),
        ));
    }
    if k_neighbors >= class_indices.len() {
        return Err(DatasetsError::InvalidFormat(
            "k_neighbors must be less than the number of samples in the target _class".to_string(),
        ));
    }

    let mut rng = match random_seed {
        Some(seed) => StdRng::seed_from_u64(seed),
        None => {
            let mut r = thread_rng();
            StdRng::seed_from_u64(r.next_u64())
        }
    };

    let n_features = data.ncols();
    let mut synthetic_data = Array2::<f64>::zeros((n_synthetic, n_features));
    let synthetic_targets = Array1::from_elem(n_synthetic, target_class);

    for i in 0..n_synthetic {
        // Pick a random base sample from the class.
        let base_idx = class_indices
            [rng.sample(Uniform::new(0, class_indices.len()).expect("Operation failed"))];
        let base_sample = data.row(base_idx);

        // Euclidean distance from the base sample to every other class member.
        let mut distances: Vec<(usize, f64)> = class_indices
            .iter()
            .filter(|&&idx| idx != base_idx)
            .map(|&idx| {
                let neighbor = data.row(idx);
                let distance: f64 = base_sample
                    .iter()
                    .zip(neighbor.iter())
                    .map(|(&a, &b)| (a - b).powi(2))
                    .sum::<f64>()
                    .sqrt();
                (idx, distance)
            })
            .collect();

        // `total_cmp` is a total order on f64, so NaN distances (from NaN
        // features) sort last instead of panicking as `partial_cmp` + `expect`
        // did. For non-NaN distances the resulting order is unchanged.
        distances.sort_by(|a, b| a.1.total_cmp(&b.1));

        // The validation above guarantees `k_neighbors <= distances.len()`;
        // the `min` is kept as a cheap guard for the slice bound.
        let k_nearest = &distances[0..k_neighbors.min(distances.len())];
        let neighbor_idx =
            k_nearest[rng.sample(Uniform::new(0, k_nearest.len()).expect("Operation failed"))].0;
        let neighbor_sample = data.row(neighbor_idx);

        // Interpolate: base + alpha * (neighbor - base), alpha in [0, 1).
        let alpha: f64 = rng.random_range(0.0..1.0);
        for (synthetic_feature, (&base, &neighbor)) in synthetic_data
            .row_mut(i)
            .iter_mut()
            .zip(base_sample.iter().zip(neighbor_sample.iter()))
        {
            *synthetic_feature = base + alpha * (neighbor - base);
        }
    }

    Ok((synthetic_data, synthetic_targets))
}
/// Balances a dataset with the requested [`BalancingStrategy`].
///
/// * `RandomOversample` and `RandomUndersample` delegate to
///   `random_oversample` and `random_undersample` respectively.
/// * `SMOTE` keeps all original rows and, for every class smaller than the
///   largest one, appends synthetic samples produced by
///   `generate_synthetic_samples` until the class reaches the largest class
///   size. Classes are determined by rounding each target to the nearest
///   integer and are processed in ascending label order, so the order of the
///   appended rows does not depend on `HashMap` iteration order. The same
///   `random_seed` is passed to the generator for every class.
///
/// # Arguments
/// * `data` - Feature matrix with one row per sample.
/// * `targets` - Class label of each row of `data`.
/// * `strategy` - Balancing strategy to apply.
/// * `random_seed` - Seed forwarded to the chosen strategy.
///
/// # Returns
/// The balanced `(data, targets)` pair.
///
/// # Errors
/// Returns `DatasetsError::InvalidFormat` when the number of rows in `data`
/// differs from the length of `targets`, when either input is empty, when the
/// selected strategy rejects its parameters (for example a minority class that
/// is too small for `k_neighbors`), or when concatenating the synthetic rows
/// fails.
#[allow(dead_code)]
pub fn create_balanced_dataset(
    data: &Array2<f64>,
    targets: &Array1<f64>,
    strategy: BalancingStrategy,
    random_seed: Option<u64>,
) -> Result<(Array2<f64>, Array1<f64>)> {
    match strategy {
        BalancingStrategy::RandomOversample => random_oversample(data, targets, random_seed),
        BalancingStrategy::RandomUndersample => random_undersample(data, targets, random_seed),
        BalancingStrategy::SMOTE { k_neighbors } => {
            // Validate up front, mirroring the random strategies, so that bad
            // input yields an error instead of a panic or a silently
            // inconsistent result.
            if data.nrows() != targets.len() {
                return Err(DatasetsError::InvalidFormat(
                    "Data rows and targets length must match".to_string(),
                ));
            }
            if data.is_empty() || targets.is_empty() {
                return Err(DatasetsError::InvalidFormat(
                    "Data and targets cannot be empty".to_string(),
                ));
            }

            let mut class_counts: HashMap<i64, usize> = HashMap::new();
            for &target in targets.iter() {
                let class = target.round() as i64;
                *class_counts.entry(class).or_default() += 1;
            }

            // HashMap iteration order is arbitrary; sort by label so the
            // synthetic rows are appended in a reproducible order.
            let mut classes: Vec<(i64, usize)> = class_counts.into_iter().collect();
            classes.sort_unstable_by_key(|&(class, _)| class);

            // `targets` is non-empty here, so the fallback of 0 is never used.
            let max_count = classes
                .iter()
                .map(|&(_, count)| count)
                .max()
                .unwrap_or(0);

            let mut combined_data = data.clone();
            let mut combined_targets = targets.clone();
            for &(class, count) in &classes {
                if count >= max_count {
                    continue;
                }
                let samples_needed = max_count - count;
                // Synthetic samples are always derived from the original
                // data, never from previously generated rows.
                let (synthetic_data, synthetic_targets) = generate_synthetic_samples(
                    data,
                    targets,
                    class as f64,
                    samples_needed,
                    k_neighbors,
                    random_seed,
                )?;
                combined_data = scirs2_core::ndarray::concatenate(
                    scirs2_core::ndarray::Axis(0),
                    &[combined_data.view(), synthetic_data.view()],
                )
                .map_err(|_| {
                    DatasetsError::InvalidFormat("Failed to concatenate data".to_string())
                })?;
                combined_targets = scirs2_core::ndarray::concatenate(
                    scirs2_core::ndarray::Axis(0),
                    &[combined_targets.view(), synthetic_targets.view()],
                )
                .map_err(|_| {
                    DatasetsError::InvalidFormat("Failed to concatenate targets".to_string())
                })?;
            }
            Ok((combined_data, combined_targets))
        }
    }
}
/// Unit tests for the class-balancing utilities.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the shared fixture: six two-feature rows, two of class 0 and
    /// four of class 1.
    fn imbalanced_two_class() -> (Array2<f64>, Array1<f64>) {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .expect("Test: failed to build 6x2 fixture");
        let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0]);
        (data, targets)
    }

    /// Counts how many targets are exactly equal to `class`.
    fn count_class(targets: &Array1<f64>, class: f64) -> usize {
        targets.iter().filter(|&&x| x == class).count()
    }

    #[test]
    fn test_random_oversample() {
        let (data, targets) = imbalanced_two_class();
        let (balanced_data, balanced_targets) = random_oversample(&data, &targets, Some(42))
            .expect("Test: random_oversample failed");
        // The minority class is raised to the majority size.
        assert_eq!(count_class(&balanced_targets, 0.0), 4);
        assert_eq!(count_class(&balanced_targets, 1.0), 4);
        assert_eq!(balanced_data.nrows(), 8);
        assert_eq!(balanced_targets.len(), 8);
        assert_eq!(balanced_data.ncols(), 2);
    }

    #[test]
    fn test_random_oversample_invalid_params() {
        let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("Test: failed to build 3x2 fixture");
        // Length mismatch between rows and targets.
        let targets = Array1::from(vec![0.0, 1.0]);
        assert!(random_oversample(&data, &targets, None).is_err());
        // Empty input.
        let empty_data = Array2::zeros((0, 2));
        let empty_targets = Array1::from(vec![]);
        assert!(random_oversample(&empty_data, &empty_targets, None).is_err());
    }

    #[test]
    fn test_random_undersample() {
        let (data, targets) = imbalanced_two_class();
        let (balanced_data, balanced_targets) = random_undersample(&data, &targets, Some(42))
            .expect("Test: random_undersample failed");
        // The majority class is cut down to the minority size.
        assert_eq!(count_class(&balanced_targets, 0.0), 2);
        assert_eq!(count_class(&balanced_targets, 1.0), 2);
        assert_eq!(balanced_data.nrows(), 4);
        assert_eq!(balanced_targets.len(), 4);
        assert_eq!(balanced_data.ncols(), 2);
    }

    #[test]
    fn test_random_undersample_invalid_params() {
        let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("Test: failed to build 3x2 fixture");
        // Length mismatch between rows and targets.
        let targets = Array1::from(vec![0.0, 1.0]);
        assert!(random_undersample(&data, &targets, None).is_err());
        // Empty input.
        let empty_data = Array2::zeros((0, 2));
        let empty_targets = Array1::from(vec![]);
        assert!(random_undersample(&empty_data, &empty_targets, None).is_err());
    }

    #[test]
    fn test_generate_synthetic_samples() {
        let data = Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 1.5, 1.5, 2.5, 2.5])
            .expect("Test: failed to build 4x2 fixture");
        let targets = Array1::from(vec![0.0, 0.0, 0.0, 1.0]);
        let (synthetic_data, synthetic_targets) =
            generate_synthetic_samples(&data, &targets, 0.0, 2, 2, Some(42))
                .expect("Test: generate_synthetic_samples failed");
        assert_eq!(synthetic_data.nrows(), 2);
        assert_eq!(synthetic_targets.len(), 2);
        for &target in synthetic_targets.iter() {
            assert_eq!(target, 0.0);
        }
        assert_eq!(synthetic_data.ncols(), 2);
        // Interpolated values must stay near the range spanned by the class samples.
        for i in 0..synthetic_data.nrows() {
            for j in 0..synthetic_data.ncols() {
                let value = synthetic_data[[i, j]];
                assert!((0.5..=2.5).contains(&value));
            }
        }
    }

    #[test]
    fn test_generate_synthetic_samples_invalid_params() {
        let data = Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 1.5, 1.5, 2.5, 2.5])
            .expect("Test: failed to build 4x2 fixture");
        let targets = Array1::from(vec![0.0, 0.0, 0.0, 1.0]);
        // Length mismatch between rows and targets.
        let bad_targets = Array1::from(vec![0.0, 1.0]);
        assert!(generate_synthetic_samples(&data, &bad_targets, 0.0, 2, 2, None).is_err());
        // Zero synthetic samples requested.
        assert!(generate_synthetic_samples(&data, &targets, 0.0, 0, 2, None).is_err());
        // Zero neighbours requested.
        assert!(generate_synthetic_samples(&data, &targets, 0.0, 2, 0, None).is_err());
        // Class 1 has a single sample only.
        assert!(generate_synthetic_samples(&data, &targets, 1.0, 2, 2, None).is_err());
        // k_neighbors is not smaller than the class size.
        assert!(generate_synthetic_samples(&data, &targets, 0.0, 2, 3, None).is_err());
    }

    #[test]
    fn test_create_balanced_dataset_random_oversample() {
        let (data, targets) = imbalanced_two_class();
        let (balanced_data, balanced_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::RandomOversample,
            Some(42),
        )
        .expect("Test: create_balanced_dataset (RandomOversample) failed");
        assert_eq!(
            count_class(&balanced_targets, 0.0),
            count_class(&balanced_targets, 1.0)
        );
        assert_eq!(balanced_data.nrows(), balanced_targets.len());
    }

    #[test]
    fn test_create_balanced_dataset_random_undersample() {
        let (data, targets) = imbalanced_two_class();
        let (balanced_data, balanced_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::RandomUndersample,
            Some(42),
        )
        .expect("Test: create_balanced_dataset (RandomUndersample) failed");
        assert_eq!(
            count_class(&balanced_targets, 0.0),
            count_class(&balanced_targets, 1.0)
        );
        assert_eq!(balanced_data.nrows(), balanced_targets.len());
    }

    #[test]
    fn test_create_balanced_dataset_smote() {
        // Three samples of class 0 and five of class 1, so SMOTE has to
        // synthesise two class-0 samples.
        let data = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.5, 1.5, 2.0, 2.0, 5.0, 5.0, 5.5, 5.5, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0,
            ],
        )
        .expect("Test: failed to build 8x2 SMOTE fixture");
        let targets = Array1::from(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0]);
        let (balanced_data, balanced_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::SMOTE { k_neighbors: 2 },
            Some(42),
        )
        .expect("Test: create_balanced_dataset (SMOTE) failed");
        assert_eq!(count_class(&balanced_targets, 0.0), 5);
        assert_eq!(count_class(&balanced_targets, 1.0), 5);
        assert_eq!(balanced_data.nrows(), 10);
        assert_eq!(balanced_data.nrows(), balanced_targets.len());
        assert_eq!(balanced_data.ncols(), 2);
        // The original rows come first and are left untouched.
        for i in 0..data.nrows() {
            for j in 0..data.ncols() {
                assert_eq!(balanced_data[[i, j]], data[[i, j]]);
            }
            assert_eq!(balanced_targets[i], targets[i]);
        }
        // The appended rows are class-0 samples interpolated between class-0 points.
        for i in data.nrows()..balanced_data.nrows() {
            assert_eq!(balanced_targets[i], 0.0);
            for j in 0..balanced_data.ncols() {
                let value = balanced_data[[i, j]];
                assert!((1.0..=2.0).contains(&value));
            }
        }
    }

    #[test]
    fn test_balancing_strategy_with_multiple_classes() {
        let data = Array2::from_shape_vec((9, 2), (0..18).map(|x| x as f64).collect())
            .expect("Test: failed to build 9x2 fixture");
        let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0]);

        let (_over_data, over_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::RandomOversample,
            Some(42),
        )
        .expect("Test: multi-class RandomOversample failed");
        assert_eq!(count_class(&over_targets, 0.0), 4);
        assert_eq!(count_class(&over_targets, 1.0), 4);
        assert_eq!(count_class(&over_targets, 2.0), 4);

        let (_under_data, under_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::RandomUndersample,
            Some(42),
        )
        .expect("Test: multi-class RandomUndersample failed");
        assert_eq!(count_class(&under_targets, 0.0), 2);
        assert_eq!(count_class(&under_targets, 1.0), 2);
        assert_eq!(count_class(&under_targets, 2.0), 2);
    }
}