manifoldb_vector/ops/
exact_knn.rs

//! Exact K-Nearest Neighbors operator.
//!
//! Performs a brute-force k-NN search by computing the distance from the query to every candidate vector.
4
5use std::cmp::Ordering;
6use std::collections::BinaryHeap;
7
8use manifoldb_core::EntityId;
9
10use super::{SearchConfig, VectorMatch, VectorOperator};
11use crate::distance::{cosine_distance, dot_product, euclidean_distance, DistanceMetric};
12use crate::error::VectorError;
13use crate::types::Embedding;
14
15/// Exact k-NN search operator using brute force.
16///
17/// Computes distances to all provided vectors and returns the K nearest.
18/// This is useful for:
19/// - Small datasets where HNSW overhead isn't justified
20/// - Validating HNSW results
21/// - Post-filtering a small candidate set from graph traversal
22///
23/// # Complexity
24///
25/// O(n * d) where n is the number of vectors and d is the dimension.
26/// For large datasets, use [`AnnScan`](super::AnnScan) instead.
27///
28/// # Example
29///
30/// ```ignore
31/// use manifoldb_vector::ops::{ExactKnn, VectorOperator, SearchConfig};
32/// use manifoldb_vector::distance::DistanceMetric;
33///
34/// let vectors = vec![
35///     (EntityId::new(1), embedding1),
36///     (EntityId::new(2), embedding2),
37/// ];
38///
39/// let config = SearchConfig::k_nearest(5);
40/// let mut knn = ExactKnn::new(vectors, query, DistanceMetric::Cosine, config)?;
41///
42/// while let Some(m) = knn.next()? {
43///     println!("Entity {:?} at distance {}", m.entity_id, m.distance);
44/// }
45/// ```
46pub struct ExactKnn {
47    /// Pre-computed and sorted results.
48    results: Vec<VectorMatch>,
49    /// Current position in results.
50    position: usize,
51    /// Dimension of the vectors.
52    dim: usize,
53}
54
55/// Wrapper for max-heap comparison (we want smallest distances first).
56#[derive(Debug)]
57struct MaxHeapEntry {
58    entity_id: EntityId,
59    distance: f32,
60}
61
62impl PartialEq for MaxHeapEntry {
63    fn eq(&self, other: &Self) -> bool {
64        self.distance == other.distance
65    }
66}
67
68impl Eq for MaxHeapEntry {}
69
70impl PartialOrd for MaxHeapEntry {
71    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
72        Some(self.cmp(other))
73    }
74}
75
76impl Ord for MaxHeapEntry {
77    fn cmp(&self, other: &Self) -> Ordering {
78        // Max-heap: larger distances should come first (to be popped)
79        // NaN values are treated as equal to maintain a total ordering for the heap.
80        // In practice, NaN distances should not occur from valid distance calculations.
81        self.distance.partial_cmp(&other.distance).unwrap_or(Ordering::Equal)
82    }
83}
84
85impl ExactKnn {
86    /// Create a new exact k-NN search operator.
87    ///
88    /// # Arguments
89    ///
90    /// * `vectors` - Iterator of (`entity_id`, embedding) pairs to search
91    /// * `query` - The query embedding
92    /// * `metric` - Distance metric to use
93    /// * `config` - Search configuration
94    ///
95    /// # Errors
96    ///
97    /// Returns an error if any vector has a dimension mismatch with the query.
98    pub fn new<I>(
99        vectors: I,
100        query: &Embedding,
101        metric: DistanceMetric,
102        config: SearchConfig,
103    ) -> Result<Self, VectorError>
104    where
105        I: IntoIterator<Item = (EntityId, Embedding)>,
106    {
107        let dim = query.dimension();
108        let query_slice = query.as_slice();
109        let k = config.k;
110        let max_distance = config.max_distance;
111
112        // Use a max-heap to keep track of k smallest distances
113        // Use saturating_add to avoid overflow when k is usize::MAX
114        let mut heap: BinaryHeap<MaxHeapEntry> =
115            BinaryHeap::with_capacity(k.saturating_add(1).min(1024));
116
117        for (entity_id, embedding) in vectors {
118            // Validate dimension
119            if embedding.dimension() != dim {
120                return Err(VectorError::DimensionMismatch {
121                    expected: dim,
122                    actual: embedding.dimension(),
123                });
124            }
125
126            let distance = compute_distance(query_slice, embedding.as_slice(), metric);
127
128            // Skip if exceeds max_distance
129            if let Some(max_dist) = max_distance {
130                if distance > max_dist {
131                    continue;
132                }
133            }
134
135            // Add to heap if within k or better than worst
136            if heap.len() < k {
137                heap.push(MaxHeapEntry { entity_id, distance });
138            } else if let Some(worst) = heap.peek() {
139                if distance < worst.distance {
140                    heap.pop();
141                    heap.push(MaxHeapEntry { entity_id, distance });
142                }
143            }
144        }
145
146        // Convert heap to sorted vec
147        let mut results: Vec<VectorMatch> =
148            heap.into_iter().map(|e| VectorMatch::new(e.entity_id, e.distance)).collect();
149
150        // Sort by distance (ascending). NaN distances are treated as equal to maintain
151        // a stable sort order; in practice NaN should not occur from valid calculations.
152        results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap_or(Ordering::Equal));
153
154        Ok(Self { results, position: 0, dim })
155    }
156
157    /// Create a k-nearest search from an iterator of vectors.
158    ///
159    /// Convenience method for common case.
160    pub fn k_nearest<I>(
161        vectors: I,
162        query: &Embedding,
163        metric: DistanceMetric,
164        k: usize,
165    ) -> Result<Self, VectorError>
166    where
167        I: IntoIterator<Item = (EntityId, Embedding)>,
168    {
169        Self::new(vectors, query, metric, SearchConfig::k_nearest(k))
170    }
171
172    /// Create a within-distance search from an iterator of vectors.
173    ///
174    /// Returns all vectors within the specified distance threshold.
175    pub fn within_distance<I>(
176        vectors: I,
177        query: &Embedding,
178        metric: DistanceMetric,
179        max_distance: f32,
180    ) -> Result<Self, VectorError>
181    where
182        I: IntoIterator<Item = (EntityId, Embedding)>,
183    {
184        Self::new(vectors, query, metric, SearchConfig::within_distance(max_distance))
185    }
186
187    /// Create from a slice of vectors (borrows and clones).
188    ///
189    /// Useful when you have a reference to existing data.
190    pub fn from_slice(
191        vectors: &[(EntityId, Embedding)],
192        query: &Embedding,
193        metric: DistanceMetric,
194        config: SearchConfig,
195    ) -> Result<Self, VectorError> {
196        Self::new(vectors.iter().cloned(), query, metric, config)
197    }
198
199    /// Get the number of results found.
200    #[must_use]
201    pub fn len(&self) -> usize {
202        self.results.len()
203    }
204
205    /// Check if no results were found.
206    #[must_use]
207    pub fn is_empty(&self) -> bool {
208        self.results.is_empty()
209    }
210
211    /// Peek at the next result without consuming it.
212    #[must_use]
213    pub fn peek(&self) -> Option<&VectorMatch> {
214        self.results.get(self.position)
215    }
216
217    /// Reset the iterator to the beginning.
218    pub fn reset(&mut self) {
219        self.position = 0;
220    }
221
222    /// Get all results as a slice.
223    #[must_use]
224    pub fn as_slice(&self) -> &[VectorMatch] {
225        &self.results
226    }
227}
228
229impl VectorOperator for ExactKnn {
230    fn next(&mut self) -> Result<Option<VectorMatch>, VectorError> {
231        if self.position < self.results.len() {
232            let result = self.results[self.position];
233            self.position += 1;
234            Ok(Some(result))
235        } else {
236            Ok(None)
237        }
238    }
239
240    fn dimension(&self) -> usize {
241        self.dim
242    }
243}
244
245/// Compute distance between two vectors using the specified metric.
246fn compute_distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
247    use crate::distance::{chebyshev_distance, manhattan_distance};
248    match metric {
249        DistanceMetric::Euclidean => euclidean_distance(a, b),
250        DistanceMetric::Cosine => cosine_distance(a, b),
251        DistanceMetric::DotProduct => -dot_product(a, b), // Negate for min-distance
252        DistanceMetric::Manhattan => manhattan_distance(a, b),
253        DistanceMetric::Chebyshev => chebyshev_distance(a, b),
254    }
255}
256
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an embedding whose `dim` components all equal `value`.
    fn create_test_embedding(dim: usize, value: f32) -> Embedding {
        Embedding::new(vec![value; dim]).unwrap()
    }

    /// Build `count` entities with ids `1..=count`; entity `i` is the
    /// 4-dimensional vector filled with the value `i`.
    fn create_test_vectors(count: usize) -> Vec<(EntityId, Embedding)> {
        (1..=count).map(|i| (EntityId::new(i as u64), create_test_embedding(4, i as f32))).collect()
    }

    /// Searching an empty input yields an empty, immediately exhausted operator.
    #[test]
    fn test_exact_knn_empty() {
        let query = create_test_embedding(4, 1.0);
        let vectors: Vec<(EntityId, Embedding)> = Vec::new();

        let mut knn = ExactKnn::k_nearest(vectors, &query, DistanceMetric::Euclidean, 5).unwrap();

        assert!(knn.is_empty());
        assert!(knn.next().unwrap().is_none());
    }

    /// A single candidate identical to the query is returned at distance ~0.
    #[test]
    fn test_exact_knn_single() {
        let query = create_test_embedding(4, 1.0);
        let vectors = vec![(EntityId::new(1), create_test_embedding(4, 1.0))];

        let mut knn = ExactKnn::k_nearest(vectors, &query, DistanceMetric::Euclidean, 5).unwrap();

        assert_eq!(knn.len(), 1);
        let result = knn.next().unwrap().unwrap();
        assert_eq!(result.entity_id, EntityId::new(1));
        assert!(result.distance < 1e-6);
    }

    /// With more candidates than `k`, exactly `k` sorted results come back.
    #[test]
    fn test_exact_knn_k_smaller_than_n() {
        let query = create_test_embedding(4, 5.0);
        let vectors = create_test_vectors(10);

        let mut knn = ExactKnn::k_nearest(vectors, &query, DistanceMetric::Euclidean, 3).unwrap();

        let results = knn.collect_all().unwrap();
        assert_eq!(results.len(), 3);

        // Results should be sorted by distance
        assert!(results[0].distance <= results[1].distance);
        assert!(results[1].distance <= results[2].distance);

        // Closest should be entity 5 (same value as query)
        assert_eq!(results[0].entity_id, EntityId::new(5));
    }

    /// With fewer candidates than `k`, every candidate is returned.
    #[test]
    fn test_exact_knn_k_larger_than_n() {
        let query = create_test_embedding(4, 1.0);
        let vectors = create_test_vectors(3);

        let knn = ExactKnn::k_nearest(vectors, &query, DistanceMetric::Euclidean, 10).unwrap();

        assert_eq!(knn.len(), 3);
    }

    /// A distance threshold keeps the exact match and nothing beyond the limit.
    #[test]
    fn test_exact_knn_with_max_distance() {
        let query = create_test_embedding(4, 5.0);
        let vectors = create_test_vectors(10);

        let mut knn =
            ExactKnn::within_distance(vectors, &query, DistanceMetric::Euclidean, 2.5).unwrap();

        let results = knn.collect_all().unwrap();

        // Guard against a vacuous pass: entity 5 equals the query, so it must be
        // present and must be the closest match.
        assert!(!results.is_empty());
        assert_eq!(results[0].entity_id, EntityId::new(5));

        for result in &results {
            assert!(result.distance <= 2.5);
        }
    }

    /// Cosine distance ranks same-direction < orthogonal < opposite.
    #[test]
    fn test_exact_knn_cosine_distance() {
        let query = Embedding::new(vec![1.0, 0.0, 0.0, 0.0]).unwrap();
        let vectors = vec![
            (EntityId::new(1), Embedding::new(vec![1.0, 0.0, 0.0, 0.0]).unwrap()), // Same direction
            (EntityId::new(2), Embedding::new(vec![0.0, 1.0, 0.0, 0.0]).unwrap()), // Orthogonal
            (EntityId::new(3), Embedding::new(vec![-1.0, 0.0, 0.0, 0.0]).unwrap()), // Opposite
        ];

        let mut knn = ExactKnn::k_nearest(vectors, &query, DistanceMetric::Cosine, 3).unwrap();

        let results = knn.collect_all().unwrap();

        // Entity 1 should be closest (cosine distance = 0)
        assert_eq!(results[0].entity_id, EntityId::new(1));
        assert!(results[0].distance < 1e-6);

        // Entity 2 should be next (cosine distance = 1)
        assert_eq!(results[1].entity_id, EntityId::new(2));
        assert!((results[1].distance - 1.0).abs() < 1e-6);

        // Entity 3 should be furthest (cosine distance = 2)
        assert_eq!(results[2].entity_id, EntityId::new(3));
        assert!((results[2].distance - 2.0).abs() < 1e-6);
    }

    /// The dot-product metric ranks by descending similarity (negated distance).
    #[test]
    fn test_exact_knn_dot_product() {
        let query = Embedding::new(vec![1.0, 1.0, 0.0, 0.0]).unwrap();
        let vectors = vec![
            (EntityId::new(1), Embedding::new(vec![2.0, 2.0, 0.0, 0.0]).unwrap()), // Dot = 4
            (EntityId::new(2), Embedding::new(vec![1.0, 0.0, 0.0, 0.0]).unwrap()), // Dot = 1
            (EntityId::new(3), Embedding::new(vec![0.0, 0.0, 1.0, 1.0]).unwrap()), // Dot = 0
        ];

        let mut knn = ExactKnn::k_nearest(vectors, &query, DistanceMetric::DotProduct, 3).unwrap();

        let results = knn.collect_all().unwrap();

        // Entity 1 should be closest (highest dot product = -4 distance)
        assert_eq!(results[0].entity_id, EntityId::new(1));
        assert!((results[0].distance - (-4.0)).abs() < 1e-6);
    }

    /// A candidate of the wrong dimension aborts the search with an error.
    #[test]
    fn test_exact_knn_dimension_mismatch() {
        let query = create_test_embedding(4, 1.0);
        let vectors = vec![(EntityId::new(1), create_test_embedding(8, 1.0))]; // Wrong dimension

        let result = ExactKnn::k_nearest(vectors, &query, DistanceMetric::Euclidean, 5);

        assert!(matches!(result, Err(VectorError::DimensionMismatch { .. })));
    }

    /// `from_slice` behaves like `new` on borrowed input.
    #[test]
    fn test_exact_knn_from_slice() {
        let query = create_test_embedding(4, 5.0);
        let vectors = create_test_vectors(10);

        let mut knn = ExactKnn::from_slice(
            &vectors,
            &query,
            DistanceMetric::Euclidean,
            SearchConfig::k_nearest(3),
        )
        .unwrap();

        assert_eq!(knn.len(), 3);
        assert_eq!(knn.collect_all().unwrap()[0].entity_id, EntityId::new(5));
    }

    /// `peek` does not consume, and `reset` rewinds an exhausted operator.
    #[test]
    fn test_exact_knn_peek_and_reset() {
        let query = create_test_embedding(4, 1.0);
        let vectors = create_test_vectors(3);

        let mut knn = ExactKnn::k_nearest(vectors, &query, DistanceMetric::Euclidean, 5).unwrap();

        let first_id = knn.peek().unwrap().entity_id;
        assert_eq!(knn.next().unwrap().unwrap().entity_id, first_id);

        // Exhaust
        while knn.next().unwrap().is_some() {}

        // Reset and check first again
        knn.reset();
        assert_eq!(knn.peek().unwrap().entity_id, first_id);
    }

    /// `as_slice` exposes every result in ascending distance order.
    #[test]
    fn test_exact_knn_as_slice() {
        let query = create_test_embedding(4, 1.0);
        let vectors = create_test_vectors(5);

        let knn = ExactKnn::k_nearest(vectors, &query, DistanceMetric::Euclidean, 5).unwrap();

        let slice = knn.as_slice();
        assert_eq!(slice.len(), 5);
        // First should be closest to query value 1.0
        assert_eq!(slice[0].entity_id, EntityId::new(1));
    }
}