use std::collections::HashMap;
use std::hash::Hash;
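/// A k-nearest-neighbors classifier that stores its training data and
/// predicts by majority vote among the `k` nearest points, compared by
/// Euclidean distance.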
#[derive(Debug, Clone)]
pub struct KNNClassifier<L: Eq + Hash + Clone> {
    /// Number of neighbors consulted per prediction.
    pub k: usize,
    /// Training feature vectors, one row per example.
    pub features: Vec<Vec<f64>>,
    /// Training labels, aligned index-for-index with `features`.
    pub labels: Vec<L>,
}
impl<L: Eq + Hash + Clone> KNNClassifier<L> {
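    /// Builds a classifier from training data.
    ///
    /// # Panics
    /// Panics if `k == 0`, if `features` is empty, if `features` and
    /// `labels` differ in length, or if the feature vectors are empty or
    /// unequal in dimension.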
    pub fn new(k: usize, features: Vec<Vec<f64>>, labels: Vec<L>) -> Self {
        assert!(k > 0, "k must be > 0");
        let n = features.len();
        assert!(n > 0, "features cannot be empty");
        assert_eq!(n, labels.len(), "features and labels must have same length");
        let dim = features[0].len();
        assert!(
            dim > 0,
            "All feature vectors must have at least one dimension"
        );
        // Without this check, points of mismatched dimension would be
        // silently truncated by `zip` in the distance computation.
        for f in &features {
            assert_eq!(f.len(), dim, "All feature vectors must have the same dimension");
        }
        Self { k, features, labels }
    }
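    /// Predicts a label for `point` by majority vote among the `k`
    /// nearest training points under squared Euclidean distance.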
    pub fn predict(&self, point: &[f64]) -> L {
        assert!(
            !self.features.is_empty(),
            "No training data in the classifier"
        );
        assert_eq!(
            point.len(),
            self.features[0].len(),
            "Query point dimension must match the training data"
        );
        let neighbors = self.find_k_nearest(point);
        self.majority_vote(&neighbors)
    }
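    /// Predicts a label for each point in `points`.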
    pub fn predict_batch(&self, points: &[Vec<f64>]) -> Vec<L> {
        points.iter().map(|p| self.predict(p)).collect()
    }
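    /// Returns indices of the `k` training points nearest to `point`,
    /// clamped to the dataset size when `k` exceeds it.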
    fn find_k_nearest(&self, point: &[f64]) -> Vec<usize> {
        let mut dists: Vec<(f64, usize)> = self
            .features
            .iter()
            .enumerate()
            .map(|(i, f)| (euclidean_distance_sq(f, point), i))
            .collect();
        // `total_cmp` imposes a total order even if a distance is NaN,
        // where `partial_cmp(..).unwrap()` would panic.
        dists.sort_by(|(d1, _), (d2, _)| d1.total_cmp(d2));
        dists.iter().take(self.k).map(|&(_, i)| i).collect()
    }
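    /// Returns the most common label among the given training indices.
    /// Ties are broken arbitrarily by `HashMap` iteration order.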
    fn majority_vote(&self, neighbor_indices: &[usize]) -> L {
        let mut counts = HashMap::<L, usize>::new();
        for &idx in neighbor_indices {
            let label = &self.labels[idx];
            *counts.entry(label.clone()).or_insert(0) += 1;
        }
        counts
            .into_iter()
            .max_by_key(|(_label, count)| *count)
            .expect("neighbor list is non-empty")
            .0
    }
}
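/// Squared Euclidean distance between two equal-length vectors. The square
/// root is skipped because it does not change the neighbor ordering.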
fn euclidean_distance_sq(a: &[f64], b: &[f64]) -> f64 {
    debug_assert_eq!(a.len(), b.len(), "dimension mismatch");
    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| (x - y) * (x - y))
        .sum()
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_basic_knn() {
        let features = vec![
            vec![1.0, 2.0],
            vec![2.0, 3.0],
            vec![2.5, 2.7],
            vec![10.0, 10.0],
        ];
        let labels = vec!["A", "A", "B", "B"];
        let knn = KNNClassifier::new(3, features, labels);
        let pred1 = knn.predict(&[2.1, 2.9]);
        assert_eq!(pred1, "A");
        let pred2 = knn.predict(&[9.5, 9.7]);
        assert_eq!(pred2, "B");
    }
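    // Sketch: exercises `predict_batch` over the same toy data; the
    // expected labels follow from the single-point cases above.
    #[test]
    fn test_predict_batch() {
        let features = vec![
            vec![1.0, 2.0],
            vec![2.0, 3.0],
            vec![2.5, 2.7],
            vec![10.0, 10.0],
        ];
        let labels = vec!["A", "A", "B", "B"];
        let knn = KNNClassifier::new(3, features, labels);
        let preds = knn.predict_batch(&[vec![2.1, 2.9], vec![9.5, 9.7]]);
        assert_eq!(preds, vec!["A", "B"]);
    }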
    #[test]
    #[should_panic(expected = "features cannot be empty")]
    fn test_empty_features_panic() {
        let features: Vec<Vec<f64>> = vec![];
        let labels: Vec<&str> = vec![];
        KNNClassifier::new(3, features, labels);
    }
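    // Sketch: `find_k_nearest` takes at most the dataset size, so a `k`
    // larger than the number of training points still yields a prediction
    // (an overall majority vote).
    #[test]
    fn test_k_larger_than_dataset() {
        let features = vec![vec![0.0], vec![0.1], vec![5.0]];
        let labels = vec!["A", "A", "B"];
        let knn = KNNClassifier::new(10, features, labels);
        assert_eq!(knn.predict(&[0.05]), "A");
    }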
    #[test]
    #[should_panic(expected = "No training data")]
    fn test_predict_no_data_panic() {
        // The fields are public, so `new`'s validation can be bypassed.
        let knn = KNNClassifier::<&str> {
            k: 3,
            features: vec![],
            labels: vec![],
        };
        knn.predict(&[1.0, 2.0]);
    }
}