use imbalanced_core::traits::*;
use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use rand::prelude::*;
use std::collections::HashMap;
/// Random under-sampler: balances a labelled dataset by discarding randomly
/// chosen rows from classes that have more rows than their target.
///
/// How many rows each class keeps is decided by the stored [`SamplingMode`].
#[derive(Clone, Debug)]
pub struct RandomUnderSampler {
    /// Rule used to compute the number of rows kept per class.
    sampling_strategy: SamplingMode,
}
/// Strategy for choosing how many rows of each class survive under-sampling.
#[derive(Clone, Debug)]
pub enum SamplingMode {
    /// Every class is reduced to the size of the smallest class.
    Auto,
    /// Every class is capped at `ratio * (size of the largest class)`,
    /// truncated toward zero; classes already below the cap are untouched.
    Ratio(f64),
    /// Explicit per-class row counts keyed by label. Labels absent from the
    /// map keep all of their rows; a count above the class's size is an error.
    Targets(HashMap<i32, usize>),
}
/// Per-call options for [`RandomUnderSampler`] resampling.
#[derive(Clone, Debug)]
pub struct RandomUnderSamplerConfig {
    /// Sampling strategy carried by the config.
    ///
    /// NOTE(review): in this file `resample` and `calculate_target_counts` read
    /// the sampler's own `sampling_strategy`, not this field. Confirm whether
    /// it is consumed elsewhere.
    pub sampling_strategy: SamplingMode,
    /// Seed for the RNG. `Some(seed)` seeds `StdRng` deterministically;
    /// `None` seeds it from entropy.
    pub random_state: Option<u64>,
    /// When `true`, rows of a class that must shrink are drawn with replacement
    /// (duplicates possible). When `false`, a random subset is kept.
    pub replacement: bool,
}
impl Default for RandomUnderSamplerConfig {
fn default() -> Self {
Self {
sampling_strategy: SamplingMode::Auto,
random_state: None,
replacement: false,
}
}
}
impl RandomUnderSampler {
    /// Creates a sampler using [`SamplingMode::Auto`]: every class is reduced to
    /// the size of the smallest class.
    pub fn new() -> Self {
        Self {
            sampling_strategy: SamplingMode::Auto,
        }
    }

    /// Creates a sampler that caps every class at `ratio * (largest class size)`.
    ///
    /// `ratio` is clamped to `[0.0, 1.0]`. Note that `f64::clamp` leaves NaN as
    /// NaN; a NaN ratio ends up removing every row, because a float-to-int `as`
    /// cast maps NaN to 0.
    pub fn with_ratio(ratio: f64) -> Self {
        Self {
            sampling_strategy: SamplingMode::Ratio(ratio.clamp(0.0, 1.0)),
        }
    }

    /// Creates a sampler with explicit per-class target counts.
    ///
    /// Classes missing from `targets` keep all of their rows. A target larger
    /// than a class's row count makes resampling fail with a config error.
    pub fn with_targets(targets: HashMap<i32, usize>) -> Self {
        Self {
            sampling_strategy: SamplingMode::Targets(targets),
        }
    }

    /// Computes how many rows each class should keep, given the observed
    /// per-class row counts.
    ///
    /// # Errors
    /// - `ResamplingError::InvalidInput` if `class_counts` is empty (no classes
    ///   to derive a minimum or maximum from).
    /// - `ResamplingError::ConfigError` if an explicit target exceeds the number
    ///   of rows available for that class.
    fn calculate_target_counts(
        &self,
        class_counts: &HashMap<i32, usize>,
        _config: &RandomUnderSamplerConfig,
    ) -> Result<HashMap<i32, usize>, ResamplingError> {
        match &self.sampling_strategy {
            SamplingMode::Auto => {
                let min_count = class_counts.values().min().copied().ok_or_else(|| {
                    ResamplingError::InvalidInput("No classes to resample".to_string())
                })?;
                Ok(class_counts.keys().map(|&class| (class, min_count)).collect())
            }
            SamplingMode::Ratio(ratio) => {
                let max_count = class_counts.values().max().copied().ok_or_else(|| {
                    ResamplingError::InvalidInput("No classes to resample".to_string())
                })?;
                // Float-to-int `as` truncates toward zero and saturates (NaN -> 0).
                let target_count = (max_count as f64 * *ratio) as usize;
                Ok(class_counts
                    .iter()
                    .map(|(&class, &count)| (class, target_count.min(count)))
                    .collect())
            }
            SamplingMode::Targets(targets) => {
                // Visit classes in label order so that, when several targets are
                // invalid, the reported class does not depend on HashMap order.
                let mut classes: Vec<(i32, usize)> =
                    class_counts.iter().map(|(&c, &n)| (c, n)).collect();
                classes.sort_unstable();

                let mut result = HashMap::with_capacity(classes.len());
                for (class, original_count) in classes {
                    let target_count = targets.get(&class).copied().unwrap_or(original_count);
                    if target_count > original_count {
                        return Err(ResamplingError::ConfigError(
                            format!("Target count {} exceeds original count {} for class {}",
                                target_count, original_count, class)
                        ));
                    }
                    result.insert(class, target_count);
                }
                Ok(result)
            }
        }
    }

