Skip to main content

ferrolearn_preprocess/
imputer.rs

1//! Simple imputer: fill missing (NaN) values per feature column.
2//!
3//! [`SimpleImputer`] supports four imputation strategies:
4//! - [`ImputeStrategy::Mean`] — replace NaN with the column mean
5//! - [`ImputeStrategy::Median`] — replace NaN with the column median
6//! - [`ImputeStrategy::MostFrequent`] — replace NaN with the most common value
7//! - [`ImputeStrategy::Constant`] — replace NaN with a fixed constant value
8//!
9//! Fitting ignores NaN values when computing statistics (e.g. the mean is the
10//! mean of all non-NaN values in that column).  Under `Mean`/`Median`/
11//! `MostFrequent`, columns that are entirely NaN at fit time have no observed
12//! value, so — mirroring scikit-learn's default `keep_empty_features=False`
13//! (`sklearn/impute/_base.py:501,510-512,534-537` set `statistics_=nan`;
14//! `:586-603` drop them in `transform`) — they are DROPPED from the transform
15//! output.  Under `Constant`, every column (including all-NaN ones) is filled
16//! with the constant and KEPT (sklearn `:545,583`).
17//!
18//! ## REQ status
19//!
20//! Translation target: scikit-learn 1.5.2 `class SimpleImputer` +
21//! `MissingIndicator` (`sklearn/impute/_base.py:147`). Tracking: #1363. Each REQ
22//! is BINARY — SHIPPED (impl + non-test consumer + tests + green verification)
23//! or NOT-STARTED (with a concrete open blocker).
24//!
25//! | REQ | Scope | Status | Evidence / Blocker |
26//! |-----|-------|--------|--------------------|
27//! | REQ-1 | Per-column fill VALUES on columns with ≥1 observed value (Mean/Median/MostFrequent/Constant) | SHIPPED | [`SimpleImputer`] `fit` — Mean=`np.ma.mean` (`_base.py:498`), Median=`np.ma.median` (`:507`, even=avg-of-two-middle), MostFrequent=scipy mode tie→min (`_most_frequent` `:36-71`), Constant (`:545`); 9 oracle value tests in `tests/divergence_imputer.rs`. Consumer: re-export `lib.rs:136` + `PipelineTransformer` |
28//! | REQ-2 | All-NaN column DROP under Mean/Median/MostFrequent (sklearn default `keep_empty_features=False`) | SHIPPED | `fit` sets `fill_values[j]=NaN` + excludes `j` from `kept_indices`; `transform` projects onto `kept_indices` (mirrors `statistics_=nan` + `X=X[:, valid]` `_base.py:586-603`); `Constant` keeps+fills all (`:583`); 10 oracle tests (column-order, all-dropped, separate matrix, f32) — was DIV-1 #1364, fixed |
29//! | REQ-3 | Error/parameter contracts (n_samples==0, transform ncols, unfitted) | SHIPPED (scoped) | [`SimpleImputer::fit`]/[`FittedSimpleImputer`] `transform`; in-module + divergence error tests |
30//! | REQ-4 | `keep_empty_features` param (True → fill 0 + keep all-NaN cols) | NOT-STARTED | always drops; sklearn `_base.py:583,501` — blocker #1365 |
31//! | REQ-5 | `missing_values` param (non-NaN sentinel / None / str) | NOT-STARTED | NaN-only; sklearn `_base.py:161,288` — blocker #1366 |
32//! | REQ-6 | `add_indicator` + `MissingIndicator` estimator (route parity_op, ABSENT) | NOT-STARTED | needs acto-builder; sklearn `_base.py:205` + `MissingIndicator` — blocker #1367 |
33//! | REQ-7 | `inverse_transform` (requires add_indicator) | NOT-STARTED | sklearn `_base.py:641` — blocker #1368 |
34//! | REQ-8 | `fill_value=None`→0 default + `statistics_` attr name + `copy` param | NOT-STARTED | `Constant(F)` explicit; sklearn `_base.py:425-427,223,288` — blocker #1369 |
35//! | REQ-9 | string/object dtype (most_frequent/constant on non-numeric) | NOT-STARTED | `F: Float` only; sklearn `_base.py:42-52,526` — blocker #1370 |
36//! | REQ-10 | sparse `_sparse_fit` | NOT-STARTED | dense `Array2` only; sklearn `_base.py:444` — blocker #1371 |
37//! | REQ-11 | `get_feature_names_out` + `n_features_in_`/`feature_names_in_` | NOT-STARTED | `_BaseImputer` — blocker #1372 |
38//! | REQ-12 | PyO3 binding | NOT-STARTED | no `ferrolearn-python` registration — blocker #1373 |
39//! | REQ-13 | ferray substrate | NOT-STARTED | dense `Array2` + `num_traits::Float` only — blocker #1374 |
40
41use ferrolearn_core::error::FerroError;
42use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
43use ferrolearn_core::traits::{Fit, FitTransform, Transform};
44use ndarray::{Array1, Array2};
45use num_traits::Float;
46
47// ---------------------------------------------------------------------------
48// ImputeStrategy
49// ---------------------------------------------------------------------------
50
/// Rule for deriving a column's fill value from its observed (non-NaN) entries.
///
/// The variant is fixed when the [`SimpleImputer`] is built and applied to
/// every column during fitting.
#[derive(Clone, Debug, PartialEq)]
pub enum ImputeStrategy<F> {
    /// Arithmetic mean of the column's observed entries.
    Mean,
    /// Median of the column's observed entries (the average of the two middle
    /// values when the observed count is even).
    Median,
    /// Most common observed entry in the column; ties go to the smallest value.
    MostFrequent,
    /// A caller-supplied value, used for every column regardless of its data.
    Constant(F),
}
63
64// ---------------------------------------------------------------------------
65// SimpleImputer (unfitted)
66// ---------------------------------------------------------------------------
67
68/// An unfitted simple imputer.
69///
70/// Calling [`Fit::fit`] computes the per-column fill values according to
71/// the chosen [`ImputeStrategy`] and returns a [`FittedSimpleImputer`] that
72/// can transform new data by replacing NaN values with those fill values.
73///
74/// NaN values are *ignored* when computing statistics during fitting — e.g.
75/// the `Mean` strategy computes the mean of only the non-NaN elements.
76///
77/// # Examples
78///
79/// ```
80/// use ferrolearn_preprocess::imputer::{SimpleImputer, ImputeStrategy};
81/// use ferrolearn_core::traits::{Fit, Transform};
82/// use ndarray::array;
83///
84/// let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Mean);
85/// let x = array![[1.0, f64::NAN], [3.0, 4.0], [5.0, 6.0]];
86/// let fitted = imputer.fit(&x, &()).unwrap();
87/// let out = fitted.transform(&x).unwrap();
88/// // NaN in column 1 row 0 is replaced with the mean of column 1 = (4+6)/2 = 5.0
89/// assert!((out[[0, 1]] - 5.0).abs() < 1e-10);
90/// ```
91#[derive(Debug, Clone)]
92pub struct SimpleImputer<F> {
93    strategy: ImputeStrategy<F>,
94}
95
96impl<F: Float + Send + Sync + 'static> SimpleImputer<F> {
97    /// Create a new `SimpleImputer` with the given strategy.
98    #[must_use]
99    pub fn new(strategy: ImputeStrategy<F>) -> Self {
100        Self { strategy }
101    }
102
103    /// Return the imputation strategy.
104    #[must_use]
105    pub fn strategy(&self) -> &ImputeStrategy<F> {
106        &self.strategy
107    }
108}
109
110// ---------------------------------------------------------------------------
111// FittedSimpleImputer
112// ---------------------------------------------------------------------------
113
114/// A fitted simple imputer holding one fill value per feature column.
115///
116/// Created by calling [`Fit::fit`] on a [`SimpleImputer`].
117#[derive(Debug, Clone)]
118pub struct FittedSimpleImputer<F> {
119    /// Per-INPUT-column fill values learned during fitting.
120    ///
121    /// One entry per input column, mirroring scikit-learn's `statistics_`:
122    /// holds `F::nan()` for an all-NaN non-constant column that is dropped, and
123    /// the computed fill statistic (or the user constant) otherwise.
124    fill_values: Array1<F>,
125    /// Input-column indices that survive transform, in ascending order.
126    ///
127    /// Under `Mean`/`Median`/`MostFrequent` an all-NaN column has no observed
128    /// value and is excluded (sklearn `keep_empty_features=False`); under
129    /// `Constant` every column is kept.
130    kept_indices: Vec<usize>,
131}
132
133impl<F: Float + Send + Sync + 'static> FittedSimpleImputer<F> {
134    /// Return the per-input-column fill values learned during fitting.
135    ///
136    /// Mirrors scikit-learn's `statistics_`: entries for all-NaN columns that
137    /// are dropped under `Mean`/`Median`/`MostFrequent` are `F::nan()`.
138    #[must_use]
139    pub fn fill_values(&self) -> &Array1<F> {
140        &self.fill_values
141    }
142
143    /// Return the input-column indices that survive `transform`, ascending.
144    #[must_use]
145    pub fn kept_indices(&self) -> &[usize] {
146        &self.kept_indices
147    }
148}
149
// ---------------------------------------------------------------------------
// Helpers: per-column numeric reductions (sum, median, mode)
// ---------------------------------------------------------------------------

