1use ferrolearn_core::error::FerroError;
15use ferrolearn_core::traits::{Fit, FitTransform, Transform};
16use ndarray::Array2;
17use num_traits::Float;
18
/// Strategy for combining the values of the selected neighbours into one
/// imputed value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KNNWeights {
    /// Every selected neighbour contributes equally (arithmetic mean).
    Uniform,
    /// Each selected neighbour is weighted by the inverse of its distance to
    /// the row being imputed.
    Distance,
}
31
32#[must_use]
60#[derive(Debug, Clone)]
61pub struct KNNImputer<F> {
62 n_neighbors: usize,
64 weights: KNNWeights,
66 _marker: std::marker::PhantomData<F>,
67}
68
69impl<F: Float + Send + Sync + 'static> KNNImputer<F> {
70 pub fn new(n_neighbors: usize, weights: KNNWeights) -> Self {
72 Self {
73 n_neighbors,
74 weights,
75 _marker: std::marker::PhantomData,
76 }
77 }
78
79 #[must_use]
81 pub fn n_neighbors(&self) -> usize {
82 self.n_neighbors
83 }
84
85 #[must_use]
87 pub fn weights(&self) -> KNNWeights {
88 self.weights
89 }
90}
91
92impl<F: Float + Send + Sync + 'static> Default for KNNImputer<F> {
93 fn default() -> Self {
94 Self::new(5, KNNWeights::Uniform)
95 }
96}
97
98#[derive(Debug, Clone)]
106pub struct FittedKNNImputer<F> {
107 train_data: Array2<F>,
109 n_neighbors: usize,
111 weights: KNNWeights,
113}
114
115impl<F: Float + Send + Sync + 'static> FittedKNNImputer<F> {
116 #[must_use]
118 pub fn n_train_samples(&self) -> usize {
119 self.train_data.nrows()
120 }
121}
122
123fn partial_euclidean_distance<F: Float>(row_a: &[F], row_b: &[F]) -> (F, usize) {
133 let mut sum_sq = F::zero();
134 let mut n_valid = 0usize;
135 for (&a, &b) in row_a.iter().zip(row_b.iter()) {
136 if !a.is_nan() && !b.is_nan() {
137 let d = a - b;
138 sum_sq = sum_sq + d * d;
139 n_valid += 1;
140 }
141 }
142 if n_valid == 0 {
143 (F::infinity(), 0)
144 } else {
145 (sum_sq.sqrt(), n_valid)
149 }
150}
151
152impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for KNNImputer<F> {
157 type Fitted = FittedKNNImputer<F>;
158 type Error = FerroError;
159
160 fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedKNNImputer<F>, FerroError> {
168 let n_samples = x.nrows();
169 if n_samples == 0 {
170 return Err(FerroError::InsufficientSamples {
171 required: 1,
172 actual: 0,
173 context: "KNNImputer::fit".into(),
174 });
175 }
176 if self.n_neighbors == 0 {
177 return Err(FerroError::InvalidParameter {
178 name: "n_neighbors".into(),
179 reason: "n_neighbors must be at least 1".into(),
180 });
181 }
182 if self.n_neighbors > n_samples {
183 return Err(FerroError::InvalidParameter {
184 name: "n_neighbors".into(),
185 reason: format!(
186 "n_neighbors ({}) exceeds the number of training samples ({})",
187 self.n_neighbors, n_samples
188 ),
189 });
190 }
191
192 Ok(FittedKNNImputer {
193 train_data: x.to_owned(),
194 n_neighbors: self.n_neighbors,
195 weights: self.weights,
196 })
197 }
198}
199
200impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedKNNImputer<F> {
201 type Output = Array2<F>;
202 type Error = FerroError;
203
204 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
217 let n_features = self.train_data.ncols();
218 if x.ncols() != n_features {
219 return Err(FerroError::ShapeMismatch {
220 expected: vec![x.nrows(), n_features],
221 actual: vec![x.nrows(), x.ncols()],
222 context: "FittedKNNImputer::transform".into(),
223 });
224 }
225
226 let mut out = x.to_owned();
227 let n_train = self.train_data.nrows();
228
229 for i in 0..out.nrows() {
230 let row_slice: Vec<F> = out.row(i).to_vec();
232 let has_missing = row_slice.iter().any(|v| v.is_nan());
233 if !has_missing {
234 continue;
235 }
236
237 let mut dists: Vec<(usize, F)> = Vec::with_capacity(n_train);
239 for t in 0..n_train {
240 let train_row: Vec<F> = self.train_data.row(t).to_vec();
241 let (d, n_valid) = partial_euclidean_distance(&row_slice, &train_row);
242 if n_valid > 0 {
243 dists.push((t, d));
244 }
245 }
246 dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
248
249 for j in 0..n_features {
252 if !row_slice[j].is_nan() {
253 continue;
254 }
255
256 let mut neighbor_vals: Vec<(F, F)> = Vec::new(); for &(t_idx, dist) in &dists {
259 let val = self.train_data[[t_idx, j]];
260 if !val.is_nan() {
261 neighbor_vals.push((val, dist));
262 if neighbor_vals.len() >= self.n_neighbors {
263 break;
264 }
265 }
266 }
267
268 if neighbor_vals.is_empty() {
269 out[[i, j]] = F::zero();
271 continue;
272 }
273
274 let imputed = match self.weights {
275 KNNWeights::Uniform => {
276 let sum = neighbor_vals
277 .iter()
278 .map(|&(v, _)| v)
279 .fold(F::zero(), |acc, v| acc + v);
280 sum / F::from(neighbor_vals.len()).unwrap_or(F::one())
281 }
282 KNNWeights::Distance => {
283 let mut weight_sum = F::zero();
285 let mut val_sum = F::zero();
286 let epsilon = F::from(1e-12).unwrap_or(F::min_positive_value());
287 for &(val, dist) in &neighbor_vals {
288 let w = if dist <= epsilon {
289 F::from(1e12).unwrap_or(F::max_value())
291 } else {
292 F::one() / dist
293 };
294 weight_sum = weight_sum + w;
295 val_sum = val_sum + w * val;
296 }
297 if weight_sum > F::zero() {
298 val_sum / weight_sum
299 } else {
300 neighbor_vals[0].0
301 }
302 }
303 };
304
305 out[[i, j]] = imputed;
306 }
307 }
308
309 Ok(out)
310 }
311}
312
313impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for KNNImputer<F> {
316 type Output = Array2<F>;
317 type Error = FerroError;
318
319 fn transform(&self, _x: &Array2<F>) -> Result<Array2<F>, FerroError> {
321 Err(FerroError::InvalidParameter {
322 name: "KNNImputer".into(),
323 reason: "imputer must be fitted before calling transform; use fit() first".into(),
324 })
325 }
326}
327
328impl<F: Float + Send + Sync + 'static> FitTransform<Array2<F>> for KNNImputer<F> {
329 type FitError = FerroError;
330
331 fn fit_transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
337 let fitted = self.fit(x, &())?;
338 fitted.transform(x)
339 }
340}
341
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Uniform weighting: the missing entry becomes the plain mean of the two
    /// nearest rows that have a value in that column.
    #[test]
    fn test_knn_imputer_uniform_basic() {
        let imputer = KNNImputer::<f64>::new(2, KNNWeights::Uniform);
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, f64::NAN]];
        let fitted = imputer.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        // Neighbours of row 2 are rows 1 and 0: (4 + 2) / 2 = 3.
        assert_abs_diff_eq!(out[[2, 1]], 3.0, epsilon = 1e-10);
        // Observed entries pass through untouched.
        assert_abs_diff_eq!(out[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[1, 1]], 4.0, epsilon = 1e-10);
    }

    /// Distance weighting: each neighbour contributes with weight 1 / distance.
    #[test]
    fn test_knn_imputer_distance_weighted() {
        let imputer = KNNImputer::<f64>::new(2, KNNWeights::Distance);
        let x = array![[1.0, 2.0], [3.0, 6.0], [4.0, f64::NAN]];
        let fitted = imputer.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        // Row 0 is at distance 3 from the query, row 1 at distance 1.
        let w0 = 1.0 / 3.0;
        let w1 = 1.0 / 1.0;
        let expected = (2.0 * w0 + 6.0 * w1) / (w0 + w1);
        assert_abs_diff_eq!(out[[2, 1]], expected, epsilon = 1e-10);
    }

    /// Distance weighting with an exact match: the coincident neighbour gets
    /// the large fixed weight and therefore dominates the result.
    #[test]
    fn test_knn_imputer_distance_exact_match_dominates() {
        let imputer = KNNImputer::<f64>::new(2, KNNWeights::Distance);
        let train = array![[1.0, 10.0], [2.0, 20.0]];
        let fitted = imputer.fit(&train, &()).unwrap();
        let query = array![[1.0, f64::NAN]];
        let out = fitted.transform(&query).unwrap();
        // (1e12 * 10 + 1 * 20) / (1e12 + 1) is within 1e-11 of 10.
        assert_abs_diff_eq!(out[[0, 1]], 10.0, epsilon = 1e-9);
    }

    /// Complete data is returned unchanged.
    #[test]
    fn test_knn_imputer_no_missing() {
        let imputer = KNNImputer::<f64>::new(2, KNNWeights::Uniform);
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let fitted = imputer.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out, x);
    }

    /// Several missing entries in one row are each imputed from that column's
    /// own neighbour values.
    #[test]
    fn test_knn_imputer_multiple_missing() {
        let imputer = KNNImputer::<f64>::new(2, KNNWeights::Uniform);
        let x = array![
            [1.0, 10.0, 100.0],
            [2.0, 20.0, 200.0],
            [3.0, f64::NAN, f64::NAN]
        ];
        let fitted = imputer.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[2, 0]], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[2, 1]], 15.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[2, 2]], 150.0, epsilon = 1e-10);
    }

    /// `fit_transform` gives the same result as `fit` followed by `transform`.
    #[test]
    fn test_knn_imputer_fit_transform() {
        let imputer = KNNImputer::<f64>::new(2, KNNWeights::Uniform);
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, f64::NAN]];
        let out = imputer.fit_transform(&x).unwrap();
        assert_abs_diff_eq!(out[[2, 1]], 3.0, epsilon = 1e-10);

        let two_step = imputer.fit(&x, &()).unwrap().transform(&x).unwrap();
        assert_eq!(out, two_step);
    }

    /// Rows that were not part of the training data are imputed from the
    /// training rows; complete query rows pass through.
    #[test]
    fn test_knn_imputer_transform_unseen_rows() {
        let imputer = KNNImputer::<f64>::new(2, KNNWeights::Uniform);
        let train = array![[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]];
        let fitted = imputer.fit(&train, &()).unwrap();
        assert_eq!(fitted.n_train_samples(), 3);

        let query = array![[2.4, f64::NAN], [7.0, 8.0]];
        let out = fitted.transform(&query).unwrap();
        // Nearest training rows to 2.4 are 2.0 and 3.0: (20 + 30) / 2 = 25.
        assert_abs_diff_eq!(out[[0, 0]], 2.4, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 1]], 25.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[1, 0]], 7.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[1, 1]], 8.0, epsilon = 1e-10);
    }

    /// When no training row has a value in the missing column, the entry
    /// falls back to zero.
    #[test]
    fn test_knn_imputer_no_valid_neighbor_value_falls_back_to_zero() {
        let imputer = KNNImputer::<f64>::new(1, KNNWeights::Uniform);
        let x = array![[1.0, f64::NAN], [2.0, f64::NAN]];
        let fitted = imputer.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 1]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[1, 1]], 0.0, epsilon = 1e-10);
    }

    /// A query row with no observed value shares no coordinate with any
    /// training row, so every entry falls back to zero.
    #[test]
    fn test_knn_imputer_all_nan_row_falls_back_to_zero() {
        let imputer = KNNImputer::<f64>::new(2, KNNWeights::Distance);
        let train = array![[1.0, 2.0], [3.0, 4.0]];
        let fitted = imputer.fit(&train, &()).unwrap();
        let query = array![[f64::NAN, f64::NAN]];
        let out = fitted.transform(&query).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 1]], 0.0, epsilon = 1e-10);
    }

    /// Fitting on an empty matrix is rejected.
    #[test]
    fn test_knn_imputer_zero_rows_error() {
        let imputer = KNNImputer::<f64>::new(2, KNNWeights::Uniform);
        let x: Array2<f64> = Array2::zeros((0, 3));
        let err = imputer.fit(&x, &()).unwrap_err();
        assert!(matches!(err, FerroError::InsufficientSamples { .. }));
    }

    /// `n_neighbors == 0` is rejected at fit time.
    #[test]
    fn test_knn_imputer_zero_neighbors_error() {
        let imputer = KNNImputer::<f64>::new(0, KNNWeights::Uniform);
        let x = array![[1.0, 2.0]];
        let err = imputer.fit(&x, &()).unwrap_err();
        assert!(matches!(err, FerroError::InvalidParameter { .. }));
    }

    /// `n_neighbors` larger than the number of training rows is rejected.
    #[test]
    fn test_knn_imputer_too_many_neighbors_error() {
        let imputer = KNNImputer::<f64>::new(10, KNNWeights::Uniform);
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let err = imputer.fit(&x, &()).unwrap_err();
        assert!(matches!(err, FerroError::InvalidParameter { .. }));
    }

    /// Transforming data with a different column count is rejected.
    #[test]
    fn test_knn_imputer_shape_mismatch_error() {
        let imputer = KNNImputer::<f64>::new(2, KNNWeights::Uniform);
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let fitted = imputer.fit(&x_train, &()).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        let err = fitted.transform(&x_bad).unwrap_err();
        assert!(matches!(err, FerroError::ShapeMismatch { .. }));
    }

    /// Calling `transform` on the unfitted imputer is an error.
    #[test]
    fn test_knn_imputer_unfitted_transform_error() {
        let imputer = KNNImputer::<f64>::new(2, KNNWeights::Uniform);
        let x = array![[1.0, 2.0]];
        let err = imputer.transform(&x).unwrap_err();
        assert!(matches!(err, FerroError::InvalidParameter { .. }));
    }

    /// Default configuration is five neighbours with uniform weights.
    #[test]
    fn test_knn_imputer_default() {
        let imputer = KNNImputer::<f64>::default();
        assert_eq!(imputer.n_neighbors(), 5);
        assert_eq!(imputer.weights(), KNNWeights::Uniform);
    }

    /// With a single neighbour the imputed value is exactly the nearest
    /// row's value.
    #[test]
    fn test_knn_imputer_single_neighbor() {
        let imputer = KNNImputer::<f64>::new(1, KNNWeights::Uniform);
        let x = array![[1.0, 10.0], [4.0, 40.0], [5.0, f64::NAN]];
        let fitted = imputer.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[2, 1]], 40.0, epsilon = 1e-10);
    }

    /// The imputer also works for `f32` data.
    #[test]
    fn test_knn_imputer_f32() {
        let imputer = KNNImputer::<f32>::new(2, KNNWeights::Uniform);
        let x: Array2<f32> = array![[1.0f32, 2.0], [3.0, 4.0], [5.0, f32::NAN]];
        let fitted = imputer.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[2, 1]], 3.0f32, epsilon = 1e-6);
    }
}