    /// Picks `target_count` row indices from `class_indices`.
    ///
    /// If `target_count >= class_indices.len()`, all indices are returned
    /// unchanged and no random draws are made. Otherwise, with `replacement`
    /// the indices are drawn uniformly and independently, so duplicates are
    /// possible. Without it, the list is shuffled and truncated, which yields a
    /// random subset with no duplicates.
    fn sample_indices(
        &self,
        class_indices: &[usize],
        target_count: usize,
        replacement: bool,
        rng: &mut StdRng,
    ) -> Vec<usize> {
        if target_count >= class_indices.len() {
            return class_indices.to_vec();
        }
        if replacement {
            (0..target_count)
                .map(|_| class_indices[rng.gen_range(0..class_indices.len())])
                .collect()
        } else {
            let mut indices = class_indices.to_vec();
            indices.shuffle(rng);
            indices.truncate(target_count);
            indices
        }
    }
}
impl Default for RandomUnderSampler {
fn default() -> Self {
Self::new()
}
}
impl ResamplingStrategy for RandomUnderSampler {
    type Input = ();
    type Output = (Array2<f64>, Array1<i32>);
    type Config = RandomUnderSamplerConfig;

    /// Randomly drops rows so each class keeps the count chosen by the
    /// sampler's [`SamplingMode`].
    ///
    /// Surviving rows are returned in their original relative order. With
    /// `config.random_state = Some(seed)`, the selection is reproducible across
    /// runs.
    ///
    /// # Errors
    /// - `ResamplingError::InvalidInput` if `x` and `y` have different numbers
    ///   of rows, or if `y` contains fewer than two distinct labels.
    /// - Any error returned while computing per-class targets, e.g. a
    ///   `ConfigError` when an explicit target exceeds a class's size.
    fn resample(
        &self,
        x: ArrayView2<f64>,
        y: ArrayView1<i32>,
        config: &Self::Config,
    ) -> Result<(Array2<f64>, Array1<i32>), ResamplingError> {
        if x.nrows() != y.len() {
            return Err(ResamplingError::InvalidInput(
                "Feature matrix and target array must have same number of samples".to_string()
            ));
        }

        // Group row indices by label in one pass; class sizes follow directly.
        let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();
        for (idx, &label) in y.iter().enumerate() {
            class_indices.entry(label).or_default().push(idx);
        }
        if class_indices.len() < 2 {
            return Err(ResamplingError::InvalidInput(
                "Need at least 2 classes for resampling".to_string()
            ));
        }
        let class_counts: HashMap<i32, usize> = class_indices
            .iter()
            .map(|(&class, indices)| (class, indices.len()))
            .collect();

        let target_counts = self.calculate_target_counts(&class_counts, config)?;
        let mut rng = match config.random_state {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::from_entropy(),
        };

        // std HashMap iteration order is randomized per map instance. Visiting
        // classes in label order keeps the sequence of RNG draws fixed, so the
        // same seed yields the same selection on every run.
        let mut plan: Vec<(i32, usize)> = target_counts
            .iter()
            .map(|(&class, &target)| (class, target))
            .collect();
        plan.sort_unstable();

        let mut selected_indices = Vec::new();
        for (class, target_count) in plan {
            if let Some(indices) = class_indices.get(&class) {
                let sampled = self.sample_indices(indices, target_count, config.replacement, &mut rng);
                selected_indices.extend(sampled);
            }
        }
        // Restore original row order (duplicates from replacement sampling stay adjacent).
        selected_indices.sort_unstable();

        let n_samples = selected_indices.len();
        let n_features = x.ncols();
        let mut resampled_x = Array2::zeros((n_samples, n_features));
        let mut resampled_y = Array1::zeros(n_samples);
        for (new_idx, &original_idx) in selected_indices.iter().enumerate() {
            resampled_x.row_mut(new_idx).assign(&x.row(original_idx));
            resampled_y[new_idx] = y[original_idx];
        }
        Ok((resampled_x, resampled_y))
    }

    fn performance_hints(&self) -> PerformanceHints {
        PerformanceHints::new()
            .with_hint(PerformanceHint::CacheFriendly)
    }
}