use crate::linalg::Matrix;
use super::weights::{EPSILON, MIN_NEIGHBORS_QUADRATIC};
/// Euclidean (L2) distance between two points given as coordinate slices.
///
/// # Panics
///
/// Panics if `a` and `b` do not have the same length.
#[inline]
pub fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(
        a.len(),
        b.len(),
        "Points must have the same dimensionality"
    );
    let squared_sum: f64 = a
        .iter()
        .zip(b)
        .map(|(&lhs, &rhs)| {
            let delta = lhs - rhs;
            delta.powi(2)
        })
        .sum();
    squared_sum.sqrt()
}
/// Returns the indices of the `k` rows of `data` closest to `query`.
///
/// Indices are ordered by increasing Euclidean distance. Rows at exactly the
/// same distance are ordered by ascending row index. Rows whose distance is
/// NaN (e.g. because the row or the query contains NaN) rank after every
/// other row, so they are only returned when fewer than `k` valid rows exist.
///
/// # Panics
///
/// Panics if `k == 0`, if `k > data.rows`, or if `query.len() != data.cols`.
pub fn find_nearest_neighbors(query: &[f64], data: &Matrix, k: usize) -> Vec<usize> {
    let n = data.rows;
    let p = data.cols;
    assert!(k > 0, "k must be positive");
    assert!(k <= n, "k cannot exceed number of data points");
    assert_eq!(
        query.len(),
        p,
        "Query dimension must match number of predictors"
    );

    // One scratch buffer reused for every row instead of a fresh allocation per row.
    let mut point: Vec<f64> = Vec::with_capacity(p);
    let mut distances: Vec<(usize, f64)> = (0..n)
        .map(|i| {
            point.clear();
            point.extend((0..p).map(|j| data.get(i, j)));
            (i, euclidean_distance(query, &point))
        })
        .collect();

    // Strict total order: non-NaN before NaN, then by distance, then by row index.
    // `partial_cmp(..).unwrap_or(Equal)` is not a total order once NaN appears
    // (NaN would "equal" every distance), which makes the sort result
    // unspecified and can make newer std sort implementations panic. The index
    // tie-break reproduces the tie order of the previous stable sort while
    // allowing the unstable (allocation-free) routines below.
    let by_distance = |a: &(usize, f64), b: &(usize, f64)| {
        a.1.is_nan()
            .cmp(&b.1.is_nan())
            .then_with(|| a.1.total_cmp(&b.1))
            .then_with(|| a.0.cmp(&b.0))
    };

    // Move the k nearest to the front, then order only those k.
    distances.select_nth_unstable_by(k - 1, by_distance);
    distances.truncate(k);
    distances.sort_unstable_by(by_distance);
    distances.into_iter().map(|(idx, _)| idx).collect()
}
/// Returns how many nearest neighbours a local fit of the given `degree` uses.
///
/// The base count is `floor(span * n + EPSILON)`. It is raised to a minimum
/// (`MIN_NEIGHBORS_QUADRATIC` when `degree == 2`, otherwise 2) and finally
/// capped at `n`, so the result never exceeds the number of points. A negative
/// base value saturates to 0 through the `as usize` cast.
#[inline]
pub fn compute_neighborhood_size(n: usize, span: f64, degree: usize) -> usize {
    // NOTE(review): degrees above 2 also get a minimum of 2 here — confirm intended.
    let minimum = match degree {
        2 => MIN_NEIGHBORS_QUADRATIC,
        _ => 2,
    };
    // NOTE(review): `EPSILON` presumably absorbs rounding in `span * n` (a
    // product landing just below an integer); confirm against its definition.
    let from_span = (span * n as f64 + EPSILON).floor() as usize;
    let at_least_minimum = if from_span < minimum { minimum } else { from_span };
    if at_least_minimum > n {
        n
    } else {
        at_least_minimum
    }
}
/// Computes the local bandwidth at `query` together with the neighbourhood used to derive it.
///
/// The neighbourhood size comes from `compute_neighborhood_size(data.rows, span, degree)`.
/// The bandwidth is the Euclidean distance from `query` to the last (farthest)
/// index returned by `find_nearest_neighbors` for that size.
///
/// Returns `(bandwidth, neighbor_indices)`.
///
/// # Panics
///
/// Panics if `data` has no rows, or if `span` is not in `(0, 1]`; NaN is
/// rejected because both comparisons are false. Also panics under any panic
/// condition of `find_nearest_neighbors` (e.g. a query of the wrong dimension).
pub fn compute_bandwidth(
    query: &[f64],
    data: &Matrix,
    span: f64,
    degree: usize,
) -> (f64, Vec<usize>) {
    assert!(data.rows > 0, "Data matrix must have at least one row");
    assert!(
        0.0 < span && span <= 1.0,
        "Span must be in (0, 1], got {}",
        span
    );

    let neighborhood = compute_neighborhood_size(data.rows, span, degree);
    let neighbors = find_nearest_neighbors(query, data, neighborhood);

    // Neighbours come back nearest-first, so the final entry is the farthest
    // point inside the neighbourhood.
    let farthest_row = neighbors[neighborhood - 1];
    let farthest_point: Vec<f64> = (0..data.cols)
        .map(|col| data.get(farthest_row, col))
        .collect();

    (euclidean_distance(query, &farthest_point), neighbors)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a single-column (one predictor) matrix from `values`.
    fn column(values: &[f64]) -> Matrix {
        Matrix::new(values.len(), 1, values.to_vec())
    }

    #[test]
    fn test_euclidean_distance_1d() {
        assert_eq!(euclidean_distance(&[0.0], &[3.0]), 3.0);
    }

    #[test]
    fn test_euclidean_distance_2d() {
        assert!((euclidean_distance(&[0.0, 0.0], &[3.0, 4.0]) - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_euclidean_distance_same_point() {
        let a = [1.0, 2.0, 3.0];
        assert_eq!(euclidean_distance(&a, &a), 0.0);
    }

    #[test]
    #[should_panic(expected = "Points must have the same dimensionality")]
    fn test_euclidean_distance_dimension_mismatch() {
        let _ = euclidean_distance(&[1.0], &[1.0, 2.0]);
    }

    #[test]
    fn test_find_nearest_neighbors_simple() {
        let x = column(&[0.0, 2.0, 4.0, 6.0, 8.0]);
        // Rows 1 and 2 are both at distance 1 from the query; the lower index comes first.
        assert_eq!(find_nearest_neighbors(&[3.0], &x, 2), vec![1, 2]);
    }

    #[test]
    fn test_find_nearest_neighbors_2d() {
        // Assuming `Matrix::new` takes row-major data, the points are
        // (0,0), (1,0), (0,1), (5,5).
        let data = vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 5.0, 5.0];
        let x = Matrix::new(4, 2, data);
        let neighbors = find_nearest_neighbors(&[0.1, 0.1], &x, 2);
        assert_eq!(neighbors.len(), 2);
        assert_eq!(neighbors[0], 0);
        // The far point must never be among the two nearest.
        assert!(!neighbors.contains(&3));
    }

    #[test]
    #[should_panic(expected = "k must be positive")]
    fn test_find_nearest_neighbors_zero_k() {
        let x = column(&[0.0, 1.0]);
        let _ = find_nearest_neighbors(&[0.0], &x, 0);
    }

    #[test]
    #[should_panic(expected = "k cannot exceed number of data points")]
    fn test_find_nearest_neighbors_k_too_large() {
        let x = column(&[0.0, 1.0]);
        let _ = find_nearest_neighbors(&[0.0], &x, 3);
    }

    #[test]
    #[should_panic(expected = "Query dimension must match number of predictors")]
    fn test_find_nearest_neighbors_dimension_mismatch() {
        let x = column(&[0.0, 1.0]);
        let _ = find_nearest_neighbors(&[0.0, 0.0], &x, 1);
    }

    #[test]
    fn test_compute_neighborhood_size() {
        assert_eq!(compute_neighborhood_size(10, 0.5, 1), 5);
        assert_eq!(compute_neighborhood_size(10, 0.1, 1), 2);
        assert_eq!(compute_neighborhood_size(10, 0.1, 2), 3);
        assert_eq!(compute_neighborhood_size(10, 1.0, 1), 10);
        // The degree-dependent minimum is still capped at the number of points.
        assert_eq!(compute_neighborhood_size(1, 0.5, 2), 1);
        assert_eq!(compute_neighborhood_size(1, 0.1, 1), 1);
    }

    #[test]
    fn test_compute_bandwidth_span() {
        let x = column(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
        let (bw, neighbors) = compute_bandwidth(&[5.0], &x, 0.5, 1);
        assert_eq!(neighbors.len(), 5);
        assert_eq!(neighbors[0], 5);
        // The five points nearest 5.0 are 3..=7, so the farthest is at distance 2.
        assert!((bw - 2.0).abs() < 1e-12);
    }

    #[test]
    fn test_compute_bandwidth_small_span() {
        let x = column(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
        let (bw, neighbors) = compute_bandwidth(&[5.0], &x, 0.1, 1);
        assert_eq!(neighbors.len(), 2);
        assert_eq!(neighbors[0], 5);
        assert!((bw - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_compute_bandwidth_quadratic_minimum() {
        let x = column(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
        let (bw, neighbors) = compute_bandwidth(&[5.0], &x, 0.1, 2);
        assert_eq!(neighbors.len(), 3);
        assert!((bw - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_compute_bandwidth_full_span() {
        let x = column(&[0.0, 1.0, 2.0, 3.0, 4.0]);
        let (bw, neighbors) = compute_bandwidth(&[2.0], &x, 1.0, 1);
        let mut sorted = neighbors.clone();
        sorted.sort_unstable();
        assert_eq!(sorted, vec![0, 1, 2, 3, 4]);
        assert!((bw - 2.0).abs() < 1e-12);
    }

    #[test]
    #[should_panic(expected = "Span must be in (0, 1]")]
    fn test_compute_bandwidth_rejects_zero_span() {
        let x = column(&[0.0, 1.0]);
        let _ = compute_bandwidth(&[0.0], &x, 0.0, 1);
    }

    #[test]
    #[should_panic(expected = "Span must be in (0, 1]")]
    fn test_compute_bandwidth_rejects_span_above_one() {
        let x = column(&[0.0, 1.0]);
        let _ = compute_bandwidth(&[0.0], &x, 1.5, 1);
    }

    #[test]
    fn test_neighbors_sorted_by_distance() {
        let x = column(&[0.0, 10.0, 20.0, 30.0, 40.0]);
        let query = [25.0];
        let neighbors = find_nearest_neighbors(&query, &x, 3);
        // Distances: 20 and 30 are 5 away, 10 is 15 away (ties keep index order).
        assert_eq!(neighbors, vec![2, 3, 1]);
        let dists: Vec<f64> = neighbors
            .iter()
            .map(|&i| euclidean_distance(&query, &[x.get(i, 0)]))
            .collect();
        assert!(dists.windows(2).all(|w| w[0] <= w[1]));
    }
}