use crate::error::{DatasetsError, Result};
use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_core::random::prelude::*;
use scirs2_core::random::rand_distributions::Distribution;
/// Generates a synthetic two-class classification dataset with a class
/// imbalance: `n_majority` rows labelled 0 drawn from a standard normal,
/// followed by `n_minority` rows labelled 1 whose features are shifted by
/// `imbalance_ratio` (acting as the class-separation offset).
///
/// Returns `(features, labels)` with shape `(n_majority + n_minority, n_features)`.
///
/// # Errors
/// Returns `DatasetsError::InvalidFormat` when any of the three counts is
/// zero, and `DatasetsError::ComputationError` if the normal distribution
/// cannot be constructed.
pub fn imbalanced_classification(
    n_majority: usize,
    n_minority: usize,
    n_features: usize,
    imbalance_ratio: f64,
    rng: &mut StdRng,
) -> Result<(Array2<f64>, Array1<usize>)> {
    // Guard clauses: every dimension must be at least 1.
    for (value, name) in [
        (n_majority, "n_majority"),
        (n_minority, "n_minority"),
        (n_features, "n_features"),
    ] {
        if value == 0 {
            return Err(DatasetsError::InvalidFormat(format!(
                "imbalanced_classification: {name} must be >= 1"
            )));
        }
    }
    let total = n_majority + n_minority;
    let normal = scirs2_core::random::Normal::new(0.0_f64, 1.0_f64).map_err(|e| {
        DatasetsError::ComputationError(format!("Normal distribution failed: {e}"))
    })?;
    let mut features = Array2::zeros((total, n_features));
    let mut labels = Array1::zeros(total);
    // Rows are filled in order: majority block first, then the shifted
    // minority block (same RNG consumption order as row-major filling).
    for row in 0..total {
        let label = usize::from(row >= n_majority);
        for col in 0..n_features {
            features[[row, col]] = if label == 1 {
                normal.sample(rng) + imbalance_ratio
            } else {
                normal.sample(rng)
            };
        }
        labels[row] = label;
    }
    Ok((features, labels))
}
/// Generates `n_synthetic` SMOTE-style samples from `minority_samples`.
///
/// For each synthetic row: pick a random anchor row, find its `k` nearest
/// neighbours (squared Euclidean distance), pick one of those neighbours at
/// random, and linearly interpolate between anchor and neighbour with a
/// uniform random coefficient in [0, 1).
///
/// # Errors
/// Returns `DatasetsError::InvalidFormat` when there are fewer than 2 rows,
/// `k` is 0, `k >= nrows`, or (for a non-zero request) the matrix has no
/// columns; `DatasetsError::ComputationError` if a distribution fails to
/// construct.
pub fn synthetic_smote(
    minority_samples: &Array2<f64>,
    k: usize,
    n_synthetic: usize,
    rng: &mut StdRng,
) -> Result<Array2<f64>> {
    let n_rows = minority_samples.nrows();
    let n_cols = minority_samples.ncols();
    if n_rows < 2 {
        return Err(DatasetsError::InvalidFormat(
            "synthetic_smote: minority_samples must have at least 2 rows".to_string(),
        ));
    }
    if k == 0 {
        return Err(DatasetsError::InvalidFormat(
            "synthetic_smote: k must be >= 1".to_string(),
        ));
    }
    if k >= n_rows {
        return Err(DatasetsError::InvalidFormat(format!(
            "synthetic_smote: k ({k}) must be < minority_samples.nrows() ({n_rows})"
        )));
    }
    // Nothing requested: return an empty matrix with the right column count.
    if n_synthetic == 0 {
        return Ok(Array2::zeros((0, n_cols)));
    }
    if n_cols == 0 {
        return Err(DatasetsError::InvalidFormat(
            "synthetic_smote: minority_samples must have at least 1 column".to_string(),
        ));
    }
    let alpha_dist = scirs2_core::random::Uniform::new(0.0_f64, 1.0_f64).map_err(|e| {
        DatasetsError::ComputationError(format!("Uniform distribution failed: {e}"))
    })?;
    let anchor_dist =
        scirs2_core::random::Uniform::new(0usize, n_rows).map_err(|e| {
            DatasetsError::ComputationError(format!("Uniform index distribution failed: {e}"))
        })?;
    let neighbour_dist =
        scirs2_core::random::Uniform::new(0usize, k).map_err(|e| {
            DatasetsError::ComputationError(format!(
                "Uniform neighbour index distribution failed: {e}"
            ))
        })?;
    let mut synthetic = Array2::zeros((n_synthetic, n_cols));
    for sample_idx in 0..n_synthetic {
        // RNG draw order is: anchor index, neighbour position, alpha.
        let base_idx = anchor_dist.sample(rng);
        let base = minority_samples.row(base_idx);
        // Squared Euclidean distance from the anchor to every other row.
        let mut neighbours: Vec<(f64, usize)> = Vec::with_capacity(n_rows - 1);
        for other in (0..n_rows).filter(|&i| i != base_idx) {
            let candidate = minority_samples.row(other);
            let sq_dist = base
                .iter()
                .zip(candidate.iter())
                .fold(0.0_f64, |acc, (a, b)| acc + (a - b) * (a - b));
            neighbours.push((sq_dist, other));
        }
        // NaN distances compare Equal, matching the original sort behaviour.
        neighbours.sort_unstable_by(|lhs, rhs| {
            lhs.0.partial_cmp(&rhs.0).unwrap_or(std::cmp::Ordering::Equal)
        });
        // Uniformly pick one of the k nearest rows (k <= n_rows - 1).
        let chosen_idx = neighbours[neighbour_dist.sample(rng)].1;
        let chosen = minority_samples.row(chosen_idx);
        let alpha = alpha_dist.sample(rng);
        for col in 0..n_cols {
            synthetic[[sample_idx, col]] = base[col] + alpha * (chosen[col] - base[col]);
        }
    }
    Ok(synthetic)
}
/// Randomly duplicates minority-class rows (with replacement) until the
/// minority/majority count ratio reaches `target_ratio`, which is clamped
/// to [0, 1]. Original rows are preserved in order; duplicated rows are
/// appended at the end.
///
/// Returns the input unchanged (as owned copies) when the dataset already
/// satisfies the requested ratio.
///
/// # Errors
/// Returns `DatasetsError::InvalidFormat` when `x`/`y` are inconsistent or
/// the labels are not strictly binary with both classes present, and
/// `DatasetsError::ComputationError` if the index distribution fails.
pub fn random_oversample(
    x: &Array2<f64>,
    y: &Array1<usize>,
    target_ratio: f64,
    rng: &mut StdRng,
) -> Result<(Array2<f64>, Array1<usize>)> {
    validate_xy(x, y, "random_oversample")?;
    let target_ratio = target_ratio.clamp(0.0, 1.0);
    let (maj_indices, min_indices) = split_binary_classes(y, "random_oversample")?;
    let n_maj = maj_indices.len();
    let n_min = min_indices.len();
    let desired_n_min = (n_maj as f64 * target_ratio).round() as usize;
    if desired_n_min <= n_min {
        // Already at or above the requested ratio — nothing to do.
        return Ok((x.to_owned(), y.to_owned()));
    }
    let n_extra = desired_n_min - n_min;
    let p = x.ncols();
    let n_out = x.nrows() + n_extra;
    let mut x_out = Array2::zeros((n_out, p));
    let mut y_out = Array1::zeros(n_out);
    // Copy the original rows and labels unchanged.
    for i in 0..x.nrows() {
        for j in 0..p {
            x_out[[i, j]] = x[[i, j]];
        }
        y_out[i] = y[i];
    }
    let min_dist =
        scirs2_core::random::Uniform::new(0usize, n_min).map_err(|e| {
            DatasetsError::ComputationError(format!("Uniform distribution failed: {e}"))
        })?;
    for extra in 0..n_extra {
        let src_idx = min_indices[min_dist.sample(rng)];
        let row_out = x.nrows() + extra;
        for j in 0..p {
            x_out[[row_out, j]] = x[[src_idx, j]];
        }
        // BUG FIX: `split_binary_classes` orders classes by size, not by
        // label, so the minority class is not necessarily labelled 1.
        // Copy the source row's actual label instead of hard-coding 1.
        y_out[row_out] = y[src_idx];
    }
    Ok((x_out, y_out))
}
/// Randomly discards majority-class rows until the minority/majority count
/// ratio reaches `target_ratio` (clamped to [0, 1]). All minority rows are
/// kept, and the surviving rows preserve their original relative order.
///
/// Returns the input unchanged (as owned copies) when no rows need to be
/// dropped.
///
/// # Errors
/// Returns `DatasetsError::InvalidFormat` when `x`/`y` are inconsistent or
/// the labels are not strictly binary with both classes present.
pub fn random_undersample(
    x: &Array2<f64>,
    y: &Array1<usize>,
    target_ratio: f64,
    rng: &mut StdRng,
) -> Result<(Array2<f64>, Array1<usize>)> {
    validate_xy(x, y, "random_undersample")?;
    let ratio = target_ratio.max(0.0).min(1.0);
    let (maj_indices, min_indices) = split_binary_classes(y, "random_undersample")?;
    let n_maj = maj_indices.len();
    // A ratio of 0 means "drop every majority row".
    let wanted_majority = if ratio > 0.0 {
        (min_indices.len() as f64 / ratio).ceil() as usize
    } else {
        0
    };
    let keep_majority = wanted_majority.min(n_maj);
    if keep_majority == n_maj {
        return Ok((x.to_owned(), y.to_owned()));
    }
    // Shuffle the majority indices and keep a random prefix.
    let mut kept: Vec<usize> = {
        use scirs2_core::random::SliceRandom;
        let mut pool = maj_indices;
        pool.shuffle(rng);
        pool.truncate(keep_majority);
        pool
    };
    kept.extend_from_slice(&min_indices);
    // Restore the original row order for the surviving indices.
    kept.sort_unstable();
    let n_cols = x.ncols();
    let mut x_out = Array2::zeros((kept.len(), n_cols));
    let mut y_out = Array1::zeros(kept.len());
    for (dst, &src) in kept.iter().enumerate() {
        for col in 0..n_cols {
            x_out[[dst, col]] = x[[src, col]];
        }
        y_out[dst] = y[src];
    }
    Ok((x_out, y_out))
}
/// Computes the binary F1 score, treating label 1 as the positive class.
/// Labels other than 0/1 are silently ignored. Returns 0.0 when either
/// precision or recall is undefined (no predicted / no actual positives).
///
/// # Errors
/// Returns `DatasetsError::InvalidFormat` when the arrays differ in length.
pub fn f1_score(y_true: &Array1<usize>, y_pred: &Array1<usize>) -> Result<f64> {
    check_same_len(y_true.len(), y_pred.len(), "f1_score")?;
    // Tally the three confusion cells that F1 depends on.
    let (mut true_pos, mut false_pos, mut false_neg) = (0.0_f64, 0.0_f64, 0.0_f64);
    for (&actual, &predicted) in y_true.iter().zip(y_pred.iter()) {
        if predicted == 1 {
            if actual == 1 {
                true_pos += 1.0;
            } else {
                false_pos += 1.0;
            }
        } else if actual == 1 {
            false_neg += 1.0;
        }
    }
    let precision = if true_pos + false_pos > 0.0 {
        true_pos / (true_pos + false_pos)
    } else {
        0.0
    };
    let recall = if true_pos + false_neg > 0.0 {
        true_pos / (true_pos + false_neg)
    } else {
        0.0
    };
    let denom = precision + recall;
    Ok(if denom > 0.0 {
        2.0 * precision * recall / denom
    } else {
        0.0
    })
}
/// Computes balanced accuracy: the unweighted mean of per-class recall over
/// the distinct classes present in `y_true` (classes appearing only in
/// `y_pred` are ignored, matching the usual definition).
///
/// # Errors
/// Returns `DatasetsError::InvalidFormat` when the arrays differ in length
/// or `y_true` is empty.
pub fn balanced_accuracy(y_true: &Array1<usize>, y_pred: &Array1<usize>) -> Result<f64> {
    check_same_len(y_true.len(), y_pred.len(), "balanced_accuracy")?;
    if y_true.is_empty() {
        return Err(DatasetsError::InvalidFormat(
            "balanced_accuracy: y_true must not be empty".to_string(),
        ));
    }
    // Distinct classes, taken from the ground truth only.
    let mut classes: Vec<usize> = y_true.to_vec();
    classes.sort_unstable();
    classes.dedup();
    let recall_sum: f64 = classes
        .iter()
        .map(|&class| {
            let mut support = 0usize;
            let mut hits = 0usize;
            for (&actual, &predicted) in y_true.iter().zip(y_pred.iter()) {
                if actual == class {
                    support += 1;
                    if predicted == class {
                        hits += 1;
                    }
                }
            }
            // Every class came from y_true, so support > 0 in practice;
            // the guard keeps the arithmetic total even so.
            if support > 0 {
                hits as f64 / support as f64
            } else {
                0.0
            }
        })
        .sum();
    Ok(recall_sum / classes.len() as f64)
}
/// Builds an `n_classes` x `n_classes` confusion matrix where rows index the
/// true label and columns index the predicted label.
///
/// # Errors
/// Returns `DatasetsError::InvalidFormat` when the arrays differ in length,
/// `n_classes` is zero, or any label falls outside `[0, n_classes)`.
pub fn confusion_matrix(
    y_true: &Array1<usize>,
    y_pred: &Array1<usize>,
    n_classes: usize,
) -> Result<Array2<usize>> {
    check_same_len(y_true.len(), y_pred.len(), "confusion_matrix")?;
    if n_classes == 0 {
        return Err(DatasetsError::InvalidFormat(
            "confusion_matrix: n_classes must be >= 1".to_string(),
        ));
    }
    let mut matrix = Array2::zeros((n_classes, n_classes));
    for (&actual, &predicted) in y_true.iter().zip(y_pred.iter()) {
        // Reject out-of-range labels rather than panicking on the index.
        if actual >= n_classes {
            return Err(DatasetsError::InvalidFormat(format!(
                "confusion_matrix: true label {actual} >= n_classes {n_classes}"
            )));
        }
        if predicted >= n_classes {
            return Err(DatasetsError::InvalidFormat(format!(
                "confusion_matrix: predicted label {predicted} >= n_classes {n_classes}"
            )));
        }
        matrix[[actual, predicted]] += 1;
    }
    Ok(matrix)
}
/// Computes the area under the ROC curve for binary labels (0/1) and
/// continuous scores, using the trapezoidal rule over the ROC curve.
///
/// Samples are swept in descending score order; tied scores are collapsed
/// into a single ROC vertex, so ties contribute one trapezoid (e.g. all
/// scores equal yields exactly 0.5). When only one class is present the
/// AUC is undefined and 0.5 is returned by convention.
///
/// # Errors
/// Returns `DatasetsError::InvalidFormat` when the arrays differ in length
/// or are empty.
pub fn roc_auc(y_true: &Array1<usize>, y_score: &Array1<f64>) -> Result<f64> {
    check_same_len(y_true.len(), y_score.len(), "roc_auc")?;
    if y_true.is_empty() {
        return Err(DatasetsError::InvalidFormat(
            "roc_auc: arrays must not be empty".to_string(),
        ));
    }
    let n_pos: usize = y_true.iter().filter(|&&v| v == 1).count();
    let n_neg: usize = y_true.iter().filter(|&&v| v == 0).count();
    // Degenerate input: with no positives or no negatives the ROC curve is
    // undefined; return the chance-level value.
    if n_pos == 0 || n_neg == 0 {
        return Ok(0.5);
    }
    // Indices sorted by descending score; NaN scores compare Equal.
    let mut order: Vec<usize> = (0..y_true.len()).collect();
    order.sort_unstable_by(|&a, &b| {
        y_score[b]
            .partial_cmp(&y_score[a])
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    let mut auc = 0.0_f64;
    // Running positive/negative counts as the threshold sweeps down.
    let mut tp = 0.0_f64;
    let mut fp = 0.0_f64;
    // prev_score sentinel: INFINITY marks "before the first sample", so the
    // `is_finite()` guard skips emitting a trapezoid on the first score.
    let mut prev_score = f64::INFINITY;
    // Counts at the last emitted ROC vertex (start of the current tie group).
    let mut prev_tp = 0.0_f64;
    let mut prev_fp = 0.0_f64;
    for &idx in &order {
        let score = y_score[idx];
        // A score change closes the previous tie group: add the trapezoid
        // from the last vertex to the current (tp, fp) point.
        if score != prev_score && prev_score.is_finite() {
            let tpr_cur = tp / n_pos as f64;
            let fpr_cur = fp / n_neg as f64;
            let tpr_prev = prev_tp / n_pos as f64;
            let fpr_prev = prev_fp / n_neg as f64;
            auc += (fpr_cur - fpr_prev) * (tpr_cur + tpr_prev) / 2.0;
            prev_tp = tp;
            prev_fp = fp;
        }
        prev_score = score;
        // Counts are updated *after* the trapezoid check so tied scores
        // accumulate into one vertex.
        if y_true[idx] == 1 {
            tp += 1.0;
        } else {
            fp += 1.0;
        }
    }
    // Close the final segment up to the (1, 1) corner of the ROC curve.
    let tpr_cur = tp / n_pos as f64;
    let fpr_cur = fp / n_neg as f64;
    let tpr_prev = prev_tp / n_pos as f64;
    let fpr_prev = prev_fp / n_neg as f64;
    auc += (fpr_cur - fpr_prev) * (tpr_cur + tpr_prev) / 2.0;
    // FPR is non-decreasing along the sweep, so auc is already >= 0;
    // abs() is a belt-and-braces guard against float round-off.
    Ok(auc.abs())
}
/// Checks that `x` and `y` describe the same samples: one label per row and
/// a non-empty feature matrix. `fn_name` prefixes the error messages.
fn validate_xy(x: &Array2<f64>, y: &Array1<usize>, fn_name: &str) -> Result<()> {
    match (x.nrows(), y.len()) {
        (rows, labels) if rows != labels => Err(DatasetsError::InvalidFormat(format!(
            "{fn_name}: x.nrows() ({rows}) must equal y.len() ({labels})"
        ))),
        _ if x.is_empty() => Err(DatasetsError::InvalidFormat(format!(
            "{fn_name}: x must not be empty"
        ))),
        _ => Ok(()),
    }
}
/// Errors unless the two lengths match; `fn_name` prefixes the message.
fn check_same_len(a: usize, b: usize, fn_name: &str) -> Result<()> {
    if a == b {
        Ok(())
    } else {
        Err(DatasetsError::InvalidFormat(format!(
            "{fn_name}: arrays must have the same length, got {a} and {b}"
        )))
    }
}
/// Splits the sample indices of a strictly binary label vector into
/// `(majority_indices, minority_indices)`, ordered by class *size*, not by
/// label value — ties go to class 0 first. Both classes must be present.
///
/// # Errors
/// Returns `DatasetsError::InvalidFormat` for any label other than 0/1 or
/// when either class is absent.
fn split_binary_classes(
    y: &Array1<usize>,
    fn_name: &str,
) -> Result<(Vec<usize>, Vec<usize>)> {
    let mut zeros: Vec<usize> = Vec::new();
    let mut ones: Vec<usize> = Vec::new();
    for (idx, &label) in y.iter().enumerate() {
        if label == 0 {
            zeros.push(idx);
        } else if label == 1 {
            ones.push(idx);
        } else {
            return Err(DatasetsError::InvalidFormat(format!(
                "{fn_name}: label {label} is not 0 or 1"
            )));
        }
    }
    if zeros.is_empty() {
        return Err(DatasetsError::InvalidFormat(format!(
            "{fn_name}: class 0 is absent"
        )));
    }
    if ones.is_empty() {
        return Err(DatasetsError::InvalidFormat(format!(
            "{fn_name}: class 1 is absent"
        )));
    }
    // Larger class first; on a tie, class 0 is treated as the majority.
    if ones.len() > zeros.len() {
        Ok((ones, zeros))
    } else {
        Ok((zeros, ones))
    }
}
// Unit tests: shape/label contracts for the generators, balance invariants
// for the resamplers, and known-value checks for each metric.
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::array;
// Deterministic RNG so every test is reproducible.
fn make_rng(seed: u64) -> StdRng {
StdRng::seed_from_u64(seed)
}
// --- imbalanced_classification: shape, label counts, input validation ---
#[test]
fn test_imbalanced_shape() {
let mut rng = make_rng(42);
let (x, y) =
imbalanced_classification(900, 100, 4, 2.0, &mut rng).expect("gen failed");
assert_eq!(x.shape(), &[1000, 4]);
assert_eq!(y.len(), 1000);
}
#[test]
fn test_imbalanced_label_counts() {
let mut rng = make_rng(1);
let (_, y) =
imbalanced_classification(80, 20, 3, 1.5, &mut rng).expect("gen failed");
let n0: usize = y.iter().filter(|&&v| v == 0).count();
let n1: usize = y.iter().filter(|&&v| v == 1).count();
assert_eq!(n0, 80);
assert_eq!(n1, 20);
}
#[test]
fn test_imbalanced_error_n_majority_zero() {
let mut rng = make_rng(1);
assert!(imbalanced_classification(0, 10, 3, 1.0, &mut rng).is_err());
}
#[test]
fn test_imbalanced_error_n_minority_zero() {
let mut rng = make_rng(1);
assert!(imbalanced_classification(10, 0, 3, 1.0, &mut rng).is_err());
}
#[test]
fn test_imbalanced_error_n_features_zero() {
let mut rng = make_rng(1);
assert!(imbalanced_classification(10, 5, 0, 1.0, &mut rng).is_err());
}
// --- synthetic_smote: shape, empty request, validation, interpolation ---
#[test]
fn test_smote_shape() {
let minority = Array2::from_shape_vec(
(5, 2),
vec![0.0, 0.0, 1.0, 0.5, 0.5, 1.0, 0.2, 0.8, 0.9, 0.1],
)
.expect("shape");
let mut rng = make_rng(7);
let syn = synthetic_smote(&minority, 2, 20, &mut rng).expect("smote");
assert_eq!(syn.shape(), &[20, 2]);
}
#[test]
fn test_smote_zero_synthetic() {
let minority = Array2::from_shape_vec(
(4, 2),
vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0],
)
.expect("shape");
let mut rng = make_rng(1);
let syn = synthetic_smote(&minority, 2, 0, &mut rng).expect("smote zero");
assert_eq!(syn.shape(), &[0, 2]);
}
#[test]
fn test_smote_error_too_few_samples() {
let minority = Array2::from_shape_vec((1, 2), vec![0.0, 0.0]).expect("shape");
let mut rng = make_rng(1);
assert!(synthetic_smote(&minority, 1, 5, &mut rng).is_err());
}
#[test]
fn test_smote_error_k_too_large() {
// k must be strictly less than the number of rows (here k == n == 3).
let minority = Array2::from_shape_vec((3, 2), vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0])
.expect("shape");
let mut rng = make_rng(1);
assert!(synthetic_smote(&minority, 3, 5, &mut rng).is_err());
}
#[test]
fn test_smote_interpolation_bounds() {
// All inputs lie in [0, 1], so every interpolated sample must too.
let minority = Array2::from_shape_vec(
(4, 1),
vec![0.0, 0.25, 0.75, 1.0],
)
.expect("shape");
let mut rng = make_rng(99);
let syn = synthetic_smote(&minority, 2, 100, &mut rng).expect("smote bounds");
for v in syn.iter() {
assert!(
*v >= -1e-9 && *v <= 1.0 + 1e-9,
"SMOTE sample {v} out of [0,1]"
);
}
}
// --- random_oversample: balance, no-op path, validation ---
#[test]
fn test_oversample_balance() {
let mut rng = make_rng(1);
let (x, y) =
imbalanced_classification(90, 10, 3, 1.5, &mut rng).expect("gen");
let mut rng2 = make_rng(2);
let (_, y2) = random_oversample(&x, &y, 1.0, &mut rng2).expect("oversample");
let n0: usize = y2.iter().filter(|&&v| v == 0).count();
let n1: usize = y2.iter().filter(|&&v| v == 1).count();
assert_eq!(n0, n1, "oversample should balance classes");
}
#[test]
fn test_oversample_already_balanced() {
let mut rng = make_rng(1);
let (x, y) =
imbalanced_classification(50, 50, 3, 1.0, &mut rng).expect("gen");
let mut rng2 = make_rng(2);
let (x2, y2) = random_oversample(&x, &y, 1.0, &mut rng2).expect("oversample noop");
assert_eq!(x2.nrows(), x.nrows());
assert_eq!(y2.len(), y.len());
}
#[test]
fn test_oversample_error_mismatch() {
// x has 10 rows but y has 9 labels.
let x = Array2::zeros((10, 3));
let y = Array1::zeros(9usize);
let mut rng = make_rng(1);
assert!(random_oversample(&x, &y, 1.0, &mut rng).is_err());
}
// --- random_undersample: balance and minority preservation ---
#[test]
fn test_undersample_balance() {
let mut rng = make_rng(1);
let (x, y) =
imbalanced_classification(900, 100, 3, 1.5, &mut rng).expect("gen");
let mut rng2 = make_rng(3);
let (_, y2) = random_undersample(&x, &y, 1.0, &mut rng2).expect("undersample");
let n0: usize = y2.iter().filter(|&&v| v == 0).count();
let n1: usize = y2.iter().filter(|&&v| v == 1).count();
assert_eq!(n0, n1, "undersample should balance classes");
}
#[test]
fn test_undersample_preserves_minority() {
let mut rng = make_rng(5);
let (x, y) =
imbalanced_classification(80, 20, 4, 1.5, &mut rng).expect("gen");
let mut rng2 = make_rng(6);
let (_, y2) = random_undersample(&x, &y, 1.0, &mut rng2).expect("undersample");
let n_min: usize = y2.iter().filter(|&&v| v == 1).count();
assert_eq!(n_min, 20, "minority class should be fully preserved");
}
// --- f1_score: perfect, degenerate, known value, validation ---
#[test]
fn test_f1_perfect() {
let y = array![0usize, 1, 1, 0, 1];
let f1 = f1_score(&y, &y).expect("perfect f1");
assert!((f1 - 1.0).abs() < 1e-9);
}
#[test]
fn test_f1_zero_precision_recall() {
let y_true = array![1usize, 1, 1];
let y_pred = array![0usize, 0, 0];
let f1 = f1_score(&y_true, &y_pred).expect("zero f1");
assert!((f1 - 0.0).abs() < 1e-9);
}
#[test]
fn test_f1_known_value() {
// tp=2, fp=1, fn=1 -> precision = recall = 2/3 -> f1 = 2/3.
let y_true = array![1usize, 0, 1, 1, 0];
let y_pred = array![1usize, 0, 0, 1, 1];
let f1 = f1_score(&y_true, &y_pred).expect("known f1");
assert!((f1 - 2.0 / 3.0).abs() < 1e-9, "f1={f1}");
}
#[test]
fn test_f1_error_length_mismatch() {
let a = array![1usize, 0];
let b = array![1usize, 0, 1];
assert!(f1_score(&a, &b).is_err());
}
// --- balanced_accuracy: perfect, known value, validation ---
#[test]
fn test_balanced_accuracy_perfect() {
let y = array![0usize, 1, 0, 1];
let ba = balanced_accuracy(&y, &y).expect("perfect ba");
assert!((ba - 1.0).abs() < 1e-9);
}
#[test]
fn test_balanced_accuracy_known() {
// recall(class 0) = 1/2, recall(class 1) = 1 -> mean = 0.75.
let y_true = array![0usize, 0, 1, 1];
let y_pred = array![0usize, 1, 1, 1];
let ba = balanced_accuracy(&y_true, &y_pred).expect("known ba");
assert!((ba - 0.75).abs() < 1e-9, "ba={ba}");
}
#[test]
fn test_balanced_accuracy_error_length_mismatch() {
let a = array![0usize, 1];
let b = array![0usize];
assert!(balanced_accuracy(&a, &b).is_err());
}
// --- confusion_matrix: binary, multi-class, out-of-range labels ---
#[test]
fn test_confusion_matrix_binary() {
let y_true = array![0usize, 1, 1, 0, 1];
let y_pred = array![0usize, 1, 0, 0, 1];
let cm = confusion_matrix(&y_true, &y_pred, 2).expect("cm binary");
assert_eq!(cm[[0, 0]], 2); assert_eq!(cm[[1, 1]], 2); assert_eq!(cm[[1, 0]], 1); assert_eq!(cm[[0, 1]], 0); }
#[test]
fn test_confusion_matrix_3class() {
let y_true = array![0usize, 1, 2, 0, 1];
let y_pred = array![0usize, 2, 2, 1, 1];
let cm = confusion_matrix(&y_true, &y_pred, 3).expect("cm 3-class");
assert_eq!(cm[[0, 0]], 1);
assert_eq!(cm[[1, 2]], 1);
assert_eq!(cm[[2, 2]], 1);
}
#[test]
fn test_confusion_matrix_error_out_of_range() {
// True label 3 is outside [0, 3).
let y_true = array![0usize, 1, 3]; let y_pred = array![0usize, 1, 2];
assert!(confusion_matrix(&y_true, &y_pred, 3).is_err());
}
// --- roc_auc: perfect separation, all-tied scores, degenerate, empty ---
#[test]
fn test_roc_auc_perfect() {
let y_true = array![0usize, 0, 1, 1];
let y_score = array![0.1_f64, 0.2, 0.8, 0.9];
let auc = roc_auc(&y_true, &y_score).expect("perfect auc");
assert!((auc - 1.0).abs() < 1e-9, "perfect auc={auc}");
}
#[test]
fn test_roc_auc_random() {
// All scores tied: the tie-collapsed curve is the diagonal, AUC = 0.5.
let y_true = array![0usize, 1, 0, 1, 0, 1];
let y_score = array![0.5_f64, 0.5, 0.5, 0.5, 0.5, 0.5];
let auc = roc_auc(&y_true, &y_score).expect("random auc");
assert!((auc - 0.5).abs() < 1e-9, "random auc={auc}");
}
#[test]
fn test_roc_auc_degenerate_no_positive() {
let y_true = array![0usize, 0, 0];
let y_score = array![0.1_f64, 0.5, 0.9];
let auc = roc_auc(&y_true, &y_score).expect("degenerate auc");
assert!((auc - 0.5).abs() < 1e-9);
}
#[test]
fn test_roc_auc_error_empty() {
let y_true: Array1<usize> = Array1::zeros(0);
let y_score: Array1<f64> = Array1::zeros(0);
assert!(roc_auc(&y_true, &y_score).is_err());
}
}