// grafeo_core/index/vector/accessor.rs

use std::sync::Arc;

use grafeo_common::types::{NodeId, PropertyKey, Value};

use crate::graph::lpg::LpgStore;

30pub trait VectorAccessor: Send + Sync {
35 fn get_vector(&self, id: NodeId) -> Option<Arc<[f32]>>;
37}
38
39pub struct PropertyVectorAccessor<'a> {
45 store: &'a LpgStore,
46 property: PropertyKey,
47}
48
49impl<'a> PropertyVectorAccessor<'a> {
50 #[must_use]
52 pub fn new(store: &'a LpgStore, property: impl Into<PropertyKey>) -> Self {
53 Self {
54 store,
55 property: property.into(),
56 }
57 }
58}
59
60impl VectorAccessor for PropertyVectorAccessor<'_> {
61 fn get_vector(&self, id: NodeId) -> Option<Arc<[f32]>> {
62 match self.store.get_node_property(id, &self.property) {
63 Some(Value::Vector(v)) => Some(v),
64 _ => None,
65 }
66 }
67}
68
69impl<F> VectorAccessor for F
71where
72 F: Fn(NodeId) -> Option<Arc<[f32]>> + Send + Sync,
73{
74 fn get_vector(&self, id: NodeId) -> Option<Arc<[f32]>> {
75 self(id)
76 }
77}
78
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_closure_accessor() {
        let vectors: std::collections::HashMap<NodeId, Arc<[f32]>> = [
            (NodeId::new(1), Arc::from(vec![1.0_f32, 0.0, 0.0])),
            (NodeId::new(2), Arc::from(vec![0.0_f32, 1.0, 0.0])),
        ]
        .into_iter()
        .collect();

        let accessor = move |id: NodeId| -> Option<Arc<[f32]>> { vectors.get(&id).cloned() };

        // Known node: the exact stored vector comes back.
        let first = accessor
            .get_vector(NodeId::new(1))
            .expect("node 1 is present in the map");
        assert_eq!(first.as_ref(), &[1.0_f32, 0.0, 0.0]);

        // Unknown node.
        assert!(accessor.get_vector(NodeId::new(3)).is_none());

        // The blanket closure impl must also work through a trait object.
        let dyn_accessor: &dyn VectorAccessor = &accessor;
        assert!(dyn_accessor.get_vector(NodeId::new(2)).is_some());
    }

    #[test]
    fn test_property_vector_accessor() {
        let store = LpgStore::new();
        let id = store.create_node(&["Test"]);
        let vec_data: Arc<[f32]> = vec![1.0, 2.0, 3.0].into();
        store.set_node_property(id, "embedding", Value::Vector(Arc::clone(&vec_data)));

        let accessor = PropertyVectorAccessor::new(&store, "embedding");
        let result = accessor
            .get_vector(id)
            .expect("embedding property was just set");
        assert_eq!(result.as_ref(), vec_data.as_ref());

        // Unknown node.
        assert!(accessor.get_vector(NodeId::new(999)).is_none());

        // Existing node without the requested property.
        let missing_accessor = PropertyVectorAccessor::new(&store, "missing");
        assert!(missing_accessor.get_vector(id).is_none());

        // Existing property that holds a non-vector value.
        store.set_node_property(id, "name", Value::from("hello"));
        let name_accessor = PropertyVectorAccessor::new(&store, "name");
        assert!(name_accessor.get_vector(id).is_none());
    }
}