use std::collections::HashMap;
use std::sync::Arc;
use common::{
DakeraError, DistanceMetric, NamespaceId, PaginationCursor, QueryRequest, QueryResponse,
Result, SearchResult,
};
use parking_lot::RwLock;
use storage::VectorStorage;
use crate::filter::evaluate_filter;
use crate::hnsw::{HnswConfig, HnswIndex};
use crate::search::brute_force_search;
/// Namespaces holding more than this many vectors are searched through the HNSW index rather
/// than by brute force. Overridable at runtime through `DAKERA_ANN_THRESHOLD`.
const DEFAULT_ANN_THRESHOLD: usize = 1_000;
/// With a metadata filter, the HNSW index is asked for `top_k` times this many candidates so
/// that enough of them survive post-filtering.
const ANN_FILTER_OVERFETCH_FACTOR: usize = 4;
/// Converts a distance reported by the HNSW index into a higher-is-better score.
///
/// Cosine distance maps to `1 - distance`. Euclidean and dot-product distances are simply
/// negated, so a smaller distance still yields a larger score.
#[inline]
fn distance_to_similarity(distance: f32, metric: DistanceMetric) -> f32 {
    match metric {
        DistanceMetric::Euclidean | DistanceMetric::DotProduct => -distance,
        DistanceMetric::Cosine => 1.0 - distance,
    }
}
fn ann_threshold_from_env() -> usize {
std::env::var("DAKERA_ANN_THRESHOLD")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(DEFAULT_ANN_THRESHOLD)
}
/// HNSW indices cached by [`SearchEngine`], plus the per-namespace invalidation generations
/// used to avoid caching an index that was built from data invalidated mid-build.
#[derive(Default)]
struct AnnIndexCache {
    /// Built indices, keyed by namespace and then by the distance metric each index was built
    /// with, so queries using different metrics never share an index.
    indices: HashMap<String, HashMap<std::mem::Discriminant<DistanceMetric>, Arc<HnswIndex>>>,
    /// Number of times each namespace has been invalidated. A build caches its index only if
    /// this value is unchanged between the start and the end of the build.
    /// NOTE(review): entries are never removed, so this grows by one `u64` per distinct
    /// namespace ever invalidated.
    generations: HashMap<String, u64>,
}

/// Vector search over a [`VectorStorage`] backend.
///
/// Small namespaces are searched exactly by brute force. Once a namespace holds more than
/// `ann_threshold` vectors, unpaginated queries go through a lazily built HNSW index. That
/// index stays cached until [`SearchEngine::invalidate_ann_index`] is called for the namespace.
pub struct SearchEngine<S: VectorStorage + ?Sized> {
    /// Backend the vectors are read from.
    storage: Arc<S>,
    /// Cached HNSW indices and their invalidation bookkeeping.
    ann_indices: RwLock<AnnIndexCache>,
    /// Namespaces with more vectors than this use ANN search; `0` disables ANN entirely.
    ann_threshold: usize,
}

impl<S: VectorStorage + ?Sized> SearchEngine<S> {
    /// Creates an engine over `storage`.
    ///
    /// The ANN threshold comes from the `DAKERA_ANN_THRESHOLD` environment variable, or the
    /// built-in default if it is unset or invalid.
    pub fn new(storage: Arc<S>) -> Self {
        Self::with_threshold(storage, ann_threshold_from_env())
    }

    /// Shared constructor: an engine with an empty ANN cache and the given threshold.
    fn with_threshold(storage: Arc<S>, ann_threshold: usize) -> Self {
        Self {
            storage,
            ann_indices: RwLock::new(AnnIndexCache::default()),
            ann_threshold,
        }
    }

