nt_memory/agentdb/
vector_store.rs1use crate::Result;
4use nt_agentdb_client::{AgentDBClient, BatchDocument, CollectionConfig};
5use serde::{Serialize, Deserialize};
6use std::sync::Arc;
7use tokio::sync::RwLock;
8
9pub struct VectorStore {
11 client: Arc<AgentDBClient>,
12 collections: Arc<RwLock<std::collections::HashSet<String>>>,
13}
14
15impl VectorStore {
16 pub async fn new(base_url: &str) -> anyhow::Result<Self> {
18 let client = AgentDBClient::new(base_url.to_string());
19
20 client.health_check().await?;
22
23 Ok(Self {
24 client: Arc::new(client),
25 collections: Arc::new(RwLock::new(std::collections::HashSet::new())),
26 })
27 }
28
29 pub async fn ensure_collection(
31 &self,
32 name: &str,
33 dimension: usize,
34 ) -> anyhow::Result<()> {
35 let mut collections = self.collections.write().await;
36
37 if collections.contains(name) {
38 return Ok(());
39 }
40
41 use nt_agentdb_client::client::CollectionConfig;
43
44 let config = CollectionConfig {
45 name: name.to_string(),
46 dimension,
47 index_type: "hnsw".to_string(),
48 metadata_schema: None,
49 };
50
51 self.client.create_collection(config).await
52 .map_err(|e| anyhow::anyhow!("AgentDB error: {}", e))?;
53 collections.insert(name.to_string());
54
55 Ok(())
56 }
57
58 pub async fn insert<T: Serialize>(
60 &self,
61 collection: &str,
62 id: &str,
63 embedding: Vec<f32>,
64 metadata: Option<T>,
65 ) -> anyhow::Result<()> {
66 let id_bytes = id.as_bytes();
67
68 self.client
69 .insert(collection, id_bytes, &embedding, metadata.as_ref())
70 .await?;
71
72 Ok(())
73 }
74
75 pub async fn batch_insert<T: Serialize>(
77 &self,
78 collection: &str,
79 documents: Vec<(String, Vec<f32>, Option<T>)>,
80 ) -> anyhow::Result<usize> {
81 use nt_agentdb_client::client::BatchDocument;
83
84 let batch: Vec<BatchDocument<T>> = documents
85 .into_iter()
86 .map(|(id, embedding, metadata)| BatchDocument {
87 id: id.into_bytes(),
88 embedding,
89 metadata,
90 })
91 .collect();
92
93 let response = self.client.batch_insert(collection, batch).await
94 .map_err(|e| anyhow::anyhow!("AgentDB error: {}", e))?;
95 Ok(response.inserted)
96 }
97
98 pub async fn search(
100 &self,
101 collection: &str,
102 query_embedding: Vec<f32>,
103 top_k: usize,
104 ) -> Result<Vec<(String, f32)>> {
105 use nt_agentdb_client::VectorQuery;
106
107 let query = VectorQuery::new(
108 collection.to_string(),
109 query_embedding,
110 top_k,
111 );
112
113 let results: Vec<SearchResult> = self
114 .client
115 .vector_search(query)
116 .await
117 .map_err(|e| crate::MemoryError::VectorDB(e.to_string()))?;
118
119 Ok(results
120 .into_iter()
121 .map(|r| (r.id, r.score))
122 .collect())
123 }
124
125 pub async fn get<T: for<'de> Deserialize<'de>>(
127 &self,
128 collection: &str,
129 id: &str,
130 ) -> anyhow::Result<Option<T>> {
131 let id_bytes = id.as_bytes();
132 self.client.get(collection, id_bytes).await
133 .map_err(|e| anyhow::anyhow!("AgentDB error: {}", e))
134 }
135
136 pub async fn delete(&self, collection: &str, id: &str) -> anyhow::Result<()> {
138 let id_bytes = id.as_bytes();
139 self.client.delete(collection, id_bytes).await
140 .map_err(|e| anyhow::anyhow!("AgentDB error: {}", e))
141 }
142}
143
144#[derive(Debug, Clone, Serialize, Deserialize)]
146pub struct SearchResult {
147 pub id: String,
148 pub score: f32,
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end round trip against a live AgentDB server: create a collection, insert
    /// one vector, then search with that same vector.
    #[tokio::test]
    #[ignore = "requires a running AgentDB server at http://localhost:3000"]
    async fn test_vector_store_operations() {
        let store = VectorStore::new("http://localhost:3000")
            .await
            .expect("AgentDB server should be reachable at http://localhost:3000");

        store
            .ensure_collection("test_collection", 384)
            .await
            .expect("creating test_collection should succeed");

        let embedding = vec![0.1; 384];
        store
            .insert(
                "test_collection",
                "test_id",
                embedding.clone(),
                Some(serde_json::json!({"type": "test"})),
            )
            .await
            .expect("inserting test_id should succeed");

        let results = store
            .search("test_collection", embedding, 1)
            .await
            .expect("searching test_collection should succeed");

        // NOTE(review): also assert `results[0].0 == "test_id"` once it is confirmed that
        // ids inserted as bytes come back from search as the same UTF-8 string.
        assert_eq!(results.len(), 1, "top_k = 1 should yield exactly one hit");
    }
}