Skip to main content

ferrolearn_preprocess/
iterative_imputer.rs

//! Iterative imputer: fill missing values by modeling each feature as a function
//! of all other features.
//!
//! [`IterativeImputer`] performs round-robin imputation: for each feature with
//! missing values, it fits a simple Ridge regression on the non-missing rows
//! using the other features as predictors, then predicts the missing values.
//! This process is repeated for `max_iter` iterations or until convergence.
//!
//! # Initial Imputation
//!
//! Before the iterative process begins, missing values are filled using a simple
//! strategy (mean by default). This initial imputation provides a starting point
//! for the regression models.
14
15use ferrolearn_core::error::FerroError;
16use ferrolearn_core::traits::{Fit, FitTransform, Transform};
17use ndarray::{Array1, Array2};
18use num_traits::Float;
19
20// ---------------------------------------------------------------------------
21// InitialStrategy
22// ---------------------------------------------------------------------------
23
/// Strategy for the initial imputation before iterative refinement.
///
/// The iterative process needs a complete (NaN-free) matrix to start from;
/// this enum selects how that first per-column fill value is computed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InitialStrategy {
    /// Replace NaN with the column mean (computed over non-NaN entries).
    Mean,
    /// Replace NaN with the column median (computed over non-NaN entries).
    Median,
}
32
33// ---------------------------------------------------------------------------
34// IterativeImputer (unfitted)
35// ---------------------------------------------------------------------------
36
/// An unfitted iterative imputer.
///
/// Calling [`Fit::fit`] learns the imputation model and returns a
/// [`FittedIterativeImputer`] that can impute missing values in new data.
///
/// # Parameters
///
/// - `max_iter` — maximum number of imputation rounds (default 10).
/// - `tol` — convergence tolerance on the relative change in imputed values
///   (default 1e-3).
/// - `initial_strategy` — strategy for the initial fill (default `Mean`).
///
/// # Examples
///
/// ```
/// use ferrolearn_preprocess::iterative_imputer::{IterativeImputer, InitialStrategy};
/// use ferrolearn_core::traits::{Fit, Transform};
/// use ndarray::array;
///
/// let imputer = IterativeImputer::<f64>::new(10, 1e-3, InitialStrategy::Mean);
/// let x = array![[1.0, 2.0], [3.0, f64::NAN], [f64::NAN, 6.0]];
/// let fitted = imputer.fit(&x, &()).unwrap();
/// let out = fitted.transform(&x).unwrap();
/// assert!(!out[[1, 1]].is_nan());
/// assert!(!out[[2, 0]].is_nan());
/// ```
#[must_use]
#[derive(Debug, Clone)]
pub struct IterativeImputer<F> {
    /// Maximum number of imputation rounds.
    max_iter: usize,
    /// Convergence tolerance (compared against the relative L2 change of the
    /// imputed entries between consecutive rounds).
    tol: F,
    /// Initial imputation strategy.
    initial_strategy: InitialStrategy,
}
73
74impl<F: Float + Send + Sync + 'static> IterativeImputer<F> {
75    /// Create a new `IterativeImputer` with the given parameters.
76    pub fn new(max_iter: usize, tol: F, initial_strategy: InitialStrategy) -> Self {
77        Self {
78            max_iter,
79            tol,
80            initial_strategy,
81        }
82    }
83
84    /// Return the maximum number of iterations.
85    #[must_use]
86    pub fn max_iter(&self) -> usize {
87        self.max_iter
88    }
89
90    /// Return the convergence tolerance.
91    #[must_use]
92    pub fn tol(&self) -> F {
93        self.tol
94    }
95
96    /// Return the initial imputation strategy.
97    #[must_use]
98    pub fn initial_strategy(&self) -> InitialStrategy {
99        self.initial_strategy
100    }
101}
102
103impl<F: Float + Send + Sync + 'static> Default for IterativeImputer<F> {
104    fn default() -> Self {
105        Self::new(
106            10,
107            F::from(1e-3).unwrap_or(F::epsilon()),
108            InitialStrategy::Mean,
109        )
110    }
111}
112
113// ---------------------------------------------------------------------------
114// FittedIterativeImputer
115// ---------------------------------------------------------------------------
116
/// A fitted iterative imputer that stores per-feature Ridge regression
/// coefficients learned during fitting.
///
/// Created by calling [`Fit::fit`] on an [`IterativeImputer`].
#[derive(Debug, Clone)]
pub struct FittedIterativeImputer<F> {
    /// Per-feature initial fill values (used for initial imputation of transform data).
    initial_fill: Array1<F>,
    /// Per-feature Ridge models: `feature_models[j]` holds the coefficient
    /// vector and intercept for predicting feature `j` from the other features.
    /// Only populated for features that had missing values during training
    /// (and for which the Ridge fit succeeded); all other slots are `None`.
    feature_models: Vec<Option<FeatureModel<F>>>,
    /// Indices of features that had missing values during training.
    missing_features: Vec<usize>,
    /// Number of iterations that were performed during fitting.
    n_iter: usize,
    /// Maximum iterations (reused when iterating during `transform`).
    max_iter: usize,
    /// Convergence tolerance (reused when iterating during `transform`).
    tol: F,
    /// Initial strategy that was used during fitting.
    initial_strategy: InitialStrategy,
}
140
/// Ridge regression model for a single feature.
///
/// Together these define the linear predictor
/// `y = coefficients · x_other + intercept`, where `x_other` is the row with
/// the target feature's column removed.
#[derive(Debug, Clone)]
struct FeatureModel<F> {
    /// Coefficients (one per predictor feature, i.e. `n_features - 1` entries).
    coefficients: Array1<F>,
    /// Intercept.
    intercept: F,
}
149
150impl<F: Float + Send + Sync + 'static> FittedIterativeImputer<F> {
151    /// Return the number of iterations performed during fitting.
152    #[must_use]
153    pub fn n_iter(&self) -> usize {
154        self.n_iter
155    }
156
157    /// Return the initial fill values per feature.
158    #[must_use]
159    pub fn initial_fill(&self) -> &Array1<F> {
160        &self.initial_fill
161    }
162
163    /// Return the initial imputation strategy used during fitting.
164    #[must_use]
165    pub fn initial_strategy(&self) -> InitialStrategy {
166        self.initial_strategy
167    }
168}
169
170// ---------------------------------------------------------------------------
171// Helpers
172// ---------------------------------------------------------------------------
173
174/// Compute column means, ignoring NaN values.
175fn column_means_nan<F: Float>(x: &Array2<F>) -> Array1<F> {
176    let n_features = x.ncols();
177    let mut means = Array1::zeros(n_features);
178    for j in 0..n_features {
179        let col = x.column(j);
180        let mut sum = F::zero();
181        let mut count = 0usize;
182        for &v in col.iter() {
183            if !v.is_nan() {
184                sum = sum + v;
185                count += 1;
186            }
187        }
188        means[j] = if count > 0 {
189            sum / F::from(count).unwrap_or(F::one())
190        } else {
191            F::zero()
192        };
193    }
194    means
195}
196
197/// Compute column medians, ignoring NaN values.
198fn column_medians_nan<F: Float>(x: &Array2<F>) -> Array1<F> {
199    let n_features = x.ncols();
200    let mut medians = Array1::zeros(n_features);
201    for j in 0..n_features {
202        let col = x.column(j);
203        let mut vals: Vec<F> = col.iter().copied().filter(|v| !v.is_nan()).collect();
204        if vals.is_empty() {
205            medians[j] = F::zero();
206        } else {
207            vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
208            let n = vals.len();
209            medians[j] = if n % 2 == 1 {
210                vals[n / 2]
211            } else {
212                (vals[n / 2 - 1] + vals[n / 2]) / (F::one() + F::one())
213            };
214        }
215    }
216    medians
217}
218
219/// Fill NaN values in a matrix with the given fill values.
220fn initial_fill<F: Float>(x: &Array2<F>, fill: &Array1<F>) -> Array2<F> {
221    let mut out = x.to_owned();
222    for (mut col, &f) in out.columns_mut().into_iter().zip(fill.iter()) {
223        for v in col.iter_mut() {
224            if v.is_nan() {
225                *v = f;
226            }
227        }
228    }
229    out
230}
231
/// Fit a simple Ridge regression: y = X * beta + intercept.
/// Uses the closed-form solution: beta = (X^T X + alpha * I)^{-1} X^T y.
///
/// Both `X` and `y` are centered before solving, so the penalty `alpha` is
/// applied only to the slopes; the intercept is recovered afterwards from the
/// column means. Returns `None` when the input is empty or the regularized
/// normal equations are numerically singular.
///
/// For simplicity we solve this using a small linear system solver.
fn ridge_fit<F: Float>(x: &Array2<F>, y: &Array1<F>, alpha: F) -> Option<(Array1<F>, F)> {
    let n_samples = x.nrows();
    let n_features = x.ncols();

    if n_samples == 0 || n_features == 0 {
        return None;
    }

    // Center y
    let y_mean =
        y.iter().copied().fold(F::zero(), |a, v| a + v) / F::from(n_samples).unwrap_or(F::one());

    // Center X
    let mut x_means = Array1::zeros(n_features);
    for j in 0..n_features {
        x_means[j] = x.column(j).iter().copied().fold(F::zero(), |a, v| a + v)
            / F::from(n_samples).unwrap_or(F::one());
    }

    // Compute X^T X + alpha * I (n_features x n_features), on centered data.
    let mut xtx = Array2::zeros((n_features, n_features));
    for i in 0..n_features {
        for j in 0..n_features {
            let mut s = F::zero();
            for k in 0..n_samples {
                s = s + (x[[k, i]] - x_means[i]) * (x[[k, j]] - x_means[j]);
            }
            xtx[[i, j]] = s;
        }
        // Ridge penalty on the diagonal.
        xtx[[i, i]] = xtx[[i, i]] + alpha;
    }

    // Compute X^T y (n_features), also on centered data.
    let mut xty = Array1::zeros(n_features);
    for i in 0..n_features {
        let mut s = F::zero();
        for k in 0..n_samples {
            s = s + (x[[k, i]] - x_means[i]) * (y[k] - y_mean);
        }
        xty[i] = s;
    }

    // Solve xtx * beta = xty using Cholesky-like approach (simple Gaussian elimination)
    let beta = solve_linear_system(&xtx, &xty)?;

    // Compute intercept: shift back from the centered coordinates.
    let mut intercept = y_mean;
    for j in 0..n_features {
        intercept = intercept - beta[j] * x_means[j];
    }

    Some((beta, intercept))
}
289
/// Solve A * x = b using Gaussian elimination with partial pivoting.
///
/// Returns `None` when the shapes are inconsistent or the matrix is detected
/// as numerically singular (pivot magnitude below ~1e-15). An empty system
/// (n == 0) yields an empty solution vector.
fn solve_linear_system<F: Float>(a: &Array2<F>, b: &Array1<F>) -> Option<Array1<F>> {
    let n = a.nrows();
    if n != a.ncols() || n != b.len() {
        return None;
    }
    if n == 0 {
        return Some(Array1::zeros(0));
    }

    // Augmented matrix [A | b], so row operations update b in lockstep.
    let mut aug = Array2::zeros((n, n + 1));
    for i in 0..n {
        for j in 0..n {
            aug[[i, j]] = a[[i, j]];
        }
        aug[[i, n]] = b[i];
    }

    // Forward elimination with partial pivoting
    for col in 0..n {
        // Find pivot: largest absolute value in this column at or below `col`.
        let mut max_row = col;
        let mut max_val = aug[[col, col]].abs();
        for row in (col + 1)..n {
            let val = aug[[row, col]].abs();
            if val > max_val {
                max_val = val;
                max_row = row;
            }
        }

        if max_val < F::from(1e-15).unwrap_or(F::min_positive_value()) {
            return None; // Singular matrix
        }

        // Swap rows so the pivot sits on the diagonal.
        if max_row != col {
            for j in 0..=n {
                let tmp = aug[[col, j]];
                aug[[col, j]] = aug[[max_row, j]];
                aug[[max_row, j]] = tmp;
            }
        }

        // Eliminate below: zero out this column in all lower rows.
        let pivot = aug[[col, col]];
        for row in (col + 1)..n {
            let factor = aug[[row, col]] / pivot;
            // Start at `col` (entries to the left are already zero).
            for j in col..=n {
                let val = aug[[col, j]];
                aug[[row, j]] = aug[[row, j]] - factor * val;
            }
        }
    }

    // Back substitution on the now upper-triangular system.
    let mut x = Array1::zeros(n);
    for i in (0..n).rev() {
        let mut sum = aug[[i, n]];
        for j in (i + 1)..n {
            sum = sum - aug[[i, j]] * x[j];
        }
        let diag = aug[[i, i]];
        // Defensive re-check; pivoting above should already have caught this.
        if diag.abs() < F::from(1e-15).unwrap_or(F::min_positive_value()) {
            return None;
        }
        x[i] = sum / diag;
    }

    Some(x)
}
362
363/// Predict using a Ridge model.
364fn ridge_predict<F: Float>(x: &Array2<F>, coefficients: &Array1<F>, intercept: F) -> Array1<F> {
365    let n_samples = x.nrows();
366    let mut y = Array1::zeros(n_samples);
367    for i in 0..n_samples {
368        let mut val = intercept;
369        for j in 0..x.ncols() {
370            val = val + coefficients[j] * x[[i, j]];
371        }
372        y[i] = val;
373    }
374    y
375}
376
377// ---------------------------------------------------------------------------
378// Trait implementations
379// ---------------------------------------------------------------------------
380
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for IterativeImputer<F> {
    type Fitted = FittedIterativeImputer<F>;
    type Error = FerroError;

    /// Fit the iterative imputer by performing round-robin Ridge regression.
    ///
    /// Each feature that contains NaN is modeled from all other features using
    /// the rows where it is observed; its missing entries are then replaced by
    /// the model's predictions. The sweep repeats until the relative change of
    /// the imputed entries drops below `tol` or `max_iter` rounds have run.
    ///
    /// # Errors
    ///
    /// - [`FerroError::InsufficientSamples`] if the input has zero rows.
    /// - [`FerroError::InvalidParameter`] if `max_iter` is zero.
    fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedIterativeImputer<F>, FerroError> {
        let n_samples = x.nrows();
        if n_samples == 0 {
            return Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                context: "IterativeImputer::fit".into(),
            });
        }
        if self.max_iter == 0 {
            return Err(FerroError::InvalidParameter {
                name: "max_iter".into(),
                reason: "max_iter must be at least 1".into(),
            });
        }

        let n_features = x.ncols();

        // Compute initial fill values (per-column, NaN-aware).
        let fill_values = match self.initial_strategy {
            InitialStrategy::Mean => column_means_nan(x),
            InitialStrategy::Median => column_medians_nan(x),
        };

        // Create mask of missing values and record which features need a model.
        let mut missing_mask = Array2::from_elem((n_samples, n_features), false);
        let mut missing_features = Vec::new();
        for j in 0..n_features {
            let mut has_missing = false;
            for i in 0..n_samples {
                if x[[i, j]].is_nan() {
                    missing_mask[[i, j]] = true;
                    has_missing = true;
                }
            }
            if has_missing {
                missing_features.push(j);
            }
        }

        // Initial imputation: gives the regressions a complete matrix to start from.
        let mut imputed = initial_fill(x, &fill_values);

        // Iterative refinement
        let alpha = F::one(); // Ridge alpha (fixed; not exposed as a parameter)
        let mut n_iter = 0usize;
        let mut feature_models: Vec<Option<FeatureModel<F>>> =
            (0..n_features).map(|_| None).collect();

        for iter_idx in 0..self.max_iter {
            n_iter = iter_idx + 1;
            // Snapshot for the convergence check at the end of this round.
            let prev_imputed = imputed.clone();

            for &j in &missing_features {
                // Build predictor matrix (all features except j) and target (feature j)
                // Only use rows where feature j is NOT missing
                let predictor_cols: Vec<usize> = (0..n_features).filter(|&k| k != j).collect();
                let n_predictors = predictor_cols.len();

                // Collect non-missing rows for feature j
                let non_missing_rows: Vec<usize> =
                    (0..n_samples).filter(|&i| !missing_mask[[i, j]]).collect();

                if non_missing_rows.is_empty() || n_predictors == 0 {
                    continue;
                }

                // Build X_train and y_train from the current (partially imputed) matrix.
                let n_train = non_missing_rows.len();
                let mut x_train = Array2::zeros((n_train, n_predictors));
                let mut y_train = Array1::zeros(n_train);
                for (row_idx, &i) in non_missing_rows.iter().enumerate() {
                    for (col_idx, &k) in predictor_cols.iter().enumerate() {
                        x_train[[row_idx, col_idx]] = imputed[[i, k]];
                    }
                    y_train[row_idx] = imputed[[i, j]];
                }

                // Fit Ridge regression; `None` (singular system) simply skips this
                // feature for this round, leaving its current imputed values as-is.
                if let Some((coefficients, intercept)) = ridge_fit(&x_train, &y_train, alpha) {
                    // Predict for missing rows
                    let missing_rows: Vec<usize> =
                        (0..n_samples).filter(|&i| missing_mask[[i, j]]).collect();

                    if !missing_rows.is_empty() {
                        let n_missing = missing_rows.len();
                        let mut x_missing = Array2::zeros((n_missing, n_predictors));
                        for (row_idx, &i) in missing_rows.iter().enumerate() {
                            for (col_idx, &k) in predictor_cols.iter().enumerate() {
                                x_missing[[row_idx, col_idx]] = imputed[[i, k]];
                            }
                        }

                        let predictions = ridge_predict(&x_missing, &coefficients, intercept);
                        for (row_idx, &i) in missing_rows.iter().enumerate() {
                            imputed[[i, j]] = predictions[row_idx];
                        }
                    }

                    // Keep the most recent model so `transform` can reuse it.
                    feature_models[j] = Some(FeatureModel {
                        coefficients,
                        intercept,
                    });
                }
            }

            // Check convergence: relative L2 change over the imputed entries only.
            let mut total_change = F::zero();
            let mut total_value = F::zero();
            for &j in &missing_features {
                for i in 0..n_samples {
                    if missing_mask[[i, j]] {
                        let diff = imputed[[i, j]] - prev_imputed[[i, j]];
                        total_change = total_change + diff * diff;
                        total_value = total_value + imputed[[i, j]] * imputed[[i, j]];
                    }
                }
            }

            if total_value > F::zero() {
                let relative_change = (total_change / total_value).sqrt();
                if relative_change < self.tol {
                    break;
                }
            } else if total_change < self.tol * self.tol {
                // All imputed values are exactly zero (or nothing was missing);
                // fall back to an absolute criterion to avoid dividing by zero.
                break;
            }
        }

        Ok(FittedIterativeImputer {
            initial_fill: fill_values,
            feature_models,
            missing_features,
            n_iter,
            max_iter: self.max_iter,
            tol: self.tol,
            initial_strategy: self.initial_strategy,
        })
    }
}
531
impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedIterativeImputer<F> {
    type Output = Array2<F>;
    type Error = FerroError;

    /// Impute missing values in `x` using the learned feature models.
    ///
    /// Only features that were missing during *training* are iteratively
    /// refined; NaNs in any other column keep their initial fill value. When a
    /// training-time model is absent for a refined feature (the Ridge fit
    /// failed during `fit`), a fallback model is fitted on the non-missing
    /// rows of `x` itself.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if the number of columns does not
    /// match the training data.
    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        let n_features = self.initial_fill.len();
        if x.ncols() != n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![x.nrows(), n_features],
                actual: vec![x.nrows(), x.ncols()],
                context: "FittedIterativeImputer::transform".into(),
            });
        }

        let n_samples = x.nrows();

        // Initial imputation with the fill values learned at fit time.
        let mut imputed = initial_fill(x, &self.initial_fill);

        // Create missing mask for the new data.
        let mut missing_mask = Array2::from_elem((n_samples, n_features), false);
        for j in 0..n_features {
            for i in 0..n_samples {
                if x[[i, j]].is_nan() {
                    missing_mask[[i, j]] = true;
                }
            }
        }

        // Apply iterative imputation using learned models.
        let alpha = F::one();
        for _iter in 0..self.max_iter {
            // Snapshot for the convergence check below.
            let prev = imputed.clone();

            for &j in &self.missing_features {
                let predictor_cols: Vec<usize> = (0..n_features).filter(|&k| k != j).collect();
                let n_predictors = predictor_cols.len();

                if n_predictors == 0 {
                    continue;
                }

                // Use the stored model if available, otherwise re-fit on non-missing data
                let model = if let Some(ref m) = self.feature_models[j] {
                    Some((m.coefficients.clone(), m.intercept))
                } else {
                    // Fallback: fit on non-missing rows of transform data
                    let non_missing_rows: Vec<usize> =
                        (0..n_samples).filter(|&i| !missing_mask[[i, j]]).collect();
                    if non_missing_rows.is_empty() {
                        None
                    } else {
                        let n_train = non_missing_rows.len();
                        let mut x_train = Array2::zeros((n_train, n_predictors));
                        let mut y_train = Array1::zeros(n_train);
                        for (row_idx, &i) in non_missing_rows.iter().enumerate() {
                            for (col_idx, &k) in predictor_cols.iter().enumerate() {
                                x_train[[row_idx, col_idx]] = imputed[[i, k]];
                            }
                            y_train[row_idx] = imputed[[i, j]];
                        }
                        ridge_fit(&x_train, &y_train, alpha)
                    }
                };

                if let Some((coefficients, intercept)) = model {
                    let missing_rows: Vec<usize> =
                        (0..n_samples).filter(|&i| missing_mask[[i, j]]).collect();
                    if !missing_rows.is_empty() {
                        let n_missing = missing_rows.len();
                        let mut x_missing = Array2::zeros((n_missing, n_predictors));
                        for (row_idx, &i) in missing_rows.iter().enumerate() {
                            for (col_idx, &k) in predictor_cols.iter().enumerate() {
                                x_missing[[row_idx, col_idx]] = imputed[[i, k]];
                            }
                        }
                        let predictions = ridge_predict(&x_missing, &coefficients, intercept);
                        for (row_idx, &i) in missing_rows.iter().enumerate() {
                            imputed[[i, j]] = predictions[row_idx];
                        }
                    }
                }
            }

            // Check convergence: same relative-L2 criterion as `fit`.
            let mut total_change = F::zero();
            let mut total_value = F::zero();
            for &j in &self.missing_features {
                for i in 0..n_samples {
                    if missing_mask[[i, j]] {
                        let diff = imputed[[i, j]] - prev[[i, j]];
                        total_change = total_change + diff * diff;
                        total_value = total_value + imputed[[i, j]] * imputed[[i, j]];
                    }
                }
            }
            if total_value > F::zero() {
                let relative_change = (total_change / total_value).sqrt();
                if relative_change < self.tol {
                    break;
                }
            } else if total_change < self.tol * self.tol {
                // Nothing to refine (or all-zero imputations): absolute check.
                break;
            }
        }

        Ok(imputed)
    }
}
647
648/// Implement `Transform` on the unfitted imputer to satisfy the
649/// `FitTransform: Transform` supertrait bound.
650impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for IterativeImputer<F> {
651    type Output = Array2<F>;
652    type Error = FerroError;
653
654    /// Always returns an error — the imputer must be fitted first.
655    fn transform(&self, _x: &Array2<F>) -> Result<Array2<F>, FerroError> {
656        Err(FerroError::InvalidParameter {
657            name: "IterativeImputer".into(),
658            reason: "imputer must be fitted before calling transform; use fit() first".into(),
659        })
660    }
661}
662
663impl<F: Float + Send + Sync + 'static> FitTransform<Array2<F>> for IterativeImputer<F> {
664    type FitError = FerroError;
665
666    /// Fit the imputer on `x` and return the imputed output in one step.
667    ///
668    /// # Errors
669    ///
670    /// Returns an error if fitting fails.
671    fn fit_transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
672        let fitted = self.fit(x, &())?;
673        fitted.transform(x)
674    }
675}
676
677// ---------------------------------------------------------------------------
678// Tests
679// ---------------------------------------------------------------------------
680
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    // Happy path: fitting and transforming removes all NaNs.
    #[test]
    fn test_iterative_imputer_basic() {
        let imputer = IterativeImputer::<f64>::new(10, 1e-3, InitialStrategy::Mean);
        let x = array![[1.0, 2.0], [3.0, f64::NAN], [f64::NAN, 6.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        // All values should be non-NaN
        for v in out.iter() {
            assert!(!v.is_nan(), "Output contains NaN");
        }
    }

    // A matrix without NaNs must pass through unchanged.
    #[test]
    fn test_iterative_imputer_no_missing() {
        let imputer = IterativeImputer::<f64>::new(10, 1e-3, InitialStrategy::Mean);
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        for (a, b) in x.iter().zip(out.iter()) {
            assert!((a - b).abs() < 1e-10);
        }
    }

    // With strongly correlated features, imputed values should approach the
    // values implied by the linear relationship.
    #[test]
    fn test_iterative_imputer_convergence() {
        let imputer = IterativeImputer::<f64>::new(100, 1e-6, InitialStrategy::Mean);
        // Correlated features: feature 1 ≈ 2 * feature 0
        let x = array![
            [1.0, 2.0],
            [2.0, 4.0],
            [3.0, 6.0],
            [4.0, f64::NAN],
            [f64::NAN, 10.0]
        ];
        let fitted = imputer.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        // Check that imputed values are reasonable
        // Feature 1 of row 3 should be close to 8.0 (2 * 4.0)
        assert!(
            (out[[3, 1]] - 8.0).abs() < 2.0,
            "Expected ~8.0, got {}",
            out[[3, 1]]
        );
        // Feature 0 of row 4 should be close to 5.0 (10.0 / 2)
        assert!(
            (out[[4, 0]] - 5.0).abs() < 2.0,
            "Expected ~5.0, got {}",
            out[[4, 0]]
        );
    }

    // Median strategy: only checks that imputation completes without NaN.
    #[test]
    fn test_iterative_imputer_median_strategy() {
        let imputer = IterativeImputer::<f64>::new(10, 1e-3, InitialStrategy::Median);
        let x = array![[1.0, 10.0], [2.0, 20.0], [3.0, f64::NAN]];
        let fitted = imputer.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert!(!out[[2, 1]].is_nan());
    }

    // fit_transform must behave like fit followed by transform.
    #[test]
    fn test_iterative_imputer_fit_transform() {
        let imputer = IterativeImputer::<f64>::new(10, 1e-3, InitialStrategy::Mean);
        let x = array![[1.0, 2.0], [3.0, f64::NAN], [f64::NAN, 6.0]];
        let out = imputer.fit_transform(&x).unwrap();
        for v in out.iter() {
            assert!(!v.is_nan());
        }
    }

    // Zero-row input is rejected with InsufficientSamples.
    #[test]
    fn test_iterative_imputer_zero_rows_error() {
        let imputer = IterativeImputer::<f64>::new(10, 1e-3, InitialStrategy::Mean);
        let x: Array2<f64> = Array2::zeros((0, 3));
        assert!(imputer.fit(&x, &()).is_err());
    }

    // max_iter == 0 is rejected with InvalidParameter.
    #[test]
    fn test_iterative_imputer_zero_max_iter_error() {
        let imputer = IterativeImputer::<f64>::new(0, 1e-3, InitialStrategy::Mean);
        let x = array![[1.0, 2.0]];
        assert!(imputer.fit(&x, &()).is_err());
    }

    // Transforming data with the wrong column count fails with ShapeMismatch.
    #[test]
    fn test_iterative_imputer_shape_mismatch_error() {
        let imputer = IterativeImputer::<f64>::new(10, 1e-3, InitialStrategy::Mean);
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let fitted = imputer.fit(&x_train, &()).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    // Calling transform on the unfitted imputer is always an error.
    #[test]
    fn test_iterative_imputer_unfitted_transform_error() {
        let imputer = IterativeImputer::<f64>::new(10, 1e-3, InitialStrategy::Mean);
        let x = array![[1.0, 2.0]];
        assert!(imputer.transform(&x).is_err());
    }

    // Default parameters match the documented values.
    #[test]
    fn test_iterative_imputer_default() {
        let imputer = IterativeImputer::<f64>::default();
        assert_eq!(imputer.max_iter(), 10);
        assert_eq!(imputer.initial_strategy(), InitialStrategy::Mean);
    }

    // n_iter reports a value in [1, max_iter].
    #[test]
    fn test_iterative_imputer_n_iter_accessor() {
        let imputer = IterativeImputer::<f64>::new(10, 1e-3, InitialStrategy::Mean);
        let x = array![[1.0, 2.0], [3.0, f64::NAN]];
        let fitted = imputer.fit(&x, &()).unwrap();
        assert!(fitted.n_iter() > 0);
        assert!(fitted.n_iter() <= 10);
    }

    // The imputer is generic over the float type; smoke-test f32.
    #[test]
    fn test_iterative_imputer_f32() {
        let imputer = IterativeImputer::<f32>::new(10, 1e-3, InitialStrategy::Mean);
        let x: Array2<f32> = array![[1.0f32, 2.0], [3.0, f32::NAN]];
        let fitted = imputer.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert!(!out[[1, 1]].is_nan());
    }
}