use super::{Operator, OperatorError, OperatorResult};
use crate::execution::DataChunk;
use crate::graph::GraphStore;
use crate::index::vector::DistanceMetric;
use grafeo_common::types::{LogicalType, NodeId, PropertyKey, Value};
use std::sync::Arc;
#[cfg(feature = "vector-index")]
use crate::index::vector::VectorIndexKind;
/// Physical operator that returns the `k` nodes nearest to a query vector.
///
/// Rows are emitted as `(node, distance)` pairs in batches of at most
/// `chunk_capacity` rows. The search runs against an attached vector index when
/// the `vector-index` feature is enabled and an index was supplied; otherwise it
/// falls back to a brute-force scan over a node property.
pub struct VectorScanOperator {
    // ----- data sources -----
    /// Graph store used to enumerate nodes and read their vector property.
    store: Arc<dyn GraphStore>,
    /// Optional prebuilt vector index; `None` selects the brute-force path.
    #[cfg(feature = "vector-index")]
    index: Option<Arc<VectorIndexKind>>,
    /// Name of the node property that holds the vectors.
    property: String,
    /// Label restricting brute-force candidates; the index path does not consult it.
    label: Option<String>,

    // ----- query parameters -----
    /// The query vector.
    query: Vec<f32>,
    /// Maximum number of neighbours to return.
    k: usize,
    /// Metric used by brute force; also decides whether `min_similarity` applies.
    metric: DistanceMetric,
    /// Search breadth forwarded to the index; unused on the brute-force path.
    ef: usize,

    // ----- post-search filters -----
    /// Minimum similarity, converted from the distance (cosine metric only).
    min_similarity: Option<f32>,
    /// Maximum distance a result may have to be kept.
    max_distance: Option<f32>,

    // ----- execution state -----
    /// Filtered `(node, distance)` results, computed lazily on the first `next()`.
    results: Vec<(NodeId, f32)>,
    /// Index of the next result to emit.
    position: usize,
    /// Whether `results` has been populated since construction / last reset.
    executed: bool,
    /// Maximum rows per emitted chunk.
    chunk_capacity: usize,
    /// Whether this operator was built with an index (reported by `name()`).
    uses_index: bool,
}
impl VectorScanOperator {
    /// Default index search breadth used when [`Self::with_ef`] is not called.
    const DEFAULT_EF: usize = 64;
    /// Default maximum number of rows per emitted chunk.
    const DEFAULT_CHUNK_CAPACITY: usize = 2048;

    /// Creates a scan that answers the query through the given vector index.
    ///
    /// The metric defaults to [`DistanceMetric::Cosine`] and the property name to
    /// the empty string. Call [`Self::with_property`] so the accessor handed to the
    /// index reads the right property. If the index was built with a non-cosine
    /// metric, also call [`Self::with_metric`], because the metric decides how
    /// `min_similarity` is interpreted.
    #[cfg(feature = "vector-index")]
    #[must_use]
    pub fn with_index(
        store: Arc<dyn GraphStore>,
        index: Arc<VectorIndexKind>,
        query: Vec<f32>,
        k: usize,
    ) -> Self {
        Self {
            store,
            index: Some(index),
            query,
            k,
            metric: DistanceMetric::Cosine,
            property: String::new(),
            label: None,
            min_similarity: None,
            max_distance: None,
            ef: Self::DEFAULT_EF,
            results: Vec::new(),
            position: 0,
            executed: false,
            chunk_capacity: Self::DEFAULT_CHUNK_CAPACITY,
            uses_index: true,
        }
    }

    /// Sets the node property that holds the vectors.
    #[must_use]
    pub fn with_property(mut self, property: impl Into<String>) -> Self {
        self.property = property.into();
        self
    }

    /// Creates a scan that computes distances to every candidate node's `property`.
    #[must_use]
    pub fn brute_force(
        store: Arc<dyn GraphStore>,
        property: impl Into<String>,
        query: Vec<f32>,
        k: usize,
        metric: DistanceMetric,
    ) -> Self {
        Self {
            store,
            #[cfg(feature = "vector-index")]
            index: None,
            query,
            k,
            metric,
            property: property.into(),
            label: None,
            min_similarity: None,
            max_distance: None,
            ef: Self::DEFAULT_EF,
            results: Vec::new(),
            position: 0,
            executed: false,
            chunk_capacity: Self::DEFAULT_CHUNK_CAPACITY,
            uses_index: false,
        }
    }

