use std::collections::BinaryHeap;
use crate::math::{
metrics::DistanceMetric,
neighbors::{neighbor::Neighbor, search::NeighborSearch},
point::Point,
FloatNumber,
};
/// Nearest-neighbor search that scans every point for each query.
///
/// The struct only borrows the points and stores the metric; no auxiliary
/// index is built, so each query measures the distance to every point in
/// the slice.
///
/// # Type Parameters
/// * `T` - The floating point type of the coordinates.
/// * `N` - The number of dimensions of each point.
#[derive(Debug)]
pub struct LinearSearch<'a, T, const N: usize>
where
T: FloatNumber,
{
/// Borrowed points to search; neighbors are reported by index into this slice.
points: &'a [Point<T, N>],
/// Metric used to measure the distance between the query and each point.
metric: DistanceMetric,
}
impl<'a, T, const N: usize> LinearSearch<'a, T, N>
where
    T: FloatNumber,
{
    /// Creates a `LinearSearch` over the given points.
    ///
    /// The points are only borrowed and the metric is stored as given; no
    /// preprocessing is performed.
    ///
    /// # Arguments
    /// * `points` - The points to search; neighbor indices refer to this slice.
    /// * `metric` - The distance metric used to compare a query with each point.
    ///
    /// # Returns
    /// A new `LinearSearch` borrowing `points` for the lifetime `'a`.
    #[must_use]
    pub fn build(points: &'a [Point<T, N>], metric: DistanceMetric) -> Self {
        Self { points, metric }
    }
}
impl<T, const N: usize> NeighborSearch<T, N> for LinearSearch<'_, T, N>
where
T: FloatNumber,
{
#[must_use]
fn search(&self, query: &Point<T, N>, k: usize) -> Vec<Neighbor<T>> {
let mut neighbors = BinaryHeap::with_capacity(k);
for (index, point) in self.points.iter().enumerate() {
let distance = self.metric.measure(query, point);
let neighbor = Neighbor::new(index, distance);
neighbors.push(neighbor);
if neighbors.len() > k {
neighbors.pop();
}
}
neighbors.into_sorted_vec()
}
#[must_use]
fn search_nearest(&self, query: &Point<T, N>) -> Option<Neighbor<T>> {
let mut nearest = Neighbor::new(0, T::infinity());
for (index, other) in self.points.iter().enumerate() {
let distance = self.metric.measure(query, other);
if distance < nearest.distance {
nearest.index = index;
nearest.distance = distance;
}
}
Some(nearest)
}
#[must_use]
fn search_radius(&self, query: &Point<T, N>, radius: T) -> Vec<Neighbor<T>> {
let mut neighbors = Vec::new();
for (index, point) in self.points.iter().enumerate() {
let distance = self.metric.measure(query, point);
if distance <= radius {
let neighbor = Neighbor::new(index, distance);
neighbors.push(neighbor);
}
}
neighbors
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Query point shared by all search tests.
    const QUERY: [f32; 3] = [2.0, 5.0, 6.0];

    /// Returns the fixed set of 14 three-dimensional points used as fixtures.
    #[must_use]
    fn sample_points() -> Vec<[f32; 3]> {
        vec![
            [1.0, 2.0, 3.0],
            [5.0, 1.0, 2.0],
            [9.0, 3.0, 4.0],
            [3.0, 9.0, 1.0],
            [4.0, 8.0, 3.0],
            [9.0, 1.0, 1.0],
            [5.0, 0.0, 0.0],
            [1.0, 1.0, 1.0],
            [7.0, 2.0, 2.0],
            [5.0, 9.0, 1.0],
            [1.0, 1.0, 9.0],
            [9.0, 8.0, 7.0],
            [2.0, 3.0, 4.0],
            [4.0, 5.0, 4.0],
        ]
    }

    /// Flattens neighbors into `(index, distance)` pairs for whole-result comparison.
    fn pairs(found: &[Neighbor<f32>]) -> Vec<(usize, f32)> {
        found.iter().map(|n| (n.index, n.distance)).collect()
    }

    #[test]
    fn test_build() {
        // Arrange
        let fixtures = sample_points();

        // Act
        let linear = LinearSearch::build(&fixtures, DistanceMetric::Euclidean);

        // Assert
        assert_eq!(linear.points.len(), 14);
        assert_eq!(linear.metric, DistanceMetric::Euclidean);
    }

    #[test]
    fn test_search() {
        // Arrange
        let fixtures = sample_points();
        let linear = LinearSearch::build(&fixtures, DistanceMetric::Euclidean);

        // Act
        let found = linear.search(&QUERY, 3);

        // Assert: points 12 and 13 tie at sqrt(8), point 0 follows at sqrt(19).
        let expected = vec![
            (12, 8.0_f32.sqrt()),
            (13, 8.0_f32.sqrt()),
            (0, 19.0_f32.sqrt()),
        ];
        assert_eq!(pairs(&found), expected);
    }

    #[test]
    fn test_search_nearest() {
        // Arrange
        let fixtures = sample_points();
        let linear = LinearSearch::build(&fixtures, DistanceMetric::Euclidean);

        // Act
        let closest = linear.search_nearest(&QUERY).unwrap();

        // Assert: the first of the two tied points is reported.
        assert_eq!((closest.index, closest.distance), (12, 8.0_f32.sqrt()));
    }

    #[test]
    fn test_search_radius() {
        // Arrange
        let fixtures = sample_points();
        let linear = LinearSearch::build(&fixtures, DistanceMetric::Euclidean);

        // Act
        let found = linear.search_radius(&QUERY, 4.0);

        // Assert: only the two points at distance sqrt(8) lie within the radius.
        let expected = vec![(12, 8.0_f32.sqrt()), (13, 8.0_f32.sqrt())];
        assert_eq!(pairs(&found), expected);
    }
}