use std::sync::Arc;
use crate::error::Result;
use crate::vector::core::vector::Vector;
use crate::vector::reader::VectorIndexReader;
use crate::vector::search::searcher::VectorIndexSearcher;
use crate::vector::search::searcher::{VectorIndexQuery, VectorIndexQueryResults};
/// Vector searcher backed by a shared [`VectorIndexReader`].
///
/// "Flat" here means no auxiliary search structure is kept: the searcher only
/// stores the reader handle and asks it for vector ids and vectors on demand.
#[derive(Debug)]
pub struct FlatVectorSearcher {
/// Shared, reference-counted handle to the index reader queried during search.
index_reader: Arc<dyn VectorIndexReader>,
}
impl FlatVectorSearcher {
pub fn new(index_reader: Arc<dyn VectorIndexReader>) -> Result<Self> {
Ok(Self { index_reader })
}
}
impl VectorIndexSearcher for FlatVectorSearcher {
    /// Performs an exhaustive search: every vector id reported by the reader
    /// (optionally restricted to `request.field_name`) is scored against
    /// `request.query`, and the best `request.params.top_k` matches are returned
    /// in descending order of similarity.
    ///
    /// Results stop at the first candidate whose similarity is below
    /// `request.params.min_similarity`. Candidates whose similarity is NaN are
    /// ranked after all comparable candidates. The stored vector is attached to
    /// a result only when `request.params.include_vectors` is set.
    ///
    /// # Errors
    ///
    /// Returns an error if the reader cannot list its vector ids or if the
    /// distance computation fails for any candidate.
    fn search(&self, request: &VectorIndexQuery) -> Result<VectorIndexQueryResults> {
        use std::cmp::Ordering;
        use std::time::Instant;

        /// Orders similarities from highest to lowest, placing NaN last.
        ///
        /// Unlike `partial_cmp(..).unwrap_or(Equal)`, this is a consistent total
        /// order even when NaN is present, which the sort routines require.
        fn by_similarity_desc(a: f32, b: f32) -> Ordering {
            match (a.is_nan(), b.is_nan()) {
                (true, true) => Ordering::Equal,
                (true, false) => Ordering::Greater,
                (false, true) => Ordering::Less,
                // Neither side is NaN, so `partial_cmp` cannot return `None` here.
                (false, false) => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
            }
        }

        let start = Instant::now();
        let mut results = VectorIndexQueryResults::new();

        // Records the reader's total vector count, taken before any field filtering.
        let vector_count = self.index_reader.vector_count();
        results.candidates_examined = vector_count;

        let vector_ids = self.index_reader.vector_ids()?;
        let filtered_vector_ids: Vec<(u64, String)> =
            if let Some(ref field_name) = request.field_name {
                vector_ids
                    .into_iter()
                    .filter(|(_, f)| f == field_name)
                    .collect()
            } else {
                vector_ids
            };

        // The metric and the include flag do not change per candidate; read them once.
        let metric = self.index_reader.distance_metric();
        let keep_vectors = request.params.include_vectors;

        let mut candidates: Vec<(u64, String, f32, f32, Option<Vector>)> =
            Vec::with_capacity(filtered_vector_ids.len());
        for (doc_id, field_name) in filtered_vector_ids {
            // NOTE(review): a failed or empty lookup is skipped silently, so a reader
            // error yields fewer results rather than a failed search. Confirm this
            // best-effort behavior is intended.
            if let Ok(Some(vector)) = self.index_reader.get_vector(doc_id, &field_name) {
                let distance = metric.distance(&request.query.data, &vector.data)?;
                let similarity = metric.distance_to_similarity(distance);
                // Keep the vector only when the caller asked for it, so the
                // candidate list does not hold every vector in the index.
                let stored = if keep_vectors { Some(vector) } else { None };
                candidates.push((doc_id, field_name, similarity, distance, stored));
            }
        }

        candidates.sort_unstable_by(|a, b| by_similarity_desc(a.2, b.2));

        let top_k = request.params.top_k.min(candidates.len());
        for (doc_id, field_name, similarity, distance, vector) in
            candidates.into_iter().take(top_k)
        {
            // Candidates are sorted best-first, so everything after this point
            // is below the threshold as well.
            if similarity < request.params.min_similarity {
                break;
            }
            results
                .results
                .push(crate::vector::search::searcher::VectorIndexQueryResult {
                    doc_id,
                    field_name,
                    similarity,
                    distance,
                    vector,
                });
        }

        results.search_time_ms = start.elapsed().as_secs_f64() * 1000.0;
        Ok(results)
    }

    /// Counts the vector ids known to the reader, restricted to
    /// `request.field_name` when one is given.
    ///
    /// Only the field name of the request is consulted; no similarity scoring
    /// is performed.
    ///
    /// # Errors
    ///
    /// Returns an error if the reader cannot list its vector ids.
    fn count(&self, request: VectorIndexQuery) -> Result<u64> {
        let vector_ids = self.index_reader.vector_ids()?;
        if let Some(ref field_name) = request.field_name {
            Ok(vector_ids.iter().filter(|(_, f)| f == field_name).count() as u64)
        } else {
            Ok(vector_ids.len() as u64)
        }
    }
}