manifoldb_vector/ops/
ann_scan.rs1use manifoldb_storage::StorageEngine;
6
7use super::{SearchConfig, VectorMatch, VectorOperator};
8use crate::error::VectorError;
9use crate::index::{HnswIndex, VectorIndex};
10use crate::types::Embedding;
11
12pub struct AnnScan<'a, E: StorageEngine> {
30 index: &'a HnswIndex<E>,
32 results: Vec<VectorMatch>,
34 position: usize,
36 dim: usize,
38}
39
40impl<'a, E: StorageEngine> AnnScan<'a, E> {
41 pub fn new(
57 index: &'a HnswIndex<E>,
58 query: &Embedding,
59 config: SearchConfig,
60 ) -> Result<Self, VectorError> {
61 let dim = index.dimension()?;
62
63 if query.dimension() != dim {
65 return Err(VectorError::DimensionMismatch {
66 expected: dim,
67 actual: query.dimension(),
68 });
69 }
70
71 let search_results = index.search(query, config.k, config.ef_search)?;
73
74 let results: Vec<VectorMatch> = search_results
76 .into_iter()
77 .filter(|r| match config.max_distance {
78 Some(max_dist) => r.distance <= max_dist,
79 None => true,
80 })
81 .map(VectorMatch::from)
82 .collect();
83
84 Ok(Self { index, results, position: 0, dim })
85 }
86
87 pub fn k_nearest(
101 index: &'a HnswIndex<E>,
102 query: &Embedding,
103 k: usize,
104 ) -> Result<Self, VectorError> {
105 Self::new(index, query, SearchConfig::k_nearest(k))
106 }
107
108 pub fn within_distance(
125 index: &'a HnswIndex<E>,
126 query: &Embedding,
127 max_distance: f32,
128 max_results: usize,
129 ) -> Result<Self, VectorError> {
130 Self::new(index, query, SearchConfig::within_distance(max_distance).with_k(max_results))
131 }
132
133 #[must_use]
135 pub const fn index(&self) -> &'a HnswIndex<E> {
136 self.index
137 }
138
139 #[must_use]
141 pub fn len(&self) -> usize {
142 self.results.len()
143 }
144
145 #[must_use]
147 pub fn is_empty(&self) -> bool {
148 self.results.is_empty()
149 }
150
151 #[must_use]
153 pub fn peek(&self) -> Option<&VectorMatch> {
154 self.results.get(self.position)
155 }
156
157 pub fn reset(&mut self) {
159 self.position = 0;
160 }
161}
162
163impl<E: StorageEngine> VectorOperator for AnnScan<'_, E> {
164 fn next(&mut self) -> Result<Option<VectorMatch>, VectorError> {
165 if self.position < self.results.len() {
166 let result = self.results[self.position];
167 self.position += 1;
168 Ok(Some(result))
169 } else {
170 Ok(None)
171 }
172 }
173
174 fn dimension(&self) -> usize {
175 self.dim
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distance::DistanceMetric;
    use crate::index::HnswConfig;
    use manifoldb_core::EntityId;
    use manifoldb_storage::backends::RedbEngine;

    /// Builds a `dim`-dimensional embedding with every component set to `value`.
    fn create_test_embedding(dim: usize, value: f32) -> Embedding {
        Embedding::new(vec![value; dim]).unwrap()
    }

    /// Builds an empty, in-memory, 4-dimensional Euclidean HNSW index.
    fn create_test_index() -> HnswIndex<RedbEngine> {
        let engine = RedbEngine::in_memory().unwrap();
        HnswIndex::new(engine, "test", 4, DistanceMetric::Euclidean, HnswConfig::new(4)).unwrap()
    }

    #[test]
    fn test_ann_scan_empty_index() {
        let index = create_test_index();
        let query = create_test_embedding(4, 1.0);

        let mut scan = AnnScan::k_nearest(&index, &query, 5).unwrap();
        assert!(scan.is_empty());
        assert_eq!(scan.len(), 0);
        assert!(scan.peek().is_none());
        assert!(scan.next().unwrap().is_none());
    }

    #[test]
    fn test_ann_scan_single_result() {
        let mut index = create_test_index();
        let embedding = create_test_embedding(4, 1.0);
        index.insert(EntityId::new(1), &embedding).unwrap();

        let query = create_test_embedding(4, 1.0);
        let mut scan = AnnScan::k_nearest(&index, &query, 5).unwrap();

        assert_eq!(scan.len(), 1);
        let result = scan.next().unwrap().unwrap();
        assert_eq!(result.entity_id, EntityId::new(1));
        assert!(result.distance < 1e-6);

        assert!(scan.next().unwrap().is_none());
    }

    #[test]
    fn test_ann_scan_multiple_results() {
        let mut index = create_test_index();
        for i in 1..=10 {
            let embedding = create_test_embedding(4, i as f32);
            index.insert(EntityId::new(i), &embedding).unwrap();
        }

        let query = create_test_embedding(4, 5.0);
        let mut scan = AnnScan::k_nearest(&index, &query, 3).unwrap();

        assert_eq!(scan.len(), 3);

        let results = scan.collect_all().unwrap();
        assert_eq!(results.len(), 3);

        // Results are expected to be ordered by non-decreasing distance.
        assert!(results[0].distance <= results[1].distance);
        assert!(results[1].distance <= results[2].distance);
    }

    #[test]
    fn test_ann_scan_with_max_distance() {
        let mut index = create_test_index();
        for i in 1..=10 {
            let embedding = create_test_embedding(4, i as f32);
            index.insert(EntityId::new(i), &embedding).unwrap();
        }

        let query = create_test_embedding(4, 5.0);
        let mut scan = AnnScan::within_distance(&index, &query, 2.5, 10).unwrap();

        let results = scan.collect_all().unwrap();

        // Without this, the loop below would pass vacuously on an empty result set.
        // Entity 5 is identical to the query (distance 0), so it must survive any
        // non-negative distance limit.
        assert!(!results.is_empty(), "exact match should pass the distance filter");
        assert!(results.iter().any(|r| r.entity_id == EntityId::new(5)));

        for result in &results {
            assert!(result.distance <= 2.5);
        }
    }

    #[test]
    fn test_ann_scan_dimension_mismatch() {
        let index = create_test_index();
        let query = create_test_embedding(8, 1.0);

        let result = AnnScan::k_nearest(&index, &query, 5);
        assert!(matches!(
            result,
            Err(VectorError::DimensionMismatch { expected: 4, actual: 8 })
        ));
    }

    #[test]
    fn test_ann_scan_peek_and_reset() {
        let mut index = create_test_index();
        let embedding = create_test_embedding(4, 1.0);
        index.insert(EntityId::new(1), &embedding).unwrap();

        let query = create_test_embedding(4, 1.0);
        let mut scan = AnnScan::k_nearest(&index, &query, 5).unwrap();

        // Peeking must not consume the match.
        let peeked = scan.peek().unwrap();
        assert_eq!(peeked.entity_id, EntityId::new(1));

        let result = scan.next().unwrap().unwrap();
        assert_eq!(result.entity_id, EntityId::new(1));

        // Exhausted: both peek and next report nothing left.
        assert!(scan.peek().is_none());
        assert!(scan.next().unwrap().is_none());

        // After a reset the same buffered match is available again.
        scan.reset();
        assert_eq!(scan.peek().unwrap().entity_id, EntityId::new(1));
        let result = scan.next().unwrap().unwrap();
        assert_eq!(result.entity_id, EntityId::new(1));
    }

    #[test]
    fn test_ann_scan_with_ef_search() {
        let mut index = create_test_index();
        for i in 1..=20 {
            let embedding = create_test_embedding(4, i as f32);
            index.insert(EntityId::new(i), &embedding).unwrap();
        }

        let query = create_test_embedding(4, 10.0);
        let config = SearchConfig::k_nearest(5).with_ef_search(100);
        let mut scan = AnnScan::new(&index, &query, config).unwrap();

        let results = scan.collect_all().unwrap();
        assert_eq!(results.len(), 5);
    }
}