use crate::types::{AppError, Document, Result, SearchResult};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
/// Selects a vector-store backend and carries its connection settings.
///
/// Serialized as an internally tagged object. The `provider` field holds the
/// lowercased variant name (`"aresvector"`, `"lancedb"`, `"qdrant"`, `"pgvector"`,
/// `"chromadb"`, `"pinecone"`, `"inmemory"`), and the remaining fields are the
/// variant's settings. Every backend except `InMemory` is compiled in only when
/// its Cargo feature is enabled.
///
/// `Debug` is implemented by hand so that API keys and connection strings are
/// redacted when the value is logged; `Serialize` still emits them verbatim.
#[derive(Clone, Serialize, Deserialize)]
#[serde(tag = "provider", rename_all = "lowercase")]
pub enum VectorStoreProvider {
    /// Ares vector backend; `path` is optional.
    #[cfg(feature = "ares-vector")]
    AresVector {
        path: Option<String>,
    },
    /// LanceDB backend located at `path`.
    #[cfg(feature = "lancedb")]
    LanceDB {
        path: String,
    },
    /// Qdrant server at `url`, with an optional API key.
    #[cfg(feature = "qdrant")]
    Qdrant {
        url: String,
        api_key: Option<String>,
    },
    /// PostgreSQL + pgvector, reached through `connection_string`.
    #[cfg(feature = "pgvector")]
    PgVector {
        connection_string: String,
    },
    /// ChromaDB server at `url`.
    #[cfg(feature = "chromadb")]
    ChromaDB {
        url: String,
    },
    /// Pinecone index `index_name` in `environment`.
    #[cfg(feature = "pinecone")]
    Pinecone {
        api_key: String,
        environment: String,
        index_name: String,
    },
    /// Process-local store; always available.
    InMemory,
}

impl std::fmt::Debug for VectorStoreProvider {
    /// Formats like a derived `Debug`, except that secret-bearing fields
    /// (`api_key`, `connection_string`) are replaced by `"<redacted>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            #[cfg(feature = "ares-vector")]
            VectorStoreProvider::AresVector { path } => {
                f.debug_struct("AresVector").field("path", path).finish()
            }
            #[cfg(feature = "lancedb")]
            VectorStoreProvider::LanceDB { path } => {
                f.debug_struct("LanceDB").field("path", path).finish()
            }
            #[cfg(feature = "qdrant")]
            VectorStoreProvider::Qdrant { url, api_key } => f
                .debug_struct("Qdrant")
                .field("url", url)
                // Keep the Some/None distinction visible without the key itself.
                .field("api_key", &api_key.as_ref().map(|_| "<redacted>"))
                .finish(),
            #[cfg(feature = "pgvector")]
            VectorStoreProvider::PgVector { .. } => f
                .debug_struct("PgVector")
                // Connection strings typically embed credentials; hide the whole value.
                .field("connection_string", &"<redacted>")
                .finish(),
            #[cfg(feature = "chromadb")]
            VectorStoreProvider::ChromaDB { url } => {
                f.debug_struct("ChromaDB").field("url", url).finish()
            }
            #[cfg(feature = "pinecone")]
            VectorStoreProvider::Pinecone {
                environment,
                index_name,
                ..
            } => f
                .debug_struct("Pinecone")
                .field("api_key", &"<redacted>")
                .field("environment", environment)
                .field("index_name", index_name)
                .finish(),
            VectorStoreProvider::InMemory => f.write_str("InMemory"),
        }
    }
}
impl VectorStoreProvider {
    /// Instantiates the backend described by `self`.
    ///
    /// The match below is exhaustive over the variants compiled in, because each
    /// arm is gated by the same `#[cfg]` as its variant. No catch-all arm is
    /// needed, and adding a variant without an arm becomes a compile error.
    ///
    /// # Errors
    /// Propagates any error returned by the selected backend's constructor.
    pub async fn create_store(&self) -> Result<Box<dyn VectorStore>> {
        match self {
            #[cfg(feature = "ares-vector")]
            VectorStoreProvider::AresVector { path } => {
                let store = super::ares_vector::AresVectorStore::new(path.clone()).await?;
                Ok(Box::new(store))
            }
            #[cfg(feature = "lancedb")]
            VectorStoreProvider::LanceDB { path } => {
                let store = super::lancedb::LanceDBStore::new(path).await?;
                Ok(Box::new(store))
            }
            #[cfg(feature = "qdrant")]
            VectorStoreProvider::Qdrant { url, api_key } => {
                let store =
                    super::qdrant::QdrantVectorStore::new(url.clone(), api_key.clone()).await?;
                Ok(Box::new(store))
            }
            #[cfg(feature = "pgvector")]
            VectorStoreProvider::PgVector { connection_string } => {
                let store = super::pgvector::PgVectorStore::new(connection_string).await?;
                Ok(Box::new(store))
            }
            #[cfg(feature = "chromadb")]
            VectorStoreProvider::ChromaDB { url } => {
                let store = super::chromadb::ChromaDBStore::new(url).await?;
                Ok(Box::new(store))
            }
            #[cfg(feature = "pinecone")]
            VectorStoreProvider::Pinecone {
                api_key,
                environment,
                index_name,
            } => {
                let store =
                    super::pinecone::PineconeStore::new(api_key, environment, index_name).await?;
                Ok(Box::new(store))
            }
            VectorStoreProvider::InMemory => {
                let store = InMemoryVectorStore::new();
                Ok(Box::new(store))
            }
        }
    }

