use std::cmp::Ordering;
use roaring::RoaringBitmap;
use selene_core::{
CancellationChecker, DbString, NodeId, Value, VectorMetric, VectorMetricQuery, VectorTopK,
VectorValue,
};
use crate::error::{GraphError, GraphResult};
use crate::graph::SeleneGraph;
use crate::parallel_scan::{should_parallelize_scan, try_reduce_bitmap_chunks};
#[cfg(test)]
use crate::shared::SharedGraph;
use crate::store::RowIndex;
use crate::vector_index::VectorIndexSearchHit;
#[path = "vector_search/types.rs"]
mod types;
pub use types::{
ApproximateVectorExpansionOptions, ApproximateVectorSearchOptions, VectorCandidateSet,
VectorNeighborDirection, VectorNeighborSearchOptions, VectorNodeSearchHit, VectorSearchError,
};
#[path = "vector_search/approx_batch.rs"]
mod approx_batch;
#[path = "vector_search/approx_turbo_quant.rs"]
mod approx_turbo_quant;
#[path = "vector_search/exact_batch.rs"]
mod exact_batch;
#[path = "vector_search/shared_wrappers.rs"]
mod shared_wrappers;
#[path = "vector_search/turbo_quant_exact.rs"]
mod turbo_quant_exact;
/// Rows scanned by the sequential exact search between two cancellation checks.
const VECTOR_SEARCH_CANCEL_STRIDE: usize = 1_024;
/// Rows per chunk handed to `try_reduce_bitmap_chunks` by the parallel exact search.
const VECTOR_SEARCH_PARALLEL_CHUNK_ROWS: usize = 2_048;
/// Row-count threshold passed (with the candidate row count and `k`) to
/// `should_parallelize_scan` when deciding whether the exact search runs in parallel.
#[cfg(not(test))]
const VECTOR_SEARCH_PARALLEL_MIN_ROWS: u64 = 16_384;
/// Test-build threshold, far below the release value — presumably so that small test
/// fixtures can reach the parallel path.
#[cfg(test)]
const VECTOR_SEARCH_PARALLEL_MIN_ROWS: u64 = 8;
impl SeleneGraph {
pub fn exact_vector_search_nodes(
&self,
label: &DbString,
property: &DbString,
query: &VectorValue,
metric: VectorMetric,
k: usize,
) -> GraphResult<Vec<VectorNodeSearchHit>> {
self.exact_vector_search_nodes_checked(
label,
property,
query,
metric,
k,
CancellationChecker::disabled(),
)
.map_err(VectorSearchError::into_graph_error)
}
pub fn exact_vector_search_nodes_checked(
&self,
label: &DbString,
property: &DbString,
query: &VectorValue,
metric: VectorMetric,
k: usize,
checker: CancellationChecker<'_>,
) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
checker.check()?;
if k == 0 {
return Ok(Vec::new());
}
let Some(label_rows) = self.nodes_with_label(label) else {
return Ok(Vec::new());
};
let query_dimension = u32::try_from(query.dimension()).ok();
let vector_index = query_dimension.and_then(|dimension| {
self.vector_index_for(label, property)
.filter(|index| index.dimension() == dimension)
});
let rows = vector_index
.as_ref()
.map_or(label_rows, |index| index.rows());
let scorer = metric.bind_query(query).map_err(GraphError::from)?;
if should_parallelize_exact_scan(rows, k) {
return self.exact_vector_search_parallel(label, property, scorer, k, rows, checker);
}
let mut top_k = VectorTopK::new(k);
let mut rows_since_check = 0usize;
for raw_row in rows.iter() {
rows_since_check += 1;
if rows_since_check >= VECTOR_SEARCH_CANCEL_STRIDE {
checker.check()?;
rows_since_check = 0;
}
if !self.node_store.is_alive(raw_row) {
continue;
}
let row = RowIndex::new(raw_row);
let node_id = self
.node_id_for_row(row)
.ok_or_else(|| GraphError::Inconsistent {
reason: format!(
"label index row {raw_row} for {} has no node id",
label.as_str()
),
})?;
let properties = self
.node_store
.properties
.get(raw_row as usize)
.ok_or_else(|| GraphError::Inconsistent {
reason: format!(
"label index row {raw_row} for {} has no property row",
label.as_str()
),
})?;
let Some(Value::Vector(vector)) = properties.get(property) else {
continue;
};
let distance = scorer.distance(vector).map_err(GraphError::from)?;
top_k.push_distance(node_id, distance);
}
Ok(top_k
.into_hits()
.into_iter()
.map(|hit| VectorNodeSearchHit {
node_id: hit.key,
distance: hit.distance,
})
.collect())
}
fn exact_vector_search_parallel(
&self,
label: &DbString,
property: &DbString,
scorer: VectorMetricQuery<'_>,
k: usize,
rows: &RoaringBitmap,
checker: CancellationChecker<'_>,
) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
let top_k = try_reduce_bitmap_chunks(
rows,
VECTOR_SEARCH_PARALLEL_CHUNK_ROWS,
checker,
|| VectorTopK::new(k),
|chunk| self.exact_vector_search_chunk(label, property, scorer, k, chunk),
merge_top_k,
)?;
Ok(vector_node_hits(top_k))
}
fn exact_vector_search_chunk(
&self,
label: &DbString,
property: &DbString,
scorer: VectorMetricQuery<'_>,
k: usize,
rows: &[u32],
) -> Result<VectorTopK<NodeId>, VectorSearchError> {
let mut top_k = VectorTopK::new(k);
for &raw_row in rows {
if !self.node_store.is_alive(raw_row) {
continue;
}
let row = RowIndex::new(raw_row);
let node_id = self
.node_id_for_row(row)
.ok_or_else(|| GraphError::Inconsistent {
reason: format!(
"vector search row {raw_row} for {} has no node id",
label.as_str()
),
})?;
let properties = self
.node_store
.properties
.get(raw_row as usize)
.ok_or_else(|| GraphError::Inconsistent {
reason: format!(
"vector search row {raw_row} for {} has no property row",
label.as_str()
),
})?;
let Some(Value::Vector(vector)) = properties.get(property) else {
continue;
};
let distance = scorer.distance(vector).map_err(GraphError::from)?;
top_k.push_distance(node_id, distance);
}
Ok(top_k)
}
pub fn approximate_vector_search_nodes_checked(
&self,
label: &DbString,
property: &DbString,
query: &VectorValue,
options: ApproximateVectorSearchOptions,
checker: CancellationChecker<'_>,
) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
checker.check()?;
let query_dimension = u32::try_from(query.dimension())
.map_err(|_| VectorSearchError::ApproximateIndexMissing)?;
let Some(index) = self
.vector_index_for(label, property)
.filter(|index| index.dimension() == query_dimension)
else {
return Err(VectorSearchError::ApproximateIndexMissing);
};
let Some(indexed_metric) = index.ann_metric() else {
return Err(VectorSearchError::ApproximateIndexMissing);
};
if indexed_metric != options.metric {
return Err(VectorSearchError::ApproximateMetricMismatch {
indexed: indexed_metric,
requested: options.metric,
});
}
if index.is_turbo_quant() {
if turbo_quant_exact::covers_rows(index.rows(), options) {
return self.exact_vector_search_nodes_checked(
label,
property,
query,
options.metric,
options.k,
checker,
);
}
let row_hits = index
.turbo_quant_candidates(query, options.k, options.ef_search)
.ok_or(VectorSearchError::ApproximateIndexMissing)?
.map_err(GraphError::from)?;
return rerank_ann_row_candidates(
self,
property,
query,
options.metric,
options.k,
row_hits,
&checker,
);
}
let row_hits = index
.ann_search(query, options.k, options.ef_search)
.ok_or(VectorSearchError::ApproximateIndexMissing)?
.map_err(GraphError::from)?;
ann_row_hits_to_node_hits(self, label, row_hits, &checker)
}
pub fn approximate_vector_search_nodes_batch_checked(
&self,
label: &DbString,
property: &DbString,
queries: &[VectorValue],
options: ApproximateVectorSearchOptions,
checker: CancellationChecker<'_>,
) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError> {
checker.check()?;
let Some(first_query) = queries.first() else {
return Ok(Vec::new());
};
let query_dimension = u32::try_from(first_query.dimension())
.map_err(|_| VectorSearchError::ApproximateIndexMissing)?;
let Some(index) = self
.vector_index_for(label, property)
.filter(|index| index.dimension() == query_dimension)
else {
return Err(VectorSearchError::ApproximateIndexMissing);
};
let Some(indexed_metric) = index.ann_metric() else {
return Err(VectorSearchError::ApproximateIndexMissing);
};
if indexed_metric != options.metric {
return Err(VectorSearchError::ApproximateMetricMismatch {
indexed: indexed_metric,
requested: options.metric,
});
}
if index.is_turbo_quant() {
for query in queries {
checker.check()?;
let dimension = u32::try_from(query.dimension())
.map_err(|_| VectorSearchError::ApproximateIndexMissing)?;
if dimension != query_dimension {
return Err(VectorSearchError::ApproximateIndexMissing);
}
}
if turbo_quant_exact::covers_rows(index.rows(), options) {
return self.exact_vector_search_nodes_batch_checked(
label,
property,
queries,
options.metric,
options.k,
checker,
);
}
if !index.turbo_quant_prefers_fused_batch(queries.len()) {
let mut batch_hits = Vec::with_capacity(queries.len());
for query in queries {
checker.check()?;
let row_hits = index
.turbo_quant_candidates(query, options.k, options.ef_search)
.ok_or(VectorSearchError::ApproximateIndexMissing)?
.map_err(GraphError::from)?;
batch_hits.push(rerank_ann_row_candidates(
self,
property,
query,
options.metric,
options.k,
row_hits,
&checker,
)?);
}
return Ok(batch_hits);
}
let row_batches = index
.turbo_quant_candidates_batch(queries, options.k, options.ef_search)
.ok_or(VectorSearchError::ApproximateIndexMissing)?
.map_err(GraphError::from)?;
let mut batch_hits = Vec::with_capacity(queries.len());
for (query, row_hits) in queries.iter().zip(row_batches) {
batch_hits.push(rerank_ann_row_candidates(
self,
property,
query,
options.metric,
options.k,
row_hits,
&checker,
)?);
}
return Ok(batch_hits);
}
approx_batch::ann_index_batch_search(
self,
label,
&index,
queries,
options,
query_dimension,
checker,
)
}
pub fn approximate_vector_search_expanded_candidates_checked(
&self,
label: &DbString,
property: &DbString,
query: &VectorValue,
options: ApproximateVectorExpansionOptions<'_>,
checker: CancellationChecker<'_>,
) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
checker.check()?;
let root_hits = self.approximate_vector_search_nodes_checked(
label,
property,
query,
ApproximateVectorSearchOptions::new(options.metric, options.root_k, options.ef_search),
checker,
)?;
if options.k == 0 || root_hits.is_empty() {
return Ok(Vec::new());
}
let roots = VectorCandidateSet::from_search_hits(&root_hits);
let expanded = self.expand_vector_candidate_set_checked(
&roots,
options.edge_label,
options.direction,
checker,
)?;
self.score_vector_candidate_set_checked(
property,
query,
&expanded,
options.metric,
options.k,
checker,
)
}
pub fn approximate_vector_search_expanded_candidates_batch_checked(
&self,
label: &DbString,
property: &DbString,
queries: &[VectorValue],
options: ApproximateVectorExpansionOptions<'_>,
checker: CancellationChecker<'_>,
) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError> {
checker.check()?;
let root_hits = self.approximate_vector_search_nodes_batch_checked(
label,
property,
queries,
ApproximateVectorSearchOptions::new(options.metric, options.root_k, options.ef_search),
checker,
)?;
if options.k == 0 {
return Ok(vec![Vec::new(); queries.len()]);
}
let root_sets = root_hits
.iter()
.map(VectorCandidateSet::from_search_hits)
.collect::<Vec<_>>();
self.score_vector_expanded_candidate_sets_batch_checked(
property,
queries,
&root_sets,
VectorNeighborSearchOptions::new(
options.edge_label,
options.direction,
options.metric,
options.k,
),
checker,
)
}
}
/// Decides whether an exact scan over `rows` keeping `k` hits should run in parallel,
/// delegating to `should_parallelize_scan` with `VECTOR_SEARCH_PARALLEL_MIN_ROWS`.
fn should_parallelize_exact_scan(rows: &RoaringBitmap, k: usize) -> bool {
    let candidate_rows = rows.len();
    should_parallelize_scan(candidate_rows, k, VECTOR_SEARCH_PARALLEL_MIN_ROWS)
}
/// Reduction step for the parallel exact scan: folds every hit of `rhs` into `lhs`.
///
/// This never fails; the `Result` return type matches the reducer shape that
/// `try_reduce_bitmap_chunks` is called with.
fn merge_top_k(
    lhs: VectorTopK<NodeId>,
    rhs: VectorTopK<NodeId>,
) -> Result<VectorTopK<NodeId>, VectorSearchError> {
    let merged = rhs.into_hits().into_iter().fold(lhs, |mut acc, hit| {
        acc.push_distance(hit.key, hit.distance);
        acc
    });
    Ok(merged)
}
/// Converts a finished [`VectorTopK`] of node ids into [`VectorNodeSearchHit`]s, keeping the
/// order produced by `VectorTopK::into_hits`.
fn vector_node_hits(top_k: VectorTopK<NodeId>) -> Vec<VectorNodeSearchHit> {
    let mut node_hits = Vec::new();
    for ranked in top_k.into_hits() {
        node_hits.push(VectorNodeSearchHit {
            node_id: ranked.key,
            distance: ranked.distance,
        });
    }
    node_hits
}
/// Maps ANN index row hits to node hits, keeping the index-reported distances.
///
/// Cancellation is checked before each hit, and dead rows are dropped. The result is sorted
/// with `compare_node_search_hit` only if it is not already in that order.
///
/// # Errors
///
/// A cancellation error from `checker`, or [`GraphError::Inconsistent`] when a live row has
/// no node id.
fn ann_row_hits_to_node_hits(
    graph: &SeleneGraph,
    label: &DbString,
    row_hits: Vec<VectorIndexSearchHit>,
    checker: &CancellationChecker<'_>,
) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
    let mut node_hits = Vec::with_capacity(row_hits.len());
    for index_hit in row_hits {
        checker.check()?;
        let raw_row = index_hit.row;
        if !graph.node_store.is_alive(raw_row) {
            continue;
        }
        let Some(node_id) = graph.node_id_for_row(RowIndex::new(raw_row)) else {
            return Err(GraphError::Inconsistent {
                reason: format!(
                    "ANN vector index row {raw_row} for {} has no node id",
                    label.as_str()
                ),
            }
            .into());
        };
        node_hits.push(VectorNodeSearchHit {
            node_id,
            distance: index_hit.distance,
        });
    }
    // Sort only when some adjacent pair violates `compare_node_search_hit` order.
    let out_of_order = node_hits
        .windows(2)
        .any(|pair| compare_node_search_hit(&pair[0], &pair[1]).is_gt());
    if out_of_order {
        node_hits.sort_by(compare_node_search_hit);
    }
    Ok(node_hits)
}
/// Re-scores ANN candidate rows exactly against `query` and keeps the best `k` nodes.
///
/// Cancellation is checked before each candidate. Dead rows and nodes without a
/// vector-valued `property` are skipped.
///
/// # Errors
///
/// * a cancellation error from `checker`;
/// * a [`GraphError`] if `metric` cannot bind `query` or a distance cannot be computed;
/// * [`GraphError::Inconsistent`] if a live candidate row has no node id or property row.
fn rerank_ann_row_candidates(
    graph: &SeleneGraph,
    property: &DbString,
    query: &VectorValue,
    metric: VectorMetric,
    k: usize,
    row_hits: Vec<VectorIndexSearchHit>,
    checker: &CancellationChecker<'_>,
) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
    let scorer = metric.bind_query(query).map_err(GraphError::from)?;
    let mut best = VectorTopK::new(k);
    for candidate in row_hits {
        checker.check()?;
        let raw_row = candidate.row;
        if !graph.node_store.is_alive(raw_row) {
            continue;
        }
        let Some(node_id) = graph.node_id_for_row(RowIndex::new(raw_row)) else {
            return Err(GraphError::Inconsistent {
                reason: format!("ANN vector candidate row {raw_row} has no node id"),
            }
            .into());
        };
        let Some(properties) = graph.node_store.properties.get(raw_row as usize) else {
            return Err(GraphError::Inconsistent {
                reason: format!("ANN vector candidate row {raw_row} has no property row"),
            }
            .into());
        };
        if let Some(Value::Vector(vector)) = properties.get(property) {
            let distance = scorer.distance(vector).map_err(GraphError::from)?;
            best.push_distance(node_id, distance);
        }
    }
    Ok(vector_node_hits(best))
}
/// Total order on node hits: ascending by distance (IEEE `total_cmp`), ties broken by node id.
fn compare_node_search_hit(lhs: &VectorNodeSearchHit, rhs: &VectorNodeSearchHit) -> Ordering {
    match lhs.distance.total_cmp(&rhs.distance) {
        Ordering::Equal => lhs.node_id.cmp(&rhs.node_id),
        unequal => unequal,
    }
}
#[cfg(test)]
#[path = "vector_search/ann_conversion_tests.rs"]
mod ann_conversion_tests;
#[cfg(test)]
#[path = "vector_search/ann_expansion_tests.rs"]
mod ann_expansion_tests;
#[cfg(test)]
#[path = "vector_search/batch_tests.rs"]
mod batch_tests;
#[cfg(test)]
#[path = "vector_search/recall_tests.rs"]
mod recall_tests;
#[path = "vector_search/score.rs"]
mod score;
#[path = "vector_search/score_candidate_batch.rs"]
mod score_candidate_batch;
#[path = "vector_search/score_expanded_batch.rs"]
mod score_expanded_batch;
#[path = "vector_search/score_neighbor_batch.rs"]
mod score_neighbor_batch;
#[path = "vector_search/score_shared.rs"]
mod score_shared;
#[cfg(test)]
#[path = "vector_search/score_tests.rs"]
mod score_tests;
#[cfg(test)]
#[path = "vector_search/tests.rs"]
mod tests;