1use ferrolearn_core::error::FerroError;
16use ferrolearn_core::traits::{Fit, FitTransform, Transform};
17use ndarray::{Array1, Array2};
18use num_traits::Float;
19
/// Statistic used to pre-fill missing (NaN) entries before the iterative rounds start.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum InitialStrategy {
    /// Mean of the observed (non-NaN) values of each column.
    Mean,
    /// Median of the observed (non-NaN) values of each column.
    Median,
}
32
33#[must_use]
64#[derive(Debug, Clone)]
65pub struct IterativeImputer<F> {
66 max_iter: usize,
68 tol: F,
70 initial_strategy: InitialStrategy,
72}
73
74impl<F: Float + Send + Sync + 'static> IterativeImputer<F> {
75 pub fn new(max_iter: usize, tol: F, initial_strategy: InitialStrategy) -> Self {
77 Self {
78 max_iter,
79 tol,
80 initial_strategy,
81 }
82 }
83
84 #[must_use]
86 pub fn max_iter(&self) -> usize {
87 self.max_iter
88 }
89
90 #[must_use]
92 pub fn tol(&self) -> F {
93 self.tol
94 }
95
96 #[must_use]
98 pub fn initial_strategy(&self) -> InitialStrategy {
99 self.initial_strategy
100 }
101}
102
103impl<F: Float + Send + Sync + 'static> Default for IterativeImputer<F> {
104 fn default() -> Self {
105 Self::new(
106 10,
107 F::from(1e-3).unwrap_or(F::epsilon()),
108 InitialStrategy::Mean,
109 )
110 }
111}
112
113#[derive(Debug, Clone)]
122pub struct FittedIterativeImputer<F> {
123 initial_fill: Array1<F>,
125 feature_models: Vec<Option<FeatureModel<F>>>,
129 missing_features: Vec<usize>,
131 n_iter: usize,
133 max_iter: usize,
135 tol: F,
137 initial_strategy: InitialStrategy,
139}
140
141#[derive(Debug, Clone)]
143struct FeatureModel<F> {
144 coefficients: Array1<F>,
146 intercept: F,
148}
149
150impl<F: Float + Send + Sync + 'static> FittedIterativeImputer<F> {
151 #[must_use]
153 pub fn n_iter(&self) -> usize {
154 self.n_iter
155 }
156
157 #[must_use]
159 pub fn initial_fill(&self) -> &Array1<F> {
160 &self.initial_fill
161 }
162
163 #[must_use]
165 pub fn initial_strategy(&self) -> InitialStrategy {
166 self.initial_strategy
167 }
168}
169
170fn column_means_nan<F: Float>(x: &Array2<F>) -> Array1<F> {
176 let n_features = x.ncols();
177 let mut means = Array1::zeros(n_features);
178 for j in 0..n_features {
179 let col = x.column(j);
180 let mut sum = F::zero();
181 let mut count = 0usize;
182 for &v in col.iter() {
183 if !v.is_nan() {
184 sum = sum + v;
185 count += 1;
186 }
187 }
188 means[j] = if count > 0 {
189 sum / F::from(count).unwrap_or(F::one())
190 } else {
191 F::zero()
192 };
193 }
194 means
195}
196
197fn column_medians_nan<F: Float>(x: &Array2<F>) -> Array1<F> {
199 let n_features = x.ncols();
200 let mut medians = Array1::zeros(n_features);
201 for j in 0..n_features {
202 let col = x.column(j);
203 let mut vals: Vec<F> = col.iter().copied().filter(|v| !v.is_nan()).collect();
204 if vals.is_empty() {
205 medians[j] = F::zero();
206 } else {
207 vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
208 let n = vals.len();
209 medians[j] = if n % 2 == 1 {
210 vals[n / 2]
211 } else {
212 (vals[n / 2 - 1] + vals[n / 2]) / (F::one() + F::one())
213 };
214 }
215 }
216 medians
217}
218
219fn initial_fill<F: Float>(x: &Array2<F>, fill: &Array1<F>) -> Array2<F> {
221 let mut out = x.to_owned();
222 for (mut col, &f) in out.columns_mut().into_iter().zip(fill.iter()) {
223 for v in col.iter_mut() {
224 if v.is_nan() {
225 *v = f;
226 }
227 }
228 }
229 out
230}
231
232fn ridge_fit<F: Float>(x: &Array2<F>, y: &Array1<F>, alpha: F) -> Option<(Array1<F>, F)> {
237 let n_samples = x.nrows();
238 let n_features = x.ncols();
239
240 if n_samples == 0 || n_features == 0 {
241 return None;
242 }
243
244 let y_mean =
246 y.iter().copied().fold(F::zero(), |a, v| a + v) / F::from(n_samples).unwrap_or(F::one());
247
248 let mut x_means = Array1::zeros(n_features);
250 for j in 0..n_features {
251 x_means[j] = x.column(j).iter().copied().fold(F::zero(), |a, v| a + v)
252 / F::from(n_samples).unwrap_or(F::one());
253 }
254
255 let mut xtx = Array2::zeros((n_features, n_features));
257 for i in 0..n_features {
258 for j in 0..n_features {
259 let mut s = F::zero();
260 for k in 0..n_samples {
261 s = s + (x[[k, i]] - x_means[i]) * (x[[k, j]] - x_means[j]);
262 }
263 xtx[[i, j]] = s;
264 }
265 xtx[[i, i]] = xtx[[i, i]] + alpha;
266 }
267
268 let mut xty = Array1::zeros(n_features);
270 for i in 0..n_features {
271 let mut s = F::zero();
272 for k in 0..n_samples {
273 s = s + (x[[k, i]] - x_means[i]) * (y[k] - y_mean);
274 }
275 xty[i] = s;
276 }
277
278 let beta = solve_linear_system(&xtx, &xty)?;
280
281 let mut intercept = y_mean;
283 for j in 0..n_features {
284 intercept = intercept - beta[j] * x_means[j];
285 }
286
287 Some((beta, intercept))
288}
289
/// Solve the square system `a · x = b` by Gaussian elimination with partial pivoting.
///
/// Returns `None` in three cases:
/// - `a` is not square;
/// - `b.len()` does not match the size of `a`;
/// - a pivot (or a back-substitution diagonal) has magnitude below `1e-15`. The
///   threshold falls back to `F::min_positive_value()` if `1e-15` is not representable.
///
/// An empty (0×0) system yields an empty solution.
fn solve_linear_system<F: Float>(a: &Array2<F>, b: &Array1<F>) -> Option<Array1<F>> {
    let n = a.nrows();
    if n != a.ncols() || n != b.len() {
        return None;
    }
    if n == 0 {
        return Some(Array1::zeros(0));
    }

    // Augmented matrix [a | b]: every row operation below acts on both halves at once.
    let mut aug = Array2::zeros((n, n + 1));
    for i in 0..n {
        for j in 0..n {
            aug[[i, j]] = a[[i, j]];
        }
        aug[[i, n]] = b[i];
    }

    // Forward elimination to upper-triangular form.
    for col in 0..n {
        // Partial pivoting: choose the row at or below `col` with the largest |entry|.
        let mut max_row = col;
        let mut max_val = aug[[col, col]].abs();
        for row in (col + 1)..n {
            let val = aug[[row, col]].abs();
            if val > max_val {
                max_val = val;
                max_row = row;
            }
        }

        // Absolute (not scale-relative) singularity threshold.
        if max_val < F::from(1e-15).unwrap_or(F::min_positive_value()) {
            return None;
        }

        if max_row != col {
            for j in 0..=n {
                let tmp = aug[[col, j]];
                aug[[col, j]] = aug[[max_row, j]];
                aug[[max_row, j]] = tmp;
            }
        }

        let pivot = aug[[col, col]];
        // Eliminate column `col` from every row below the pivot (including the rhs column).
        for row in (col + 1)..n {
            let factor = aug[[row, col]] / pivot;
            for j in col..=n {
                let val = aug[[col, j]];
                aug[[row, j]] = aug[[row, j]] - factor * val;
            }
        }
    }

    // Back substitution on the upper-triangular system.
    let mut x = Array1::zeros(n);
    for i in (0..n).rev() {
        let mut sum = aug[[i, n]];
        for j in (i + 1)..n {
            sum = sum - aug[[i, j]] * x[j];
        }
        let diag = aug[[i, i]];
        // NOTE(review): appears redundant. Row `i` is never modified after it becomes the
        // pivot row, so `diag` is the pivot that already passed the check above.
        if diag.abs() < F::from(1e-15).unwrap_or(F::min_positive_value()) {
            return None;
        }
        x[i] = sum / diag;
    }

    Some(x)
}
362
363fn ridge_predict<F: Float>(x: &Array2<F>, coefficients: &Array1<F>, intercept: F) -> Array1<F> {
365 let n_samples = x.nrows();
366 let mut y = Array1::zeros(n_samples);
367 for i in 0..n_samples {
368 let mut val = intercept;
369 for j in 0..x.ncols() {
370 val = val + coefficients[j] * x[[i, j]];
371 }
372 y[i] = val;
373 }
374 y
375}
376
377impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for IterativeImputer<F> {
382 type Fitted = FittedIterativeImputer<F>;
383 type Error = FerroError;
384
385 fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedIterativeImputer<F>, FerroError> {
392 let n_samples = x.nrows();
393 if n_samples == 0 {
394 return Err(FerroError::InsufficientSamples {
395 required: 1,
396 actual: 0,
397 context: "IterativeImputer::fit".into(),
398 });
399 }
400 if self.max_iter == 0 {
401 return Err(FerroError::InvalidParameter {
402 name: "max_iter".into(),
403 reason: "max_iter must be at least 1".into(),
404 });
405 }
406
407 let n_features = x.ncols();
408
409 let fill_values = match self.initial_strategy {
411 InitialStrategy::Mean => column_means_nan(x),
412 InitialStrategy::Median => column_medians_nan(x),
413 };
414
415 let mut missing_mask = Array2::from_elem((n_samples, n_features), false);
417 let mut missing_features = Vec::new();
418 for j in 0..n_features {
419 let mut has_missing = false;
420 for i in 0..n_samples {
421 if x[[i, j]].is_nan() {
422 missing_mask[[i, j]] = true;
423 has_missing = true;
424 }
425 }
426 if has_missing {
427 missing_features.push(j);
428 }
429 }
430
431 let mut imputed = initial_fill(x, &fill_values);
433
434 let alpha = F::one(); let mut n_iter = 0usize;
437 let mut feature_models: Vec<Option<FeatureModel<F>>> =
438 (0..n_features).map(|_| None).collect();
439
440 for iter_idx in 0..self.max_iter {
441 n_iter = iter_idx + 1;
442 let prev_imputed = imputed.clone();
443
444 for &j in &missing_features {
445 let predictor_cols: Vec<usize> = (0..n_features).filter(|&k| k != j).collect();
448 let n_predictors = predictor_cols.len();
449
450 let non_missing_rows: Vec<usize> =
452 (0..n_samples).filter(|&i| !missing_mask[[i, j]]).collect();
453
454 if non_missing_rows.is_empty() || n_predictors == 0 {
455 continue;
456 }
457
458 let n_train = non_missing_rows.len();
460 let mut x_train = Array2::zeros((n_train, n_predictors));
461 let mut y_train = Array1::zeros(n_train);
462 for (row_idx, &i) in non_missing_rows.iter().enumerate() {
463 for (col_idx, &k) in predictor_cols.iter().enumerate() {
464 x_train[[row_idx, col_idx]] = imputed[[i, k]];
465 }
466 y_train[row_idx] = imputed[[i, j]];
467 }
468
469 if let Some((coefficients, intercept)) = ridge_fit(&x_train, &y_train, alpha) {
471 let missing_rows: Vec<usize> =
473 (0..n_samples).filter(|&i| missing_mask[[i, j]]).collect();
474
475 if !missing_rows.is_empty() {
476 let n_missing = missing_rows.len();
477 let mut x_missing = Array2::zeros((n_missing, n_predictors));
478 for (row_idx, &i) in missing_rows.iter().enumerate() {
479 for (col_idx, &k) in predictor_cols.iter().enumerate() {
480 x_missing[[row_idx, col_idx]] = imputed[[i, k]];
481 }
482 }
483
484 let predictions = ridge_predict(&x_missing, &coefficients, intercept);
485 for (row_idx, &i) in missing_rows.iter().enumerate() {
486 imputed[[i, j]] = predictions[row_idx];
487 }
488 }
489
490 feature_models[j] = Some(FeatureModel {
491 coefficients,
492 intercept,
493 });
494 }
495 }
496
497 let mut total_change = F::zero();
499 let mut total_value = F::zero();
500 for &j in &missing_features {
501 for i in 0..n_samples {
502 if missing_mask[[i, j]] {
503 let diff = imputed[[i, j]] - prev_imputed[[i, j]];
504 total_change = total_change + diff * diff;
505 total_value = total_value + imputed[[i, j]] * imputed[[i, j]];
506 }
507 }
508 }
509
510 if total_value > F::zero() {
511 let relative_change = (total_change / total_value).sqrt();
512 if relative_change < self.tol {
513 break;
514 }
515 } else if total_change < self.tol * self.tol {
516 break;
517 }
518 }
519
520 Ok(FittedIterativeImputer {
521 initial_fill: fill_values,
522 feature_models,
523 missing_features,
524 n_iter,
525 max_iter: self.max_iter,
526 tol: self.tol,
527 initial_strategy: self.initial_strategy,
528 })
529 }
530}
531
532impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedIterativeImputer<F> {
533 type Output = Array2<F>;
534 type Error = FerroError;
535
536 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
543 let n_features = self.initial_fill.len();
544 if x.ncols() != n_features {
545 return Err(FerroError::ShapeMismatch {
546 expected: vec![x.nrows(), n_features],
547 actual: vec![x.nrows(), x.ncols()],
548 context: "FittedIterativeImputer::transform".into(),
549 });
550 }
551
552 let n_samples = x.nrows();
553
554 let mut imputed = initial_fill(x, &self.initial_fill);
556
557 let mut missing_mask = Array2::from_elem((n_samples, n_features), false);
559 for j in 0..n_features {
560 for i in 0..n_samples {
561 if x[[i, j]].is_nan() {
562 missing_mask[[i, j]] = true;
563 }
564 }
565 }
566
567 let alpha = F::one();
569 for _iter in 0..self.max_iter {
570 let prev = imputed.clone();
571
572 for &j in &self.missing_features {
573 let predictor_cols: Vec<usize> = (0..n_features).filter(|&k| k != j).collect();
574 let n_predictors = predictor_cols.len();
575
576 if n_predictors == 0 {
577 continue;
578 }
579
580 let model = if let Some(ref m) = self.feature_models[j] {
582 Some((m.coefficients.clone(), m.intercept))
583 } else {
584 let non_missing_rows: Vec<usize> =
586 (0..n_samples).filter(|&i| !missing_mask[[i, j]]).collect();
587 if non_missing_rows.is_empty() {
588 None
589 } else {
590 let n_train = non_missing_rows.len();
591 let mut x_train = Array2::zeros((n_train, n_predictors));
592 let mut y_train = Array1::zeros(n_train);
593 for (row_idx, &i) in non_missing_rows.iter().enumerate() {
594 for (col_idx, &k) in predictor_cols.iter().enumerate() {
595 x_train[[row_idx, col_idx]] = imputed[[i, k]];
596 }
597 y_train[row_idx] = imputed[[i, j]];
598 }
599 ridge_fit(&x_train, &y_train, alpha)
600 }
601 };
602
603 if let Some((coefficients, intercept)) = model {
604 let missing_rows: Vec<usize> =
605 (0..n_samples).filter(|&i| missing_mask[[i, j]]).collect();
606 if !missing_rows.is_empty() {
607 let n_missing = missing_rows.len();
608 let mut x_missing = Array2::zeros((n_missing, n_predictors));
609 for (row_idx, &i) in missing_rows.iter().enumerate() {
610 for (col_idx, &k) in predictor_cols.iter().enumerate() {
611 x_missing[[row_idx, col_idx]] = imputed[[i, k]];
612 }
613 }
614 let predictions = ridge_predict(&x_missing, &coefficients, intercept);
615 for (row_idx, &i) in missing_rows.iter().enumerate() {
616 imputed[[i, j]] = predictions[row_idx];
617 }
618 }
619 }
620 }
621
622 let mut total_change = F::zero();
624 let mut total_value = F::zero();
625 for &j in &self.missing_features {
626 for i in 0..n_samples {
627 if missing_mask[[i, j]] {
628 let diff = imputed[[i, j]] - prev[[i, j]];
629 total_change = total_change + diff * diff;
630 total_value = total_value + imputed[[i, j]] * imputed[[i, j]];
631 }
632 }
633 }
634 if total_value > F::zero() {
635 let relative_change = (total_change / total_value).sqrt();
636 if relative_change < self.tol {
637 break;
638 }
639 } else if total_change < self.tol * self.tol {
640 break;
641 }
642 }
643
644 Ok(imputed)
645 }
646}
647
648impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for IterativeImputer<F> {
651 type Output = Array2<F>;
652 type Error = FerroError;
653
654 fn transform(&self, _x: &Array2<F>) -> Result<Array2<F>, FerroError> {
656 Err(FerroError::InvalidParameter {
657 name: "IterativeImputer".into(),
658 reason: "imputer must be fitted before calling transform; use fit() first".into(),
659 })
660 }
661}
662
663impl<F: Float + Send + Sync + 'static> FitTransform<Array2<F>> for IterativeImputer<F> {
664 type FitError = FerroError;
665
666 fn fit_transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
672 let fitted = self.fit(x, &())?;
673 fitted.transform(x)
674 }
675}
676
#[cfg(test)]
mod tests {
    // Covers the public imputer API end to end, plus the private numeric helpers
    // (column statistics, linear solve, ridge fit/predict) that it is built from.
    use super::*;
    use ndarray::array;

