use crate::error::FdarError;
use crate::matrix::FdMatrix;
use super::{
build_feature_matrix, compute_accuracy, confusion_matrix, remap_labels, ClassifResult,
};
/// k-nearest-neighbour classification of functional data, scored by leave-one-out.
///
/// The rows of `data` are turned into a feature matrix by `build_feature_matrix`, using
/// `ncomp` components plus any `scalar_covariates`. Each observation is then classified by
/// a majority vote of its `k_nn` nearest *other* observations (`knn_predict_loo`). The
/// returned `accuracy` and `confusion` are therefore leave-one-out estimates.
///
/// `y` holds one class label per row of `data`. The labels are passed through
/// `remap_labels` before classification, and `predicted` is expressed in those remapped
/// labels (presumably `0..n_classes` — confirm against `remap_labels`).
/// `probabilities` is always `None` for this classifier.
///
/// # Errors
///
/// * [`FdarError::InvalidDimension`] if `data` has no rows or `y.len()` differs from the
///   number of rows.
/// * [`FdarError::InvalidParameter`] if `ncomp == 0`, `k_nn == 0`, or `y` contains fewer
///   than two classes (checked in that order).
/// * Any error propagated from `build_feature_matrix`.
#[must_use = "expensive computation whose result should not be discarded"]
pub fn fclassif_knn(
    data: &FdMatrix,
    y: &[usize],
    scalar_covariates: Option<&FdMatrix>,
    ncomp: usize,
    k_nn: usize,
) -> Result<ClassifResult, FdarError> {
    let n_obs = data.nrows();
    let n_labels = y.len();
    if n_obs == 0 || n_labels != n_obs {
        return Err(FdarError::InvalidDimension {
            parameter: "data/y",
            expected: String::from("n > 0 and y.len() == n"),
            actual: format!("n={n_obs}, y.len()={n_labels}"),
        });
    }

    // Both tuning counts must be positive; `ncomp` is reported first when both are zero.
    for (name, value) in [("ncomp", ncomp), ("k_nn", k_nn)] {
        if value == 0 {
            return Err(FdarError::InvalidParameter {
                parameter: name,
                message: String::from("must be > 0"),
            });
        }
    }

    let (class_ids, n_classes) = remap_labels(y);
    if n_classes < 2 {
        return Err(FdarError::InvalidParameter {
            parameter: "y",
            message: format!("need at least 2 classes, got {n_classes}"),
        });
    }

    // Only the feature matrix is needed here; the remaining outputs are discarded.
    let (features, ..) = build_feature_matrix(data, scalar_covariates, ncomp)?;
    let n_features = features.ncols();

    let predicted = knn_predict_loo(&features, &class_ids, n_classes, n_features, k_nn);
    let accuracy = compute_accuracy(&class_ids, &predicted);
    let confusion = confusion_matrix(&class_ids, &predicted, n_classes);

    Ok(ClassifResult {
        predicted,
        probabilities: None,
        accuracy,
        confusion,
        n_classes,
        // Capped by the feature-matrix width; NOTE(review): that width may include scalar
        // covariates, so confirm this is the intended component count.
        ncomp: ncomp.min(n_features),
    })
}
/// Leave-one-out k-nearest-neighbour class predictions for every row of `features`.
///
/// For each row `i`, the `k_nn` nearest *other* rows are found by squared Euclidean
/// distance over the first `d` columns. `d` is presumably `features.ncols()`;
/// `fclassif_knn` passes exactly that. Row `i` is assigned the class with the most votes
/// among those neighbours.
///
/// * `labels[j]` is the class of row `j` and must be `< g`; a shorter `labels` slice or an
///   out-of-range label panics.
/// * `k_nn` is clamped to `n - 1`, the number of available neighbours.
/// * Vote ties are resolved in favour of the highest class index.
/// * With a single row there are no neighbours, so every vote count is zero and the
///   prediction is `g - 1` (or `0` when `g == 0`).
/// * NaN distances (from NaN features) rank behind every numeric distance. They are only
///   used when fewer than `k_nn` comparable neighbours exist.
///
/// Returns an empty vector when `features` has no rows.
pub(crate) fn knn_predict_loo(
    features: &FdMatrix,
    labels: &[usize],
    g: usize,
    d: usize,
    k_nn: usize,
) -> Vec<usize> {
    /// Total preorder on (distance, label) pairs: NaN sorts after every number.
    /// For non-NaN values this is the same comparison the classifier always used.
    fn by_distance(a: &(f64, usize), b: &(f64, usize)) -> std::cmp::Ordering {
        a.0.is_nan()
            .cmp(&b.0.is_nan())
            .then_with(|| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal))
    }

    let n = features.nrows();
    // `n - 1` below would underflow (a panic in debug builds) on an empty matrix.
    if n == 0 {
        return Vec::new();
    }
    let k_nn = k_nn.min(n - 1);

    // Copy each row out once. Re-collecting both rows for every (i, j) pair would cost
    // O(n²) temporary allocations.
    let rows: Vec<Vec<f64>> = (0..n)
        .map(|i| (0..d).map(|j| features[(i, j)]).collect())
        .collect();

    let mut predicted = Vec::with_capacity(n);
    let mut dists: Vec<(f64, usize)> = Vec::with_capacity(n - 1);
    let mut votes = vec![0usize; g];
    for (i, xi) in rows.iter().enumerate() {
        dists.clear();
        dists.extend((0..n).filter(|&j| j != i).map(|j| {
            let d_sq: f64 = xi
                .iter()
                .zip(&rows[j])
                .map(|(&a, &b)| (a - b).powi(2))
                .sum();
            (d_sq, labels[j])
        }));

        // Move the k_nn smallest distances to the front. This is skipped when every
        // neighbour is used anyway.
        if k_nn > 0 && k_nn < dists.len() {
            dists.select_nth_unstable_by(k_nn - 1, by_distance);
        }

        votes.fill(0);
        for &(_, label) in dists.iter().take(k_nn) {
            votes[label] += 1;
        }
        // `max_by_key` yields the *last* maximal element, so ties go to the highest class.
        let winner = votes
            .iter()
            .enumerate()
            .max_by_key(|&(_, &v)| v)
            .map_or(0, |(c, _)| c);
        predicted.push(winner);
    }
    predicted
}