use selene_core::{DbString, Value, VectorMetric};
use selene_graph::SeleneGraph;
use crate::procedure_registry::ProcedureError;
use super::vector_common::metric_arg;
/// Human-readable description of the metric chosen when the metric argument is
/// NULL. The wording mirrors `default_metric` in this file: the metric of a
/// matching index when there is one, otherwise squared Euclidean.
/// NOTE(review): where this text is surfaced is not visible here — presumably
/// procedure signature documentation; confirm against the callers.
pub(super) const ANN_METRIC_DEFAULT_DOC: &str =
"NULL (matching index metric, otherwise squared_euclidean)";
/// Search width returned by `default_search_width` whenever no matching IVF or
/// TurboQuant index applies (including when no index matches at all).
pub(super) const DEFAULT_HNSW_SEARCH_WIDTH: usize = 64;
/// Search width returned by `default_search_width` for a matching IVF index.
pub(super) const DEFAULT_IVF_SEARCH_WIDTH: usize = 2;
/// Search width returned by `default_search_width` for a matching TurboQuant index.
pub(super) const DEFAULT_TURBO_QUANT_SEARCH_WIDTH: usize = 512;
/// Human-readable description of the search width chosen when the argument is
/// NULL. The numbers repeat the three `DEFAULT_*_SEARCH_WIDTH` constants above
/// and must be edited together with them.
pub(super) const SEARCH_WIDTH_DEFAULT_DOC: &str = "NULL (HNSW 64, IVF 2, TurboQuant 512)";
/// Parses the optional `ef_search` (search width) argument of a procedure.
///
/// * `Value::Null` yields `Ok(None)`, leaving the choice of width to the caller.
/// * A non-negative `Value::Int` or any `Value::Uint` yields `Ok(Some(width))`.
///
/// `proc_name` is only used as the prefix of the error messages.
///
/// # Errors
///
/// Returns [`ProcedureError::InvalidArgument`] when the value is a negative
/// integer or any non-integer, non-NULL value, and also when the integer does
/// not fit into `usize` on the current platform.
pub(super) fn optional_search_width_arg(
    proc_name: &str,
    value: &Value,
) -> Result<Option<usize>, ProcedureError> {
    // `None` at this point means: an acceptable integer that overflows `usize`.
    let converted = match value {
        Value::Null => return Ok(None),
        Value::Int(signed) if *signed >= 0 => usize::try_from(*signed).ok(),
        Value::Uint(unsigned) => usize::try_from(*unsigned).ok(),
        _ => {
            return Err(ProcedureError::InvalidArgument {
                detail: format!("{proc_name} ef_search must be NULL or a non-negative INTEGER"),
            });
        }
    };
    match converted {
        Some(width) => Ok(Some(width)),
        None => Err(search_width_too_large(proc_name)),
    }
}
/// Parses an optional metric argument of a procedure.
///
/// `Value::Null` yields `Ok(None)` so the caller can pick a default; every
/// other value is handed to `metric_arg` and its result is wrapped in `Some`.
///
/// # Errors
///
/// Propagates whatever error `metric_arg` reports for a non-NULL value.
pub(super) fn optional_metric_arg(
    proc_name: &'static str,
    value: &Value,
) -> Result<Option<VectorMetric>, ProcedureError> {
    if matches!(value, Value::Null) {
        return Ok(None);
    }
    let metric = metric_arg(proc_name, value)?;
    Ok(Some(metric))
}
/// Chooses the metric to use when the caller did not supply one.
///
/// Returns the ANN metric of the vector index registered for
/// `label`/`property`, provided that index has exactly `query_dimension`
/// dimensions and reports an ANN metric. In every other case — no index, a
/// dimension mismatch, an index without an ANN metric, or a `query_dimension`
/// that does not fit in `u32` — the result is `VectorMetric::SquaredEuclidean`.
pub(super) fn default_metric(
    graph: &SeleneGraph,
    label: &DbString,
    property: &DbString,
    query_dimension: usize,
) -> VectorMetric {
    // Index dimensions are compared as `u32`; a wider query cannot match any index.
    let Ok(wanted_dimension) = u32::try_from(query_dimension) else {
        return VectorMetric::SquaredEuclidean;
    };
    let matching_index = graph
        .vector_index_for(label, property)
        .filter(|candidate| candidate.dimension() == wanted_dimension);
    let index_metric = matching_index.and_then(|candidate| candidate.ann_metric());
    index_metric.unwrap_or(VectorMetric::SquaredEuclidean)
}
/// Chooses the search width to use when the caller did not supply one.
///
/// The width depends on the vector index registered for `label`/`property`:
///
/// * `DEFAULT_IVF_SEARCH_WIDTH` when the index has `query_dimension`
///   dimensions, its ANN metric equals `metric`, and it is an IVF index;
/// * `DEFAULT_TURBO_QUANT_SEARCH_WIDTH` under the same conditions for a
///   TurboQuant index;
/// * `DEFAULT_HNSW_SEARCH_WIDTH` in every other case, including a missing
///   index, a dimension or metric mismatch, and a `query_dimension` that does
///   not fit in `u32`.
pub(super) fn default_search_width(
    graph: &SeleneGraph,
    label: &DbString,
    property: &DbString,
    query_dimension: usize,
    metric: VectorMetric,
) -> usize {
    // Index dimensions are compared as `u32`; a wider query cannot match any index.
    let Ok(wanted_dimension) = u32::try_from(query_dimension) else {
        return DEFAULT_HNSW_SEARCH_WIDTH;
    };
    let candidate = graph.vector_index_for(label, property);
    let Some(index) = candidate.filter(|index| index.dimension() == wanted_dimension) else {
        return DEFAULT_HNSW_SEARCH_WIDTH;
    };
    // The index kind only matters when the index serves the requested metric.
    if index.ann_metric() != Some(metric) {
        return DEFAULT_HNSW_SEARCH_WIDTH;
    }
    if index.is_ivf() {
        DEFAULT_IVF_SEARCH_WIDTH
    } else if index.is_turbo_quant() {
        DEFAULT_TURBO_QUANT_SEARCH_WIDTH
    } else {
        DEFAULT_HNSW_SEARCH_WIDTH
    }
}
/// Builds the `InvalidArgument` error reported when an `ef_search` value does
/// not fit in `usize`; `proc_name` prefixes the message.
fn search_width_too_large(proc_name: &str) -> ProcedureError {
    let detail = format!("{proc_name} ef_search is too large for this platform");
    ProcedureError::InvalidArgument { detail }
}
#[cfg(test)]
mod tests {
    use selene_core::{DbString, GraphId, VectorMetric, db_string};
    use selene_graph::{SharedGraph, VectorIndexKind};

