1use ferrolearn_core::error::FerroError;
16use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
17use ferrolearn_core::traits::{Fit, FitTransform, Transform};
18use ndarray::{Array1, Array2};
19use num_traits::Float;
20
/// Strategy a [`SimpleImputer`] uses to pick the replacement for NaN entries.
///
/// The data-driven strategies (`Mean`, `Median`, `MostFrequent`) are computed
/// per column from the non-NaN values seen while fitting.
#[derive(Debug, Clone, PartialEq)]
pub enum ImputeStrategy<F> {
    /// Fill with the arithmetic mean of the column's non-NaN values.
    Mean,
    /// Fill with the median of the column's non-NaN values (the two middle
    /// values are averaged when their count is even).
    Median,
    /// Fill with the most frequent non-NaN value of the column; ties resolve
    /// to the smallest tied value.
    MostFrequent,
    /// Fill with the supplied constant.
    Constant(F),
}
37
/// Unfitted imputer that replaces NaN entries column by column.
///
/// Fitting it yields a [`FittedSimpleImputer`] holding one fill value per
/// column; only the fitted value can actually transform data.
#[derive(Debug, Clone)]
pub struct SimpleImputer<F> {
    /// Strategy used to derive each column's fill value.
    strategy: ImputeStrategy<F>,
}
69
impl<F: Float + Send + Sync + 'static> SimpleImputer<F> {
    /// Creates an unfitted imputer that will use `strategy`.
    #[must_use]
    pub fn new(strategy: ImputeStrategy<F>) -> Self {
        Self { strategy }
    }

    /// Returns the strategy this imputer was configured with.
    #[must_use]
    pub fn strategy(&self) -> &ImputeStrategy<F> {
        &self.strategy
    }
}
83
/// Imputer whose per-column fill values have already been learned.
///
/// Produced by fitting a [`SimpleImputer`].
#[derive(Debug, Clone)]
pub struct FittedSimpleImputer<F> {
    /// Replacement value for NaN entries, one element per feature column.
    fill_values: Array1<F>,
}
96
impl<F: Float + Send + Sync + 'static> FittedSimpleImputer<F> {
    /// Returns the learned fill values, one per feature column.
    #[must_use]
    pub fn fill_values(&self) -> &Array1<F> {
        &self.fill_values
    }
}
104
105fn median_of<F: Float>(values: &mut [F]) -> F {
113 values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
114 let n = values.len();
115 if n % 2 == 1 {
116 values[n / 2]
117 } else {
118 let mid = n / 2;
119 (values[mid - 1] + values[mid]) / (F::one() + F::one())
120 }
121}
122
123fn most_frequent_of<F: Float>(values: &[F]) -> F {
127 let mut sorted = values.to_vec();
130 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
131
132 let mut best_val = sorted[0];
133 let mut best_count = 1usize;
134 let mut current_val = sorted[0];
135 let mut current_count = 1usize;
136
137 for &v in &sorted[1..] {
138 if v == current_val {
139 current_count += 1;
140 } else {
141 if current_count > best_count {
142 best_count = current_count;
143 best_val = current_val;
144 }
145 current_val = v;
146 current_count = 1;
147 }
148 }
149 if current_count > best_count {
151 best_val = current_val;
152 }
153 best_val
154}
155
156impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for SimpleImputer<F> {
161 type Fitted = FittedSimpleImputer<F>;
162 type Error = FerroError;
163
164 fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedSimpleImputer<F>, FerroError> {
174 let n_samples = x.nrows();
175 if n_samples == 0 {
176 return Err(FerroError::InsufficientSamples {
177 required: 1,
178 actual: 0,
179 context: "SimpleImputer::fit".into(),
180 });
181 }
182
183 let n_features = x.ncols();
184 let mut fill_values = Array1::zeros(n_features);
185
186 for j in 0..n_features {
187 let col_vals: Vec<F> = x
188 .column(j)
189 .iter()
190 .copied()
191 .filter(|v| !v.is_nan())
192 .collect();
193
194 let fill = if col_vals.is_empty() {
195 F::zero()
197 } else {
198 match &self.strategy {
199 ImputeStrategy::Mean => {
200 let n = F::from(col_vals.len()).unwrap_or(F::one());
201 col_vals.iter().copied().fold(F::zero(), |acc, v| acc + v) / n
202 }
203 ImputeStrategy::Median => {
204 let mut vals = col_vals.clone();
205 median_of(&mut vals)
206 }
207 ImputeStrategy::MostFrequent => most_frequent_of(&col_vals),
208 ImputeStrategy::Constant(c) => *c,
209 }
210 };
211 fill_values[j] = fill;
212 }
213
214 Ok(FittedSimpleImputer { fill_values })
215 }
216}
217
218impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSimpleImputer<F> {
219 type Output = Array2<F>;
220 type Error = FerroError;
221
222 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
229 let n_features = self.fill_values.len();
230 if x.ncols() != n_features {
231 return Err(FerroError::ShapeMismatch {
232 expected: vec![x.nrows(), n_features],
233 actual: vec![x.nrows(), x.ncols()],
234 context: "FittedSimpleImputer::transform".into(),
235 });
236 }
237
238 let mut out = x.to_owned();
239 for (mut col, &fill) in out.columns_mut().into_iter().zip(self.fill_values.iter()) {
240 for v in col.iter_mut() {
241 if v.is_nan() {
242 *v = fill;
243 }
244 }
245 }
246 Ok(out)
247 }
248}
249
impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for SimpleImputer<F> {
    type Output = Array2<F>;
    type Error = FerroError;

    /// Always fails: an unfitted imputer has no fill values to apply.
    ///
    /// # Errors
    ///
    /// Always returns [`FerroError::InvalidParameter`]; the input is ignored.
    fn transform(&self, _x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        Err(FerroError::InvalidParameter {
            name: "SimpleImputer".into(),
            reason: "imputer must be fitted before calling transform; use fit() first".into(),
        })
    }
}
267
268impl<F: Float + Send + Sync + 'static> FitTransform<Array2<F>> for SimpleImputer<F> {
269 type FitError = FerroError;
270
271 fn fit_transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
277 let fitted = self.fit(x, &())?;
278 fitted.transform(x)
279 }
280}
281
282impl PipelineTransformer for SimpleImputer<f64> {
287 fn fit_pipeline(
295 &self,
296 x: &Array2<f64>,
297 _y: &Array1<f64>,
298 ) -> Result<Box<dyn FittedPipelineTransformer>, FerroError> {
299 let fitted = self.fit(x, &())?;
300 Ok(Box::new(fitted))
301 }
302}
303
impl FittedPipelineTransformer for FittedSimpleImputer<f64> {
    /// Pipeline entry point; delegates to this type's `transform`.
    fn transform_pipeline(&self, x: &Array2<f64>) -> Result<Array2<f64>, FerroError> {
        self.transform(x)
    }
}
314
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    // ---- Mean -------------------------------------------------------------

    /// Column means ignore NaN entries and are used to fill them.
    #[test]
    fn test_mean_basic() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Mean);
        let x = array![[1.0, f64::NAN], [3.0, 4.0], [5.0, 6.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        assert_abs_diff_eq!(fitted.fill_values()[0], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(fitted.fill_values()[1], 5.0, epsilon = 1e-10);
        let out = fitted.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 1]], 5.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[1, 1]], 4.0, epsilon = 1e-10);
    }

    /// Data without NaN passes through unchanged.
    #[test]
    fn test_mean_no_nan() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Mean);
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        // `zip` stops at the shorter side, so pin the shape first.
        assert_eq!(x.dim(), out.dim());
        for (a, b) in x.iter().zip(out.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-15);
        }
    }

    /// Every NaN in a column receives the same learned value.
    #[test]
    fn test_mean_multiple_nans_same_column() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Mean);
        let x = array![[f64::NAN], [f64::NAN], [6.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        assert_abs_diff_eq!(fitted.fill_values()[0], 6.0, epsilon = 1e-10);
        let out = fitted.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 6.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[1, 0]], 6.0, epsilon = 1e-10);
    }

    /// A column with no observed value falls back to zero.
    #[test]
    fn test_mean_all_nan_column_fills_zero() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Mean);
        let x = array![[f64::NAN], [f64::NAN]];
        let fitted = imputer.fit(&x, &()).unwrap();
        assert_abs_diff_eq!(fitted.fill_values()[0], 0.0, epsilon = 1e-15);
        let out = fitted.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 0.0, epsilon = 1e-15);
        assert_abs_diff_eq!(out[[1, 0]], 0.0, epsilon = 1e-15);
    }

    // ---- Median -----------------------------------------------------------

    /// Odd count: the middle element.
    #[test]
    fn test_median_odd_count() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Median);
        let x = array![[1.0], [3.0], [5.0], [7.0], [9.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        assert_abs_diff_eq!(fitted.fill_values()[0], 5.0, epsilon = 1e-10);
    }

    /// Even count: the average of the two middle elements.
    #[test]
    fn test_median_even_count() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Median);
        let x = array![[1.0], [3.0], [5.0], [7.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        assert_abs_diff_eq!(fitted.fill_values()[0], 4.0, epsilon = 1e-10);
    }

    /// NaN entries do not take part in the median.
    #[test]
    fn test_median_with_nan() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Median);
        let x = array![[2.0], [f64::NAN], [4.0], [6.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        assert_abs_diff_eq!(fitted.fill_values()[0], 4.0, epsilon = 1e-10);
        let out = fitted.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[1, 0]], 4.0, epsilon = 1e-10);
    }

    // ---- MostFrequent -----------------------------------------------------

    /// The mode of the column is learned.
    #[test]
    fn test_most_frequent_basic() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::MostFrequent);
        let x = array![[1.0], [2.0], [2.0], [3.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        assert_abs_diff_eq!(fitted.fill_values()[0], 2.0, epsilon = 1e-10);
    }

    /// Equal counts resolve to the smallest value.
    #[test]
    fn test_most_frequent_tie_chooses_smallest() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::MostFrequent);
        let x = array![[1.0], [1.0], [3.0], [3.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        assert_abs_diff_eq!(fitted.fill_values()[0], 1.0, epsilon = 1e-10);
    }

    /// NaN entries are not counted as a candidate value.
    #[test]
    fn test_most_frequent_with_nan() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::MostFrequent);
        let x = array![[1.0], [f64::NAN], [2.0], [2.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        assert_abs_diff_eq!(fitted.fill_values()[0], 2.0, epsilon = 1e-10);
        let out = fitted.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[1, 0]], 2.0, epsilon = 1e-10);
    }

    // ---- Constant ---------------------------------------------------------

    /// Every column uses the configured constant; observed values are kept.
    #[test]
    fn test_constant_strategy() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Constant(-99.0));
        let x = array![[1.0, f64::NAN], [f64::NAN, 4.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        assert_abs_diff_eq!(fitted.fill_values()[0], -99.0, epsilon = 1e-15);
        assert_abs_diff_eq!(fitted.fill_values()[1], -99.0, epsilon = 1e-15);
        let out = fitted.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[1, 0]], -99.0, epsilon = 1e-15);
        assert_abs_diff_eq!(out[[0, 1]], -99.0, epsilon = 1e-15);
        assert_abs_diff_eq!(out[[0, 0]], 1.0, epsilon = 1e-15);
        assert_abs_diff_eq!(out[[1, 1]], 4.0, epsilon = 1e-15);
    }

    // ---- Error paths ------------------------------------------------------
    // Each test pins the error variant so an unrelated failure cannot pass.

    #[test]
    fn test_fit_zero_rows_error() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Mean);
        let x: Array2<f64> = Array2::zeros((0, 3));
        assert!(matches!(
            imputer.fit(&x, &()),
            Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                ..
            })
        ));
    }

    #[test]
    fn test_transform_shape_mismatch_error() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Mean);
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let fitted = imputer.fit(&x_train, &()).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(matches!(
            fitted.transform(&x_bad),
            Err(FerroError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn test_unfitted_transform_error() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Mean);
        let x = array![[1.0, 2.0]];
        assert!(matches!(
            imputer.transform(&x),
            Err(FerroError::InvalidParameter { .. })
        ));
    }

    // ---- fit_transform ----------------------------------------------------

    /// `fit_transform` matches `fit` followed by `transform`.
    #[test]
    fn test_fit_transform_equivalence() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Mean);
        let x = array![[1.0, f64::NAN], [3.0, 4.0], [5.0, 6.0]];
        let via_fit_transform = imputer.fit_transform(&x).unwrap();
        let fitted = imputer.fit(&x, &()).unwrap();
        let via_separate = fitted.transform(&x).unwrap();
        // `zip` stops at the shorter side, so pin the shape first.
        assert_eq!(via_fit_transform.dim(), via_separate.dim());
        for (a, b) in via_fit_transform.iter().zip(via_separate.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-15);
        }
    }

    // ---- Other element types ----------------------------------------------

    #[test]
    fn test_f32_imputer() {
        let imputer = SimpleImputer::<f32>::new(ImputeStrategy::Mean);
        let x: Array2<f32> = array![[1.0f32, f32::NAN], [3.0, 4.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert!((out[[0, 1]] - 4.0f32).abs() < 1e-6);
    }

    // ---- Pipeline ---------------------------------------------------------

    /// The boxed pipeline step imputes with the learned column mean.
    #[test]
    fn test_pipeline_integration() {
        use ferrolearn_core::pipeline::PipelineTransformer;

        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Mean);
        let x = array![[1.0, f64::NAN], [3.0, 4.0]];
        let y = ndarray::array![0.0, 1.0];
        let fitted_box = imputer.fit_pipeline(&x, &y).unwrap();
        let out = fitted_box.transform_pipeline(&x).unwrap();
        assert!(!out[[0, 1]].is_nan());
        // The only observed value in column 1 is 4.0, so that is its mean.
        assert_abs_diff_eq!(out[[0, 1]], 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 0]], 1.0, epsilon = 1e-15);
    }

    // ---- Multiple columns -------------------------------------------------

    /// Each column is imputed with its own statistic.
    #[test]
    fn test_multi_column_mixed_nan() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Median);
        let x = array![[f64::NAN, 10.0], [2.0, f64::NAN], [4.0, 30.0], [6.0, 40.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        // Column 0 observes {2, 4, 6}; column 1 observes {10, 30, 40}.
        assert_abs_diff_eq!(out[[0, 0]], 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[1, 1]], 30.0, epsilon = 1e-10);
    }

    // ---- Accessors --------------------------------------------------------

    #[test]
    fn test_strategy_accessor() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Constant(42.0));
        assert_eq!(imputer.strategy(), &ImputeStrategy::Constant(42.0));
    }
}