radiate_core/domain/math/
knn.rs1use crate::diversity::Distance;
2use std::{cmp::Ordering, sync::Arc};
3
/// Outcome of one k-nearest-neighbour query: the neighbours found plus distance bounds.
///
/// The neighbour slice is borrowed, so a result lives only as long as the buffer it
/// points into.
pub struct KnnQueryResult<'a> {
    /// `(point index, distance)` pairs for the neighbours found.
    pub cluster: &'a [(usize, f32)],
    /// Upper distance bound supplied by the producer. `KNN` queries compute it over every
    /// candidate scanned, not only those kept in `cluster`, and use `-inf` when nothing
    /// was scanned.
    pub max_distance: f32,
    /// Lower distance bound supplied by the producer. `KNN` queries compute it over every
    /// candidate scanned, not only those kept in `cluster`, and use `+inf` when nothing
    /// was scanned.
    pub min_distance: f32,
}

impl<'a> KnnQueryResult<'a> {
    /// Bundles a neighbour slice with its distance bounds; performs no validation.
    pub fn new(cluster: &'a [(usize, f32)], max_distance: f32, min_distance: f32) -> Self {
        Self {
            cluster,
            max_distance,
            min_distance,
        }
    }

    /// Arithmetic mean of the distances stored in `cluster`.
    ///
    /// Returns `0.0` for an empty cluster instead of dividing by zero.
    pub fn average_distance(&self) -> f32 {
        match self.cluster.len() {
            0 => 0.0,
            count => {
                let total: f32 = self.cluster.iter().map(|&(_, d)| d).sum();
                total / count as f32
            }
        }
    }
}
28
29pub struct KNN<'a, P> {
34 points: &'a [P],
35 metric: Arc<dyn Distance<P>>,
36 scratch: Vec<(usize, f32)>,
37}
38
39impl<'a, P> KNN<'a, P> {
40 #[inline]
41 pub fn new(points: &'a [P], metric: impl Into<Arc<dyn Distance<P>>>) -> Self {
42 let len = points.len();
43 KNN {
44 points,
45 metric: metric.into(),
46 scratch: Vec::with_capacity(len.saturating_sub(1)),
47 }
48 }
49
50 #[inline]
52 pub fn points(&self) -> &'a [P] {
53 self.points
54 }
55
56 pub fn query_index(&mut self, query_index: usize, k: usize) -> KnnQueryResult<'_> {
65 let len = self.points.len();
66 if len == 0 || k == 0 {
67 self.scratch.clear();
68 return KnnQueryResult::new(&self.scratch, f32::NEG_INFINITY, f32::INFINITY);
69 }
70
71 let points = &self.points[query_index];
72 self.query_point_internal(points, Some(query_index), k)
73 }
74
75 pub fn query_point(&mut self, query: &P, k: usize) -> KnnQueryResult<'_> {
78 if self.points.is_empty() || k == 0 {
79 self.scratch.clear();
80 return KnnQueryResult::new(&self.scratch, f32::NEG_INFINITY, f32::INFINITY);
81 }
82
83 self.query_point_internal(query, None, k)
84 }
85
86 #[inline]
87 fn query_point_internal(
88 &mut self,
89 query: &P,
90 query_index: Option<usize>,
91 k: usize,
92 ) -> KnnQueryResult<'_> {
93 self.scratch.clear();
94
95 let mut min_distance = f32::INFINITY;
96 let mut max_distance = f32::NEG_INFINITY;
97 for (idx, p) in self.points.iter().enumerate() {
98 if let Some(qi) = query_index {
99 if qi == idx {
100 continue;
101 }
102 }
103
104 let dist = self.metric.distance(query, p).max(1e-12);
105 min_distance = min_distance.min(dist);
106 max_distance = max_distance.max(dist);
107 self.scratch.push((idx, dist));
108 }
109
110 let n = self.scratch.len();
111 if n == 0 || k == 0 {
112 self.scratch.clear();
113 return KnnQueryResult::new(&self.scratch, max_distance, min_distance);
114 }
115
116 if k >= n {
117 self.scratch
118 .sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
119 return KnnQueryResult::new(&self.scratch, max_distance, min_distance);
120 }
121
122 let (left, _, _) = self.scratch.select_nth_unstable_by(k - 1, |a, b| {
123 a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal)
124 });
125
126 left.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
127
128 self.scratch.truncate(k);
129
130 KnnQueryResult::new(&self.scratch, max_distance, min_distance)
131 }
132}