/// Sum a slice using numpy's pairwise-summation algorithm, in `F` precision.
///
/// scikit-learn's `Mean` strategy computes the per-column mean via
/// `np.ma.mean(masked_X, axis=0)` (`sklearn/impute/_base.py:498`), whose
/// reduction is numpy's pairwise summation over the observed values in the input
/// dtype.  A naive left-to-right fold diverges from this by many ULPs for an
/// `f32` column.  This mirrors numpy's `pairwise_sum`
/// (`numpy/_core/src/umath/loops_utils.h.src`):
/// - blocks of `len > 128` are split in half, with the split rounded down to a
///   multiple of 8;
/// - the `<= 128` base case accumulates into 8 partial sums (unrolled by 8) and
///   then combines them as a balanced tree
///   `((r0+r1)+(r2+r3)) + ((r4+r5)+(r6+r7))`.
///
/// For `F = f64`, pairwise and sequential summation agree to f64 ULPs.
///
/// The exact order of every addition below is significant: reordering any of
/// them changes the rounded result for `f32` inputs.
///
/// Returns `F::zero()` for an empty slice.  NaN inputs propagate through the
/// additions; the caller in `fit` replaces NaN with zero before calling.
fn pairwise_sum<F: Float>(values: &[F]) -> F {
    let n = values.len();
    if n == 0 {
        return F::zero();
    }
    if n < 8 {
        // Sequential base case for very short runs (numpy does the same).
        // Seeding with `values[0]` (rather than adding it onto a zero
        // accumulator) behaves like a `-0.0` seed, because `-0.0 + v == v` for
        // every `v`, including both signed zeros.
        let mut s = values[0];
        for &v in &values[1..] {
            s = s + v;
        }
        return s;
    }
    if n <= 128 {
        // Eight partial accumulators, unrolled by 8 (numpy's inner block).
        let mut r = [F::zero(); 8];
        r.copy_from_slice(&values[..8]);
        let mut i = 8;
        // Equivalent to numpy's `for (i = 8; i < n - (n % 8); i += 8)`.
        while i + 8 <= n {
            for j in 0..8 {
                r[j] = r[j] + values[i + j];
            }
            i += 8;
        }
        // Balanced-tree combine of the eight partials.
        let mut res = ((r[0] + r[1]) + (r[2] + r[3])) + ((r[4] + r[5]) + (r[6] + r[7]));
        // Tail elements (n not a multiple of 8) folded sequentially.
        for &v in &values[i..] {
            res = res + v;
        }
        return res;
    }
    // Recursive split; numpy rounds the half-point down to a multiple of 8.
    let mut half = n / 2;
    half -= half % 8;
    pairwise_sum(&values[..half]) + pairwise_sum(&values[half..])
}
203
204/// Compute the median of a non-empty slice of finite (non-NaN) values.
205///
206/// Uses a sort-and-interpolate approach.  Panics if the slice is empty.
207fn median_of<F: Float>(values: &mut [F]) -> F {
208    values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
209    let n = values.len();
210    if n % 2 == 1 {
211        values[n / 2]
212    } else {
213        let mid = n / 2;
214        (values[mid - 1] + values[mid]) / (F::one() + F::one())
215    }
216}
217
218/// Find the most-frequent value in a non-empty slice of finite values.
219///
220/// Ties are broken by choosing the smallest value.
221fn most_frequent_of<F: Float>(values: &[F]) -> F {
222    // Collect (value, count) by scanning; values are finite so partial_cmp is
223    // total.
224    let mut sorted = values.to_vec();
225    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
226
227    let mut best_val = sorted[0];
228    let mut best_count = 1usize;
229    let mut current_val = sorted[0];
230    let mut current_count = 1usize;
231
232    for &v in &sorted[1..] {
233        if v == current_val {
234            current_count += 1;
235        } else {
236            if current_count > best_count {
237                best_count = current_count;
238                best_val = current_val;
239            }
240            current_val = v;
241            current_count = 1;
242        }
243    }
244    // Final run
245    if current_count > best_count {
246        best_val = current_val;
247    }
248    best_val
249}
250
251// ---------------------------------------------------------------------------
252// Trait implementations
253// ---------------------------------------------------------------------------
254
255impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for SimpleImputer<F> {
256    type Fitted = FittedSimpleImputer<F>;
257    type Error = FerroError;
258
259    /// Fit the imputer by computing per-column fill values.
260    ///
261    /// NaN values are excluded from the statistic computation.  Under
262    /// `Mean`/`Median`/`MostFrequent`, a column that is entirely NaN has no
263    /// observed value: its `fill_values` entry is set to `F::nan()` and it is
264    /// excluded from `kept_indices`, so `transform` DROPS it (mirroring
265    /// scikit-learn `keep_empty_features=False`, `sklearn/impute/_base.py:501,
266    /// 510-512,534-537,586-603`).  Under `Constant`, every column is filled
267    /// with the constant and kept (sklearn `:545,583`).
268    ///
269    /// # Errors
270    ///
271    /// Returns [`FerroError::InsufficientSamples`] if the input has zero rows.
272    fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedSimpleImputer<F>, FerroError> {
273        let n_samples = x.nrows();
274        if n_samples == 0 {
275            return Err(FerroError::InsufficientSamples {
276                required: 1,
277                actual: 0,
278                context: "SimpleImputer::fit".into(),
279            });
280        }
281
282        let n_features = x.ncols();
283        let mut fill_values = Array1::zeros(n_features);
284        let mut kept_indices: Vec<usize> = Vec::with_capacity(n_features);
285
286        for j in 0..n_features {
287            let col_vals: Vec<F> = x
288                .column(j)
289                .iter()
290                .copied()
291                .filter(|v| !v.is_nan())
292                .collect();
293
294            // Constant fills (and keeps) every column, including all-NaN ones
295            // (sklearn `np.full(X.shape[1], fill_value)`, `_base.py:545,583`).
296            if let ImputeStrategy::Constant(c) = &self.strategy {
297                fill_values[j] = *c;
298                kept_indices.push(j);
299                continue;
300            }
301
302            if col_vals.is_empty() {
303                // All-NaN column with no observed value: sklearn sets
304                // `statistics_=nan` and DROPS it (`_base.py:501,510-512,
305                // 534-537,586-603`).
306                fill_values[j] = F::nan();
307                continue;
308            }
309
310            fill_values[j] = match &self.strategy {
311                ImputeStrategy::Mean => {
312                    // sklearn computes the mean via `np.ma.mean(masked_X, axis=0)`
313                    // (`sklearn/impute/_base.py:498`). `np.ma.mean` divides
314                    // `MaskedArray.sum` by the count of observed (non-masked)
315                    // elements; `MaskedArray.sum` does `self.filled(0).sum(axis)`
316                    // (`numpy/ma/core.py:5242,5251`), i.e. it sums the FULL-LENGTH
317                    // column with masked (NaN) entries set to 0, using numpy's
318                    // PAIRWISE summation, then divides by the OBSERVED count. The
319                    // fill rounds to `F` only at the transform assignment into the
320                    // output array (`:625-635`).
321                    //
322                    // numpy's pairwise tree shape depends on the FULL array length
323                    // and element POSITIONS, so summing the full column (NaN->0) is
324                    // NOT bit-equal to summing only the compressed observed values
325                    // when NaN is scattered (the zeros sit at different tree
326                    // positions, shifting f32 partial sums by a few ULPs). Build the
327                    // full-length NaN->0 column and pairwise-sum THAT, then divide by
328                    // the observed count, to be bit-identical to `np.ma.mean`.
329                    //
330                    // For F=f64 pairwise and sequential agree to f64 ULPs, and with
331                    // no NaN the full-length and compressed sums are identical (the
332                    // #2308 no-NaN pin and the f64 oracle tests guard no-regression).
333                    let col_filled: Vec<F> = x
334                        .column(j)
335                        .iter()
336                        .map(|v| if v.is_nan() { F::zero() } else { *v })
337                        .collect();
338                    let n = F::from(col_vals.len()).unwrap_or_else(F::one);
339                    pairwise_sum(&col_filled) / n
340                }
341                ImputeStrategy::Median => {
342                    let mut vals = col_vals.clone();
343                    median_of(&mut vals)
344                }
345                ImputeStrategy::MostFrequent => most_frequent_of(&col_vals),
346                // Constant handled above.
347                ImputeStrategy::Constant(c) => *c,
348            };
349            kept_indices.push(j);
350        }
351
352        Ok(FittedSimpleImputer {
353            fill_values,
354            kept_indices,
355        })
356    }
357}
358
359impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSimpleImputer<F> {
360    type Output = Array2<F>;
361    type Error = FerroError;
362
363    /// Replace NaN values with the learned fill value, projecting onto the
364    /// columns that survived fitting.
365    ///
366    /// The transform input must have the same number of columns as the fit
367    /// input (the full input width, `fill_values.len()`), matching scikit-learn
368    /// which validates against `statistics_.shape[0]` (`_base.py:573-577`).
369    /// The OUTPUT keeps only [`Self::kept_indices`] columns, in ascending
370    /// order — dropping all-NaN columns under `Mean`/`Median`/`MostFrequent`
371    /// (sklearn `X = X[:, valid_statistics_indexes]`, `_base.py:586-603`).
372    ///
373    /// # Errors
374    ///
375    /// Returns [`FerroError::ShapeMismatch`] if the number of columns does not
376    /// match the number of features seen during fitting.
377    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
378        let n_features = self.fill_values.len();
379        if x.ncols() != n_features {
380            return Err(FerroError::ShapeMismatch {
381                expected: vec![x.nrows(), n_features],
382                actual: vec![x.nrows(), x.ncols()],
383                context: "FittedSimpleImputer::transform".into(),
384            });
385        }
386
387        // Gather the surviving columns (the column-projection pattern used
388        // elsewhere, e.g. select_from_model's `select_columns`), imputing NaN
389        // with each column's learned fill value as we go.
390        let mut out = Array2::zeros((x.nrows(), self.kept_indices.len()));
391        for (out_j, &in_j) in self.kept_indices.iter().enumerate() {
392            let fill = self.fill_values[in_j];
393            for (row, &v) in x.column(in_j).iter().enumerate() {
394                out[[row, out_j]] = if v.is_nan() { fill } else { v };
395            }
396        }
397        Ok(out)
398    }
399}
400
401/// Implement `Transform` on the unfitted imputer to satisfy the
402/// `FitTransform: Transform` supertrait bound.  Always returns an error.
403impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for SimpleImputer<F> {
404    type Output = Array2<F>;
405    type Error = FerroError;
406
407    /// Always returns an error — the imputer must be fitted first.
408    ///
409    /// Use [`Fit::fit`] to produce a [`FittedSimpleImputer`], then call
410    /// [`Transform::transform`] on that.
411    fn transform(&self, _x: &Array2<F>) -> Result<Array2<F>, FerroError> {
412        Err(FerroError::InvalidParameter {
413            name: "SimpleImputer".into(),
414            reason: "imputer must be fitted before calling transform; use fit() first".into(),
415        })
416    }
417}
418
419impl<F: Float + Send + Sync + 'static> FitTransform<Array2<F>> for SimpleImputer<F> {
420    type FitError = FerroError;
421
422    /// Fit the imputer on `x` and return the imputed output in one step.
423    ///
424    /// # Errors
425    ///
426    /// Returns an error if fitting fails (e.g. zero rows).
427    fn fit_transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
428        let fitted = self.fit(x, &())?;
429        fitted.transform(x)
430    }
431}
432
433// ---------------------------------------------------------------------------
434// Pipeline integration (generic)
435// ---------------------------------------------------------------------------
436
437impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for SimpleImputer<F> {
438    /// Fit the imputer using the pipeline interface.
439    ///
440    /// The `y` argument is ignored; it exists only for API compatibility.
441    ///
442    /// # Errors
443    ///
444    /// Propagates errors from [`Fit::fit`].
445    fn fit_pipeline(
446        &self,
447        x: &Array2<F>,
448        _y: &Array1<F>,
449    ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
450        let fitted = self.fit(x, &())?;
451        Ok(Box::new(fitted))
452    }
453}
454
455impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F> for FittedSimpleImputer<F> {
456    /// Transform data using the pipeline interface.
457    ///
458    /// # Errors
459    ///
460    /// Propagates errors from [`Transform::transform`].
461    fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
462        self.transform(x)
463    }
464}
465
466// ---------------------------------------------------------------------------
467// Tests
468// ---------------------------------------------------------------------------
469
#[cfg(test)]
mod tests {
    // Unit tests for the imputer.  Oracle-parity tests against scikit-learn
    // live separately; these pin the in-module contracts.
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    // ---- Mean strategy -------------------------------------------------------

