use std::collections::HashMap;
use crate::core::traits::DistanceMetric;
/// Hyper-parameters for the k-nearest-neighbors classifier.
#[derive(Debug, Clone)]
pub struct KnnConfig {
    /// Number of nearest training samples consulted when voting on a label.
    pub n_neighbors: usize,
}

impl KnnConfig {
    /// Creates a configuration with the given neighbor count.
    ///
    /// The value is validated later, by `Knn::fit`, against the training
    /// set size.
    pub fn new(n_neighbors: usize) -> Self {
        KnnConfig { n_neighbors }
    }
}
/// A fitted KNN model: KNN is a lazy learner, so this simply stores the
/// full training set alongside the metric and `k` captured at fit time.
#[derive(Debug, Clone)]
pub struct KnnFitted<D: DistanceMetric> {
/// Training feature rows, copied from the data passed to `Knn::fit`.
pub x_train: Vec<Vec<f64>>,
/// Training labels, parallel to `x_train` (same length, checked in `fit`).
pub y_train: Vec<String>,
/// Distance function used to compare a query sample to training rows.
pub metric: D,
/// Number of neighbors that vote; `fit` validates 1 <= k <= x_train.len().
pub n_neighbors: usize,
}
/// Namespace type grouping the KNN classifier's associated functions.
pub struct Knn;

impl Knn {
    /// Validates the inputs and captures the training data, producing a
    /// fitted model. No computation happens here — KNN defers all work
    /// to prediction time.
    ///
    /// # Panics
    /// Panics when `x` is empty, when `x` and `y` differ in length, or when
    /// `config.n_neighbors` is 0 or exceeds the training-set size.
    pub fn fit<D: DistanceMetric>(
        config: &KnnConfig,
        x: &[Vec<f64>],
        y: &[String],
        metric: D,
    ) -> KnnFitted<D> {
        assert!(!x.is_empty(), "Input must have at least one sample");
        assert_eq!(x.len(), y.len(), "X and y must have same length");
        assert!(config.n_neighbors >= 1, "n_neighbors must be >= 1");
        assert!(
            config.n_neighbors <= x.len(),
            "n_neighbors must not exceed training set size"
        );
        KnnFitted {
            x_train: x.to_vec(),
            y_train: y.to_vec(),
            metric,
            n_neighbors: config.n_neighbors,
        }
    }

    /// Predicts one label per row of `x` by majority vote among the
    /// `n_neighbors` nearest training samples. An empty `x` yields an
    /// empty vector.
    pub fn predict<D: DistanceMetric>(fitted: &KnnFitted<D>, x: &[Vec<f64>]) -> Vec<String> {
        x.iter()
            .map(|sample| predict_single(sample, fitted))
            .collect()
    }

    /// Mean accuracy of the model's predictions against `y`.
    ///
    /// # Panics
    /// Panics when `x` and `y` differ in length (previously extra rows were
    /// silently dropped by `zip`) or when the test set is empty (previously
    /// this returned `NaN` from the 0/0 division).
    pub fn score<D: DistanceMetric>(fitted: &KnnFitted<D>, x: &[Vec<f64>], y: &[String]) -> f64 {
        assert_eq!(x.len(), y.len(), "X and y must have same length");
        assert!(!y.is_empty(), "Test set must have at least one sample");
        let predictions = Self::predict(fitted, x);
        let correct = predictions
            .iter()
            .zip(y.iter())
            .filter(|(p, t)| p == t)
            .count();
        correct as f64 / y.len() as f64
    }
}
/// Classifies one sample by majority vote among the `n_neighbors` nearest
/// training points.
///
/// Correctness notes:
/// - Distances are ordered with `f64::total_cmp`, so a NaN distance sorts
///   last instead of panicking (`partial_cmp(..).unwrap()` previously did).
/// - Vote ties are broken deterministically in favor of the tied label
///   whose nearest neighbor is closest; previously the winner among tied
///   labels depended on `HashMap` iteration order, which is randomized.
fn predict_single<D: DistanceMetric>(sample: &[f64], fitted: &KnnFitted<D>) -> String {
    let mut distances: Vec<(f64, &str)> = fitted
        .x_train
        .iter()
        .zip(fitted.y_train.iter())
        .map(|(train_sample, label)| {
            (fitted.metric.distance(sample, train_sample), label.as_str())
        })
        .collect();
    // Total order over f64: never panics, NaN sorts after all real numbers.
    distances.sort_unstable_by(|a, b| a.0.total_cmp(&b.0));
    // Defensive clamp: `fit` guarantees n_neighbors <= training size, but
    // the fields are public, so avoid an out-of-bounds slice regardless.
    let k = fitted.n_neighbors.min(distances.len());
    let neighbors = &distances[..k];
    let mut votes: HashMap<&str, usize> = HashMap::new();
    for &(_, label) in neighbors {
        *votes.entry(label).or_insert(0) += 1;
    }
    let max_votes = *votes
        .values()
        .max()
        .expect("fit guarantees n_neighbors >= 1, so neighbors is non-empty");
    // Scan neighbors in ascending distance order and return the first label
    // that achieved the winning vote count: a deterministic tie-break.
    neighbors
        .iter()
        .find(|&&(_, label)| votes[label] == max_votes)
        .map(|&(_, label)| label.to_string())
        .expect("a label with max_votes exists among the neighbors")
}
/// Standard L2 (Euclidean) distance between two equal-length vectors.
#[derive(Debug, Clone)]
pub struct EuclideanMetric;

