1use crate::types::{AppError, Document, Result, SearchResult};
44use async_trait::async_trait;
45use serde::{Deserialize, Serialize};
46
/// Selects which vector-store backend to build and carries that backend's settings.
///
/// Serialized as an internally tagged enum: the lowercased variant name is stored under the
/// `"provider"` key, e.g. `{"provider": "qdrant", "url": "...", "api_key": null}`. The tag
/// values are `aresvector`, `lancedb`, `qdrant`, `pgvector`, `chromadb`, `pinecone` and
/// `inmemory`.
///
/// Backend-specific variants exist only when their Cargo feature is enabled; `InMemory` is
/// always available. Use [`VectorStoreProvider::create_store`] to build the store.
///
/// NOTE(review): the derived `Debug` prints `api_key` values in plain text; avoid logging
/// this type with `{:?}` if keys are sensitive.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "provider", rename_all = "lowercase")]
pub enum VectorStoreProvider {
    /// ARES vector store (feature `ares-vector`).
    #[cfg(feature = "ares-vector")]
    AresVector {
        /// Passed unchanged to `AresVectorStore::new`; the meaning of `None` is defined by
        /// that constructor.
        path: Option<String>,
    },

    /// LanceDB store (feature `lancedb`).
    #[cfg(feature = "lancedb")]
    LanceDB {
        /// Path handed to `LanceDBStore::new`.
        path: String,
    },

    /// Qdrant server (feature `qdrant`).
    #[cfg(feature = "qdrant")]
    Qdrant {
        /// Server URL.
        url: String,
        /// Optional API key forwarded to the Qdrant client.
        api_key: Option<String>,
    },

    /// PostgreSQL with pgvector (feature `pgvector`).
    #[cfg(feature = "pgvector")]
    PgVector {
        /// Connection string handed to `PgVectorStore::new`.
        connection_string: String,
    },

    /// ChromaDB server (feature `chromadb`).
    #[cfg(feature = "chromadb")]
    ChromaDB {
        /// Server URL.
        url: String,
    },

    /// Pinecone hosted index (feature `pinecone`).
    #[cfg(feature = "pinecone")]
    Pinecone {
        /// Pinecone API key.
        api_key: String,
        /// Pinecone environment (`from_env` defaults this to `"us-east-1"`).
        environment: String,
        /// Index name (`from_env` defaults this to `"ares-documents"`).
        index_name: String,
    },

