use roaring::RoaringBitmap;
use selene_core::{
CancellationChecker, CoreError, DbString, NodeId, Value, VectorMetric, VectorMetricQuery,
VectorTopK, VectorValue,
};
use super::{
VECTOR_SEARCH_CANCEL_STRIDE, VECTOR_SEARCH_PARALLEL_CHUNK_ROWS, VectorNodeSearchHit,
VectorSearchError, should_parallelize_exact_scan, vector_node_hits,
};
use crate::error::GraphError;
use crate::graph::SeleneGraph;
use crate::parallel_scan::try_reduce_bitmap_chunks;
use crate::store::RowIndex;
impl SeleneGraph {
    /// Runs an exact (brute-force) top-`k` vector search for several query vectors at once.
    ///
    /// Each query is scored with `metric` against the `property` value of every live row
    /// carrying `label`. One result list is returned per query, in the same order as
    /// `queries`. Each list holds the hits kept by a `VectorTopK` of capacity `k`.
    ///
    /// When a vector index exists for `(label, property)` and its dimension equals the
    /// query dimension, only the rows it covers are scanned. Otherwise every row of the
    /// label is scanned. Rows whose `property` is missing or not a `Value::Vector` are
    /// skipped.
    ///
    /// The scan goes through the chunked reducer (`try_reduce_bitmap_chunks`) when
    /// `should_parallelize_exact_scan` says so. Otherwise it runs sequentially.
    ///
    /// Returns:
    /// - an empty outer list when `queries` is empty;
    /// - one empty list per query when `k == 0` or the label has no rows.
    ///
    /// # Errors
    ///
    /// - A cancellation error from `checker`, which is checked up front and again during
    ///   the scan.
    /// - `CoreError::VectorDimensionMismatch` (wrapped in `GraphError`) when the queries
    ///   do not all share the first query's dimension.
    /// - Errors from `metric.bind_query` or from computing a distance. Presumably this
    ///   includes stored vectors of a different dimension; TODO confirm in `selene_core`.
    /// - `GraphError::Inconsistent` when a live row has no node id or no property row.
    pub fn exact_vector_search_nodes_batch_checked(
        &self,
        label: &DbString,
        property: &DbString,
        queries: &[VectorValue],
        metric: VectorMetric,
        k: usize,
        checker: CancellationChecker<'_>,
    ) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError> {
        checker.check()?;
        let Some(first_query) = queries.first() else {
            return Ok(Vec::new());
        };
        let first_dimension = first_query.dimension();
        // All queries must share one dimension so that a single scan (and a single
        // index row set) serves every query in the batch.
        if let Some(mismatched) = queries
            .iter()
            .skip(1)
            .find(|query| query.dimension() != first_dimension)
        {
            return Err(GraphError::from(CoreError::VectorDimensionMismatch {
                lhs: first_dimension,
                rhs: mismatched.dimension(),
            })
            .into());
        }
        if k == 0 {
            return Ok(vec![Vec::new(); queries.len()]);
        }
        let Some(label_rows) = self.nodes_with_label(label) else {
            return Ok(vec![Vec::new(); queries.len()]);
        };
        // Only use the index's row set when its dimension matches the queries. A query
        // dimension that does not fit in `u32` can never match an index.
        let vector_index = u32::try_from(first_dimension).ok().and_then(|dimension| {
            self.vector_index_for(label, property)
                .filter(|index| index.dimension() == dimension)
        });
        let rows = vector_index
            .as_ref()
            .map_or(label_rows, |index| index.rows());
        let scorers = queries
            .iter()
            .map(|query| metric.bind_query(query).map_err(GraphError::from))
            .collect::<Result<Vec<_>, GraphError>>()?;
        // Both strategies yield raw per-query top-k sets. Converting them into hits
        // happens in exactly one place below.
        let top_ks = if should_parallelize_exact_scan(rows, k) {
            self.exact_vector_search_batch_parallel(label, property, &scorers, k, rows, checker)?
        } else {
            self.exact_vector_search_batch_serial(label, property, &scorers, k, rows, checker)?
        };
        Ok(top_ks.into_iter().map(vector_node_hits).collect())
    }

