// imbalanced-sampling/src/random_undersampler.rs
use imbalanced_core::traits::*;
use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use rand::prelude::*;
use std::collections::{BTreeMap, HashMap};

7/// Random Under-sampling strategy
8/// 
9/// Reduces the majority class by randomly removing samples until
10/// the desired class balance is achieved.
11#[derive(Debug, Clone)]
12pub struct RandomUnderSampler {
13    sampling_strategy: SamplingMode,
14}
15
/// How [`RandomUnderSampler`] derives the target size of each class.
///
/// Every mode only ever removes samples; a class is never grown beyond the
/// number of samples it already has.
#[derive(Debug, Clone, PartialEq)]
pub enum SamplingMode {
    /// Shrink every class to the size of the smallest (minority) class.
    Auto,
    /// Cap every class at `floor(ratio * majority_count)` samples, where
    /// `majority_count` is the size of the largest class.
    ///
    /// The intended range is `0.0..=1.0`: `1.0` leaves the data unchanged,
    /// values below `minority_count / majority_count` shrink the minority
    /// class as well, and `0.0` removes every sample.
    /// [`RandomUnderSampler::with_ratio`] clamps its argument to that range.
    Ratio(f64),
    /// Explicit per-class target sizes, keyed by class label.
    ///
    /// Classes absent from the map keep all of their samples, and labels in
    /// the map that do not occur in the data are ignored. A target larger than
    /// the class's current size is rejected as a configuration error.
    Targets(HashMap<i32, usize>),
}

27/// Configuration for RandomUnderSampler
28#[derive(Debug, Clone)]
29pub struct RandomUnderSamplerConfig {
30    /// Sampling strategy
31    pub sampling_strategy: SamplingMode,
32    /// Random seed
33    pub random_state: Option<u64>,
34    /// Whether to replace samples (with replacement)
35    pub replacement: bool,
36}
37
38impl Default for RandomUnderSamplerConfig {
39    fn default() -> Self {
40        Self {
41            sampling_strategy: SamplingMode::Auto,
42            random_state: None,
43            replacement: false,
44        }
45    }
46}
47
48impl RandomUnderSampler {
49    /// Create a new RandomUnderSampler with auto balancing
50    pub fn new() -> Self {
51        Self {
52            sampling_strategy: SamplingMode::Auto,
53        }
54    }
55    
56    /// Create with specific ratio (0.0 to 1.0)
57    pub fn with_ratio(ratio: f64) -> Self {
58        Self {
59            sampling_strategy: SamplingMode::Ratio(ratio.clamp(0.0, 1.0)),
60        }
61    }
62    
63    /// Create with specific target counts
64    pub fn with_targets(targets: HashMap<i32, usize>) -> Self {
65        Self {
66            sampling_strategy: SamplingMode::Targets(targets),
67        }
68    }
69    
70    /// Calculate target sample counts for each class
71    fn calculate_target_counts(
72        &self,
73        class_counts: &HashMap<i32, usize>,
74        _config: &RandomUnderSamplerConfig,
75    ) -> Result<HashMap<i32, usize>, ResamplingError> {
76        match &self.sampling_strategy {
77            SamplingMode::Auto => {
78                // Balance to minority class size
79                let min_count = *class_counts.values().min().unwrap();
80                Ok(class_counts.keys().map(|&class| (class, min_count)).collect())
81            },
82            SamplingMode::Ratio(ratio) => {
83                // Balance to ratio of majority class
84                let max_count = *class_counts.values().max().unwrap();
85                let target_count = (max_count as f64 * ratio) as usize;
86                Ok(class_counts.keys().map(|&class| (class, target_count.min(class_counts[&class]))).collect())
87            },
88            SamplingMode::Targets(targets) => {
89                // Use provided targets, but don't exceed original counts
90                let mut result = HashMap::new();
91                for (&class, &original_count) in class_counts {
92                    let target_count = targets.get(&class).copied().unwrap_or(original_count);
93                    if target_count > original_count {
94                        return Err(ResamplingError::ConfigError(
95                            format!("Target count {} exceeds original count {} for class {}", 
96                                   target_count, original_count, class)
97                        ));
98                    }
99                    result.insert(class, target_count);
100                }
101                Ok(result)
102            }
103        }
104    }
105    
106    /// Sample indices for a given class
107    fn sample_indices(
108        &self,
109        class_indices: &[usize],
110        target_count: usize,
111        replacement: bool,
112        rng: &mut StdRng,
113    ) -> Vec<usize> {
114        if target_count >= class_indices.len() {
115            return class_indices.to_vec();
116        }
117        
118        if replacement {
119            // Sample with replacement
120            (0..target_count)
121                .map(|_| class_indices[rng.gen_range(0..class_indices.len())])
122                .collect()
123        } else {
124            // Sample without replacement
125            let mut indices = class_indices.to_vec();
126            indices.shuffle(rng);
127            indices.truncate(target_count);
128            indices
129        }
130    }
131}
132
133impl Default for RandomUnderSampler {
134    fn default() -> Self {
135        Self::new()
136    }
137}
138
139impl ResamplingStrategy for RandomUnderSampler {
140    type Input = ();
141    type Output = (Array2<f64>, Array1<i32>);
142    type Config = RandomUnderSamplerConfig;
143    
144    fn resample(
145        &self,
146        x: ArrayView2<f64>,
147        y: ArrayView1<i32>,
148        config: &Self::Config,
149    ) -> Result<(Array2<f64>, Array1<i32>), ResamplingError> {
150        if x.nrows() != y.len() {
151            return Err(ResamplingError::InvalidInput(
152                "Feature matrix and target array must have same number of samples".to_string()
153            ));
154        }
155        
156        // Count class frequencies
157        let mut class_counts = HashMap::new();
158        for &label in y.iter() {
159            *class_counts.entry(label).or_insert(0) += 1;
160        }
161        
162        if class_counts.len() < 2 {
163            return Err(ResamplingError::InvalidInput(
164                "Need at least 2 classes for resampling".to_string()
165            ));
166        }
167        
168        // Calculate target counts
169        let target_counts = self.calculate_target_counts(&class_counts, config)?;
170        
171        let mut rng = if let Some(seed) = config.random_state {
172            StdRng::seed_from_u64(seed)
173        } else {
174            StdRng::from_entropy()
175        };
176        
177        // Collect indices for each class
178        let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();
179        for (idx, &label) in y.iter().enumerate() {
180            class_indices.entry(label).or_default().push(idx);
181        }
182        
183        // Sample indices for each class
184        let mut selected_indices = Vec::new();
185        for (&class, &target_count) in &target_counts {
186            if let Some(indices) = class_indices.get(&class) {
187                let sampled = self.sample_indices(indices, target_count, config.replacement, &mut rng);
188                selected_indices.extend(sampled);
189            }
190        }
191        
192        // Sort indices to maintain some order consistency
193        selected_indices.sort_unstable();
194        
195        let n_samples = selected_indices.len();
196        let n_features = x.ncols();
197        
198        // Create resampled arrays
199        let mut resampled_x = Array2::zeros((n_samples, n_features));
200        let mut resampled_y = Array1::zeros(n_samples);
201        
202        for (new_idx, &original_idx) in selected_indices.iter().enumerate() {
203            resampled_x.row_mut(new_idx).assign(&x.row(original_idx));
204            resampled_y[new_idx] = y[original_idx];
205        }
206        
207        Ok((resampled_x, resampled_y))
208    }
209    
210    fn performance_hints(&self) -> PerformanceHints {
211        PerformanceHints::new()
212            .with_hint(PerformanceHint::CacheFriendly)
213    }
214}