1use std::collections::HashMap;
2use std::sync::Arc;
3
4use common::{
5 DakeraError, DistanceMetric, NamespaceId, PaginationCursor, QueryRequest, QueryResponse,
6 Result, SearchResult,
7};
8use parking_lot::RwLock;
9use storage::VectorStorage;
10
11use crate::filter::evaluate_filter;
12use crate::hnsw::{HnswConfig, HnswIndex};
13use crate::search::brute_force_search;
14
/// Fallback ANN threshold used by `ann_threshold_from_env` when the
/// `DAKERA_ANN_THRESHOLD` environment variable is missing or unparsable.
const DEFAULT_ANN_THRESHOLD: usize = 1_000;

/// Multiplier applied to `top_k` when an ANN query carries a metadata filter,
/// so that enough candidates remain after the filter has been applied.
const ANN_FILTER_OVERFETCH_FACTOR: usize = 4;
22
23#[inline]
25fn distance_to_similarity(distance: f32, metric: DistanceMetric) -> f32 {
26 match metric {
27 DistanceMetric::Cosine => 1.0 - distance,
28 DistanceMetric::Euclidean => -distance,
29 DistanceMetric::DotProduct => -distance,
30 }
31}
32
33fn ann_threshold_from_env() -> usize {
35 std::env::var("DAKERA_ANN_THRESHOLD")
36 .ok()
37 .and_then(|v| v.parse().ok())
38 .unwrap_or(DEFAULT_ANN_THRESHOLD)
39}
40
41pub struct SearchEngine<S: VectorStorage + ?Sized> {
43 storage: Arc<S>,
44 ann_indices: RwLock<HashMap<String, Arc<HnswIndex>>>,
46 ann_threshold: usize,
48}
49
50impl<S: VectorStorage + ?Sized> SearchEngine<S> {
51 pub fn new(storage: Arc<S>) -> Self {
52 Self {
53 storage,
54 ann_indices: RwLock::new(HashMap::new()),
55 ann_threshold: ann_threshold_from_env(),
56 }
57 }
58
59 pub async fn search(
61 &self,
62 namespace: &NamespaceId,
63 request: &QueryRequest,
64 ) -> Result<QueryResponse> {
65 if !self.storage.namespace_exists(namespace).await? {
67 return Err(DakeraError::NamespaceNotFound(namespace.clone()));
68 }
69
70 if let Some(expected_dim) = self.storage.dimension(namespace).await? {
72 if request.vector.len() != expected_dim {
73 return Err(DakeraError::DimensionMismatch {
74 expected: expected_dim,
75 actual: request.vector.len(),
76 });
77 }
78 }
79
80 let use_ann = request.cursor.is_none() && self.ann_threshold > 0;
87
88 if use_ann {
89 let count = self.storage.count(namespace).await?;
90 if count > self.ann_threshold {
91 return self.ann_search(namespace, request, count).await;
92 }
93 }
94
95 self.brute_force_path(namespace, request).await
97 }
98
99 async fn brute_force_path(
101 &self,
102 namespace: &NamespaceId,
103 request: &QueryRequest,
104 ) -> Result<QueryResponse> {
105 let vectors = self.storage.get_all(namespace).await?;
106
107 let filtered_vectors: Vec<_> = if let Some(ref filter) = request.filter {
108 vectors
109 .into_iter()
110 .filter(|v| evaluate_filter(filter, v.metadata.as_ref()))
111 .collect()
112 } else {
113 vectors
114 };
115
116 let cursor = request
117 .cursor
118 .as_ref()
119 .and_then(|c| PaginationCursor::decode(c));
120
121 tracing::debug!(
122 namespace = %namespace,
123 vector_count = filtered_vectors.len(),
124 top_k = request.top_k,
125 metric = ?request.distance_metric,
126 has_filter = request.filter.is_some(),
127 has_cursor = cursor.is_some(),
128 "Performing brute-force search"
129 );
130
131 let response = brute_force_search(
132 &request.vector,
133 &filtered_vectors,
134 request.top_k,
135 request.distance_metric,
136 request.include_metadata,
137 request.include_vectors,
138 cursor.as_ref(),
139 );
140
141 Ok(response)
142 }
143
144 async fn ann_search(
146 &self,
147 namespace: &NamespaceId,
148 request: &QueryRequest,
149 vector_count: usize,
150 ) -> Result<QueryResponse> {
151 let index = self
153 .get_or_build_index(namespace, request.distance_metric)
154 .await?;
155
156 let has_filter = request.filter.is_some();
157
158 let hnsw_top_k = if has_filter {
161 request.top_k.saturating_mul(ANN_FILTER_OVERFETCH_FACTOR)
162 } else {
163 request.top_k
164 };
165
166 tracing::debug!(
167 namespace = %namespace,
168 vector_count = vector_count,
169 top_k = request.top_k,
170 hnsw_top_k,
171 has_filter,
172 metric = ?request.distance_metric,
173 "Performing ANN search (HNSW)"
174 );
175
176 let hnsw_results = index.search(&request.vector, hnsw_top_k);
178
179 let need_fetch = request.include_metadata || request.include_vectors || has_filter;
181 let fetched = if need_fetch && !hnsw_results.is_empty() {
182 let ids: Vec<String> = hnsw_results.iter().map(|(id, _)| id.clone()).collect();
183 let vectors = self.storage.get(namespace, &ids).await?;
184 let map: HashMap<String, _> = vectors.into_iter().map(|v| (v.id.clone(), v)).collect();
185 Some(map)
186 } else {
187 None
188 };
189
190 let mut results: Vec<SearchResult> = hnsw_results
192 .into_iter()
193 .filter_map(|(id, distance)| {
194 let score = distance_to_similarity(distance, request.distance_metric);
195 let entry = fetched.as_ref().and_then(|map| map.get(&id));
196
197 if let Some(ref filter) = request.filter {
199 let metadata = entry.and_then(|v| v.metadata.as_ref());
200 if !evaluate_filter(filter, metadata) {
201 return None;
202 }
203 }
204
205 let (metadata, vector) = if let Some(v) = entry {
206 (
207 if request.include_metadata {
208 v.metadata.clone()
209 } else {
210 None
211 },
212 if request.include_vectors {
213 Some(v.values.clone())
214 } else {
215 None
216 },
217 )
218 } else {
219 (None, None)
220 };
221 Some(SearchResult {
222 id,
223 score,
224 metadata,
225 vector,
226 })
227 })
228 .collect();
229
230 results.truncate(request.top_k);
232
233 Ok(QueryResponse {
234 results,
235 next_cursor: None,
236 has_more: Some(false),
237 search_time_ms: 0, })
239 }
240
241 async fn get_or_build_index(
243 &self,
244 namespace: &NamespaceId,
245 metric: DistanceMetric,
246 ) -> Result<Arc<HnswIndex>> {
247 {
249 let indices = self.ann_indices.read();
250 if let Some(index) = indices.get(namespace.as_str()) {
251 return Ok(Arc::clone(index));
252 }
253 }
254
255 tracing::info!(namespace = %namespace, "Building HNSW index for ANN acceleration");
257 let vectors = self.storage.get_all(namespace).await?;
258
259 let config = HnswConfig::default().with_distance_metric(metric);
260 let index = HnswIndex::with_config(config);
261
262 for v in &vectors {
263 index.insert(v.id.clone(), v.values.clone());
264 }
265
266 let index = Arc::new(index);
267
268 {
270 let mut indices = self.ann_indices.write();
271 indices.insert(namespace.clone(), Arc::clone(&index));
272 }
273
274 tracing::info!(
275 namespace = %namespace,
276 vectors = vectors.len(),
277 "HNSW index built and cached"
278 );
279
280 Ok(index)
281 }
282
283 pub fn invalidate_ann_index(&self, namespace: &NamespaceId) {
285 let mut indices = self.ann_indices.write();
286 if indices.remove(namespace.as_str()).is_some() {
287 tracing::debug!(namespace = %namespace, "HNSW index invalidated");
288 }
289 }
290
291 pub fn storage(&self) -> &Arc<S> {
293 &self.storage
294 }
295
296 #[cfg(test)]
298 pub fn new_with_threshold(storage: Arc<S>, ann_threshold: usize) -> Self {
299 Self {
300 storage,
301 ann_indices: RwLock::new(HashMap::new()),
302 ann_threshold,
303 }
304 }
305}
306
#[cfg(test)]
mod tests {
    use super::*;
    use common::{DistanceMetric, FilterCondition, FilterExpression, FilterValue, Vector};
    use std::collections::HashMap;
    use storage::InMemoryStorage;

