use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
use crate::api::{Predictor, SupervisedEstimator};
use crate::error::{Failed, FailedError};
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::metrics::distance::euclidian::Euclidian;
use crate::metrics::distance::{Distance, Distances};
use crate::neighbors::KNNWeightFunction;
use crate::numbers::basenum::Number;
/// Hyper-parameters for [`KNNClassifier`].
///
/// `T` is the element type of the feature vectors; `D` is the distance
/// metric applied to rows represented as `Vec<T>`.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct KNNClassifierParameters<T: Number, D: Distance<Vec<T>>> {
    /// Distance metric used to compare samples.
    #[cfg_attr(feature = "serde", serde(default))]
    pub distance: D,
    /// Nearest-neighbor search strategy (e.g. linear search or cover tree).
    #[cfg_attr(feature = "serde", serde(default))]
    pub algorithm: KNNAlgorithmName,
    /// How neighbor votes are weighted (uniform or by inverse distance).
    #[cfg_attr(feature = "serde", serde(default))]
    pub weight: KNNWeightFunction,
    /// Number of neighbors consulted per prediction; `fit` rejects `k <= 1`.
    #[cfg_attr(feature = "serde", serde(default))]
    pub k: usize,
    // Marker tying the unused type parameter `T` to the struct; stores no data.
    #[cfg_attr(feature = "serde", serde(default))]
    t: PhantomData<T>,
}
/// K-nearest-neighbor classifier.
///
/// All model fields are `Option`s: they are `None` for an unfitted
/// estimator (as produced by `SupervisedEstimator::new`) and populated
/// by [`KNNClassifier::fit`].
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct KNNClassifier<
    TX: Number,
    TY: Number + Ord,
    X: Array2<TX>,
    Y: Array1<TY>,
    D: Distance<Vec<TX>>,
