Skip to main content

ferrolearn_preprocess/
imputer.rs

1//! Simple imputer: fill missing (NaN) values per feature column.
2//!
3//! [`SimpleImputer`] supports four imputation strategies:
4//! - [`ImputeStrategy::Mean`] — replace NaN with the column mean
5//! - [`ImputeStrategy::Median`] — replace NaN with the column median
6//! - [`ImputeStrategy::MostFrequent`] — replace NaN with the most common value
7//! - [`ImputeStrategy::Constant`] — replace NaN with a fixed constant value
8//!
9//! Fitting ignores NaN values when computing statistics (e.g. the mean is the
10//! mean of all non-NaN values in that column).  Columns that are entirely NaN
11//! at fit time are filled with `F::zero()` under `Mean`/`Median` and with the
12//! most frequent non-NaN value (defaulting to `F::zero()`) under
13//! `MostFrequent`.
14
15use ferrolearn_core::error::FerroError;
16use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
17use ferrolearn_core::traits::{Fit, FitTransform, Transform};
18use ndarray::{Array1, Array2};
19use num_traits::Float;
20
21// ---------------------------------------------------------------------------
22// ImputeStrategy
23// ---------------------------------------------------------------------------
24
/// Selects how the replacement value for missing entries is derived for each
/// feature column.
///
/// The type parameter `F` is the element type carried by
/// [`ImputeStrategy::Constant`].
#[derive(Clone, Debug, PartialEq)]
pub enum ImputeStrategy<F> {
    /// Fill with the arithmetic mean of the column's non-NaN entries.
    Mean,
    /// Fill with the median of the column's non-NaN entries.
    Median,
    /// Fill with the value that occurs most often in the column.
    MostFrequent,
    /// Fill with the supplied fixed value.
    Constant(F),
}
37
38// ---------------------------------------------------------------------------
39// SimpleImputer (unfitted)
40// ---------------------------------------------------------------------------
41
42/// An unfitted simple imputer.
43///
44/// Calling [`Fit::fit`] computes the per-column fill values according to
45/// the chosen [`ImputeStrategy`] and returns a [`FittedSimpleImputer`] that
46/// can transform new data by replacing NaN values with those fill values.
47///
48/// NaN values are *ignored* when computing statistics during fitting — e.g.
49/// the `Mean` strategy computes the mean of only the non-NaN elements.
50///
51/// # Examples
52///
53/// ```
54/// use ferrolearn_preprocess::imputer::{SimpleImputer, ImputeStrategy};
55/// use ferrolearn_core::traits::{Fit, Transform};
56/// use ndarray::array;
57///
58/// let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Mean);
59/// let x = array![[1.0, f64::NAN], [3.0, 4.0], [5.0, 6.0]];
60/// let fitted = imputer.fit(&x, &()).unwrap();
61/// let out = fitted.transform(&x).unwrap();
62/// // NaN in column 1 row 0 is replaced with the mean of column 1 = (4+6)/2 = 5.0
63/// assert!((out[[0, 1]] - 5.0).abs() < 1e-10);
64/// ```
65#[derive(Debug, Clone)]
66pub struct SimpleImputer<F> {
67    strategy: ImputeStrategy<F>,
68}
69
70impl<F: Float + Send + Sync + 'static> SimpleImputer<F> {
71    /// Create a new `SimpleImputer` with the given strategy.
72    #[must_use]
73    pub fn new(strategy: ImputeStrategy<F>) -> Self {
74        Self { strategy }
75    }
76
77    /// Return the imputation strategy.
78    #[must_use]
79    pub fn strategy(&self) -> &ImputeStrategy<F> {
80        &self.strategy
81    }
82}
83
84// ---------------------------------------------------------------------------
85// FittedSimpleImputer
86// ---------------------------------------------------------------------------
87
88/// A fitted simple imputer holding one fill value per feature column.
89///
90/// Created by calling [`Fit::fit`] on a [`SimpleImputer`].
91#[derive(Debug, Clone)]
92pub struct FittedSimpleImputer<F> {
93    /// Per-column fill values learned during fitting.
94    fill_values: Array1<F>,
95}
96
97impl<F: Float + Send + Sync + 'static> FittedSimpleImputer<F> {
98    /// Return the per-column fill values learned during fitting.
99    #[must_use]
100    pub fn fill_values(&self) -> &Array1<F> {
101        &self.fill_values
102    }
103}
104
105// ---------------------------------------------------------------------------
106// Helper: compute median of a non-empty Vec (may contain NaN — caller filters)
107// ---------------------------------------------------------------------------
108
109/// Compute the median of a non-empty slice of finite (non-NaN) values.
110///
111/// Uses a sort-and-interpolate approach.  Panics if the slice is empty.
112fn median_of<F: Float>(values: &mut [F]) -> F {
113    values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
114    let n = values.len();
115    if n % 2 == 1 {
116        values[n / 2]
117    } else {
118        let mid = n / 2;
119        (values[mid - 1] + values[mid]) / (F::one() + F::one())
120    }
121}
122
123/// Find the most-frequent value in a non-empty slice of finite values.
124///
125/// Ties are broken by choosing the smallest value.
126fn most_frequent_of<F: Float>(values: &[F]) -> F {
127    // Collect (value, count) by scanning; values are finite so partial_cmp is
128    // total.
129    let mut sorted = values.to_vec();
130    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
131
132    let mut best_val = sorted[0];
133    let mut best_count = 1usize;
134    let mut current_val = sorted[0];
135    let mut current_count = 1usize;
136
137    for &v in &sorted[1..] {
138        if v == current_val {
139            current_count += 1;
140        } else {
141            if current_count > best_count {
142                best_count = current_count;
143                best_val = current_val;
144            }
145            current_val = v;
146            current_count = 1;
147        }
148    }
149    // Final run
150    if current_count > best_count {
151        best_val = current_val;
152    }
153    best_val
154}
155
156// ---------------------------------------------------------------------------
157// Trait implementations
158// ---------------------------------------------------------------------------
159
160impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for SimpleImputer<F> {
161    type Fitted = FittedSimpleImputer<F>;
162    type Error = FerroError;
163
164    /// Fit the imputer by computing per-column fill values.
165    ///
166    /// NaN values are excluded from the statistic computation.  Columns that
167    /// are entirely NaN at fit time are filled with `F::zero()` for `Mean` and
168    /// `Median`, and `F::zero()` for `MostFrequent`.
169    ///
170    /// # Errors
171    ///
172    /// Returns [`FerroError::InsufficientSamples`] if the input has zero rows.
173    fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedSimpleImputer<F>, FerroError> {
174        let n_samples = x.nrows();
175        if n_samples == 0 {
176            return Err(FerroError::InsufficientSamples {
177                required: 1,
178                actual: 0,
179                context: "SimpleImputer::fit".into(),
180            });
181        }
182
183        let n_features = x.ncols();
184        let mut fill_values = Array1::zeros(n_features);
185
186        for j in 0..n_features {
187            let col_vals: Vec<F> = x
188                .column(j)
189                .iter()
190                .copied()
191                .filter(|v| !v.is_nan())
192                .collect();
193
194            let fill = if col_vals.is_empty() {
195                // All-NaN column: fall back to zero
196                F::zero()
197            } else {
198                match &self.strategy {
199                    ImputeStrategy::Mean => {
200                        let n = F::from(col_vals.len()).unwrap_or(F::one());
201                        col_vals.iter().copied().fold(F::zero(), |acc, v| acc + v) / n
202                    }
203                    ImputeStrategy::Median => {
204                        let mut vals = col_vals.clone();
205                        median_of(&mut vals)
206                    }
207                    ImputeStrategy::MostFrequent => most_frequent_of(&col_vals),
208                    ImputeStrategy::Constant(c) => *c,
209                }
210            };
211            fill_values[j] = fill;
212        }
213
214        Ok(FittedSimpleImputer { fill_values })
215    }
216}
217
218impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSimpleImputer<F> {
219    type Output = Array2<F>;
220    type Error = FerroError;
221
222    /// Replace NaN values in each column with the learned fill value.
223    ///
224    /// # Errors
225    ///
226    /// Returns [`FerroError::ShapeMismatch`] if the number of columns does not
227    /// match the number of features seen during fitting.
228    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
229        let n_features = self.fill_values.len();
230        if x.ncols() != n_features {
231            return Err(FerroError::ShapeMismatch {
232                expected: vec![x.nrows(), n_features],
233                actual: vec![x.nrows(), x.ncols()],
234                context: "FittedSimpleImputer::transform".into(),
235            });
236        }
237
238        let mut out = x.to_owned();
239        for (mut col, &fill) in out.columns_mut().into_iter().zip(self.fill_values.iter()) {
240            for v in col.iter_mut() {
241                if v.is_nan() {
242                    *v = fill;
243                }
244            }
245        }
246        Ok(out)
247    }
248}
249
250/// Implement `Transform` on the unfitted imputer to satisfy the
251/// `FitTransform: Transform` supertrait bound.  Always returns an error.
252impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for SimpleImputer<F> {
253    type Output = Array2<F>;
254    type Error = FerroError;
255
256    /// Always returns an error — the imputer must be fitted first.
257    ///
258    /// Use [`Fit::fit`] to produce a [`FittedSimpleImputer`], then call
259    /// [`Transform::transform`] on that.
260    fn transform(&self, _x: &Array2<F>) -> Result<Array2<F>, FerroError> {
261        Err(FerroError::InvalidParameter {
262            name: "SimpleImputer".into(),
263            reason: "imputer must be fitted before calling transform; use fit() first".into(),
264        })
265    }
266}
267
268impl<F: Float + Send + Sync + 'static> FitTransform<Array2<F>> for SimpleImputer<F> {
269    type FitError = FerroError;
270
271    /// Fit the imputer on `x` and return the imputed output in one step.
272    ///
273    /// # Errors
274    ///
275    /// Returns an error if fitting fails (e.g. zero rows).
276    fn fit_transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
277        let fitted = self.fit(x, &())?;
278        fitted.transform(x)
279    }
280}
281
282// ---------------------------------------------------------------------------
283// Pipeline integration (f64 specialisation)
284// ---------------------------------------------------------------------------
285
286impl PipelineTransformer for SimpleImputer<f64> {
287    /// Fit the imputer using the pipeline interface.
288    ///
289    /// The `y` argument is ignored; it exists only for API compatibility.
290    ///
291    /// # Errors
292    ///
293    /// Propagates errors from [`Fit::fit`].
294    fn fit_pipeline(
295        &self,
296        x: &Array2<f64>,
297        _y: &Array1<f64>,
298    ) -> Result<Box<dyn FittedPipelineTransformer>, FerroError> {
299        let fitted = self.fit(x, &())?;
300        Ok(Box::new(fitted))
301    }
302}
303
304impl FittedPipelineTransformer for FittedSimpleImputer<f64> {
305    /// Transform data using the pipeline interface.
306    ///
307    /// # Errors
308    ///
309    /// Propagates errors from [`Transform::transform`].
310    fn transform_pipeline(&self, x: &Array2<f64>) -> Result<Array2<f64>, FerroError> {
311        self.transform(x)
312    }
313}
314
315// ---------------------------------------------------------------------------
316// Tests
317// ---------------------------------------------------------------------------
318
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    // ---- Mean ---------------------------------------------------------------

    /// Mean fill values are computed per column from the non-NaN entries only.
    #[test]
    fn test_mean_basic() {
        let data = array![[1.0, f64::NAN], [3.0, 4.0], [5.0, 6.0]];
        let model = SimpleImputer::<f64>::new(ImputeStrategy::Mean)
            .fit(&data, &())
            .unwrap();

        // Column 0: (1 + 3 + 5) / 3 = 3.0; column 1: (4 + 6) / 2 = 5.0
        let fills = model.fill_values();
        assert_abs_diff_eq!(fills[0], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(fills[1], 5.0, epsilon = 1e-10);

        let imputed = model.transform(&data).unwrap();
        assert_abs_diff_eq!(imputed[[0, 1]], 5.0, epsilon = 1e-10);
        // Entries that were present must come through unchanged.
        assert_abs_diff_eq!(imputed[[1, 1]], 4.0, epsilon = 1e-10);
    }

    /// Data without any NaN is returned unchanged.
    #[test]
    fn test_mean_no_nan() {
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let model = SimpleImputer::<f64>::new(ImputeStrategy::Mean)
            .fit(&data, &())
            .unwrap();
        let imputed = model.transform(&data).unwrap();
        for (before, after) in data.iter().zip(imputed.iter()) {
            assert_abs_diff_eq!(before, after, epsilon = 1e-15);
        }
    }

    /// Several NaN entries in one column all receive the same fill value.
    #[test]
    fn test_mean_multiple_nans_same_column() {
        let data = array![[f64::NAN], [f64::NAN], [6.0]];
        let model = SimpleImputer::<f64>::new(ImputeStrategy::Mean)
            .fit(&data, &())
            .unwrap();
        assert_abs_diff_eq!(model.fill_values()[0], 6.0, epsilon = 1e-10);

        let imputed = model.transform(&data).unwrap();
        assert_abs_diff_eq!(imputed[[0, 0]], 6.0, epsilon = 1e-10);
        assert_abs_diff_eq!(imputed[[1, 0]], 6.0, epsilon = 1e-10);
    }

    /// A column with no observed values falls back to zero under `Mean`.
    #[test]
    fn test_mean_all_nan_column_fills_zero() {
        let data = array![[f64::NAN], [f64::NAN]];
        let model = SimpleImputer::<f64>::new(ImputeStrategy::Mean)
            .fit(&data, &())
            .unwrap();
        assert_abs_diff_eq!(model.fill_values()[0], 0.0, epsilon = 1e-15);

        let imputed = model.transform(&data).unwrap();
        assert_abs_diff_eq!(imputed[[0, 0]], 0.0, epsilon = 1e-15);
    }

    // ---- Median -------------------------------------------------------------

    /// Odd number of values: the middle one is the median.
    #[test]
    fn test_median_odd_count() {
        let data = array![[1.0], [3.0], [5.0], [7.0], [9.0]];
        let model = SimpleImputer::<f64>::new(ImputeStrategy::Median)
            .fit(&data, &())
            .unwrap();
        assert_abs_diff_eq!(model.fill_values()[0], 5.0, epsilon = 1e-10);
    }

    /// Even number of values: the two middle ones are averaged.
    #[test]
    fn test_median_even_count() {
        let data = array![[1.0], [3.0], [5.0], [7.0]];
        let model = SimpleImputer::<f64>::new(ImputeStrategy::Median)
            .fit(&data, &())
            .unwrap();
        // [1, 3, 5, 7] -> (3 + 5) / 2 = 4.0
        assert_abs_diff_eq!(model.fill_values()[0], 4.0, epsilon = 1e-10);
    }

    /// NaN entries are left out of the median computation.
    #[test]
    fn test_median_with_nan() {
        // Observed values are [2, 4, 6], so the median is 4.
        let data = array![[2.0], [f64::NAN], [4.0], [6.0]];
        let model = SimpleImputer::<f64>::new(ImputeStrategy::Median)
            .fit(&data, &())
            .unwrap();
        assert_abs_diff_eq!(model.fill_values()[0], 4.0, epsilon = 1e-10);

        let imputed = model.transform(&data).unwrap();
        assert_abs_diff_eq!(imputed[[1, 0]], 4.0, epsilon = 1e-10);
    }

    // ---- MostFrequent -------------------------------------------------------

    /// The value with the highest count is chosen.
    #[test]
    fn test_most_frequent_basic() {
        let data = array![[1.0], [2.0], [2.0], [3.0]];
        let model = SimpleImputer::<f64>::new(ImputeStrategy::MostFrequent)
            .fit(&data, &())
            .unwrap();
        assert_abs_diff_eq!(model.fill_values()[0], 2.0, epsilon = 1e-10);
    }

    /// Equal counts are resolved in favour of the smaller value.
    #[test]
    fn test_most_frequent_tie_chooses_smallest() {
        // 1.0 and 3.0 both occur twice; 1.0 is expected.
        let data = array![[1.0], [1.0], [3.0], [3.0]];
        let model = SimpleImputer::<f64>::new(ImputeStrategy::MostFrequent)
            .fit(&data, &())
            .unwrap();
        assert_abs_diff_eq!(model.fill_values()[0], 1.0, epsilon = 1e-10);
    }

    /// NaN entries are not counted as a candidate value.
    #[test]
    fn test_most_frequent_with_nan() {
        let data = array![[1.0], [f64::NAN], [2.0], [2.0]];
        let model = SimpleImputer::<f64>::new(ImputeStrategy::MostFrequent)
            .fit(&data, &())
            .unwrap();
        assert_abs_diff_eq!(model.fill_values()[0], 2.0, epsilon = 1e-10);

        let imputed = model.transform(&data).unwrap();
        assert_abs_diff_eq!(imputed[[1, 0]], 2.0, epsilon = 1e-10);
    }

    // ---- Constant -----------------------------------------------------------

    /// Every column uses the supplied constant as its fill value.
    #[test]
    fn test_constant_strategy() {
        let data = array![[1.0, f64::NAN], [f64::NAN, 4.0]];
        let model = SimpleImputer::<f64>::new(ImputeStrategy::Constant(-99.0))
            .fit(&data, &())
            .unwrap();

        let fills = model.fill_values();
        assert_abs_diff_eq!(fills[0], -99.0, epsilon = 1e-15);
        assert_abs_diff_eq!(fills[1], -99.0, epsilon = 1e-15);

        let imputed = model.transform(&data).unwrap();
        assert_abs_diff_eq!(imputed[[1, 0]], -99.0, epsilon = 1e-15);
        assert_abs_diff_eq!(imputed[[0, 1]], -99.0, epsilon = 1e-15);
    }

    // ---- Error paths --------------------------------------------------------

    /// Fitting on an input with zero rows is rejected.
    #[test]
    fn test_fit_zero_rows_error() {
        let empty: Array2<f64> = Array2::zeros((0, 3));
        let outcome = SimpleImputer::<f64>::new(ImputeStrategy::Mean).fit(&empty, &());
        assert!(outcome.is_err());
    }

    /// Transforming data with a different column count is rejected.
    #[test]
    fn test_transform_shape_mismatch_error() {
        let train = array![[1.0, 2.0], [3.0, 4.0]];
        let model = SimpleImputer::<f64>::new(ImputeStrategy::Mean)
            .fit(&train, &())
            .unwrap();
        let wrong_width = array![[1.0, 2.0, 3.0]];
        assert!(model.transform(&wrong_width).is_err());
    }

    /// Calling `transform` on the unfitted imputer is an error.
    #[test]
    fn test_unfitted_transform_error() {
        let unfitted = SimpleImputer::<f64>::new(ImputeStrategy::Mean);
        let data = array![[1.0, 2.0]];
        assert!(unfitted.transform(&data).is_err());
    }

    // ---- fit_transform ------------------------------------------------------

    /// `fit_transform` yields the same output as `fit` followed by `transform`.
    #[test]
    fn test_fit_transform_equivalence() {
        let unfitted = SimpleImputer::<f64>::new(ImputeStrategy::Mean);
        let data = array![[1.0, f64::NAN], [3.0, 4.0], [5.0, 6.0]];

        let one_step = unfitted.fit_transform(&data).unwrap();
        let two_step = unfitted.fit(&data, &()).unwrap().transform(&data).unwrap();

        for (lhs, rhs) in one_step.iter().zip(two_step.iter()) {
            assert_abs_diff_eq!(lhs, rhs, epsilon = 1e-15);
        }
    }

    // ---- f32 ----------------------------------------------------------------

    /// The generic implementation also works for `f32` data.
    #[test]
    fn test_f32_imputer() {
        let data: Array2<f32> = array![[1.0f32, f32::NAN], [3.0, 4.0]];
        let model = SimpleImputer::<f32>::new(ImputeStrategy::Mean)
            .fit(&data, &())
            .unwrap();
        let imputed = model.transform(&data).unwrap();
        assert!((imputed[[0, 1]] - 4.0f32).abs() < 1e-6);
    }

    // ---- Pipeline integration -----------------------------------------------

    /// Fitting and transforming through the pipeline traits removes the NaN.
    #[test]
    fn test_pipeline_integration() {
        use ferrolearn_core::pipeline::PipelineTransformer;

        let unfitted = SimpleImputer::<f64>::new(ImputeStrategy::Mean);
        let data = array![[1.0, f64::NAN], [3.0, 4.0]];
        let targets = ndarray::array![0.0, 1.0];

        let boxed_model = unfitted.fit_pipeline(&data, &targets).unwrap();
        let imputed = boxed_model.transform_pipeline(&data).unwrap();
        // The previously missing entry must now hold a real number.
        assert!(!imputed[[0, 1]].is_nan());
    }

    // ---- Several columns, NaN in different rows -------------------------------

    /// Each column is imputed from its own observed values.
    #[test]
    fn test_multi_column_mixed_nan() {
        let data = array![[f64::NAN, 10.0], [2.0, f64::NAN], [4.0, 30.0], [6.0, 40.0]];
        let model = SimpleImputer::<f64>::new(ImputeStrategy::Median)
            .fit(&data, &())
            .unwrap();
        let imputed = model.transform(&data).unwrap();
        // Column 0 observed [2, 4, 6] -> median 4
        assert_abs_diff_eq!(imputed[[0, 0]], 4.0, epsilon = 1e-10);
        // Column 1 observed [10, 30, 40] -> median 30
        assert_abs_diff_eq!(imputed[[1, 1]], 30.0, epsilon = 1e-10);
    }

    // ---- Accessors ----------------------------------------------------------

    /// `strategy()` returns the strategy passed to `new`.
    #[test]
    fn test_strategy_accessor() {
        let unfitted = SimpleImputer::<f64>::new(ImputeStrategy::Constant(42.0));
        assert_eq!(unfitted.strategy(), &ImputeStrategy::Constant(42.0));
    }
}