manifoldb_vector/ops/
ann_scan.rs

1//! Approximate Nearest Neighbor scan operator.
2//!
3//! Uses the HNSW index for fast approximate k-NN search.
4
5use manifoldb_storage::StorageEngine;
6
7use super::{SearchConfig, VectorMatch, VectorOperator};
8use crate::error::VectorError;
9use crate::index::{HnswIndex, VectorIndex};
10use crate::types::Embedding;
11
12/// Approximate nearest neighbor search operator.
13///
14/// Uses an HNSW index to efficiently find the K nearest neighbors to a query
15/// vector. The search is approximate, trading some accuracy for speed.
16///
17/// # Example
18///
19/// ```ignore
20/// use manifoldb_vector::ops::{AnnScan, VectorOperator, SearchConfig};
21///
22/// let config = SearchConfig::k_nearest(10).with_ef_search(50);
23/// let mut scan = AnnScan::new(&index, query, config)?;
24///
25/// while let Some(m) = scan.next()? {
26///     println!("Entity {:?} at distance {}", m.entity_id, m.distance);
27/// }
28/// ```
29pub struct AnnScan<'a, E: StorageEngine> {
30    /// Reference to the HNSW index.
31    index: &'a HnswIndex<E>,
32    /// Pre-computed search results (buffered).
33    results: Vec<VectorMatch>,
34    /// Current position in the results.
35    position: usize,
36    /// The dimension of the index.
37    dim: usize,
38}
39
40impl<'a, E: StorageEngine> AnnScan<'a, E> {
41    /// Create a new ANN scan operator.
42    ///
43    /// Performs the HNSW search immediately and buffers the results for
44    /// iteration. This is because HNSW search is most efficient when
45    /// performed as a single operation.
46    ///
47    /// # Arguments
48    ///
49    /// * `index` - Reference to the HNSW index
50    /// * `query` - The query embedding
51    /// * `config` - Search configuration (k, `max_distance`, `ef_search`)
52    ///
53    /// # Errors
54    ///
55    /// Returns an error if the query dimension doesn't match the index.
56    pub fn new(
57        index: &'a HnswIndex<E>,
58        query: &Embedding,
59        config: SearchConfig,
60    ) -> Result<Self, VectorError> {
61        let dim = index.dimension()?;
62
63        // Validate dimension
64        if query.dimension() != dim {
65            return Err(VectorError::DimensionMismatch {
66                expected: dim,
67                actual: query.dimension(),
68            });
69        }
70
71        // Perform the search
72        let search_results = index.search(query, config.k, config.ef_search)?;
73
74        // Convert and filter by max_distance if specified
75        let results: Vec<VectorMatch> = search_results
76            .into_iter()
77            .filter(|r| match config.max_distance {
78                Some(max_dist) => r.distance <= max_dist,
79                None => true,
80            })
81            .map(VectorMatch::from)
82            .collect();
83
84        Ok(Self { index, results, position: 0, dim })
85    }
86
87    /// Create an ANN scan with simple k-nearest configuration.
88    ///
89    /// Convenience method for common case of finding K nearest neighbors.
90    ///
91    /// # Arguments
92    ///
93    /// * `index` - Reference to the HNSW index
94    /// * `query` - The query embedding
95    /// * `k` - Number of nearest neighbors to find
96    ///
97    /// # Errors
98    ///
99    /// Returns an error if the query dimension doesn't match the index.
100    pub fn k_nearest(
101        index: &'a HnswIndex<E>,
102        query: &Embedding,
103        k: usize,
104    ) -> Result<Self, VectorError> {
105        Self::new(index, query, SearchConfig::k_nearest(k))
106    }
107
108    /// Create an ANN scan to find vectors within a distance threshold.
109    ///
110    /// Note: HNSW is optimized for k-NN search. For within-distance queries,
111    /// this searches for a large k and filters by distance. Consider using
112    /// `ExactKnn` for precise distance-based filtering on small sets.
113    ///
114    /// # Arguments
115    ///
116    /// * `index` - Reference to the HNSW index
117    /// * `query` - The query embedding
118    /// * `max_distance` - Maximum distance threshold
119    /// * `max_results` - Maximum number of results to return
120    ///
121    /// # Errors
122    ///
123    /// Returns an error if the query dimension doesn't match the index.
124    pub fn within_distance(
125        index: &'a HnswIndex<E>,
126        query: &Embedding,
127        max_distance: f32,
128        max_results: usize,
129    ) -> Result<Self, VectorError> {
130        Self::new(index, query, SearchConfig::within_distance(max_distance).with_k(max_results))
131    }
132
133    /// Get the underlying index.
134    #[must_use]
135    pub const fn index(&self) -> &'a HnswIndex<E> {
136        self.index
137    }
138
139    /// Get the number of results found.
140    #[must_use]
141    pub fn len(&self) -> usize {
142        self.results.len()
143    }
144
145    /// Check if no results were found.
146    #[must_use]
147    pub fn is_empty(&self) -> bool {
148        self.results.is_empty()
149    }
150
151    /// Peek at the next result without consuming it.
152    #[must_use]
153    pub fn peek(&self) -> Option<&VectorMatch> {
154        self.results.get(self.position)
155    }
156
157    /// Reset the iterator to the beginning.
158    pub fn reset(&mut self) {
159        self.position = 0;
160    }
161}
162
163impl<E: StorageEngine> VectorOperator for AnnScan<'_, E> {
164    fn next(&mut self) -> Result<Option<VectorMatch>, VectorError> {
165        if self.position < self.results.len() {
166            let result = self.results[self.position];
167            self.position += 1;
168            Ok(Some(result))
169        } else {
170            Ok(None)
171        }
172    }
173
174    fn dimension(&self) -> usize {
175        self.dim
176    }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distance::DistanceMetric;
    use crate::index::HnswConfig;
    use manifoldb_core::EntityId;
    use manifoldb_storage::backends::RedbEngine;

    /// Build a `dim`-dimensional embedding with every component set to `value`.
    fn create_test_embedding(dim: usize, value: f32) -> Embedding {
        Embedding::new(vec![value; dim]).unwrap()
    }

    /// Build an empty 4-dimensional Euclidean HNSW index backed by in-memory storage.
    fn create_test_index() -> HnswIndex<RedbEngine> {
        let engine = RedbEngine::in_memory().unwrap();
        HnswIndex::new(engine, "test", 4, DistanceMetric::Euclidean, HnswConfig::new(4)).unwrap()
    }

    #[test]
    fn test_ann_scan_empty_index() {
        let index = create_test_index();
        let query = create_test_embedding(4, 1.0);

        let mut scan = AnnScan::k_nearest(&index, &query, 5).unwrap();
        assert!(scan.is_empty());
        assert!(scan.next().unwrap().is_none());
    }

    #[test]
    fn test_ann_scan_single_result() {
        let mut index = create_test_index();
        let embedding = create_test_embedding(4, 1.0);
        index.insert(EntityId::new(1), &embedding).unwrap();

        let query = create_test_embedding(4, 1.0);
        let mut scan = AnnScan::k_nearest(&index, &query, 5).unwrap();

        assert_eq!(scan.len(), 1);
        let result = scan.next().unwrap().unwrap();
        assert_eq!(result.entity_id, EntityId::new(1));
        assert!(result.distance < 1e-6);

        assert!(scan.next().unwrap().is_none());
    }

    #[test]
    fn test_ann_scan_multiple_results() {
        let mut index = create_test_index();
        for i in 1..=10 {
            let embedding = create_test_embedding(4, i as f32);
            index.insert(EntityId::new(i), &embedding).unwrap();
        }

        let query = create_test_embedding(4, 5.0);
        let mut scan = AnnScan::k_nearest(&index, &query, 3).unwrap();

        assert_eq!(scan.len(), 3);

        let results = scan.collect_all().unwrap();
        assert_eq!(results.len(), 3);

        // Results should be sorted by distance
        assert!(results[0].distance <= results[1].distance);
        assert!(results[1].distance <= results[2].distance);
    }

    #[test]
    fn test_ann_scan_with_max_distance() {
        let mut index = create_test_index();
        for i in 1..=10 {
            let embedding = create_test_embedding(4, i as f32);
            index.insert(EntityId::new(i), &embedding).unwrap();
        }

        // The query coincides with entity 5, so that entity sits at distance ~0
        // and must survive the 2.5 threshold. Every other survivor must also
        // lie within 2.5 of the query.
        let query = create_test_embedding(4, 5.0);
        let mut scan = AnnScan::within_distance(&index, &query, 2.5, 10).unwrap();

        let results = scan.collect_all().unwrap();
        // Guard against the loop below passing vacuously on an empty result set.
        assert!(!results.is_empty(), "exact match should survive the distance filter");
        assert!(results.iter().any(|r| r.entity_id == EntityId::new(5)));
        for result in &results {
            assert!(result.distance <= 2.5);
        }
    }

    #[test]
    fn test_ann_scan_within_distance_respects_max_results() {
        let mut index = create_test_index();
        for i in 1..=10 {
            let embedding = create_test_embedding(4, i as f32);
            index.insert(EntityId::new(i), &embedding).unwrap();
        }

        // With an effectively unbounded threshold, the result count is limited
        // only by `max_results`.
        let query = create_test_embedding(4, 5.0);
        let scan = AnnScan::within_distance(&index, &query, f32::MAX, 3).unwrap();
        assert_eq!(scan.len(), 3);
    }

    #[test]
    fn test_ann_scan_dimension_mismatch() {
        let index = create_test_index();
        let query = create_test_embedding(8, 1.0); // Wrong dimension

        let result = AnnScan::k_nearest(&index, &query, 5);
        assert!(matches!(result, Err(VectorError::DimensionMismatch { .. })));
    }

    #[test]
    fn test_ann_scan_peek_and_reset() {
        let mut index = create_test_index();
        let embedding = create_test_embedding(4, 1.0);
        index.insert(EntityId::new(1), &embedding).unwrap();

        let query = create_test_embedding(4, 1.0);
        let mut scan = AnnScan::k_nearest(&index, &query, 5).unwrap();

        // Peek shouldn't consume
        let peeked = scan.peek().unwrap();
        assert_eq!(peeked.entity_id, EntityId::new(1));

        // Next should return the same
        let result = scan.next().unwrap().unwrap();
        assert_eq!(result.entity_id, EntityId::new(1));

        // Now exhausted
        assert!(scan.next().unwrap().is_none());
        assert!(scan.peek().is_none());

        // Reset and iterate again
        scan.reset();
        let result = scan.next().unwrap().unwrap();
        assert_eq!(result.entity_id, EntityId::new(1));
    }

    #[test]
    fn test_ann_scan_with_ef_search() {
        let mut index = create_test_index();
        for i in 1..=20 {
            let embedding = create_test_embedding(4, i as f32);
            index.insert(EntityId::new(i), &embedding).unwrap();
        }

        let query = create_test_embedding(4, 10.0);
        let config = SearchConfig::k_nearest(5).with_ef_search(100);
        let mut scan = AnnScan::new(&index, &query, config).unwrap();

        let results = scan.collect_all().unwrap();
        assert_eq!(results.len(), 5);
    }
}