    /// Process-local [`InMemoryVectorStore`]; always compiled in.
    InMemory,
}
125
126impl VectorStoreProvider {
127 pub async fn create_store(&self) -> Result<Box<dyn VectorStore>> {
134 match self {
135 #[cfg(feature = "ares-vector")]
136 VectorStoreProvider::AresVector { path } => {
137 let store = super::ares_vector::AresVectorStore::new(path.clone()).await?;
138 Ok(Box::new(store))
139 }
140
141 #[cfg(feature = "lancedb")]
142 VectorStoreProvider::LanceDB { path } => {
143 let store = super::lancedb::LanceDBStore::new(path).await?;
144 Ok(Box::new(store))
145 }
146
147 #[cfg(feature = "qdrant")]
148 VectorStoreProvider::Qdrant { url, api_key } => {
149 let store =
150 super::qdrant::QdrantVectorStore::new(url.clone(), api_key.clone()).await?;
151 Ok(Box::new(store))
152 }
153
154 #[cfg(feature = "pgvector")]
155 VectorStoreProvider::PgVector { connection_string } => {
156 let store = super::pgvector::PgVectorStore::new(connection_string).await?;
157 Ok(Box::new(store))
158 }
159
160 #[cfg(feature = "chromadb")]
161 VectorStoreProvider::ChromaDB { url } => {
162 let store = super::chromadb::ChromaDBStore::new(url).await?;
163 Ok(Box::new(store))
164 }
165
166 #[cfg(feature = "pinecone")]
167 VectorStoreProvider::Pinecone {
168 api_key,
169 environment,
170 index_name,
171 } => {
172 let store =
173 super::pinecone::PineconeStore::new(api_key, environment, index_name).await?;
174 Ok(Box::new(store))
175 }
176
177 VectorStoreProvider::InMemory => {
178 let store = InMemoryVectorStore::new();
179 Ok(Box::new(store))
180 }
181
182 #[allow(unreachable_patterns)]
183 _ => Err(AppError::Configuration(
184 "Vector store provider not enabled. Check feature flags.".into(),
185 )),
186 }
187 }
188
189 pub fn from_env() -> Self {
200 #[cfg(feature = "ares-vector")]
201 if let Ok(path) = std::env::var("ARES_VECTOR_PATH") {
202 return VectorStoreProvider::AresVector {
203 path: Some(path),
204 };
205 }
206
207 #[cfg(feature = "lancedb")]
208 if let Ok(path) = std::env::var("LANCEDB_PATH") {
209 return VectorStoreProvider::LanceDB { path };
210 }
211
212 #[cfg(feature = "qdrant")]
213 if let Ok(url) = std::env::var("QDRANT_URL") {
214 let api_key = std::env::var("QDRANT_API_KEY").ok();
215 return VectorStoreProvider::Qdrant { url, api_key };
216 }
217
218 #[cfg(feature = "pgvector")]
219 if let Ok(connection_string) = std::env::var("PGVECTOR_URL") {
220 return VectorStoreProvider::PgVector { connection_string };
221 }
222
223 #[cfg(feature = "chromadb")]
224 if let Ok(url) = std::env::var("CHROMADB_URL") {
225 return VectorStoreProvider::ChromaDB { url };
226 }
227
228 #[cfg(feature = "pinecone")]
229 if let Ok(api_key) = std::env::var("PINECONE_API_KEY") {
230 let environment =
231 std::env::var("PINECONE_ENVIRONMENT").unwrap_or_else(|_| "us-east-1".into());
232 let index_name =
233 std::env::var("PINECONE_INDEX").unwrap_or_else(|_| "ares-documents".into());
234 return VectorStoreProvider::Pinecone {
235 api_key,
236 environment,
237 index_name,
238 };
239 }
240
241 #[cfg(feature = "ares-vector")]
243 return VectorStoreProvider::AresVector { path: None };
244
245 #[cfg(not(feature = "ares-vector"))]
246 VectorStoreProvider::InMemory
247 }
248}
249
/// Detailed statistics for one collection, as returned by
/// [`VectorStore::collection_stats`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CollectionStats {
    /// Collection name.
    pub name: String,
    /// Number of documents currently stored in the collection.
    pub document_count: usize,
    /// Embedding dimensionality recorded for the collection.
    pub dimensions: usize,
    /// Index size in bytes, if the backend reports one (the in-memory store reports `None`).
    pub index_size_bytes: Option<u64>,
    /// Name of the similarity metric (the in-memory store reports `"cosine"`).
    pub distance_metric: String,
}
268
/// Lightweight per-collection summary, as returned by [`VectorStore::list_collections`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CollectionInfo {
    /// Collection name.
    pub name: String,
    /// Number of documents currently stored in the collection.
    pub document_count: usize,
    /// Embedding dimensionality recorded for the collection.
    pub dimensions: usize,
}
279
280#[async_trait]
298pub trait VectorStore: Send + Sync {
299 fn provider_name(&self) -> &'static str;
301
302 async fn create_collection(&self, name: &str, dimensions: usize) -> Result<()>;
313
314 async fn delete_collection(&self, name: &str) -> Result<()>;
324
325 async fn list_collections(&self) -> Result<Vec<CollectionInfo>>;
327
328 async fn collection_exists(&self, name: &str) -> Result<bool>;
330
331 async fn collection_stats(&self, name: &str) -> Result<CollectionStats>;
333
334 async fn upsert(&self, collection: &str, documents: &[Document]) -> Result<usize>;
349
350 async fn search(
363 &self,
364 collection: &str,
365 embedding: &[f32],
366 limit: usize,
367 threshold: f32,
368 ) -> Result<Vec<SearchResult>>;
369
370 async fn search_with_filters(
384 &self,
385 collection: &str,
386 embedding: &[f32],
387 limit: usize,
388 threshold: f32,
389 _filters: &[(String, String)],
390 ) -> Result<Vec<SearchResult>> {
391 self.search(collection, embedding, limit, threshold).await
394 }
395
396 async fn delete(&self, collection: &str, ids: &[String]) -> Result<usize>;
407
408 async fn get(&self, collection: &str, id: &str) -> Result<Option<Document>>;
419
420 async fn count(&self, collection: &str) -> Result<usize> {
422 let stats = self.collection_stats(collection).await?;
423 Ok(stats.document_count)
424 }
425}
426
427use parking_lot::RwLock;
432use std::collections::HashMap;
433use std::sync::Arc;
434
/// [`VectorStore`] implementation that keeps every collection in process memory.
///
/// Nothing is written to disk; the contents exist only as long as the store does. It is
/// built for [`VectorStoreProvider::InMemory`].
pub struct InMemoryVectorStore {
    /// Collections keyed by name. The lock provides interior mutability, so the `&self`
    /// trait methods can create, modify and delete collections.
    ///
    /// NOTE(review): the `Arc` is never cloned in the code visible here; a bare `RwLock`
    /// may suffice unless sharing is intended elsewhere.
    collections: Arc<RwLock<HashMap<String, InMemoryCollection>>>,
}
442
/// One named collection inside [`InMemoryVectorStore`].
struct InMemoryCollection {
    /// Embedding length declared when the collection was created.
    ///
    /// NOTE(review): only reported through stats/listing; upsert and search do not check
    /// embedding lengths against it.
    dimensions: usize,
    /// Stored documents keyed by `Document::id`; upserting an existing id replaces it.
    documents: HashMap<String, Document>,
}
447
448impl InMemoryVectorStore {
449 pub fn new() -> Self {
451 Self {
452 collections: Arc::new(RwLock::new(HashMap::new())),
453 }
454 }
455
456 fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
458 if a.len() != b.len() {
459 return 0.0;
460 }
461
462 let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
463 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
464 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
465
466 if norm_a == 0.0 || norm_b == 0.0 {
467 return 0.0;
468 }
469
470 dot_product / (norm_a * norm_b)
471 }
472}
473
474impl Default for InMemoryVectorStore {
475 fn default() -> Self {
476 Self::new()
477 }
478}
479
480#[async_trait]
481impl VectorStore for InMemoryVectorStore {
482 fn provider_name(&self) -> &'static str {
483 "in-memory"
484 }
485
486 async fn create_collection(&self, name: &str, dimensions: usize) -> Result<()> {
487 let mut collections = self.collections.write();
488 if collections.contains_key(name) {
489 return Err(AppError::InvalidInput(format!(
490 "Collection '{}' already exists",
491 name
492 )));
493 }
494 collections.insert(
495 name.to_string(),
496 InMemoryCollection {
497 dimensions,
498 documents: HashMap::new(),
499 },
500 );
501 Ok(())
502 }
503
504 async fn delete_collection(&self, name: &str) -> Result<()> {
505 let mut collections = self.collections.write();
506 collections
507 .remove(name)
508 .ok_or_else(|| AppError::NotFound(format!("Collection '{}' not found", name)))?;
509 Ok(())
510 }
511
512 async fn list_collections(&self) -> Result<Vec<CollectionInfo>> {
513 let collections = self.collections.read();
514 Ok(collections
515 .iter()
516 .map(|(name, col)| CollectionInfo {
517 name: name.clone(),
518 document_count: col.documents.len(),
519 dimensions: col.dimensions,
520 })
521 .collect())
522 }
523
524 async fn collection_exists(&self, name: &str) -> Result<bool> {
525 let collections = self.collections.read();
526 Ok(collections.contains_key(name))
527 }
528
529 async fn collection_stats(&self, name: &str) -> Result<CollectionStats> {
530 let collections = self.collections.read();
531 let col = collections
532 .get(name)
533 .ok_or_else(|| AppError::NotFound(format!("Collection '{}' not found", name)))?;
534
535 Ok(CollectionStats {
536 name: name.to_string(),
537 document_count: col.documents.len(),
538 dimensions: col.dimensions,
539 index_size_bytes: None,
540 distance_metric: "cosine".to_string(),
541 })
542 }
543
544 async fn upsert(&self, collection: &str, documents: &[Document]) -> Result<usize> {
545 let mut collections = self.collections.write();
546 let col = collections
547 .get_mut(collection)
548 .ok_or_else(|| AppError::NotFound(format!("Collection '{}' not found", collection)))?;
549
550 let mut count = 0;
551 for doc in documents {
552 if doc.embedding.is_none() {
553 return Err(AppError::InvalidInput(format!(
554 "Document '{}' is missing embedding",
555 doc.id
556 )));
557 }
558 col.documents.insert(doc.id.clone(), doc.clone());
559 count += 1;
560 }
561
562 Ok(count)
563 }
564
565 async fn search(
566 &self,
567 collection: &str,
568 embedding: &[f32],
569 limit: usize,
570 threshold: f32,
571 ) -> Result<Vec<SearchResult>> {
572 let collections = self.collections.read();
573 let col = collections
574 .get(collection)
575 .ok_or_else(|| AppError::NotFound(format!("Collection '{}' not found", collection)))?;
576
577 let mut results: Vec<SearchResult> = col
578 .documents
579 .values()
580 .filter_map(|doc| {
581 let doc_embedding = doc.embedding.as_ref()?;
582 let score = Self::cosine_similarity(embedding, doc_embedding);
583 if score >= threshold {
584 Some(SearchResult {
585 document: Document {
586 id: doc.id.clone(),
587 content: doc.content.clone(),
588 metadata: doc.metadata.clone(),
589 embedding: None, },
591 score,
592 })
593 } else {
594 None
595 }
596 })
597 .collect();
598
599 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
601
602 results.truncate(limit);
604
605 Ok(results)
606 }
607
608 async fn delete(&self, collection: &str, ids: &[String]) -> Result<usize> {
609 let mut collections = self.collections.write();
610 let col = collections
611 .get_mut(collection)
612 .ok_or_else(|| AppError::NotFound(format!("Collection '{}' not found", collection)))?;
613
614 let mut count = 0;
615 for id in ids {
616 if col.documents.remove(id).is_some() {
617 count += 1;
618 }
619 }
620
621 Ok(count)
622 }
623
624 async fn get(&self, collection: &str, id: &str) -> Result<Option<Document>> {
625 let collections = self.collections.read();
626 let col = collections
627 .get(collection)
628 .ok_or_else(|| AppError::NotFound(format!("Collection '{}' not found", collection)))?;
629
630 Ok(col.documents.get(id).cloned())
631 }
632}
633
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::DocumentMetadata;
    use chrono::Utc;

    /// Builds a document with fixed test metadata and the given embedding.
    fn create_test_document(id: &str, content: &str, embedding: Vec<f32>) -> Document {
        Document {
            id: id.to_string(),
            content: content.to_string(),
            metadata: DocumentMetadata {
                title: format!("Test Doc {}", id),
                source: "test".to_string(),
                created_at: Utc::now(),
                tags: vec!["test".to_string()],
            },
            embedding: Some(embedding),
        }
    }

    #[tokio::test]
    async fn test_inmemory_create_collection() {
        let store = InMemoryVectorStore::new();

        store.create_collection("test", 384).await.unwrap();

        assert!(store.collection_exists("test").await.unwrap());
        assert_eq!(store.provider_name(), "in-memory");
    }

    #[tokio::test]
    async fn test_inmemory_duplicate_collection_error() {
        let store = InMemoryVectorStore::new();

        store.create_collection("test", 384).await.unwrap();
        let result = store.create_collection("test", 384).await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_inmemory_upsert_and_search() {
        let store = InMemoryVectorStore::new();
        store.create_collection("test", 3).await.unwrap();

        let doc1 = create_test_document("doc1", "Hello world", vec![1.0, 0.0, 0.0]);
        let doc2 = create_test_document("doc2", "Goodbye world", vec![0.0, 1.0, 0.0]);
        let doc3 = create_test_document("doc3", "Hello again", vec![0.9, 0.1, 0.0]);

        store.upsert("test", &[doc1, doc2, doc3]).await.unwrap();

        // Query along the x axis: doc1 scores 1.0, doc3 about 0.994, and doc2 scores 0.0,
        // which is below the threshold.
        let results = store
            .search("test", &[1.0, 0.0, 0.0], 10, 0.5)
            .await
            .unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].document.id, "doc1");
        assert_eq!(results[1].document.id, "doc3");
    }

    #[tokio::test]
    async fn test_inmemory_search_limit_strips_embeddings() {
        let store = InMemoryVectorStore::new();
        store.create_collection("test", 3).await.unwrap();

        let doc1 = create_test_document("doc1", "Hello world", vec![1.0, 0.0, 0.0]);
        let doc3 = create_test_document("doc3", "Hello again", vec![0.9, 0.1, 0.0]);
        store.upsert("test", &[doc1, doc3]).await.unwrap();

        let results = store
            .search("test", &[1.0, 0.0, 0.0], 1, 0.5)
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].document.id, "doc1");
        assert!(results[0].document.embedding.is_none());

        // `get` returns the stored document, embedding included.
        let stored = store.get("test", "doc1").await.unwrap().unwrap();
        assert_eq!(stored.embedding, Some(vec![1.0, 0.0, 0.0]));
    }

    #[tokio::test]
    async fn test_inmemory_upsert_rejects_missing_embedding() {
        let store = InMemoryVectorStore::new();
        store.create_collection("test", 3).await.unwrap();

        let mut doc = create_test_document("doc1", "No vector", vec![1.0, 0.0, 0.0]);
        doc.embedding = None;

        assert!(store.upsert("test", &[doc]).await.is_err());
        assert_eq!(store.count("test").await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_inmemory_missing_collection_errors() {
        let store = InMemoryVectorStore::new();
        let doc = create_test_document("doc1", "Test", vec![1.0, 0.0, 0.0]);

        assert!(!store.collection_exists("missing").await.unwrap());
        assert!(store.delete_collection("missing").await.is_err());
        assert!(store.upsert("missing", &[doc]).await.is_err());
        assert!(store.search("missing", &[1.0], 1, 0.0).await.is_err());
        assert!(store.get("missing", "doc1").await.is_err());
        assert!(store.collection_stats("missing").await.is_err());
    }

    #[tokio::test]
    async fn test_inmemory_collection_stats() {
        let store = InMemoryVectorStore::new();
        store.create_collection("test", 3).await.unwrap();

        let doc = create_test_document("doc1", "Test", vec![1.0, 0.0, 0.0]);
        store.upsert("test", &[doc]).await.unwrap();

        let stats = store.collection_stats("test").await.unwrap();
        assert_eq!(stats.name, "test");
        assert_eq!(stats.document_count, 1);
        assert_eq!(stats.dimensions, 3);
        assert_eq!(stats.distance_metric, "cosine");
        assert!(stats.index_size_bytes.is_none());
    }

    #[tokio::test]
    async fn test_inmemory_delete() {
        let store = InMemoryVectorStore::new();
        store.create_collection("test", 3).await.unwrap();

        let doc = create_test_document("doc1", "Test", vec![1.0, 0.0, 0.0]);
        store.upsert("test", &[doc]).await.unwrap();

        assert_eq!(store.count("test").await.unwrap(), 1);

        let deleted = store
            .delete("test", &["doc1".to_string()])
            .await
            .unwrap();
        assert_eq!(deleted, 1);

        assert_eq!(store.count("test").await.unwrap(), 0);

        // Deleting ids that are not present removes nothing.
        let deleted = store
            .delete("test", &["doc1".to_string(), "nonexistent".to_string()])
            .await
            .unwrap();
        assert_eq!(deleted, 0);
    }

    #[tokio::test]
    async fn test_inmemory_get() {
        let store = InMemoryVectorStore::new();
        store.create_collection("test", 3).await.unwrap();

        let doc = create_test_document("doc1", "Test content", vec![1.0, 0.0, 0.0]);
        store.upsert("test", &[doc]).await.unwrap();

        let retrieved = store.get("test", "doc1").await.unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().content, "Test content");

        let not_found = store.get("test", "nonexistent").await.unwrap();
        assert!(not_found.is_none());
    }

    #[tokio::test]
    async fn test_inmemory_list_collections() {
        let store = InMemoryVectorStore::new();

        store.create_collection("col1", 384).await.unwrap();
        store.create_collection("col2", 768).await.unwrap();

        let collections = store.list_collections().await.unwrap();
        assert_eq!(collections.len(), 2);
    }

    // Pure synchronous math: no async runtime needed.
    #[test]
    fn test_cosine_similarity() {
        // Identical direction.
        assert!((InMemoryVectorStore::cosine_similarity(&[1.0, 0.0], &[1.0, 0.0]) - 1.0).abs() < 0.001);
        // Orthogonal.
        assert!(InMemoryVectorStore::cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]).abs() < 0.001);
        // Opposite direction.
        assert!((InMemoryVectorStore::cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]) + 1.0).abs() < 0.001);
        // Length mismatch and zero vectors fall back to 0.0.
        assert_eq!(InMemoryVectorStore::cosine_similarity(&[1.0, 0.0], &[1.0]), 0.0);
        assert_eq!(InMemoryVectorStore::cosine_similarity(&[0.0, 0.0], &[1.0, 0.0]), 0.0);
    }
}