use selene_core::{
CancellationChecker, CoreError, DbString, NodeId, Value, VectorMetric, VectorMetricQuery,
VectorTopK, VectorValue,
};
use crate::error::{GraphError, GraphResult};
use crate::graph::SeleneGraph;
use crate::parallel_scan::{should_parallelize_scan, try_reduce_chunks};
use super::{
VECTOR_SEARCH_CANCEL_STRIDE, VectorCandidateSet, VectorNeighborDirection,
VectorNeighborSearchOptions, VectorNodeSearchHit, VectorSearchError, merge_top_k,
score_candidate_batch::{
candidate_sets_all_match, should_parallelize_candidate_batch_scoring,
should_parallelize_repeated_candidate_batch,
},
vector_node_hits,
};
/// Candidate count passed to `should_parallelize_scan` as the minimum size at which
/// single-query candidate scoring may be split into parallel chunks.
///
/// Test builds use a tiny value (presumably so the parallel path is exercised on small
/// fixture graphs).
const VECTOR_CANDIDATE_SCORE_PARALLEL_MIN_NODES: usize = if cfg!(test) { 8 } else { 4_096 };

/// Number of candidates handed to each chunk when candidate scoring runs in parallel.
///
/// Test builds use a tiny value so that even small candidate sets span several chunks.
const VECTOR_CANDIDATE_SCORE_PARALLEL_CHUNK_NODES: usize = if cfg!(test) { 4 } else { 1_024 };
impl SeleneGraph {
pub fn score_vector_nodes(
&self,
property: &DbString,
query: &VectorValue,
candidates: &[NodeId],
metric: VectorMetric,
k: usize,
) -> GraphResult<Vec<VectorNodeSearchHit>> {
self.score_vector_nodes_checked(
property,
query,
candidates,
metric,
k,
CancellationChecker::disabled(),
)
.map_err(VectorSearchError::into_graph_error)
}
pub fn score_vector_nodes_checked(
&self,
property: &DbString,
query: &VectorValue,
candidates: &[NodeId],
metric: VectorMetric,
k: usize,
checker: CancellationChecker<'_>,
) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
checker.check()?;
if k == 0 || candidates.is_empty() {
return Ok(Vec::new());
}
let candidates = VectorCandidateSet::from_nodes(candidates.iter().copied());
self.score_vector_candidate_set_after_initial_check(
property,
query,
&candidates,
metric,
k,
checker,
)
}
pub fn score_vector_candidate_set(
&self,
property: &DbString,
query: &VectorValue,
candidates: &VectorCandidateSet,
metric: VectorMetric,
k: usize,
) -> GraphResult<Vec<VectorNodeSearchHit>> {
self.score_vector_candidate_set_checked(
property,
query,
candidates,
metric,
k,
CancellationChecker::disabled(),
)
.map_err(VectorSearchError::into_graph_error)
}
pub fn score_vector_candidate_set_checked(
&self,
property: &DbString,
query: &VectorValue,
candidates: &VectorCandidateSet,
metric: VectorMetric,
k: usize,
checker: CancellationChecker<'_>,
) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
checker.check()?;
if k == 0 || candidates.is_empty() {
return Ok(Vec::new());
}
self.score_vector_candidate_set_after_initial_check(
property, query, candidates, metric, k, checker,
)
}
fn score_vector_candidate_set_after_initial_check(
&self,
property: &DbString,
query: &VectorValue,
candidates: &VectorCandidateSet,
metric: VectorMetric,
k: usize,
checker: CancellationChecker<'_>,
) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
checker.check()?;
let scorer = metric.bind_query(query).map_err(GraphError::from)?;
if should_parallelize_candidate_scoring(candidates.len(), k) {
return self
.score_vector_candidate_set_parallel(property, scorer, candidates, k, checker);
}
self.score_vector_candidate_set_serial(property, scorer, candidates, k, checker)
}
pub(super) fn score_vector_candidate_set_serial(
&self,
property: &DbString,
scorer: VectorMetricQuery<'_>,
candidates: &VectorCandidateSet,
k: usize,
checker: CancellationChecker<'_>,
) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
let mut top_k = VectorTopK::new(k);
for (offset, node_id) in candidates.as_nodes().iter().copied().enumerate() {
if offset % VECTOR_SEARCH_CANCEL_STRIDE == 0 {
checker.check()?;
}
let Some(properties) = self.node_properties(node_id) else {
continue;
};
let Some(Value::Vector(vector)) = properties.get(property) else {
continue;
};
let distance = scorer.distance(vector).map_err(GraphError::from)?;
top_k.push_distance(node_id, distance);
}
Ok(vector_node_hits(top_k))
}
fn score_vector_candidate_set_parallel(
&self,
property: &DbString,
scorer: VectorMetricQuery<'_>,
candidates: &VectorCandidateSet,
k: usize,
checker: CancellationChecker<'_>,
) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
let top_k = try_reduce_chunks(
candidates.as_nodes(),
VECTOR_CANDIDATE_SCORE_PARALLEL_CHUNK_NODES,
checker,
|| VectorTopK::new(k),
|chunk| self.score_vector_candidate_set_chunk(property, scorer, chunk, k),
merge_top_k,
)?;
Ok(vector_node_hits(top_k))
}
fn score_vector_candidate_set_chunk(
&self,
property: &DbString,
scorer: VectorMetricQuery<'_>,
candidates: &[NodeId],
k: usize,
) -> Result<VectorTopK<NodeId>, VectorSearchError> {
let mut top_k = VectorTopK::new(k);
for node_id in candidates.iter().copied() {
let Some(properties) = self.node_properties(node_id) else {
continue;
};
let Some(Value::Vector(vector)) = properties.get(property) else {
continue;
};
let distance = scorer.distance(vector).map_err(GraphError::from)?;
top_k.push_distance(node_id, distance);
}
Ok(top_k)
}
pub fn score_vector_nodes_batch<C>(
&self,
property: &DbString,
queries: &[VectorValue],
candidate_sets: &[C],
metric: VectorMetric,
k: usize,
) -> GraphResult<Vec<Vec<VectorNodeSearchHit>>>
where
C: AsRef<[NodeId]>,
{
self.score_vector_nodes_batch_checked(
property,
queries,
candidate_sets,
metric,
k,
CancellationChecker::disabled(),
)
.map_err(VectorSearchError::into_graph_error)
}
pub fn score_vector_nodes_batch_checked<C>(
&self,
property: &DbString,
queries: &[VectorValue],
candidate_sets: &[C],
metric: VectorMetric,
k: usize,
checker: CancellationChecker<'_>,
) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError>
where
C: AsRef<[NodeId]>,
{
checker.check()?;
validate_batch_inputs(queries, candidate_sets.len())?;
if queries.is_empty() {
return Ok(Vec::new());
}
if k == 0 {
return Ok(vec![Vec::new(); queries.len()]);
}
let mut canonical_sets = Vec::with_capacity(candidate_sets.len());
for candidates in candidate_sets {
checker.check()?;
canonical_sets.push(VectorCandidateSet::from_nodes(
candidates.as_ref().iter().copied(),
));
}
self.score_vector_candidate_sets_batch_checked(
property,
queries,
&canonical_sets,
metric,
k,
checker,
)
}
pub fn score_vector_candidate_sets_batch(
&self,
property: &DbString,
queries: &[VectorValue],
candidate_sets: &[VectorCandidateSet],
metric: VectorMetric,
k: usize,
) -> GraphResult<Vec<Vec<VectorNodeSearchHit>>> {
self.score_vector_candidate_sets_batch_checked(
property,
queries,
candidate_sets,
metric,
k,
CancellationChecker::disabled(),
)
.map_err(VectorSearchError::into_graph_error)
}
pub fn score_vector_candidate_sets_batch_checked(
&self,
property: &DbString,
queries: &[VectorValue],
candidate_sets: &[VectorCandidateSet],
metric: VectorMetric,
k: usize,
checker: CancellationChecker<'_>,
) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError> {
checker.check()?;
validate_batch_inputs(queries, candidate_sets.len())?;
if queries.is_empty() {
return Ok(Vec::new());
}
if k == 0 {
return Ok(vec![Vec::new(); queries.len()]);
}
let should_parallelize_batch =
should_parallelize_candidate_batch_scoring(candidate_sets, k);
if let Some(candidates) = candidate_sets.first()
&& should_parallelize_repeated_candidate_batch(queries.len(), candidates.len(), k)
&& candidate_sets_all_match(candidate_sets)
{
return self.score_repeated_vector_candidate_set_batch_parallel(
property, queries, candidates, metric, k, checker,
);
}
if should_parallelize_batch {
return self.score_vector_candidate_sets_batch_parallel(
property,
queries,
candidate_sets,
metric,
k,
checker,
);
}
if candidate_sets_all_match(candidate_sets) {
return self.score_repeated_vector_candidate_set_batch_serial(
property,
queries,
&candidate_sets[0],
metric,
k,
checker,
);
}
self.score_vector_candidate_sets_batch_grouped_serial(
property,
queries,
candidate_sets,
metric,
k,
checker,
)
}
pub fn score_vector_neighbors(
&self,
property: &DbString,
query: &VectorValue,
anchor: NodeId,
options: VectorNeighborSearchOptions<'_>,
) -> GraphResult<Vec<VectorNodeSearchHit>> {
self.score_vector_neighbors_checked(
property,
query,
anchor,
options,
CancellationChecker::disabled(),
)
.map_err(VectorSearchError::into_graph_error)
}
pub fn score_vector_neighbors_checked(
&self,
property: &DbString,
query: &VectorValue,
anchor: NodeId,
options: VectorNeighborSearchOptions<'_>,
checker: CancellationChecker<'_>,
) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
checker.check()?;
if options.k == 0 {
return Ok(Vec::new());
}
let candidates =
self.vector_neighbor_candidates(anchor, options.edge_label, options.direction);
self.score_vector_candidate_set_checked(
property,
query,
&candidates,
options.metric,
options.k,
checker,
)
}
pub fn score_vector_neighbors_batch(
&self,
property: &DbString,
queries: &[VectorValue],
anchors: &[NodeId],
options: VectorNeighborSearchOptions<'_>,
) -> GraphResult<Vec<Vec<VectorNodeSearchHit>>> {
self.score_vector_neighbors_batch_checked(
property,
queries,
anchors,
options,
CancellationChecker::disabled(),
)
.map_err(VectorSearchError::into_graph_error)
}
pub fn score_vector_neighbors_batch_checked(
&self,
property: &DbString,
queries: &[VectorValue],
anchors: &[NodeId],
options: VectorNeighborSearchOptions<'_>,
checker: CancellationChecker<'_>,
) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError> {
checker.check()?;
validate_batch_inputs(queries, anchors.len())?;
if queries.is_empty() {
return Ok(Vec::new());
}
if options.k == 0 {
return Ok(vec![Vec::new(); queries.len()]);
}
let candidate_sets = self.vector_neighbor_candidate_sets_batch(
anchors,
options.edge_label,
options.direction,
options.k,
checker,
)?;
self.score_vector_candidate_sets_batch_checked(
property,
queries,
&candidate_sets,
options.metric,
options.k,
checker,
)
}
pub fn score_vector_expanded_candidate_sets_batch(
&self,
property: &DbString,
queries: &[VectorValue],
root_sets: &[VectorCandidateSet],
options: VectorNeighborSearchOptions<'_>,
) -> GraphResult<Vec<Vec<VectorNodeSearchHit>>> {
self.score_vector_expanded_candidate_sets_batch_checked(
property,
queries,
root_sets,
options,
CancellationChecker::disabled(),
)
.map_err(VectorSearchError::into_graph_error)
}
pub fn score_vector_expanded_candidate_sets_batch_checked(
&self,
property: &DbString,
queries: &[VectorValue],
root_sets: &[VectorCandidateSet],
options: VectorNeighborSearchOptions<'_>,
checker: CancellationChecker<'_>,
) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError> {
checker.check()?;
validate_batch_inputs(queries, root_sets.len())?;
if queries.is_empty() {
return Ok(Vec::new());
}
if options.k == 0 {
return Ok(vec![Vec::new(); queries.len()]);
}
let expanded_sets = self.expand_vector_candidate_sets_batch(
root_sets,
options.edge_label,
options.direction,
options.k,
checker,
)?;
self.score_vector_candidate_sets_batch_checked(
property,
queries,
&expanded_sets,
options.metric,
options.k,
checker,
)
}
pub fn expand_vector_candidate_sets_batch_checked(
&self,
root_sets: &[VectorCandidateSet],
edge_label: &DbString,
direction: VectorNeighborDirection,
k: usize,
checker: CancellationChecker<'_>,
) -> Result<Vec<VectorCandidateSet>, VectorSearchError> {
self.expand_vector_candidate_sets_batch(root_sets, edge_label, direction, k, checker)
}
#[must_use]
pub fn vector_neighbor_candidates(
&self,
anchor: NodeId,
edge_label: &DbString,
direction: VectorNeighborDirection,
) -> VectorCandidateSet {
let mut candidates = Vec::new();
if matches!(
direction,
VectorNeighborDirection::Outgoing | VectorNeighborDirection::Both
) && let Some(entry) = self.outgoing_edges(anchor)
{
candidates.extend(entry.iter_label(edge_label).map(|edge| edge.neighbor));
}
if matches!(
direction,
VectorNeighborDirection::Incoming | VectorNeighborDirection::Both
) && let Some(entry) = self.incoming_edges(anchor)
{
candidates.extend(entry.iter_label(edge_label).map(|edge| edge.neighbor));
}
VectorCandidateSet::from_nodes(candidates)
}
#[must_use]
pub fn expand_vector_candidate_set(
&self,
roots: &VectorCandidateSet,
edge_label: &DbString,
direction: VectorNeighborDirection,
) -> VectorCandidateSet {
self.expand_vector_candidate_set_checked(
roots,
edge_label,
direction,
CancellationChecker::disabled(),
)
.expect("disabled cancellation cannot fail")
}
pub fn expand_vector_candidate_set_checked(
&self,
roots: &VectorCandidateSet,
edge_label: &DbString,
direction: VectorNeighborDirection,
checker: CancellationChecker<'_>,
) -> Result<VectorCandidateSet, VectorSearchError> {
checker.check()?;
if roots.is_empty() {
return Ok(VectorCandidateSet::default());
}
let mut candidates = Vec::with_capacity(roots.len());
candidates.extend_from_slice(roots.as_nodes());
for (offset, root) in roots.as_nodes().iter().copied().enumerate() {
if offset % VECTOR_SEARCH_CANCEL_STRIDE == 0 {
checker.check()?;
}
if matches!(
direction,
VectorNeighborDirection::Outgoing | VectorNeighborDirection::Both
) && let Some(entry) = self.outgoing_edges(root)
{
candidates.extend(entry.iter_label(edge_label).map(|edge| edge.neighbor));
}
if matches!(
direction,
VectorNeighborDirection::Incoming | VectorNeighborDirection::Both
) && let Some(entry) = self.incoming_edges(root)
{
candidates.extend(entry.iter_label(edge_label).map(|edge| edge.neighbor));
}
}
Ok(VectorCandidateSet::from_nodes(candidates))
}
}
/// Reports whether scoring `candidate_count` candidates for a top-`k` query should use the
/// chunked parallel scorer.
///
/// The decision is delegated to `should_parallelize_scan`, using
/// `VECTOR_CANDIDATE_SCORE_PARALLEL_MIN_NODES` as the size threshold. The `usize` -> `u64`
/// widening casts are lossless on 32- and 64-bit targets.
fn should_parallelize_candidate_scoring(candidate_count: usize, k: usize) -> bool {
    let total = candidate_count as u64;
    let threshold = VECTOR_CANDIDATE_SCORE_PARALLEL_MIN_NODES as u64;
    should_parallelize_scan(total, k, threshold)
}
/// Checks that a batch request is well formed. There must be exactly one candidate set (or
/// anchor / root set) per query, and every query must have the same dimension as the first.
///
/// An empty `queries` slice with `candidate_set_count == 0` is accepted.
///
/// # Errors
///
/// - [`VectorSearchError::BatchLengthMismatch`] when `queries.len() != candidate_set_count`.
/// - A `CoreError::VectorDimensionMismatch` (wrapped as a graph error) naming the first
///   query's dimension as `lhs` and the first differing dimension as `rhs`.
fn validate_batch_inputs(
    queries: &[VectorValue],
    candidate_set_count: usize,
) -> Result<(), VectorSearchError> {
    let query_count = queries.len();
    if query_count != candidate_set_count {
        return Err(VectorSearchError::BatchLengthMismatch {
            queries: query_count,
            candidate_sets: candidate_set_count,
        });
    }
    let mut remaining = queries.iter();
    let Some(head) = remaining.next() else {
        return Ok(());
    };
    let expected = head.dimension();
    match remaining
        .map(|query| query.dimension())
        .find(|dimension| *dimension != expected)
    {
        Some(actual) => Err(GraphError::from(CoreError::VectorDimensionMismatch {
            lhs: expected,
            rhs: actual,
        })
        .into()),
        None => Ok(()),
    }
}