use crate::core::{ImputationError, ImputationResult, Imputer};
use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
use scirs2_core::random::Random;
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Transform, Untrained},
    types::Float,
};
use std::collections::HashMap;

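/// Imputation transformer for completing missing values with simple per-feature
/// statistics: "mean", "median", "most_frequent", "constant", "forward_fill",
/// "backward_fill", or "random_sampling".
///
/// A minimal usage sketch, assuming `Float` is `f64` and that `scirs2_core`
/// re-exports ndarray's constructors:
///
/// ```ignore
/// use scirs2_core::ndarray::Array2;
/// use sklears_core::traits::{Fit, Transform};
///
/// let x = Array2::from_shape_vec((3, 1), vec![1.0, f64::NAN, 3.0]).unwrap();
/// let fitted = SimpleImputer::new().fit(&x.view(), &()).unwrap();
/// let imputed = fitted.transform(&x.view()).unwrap();
/// assert!((imputed[[1, 0]] - 2.0).abs() < 1e-12); // NaN replaced by the column mean
/// ```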
#[derive(Debug, Clone)]
pub struct SimpleImputer<S = Untrained> {
    state: S,
    missing_values: f64,
    strategy: String,
    fill_value: Option<f64>,
    copy: bool,
}

/// Fitted state of [`SimpleImputer`]: per-feature statistics plus the observed
/// (non-missing) values kept for the "random_sampling" strategy.
#[derive(Debug, Clone)]
pub struct SimpleImputerTrained {
    statistics: Array1<f64>,
    valid_values: Vec<Vec<f64>>,
}

impl SimpleImputer<Untrained> {
    /// Creates an imputer with the defaults: NaN as the missing-value marker,
    /// the "mean" strategy, no constant fill value, and copying enabled.
    pub fn new() -> Self {
        Self {
            state: Untrained,
            missing_values: f64::NAN,
            strategy: "mean".to_string(),
            fill_value: None,
            copy: true,
        }
    }

    /// Sets the value that is treated as missing (NaN by default).
    pub fn missing_values(mut self, missing_values: f64) -> Self {
        self.missing_values = missing_values;
        self
    }

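    /// Sets the imputation strategy. The strategies implemented below are
    /// "mean", "median", "most_frequent", "constant", "forward_fill",
    /// "backward_fill", and "random_sampling"; any other value is rejected at
    /// fit time.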
    pub fn strategy(mut self, strategy: String) -> Self {
        self.strategy = strategy;
        self
    }

    /// Sets the constant used by the "constant" strategy (defaults to 0.0).
    pub fn fill_value(mut self, fill_value: Option<f64>) -> Self {
        self.fill_value = fill_value;
        self
    }

    pub fn copy(mut self, copy: bool) -> Self {
        self.copy = copy;
        self
    }

    fn is_missing(&self, value: f64) -> bool {
        if self.missing_values.is_nan() {
            // NaN never compares equal to itself, so it needs a dedicated check.
            value.is_nan()
        } else {
            (value - self.missing_values).abs() < f64::EPSILON
        }
    }
}

impl Default for SimpleImputer<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for SimpleImputer<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ()> for SimpleImputer<Untrained> {
    type Fitted = SimpleImputer<SimpleImputerTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let (_, n_features) = X.dim();
        let mut statistics = Vec::new();
        let mut all_valid_values = Vec::new();

        for feature_idx in 0..n_features {
            // Collect the observed (non-missing) values for this feature.
            let column = X.column(feature_idx);
            let valid_values: Vec<f64> = column
                .iter()
                .filter(|&&x| !self.is_missing(x))
                .copied()
                .collect();

            if valid_values.is_empty() {
                return Err(SklearsError::InvalidInput(format!(
                    "All values are missing in feature {feature_idx}"
                )));
            }

            let statistic = match self.strategy.as_str() {
                "mean" => {
                    let sum: f64 = valid_values.iter().sum();
                    sum / valid_values.len() as f64
                }
                "median" => {
                    let mut sorted_values = valid_values.clone();
                    // Safe to unwrap: missing (NaN) values were filtered out above.
                    sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
                    let len = sorted_values.len();
                    if len % 2 == 0 {
                        (sorted_values[len / 2 - 1] + sorted_values[len / 2]) / 2.0
                    } else {
                        sorted_values[len / 2]
                    }
                }
                "most_frequent" => {
                    // Count by bit pattern so f64 values can be used as HashMap keys.
                    let mut counts = HashMap::new();
                    for &value in &valid_values {
                        *counts.entry(value.to_bits()).or_insert(0) += 1;
                    }
                    // Safe to unwrap: `valid_values` is non-empty, so `counts` is too.
                    let most_frequent_bits = counts
                        .into_iter()
                        .max_by_key(|&(_, count)| count)
                        .unwrap()
                        .0;
                    f64::from_bits(most_frequent_bits)
                }
                // Default to 0.0 when no explicit fill value was provided.
                "constant" => self.fill_value.unwrap_or(0.0),
                "forward_fill" | "backward_fill" | "random_sampling" => {
                    // Store the column mean as a fallback statistic; the actual fill
                    // direction or sampling is applied in `transform`.
                    let sum: f64 = valid_values.iter().sum();
                    sum / valid_values.len() as f64
                }
                _ => {
                    return Err(SklearsError::InvalidInput(format!(
                        "Unknown strategy: {}",
                        self.strategy
                    )));
                }
            };

            statistics.push(statistic);
            all_valid_values.push(valid_values);
        }

        Ok(SimpleImputer {
            state: SimpleImputerTrained {
                statistics: Array1::from(statistics),
                valid_values: all_valid_values,
            },
            missing_values: self.missing_values,
            strategy: self.strategy,
            fill_value: self.fill_value,
            copy: self.copy,
        })
    }
}

impl Transform<ArrayView2<'_, Float>, Array2<Float>> for SimpleImputer<SimpleImputerTrained> {
    #[allow(non_snake_case)]
    fn transform(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        let X = X.to_owned();
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.statistics.len() {
            return Err(SklearsError::InvalidInput(format!(
                "Number of features {} does not match training features {}",
                n_features,
                self.state.statistics.len()
            )));
        }

        let mut X_imputed = if self.copy { X.clone() } else { X };

        match self.strategy.as_str() {
            "forward_fill" => {
                for feature_idx in 0..n_features {
                    let mut last_valid = None;
                    for sample_idx in 0..n_samples {
                        let value = X_imputed[[sample_idx, feature_idx]];
                        if self.is_missing(value) {
                            if let Some(fill_value) = last_valid {
                                X_imputed[[sample_idx, feature_idx]] = fill_value;
                            } else {
                                // No earlier observed value: fall back to the fitted statistic.
                                X_imputed[[sample_idx, feature_idx]] =
                                    self.state.statistics[feature_idx];
                            }
                        } else {
                            last_valid = Some(value);
                        }
                    }
                }
            }
245 "backward_fill" => {
246 for feature_idx in 0..n_features {
247 let mut next_valid = None;
248 for sample_idx in (0..n_samples).rev() {
249 let value = X_imputed[[sample_idx, feature_idx]];
250 if self.is_missing(value) {
251 if let Some(fill_value) = next_valid {
252 X_imputed[[sample_idx, feature_idx]] = fill_value;
253 } else {
254 X_imputed[[sample_idx, feature_idx]] =
256 self.state.statistics[feature_idx];
257 }
258 } else {
259 next_valid = Some(value);
260 }
261 }
262 }
263 }
264 "random_sampling" => {
265 let mut rng = Random::default();
266 for feature_idx in 0..n_features {
267 let valid_values = &self.state.valid_values[feature_idx];
268 if !valid_values.is_empty() {
269 for sample_idx in 0..n_samples {
270 if self.is_missing(X_imputed[[sample_idx, feature_idx]]) {
271 let random_idx = rng.gen_range(0..valid_values.len());
272 let random_value = &valid_values[random_idx];
273 X_imputed[[sample_idx, feature_idx]] = *random_value;
274 }
275 }
276 }
277 }
278 }
            _ => {
                // mean / median / most_frequent / constant: fill every missing entry
                // with the statistic computed during fit.
                for feature_idx in 0..n_features {
                    let fill_value = self.state.statistics[feature_idx];
                    for sample_idx in 0..n_samples {
                        if self.is_missing(X_imputed[[sample_idx, feature_idx]]) {
                            X_imputed[[sample_idx, feature_idx]] = fill_value;
                        }
                    }
                }
            }
        }

        Ok(X_imputed)
    }
}

impl SimpleImputer<SimpleImputerTrained> {
    fn is_missing(&self, value: f64) -> bool {
        if self.missing_values.is_nan() {
            value.is_nan()
        } else {
            (value - self.missing_values).abs() < f64::EPSILON
        }
    }
}

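/// Binary indicator transformer that marks which entries of the input were
/// missing, producing one 0/1 column per tracked feature.
///
/// A minimal usage sketch, assuming `Float` is `f64` and that `scirs2_core`
/// re-exports ndarray's constructors:
///
/// ```ignore
/// use scirs2_core::ndarray::Array2;
/// use sklears_core::traits::{Fit, Transform};
///
/// // Rows: [1.0, NaN] and [3.0, 4.0]; only the second column has missing values.
/// let x = Array2::from_shape_vec((2, 2), vec![1.0, f64::NAN, 3.0, 4.0]).unwrap();
/// let fitted = MissingIndicator::new().fit(&x.view(), &()).unwrap();
/// let indicators = fitted.transform(&x.view()).unwrap();
/// assert_eq!(indicators.dim(), (2, 1));
/// assert_eq!(indicators[[0, 0]], 1.0);
/// ```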
#[derive(Debug, Clone)]
pub struct MissingIndicator<S = Untrained> {
    state: S,
    missing_values: f64,
    features: String,
    sparse: bool,
    error_on_new: bool,
}

/// Fitted state of [`MissingIndicator`]: the indices of the tracked features and
/// the number of features seen during fit.
#[derive(Debug, Clone)]
pub struct MissingIndicatorTrained {
    features_: Vec<usize>,
    n_features_in_: usize,
}

impl MissingIndicator<Untrained> {
    /// Creates an indicator with the defaults: NaN as the missing-value marker,
    /// "missing-only" feature selection, dense output, and `error_on_new = true`.
    pub fn new() -> Self {
        Self {
            state: Untrained,
            missing_values: f64::NAN,
            features: "missing-only".to_string(),
            sparse: false,
            error_on_new: true,
        }
    }

    /// Sets the value that is treated as missing (NaN by default).
    pub fn missing_values(mut self, missing_values: f64) -> Self {
        self.missing_values = missing_values;
        self
    }

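    /// Selects which features get an indicator column: "missing-only" keeps only
    /// the features that contained missing values during fit, while "all" keeps
    /// every feature; any other value is rejected at fit time.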
    pub fn features(mut self, features: String) -> Self {
        self.features = features;
        self
    }

    pub fn sparse(mut self, sparse: bool) -> Self {
        self.sparse = sparse;
        self
    }

    pub fn error_on_new(mut self, error_on_new: bool) -> Self {
        self.error_on_new = error_on_new;
        self
    }

    fn is_missing(&self, value: f64) -> bool {
        if self.missing_values.is_nan() {
            value.is_nan()
        } else {
            (value - self.missing_values).abs() < f64::EPSILON
        }
    }
}

impl Default for MissingIndicator<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for MissingIndicator<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ()> for MissingIndicator<Untrained> {
    type Fitted = MissingIndicator<MissingIndicatorTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let (_, n_features) = X.dim();

        let features_ = match self.features.as_str() {
            "missing-only" => {
                // Keep only the features that contain at least one missing value.
                let mut selected_features = Vec::new();
                for feature_idx in 0..n_features {
                    let column = X.column(feature_idx);
                    if column.iter().any(|&x| self.is_missing(x)) {
                        selected_features.push(feature_idx);
                    }
                }
                selected_features
            }
            "all" => (0..n_features).collect(),
            _ => {
                return Err(SklearsError::InvalidInput(format!(
                    "Unknown features option: {}",
                    self.features
                )));
            }
        };

        Ok(MissingIndicator {
            state: MissingIndicatorTrained {
                features_,
                n_features_in_: n_features,
            },
            missing_values: self.missing_values,
            features: self.features,
            sparse: self.sparse,
            error_on_new: self.error_on_new,
        })
    }
}

impl Transform<ArrayView2<'_, Float>, Array2<Float>> for MissingIndicator<MissingIndicatorTrained> {
    #[allow(non_snake_case)]
    fn transform(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        let X = X.to_owned();
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.n_features_in_ {
            return Err(SklearsError::InvalidInput(format!(
                "Number of features {} does not match training features {}",
                n_features, self.state.n_features_in_
            )));
        }

        if self.error_on_new {
            // Features that were not tracked during fit must not contain missing values.
            for feature_idx in 0..n_features {
                if !self.state.features_.contains(&feature_idx) {
                    let column = X.column(feature_idx);
                    if column.iter().any(|&x| self.is_missing(x)) {
                        return Err(SklearsError::InvalidInput(format!(
                            "Feature {feature_idx} has missing values but was not seen during fit"
                        )));
                    }
                }
            }
        }

        let n_indicator_features = self.state.features_.len();
        let mut indicators = Array2::<f64>::zeros((n_samples, n_indicator_features));

        // One indicator column per tracked feature: 1.0 where the value was missing.
        for (indicator_idx, &feature_idx) in self.state.features_.iter().enumerate() {
            let column = X.column(feature_idx);
            for (sample_idx, &value) in column.iter().enumerate() {
                if self.is_missing(value) {
                    indicators[[sample_idx, indicator_idx]] = 1.0;
                }
            }
        }

        Ok(indicators.mapv(|x| x as Float))
    }
}

impl MissingIndicator<MissingIndicatorTrained> {
    fn is_missing(&self, value: f64) -> bool {
        if self.missing_values.is_nan() {
            value.is_nan()
        } else {
            (value - self.missing_values).abs() < f64::EPSILON
        }
    }
}

impl Imputer for SimpleImputer<Untrained> {
    #[allow(non_snake_case)]
    fn fit_transform(
        &self,
        X: &scirs2_core::ndarray::ArrayView2<f64>,
    ) -> ImputationResult<scirs2_core::ndarray::Array2<f64>> {
        // Convert to the crate's float type, fit on a clone of the builder, and
        // transform in one pass.
        let X_float = X.mapv(|x| x as Float);
        let X_view = X_float.view();

        let fitted = self.clone().fit(&X_view, &()).map_err(|e| {
            ImputationError::ProcessingError(format!("Failed to fit imputer: {e}"))
        })?;

        let result = fitted.transform(&X_view).map_err(|e| {
            ImputationError::ProcessingError(format!("Failed to transform data: {e}"))
        })?;

        Ok(result.mapv(|x| x))
    }
}
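
#[cfg(test)]
mod tests {
    // Smoke-test sketches of the expected behaviour. They assume `Float` is an
    // alias for `f64` and that `scirs2_core` re-exports ndarray's
    // `Array2::from_shape_vec` constructor.
    use super::*;

    #[test]
    fn fit_transform_fills_missing_with_column_mean() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, f64::NAN, 3.0]).unwrap();
        let imputed = SimpleImputer::new().fit_transform(&x.view()).unwrap();
        // The default "mean" strategy fills the NaN with (1.0 + 3.0) / 2.
        assert!((imputed[[1, 0]] - 2.0).abs() < 1e-12);
    }

    #[test]
    fn constant_strategy_uses_fill_value() {
        let x = Array2::from_shape_vec((2, 1), vec![f64::NAN, 5.0]).unwrap();
        let fitted = SimpleImputer::new()
            .strategy("constant".to_string())
            .fill_value(Some(-1.0))
            .fit(&x.view(), &())
            .unwrap();
        let imputed = fitted.transform(&x.view()).unwrap();
        assert_eq!(imputed[[0, 0]], -1.0);
        assert_eq!(imputed[[1, 0]], 5.0);
    }
}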