use std::collections::BinaryHeap;
use crate::math::{
metrics::DistanceMetric,
neighbors::{neighbor::Neighbor, search::NeighborSearch},
point::Point,
FloatNumber,
};
/// Brute-force neighbor search over a borrowed slice of points.
///
/// No index structure is built: each query measures the distance from the
/// query to every stored point using the configured metric.
#[derive(Debug)]
pub struct LinearSearch<'a, T: FloatNumber, const N: usize> {
    /// Points to search; borrowed for `'a` rather than copied.
    points: &'a [Point<T, N>],
    /// Metric used to measure the distance between the query and each point.
    metric: DistanceMetric,
}
impl<'a, T, const N: usize> LinearSearch<'a, T, N>
where
T: FloatNumber,
{
#[allow(dead_code)]
pub fn build(points: &'a [Point<T, N>], metric: DistanceMetric) -> Self {
Self { points, metric }
}
}
impl<T, const N: usize> NeighborSearch<T, N> for LinearSearch<'_, T, N>
where
    T: FloatNumber,
{
    /// Returns at most `k` neighbors of `query`, keeping the `k` smallest
    /// candidates according to `Neighbor`'s ordering.
    ///
    /// The result is sorted ascending by that ordering (the tests in this file
    /// expect it to be nearest-first). Returns an empty vector when `k` is zero
    /// or there are no points. If `k` exceeds the number of points, every point
    /// is returned.
    fn search(&self, query: &Point<T, N>, k: usize) -> Vec<Neighbor<T>> {
        if k == 0 || self.points.is_empty() {
            return Vec::new();
        }

        // At most `min(k, len)` candidates survive, plus one transient slot for
        // the push-before-pop below. Sizing the heap by `k` alone reserves
        // memory that can never be used when `k` is much larger than the
        // number of points, and panics with a capacity overflow for huge `k`.
        let capacity = k.min(self.points.len()).saturating_add(1);
        let mut candidates = BinaryHeap::with_capacity(capacity);

        for (index, point) in self.points.iter().enumerate() {
            let distance = self.metric.measure(query, point);
            candidates.push(Neighbor::new(index, distance));
            if candidates.len() > k {
                // `BinaryHeap` is a max-heap, so this discards the greatest of
                // the `k + 1` candidates currently held.
                candidates.pop();
            }
        }

        candidates.into_sorted_vec()
    }

    /// Returns the single point with the smallest measured distance to `query`,
    /// or `None` when there are no points.
    ///
    /// On equal distances the point with the lowest index wins, because a later
    /// candidate only replaces the current best when it is strictly closer.
    fn search_nearest(&self, query: &Point<T, N>) -> Option<Neighbor<T>> {
        let mut nearest: Option<Neighbor<T>> = None;

        for (index, point) in self.points.iter().enumerate() {
            let distance = self.metric.measure(query, point);
            let is_closer = match &nearest {
                Some(best) => distance < best.distance,
                None => true,
            };
            if is_closer {
                nearest = Some(Neighbor::new(index, distance));
            }
        }

        nearest
    }

    /// Returns every point whose measured distance to `query` is at most
    /// `radius` (inclusive), in ascending index order.
    ///
    /// The result is not sorted by distance. Returns an empty vector when
    /// `radius` is negative or there are no points.
    fn search_radius(&self, query: &Point<T, N>, radius: T) -> Vec<Neighbor<T>> {
        if radius < T::zero() || self.points.is_empty() {
            return Vec::new();
        }

        // Collect only the matches instead of reserving room for every point,
        // so a small radius does not hand back a vector holding capacity for
        // the whole data set.
        self.points
            .iter()
            .enumerate()
            .filter_map(|(index, point)| {
                let distance = self.metric.measure(query, point);
                (distance <= radius).then(|| Neighbor::new(index, distance))
            })
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture of 14 three-dimensional points shared by the tests below.
    #[must_use]
    fn sample_points() -> Vec<[f32; 3]> {
        vec![
            [1.0, 2.0, 3.0],
            [5.0, 1.0, 2.0],
            [9.0, 3.0, 4.0],
            [3.0, 9.0, 1.0],
            [4.0, 8.0, 3.0],
            [9.0, 1.0, 1.0],
            [5.0, 0.0, 0.0],
            [1.0, 1.0, 1.0],
            [7.0, 2.0, 2.0],
            [5.0, 9.0, 1.0],
            [1.0, 1.0, 9.0],
            [9.0, 8.0, 7.0],
            [2.0, 3.0, 4.0],
            [4.0, 5.0, 4.0],
        ]
    }

    /// Fixture with no points, used to exercise the empty-input paths.
    #[must_use]
    fn empty_points() -> Vec<Point<f32, 3>> {
        Vec::new()
    }

    #[test]
    fn test_build() {
        let points = sample_points();
        let search = LinearSearch::build(&points, DistanceMetric::Euclidean);

        assert_eq!(search.points.len(), 14);
        assert_eq!(search.metric, DistanceMetric::Euclidean);
    }

    #[test]
    fn test_search() {
        let points = sample_points();
        let search = LinearSearch::build(&points, DistanceMetric::Euclidean);

        let query = [2.0, 5.0, 6.0];
        let neighbors = search.search(&query, 3);

        assert_eq!(neighbors.len(), 3);
        assert_eq!(neighbors[0].index, 12);
        assert_eq!(neighbors[0].distance, 8.0_f32.sqrt());
        assert_eq!(neighbors[1].index, 13);
        assert_eq!(neighbors[1].distance, 8.0_f32.sqrt());
        assert_eq!(neighbors[2].index, 0);
        assert_eq!(neighbors[2].distance, 19.0_f32.sqrt());
    }

    #[test]
    fn test_search_zero() {
        let points = sample_points();
        let search = LinearSearch::build(&points, DistanceMetric::Euclidean);

        let query = [2.0, 5.0, 6.0];
        let neighbors = search.search(&query, 0);

        assert!(neighbors.is_empty());
    }

    #[test]
    fn test_search_empty() {
        let points = empty_points();
        let search = LinearSearch::build(&points, DistanceMetric::Euclidean);

        let query = [2.0, 5.0, 6.0];
        let neighbors = search.search(&query, 3);

        assert!(neighbors.is_empty());
    }

    #[test]
    fn test_search_k_exceeds_points() {
        let points = sample_points();
        let search = LinearSearch::build(&points, DistanceMetric::Euclidean);

        let query = [2.0, 5.0, 6.0];
        let neighbors = search.search(&query, 100);

        // Asking for more neighbors than there are points returns all of them.
        assert_eq!(neighbors.len(), 14);
        // Results are ordered nearest-first; tie order is deliberately not pinned.
        assert!(neighbors
            .windows(2)
            .all(|pair| pair[0].distance <= pair[1].distance));
        // Point 5 is the unique farthest point from the query.
        assert_eq!(neighbors[13].index, 5);
    }

    #[test]
    fn test_search_nearest() {
        let points = sample_points();
        let search = LinearSearch::build(&points, DistanceMetric::Euclidean);

        let query = [2.0, 5.0, 6.0];
        let nearest = search.search_nearest(&query).unwrap();

        assert_eq!(nearest.index, 12);
        assert_eq!(nearest.distance, 8.0_f32.sqrt());
    }

    #[test]
    fn test_search_nearest_empty() {
        let points = empty_points();
        let search = LinearSearch::build(&points, DistanceMetric::Euclidean);

        let query = [2.0, 5.0, 6.0];
        let nearest = search.search_nearest(&query);

        assert!(nearest.is_none());
    }

    #[test]
    fn test_search_radius() {
        let points = sample_points();
        let search = LinearSearch::build(&points, DistanceMetric::Euclidean);

        let query = [2.0, 5.0, 6.0];
        let neighbors = search.search_radius(&query, 4.0);

        assert_eq!(neighbors.len(), 2);
        assert_eq!(neighbors[0].index, 12);
        assert_eq!(neighbors[0].distance, 8.0_f32.sqrt());
        assert_eq!(neighbors[1].index, 13);
        assert_eq!(neighbors[1].distance, 8.0_f32.sqrt());
    }

    #[test]
    fn test_search_radius_zero() {
        let points = sample_points();
        let search = LinearSearch::build(&points, DistanceMetric::Euclidean);

        // No sample point coincides with the query, so a zero radius matches nothing.
        let query = [2.0, 5.0, 6.0];
        let neighbors = search.search_radius(&query, 0.0);

        assert!(neighbors.is_empty());
    }

    #[test]
    fn test_search_radius_negative() {
        let points = sample_points();
        let search = LinearSearch::build(&points, DistanceMetric::Euclidean);

        // A negative radius is rejected up front and yields no neighbors.
        let query = [2.0, 5.0, 6.0];
        let neighbors = search.search_radius(&query, -1.0);

        assert!(neighbors.is_empty());
    }

    #[test]
    fn test_search_radius_empty() {
        let points = empty_points();
        let search = LinearSearch::build(&points, DistanceMetric::Euclidean);

        let query = [2.0, 5.0, 6.0];
        let neighbors = search.search_radius(&query, 4.0);

        assert!(neighbors.is_empty());
    }
}