use std::sync::Arc;
use crate::error::{LaurusError, Result};
use crate::vector::core::vector::Vector;
use crate::vector::reader::VectorIndexReader;
use crate::vector::search::searcher::VectorIndexSearcher;
use crate::vector::search::searcher::{VectorIndexQuery, VectorIndexQueryResults};
/// Searcher that answers vector queries against an IVF index by scanning only
/// the clusters whose centroids are closest to the query vector.
#[derive(Debug)]
pub struct IvfSearcher {
/// Shared handle to the underlying index reader; `probe_clusters` requires
/// it to downcast to an `IvfIndexReader`.
index_reader: Arc<dyn VectorIndexReader>,
/// Number of nearest clusters to probe per search. Initialised to 1 by
/// `new`; `search` additionally caps the value it uses at 10.
n_probe: usize,
}
impl IvfSearcher {
    /// Creates a searcher over `index_reader` that probes a single cluster by default.
    ///
    /// Call [`IvfSearcher::set_n_probe`] to widen the search.
    ///
    /// # Errors
    ///
    /// This constructor currently always returns `Ok`.
    pub fn new(index_reader: Arc<dyn VectorIndexReader>) -> Result<Self> {
        Ok(Self {
            index_reader,
            n_probe: 1,
        })
    }

    /// Sets how many of the nearest clusters a search should probe.
    pub fn set_n_probe(&mut self, n_probe: usize) {
        self.n_probe = n_probe;
    }

    /// Collects the `(doc_id, field_name)` entries stored in the `n_probe`
    /// clusters whose centroids are closest to `query`.
    ///
    /// When `field_name` is `Some`, only entries for that field are returned.
    /// An index without centroids yields an empty list.
    ///
    /// # Errors
    ///
    /// Returns [`LaurusError::InvalidOperation`] when the underlying reader is
    /// not an `IvfIndexReader`.
    fn probe_clusters(
        &self,
        query: &Vector,
        n_probe: usize,
        field_name: Option<&str>,
    ) -> Result<Vec<(u64, String)>> {
        use super::reader::IvfIndexReader;

        let ivf_reader = self
            .index_reader
            .as_any()
            .downcast_ref::<IvfIndexReader>()
            .ok_or_else(|| {
                LaurusError::InvalidOperation(
                    "IVF searcher requires an IvfIndexReader, but a different reader type was provided"
                        .to_string(),
                )
            })?;

        let centroids = ivf_reader.centroids();
        let distance_metric = self.index_reader.distance_metric();
        if centroids.is_empty() {
            return Ok(Vec::new());
        }

        // Score every centroid against the query. A centroid whose distance
        // cannot be computed is given `f32::MAX` so that it ranks after every
        // centroid with a real distance instead of aborting the search.
        let mut centroid_distances: Vec<(usize, f32)> = centroids
            .iter()
            .enumerate()
            .map(|(i, centroid)| {
                let dist = distance_metric
                    .distance(&query.data, &centroid.data)
                    .unwrap_or(f32::MAX);
                (i, dist)
            })
            .collect();

        // `total_cmp` is a total order even if a distance is NaN; a
        // `partial_cmp` comparator with an `Equal` fallback is not, and the
        // sort contract requires a total order.
        centroid_distances.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));

        let mut result = Vec::new();
        for &(cluster_idx, _) in centroid_distances.iter().take(n_probe) {
            let cluster_vecs = ivf_reader.cluster_vectors(cluster_idx);
            match field_name {
                Some(field) => {
                    result.extend(cluster_vecs.iter().filter(|(_, f)| f == field).cloned());
                }
                None => result.extend_from_slice(cluster_vecs),
            }
        }
        Ok(result)
    }
}
impl VectorIndexSearcher for IvfSearcher {
    /// Runs an approximate nearest-neighbour search for `request.query`.
    ///
    /// The nearest clusters are probed, every candidate vector that can be
    /// loaded is scored against the query, and at most `request.params.top_k`
    /// results are returned in descending similarity order. Results stop at
    /// the first candidate whose similarity is below
    /// `request.params.min_similarity`.
    ///
    /// `candidates_examined` reports how many candidates were scored, and
    /// `search_time_ms` the elapsed time measured by `Timer`.
    ///
    /// # Errors
    ///
    /// Propagates errors from cluster probing and from the distance metric.
    /// Candidates whose vector cannot be fetched from the reader are skipped
    /// rather than reported.
    fn search(&self, request: &VectorIndexQuery) -> Result<VectorIndexQueryResults> {
        use crate::util::time::Timer;

        // Upper bound on the number of clusters probed per query, regardless
        // of the configured `n_probe`.
        const MAX_N_PROBE: usize = 10;

        let start = Timer::now();
        let mut results = VectorIndexQueryResults::new();

        let n_probe = self.n_probe.min(MAX_N_PROBE);
        let vector_ids =
            self.probe_clusters(&request.query, n_probe, request.field_name.as_deref())?;

        let metric = self.index_reader.distance_metric();
        let mut candidates: Vec<(u64, String, f32, f32, Vector)> =
            Vec::with_capacity(vector_ids.len());
        for (doc_id, field_name) in &vector_ids {
            // Best effort: a vector that is missing or fails to load is
            // dropped from the candidate set instead of failing the query.
            if let Ok(Some(vector)) = self.index_reader.get_vector(*doc_id, field_name) {
                let distance = metric.distance(&request.query.data, &vector.data)?;
                let similarity = metric.distance_to_similarity(distance);
                candidates.push((*doc_id, field_name.clone(), similarity, distance, vector));
            }
        }

        // Highest similarity first. `total_cmp` keeps the comparator a total
        // order even if a similarity is NaN, as the sort contract requires.
        candidates.sort_unstable_by(|a, b| b.2.total_cmp(&a.2));

        let candidates_len = candidates.len();
        for (doc_id, field_name, similarity, distance, vector) in
            candidates.into_iter().take(request.params.top_k)
        {
            // Candidates are sorted, so everything after the first one below
            // the threshold is below it as well.
            if similarity < request.params.min_similarity {
                break;
            }
            results
                .results
                .push(crate::vector::search::searcher::VectorIndexQueryResult {
                    doc_id,
                    field_name,
                    similarity,
                    distance,
                    vector: request.params.include_vectors.then_some(vector),
                });
        }

        results.search_time_ms = start.elapsed().as_secs_f64() * 1000.0;
        results.candidates_examined = candidates_len;
        Ok(results)
    }

    /// Counts the vectors known to the reader, restricted to
    /// `request.field_name` when it is set.
    ///
    /// The query vector itself is not consulted.
    ///
    /// # Errors
    ///
    /// Propagates any error returned while listing the reader's vector ids.
    fn count(&self, request: VectorIndexQuery) -> Result<u64> {
        let vector_ids = self.index_reader.vector_ids()?;
        let total = match request.field_name {
            Some(ref field_name) => vector_ids.iter().filter(|(_, f)| f == field_name).count(),
            None => vector_ids.len(),
        };
        Ok(total as u64)
    }
}