use roaring::RoaringBitmap;
use selene_core::{CancellationChecker, DbString, VectorValue};
use crate::error::GraphError;
use crate::graph::SeleneGraph;
use super::{
ApproximateVectorSearchOptions, VECTOR_SEARCH_CANCEL_STRIDE, VectorCandidateSet,
VectorNodeSearchHit, VectorSearchError, approx_batch, rerank_ann_row_candidates,
turbo_quant_exact,
};
impl SeleneGraph {
    /// Runs an approximate top-`k` vector search over the `label`/`property` vector index,
    /// restricted to the nodes listed in `candidates`.
    ///
    /// Candidate node ids are first resolved to index rows: nodes with no row, dead rows, and
    /// rows absent from the index are dropped. If `turbo_quant_exact::covers_rows` accepts the
    /// surviving rows (presumably "few enough to score exhaustively"), every one of them is
    /// handed straight to the reranker. Otherwise the TurboQuant index proposes candidates
    /// within those rows, driven by `options.k` and `options.ef_search`. Either way, the final
    /// ranking is done by `rerank_ann_row_candidates` with `options.metric`.
    ///
    /// Returns an empty list when `options.k == 0` or `candidates` is empty. Both of these are
    /// checked before the index is consulted. It also returns an empty list when no candidate
    /// maps to a live, indexed row.
    ///
    /// # Errors
    ///
    /// - Errors raised by `checker` (e.g. cancellation) are propagated.
    /// - [`VectorSearchError::ApproximateIndexMissing`] is returned when:
    ///   - the query dimension does not fit in `u32`;
    ///   - no index exists for `label`/`property` with the query's dimension;
    ///   - the index has no ANN metric or is not a TurboQuant index;
    ///   - the index declines to produce candidates (returns `None`).
    /// - [`VectorSearchError::ApproximateMetricMismatch`] is returned when the indexed metric
    ///   differs from `options.metric`.
    /// - Errors from TurboQuant candidate generation (converted through [`GraphError`]) and
    ///   from reranking are propagated.
    pub fn approximate_vector_search_candidate_set_checked(
        &self,
        label: &DbString,
        property: &DbString,
        query: &VectorValue,
        candidates: &VectorCandidateSet,
        options: ApproximateVectorSearchOptions,
        checker: CancellationChecker<'_>,
    ) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
        checker.check()?;
        // Nothing can match: answer before validating the index at all.
        if options.k == 0 || candidates.is_empty() {
            return Ok(Vec::new());
        }
        let query_dimension = u32::try_from(query.dimension())
            .map_err(|_| VectorSearchError::ApproximateIndexMissing)?;
        // An index with a different dimension is treated the same as no index.
        let Some(index) = self
            .vector_index_for(label, property)
            .filter(|index| index.dimension() == query_dimension)
        else {
            return Err(VectorSearchError::ApproximateIndexMissing);
        };
        let Some(indexed_metric) = index.ann_metric() else {
            return Err(VectorSearchError::ApproximateIndexMissing);
        };
        if indexed_metric != options.metric {
            return Err(VectorSearchError::ApproximateMetricMismatch {
                indexed: indexed_metric,
                requested: options.metric,
            });
        }
        if !index.is_turbo_quant() {
            return Err(VectorSearchError::ApproximateIndexMissing);
        }
        let allowed_rows = self.vector_candidate_rows(candidates, index.rows(), &checker)?;
        if allowed_rows.is_empty() {
            return Ok(Vec::new());
        }
        // Exhaustive path: skip the ANN probe and rerank every allowed row.
        if turbo_quant_exact::covers_rows(&allowed_rows, options) {
            return rerank_ann_row_candidates(
                self,
                property,
                query,
                options.metric,
                options.k,
                turbo_quant_exact::row_hits(&allowed_rows),
                &checker,
            );
        }
        let row_hits = index
            .turbo_quant_candidates_in_rows(query, options.k, options.ef_search, &allowed_rows)
            .ok_or(VectorSearchError::ApproximateIndexMissing)?
            .map_err(GraphError::from)?;
        rerank_ann_row_candidates(
            self,
            property,
            query,
            options.metric,
            options.k,
            row_hits,
            &checker,
        )
    }

    /// Batched form of `approximate_vector_search_candidate_set_checked`: `queries[i]` is
    /// searched against `candidate_sets[i]`.
    ///
    /// Every query must have the same dimension as the first one. If every candidate set
    /// matches the first according to `approx_batch::candidate_sets_match`, candidate rows are
    /// resolved once and shared by all queries. Otherwise they are resolved per query. Within
    /// each of those two paths, either all row sets take the exhaustive
    /// `turbo_quant_exact` route, or the whole batch goes through the TurboQuant index.
    ///
    /// Returns an empty outer list for an empty batch. It returns `queries.len()` empty lists
    /// in these cases:
    /// - `options.k == 0`;
    /// - every candidate set is empty;
    /// - no candidate maps to a live, indexed row.
    ///
    /// The first two cases are checked before the index is consulted, matching the
    /// single-query entry point. Otherwise results come from
    /// `approx_batch::rerank_ann_row_candidate_batches`, presumably one list per query in
    /// query order.
    ///
    /// # Errors
    ///
    /// - [`VectorSearchError::BatchLengthMismatch`] is returned when `queries` and
    ///   `candidate_sets` have different lengths.
    /// - Errors raised by `checker` (e.g. cancellation) are propagated.
    /// - [`VectorSearchError::ApproximateIndexMissing`] is returned when:
    ///   - a query dimension does not fit in `u32`, or differs from the first query's;
    ///   - no index of that dimension exists for `label`/`property`;
    ///   - the index has no ANN metric or is not TurboQuant;
    ///   - the index declines to produce candidates.
    /// - [`VectorSearchError::ApproximateMetricMismatch`] is returned when the indexed metric
    ///   differs from `options.metric`.
    /// - Errors from TurboQuant candidate generation (via [`GraphError`]) and from reranking
    ///   are propagated.
    pub fn approximate_vector_search_candidate_sets_batch_checked(
        &self,
        label: &DbString,
        property: &DbString,
        queries: &[VectorValue],
        candidate_sets: &[VectorCandidateSet],
        options: ApproximateVectorSearchOptions,
        checker: CancellationChecker<'_>,
    ) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError> {
        checker.check()?;
        if queries.len() != candidate_sets.len() {
            return Err(VectorSearchError::BatchLengthMismatch {
                queries: queries.len(),
                candidate_sets: candidate_sets.len(),
            });
        }
        let Some(first_query) = queries.first() else {
            return Ok(Vec::new());
        };
        // Same early-out as the single-query path: with k == 0 or no candidates anywhere,
        // nothing can match, so the index is not validated at all.
        if options.k == 0 || candidate_sets.iter().all(|candidates| candidates.is_empty()) {
            return Ok(vec![Vec::new(); queries.len()]);
        }
        let query_dimension = u32::try_from(first_query.dimension())
            .map_err(|_| VectorSearchError::ApproximateIndexMissing)?;
        let Some(index) = self
            .vector_index_for(label, property)
            .filter(|index| index.dimension() == query_dimension)
        else {
            return Err(VectorSearchError::ApproximateIndexMissing);
        };
        let Some(indexed_metric) = index.ann_metric() else {
            return Err(VectorSearchError::ApproximateIndexMissing);
        };
        if indexed_metric != options.metric {
            return Err(VectorSearchError::ApproximateMetricMismatch {
                indexed: indexed_metric,
                requested: options.metric,
            });
        }
        if !index.is_turbo_quant() {
            return Err(VectorSearchError::ApproximateIndexMissing);
        }
        // The index was selected by the first query's dimension; every other query must agree.
        for query in queries {
            checker.check()?;
            let dimension = u32::try_from(query.dimension())
                .map_err(|_| VectorSearchError::ApproximateIndexMissing)?;
            if dimension != query_dimension {
                return Err(VectorSearchError::ApproximateIndexMissing);
            }
        }
        // Shared-candidate fast path: resolve rows once and reuse them for every query.
        if let Some(first_candidate_set) = candidate_sets.first()
            && candidate_sets.iter().skip(1).all(|candidate_set| {
                approx_batch::candidate_sets_match(first_candidate_set, candidate_set)
            })
        {
            let allowed_rows =
                self.vector_candidate_rows(first_candidate_set, index.rows(), &checker)?;
            if allowed_rows.is_empty() {
                return Ok(vec![Vec::new(); queries.len()]);
            }
            if turbo_quant_exact::covers_rows(&allowed_rows, options) {
                let row_batches = vec![turbo_quant_exact::row_hits(&allowed_rows); queries.len()];
                return approx_batch::rerank_ann_row_candidate_batches(
                    self,
                    property,
                    queries,
                    options.metric,
                    options.k,
                    row_batches,
                    &checker,
                );
            }
            let row_batches = index
                .turbo_quant_candidates_batch_in_shared_rows(
                    queries,
                    options.k,
                    options.ef_search,
                    &allowed_rows,
                )
                .ok_or(VectorSearchError::ApproximateIndexMissing)?
                .map_err(GraphError::from)?;
            return approx_batch::rerank_ann_row_candidate_batches(
                self,
                property,
                queries,
                options.metric,
                options.k,
                row_batches,
                &checker,
            );
        }
        // General path: each query has its own candidate rows.
        let allowed_rows = candidate_sets
            .iter()
            .map(|candidates| self.vector_candidate_rows(candidates, index.rows(), &checker))
            .collect::<Result<Vec<_>, _>>()?;
        if allowed_rows.iter().all(|rows| rows.is_empty()) {
            return Ok(vec![Vec::new(); queries.len()]);
        }
        // The exhaustive route is only taken when it applies to every query in the batch.
        if allowed_rows
            .iter()
            .all(|rows| turbo_quant_exact::covers_rows(rows, options))
        {
            let row_batches = allowed_rows
                .iter()
                .map(turbo_quant_exact::row_hits)
                .collect::<Vec<_>>();
            return approx_batch::rerank_ann_row_candidate_batches(
                self,
                property,
                queries,
                options.metric,
                options.k,
                row_batches,
                &checker,
            );
        }
        let row_batches = index
            .turbo_quant_candidates_batch_in_rows(
                queries,
                options.k,
                options.ef_search,
                &allowed_rows,
            )
            .ok_or(VectorSearchError::ApproximateIndexMissing)?
            .map_err(GraphError::from)?;
        approx_batch::rerank_ann_row_candidate_batches(
            self,
            property,
            queries,
            options.metric,
            options.k,
            row_batches,
            &checker,
        )
    }

    /// Resolves `candidates` to the raw row ids that are alive in the node store and present
    /// in `index_rows`.
    ///
    /// Candidates without a row mapping are skipped. Duplicate node ids collapse into a single
    /// bitmap entry. Progress is reported to `checker` every `VECTOR_SEARCH_CANCEL_STRIDE`
    /// candidates, plus once for any remainder. This lets the checker abort long scans
    /// (presumably on cancellation or a scan budget).
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `checker.note_nodes_scanned`.
    fn vector_candidate_rows(
        &self,
        candidates: &VectorCandidateSet,
        index_rows: &RoaringBitmap,
        checker: &CancellationChecker<'_>,
    ) -> Result<RoaringBitmap, VectorSearchError> {
        let mut rows = RoaringBitmap::new();
        let mut candidates_since_check = 0usize;
        for node_id in candidates.as_nodes().iter().copied() {
            candidates_since_check += 1;
            if candidates_since_check >= VECTOR_SEARCH_CANCEL_STRIDE {
                checker.note_nodes_scanned(candidates_since_check)?;
                candidates_since_check = 0;
            }
            let Some(row) = self.row_for_node_id(node_id) else {
                continue;
            };
            let raw_row = row.get();
            if self.node_store.is_alive(raw_row) && index_rows.contains(raw_row) {
                rows.insert(raw_row);
            }
        }
        // Report the tail that did not fill a whole stride.
        if candidates_since_check > 0 {
            checker.note_nodes_scanned(candidates_since_check)?;
        }
        Ok(rows)
    }
}