use std::sync::Arc;
use grafeo_common::types::{NodeId, PropertyKey, Value};
use crate::graph::GraphStore;
/// Read-only lookup of a node's vector by its [`NodeId`].
///
/// The `Send + Sync` supertraits require every implementor to be usable from
/// multiple threads.
pub trait VectorAccessor: Send + Sync {
/// Returns the vector for `id`, or `None` when the implementor has no vector
/// for that node.
fn get_vector(&self, id: NodeId) -> Option<Arc<[f32]>>;
}
/// [`VectorAccessor`] that reads vectors from a node property of a borrowed
/// [`GraphStore`].
///
/// A lookup yields a vector only when the named property holds a
/// `Value::Vector`.
pub struct PropertyVectorAccessor<'a> {
/// Graph store queried for node properties.
store: &'a dyn GraphStore,
/// Key of the node property expected to hold the vector.
property: PropertyKey,
}
impl<'a> PropertyVectorAccessor<'a> {
    /// Builds an accessor that reads the `property` node property from `store`.
    ///
    /// `property` accepts anything convertible into a [`PropertyKey`]; the
    /// conversion is performed once, here.
    #[must_use]
    pub fn new(store: &'a dyn GraphStore, property: impl Into<PropertyKey>) -> Self {
        let property = property.into();
        Self { store, property }
    }
}
impl VectorAccessor for PropertyVectorAccessor<'_> {
    /// Looks up the configured property on node `id`.
    ///
    /// Returns `None` when the store yields no value for the property, or when
    /// the stored value is not a `Value::Vector`.
    fn get_vector(&self, id: NodeId) -> Option<Arc<[f32]>> {
        // Bail out early if the property is absent for this node.
        let value = self.store.get_node_property(id, &self.property)?;
        match value {
            Value::Vector(embedding) => Some(embedding),
            _ => None,
        }
    }
}
/// [`VectorAccessor`] that consults a spill storage first and falls back to a
/// node property of a borrowed [`GraphStore`].
pub struct SpillableVectorAccessor<'a> {
/// Graph store used for the property fallback.
store: &'a dyn GraphStore,
/// Key of the node property read when the spill storage has no entry.
property: PropertyKey,
/// Storage checked before the property store.
spill_storage: Arc<dyn super::storage::VectorStorage>,
}
impl<'a> SpillableVectorAccessor<'a> {
    /// Builds an accessor that checks `spill_storage` first and otherwise reads
    /// the `property` node property from `store`.
    ///
    /// `property` accepts anything convertible into a [`PropertyKey`].
    #[must_use]
    pub fn new(
        store: &'a dyn GraphStore,
        property: impl Into<PropertyKey>,
        spill_storage: Arc<dyn super::storage::VectorStorage>,
    ) -> Self {
        let property = property.into();
        Self { store, property, spill_storage }
    }
}
impl VectorAccessor for SpillableVectorAccessor<'_> {
    /// Returns the vector for `id`, preferring the spill storage.
    ///
    /// The property store is only queried when the spill storage has no entry
    /// for `id`; a property value that is not a `Value::Vector` yields `None`.
    fn get_vector(&self, id: NodeId) -> Option<Arc<[f32]>> {
        if let Some(spilled) = self.spill_storage.get(id) {
            // Spill storage takes precedence whenever it has an entry.
            Some(spilled)
        } else if let Some(Value::Vector(stored)) =
            self.store.get_node_property(id, &self.property)
        {
            Some(stored)
        } else {
            None
        }
    }
}
/// Enum wrapper over the concrete accessor types defined in this file.
///
/// Its [`VectorAccessor`] implementation forwards to the wrapped accessor.
/// Marked `#[non_exhaustive]`, so code outside this crate must not assume the
/// variant list is complete.
#[non_exhaustive]
pub enum VectorAccessorKind<'a> {
/// Wraps a [`PropertyVectorAccessor`].
Property(PropertyVectorAccessor<'a>),
/// Wraps a [`SpillableVectorAccessor`].
Spilled(SpillableVectorAccessor<'a>),
}
impl VectorAccessor for VectorAccessorKind<'_> {
fn get_vector(&self, id: NodeId) -> Option<Arc<[f32]>> {
match self {
Self::Property(a) => a.get_vector(id),
Self::Spilled(a) => a.get_vector(id),
}
}
}
/// Lets any `Send + Sync` closure or function of the form
/// `Fn(NodeId) -> Option<Arc<[f32]>>` act as a [`VectorAccessor`].
impl<F: Fn(NodeId) -> Option<Arc<[f32]>> + Send + Sync> VectorAccessor for F {
    /// Calls the wrapped function with `id` and returns its result unchanged.
    fn get_vector(&self, id: NodeId) -> Option<Arc<[f32]>> {
        let lookup: &F = self;
        lookup(id)
    }
}
// Unit tests for the closure-based and property-backed accessors.
// Compiled only for test builds with the `lpg` feature enabled.
#[cfg(all(test, feature = "lpg"))]
mod tests {
use super::*;
use crate::graph::lpg::LpgStore;
// A closure capturing a map of vectors can be used as a `VectorAccessor`
// through the blanket impl for `Fn(NodeId) -> Option<Arc<[f32]>>`.
#[test]
fn test_closure_accessor() {
let vectors: std::collections::HashMap<NodeId, Arc<[f32]>> = [
(NodeId::new(1), Arc::from(vec![1.0_f32, 0.0, 0.0])),
(NodeId::new(2), Arc::from(vec![0.0_f32, 1.0, 0.0])),
]
.into_iter()
.collect();
let accessor = move |id: NodeId| -> Option<Arc<[f32]>> { vectors.get(&id).cloned() };
// Known id: present and of the inserted length.
assert!(accessor.get_vector(NodeId::new(1)).is_some());
assert_eq!(accessor.get_vector(NodeId::new(1)).unwrap().len(), 3);
// Id that was never inserted into the map.
assert!(accessor.get_vector(NodeId::new(3)).is_none());
}
// `PropertyVectorAccessor` returns the stored vector for a vector-valued
// property and `None` otherwise.
#[test]
fn test_property_vector_accessor() {
let store = LpgStore::new().unwrap();
let id = store.create_node(&["Test"]);
let vec_data: Arc<[f32]> = vec![1.0, 2.0, 3.0].into();
store.set_node_property(id, "embedding", Value::Vector(vec_data.clone()));
let accessor = PropertyVectorAccessor::new(&store, "embedding");
let result = accessor.get_vector(id);
assert!(result.is_some());
assert_eq!(result.unwrap().as_ref(), vec_data.as_ref());
// A node id that was not created in this test yields `None`.
assert!(accessor.get_vector(NodeId::new(999)).is_none());
// A property holding a value built from a string literal is expected to
// yield `None` rather than a vector.
store.set_node_property(id, "name", Value::from("hello"));
let name_accessor = PropertyVectorAccessor::new(&store, "name");
assert!(name_accessor.get_vector(id).is_none());
}
}
// Unit tests for `SpillableVectorAccessor` and `VectorAccessorKind`.
// Compiled only for test builds with both the `lpg` and `vector-index`
// features enabled.
#[cfg(all(test, feature = "lpg", feature = "vector-index"))]
mod spill_tests {
use super::*;
use crate::graph::lpg::LpgStore;
use crate::index::vector::storage::{RamStorage, VectorStorage};
// A vector present only in the spill storage is returned.
#[test]
fn spill_accessor_returns_vector_from_spill_storage() {
let store = LpgStore::new().unwrap();
let alix_id = store.create_node(&["Person"]);
let spill_vec: Vec<f32> = vec![0.1, 0.2, 0.3];
let spill = Arc::new(RamStorage::new(3));
spill.insert(alix_id, &spill_vec).unwrap();
let accessor = SpillableVectorAccessor::new(&store as &dyn GraphStore, "embedding", spill);
let result = accessor.get_vector(alix_id);
assert!(result.is_some());
assert_eq!(result.unwrap().as_ref(), spill_vec.as_slice());
}
// With nothing inserted into the spill storage, the vector stored as a node
// property is returned.
#[test]
fn spill_accessor_falls_back_to_property_store() {
let store = LpgStore::new().unwrap();
let gus_id = store.create_node(&["Person"]);
let prop_vec: Arc<[f32]> = vec![0.4, 0.5, 0.6].into();
store.set_node_property(gus_id, "embedding", Value::Vector(prop_vec.clone()));
let spill: Arc<dyn VectorStorage> = Arc::new(RamStorage::new(3));
let accessor = SpillableVectorAccessor::new(&store as &dyn GraphStore, "embedding", spill);
let result = accessor.get_vector(gus_id);
assert!(result.is_some());
assert_eq!(result.unwrap().as_ref(), prop_vec.as_ref());
}
// When both sources hold a vector for the node, the spill storage's vector
// is the one returned.
#[test]
fn spill_accessor_prefers_spill_over_property_store() {
let store = LpgStore::new().unwrap();
let vincent_id = store.create_node(&["Person"]);
let prop_vec: Arc<[f32]> = vec![1.0, 0.0, 0.0].into();
store.set_node_property(vincent_id, "embedding", Value::Vector(prop_vec));
let spill_vec: Vec<f32> = vec![0.0, 1.0, 0.0];
let spill = Arc::new(RamStorage::new(3));
spill.insert(vincent_id, &spill_vec).unwrap();
let accessor = SpillableVectorAccessor::new(&store as &dyn GraphStore, "embedding", spill);
let result = accessor.get_vector(vincent_id).unwrap();
assert_eq!(result.as_ref(), spill_vec.as_slice());
}
// A node id present in neither source yields `None`.
#[test]
fn spill_accessor_returns_none_when_missing() {
let store = LpgStore::new().unwrap();
let spill: Arc<dyn VectorStorage> = Arc::new(RamStorage::new(3));
let accessor = SpillableVectorAccessor::new(&store as &dyn GraphStore, "embedding", spill);
assert!(accessor.get_vector(NodeId::new(999)).is_none());
}
// The `Property` variant forwards lookups to the wrapped
// `PropertyVectorAccessor`.
#[test]
fn accessor_kind_property_dispatches() {
let store = LpgStore::new().unwrap();
let jules_id = store.create_node(&["Person"]);
let vec_data: Arc<[f32]> = vec![0.7, 0.8, 0.9].into();
store.set_node_property(jules_id, "embedding", Value::Vector(vec_data.clone()));
let accessor =
VectorAccessorKind::Property(PropertyVectorAccessor::new(&store, "embedding"));
let result = accessor.get_vector(jules_id);
assert!(result.is_some());
assert_eq!(result.unwrap().as_ref(), vec_data.as_ref());
assert!(accessor.get_vector(NodeId::new(999)).is_none());
}
// The `Spilled` variant forwards lookups to the wrapped
// `SpillableVectorAccessor` (vector found in the spill storage).
#[test]
fn accessor_kind_spilled_dispatches() {
let store = LpgStore::new().unwrap();
let mia_id = store.create_node(&["Person"]);
let spill_vec: Vec<f32> = vec![0.3, 0.6, 0.9];
let spill = Arc::new(RamStorage::new(3));
spill.insert(mia_id, &spill_vec).unwrap();
let accessor = VectorAccessorKind::Spilled(SpillableVectorAccessor::new(
&store as &dyn GraphStore,
"embedding",
spill,
));
let result = accessor.get_vector(mia_id);
assert!(result.is_some());
assert_eq!(result.unwrap().as_ref(), spill_vec.as_slice());
}
// The `Spilled` variant still reaches the property fallback when nothing
// was inserted into the spill storage.
#[test]
fn accessor_kind_spilled_uses_fallback() {
let store = LpgStore::new().unwrap();
let butch_id = store.create_node(&["Person"]);
let prop_vec: Arc<[f32]> = vec![0.2, 0.4, 0.6].into();
store.set_node_property(butch_id, "embedding", Value::Vector(prop_vec.clone()));
let spill: Arc<dyn VectorStorage> = Arc::new(RamStorage::new(3));
let accessor = VectorAccessorKind::Spilled(SpillableVectorAccessor::new(
&store as &dyn GraphStore,
"embedding",
spill,
));
let result = accessor.get_vector(butch_id);
assert!(result.is_some());
assert_eq!(result.unwrap().as_ref(), prop_vec.as_ref());
}
}