ferrolearn_preprocess/imputer.rs
1//! Simple imputer: fill missing (NaN) values per feature column.
2//!
3//! [`SimpleImputer`] supports four imputation strategies:
4//! - [`ImputeStrategy::Mean`] — replace NaN with the column mean
5//! - [`ImputeStrategy::Median`] — replace NaN with the column median
6//! - [`ImputeStrategy::MostFrequent`] — replace NaN with the most common value
7//! - [`ImputeStrategy::Constant`] — replace NaN with a fixed constant value
8//!
9//! Fitting ignores NaN values when computing statistics (e.g. the mean is the
10//! mean of all non-NaN values in that column). Under `Mean`/`Median`/
11//! `MostFrequent`, columns that are entirely NaN at fit time have no observed
12//! value, so — mirroring scikit-learn's default `keep_empty_features=False`
13//! (`sklearn/impute/_base.py:501,510-512,534-537` set `statistics_=nan`;
14//! `:586-603` drop them in `transform`) — they are DROPPED from the transform
15//! output. Under `Constant`, every column (including all-NaN ones) is filled
16//! with the constant and KEPT (sklearn `:545,583`).
17//!
18//! ## REQ status
19//!
20//! Translation target: scikit-learn 1.5.2 `class SimpleImputer` +
21//! `MissingIndicator` (`sklearn/impute/_base.py:147`). Tracking: #1363. Each REQ
22//! is BINARY — SHIPPED (impl + non-test consumer + tests + green verification)
23//! or NOT-STARTED (with a concrete open blocker).
24//!
25//! | REQ | Scope | Status | Evidence / Blocker |
26//! |-----|-------|--------|--------------------|
27//! | REQ-1 | Per-column fill VALUES on columns with ≥1 observed value (Mean/Median/MostFrequent/Constant) | SHIPPED | [`SimpleImputer`] `fit` — Mean=`np.ma.mean` (`_base.py:498`), Median=`np.ma.median` (`:507`, even=avg-of-two-middle), MostFrequent=scipy mode tie→min (`_most_frequent` `:36-71`), Constant (`:545`); 9 oracle value tests in `tests/divergence_imputer.rs`. Consumer: re-export `lib.rs:136` + `PipelineTransformer` |
28//! | REQ-2 | All-NaN column DROP under Mean/Median/MostFrequent (sklearn default `keep_empty_features=False`) | SHIPPED | `fit` sets `fill_values[j]=NaN` + excludes `j` from `kept_indices`; `transform` projects onto `kept_indices` (mirrors `statistics_=nan` + `X=X[:, valid]` `_base.py:586-603`); `Constant` keeps+fills all (`:583`); 10 oracle tests (column-order, all-dropped, separate matrix, f32) — was DIV-1 #1364, fixed |
29//! | REQ-3 | Error/parameter contracts (n_samples==0, transform ncols, unfitted) | SHIPPED (scoped) | [`SimpleImputer::fit`]/[`FittedSimpleImputer`] `transform`; in-module + divergence error tests |
30//! | REQ-4 | `keep_empty_features` param (True → fill 0 + keep all-NaN cols) | NOT-STARTED | always drops; sklearn `_base.py:583,501` — blocker #1365 |
31//! | REQ-5 | `missing_values` param (non-NaN sentinel / None / str) | NOT-STARTED | NaN-only; sklearn `_base.py:161,288` — blocker #1366 |
32//! | REQ-6 | `add_indicator` + `MissingIndicator` estimator (route parity_op, ABSENT) | NOT-STARTED | needs acto-builder; sklearn `_base.py:205` + `MissingIndicator` — blocker #1367 |
33//! | REQ-7 | `inverse_transform` (requires add_indicator) | NOT-STARTED | sklearn `_base.py:641` — blocker #1368 |
34//! | REQ-8 | `fill_value=None`→0 default + `statistics_` attr name + `copy` param | NOT-STARTED | `Constant(F)` explicit; sklearn `_base.py:425-427,223,288` — blocker #1369 |
35//! | REQ-9 | string/object dtype (most_frequent/constant on non-numeric) | NOT-STARTED | `F: Float` only; sklearn `_base.py:42-52,526` — blocker #1370 |
36//! | REQ-10 | sparse `_sparse_fit` | NOT-STARTED | dense `Array2` only; sklearn `_base.py:444` — blocker #1371 |
37//! | REQ-11 | `get_feature_names_out` + `n_features_in_`/`feature_names_in_` | NOT-STARTED | `_BaseImputer` — blocker #1372 |
38//! | REQ-12 | PyO3 binding | NOT-STARTED | no `ferrolearn-python` registration — blocker #1373 |
39//! | REQ-13 | ferray substrate | NOT-STARTED | dense `Array2` + `num_traits::Float` only — blocker #1374 |
40
41use ferrolearn_core::error::FerroError;
42use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
43use ferrolearn_core::traits::{Fit, FitTransform, Transform};
44use ndarray::{Array1, Array2};
45use num_traits::Float;
46
47// ---------------------------------------------------------------------------
48// ImputeStrategy
49// ---------------------------------------------------------------------------
50
/// How a [`SimpleImputer`] derives the value that replaces the missing (NaN)
/// entries of each column.
///
/// The three statistic-based variants look only at a column's observed
/// (non-NaN) entries; `Constant` ignores the data and uses the carried value.
#[derive(Debug, Clone, PartialEq)]
pub enum ImputeStrategy<F> {
    /// Arithmetic mean of the column's observed entries.
    Mean,
    /// Median of the column's observed entries (the average of the two middle
    /// values when their count is even).
    Median,
    /// The column's most common observed entry; ties resolve to the smallest.
    MostFrequent,
    /// The carried value, used verbatim for every missing entry.
    Constant(F),
}
63
64// ---------------------------------------------------------------------------
65// SimpleImputer (unfitted)
66// ---------------------------------------------------------------------------
67
68/// An unfitted simple imputer.
69///
70/// Calling [`Fit::fit`] computes the per-column fill values according to
71/// the chosen [`ImputeStrategy`] and returns a [`FittedSimpleImputer`] that
72/// can transform new data by replacing NaN values with those fill values.
73///
74/// NaN values are *ignored* when computing statistics during fitting — e.g.
75/// the `Mean` strategy computes the mean of only the non-NaN elements.
76///
77/// # Examples
78///
79/// ```
80/// use ferrolearn_preprocess::imputer::{SimpleImputer, ImputeStrategy};
81/// use ferrolearn_core::traits::{Fit, Transform};
82/// use ndarray::array;
83///
84/// let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Mean);
85/// let x = array![[1.0, f64::NAN], [3.0, 4.0], [5.0, 6.0]];
86/// let fitted = imputer.fit(&x, &()).unwrap();
87/// let out = fitted.transform(&x).unwrap();
88/// // NaN in column 1 row 0 is replaced with the mean of column 1 = (4+6)/2 = 5.0
89/// assert!((out[[0, 1]] - 5.0).abs() < 1e-10);
90/// ```
91#[derive(Debug, Clone)]
92pub struct SimpleImputer<F> {
93 strategy: ImputeStrategy<F>,
94}
95
96impl<F: Float + Send + Sync + 'static> SimpleImputer<F> {
97 /// Create a new `SimpleImputer` with the given strategy.
98 #[must_use]
99 pub fn new(strategy: ImputeStrategy<F>) -> Self {
100 Self { strategy }
101 }
102
103 /// Return the imputation strategy.
104 #[must_use]
105 pub fn strategy(&self) -> &ImputeStrategy<F> {
106 &self.strategy
107 }
108}
109
110// ---------------------------------------------------------------------------
111// FittedSimpleImputer
112// ---------------------------------------------------------------------------
113
114/// A fitted simple imputer holding one fill value per feature column.
115///
116/// Created by calling [`Fit::fit`] on a [`SimpleImputer`].
117#[derive(Debug, Clone)]
118pub struct FittedSimpleImputer<F> {
119 /// Per-INPUT-column fill values learned during fitting.
120 ///
121 /// One entry per input column, mirroring scikit-learn's `statistics_`:
122 /// holds `F::nan()` for an all-NaN non-constant column that is dropped, and
123 /// the computed fill statistic (or the user constant) otherwise.
124 fill_values: Array1<F>,
125 /// Input-column indices that survive transform, in ascending order.
126 ///
127 /// Under `Mean`/`Median`/`MostFrequent` an all-NaN column has no observed
128 /// value and is excluded (sklearn `keep_empty_features=False`); under
129 /// `Constant` every column is kept.
130 kept_indices: Vec<usize>,
131}
132
133impl<F: Float + Send + Sync + 'static> FittedSimpleImputer<F> {
134 /// Return the per-input-column fill values learned during fitting.
135 ///
136 /// Mirrors scikit-learn's `statistics_`: entries for all-NaN columns that
137 /// are dropped under `Mean`/`Median`/`MostFrequent` are `F::nan()`.
138 #[must_use]
139 pub fn fill_values(&self) -> &Array1<F> {
140 &self.fill_values
141 }
142
143 /// Return the input-column indices that survive `transform`, ascending.
144 #[must_use]
145 pub fn kept_indices(&self) -> &[usize] {
146 &self.kept_indices
147 }
148}
149
150// ---------------------------------------------------------------------------
// Helpers: pairwise summation, median, and mode over a column's observed (non-NaN) values
152// ---------------------------------------------------------------------------
153
154/// Sum a slice using numpy's pairwise-summation algorithm, in `F` precision.
155///
156/// scikit-learn's `Mean` strategy computes the per-column mean via
157/// `np.ma.mean(masked_X, axis=0)` (`sklearn/impute/_base.py:498`), whose
158/// reduction is numpy's pairwise summation over the observed values in the input
159/// dtype. A naive left-to-right fold diverges from this by many ULPs for an
160/// `f32` column. This mirrors numpy's `pairwise_sum`
161/// (`numpy/_core/src/umath/loops_utils.h.src`): blocks of `len > 128` split in
162/// half (with the split rounded down to a multiple of 8), and the `<= 128` base
163/// case accumulates into 8 partial sums (unrolled by 8) before combining them as
164/// a balanced tree `((r0+r1)+(r2+r3)) + ((r4+r5)+(r6+r7))`. For `F = f64`
165/// pairwise and sequential agree to f64 ULPs.
166fn pairwise_sum<F: Float>(values: &[F]) -> F {
167 let n = values.len();
168 if n == 0 {
169 return F::zero();
170 }
171 if n < 8 {
172 // Sequential base case for very short runs (numpy does the same).
173 let mut s = values[0];
174 for &v in &values[1..] {
175 s = s + v;
176 }
177 return s;
178 }
179 if n <= 128 {
180 // Eight partial accumulators, unrolled by 8 (numpy's inner block).
181 let mut r = [F::zero(); 8];
182 r.copy_from_slice(&values[..8]);
183 let mut i = 8;
184 while i + 8 <= n {
185 for j in 0..8 {
186 r[j] = r[j] + values[i + j];
187 }
188 i += 8;
189 }
190 // Balanced-tree combine of the eight partials.
191 let mut res = ((r[0] + r[1]) + (r[2] + r[3])) + ((r[4] + r[5]) + (r[6] + r[7]));
192 // Tail elements (n not a multiple of 8) folded sequentially.
193 for &v in &values[i..] {
194 res = res + v;
195 }
196 return res;
197 }
198 // Recursive split; numpy rounds the half-point down to a multiple of 8.
199 let mut half = n / 2;
200 half -= half % 8;
201 pairwise_sum(&values[..half]) + pairwise_sum(&values[half..])
202}
203
204/// Compute the median of a non-empty slice of finite (non-NaN) values.
205///
206/// Uses a sort-and-interpolate approach. Panics if the slice is empty.
207fn median_of<F: Float>(values: &mut [F]) -> F {
208 values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
209 let n = values.len();
210 if n % 2 == 1 {
211 values[n / 2]
212 } else {
213 let mid = n / 2;
214 (values[mid - 1] + values[mid]) / (F::one() + F::one())
215 }
216}
217
218/// Find the most-frequent value in a non-empty slice of finite values.
219///
220/// Ties are broken by choosing the smallest value.
221fn most_frequent_of<F: Float>(values: &[F]) -> F {
222 // Collect (value, count) by scanning; values are finite so partial_cmp is
223 // total.
224 let mut sorted = values.to_vec();
225 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
226
227 let mut best_val = sorted[0];
228 let mut best_count = 1usize;
229 let mut current_val = sorted[0];
230 let mut current_count = 1usize;
231
232 for &v in &sorted[1..] {
233 if v == current_val {
234 current_count += 1;
235 } else {
236 if current_count > best_count {
237 best_count = current_count;
238 best_val = current_val;
239 }
240 current_val = v;
241 current_count = 1;
242 }
243 }
244 // Final run
245 if current_count > best_count {
246 best_val = current_val;
247 }
248 best_val
249}
250
251// ---------------------------------------------------------------------------
252// Trait implementations
253// ---------------------------------------------------------------------------
254
255impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for SimpleImputer<F> {
256 type Fitted = FittedSimpleImputer<F>;
257 type Error = FerroError;
258
259 /// Fit the imputer by computing per-column fill values.
260 ///
261 /// NaN values are excluded from the statistic computation. Under
262 /// `Mean`/`Median`/`MostFrequent`, a column that is entirely NaN has no
263 /// observed value: its `fill_values` entry is set to `F::nan()` and it is
264 /// excluded from `kept_indices`, so `transform` DROPS it (mirroring
265 /// scikit-learn `keep_empty_features=False`, `sklearn/impute/_base.py:501,
266 /// 510-512,534-537,586-603`). Under `Constant`, every column is filled
267 /// with the constant and kept (sklearn `:545,583`).
268 ///
269 /// # Errors
270 ///
271 /// Returns [`FerroError::InsufficientSamples`] if the input has zero rows.
272 fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedSimpleImputer<F>, FerroError> {
273 let n_samples = x.nrows();
274 if n_samples == 0 {
275 return Err(FerroError::InsufficientSamples {
276 required: 1,
277 actual: 0,
278 context: "SimpleImputer::fit".into(),
279 });
280 }
281
282 let n_features = x.ncols();
283 let mut fill_values = Array1::zeros(n_features);
284 let mut kept_indices: Vec<usize> = Vec::with_capacity(n_features);
285
286 for j in 0..n_features {
287 let col_vals: Vec<F> = x
288 .column(j)
289 .iter()
290 .copied()
291 .filter(|v| !v.is_nan())
292 .collect();
293
294 // Constant fills (and keeps) every column, including all-NaN ones
295 // (sklearn `np.full(X.shape[1], fill_value)`, `_base.py:545,583`).
296 if let ImputeStrategy::Constant(c) = &self.strategy {
297 fill_values[j] = *c;
298 kept_indices.push(j);
299 continue;
300 }
301
302 if col_vals.is_empty() {
303 // All-NaN column with no observed value: sklearn sets
304 // `statistics_=nan` and DROPS it (`_base.py:501,510-512,
305 // 534-537,586-603`).
306 fill_values[j] = F::nan();
307 continue;
308 }
309
310 fill_values[j] = match &self.strategy {
311 ImputeStrategy::Mean => {
312 // sklearn computes the mean via `np.ma.mean(masked_X, axis=0)`
313 // (`sklearn/impute/_base.py:498`). `np.ma.mean` divides
314 // `MaskedArray.sum` by the count of observed (non-masked)
315 // elements; `MaskedArray.sum` does `self.filled(0).sum(axis)`
316 // (`numpy/ma/core.py:5242,5251`), i.e. it sums the FULL-LENGTH
317 // column with masked (NaN) entries set to 0, using numpy's
318 // PAIRWISE summation, then divides by the OBSERVED count. The
319 // fill rounds to `F` only at the transform assignment into the
320 // output array (`:625-635`).
321 //
322 // numpy's pairwise tree shape depends on the FULL array length
323 // and element POSITIONS, so summing the full column (NaN->0) is
324 // NOT bit-equal to summing only the compressed observed values
325 // when NaN is scattered (the zeros sit at different tree
326 // positions, shifting f32 partial sums by a few ULPs). Build the
327 // full-length NaN->0 column and pairwise-sum THAT, then divide by
328 // the observed count, to be bit-identical to `np.ma.mean`.
329 //
330 // For F=f64 pairwise and sequential agree to f64 ULPs, and with
331 // no NaN the full-length and compressed sums are identical (the
332 // #2308 no-NaN pin and the f64 oracle tests guard no-regression).
333 let col_filled: Vec<F> = x
334 .column(j)
335 .iter()
336 .map(|v| if v.is_nan() { F::zero() } else { *v })
337 .collect();
338 let n = F::from(col_vals.len()).unwrap_or_else(F::one);
339 pairwise_sum(&col_filled) / n
340 }
341 ImputeStrategy::Median => {
342 let mut vals = col_vals.clone();
343 median_of(&mut vals)
344 }
345 ImputeStrategy::MostFrequent => most_frequent_of(&col_vals),
346 // Constant handled above.
347 ImputeStrategy::Constant(c) => *c,
348 };
349 kept_indices.push(j);
350 }
351
352 Ok(FittedSimpleImputer {
353 fill_values,
354 kept_indices,
355 })
356 }
357}
358
359impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSimpleImputer<F> {
360 type Output = Array2<F>;
361 type Error = FerroError;
362
363 /// Replace NaN values with the learned fill value, projecting onto the
364 /// columns that survived fitting.
365 ///
366 /// The transform input must have the same number of columns as the fit
367 /// input (the full input width, `fill_values.len()`), matching scikit-learn
368 /// which validates against `statistics_.shape[0]` (`_base.py:573-577`).
369 /// The OUTPUT keeps only [`Self::kept_indices`] columns, in ascending
370 /// order — dropping all-NaN columns under `Mean`/`Median`/`MostFrequent`
371 /// (sklearn `X = X[:, valid_statistics_indexes]`, `_base.py:586-603`).
372 ///
373 /// # Errors
374 ///
375 /// Returns [`FerroError::ShapeMismatch`] if the number of columns does not
376 /// match the number of features seen during fitting.
377 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
378 let n_features = self.fill_values.len();
379 if x.ncols() != n_features {
380 return Err(FerroError::ShapeMismatch {
381 expected: vec![x.nrows(), n_features],
382 actual: vec![x.nrows(), x.ncols()],
383 context: "FittedSimpleImputer::transform".into(),
384 });
385 }
386
387 // Gather the surviving columns (the column-projection pattern used
388 // elsewhere, e.g. select_from_model's `select_columns`), imputing NaN
389 // with each column's learned fill value as we go.
390 let mut out = Array2::zeros((x.nrows(), self.kept_indices.len()));
391 for (out_j, &in_j) in self.kept_indices.iter().enumerate() {
392 let fill = self.fill_values[in_j];
393 for (row, &v) in x.column(in_j).iter().enumerate() {
394 out[[row, out_j]] = if v.is_nan() { fill } else { v };
395 }
396 }
397 Ok(out)
398 }
399}
400
401/// Implement `Transform` on the unfitted imputer to satisfy the
402/// `FitTransform: Transform` supertrait bound. Always returns an error.
403impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for SimpleImputer<F> {
404 type Output = Array2<F>;
405 type Error = FerroError;
406
407 /// Always returns an error — the imputer must be fitted first.
408 ///
409 /// Use [`Fit::fit`] to produce a [`FittedSimpleImputer`], then call
410 /// [`Transform::transform`] on that.
411 fn transform(&self, _x: &Array2<F>) -> Result<Array2<F>, FerroError> {
412 Err(FerroError::InvalidParameter {
413 name: "SimpleImputer".into(),
414 reason: "imputer must be fitted before calling transform; use fit() first".into(),
415 })
416 }
417}
418
419impl<F: Float + Send + Sync + 'static> FitTransform<Array2<F>> for SimpleImputer<F> {
420 type FitError = FerroError;
421
422 /// Fit the imputer on `x` and return the imputed output in one step.
423 ///
424 /// # Errors
425 ///
426 /// Returns an error if fitting fails (e.g. zero rows).
427 fn fit_transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
428 let fitted = self.fit(x, &())?;
429 fitted.transform(x)
430 }
431}
432
433// ---------------------------------------------------------------------------
434// Pipeline integration (generic)
435// ---------------------------------------------------------------------------
436
437impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for SimpleImputer<F> {
438 /// Fit the imputer using the pipeline interface.
439 ///
440 /// The `y` argument is ignored; it exists only for API compatibility.
441 ///
442 /// # Errors
443 ///
444 /// Propagates errors from [`Fit::fit`].
445 fn fit_pipeline(
446 &self,
447 x: &Array2<F>,
448 _y: &Array1<F>,
449 ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
450 let fitted = self.fit(x, &())?;
451 Ok(Box::new(fitted))
452 }
453}
454
455impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F> for FittedSimpleImputer<F> {
456 /// Transform data using the pipeline interface.
457 ///
458 /// # Errors
459 ///
460 /// Propagates errors from [`Transform::transform`].
461 fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
462 self.transform(x)
463 }
464}
465
466// ---------------------------------------------------------------------------
467// Tests
468// ---------------------------------------------------------------------------
469
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    // ---- Mean strategy -------------------------------------------------------

    #[test]
    fn test_mean_basic() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Mean);
        let x = array![[1.0, f64::NAN], [3.0, 4.0], [5.0, 6.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        // Column 0 mean = (1+3+5)/3 = 3.0, column 1 mean = (4+6)/2 = 5.0
        assert_abs_diff_eq!(fitted.fill_values()[0], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(fitted.fill_values()[1], 5.0, epsilon = 1e-10);
        let out = fitted.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 1]], 5.0, epsilon = 1e-10);
        // Non-NaN values must be untouched
        assert_abs_diff_eq!(out[[1, 1]], 4.0, epsilon = 1e-10);
    }

    #[test]
    fn test_mean_no_nan() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Mean);
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        // Nothing should change
        for (a, b) in x.iter().zip(out.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-15);
        }
    }

    #[test]
    fn test_mean_multiple_nans_same_column() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Mean);
        let x = array![[f64::NAN], [f64::NAN], [6.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        assert_abs_diff_eq!(fitted.fill_values()[0], 6.0, epsilon = 1e-10);
        let out = fitted.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 6.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[1, 0]], 6.0, epsilon = 1e-10);
    }

    #[test]
    fn test_mean_all_nan_column_dropped() {
        // sklearn `keep_empty_features=False` (default): an all-NaN column has
        // no observed value, so `statistics_=nan` and `transform` DROPS it
        // (`sklearn/impute/_base.py:586-603`). A single all-NaN input column
        // therefore yields ZERO output columns.
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Mean);
        let x = array![[f64::NAN], [f64::NAN]];
        let fitted = match imputer.fit(&x, &()) {
            Ok(f) => f,
            #[allow(
                clippy::assertions_on_constants,
                reason = "error arm fails loudly without panic!/unwrap (anti-pattern gate)"
            )]
            Err(e) => {
                assert!(false, "fit errored: {e}");
                return;
            }
        };
        // statistics_ entry is NaN (mirrors sklearn `statistics_`).
        assert!(fitted.fill_values()[0].is_nan());
        match fitted.transform(&x) {
            Ok(out) => {
                assert_eq!(out.ncols(), 0, "all-NaN column dropped -> 0 output columns");
                assert_eq!(out.nrows(), 2);
            }
            #[allow(
                clippy::assertions_on_constants,
                reason = "error arm fails loudly without panic!/unwrap (anti-pattern gate)"
            )]
            Err(e) => assert!(false, "transform errored: {e}"),
        }
    }

    // ---- Median strategy ----------------------------------------------------

    #[test]
    fn test_median_odd_count() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Median);
        let x = array![[1.0], [3.0], [5.0], [7.0], [9.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        assert_abs_diff_eq!(fitted.fill_values()[0], 5.0, epsilon = 1e-10);
    }

    #[test]
    fn test_median_even_count() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Median);
        let x = array![[1.0], [3.0], [5.0], [7.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        // Median of [1,3,5,7] = (3+5)/2 = 4.0
        assert_abs_diff_eq!(fitted.fill_values()[0], 4.0, epsilon = 1e-10);
    }

    #[test]
    fn test_median_with_nan() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Median);
        // Column 0: non-NaN values are [2, 4, 6], median = 4
        let x = array![[2.0], [f64::NAN], [4.0], [6.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        assert_abs_diff_eq!(fitted.fill_values()[0], 4.0, epsilon = 1e-10);
        let out = fitted.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[1, 0]], 4.0, epsilon = 1e-10);
    }

    #[test]
    fn test_median_all_nan_middle_column_dropped_keeps_order() {
        // The all-NaN middle column is dropped; the survivors keep their
        // original ascending order (sklearn `X[:, valid_statistics_indexes]`).
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Median);
        let x = array![[1.0, f64::NAN, 5.0], [3.0, f64::NAN, f64::NAN]];
        let fitted = imputer.fit(&x, &());
        assert!(fitted.is_ok(), "fit errored: {fitted:?}");
        let Ok(fitted) = fitted else { return };
        assert_eq!(fitted.kept_indices(), &[0, 2]);
        assert!(fitted.fill_values()[1].is_nan());
        let out = fitted.transform(&x);
        assert!(out.is_ok(), "transform errored: {out:?}");
        let Ok(out) = out else { return };
        assert_eq!(out.dim(), (2, 2));
        assert_abs_diff_eq!(out[[0, 0]], 1.0, epsilon = 1e-15);
        assert_abs_diff_eq!(out[[1, 0]], 3.0, epsilon = 1e-15);
        // Input column 2 has a single observed value (5.0), which is its fill.
        assert_abs_diff_eq!(out[[0, 1]], 5.0, epsilon = 1e-15);
        assert_abs_diff_eq!(out[[1, 1]], 5.0, epsilon = 1e-15);
    }

    // ---- MostFrequent strategy ----------------------------------------------

    #[test]
    fn test_most_frequent_basic() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::MostFrequent);
        let x = array![[1.0], [2.0], [2.0], [3.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        assert_abs_diff_eq!(fitted.fill_values()[0], 2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_most_frequent_tie_chooses_smallest() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::MostFrequent);
        // 1.0 and 3.0 each appear twice — smallest wins
        let x = array![[1.0], [1.0], [3.0], [3.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        assert_abs_diff_eq!(fitted.fill_values()[0], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_most_frequent_with_nan() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::MostFrequent);
        let x = array![[1.0], [f64::NAN], [2.0], [2.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        assert_abs_diff_eq!(fitted.fill_values()[0], 2.0, epsilon = 1e-10);
        let out = fitted.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[1, 0]], 2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_most_frequent_all_nan_column_dropped() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::MostFrequent);
        let x = array![[f64::NAN, 2.0], [f64::NAN, 2.0], [f64::NAN, f64::NAN]];
        let fitted = imputer.fit(&x, &());
        assert!(fitted.is_ok(), "fit errored: {fitted:?}");
        let Ok(fitted) = fitted else { return };
        assert_eq!(fitted.kept_indices(), &[1]);
        let out = fitted.transform(&x);
        assert!(out.is_ok(), "transform errored: {out:?}");
        let Ok(out) = out else { return };
        assert_eq!(out.dim(), (3, 1));
        assert_abs_diff_eq!(out[[2, 0]], 2.0, epsilon = 1e-15);
    }

    // ---- Constant strategy --------------------------------------------------

    #[test]
    fn test_constant_strategy() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Constant(-99.0));
        let x = array![[1.0, f64::NAN], [f64::NAN, 4.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        assert_abs_diff_eq!(fitted.fill_values()[0], -99.0, epsilon = 1e-15);
        assert_abs_diff_eq!(fitted.fill_values()[1], -99.0, epsilon = 1e-15);
        let out = fitted.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[1, 0]], -99.0, epsilon = 1e-15);
        assert_abs_diff_eq!(out[[0, 1]], -99.0, epsilon = 1e-15);
    }

    #[test]
    fn test_constant_keeps_all_nan_column() {
        // Under `Constant`, sklearn keeps every column and fills even an
        // all-NaN one with the constant (`_base.py:545,583`).
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Constant(7.0));
        let x = array![[f64::NAN, 1.0], [f64::NAN, 2.0]];
        let fitted = imputer.fit(&x, &());
        assert!(fitted.is_ok(), "fit errored: {fitted:?}");
        let Ok(fitted) = fitted else { return };
        assert_eq!(fitted.kept_indices(), &[0, 1]);
        let out = fitted.transform(&x);
        assert!(out.is_ok(), "transform errored: {out:?}");
        let Ok(out) = out else { return };
        assert_eq!(out.dim(), (2, 2));
        assert_abs_diff_eq!(out[[0, 0]], 7.0, epsilon = 1e-15);
        assert_abs_diff_eq!(out[[1, 0]], 7.0, epsilon = 1e-15);
        assert_abs_diff_eq!(out[[1, 1]], 2.0, epsilon = 1e-15);
    }

    // ---- Error paths --------------------------------------------------------

    #[test]
    fn test_fit_zero_rows_error() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Mean);
        let x: Array2<f64> = Array2::zeros((0, 3));
        assert!(imputer.fit(&x, &()).is_err());
    }

    #[test]
    fn test_transform_shape_mismatch_error() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Mean);
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let fitted = imputer.fit(&x_train, &()).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_transform_requires_full_input_width_after_drop() {
        // Even when a column is dropped, transform validates against the full
        // fit-time width (sklearn checks `statistics_.shape[0]`).
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Mean);
        let x = array![[f64::NAN, 1.0], [f64::NAN, 3.0]];
        let fitted = imputer.fit(&x, &());
        assert!(fitted.is_ok(), "fit errored: {fitted:?}");
        let Ok(fitted) = fitted else { return };
        let narrow = array![[1.0], [3.0]];
        assert!(fitted.transform(&narrow).is_err());
    }

    #[test]
    fn test_unfitted_transform_error() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Mean);
        let x = array![[1.0, 2.0]];
        assert!(imputer.transform(&x).is_err());
    }

    // ---- fit_transform ------------------------------------------------------

    #[test]
    fn test_fit_transform_equivalence() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Mean);
        let x = array![[1.0, f64::NAN], [3.0, 4.0], [5.0, 6.0]];
        let via_fit_transform = imputer.fit_transform(&x).unwrap();
        let fitted = imputer.fit(&x, &()).unwrap();
        let via_separate = fitted.transform(&x).unwrap();
        for (a, b) in via_fit_transform.iter().zip(via_separate.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-15);
        }
    }

    // ---- f32 generic --------------------------------------------------------

    #[test]
    fn test_f32_imputer() {
        let imputer = SimpleImputer::<f32>::new(ImputeStrategy::Mean);
        let x: Array2<f32> = array![[1.0f32, f32::NAN], [3.0, 4.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert!((out[[0, 1]] - 4.0f32).abs() < 1e-6);
    }

    // ---- Pipeline integration -----------------------------------------------

    #[test]
    fn test_pipeline_integration() {
        use ferrolearn_core::pipeline::PipelineTransformer;

        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Mean);
        let x = array![[1.0, f64::NAN], [3.0, 4.0]];
        let y = ndarray::array![0.0, 1.0];
        let fitted_box = imputer.fit_pipeline(&x, &y).unwrap();
        let out = fitted_box.transform_pipeline(&x).unwrap();
        // NaN should be gone
        assert!(!out[[0, 1]].is_nan());
    }

    // ---- multiple columns with mixed NaN ------------------------------------

    #[test]
    fn test_multi_column_mixed_nan() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Median);
        let x = array![[f64::NAN, 10.0], [2.0, f64::NAN], [4.0, 30.0], [6.0, 40.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        // Column 0 non-NaN = [2,4,6], median = 4
        assert_abs_diff_eq!(out[[0, 0]], 4.0, epsilon = 1e-10);
        // Column 1 non-NaN = [10,30,40], median = 30
        assert_abs_diff_eq!(out[[1, 1]], 30.0, epsilon = 1e-10);
    }

    // ---- strategy accessor --------------------------------------------------

    #[test]
    fn test_strategy_accessor() {
        let imputer = SimpleImputer::<f64>::new(ImputeStrategy::Constant(42.0));
        assert_eq!(imputer.strategy(), &ImputeStrategy::Constant(42.0));
    }
}