    /// Reads environment variable `key`, treating unset, empty, and
    /// whitespace-only values alike as "not configured".
    // Only referenced from feature-gated branches of `from_env`, so it is
    // unused when no external backend feature is enabled.
    #[allow(dead_code)]
    fn env_non_empty(key: &str) -> Option<String> {
        std::env::var(key).ok().filter(|value| !value.trim().is_empty())
    }

    /// Picks a provider from environment variables.
    ///
    /// Providers are probed in this order, restricted to enabled features:
    /// `ARES_VECTOR_PATH`, `LANCEDB_PATH`, `QDRANT_URL` (+ optional
    /// `QDRANT_API_KEY`), `PGVECTOR_URL`, `CHROMADB_URL`, then `PINECONE_API_KEY`
    /// (+ `PINECONE_ENVIRONMENT`, default `"us-east-1"`, and `PINECONE_INDEX`,
    /// default `"ares-documents"`). A variable that is set but empty counts as
    /// unset. If nothing matches, the result is `AresVector { path: None }` when
    /// the `ares-vector` feature is enabled, otherwise `InMemory`.
    pub fn from_env() -> Self {
        #[cfg(feature = "ares-vector")]
        if let Some(path) = Self::env_non_empty("ARES_VECTOR_PATH") {
            return VectorStoreProvider::AresVector { path: Some(path) };
        }
        #[cfg(feature = "lancedb")]
        if let Some(path) = Self::env_non_empty("LANCEDB_PATH") {
            return VectorStoreProvider::LanceDB { path };
        }
        #[cfg(feature = "qdrant")]
        if let Some(url) = Self::env_non_empty("QDRANT_URL") {
            let api_key = Self::env_non_empty("QDRANT_API_KEY");
            return VectorStoreProvider::Qdrant { url, api_key };
        }
        #[cfg(feature = "pgvector")]
        if let Some(connection_string) = Self::env_non_empty("PGVECTOR_URL") {
            return VectorStoreProvider::PgVector { connection_string };
        }
        #[cfg(feature = "chromadb")]
        if let Some(url) = Self::env_non_empty("CHROMADB_URL") {
            return VectorStoreProvider::ChromaDB { url };
        }
        #[cfg(feature = "pinecone")]
        if let Some(api_key) = Self::env_non_empty("PINECONE_API_KEY") {
            let environment = Self::env_non_empty("PINECONE_ENVIRONMENT")
                .unwrap_or_else(|| "us-east-1".into());
            let index_name = Self::env_non_empty("PINECONE_INDEX")
                .unwrap_or_else(|| "ares-documents".into());
            return VectorStoreProvider::Pinecone {
                api_key,
                environment,
                index_name,
            };
        }
        #[cfg(feature = "ares-vector")]
        return VectorStoreProvider::AresVector { path: None };
        #[cfg(not(feature = "ares-vector"))]
        VectorStoreProvider::InMemory
    }
}
/// Detailed statistics for a single collection.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CollectionStats {
    /// Collection name.
    pub name: String,
    /// Number of documents currently stored.
    pub document_count: usize,
    /// Embedding dimensionality declared for the collection.
    pub dimensions: usize,
    /// Index size in bytes, or `None` when the backend does not report it
    /// (the in-memory store always reports `None`).
    pub index_size_bytes: Option<u64>,
    /// Name of the similarity metric (the in-memory store reports `"cosine"`).
    pub distance_metric: String,
}
/// Summary entry for a collection, as returned by `VectorStore::list_collections`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CollectionInfo {
    /// Collection name.
    pub name: String,
    /// Number of documents currently stored.
    pub document_count: usize,
    /// Embedding dimensionality declared for the collection.
    pub dimensions: usize,
}
/// Backend-agnostic interface for storing and searching embedded documents,
/// organised into named collections.
#[async_trait]
pub trait VectorStore: Send + Sync {
    /// Static identifier of the backend implementation.
    fn provider_name(&self) -> &'static str;