    /// Builds a vector without TTL or expiry.
    fn vector(id: &str, values: Vec<f32>, metadata: Option<serde_json::Value>) -> Vector {
        Vector {
            id: id.to_string(),
            values,
            metadata,
            ttl_seconds: None,
            expires_at: None,
        }
    }

    /// Cosine query for `[1.0, 0.0, 0.0]` that asks for metadata but not for
    /// the raw vectors, with no cursor.
    fn cosine_request(top_k: usize, filter: Option<FilterExpression>) -> QueryRequest {
        QueryRequest {
            vector: vec![1.0, 0.0, 0.0],
            top_k,
            distance_metric: DistanceMetric::Cosine,
            include_metadata: true,
            include_vectors: false,
            filter,
            cursor: None,
            consistency: Default::default(),
            staleness_config: None,
        }
    }

    /// Wraps a single field condition into a `FilterExpression::Field`.
    fn field_filter(name: &str, condition: FilterCondition) -> FilterExpression {
        let mut field = HashMap::new();
        field.insert(name.to_string(), condition);
        FilterExpression::Field { field }
    }

    /// Creates an in-memory storage holding `namespace`, populated with
    /// `vectors` when any are given.
    async fn storage_with(namespace: &NamespaceId, vectors: Vec<Vector>) -> Arc<InMemoryStorage> {
        let storage = Arc::new(InMemoryStorage::new());
        storage.ensure_namespace(namespace).await.unwrap();
        if !vectors.is_empty() {
            storage.upsert(namespace, vectors).await.unwrap();
        }
        storage
    }