    /// Scans `rows` in chunks of `VECTOR_SEARCH_PARALLEL_CHUNK_ROWS` via
    /// `try_reduce_bitmap_chunks`.
    ///
    /// Each chunk fills its own per-query top-k sets. Chunk results are combined with
    /// `merge_batch_top_ks`, and `checker` is handed to the reducer for cancellation.
    /// Returns one `VectorTopK` per scorer, in scorer order.
    fn exact_vector_search_batch_parallel(
        &self,
        label: &DbString,
        property: &DbString,
        scorers: &[VectorMetricQuery<'_>],
        k: usize,
        rows: &RoaringBitmap,
        checker: CancellationChecker<'_>,
    ) -> Result<Vec<VectorTopK<NodeId>>, VectorSearchError> {
        let top_ks = try_reduce_bitmap_chunks(
            rows,
            VECTOR_SEARCH_PARALLEL_CHUNK_ROWS,
            checker,
            || new_batch_top_ks(scorers.len(), k),
            |chunk| self.exact_vector_search_batch_chunk(label, property, scorers, k, chunk),
            merge_batch_top_ks,
        )?;
        Ok(top_ks)
    }

    /// Scans `rows` sequentially, polling `checker` once every
    /// `VECTOR_SEARCH_CANCEL_STRIDE` rows.
    ///
    /// Returns one `VectorTopK` per scorer, in scorer order.
    fn exact_vector_search_batch_serial(
        &self,
        label: &DbString,
        property: &DbString,
        scorers: &[VectorMetricQuery<'_>],
        k: usize,
        rows: &RoaringBitmap,
        checker: CancellationChecker<'_>,
    ) -> Result<Vec<VectorTopK<NodeId>>, VectorSearchError> {
        let mut top_ks = new_batch_top_ks(scorers.len(), k);
        let mut rows_since_check = 0usize;
        for raw_row in rows.iter() {
            rows_since_check += 1;
            if rows_since_check >= VECTOR_SEARCH_CANCEL_STRIDE {
                checker.check()?;
                rows_since_check = 0;
            }
            self.push_batch_row(label, property, scorers, &mut top_ks, raw_row)?;
        }
        Ok(top_ks)
    }

    /// Scores one chunk of raw row indices into fresh per-query top-k sets.
    ///
    /// Does not poll for cancellation itself; the caller (the chunk reducer) does that.
    fn exact_vector_search_batch_chunk(
        &self,
        label: &DbString,
        property: &DbString,
        scorers: &[VectorMetricQuery<'_>],
        k: usize,
        rows: &[u32],
    ) -> Result<Vec<VectorTopK<NodeId>>, VectorSearchError> {
        let mut top_ks = new_batch_top_ks(scorers.len(), k);
        for &raw_row in rows {
            self.push_batch_row(label, property, scorers, &mut top_ks, raw_row)?;
        }
        Ok(top_ks)
    }

    /// Scores row `raw_row` against every scorer and records each distance in the
    /// matching entry of `top_ks` (`scorers[i]` pairs with `top_ks[i]`).
    ///
    /// Dead rows, and rows whose `property` is missing or not a `Value::Vector`, are
    /// skipped.
    ///
    /// # Errors
    ///
    /// - `GraphError::Inconsistent` if a live row has no node id or no property row.
    /// - Any error returned by `scorer.distance`.
    fn push_batch_row(
        &self,
        label: &DbString,
        property: &DbString,
        scorers: &[VectorMetricQuery<'_>],
        top_ks: &mut [VectorTopK<NodeId>],
        raw_row: u32,
    ) -> Result<(), VectorSearchError> {
        if !self.node_store.is_alive(raw_row) {
            return Ok(());
        }
        let row = RowIndex::new(raw_row);
        let node_id = self
            .node_id_for_row(row)
            .ok_or_else(|| GraphError::Inconsistent {
                reason: format!(
                    "vector search row {raw_row} for {} has no node id",
                    label.as_str()
                ),
            })?;
        let properties = self
            .node_store
            .properties
            .get(raw_row as usize)
            .ok_or_else(|| GraphError::Inconsistent {
                reason: format!(
                    "vector search row {raw_row} for {} has no property row",
                    label.as_str()
                ),
            })?;
        let Some(Value::Vector(vector)) = properties.get(property) else {
            return Ok(());
        };
        for (scorer, top_k) in scorers.iter().zip(top_ks) {
            let distance = scorer.distance(vector).map_err(GraphError::from)?;
            top_k.push_distance(node_id, distance);
        }
        Ok(())
    }
}
/// Builds `query_count` empty top-`k` accumulators, one per query, in query order.
fn new_batch_top_ks(query_count: usize, k: usize) -> Vec<VectorTopK<NodeId>> {
    let mut accumulators = Vec::with_capacity(query_count);
    accumulators.resize_with(query_count, || VectorTopK::new(k));
    accumulators
}
/// Folds the per-query top-k sets of one partial result (`rhs`) into the running
/// per-query sets (`lhs`), pairing entries by position.
///
/// Both vectors are expected to hold one accumulator per query in the same order. This
/// is only checked in debug builds. The function always returns `Ok`; it is used as the
/// merge step passed to `try_reduce_bitmap_chunks`.
fn merge_batch_top_ks(
    mut lhs: Vec<VectorTopK<NodeId>>,
    rhs: Vec<VectorTopK<NodeId>>,
) -> Result<Vec<VectorTopK<NodeId>>, VectorSearchError> {
    debug_assert_eq!(lhs.len(), rhs.len());
    lhs.iter_mut().zip(rhs).for_each(|(running, partial)| {
        // Re-push every hit of the partial set so `running` keeps its own top-k policy.
        for hit in partial.into_hits() {
            running.push_distance(hit.key, hit.distance);
        }
    });
    Ok(lhs)
}