    #[test]
    fn test_iterative_imputer_basic() {
        let imputer = IterativeImputer::<f64>::new(10, 1e-3, InitialStrategy::Mean);
        let x = array![[1.0, 2.0], [3.0, f64::NAN], [f64::NAN, 6.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        for v in out.iter() {
            assert!(!v.is_nan(), "Output contains NaN");
        }
    }

    #[test]
    fn test_iterative_imputer_no_missing() {
        let imputer = IterativeImputer::<f64>::new(10, 1e-3, InitialStrategy::Mean);
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        for (a, b) in x.iter().zip(out.iter()) {
            assert!((a - b).abs() < 1e-10);
        }
    }

    #[test]
    fn test_iterative_imputer_convergence() {
        let imputer = IterativeImputer::<f64>::new(100, 1e-6, InitialStrategy::Mean);
        let x = array![
            [1.0, 2.0],
            [2.0, 4.0],
            [3.0, 6.0],
            [4.0, f64::NAN],
            [f64::NAN, 10.0]
        ];
        let fitted = imputer.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert!(
            (out[[3, 1]] - 8.0).abs() < 2.0,
            "Expected ~8.0, got {}",
            out[[3, 1]]
        );
        assert!(
            (out[[4, 0]] - 5.0).abs() < 2.0,
            "Expected ~5.0, got {}",
            out[[4, 0]]
        );
    }

    #[test]
    fn test_iterative_imputer_median_strategy() {
        let imputer = IterativeImputer::<f64>::new(10, 1e-3, InitialStrategy::Median);
        let x = array![[1.0, 10.0], [2.0, 20.0], [3.0, f64::NAN]];
        let fitted = imputer.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert!(!out[[2, 1]].is_nan());
    }

    #[test]
    fn test_iterative_imputer_fit_transform() {
        let imputer = IterativeImputer::<f64>::new(10, 1e-3, InitialStrategy::Mean);
        let x = array![[1.0, 2.0], [3.0, f64::NAN], [f64::NAN, 6.0]];
        let out = imputer.fit_transform(&x).unwrap();
        for v in out.iter() {
            assert!(!v.is_nan());
        }
    }

    #[test]
    fn test_iterative_imputer_zero_rows_error() {
        let imputer = IterativeImputer::<f64>::new(10, 1e-3, InitialStrategy::Mean);
        let x: Array2<f64> = Array2::zeros((0, 3));
        assert!(imputer.fit(&x, &()).is_err());
    }

    #[test]
    fn test_iterative_imputer_zero_max_iter_error() {
        let imputer = IterativeImputer::<f64>::new(0, 1e-3, InitialStrategy::Mean);
        let x = array![[1.0, 2.0]];
        assert!(imputer.fit(&x, &()).is_err());
    }

    #[test]
    fn test_iterative_imputer_shape_mismatch_error() {
        let imputer = IterativeImputer::<f64>::new(10, 1e-3, InitialStrategy::Mean);
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let fitted = imputer.fit(&x_train, &()).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_iterative_imputer_unfitted_transform_error() {
        let imputer = IterativeImputer::<f64>::new(10, 1e-3, InitialStrategy::Mean);
        let x = array![[1.0, 2.0]];
        assert!(imputer.transform(&x).is_err());
    }

    #[test]
    fn test_iterative_imputer_default() {
        let imputer = IterativeImputer::<f64>::default();
        assert_eq!(imputer.max_iter(), 10);
        assert_eq!(imputer.tol(), 1e-3);
        assert_eq!(imputer.initial_strategy(), InitialStrategy::Mean);
    }

    #[test]
    fn test_iterative_imputer_n_iter_accessor() {
        let imputer = IterativeImputer::<f64>::new(10, 1e-3, InitialStrategy::Mean);
        let x = array![[1.0, 2.0], [3.0, f64::NAN]];
        let fitted = imputer.fit(&x, &()).unwrap();
        assert!(fitted.n_iter() > 0);
        assert!(fitted.n_iter() <= 10);
    }

    #[test]
    fn test_iterative_imputer_f32() {
        let imputer = IterativeImputer::<f32>::new(10, 1e-3, InitialStrategy::Mean);
        let x: Array2<f32> = array![[1.0f32, 2.0], [3.0, f32::NAN]];
        let fitted = imputer.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert!(!out[[1, 1]].is_nan());
    }

    #[test]
    fn test_transform_preserves_observed_values() {
        let imputer = IterativeImputer::<f64>::new(10, 1e-3, InitialStrategy::Mean);
        let x = array![
            [1.0, 2.0, 3.0],
            [4.0, f64::NAN, 6.0],
            [f64::NAN, 8.0, 9.0],
            [10.0, 11.0, f64::NAN]
        ];
        let out = imputer.fit_transform(&x).unwrap();
        for (&orig, &got) in x.iter().zip(out.iter()) {
            if orig.is_nan() {
                assert!(got.is_finite(), "imputed value should be finite, got {got}");
            } else {
                assert_eq!(orig, got, "observed value must be left untouched");
            }
        }
    }

    #[test]
    fn test_fitted_initial_fill_and_strategy() {
        let imputer = IterativeImputer::<f64>::new(10, 1e-3, InitialStrategy::Mean);
        let x = array![[1.0, 2.0], [3.0, f64::NAN], [f64::NAN, 6.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        assert_eq!(fitted.initial_fill(), &array![2.0, 4.0]);
        assert_eq!(fitted.initial_strategy(), InitialStrategy::Mean);
    }

    #[test]
    fn test_fit_without_missing_runs_single_round() {
        let imputer = IterativeImputer::<f64>::new(10, 1e-3, InitialStrategy::Mean);
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        assert_eq!(fitted.n_iter(), 1);
    }

    #[test]
    fn test_single_feature_keeps_initial_fill() {
        // With one column there are no predictors, so only the initial fill applies.
        let imputer = IterativeImputer::<f64>::new(10, 1e-3, InitialStrategy::Mean);
        let x = array![[1.0], [f64::NAN], [3.0]];
        let out = imputer.fit_transform(&x).unwrap();
        assert_eq!(out, array![[1.0], [2.0], [3.0]]);
    }

    #[test]
    fn test_column_means_nan_skips_missing() {
        let x = array![[1.0, f64::NAN], [3.0, f64::NAN], [f64::NAN, f64::NAN]];
        assert_eq!(column_means_nan(&x), array![2.0, 0.0]);
    }

    #[test]
    fn test_column_medians_nan_even_count_and_all_missing() {
        let x = array![
            [4.0, f64::NAN],
            [1.0, f64::NAN],
            [3.0, f64::NAN],
            [2.0, f64::NAN],
            [f64::NAN, f64::NAN]
        ];
        assert_eq!(column_medians_nan(&x), array![2.5, 0.0]);
    }

    #[test]
    fn test_solve_linear_system_known_solution() {
        let a = array![[2.0, 1.0], [1.0, 3.0]];
        let b = array![3.0, 5.0];
        let sol = solve_linear_system(&a, &b).unwrap();
        assert!((sol[0] - 0.8).abs() < 1e-12);
        assert!((sol[1] - 1.4).abs() < 1e-12);
    }

    #[test]
    fn test_solve_linear_system_singular_or_mismatched_is_none() {
        let singular = array![[1.0, 2.0], [2.0, 4.0]];
        assert!(solve_linear_system(&singular, &array![1.0, 2.0]).is_none());
        let not_square: Array2<f64> = Array2::zeros((2, 3));
        assert!(solve_linear_system(&not_square, &array![1.0, 2.0]).is_none());
    }

    #[test]
    fn test_ridge_fit_recovers_linear_relation() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![3.0, 5.0, 7.0, 9.0];
        let (coef, intercept) = ridge_fit(&x, &y, 1e-12).unwrap();
        assert!((coef[0] - 2.0).abs() < 1e-9);
        assert!((intercept - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_ridge_fit_empty_is_none() {
        let x: Array2<f64> = Array2::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);
        assert!(ridge_fit(&x, &y, 1.0).is_none());
    }

    #[test]
    fn test_ridge_predict_applies_intercept() {
        let x = array![[1.0, 2.0], [0.0, -1.0]];
        let coef = array![3.0, 0.5];
        assert_eq!(ridge_predict(&x, &coef, 1.0), array![5.0, 0.5]);
    }
}