// engine/engine.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use common::{
5    DakeraError, DistanceMetric, NamespaceId, PaginationCursor, QueryRequest, QueryResponse,
6    Result, SearchResult,
7};
8use parking_lot::RwLock;
9use storage::VectorStorage;
10
11use crate::filter::evaluate_filter;
12use crate::hnsw::{HnswConfig, HnswIndex};
13use crate::search::brute_force_search;
14
/// Fallback ANN threshold used when `DAKERA_ANN_THRESHOLD` is unset or unparsable.
///
/// A namespace must hold strictly more vectors than this before the HNSW path
/// is chosen over brute force.
const DEFAULT_ANN_THRESHOLD: usize = 1_000;
17
18/// Convert HNSW distance back to similarity score (inverse of hnsw::similarity_to_distance)
19#[inline]
20fn distance_to_similarity(distance: f32, metric: DistanceMetric) -> f32 {
21    match metric {
22        DistanceMetric::Cosine => 1.0 - distance,
23        DistanceMetric::Euclidean => -distance,
24        DistanceMetric::DotProduct => -distance,
25    }
26}
27
28/// Read ANN threshold from environment variable
29fn ann_threshold_from_env() -> usize {
30    std::env::var("DAKERA_ANN_THRESHOLD")
31        .ok()
32        .and_then(|v| v.parse().ok())
33        .unwrap_or(DEFAULT_ANN_THRESHOLD)
34}
35
36/// Main search engine that coordinates storage and search operations
37pub struct SearchEngine<S: VectorStorage + ?Sized> {
38    storage: Arc<S>,
39    /// Cached HNSW indices per namespace for ANN acceleration
40    ann_indices: RwLock<HashMap<String, Arc<HnswIndex>>>,
41    /// Vector count threshold above which HNSW is used
42    ann_threshold: usize,
43}
44
45impl<S: VectorStorage + ?Sized> SearchEngine<S> {
46    pub fn new(storage: Arc<S>) -> Self {
47        Self {
48            storage,
49            ann_indices: RwLock::new(HashMap::new()),
50            ann_threshold: ann_threshold_from_env(),
51        }
52    }
53
54    /// Perform vector search in a namespace
55    pub async fn search(
56        &self,
57        namespace: &NamespaceId,
58        request: &QueryRequest,
59    ) -> Result<QueryResponse> {
60        // Check if namespace exists
61        if !self.storage.namespace_exists(namespace).await? {
62            return Err(DakeraError::NamespaceNotFound(namespace.clone()));
63        }
64
65        // Validate query dimension against namespace dimension
66        if let Some(expected_dim) = self.storage.dimension(namespace).await? {
67            if request.vector.len() != expected_dim {
68                return Err(DakeraError::DimensionMismatch {
69                    expected: expected_dim,
70                    actual: request.vector.len(),
71                });
72            }
73        }
74
75        // Determine if we can use ANN acceleration:
76        // - No filter (HNSW doesn't support metadata filtering)
77        // - No cursor pagination (HNSW returns top-k directly)
78        // - Namespace has enough vectors to justify index overhead
79        let use_ann =
80            request.filter.is_none() && request.cursor.is_none() && self.ann_threshold > 0;
81
82        if use_ann {
83            let count = self.storage.count(namespace).await?;
84            if count > self.ann_threshold {
85                return self.ann_search(namespace, request, count).await;
86            }
87        }
88
89        // Fall through to brute-force search
90        self.brute_force_path(namespace, request).await
91    }
92
93    /// Brute-force search path (original behavior)
94    async fn brute_force_path(
95        &self,
96        namespace: &NamespaceId,
97        request: &QueryRequest,
98    ) -> Result<QueryResponse> {
99        let vectors = self.storage.get_all(namespace).await?;
100
101        let filtered_vectors: Vec<_> = if let Some(ref filter) = request.filter {
102            vectors
103                .into_iter()
104                .filter(|v| evaluate_filter(filter, v.metadata.as_ref()))
105                .collect()
106        } else {
107            vectors
108        };
109
110        let cursor = request
111            .cursor
112            .as_ref()
113            .and_then(|c| PaginationCursor::decode(c));
114
115        tracing::debug!(
116            namespace = %namespace,
117            vector_count = filtered_vectors.len(),
118            top_k = request.top_k,
119            metric = ?request.distance_metric,
120            has_filter = request.filter.is_some(),
121            has_cursor = cursor.is_some(),
122            "Performing brute-force search"
123        );
124
125        let response = brute_force_search(
126            &request.vector,
127            &filtered_vectors,
128            request.top_k,
129            request.distance_metric,
130            request.include_metadata,
131            request.include_vectors,
132            cursor.as_ref(),
133        );
134
135        Ok(response)
136    }
137
138    /// ANN search path using cached HNSW index
139    async fn ann_search(
140        &self,
141        namespace: &NamespaceId,
142        request: &QueryRequest,
143        vector_count: usize,
144    ) -> Result<QueryResponse> {
145        // Get or build the HNSW index for this namespace
146        let index = self
147            .get_or_build_index(namespace, request.distance_metric)
148            .await?;
149
150        tracing::debug!(
151            namespace = %namespace,
152            vector_count = vector_count,
153            top_k = request.top_k,
154            metric = ?request.distance_metric,
155            "Performing ANN search (HNSW)"
156        );
157
158        // Search the HNSW index — returns (VectorId, distance)
159        let hnsw_results = index.search(&request.vector, request.top_k);
160
161        // If we need metadata or vectors, fetch them from storage
162        let need_fetch = request.include_metadata || request.include_vectors;
163        let fetched = if need_fetch && !hnsw_results.is_empty() {
164            let ids: Vec<String> = hnsw_results.iter().map(|(id, _)| id.clone()).collect();
165            let vectors = self.storage.get(namespace, &ids).await?;
166            let map: HashMap<String, _> = vectors.into_iter().map(|v| (v.id.clone(), v)).collect();
167            Some(map)
168        } else {
169            None
170        };
171
172        // Convert HNSW results to SearchResults
173        let results: Vec<SearchResult> = hnsw_results
174            .into_iter()
175            .map(|(id, distance)| {
176                let score = distance_to_similarity(distance, request.distance_metric);
177                let (metadata, vector) = if let Some(ref map) = fetched {
178                    if let Some(v) = map.get(&id) {
179                        (
180                            if request.include_metadata {
181                                v.metadata.clone()
182                            } else {
183                                None
184                            },
185                            if request.include_vectors {
186                                Some(v.values.clone())
187                            } else {
188                                None
189                            },
190                        )
191                    } else {
192                        (None, None)
193                    }
194                } else {
195                    (None, None)
196                };
197                SearchResult {
198                    id,
199                    score,
200                    metadata,
201                    vector,
202                }
203            })
204            .collect();
205
206        Ok(QueryResponse {
207            results,
208            next_cursor: None,
209            has_more: Some(false),
210            search_time_ms: 0, // caller typically overwrites this
211        })
212    }
213
214    /// Get cached HNSW index or build one from storage
215    async fn get_or_build_index(
216        &self,
217        namespace: &NamespaceId,
218        metric: DistanceMetric,
219    ) -> Result<Arc<HnswIndex>> {
220        // Fast path: check read lock
221        {
222            let indices = self.ann_indices.read();
223            if let Some(index) = indices.get(namespace.as_str()) {
224                return Ok(Arc::clone(index));
225            }
226        }
227
228        // Slow path: build index from storage
229        tracing::info!(namespace = %namespace, "Building HNSW index for ANN acceleration");
230        let vectors = self.storage.get_all(namespace).await?;
231
232        let config = HnswConfig::default().with_distance_metric(metric);
233        let index = HnswIndex::with_config(config);
234
235        for v in &vectors {
236            index.insert(v.id.clone(), v.values.clone());
237        }
238
239        let index = Arc::new(index);
240
241        // Cache it
242        {
243            let mut indices = self.ann_indices.write();
244            indices.insert(namespace.clone(), Arc::clone(&index));
245        }
246
247        tracing::info!(
248            namespace = %namespace,
249            vectors = vectors.len(),
250            "HNSW index built and cached"
251        );
252
253        Ok(index)
254    }
255
256    /// Invalidate the cached HNSW index for a namespace (call after upsert/delete)
257    pub fn invalidate_ann_index(&self, namespace: &NamespaceId) {
258        let mut indices = self.ann_indices.write();
259        if indices.remove(namespace.as_str()).is_some() {
260            tracing::debug!(namespace = %namespace, "HNSW index invalidated");
261        }
262    }
263
264    /// Get reference to storage
265    pub fn storage(&self) -> &Arc<S> {
266        &self.storage
267    }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use common::{DistanceMetric, FilterCondition, FilterExpression, FilterValue, Vector};
    use std::collections::HashMap;
    use storage::InMemoryStorage;

    /// Template vector with nothing set; tests override `id`, `values` and
    /// (optionally) `metadata` via struct-update syntax.
    fn blank_vector() -> Vector {
        Vector {
            id: String::new(),
            values: Vec::new(),
            metadata: None,
            ttl_seconds: None,
            expires_at: None,
        }
    }

    /// Template request shared by all tests: cosine query along the x axis,
    /// `top_k` of 5, metadata included, no filter and no cursor.
    fn base_request() -> QueryRequest {
        QueryRequest {
            vector: vec![1.0, 0.0, 0.0],
            top_k: 5,
            distance_metric: DistanceMetric::Cosine,
            include_metadata: true,
            include_vectors: false,
            filter: None,
            cursor: None,
            consistency: Default::default(),
            staleness_config: None,
        }
    }

    /// Build a single-field filter expression: `name` must satisfy `condition`.
    fn field_filter(name: &str, condition: FilterCondition) -> FilterExpression {
        let mut field = HashMap::new();
        field.insert(name.to_string(), condition);
        FilterExpression::Field { field }
    }

    /// Create an engine over fresh in-memory storage, ensure `namespace`
    /// exists and upsert `vectors` into it (no upsert when `vectors` is empty).
    async fn engine_with(
        namespace: &str,
        vectors: Vec<Vector>,
    ) -> (SearchEngine<InMemoryStorage>, String) {
        let storage = Arc::new(InMemoryStorage::new());
        let engine = SearchEngine::new(storage.clone());
        let namespace = namespace.to_string();

        storage.ensure_namespace(&namespace).await.unwrap();
        if !vectors.is_empty() {
            storage.upsert(&namespace, vectors).await.unwrap();
        }

        (engine, namespace)
    }

    /// Engine with three metadata-free 3-d vectors in namespace `test`.
    async fn setup_engine() -> (SearchEngine<InMemoryStorage>, String) {
        let vectors = vec![
            Vector {
                id: "v1".to_string(),
                values: vec![1.0, 0.0, 0.0],
                ..blank_vector()
            },
            Vector {
                id: "v2".to_string(),
                values: vec![0.0, 1.0, 0.0],
                ..blank_vector()
            },
            Vector {
                id: "v3".to_string(),
                values: vec![0.707, 0.707, 0.0],
                ..blank_vector()
            },
        ];

        engine_with("test", vectors).await
    }

    #[tokio::test]
    async fn test_search_basic() {
        let (engine, namespace) = setup_engine().await;

        let request = QueryRequest {
            top_k: 2,
            ..base_request()
        };

        let response = engine.search(&namespace, &request).await.unwrap();

        assert_eq!(response.results.len(), 2);
        assert_eq!(response.results[0].id, "v1"); // Perfect match
    }

    #[tokio::test]
    async fn test_search_namespace_not_found() {
        let engine = SearchEngine::new(Arc::new(InMemoryStorage::new()));
        let request = base_request();

        let result = engine.search(&"nonexistent".to_string(), &request).await;

        assert!(matches!(result, Err(DakeraError::NamespaceNotFound(_))));
    }

    #[tokio::test]
    async fn test_search_dimension_mismatch() {
        let (engine, namespace) = setup_engine().await;

        let request = QueryRequest {
            vector: vec![1.0, 0.0], // Wrong dimension (2 instead of 3)
            ..base_request()
        };

        let result = engine.search(&namespace, &request).await;

        assert!(matches!(
            result,
            Err(DakeraError::DimensionMismatch {
                expected: 3,
                actual: 2
            })
        ));
    }

    #[tokio::test]
    async fn test_search_empty_namespace() {
        let (engine, namespace) = engine_with("empty", Vec::new()).await;
        let request = base_request();

        let response = engine.search(&namespace, &request).await.unwrap();

        assert!(response.results.is_empty());
    }

    #[tokio::test]
    async fn test_search_with_filter() {
        let vectors = vec![
            Vector {
                id: "v1".to_string(),
                values: vec![1.0, 0.0, 0.0],
                metadata: Some(serde_json::json!({"category": "electronics", "price": 100})),
                ..blank_vector()
            },
            Vector {
                id: "v2".to_string(),
                values: vec![0.9, 0.1, 0.0],
                metadata: Some(serde_json::json!({"category": "books", "price": 20})),
                ..blank_vector()
            },
            Vector {
                id: "v3".to_string(),
                values: vec![0.8, 0.2, 0.0],
                metadata: Some(serde_json::json!({"category": "electronics", "price": 50})),
                ..blank_vector()
            },
        ];
        let (engine, namespace) = engine_with("test", vectors).await;

        // Filter for electronics only
        let electronics_only = field_filter(
            "category",
            FilterCondition::Eq(FilterValue::String("electronics".to_string())),
        );

        let request = QueryRequest {
            top_k: 10,
            filter: Some(electronics_only),
            ..base_request()
        };

        let response = engine.search(&namespace, &request).await.unwrap();

        // Should only return v1 and v3 (electronics)
        assert_eq!(response.results.len(), 2);
        assert!(response
            .results
            .iter()
            .all(|r| r.id == "v1" || r.id == "v3"));
    }

    #[tokio::test]
    async fn test_search_with_numeric_filter() {
        let vectors = vec![
            Vector {
                id: "v1".to_string(),
                values: vec![1.0, 0.0, 0.0],
                metadata: Some(serde_json::json!({"price": 100})),
                ..blank_vector()
            },
            Vector {
                id: "v2".to_string(),
                values: vec![0.9, 0.1, 0.0],
                metadata: Some(serde_json::json!({"price": 20})),
                ..blank_vector()
            },
            Vector {
                id: "v3".to_string(),
                values: vec![0.8, 0.2, 0.0],
                metadata: Some(serde_json::json!({"price": 50})),
                ..blank_vector()
            },
        ];
        let (engine, namespace) = engine_with("test", vectors).await;

        // Filter for price < 60
        let cheaper_than_60 = field_filter("price", FilterCondition::Lt(FilterValue::Number(60.0)));

        let request = QueryRequest {
            top_k: 10,
            filter: Some(cheaper_than_60),
            ..base_request()
        };

        let response = engine.search(&namespace, &request).await.unwrap();

        // Should only return v2 (20) and v3 (50)
        assert_eq!(response.results.len(), 2);
        assert!(response
            .results
            .iter()
            .all(|r| r.id == "v2" || r.id == "v3"));
    }
}
545}