    use super::{
        DEFAULT_HNSW_SEARCH_WIDTH, DEFAULT_IVF_SEARCH_WIDTH, DEFAULT_TURBO_QUANT_SEARCH_WIDTH,
        default_metric, default_search_width,
    };

    /// Returns the `(label, property)` pair every fixture index is registered
    /// under, so the fixture and the lookups in the tests cannot drift apart.
    fn vector_doc_key() -> (DbString, DbString) {
        let label = db_string("VectorDoc").expect("label fits DB string cap");
        let property = db_string("embedding").expect("property fits DB string cap");
        (label, property)
    }

    /// Builds a graph holding one two-dimensional vector index of `kind`.
    fn graph_with_index(kind: VectorIndexKind) -> SharedGraph {
        graph_with_index_dimension(kind, 2)
    }

    /// Builds a graph holding one vector index of `kind` with `dimension`
    /// dimensions, registered under the key from `vector_doc_key`.
    fn graph_with_index_dimension(kind: VectorIndexKind, dimension: u32) -> SharedGraph {
        let graph = SharedGraph::new(GraphId::new(431_001));
        let (label, property) = vector_doc_key();
        let mut txn = graph.begin_write();
        txn.mutator()
            .create_vector_index(label, property, kind, dimension)
            .expect("vector index creates");
        txn.commit().expect("index creation commits");
        graph
    }