    /// Restricts candidates to nodes carrying `label`.
    ///
    /// Only the brute-force path applies this filter; the index path ignores it.
    #[must_use]
    pub fn with_label(mut self, label: impl Into<String>) -> Self {
        self.label = Some(label.into());
        self
    }

    /// Sets the search breadth forwarded to the index (ignored by brute force).
    #[must_use]
    pub fn with_ef(mut self, ef: usize) -> Self {
        self.ef = ef;
        self
    }

    /// Keeps only results whose similarity is at least `threshold`.
    ///
    /// Similarity is derived as `1 - distance` and only for
    /// [`DistanceMetric::Cosine`]. Under any other metric this filter is ignored.
    #[must_use]
    pub fn with_min_similarity(mut self, threshold: f32) -> Self {
        self.min_similarity = Some(threshold);
        self
    }

    /// Keeps only results whose distance is at most `threshold`.
    #[must_use]
    pub fn with_max_distance(mut self, threshold: f32) -> Self {
        self.max_distance = Some(threshold);
        self
    }

    /// Overrides the distance metric.
    ///
    /// On the brute-force path this is the metric used to compute distances. On
    /// the index path it should match the metric the index was built with, since
    /// it decides whether `min_similarity` can be applied.
    #[must_use]
    pub fn with_metric(mut self, metric: DistanceMetric) -> Self {
        self.metric = metric;
        self
    }

    /// Sets the maximum number of rows per emitted chunk.
    ///
    /// A capacity of 0 is clamped to 1. A zero-sized chunk could never advance
    /// the scan, so `next()` would return empty chunks forever.
    #[must_use]
    pub fn with_chunk_capacity(mut self, capacity: usize) -> Self {
        self.chunk_capacity = capacity.max(1);
        self
    }

    /// Runs the search and filters exactly once.
    ///
    /// Subsequent calls are no-ops until `executed` is cleared again.
    fn execute_search(&mut self) {
        if self.executed {
            return;
        }
        self.executed = true;
        self.results = self.run_search();
        self.apply_filters();
    }

    /// Produces the unfiltered top-k candidates.
    ///
    /// Uses the attached index when one is present and otherwise falls back to
    /// brute force.
    fn run_search(&self) -> Vec<(NodeId, f32)> {
        #[cfg(feature = "vector-index")]
        {
            if let Some(ref index) = self.index {
                let accessor = crate::index::vector::PropertyVectorAccessor::new(
                    &*self.store,
                    &*self.property,
                );
                return index.search_with_ef(&self.query, self.k, self.ef, &accessor);
            }
        }
        self.brute_force_search()
    }

    /// Exact k-NN over every candidate node whose `property` holds a vector.
    ///
    /// Nodes lacking the property, or holding a non-vector value, are skipped.
    fn brute_force_search(&self) -> Vec<(NodeId, f32)> {
        use crate::index::vector::brute_force_knn;

        let node_ids = match &self.label {
            Some(label) => self.store.nodes_by_label(label),
            None => self.store.node_ids(),
        };
        // Build the lookup key once rather than once per candidate node.
        let key = PropertyKey::new(&self.property);
        let vectors: Vec<(NodeId, Vec<f32>)> = node_ids
            .into_iter()
            .filter_map(|id| match self.store.get_node_property(id, &key) {
                Some(Value::Vector(vec)) => Some((id, vec.to_vec())),
                _ => None,
            })
            .collect();
        let candidates = vectors.iter().map(|(id, v)| (*id, v.as_slice()));
        brute_force_knn(candidates, &self.query, self.k, self.metric)
    }

