use super::filter::Filter;
use super::sort::QueryLimits;
use crate::storage::engine::distance::DistanceMetric;
use crate::storage::engine::vector_store::{SearchResult, VectorCollection, VectorId};
use crate::storage::schema::Value;
use std::collections::HashMap;
/// An owned, dense vector of `f32` components.
#[derive(Debug, Clone)]
pub struct DenseVector {
    values: Vec<f32>,
}

impl DenseVector {
    /// Wraps `values`, taking ownership of the buffer without copying it.
    pub fn new(values: Vec<f32>) -> Self {
        Self::from(values)
    }

    /// Borrows the components as a slice.
    pub fn as_slice(&self) -> &[f32] {
        self.values.as_slice()
    }
}

impl From<Vec<f32>> for DenseVector {
    /// Equivalent to [`DenseVector::new`].
    fn from(values: Vec<f32>) -> Self {
        DenseVector { values }
    }
}
/// A k-nearest-neighbour request: the query vector plus optional search knobs.
#[derive(Debug, Clone)]
pub struct SimilarityQuery {
    /// Vector to search around.
    pub vector: DenseVector,
    /// Maximum number of results requested.
    pub k: usize,
    /// Distance metric requested by the caller (cosine unless overridden).
    ///
    /// NOTE(review): the executors in this module score results with the
    /// index's own metric (`index.distance_metric()`), not this field.
    pub distance: DistanceMetric,
    /// Optional predicate that results must satisfy.
    pub filter: Option<Filter>,
    /// Optional probe count forwarded to `VectorIndex::search_with_params`;
    /// implementations may ignore it.
    pub n_probes: Option<usize>,
    /// Optional upper bound on a result's distance.
    pub distance_threshold: Option<f32>,
}

impl SimilarityQuery {
    /// Creates a query for the `k` nearest neighbours of `vector` using cosine
    /// distance, with no filter, probe override, or distance threshold.
    pub fn new(vector: DenseVector, k: usize) -> Self {
        Self {
            distance_threshold: None,
            n_probes: None,
            filter: None,
            distance: DistanceMetric::Cosine,
            k,
            vector,
        }
    }

    /// Returns the query with its distance metric replaced by `distance`.
    pub fn with_distance(self, distance: DistanceMetric) -> Self {
        Self { distance, ..self }
    }

    /// Returns the query with `filter` installed, replacing any previous filter.
    pub fn with_filter(self, filter: Filter) -> Self {
        Self {
            filter: Some(filter),
            ..self
        }
    }

    /// Returns the query with the probe count set to `n_probes`.
    pub fn with_probes(self, n_probes: usize) -> Self {
        Self {
            n_probes: Some(n_probes),
            ..self
        }
    }

    /// Returns the query with the maximum accepted distance set to `threshold`.
    pub fn with_threshold(self, threshold: f32) -> Self {
        Self {
            distance_threshold: Some(threshold),
            ..self
        }
    }
}
/// One hit from a similarity search.
#[derive(Debug, Clone)]
pub struct SimilarityResult {
    /// Identifier of the matched vector.
    pub id: VectorId,
    /// Raw distance reported by the index (lower is closer).
    pub distance: f32,
    /// Similarity score derived from `distance`; higher is more similar.
    pub score: f32,
    /// Optional metadata attached to the hit.
    pub metadata: Option<HashMap<String, Value>>,
}

impl SimilarityResult {
    /// Builds a hit whose score uses the L2-style mapping `1 / (1 + distance)`.
    ///
    /// Unlike [`SimilarityResult::with_metric`], the score is not clamped at zero.
    pub fn new(id: VectorId, distance: f32) -> Self {
        let score = 1.0 / (1.0 + distance);
        Self {
            id,
            distance,
            score,
            metadata: None,
        }
    }

    /// Builds a hit whose score is derived from `distance` according to `metric`:
    /// cosine gives `1 - d`, inner product gives `-d`, and L2 gives `1 / (1 + d)`.
    ///
    /// Negative scores, and NaN (via `f32::max`), become `0.0`.
    pub fn with_metric(id: VectorId, distance: f32, metric: DistanceMetric) -> Self {
        let raw = match metric {
            DistanceMetric::L2 => 1.0 / (1.0 + distance),
            DistanceMetric::InnerProduct => -distance,
            DistanceMetric::Cosine => 1.0 - distance,
        };
        Self {
            id,
            distance,
            score: raw.max(0.0),
            metadata: None,
        }
    }