    /// Returns true when `response` contains only ids listed in `allowed`.
    fn only_ids(response: &QueryResponse, allowed: &[&str]) -> bool {
        response
            .results
            .iter()
            .all(|r| allowed.contains(&r.id.as_str()))
    }

    async fn setup_engine() -> (SearchEngine<InMemoryStorage>, String) {
        let namespace = "test".to_string();
        let storage = storage_with(
            &namespace,
            vec![
                vector("v1", vec![1.0, 0.0, 0.0], None),
                vector("v2", vec![0.0, 1.0, 0.0], None),
                vector("v3", vec![0.707, 0.707, 0.0], None),
            ],
        )
        .await;

        (SearchEngine::new(storage), namespace)
    }

    #[tokio::test]
    async fn test_search_basic() {
        let (engine, namespace) = setup_engine().await;

        let response = engine
            .search(&namespace, &cosine_request(2, None))
            .await
            .unwrap();

        assert_eq!(response.results.len(), 2);
        assert_eq!(response.results[0].id, "v1");
    }

    #[tokio::test]
    async fn test_search_namespace_not_found() {
        let engine = SearchEngine::new(Arc::new(InMemoryStorage::new()));

        let result = engine
            .search(&"nonexistent".to_string(), &cosine_request(5, None))
            .await;

        assert!(matches!(result, Err(DakeraError::NamespaceNotFound(_))));
    }

