use selene_core::Value;
use selene_graph::{ApproximateVectorSearchOptions, VectorCandidateSet, VectorNeighborDirection};
use super::meta::{StaticOutputColumn, StaticParameter};
use super::vector_candidate_state_common::{
CandidateStateOperation, candidate_state_error, operation_arg,
};
use super::vector_common::{
BatchMismatch, approximate_vector_search_error, cardinality_arg, expansion_direction_arg,
invalid_arg, query_arg, string_arg, vector_search_error,
};
use super::vector_search_ann_defaults::{
ANN_METRIC_DEFAULT_DOC, SEARCH_WIDTH_DEFAULT_DOC, default_metric, default_search_width,
optional_metric_arg, optional_search_width_arg,
};
use crate::procedure_registry::ProcedureError;
use crate::{
GqlType, GraphContext, ProcedureDefaultValue, ProcedureOutputColumn, ProcedureParameter,
ProcedureResult,
};
/// Procedure name; passed to every argument helper and used as the prefix of this
/// module's argument-validation and error messages.
const PROC_NAME: &str = "selene.vector_search_candidate_state_expanded_ann";
/// Output schema: one row per scored candidate, holding the node id (`NodeRef`) and its
/// lower-is-better distance (`Float64`). Converted to owned columns by `output_columns`.
static VECTOR_SEARCH_OUTPUTS: [StaticOutputColumn; 2] = [
    StaticOutputColumn::new("node_id", GqlType::NodeRef)
        .with_description("Scored composed-candidate node id."),
    StaticOutputColumn::new("distance", GqlType::Float64)
        .with_description("Lower-is-better distance."),
];
/// Declared parameters of the procedure, in positional order.
///
/// `label` through `k` carry no default. `operation` and `direction` default to
/// `"intersection"` and `"outgoing"`. `metric` and `ef_search` default to NULL, which defers
/// to a default chosen from the matching index.
pub(super) fn signature() -> Vec<ProcedureParameter> {
    let params = [
        // --- required ---
        StaticParameter::new("label", GqlType::String, false)
            .with_description("ANN root node label."),
        StaticParameter::new("property", GqlType::String, false)
            .with_description("Property name."),
        StaticParameter::new("query", GqlType::Vector, false)
            .with_description("Query vector."),
        StaticParameter::new("state_name", GqlType::String, false)
            .with_description("Maintained candidate-state name."),
        StaticParameter::new("root_k", GqlType::Integer, false)
            .with_description("Maximum ANN root count before graph expansion."),
        StaticParameter::new("edge_label", GqlType::String, false)
            .with_description("Edge label used to expand ANN root candidates."),
        StaticParameter::new("k", GqlType::Integer, false)
            .with_description("Maximum final result count."),
        // --- optional, with defaults ---
        StaticParameter::new("operation", GqlType::String, false)
            .with_description("Candidate-set algebra operation.")
            .with_default_doc("intersection")
            .with_default(ProcedureDefaultValue::String("intersection")),
        StaticParameter::new("direction", GqlType::String, false)
            .with_description("Expansion direction: outgoing, incoming, or both.")
            .with_default_doc("outgoing")
            .with_default(ProcedureDefaultValue::String("outgoing")),
        StaticParameter::new("metric", GqlType::String, true)
            .with_description("Distance metric; NULL uses the matching index metric when available.")
            .with_default_doc(ANN_METRIC_DEFAULT_DOC)
            .with_default(ProcedureDefaultValue::Null),
        StaticParameter::new("ef_search", GqlType::Integer, true)
            .with_description("ANN search-width hint; NULL uses the index-kind default.")
            .with_default_doc(SEARCH_WIDTH_DEFAULT_DOC)
            .with_default(ProcedureDefaultValue::Null),
    ];
    // Convert element-wise in place, then hand the array to a Vec; order is preserved.
    Vec::from(params.map(StaticParameter::into_parameter))
}
/// Returns owned copies of the static output schema (`node_id`, `distance`), in declaration
/// order.
pub(super) fn output_columns() -> Vec<ProcedureOutputColumn> {
    VECTOR_SEARCH_OUTPUTS
        .iter()
        .map(|column| StaticOutputColumn::into_output_column(column.clone()))
        .collect::<Vec<_>>()
}
/// Runs `selene.vector_search_candidate_state_expanded_ann`.
///
/// The procedure works in four stages:
/// 1. An approximate vector search over `label` nodes on `property` collects at most
///    `root_k` root candidates, using `metric` and the `ef_search` width hint.
/// 2. The roots are expanded along `edge_label` edges in `direction`.
/// 3. The maintained candidate-state set named `state_name` is composed with the expanded
///    set via `operation` (state set as the left operand, expanded set as the right).
/// 4. The composed candidates are scored against `query` on `property`, keeping at most
///    `k` hits, each emitted as a `(node_id, distance)` row.
///
/// Arguments are positional and follow [`signature`]. The first seven are required. A
/// missing `operation`/`direction` falls back to `Intersection`/`Outgoing`. A missing or
/// NULL `metric`/`ef_search` falls back to a default computed from the current snapshot
/// for `label`/`property`.
///
/// # Errors
///
/// Returns an invalid-argument error in two cases:
/// - fewer than 7 or more than 11 arguments are supplied;
/// - `state_name` does not name a maintained candidate-state set.
///
/// Argument-conversion failures are propagated. Failures from the search, expansion, and
/// scoring stages are mapped through this module's error helpers and then propagated.
pub(super) fn execute(
    ctx: &GraphContext<'_>,
    args: &[Value],
) -> Result<ProcedureResult, ProcedureError> {
    // Every stage reports failures with the same context and batched-only message; keeping
    // them in one place prevents the three error mappings from drifting apart.
    const SEARCH_CONTEXT: &str = "ANN candidate-state expanded vector search";
    const BATCHED_ONLY: &str =
        "ANN candidate-state expanded vector search received batched-only error";

    if !(7..=11).contains(&args.len()) {
        return Err(invalid_arg(format!(
            "{PROC_NAME} expects 7 to 11 arguments"
        )));
    }

    // Required positional arguments (indices 0..=6 are guaranteed by the arity check).
    let label = string_arg(PROC_NAME, &args[0], "label")?;
    let property = string_arg(PROC_NAME, &args[1], "property")?;
    let query = query_arg(PROC_NAME, &args[2])?;
    let state_name = string_arg(PROC_NAME, &args[3], "state_name")?;
    let root_k = cardinality_arg(PROC_NAME, &args[4], "root_k")?;
    let edge_label = string_arg(PROC_NAME, &args[5], "edge_label")?;
    let k = cardinality_arg(PROC_NAME, &args[6], "k")?;

    // Optional trailing arguments: absent ones take the signature's documented defaults.
    let operation = args
        .get(7)
        .map(|arg| operation_arg(PROC_NAME, arg))
        .transpose()?
        .unwrap_or(CandidateStateOperation::Intersection);
    let direction = args
        .get(8)
        .map(|arg| expansion_direction_arg(PROC_NAME, arg))
        .transpose()?
        .unwrap_or(VectorNeighborDirection::Outgoing);
    // `flatten` merges "argument absent" with "argument is NULL" so both use the default.
    let metric = args
        .get(9)
        .map(|arg| optional_metric_arg(PROC_NAME, arg))
        .transpose()?
        .flatten()
        .unwrap_or_else(|| default_metric(ctx.snapshot(), &label, &property, query.dimension()));
    let ef_search = args
        .get(10)
        .map(|arg| optional_search_width_arg(PROC_NAME, arg))
        .transpose()?
        .flatten()
        .unwrap_or_else(|| {
            default_search_width(ctx.snapshot(), &label, &property, query.dimension(), metric)
        });

    let state = ctx
        .vector_candidate_set(&state_name)
        .map_err(|error| candidate_state_error(PROC_NAME, error))?
        .ok_or_else(|| {
            invalid_arg(format!(
                "{PROC_NAME} unknown maintained candidate-state set '{}'",
                state_name.as_str()
            ))
        })?;

    // Stage 1: ANN search for at most `root_k` root candidates.
    let root_hits = ctx
        .snapshot()
        .approximate_vector_search_nodes_checked(
            &label,
            &property,
            &query,
            ApproximateVectorSearchOptions::new(metric, root_k, ef_search),
            ctx.cancellation_checker(),
        )
        .map_err(|error| {
            approximate_vector_search_error(
                PROC_NAME,
                error,
                SEARCH_CONTEXT,
                BatchMismatch::Internal(BATCHED_ONLY),
            )
        })?;

    // Stage 2: expand the roots along `edge_label` in `direction`.
    let roots = VectorCandidateSet::from_search_hits(&root_hits);
    let expanded = ctx
        .snapshot()
        .expand_vector_candidate_set_checked(
            &roots,
            &edge_label,
            direction,
            ctx.cancellation_checker(),
        )
        .map_err(|error| {
            vector_search_error(
                error,
                SEARCH_CONTEXT,
                BatchMismatch::Internal(BATCHED_ONLY),
                SEARCH_CONTEXT,
            )
        })?;

    // Stage 3: combine with the maintained state; operand order matters for
    // non-commutative operations, so the state set stays on the left.
    let candidates = operation.compose(&state, &expanded);

    // Stage 4: score the composed candidates and keep at most `k`.
    let hits = ctx
        .snapshot()
        .score_vector_candidate_set_checked(
            &property,
            &query,
            &candidates,
            metric,
            k,
            ctx.cancellation_checker(),
        )
        .map_err(|error| {
            vector_search_error(
                error,
                SEARCH_CONTEXT,
                BatchMismatch::Internal(BATCHED_ONLY),
                SEARCH_CONTEXT,
            )
        })?;

    let rows = hits
        .into_iter()
        .map(|hit| vec![Value::NodeRef(hit.node_id), Value::Float(hit.distance)])
        .collect();
    Ok(ProcedureResult { rows })
}