Skip to main content

engine/
engine.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use common::{
5    DakeraError, DistanceMetric, NamespaceId, PaginationCursor, QueryRequest, QueryResponse,
6    Result, SearchResult,
7};
8use parking_lot::RwLock;
9use storage::VectorStorage;
10
11use crate::filter::evaluate_filter;
12use crate::hnsw::{HnswConfig, HnswIndex};
13use crate::search::brute_force_search;
14
/// Default vector count threshold above which HNSW is used instead of brute force.
///
/// Used when the `DAKERA_ANN_THRESHOLD` environment variable is unset or cannot be
/// parsed. The comparison is strict: a namespace must hold *more* than this many
/// vectors before the ANN path is taken.
const DEFAULT_ANN_THRESHOLD: usize = 1000;
17
/// Over-fetch multiplier applied to `top_k` when filtering post-HNSW.
///
/// The ANN index is asked for `top_k * ANN_FILTER_OVERFETCH_FACTOR` candidates so
/// that metadata filtering has a larger pool to choose from before the result list
/// is truncated back to the caller's `top_k`.
const ANN_FILTER_OVERFETCH_FACTOR: usize = 4;
22
23/// Convert HNSW distance back to similarity score (inverse of hnsw::similarity_to_distance)
24#[inline]
25fn distance_to_similarity(distance: f32, metric: DistanceMetric) -> f32 {
26    match metric {
27        DistanceMetric::Cosine => 1.0 - distance,
28        DistanceMetric::Euclidean => -distance,
29        DistanceMetric::DotProduct => -distance,
30    }
31}
32
33/// Read ANN threshold from environment variable
34fn ann_threshold_from_env() -> usize {
35    std::env::var("DAKERA_ANN_THRESHOLD")
36        .ok()
37        .and_then(|v| v.parse().ok())
38        .unwrap_or(DEFAULT_ANN_THRESHOLD)
39}
40
/// Main search engine that coordinates storage and search operations.
///
/// Queries are answered either by a brute-force scan of the namespace or, for
/// cursor-less queries on namespaces holding more than `ann_threshold` vectors,
/// by a cached HNSW index.
pub struct SearchEngine<S: VectorStorage + ?Sized> {
    /// Backing vector store (held behind an `Arc`)
    storage: Arc<S>,
    /// Cached HNSW indices for ANN acceleration, looked up by a namespace-derived key
    ann_indices: RwLock<HashMap<String, Arc<HnswIndex>>>,
    /// Vector count threshold above which HNSW is used; `0` disables the ANN path
    ann_threshold: usize,
}
49
50impl<S: VectorStorage + ?Sized> SearchEngine<S> {
51    pub fn new(storage: Arc<S>) -> Self {
52        Self {
53            storage,
54            ann_indices: RwLock::new(HashMap::new()),
55            ann_threshold: ann_threshold_from_env(),
56        }
57    }
58
59    /// Perform vector search in a namespace
60    pub async fn search(
61        &self,
62        namespace: &NamespaceId,
63        request: &QueryRequest,
64    ) -> Result<QueryResponse> {
65        // Check if namespace exists
66        if !self.storage.namespace_exists(namespace).await? {
67            return Err(DakeraError::NamespaceNotFound(namespace.clone()));
68        }
69
70        // Validate query dimension against namespace dimension
71        if let Some(expected_dim) = self.storage.dimension(namespace).await? {
72            if request.vector.len() != expected_dim {
73                return Err(DakeraError::DimensionMismatch {
74                    expected: expected_dim,
75                    actual: request.vector.len(),
76                });
77            }
78        }
79
80        // Determine if we can use ANN acceleration:
81        // - No cursor pagination (HNSW returns top-k directly)
82        // - Namespace has enough vectors to justify index overhead
83        // Note: filtered queries are supported via post-filter ANN (over-fetch from HNSW then
84        // apply evaluate_filter on candidates). This restores O(log N) behaviour for the
85        // filtered queries that dominate production traffic (agent_id, tags, min_importance).
86        let use_ann = request.cursor.is_none() && self.ann_threshold > 0;
87
88        if use_ann {
89            let count = self.storage.count(namespace).await?;
90            if count > self.ann_threshold {
91                return self.ann_search(namespace, request, count).await;
92            }
93        }
94
95        // Fall through to brute-force search
96        self.brute_force_path(namespace, request).await
97    }
98
99    /// Brute-force search path (original behavior)
100    async fn brute_force_path(
101        &self,
102        namespace: &NamespaceId,
103        request: &QueryRequest,
104    ) -> Result<QueryResponse> {
105        let vectors = self.storage.get_all(namespace).await?;
106
107        let filtered_vectors: Vec<_> = if let Some(ref filter) = request.filter {
108            vectors
109                .into_iter()
110                .filter(|v| evaluate_filter(filter, v.metadata.as_ref()))
111                .collect()
112        } else {
113            vectors
114        };
115
116        let cursor = request
117            .cursor
118            .as_ref()
119            .and_then(|c| PaginationCursor::decode(c));
120
121        tracing::debug!(
122            namespace = %namespace,
123            vector_count = filtered_vectors.len(),
124            top_k = request.top_k,
125            metric = ?request.distance_metric,
126            has_filter = request.filter.is_some(),
127            has_cursor = cursor.is_some(),
128            "Performing brute-force search"
129        );
130
131        let response = brute_force_search(
132            &request.vector,
133            &filtered_vectors,
134            request.top_k,
135            request.distance_metric,
136            request.include_metadata,
137            request.include_vectors,
138            cursor.as_ref(),
139        );
140
141        Ok(response)
142    }
143
144    /// ANN search path using cached HNSW index
145    async fn ann_search(
146        &self,
147        namespace: &NamespaceId,
148        request: &QueryRequest,
149        vector_count: usize,
150    ) -> Result<QueryResponse> {
151        // Get or build the HNSW index for this namespace
152        let index = self
153            .get_or_build_index(namespace, request.distance_metric)
154            .await?;
155
156        let has_filter = request.filter.is_some();
157
158        // Over-fetch when a filter is present: HNSW may return candidates that fail the
159        // filter, so we ask for more than top_k and truncate after filtering.
160        let hnsw_top_k = if has_filter {
161            request.top_k.saturating_mul(ANN_FILTER_OVERFETCH_FACTOR)
162        } else {
163            request.top_k
164        };
165
166        tracing::debug!(
167            namespace = %namespace,
168            vector_count = vector_count,
169            top_k = request.top_k,
170            hnsw_top_k,
171            has_filter,
172            metric = ?request.distance_metric,
173            "Performing ANN search (HNSW)"
174        );
175
176        // Search the HNSW index — returns (VectorId, distance) sorted by score
177        let hnsw_results = index.search(&request.vector, hnsw_top_k);
178
179        // Fetch metadata when caller wants it or when we need to evaluate the filter
180        let need_fetch = request.include_metadata || request.include_vectors || has_filter;
181        let fetched = if need_fetch && !hnsw_results.is_empty() {
182            let ids: Vec<String> = hnsw_results.iter().map(|(id, _)| id.clone()).collect();
183            let vectors = self.storage.get(namespace, &ids).await?;
184            let map: HashMap<String, _> = vectors.into_iter().map(|v| (v.id.clone(), v)).collect();
185            Some(map)
186        } else {
187            None
188        };
189
190        // Convert HNSW results to SearchResults, applying post-filter when present
191        let mut results: Vec<SearchResult> = hnsw_results
192            .into_iter()
193            .filter_map(|(id, distance)| {
194                let score = distance_to_similarity(distance, request.distance_metric);
195                let entry = fetched.as_ref().and_then(|map| map.get(&id));
196
197                // Post-filter: drop candidates that do not match the filter expression
198                if let Some(ref filter) = request.filter {
199                    let metadata = entry.and_then(|v| v.metadata.as_ref());
200                    if !evaluate_filter(filter, metadata) {
201                        return None;
202                    }
203                }
204
205                let (metadata, vector) = if let Some(v) = entry {
206                    (
207                        if request.include_metadata {
208                            v.metadata.clone()
209                        } else {
210                            None
211                        },
212                        if request.include_vectors {
213                            Some(v.values.clone())
214                        } else {
215                            None
216                        },
217                    )
218                } else {
219                    (None, None)
220                };
221                Some(SearchResult {
222                    id,
223                    score,
224                    metadata,
225                    vector,
226                })
227            })
228            .collect();
229
230        // Truncate to the caller's requested top_k (over-fetch may have returned more)
231        results.truncate(request.top_k);
232
233        Ok(QueryResponse {
234            results,
235            next_cursor: None,
236            has_more: Some(false),
237            search_time_ms: 0, // caller typically overwrites this
238        })
239    }
240
241    /// Get cached HNSW index or build one from storage
242    async fn get_or_build_index(
243        &self,
244        namespace: &NamespaceId,
245        metric: DistanceMetric,
246    ) -> Result<Arc<HnswIndex>> {
247        // Fast path: check read lock
248        {
249            let indices = self.ann_indices.read();
250            if let Some(index) = indices.get(namespace.as_str()) {
251                return Ok(Arc::clone(index));
252            }
253        }
254
255        // Slow path: build index from storage
256        tracing::info!(namespace = %namespace, "Building HNSW index for ANN acceleration");
257        let vectors = self.storage.get_all(namespace).await?;
258
259        let config = HnswConfig::default().with_distance_metric(metric);
260        let index = HnswIndex::with_config(config);
261
262        for v in &vectors {
263            index.insert(v.id.clone(), v.values.clone());
264        }
265
266        let index = Arc::new(index);
267
268        // Cache it
269        {
270            let mut indices = self.ann_indices.write();
271            indices.insert(namespace.clone(), Arc::clone(&index));
272        }
273
274        tracing::info!(
275            namespace = %namespace,
276            vectors = vectors.len(),
277            "HNSW index built and cached"
278        );
279
280        Ok(index)
281    }
282
283    /// Invalidate the cached HNSW index for a namespace (call after upsert/delete)
284    pub fn invalidate_ann_index(&self, namespace: &NamespaceId) {
285        let mut indices = self.ann_indices.write();
286        if indices.remove(namespace.as_str()).is_some() {
287            tracing::debug!(namespace = %namespace, "HNSW index invalidated");
288        }
289    }
290
291    /// Get reference to storage
292    pub fn storage(&self) -> &Arc<S> {
293        &self.storage
294    }
295
296    /// Create engine with an explicit ANN threshold (for tests only)
297    #[cfg(test)]
298    pub fn new_with_threshold(storage: Arc<S>, ann_threshold: usize) -> Self {
299        Self {
300            storage,
301            ann_indices: RwLock::new(HashMap::new()),
302            ann_threshold,
303        }
304    }
305}
306
#[cfg(test)]
mod tests {
    use super::*;
    use common::{DistanceMetric, FilterCondition, FilterExpression, FilterValue, Vector};
    use std::collections::HashMap;
    use storage::InMemoryStorage;

