1use std::cmp::Ordering;
6use std::collections::BinaryHeap;
7
8use manifoldb_core::EntityId;
9
10use super::{SearchConfig, VectorMatch, VectorOperator};
11use crate::distance::{cosine_distance, dot_product, euclidean_distance, DistanceMetric};
12use crate::error::VectorError;
13use crate::types::Embedding;
14
15pub struct ExactKnn {
47 results: Vec<VectorMatch>,
49 position: usize,
51 dim: usize,
53}
54
55#[derive(Debug)]
57struct MaxHeapEntry {
58 entity_id: EntityId,
59 distance: f32,
60}
61
62impl PartialEq for MaxHeapEntry {
63 fn eq(&self, other: &Self) -> bool {
64 self.distance == other.distance
65 }
66}
67
68impl Eq for MaxHeapEntry {}
69
70impl PartialOrd for MaxHeapEntry {
71 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
72 Some(self.cmp(other))
73 }
74}
75
76impl Ord for MaxHeapEntry {
77 fn cmp(&self, other: &Self) -> Ordering {
78 self.distance.partial_cmp(&other.distance).unwrap_or(Ordering::Equal)
82 }
83}
84
85impl ExactKnn {
86 pub fn new<I>(
99 vectors: I,
100 query: &Embedding,
101 metric: DistanceMetric,
102 config: SearchConfig,
103 ) -> Result<Self, VectorError>
104 where
105 I: IntoIterator<Item = (EntityId, Embedding)>,
106 {
107 let dim = query.dimension();
108 let query_slice = query.as_slice();
109 let k = config.k;
110 let max_distance = config.max_distance;
111
112 let mut heap: BinaryHeap<MaxHeapEntry> =
115 BinaryHeap::with_capacity(k.saturating_add(1).min(1024));
116
117 for (entity_id, embedding) in vectors {
118 if embedding.dimension() != dim {
120 return Err(VectorError::DimensionMismatch {
121 expected: dim,
122 actual: embedding.dimension(),
123 });
124 }
125
126 let distance = compute_distance(query_slice, embedding.as_slice(), metric);
127
128 if let Some(max_dist) = max_distance {
130 if distance > max_dist {
131 continue;
132 }
133 }
134
135 if heap.len() < k {
137 heap.push(MaxHeapEntry { entity_id, distance });
138 } else if let Some(worst) = heap.peek() {
139 if distance < worst.distance {
140 heap.pop();
141 heap.push(MaxHeapEntry { entity_id, distance });
142 }
143 }
144 }
145
146 let mut results: Vec<VectorMatch> =
148 heap.into_iter().map(|e| VectorMatch::new(e.entity_id, e.distance)).collect();
149
150 results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap_or(Ordering::Equal));
153
154 Ok(Self { results, position: 0, dim })
155 }
156
157 pub fn k_nearest<I>(
161 vectors: I,
162 query: &Embedding,
163 metric: DistanceMetric,
164 k: usize,
165 ) -> Result<Self, VectorError>
166 where
167 I: IntoIterator<Item = (EntityId, Embedding)>,
168 {
169 Self::new(vectors, query, metric, SearchConfig::k_nearest(k))
170 }
171
172 pub fn within_distance<I>(
176 vectors: I,
177 query: &Embedding,
178 metric: DistanceMetric,
179 max_distance: f32,
180 ) -> Result<Self, VectorError>
181 where
182 I: IntoIterator<Item = (EntityId, Embedding)>,
183 {
184 Self::new(vectors, query, metric, SearchConfig::within_distance(max_distance))
185 }
186
187 pub fn from_slice(
191 vectors: &[(EntityId, Embedding)],
192 query: &Embedding,
193 metric: DistanceMetric,
194 config: SearchConfig,
195 ) -> Result<Self, VectorError> {
196 Self::new(vectors.iter().cloned(), query, metric, config)
197 }
198
199 #[must_use]
201 pub fn len(&self) -> usize {
202 self.results.len()
203 }
204
205 #[must_use]
207 pub fn is_empty(&self) -> bool {
208 self.results.is_empty()
209 }
210
211 #[must_use]
213 pub fn peek(&self) -> Option<&VectorMatch> {
214 self.results.get(self.position)
215 }
216
217 pub fn reset(&mut self) {
219 self.position = 0;
220 }
221
222 #[must_use]
224 pub fn as_slice(&self) -> &[VectorMatch] {
225 &self.results
226 }
227}
228
229impl VectorOperator for ExactKnn {
230 fn next(&mut self) -> Result<Option<VectorMatch>, VectorError> {
231 if self.position < self.results.len() {
232 let result = self.results[self.position];
233 self.position += 1;
234 Ok(Some(result))
235 } else {
236 Ok(None)
237 }
238 }
239
240 fn dimension(&self) -> usize {
241 self.dim
242 }
243}
244
245fn compute_distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
247 use crate::distance::{chebyshev_distance, manhattan_distance};
248 match metric {
249 DistanceMetric::Euclidean => euclidean_distance(a, b),
250 DistanceMetric::Cosine => cosine_distance(a, b),
251 DistanceMetric::DotProduct => -dot_product(a, b), DistanceMetric::Manhattan => manhattan_distance(a, b),
253 DistanceMetric::Chebyshev => chebyshev_distance(a, b),
254 }
255}
256
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an embedding of `dim` components, all equal to `value`.
    fn create_test_embedding(dim: usize, value: f32) -> Embedding {
        Embedding::new(vec![value; dim]).unwrap()
    }

    /// Builds `count` 4-dimensional vectors; entity `i` has every component equal to `i`.
    fn create_test_vectors(count: usize) -> Vec<(EntityId, Embedding)> {
        (1..=count).map(|i| (EntityId::new(i as u64), create_test_embedding(4, i as f32))).collect()
    }

    #[test]
    fn test_exact_knn_empty() {
        let query = create_test_embedding(4, 1.0);
        let vectors: Vec<(EntityId, Embedding)> = vec![];

        let mut knn =
            ExactKnn::k_nearest(vectors.into_iter(), &query, DistanceMetric::Euclidean, 5).unwrap();

        assert!(knn.is_empty());
        assert!(knn.next().unwrap().is_none());
    }

    #[test]
    fn test_exact_knn_single() {
        let query = create_test_embedding(4, 1.0);
        let vectors = vec![(EntityId::new(1), create_test_embedding(4, 1.0))];

        let mut knn =
            ExactKnn::k_nearest(vectors.into_iter(), &query, DistanceMetric::Euclidean, 5).unwrap();

        assert_eq!(knn.len(), 1);
        let result = knn.next().unwrap().unwrap();
        assert_eq!(result.entity_id, EntityId::new(1));
        assert!(result.distance < 1e-6);
    }

    #[test]
    fn test_exact_knn_k_smaller_than_n() {
        let query = create_test_embedding(4, 5.0);
        let vectors = create_test_vectors(10);

        let mut knn =
            ExactKnn::k_nearest(vectors.into_iter(), &query, DistanceMetric::Euclidean, 3).unwrap();

        let results = knn.collect_all().unwrap();
        assert_eq!(results.len(), 3);

        // Results come back nearest first.
        assert!(results[0].distance <= results[1].distance);
        assert!(results[1].distance <= results[2].distance);

        // The exact match wins; its two immediate neighbours (tied) fill the other slots.
        assert_eq!(results[0].entity_id, EntityId::new(5));
        let neighbours = [EntityId::new(4), EntityId::new(6)];
        assert!(neighbours.contains(&results[1].entity_id));
        assert!(neighbours.contains(&results[2].entity_id));
        assert_ne!(results[1].entity_id, results[2].entity_id);
    }

    #[test]
    fn test_exact_knn_k_larger_than_n() {
        let query = create_test_embedding(4, 1.0);
        let vectors = create_test_vectors(3);

        let knn = ExactKnn::k_nearest(vectors.into_iter(), &query, DistanceMetric::Euclidean, 10)
            .unwrap();

        assert_eq!(knn.len(), 3);
    }

    #[test]
    fn test_exact_knn_with_max_distance() {
        let query = create_test_embedding(4, 5.0);
        let vectors = create_test_vectors(10);

        let mut knn =
            ExactKnn::within_distance(vectors.into_iter(), &query, DistanceMetric::Euclidean, 2.5)
                .unwrap();

        let results = knn.collect_all().unwrap();

        // The exact match is always within range, so an empty result would be a bug that
        // the per-result checks below could not catch on their own.
        assert!(!results.is_empty());
        assert_eq!(results[0].entity_id, EntityId::new(5));

        for pair in results.windows(2) {
            assert!(pair[0].distance <= pair[1].distance);
        }

        // Only the exact match and its immediate neighbours can be within 2.5.
        let allowed = [EntityId::new(4), EntityId::new(5), EntityId::new(6)];
        for result in &results {
            assert!(result.distance <= 2.5);
            assert!(allowed.contains(&result.entity_id));
        }
    }

    #[test]
    fn test_exact_knn_cosine_distance() {
        let query = Embedding::new(vec![1.0, 0.0, 0.0, 0.0]).unwrap();
        let vectors = vec![
            (EntityId::new(1), Embedding::new(vec![1.0, 0.0, 0.0, 0.0]).unwrap()),
            (EntityId::new(2), Embedding::new(vec![0.0, 1.0, 0.0, 0.0]).unwrap()),
            (EntityId::new(3), Embedding::new(vec![-1.0, 0.0, 0.0, 0.0]).unwrap()),
        ];

        let mut knn =
            ExactKnn::k_nearest(vectors.into_iter(), &query, DistanceMetric::Cosine, 3).unwrap();

        let results = knn.collect_all().unwrap();

        // Same direction: distance 0.
        assert_eq!(results[0].entity_id, EntityId::new(1));
        assert!(results[0].distance < 1e-6);

        // Orthogonal: distance 1.
        assert_eq!(results[1].entity_id, EntityId::new(2));
        assert!((results[1].distance - 1.0).abs() < 1e-6);

        // Opposite direction: distance 2.
        assert_eq!(results[2].entity_id, EntityId::new(3));
        assert!((results[2].distance - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_exact_knn_dot_product() {
        let query = Embedding::new(vec![1.0, 1.0, 0.0, 0.0]).unwrap();
        let vectors = vec![
            (EntityId::new(1), Embedding::new(vec![2.0, 2.0, 0.0, 0.0]).unwrap()),
            (EntityId::new(2), Embedding::new(vec![1.0, 0.0, 0.0, 0.0]).unwrap()),
            (EntityId::new(3), Embedding::new(vec![0.0, 0.0, 1.0, 1.0]).unwrap()),
        ];

        let mut knn =
            ExactKnn::k_nearest(vectors.into_iter(), &query, DistanceMetric::DotProduct, 3)
                .unwrap();

        let results = knn.collect_all().unwrap();

        // Dot products are 4, 1 and 0; negated they rank in that order.
        assert_eq!(results[0].entity_id, EntityId::new(1));
        assert!((results[0].distance - (-4.0)).abs() < 1e-6);
        assert_eq!(results[1].entity_id, EntityId::new(2));
        assert_eq!(results[2].entity_id, EntityId::new(3));
    }

    #[test]
    fn test_exact_knn_dimension_mismatch() {
        let query = create_test_embedding(4, 1.0);
        let vectors = vec![(EntityId::new(1), create_test_embedding(8, 1.0))];

        let result = ExactKnn::k_nearest(vectors.into_iter(), &query, DistanceMetric::Euclidean, 5);

        assert!(matches!(result, Err(VectorError::DimensionMismatch { .. })));
    }

    #[test]
    fn test_exact_knn_from_slice() {
        let query = create_test_embedding(4, 5.0);
        let vectors = create_test_vectors(10);

        let mut knn = ExactKnn::from_slice(
            &vectors,
            &query,
            DistanceMetric::Euclidean,
            SearchConfig::k_nearest(3),
        )
        .unwrap();

        assert_eq!(knn.len(), 3);
        assert_eq!(knn.collect_all().unwrap()[0].entity_id, EntityId::new(5));
    }

    #[test]
    fn test_exact_knn_from_slice_matches_owned_input() {
        let query = create_test_embedding(4, 1.0);
        let vectors = create_test_vectors(6);

        let borrowed = ExactKnn::from_slice(
            &vectors,
            &query,
            DistanceMetric::Euclidean,
            SearchConfig::k_nearest(4),
        )
        .unwrap();
        let owned =
            ExactKnn::k_nearest(vectors.clone().into_iter(), &query, DistanceMetric::Euclidean, 4)
                .unwrap();

        let summarize = |knn: &ExactKnn| -> Vec<(EntityId, f32)> {
            knn.as_slice().iter().map(|m| (m.entity_id, m.distance)).collect()
        };

        assert_eq!(borrowed.len(), 4);
        assert_eq!(summarize(&borrowed), summarize(&owned));
    }

    #[test]
    fn test_exact_knn_peek_and_reset() {
        let query = create_test_embedding(4, 1.0);
        let vectors = create_test_vectors(3);

        let mut knn =
            ExactKnn::k_nearest(vectors.into_iter(), &query, DistanceMetric::Euclidean, 5).unwrap();

        let first_id = knn.peek().unwrap().entity_id;
        assert_eq!(knn.next().unwrap().unwrap().entity_id, first_id);

        // Drain the remaining matches.
        while knn.next().unwrap().is_some() {}

        knn.reset();
        assert_eq!(knn.peek().unwrap().entity_id, first_id);
    }

    #[test]
    fn test_exact_knn_as_slice() {
        let query = create_test_embedding(4, 1.0);
        let vectors = create_test_vectors(5);

        let knn =
            ExactKnn::k_nearest(vectors.into_iter(), &query, DistanceMetric::Euclidean, 5).unwrap();

        let slice = knn.as_slice();
        assert_eq!(slice.len(), 5);
        assert_eq!(slice[0].entity_id, EntityId::new(1));
    }
}