impl DistanceMetric for EuclideanMetric {
    /// sqrt of the sum of squared per-coordinate differences. Extra
    /// coordinates in the longer slice are ignored (zip stops at the
    /// shorter input).
    fn distance(&self, a: &[f64], b: &[f64]) -> f64 {
        let squared_sum: f64 = a
            .iter()
            .zip(b)
            .map(|(lhs, rhs)| {
                let diff = lhs - rhs;
                diff * diff
            })
            .sum();
        squared_sum.sqrt()
    }
}
#[cfg(test)]
mod tests {
use super::*;
/// k=1: the prediction must equal the label of the single nearest
/// training point, on both sides of the A/B boundary.
#[test]
fn test_knn_basic() {
let config = KnnConfig::new(1);
let x_train = vec![vec![0.0, 0.0], vec![1.0, 1.0], vec![2.0, 2.0]];
let y_train = vec!["A".to_string(), "A".to_string(), "B".to_string()];
let metric = EuclideanMetric;
let fitted = Knn::fit(&config, &x_train, &y_train, metric);
let x_test = vec![vec![0.1, 0.1]];
let pred = Knn::predict(&fitted, &x_test);
assert_eq!(pred[0], "A");
let x_test = vec![vec![1.9, 1.9]];
let pred = Knn::predict(&fitted, &x_test);
assert_eq!(pred[0], "B");
}
/// k=3: the three nearest neighbors of 0.05 are labeled A, A, B, so the
/// majority vote must pick "A" even though a B point is among them.
#[test]
fn test_knn_k3_voting() {
let config = KnnConfig::new(3);
let x_train = vec![vec![0.0], vec![0.1], vec![0.2], vec![10.0]];
let y_train = vec![
"A".to_string(),
"A".to_string(),
"B".to_string(),
"B".to_string(),
];
let metric = EuclideanMetric;
let fitted = Knn::fit(&config, &x_train, &y_train, metric);
let pred = Knn::predict(&fitted, &[vec![0.05]]);
assert_eq!(pred[0], "A");
}
/// Scoring the model on its own (well-separated) training points must
/// give perfect accuracy; compared with a tolerance, not `==`, since the
/// score is an f64 ratio.
#[test]
fn test_knn_score() {
let config = KnnConfig::new(1);
let x_train = vec![vec![0.0], vec![10.0]];
let y_train = vec!["A".to_string(), "B".to_string()];
let metric = EuclideanMetric;
let fitted = Knn::fit(&config, &x_train, &y_train, metric);
let x_test = vec![vec![0.0], vec![10.0]];
let y_test = vec!["A".to_string(), "B".to_string()];
let score = Knn::score(&fitted, &x_test, &y_test);
assert!((score - 1.0).abs() < 1e-10);
}
}