Skip to main content

graphmind_sdk/
vector_ext.rs

1//! VectorClient — extension trait for vector search operations (EmbeddedClient only)
2//!
3//! Provides vector index creation, insertion, and k-NN search via the
4//! `VectorClient` trait. Only `EmbeddedClient` implements this trait
5//! since vector operations require direct in-process access to the graph store.
6
7use async_trait::async_trait;
8
9use graphmind::graph::NodeId;
10use graphmind::vector::DistanceMetric;
11
12use crate::embedded::EmbeddedClient;
13use crate::error::{GraphmindError, GraphmindResult};
14
15/// Extension trait for vector search operations.
16///
17/// Only implemented by `EmbeddedClient` since vector ops need direct store access.
18#[async_trait]
19pub trait VectorClient {
20    /// Create a vector index for a given label and property.
21    async fn create_vector_index(
22        &self,
23        label: &str,
24        property: &str,
25        dimensions: usize,
26        metric: DistanceMetric,
27    ) -> GraphmindResult<()>;
28
29    /// Add a vector to the index for a given node.
30    async fn add_vector(
31        &self,
32        label: &str,
33        property: &str,
34        node_id: NodeId,
35        vector: &[f32],
36    ) -> GraphmindResult<()>;
37
38    /// Search for the k nearest neighbors to a query vector.
39    async fn vector_search(
40        &self,
41        label: &str,
42        property: &str,
43        query_vec: &[f32],
44        k: usize,
45    ) -> GraphmindResult<Vec<(NodeId, f32)>>;
46}
47
48#[async_trait]
49impl VectorClient for EmbeddedClient {
50    async fn create_vector_index(
51        &self,
52        label: &str,
53        property: &str,
54        dimensions: usize,
55        metric: DistanceMetric,
56    ) -> GraphmindResult<()> {
57        let store = self.store.read().await;
58        store
59            .create_vector_index(label, property, dimensions, metric)
60            .map_err(|e| GraphmindError::VectorError(e.to_string()))
61    }
62
63    async fn add_vector(
64        &self,
65        label: &str,
66        property: &str,
67        node_id: NodeId,
68        vector: &[f32],
69    ) -> GraphmindResult<()> {
70        let store = self.store.read().await;
71        store
72            .vector_index
73            .add_vector(label, property, node_id, &vector.to_vec())
74            .map_err(|e| GraphmindError::VectorError(e.to_string()))
75    }
76
77    async fn vector_search(
78        &self,
79        label: &str,
80        property: &str,
81        query_vec: &[f32],
82        k: usize,
83    ) -> GraphmindResult<Vec<(NodeId, f32)>> {
84        let store = self.store.read().await;
85        store
86            .vector_search(label, property, query_vec, k)
87            .map_err(|e| GraphmindError::VectorError(e.to_string()))
88    }
89}
90
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{EmbeddedClient, GraphmindClient};

    /// End-to-end check: create a cosine index, attach embeddings to two `Doc`
    /// nodes, and confirm a query close to the first embedding ranks it first.
    #[tokio::test]
    async fn test_vector_index_create_and_search() {
        let client = EmbeddedClient::new();

        // The index has to exist before any vectors are added to it.
        client
            .create_vector_index("Doc", "embedding", 4, DistanceMetric::Cosine)
            .await
            .expect("create_vector_index failed");

        // Two labelled nodes to hang embeddings on.
        client
            .query("default", r#"CREATE (d:Doc {title: "Alpha"})"#)
            .await
            .expect("CREATE Alpha failed");
        client
            .query("default", r#"CREATE (d:Doc {title: "Beta"})"#)
            .await
            .expect("CREATE Beta failed");

        // Read the node ids inside a scope so the read guard is released before
        // the client methods below lock the store again. The ids are bound to a
        // local first so no temporary borrowing the guard outlives it.
        let nodes: Vec<_> = {
            let store = client.store().read().await;
            let ids: Vec<_> = store.all_nodes().iter().map(|n| n.id).collect();
            ids
        };
        // Fail with a clear message instead of an index-out-of-bounds panic if
        // node creation did not produce exactly the two nodes above.
        assert_eq!(
            nodes.len(),
            2,
            "expected exactly the two Doc nodes created above"
        );

        // Two orthogonal unit vectors; which node gets which does not matter,
        // since the assertions below are phrased in terms of `nodes`.
        client
            .add_vector("Doc", "embedding", nodes[0], &[1.0, 0.0, 0.0, 0.0])
            .await
            .expect("add_vector for the first node failed");
        client
            .add_vector("Doc", "embedding", nodes[1], &[0.0, 1.0, 0.0, 0.0])
            .await
            .expect("add_vector for the second node failed");

        let results = client
            .vector_search("Doc", "embedding", &[1.0, 0.1, 0.0, 0.0], 2)
            .await
            .expect("vector_search failed");
        assert_eq!(results.len(), 2);
        // The query is nearly parallel to the first embedding and nearly
        // orthogonal to the second, so under cosine it must rank them in order.
        assert_eq!(results[0].0, nodes[0]);
        assert_eq!(results[1].0, nodes[1]);
    }
}