nt_memory/agentdb/
vector_store.rs

1//! Vector store with semantic search capabilities
2
3use crate::Result;
4use nt_agentdb_client::{AgentDBClient, BatchDocument, CollectionConfig};
5use serde::{Serialize, Deserialize};
6use std::sync::Arc;
7use tokio::sync::RwLock;
8
9/// Vector store for semantic search
10pub struct VectorStore {
11    client: Arc<AgentDBClient>,
12    collections: Arc<RwLock<std::collections::HashSet<String>>>,
13}
14
15impl VectorStore {
16    /// Create new vector store
17    pub async fn new(base_url: &str) -> anyhow::Result<Self> {
18        let client = AgentDBClient::new(base_url.to_string());
19
20        // Verify connection
21        client.health_check().await?;
22
23        Ok(Self {
24            client: Arc::new(client),
25            collections: Arc::new(RwLock::new(std::collections::HashSet::new())),
26        })
27    }
28
29    /// Ensure collection exists
30    pub async fn ensure_collection(
31        &self,
32        name: &str,
33        dimension: usize,
34    ) -> anyhow::Result<()> {
35        let mut collections = self.collections.write().await;
36
37        if collections.contains(name) {
38            return Ok(());
39        }
40
41        // Use the client module's CollectionConfig
42        use nt_agentdb_client::client::CollectionConfig;
43
44        let config = CollectionConfig {
45            name: name.to_string(),
46            dimension,
47            index_type: "hnsw".to_string(),
48            metadata_schema: None,
49        };
50
51        self.client.create_collection(config).await
52            .map_err(|e| anyhow::anyhow!("AgentDB error: {}", e))?;
53        collections.insert(name.to_string());
54
55        Ok(())
56    }
57
58    /// Insert vector with metadata
59    pub async fn insert<T: Serialize>(
60        &self,
61        collection: &str,
62        id: &str,
63        embedding: Vec<f32>,
64        metadata: Option<T>,
65    ) -> anyhow::Result<()> {
66        let id_bytes = id.as_bytes();
67
68        self.client
69            .insert(collection, id_bytes, &embedding, metadata.as_ref())
70            .await?;
71
72        Ok(())
73    }
74
75    /// Batch insert vectors
76    pub async fn batch_insert<T: Serialize>(
77        &self,
78        collection: &str,
79        documents: Vec<(String, Vec<f32>, Option<T>)>,
80    ) -> anyhow::Result<usize> {
81        // Use the client module's BatchDocument
82        use nt_agentdb_client::client::BatchDocument;
83
84        let batch: Vec<BatchDocument<T>> = documents
85            .into_iter()
86            .map(|(id, embedding, metadata)| BatchDocument {
87                id: id.into_bytes(),
88                embedding,
89                metadata,
90            })
91            .collect();
92
93        let response = self.client.batch_insert(collection, batch).await
94            .map_err(|e| anyhow::anyhow!("AgentDB error: {}", e))?;
95        Ok(response.inserted)
96    }
97
98    /// Search for similar vectors
99    pub async fn search(
100        &self,
101        collection: &str,
102        query_embedding: Vec<f32>,
103        top_k: usize,
104    ) -> Result<Vec<(String, f32)>> {
105        use nt_agentdb_client::VectorQuery;
106
107        let query = VectorQuery::new(
108            collection.to_string(),
109            query_embedding,
110            top_k,
111        );
112
113        let results: Vec<SearchResult> = self
114            .client
115            .vector_search(query)
116            .await
117            .map_err(|e| crate::MemoryError::VectorDB(e.to_string()))?;
118
119        Ok(results
120            .into_iter()
121            .map(|r| (r.id, r.score))
122            .collect())
123    }
124
125    /// Get vector by ID
126    pub async fn get<T: for<'de> Deserialize<'de>>(
127        &self,
128        collection: &str,
129        id: &str,
130    ) -> anyhow::Result<Option<T>> {
131        let id_bytes = id.as_bytes();
132        self.client.get(collection, id_bytes).await
133            .map_err(|e| anyhow::anyhow!("AgentDB error: {}", e))
134    }
135
136    /// Delete vector
137    pub async fn delete(&self, collection: &str, id: &str) -> anyhow::Result<()> {
138        let id_bytes = id.as_bytes();
139        self.client.delete(collection, id_bytes).await
140            .map_err(|e| anyhow::anyhow!("AgentDB error: {}", e))
141    }
142}
143
144/// Search result with metadata
145#[derive(Debug, Clone, Serialize, Deserialize)]
146pub struct SearchResult {
147    pub id: String,
148    pub score: f32,
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test against a live server: create a collection, insert one
    /// vector, then search with the same vector and expect exactly one hit.
    #[tokio::test]
    #[ignore] // Requires AgentDB server
    async fn test_vector_store_operations() {
        const COLLECTION: &str = "test_collection";
        const DIMENSION: usize = 384;

        let store = VectorStore::new("http://localhost:3000").await.unwrap();

        // Make sure the target collection is there before writing to it.
        store
            .ensure_collection(COLLECTION, DIMENSION)
            .await
            .unwrap();

        // Store a single vector with a small JSON metadata payload.
        let vector = vec![0.1; DIMENSION];
        let metadata = serde_json::json!({"type": "test"});
        store
            .insert(COLLECTION, "test_id", vector.clone(), Some(metadata))
            .await
            .unwrap();

        // Query with the same vector, asking for the single best match.
        let hits = store.search(COLLECTION, vector, 1).await.unwrap();

        assert_eq!(hits.len(), 1);
    }
}