1use crate::types::{AppError, Document, Result, SearchResult};
23use async_trait::async_trait;
24use parking_lot::RwLock;
25use std::collections::HashMap;
26use std::path::PathBuf;
27use std::sync::Arc;
28
29use super::vectorstore::{CollectionInfo, CollectionStats, VectorStore};
30use ares_vector::{Config, DistanceMetric, VectorDb, VectorMetadata};
31
/// [`VectorStore`] implementation backed by the embedded `ares_vector` database.
///
/// Embeddings (plus `title`/`source` metadata) are stored in the [`VectorDb`]; the
/// complete [`Document`] values are kept alongside in an in-memory map and, when a data
/// directory is configured, mirrored to `documents.json` inside that directory.
pub struct AresVectorStore {
    /// Underlying vector database; collections and vectors are created, searched and
    /// deleted through it.
    db: VectorDb,
    /// Data directory for persistent mode, or `None` for a purely in-memory store.
    path: Option<PathBuf>,
    /// Full documents keyed by collection name, then by document id.
    documents: Arc<RwLock<HashMap<String, HashMap<String, Document>>>>,
}
51
52impl AresVectorStore {
53 pub async fn new(path: Option<String>) -> Result<Self> {
63 let path_buf = path.map(PathBuf::from);
64
65 let config = if let Some(ref p) = path_buf {
67 Config::persistent(p.to_string_lossy().to_string())
68 } else {
69 Config::memory()
70 };
71
72 let db = VectorDb::open(config).await.map_err(|e| {
74 AppError::Configuration(format!("Failed to initialize AresVector: {}", e))
75 })?;
76
77 let store = Self {
79 db,
80 path: path_buf,
81 documents: Arc::new(RwLock::new(HashMap::new())),
82 };
83
84 if let Some(ref path) = store.path {
86 store.load_documents(path).await?;
87 }
88
89 Ok(store)
90 }
91
92 async fn load_documents(&self, path: &PathBuf) -> Result<()> {
94 let docs_path = path.join("documents.json");
95 if docs_path.exists() {
96 let data = tokio::fs::read_to_string(&docs_path).await.map_err(|e| {
97 AppError::Configuration(format!("Failed to read documents file: {}", e))
98 })?;
99
100 let loaded: HashMap<String, HashMap<String, Document>> =
101 serde_json::from_str(&data).map_err(|e| {
102 AppError::Configuration(format!("Failed to parse documents file: {}", e))
103 })?;
104
105 let mut docs = self.documents.write();
106 *docs = loaded;
107 }
108 Ok(())
109 }
110
111 async fn save_documents(&self) -> Result<()> {
113 if let Some(ref path) = self.path {
114 let data = {
116 let docs = self.documents.read();
117 serde_json::to_string_pretty(&*docs).map_err(|e| {
118 AppError::Internal(format!("Failed to serialize documents: {}", e))
119 })?
120 };
121
122 tokio::fs::create_dir_all(path).await.map_err(|e| {
124 AppError::Internal(format!("Failed to create data directory: {}", e))
125 })?;
126
127 let docs_path = path.join("documents.json");
128 tokio::fs::write(&docs_path, data).await.map_err(|e| {
129 AppError::Internal(format!("Failed to write documents file: {}", e))
130 })?;
131 }
132 Ok(())
133 }
134
135}
136
137#[async_trait]
138impl VectorStore for AresVectorStore {
139 fn provider_name(&self) -> &'static str {
140 "ares-vector"
141 }
142
143 async fn create_collection(&self, name: &str, dimensions: usize) -> Result<()> {
144 if self.db.list_collections().contains(&name.to_string()) {
146 return Err(AppError::Configuration(format!(
147 "Collection '{}' already exists",
148 name
149 )));
150 }
151
152 self.db.create_collection(name, dimensions, DistanceMetric::Cosine)
154 .await
155 .map_err(|e| AppError::Internal(format!("Failed to create collection: {}", e)))?;
156
157 {
159 let mut docs = self.documents.write();
160 docs.insert(name.to_string(), HashMap::new());
161 }
162
163 if self.path.is_some() {
165 self.save_documents().await?;
166 }
167
168 Ok(())
169 }
170
171 async fn delete_collection(&self, name: &str) -> Result<()> {
172 self.db.delete_collection(name)
173 .await
174 .map_err(|e| AppError::Internal(format!("Failed to delete collection: {}", e)))?;
175
176 {
178 let mut docs = self.documents.write();
179 docs.remove(name);
180 }
181
182 if self.path.is_some() {
184 self.save_documents().await?;
185 }
186
187 Ok(())
188 }
189
190 async fn list_collections(&self) -> Result<Vec<CollectionInfo>> {
191 let collections = self.db.list_collections();
192
193 let mut infos = Vec::with_capacity(collections.len());
194 for name in collections {
195 if let Ok(collection) = self.db.get_collection(&name) {
196 let stats = collection.stats();
197 infos.push(CollectionInfo {
198 name,
199 dimensions: stats.dimensions,
200 document_count: stats.vector_count,
201 });
202 }
203 }
204
205 Ok(infos)
206 }
207
208 async fn collection_exists(&self, name: &str) -> Result<bool> {
209 Ok(self.db.list_collections().contains(&name.to_string()))
210 }
211
212 async fn collection_stats(&self, name: &str) -> Result<CollectionStats> {
213 let collection = self.db.get_collection(name).map_err(|_| {
214 AppError::NotFound(format!("Collection '{}' not found", name))
215 })?;
216
217 let stats = collection.stats();
218
219 Ok(CollectionStats {
220 name: stats.name,
221 document_count: stats.vector_count,
222 dimensions: stats.dimensions,
223 index_size_bytes: Some(stats.memory_bytes as u64),
224 distance_metric: format!("{:?}", stats.metric),
225 })
226 }
227
228 async fn upsert(&self, collection: &str, documents: &[Document]) -> Result<usize> {
229 if documents.is_empty() {
230 return Ok(0);
231 }
232
233 if !self.db.list_collections().contains(&collection.to_string()) {
235 return Err(AppError::NotFound(format!(
236 "Collection '{}' not found",
237 collection
238 )));
239 }
240
241 let mut upserted = 0;
242
243 for doc in documents {
244 let embedding = doc.embedding.as_ref().ok_or_else(|| {
245 AppError::Internal(format!("Document '{}' missing embedding", doc.id))
246 })?;
247
248 let meta = VectorMetadata::from_pairs([
250 ("title", ares_vector::types::MetadataValue::String(doc.metadata.title.clone())),
251 ("source", ares_vector::types::MetadataValue::String(doc.metadata.source.clone())),
252 ]);
253
254 self.db.insert(collection, &doc.id, embedding, Some(meta))
256 .await
257 .map_err(|e| AppError::Internal(format!("Failed to insert vector: {}", e)))?;
258
259 {
261 let mut docs = self.documents.write();
262 let collection_docs = docs.entry(collection.to_string()).or_default();
263 collection_docs.insert(doc.id.clone(), doc.clone());
264 }
265
266 upserted += 1;
267 }
268
269 if self.path.is_some() {
271 self.save_documents().await?;
272 }
273
274 Ok(upserted)
275 }
276
277 async fn search(
278 &self,
279 collection: &str,
280 embedding: &[f32],
281 limit: usize,
282 threshold: f32,
283 ) -> Result<Vec<SearchResult>> {
284 let vector_results = self.db
286 .search(collection, embedding, limit * 2) .await
288 .map_err(|e| AppError::Internal(format!("Search failed: {}", e)))?;
289
290 let docs = self.documents.read();
292 let collection_docs = docs.get(collection).ok_or_else(|| {
293 AppError::NotFound(format!("Collection '{}' not found", collection))
294 })?;
295
296 let mut results = Vec::with_capacity(limit);
297 for result in vector_results {
298 let similarity = result.score;
300
301 if similarity >= threshold {
302 if let Some(doc) = collection_docs.get(&result.id) {
303 results.push(SearchResult {
304 document: doc.clone(),
305 score: similarity,
306 });
307
308 if results.len() >= limit {
309 break;
310 }
311 }
312 }
313 }
314
315 results.sort_by(|a, b| {
317 b.score
318 .partial_cmp(&a.score)
319 .unwrap_or(std::cmp::Ordering::Equal)
320 });
321
322 Ok(results)
323 }
324
325 async fn delete(&self, collection: &str, ids: &[String]) -> Result<usize> {
326 if ids.is_empty() {
327 return Ok(0);
328 }
329
330 let mut deleted = 0;
331
332 for id in ids {
333 match self.db.delete(collection, id).await {
334 Ok(true) => {
335 let mut docs = self.documents.write();
337 if let Some(collection_docs) = docs.get_mut(collection) {
338 if collection_docs.remove(id).is_some() {
339 deleted += 1;
340 }
341 }
342 }
343 _ => {}
344 }
345 }
346
347 if self.path.is_some() {
349 self.save_documents().await?;
350 }
351
352 Ok(deleted)
353 }
354
355 async fn get(&self, collection: &str, id: &str) -> Result<Option<Document>> {
356 let docs = self.documents.read();
357
358 let collection_docs = docs.get(collection).ok_or_else(|| {
359 AppError::NotFound(format!("Collection '{}' not found", collection))
360 })?;
361
362 Ok(collection_docs.get(id).cloned())
363 }
364}
365
366impl Default for AresVectorStore {
367 fn default() -> Self {
368 let config = Config::memory();
371 let db = tokio::task::block_in_place(|| {
372 tokio::runtime::Handle::current().block_on(async {
373 VectorDb::open(config).await.expect("Failed to create in-memory VectorDb")
374 })
375 });
376
377 Self {
378 db,
379 path: None,
380 documents: Arc::new(RwLock::new(HashMap::new())),
381 }
382 }
383}
384
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::DocumentMetadata;
    use chrono::Utc;

    /// Builds a test document with `source: "test"`, no tags and the given embedding.
    fn make_doc(id: &str, content: &str, title: &str, embedding: Vec<f32>) -> Document {
        Document {
            id: id.to_string(),
            content: content.to_string(),
            metadata: DocumentMetadata {
                title: title.to_string(),
                source: "test".to_string(),
                created_at: Utc::now(),
                tags: vec![],
            },
            embedding: Some(embedding),
        }
    }

    #[tokio::test]
    async fn test_create_and_search() {
        let store = AresVectorStore::new(None).await.unwrap();
        store.create_collection("test", 3).await.unwrap();

        let corpus = [
            make_doc("doc1", "Hello world", "Test 1", vec![1.0, 0.0, 0.0]),
            make_doc("doc2", "Goodbye world", "Test 2", vec![0.0, 1.0, 0.0]),
        ];
        let inserted = store.upsert("test", &corpus).await.unwrap();
        assert_eq!(inserted, 2);

        // The query leans towards doc1's axis, so doc1 must rank first.
        let query: [f32; 3] = [1.0, 0.1, 0.0];
        let hits = store.search("test", &query, 10, 0.0).await.unwrap();
        let best = hits.first().expect("search returned no results");
        assert_eq!(best.document.id, "doc1");
    }

    #[tokio::test]
    async fn test_collection_operations() {
        let store = AresVectorStore::new(None).await.unwrap();

        for &(name, dims) in &[("col1", 128), ("col2", 256)] {
            store.create_collection(name, dims).await.unwrap();
        }
        assert_eq!(store.list_collections().await.unwrap().len(), 2);

        assert!(store.collection_exists("col1").await.unwrap());
        assert!(!store.collection_exists("col3").await.unwrap());

        store.delete_collection("col1").await.unwrap();
        assert!(!store.collection_exists("col1").await.unwrap());
    }
}