    /// Build a `Vector` without TTL/expiry from an id, its components and optional metadata.
    macro_rules! vector {
        ($id:expr, [$($component:expr),* $(,)?], $metadata:expr) => {
            Vector {
                id: $id.to_string(),
                values: vec![$($component),*],
                metadata: $metadata,
                ttl_seconds: None,
                expires_at: None,
            }
        };
    }

    /// Build a cursor-less cosine `QueryRequest` that asks for metadata but not vectors.
    macro_rules! cosine_query {
        ($vector:expr, $top_k:expr, $filter:expr) => {
            QueryRequest {
                vector: $vector,
                top_k: $top_k,
                distance_metric: DistanceMetric::Cosine,
                include_metadata: true,
                include_vectors: false,
                filter: $filter,
                cursor: None,
                consistency: Default::default(),
                staleness_config: None,
            }
        };
    }

    /// Filter expression constraining a single metadata field.
    fn single_field_filter(name: &str, condition: FilterCondition) -> FilterExpression {
        let field = HashMap::from([(name.to_string(), condition)]);
        FilterExpression::Field { field }
    }

    /// Create `namespace` in `storage` and load `vectors` into it.
    async fn seed(storage: &Arc<InMemoryStorage>, namespace: &NamespaceId, vectors: Vec<Vector>) {
        storage.ensure_namespace(namespace).await.unwrap();
        storage.upsert(namespace, vectors).await.unwrap();
    }

