use crate::math::distance;
use crate::math::norm;
use rayon::prelude::*;
use std::cmp::Ordering;
use std::collections::HashMap;
use std::collections::HashSet;
/// A k-nearest-neighbours classifier over `f64` feature vectors with `i32` labels.
#[derive(Debug)]
pub struct KNN {
    /// Number of nearest neighbours consulted by `predict`.
    pub k: i32,
    /// Training samples, one feature vector per entry. `KNN::new` normalizes
    /// them in place when `normalize` is set.
    pub x: Vec<Vec<f64>>,
    /// Label of each training sample, paired with `x` by position.
    pub y: Vec<i32>,
    /// Number of distinct labels in `y`, computed by `KNN::new`.
    pub num_labels: usize,
    /// Distance metric: `Some(Manhattan)` selects Manhattan distance; anything
    /// else (including `None`) selects Euclidean distance.
    pub distance: Option<distance::Distance>,
    /// Optional normalization; `KNN::new` applies it to every training sample.
    pub normalize: Option<norm::Norm>,
}
/// A training sample's class label paired with its distance to a query point.
#[derive(PartialEq, Debug)]
pub struct Point {
    pub class: i32,
    pub distance: f64,
}

/// Orders points by `distance`, breaking ties by `class`.
///
/// The `class` tie-break keeps this consistent with the derived `PartialEq`,
/// which compares both fields. As a result, `partial_cmp` returns `Some(Equal)`
/// exactly when the two points are `==`. Any comparison involving a NaN
/// distance yields `None`.
impl PartialOrd for Point {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let by_distance = self.distance.partial_cmp(&other.distance)?;
        Some(by_distance.then_with(|| self.class.cmp(&other.class)))
    }
}

// NOTE(review): `Eq` promises that every value equals itself, but a point whose
// `distance` is NaN is not `==` to itself. The impl is kept because removing a
// trait impl from a public type could break callers; treat NaN distances as
// unsupported wherever `Eq` semantics are relied on.
impl Eq for Point {}
impl KNN {
pub fn new(
k: i32,
x: Vec<Vec<f64>>,
y: Vec<i32>,
distance: Option<distance::Distance>,
normalize: Option<norm::Norm>,
) -> KNN {
let num_labels = KNN::get_num_labels(&y);
let mut knn = KNN {
k,
x,
y,
num_labels,
distance,
normalize,
};
knn.normalize_data();
knn
}
pub fn get_num_labels(y: &[i32]) -> usize {
let set: HashSet<i32> = y.iter().cloned().collect::<HashSet<_>>();
set.len()
}
pub fn normalize_data(&mut self) {
if let Some(n) = &self.normalize {
self.x
.iter_mut()
.for_each(|xi| norm::normalize_vector(xi, n));
}
}
pub fn data(&self) -> (&Vec<Vec<f64>>, &Vec<i32>) {
(&self.x, &self.y)
}
pub fn calculate_distances(&self, new_point: &[f64]) -> Vec<Point> {
let distance_fn = match self.distance {
Some(distance::Distance::Manhattan) => distance::manhattan_distance,
_ => distance::euclidean_distance,
};
self.x
.par_iter()
.zip(self.y.par_iter())
.map(|(x, y)| Point {
class: *y,
distance: distance_fn(new_point, x),
})
.collect()
}
pub fn predict(&self, x: &[f64]) -> i32 {
let mut norm_x: Vec<f64> = x.to_owned();
if let Some(n) = &self.normalize {
norm::normalize_vector(&mut norm_x, n);
}
let mut points = self.calculate_distances(x);
points.par_sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
let mut predictions = vec![0; self.num_labels];
for i in &points[0..self.k as usize] {
predictions[i.class as usize] += 1;
}
KNN::get_max_value(&predictions)
}
fn get_max_value(predictions: &[i32]) -> i32 {
predictions
.iter()
.enumerate() .max_by_key(|(_, pred)| **pred) .map(|(idx, _)| idx) .unwrap() as i32
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two float slices agree element-wise to within `1e-12`.
    fn assert_close(actual: &[f64], expected: &[f64]) {
        assert_eq!(actual.len(), expected.len());
        for (a, e) in actual.iter().zip(expected) {
            assert!((a - e).abs() < 1e-12, "{:?} != {:?}", actual, expected);
        }
    }

    #[test]
    fn normalize_data_test() {
        let p: Vec<Vec<f64>> = vec![vec![2.0, 2.0, 2.0]];
        let mut knn = KNN::new(5, p, vec![1], None, Some(norm::Norm::L2));
        let expected = vec![2.0 / f64::from(12).sqrt(); 3];
        // `new` already normalizes the training data.
        assert_eq!(knn.data().0.len(), 1);
        assert_close(&knn.data().0[0], &expected);
        // Re-normalizing an L2 unit vector should leave it (numerically) unchanged.
        knn.normalize_data();
        assert_close(&knn.data().0[0], &expected);
    }

    #[test]
    fn calculate_distances_test() {
        let p: Vec<Vec<f64>> = vec![vec![2.0, 2.0]];
        let knn = KNN::new(5, p, vec![1], None, None);
        let q = knn.calculate_distances(&[0.0, 0.0]);
        assert_eq!(q.len(), 1);
        assert_eq!(q[0].class, 1);
        assert!((q[0].distance - f64::from(8).sqrt()).abs() < 1e-12);
    }

    #[test]
    fn calculate_distances_manhattan_test() {
        let p: Vec<Vec<f64>> = vec![vec![2.0, 2.0]];
        let knn = KNN::new(
            1,
            p,
            vec![1],
            Some(distance::Distance::Manhattan),
            None,
        );
        let q = knn.calculate_distances(&[0.0, 0.0]);
        assert_eq!(q.len(), 1);
        assert!((q[0].distance - 4.0).abs() < 1e-12);
    }

    #[test]
    fn get_num_labels_test() {
        assert_eq!(KNN::get_num_labels(&[3, 1, 3, 2]), 3);
        assert_eq!(KNN::get_num_labels(&[]), 0);
    }

    #[test]
    fn predict_test() {
        let x = vec![vec![0.0, 0.0], vec![0.0, 1.0], vec![10.0, 10.0]];
        let y = vec![0, 0, 1];
        // The single nearest neighbour of (9, 9) is (10, 10), labelled 1.
        let nearest = KNN::new(1, x.clone(), y.clone(), None, None);
        assert_eq!(nearest.predict(&[9.0, 9.0]), 1);
        // With every sample voting, label 0 wins two votes to one.
        let all = KNN::new(3, x, y, None, None);
        assert_eq!(all.predict(&[9.0, 9.0]), 0);
    }
}