1use crate::types::{AppError, Document, Result, SearchResult};
44use async_trait::async_trait;
45use serde::{Deserialize, Serialize};
46
/// Configuration selecting which vector-store backend to instantiate.
///
/// (De)serialized as an internally tagged object: the `"provider"` field holds the
/// lowercased variant name (`"aresvector"`, `"lancedb"`, `"qdrant"`, `"pgvector"`,
/// `"chromadb"`, `"pinecone"`, `"inmemory"`) and the remaining fields are the
/// variant's fields, e.g. `{"provider": "lancedb", "path": "./data"}`.
///
/// Every variant except `InMemory` only exists when its Cargo feature is enabled,
/// so a config naming a backend whose feature is disabled fails to deserialize.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "provider", rename_all = "lowercase")]
pub enum VectorStoreProvider {
    /// Ares' own vector store (`ares-vector` feature).
    #[cfg(feature = "ares-vector")]
    AresVector {
        /// Passed to `AresVectorStore::new`; `from_env` uses `None` when
        /// `ARES_VECTOR_PATH` is unset.
        /// NOTE(review): what `None` means (default location vs. non-persistent)
        /// is decided inside `ares_vector` — confirm there.
        path: Option<String>,
    },

    /// LanceDB backend (`lancedb` feature).
    #[cfg(feature = "lancedb")]
    LanceDB {
        /// Passed to `LanceDBStore::new`; presumably a filesystem location — TODO confirm.
        path: String,
    },

    /// Qdrant backend (`qdrant` feature).
    #[cfg(feature = "qdrant")]
    Qdrant {
        /// Qdrant server URL.
        url: String,
        /// Optional API key forwarded to the Qdrant client.
        api_key: Option<String>,
    },

    /// PostgreSQL + pgvector backend (`pgvector` feature).
    #[cfg(feature = "pgvector")]
    PgVector {
        /// Database connection string (read from `PGVECTOR_URL` by `from_env`).
        connection_string: String,
    },

    /// ChromaDB backend (`chromadb` feature).
    #[cfg(feature = "chromadb")]
    ChromaDB {
        /// ChromaDB server URL.
        url: String,
    },

    /// Pinecone backend (`pinecone` feature).
    #[cfg(feature = "pinecone")]
    Pinecone {
        /// Pinecone API key.
        api_key: String,
        /// Pinecone environment/region (`from_env` defaults it to `"us-east-1"`).
        environment: String,
        /// Index to operate on (`from_env` defaults it to `"ares-documents"`).
        index_name: String,
    },