    /// Attaches `metadata`, replacing any previously attached map.
    pub fn with_metadata(self, metadata: HashMap<String, Value>) -> Self {
        Self {
            metadata: Some(metadata),
            ..self
        }
    }
}
/// The ordered hits of one similarity search, plus bookkeeping about the search.
#[derive(Debug, Clone)]
pub struct SimilarityResultSet {
    /// Hits in the order they were produced.
    pub results: Vec<SimilarityResult>,
    /// Dimensionality of the searched index.
    pub dimension: usize,
    /// Metric used to turn distances into scores.
    pub distance: DistanceMetric,
    /// Size of the searched index, when recorded.
    pub vectors_searched: Option<usize>,
    /// Wall-clock search duration in microseconds (0 when not measured).
    pub search_time_us: u64,
}

impl SimilarityResultSet {
    /// An empty result set with no timing or search-size information.
    pub fn empty(dimension: usize, distance: DistanceMetric) -> Self {
        Self {
            search_time_us: 0,
            vectors_searched: None,
            distance,
            dimension,
            results: Vec::new(),
        }
    }

    /// Converts raw index hits into scored results (via
    /// [`SimilarityResult::with_metric`]), preserving their order.
    pub fn from_results(
        results: Vec<SearchResult>,
        dimension: usize,
        distance: DistanceMetric,
    ) -> Self {
        let mut set = Self::empty(dimension, distance);
        set.results = results
            .into_iter()
            .map(|hit| SimilarityResult::with_metric(hit.id, hit.distance, distance))
            .collect();
        set
    }

    /// Number of hits.
    pub fn len(&self) -> usize {
        self.results.len()
    }

    /// `true` when there are no hits.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Ids of the first `k` hits, or of every hit if there are fewer than `k`.
    pub fn top_ids(&self, k: usize) -> Vec<VectorId> {
        let leading = self.results.iter().take(k);
        leading.map(|hit| hit.id).collect()
    }

    /// Hits whose score is at least `threshold`, in their original order.
    pub fn above_score(&self, threshold: f32) -> Vec<&SimilarityResult> {
        let passes = |hit: &&SimilarityResult| hit.score >= threshold;
        self.results.iter().filter(passes).collect()
    }

    /// Runs the hits through `QueryLimits::apply` and keeps every other field.
    pub fn apply_limits(self, limits: QueryLimits) -> Self {
        let Self {
            results,
            dimension,
            distance,
            vectors_searched,
            search_time_us,
        } = self;
        Self {
            results: limits.apply(results),
            dimension,
            distance,
            vectors_searched,
            search_time_us,
        }
    }
}
/// Read-only access (every method takes `&self`) to a searchable collection of
/// dense vectors. Implementors must be `Send + Sync`.
pub trait VectorIndex: Send + Sync {
    /// Returns at most `k` hits for `query`. The executors in this module take
    /// the leading hits, so they presumably expect nearest-first order.
    fn search(&self, query: &DenseVector, k: usize) -> Vec<SearchResult>;

    /// Like [`VectorIndex::search`], with an optional probe count that an
    /// implementation may use or ignore.
    fn search_with_params(
        &self,
        query: &DenseVector,
        k: usize,
        n_probes: Option<usize>,
    ) -> Vec<SearchResult>;

    /// Returns the vector stored under `id`, if any.
    fn get(&self, id: VectorId) -> Option<DenseVector>;

    /// Number of components in the stored vectors.
    fn dimension(&self) -> usize;

    /// Metric this index reports for its distances.
    fn distance_metric(&self) -> DistanceMetric;

    /// Number of stored vectors.
    fn len(&self) -> usize;

    /// `true` when [`VectorIndex::len`] is zero.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }
}
/// [`VectorIndex`] adapter for [`VectorCollection`]: each method forwards to the
/// collection's own API, named with a fully qualified path.
impl VectorIndex for VectorCollection {
    fn search(&self, query: &DenseVector, k: usize) -> Vec<SearchResult> {
        VectorCollection::search(self, query.as_slice(), k)
    }

