use std::sync::Arc;
use crate::error::{LaurusError, Result};
use crate::vector::core::vector::Vector;
use crate::vector::reader::VectorIndexReader;
use crate::vector::search::searcher::VectorIndexSearcher;
use crate::vector::search::searcher::{VectorIndexQuery, VectorIndexQueryResults};
/// Searcher for IVF-style vector indexes.
///
/// Wraps a shared [`VectorIndexReader`] together with the number of centroids
/// (`n_probe`) that a query asks the centroid lookup for.
#[derive(Debug)]
pub struct IvfSearcher {
    /// Shared reader used to fetch stored vectors, ids, centroids and the metric.
    index_reader: Arc<dyn VectorIndexReader>,
    /// How many nearest centroids to request per query.
    n_probe: usize,
}
impl IvfSearcher {
    /// Number of centroids probed per query until `set_n_probe` is called.
    const DEFAULT_N_PROBE: usize = 1;

    /// Creates a searcher over `index_reader` that probes one centroid per
    /// query by default.
    ///
    /// # Errors
    /// Currently always returns `Ok`.
    pub fn new(index_reader: Arc<dyn VectorIndexReader>) -> Result<Self> {
        Ok(Self {
            index_reader,
            n_probe: Self::DEFAULT_N_PROBE,
        })
    }

    /// Sets how many nearest centroids are requested per query.
    pub fn set_n_probe(&mut self, n_probe: usize) {
        self.n_probe = n_probe;
    }

    /// Returns the indices of up to `n_probe` centroids closest to `query`,
    /// ordered from nearest to farthest.
    ///
    /// A centroid whose distance cannot be computed (the metric returns an
    /// error) or is NaN is ranked as farthest (`f32::MAX`) instead of aborting
    /// the lookup. Returns an empty list when the reader has no centroids.
    ///
    /// # Errors
    /// Returns `LaurusError::InvalidOperation` if the underlying reader is not
    /// an `IvfIndexReader`.
    fn find_nearest_centroids(&self, query: &Vector, n_probe: usize) -> Result<Vec<usize>> {
        use super::reader::IvfIndexReader;

        let ivf_reader = self
            .index_reader
            .as_any()
            .downcast_ref::<IvfIndexReader>()
            .ok_or_else(|| {
                LaurusError::InvalidOperation(
                    "IVF searcher requires an IvfIndexReader, but a different reader type was provided"
                        .to_string(),
                )
            })?;

        let centroids = ivf_reader.centroids();
        let distance_metric = self.index_reader.distance_metric();
        if centroids.is_empty() {
            return Ok(Vec::new());
        }

        let mut centroid_distances: Vec<(usize, f32)> = centroids
            .iter()
            .enumerate()
            .map(|(i, centroid)| {
                let dist = distance_metric
                    .distance(&query.data, &centroid.data)
                    .ok()
                    // NaN would break the sort's total order; treat it like a failure.
                    .filter(|d| !d.is_nan())
                    .unwrap_or(f32::MAX);
                (i, dist)
            })
            .collect();

        // No NaN remains, so `total_cmp` is a valid total order for `sort_by`
        // (a non-total comparator may panic in the standard sort).
        centroid_distances.sort_by(|a, b| a.1.total_cmp(&b.1));

        Ok(centroid_distances
            .into_iter()
            .take(n_probe)
            .map(|(i, _)| i)
            .collect())
    }
}
impl VectorIndexSearcher for IvfSearcher {
    /// Runs a top-k similarity search for `request.query`.
    ///
    /// The `n_probe` nearest centroids are looked up first. This also makes the
    /// search fail early when the reader is not an `IvfIndexReader`.
    /// NOTE(review): the probed centroid indices are not yet used to prune the
    /// scan. Every stored vector (restricted to `request.field_name` when set)
    /// is scored, so the search is linear in the number of stored vectors.
    ///
    /// Candidates are ranked by descending similarity, with NaN similarities
    /// ranked last. At most `top_k` of them are returned, and iteration stops
    /// at the first candidate below `min_similarity`. Vectors are included in
    /// the results only when `include_vectors` is set. `candidates_examined`
    /// reports how many vectors were actually scored.
    ///
    /// # Errors
    /// Propagates errors from the centroid lookup, `vector_ids`, and the
    /// metric's `similarity`/`distance`. A vector whose `get_vector` call fails
    /// or returns `None` is skipped rather than treated as an error.
    fn search(&self, request: &VectorIndexQuery) -> Result<VectorIndexQueryResults> {
        use crate::vector::search::searcher::VectorIndexQueryResult;
        use std::cmp::Ordering;
        use std::time::Instant;

        let start = Instant::now();
        let mut results = VectorIndexQueryResults::new();

        // Validates the reader type. The indices are currently unused (see doc comment).
        let _nearest_centroid_indices =
            self.find_nearest_centroids(&request.query, self.n_probe)?;

        let mut vector_ids = self.index_reader.vector_ids()?;
        if let Some(ref field_name) = request.field_name {
            vector_ids.retain(|(_, field)| field == field_name);
        }

        // The metric is the same for every candidate; fetch it once.
        let distance_metric = self.index_reader.distance_metric();
        let mut candidates: Vec<(u64, String, f32, f32, Option<Vector>)> =
            Vec::with_capacity(vector_ids.len());
        for (doc_id, field_name) in vector_ids {
            // Best-effort: unreadable or missing vectors are skipped.
            if let Ok(Some(vector)) = self.index_reader.get_vector(doc_id, &field_name) {
                let similarity = distance_metric.similarity(&request.query.data, &vector.data)?;
                let distance = distance_metric.distance(&request.query.data, &vector.data)?;
                // Only keep the (possibly large) vector if it will be returned.
                let vector = request.params.include_vectors.then_some(vector);
                candidates.push((doc_id, field_name, similarity, distance, vector));
            }
        }

        // Descending similarity with NaN last. This is a total order, as
        // `sort_by` requires; the `unwrap_or` fallback is unreachable for
        // non-NaN floats. The stable sort keeps id order among ties.
        candidates.sort_by(|a, b| match (a.2.is_nan(), b.2.is_nan()) {
            (false, false) => b.2.partial_cmp(&a.2).unwrap_or(Ordering::Equal),
            (a_nan, b_nan) => a_nan.cmp(&b_nan),
        });

        let candidates_examined = candidates.len();
        for (doc_id, field_name, similarity, distance, vector) in
            candidates.into_iter().take(request.params.top_k)
        {
            // Sorted descending, so everything after this is below the threshold too.
            if similarity < request.params.min_similarity {
                break;
            }
            results.results.push(VectorIndexQueryResult {
                doc_id,
                field_name,
                similarity,
                distance,
                vector,
            });
        }

        results.search_time_ms = start.elapsed().as_secs_f64() * 1000.0;
        results.candidates_examined = candidates_examined;
        Ok(results)
    }

    /// Counts stored vectors, restricted to `request.field_name` when set.
    ///
    /// # Errors
    /// Propagates errors from the reader's `vector_ids`.
    fn count(&self, request: VectorIndexQuery) -> Result<u64> {
        let vector_ids = self.index_reader.vector_ids()?;
        let count = match request.field_name {
            Some(ref field_name) => vector_ids.iter().filter(|(_, f)| f == field_name).count(),
            None => vector_ids.len(),
        };
        Ok(count as u64)
    }
}