    #[test]
    fn test_mean_basic() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Mean);
        let x = array![[1.0, f64::NAN], [3.0, 4.0], [5.0, 6.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        // Column 0 mean = (1+3+5)/3 = 3.0, column 1 mean = (4+6)/2 = 5.0
        assert_abs_diff_eq!(fitted.fill_values()[0], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(fitted.fill_values()[1], 5.0, epsilon = 1e-10);
        let out = fitted.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 1]], 5.0, epsilon = 1e-10);
        // Non-NaN values must be untouched
        assert_abs_diff_eq!(out[[1, 1]], 4.0, epsilon = 1e-10);
    }

    #[test]
    fn test_mean_no_nan() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Mean);
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        // Nothing should change
        for (a, b) in x.iter().zip(out.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-15);
        }
    }

    #[test]
    fn test_mean_multiple_nans_same_column() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Mean);
        let x = array![[f64::NAN], [f64::NAN], [6.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        assert_abs_diff_eq!(fitted.fill_values()[0], 6.0, epsilon = 1e-10);
        let out = fitted.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 6.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[1, 0]], 6.0, epsilon = 1e-10);
    }

    #[test]
    fn test_mean_all_nan_column_dropped() {
        // sklearn `keep_empty_features=False` (default): an all-NaN column has
        // no observed value, so `statistics_=nan` and `transform` DROPS it
        // (`sklearn/impute/_base.py:586-603`). A single all-NaN input column
        // therefore yields ZERO output columns.
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Mean);
        let x = array![[f64::NAN], [f64::NAN]];
        let fitted = match imputer.fit(&x, &()) {
            Ok(f) => f,
            #[allow(
                clippy::assertions_on_constants,
                reason = "error arm fails loudly without panic!/unwrap (anti-pattern gate)"
            )]
            Err(e) => {
                assert!(false, "fit errored: {e}");
                return;
            }
        };
        // statistics_ entry is NaN (mirrors sklearn `statistics_`).
        assert!(fitted.fill_values()[0].is_nan());
        match fitted.transform(&x) {
            Ok(out) => {
                assert_eq!(out.ncols(), 0, "all-NaN column dropped -> 0 output columns");
                assert_eq!(out.nrows(), 2);
            }
            #[allow(
                clippy::assertions_on_constants,
                reason = "error arm fails loudly without panic!/unwrap (anti-pattern gate)"
            )]
            Err(e) => assert!(false, "transform errored: {e}"),
        }
    }

    // ---- Median strategy ----------------------------------------------------

    #[test]
    fn test_median_odd_count() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Median);
        let x = array![[1.0], [3.0], [5.0], [7.0], [9.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        assert_abs_diff_eq!(fitted.fill_values()[0], 5.0, epsilon = 1e-10);
    }

    #[test]
    fn test_median_even_count() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Median);
        let x = array![[1.0], [3.0], [5.0], [7.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        // Median of [1,3,5,7] = (3+5)/2 = 4.0
        assert_abs_diff_eq!(fitted.fill_values()[0], 4.0, epsilon = 1e-10);
    }

    #[test]
    fn test_median_with_nan() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Median);
        // Column 0: non-NaN values are [2, 4, 6], median = 4
        let x = array![[2.0], [f64::NAN], [4.0], [6.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        assert_abs_diff_eq!(fitted.fill_values()[0], 4.0, epsilon = 1e-10);
        let out = fitted.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[1, 0]], 4.0, epsilon = 1e-10);
    }

    // ---- MostFrequent strategy ----------------------------------------------

    #[test]
    fn test_most_frequent_basic() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::MostFrequent);
        let x = array![[1.0], [2.0], [2.0], [3.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        assert_abs_diff_eq!(fitted.fill_values()[0], 2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_most_frequent_tie_chooses_smallest() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::MostFrequent);
        // 1.0 and 3.0 each appear twice — smallest wins
        let x = array![[1.0], [1.0], [3.0], [3.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        assert_abs_diff_eq!(fitted.fill_values()[0], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_most_frequent_with_nan() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::MostFrequent);
        let x = array![[1.0], [f64::NAN], [2.0], [2.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        assert_abs_diff_eq!(fitted.fill_values()[0], 2.0, epsilon = 1e-10);
        let out = fitted.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[1, 0]], 2.0, epsilon = 1e-10);
    }

    // ---- Constant strategy --------------------------------------------------

    #[test]
    fn test_constant_strategy() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Constant(-99.0));
        let x = array![[1.0, f64::NAN], [f64::NAN, 4.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        assert_abs_diff_eq!(fitted.fill_values()[0], -99.0, epsilon = 1e-15);
        assert_abs_diff_eq!(fitted.fill_values()[1], -99.0, epsilon = 1e-15);
        let out = fitted.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[1, 0]], -99.0, epsilon = 1e-15);
        assert_abs_diff_eq!(out[[0, 1]], -99.0, epsilon = 1e-15);
    }

    #[test]
    fn test_constant_keeps_all_nan_column() -> Result<(), FerroError> {
        // Under `Constant` an all-NaN column is filled and KEPT (sklearn
        // `_base.py:545,583`), unlike the statistic-based strategies.
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Constant(7.0));
        let x = array![[1.0, f64::NAN], [2.0, f64::NAN]];
        let fitted = imputer.fit(&x, &())?;
        assert_eq!(fitted.kept_indices(), &[0, 1]);
        let out = fitted.transform(&x)?;
        assert_eq!(out.dim(), (2, 2));
        assert_abs_diff_eq!(out[[0, 0]], 1.0, epsilon = 1e-15);
        assert_abs_diff_eq!(out[[1, 0]], 2.0, epsilon = 1e-15);
        assert_abs_diff_eq!(out[[0, 1]], 7.0, epsilon = 1e-15);
        assert_abs_diff_eq!(out[[1, 1]], 7.0, epsilon = 1e-15);
        Ok(())
    }

    // ---- Column dropping keeps surviving order --------------------------------

    #[test]
    fn test_all_nan_middle_column_dropped_keeps_order() -> Result<(), FerroError> {
        // A dropped all-NaN column must not shift the surviving columns: the
        // output holds input columns 0 and 2, in that order.
        let x = array![
            [1.0, f64::NAN, 10.0],
            [f64::NAN, f64::NAN, 20.0],
            [1.0, f64::NAN, f64::NAN]
        ];
        for strategy in [ImputeStrategy::Median, ImputeStrategy::MostFrequent] {
            let fitted = SimpleImputer::<f64>::new(strategy).fit(&x, &())?;
            assert_eq!(fitted.kept_indices(), &[0, 2]);
            assert!(fitted.fill_values()[1].is_nan());
            let out = fitted.transform(&x)?;
            assert_eq!(out.dim(), (3, 2));
            // Column 0 observes [1, 1]: both strategies fill with 1.
            assert_abs_diff_eq!(out[[1, 0]], 1.0, epsilon = 1e-12);
            // Input column 2 lands in output column 1, untouched where observed.
            assert_abs_diff_eq!(out[[0, 1]], 10.0, epsilon = 1e-12);
            assert_abs_diff_eq!(out[[1, 1]], 20.0, epsilon = 1e-12);
            assert_abs_diff_eq!(out[[2, 1]], fitted.fill_values()[2], epsilon = 1e-12);
        }
        Ok(())
    }

    // ---- pairwise_sum ---------------------------------------------------------

    #[test]
    fn test_pairwise_sum_exact_on_integers() {
        // Small integers are exactly representable, so every branch (empty,
        // n < 8, 8..=128, recursive split) must reproduce the exact total.
        for n in [0usize, 1, 7, 8, 13, 128, 129, 300] {
            let v: Vec<f64> = (1..=n).map(|i| i as f64).collect();
            let expected = (n * (n + 1) / 2) as f64;
            assert_abs_diff_eq!(pairwise_sum(&v), expected, epsilon = 0.0);
        }
    }

    // ---- Error paths --------------------------------------------------------

    #[test]
    fn test_fit_zero_rows_error() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Mean);
        let x: Array2<f64> = Array2::zeros((0, 3));
        assert!(imputer.fit(&x, &()).is_err());
    }

    #[test]
    fn test_transform_shape_mismatch_error() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Mean);
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let fitted = imputer.fit(&x_train, &()).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_unfitted_transform_error() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Mean);
        let x = array![[1.0, 2.0]];
        assert!(imputer.transform(&x).is_err());
    }

    // ---- fit_transform ------------------------------------------------------

    #[test]
    fn test_fit_transform_equivalence() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Mean);
        let x = array![[1.0, f64::NAN], [3.0, 4.0], [5.0, 6.0]];
        let via_fit_transform = imputer.fit_transform(&x).unwrap();
        let fitted = imputer.fit(&x, &()).unwrap();
        let via_separate = fitted.transform(&x).unwrap();
        for (a, b) in via_fit_transform.iter().zip(via_separate.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-15);
        }
    }

    // ---- f32 generic --------------------------------------------------------

    #[test]
    fn test_f32_imputer() {
        let imputer = SimpleImputer::<f32>::new(ImputeStrategy::Mean);
        let x: Array2<f32> = array![[1.0f32, f32::NAN], [3.0, 4.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert!((out[[0, 1]] - 4.0f32).abs() < 1e-6);
    }

    // ---- Pipeline integration -----------------------------------------------

    #[test]
    fn test_pipeline_integration() {
        use ferrolearn_core::pipeline::PipelineTransformer;

        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Mean);
        let x = array![[1.0, f64::NAN], [3.0, 4.0]];
        let y = ndarray::array![0.0, 1.0];
        let fitted_box = imputer.fit_pipeline(&x, &y).unwrap();
        let out = fitted_box.transform_pipeline(&x).unwrap();
        // NaN should be gone
        assert!(!out[[0, 1]].is_nan());
    }

    // ---- multiple columns with mixed NaN ------------------------------------

    #[test]
    fn test_multi_column_mixed_nan() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Median);
        let x = array![[f64::NAN, 10.0], [2.0, f64::NAN], [4.0, 30.0], [6.0, 40.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        // Column 0 non-NaN = [2,4,6], median = 4
        assert_abs_diff_eq!(out[[0, 0]], 4.0, epsilon = 1e-10);
        // Column 1 non-NaN = [10,30,40], median = 30
        assert_abs_diff_eq!(out[[1, 1]], 30.0, epsilon = 1e-10);
    }

    // ---- strategy accessor --------------------------------------------------

    #[test]
    fn test_strategy_accessor() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Constant(42.0));
        assert_eq!(imputer.strategy(), &ImputeStrategy::Constant(42.0));
    }
}