fdars_core/classification/
knn.rs1use crate::error::FdarError;
4use crate::matrix::FdMatrix;
5
6use super::{
7 build_feature_matrix, compute_accuracy, confusion_matrix, remap_labels, ClassifResult,
8};
9
10#[must_use = "expensive computation whose result should not be discarded"]
27pub fn fclassif_knn(
28 data: &FdMatrix,
29 y: &[usize],
30 scalar_covariates: Option<&FdMatrix>,
31 ncomp: usize,
32 k_nn: usize,
33) -> Result<ClassifResult, FdarError> {
34 let n = data.nrows();
35 if n == 0 || y.len() != n {
36 return Err(FdarError::InvalidDimension {
37 parameter: "data/y",
38 expected: "n > 0 and y.len() == n".to_string(),
39 actual: format!("n={}, y.len()={}", n, y.len()),
40 });
41 }
42 if ncomp == 0 {
43 return Err(FdarError::InvalidParameter {
44 parameter: "ncomp",
45 message: "must be > 0".to_string(),
46 });
47 }
48 if k_nn == 0 {
49 return Err(FdarError::InvalidParameter {
50 parameter: "k_nn",
51 message: "must be > 0".to_string(),
52 });
53 }
54
55 let (labels, g) = remap_labels(y);
56 if g < 2 {
57 return Err(FdarError::InvalidParameter {
58 parameter: "y",
59 message: format!("need at least 2 classes, got {g}"),
60 });
61 }
62
63 let (features, _mean, _rotation) = build_feature_matrix(data, scalar_covariates, ncomp)?;
64 let d = features.ncols();
65
66 let predicted = knn_predict_loo(&features, &labels, g, d, k_nn);
67 let accuracy = compute_accuracy(&labels, &predicted);
68 let confusion = confusion_matrix(&labels, &predicted, g);
69
70 Ok(ClassifResult {
71 predicted,
72 probabilities: None,
73 accuracy,
74 confusion,
75 n_classes: g,
76 ncomp: d.min(ncomp),
77 })
78}
79
80pub(crate) fn knn_predict_loo(
82 features: &FdMatrix,
83 labels: &[usize],
84 g: usize,
85 d: usize,
86 k_nn: usize,
87) -> Vec<usize> {
88 let n = features.nrows();
89 let k_nn = k_nn.min(n - 1);
90
91 (0..n)
92 .map(|i| {
93 let xi: Vec<f64> = (0..d).map(|j| features[(i, j)]).collect();
94 let mut dists: Vec<(f64, usize)> = (0..n)
95 .filter(|&j| j != i)
96 .map(|j| {
97 let xj: Vec<f64> = (0..d).map(|jj| features[(j, jj)]).collect();
98 let d_sq: f64 = xi.iter().zip(&xj).map(|(&a, &b)| (a - b).powi(2)).sum();
99 (d_sq, labels[j])
100 })
101 .collect();
102 if k_nn > 0 && k_nn < dists.len() {
103 dists.select_nth_unstable_by(k_nn - 1, |a, b| {
104 a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
105 });
106 }
107
108 let mut votes = vec![0usize; g];
110 for &(_, label) in dists.iter().take(k_nn) {
111 votes[label] += 1;
112 }
113 votes
114 .iter()
115 .enumerate()
116 .max_by_key(|&(_, &v)| v)
117 .map_or(0, |(c, _)| c)
118 })
119 .collect()
120}