    /// Assert that every hit in `response` carries one of the `allowed` ids.
    #[track_caller]
    fn assert_ids_within(response: &QueryResponse, allowed: &[&str]) {
        assert!(response
            .results
            .iter()
            .all(|hit| allowed.contains(&hit.id.as_str())));
    }

    /// Engine over a three-vector, metadata-free namespace named `test`.
    async fn setup_engine() -> (SearchEngine<InMemoryStorage>, String) {
        let storage = Arc::new(InMemoryStorage::new());
        let engine = SearchEngine::new(storage.clone());
        let namespace = "test".to_string();

        seed(
            &storage,
            &namespace,
            vec![
                vector!("v1", [1.0, 0.0, 0.0], None),
                vector!("v2", [0.0, 1.0, 0.0], None),
                vector!("v3", [0.707, 0.707, 0.0], None),
            ],
        )
        .await;

        (engine, namespace)
    }

    #[tokio::test]
    async fn test_search_basic() {
        let (engine, namespace) = setup_engine().await;
        let request = cosine_query!(vec![1.0, 0.0, 0.0], 2, None);

        let response = engine.search(&namespace, &request).await.unwrap();

        assert_eq!(response.results.len(), 2);
        assert_eq!(response.results[0].id, "v1"); // Perfect match
    }

    #[tokio::test]
    async fn test_search_namespace_not_found() {
        let storage = Arc::new(InMemoryStorage::new());
        let engine = SearchEngine::new(storage);
        let request = cosine_query!(vec![1.0, 0.0, 0.0], 5, None);

        let result = engine.search(&"nonexistent".to_string(), &request).await;

        assert!(matches!(result, Err(DakeraError::NamespaceNotFound(_))));
    }