> {
    // Distinct class labels, in the order returned by `y.unique()` at fit time.
    classes: Option<Vec<TY>>,
    // Training labels encoded as indices into `classes`.
    y: Option<Vec<usize>>,
    // Fitted nearest-neighbor search structure over the training rows.
    knn_algorithm: Option<KNNAlgorithm<TX, D>>,
    // Vote-weighting function applied to neighbor distances.
    weight: Option<KNNWeightFunction>,
    // Number of neighbors consulted per prediction.
    k: Option<usize>,
    // Markers for the otherwise-unused type parameters; store no data.
    _phantom_tx: PhantomData<TX>,
    _phantom_x: PhantomData<X>,
    _phantom_y: PhantomData<Y>,
}
impl<TX: Number, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
    KNNClassifier<TX, TY, X, Y, D>
{
    /// Distinct class labels captured during `fit`.
    ///
    /// Panics if the model has not been fitted.
    fn classes(&self) -> &Vec<TY> {
        self.classes.as_ref().unwrap()
    }
    /// Training labels encoded as indices into `classes()`.
    ///
    /// Panics if the model has not been fitted.
    fn y(&self) -> &Vec<usize> {
        self.y.as_ref().unwrap()
    }
    /// Fitted nearest-neighbor search structure.
    ///
    /// Panics if the model has not been fitted.
    fn knn_algorithm(&self) -> &KNNAlgorithm<TX, D> {
        self.knn_algorithm.as_ref().unwrap()
    }
    /// Neighbor vote-weighting function.
    ///
    /// Panics if the model has not been fitted.
    fn weight(&self) -> &KNNWeightFunction {
        self.weight.as_ref().unwrap()
    }
    /// Number of neighbors consulted per prediction.
    ///
    /// Panics if the model has not been fitted.
    fn k(&self) -> usize {
        self.k.unwrap()
    }
}
impl<T: Number, D: Distance<Vec<T>>> KNNClassifierParameters<T, D> {
    /// Sets the number of neighbors consulted per prediction.
    pub fn with_k(self, k: usize) -> Self {
        Self { k, ..self }
    }
    /// Replaces the distance metric, possibly changing the metric type.
    ///
    /// Consumes `self` and rebuilds the parameters around the new distance.
    pub fn with_distance<DD: Distance<Vec<T>>>(
        self,
        distance: DD,
    ) -> KNNClassifierParameters<T, DD> {
        let KNNClassifierParameters {
            algorithm,
            weight,
            k,
            ..
        } = self;
        KNNClassifierParameters {
            distance,
            algorithm,
            weight,
            k,
            t: PhantomData,
        }
    }
    /// Sets the nearest-neighbor search strategy.
    pub fn with_algorithm(self, algorithm: KNNAlgorithmName) -> Self {
        Self { algorithm, ..self }
    }
    /// Sets the neighbor vote-weighting function.
    pub fn with_weight(self, weight: KNNWeightFunction) -> Self {
        Self { weight, ..self }
    }
}
impl<T: Number> Default for KNNClassifierParameters<T, Euclidian<T>> {
    /// Default configuration: Euclidean distance, default search algorithm
    /// and weighting, and `k = 3`.
    fn default() -> Self {
        Self {
            k: 3,
            distance: Distances::euclidian(),
            algorithm: KNNAlgorithmName::default(),
            weight: KNNWeightFunction::default(),
            t: PhantomData,
        }
    }
}
impl<TX: Number, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>> PartialEq
    for KNNClassifier<TX, TY, X, Y, D>
{
    /// Two classifiers are equal when they agree on `k`, the class set, and
    /// the encoded training labels. The fitted search structure and the
    /// weight function are intentionally not compared (the previous
    /// element-wise comparison ignored them as well).
    ///
    /// `Option`/`Vec` structural equality reproduces the old length-then-
    /// element checks for fitted models, but — unlike the previous
    /// implementation, which unwrapped each field — it no longer panics
    /// when either model is unfitted: `None` fields compare structurally.
    fn eq(&self, other: &Self) -> bool {
        self.k == other.k && self.classes == other.classes && self.y == other.y
    }
}
impl<TX: Number, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
    SupervisedEstimator<X, Y, KNNClassifierParameters<TX, D>> for KNNClassifier<TX, TY, X, Y, D>
{
    /// Creates an unfitted classifier: every model field starts out empty.
    fn new() -> Self {
        Self {
            classes: None,
            y: None,
            knn_algorithm: None,
            weight: None,
            k: None,
            _phantom_tx: PhantomData,
            _phantom_x: PhantomData,
            _phantom_y: PhantomData,
        }
    }

    /// Fits the classifier; delegates to the inherent [`KNNClassifier::fit`].
    fn fit(x: &X, y: &Y, parameters: KNNClassifierParameters<TX, D>) -> Result<Self, Failed> {
        KNNClassifier::fit(x, y, parameters)
    }
}
impl<TX: Number, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
    Predictor<X, Y> for KNNClassifier<TX, TY, X, Y, D>
{
    /// Delegates to the inherent [`KNNClassifier::predict`].
    fn predict(&self, x: &X) -> Result<Y, Failed> {
        self.predict(x)
    }
}
impl<TX: Number, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
    KNNClassifier<TX, TY, X, Y, D>
{
    /// Fits a k-nearest-neighbor classifier to the training data.
    ///
    /// * `x` - training samples, one row per sample.
    /// * `y` - class labels, one per row of `x`.
    /// * `parameters` - k, distance metric, search algorithm, and weighting.
    ///
    /// # Errors
    /// Returns `Failed` when `|x| != |y|`, when `k <= 1`, or when the
    /// underlying search structure fails to build.
    pub fn fit(
        x: &X,
        y: &Y,
        parameters: KNNClassifierParameters<TX, D>,
    ) -> Result<KNNClassifier<TX, TY, X, Y, D>, Failed> {
        let y_n = y.shape();
        let (x_n, _) = x.shape();
        // Validate the cheap invariants BEFORE materializing the data:
        // the previous version copied every row and encoded every label
        // first, wasting that work on invalid input.
        if x_n != y_n {
            return Err(Failed::fit(&format!(
                "Size of x should equal size of y; |x|=[{x_n}], |y|=[{y_n}]"
            )));
        }
        if parameters.k <= 1 {
            return Err(Failed::fit(&format!(
                "k should be > 1, k=[{}]",
                parameters.k
            )));
        }
        // Materialize each row as an owned Vec for the KNN search structure.
        let data: Vec<Vec<TX>> = x
            .row_iter()
            .map(|row| row.iterator(0).copied().collect())
            .collect();
        // Encode labels as indices into the unique class list.
        let classes = y.unique();
        let yi: Vec<usize> = (0..y_n)
            .map(|i| {
                let yc = *y.get(i);
                classes
                    .iter()
                    .position(|c| yc == *c)
                    // Invariant: every label in `y` appears in `y.unique()`.
                    .expect("label must be present in y.unique()")
            })
            .collect();
        Ok(KNNClassifier {
            classes: Some(classes),
            y: Some(yi),
            k: Some(parameters.k),
            knn_algorithm: Some(parameters.algorithm.fit(data, parameters.distance)?),
            weight: Some(parameters.weight),
            _phantom_tx: PhantomData,
            _phantom_x: PhantomData,
            _phantom_y: PhantomData,
        })
    }
    /// Predicts the class label for every row of `x`.
    ///
    /// # Errors
    /// Propagates failures from the neighbor search or from probability
    /// normalization (see [`predict_proba`](Self::predict_proba)).
    pub fn predict(&self, x: &X) -> Result<Y, Failed> {
        let (n, m) = x.shape();
        let mut result = Y::zeros(n);
        // One reusable row buffer avoids a fresh allocation per sample.
        let mut row_vec = vec![TX::zero(); m];
        for (i, row) in x.row_iter().enumerate() {
            row.iterator(0)
                .zip(row_vec.iter_mut())
                .for_each(|(&s, v)| *v = s);
            result.set(i, self.classes()[self.predict_for_row(&row_vec)?]);
        }
        Ok(result)
    }
    /// Computes normalized per-class vote weights for a single sample.
    ///
    /// # Errors
    /// Fails when the k neighbor weights sum to zero, since the votes
    /// cannot then be normalized into probabilities.
    fn predict_proba_for_row(&self, row: &Vec<TX>) -> Result<Vec<f64>, Failed> {
        let search_result = self.knn_algorithm().find(row, self.k())?;
        // `v.1` is the neighbor's distance; the weight function maps the
        // distances to vote weights (uniform or inverse-distance).
        let weights = self
            .weight()
            .calc_weights(search_result.iter().map(|v| v.1).collect());
        let w_sum: f64 = weights.iter().copied().sum();
        if w_sum == 0.0 {
            return Err(Failed::because(
                FailedError::PredictFailed,
                "Sum of weights is zero; cannot compute probabilities",
            ));
        }
        // Accumulate each neighbor's weight into its class bucket
        // (`r.0` is the neighbor's index into the training set) ...
        let mut class_votes = vec![0.0; self.classes().len()];
        for (r, w) in search_result.iter().zip(weights.iter()) {
            class_votes[self.y()[r.0]] += *w;
        }
        // ... then normalize so the votes sum to 1.
        let inv_sum = 1.0 / w_sum;
        for v in &mut class_votes {
            *v *= inv_sum;
        }
        Ok(class_votes)
    }
    /// Returns the index (into `classes()`) of the most probable class.
    fn predict_for_row(&self, row: &Vec<TX>) -> Result<usize, Failed> {
        let proba = self.predict_proba_for_row(row)?;
        // Manual argmax kept on purpose: `>` is false for NaN, so a NaN
        // entry can never win, unlike `max_by(partial_cmp(...).unwrap())`,
        // which would panic.
        let mut max_idx = 0;
        let mut max_val = proba[0];
        for (i, &val) in proba.iter().enumerate().skip(1) {
            if val > max_val {
                max_val = val;
                max_idx = i;
            }
        }
        Ok(max_idx)
    }
    /// Computes class probabilities for every row of `x`.
    ///
    /// Returns one vector per sample; each vector has one entry per class
    /// (ordered as `classes()`) and sums to 1.
    ///
    /// # Errors
    /// Fails when, for any sample, the neighbor weights sum to zero.
    pub fn predict_proba(&self, x: &X) -> Result<Vec<Vec<f64>>, Failed> {
        let (n, m) = x.shape();
        let mut result = Vec::with_capacity(n);
        // Same reusable row buffer pattern as in `predict`.
        let mut row_vec = vec![TX::zero(); m];
        for row in x.row_iter() {
            row.iterator(0)
                .zip(row_vec.iter_mut())
                .for_each(|(&s, v)| *v = s);
            result.push(self.predict_proba_for_row(&row_vec)?);
        }
        Ok(result)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::linalg::basic::matrix::DenseMatrix;

    /// Asserts element-wise equality of two f64 slices within `tol`.
    fn assert_vec_f64_eq(a: &[f64], b: &[f64], tol: f64, msg: &str) {
        assert_eq!(a.len(), b.len(), "{}: length mismatch", msg);
        for (i, (va, vb)) in a.iter().zip(b.iter()).enumerate() {
            assert!(
                (va - vb).abs() < tol,
                "{}: index {} differs: {} vs {}",
                msg,
                i,
                va,
                vb
            );
        }
    }

    // Round-trip sanity: with default parameters the training points are
    // classified as their own labels.
    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn knn_fit_predict() {
        let x =
            DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]])
                .unwrap();
        let y = vec![2, 2, 2, 3, 3];
        let knn = KNNClassifier::fit(&x, &y, Default::default()).unwrap();
        let y_hat = knn.predict(&x).unwrap();
        assert_eq!(5, y_hat.len());
        assert_eq!(y, y_hat);
    }

    // With k = n and distance weighting, the nearest point dominates.
    #[test]
    fn knn_fit_predict_weighted() {
        let x = DenseMatrix::from_2d_array(&[&[1.], &[2.], &[3.], &[4.], &[5.]]).unwrap();
        let y = vec![2, 2, 2, 3, 3];
        let knn = KNNClassifier::fit(
            &x,
            &y,
            KNNClassifierParameters::default()
                .with_k(5)
                .with_algorithm(KNNAlgorithmName::LinearSearch)
                .with_weight(KNNWeightFunction::Distance),
        )
        .unwrap();
        let y_hat = knn
            .predict(&DenseMatrix::from_2d_array(&[&[4.1]]).unwrap())
            .unwrap();
        assert_eq!(vec![3], y_hat);
    }

    // Every probability vector must be a valid distribution.
    #[test]
    fn knn_predict_proba_valid() {
        let x = DenseMatrix::from_2d_array(&[
            &[1., 2.],
            &[2., 3.],
            &[3., 4.],
            &[8., 9.],
            &[9., 10.],
            &[10., 11.],
        ])
        .unwrap();
        let y = vec![0, 0, 0, 1, 1, 1];
        let knn = KNNClassifier::fit(&x, &y, Default::default()).unwrap();
        let proba = knn.predict_proba(&x).unwrap();
        for (i, p) in proba.iter().enumerate() {
            assert!(
                (p.iter().sum::<f64>() - 1.0).abs() < 1e-10,
                "Sample {}: probabilities don't sum to 1",
                i
            );
            for &prob in p {
                assert!(
                    (0.0..=1.0).contains(&prob),
                    "Sample {}: probability {} out of range",
                    i,
                    prob
                );
            }
        }
    }

    // predict() must return the argmax class of predict_proba().
    #[test]
    fn knn_predict_consistent_with_proba() {
        let x = DenseMatrix::from_2d_array(&[
            &[1., 1.],
            &[2., 2.],
            &[3., 3.],
            &[8., 8.],
            &[9., 9.],
            &[10., 10.],
        ])
        .unwrap();
        let y = vec![10, 10, 10, 20, 20, 20];
        let knn = KNNClassifier::fit(&x, &y, KNNClassifierParameters::default().with_k(3)).unwrap();
        let test = DenseMatrix::from_2d_array(&[&[2.5, 2.5]]).unwrap();
        let pred_class = knn.predict(&test).unwrap();
        let pred_proba = knn.predict_proba(&test).unwrap();
        let max_proba_idx = pred_proba[0]
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
            .map(|(i, _)| i)
            .unwrap();
        assert_eq!(
            knn.classes()[max_proba_idx],
            pred_class[0],
            "predict() and predict_proba() disagree on class"
        );
    }

    // The two search backends must find the same neighbors and therefore
    // produce identical probabilities.
    #[test]
    fn knn_predict_proba_linear_vs_cover_tree() {
        let x = DenseMatrix::from_2d_array(&[
            &[1., 2.],
            &[2., 2.],
            &[3., 3.],
            &[8., 8.],
            &[9., 9.],
            &[10., 10.],
        ])
        .unwrap();
        let y = vec![0, 0, 0, 1, 1, 1];
        let test = DenseMatrix::from_2d_array(&[&[2.5, 2.5], &[9.5, 9.5]]).unwrap();
        let knn_linear = KNNClassifier::fit(
            &x,
            &y,
            KNNClassifierParameters::default()
                .with_algorithm(KNNAlgorithmName::LinearSearch)
                .with_k(3),
        )
        .unwrap();
        let knn_cover = KNNClassifier::fit(
            &x,
            &y,
            KNNClassifierParameters::default()
                .with_algorithm(KNNAlgorithmName::CoverTree)
                .with_k(3),
        )
        .unwrap();
        let proba_linear = knn_linear.predict_proba(&test).unwrap();
        let proba_cover = knn_cover.predict_proba(&test).unwrap();
        for (i, (pl, pc)) in proba_linear.iter().zip(proba_cover.iter()).enumerate() {
            assert_vec_f64_eq(
                pl,
                pc,
                1e-10,
                &format!("Sample {} probability vectors differ", i),
            );
        }
    }

    // Querying exactly on duplicated training points with distance
    // weighting may yield a zero (or degenerate) weight sum; accept
    // either a valid distribution or a descriptive error.
    #[test]
    fn knn_predict_proba_zero_weights_error() {
        let x = DenseMatrix::from_2d_array(&[&[1., 1.], &[1., 1.], &[1., 1.]]).unwrap();
        let y = vec![0, 1, 2];
        let knn = KNNClassifier::fit(
            &x,
            &y,
            KNNClassifierParameters::default()
                .with_k(3)
                .with_weight(KNNWeightFunction::Distance),
        )
        .unwrap();
        let test = DenseMatrix::from_2d_array(&[&[1., 1.]]).unwrap();
        let result = knn.predict_proba(&test);
        match result {
            Ok(proba) => {
                assert_eq!(proba.len(), 1);
                assert!((proba[0].iter().sum::<f64>() - 1.0).abs() < 1e-10);
            }
            Err(e) => {
                let err_msg = format!("{:?}", e);
                assert!(
                    err_msg.contains("weight") || err_msg.contains("zero"),
                    "Error message should mention weights or zero sum: {}",
                    err_msg
                );
            }
        }
    }

    // Uniform and inverse-distance weighting must disagree when the
    // neighbors are at unequal distances.
    #[test]
    fn knn_predict_proba_weight_functions_differ() {
        let x = DenseMatrix::from_2d_array(&[&[1., 1.], &[2., 2.], &[10., 10.]]).unwrap();
        let y = vec![0, 0, 1];
        let test = DenseMatrix::from_2d_array(&[&[1.5, 1.5]]).unwrap();
        let knn_uniform = KNNClassifier::fit(
            &x,
            &y,
            KNNClassifierParameters::default()
                .with_k(3)
                .with_weight(KNNWeightFunction::Uniform),
        )
        .unwrap();
        let knn_distance = KNNClassifier::fit(
            &x,
            &y,
            KNNClassifierParameters::default()
                .with_k(3)
                .with_weight(KNNWeightFunction::Distance),
        )
        .unwrap();
        let proba_uniform = knn_uniform.predict_proba(&test).unwrap();
        let proba_distance = knn_distance.predict_proba(&test).unwrap();
        let mut differs = false;
        for (vu, vd) in proba_uniform[0].iter().zip(proba_distance[0].iter()) {
            if (vu - vd).abs() > 1e-10 {
                differs = true;
                break;
            }
        }
        assert!(
            differs,
            "Uniform and Distance weights should produce different probabilities"
        );
    }

    // k = n over a mixed-class training set cannot be fully confident.
    #[test]
    fn knn_predict_proba_extreme_k_values() {
        let x =
            DenseMatrix::from_2d_array(&[&[1., 1.], &[2., 2.], &[3., 3.], &[8., 8.], &[9., 9.]])
                .unwrap();
        let y = vec![0, 0, 1, 1, 1];
        let test = DenseMatrix::from_2d_array(&[&[2.5, 2.5]]).unwrap();
        let knn_kn =
            KNNClassifier::fit(&x, &y, KNNClassifierParameters::default().with_k(5)).unwrap();
        let proba_kn = knn_kn.predict_proba(&test).unwrap();
        let max_prob = proba_kn[0].iter().copied().fold(0.0, f64::max);
        assert!(
            max_prob < 1.0 - 1e-10,
            "k=n with mixed classes should not give probability 1.0"
        );
    }

    // Three well-separated clusters: the middle cluster's label must win
    // for a query inside it, and the distribution must stay valid.
    #[test]
    fn knn_predict_proba_multiclass() {
        let x = DenseMatrix::from_2d_array(&[
            &[1., 1.],
            &[1.5, 1.5],
            &[4., 4.],
            &[4.5, 4.5],
            &[8., 8.],
            &[8.5, 8.5],
        ])
        .unwrap();
        let y = vec![10, 10, 20, 20, 30, 30];
        let knn = KNNClassifier::fit(&x, &y, Default::default()).unwrap();
        let test = DenseMatrix::from_2d_array(&[&[4.2, 4.2]]).unwrap();
        let proba = knn.predict_proba(&test).unwrap();
        assert_eq!(proba[0].len(), 3, "Should have 3 class probabilities");
        assert!((proba[0].iter().sum::<f64>() - 1.0).abs() < 1e-10);
        let max_idx = proba[0]
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
            .unwrap()
            .0;
        assert_eq!(knn.classes()[max_idx], 20);
    }

    // Batch prediction returns one valid distribution per query row.
    #[test]
    fn knn_predict_proba_batch() {
        let x = DenseMatrix::from_2d_array(&[
            &[1., 1.],
            &[2., 2.],
            &[3., 3.],
            &[8., 8.],
            &[9., 9.],
            &[10., 10.],
        ])
        .unwrap();
        let y = vec![0, 0, 0, 1, 1, 1];
        let knn = KNNClassifier::fit(&x, &y, Default::default()).unwrap();
        let test =
            DenseMatrix::from_2d_array(&[&[1.5, 1.5], &[9.5, 9.5], &[5., 5.]]).unwrap();
        let proba = knn.predict_proba(&test).unwrap();
        assert_eq!(proba.len(), 3, "Should return probabilities for 3 samples");
        for p in &proba {
            assert_eq!(p.len(), 2);
            assert!((p.iter().sum::<f64>() - 1.0).abs() < 1e-10);
        }
        assert!(
            proba[0][0] > proba[0][1],
            "First sample should favor class 0"
        );
        assert!(
            proba[1][1] > proba[1][0],
            "Second sample should favor class 1"
        );
    }

    // A fitted model must survive a bincode round-trip unchanged.
    #[test]
    #[cfg(feature = "serde")]
    fn serde() {
        let x =
            DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]])
                .unwrap();
        let y = vec![2, 2, 2, 3, 3];
        let knn = KNNClassifier::fit(&x, &y, Default::default()).unwrap();
        let deserialized_knn = bincode::deserialize(&bincode::serialize(&knn).unwrap()).unwrap();
        assert_eq!(knn, deserialized_knn);
    }
}