// ferrolearn_preprocess — knn_imputer.rs
//! KNN imputer: fill missing (NaN) values using K-nearest neighbors.
//!
//! [`KNNImputer`] replaces each missing value by computing the weighted average
//! of the corresponding feature from the `k` nearest non-missing neighbors.
//! Distance is computed only on the features that are non-missing in both the
//! query row and the candidate row (partial Euclidean distance).
//!
//! # Weighting
//!
//! - [`KNNWeights::Uniform`] — all neighbors contribute equally.
//! - [`KNNWeights::Distance`] — neighbors are weighted by the inverse of their
//!   distance (closer neighbors contribute more).
use ferrolearn_core::error::FerroError;
use ferrolearn_core::traits::{Fit, FitTransform, Transform};
use ndarray::Array2;
use num_traits::Float;

// ---------------------------------------------------------------------------
// KNNWeights
// ---------------------------------------------------------------------------

/// Strategy used to weight neighbor contributions during imputation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KNNWeights {
    /// Every neighbor is given the same weight.
    Uniform,
    /// Each neighbor's weight is the reciprocal of its distance, so closer
    /// neighbors have greater influence on the imputed value.
    Distance,
}
31
32// ---------------------------------------------------------------------------
33// KNNImputer (unfitted)
34// ---------------------------------------------------------------------------
35
36/// An unfitted KNN imputer.
37///
38/// Calling [`Fit::fit`] stores the training data and returns a
39/// [`FittedKNNImputer`] that can impute missing values in new data.
40///
41/// # Parameters
42///
43/// - `n_neighbors` — number of nearest neighbors to use (default 5).
44/// - `weights` — how to weight neighbor contributions (default `Uniform`).
45///
46/// # Examples
47///
48/// ```
49/// use ferrolearn_preprocess::knn_imputer::{KNNImputer, KNNWeights};
50/// use ferrolearn_core::traits::{Fit, Transform};
51/// use ndarray::array;
52///
53/// let imputer = KNNImputer::<f64>::new(2, KNNWeights::Uniform);
54/// let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, f64::NAN]];
55/// let fitted = imputer.fit(&x, &()).unwrap();
56/// let out = fitted.transform(&x).unwrap();
57/// assert!(!out[[2, 1]].is_nan());
58/// ```
59#[must_use]
60#[derive(Debug, Clone)]
61pub struct KNNImputer<F> {
62    /// Number of nearest neighbors to use.
63    n_neighbors: usize,
64    /// Weighting strategy.
65    weights: KNNWeights,
66    _marker: std::marker::PhantomData<F>,
67}
68
69impl<F: Float + Send + Sync + 'static> KNNImputer<F> {
70    /// Create a new `KNNImputer` with the given parameters.
71    pub fn new(n_neighbors: usize, weights: KNNWeights) -> Self {
72        Self {
73            n_neighbors,
74            weights,
75            _marker: std::marker::PhantomData,
76        }
77    }
78
79    /// Return the number of neighbors.
80    #[must_use]
81    pub fn n_neighbors(&self) -> usize {
82        self.n_neighbors
83    }
84
85    /// Return the weighting strategy.
86    #[must_use]
87    pub fn weights(&self) -> KNNWeights {
88        self.weights
89    }
90}
91
92impl<F: Float + Send + Sync + 'static> Default for KNNImputer<F> {
93    fn default() -> Self {
94        Self::new(5, KNNWeights::Uniform)
95    }
96}
97
98// ---------------------------------------------------------------------------
99// FittedKNNImputer
100// ---------------------------------------------------------------------------
101
102/// A fitted KNN imputer holding the training data used for neighbor lookup.
103///
104/// Created by calling [`Fit::fit`] on a [`KNNImputer`].
105#[derive(Debug, Clone)]
106pub struct FittedKNNImputer<F> {
107    /// The training data (used for neighbor lookup).
108    train_data: Array2<F>,
109    /// Number of neighbors.
110    n_neighbors: usize,
111    /// Weighting strategy.
112    weights: KNNWeights,
113}
114
115impl<F: Float + Send + Sync + 'static> FittedKNNImputer<F> {
116    /// Return the number of training samples.
117    #[must_use]
118    pub fn n_train_samples(&self) -> usize {
119        self.train_data.nrows()
120    }
121}
122
123// ---------------------------------------------------------------------------
124// Helpers
125// ---------------------------------------------------------------------------
126
127/// Compute partial Euclidean distance between two rows, using only features
128/// that are non-missing in both rows.
129///
130/// Returns `(distance, n_valid)`. If no valid features exist, returns
131/// `(F::infinity(), 0)`.
132fn partial_euclidean_distance<F: Float>(row_a: &[F], row_b: &[F]) -> (F, usize) {
133    let mut sum_sq = F::zero();
134    let mut n_valid = 0usize;
135    for (&a, &b) in row_a.iter().zip(row_b.iter()) {
136        if !a.is_nan() && !b.is_nan() {
137            let d = a - b;
138            sum_sq = sum_sq + d * d;
139            n_valid += 1;
140        }
141    }
142    if n_valid == 0 {
143        (F::infinity(), 0)
144    } else {
145        // Scale distance to account for missing dimensions:
146        // d_full = d_partial * sqrt(n_total / n_valid)
147        // But we keep it simple here: just use sqrt(sum_sq)
148        (sum_sq.sqrt(), n_valid)
149    }
150}
151
152// ---------------------------------------------------------------------------
153// Trait implementations
154// ---------------------------------------------------------------------------
155
156impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for KNNImputer<F> {
157    type Fitted = FittedKNNImputer<F>;
158    type Error = FerroError;
159
160    /// Fit the imputer by storing the training data.
161    ///
162    /// # Errors
163    ///
164    /// - [`FerroError::InsufficientSamples`] if the input has zero rows.
165    /// - [`FerroError::InvalidParameter`] if `n_neighbors` is zero or exceeds
166    ///   the number of samples.
167    fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedKNNImputer<F>, FerroError> {
168        let n_samples = x.nrows();
169        if n_samples == 0 {
170            return Err(FerroError::InsufficientSamples {
171                required: 1,
172                actual: 0,
173                context: "KNNImputer::fit".into(),
174            });
175        }
176        if self.n_neighbors == 0 {
177            return Err(FerroError::InvalidParameter {
178                name: "n_neighbors".into(),
179                reason: "n_neighbors must be at least 1".into(),
180            });
181        }
182        if self.n_neighbors > n_samples {
183            return Err(FerroError::InvalidParameter {
184                name: "n_neighbors".into(),
185                reason: format!(
186                    "n_neighbors ({}) exceeds the number of training samples ({})",
187                    self.n_neighbors, n_samples
188                ),
189            });
190        }
191
192        Ok(FittedKNNImputer {
193            train_data: x.to_owned(),
194            n_neighbors: self.n_neighbors,
195            weights: self.weights,
196        })
197    }
198}
199
200impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedKNNImputer<F> {
201    type Output = Array2<F>;
202    type Error = FerroError;
203
204    /// Impute missing values in `x` using the k-nearest neighbors from the
205    /// training data.
206    ///
207    /// For each missing value `x[i, j]`, the method finds the `k` nearest
208    /// training rows (based on partial Euclidean distance over non-missing
209    /// features) that also have a non-missing value at feature `j`, then
210    /// computes a (optionally distance-weighted) average.
211    ///
212    /// # Errors
213    ///
214    /// Returns [`FerroError::ShapeMismatch`] if the number of columns does not
215    /// match the training data.
216    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
217        let n_features = self.train_data.ncols();
218        if x.ncols() != n_features {
219            return Err(FerroError::ShapeMismatch {
220                expected: vec![x.nrows(), n_features],
221                actual: vec![x.nrows(), x.ncols()],
222                context: "FittedKNNImputer::transform".into(),
223            });
224        }
225
226        let mut out = x.to_owned();
227        let n_train = self.train_data.nrows();
228
229        for i in 0..out.nrows() {
230            // Check if this row has any missing values
231            let row_slice: Vec<F> = out.row(i).to_vec();
232            let has_missing = row_slice.iter().any(|v| v.is_nan());
233            if !has_missing {
234                continue;
235            }
236
237            // Compute distances to all training rows
238            let mut dists: Vec<(usize, F)> = Vec::with_capacity(n_train);
239            for t in 0..n_train {
240                let train_row: Vec<F> = self.train_data.row(t).to_vec();
241                let (d, n_valid) = partial_euclidean_distance(&row_slice, &train_row);
242                if n_valid > 0 {
243                    dists.push((t, d));
244                }
245            }
246            // Sort by distance
247            dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
248
249            // For each missing feature, impute from the k nearest neighbors
250            // that have a non-missing value for that feature
251            for j in 0..n_features {
252                if !row_slice[j].is_nan() {
253                    continue;
254                }
255
256                // Collect up to k neighbors that have a valid value at feature j
257                let mut neighbor_vals: Vec<(F, F)> = Vec::new(); // (value, distance)
258                for &(t_idx, dist) in &dists {
259                    let val = self.train_data[[t_idx, j]];
260                    if !val.is_nan() {
261                        neighbor_vals.push((val, dist));
262                        if neighbor_vals.len() >= self.n_neighbors {
263                            break;
264                        }
265                    }
266                }
267
268                if neighbor_vals.is_empty() {
269                    // No valid neighbors found — leave as NaN or fill with zero
270                    out[[i, j]] = F::zero();
271                    continue;
272                }
273
274                let imputed = match self.weights {
275                    KNNWeights::Uniform => {
276                        let sum = neighbor_vals
277                            .iter()
278                            .map(|&(v, _)| v)
279                            .fold(F::zero(), |acc, v| acc + v);
280                        sum / F::from(neighbor_vals.len()).unwrap_or(F::one())
281                    }
282                    KNNWeights::Distance => {
283                        // Inverse distance weighting
284                        let mut weight_sum = F::zero();
285                        let mut val_sum = F::zero();
286                        let epsilon = F::from(1e-12).unwrap_or(F::min_positive_value());
287                        for &(val, dist) in &neighbor_vals {
288                            let w = if dist <= epsilon {
289                                // Exact match — give very high weight
290                                F::from(1e12).unwrap_or(F::max_value())
291                            } else {
292                                F::one() / dist
293                            };
294                            weight_sum = weight_sum + w;
295                            val_sum = val_sum + w * val;
296                        }
297                        if weight_sum > F::zero() {
298                            val_sum / weight_sum
299                        } else {
300                            neighbor_vals[0].0
301                        }
302                    }
303                };
304
305                out[[i, j]] = imputed;
306            }
307        }
308
309        Ok(out)
310    }
311}
312
313/// Implement `Transform` on the unfitted imputer to satisfy the
314/// `FitTransform: Transform` supertrait bound.
315impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for KNNImputer<F> {
316    type Output = Array2<F>;
317    type Error = FerroError;
318
319    /// Always returns an error — the imputer must be fitted first.
320    fn transform(&self, _x: &Array2<F>) -> Result<Array2<F>, FerroError> {
321        Err(FerroError::InvalidParameter {
322            name: "KNNImputer".into(),
323            reason: "imputer must be fitted before calling transform; use fit() first".into(),
324        })
325    }
326}
327
328impl<F: Float + Send + Sync + 'static> FitTransform<Array2<F>> for KNNImputer<F> {
329    type FitError = FerroError;
330
331    /// Fit the imputer on `x` and return the imputed output in one step.
332    ///
333    /// # Errors
334    ///
335    /// Returns an error if fitting fails.
336    fn fit_transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
337        let fitted = self.fit(x, &())?;
338        fitted.transform(x)
339    }
340}
341
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Uniform weighting: the missing value becomes the plain mean of the
    /// k nearest neighbors' values for that feature.
    #[test]
    fn test_knn_imputer_uniform_basic() {
        let imputer = KNNImputer::<f64>::new(2, KNNWeights::Uniform);
        // Row 2 has NaN in column 1
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, f64::NAN]];
        let fitted = imputer.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        // Missing value at [2,1]: nearest 2 neighbors are rows 0 and 1
        // with values 2.0 and 4.0 → mean = 3.0
        assert_abs_diff_eq!(out[[2, 1]], 3.0, epsilon = 1e-10);
        // Non-missing values unchanged
        assert_abs_diff_eq!(out[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[1, 1]], 4.0, epsilon = 1e-10);
    }

    /// Distance weighting: closer neighbors contribute proportionally more.
    /// (The expected value is invariant to any uniform rescaling of the
    /// distances, since only the weight *ratio* matters.)
    #[test]
    fn test_knn_imputer_distance_weighted() {
        let imputer = KNNImputer::<f64>::new(2, KNNWeights::Distance);
        // Rows 0 and 1 have known feature 1; row 2 is missing feature 1
        // Row 2 feature 0 = 4.0, row 0 feature 0 = 1.0, row 1 feature 0 = 3.0
        // Distance to row 0: |4 - 1| = 3.0
        // Distance to row 1: |4 - 3| = 1.0
        // Weighted: (2.0 * 1/3 + 6.0 * 1/1) / (1/3 + 1/1) = (0.667 + 6.0) / 1.333 ≈ 5.0
        let x = array![[1.0, 2.0], [3.0, 6.0], [4.0, f64::NAN]];
        let fitted = imputer.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        // w0 = 1/3, w1 = 1/1
        let w0 = 1.0 / 3.0;
        let w1 = 1.0 / 1.0;
        let expected = (2.0 * w0 + 6.0 * w1) / (w0 + w1);
        assert_abs_diff_eq!(out[[2, 1]], expected, epsilon = 1e-10);
    }

    /// A matrix with no NaN entries must pass through unchanged.
    #[test]
    fn test_knn_imputer_no_missing() {
        let imputer = KNNImputer::<f64>::new(2, KNNWeights::Uniform);
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[1, 1]], 4.0, epsilon = 1e-10);
    }

    /// Several missing features in the same row are all imputed.
    #[test]
    fn test_knn_imputer_multiple_missing() {
        let imputer = KNNImputer::<f64>::new(2, KNNWeights::Uniform);
        let x = array![
            [1.0, 10.0, 100.0],
            [2.0, 20.0, 200.0],
            [3.0, f64::NAN, f64::NAN]
        ];
        let fitted = imputer.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        // All imputed values should be finite
        assert!(!out[[2, 1]].is_nan());
        assert!(!out[[2, 2]].is_nan());
    }

    /// `fit_transform` is equivalent to fit followed by transform.
    #[test]
    fn test_knn_imputer_fit_transform() {
        let imputer = KNNImputer::<f64>::new(2, KNNWeights::Uniform);
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, f64::NAN]];
        let out = imputer.fit_transform(&x).unwrap();
        assert!(!out[[2, 1]].is_nan());
    }

    /// Fitting on an empty matrix is rejected.
    #[test]
    fn test_knn_imputer_zero_rows_error() {
        let imputer = KNNImputer::<f64>::new(2, KNNWeights::Uniform);
        let x: Array2<f64> = Array2::zeros((0, 3));
        assert!(imputer.fit(&x, &()).is_err());
    }

    /// `n_neighbors == 0` is an invalid hyperparameter.
    #[test]
    fn test_knn_imputer_zero_neighbors_error() {
        let imputer = KNNImputer::<f64>::new(0, KNNWeights::Uniform);
        let x = array![[1.0, 2.0]];
        assert!(imputer.fit(&x, &()).is_err());
    }

    /// `n_neighbors` larger than the training-set size is rejected.
    #[test]
    fn test_knn_imputer_too_many_neighbors_error() {
        let imputer = KNNImputer::<f64>::new(10, KNNWeights::Uniform);
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        assert!(imputer.fit(&x, &()).is_err());
    }

    /// Transforming data with a different column count fails.
    #[test]
    fn test_knn_imputer_shape_mismatch_error() {
        let imputer = KNNImputer::<f64>::new(2, KNNWeights::Uniform);
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let fitted = imputer.fit(&x_train, &()).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    /// The unfitted imputer's `transform` always errors.
    #[test]
    fn test_knn_imputer_unfitted_transform_error() {
        let imputer = KNNImputer::<f64>::new(2, KNNWeights::Uniform);
        let x = array![[1.0, 2.0]];
        assert!(imputer.transform(&x).is_err());
    }

    /// Defaults: 5 neighbors, uniform weighting.
    #[test]
    fn test_knn_imputer_default() {
        let imputer = KNNImputer::<f64>::default();
        assert_eq!(imputer.n_neighbors(), 5);
        assert_eq!(imputer.weights(), KNNWeights::Uniform);
    }

    /// With k = 1 the single nearest neighbor's value is copied verbatim.
    #[test]
    fn test_knn_imputer_single_neighbor() {
        let imputer = KNNImputer::<f64>::new(1, KNNWeights::Uniform);
        // Row 1 is closest to row 2 (distance on col 0 = |5 - 4| = 1)
        let x = array![[1.0, 10.0], [4.0, 40.0], [5.0, f64::NAN]];
        let fitted = imputer.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        // Nearest neighbor to row 2 by col 0: row 1 (dist = 1) vs row 0 (dist = 4)
        assert_abs_diff_eq!(out[[2, 1]], 40.0, epsilon = 1e-10);
    }

    /// The imputer is generic over the float type: f32 works too.
    #[test]
    fn test_knn_imputer_f32() {
        let imputer = KNNImputer::<f32>::new(2, KNNWeights::Uniform);
        let x: Array2<f32> = array![[1.0f32, 2.0], [3.0, 4.0], [5.0, f32::NAN]];
        let fitted = imputer.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert!(!out[[2, 1]].is_nan());
    }
}