fdars_core/classification/
knn.rs1use crate::error::FdarError;
4use crate::matrix::FdMatrix;
5
6use super::{
7 build_feature_matrix, compute_accuracy, confusion_matrix, remap_labels, ClassifResult,
8};
9
10#[must_use = "expensive computation whose result should not be discarded"]
43pub fn fclassif_knn(
44 data: &FdMatrix,
45 y: &[usize],
46 scalar_covariates: Option<&FdMatrix>,
47 ncomp: usize,
48 k_nn: usize,
49) -> Result<ClassifResult, FdarError> {
50 let n = data.nrows();
51 if n == 0 || y.len() != n {
52 return Err(FdarError::InvalidDimension {
53 parameter: "data/y",
54 expected: "n > 0 and y.len() == n".to_string(),
55 actual: format!("n={}, y.len()={}", n, y.len()),
56 });
57 }
58 if ncomp == 0 {
59 return Err(FdarError::InvalidParameter {
60 parameter: "ncomp",
61 message: "must be > 0".to_string(),
62 });
63 }
64 if k_nn == 0 {
65 return Err(FdarError::InvalidParameter {
66 parameter: "k_nn",
67 message: "must be > 0".to_string(),
68 });
69 }
70
71 let (labels, g) = remap_labels(y);
72 if g < 2 {
73 return Err(FdarError::InvalidParameter {
74 parameter: "y",
75 message: format!("need at least 2 classes, got {g}"),
76 });
77 }
78
79 let (features, _mean, _rotation) = build_feature_matrix(data, scalar_covariates, ncomp)?;
80 let d = features.ncols();
81
82 let predicted = knn_predict_loo(&features, &labels, g, d, k_nn);
83 let accuracy = compute_accuracy(&labels, &predicted);
84 let confusion = confusion_matrix(&labels, &predicted, g);
85
86 Ok(ClassifResult {
87 predicted,
88 probabilities: None,
89 accuracy,
90 confusion,
91 n_classes: g,
92 ncomp: d.min(ncomp),
93 })
94}
95
96pub(crate) fn knn_predict_loo(
98 features: &FdMatrix,
99 labels: &[usize],
100 g: usize,
101 d: usize,
102 k_nn: usize,
103) -> Vec<usize> {
104 let n = features.nrows();
105 let k_nn = k_nn.min(n - 1);
106
107 (0..n)
108 .map(|i| {
109 let xi: Vec<f64> = (0..d).map(|j| features[(i, j)]).collect();
110 let mut dists: Vec<(f64, usize)> = (0..n)
111 .filter(|&j| j != i)
112 .map(|j| {
113 let xj: Vec<f64> = (0..d).map(|jj| features[(j, jj)]).collect();
114 let d_sq: f64 = xi.iter().zip(&xj).map(|(&a, &b)| (a - b).powi(2)).sum();
115 (d_sq, labels[j])
116 })
117 .collect();
118 if k_nn > 0 && k_nn < dists.len() {
119 dists.select_nth_unstable_by(k_nn - 1, |a, b| {
120 a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
121 });
122 }
123
124 let mut votes = vec![0usize; g];
126 for &(_, label) in dists.iter().take(k_nn) {
127 votes[label] += 1;
128 }
129 votes
130 .iter()
131 .enumerate()
132 .max_by_key(|&(_, &v)| v)
133 .map_or(0, |(c, _)| c)
134 })
135 .collect()
136}