Skip to main content

fdars_core/classification/
knn.rs

1//! k-NN classifier internals.
2
3use crate::error::FdarError;
4use crate::matrix::FdMatrix;
5
6use super::{
7    build_feature_matrix, compute_accuracy, confusion_matrix, remap_labels, ClassifResult,
8};
9
10/// FPC + k-NN classification.
11///
12/// # Arguments
13/// * `data` — Functional data (n × m)
14/// * `y` — Class labels
15/// * `scalar_covariates` — Optional scalar covariates
16/// * `ncomp` — Number of FPC components
17/// * `k_nn` — Number of nearest neighbors
18///
19/// # Errors
20///
21/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows or `y.len() != n`.
22/// Returns [`FdarError::InvalidParameter`] if `ncomp` is zero.
23/// Returns [`FdarError::InvalidParameter`] if `k_nn` is zero.
24/// Returns [`FdarError::InvalidParameter`] if `y` contains fewer than 2 distinct classes.
25/// Returns [`FdarError::ComputationFailed`] if the SVD decomposition in FPCA fails.
26///
27/// # Examples
28///
29/// ```
30/// use fdars_core::matrix::FdMatrix;
31/// use fdars_core::classification::knn::fclassif_knn;
32///
33/// let data = FdMatrix::from_column_major(
34///     (0..100).map(|i| (i as f64 * 0.1).sin()).collect(),
35///     10, 10,
36/// ).unwrap();
37/// let y = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
38/// let result = fclassif_knn(&data, &y, None, 3, 3).unwrap();
39/// assert_eq!(result.predicted.len(), 10);
40/// assert_eq!(result.n_classes, 2);
41/// ```
42#[must_use = "expensive computation whose result should not be discarded"]
43pub fn fclassif_knn(
44    data: &FdMatrix,
45    y: &[usize],
46    scalar_covariates: Option<&FdMatrix>,
47    ncomp: usize,
48    k_nn: usize,
49) -> Result<ClassifResult, FdarError> {
50    let n = data.nrows();
51    if n == 0 || y.len() != n {
52        return Err(FdarError::InvalidDimension {
53            parameter: "data/y",
54            expected: "n > 0 and y.len() == n".to_string(),
55            actual: format!("n={}, y.len()={}", n, y.len()),
56        });
57    }
58    if ncomp == 0 {
59        return Err(FdarError::InvalidParameter {
60            parameter: "ncomp",
61            message: "must be > 0".to_string(),
62        });
63    }
64    if k_nn == 0 {
65        return Err(FdarError::InvalidParameter {
66            parameter: "k_nn",
67            message: "must be > 0".to_string(),
68        });
69    }
70
71    let (labels, g) = remap_labels(y);
72    if g < 2 {
73        return Err(FdarError::InvalidParameter {
74            parameter: "y",
75            message: format!("need at least 2 classes, got {g}"),
76        });
77    }
78
79    let (features, _mean, _rotation, _weights) =
80        build_feature_matrix(data, scalar_covariates, ncomp)?;
81    let d = features.ncols();
82
83    let predicted = knn_predict_loo(&features, &labels, g, d, k_nn);
84    let accuracy = compute_accuracy(&labels, &predicted);
85    let confusion = confusion_matrix(&labels, &predicted, g);
86
87    Ok(ClassifResult {
88        predicted,
89        probabilities: None,
90        accuracy,
91        confusion,
92        n_classes: g,
93        ncomp: d.min(ncomp),
94    })
95}
96
97/// k-NN classification from a precomputed distance matrix.
98///
99/// Works with **any** distance matrix (elastic, DTW, Lp, or custom).
100/// Labels are 0-indexed class indices.
101///
102/// # Arguments
103/// * `dist_mat` — Symmetric n × n distance matrix
104/// * `y` — Class labels (length n, 0-indexed)
105/// * `k_nn` — Number of nearest neighbors
106///
107/// # Errors
108/// Returns errors if `dist_mat` is not square, `y.len() != n`, `k_nn == 0`, or fewer than 2 classes.
109#[must_use = "expensive computation whose result should not be discarded"]
110pub fn knn_classify_from_distances(
111    dist_mat: &FdMatrix,
112    y: &[usize],
113    k_nn: usize,
114) -> Result<ClassifResult, FdarError> {
115    let n = dist_mat.nrows();
116    if dist_mat.ncols() != n {
117        return Err(FdarError::InvalidDimension {
118            parameter: "dist_mat",
119            expected: format!("{n} x {n} (square)"),
120            actual: format!("{} x {}", n, dist_mat.ncols()),
121        });
122    }
123    if y.len() != n {
124        return Err(FdarError::InvalidDimension {
125            parameter: "y",
126            expected: format!("{n}"),
127            actual: format!("{}", y.len()),
128        });
129    }
130    if k_nn == 0 {
131        return Err(FdarError::InvalidParameter {
132            parameter: "k_nn",
133            message: "must be > 0".to_string(),
134        });
135    }
136
137    let (labels, g) = remap_labels(y);
138    if g < 2 {
139        return Err(FdarError::InvalidParameter {
140            parameter: "y",
141            message: format!("need at least 2 classes, got {g}"),
142        });
143    }
144
145    let k_nn = k_nn.min(n - 1);
146    let predicted: Vec<usize> = (0..n)
147        .map(|i| {
148            let mut dists: Vec<(f64, usize)> = (0..n)
149                .filter(|&j| j != i)
150                .map(|j| (dist_mat[(i, j)], labels[j]))
151                .collect();
152            if k_nn > 0 && k_nn < dists.len() {
153                dists.select_nth_unstable_by(k_nn - 1, |a, b| {
154                    a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
155                });
156            }
157            let mut votes = vec![0usize; g];
158            for &(_, label) in dists.iter().take(k_nn) {
159                votes[label] += 1;
160            }
161            votes
162                .iter()
163                .enumerate()
164                .max_by_key(|&(_, &v)| v)
165                .map_or(0, |(c, _)| c)
166        })
167        .collect();
168
169    let accuracy = compute_accuracy(&labels, &predicted);
170    let confusion = confusion_matrix(&labels, &predicted, g);
171
172    Ok(ClassifResult {
173        predicted,
174        probabilities: None,
175        accuracy,
176        confusion,
177        n_classes: g,
178        ncomp: 0,
179    })
180}
181
182/// Leave-one-out k-NN prediction.
183pub(crate) fn knn_predict_loo(
184    features: &FdMatrix,
185    labels: &[usize],
186    g: usize,
187    d: usize,
188    k_nn: usize,
189) -> Vec<usize> {
190    let n = features.nrows();
191    let k_nn = k_nn.min(n - 1);
192
193    (0..n)
194        .map(|i| {
195            let xi: Vec<f64> = (0..d).map(|j| features[(i, j)]).collect();
196            let mut dists: Vec<(f64, usize)> = (0..n)
197                .filter(|&j| j != i)
198                .map(|j| {
199                    let xj: Vec<f64> = (0..d).map(|jj| features[(j, jj)]).collect();
200                    let d_sq: f64 = xi.iter().zip(&xj).map(|(&a, &b)| (a - b).powi(2)).sum();
201                    (d_sq, labels[j])
202                })
203                .collect();
204            if k_nn > 0 && k_nn < dists.len() {
205                dists.select_nth_unstable_by(k_nn - 1, |a, b| {
206                    a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
207                });
208            }
209
210            // Majority vote among k nearest
211            let mut votes = vec![0usize; g];
212            for &(_, label) in dists.iter().take(k_nn) {
213                votes[label] += 1;
214            }
215            votes
216                .iter()
217                .enumerate()
218                .max_by_key(|&(_, &v)| v)
219                .map_or(0, |(c, _)| c)
220        })
221        .collect()
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use crate::matrix::FdMatrix;

    /// Two tight, well-separated clusters must be classified perfectly.
    #[test]
    fn knn_from_distances_smoke() {
        // Indices 0..3 belong to class 0 and indices 3..6 to class 1.
        let y = vec![0, 0, 0, 1, 1, 1];
        let n = y.len();

        // Same-class pairs are close (0.1), cross-class pairs are far (5.0);
        // the diagonal keeps its initial zero.
        let mut dist = FdMatrix::zeros(n, n);
        for i in 0..n {
            for j in 0..n {
                if i == j {
                    continue;
                }
                dist[(i, j)] = if y[i] == y[j] { 0.1 } else { 5.0 };
            }
        }

        let result = knn_classify_from_distances(&dist, &y, 3).unwrap();
        assert_eq!(result.predicted, vec![0, 0, 0, 1, 1, 1]);
        assert!((result.accuracy - 1.0).abs() < 1e-10);
    }
}