Skip to main content

radiate_core/domain/math/
knn.rs

1use crate::diversity::Distance;
2use std::{cmp::Ordering, sync::Arc};
3
/// Outcome of a single nearest-neighbour query.
///
/// The result only borrows its neighbour list, so it cannot outlive the
/// buffer the list is stored in.
pub struct KnnQueryResult<'a> {
    /// Neighbours as `(index, distance)` pairs.
    pub cluster: &'a [(usize, f32)],
    /// Largest distance reported with this result. Stored exactly as passed
    /// to [`KnnQueryResult::new`]; it is not derived from `cluster`.
    pub max_distance: f32,
    /// Smallest distance reported with this result. Stored exactly as passed
    /// to [`KnnQueryResult::new`]; it is not derived from `cluster`.
    pub min_distance: f32,
}

impl<'a> KnnQueryResult<'a> {
    /// Bundles a neighbour slice with its accompanying distance bounds.
    ///
    /// Note the argument order: `max_distance` comes before `min_distance`.
    pub fn new(cluster: &'a [(usize, f32)], max_distance: f32, min_distance: f32) -> Self {
        Self {
            cluster,
            max_distance,
            min_distance,
        }
    }

    /// Mean of the distances stored in `cluster`.
    ///
    /// Returns `0.0` for an empty cluster instead of dividing by zero.
    pub fn average_distance(&self) -> f32 {
        let count = self.cluster.len();
        if count == 0 {
            return 0.0;
        }

        let sum: f32 = self.cluster.iter().map(|&(_, distance)| distance).sum();
        sum / count as f32
    }
}
28
29/// Brute-force KNN index over a slice of points.
30///
31/// - `P`: point type, must implement `KnnPoint`
32/// - `M`: distance metric
33pub struct KNN<'a, P> {
34    points: &'a [P],
35    metric: Arc<dyn Distance<P>>,
36    scratch: Vec<(usize, f32)>,
37}
38
39impl<'a, P> KNN<'a, P> {
40    #[inline]
41    pub fn new(points: &'a [P], metric: impl Into<Arc<dyn Distance<P>>>) -> Self {
42        let len = points.len();
43        KNN {
44            points,
45            metric: metric.into(),
46            scratch: Vec::with_capacity(len.saturating_sub(1)),
47        }
48    }
49
50    /// Returns a reference to the underlying points.
51    #[inline]
52    pub fn points(&self) -> &'a [P] {
53        self.points
54    }
55
56    /// Query the k nearest neighbors of the point at `query_index`.
57    ///
58    /// - `k`: number of neighbors to return
59    /// - `exclude_self`: if true, will skip the point at `query_index`
60    ///
61    /// Returns a slice of `(index, distance)` sorted by increasing distance.
62    /// The slice is backed by an internal scratch buffer and is invalidated
63    /// by the next query.
64    pub fn query_index(&mut self, query_index: usize, k: usize) -> KnnQueryResult<'_> {
65        let len = self.points.len();
66        if len == 0 || k == 0 {
67            self.scratch.clear();
68            return KnnQueryResult::new(&self.scratch, f32::NEG_INFINITY, f32::INFINITY);
69        }
70
71        let points = &self.points[query_index];
72        self.query_point_internal(points, Some(query_index), k)
73    }
74
75    /// Query the k nearest neighbors of an arbitrary query point (not
76    /// necessarily in the index).
77    pub fn query_point(&mut self, query: &P, k: usize) -> KnnQueryResult<'_> {
78        if self.points.is_empty() || k == 0 {
79            self.scratch.clear();
80            return KnnQueryResult::new(&self.scratch, f32::NEG_INFINITY, f32::INFINITY);
81        }
82
83        self.query_point_internal(query, None, k)
84    }
85
86    #[inline]
87    fn query_point_internal(
88        &mut self,
89        query: &P,
90        query_index: Option<usize>,
91        k: usize,
92    ) -> KnnQueryResult<'_> {
93        self.scratch.clear();
94
95        let mut min_distance = f32::INFINITY;
96        let mut max_distance = f32::NEG_INFINITY;
97        for (idx, p) in self.points.iter().enumerate() {
98            if let Some(qi) = query_index {
99                if qi == idx {
100                    continue;
101                }
102            }
103
104            let dist = self.metric.distance(query, p).max(1e-12);
105            min_distance = min_distance.min(dist);
106            max_distance = max_distance.max(dist);
107            self.scratch.push((idx, dist));
108        }
109
110        let n = self.scratch.len();
111        if n == 0 || k == 0 {
112            self.scratch.clear();
113            return KnnQueryResult::new(&self.scratch, max_distance, min_distance);
114        }
115
116        if k >= n {
117            self.scratch
118                .sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
119            return KnnQueryResult::new(&self.scratch, max_distance, min_distance);
120        }
121
122        let (left, _, _) = self.scratch.select_nth_unstable_by(k - 1, |a, b| {
123            a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal)
124        });
125
126        left.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
127
128        self.scratch.truncate(k);
129
130        KnnQueryResult::new(&self.scratch, max_distance, min_distance)
131    }
132}