// fdars_core/classification/knn.rs
1//! k-NN classifier internals.
2
3use crate::error::FdarError;
4use crate::matrix::FdMatrix;
5
6use super::{
7    build_feature_matrix, compute_accuracy, confusion_matrix, remap_labels, ClassifResult,
8};
9
10/// FPC + k-NN classification.
11///
12/// # Arguments
13/// * `data` — Functional data (n × m)
14/// * `y` — Class labels
15/// * `scalar_covariates` — Optional scalar covariates
16/// * `ncomp` — Number of FPC components
17/// * `k_nn` — Number of nearest neighbors
18///
19/// # Errors
20///
21/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows or `y.len() != n`.
22/// Returns [`FdarError::InvalidParameter`] if `ncomp` is zero.
23/// Returns [`FdarError::InvalidParameter`] if `k_nn` is zero.
24/// Returns [`FdarError::InvalidParameter`] if `y` contains fewer than 2 distinct classes.
25/// Returns [`FdarError::ComputationFailed`] if the SVD decomposition in FPCA fails.
26#[must_use = "expensive computation whose result should not be discarded"]
27pub fn fclassif_knn(
28    data: &FdMatrix,
29    y: &[usize],
30    scalar_covariates: Option<&FdMatrix>,
31    ncomp: usize,
32    k_nn: usize,
33) -> Result<ClassifResult, FdarError> {
34    let n = data.nrows();
35    if n == 0 || y.len() != n {
36        return Err(FdarError::InvalidDimension {
37            parameter: "data/y",
38            expected: "n > 0 and y.len() == n".to_string(),
39            actual: format!("n={}, y.len()={}", n, y.len()),
40        });
41    }
42    if ncomp == 0 {
43        return Err(FdarError::InvalidParameter {
44            parameter: "ncomp",
45            message: "must be > 0".to_string(),
46        });
47    }
48    if k_nn == 0 {
49        return Err(FdarError::InvalidParameter {
50            parameter: "k_nn",
51            message: "must be > 0".to_string(),
52        });
53    }
54
55    let (labels, g) = remap_labels(y);
56    if g < 2 {
57        return Err(FdarError::InvalidParameter {
58            parameter: "y",
59            message: format!("need at least 2 classes, got {g}"),
60        });
61    }
62
63    let (features, _mean, _rotation) = build_feature_matrix(data, scalar_covariates, ncomp)?;
64    let d = features.ncols();
65
66    let predicted = knn_predict_loo(&features, &labels, g, d, k_nn);
67    let accuracy = compute_accuracy(&labels, &predicted);
68    let confusion = confusion_matrix(&labels, &predicted, g);
69
70    Ok(ClassifResult {
71        predicted,
72        probabilities: None,
73        accuracy,
74        confusion,
75        n_classes: g,
76        ncomp: d.min(ncomp),
77    })
78}
79
80/// Leave-one-out k-NN prediction.
81pub(crate) fn knn_predict_loo(
82    features: &FdMatrix,
83    labels: &[usize],
84    g: usize,
85    d: usize,
86    k_nn: usize,
87) -> Vec<usize> {
88    let n = features.nrows();
89    let k_nn = k_nn.min(n - 1);
90
91    (0..n)
92        .map(|i| {
93            let xi: Vec<f64> = (0..d).map(|j| features[(i, j)]).collect();
94            let mut dists: Vec<(f64, usize)> = (0..n)
95                .filter(|&j| j != i)
96                .map(|j| {
97                    let xj: Vec<f64> = (0..d).map(|jj| features[(j, jj)]).collect();
98                    let d_sq: f64 = xi.iter().zip(&xj).map(|(&a, &b)| (a - b).powi(2)).sum();
99                    (d_sq, labels[j])
100                })
101                .collect();
102            if k_nn > 0 && k_nn < dists.len() {
103                dists.select_nth_unstable_by(k_nn - 1, |a, b| {
104                    a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
105                });
106            }
107
108            // Majority vote among k nearest
109            let mut votes = vec![0usize; g];
110            for &(_, label) in dists.iter().take(k_nn) {
111                votes[label] += 1;
112            }
113            votes
114                .iter()
115                .enumerate()
116                .max_by_key(|&(_, &v)| v)
117                .map_or(0, |(c, _)| c)
118        })
119        .collect()
120}