    /// Creates collection `name` whose embeddings have `dimensions` components.
    async fn create_collection(&self, name: &str, dimensions: usize) -> Result<()>;

    /// Removes collection `name` and all of its documents.
    async fn delete_collection(&self, name: &str) -> Result<()>;

    /// Lists every collection with a short summary.
    async fn list_collections(&self) -> Result<Vec<CollectionInfo>>;

    /// Reports whether collection `name` exists.
    async fn collection_exists(&self, name: &str) -> Result<bool>;

    /// Returns detailed statistics for collection `name`.
    async fn collection_stats(&self, name: &str) -> Result<CollectionStats>;

    /// Inserts or replaces `documents` in `collection`; returns how many were written.
    async fn upsert(&self, collection: &str, documents: &[Document]) -> Result<usize>;

    /// Returns up to `limit` documents whose similarity to `embedding` is at
    /// least `threshold`.
    async fn search(
        &self,
        collection: &str,
        embedding: &[f32],
        limit: usize,
        threshold: f32,
    ) -> Result<Vec<SearchResult>>;

    /// Like [`VectorStore::search`], restricted by key/value filters.
    ///
    /// NOTE: the default implementation IGNORES `_filters` and simply delegates
    /// to `search`; backends that support metadata filtering must override it.
    async fn search_with_filters(
        &self,
        collection: &str,
        embedding: &[f32],
        limit: usize,
        threshold: f32,
        _filters: &[(String, String)],
    ) -> Result<Vec<SearchResult>> {
        self.search(collection, embedding, limit, threshold).await
    }

    /// Deletes documents by id; returns how many were actually removed.
    async fn delete(&self, collection: &str, ids: &[String]) -> Result<usize>;

    /// Fetches one document by id, or `None` if it is absent.
    async fn get(&self, collection: &str, id: &str) -> Result<Option<Document>>;

