1use std::collections::HashMap;
2use std::sync::Arc;
3
4use common::{
5 DakeraError, DistanceMetric, NamespaceId, PaginationCursor, QueryRequest, QueryResponse,
6 Result, SearchResult,
7};
8use parking_lot::RwLock;
9use storage::VectorStorage;
10
11use crate::filter::evaluate_filter;
12use crate::hnsw::{HnswConfig, HnswIndex};
13use crate::search::brute_force_search;
14
/// Fallback ANN threshold used when `DAKERA_ANN_THRESHOLD` is unset or does
/// not parse as a `usize`.
const DEFAULT_ANN_THRESHOLD: usize = 1_000;
17
18#[inline]
20fn distance_to_similarity(distance: f32, metric: DistanceMetric) -> f32 {
21 match metric {
22 DistanceMetric::Cosine => 1.0 - distance,
23 DistanceMetric::Euclidean => -distance,
24 DistanceMetric::DotProduct => -distance,
25 }
26}
27
28fn ann_threshold_from_env() -> usize {
30 std::env::var("DAKERA_ANN_THRESHOLD")
31 .ok()
32 .and_then(|v| v.parse().ok())
33 .unwrap_or(DEFAULT_ANN_THRESHOLD)
34}
35
36pub struct SearchEngine<S: VectorStorage + ?Sized> {
38 storage: Arc<S>,
39 ann_indices: RwLock<HashMap<String, Arc<HnswIndex>>>,
41 ann_threshold: usize,
43}
44
45impl<S: VectorStorage + ?Sized> SearchEngine<S> {
46 pub fn new(storage: Arc<S>) -> Self {
47 Self {
48 storage,
49 ann_indices: RwLock::new(HashMap::new()),
50 ann_threshold: ann_threshold_from_env(),
51 }
52 }
53
54 pub async fn search(
56 &self,
57 namespace: &NamespaceId,
58 request: &QueryRequest,
59 ) -> Result<QueryResponse> {
60 if !self.storage.namespace_exists(namespace).await? {
62 return Err(DakeraError::NamespaceNotFound(namespace.clone()));
63 }
64
65 if let Some(expected_dim) = self.storage.dimension(namespace).await? {
67 if request.vector.len() != expected_dim {
68 return Err(DakeraError::DimensionMismatch {
69 expected: expected_dim,
70 actual: request.vector.len(),
71 });
72 }
73 }
74
75 let use_ann =
80 request.filter.is_none() && request.cursor.is_none() && self.ann_threshold > 0;
81
82 if use_ann {
83 let count = self.storage.count(namespace).await?;
84 if count > self.ann_threshold {
85 return self.ann_search(namespace, request, count).await;
86 }
87 }
88
89 self.brute_force_path(namespace, request).await
91 }
92
93 async fn brute_force_path(
95 &self,
96 namespace: &NamespaceId,
97 request: &QueryRequest,
98 ) -> Result<QueryResponse> {
99 let vectors = self.storage.get_all(namespace).await?;
100
101 let filtered_vectors: Vec<_> = if let Some(ref filter) = request.filter {
102 vectors
103 .into_iter()
104 .filter(|v| evaluate_filter(filter, v.metadata.as_ref()))
105 .collect()
106 } else {
107 vectors
108 };
109
110 let cursor = request
111 .cursor
112 .as_ref()
113 .and_then(|c| PaginationCursor::decode(c));
114
115 tracing::debug!(
116 namespace = %namespace,
117 vector_count = filtered_vectors.len(),
118 top_k = request.top_k,
119 metric = ?request.distance_metric,
120 has_filter = request.filter.is_some(),
121 has_cursor = cursor.is_some(),
122 "Performing brute-force search"
123 );
124
125 let response = brute_force_search(
126 &request.vector,
127 &filtered_vectors,
128 request.top_k,
129 request.distance_metric,
130 request.include_metadata,
131 request.include_vectors,
132 cursor.as_ref(),
133 );
134
135 Ok(response)
136 }
137
138 async fn ann_search(
140 &self,
141 namespace: &NamespaceId,
142 request: &QueryRequest,
143 vector_count: usize,
144 ) -> Result<QueryResponse> {
145 let index = self
147 .get_or_build_index(namespace, request.distance_metric)
148 .await?;
149
150 tracing::debug!(
151 namespace = %namespace,
152 vector_count = vector_count,
153 top_k = request.top_k,
154 metric = ?request.distance_metric,
155 "Performing ANN search (HNSW)"
156 );
157
158 let hnsw_results = index.search(&request.vector, request.top_k);
160
161 let need_fetch = request.include_metadata || request.include_vectors;
163 let fetched = if need_fetch && !hnsw_results.is_empty() {
164 let ids: Vec<String> = hnsw_results.iter().map(|(id, _)| id.clone()).collect();
165 let vectors = self.storage.get(namespace, &ids).await?;
166 let map: HashMap<String, _> = vectors.into_iter().map(|v| (v.id.clone(), v)).collect();
167 Some(map)
168 } else {
169 None
170 };
171
172 let results: Vec<SearchResult> = hnsw_results
174 .into_iter()
175 .map(|(id, distance)| {
176 let score = distance_to_similarity(distance, request.distance_metric);
177 let (metadata, vector) = if let Some(ref map) = fetched {
178 if let Some(v) = map.get(&id) {
179 (
180 if request.include_metadata {
181 v.metadata.clone()
182 } else {
183 None
184 },
185 if request.include_vectors {
186 Some(v.values.clone())
187 } else {
188 None
189 },
190 )
191 } else {
192 (None, None)
193 }
194 } else {
195 (None, None)
196 };
197 SearchResult {
198 id,
199 score,
200 metadata,
201 vector,
202 }
203 })
204 .collect();
205
206 Ok(QueryResponse {
207 results,
208 next_cursor: None,
209 has_more: Some(false),
210 search_time_ms: 0, })
212 }
213
214 async fn get_or_build_index(
216 &self,
217 namespace: &NamespaceId,
218 metric: DistanceMetric,
219 ) -> Result<Arc<HnswIndex>> {
220 {
222 let indices = self.ann_indices.read();
223 if let Some(index) = indices.get(namespace.as_str()) {
224 return Ok(Arc::clone(index));
225 }
226 }
227
228 tracing::info!(namespace = %namespace, "Building HNSW index for ANN acceleration");
230 let vectors = self.storage.get_all(namespace).await?;
231
232 let config = HnswConfig::default().with_distance_metric(metric);
233 let index = HnswIndex::with_config(config);
234
235 for v in &vectors {
236 index.insert(v.id.clone(), v.values.clone());
237 }
238
239 let index = Arc::new(index);
240
241 {
243 let mut indices = self.ann_indices.write();
244 indices.insert(namespace.clone(), Arc::clone(&index));
245 }
246
247 tracing::info!(
248 namespace = %namespace,
249 vectors = vectors.len(),
250 "HNSW index built and cached"
251 );
252
253 Ok(index)
254 }
255
256 pub fn invalidate_ann_index(&self, namespace: &NamespaceId) {
258 let mut indices = self.ann_indices.write();
259 if indices.remove(namespace.as_str()).is_some() {
260 tracing::debug!(namespace = %namespace, "HNSW index invalidated");
261 }
262 }
263
264 pub fn storage(&self) -> &Arc<S> {
266 &self.storage
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use common::{DistanceMetric, FilterCondition, FilterExpression, FilterValue, Vector};
    use std::collections::HashMap;
    use storage::InMemoryStorage;

    /// Cosine query for `[1.0, 0.0, 0.0]` with `top_k = 5`, metadata included,
    /// no filter and no cursor. Tests override individual fields through
    /// struct-update syntax.
    fn base_request() -> QueryRequest {
        QueryRequest {
            vector: vec![1.0, 0.0, 0.0],
            top_k: 5,
            distance_metric: DistanceMetric::Cosine,
            include_metadata: true,
            include_vectors: false,
            filter: None,
            cursor: None,
            consistency: Default::default(),
            staleness_config: None,
        }
    }

    /// Builds a filter expression constraining a single metadata field.
    fn single_field_filter(name: &str, condition: FilterCondition) -> FilterExpression {
        let field = HashMap::from([(name.to_string(), condition)]);
        FilterExpression::Field { field }
    }

    /// Creates an engine over fresh in-memory storage and ensures that the
    /// namespace `name` exists in it.
    async fn engine_for(
        name: &str,
    ) -> (
        SearchEngine<InMemoryStorage>,
        Arc<InMemoryStorage>,
        String,
    ) {
        let storage = Arc::new(InMemoryStorage::new());
        let engine = SearchEngine::new(storage.clone());
        let namespace = name.to_string();
        storage.ensure_namespace(&namespace).await.unwrap();
        (engine, storage, namespace)
    }

    /// Engine over namespace `"test"` holding three metadata-free 3-d vectors.
    async fn setup_engine() -> (SearchEngine<InMemoryStorage>, String) {
        let (engine, storage, namespace) = engine_for("test").await;

        let plain = |id: &str, values| Vector {
            id: id.to_string(),
            values,
            metadata: None,
            ttl_seconds: None,
            expires_at: None,
        };
        let vectors = vec![
            plain("v1", vec![1.0, 0.0, 0.0]),
            plain("v2", vec![0.0, 1.0, 0.0]),
            plain("v3", vec![0.707, 0.707, 0.0]),
        ];
        storage.upsert(&namespace, vectors).await.unwrap();

        (engine, namespace)
    }

    #[tokio::test]
    async fn test_search_basic() {
        let (engine, namespace) = setup_engine().await;
        let request = QueryRequest {
            top_k: 2,
            ..base_request()
        };

        let response = engine.search(&namespace, &request).await.unwrap();

        assert_eq!(response.results.len(), 2);
        assert_eq!(response.results[0].id, "v1");
    }

    #[tokio::test]
    async fn test_search_namespace_not_found() {
        let engine = SearchEngine::new(Arc::new(InMemoryStorage::new()));
        let request = base_request();
        let missing = "nonexistent".to_string();

        let result = engine.search(&missing, &request).await;

        assert!(matches!(result, Err(DakeraError::NamespaceNotFound(_))));
    }

    #[tokio::test]
    async fn test_search_dimension_mismatch() {
        let (engine, namespace) = setup_engine().await;
        // Two components against a namespace that stores 3-d vectors.
        let request = QueryRequest {
            vector: vec![1.0, 0.0],
            ..base_request()
        };

        let result = engine.search(&namespace, &request).await;

        assert!(matches!(
            result,
            Err(DakeraError::DimensionMismatch {
                expected: 3,
                actual: 2
            })
        ));
    }

    #[tokio::test]
    async fn test_search_empty_namespace() {
        let (engine, _storage, namespace) = engine_for("empty").await;
        let request = base_request();

        let response = engine.search(&namespace, &request).await.unwrap();

        assert!(response.results.is_empty());
    }

    #[tokio::test]
    async fn test_search_with_filter() {
        let (engine, storage, namespace) = engine_for("test").await;

        let item = |id: &str, values, metadata| Vector {
            id: id.to_string(),
            values,
            metadata: Some(metadata),
            ttl_seconds: None,
            expires_at: None,
        };
        let vectors = vec![
            item(
                "v1",
                vec![1.0, 0.0, 0.0],
                serde_json::json!({"category": "electronics", "price": 100}),
            ),
            item(
                "v2",
                vec![0.9, 0.1, 0.0],
                serde_json::json!({"category": "books", "price": 20}),
            ),
            item(
                "v3",
                vec![0.8, 0.2, 0.0],
                serde_json::json!({"category": "electronics", "price": 50}),
            ),
        ];
        storage.upsert(&namespace, vectors).await.unwrap();

        let request = QueryRequest {
            top_k: 10,
            filter: Some(single_field_filter(
                "category",
                FilterCondition::Eq(FilterValue::String("electronics".to_string())),
            )),
            ..base_request()
        };

        let response = engine.search(&namespace, &request).await.unwrap();

        assert_eq!(response.results.len(), 2);
        assert!(response
            .results
            .iter()
            .all(|r| matches!(r.id.as_str(), "v1" | "v3")));
    }

    #[tokio::test]
    async fn test_search_with_numeric_filter() {
        let (engine, storage, namespace) = engine_for("test").await;

        let priced = |id: &str, values, metadata| Vector {
            id: id.to_string(),
            values,
            metadata: Some(metadata),
            ttl_seconds: None,
            expires_at: None,
        };
        let vectors = vec![
            priced("v1", vec![1.0, 0.0, 0.0], serde_json::json!({"price": 100})),
            priced("v2", vec![0.9, 0.1, 0.0], serde_json::json!({"price": 20})),
            priced("v3", vec![0.8, 0.2, 0.0], serde_json::json!({"price": 50})),
        ];
        storage.upsert(&namespace, vectors).await.unwrap();

        let request = QueryRequest {
            top_k: 10,
            filter: Some(single_field_filter(
                "price",
                FilterCondition::Lt(FilterValue::Number(60.0)),
            )),
            ..base_request()
        };

        let response = engine.search(&namespace, &request).await.unwrap();

        assert_eq!(response.results.len(), 2);
        assert!(response
            .results
            .iter()
            .all(|r| matches!(r.id.as_str(), "v2" | "v3")));
    }
}