    /// Process-local store backed by [`InMemoryVectorStore`]; always available,
    /// nothing is persisted.
    InMemory,
}
125
126impl VectorStoreProvider {
127 pub async fn create_store(&self) -> Result<Box<dyn VectorStore>> {
134 match self {
135 #[cfg(feature = "ares-vector")]
136 VectorStoreProvider::AresVector { path } => {
137 let store = super::ares_vector::AresVectorStore::new(path.clone()).await?;
138 Ok(Box::new(store))
139 }
140
141 #[cfg(feature = "lancedb")]
142 VectorStoreProvider::LanceDB { path } => {
143 let store = super::lancedb::LanceDBStore::new(path).await?;
144 Ok(Box::new(store))
145 }
146
147 #[cfg(feature = "qdrant")]
148 VectorStoreProvider::Qdrant { url, api_key } => {
149 let store =
150 super::qdrant::QdrantVectorStore::new(url.clone(), api_key.clone()).await?;
151 Ok(Box::new(store))
152 }
153
154 #[cfg(feature = "pgvector")]
155 VectorStoreProvider::PgVector { connection_string } => {
156 let store = super::pgvector::PgVectorStore::new(connection_string).await?;
157 Ok(Box::new(store))
158 }
159
160 #[cfg(feature = "chromadb")]
161 VectorStoreProvider::ChromaDB { url } => {
162 let store = super::chromadb::ChromaDBStore::new(url).await?;
163 Ok(Box::new(store))
164 }
165
166 #[cfg(feature = "pinecone")]
167 VectorStoreProvider::Pinecone {
168 api_key,
169 environment,
170 index_name,
171 } => {
172 let store =
173 super::pinecone::PineconeStore::new(api_key, environment, index_name).await?;
174 Ok(Box::new(store))
175 }
176
177 VectorStoreProvider::InMemory => {
178 let store = InMemoryVectorStore::new();
179 Ok(Box::new(store))
180 }
181
182 #[allow(unreachable_patterns)]
183 _ => Err(AppError::Configuration(
184 "Vector store provider not enabled. Check feature flags.".into(),
185 )),
186 }
187 }
188
189 pub fn from_env() -> Self {
200 #[cfg(feature = "ares-vector")]
201 if let Ok(path) = std::env::var("ARES_VECTOR_PATH") {
202 return VectorStoreProvider::AresVector { path: Some(path) };
203 }
204
205 #[cfg(feature = "lancedb")]
206 if let Ok(path) = std::env::var("LANCEDB_PATH") {
207 return VectorStoreProvider::LanceDB { path };
208 }
209
210 #[cfg(feature = "qdrant")]
211 if let Ok(url) = std::env::var("QDRANT_URL") {
212 let api_key = std::env::var("QDRANT_API_KEY").ok();
213 return VectorStoreProvider::Qdrant { url, api_key };
214 }
215
216 #[cfg(feature = "pgvector")]
217 if let Ok(connection_string) = std::env::var("PGVECTOR_URL") {
218 return VectorStoreProvider::PgVector { connection_string };
219 }
220
221 #[cfg(feature = "chromadb")]
222 if let Ok(url) = std::env::var("CHROMADB_URL") {
223 return VectorStoreProvider::ChromaDB { url };
224 }
225
226 #[cfg(feature = "pinecone")]
227 if let Ok(api_key) = std::env::var("PINECONE_API_KEY") {
228 let environment =
229 std::env::var("PINECONE_ENVIRONMENT").unwrap_or_else(|_| "us-east-1".into());
230 let index_name =
231 std::env::var("PINECONE_INDEX").unwrap_or_else(|_| "ares-documents".into());
232 return VectorStoreProvider::Pinecone {
233 api_key,
234 environment,
235 index_name,
236 };
237 }
238
239 #[cfg(feature = "ares-vector")]
241 return VectorStoreProvider::AresVector { path: None };
242
243 #[cfg(not(feature = "ares-vector"))]
244 VectorStoreProvider::InMemory
245 }
246}
247
/// Detailed statistics for one collection, as returned by
/// [`VectorStore::collection_stats`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CollectionStats {
    /// Collection name.
    pub name: String,
    /// Number of documents currently stored in the collection.
    pub document_count: usize,
    /// Embedding dimensionality the collection was created with.
    pub dimensions: usize,
    /// Size of the backing index in bytes, if the backend can report it
    /// (the in-memory store reports `None`).
    pub index_size_bytes: Option<u64>,
    /// Name of the similarity metric used for search
    /// (the in-memory store reports `"cosine"`).
    pub distance_metric: String,
}
266
/// Lightweight per-collection summary, as returned by
/// [`VectorStore::list_collections`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CollectionInfo {
    /// Collection name.
    pub name: String,
    /// Number of documents currently stored in the collection.
    pub document_count: usize,
    /// Embedding dimensionality the collection was created with.
    pub dimensions: usize,
}
277
278#[async_trait]
296pub trait VectorStore: Send + Sync {
297 fn provider_name(&self) -> &'static str;
299
300 async fn create_collection(&self, name: &str, dimensions: usize) -> Result<()>;
311
312 async fn delete_collection(&self, name: &str) -> Result<()>;
322
323 async fn list_collections(&self) -> Result<Vec<CollectionInfo>>;
325
326 async fn collection_exists(&self, name: &str) -> Result<bool>;
328
329 async fn collection_stats(&self, name: &str) -> Result<CollectionStats>;
331
332 async fn upsert(&self, collection: &str, documents: &[Document]) -> Result<usize>;
347
348 async fn search(
361 &self,
362 collection: &str,
363 embedding: &[f32],
364 limit: usize,
365 threshold: f32,
366 ) -> Result<Vec<SearchResult>>;
367
368 async fn search_with_filters(
382 &self,
383 collection: &str,
384 embedding: &[f32],
385 limit: usize,
386 threshold: f32,
387 _filters: &[(String, String)],
388 ) -> Result<Vec<SearchResult>> {
389 self.search(collection, embedding, limit, threshold).await
392 }
393
394 async fn delete(&self, collection: &str, ids: &[String]) -> Result<usize>;
405
406 async fn get(&self, collection: &str, id: &str) -> Result<Option<Document>>;
417
418 async fn count(&self, collection: &str) -> Result<usize> {
420 let stats = self.collection_stats(collection).await?;
421 Ok(stats.document_count)
422 }
423}
424
425use parking_lot::RwLock;
430use std::collections::HashMap;
431use std::sync::Arc;
432
/// Process-local [`VectorStore`] that keeps all collections in a `HashMap`.
///
/// Nothing is persisted: the data lives only as long as the store. Search scores every
/// document of a collection with cosine similarity (no index).
pub struct InMemoryVectorStore {
    /// Collections keyed by name. The reader/writer lock lets the `&self` trait
    /// methods both read and mutate the map.
    collections: Arc<RwLock<HashMap<String, InMemoryCollection>>>,
}
440
/// One named collection inside [`InMemoryVectorStore`].
struct InMemoryCollection {
    /// Embedding length declared when the collection was created.
    dimensions: usize,
    /// Stored documents keyed by their id.
    documents: HashMap<String, Document>,
}
445
446impl InMemoryVectorStore {
447 pub fn new() -> Self {
449 Self {
450 collections: Arc::new(RwLock::new(HashMap::new())),
451 }
452 }
453
454 fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
456 if a.len() != b.len() {
457 return 0.0;
458 }
459
460 let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
461 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
462 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
463
464 if norm_a == 0.0 || norm_b == 0.0 {
465 return 0.0;
466 }
467
468 dot_product / (norm_a * norm_b)
469 }
470}
471
472impl Default for InMemoryVectorStore {
473 fn default() -> Self {
474 Self::new()
475 }
476}
477
478#[async_trait]
479impl VectorStore for InMemoryVectorStore {
480 fn provider_name(&self) -> &'static str {
481 "in-memory"
482 }
483
484 async fn create_collection(&self, name: &str, dimensions: usize) -> Result<()> {
485 let mut collections = self.collections.write();
486 if collections.contains_key(name) {
487 return Err(AppError::InvalidInput(format!(
488 "Collection '{}' already exists",
489 name
490 )));
491 }
492 collections.insert(
493 name.to_string(),
494 InMemoryCollection {
495 dimensions,
496 documents: HashMap::new(),
497 },
498 );
499 Ok(())
500 }
501
502 async fn delete_collection(&self, name: &str) -> Result<()> {
503 let mut collections = self.collections.write();
504 collections
505 .remove(name)
506 .ok_or_else(|| AppError::NotFound(format!("Collection '{}' not found", name)))?;
507 Ok(())
508 }
509
510 async fn list_collections(&self) -> Result<Vec<CollectionInfo>> {
511 let collections = self.collections.read();
512 Ok(collections
513 .iter()
514 .map(|(name, col)| CollectionInfo {
515 name: name.clone(),
516 document_count: col.documents.len(),
517 dimensions: col.dimensions,
518 })
519 .collect())
520 }
521
522 async fn collection_exists(&self, name: &str) -> Result<bool> {
523 let collections = self.collections.read();
524 Ok(collections.contains_key(name))
525 }
526
527 async fn collection_stats(&self, name: &str) -> Result<CollectionStats> {
528 let collections = self.collections.read();
529 let col = collections
530 .get(name)
531 .ok_or_else(|| AppError::NotFound(format!("Collection '{}' not found", name)))?;
532
533 Ok(CollectionStats {
534 name: name.to_string(),
535 document_count: col.documents.len(),
536 dimensions: col.dimensions,
537 index_size_bytes: None,
538 distance_metric: "cosine".to_string(),
539 })
540 }
541
542 async fn upsert(&self, collection: &str, documents: &[Document]) -> Result<usize> {
543 let mut collections = self.collections.write();
544 let col = collections
545 .get_mut(collection)
546 .ok_or_else(|| AppError::NotFound(format!("Collection '{}' not found", collection)))?;
547
548 let mut count = 0;
549 for doc in documents {
550 if doc.embedding.is_none() {
551 return Err(AppError::InvalidInput(format!(
552 "Document '{}' is missing embedding",
553 doc.id
554 )));
555 }
556 col.documents.insert(doc.id.clone(), doc.clone());
557 count += 1;
558 }
559
560 Ok(count)
561 }
562
563 async fn search(
564 &self,
565 collection: &str,
566 embedding: &[f32],
567 limit: usize,
568 threshold: f32,
569 ) -> Result<Vec<SearchResult>> {
570 let collections = self.collections.read();
571 let col = collections
572 .get(collection)
573 .ok_or_else(|| AppError::NotFound(format!("Collection '{}' not found", collection)))?;
574
575 let mut results: Vec<SearchResult> = col
576 .documents
577 .values()
578 .filter_map(|doc| {
579 let doc_embedding = doc.embedding.as_ref()?;
580 let score = Self::cosine_similarity(embedding, doc_embedding);
581 if score >= threshold {
582 Some(SearchResult {
583 document: Document {
584 id: doc.id.clone(),
585 content: doc.content.clone(),
586 metadata: doc.metadata.clone(),
587 embedding: None, },
589 score,
590 })
591 } else {
592 None
593 }
594 })
595 .collect();
596
597 results.sort_by(|a, b| {
599 b.score
600 .partial_cmp(&a.score)
601 .unwrap_or(std::cmp::Ordering::Equal)
602 });
603
604 results.truncate(limit);
606
607 Ok(results)
608 }
609
610 async fn delete(&self, collection: &str, ids: &[String]) -> Result<usize> {
611 let mut collections = self.collections.write();
612 let col = collections
613 .get_mut(collection)
614 .ok_or_else(|| AppError::NotFound(format!("Collection '{}' not found", collection)))?;
615
616 let mut count = 0;
617 for id in ids {
618 if col.documents.remove(id).is_some() {
619 count += 1;
620 }
621 }
622
623 Ok(count)
624 }
625
626 async fn get(&self, collection: &str, id: &str) -> Result<Option<Document>> {
627 let collections = self.collections.read();
628 let col = collections
629 .get(collection)
630 .ok_or_else(|| AppError::NotFound(format!("Collection '{}' not found", collection)))?;
631
632 Ok(col.documents.get(id).cloned())
633 }
634}
635
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::DocumentMetadata;
    use chrono::Utc;

    /// Builds a document with fixed test metadata and the given embedding.
    fn create_test_document(id: &str, content: &str, embedding: Vec<f32>) -> Document {
        Document {
            id: id.to_string(),
            content: content.to_string(),
            metadata: DocumentMetadata {
                title: format!("Test Doc {}", id),
                source: "test".to_string(),
                created_at: Utc::now(),
                tags: vec!["test".to_string()],
            },
            embedding: Some(embedding),
        }
    }

    #[tokio::test]
    async fn test_inmemory_create_collection() {
        let store = InMemoryVectorStore::new();

        store.create_collection("test", 384).await.unwrap();

        assert!(store.collection_exists("test").await.unwrap());
    }

    #[tokio::test]
    async fn test_inmemory_duplicate_collection_error() {
        let store = InMemoryVectorStore::new();

        store.create_collection("test", 384).await.unwrap();
        let result = store.create_collection("test", 384).await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_inmemory_upsert_and_search() {
        let store = InMemoryVectorStore::new();
        store.create_collection("test", 3).await.unwrap();

        let doc1 = create_test_document("doc1", "Hello world", vec![1.0, 0.0, 0.0]);
        let doc2 = create_test_document("doc2", "Goodbye world", vec![0.0, 1.0, 0.0]);
        let doc3 = create_test_document("doc3", "Hello again", vec![0.9, 0.1, 0.0]);

        store.upsert("test", &[doc1, doc2, doc3]).await.unwrap();

        let results = store
            .search("test", &[1.0, 0.0, 0.0], 10, 0.5)
            .await
            .unwrap();

        // doc2 is orthogonal to the query (score 0.0) and falls below the 0.5 threshold.
        assert_eq!(results.len(), 2);
        // Highest score first.
        assert_eq!(results[0].document.id, "doc1");
        assert_eq!(results[1].document.id, "doc3");
        // Search hits never carry the stored embedding.
        assert!(results.iter().all(|hit| hit.document.embedding.is_none()));
    }

    #[tokio::test]
    async fn test_inmemory_search_respects_limit() {
        let store = InMemoryVectorStore::new();
        store.create_collection("test", 3).await.unwrap();

        let docs = [
            create_test_document("doc1", "a", vec![1.0, 0.0, 0.0]),
            create_test_document("doc2", "b", vec![0.0, 1.0, 0.0]),
            create_test_document("doc3", "c", vec![0.9, 0.1, 0.0]),
        ];
        store.upsert("test", &docs).await.unwrap();

        let results = store
            .search("test", &[1.0, 0.0, 0.0], 1, 0.0)
            .await
            .unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].document.id, "doc1");
    }

    #[tokio::test]
    async fn test_inmemory_upsert_rejects_missing_embedding() {
        let store = InMemoryVectorStore::new();
        store.create_collection("test", 3).await.unwrap();

        let mut doc = create_test_document("doc1", "No vector", vec![1.0, 0.0, 0.0]);
        doc.embedding = None;

        assert!(store.upsert("test", &[doc]).await.is_err());
        assert_eq!(store.count("test").await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_inmemory_delete() {
        let store = InMemoryVectorStore::new();
        store.create_collection("test", 3).await.unwrap();

        let doc = create_test_document("doc1", "Test", vec![1.0, 0.0, 0.0]);
        store.upsert("test", &[doc]).await.unwrap();

        assert_eq!(store.count("test").await.unwrap(), 1);

        let deleted = store.delete("test", &["doc1".to_string()]).await.unwrap();
        assert_eq!(deleted, 1);

        assert_eq!(store.count("test").await.unwrap(), 0);

        // Ids that are not stored are not counted.
        let deleted = store.delete("test", &["ghost".to_string()]).await.unwrap();
        assert_eq!(deleted, 0);
    }

    #[tokio::test]
    async fn test_inmemory_get() {
        let store = InMemoryVectorStore::new();
        store.create_collection("test", 3).await.unwrap();

        let doc = create_test_document("doc1", "Test content", vec![1.0, 0.0, 0.0]);
        store.upsert("test", &[doc]).await.unwrap();

        let retrieved = store.get("test", "doc1").await.unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().content, "Test content");

        let not_found = store.get("test", "nonexistent").await.unwrap();
        assert!(not_found.is_none());
    }

    #[tokio::test]
    async fn test_inmemory_missing_collection_errors() {
        let store = InMemoryVectorStore::new();

        assert!(!store.collection_exists("missing").await.unwrap());
        assert!(store.delete_collection("missing").await.is_err());
        assert!(store.collection_stats("missing").await.is_err());
        assert!(store.upsert("missing", &[]).await.is_err());
        assert!(store.get("missing", "doc1").await.is_err());
        assert!(store
            .delete("missing", &["doc1".to_string()])
            .await
            .is_err());
    }

    #[tokio::test]
    async fn test_inmemory_list_collections() {
        let store = InMemoryVectorStore::new();

        store.create_collection("col1", 384).await.unwrap();
        store.create_collection("col2", 768).await.unwrap();

        let collections = store.list_collections().await.unwrap();
        assert_eq!(collections.len(), 2);
    }

    // Pure synchronous math: no async runtime needed.
    #[test]
    fn test_cosine_similarity() {
        // Identical direction.
        assert!(
            (InMemoryVectorStore::cosine_similarity(&[1.0, 0.0], &[1.0, 0.0]) - 1.0).abs() < 0.001
        );

        // Orthogonal.
        assert!(InMemoryVectorStore::cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]).abs() < 0.001);

        // Opposite direction.
        assert!(
            (InMemoryVectorStore::cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]) + 1.0).abs() < 0.001
        );

        // Degenerate inputs are defined as 0.0.
        assert_eq!(
            InMemoryVectorStore::cosine_similarity(&[1.0, 0.0], &[1.0]),
            0.0
        );
        assert_eq!(
            InMemoryVectorStore::cosine_similarity(&[0.0, 0.0], &[1.0, 0.0]),
            0.0
        );
    }
}