graphmind_sdk/
vector_ext.rs1use async_trait::async_trait;
8
9use graphmind::graph::NodeId;
10use graphmind::vector::DistanceMetric;
11
12use crate::embedded::EmbeddedClient;
13use crate::error::{GraphmindError, GraphmindResult};
14
15#[async_trait]
19pub trait VectorClient {
20 async fn create_vector_index(
22 &self,
23 label: &str,
24 property: &str,
25 dimensions: usize,
26 metric: DistanceMetric,
27 ) -> GraphmindResult<()>;
28
29 async fn add_vector(
31 &self,
32 label: &str,
33 property: &str,
34 node_id: NodeId,
35 vector: &[f32],
36 ) -> GraphmindResult<()>;
37
38 async fn vector_search(
40 &self,
41 label: &str,
42 property: &str,
43 query_vec: &[f32],
44 k: usize,
45 ) -> GraphmindResult<Vec<(NodeId, f32)>>;
46}
47
48#[async_trait]
49impl VectorClient for EmbeddedClient {
50 async fn create_vector_index(
51 &self,
52 label: &str,
53 property: &str,
54 dimensions: usize,
55 metric: DistanceMetric,
56 ) -> GraphmindResult<()> {
57 let store = self.store.read().await;
58 store
59 .create_vector_index(label, property, dimensions, metric)
60 .map_err(|e| GraphmindError::VectorError(e.to_string()))
61 }
62
63 async fn add_vector(
64 &self,
65 label: &str,
66 property: &str,
67 node_id: NodeId,
68 vector: &[f32],
69 ) -> GraphmindResult<()> {
70 let store = self.store.read().await;
71 store
72 .vector_index
73 .add_vector(label, property, node_id, &vector.to_vec())
74 .map_err(|e| GraphmindError::VectorError(e.to_string()))
75 }
76
77 async fn vector_search(
78 &self,
79 label: &str,
80 property: &str,
81 query_vec: &[f32],
82 k: usize,
83 ) -> GraphmindResult<Vec<(NodeId, f32)>> {
84 let store = self.store.read().await;
85 store
86 .vector_search(label, property, query_vec, k)
87 .map_err(|e| GraphmindError::VectorError(e.to_string()))
88 }
89}
90
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{EmbeddedClient, GraphmindClient};

    /// End-to-end check of the vector API on the embedded client.
    ///
    /// It creates a 4-d cosine index and attaches one axis-aligned vector to
    /// each of two `Doc` nodes. A query almost parallel to the first vector
    /// should then rank that node first.
    #[tokio::test]
    async fn test_vector_index_create_and_search() {
        let client = EmbeddedClient::new();

        client
            .create_vector_index("Doc", "embedding", 4, DistanceMetric::Cosine)
            .await
            .unwrap();

        client
            .query("default", r#"CREATE (d:Doc {title: "Alpha"})"#)
            .await
            .unwrap();
        client
            .query("default", r#"CREATE (d:Doc {title: "Beta"})"#)
            .await
            .unwrap();

        // Scope the read guard so it is released before the client is used
        // again. The ids are bound to a local before leaving the block so that
        // no tail-expression temporary borrowing `store` outlives the guard.
        let nodes: Vec<_> = {
            let store = client.store().read().await;
            let ids: Vec<_> = store.all_nodes().iter().map(|n| n.id).collect();
            ids
        };
        // Fail with a clear message rather than an index-out-of-bounds panic.
        assert_eq!(nodes.len(), 2, "expected exactly the two Doc nodes just created");

        client
            .add_vector("Doc", "embedding", nodes[0], &[1.0, 0.0, 0.0, 0.0])
            .await
            .unwrap();
        client
            .add_vector("Doc", "embedding", nodes[1], &[0.0, 1.0, 0.0, 0.0])
            .await
            .unwrap();

        // The query is nearly parallel to nodes[0]'s vector, so that node
        // should rank first and the other node second.
        let results = client
            .vector_search("Doc", "embedding", &[1.0, 0.1, 0.0, 0.0], 2)
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, nodes[0]);
        assert_eq!(results[1].0, nodes[1]);
    }
}