fdars_core/classification/
knn.rs1use crate::error::FdarError;
4use crate::matrix::FdMatrix;
5
6use super::{
7 build_feature_matrix, compute_accuracy, confusion_matrix, remap_labels, ClassifResult,
8};
9
10#[must_use = "expensive computation whose result should not be discarded"]
43pub fn fclassif_knn(
44 data: &FdMatrix,
45 y: &[usize],
46 scalar_covariates: Option<&FdMatrix>,
47 ncomp: usize,
48 k_nn: usize,
49) -> Result<ClassifResult, FdarError> {
50 let n = data.nrows();
51 if n == 0 || y.len() != n {
52 return Err(FdarError::InvalidDimension {
53 parameter: "data/y",
54 expected: "n > 0 and y.len() == n".to_string(),
55 actual: format!("n={}, y.len()={}", n, y.len()),
56 });
57 }
58 if ncomp == 0 {
59 return Err(FdarError::InvalidParameter {
60 parameter: "ncomp",
61 message: "must be > 0".to_string(),
62 });
63 }
64 if k_nn == 0 {
65 return Err(FdarError::InvalidParameter {
66 parameter: "k_nn",
67 message: "must be > 0".to_string(),
68 });
69 }
70
71 let (labels, g) = remap_labels(y);
72 if g < 2 {
73 return Err(FdarError::InvalidParameter {
74 parameter: "y",
75 message: format!("need at least 2 classes, got {g}"),
76 });
77 }
78
79 let (features, _mean, _rotation, _weights) =
80 build_feature_matrix(data, scalar_covariates, ncomp)?;
81 let d = features.ncols();
82
83 let predicted = knn_predict_loo(&features, &labels, g, d, k_nn);
84 let accuracy = compute_accuracy(&labels, &predicted);
85 let confusion = confusion_matrix(&labels, &predicted, g);
86
87 Ok(ClassifResult {
88 predicted,
89 probabilities: None,
90 accuracy,
91 confusion,
92 n_classes: g,
93 ncomp: d.min(ncomp),
94 })
95}
96
97#[must_use = "expensive computation whose result should not be discarded"]
110pub fn knn_classify_from_distances(
111 dist_mat: &FdMatrix,
112 y: &[usize],
113 k_nn: usize,
114) -> Result<ClassifResult, FdarError> {
115 let n = dist_mat.nrows();
116 if dist_mat.ncols() != n {
117 return Err(FdarError::InvalidDimension {
118 parameter: "dist_mat",
119 expected: format!("{n} x {n} (square)"),
120 actual: format!("{} x {}", n, dist_mat.ncols()),
121 });
122 }
123 if y.len() != n {
124 return Err(FdarError::InvalidDimension {
125 parameter: "y",
126 expected: format!("{n}"),
127 actual: format!("{}", y.len()),
128 });
129 }
130 if k_nn == 0 {
131 return Err(FdarError::InvalidParameter {
132 parameter: "k_nn",
133 message: "must be > 0".to_string(),
134 });
135 }
136
137 let (labels, g) = remap_labels(y);
138 if g < 2 {
139 return Err(FdarError::InvalidParameter {
140 parameter: "y",
141 message: format!("need at least 2 classes, got {g}"),
142 });
143 }
144
145 let k_nn = k_nn.min(n - 1);
146 let predicted: Vec<usize> = (0..n)
147 .map(|i| {
148 let mut dists: Vec<(f64, usize)> = (0..n)
149 .filter(|&j| j != i)
150 .map(|j| (dist_mat[(i, j)], labels[j]))
151 .collect();
152 if k_nn > 0 && k_nn < dists.len() {
153 dists.select_nth_unstable_by(k_nn - 1, |a, b| {
154 a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
155 });
156 }
157 let mut votes = vec![0usize; g];
158 for &(_, label) in dists.iter().take(k_nn) {
159 votes[label] += 1;
160 }
161 votes
162 .iter()
163 .enumerate()
164 .max_by_key(|&(_, &v)| v)
165 .map_or(0, |(c, _)| c)
166 })
167 .collect();
168
169 let accuracy = compute_accuracy(&labels, &predicted);
170 let confusion = confusion_matrix(&labels, &predicted, g);
171
172 Ok(ClassifResult {
173 predicted,
174 probabilities: None,
175 accuracy,
176 confusion,
177 n_classes: g,
178 ncomp: 0,
179 })
180}
181
182pub(crate) fn knn_predict_loo(
184 features: &FdMatrix,
185 labels: &[usize],
186 g: usize,
187 d: usize,
188 k_nn: usize,
189) -> Vec<usize> {
190 let n = features.nrows();
191 let k_nn = k_nn.min(n - 1);
192
193 (0..n)
194 .map(|i| {
195 let xi: Vec<f64> = (0..d).map(|j| features[(i, j)]).collect();
196 let mut dists: Vec<(f64, usize)> = (0..n)
197 .filter(|&j| j != i)
198 .map(|j| {
199 let xj: Vec<f64> = (0..d).map(|jj| features[(j, jj)]).collect();
200 let d_sq: f64 = xi.iter().zip(&xj).map(|(&a, &b)| (a - b).powi(2)).sum();
201 (d_sq, labels[j])
202 })
203 .collect();
204 if k_nn > 0 && k_nn < dists.len() {
205 dists.select_nth_unstable_by(k_nn - 1, |a, b| {
206 a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
207 });
208 }
209
210 let mut votes = vec![0usize; g];
212 for &(_, label) in dists.iter().take(k_nn) {
213 votes[label] += 1;
214 }
215 votes
216 .iter()
217 .enumerate()
218 .max_by_key(|&(_, &v)| v)
219 .map_or(0, |(c, _)| c)
220 })
221 .collect()
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::FdarError;
    use crate::matrix::FdMatrix;

    /// 6 x 6 distance matrix with two tight clusters, {0, 1, 2} and {3, 4, 5}.
    /// Within-cluster distances are 0.1, between-cluster distances are 5.0,
    /// and the diagonal is 0.0.
    fn two_cluster_distances() -> FdMatrix {
        let mut dist = FdMatrix::zeros(6, 6);
        for i in 0..6 {
            for j in 0..6 {
                if i != j {
                    let same_cluster = (i < 3) == (j < 3);
                    dist[(i, j)] = if same_cluster { 0.1 } else { 5.0 };
                }
            }
        }
        dist
    }

    #[test]
    fn knn_from_distances_smoke() {
        let dist = two_cluster_distances();
        let y = vec![0, 0, 0, 1, 1, 1];
        let result = knn_classify_from_distances(&dist, &y, 3).unwrap();
        assert_eq!(result.predicted, vec![0, 0, 0, 1, 1, 1]);
        assert!((result.accuracy - 1.0).abs() < 1e-10);
    }

    #[test]
    fn knn_from_distances_clamps_k_to_n_minus_one() {
        // k_nn = 100 is clamped to 5. Each sample then sees 2 same-class and
        // 3 other-class neighbours, so the majority is always the other class.
        let dist = two_cluster_distances();
        let y = vec![0, 0, 0, 1, 1, 1];
        let result = knn_classify_from_distances(&dist, &y, 100).unwrap();
        assert_eq!(result.predicted, vec![1, 1, 1, 0, 0, 0]);
        assert!(result.accuracy.abs() < 1e-10);
    }

    #[test]
    fn knn_from_distances_rejects_non_square_matrix() {
        let dist = FdMatrix::zeros(3, 2);
        let result = knn_classify_from_distances(&dist, &[0, 1, 1], 1);
        assert!(matches!(result, Err(FdarError::InvalidDimension { .. })));
    }

    #[test]
    fn knn_from_distances_rejects_label_length_mismatch() {
        let dist = two_cluster_distances();
        let result = knn_classify_from_distances(&dist, &[0, 0, 1, 1], 1);
        assert!(matches!(result, Err(FdarError::InvalidDimension { .. })));
    }

    #[test]
    fn knn_from_distances_rejects_zero_k() {
        let dist = two_cluster_distances();
        let y = vec![0, 0, 0, 1, 1, 1];
        let result = knn_classify_from_distances(&dist, &y, 0);
        assert!(matches!(result, Err(FdarError::InvalidParameter { .. })));
    }

    #[test]
    fn knn_from_distances_rejects_single_class() {
        let dist = two_cluster_distances();
        let y = vec![2; 6];
        let result = knn_classify_from_distances(&dist, &y, 1);
        assert!(matches!(result, Err(FdarError::InvalidParameter { .. })));
    }
}