    #[tokio::test]
    async fn test_search_dimension_mismatch() {
        let (engine, namespace) = setup_engine().await;
        // Wrong dimension (2 instead of 3)
        let request = cosine_query!(vec![1.0, 0.0], 5, None);

        let result = engine.search(&namespace, &request).await;

        assert!(matches!(
            result,
            Err(DakeraError::DimensionMismatch {
                expected: 3,
                actual: 2
            })
        ));
    }

    #[tokio::test]
    async fn test_search_empty_namespace() {
        let storage = Arc::new(InMemoryStorage::new());
        let engine = SearchEngine::new(storage.clone());
        let namespace = "empty".to_string();
        storage.ensure_namespace(&namespace).await.unwrap();

        let request = cosine_query!(vec![1.0, 0.0, 0.0], 5, None);
        let response = engine.search(&namespace, &request).await.unwrap();

        assert!(response.results.is_empty());
    }

    #[tokio::test]
    async fn test_search_with_filter() {
        let storage = Arc::new(InMemoryStorage::new());
        let engine = SearchEngine::new(storage.clone());
        let namespace = "test".to_string();

        seed(
            &storage,
            &namespace,
            vec![
                vector!(
                    "v1",
                    [1.0, 0.0, 0.0],
                    Some(serde_json::json!({"category": "electronics", "price": 100}))
                ),
                vector!(
                    "v2",
                    [0.9, 0.1, 0.0],
                    Some(serde_json::json!({"category": "books", "price": 20}))
                ),
                vector!(
                    "v3",
                    [0.8, 0.2, 0.0],
                    Some(serde_json::json!({"category": "electronics", "price": 50}))
                ),
            ],
        )
        .await;

        // Filter for electronics only
        let electronics_only = single_field_filter(
            "category",
            FilterCondition::Eq(FilterValue::String("electronics".to_string())),
        );
        let request = cosine_query!(vec![1.0, 0.0, 0.0], 10, Some(electronics_only));

        let response = engine.search(&namespace, &request).await.unwrap();

        // Should only return v1 and v3 (electronics)
        assert_eq!(response.results.len(), 2);
        assert_ids_within(&response, &["v1", "v3"]);
    }

    #[tokio::test]
    async fn test_search_with_numeric_filter() {
        let storage = Arc::new(InMemoryStorage::new());
        let engine = SearchEngine::new(storage.clone());
        let namespace = "test".to_string();

        seed(
            &storage,
            &namespace,
            vec![
                vector!("v1", [1.0, 0.0, 0.0], Some(serde_json::json!({"price": 100}))),
                vector!("v2", [0.9, 0.1, 0.0], Some(serde_json::json!({"price": 20}))),
                vector!("v3", [0.8, 0.2, 0.0], Some(serde_json::json!({"price": 50}))),
            ],
        )
        .await;

        // Filter for price < 60
        let under_sixty =
            single_field_filter("price", FilterCondition::Lt(FilterValue::Number(60.0)));
        let request = cosine_query!(vec![1.0, 0.0, 0.0], 10, Some(under_sixty));

        let response = engine.search(&namespace, &request).await.unwrap();

        // Should only return v2 (20) and v3 (50)
        assert_eq!(response.results.len(), 2);
        assert_ids_within(&response, &["v2", "v3"]);
    }

    /// Verify that HNSW (ANN) path correctly applies post-filter when the namespace
    /// exceeds the threshold.  Uses a low threshold (2) so a 3-vector namespace
    /// exercises the ANN code path.
    #[tokio::test]
    async fn test_ann_search_with_filter() {
        let storage = Arc::new(InMemoryStorage::new());
        // Threshold of 2 forces HNSW for a 3-vector namespace
        let engine = SearchEngine::new_with_threshold(storage.clone(), 2);
        let namespace = "test_ann_filter".to_string();

        seed(
            &storage,
            &namespace,
            vec![
                vector!(
                    "v1",
                    [1.0, 0.0, 0.0],
                    Some(serde_json::json!({"category": "electronics"}))
                ),
                vector!(
                    "v2",
                    [0.9, 0.1, 0.0],
                    Some(serde_json::json!({"category": "books"}))
                ),
                vector!(
                    "v3",
                    [0.8, 0.2, 0.0],
                    Some(serde_json::json!({"category": "electronics"}))
                ),
            ],
        )
        .await;

        let electronics_only = single_field_filter(
            "category",
            FilterCondition::Eq(FilterValue::String("electronics".to_string())),
        );
        let request = cosine_query!(vec![1.0, 0.0, 0.0], 10, Some(electronics_only));

        let response = engine.search(&namespace, &request).await.unwrap();

        // ANN + post-filter: only electronics results (v1, v3)
        assert_eq!(response.results.len(), 2);
        assert_ids_within(&response, &["v1", "v3"]);
        // Best match (v1 = [1,0,0]) should come first
        assert_eq!(response.results[0].id, "v1");
    }
}