    #[tokio::test]
    async fn test_search_dimension_mismatch() {
        let (engine, namespace) = setup_engine().await;

        // Two components against a three-dimensional namespace.
        let request = QueryRequest {
            vector: vec![1.0, 0.0],
            ..cosine_request(5, None)
        };

        let result = engine.search(&namespace, &request).await;

        assert!(matches!(
            result,
            Err(DakeraError::DimensionMismatch {
                expected: 3,
                actual: 2
            })
        ));
    }

    #[tokio::test]
    async fn test_search_empty_namespace() {
        let namespace = "empty".to_string();
        let engine = SearchEngine::new(storage_with(&namespace, Vec::new()).await);

        let response = engine
            .search(&namespace, &cosine_request(5, None))
            .await
            .unwrap();

        assert!(response.results.is_empty());
    }

    #[tokio::test]
    async fn test_search_with_filter() {
        let namespace = "test".to_string();
        let storage = storage_with(
            &namespace,
            vec![
                vector(
                    "v1",
                    vec![1.0, 0.0, 0.0],
                    Some(serde_json::json!({"category": "electronics", "price": 100})),
                ),
                vector(
                    "v2",
                    vec![0.9, 0.1, 0.0],
                    Some(serde_json::json!({"category": "books", "price": 20})),
                ),
                vector(
                    "v3",
                    vec![0.8, 0.2, 0.0],
                    Some(serde_json::json!({"category": "electronics", "price": 50})),
                ),
            ],
        )
        .await;
        let engine = SearchEngine::new(storage);

        let filter = field_filter(
            "category",
            FilterCondition::Eq(FilterValue::String("electronics".to_string())),
        );
        let response = engine
            .search(&namespace, &cosine_request(10, Some(filter)))
            .await
            .unwrap();

        assert_eq!(response.results.len(), 2);
        assert!(only_ids(&response, &["v1", "v3"]));
    }

    #[tokio::test]
    async fn test_search_with_numeric_filter() {
        let namespace = "test".to_string();
        let storage = storage_with(
            &namespace,
            vec![
                vector(
                    "v1",
                    vec![1.0, 0.0, 0.0],
                    Some(serde_json::json!({"price": 100})),
                ),
                vector(
                    "v2",
                    vec![0.9, 0.1, 0.0],
                    Some(serde_json::json!({"price": 20})),
                ),
                vector(
                    "v3",
                    vec![0.8, 0.2, 0.0],
                    Some(serde_json::json!({"price": 50})),
                ),
            ],
        )
        .await;
        let engine = SearchEngine::new(storage);

        let filter = field_filter("price", FilterCondition::Lt(FilterValue::Number(60.0)));
        let response = engine
            .search(&namespace, &cosine_request(10, Some(filter)))
            .await
            .unwrap();

        assert_eq!(response.results.len(), 2);
        assert!(only_ids(&response, &["v2", "v3"]));
    }

    #[tokio::test]
    async fn test_ann_search_with_filter() {
        let namespace = "test_ann_filter".to_string();
        let storage = storage_with(
            &namespace,
            vec![
                vector(
                    "v1",
                    vec![1.0, 0.0, 0.0],
                    Some(serde_json::json!({"category": "electronics"})),
                ),
                vector(
                    "v2",
                    vec![0.9, 0.1, 0.0],
                    Some(serde_json::json!({"category": "books"})),
                ),
                vector(
                    "v3",
                    vec![0.8, 0.2, 0.0],
                    Some(serde_json::json!({"category": "electronics"})),
                ),
            ],
        )
        .await;
        // Threshold of 2 with three stored vectors forces the ANN path.
        let engine = SearchEngine::new_with_threshold(storage, 2);

        let filter = field_filter(
            "category",
            FilterCondition::Eq(FilterValue::String("electronics".to_string())),
        );
        let response = engine
            .search(&namespace, &cosine_request(10, Some(filter)))
            .await
            .unwrap();

        assert_eq!(response.results.len(), 2);
        assert!(only_ids(&response, &["v1", "v3"]));
        assert_eq!(response.results[0].id, "v1");
    }
}