    /// Drops results that violate `min_similarity` or `max_distance`.
    ///
    /// `min_similarity` is only meaningful for cosine distance, where
    /// `similarity = 1 - distance`. Under other metrics it is ignored.
    fn apply_filters(&mut self) {
        let max_distance = self.max_distance;
        let similarity_floor = self
            .min_similarity
            .filter(|_| self.metric == DistanceMetric::Cosine);
        if similarity_floor.is_none() && max_distance.is_none() {
            return;
        }
        self.results.retain(|&(_, distance)| {
            let passes_similarity = similarity_floor.map_or(true, |t| 1.0 - distance >= t);
            let passes_distance = max_distance.map_or(true, |t| distance <= t);
            passes_similarity && passes_distance
        });
    }
}
impl Operator for VectorScanOperator {
    /// Emits the next batch of `(node, distance)` rows.
    ///
    /// The search runs lazily on the first call. Returns `Ok(None)` once every
    /// result has been emitted.
    ///
    /// # Errors
    ///
    /// Returns `OperatorError::ColumnNotFound` if the freshly created chunk lacks
    /// the node or distance column.
    fn next(&mut self) -> OperatorResult {
        self.execute_search();
        if self.position >= self.results.len() {
            return Ok(None);
        }

        // A zero batch size would never advance `position` and would yield empty
        // chunks forever. Saturate so a huge capacity cannot overflow the index.
        let batch_size = self.chunk_capacity.max(1);
        let end = self
            .position
            .saturating_add(batch_size)
            .min(self.results.len());
        let batch = &self.results[self.position..end];

        let schema = [LogicalType::Node, LogicalType::Float64];
        let mut chunk = DataChunk::with_capacity(&schema, batch_size);
        {
            let node_col = chunk
                .column_mut(0)
                .ok_or_else(|| OperatorError::ColumnNotFound("node column".into()))?;
            for &(node_id, _) in batch {
                node_col.push_node_id(node_id);
            }
        }
        {
            let dist_col = chunk
                .column_mut(1)
                .ok_or_else(|| OperatorError::ColumnNotFound("distance column".into()))?;
            for &(_, distance) in batch {
                dist_col.push_float64(f64::from(distance));
            }
        }
        chunk.set_count(batch.len());
        self.position = end;
        Ok(Some(chunk))
    }

    /// Rewinds the scan. The next call to `next()` re-runs the search from scratch.
    fn reset(&mut self) {
        self.position = 0;
        self.results.clear();
        self.executed = false;
    }

