Skip to main content

fdars_core/classification/
knn.rs

1//! k-NN classifier internals.
2
3use crate::error::FdarError;
4use crate::matrix::FdMatrix;
5
6use super::{
7    build_feature_matrix, compute_accuracy, confusion_matrix, remap_labels, ClassifResult,
8};
9
10/// FPC + k-NN classification.
11///
12/// # Arguments
13/// * `data` — Functional data (n × m)
14/// * `y` — Class labels
15/// * `scalar_covariates` — Optional scalar covariates
16/// * `ncomp` — Number of FPC components
17/// * `k_nn` — Number of nearest neighbors
18///
19/// # Errors
20///
21/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows or `y.len() != n`.
22/// Returns [`FdarError::InvalidParameter`] if `ncomp` is zero.
23/// Returns [`FdarError::InvalidParameter`] if `k_nn` is zero.
24/// Returns [`FdarError::InvalidParameter`] if `y` contains fewer than 2 distinct classes.
25/// Returns [`FdarError::ComputationFailed`] if the SVD decomposition in FPCA fails.
26///
27/// # Examples
28///
29/// ```
30/// use fdars_core::matrix::FdMatrix;
31/// use fdars_core::classification::knn::fclassif_knn;
32///
33/// let data = FdMatrix::from_column_major(
34///     (0..100).map(|i| (i as f64 * 0.1).sin()).collect(),
35///     10, 10,
36/// ).unwrap();
37/// let y = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
38/// let result = fclassif_knn(&data, &y, None, 3, 3).unwrap();
39/// assert_eq!(result.predicted.len(), 10);
40/// assert_eq!(result.n_classes, 2);
41/// ```
42#[must_use = "expensive computation whose result should not be discarded"]
43pub fn fclassif_knn(
44    data: &FdMatrix,
45    y: &[usize],
46    scalar_covariates: Option<&FdMatrix>,
47    ncomp: usize,
48    k_nn: usize,
49) -> Result<ClassifResult, FdarError> {
50    let n = data.nrows();
51    if n == 0 || y.len() != n {
52        return Err(FdarError::InvalidDimension {
53            parameter: "data/y",
54            expected: "n > 0 and y.len() == n".to_string(),
55            actual: format!("n={}, y.len()={}", n, y.len()),
56        });
57    }
58    if ncomp == 0 {
59        return Err(FdarError::InvalidParameter {
60            parameter: "ncomp",
61            message: "must be > 0".to_string(),
62        });
63    }
64    if k_nn == 0 {
65        return Err(FdarError::InvalidParameter {
66            parameter: "k_nn",
67            message: "must be > 0".to_string(),
68        });
69    }
70
71    let (labels, g) = remap_labels(y);
72    if g < 2 {
73        return Err(FdarError::InvalidParameter {
74            parameter: "y",
75            message: format!("need at least 2 classes, got {g}"),
76        });
77    }
78
79    let (features, _mean, _rotation) = build_feature_matrix(data, scalar_covariates, ncomp)?;
80    let d = features.ncols();
81
82    let predicted = knn_predict_loo(&features, &labels, g, d, k_nn);
83    let accuracy = compute_accuracy(&labels, &predicted);
84    let confusion = confusion_matrix(&labels, &predicted, g);
85
86    Ok(ClassifResult {
87        predicted,
88        probabilities: None,
89        accuracy,
90        confusion,
91        n_classes: g,
92        ncomp: d.min(ncomp),
93    })
94}
95
96/// Leave-one-out k-NN prediction.
97pub(crate) fn knn_predict_loo(
98    features: &FdMatrix,
99    labels: &[usize],
100    g: usize,
101    d: usize,
102    k_nn: usize,
103) -> Vec<usize> {
104    let n = features.nrows();
105    let k_nn = k_nn.min(n - 1);
106
107    (0..n)
108        .map(|i| {
109            let xi: Vec<f64> = (0..d).map(|j| features[(i, j)]).collect();
110            let mut dists: Vec<(f64, usize)> = (0..n)
111                .filter(|&j| j != i)
112                .map(|j| {
113                    let xj: Vec<f64> = (0..d).map(|jj| features[(j, jj)]).collect();
114                    let d_sq: f64 = xi.iter().zip(&xj).map(|(&a, &b)| (a - b).powi(2)).sum();
115                    (d_sq, labels[j])
116                })
117                .collect();
118            if k_nn > 0 && k_nn < dists.len() {
119                dists.select_nth_unstable_by(k_nn - 1, |a, b| {
120                    a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
121                });
122            }
123
124            // Majority vote among k nearest
125            let mut votes = vec![0usize; g];
126            for &(_, label) in dists.iter().take(k_nn) {
127                votes[label] += 1;
128            }
129            votes
130                .iter()
131                .enumerate()
132                .max_by_key(|&(_, &v)| v)
133                .map_or(0, |(c, _)| c)
134        })
135        .collect()
136}