    #[test]
    fn default_search_width_selects_ivf_width_for_matching_ivf_index() {
        let graph = graph_with_index(VectorIndexKind::IvfSquaredEuclidean);
        let (label, property) = vector_doc_key();
        let snapshot = graph.read();
        assert_eq!(
            default_search_width(
                &snapshot,
                &label,
                &property,
                2,
                VectorMetric::SquaredEuclidean
            ),
            DEFAULT_IVF_SEARCH_WIDTH
        );
    }

    #[test]
    fn default_search_width_keeps_hnsw_width_for_matching_hnsw_index() {
        let graph = graph_with_index(VectorIndexKind::HnswSquaredEuclidean);
        let (label, property) = vector_doc_key();
        let snapshot = graph.read();
        assert_eq!(
            default_search_width(
                &snapshot,
                &label,
                &property,
                2,
                VectorMetric::SquaredEuclidean
            ),
            DEFAULT_HNSW_SEARCH_WIDTH
        );
    }

    #[test]
    fn default_search_width_keeps_hnsw_width_without_matching_ivf_index() {
        let graph = graph_with_index(VectorIndexKind::IvfCosine);
        let (label, property) = vector_doc_key();
        let snapshot = graph.read();
        // Right metric, wrong dimension: the index must be ignored.
        assert_eq!(
            default_search_width(&snapshot, &label, &property, 3, VectorMetric::Cosine),
            DEFAULT_HNSW_SEARCH_WIDTH
        );
        // Right dimension, wrong metric: the IVF width must not be used.
        assert_eq!(
            default_search_width(
                &snapshot,
                &label,
                &property,
                2,
                VectorMetric::SquaredEuclidean
            ),
            DEFAULT_HNSW_SEARCH_WIDTH
        );
    }

    #[test]
    fn default_search_width_selects_turbo_quant_width_for_matching_index() {
        let graph = graph_with_index(VectorIndexKind::TurboQuantCosine);
        let (label, property) = vector_doc_key();
        let snapshot = graph.read();
        assert_eq!(
            default_search_width(&snapshot, &label, &property, 2, VectorMetric::Cosine),
            DEFAULT_TURBO_QUANT_SEARCH_WIDTH
        );
    }

    #[test]
    fn default_search_width_selects_turbo_quant_width_for_high_dimensions() {
        let graph = graph_with_index_dimension(VectorIndexKind::TurboQuantCosine, 1536);
        let (label, property) = vector_doc_key();
        let snapshot = graph.read();
        assert_eq!(
            default_search_width(&snapshot, &label, &property, 1536, VectorMetric::Cosine),
            DEFAULT_TURBO_QUANT_SEARCH_WIDTH
        );
    }

    #[test]
    fn default_metric_selects_registered_ann_metric() {
        let graph = graph_with_index(VectorIndexKind::TurboQuantCosine);
        let (label, property) = vector_doc_key();
        let snapshot = graph.read();
        assert_eq!(
            default_metric(&snapshot, &label, &property, 2),
            VectorMetric::Cosine
        );
    }

    #[test]
    fn default_metric_falls_back_without_matching_ann_index() {
        let graph = graph_with_index(VectorIndexKind::Flat);
        let (label, property) = vector_doc_key();
        let snapshot = graph.read();
        // Matching dimension, but a flat index is expected to yield the fallback.
        assert_eq!(
            default_metric(&snapshot, &label, &property, 2),
            VectorMetric::SquaredEuclidean
        );
        // Dimension mismatch: the index is not consulted at all.
        assert_eq!(
            default_metric(&snapshot, &label, &property, 3),
            VectorMetric::SquaredEuclidean
        );
    }
}