    /// Runs a nearest-neighbour query against `namespace`.
    ///
    /// Paginated queries (those carrying a cursor) and namespaces with at most
    /// `ann_threshold` vectors use exhaustive search. Larger namespaces are answered from the
    /// cached HNSW index.
    ///
    /// # Errors
    ///
    /// * [`DakeraError::NamespaceNotFound`] if `namespace` does not exist.
    /// * [`DakeraError::DimensionMismatch`] if the query vector's length differs from the
    ///   namespace's dimension.
    /// * Any error returned by the storage backend.
    pub async fn search(
        &self,
        namespace: &NamespaceId,
        request: &QueryRequest,
    ) -> Result<QueryResponse> {
        if !self.storage.namespace_exists(namespace).await? {
            return Err(DakeraError::NamespaceNotFound(namespace.clone()));
        }
        // A namespace that reports no dimension accepts a query vector of any length.
        if let Some(expected_dim) = self.storage.dimension(namespace).await? {
            if request.vector.len() != expected_dim {
                return Err(DakeraError::DimensionMismatch {
                    expected: expected_dim,
                    actual: request.vector.len(),
                });
            }
        }
        // Cursor-based pagination is only implemented by the exhaustive path.
        let use_ann = request.cursor.is_none() && self.ann_threshold > 0;
        if use_ann {
            let count = self.storage.count(namespace).await?;
            if count > self.ann_threshold {
                return self.ann_search(namespace, request, count).await;
            }
        }
        self.brute_force_path(namespace, request).await
    }

    /// Answers `request` exactly by scanning every vector in `namespace`.
    ///
    /// The metadata filter is applied before ranking, and the request's pagination cursor is
    /// honoured.
    async fn brute_force_path(
        &self,
        namespace: &NamespaceId,
        request: &QueryRequest,
    ) -> Result<QueryResponse> {
        let vectors = self.storage.get_all(namespace).await?;
        let filtered_vectors: Vec<_> = if let Some(ref filter) = request.filter {
            vectors
                .into_iter()
                .filter(|v| evaluate_filter(filter, v.metadata.as_ref()))
                .collect()
        } else {
            vectors
        };
        // NOTE(review): a cursor that fails to decode is silently treated as if no cursor had
        // been supplied.
        let cursor = request
            .cursor
            .as_ref()
            .and_then(|c| PaginationCursor::decode(c));
        tracing::debug!(
            namespace = %namespace,
            vector_count = filtered_vectors.len(),
            top_k = request.top_k,
            metric = ?request.distance_metric,
            has_filter = request.filter.is_some(),
            has_cursor = cursor.is_some(),
            "Performing brute-force search"
        );
        let response = brute_force_search(
            &request.vector,
            &filtered_vectors,
            request.top_k,
            request.distance_metric,
            request.include_metadata,
            request.include_vectors,
            cursor.as_ref(),
        );
        Ok(response)
    }

    /// Answers `request` from the HNSW index for `namespace`.
    ///
    /// With a metadata filter, the index is asked for `top_k * ANN_FILTER_OVERFETCH_FACTOR`
    /// candidates, which are then filtered. If filtering leaves fewer than `top_k` hits and
    /// the candidates did not cover all `vector_count` vectors, matching vectors may have been
    /// cut off. In that case the query is re-run on the exact brute-force path.
    async fn ann_search(
        &self,
        namespace: &NamespaceId,
        request: &QueryRequest,
        vector_count: usize,
    ) -> Result<QueryResponse> {
        let index = self
            .get_or_build_index(namespace, request.distance_metric)
            .await?;
        let has_filter = request.filter.is_some();
        // Post-filtering discards candidates, so over-fetch to leave enough survivors.
        let hnsw_top_k = if has_filter {
            request.top_k.saturating_mul(ANN_FILTER_OVERFETCH_FACTOR)
        } else {
            request.top_k
        };
        tracing::debug!(
            namespace = %namespace,
            vector_count = vector_count,
            top_k = request.top_k,
            hnsw_top_k,
            has_filter,
            metric = ?request.distance_metric,
            "Performing ANN search (HNSW)"
        );
        let hnsw_results = index.search(&request.vector, hnsw_top_k);
        let candidate_count = hnsw_results.len();

        // Storage is only consulted when the response needs payloads or the filter needs
        // metadata.
        let need_fetch = request.include_metadata || request.include_vectors || has_filter;
        let fetched = if need_fetch && !hnsw_results.is_empty() {
            let ids: Vec<String> = hnsw_results.iter().map(|(id, _)| id.clone()).collect();
            let vectors = self.storage.get(namespace, &ids).await?;
            let map: HashMap<String, _> = vectors.into_iter().map(|v| (v.id.clone(), v)).collect();
            Some(map)
        } else {
            None
        };

        let mut results: Vec<SearchResult> = hnsw_results
            .into_iter()
            .filter_map(|(id, distance)| {
                let score = distance_to_similarity(distance, request.distance_metric);
                let entry = fetched.as_ref().and_then(|map| map.get(&id));
                if let Some(ref filter) = request.filter {
                    let metadata = entry.and_then(|v| v.metadata.as_ref());
                    if !evaluate_filter(filter, metadata) {
                        return None;
                    }
                }
                let (metadata, vector) = if let Some(v) = entry {
                    (
                        if request.include_metadata {
                            v.metadata.clone()
                        } else {
                            None
                        },
                        if request.include_vectors {
                            Some(v.values.clone())
                        } else {
                            None
                        },
                    )
                } else {
                    (None, None)
                };
                Some(SearchResult {
                    id,
                    score,
                    metadata,
                    vector,
                })
            })
            .collect();
        results.truncate(request.top_k);

        // The filter ran after the HNSW cut-off. A short result page is only trustworthy if
        // every vector in the namespace was a candidate; otherwise return the exact answer
        // instead of a silently incomplete one.
        if has_filter && results.len() < request.top_k && candidate_count < vector_count {
            tracing::debug!(
                namespace = %namespace,
                matched = results.len(),
                candidate_count,
                "Filtered ANN search may be incomplete; falling back to brute-force search"
            );
            return self.brute_force_path(namespace, request).await;
        }

        Ok(QueryResponse {
            results,
            next_cursor: None,
            has_more: Some(false),
            // NOTE(review): search time is not measured on this path.
            search_time_ms: 0,
        })
    }