    /// Operator name for plans/explain output, reflecting the search strategy.
    fn name(&self) -> &'static str {
        if self.uses_index {
            "VectorScan(HNSW)"
        } else {
            "VectorScan(BruteForce)"
        }
    }
}
#[cfg(all(test, feature = "lpg"))]
mod tests {
// Unit tests run the operator against an in-memory `LpgStore`. The HNSW test
// additionally requires the `vector-index` feature.
use super::*;
use crate::graph::lpg::LpgStore;
// Brute-force Euclidean top-2. The query differs from n1 only by 0.05 in the
// last coordinate, so n1 is expected as the first (nearest) row. With k = 2 and
// the default chunk capacity, all results fit in a single chunk.
#[test]
fn test_vector_scan_brute_force() {
let store = Arc::new(LpgStore::new().unwrap());
let n1 = store.create_node(&["Document"]);
let n2 = store.create_node(&["Document"]);
let n3 = store.create_node(&["Document"]);
store.set_node_property(n1, "embedding", Value::Vector(vec![0.1, 0.2, 0.3].into()));
store.set_node_property(n2, "embedding", Value::Vector(vec![0.5, 0.6, 0.7].into()));
store.set_node_property(n3, "embedding", Value::Vector(vec![0.9, 0.8, 0.7].into()));
let query = vec![0.1, 0.2, 0.35];
let mut scan = VectorScanOperator::brute_force(
Arc::clone(&store) as Arc<dyn GraphStore>,
"embedding",
query,
2, DistanceMetric::Euclidean,
)
.with_label("Document");
let chunk = scan.next().unwrap().unwrap();
assert_eq!(chunk.row_count(), 2);
let first_node = chunk.column(0).unwrap().get_node_id(0);
assert_eq!(first_node, Some(n1));
assert!(scan.next().unwrap().is_none());
}
// After the scan is exhausted, reset() must clear state so the search re-runs
// and yields the same single row again.
#[test]
fn test_vector_scan_reset() {
let store = Arc::new(LpgStore::new().unwrap());
let n1 = store.create_node(&["Doc"]);
store.set_node_property(n1, "vec", Value::Vector(vec![0.1, 0.2].into()));
let mut scan = VectorScanOperator::brute_force(
Arc::clone(&store) as Arc<dyn GraphStore>,
"vec",
vec![0.1, 0.2],
10,
DistanceMetric::Cosine,
);
let chunk1 = scan.next().unwrap().unwrap();
assert_eq!(chunk1.row_count(), 1);
assert!(scan.next().unwrap().is_none());
scan.reset();
let chunk2 = scan.next().unwrap().unwrap();
assert_eq!(chunk2.row_count(), 1);
}
// n1 is 0.1 from the origin and n2 is about 14.1 away, so max_distance(1.0)
// should keep only n1. This holds whether the metric reports plain or squared
// Euclidean distance.
#[test]
fn test_vector_scan_with_distance_filter() {
let store = Arc::new(LpgStore::new().unwrap());
let n1 = store.create_node(&["Doc"]);
let n2 = store.create_node(&["Doc"]);
store.set_node_property(n1, "vec", Value::Vector(vec![0.1, 0.0].into()));
store.set_node_property(n2, "vec", Value::Vector(vec![10.0, 10.0].into()));
let mut scan = VectorScanOperator::brute_force(
Arc::clone(&store) as Arc<dyn GraphStore>,
"vec",
vec![0.0, 0.0],
10,
DistanceMetric::Euclidean,
)
.with_max_distance(1.0);
let chunk = scan.next().unwrap().unwrap();
assert_eq!(chunk.row_count(), 1);
let node_id = chunk.column(0).unwrap().get_node_id(0);
assert_eq!(node_id, Some(n1));
}
// The only node has no "embedding" property. Brute force therefore finds no
// candidates, and the very first next() returns None instead of an empty chunk.
#[test]
fn test_vector_scan_empty_results() {
let store = Arc::new(LpgStore::new().unwrap());
store.create_node(&["Doc"]);
let mut scan = VectorScanOperator::brute_force(
Arc::clone(&store) as Arc<dyn GraphStore>,
"embedding",
vec![0.1, 0.2],
10,
DistanceMetric::Cosine,
);
let result = scan.next().unwrap();
assert!(result.is_none());
}
// The brute-force constructor reports the brute-force operator name.
#[test]
fn test_vector_scan_name() {
let store = Arc::new(LpgStore::new().unwrap());
let brute_scan = VectorScanOperator::brute_force(
Arc::clone(&store) as Arc<dyn GraphStore>,
"vec",
vec![0.1],
10,
DistanceMetric::Cosine,
);
assert_eq!(brute_scan.name(), "VectorScan(BruteForce)");
}
// The query [0, 1] is orthogonal to n1 (cosine similarity 0) and about 45 degrees
// from n2 (similarity ~0.707). min_similarity(0.5) should keep only n2. This
// relies on apply_filters treating cosine distance as 1 - similarity.
#[test]
fn test_vector_scan_with_min_similarity() {
let store = Arc::new(LpgStore::new().unwrap());
let n1 = store.create_node(&["Doc"]);
let n2 = store.create_node(&["Doc"]);
store.set_node_property(n1, "vec", Value::Vector(vec![1.0, 0.0].into()));
store.set_node_property(n2, "vec", Value::Vector(vec![0.707, 0.707].into()));
let mut scan = VectorScanOperator::brute_force(
Arc::clone(&store) as Arc<dyn GraphStore>,
"vec",
vec![0.0, 1.0], 10,
DistanceMetric::Cosine,
)
.with_min_similarity(0.5);
let chunk = scan.next().unwrap().unwrap();
assert_eq!(chunk.row_count(), 1);
let node_id = chunk.column(0).unwrap().get_node_id(0);
assert_eq!(node_id, Some(n2));
}
// ef is only forwarded on the index path. On the brute-force path it must be
// accepted without changing the result.
#[test]
fn test_vector_scan_with_ef() {
let store = Arc::new(LpgStore::new().unwrap());
let n1 = store.create_node(&["Doc"]);
store.set_node_property(n1, "vec", Value::Vector(vec![0.1, 0.2].into()));
let mut scan = VectorScanOperator::brute_force(
Arc::clone(&store) as Arc<dyn GraphStore>,
"vec",
vec![0.1, 0.2],
10,
DistanceMetric::Cosine,
)
.with_ef(128);
let chunk = scan.next().unwrap().unwrap();
assert_eq!(chunk.row_count(), 1);
}
// Ten results with a chunk capacity of 3 are expected as chunks of 3, 3, 3
// and 1, followed by end-of-stream.
#[test]
fn test_vector_scan_with_chunk_capacity() {
let store = Arc::new(LpgStore::new().unwrap());
for i in 0..10 {
let node = store.create_node(&["Doc"]);
store.set_node_property(node, "vec", Value::Vector(vec![i as f32, 0.0].into()));
}
let mut scan = VectorScanOperator::brute_force(
Arc::clone(&store) as Arc<dyn GraphStore>,
"vec",
vec![0.0, 0.0],
10,
DistanceMetric::Euclidean,
)
.with_chunk_capacity(3);
let chunk1 = scan.next().unwrap().unwrap();
assert_eq!(chunk1.row_count(), 3);
let chunk2 = scan.next().unwrap().unwrap();
assert_eq!(chunk2.row_count(), 3);
let chunk3 = scan.next().unwrap().unwrap();
assert_eq!(chunk3.row_count(), 3);
let chunk4 = scan.next().unwrap().unwrap();
assert_eq!(chunk4.row_count(), 1);
assert!(scan.next().unwrap().is_none());
}
// Without with_label, brute force considers nodes of every label, so nodes
// labelled TypeA and TypeB are both returned.
#[test]
fn test_vector_scan_no_label_filter() {
let store = Arc::new(LpgStore::new().unwrap());
let n1 = store.create_node(&["TypeA"]);
let n2 = store.create_node(&["TypeB"]);
store.set_node_property(n1, "vec", Value::Vector(vec![0.1, 0.2].into()));
store.set_node_property(n2, "vec", Value::Vector(vec![0.3, 0.4].into()));
let mut scan = VectorScanOperator::brute_force(
Arc::clone(&store) as Arc<dyn GraphStore>,
"vec",
vec![0.0, 0.0],
10,
DistanceMetric::Euclidean,
);
let chunk = scan.next().unwrap().unwrap();
assert_eq!(chunk.row_count(), 2);
}
// Builds an HNSW index over the same vectors as the brute-force test and
// expects the same nearest neighbour (n1). with_property("vec") is needed
// because with_index leaves the property name empty, and the index path
// builds its accessor from that name.
#[cfg(feature = "vector-index")]
#[test]
fn test_vector_scan_with_hnsw_index() {
use crate::index::vector::{
HnswConfig, HnswIndex, PropertyVectorAccessor, VectorIndexKind,
};
let store = Arc::new(LpgStore::new().unwrap());
let n1 = store.create_node(&["Doc"]);
let n2 = store.create_node(&["Doc"]);
let n3 = store.create_node(&["Doc"]);
let v1 = vec![0.1f32, 0.2, 0.3];
let v2 = vec![0.5, 0.6, 0.7];
let v3 = vec![0.9, 0.8, 0.7];
store.set_node_property(n1, "vec", Value::Vector(v1.clone().into()));
store.set_node_property(n2, "vec", Value::Vector(v2.clone().into()));
store.set_node_property(n3, "vec", Value::Vector(v3.clone().into()));
let config = HnswConfig::new(3, DistanceMetric::Euclidean);
let hnsw = HnswIndex::new(config);
let accessor = PropertyVectorAccessor::new(&*store, "vec");
hnsw.insert(n1, &v1, &accessor);
hnsw.insert(n2, &v2, &accessor);
hnsw.insert(n3, &v3, &accessor);
let index = Arc::new(VectorIndexKind::Hnsw(hnsw));
let query = vec![0.1f32, 0.2, 0.35];
let mut scan = VectorScanOperator::with_index(
Arc::clone(&store) as Arc<dyn GraphStore>,
Arc::clone(&index),
query,
2,
)
.with_property("vec");
assert_eq!(scan.name(), "VectorScan(HNSW)");
let chunk = scan.next().unwrap().unwrap();
assert_eq!(chunk.row_count(), 2);
let first_node = chunk.column(0).unwrap().get_node_id(0);
assert_eq!(first_node, Some(n1));
}
}