    /// Number of documents in `collection`.
    ///
    /// The default derives this from [`VectorStore::collection_stats`].
    async fn count(&self, collection: &str) -> Result<usize> {
        self.collection_stats(collection)
            .await
            .map(|stats| stats.document_count)
    }
}
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
/// Vector store that keeps every collection in process memory.
///
/// Contents are not persisted anywhere by this type.
pub struct InMemoryVectorStore {
    /// Collections keyed by name, behind a reader/writer lock.
    collections: Arc<RwLock<HashMap<String, InMemoryCollection>>>,
}
/// Storage for a single in-memory collection.
struct InMemoryCollection {
    /// Dimensionality declared when the collection was created.
    dimensions: usize,
    /// Documents keyed by their id.
    documents: HashMap<String, Document>,
}
impl InMemoryVectorStore {
pub fn new() -> Self {
Self {
collections: Arc::new(RwLock::new(HashMap::new())),
}
}
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
if a.len() != b.len() {
return 0.0;
}
let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm_a == 0.0 || norm_b == 0.0 {
return 0.0;
}
dot_product / (norm_a * norm_b)
}
}
impl Default for InMemoryVectorStore {
fn default() -> Self {
Self::new()
}
}
/// [`VectorStore`] backed by process-local hash maps.
///
/// Search is a brute-force cosine-similarity scan over every document in the
/// target collection. No method awaits while holding the lock.
#[async_trait]
impl VectorStore for InMemoryVectorStore {
    fn provider_name(&self) -> &'static str {
        "in-memory"
    }

    /// Creates an empty collection.
    ///
    /// # Errors
    /// `AppError::InvalidInput` if a collection named `name` already exists.
    async fn create_collection(&self, name: &str, dimensions: usize) -> Result<()> {
        let mut collections = self.collections.write();
        if collections.contains_key(name) {
            return Err(AppError::InvalidInput(format!(
                "Collection '{}' already exists",
                name
            )));
        }
        collections.insert(
            name.to_string(),
            InMemoryCollection {
                dimensions,
                documents: HashMap::new(),
            },
        );
        Ok(())
    }

    /// Drops a collection and all of its documents.
    ///
    /// # Errors
    /// `AppError::NotFound` if the collection does not exist.
    async fn delete_collection(&self, name: &str) -> Result<()> {
        let mut collections = self.collections.write();
        collections
            .remove(name)
            .ok_or_else(|| AppError::NotFound(format!("Collection '{}' not found", name)))?;
        Ok(())
    }

    async fn list_collections(&self) -> Result<Vec<CollectionInfo>> {
        let collections = self.collections.read();
        Ok(collections
            .iter()
            .map(|(name, col)| CollectionInfo {
                name: name.clone(),
                document_count: col.documents.len(),
                dimensions: col.dimensions,
            })
            .collect())
    }

    async fn collection_exists(&self, name: &str) -> Result<bool> {
        let collections = self.collections.read();
        Ok(collections.contains_key(name))
    }

    /// # Errors
    /// `AppError::NotFound` if the collection does not exist.
    async fn collection_stats(&self, name: &str) -> Result<CollectionStats> {
        let collections = self.collections.read();
        let col = collections
            .get(name)
            .ok_or_else(|| AppError::NotFound(format!("Collection '{}' not found", name)))?;
        Ok(CollectionStats {
            name: name.to_string(),
            document_count: col.documents.len(),
            dimensions: col.dimensions,
            index_size_bytes: None,
            distance_metric: "cosine".to_string(),
        })
    }

    /// Inserts or replaces documents, keyed by `Document::id`.
    ///
    /// The whole batch is validated before anything is written, so a rejected
    /// batch leaves the collection unchanged. Returns the number of documents
    /// written (one per input document, duplicates included).
    ///
    /// # Errors
    /// - `AppError::NotFound` if the collection does not exist.
    /// - `AppError::InvalidInput` if any document has no embedding, or its
    ///   embedding length differs from the collection's declared `dimensions`.
    async fn upsert(&self, collection: &str, documents: &[Document]) -> Result<usize> {
        let mut collections = self.collections.write();
        let col = collections
            .get_mut(collection)
            .ok_or_else(|| AppError::NotFound(format!("Collection '{}' not found", collection)))?;

        // Validate first: inserting while validating would leave a partially
        // applied batch behind when a later document is rejected.
        for doc in documents {
            let embedding = doc.embedding.as_ref().ok_or_else(|| {
                AppError::InvalidInput(format!("Document '{}' is missing embedding", doc.id))
            })?;
            // A mismatched vector would otherwise be stored and then silently
            // score 0.0 in every search (see `cosine_similarity`).
            if embedding.len() != col.dimensions {
                return Err(AppError::InvalidInput(format!(
                    "Document '{}' has embedding dimension {}, but collection '{}' expects {}",
                    doc.id,
                    embedding.len(),
                    collection,
                    col.dimensions
                )));
            }
        }

        for doc in documents {
            col.documents.insert(doc.id.clone(), doc.clone());
        }
        Ok(documents.len())
    }

    /// Scores every stored document by cosine similarity to `embedding` and
    /// keeps those scoring at least `threshold`. Returns at most `limit` of
    /// them in descending score order. Returned documents have `embedding`
    /// set to `None`.
    ///
    /// # Errors
    /// `AppError::NotFound` if the collection does not exist.
    async fn search(
        &self,
        collection: &str,
        embedding: &[f32],
        limit: usize,
        threshold: f32,
    ) -> Result<Vec<SearchResult>> {
        let collections = self.collections.read();
        let col = collections
            .get(collection)
            .ok_or_else(|| AppError::NotFound(format!("Collection '{}' not found", collection)))?;
        let mut results: Vec<SearchResult> = col
            .documents
            .values()
            .filter_map(|doc| {
                let doc_embedding = doc.embedding.as_ref()?;
                let score = Self::cosine_similarity(embedding, doc_embedding);
                if score >= threshold {
                    Some(SearchResult {
                        document: Document {
                            id: doc.id.clone(),
                            content: doc.content.clone(),
                            metadata: doc.metadata.clone(),
                            // Strip the vector to keep result payloads small.
                            embedding: None,
                        },
                        score,
                    })
                } else {
                    None
                }
            })
            .collect();
        // Highest score first; incomparable (NaN) scores are treated as equal.
        results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        results.truncate(limit);
        Ok(results)
    }

    /// Removes the given ids; ids that are not present are skipped.
    /// Returns how many documents were actually removed.
    ///
    /// # Errors
    /// `AppError::NotFound` if the collection does not exist.
    async fn delete(&self, collection: &str, ids: &[String]) -> Result<usize> {
        let mut collections = self.collections.write();
        let col = collections
            .get_mut(collection)
            .ok_or_else(|| AppError::NotFound(format!("Collection '{}' not found", collection)))?;
        let mut count = 0;
        for id in ids {
            if col.documents.remove(id).is_some() {
                count += 1;
            }
        }
        Ok(count)
    }

    /// Returns a clone of the stored document (embedding included), if present.
    ///
    /// # Errors
    /// `AppError::NotFound` if the collection does not exist.
    async fn get(&self, collection: &str, id: &str) -> Result<Option<Document>> {
        let collections = self.collections.read();
        let col = collections
            .get(collection)
            .ok_or_else(|| AppError::NotFound(format!("Collection '{}' not found", collection)))?;
        Ok(col.documents.get(id).cloned())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::DocumentMetadata;
    use chrono::Utc;

    /// Builds a document with fixed test metadata and the given embedding.
    fn create_test_document(id: &str, content: &str, embedding: Vec<f32>) -> Document {
        Document {
            id: id.to_string(),
            content: content.to_string(),
            metadata: DocumentMetadata {
                title: format!("Test Doc {}", id),
                source: "test".to_string(),
                created_at: Utc::now(),
                tags: vec!["test".to_string()],
            },
            embedding: Some(embedding),
        }
    }

    #[tokio::test]
    async fn test_inmemory_create_collection() {
        let store = InMemoryVectorStore::new();
        store.create_collection("test", 384).await.unwrap();
        assert!(store.collection_exists("test").await.unwrap());
    }

    #[tokio::test]
    async fn test_inmemory_duplicate_collection_error() {
        let store = InMemoryVectorStore::new();
        store.create_collection("test", 384).await.unwrap();
        let result = store.create_collection("test", 384).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_inmemory_delete_collection() {
        let store = InMemoryVectorStore::new();
        store.create_collection("test", 3).await.unwrap();
        store.delete_collection("test").await.unwrap();
        assert!(!store.collection_exists("test").await.unwrap());
        // Deleting again must report the collection as missing.
        assert!(store.delete_collection("test").await.is_err());
    }

    #[tokio::test]
    async fn test_inmemory_missing_collection_errors() {
        let store = InMemoryVectorStore::new();
        let doc = create_test_document("doc1", "Test", vec![1.0, 0.0, 0.0]);
        assert!(store.upsert("missing", &[doc]).await.is_err());
        assert!(store
            .search("missing", &[1.0, 0.0, 0.0], 10, 0.0)
            .await
            .is_err());
        assert!(store.get("missing", "doc1").await.is_err());
        assert!(store
            .delete("missing", &["doc1".to_string()])
            .await
            .is_err());
        assert!(store.collection_stats("missing").await.is_err());
    }

    #[tokio::test]
    async fn test_inmemory_upsert_and_search() {
        let store = InMemoryVectorStore::new();
        store.create_collection("test", 3).await.unwrap();
        let doc1 = create_test_document("doc1", "Hello world", vec![1.0, 0.0, 0.0]);
        let doc2 = create_test_document("doc2", "Goodbye world", vec![0.0, 1.0, 0.0]);
        let doc3 = create_test_document("doc3", "Hello again", vec![0.9, 0.1, 0.0]);
        store.upsert("test", &[doc1, doc2, doc3]).await.unwrap();
        let results = store
            .search("test", &[1.0, 0.0, 0.0], 10, 0.5)
            .await
            .unwrap();
        // doc2 is orthogonal to the query (score 0.0) and falls below the threshold.
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].document.id, "doc1");
        assert_eq!(results[1].document.id, "doc3");
    }

    #[tokio::test]
    async fn test_inmemory_search_limit_and_stripped_embeddings() {
        let store = InMemoryVectorStore::new();
        store.create_collection("test", 3).await.unwrap();
        let docs = [
            create_test_document("doc1", "Hello world", vec![1.0, 0.0, 0.0]),
            create_test_document("doc3", "Hello again", vec![0.9, 0.1, 0.0]),
        ];
        store.upsert("test", &docs).await.unwrap();
        let results = store
            .search("test", &[1.0, 0.0, 0.0], 1, 0.0)
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].document.id, "doc1");
        assert!(results[0].document.embedding.is_none());
    }

    #[tokio::test]
    async fn test_inmemory_upsert_replaces_existing_id() {
        let store = InMemoryVectorStore::new();
        store.create_collection("test", 3).await.unwrap();
        let first = create_test_document("doc1", "First", vec![1.0, 0.0, 0.0]);
        let second = create_test_document("doc1", "Second", vec![0.0, 1.0, 0.0]);
        store.upsert("test", &[first]).await.unwrap();
        store.upsert("test", &[second]).await.unwrap();
        assert_eq!(store.count("test").await.unwrap(), 1);
        let stored = store.get("test", "doc1").await.unwrap().unwrap();
        assert_eq!(stored.content, "Second");
    }

    #[tokio::test]
    async fn test_inmemory_upsert_rejects_missing_embedding() {
        let store = InMemoryVectorStore::new();
        store.create_collection("test", 3).await.unwrap();
        let doc = Document {
            embedding: None,
            ..create_test_document("doc1", "Test", vec![])
        };
        assert!(store.upsert("test", &[doc]).await.is_err());
        assert_eq!(store.count("test").await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_inmemory_delete() {
        let store = InMemoryVectorStore::new();
        store.create_collection("test", 3).await.unwrap();
        let doc = create_test_document("doc1", "Test", vec![1.0, 0.0, 0.0]);
        store.upsert("test", &[doc]).await.unwrap();
        assert_eq!(store.count("test").await.unwrap(), 1);
        let deleted = store.delete("test", &["doc1".to_string()]).await.unwrap();
        assert_eq!(deleted, 1);
        assert_eq!(store.count("test").await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_inmemory_get() {
        let store = InMemoryVectorStore::new();
        store.create_collection("test", 3).await.unwrap();
        let doc = create_test_document("doc1", "Test content", vec![1.0, 0.0, 0.0]);
        store.upsert("test", &[doc]).await.unwrap();
        let retrieved = store.get("test", "doc1").await.unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().content, "Test content");
        let not_found = store.get("test", "nonexistent").await.unwrap();
        assert!(not_found.is_none());
    }

    #[tokio::test]
    async fn test_inmemory_list_collections() {
        let store = InMemoryVectorStore::new();
        store.create_collection("col1", 384).await.unwrap();
        store.create_collection("col2", 768).await.unwrap();
        let collections = store.list_collections().await.unwrap();
        assert_eq!(collections.len(), 2);
    }

    // Pure synchronous function: no async runtime needed.
    #[test]
    fn test_cosine_similarity() {
        assert!(
            (InMemoryVectorStore::cosine_similarity(&[1.0, 0.0], &[1.0, 0.0]) - 1.0).abs() < 0.001
        );
        assert!(InMemoryVectorStore::cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]).abs() < 0.001);
        assert!(
            (InMemoryVectorStore::cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]) + 1.0).abs() < 0.001
        );
        // Mismatched lengths and zero-magnitude vectors score 0.0 rather than erroring.
        assert_eq!(
            InMemoryVectorStore::cosine_similarity(&[1.0, 0.0], &[1.0]),
            0.0
        );
        assert_eq!(
            InMemoryVectorStore::cosine_similarity(&[0.0, 0.0], &[1.0, 0.0]),
            0.0
        );
    }
}