    /// Returns the cached HNSW index for (`namespace`, `metric`).
    ///
    /// On a miss, the index is built from a full scan of the namespace and then cached. The
    /// new index is not cached if the namespace was invalidated while it was being built,
    /// because it may reflect stale data; it is still returned for the current query.
    async fn get_or_build_index(
        &self,
        namespace: &NamespaceId,
        metric: DistanceMetric,
    ) -> Result<Arc<HnswIndex>> {
        let metric_key = std::mem::discriminant(&metric);

        // Look up the index and snapshot the namespace's generation under a single read lock.
        // The guard is dropped at the end of this block, before any `.await`.
        let (cached, generation) = {
            let cache = self.ann_indices.read();
            let cached = cache
                .indices
                .get(namespace.as_str())
                .and_then(|by_metric| by_metric.get(&metric_key))
                .cloned();
            let generation = cache.generations.get(namespace.as_str()).copied().unwrap_or(0);
            (cached, generation)
        };
        if let Some(index) = cached {
            return Ok(index);
        }

        tracing::info!(
            namespace = %namespace,
            metric = ?metric,
            "Building HNSW index for ANN acceleration"
        );
        let vectors = self.storage.get_all(namespace).await?;
        let vector_count = vectors.len();
        let index = HnswIndex::with_config(HnswConfig::default().with_distance_metric(metric));
        // Move ids and values into the index rather than cloning every vector.
        for v in vectors {
            index.insert(v.id, v.values);
        }
        let index = Arc::new(index);

        let mut cache = self.ann_indices.write();
        let current = cache.generations.get(namespace.as_str()).copied().unwrap_or(0);
        if current != generation {
            drop(cache);
            tracing::debug!(
                namespace = %namespace,
                "Namespace invalidated during HNSW build; index used uncached"
            );
            return Ok(index);
        }
        // A concurrent builder may already have cached an index; keep and return theirs.
        let index = Arc::clone(
            cache
                .indices
                .entry(namespace.clone())
                .or_default()
                .entry(metric_key)
                .or_insert(index),
        );
        drop(cache);
        tracing::info!(
            namespace = %namespace,
            vectors = vector_count,
            "HNSW index built and cached"
        );
        Ok(index)
    }

    /// Drops every cached HNSW index for `namespace`, for all metrics.
    ///
    /// The next ANN query then rebuilds from the current storage contents. This also bumps
    /// the namespace's generation under the write lock, so a build already in progress will
    /// not cache its possibly stale index.
    pub fn invalidate_ann_index(&self, namespace: &NamespaceId) {
        let mut cache = self.ann_indices.write();
        *cache.generations.entry(namespace.clone()).or_insert(0) += 1;
        if cache.indices.remove(namespace.as_str()).is_some() {
            tracing::debug!(namespace = %namespace, "HNSW index invalidated");
        }
    }

