use crate::types::{AppError, Document, Result, SearchResult};
use async_trait::async_trait;
use parking_lot::RwLock;
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use super::vectorstore::{CollectionInfo, CollectionStats, VectorStore};
use ares_vector::{Config, DistanceMetric, VectorDb, VectorMetadata};
/// Documents grouped by collection name, then keyed by document id.
type DocumentIndex = HashMap<String, HashMap<String, Document>>;

/// `VectorStore` implementation backed by an `ares_vector::VectorDb`.
///
/// Vectors are stored in `db`. The complete `Document` records handed to
/// `upsert` are kept alongside them in `documents`, so search hits can be
/// turned back into full documents.
pub struct AresVectorStore {
    /// Underlying vector database.
    db: VectorDb,
    /// Data directory for persistent stores; `None` for in-memory stores.
    path: Option<PathBuf>,
    /// Full documents: collection name -> (document id -> document).
    documents: Arc<RwLock<DocumentIndex>>,
}
impl AresVectorStore {
pub async fn new(path: Option<String>) -> Result<Self> {
let path_buf = path.map(PathBuf::from);
let config = if let Some(ref p) = path_buf {
Config::persistent(p.to_string_lossy().to_string())
} else {
Config::memory()
};
let db = VectorDb::open(config).await.map_err(|e| {
AppError::Configuration(format!("Failed to initialize AresVector: {}", e))
})?;
let store = Self {
db,
path: path_buf,
documents: Arc::new(RwLock::new(HashMap::new())),
};
if let Some(ref path) = store.path {
store.load_documents(path).await?;
}
Ok(store)
}
async fn load_documents(&self, path: &Path) -> Result<()> {
let docs_path = path.join("documents.json");
if docs_path.exists() {
let data = tokio::fs::read_to_string(&docs_path).await.map_err(|e| {
AppError::Configuration(format!("Failed to read documents file: {}", e))
})?;
let loaded: HashMap<String, HashMap<String, Document>> = serde_json::from_str(&data)
.map_err(|e| {
AppError::Configuration(format!("Failed to parse documents file: {}", e))
})?;
let mut docs = self.documents.write();
*docs = loaded;
}
Ok(())
}
async fn save_documents(&self) -> Result<()> {
if let Some(ref path) = self.path {
let data = {
let docs = self.documents.read();
serde_json::to_string_pretty(&*docs).map_err(|e| {
AppError::Internal(format!("Failed to serialize documents: {}", e))
})?
};
tokio::fs::create_dir_all(path).await.map_err(|e| {
AppError::Internal(format!("Failed to create data directory: {}", e))
})?;
let docs_path = path.join("documents.json");
tokio::fs::write(&docs_path, data).await.map_err(|e| {
AppError::Internal(format!("Failed to write documents file: {}", e))
})?;
}
Ok(())
}
}
#[async_trait]
impl VectorStore for AresVectorStore {
    fn provider_name(&self) -> &'static str {
        "ares-vector"
    }

    /// Creates an empty collection that uses cosine distance.
    ///
    /// # Errors
    /// - `AppError::Configuration` if a collection with this name already exists.
    /// - `AppError::Internal` if the database rejects the creation or the
    ///   document map cannot be persisted.
    async fn create_collection(&self, name: &str, dimensions: usize) -> Result<()> {
        if self.db.list_collections().iter().any(|c| c == name) {
            return Err(AppError::Configuration(format!(
                "Collection '{}' already exists",
                name
            )));
        }
        self.db
            .create_collection(name, dimensions, DistanceMetric::Cosine)
            .await
            .map_err(|e| AppError::Internal(format!("Failed to create collection: {}", e)))?;
        {
            // Start from an empty map. This also drops any stale documents
            // persisted under the same name.
            let mut docs = self.documents.write();
            docs.insert(name.to_string(), HashMap::new());
        }
        // No-op for in-memory stores.
        self.save_documents().await?;
        Ok(())
    }

    /// Deletes a collection from the database, together with its stored documents.
    ///
    /// # Errors
    /// `AppError::Internal` if the database deletion fails or the document map
    /// cannot be persisted.
    async fn delete_collection(&self, name: &str) -> Result<()> {
        self.db
            .delete_collection(name)
            .await
            .map_err(|e| AppError::Internal(format!("Failed to delete collection: {}", e)))?;
        {
            let mut docs = self.documents.write();
            docs.remove(name);
        }
        self.save_documents().await?;
        Ok(())
    }

    /// Lists every collection the database reports, with its dimensions and
    /// vector count. Any collection that `get_collection` fails to open is
    /// skipped.
    async fn list_collections(&self) -> Result<Vec<CollectionInfo>> {
        let infos = self
            .db
            .list_collections()
            .into_iter()
            .filter_map(|name| {
                let stats = self.db.get_collection(&name).ok()?.stats();
                Some(CollectionInfo {
                    name,
                    dimensions: stats.dimensions,
                    document_count: stats.vector_count,
                })
            })
            .collect();
        Ok(infos)
    }

    /// Returns whether the database has a collection with this name.
    async fn collection_exists(&self, name: &str) -> Result<bool> {
        Ok(self.db.list_collections().iter().any(|c| c == name))
    }

    /// Returns statistics for a single collection.
    ///
    /// # Errors
    /// `AppError::NotFound` if the collection does not exist.
    async fn collection_stats(&self, name: &str) -> Result<CollectionStats> {
        let collection = self
            .db
            .get_collection(name)
            .map_err(|_| AppError::NotFound(format!("Collection '{}' not found", name)))?;
        let stats = collection.stats();
        Ok(CollectionStats {
            name: stats.name,
            document_count: stats.vector_count,
            dimensions: stats.dimensions,
            index_size_bytes: Some(stats.memory_bytes as u64),
            distance_metric: format!("{:?}", stats.metric),
        })
    }

    /// Inserts each document's embedding into the database and stores the full
    /// document. Returns the number of documents applied.
    ///
    /// Every document is checked for an embedding before anything is written,
    /// so a bad document cannot leave the batch half-applied. If a database
    /// insert fails partway through, the documents already applied are still
    /// persisted before the error is returned.
    ///
    /// # Errors
    /// - `AppError::NotFound` if the collection does not exist.
    /// - `AppError::Internal` if a document has no embedding, a database insert
    ///   fails, or the document map cannot be persisted.
    async fn upsert(&self, collection: &str, documents: &[Document]) -> Result<usize> {
        if documents.is_empty() {
            return Ok(0);
        }
        if !self.db.list_collections().iter().any(|c| c == collection) {
            return Err(AppError::NotFound(format!(
                "Collection '{}' not found",
                collection
            )));
        }
        let embeddings = documents
            .iter()
            .map(|doc| {
                doc.embedding.as_ref().ok_or_else(|| {
                    AppError::Internal(format!("Document '{}' missing embedding", doc.id))
                })
            })
            .collect::<Result<Vec<_>>>()?;

        let mut upserted = 0;
        let mut failure = None;
        for (doc, embedding) in documents.iter().zip(embeddings) {
            let meta = VectorMetadata::from_pairs([
                (
                    "title",
                    ares_vector::types::MetadataValue::String(doc.metadata.title.clone()),
                ),
                (
                    "source",
                    ares_vector::types::MetadataValue::String(doc.metadata.source.clone()),
                ),
            ]);
            if let Err(e) = self
                .db
                .insert(collection, &doc.id, embedding, Some(meta))
                .await
            {
                failure = Some(AppError::Internal(format!(
                    "Failed to insert vector: {}",
                    e
                )));
                break;
            }
            {
                let mut docs = self.documents.write();
                let collection_docs = docs.entry(collection.to_string()).or_default();
                collection_docs.insert(doc.id.clone(), doc.clone());
            }
            upserted += 1;
        }
        // Persist whatever was applied, even after a failure, so the on-disk
        // document map stays consistent with the database.
        self.save_documents().await?;
        match failure {
            Some(err) => Err(err),
            None => Ok(upserted),
        }
    }

    /// Returns at most `limit` documents whose score is at least `threshold`,
    /// best match first.
    ///
    /// The database is asked for `2 * limit` candidates (saturating). Hits below
    /// the threshold, and hits with no stored document, are dropped before the
    /// result is cut down to `limit`. `limit == 0` yields an empty result.
    ///
    /// # Errors
    /// - `AppError::Internal` if the database search fails.
    /// - `AppError::NotFound` if no documents are tracked for the collection.
    async fn search(
        &self,
        collection: &str,
        embedding: &[f32],
        limit: usize,
        threshold: f32,
    ) -> Result<Vec<SearchResult>> {
        let candidates = limit.saturating_mul(2);
        let vector_results = self
            .db
            .search(collection, embedding, candidates)
            .await
            .map_err(|e| AppError::Internal(format!("Search failed: {}", e)))?;
        let docs = self.documents.read();
        let collection_docs = docs
            .get(collection)
            .ok_or_else(|| AppError::NotFound(format!("Collection '{}' not found", collection)))?;
        let mut hits: Vec<(&Document, f32)> = vector_results
            .into_iter()
            .filter(|hit| hit.score >= threshold)
            .filter_map(|hit| collection_docs.get(&hit.id).map(|doc| (doc, hit.score)))
            .collect();
        // Sort before truncating, so the result is the true top-`limit` even
        // if the database returns hits out of order.
        hits.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        hits.truncate(limit);
        Ok(hits
            .into_iter()
            .map(|(doc, score)| SearchResult {
                document: doc.clone(),
                score,
            })
            .collect())
    }

    /// Deletes the given ids from a collection and returns how many documents
    /// were removed.
    ///
    /// This is best-effort. An id is counted only when the database reports it
    /// deleted and its stored document is removed. Database errors for
    /// individual ids are ignored rather than aborting the batch.
    ///
    /// # Errors
    /// `AppError::Internal` if the document map cannot be persisted.
    async fn delete(&self, collection: &str, ids: &[String]) -> Result<usize> {
        let mut deleted = 0;
        for id in ids {
            if let Ok(true) = self.db.delete(collection, id).await {
                let mut docs = self.documents.write();
                if let Some(collection_docs) = docs.get_mut(collection) {
                    if collection_docs.remove(id).is_some() {
                        deleted += 1;
                    }
                }
            }
        }
        // Skip rewriting documents.json when nothing changed.
        if deleted > 0 {
            self.save_documents().await?;
        }
        Ok(deleted)
    }

    /// Looks up a stored document by id.
    ///
    /// # Errors
    /// `AppError::NotFound` if no documents are tracked for the collection.
    async fn get(&self, collection: &str, id: &str) -> Result<Option<Document>> {
        let docs = self.documents.read();
        let collection_docs = docs
            .get(collection)
            .ok_or_else(|| AppError::NotFound(format!("Collection '{}' not found", collection)))?;
        Ok(collection_docs.get(id).cloned())
    }
}
impl Default for AresVectorStore {
    /// Creates an in-memory store with no persistence path.
    ///
    /// `default()` is synchronous but `VectorDb::open` is async, so the future
    /// must be driven to completion here:
    /// - Inside a multi-threaded Tokio runtime, `block_in_place` blocks the
    ///   current worker, as before.
    /// - Anywhere else, blocking the current thread would panic. That covers
    ///   having no runtime, or a current-thread runtime such as the default
    ///   `#[tokio::test]` flavor. In those cases the open runs on a short-lived
    ///   helper thread with its own runtime.
    ///
    /// # Panics
    /// Panics if the in-memory `VectorDb` cannot be opened.
    fn default() -> Self {
        use tokio::runtime::{Builder, Handle, RuntimeFlavor};

        let db = match Handle::try_current() {
            Ok(handle) if handle.runtime_flavor() == RuntimeFlavor::MultiThread => {
                tokio::task::block_in_place(|| {
                    handle.block_on(async {
                        VectorDb::open(Config::memory())
                            .await
                            .expect("Failed to create in-memory VectorDb")
                    })
                })
            }
            _ => std::thread::spawn(|| {
                // NOTE(review): assumes an in-memory VectorDb does not need the
                // runtime that opened it to stay alive afterwards. Confirm
                // against ares_vector.
                Builder::new_current_thread()
                    .enable_all()
                    .build()
                    .expect("Failed to build Tokio runtime for VectorDb initialization")
                    .block_on(async {
                        VectorDb::open(Config::memory())
                            .await
                            .expect("Failed to create in-memory VectorDb")
                    })
            })
            .join()
            .unwrap_or_else(|panic| std::panic::resume_unwind(panic)),
        };
        Self {
            db,
            path: None,
            documents: Arc::new(RwLock::new(HashMap::new())),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::DocumentMetadata;
    use chrono::Utc;

    /// Builds a test document with source "test", no tags, and the given embedding.
    fn make_doc(id: &str, content: &str, title: &str, embedding: Vec<f32>) -> Document {
        Document {
            id: id.to_string(),
            content: content.to_string(),
            metadata: DocumentMetadata {
                title: title.to_string(),
                source: "test".to_string(),
                created_at: Utc::now(),
                tags: vec![],
            },
            embedding: Some(embedding),
        }
    }

    #[tokio::test]
    async fn test_create_and_search() {
        let store = AresVectorStore::new(None).await.unwrap();
        store.create_collection("test", 3).await.unwrap();

        let docs = vec![
            make_doc("doc1", "Hello world", "Test 1", vec![1.0, 0.0, 0.0]),
            make_doc("doc2", "Goodbye world", "Test 2", vec![0.0, 1.0, 0.0]),
        ];
        let inserted = store.upsert("test", &docs).await.unwrap();
        assert_eq!(inserted, 2);

        // The query points almost exactly along doc1's embedding.
        let query: Vec<f32> = vec![1.0, 0.1, 0.0];
        let hits = store.search("test", &query, 10, 0.0).await.unwrap();
        assert!(!hits.is_empty());
        assert_eq!(hits[0].document.id, "doc1");
    }

    #[tokio::test]
    async fn test_collection_operations() {
        let store = AresVectorStore::new(None).await.unwrap();
        for (name, dims) in [("col1", 128), ("col2", 256)] {
            store.create_collection(name, dims).await.unwrap();
        }
        assert_eq!(store.list_collections().await.unwrap().len(), 2);

        assert!(store.collection_exists("col1").await.unwrap());
        assert!(!store.collection_exists("col3").await.unwrap());

        store.delete_collection("col1").await.unwrap();
        assert!(!store.collection_exists("col1").await.unwrap());
    }
}