1use crate::types::{AppError, Document, Result, SearchResult};
23use async_trait::async_trait;
24use parking_lot::RwLock;
25use std::collections::HashMap;
26use std::path::{Path, PathBuf};
27use std::sync::Arc;
28
29use super::vectorstore::{CollectionInfo, CollectionStats, VectorStore};
30use ares_vector::{Config, DistanceMetric, VectorDb, VectorMetadata};
31
32pub struct AresVectorStore {
44 db: VectorDb,
46 path: Option<PathBuf>,
48 documents: Arc<RwLock<HashMap<String, HashMap<String, Document>>>>,
50}
51
52impl AresVectorStore {
53 pub async fn new(path: Option<String>) -> Result<Self> {
63 let path_buf = path.map(PathBuf::from);
64
65 let config = if let Some(ref p) = path_buf {
67 Config::persistent(p.to_string_lossy().to_string())
68 } else {
69 Config::memory()
70 };
71
72 let db = VectorDb::open(config).await.map_err(|e| {
74 AppError::Configuration(format!("Failed to initialize AresVector: {}", e))
75 })?;
76
77 let store = Self {
79 db,
80 path: path_buf,
81 documents: Arc::new(RwLock::new(HashMap::new())),
82 };
83
84 if let Some(ref path) = store.path {
86 store.load_documents(path).await?;
87 }
88
89 Ok(store)
90 }
91
92 async fn load_documents(&self, path: &Path) -> Result<()> {
94 let docs_path = path.join("documents.json");
95 if docs_path.exists() {
96 let data = tokio::fs::read_to_string(&docs_path).await.map_err(|e| {
97 AppError::Configuration(format!("Failed to read documents file: {}", e))
98 })?;
99
100 let loaded: HashMap<String, HashMap<String, Document>> = serde_json::from_str(&data)
101 .map_err(|e| {
102 AppError::Configuration(format!("Failed to parse documents file: {}", e))
103 })?;
104
105 let mut docs = self.documents.write();
106 *docs = loaded;
107 }
108 Ok(())
109 }
110
111 async fn save_documents(&self) -> Result<()> {
113 if let Some(ref path) = self.path {
114 let data = {
116 let docs = self.documents.read();
117 serde_json::to_string_pretty(&*docs).map_err(|e| {
118 AppError::Internal(format!("Failed to serialize documents: {}", e))
119 })?
120 };
121
122 tokio::fs::create_dir_all(path).await.map_err(|e| {
124 AppError::Internal(format!("Failed to create data directory: {}", e))
125 })?;
126
127 let docs_path = path.join("documents.json");
128 tokio::fs::write(&docs_path, data).await.map_err(|e| {
129 AppError::Internal(format!("Failed to write documents file: {}", e))
130 })?;
131 }
132 Ok(())
133 }
134}
135
136#[async_trait]
137impl VectorStore for AresVectorStore {
138 fn provider_name(&self) -> &'static str {
139 "ares-vector"
140 }
141
142 async fn create_collection(&self, name: &str, dimensions: usize) -> Result<()> {
143 if self.db.list_collections().contains(&name.to_string()) {
145 return Err(AppError::Configuration(format!(
146 "Collection '{}' already exists",
147 name
148 )));
149 }
150
151 self.db
153 .create_collection(name, dimensions, DistanceMetric::Cosine)
154 .await
155 .map_err(|e| AppError::Internal(format!("Failed to create collection: {}", e)))?;
156
157 {
159 let mut docs = self.documents.write();
160 docs.insert(name.to_string(), HashMap::new());
161 }
162
163 if self.path.is_some() {
165 self.save_documents().await?;
166 }
167
168 Ok(())
169 }
170
171 async fn delete_collection(&self, name: &str) -> Result<()> {
172 self.db
173 .delete_collection(name)
174 .await
175 .map_err(|e| AppError::Internal(format!("Failed to delete collection: {}", e)))?;
176
177 {
179 let mut docs = self.documents.write();
180 docs.remove(name);
181 }
182
183 if self.path.is_some() {
185 self.save_documents().await?;
186 }
187
188 Ok(())
189 }
190
191 async fn list_collections(&self) -> Result<Vec<CollectionInfo>> {
192 let collections = self.db.list_collections();
193
194 let mut infos = Vec::with_capacity(collections.len());
195 for name in collections {
196 if let Ok(collection) = self.db.get_collection(&name) {
197 let stats = collection.stats();
198 infos.push(CollectionInfo {
199 name,
200 dimensions: stats.dimensions,
201 document_count: stats.vector_count,
202 });
203 }
204 }
205
206 Ok(infos)
207 }
208
209 async fn collection_exists(&self, name: &str) -> Result<bool> {
210 Ok(self.db.list_collections().contains(&name.to_string()))
211 }
212
213 async fn collection_stats(&self, name: &str) -> Result<CollectionStats> {
214 let collection = self
215 .db
216 .get_collection(name)
217 .map_err(|_| AppError::NotFound(format!("Collection '{}' not found", name)))?;
218
219 let stats = collection.stats();
220
221 Ok(CollectionStats {
222 name: stats.name,
223 document_count: stats.vector_count,
224 dimensions: stats.dimensions,
225 index_size_bytes: Some(stats.memory_bytes as u64),
226 distance_metric: format!("{:?}", stats.metric),
227 })
228 }
229
230 async fn upsert(&self, collection: &str, documents: &[Document]) -> Result<usize> {
231 if documents.is_empty() {
232 return Ok(0);
233 }
234
235 if !self.db.list_collections().contains(&collection.to_string()) {
237 return Err(AppError::NotFound(format!(
238 "Collection '{}' not found",
239 collection
240 )));
241 }
242
243 let mut upserted = 0;
244
245 for doc in documents {
246 let embedding = doc.embedding.as_ref().ok_or_else(|| {
247 AppError::Internal(format!("Document '{}' missing embedding", doc.id))
248 })?;
249
250 let meta = VectorMetadata::from_pairs([
252 (
253 "title",
254 ares_vector::types::MetadataValue::String(doc.metadata.title.clone()),
255 ),
256 (
257 "source",
258 ares_vector::types::MetadataValue::String(doc.metadata.source.clone()),
259 ),
260 ]);
261
262 self.db
264 .insert(collection, &doc.id, embedding, Some(meta))
265 .await
266 .map_err(|e| AppError::Internal(format!("Failed to insert vector: {}", e)))?;
267
268 {
270 let mut docs = self.documents.write();
271 let collection_docs = docs.entry(collection.to_string()).or_default();
272 collection_docs.insert(doc.id.clone(), doc.clone());
273 }
274
275 upserted += 1;
276 }
277
278 if self.path.is_some() {
280 self.save_documents().await?;
281 }
282
283 Ok(upserted)
284 }
285
286 async fn search(
287 &self,
288 collection: &str,
289 embedding: &[f32],
290 limit: usize,
291 threshold: f32,
292 ) -> Result<Vec<SearchResult>> {
293 let vector_results = self
295 .db
296 .search(collection, embedding, limit * 2) .await
298 .map_err(|e| AppError::Internal(format!("Search failed: {}", e)))?;
299
300 let docs = self.documents.read();
302 let collection_docs = docs
303 .get(collection)
304 .ok_or_else(|| AppError::NotFound(format!("Collection '{}' not found", collection)))?;
305
306 let mut results = Vec::with_capacity(limit);
307 for result in vector_results {
308 let similarity = result.score;
310
311 if similarity >= threshold {
312 if let Some(doc) = collection_docs.get(&result.id) {
313 results.push(SearchResult {
314 document: doc.clone(),
315 score: similarity,
316 });
317
318 if results.len() >= limit {
319 break;
320 }
321 }
322 }
323 }
324
325 results.sort_by(|a, b| {
327 b.score
328 .partial_cmp(&a.score)
329 .unwrap_or(std::cmp::Ordering::Equal)
330 });
331
332 Ok(results)
333 }
334
335 async fn delete(&self, collection: &str, ids: &[String]) -> Result<usize> {
336 if ids.is_empty() {
337 return Ok(0);
338 }
339
340 let mut deleted = 0;
341
342 for id in ids {
343 if let Ok(true) = self.db.delete(collection, id).await {
344 let mut docs = self.documents.write();
346 if let Some(collection_docs) = docs.get_mut(collection) {
347 if collection_docs.remove(id).is_some() {
348 deleted += 1;
349 }
350 }
351 }
352 }
353
354 if self.path.is_some() {
356 self.save_documents().await?;
357 }
358
359 Ok(deleted)
360 }
361
362 async fn get(&self, collection: &str, id: &str) -> Result<Option<Document>> {
363 let docs = self.documents.read();
364
365 let collection_docs = docs
366 .get(collection)
367 .ok_or_else(|| AppError::NotFound(format!("Collection '{}' not found", collection)))?;
368
369 Ok(collection_docs.get(id).cloned())
370 }
371}
372
373impl Default for AresVectorStore {
374 fn default() -> Self {
375 let config = Config::memory();
378 let db = tokio::task::block_in_place(|| {
379 tokio::runtime::Handle::current().block_on(async {
380 VectorDb::open(config)
381 .await
382 .expect("Failed to create in-memory VectorDb")
383 })
384 });
385
386 Self {
387 db,
388 path: None,
389 documents: Arc::new(RwLock::new(HashMap::new())),
390 }
391 }
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::DocumentMetadata;
    use chrono::Utc;

    /// Builds a document with source `"test"`, no tags, and the given embedding.
    fn sample_document(id: &str, content: &str, title: &str, embedding: Vec<f32>) -> Document {
        Document {
            id: id.to_string(),
            content: content.to_string(),
            metadata: DocumentMetadata {
                title: title.to_string(),
                source: "test".to_string(),
                created_at: Utc::now(),
                tags: vec![],
            },
            embedding: Some(embedding),
        }
    }

    /// Upserts two orthogonal vectors and checks the closer one ranks first.
    #[tokio::test]
    async fn test_create_and_search() {
        let store = AresVectorStore::new(None).await.unwrap();
        store.create_collection("test", 3).await.unwrap();

        let docs = vec![
            sample_document("doc1", "Hello world", "Test 1", vec![1.0, 0.0, 0.0]),
            sample_document("doc2", "Goodbye world", "Test 2", vec![0.0, 1.0, 0.0]),
        ];

        let count = store.upsert("test", &docs).await.unwrap();
        assert_eq!(count, 2);

        // The query points almost entirely along doc1's axis.
        let query = vec![1.0, 0.1, 0.0];
        let results = store.search("test", &query, 10, 0.0).await.unwrap();

        assert!(!results.is_empty());
        assert_eq!(results[0].document.id, "doc1");
    }

    /// Exercises create, list, exists, and delete on collections.
    #[tokio::test]
    async fn test_collection_operations() {
        let store = AresVectorStore::new(None).await.unwrap();

        store.create_collection("col1", 128).await.unwrap();
        store.create_collection("col2", 256).await.unwrap();

        let listed = store.list_collections().await.unwrap();
        assert_eq!(listed.len(), 2);

        assert!(store.collection_exists("col1").await.unwrap());
        assert!(!store.collection_exists("col3").await.unwrap());

        store.delete_collection("col1").await.unwrap();
        assert!(!store.collection_exists("col1").await.unwrap());
    }
}