    /// Returns the storage backend this engine reads from.
    pub fn storage(&self) -> &Arc<S> {
        &self.storage
    }

    /// Test-only constructor with an explicit ANN threshold, bypassing the environment.
    #[cfg(test)]
    pub fn new_with_threshold(storage: Arc<S>, ann_threshold: usize) -> Self {
        Self::with_threshold(storage, ann_threshold)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use common::{DistanceMetric, FilterCondition, FilterExpression, FilterValue, Vector};
    use std::collections::HashMap;
    use storage::InMemoryStorage;

    /// Builds a vector with no TTL.
    fn make_vector(id: &str, values: Vec<f32>, metadata: Option<serde_json::Value>) -> Vector {
        Vector {
            id: id.to_string(),
            values,
            metadata,
            ttl_seconds: None,
            expires_at: None,
        }
    }

    /// Builds a cosine-similarity query without a cursor that returns metadata but not
    /// vectors.
    fn cosine_query(
        query: Vec<f32>,
        top_k: usize,
        filter: Option<FilterExpression>,
    ) -> QueryRequest {
        QueryRequest {
            vector: query,
            top_k,
            distance_metric: DistanceMetric::Cosine,
            include_metadata: true,
            include_vectors: false,
            filter,
            cursor: None,
            consistency: Default::default(),
            staleness_config: None,
        }
    }

    /// Builds a filter with a single condition on the metadata field `name`.
    fn field_filter(name: &str, condition: FilterCondition) -> FilterExpression {
        let mut field = HashMap::new();
        field.insert(name.to_string(), condition);
        FilterExpression::Field { field }
    }

    /// Creates `namespace` in fresh in-memory storage, upserts `vectors` (skipped when empty)
    /// and returns an engine with the given ANN threshold.
    ///
    /// Passing the threshold explicitly keeps the tests independent of any
    /// `DAKERA_ANN_THRESHOLD` set in the environment.
    async fn engine_with(
        namespace: &str,
        vectors: Vec<Vector>,
        ann_threshold: usize,
    ) -> (SearchEngine<InMemoryStorage>, String) {
        let storage = Arc::new(InMemoryStorage::new());
        let namespace = namespace.to_string();
        storage.ensure_namespace(&namespace).await.unwrap();
        if !vectors.is_empty() {
            storage.upsert(&namespace, vectors).await.unwrap();
        }
        (
            SearchEngine::new_with_threshold(storage, ann_threshold),
            namespace,
        )
    }

    /// Engine over namespace "test" holding three 3-d vectors without metadata; with the
    /// default threshold it is searched by brute force.
    async fn setup_engine() -> (SearchEngine<InMemoryStorage>, String) {
        engine_with(
            "test",
            vec![
                make_vector("v1", vec![1.0, 0.0, 0.0], None),
                make_vector("v2", vec![0.0, 1.0, 0.0], None),
                make_vector("v3", vec![0.707, 0.707, 0.0], None),
            ],
            DEFAULT_ANN_THRESHOLD,
        )
        .await
    }

    #[tokio::test]
    async fn test_search_basic() {
        let (engine, namespace) = setup_engine().await;
        let request = cosine_query(vec![1.0, 0.0, 0.0], 2, None);
        let response = engine.search(&namespace, &request).await.unwrap();
        assert_eq!(response.results.len(), 2);
        // v1 is identical to the query, so it must rank first.
        assert_eq!(response.results[0].id, "v1");
    }

    #[tokio::test]
    async fn test_search_namespace_not_found() {
        let engine = SearchEngine::new(Arc::new(InMemoryStorage::new()));
        let request = cosine_query(vec![1.0, 0.0, 0.0], 5, None);
        let result = engine.search(&"nonexistent".to_string(), &request).await;
        assert!(matches!(result, Err(DakeraError::NamespaceNotFound(_))));
    }

    #[tokio::test]
    async fn test_search_dimension_mismatch() {
        let (engine, namespace) = setup_engine().await;
        // The namespace holds 3-dimensional vectors; query with only 2 dimensions.
        let request = cosine_query(vec![1.0, 0.0], 5, None);
        let result = engine.search(&namespace, &request).await;
        assert!(matches!(
            result,
            Err(DakeraError::DimensionMismatch {
                expected: 3,
                actual: 2
            })
        ));
    }

    #[tokio::test]
    async fn test_search_empty_namespace() {
        let (engine, namespace) = engine_with("empty", Vec::new(), DEFAULT_ANN_THRESHOLD).await;
        let request = cosine_query(vec![1.0, 0.0, 0.0], 5, None);
        let response = engine.search(&namespace, &request).await.unwrap();
        assert!(response.results.is_empty());
    }

    #[tokio::test]
    async fn test_search_with_filter() {
        let (engine, namespace) = engine_with(
            "test",
            vec![
                make_vector(
                    "v1",
                    vec![1.0, 0.0, 0.0],
                    Some(serde_json::json!({"category": "electronics", "price": 100})),
                ),
                make_vector(
                    "v2",
                    vec![0.9, 0.1, 0.0],
                    Some(serde_json::json!({"category": "books", "price": 20})),
                ),
                make_vector(
                    "v3",
                    vec![0.8, 0.2, 0.0],
                    Some(serde_json::json!({"category": "electronics", "price": 50})),
                ),
            ],
            DEFAULT_ANN_THRESHOLD,
        )
        .await;
        let filter = field_filter(
            "category",
            FilterCondition::Eq(FilterValue::String("electronics".to_string())),
        );
        let request = cosine_query(vec![1.0, 0.0, 0.0], 10, Some(filter));
        let response = engine.search(&namespace, &request).await.unwrap();
        assert_eq!(response.results.len(), 2);
        assert!(response
            .results
            .iter()
            .all(|r| r.id == "v1" || r.id == "v3"));
    }

    #[tokio::test]
    async fn test_search_with_numeric_filter() {
        let (engine, namespace) = engine_with(
            "test",
            vec![
                make_vector(
                    "v1",
                    vec![1.0, 0.0, 0.0],
                    Some(serde_json::json!({"price": 100})),
                ),
                make_vector(
                    "v2",
                    vec![0.9, 0.1, 0.0],
                    Some(serde_json::json!({"price": 20})),
                ),
                make_vector(
                    "v3",
                    vec![0.8, 0.2, 0.0],
                    Some(serde_json::json!({"price": 50})),
                ),
            ],
            DEFAULT_ANN_THRESHOLD,
        )
        .await;
        let filter = field_filter("price", FilterCondition::Lt(FilterValue::Number(60.0)));
        let request = cosine_query(vec![1.0, 0.0, 0.0], 10, Some(filter));
        let response = engine.search(&namespace, &request).await.unwrap();
        assert_eq!(response.results.len(), 2);
        assert!(response
            .results
            .iter()
            .all(|r| r.id == "v2" || r.id == "v3"));
    }

    #[tokio::test]
    async fn test_ann_search_with_filter() {
        // A threshold of 2 with 3 stored vectors forces the HNSW path.
        let (engine, namespace) = engine_with(
            "test_ann_filter",
            vec![
                make_vector(
                    "v1",
                    vec![1.0, 0.0, 0.0],
                    Some(serde_json::json!({"category": "electronics"})),
                ),
                make_vector(
                    "v2",
                    vec![0.9, 0.1, 0.0],
                    Some(serde_json::json!({"category": "books"})),
                ),
                make_vector(
                    "v3",
                    vec![0.8, 0.2, 0.0],
                    Some(serde_json::json!({"category": "electronics"})),
                ),
            ],
            2,
        )
        .await;
        let filter = field_filter(
            "category",
            FilterCondition::Eq(FilterValue::String("electronics".to_string())),
        );
        let request = cosine_query(vec![1.0, 0.0, 0.0], 10, Some(filter));
        let response = engine.search(&namespace, &request).await.unwrap();
        assert_eq!(response.results.len(), 2);
        assert!(response
            .results
            .iter()
            .all(|r| r.id == "v1" || r.id == "v3"));
        assert_eq!(response.results[0].id, "v1");
    }
}