    /// `n_probes` is ignored because the collection's `search` takes no probe
    /// count, so this behaves exactly like [`VectorIndex::search`].
    fn search_with_params(
        &self,
        query: &DenseVector,
        k: usize,
        _n_probes: Option<usize>,
    ) -> Vec<SearchResult> {
        VectorCollection::search(self, query.as_slice(), k)
    }

    /// Returns a copy of the stored components (they are cloned).
    fn get(&self, id: VectorId) -> Option<DenseVector> {
        VectorCollection::get(self, id).map(|vec| DenseVector::new(vec.clone()))
    }

    fn dimension(&self) -> usize {
        self.dimension
    }

    fn distance_metric(&self) -> DistanceMetric {
        self.metric
    }

    fn len(&self) -> usize {
        // Spelled out like the sibling methods. A bare `self.len()` inside a
        // trait method also named `len` reads as infinite self-recursion. This
        // relies on `VectorCollection` having an inherent `len`, which takes
        // precedence over this trait method.
        VectorCollection::len(self)
    }
}
/// Runs a k-nearest-neighbour search for `query` against `index`.
///
/// When `query.distance_threshold` is set:
/// * the index is asked for `THRESHOLD_OVER_FETCH * k` candidates (saturating,
///   so a huge `k` cannot overflow);
/// * candidates farther than the threshold are dropped;
/// * at most `k` survivors are kept.
///
/// Otherwise the index is asked for exactly `k` results.
///
/// Scores are derived with the index's own metric (`index.distance_metric()`).
/// `query.distance` and `query.filter` are not consulted here; use
/// [`execute_hybrid_search`] for filtered queries.
///
/// `vectors_searched` is set to `index.len()`. `search_time_us` is the elapsed
/// wall-clock time of the index call and the threshold filtering, saturating at
/// `u64::MAX`.
pub fn execute_similarity_search(
    index: &dyn VectorIndex,
    query: &SimilarityQuery,
) -> SimilarityResultSet {
    /// Candidate multiplier used when a distance threshold may discard hits.
    const THRESHOLD_OVER_FETCH: usize = 10;

    let start = std::time::Instant::now();
    let results = match query.distance_threshold {
        Some(threshold) => {
            // saturating_mul: a plain `k * 10` overflows (and panics in debug
            // builds) when `k` is very large, e.g. `usize::MAX` for "everything".
            let fetch = query.k.saturating_mul(THRESHOLD_OVER_FETCH);
            index
                .search_with_params(&query.vector, fetch, query.n_probes)
                .into_iter()
                .filter(|r| r.distance <= threshold)
                .take(query.k)
                .collect()
        }
        None => index.search_with_params(&query.vector, query.k, query.n_probes),
    };
    // `as_micros()` is u128; saturate instead of silently truncating with `as`.
    let search_time_us = u64::try_from(start.elapsed().as_micros()).unwrap_or(u64::MAX);

    let mut result_set =
        SimilarityResultSet::from_results(results, index.dimension(), index.distance_metric());
    result_set.search_time_us = search_time_us;
    result_set.vectors_searched = Some(index.len());
    result_set
}
/// Runs a k-nearest-neighbour search and post-filters the candidates.
///
/// Candidates come from `index.search_with_params`. A candidate is kept only if:
/// * its distance is within `query.distance_threshold` (when set), matching
///   [`execute_similarity_search`]; and
/// * `filter_matches(id, filter)` returns `true` (when `query.filter` is set).
///
/// Post-filtering can discard candidates. Whenever a filter or threshold is
/// present, the index is therefore asked for `OVER_FETCH * k` candidates
/// (saturating). Fewer than `k` results are still possible if too many
/// candidates are rejected.
///
/// At most `k` survivors are kept, in the order the index returned them. Only
/// those survivors are passed to `get_metadata`, and a `Some` map it returns is
/// attached to the result.
///
/// Scores use the index's own metric; `query.distance` is not consulted.
/// `search_time_us` saturates at `u64::MAX`.
pub fn execute_hybrid_search<F>(
    index: &dyn VectorIndex,
    query: &SimilarityQuery,
    get_metadata: F,
    filter_matches: impl Fn(VectorId, &Filter) -> bool,
) -> SimilarityResultSet
where
    F: Fn(VectorId) -> Option<HashMap<String, Value>>,
{
    /// Candidate multiplier used when post-filtering may discard hits.
    const OVER_FETCH: usize = 10;

    let start = std::time::Instant::now();
    // Hoisted: the metric is constant for the whole search.
    let metric = index.distance_metric();
    let post_filtered = query.filter.is_some() || query.distance_threshold.is_some();
    let fetch = if post_filtered {
        // saturating_mul: avoids overflow (a debug-build panic) for huge `k`.
        query.k.saturating_mul(OVER_FETCH)
    } else {
        query.k
    };
    let candidates = index.search_with_params(&query.vector, fetch, query.n_probes);

    let results: Vec<SimilarityResult> = candidates
        .into_iter()
        // Threshold first: it is cheap and spares `filter_matches` calls.
        .filter(|r| match query.distance_threshold {
            Some(threshold) => r.distance <= threshold,
            None => true,
        })
        .filter(|r| match &query.filter {
            Some(filter) => filter_matches(r.id, filter),
            None => true,
        })
        .take(query.k)
        .map(|r| {
            let result = SimilarityResult::with_metric(r.id, r.distance, metric);
            match get_metadata(r.id) {
                Some(meta) => result.with_metadata(meta),
                None => result,
            }
        })
        .collect();

    // `as_micros()` is u128; saturate instead of silently truncating with `as`.
    let search_time_us = u64::try_from(start.elapsed().as_micros()).unwrap_or(u64::MAX);
    SimilarityResultSet {
        results,
        dimension: index.dimension(),
        distance: metric,
        vectors_searched: Some(index.len()),
        search_time_us,
    }
}
// Unit tests for the similarity-query types and the two search executors,
// exercised against a small `VectorCollection`.
#[cfg(test)]
mod tests {
use super::*;
// Builds a 3-dimensional cosine-metric collection holding five vectors; insert
// results are discarded. The tests below rely on ids being assigned in
// insertion order starting at 0 (id 0 is [1.0, 0.0, 0.0]).
fn create_test_index() -> VectorCollection {
let mut collection = VectorCollection::new("test", 3).with_metric(DistanceMetric::Cosine);
let _ = collection.insert(vec![1.0, 0.0, 0.0], None);
let _ = collection.insert(vec![0.0, 1.0, 0.0], None);
let _ = collection.insert(vec![0.0, 0.0, 1.0], None);
let _ = collection.insert(vec![0.7, 0.7, 0.0], None);
let _ = collection.insert(vec![0.5, 0.5, 0.7], None);
collection
}
// Querying with a stored vector returns that vector first, at ~zero distance.
#[test]
fn test_similarity_query_basic() {
let index = create_test_index();
let query = SimilarityQuery::new(DenseVector::new(vec![1.0, 0.0, 0.0]), 3);
let results = execute_similarity_search(&index, &query);
assert_eq!(results.len(), 3);
assert_eq!(results.results[0].id, 0); assert!(results.results[0].distance < 0.01);
}
// Cosine score is `1 - distance`: distance 0 scores 1 and distance 1 scores 0.
#[test]
fn test_similarity_result_score() {
let result = SimilarityResult::with_metric(1, 0.0, DistanceMetric::Cosine);
assert!((result.score - 1.0).abs() < 0.01);
let result = SimilarityResult::with_metric(1, 1.0, DistanceMetric::Cosine);
assert!(result.score < 0.01);
}
// `top_ids` truncates to the requested count, and the exact match stays first.
#[test]
fn test_similarity_result_set_top_ids() {
let index = create_test_index();
let query = SimilarityQuery::new(DenseVector::new(vec![1.0, 0.0, 0.0]), 5);
let results = execute_similarity_search(&index, &query);
let top3 = results.top_ids(3);
assert_eq!(top3.len(), 3);
assert_eq!(top3[0], 0);
}
// Every hit returned under a 0.5 distance threshold must lie within it
// (k = 10 exceeds the index size on purpose).
#[test]
fn test_similarity_threshold() {
let index = create_test_index();
let query =
SimilarityQuery::new(DenseVector::new(vec![1.0, 0.0, 0.0]), 10).with_threshold(0.5);
let results = execute_similarity_search(&index, &query);
for result in &results.results {
assert!(result.distance <= 0.5);
}
}
// The collection is usable through a `VectorIndex` trait object.
#[test]
fn test_vector_index_trait() {
let index = create_test_index();
let index_ref: &dyn VectorIndex = &index;
assert_eq!(index_ref.dimension(), 3);
assert_eq!(index_ref.len(), 5);
assert!(!index_ref.is_empty());
let vec = index_ref.get(0).unwrap();
assert_eq!(vec.as_slice(), &[1.0, 0.0, 0.0]);
}
// `above_score` keeps hits whose score is >= the cutoff.
#[test]
fn test_above_score_filter() {
let results = SimilarityResultSet {
results: vec![
// `SimilarityResult::new` scores are 1 / (1 + d): ~0.909, ~0.667 and ~0.333.
SimilarityResult::new(1, 0.1), SimilarityResult::new(2, 0.5), SimilarityResult::new(3, 2.0), ],
dimension: 3,
distance: DistanceMetric::L2,
vectors_searched: Some(100),
search_time_us: 100,
};
let above_05 = results.above_score(0.5);
// Only the first two scores (~0.909 and ~0.667) reach the 0.5 cutoff.
assert_eq!(above_05.len(), 2); }
// Builder methods set the corresponding query fields.
#[test]
fn test_similarity_query_builder() {
let query = SimilarityQuery::new(DenseVector::new(vec![1.0, 0.0, 0.0]), 10)
.with_distance(DistanceMetric::L2)
.with_probes(5)
.with_threshold(1.0);
assert_eq!(query.k, 10);
assert_eq!(query.distance, DistanceMetric::L2);
assert_eq!(query.n_probes, Some(5));
assert_eq!(query.distance_threshold, Some(1.0));
}
// Hybrid search keeps only candidates whose metadata satisfies the filter.
#[test]
fn test_hybrid_search_with_filter() {
let index = create_test_index();
// NOTE(review): metadata is keyed by ids 1..=5, but the index ids are 0..=4
// (see create_test_index). Id 0 therefore has no metadata, so the filter
// rejects it, and id 5 does not exist. With this data only ids 1 and 3 can
// match category "A". This looks like a possible off-by-one; confirm.
let metadata: HashMap<VectorId, HashMap<String, Value>> = [
(
1,
[("category".to_string(), Value::text("A".to_string()))]
.into_iter()
.collect(),
),
(
2,
[("category".to_string(), Value::text("B".to_string()))]
.into_iter()
.collect(),
),
(
3,
[("category".to_string(), Value::text("A".to_string()))]
.into_iter()
.collect(),
),
(
4,
[("category".to_string(), Value::text("B".to_string()))]
.into_iter()
.collect(),
),
(
5,
[("category".to_string(), Value::text("A".to_string()))]
.into_iter()
.collect(),
),
]
.into_iter()
.collect();
let filter = Filter::eq("category", Value::text("A".to_string()));
let query = SimilarityQuery::new(DenseVector::new(vec![1.0, 0.0, 0.0]), 5)
.with_filter(filter.clone());
let results = execute_hybrid_search(
&index,
&query,
|id| metadata.get(&id).cloned(),
|id, filter| {
// Resolves filter column names against this id's metadata map; ids
// without metadata never match.
if let Some(meta) = metadata.get(&id) {
filter.evaluate(&|col| meta.get(col).cloned())
} else {
false
}
},
);
// At most the category-"A" ids survive, and any attached metadata must be "A".
assert!(results.len() <= 3); for result in &results.results {
if let Some(meta) = &result.metadata {
assert_eq!(meta.get("category"), Some(&Value::text("A".to_string())));
}
}
}
// Skipping 2 and keeping 3 of ten results leaves 3 hits, starting at id 2.
#[test]
fn test_apply_limits() {
let results = SimilarityResultSet {
results: (0..10)
.map(|i| SimilarityResult::new(i, i as f32 * 0.1))
.collect(),
dimension: 3,
distance: DistanceMetric::L2,
vectors_searched: Some(100),
search_time_us: 100,
};
let limited = results.apply_limits(QueryLimits::none().offset(2).limit(3));
assert_eq!(limited.len(), 3);
assert_eq!(limited.results[0].id, 2);
}
}