grafeo_core/index/vector/
accessor.rs1use std::sync::Arc;
24
25use grafeo_common::types::{NodeId, PropertyKey, Value};
26
27use crate::graph::GraphStore;
28
29pub trait VectorAccessor: Send + Sync {
34 fn get_vector(&self, id: NodeId) -> Option<Arc<[f32]>>;
36}
37
38pub struct PropertyVectorAccessor<'a> {
44 store: &'a dyn GraphStore,
45 property: PropertyKey,
46}
47
48impl<'a> PropertyVectorAccessor<'a> {
49 #[must_use]
51 pub fn new(store: &'a dyn GraphStore, property: impl Into<PropertyKey>) -> Self {
52 Self {
53 store,
54 property: property.into(),
55 }
56 }
57}
58
59impl VectorAccessor for PropertyVectorAccessor<'_> {
60 fn get_vector(&self, id: NodeId) -> Option<Arc<[f32]>> {
61 match self.store.get_node_property(id, &self.property) {
62 Some(Value::Vector(v)) => Some(v),
63 _ => None,
64 }
65 }
66}
67
68impl<F> VectorAccessor for F
70where
71 F: Fn(NodeId) -> Option<Arc<[f32]>> + Send + Sync,
72{
73 fn get_vector(&self, id: NodeId) -> Option<Arc<[f32]>> {
74 self(id)
75 }
76}
77
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::lpg::LpgStore;

    /// A closure over a map acts as an accessor, both directly and through
    /// `dyn VectorAccessor`.
    #[test]
    fn test_closure_accessor() {
        let vectors: std::collections::HashMap<NodeId, Arc<[f32]>> = [
            (NodeId::new(1), Arc::from(vec![1.0_f32, 0.0, 0.0])),
            (NodeId::new(2), Arc::from(vec![0.0_f32, 1.0, 0.0])),
        ]
        .into_iter()
        .collect();

        let accessor = move |id: NodeId| -> Option<Arc<[f32]>> { vectors.get(&id).cloned() };

        // Compare full contents (not just length) so swapped entries are caught.
        assert_eq!(
            accessor.get_vector(NodeId::new(1)).as_deref(),
            Some(&[1.0_f32, 0.0, 0.0][..])
        );
        assert_eq!(
            accessor.get_vector(NodeId::new(2)).as_deref(),
            Some(&[0.0_f32, 1.0, 0.0][..])
        );
        assert!(accessor.get_vector(NodeId::new(3)).is_none());

        // The blanket closure impl must also work behind a trait object.
        let boxed: Box<dyn VectorAccessor> = Box::new(accessor);
        assert_eq!(
            boxed.get_vector(NodeId::new(2)).as_deref(),
            Some(&[0.0_f32, 1.0, 0.0][..])
        );
        assert!(boxed.get_vector(NodeId::new(3)).is_none());
    }

    /// The property-backed accessor returns stored vectors and `None` for
    /// unknown nodes, unset properties, and non-vector values.
    #[test]
    fn test_property_vector_accessor() {
        let store = LpgStore::new().unwrap();
        let id = store.create_node(&["Test"]);
        let vec_data: Arc<[f32]> = vec![1.0, 2.0, 3.0].into();
        store.set_node_property(id, "embedding", Value::Vector(Arc::clone(&vec_data)));

        let accessor = PropertyVectorAccessor::new(&store, "embedding");
        assert_eq!(accessor.get_vector(id).as_deref(), Some(&vec_data[..]));

        // Unknown node.
        assert!(accessor.get_vector(NodeId::new(999)).is_none());

        // Existing node, property never set.
        let missing_accessor = PropertyVectorAccessor::new(&store, "missing");
        assert!(missing_accessor.get_vector(id).is_none());

        // Property exists but does not hold a vector.
        store.set_node_property(id, "name", Value::from("hello"));
        let name_accessor = PropertyVectorAccessor::new(&store, "name");
        assert!(name_accessor.get_vector(id).is_none());

        // Dynamic dispatch behaves the same as the concrete type.
        let dyn_accessor: &dyn VectorAccessor = &accessor;
        assert_eq!(dyn_accessor.get_vector(id).as_deref(), Some(&vec_data[..]));
    }
}