imbalanced_sampling/
random_undersampler.rs1use imbalanced_core::traits::*;
3use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
4use rand::prelude::*;
5use std::collections::HashMap;
6
7#[derive(Debug, Clone)]
12pub struct RandomUnderSampler {
13 sampling_strategy: SamplingMode,
14}
15
/// Rule that determines how many samples each class keeps after under-sampling.
#[derive(Clone, Debug)]
pub enum SamplingMode {
    /// Reduce every class to the size of the smallest class.
    Auto,
    /// Keep at most `largest_class_count * ratio` samples per class (truncated
    /// toward zero), never more than the class originally has.
    Ratio(f64),
    /// Explicit number of samples to keep, keyed by class label. Classes that
    /// are absent from the map keep their original count.
    Targets(HashMap<i32, usize>),
}
26
27#[derive(Debug, Clone)]
29pub struct RandomUnderSamplerConfig {
30 pub sampling_strategy: SamplingMode,
32 pub random_state: Option<u64>,
34 pub replacement: bool,
36}
37
38impl Default for RandomUnderSamplerConfig {
39 fn default() -> Self {
40 Self {
41 sampling_strategy: SamplingMode::Auto,
42 random_state: None,
43 replacement: false,
44 }
45 }
46}
47
48impl RandomUnderSampler {
49 pub fn new() -> Self {
51 Self {
52 sampling_strategy: SamplingMode::Auto,
53 }
54 }
55
56 pub fn with_ratio(ratio: f64) -> Self {
58 Self {
59 sampling_strategy: SamplingMode::Ratio(ratio.clamp(0.0, 1.0)),
60 }
61 }
62
63 pub fn with_targets(targets: HashMap<i32, usize>) -> Self {
65 Self {
66 sampling_strategy: SamplingMode::Targets(targets),
67 }
68 }
69
70 fn calculate_target_counts(
72 &self,
73 class_counts: &HashMap<i32, usize>,
74 _config: &RandomUnderSamplerConfig,
75 ) -> Result<HashMap<i32, usize>, ResamplingError> {
76 match &self.sampling_strategy {
77 SamplingMode::Auto => {
78 let min_count = *class_counts.values().min().unwrap();
80 Ok(class_counts.keys().map(|&class| (class, min_count)).collect())
81 },
82 SamplingMode::Ratio(ratio) => {
83 let max_count = *class_counts.values().max().unwrap();
85 let target_count = (max_count as f64 * ratio) as usize;
86 Ok(class_counts.keys().map(|&class| (class, target_count.min(class_counts[&class]))).collect())
87 },
88 SamplingMode::Targets(targets) => {
89 let mut result = HashMap::new();
91 for (&class, &original_count) in class_counts {
92 let target_count = targets.get(&class).copied().unwrap_or(original_count);
93 if target_count > original_count {
94 return Err(ResamplingError::ConfigError(
95 format!("Target count {} exceeds original count {} for class {}",
96 target_count, original_count, class)
97 ));
98 }
99 result.insert(class, target_count);
100 }
101 Ok(result)
102 }
103 }
104 }
105
106 fn sample_indices(
108 &self,
109 class_indices: &[usize],
110 target_count: usize,
111 replacement: bool,
112 rng: &mut StdRng,
113 ) -> Vec<usize> {
114 if target_count >= class_indices.len() {
115 return class_indices.to_vec();
116 }
117
118 if replacement {
119 (0..target_count)
121 .map(|_| class_indices[rng.gen_range(0..class_indices.len())])
122 .collect()
123 } else {
124 let mut indices = class_indices.to_vec();
126 indices.shuffle(rng);
127 indices.truncate(target_count);
128 indices
129 }
130 }
131}
132
133impl Default for RandomUnderSampler {
134 fn default() -> Self {
135 Self::new()
136 }
137}
138
139impl ResamplingStrategy for RandomUnderSampler {
140 type Input = ();
141 type Output = (Array2<f64>, Array1<i32>);
142 type Config = RandomUnderSamplerConfig;
143
144 fn resample(
145 &self,
146 x: ArrayView2<f64>,
147 y: ArrayView1<i32>,
148 config: &Self::Config,
149 ) -> Result<(Array2<f64>, Array1<i32>), ResamplingError> {
150 if x.nrows() != y.len() {
151 return Err(ResamplingError::InvalidInput(
152 "Feature matrix and target array must have same number of samples".to_string()
153 ));
154 }
155
156 let mut class_counts = HashMap::new();
158 for &label in y.iter() {
159 *class_counts.entry(label).or_insert(0) += 1;
160 }
161
162 if class_counts.len() < 2 {
163 return Err(ResamplingError::InvalidInput(
164 "Need at least 2 classes for resampling".to_string()
165 ));
166 }
167
168 let target_counts = self.calculate_target_counts(&class_counts, config)?;
170
171 let mut rng = if let Some(seed) = config.random_state {
172 StdRng::seed_from_u64(seed)
173 } else {
174 StdRng::from_entropy()
175 };
176
177 let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();
179 for (idx, &label) in y.iter().enumerate() {
180 class_indices.entry(label).or_default().push(idx);
181 }
182
183 let mut selected_indices = Vec::new();
185 for (&class, &target_count) in &target_counts {
186 if let Some(indices) = class_indices.get(&class) {
187 let sampled = self.sample_indices(indices, target_count, config.replacement, &mut rng);
188 selected_indices.extend(sampled);
189 }
190 }
191
192 selected_indices.sort_unstable();
194
195 let n_samples = selected_indices.len();
196 let n_features = x.ncols();
197
198 let mut resampled_x = Array2::zeros((n_samples, n_features));
200 let mut resampled_y = Array1::zeros(n_samples);
201
202 for (new_idx, &original_idx) in selected_indices.iter().enumerate() {
203 resampled_x.row_mut(new_idx).assign(&x.row(original_idx));
204 resampled_y[new_idx] = y[original_idx];
205 }
206
207 Ok((resampled_x, resampled_y))
208 }
209
210 fn performance_hints(&self) -> PerformanceHints {
211 PerformanceHints::new()
212 .with_hint(PerformanceHint::CacheFriendly)
213 }
214}