use std::collections::HashMap;
/// Strategy used by [`compute_sample_weights`] to turn class labels into
/// per-sample weights.
///
/// The default strategy is [`ClassWeight::Uniform`].
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum ClassWeight {
    /// Every sample receives the same weight of `1.0`.
    #[default]
    Uniform,
    /// Each sample of a class receives `n / (n_classes * class_count)`, so
    /// classes with fewer samples get larger weights.
    Balanced,
    /// Explicit weight per class index. Classes missing from the map are
    /// weighted `1.0`.
    Custom(HashMap<usize, f64>),
}
/// Computes one weight per sample according to the chosen [`ClassWeight`] strategy.
///
/// Every target is converted to a class index with an `as usize` cast: the
/// fractional part is discarded, and negative or NaN targets end up on class `0`.
///
/// * `Uniform` – every sample is weighted `1.0`.
/// * `Balanced` – a sample of class `c` is weighted
///   `n / (n_classes * count(c))`, where `n` is `targets.len()`, `n_classes`
///   is the number of distinct class indices present in `targets`, and
///   `count(c)` is how many targets map to `c`.
/// * `Custom` – the weight is looked up in the supplied map; class indices
///   that are not in the map are weighted `1.0`.
///
/// The returned vector has the same length and order as `targets`; an empty
/// slice yields an empty vector.
pub fn compute_sample_weights(targets: &[f64], class_weight: &ClassWeight) -> Vec<f64> {
    // Single place for the float -> class-index conversion used by all arms.
    let class_of = |target: f64| target as usize;

    match class_weight {
        ClassWeight::Uniform => targets.iter().map(|_| 1.0).collect(),
        ClassWeight::Balanced => {
            // How many samples fall into each class index.
            let frequency = targets
                .iter()
                .fold(HashMap::<usize, usize>::new(), |mut acc, &target| {
                    *acc.entry(class_of(target)).or_default() += 1;
                    acc
                });
            let total = targets.len() as f64;
            let distinct = frequency.len() as f64;

            targets
                .iter()
                .map(|&target| match frequency.get(&class_of(target)) {
                    Some(&occurrences) => total / (distinct * occurrences as f64),
                    // Every target was counted above, so this arm is only a fallback.
                    None => 1.0,
                })
                .collect()
        }
        ClassWeight::Custom(map) => {
            let mut weights = Vec::with_capacity(targets.len());
            for &target in targets {
                let weight = match map.get(&class_of(target)) {
                    Some(&w) => w,
                    None => 1.0,
                };
                weights.push(weight);
            }
            weights
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Uniform` yields one weight of exactly 1.0 per sample.
    #[test]
    fn test_uniform_weights() {
        let weights = compute_sample_weights(&[0.0, 0.0, 1.0, 1.0, 2.0], &ClassWeight::Uniform);
        assert_eq!(weights.len(), 5);
        for w in weights {
            assert!((w - 1.0).abs() < 1e-12);
        }
    }

    /// With two equally sized classes, `Balanced` degenerates to weight 1.0.
    #[test]
    fn test_balanced_weights_equal_classes() {
        let targets: Vec<f64> = [vec![0.0; 5], vec![1.0; 5]].concat();
        let weights = compute_sample_weights(&targets, &ClassWeight::Balanced);
        weights
            .iter()
            .for_each(|&w| assert!((w - 1.0).abs() < 1e-6, "expected 1.0, got {w}"));
    }

    /// A 90/10 split: the majority gets 100 / (2 * 90), the minority 100 / (2 * 10).
    #[test]
    fn test_balanced_weights_imbalanced() {
        let targets: Vec<f64> = [vec![0.0; 90], vec![1.0; 10]].concat();
        let weights = compute_sample_weights(&targets, &ClassWeight::Balanced);
        // Index 0 is a majority-class sample, index 90 the first minority one.
        let (w0, w1) = (weights[0], weights[90]);
        let expected_majority = 100.0 / 180.0;
        assert!(
            (w0 - expected_majority).abs() < 1e-6,
            "majority weight: expected {}, got {w0}",
            expected_majority
        );
        assert!(
            (w1 - 5.0).abs() < 1e-6,
            "minority weight: expected 5.0, got {w1}"
        );
        assert!(w1 > w0 * 5.0);
    }

    /// `Custom` returns exactly the weights stored for each class.
    #[test]
    fn test_custom_weights() {
        let strategy = ClassWeight::Custom(HashMap::from([(0, 1.0), (1, 10.0)]));
        let weights = compute_sample_weights(&[0.0, 0.0, 1.0, 1.0], &strategy);
        assert!((weights[0] - 1.0).abs() < 1e-12);
        assert!((weights[2] - 10.0).abs() < 1e-12);
    }

    /// A class that is absent from the custom map falls back to weight 1.0.
    #[test]
    fn test_custom_weights_missing_class_defaults_to_one() {
        let strategy = ClassWeight::Custom(HashMap::from([(1, 5.0)]));
        let weights = compute_sample_weights(&[0.0, 1.0], &strategy);
        assert!((weights[0] - 1.0).abs() < 1e-12);
        assert!((weights[1] - 5.0).abs() < 1e-12);
    }
}