Skip to main content

cortexai_agents/
vector_store.rs

//! Vector Store and RAG (Retrieval-Augmented Generation) System
//!
//! Provides semantic memory for agents through vector embeddings and similarity search.
//!
//! # Features
//! - **VectorStore trait**: Pluggable backends for vector storage
//! - **InMemoryVectorStore**: Fast in-memory store using brute-force cosine-similarity search
//! - **SledVectorStore**: Persistent vector storage using Sled
//! - **RAGPipeline**: Complete retrieval-augmented generation pipeline
//! - **Document chunking**: Automatic text splitting for large documents
//!
//! # Example
//! ```ignore
//! use cortexai_agents::vector_store::*;
//!
//! // Create vector store
//! let store = InMemoryVectorStore::new(1536); // OpenAI embedding dimension
//!
//! // Create RAG pipeline
//! let rag = RAGPipeline::new(store, embedder)
//!     .with_chunking(ChunkingStrategy::FixedSize { chunk_size: 512, overlap: 64 })
//!     .with_top_k(5);
//!
//! // Index documents (optional metadata is copied onto every chunk)
//! rag.index_document("doc1", "Long document text...", None).await?;
//!
//! // Query with context
//! let context = rag.retrieve("What is the main topic?", Some(5)).await?;
//! ```
30
31use async_trait::async_trait;
32use serde::{Deserialize, Serialize};
33use std::collections::HashMap;
34use std::sync::Arc;
35use tokio::sync::RwLock;
36
37use cortexai_core::errors::MemoryError;
38
39/// A document chunk with its embedding
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct VectorDocument {
42    /// Unique document ID
43    pub id: String,
44    /// Original text content
45    pub content: String,
46    /// Vector embedding
47    pub embedding: Vec<f32>,
48    /// Document metadata
49    pub metadata: HashMap<String, serde_json::Value>,
50    /// Source document ID (for chunks)
51    pub source_id: Option<String>,
52    /// Chunk index within source document
53    pub chunk_index: Option<usize>,
54}
55
56impl VectorDocument {
57    pub fn new(id: impl Into<String>, content: impl Into<String>, embedding: Vec<f32>) -> Self {
58        Self {
59            id: id.into(),
60            content: content.into(),
61            embedding,
62            metadata: HashMap::new(),
63            source_id: None,
64            chunk_index: None,
65        }
66    }
67
68    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
69        self.metadata.insert(key.into(), value);
70        self
71    }
72
73    pub fn with_source(mut self, source_id: impl Into<String>, chunk_index: usize) -> Self {
74        self.source_id = Some(source_id.into());
75        self.chunk_index = Some(chunk_index);
76        self
77    }
78}
79
80/// Search result with similarity score
81#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct SearchResult {
83    /// The matched document
84    pub document: VectorDocument,
85    /// Similarity score (0.0 - 1.0, higher is more similar)
86    pub score: f32,
87}
88
89/// Trait for vector embedding generation
90#[async_trait]
91pub trait Embedder: Send + Sync {
92    /// Generate embedding for text
93    async fn embed(&self, text: &str) -> Result<Vec<f32>, MemoryError>;
94
95    /// Generate embeddings for multiple texts (batch)
96    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, MemoryError> {
97        let mut results = Vec::with_capacity(texts.len());
98        for text in texts {
99            results.push(self.embed(text).await?);
100        }
101        Ok(results)
102    }
103
104    /// Get embedding dimension
105    fn dimension(&self) -> usize;
106}
107
108/// Trait for vector storage backends
109#[async_trait]
110pub trait VectorStore: Send + Sync {
111    /// Insert a document into the store
112    async fn insert(&self, doc: VectorDocument) -> Result<(), MemoryError>;
113
114    /// Insert multiple documents
115    async fn insert_batch(&self, docs: Vec<VectorDocument>) -> Result<(), MemoryError> {
116        for doc in docs {
117            self.insert(doc).await?;
118        }
119        Ok(())
120    }
121
122    /// Search for similar documents
123    async fn search(
124        &self,
125        query_embedding: &[f32],
126        top_k: usize,
127    ) -> Result<Vec<SearchResult>, MemoryError>;
128
129    /// Get document by ID
130    async fn get(&self, id: &str) -> Result<Option<VectorDocument>, MemoryError>;
131
132    /// Delete document by ID
133    async fn delete(&self, id: &str) -> Result<bool, MemoryError>;
134
135    /// Delete all documents with a given source_id
136    async fn delete_by_source(&self, source_id: &str) -> Result<usize, MemoryError>;
137
138    /// Get total document count
139    async fn count(&self) -> Result<usize, MemoryError>;
140
141    /// Clear all documents
142    async fn clear(&self) -> Result<(), MemoryError>;
143
144    /// Get store name for logging
145    fn name(&self) -> &'static str;
146}
147
148/// In-memory vector store with brute-force similarity search
149pub struct InMemoryVectorStore {
150    documents: Arc<RwLock<HashMap<String, VectorDocument>>>,
151    dimension: usize,
152}
153
154impl InMemoryVectorStore {
155    pub fn new(dimension: usize) -> Self {
156        Self {
157            documents: Arc::new(RwLock::new(HashMap::new())),
158            dimension,
159        }
160    }
161
162    /// Compute cosine similarity between two vectors
163    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
164        if a.len() != b.len() {
165            return 0.0;
166        }
167
168        let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
169        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
170        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
171
172        if norm_a == 0.0 || norm_b == 0.0 {
173            return 0.0;
174        }
175
176        dot_product / (norm_a * norm_b)
177    }
178}
179
180#[async_trait]
181impl VectorStore for InMemoryVectorStore {
182    async fn insert(&self, doc: VectorDocument) -> Result<(), MemoryError> {
183        if doc.embedding.len() != self.dimension {
184            return Err(MemoryError::StorageError(format!(
185                "Embedding dimension mismatch: expected {}, got {}",
186                self.dimension,
187                doc.embedding.len()
188            )));
189        }
190
191        let mut docs = self.documents.write().await;
192        docs.insert(doc.id.clone(), doc);
193        Ok(())
194    }
195
196    async fn search(
197        &self,
198        query_embedding: &[f32],
199        top_k: usize,
200    ) -> Result<Vec<SearchResult>, MemoryError> {
201        if query_embedding.len() != self.dimension {
202            return Err(MemoryError::StorageError(format!(
203                "Query embedding dimension mismatch: expected {}, got {}",
204                self.dimension,
205                query_embedding.len()
206            )));
207        }
208
209        let docs = self.documents.read().await;
210
211        let mut results: Vec<SearchResult> = docs
212            .values()
213            .map(|doc| SearchResult {
214                document: doc.clone(),
215                score: Self::cosine_similarity(query_embedding, &doc.embedding),
216            })
217            .collect();
218
219        // Sort by score descending
220        results.sort_by(|a, b| {
221            b.score
222                .partial_cmp(&a.score)
223                .unwrap_or(std::cmp::Ordering::Equal)
224        });
225
226        // Take top_k
227        results.truncate(top_k);
228
229        Ok(results)
230    }
231
232    async fn get(&self, id: &str) -> Result<Option<VectorDocument>, MemoryError> {
233        let docs = self.documents.read().await;
234        Ok(docs.get(id).cloned())
235    }
236
237    async fn delete(&self, id: &str) -> Result<bool, MemoryError> {
238        let mut docs = self.documents.write().await;
239        Ok(docs.remove(id).is_some())
240    }
241
242    async fn delete_by_source(&self, source_id: &str) -> Result<usize, MemoryError> {
243        let mut docs = self.documents.write().await;
244        let to_remove: Vec<_> = docs
245            .iter()
246            .filter(|(_, doc)| doc.source_id.as_deref() == Some(source_id))
247            .map(|(id, _)| id.clone())
248            .collect();
249
250        let count = to_remove.len();
251        for id in to_remove {
252            docs.remove(&id);
253        }
254
255        Ok(count)
256    }
257
258    async fn count(&self) -> Result<usize, MemoryError> {
259        let docs = self.documents.read().await;
260        Ok(docs.len())
261    }
262
263    async fn clear(&self) -> Result<(), MemoryError> {
264        let mut docs = self.documents.write().await;
265        docs.clear();
266        Ok(())
267    }
268
269    fn name(&self) -> &'static str {
270        "in-memory"
271    }
272}
273
274/// Sled-based persistent vector store
275pub struct SledVectorStore {
276    db: sled::Db,
277    documents_tree: sled::Tree,
278    index_tree: sled::Tree,
279    dimension: usize,
280}
281
282impl SledVectorStore {
283    /// Create a new Sled vector store at the given path
284    pub fn new<P: AsRef<std::path::Path>>(path: P, dimension: usize) -> Result<Self, MemoryError> {
285        let db = sled::open(path).map_err(|e| MemoryError::StorageError(e.to_string()))?;
286
287        let documents_tree = db
288            .open_tree("vector_documents")
289            .map_err(|e| MemoryError::StorageError(e.to_string()))?;
290
291        let index_tree = db
292            .open_tree("vector_index")
293            .map_err(|e| MemoryError::StorageError(e.to_string()))?;
294
295        Ok(Self {
296            db,
297            documents_tree,
298            index_tree,
299            dimension,
300        })
301    }
302
303    /// Create a temporary in-memory Sled store (for testing)
304    pub fn temporary(dimension: usize) -> Result<Self, MemoryError> {
305        let db = sled::Config::new()
306            .temporary(true)
307            .open()
308            .map_err(|e| MemoryError::StorageError(e.to_string()))?;
309
310        let documents_tree = db
311            .open_tree("vector_documents")
312            .map_err(|e| MemoryError::StorageError(e.to_string()))?;
313
314        let index_tree = db
315            .open_tree("vector_index")
316            .map_err(|e| MemoryError::StorageError(e.to_string()))?;
317
318        Ok(Self {
319            db,
320            documents_tree,
321            index_tree,
322            dimension,
323        })
324    }
325
326    /// Flush to disk
327    pub fn flush(&self) -> Result<(), MemoryError> {
328        self.db
329            .flush()
330            .map_err(|e| MemoryError::StorageError(e.to_string()))?;
331        Ok(())
332    }
333
334    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
335        InMemoryVectorStore::cosine_similarity(a, b)
336    }
337}
338
339#[async_trait]
340impl VectorStore for SledVectorStore {
341    async fn insert(&self, doc: VectorDocument) -> Result<(), MemoryError> {
342        if doc.embedding.len() != self.dimension {
343            return Err(MemoryError::StorageError(format!(
344                "Embedding dimension mismatch: expected {}, got {}",
345                self.dimension,
346                doc.embedding.len()
347            )));
348        }
349
350        let doc_bytes =
351            serde_json::to_vec(&doc).map_err(|e| MemoryError::SerializationError(e.to_string()))?;
352
353        self.documents_tree
354            .insert(doc.id.as_bytes(), doc_bytes)
355            .map_err(|e| MemoryError::StorageError(e.to_string()))?;
356
357        // Index by source_id if present
358        if let Some(source_id) = &doc.source_id {
359            let key = format!("source:{}:{}", source_id, doc.id);
360            self.index_tree
361                .insert(key.as_bytes(), doc.id.as_bytes())
362                .map_err(|e| MemoryError::StorageError(e.to_string()))?;
363        }
364
365        Ok(())
366    }
367
368    async fn search(
369        &self,
370        query_embedding: &[f32],
371        top_k: usize,
372    ) -> Result<Vec<SearchResult>, MemoryError> {
373        if query_embedding.len() != self.dimension {
374            return Err(MemoryError::StorageError(format!(
375                "Query embedding dimension mismatch: expected {}, got {}",
376                self.dimension,
377                query_embedding.len()
378            )));
379        }
380
381        let mut results = Vec::new();
382
383        for item in self.documents_tree.iter() {
384            let (_, value) = item.map_err(|e| MemoryError::StorageError(e.to_string()))?;
385            let doc: VectorDocument = serde_json::from_slice(&value)
386                .map_err(|e| MemoryError::SerializationError(e.to_string()))?;
387
388            let score = Self::cosine_similarity(query_embedding, &doc.embedding);
389            results.push(SearchResult {
390                document: doc,
391                score,
392            });
393        }
394
395        // Sort by score descending
396        results.sort_by(|a, b| {
397            b.score
398                .partial_cmp(&a.score)
399                .unwrap_or(std::cmp::Ordering::Equal)
400        });
401        results.truncate(top_k);
402
403        Ok(results)
404    }
405
406    async fn get(&self, id: &str) -> Result<Option<VectorDocument>, MemoryError> {
407        match self.documents_tree.get(id.as_bytes()) {
408            Ok(Some(bytes)) => {
409                let doc: VectorDocument = serde_json::from_slice(&bytes)
410                    .map_err(|e| MemoryError::SerializationError(e.to_string()))?;
411                Ok(Some(doc))
412            }
413            Ok(None) => Ok(None),
414            Err(e) => Err(MemoryError::StorageError(e.to_string())),
415        }
416    }
417
418    async fn delete(&self, id: &str) -> Result<bool, MemoryError> {
419        // Get doc first to remove from index
420        if let Some(doc) = self.get(id).await? {
421            if let Some(source_id) = &doc.source_id {
422                let key = format!("source:{}:{}", source_id, id);
423                let _ = self.index_tree.remove(key.as_bytes());
424            }
425        }
426
427        let removed = self
428            .documents_tree
429            .remove(id.as_bytes())
430            .map_err(|e| MemoryError::StorageError(e.to_string()))?;
431
432        Ok(removed.is_some())
433    }
434
435    async fn delete_by_source(&self, source_id: &str) -> Result<usize, MemoryError> {
436        let prefix = format!("source:{}:", source_id);
437        let mut ids_to_remove = Vec::new();
438
439        for item in self.index_tree.scan_prefix(prefix.as_bytes()) {
440            let (_, value) = item.map_err(|e| MemoryError::StorageError(e.to_string()))?;
441            let id = String::from_utf8(value.to_vec())
442                .map_err(|e| MemoryError::SerializationError(e.to_string()))?;
443            ids_to_remove.push(id);
444        }
445
446        let count = ids_to_remove.len();
447        for id in ids_to_remove {
448            self.delete(&id).await?;
449        }
450
451        Ok(count)
452    }
453
454    async fn count(&self) -> Result<usize, MemoryError> {
455        Ok(self.documents_tree.len())
456    }
457
458    async fn clear(&self) -> Result<(), MemoryError> {
459        self.documents_tree
460            .clear()
461            .map_err(|e| MemoryError::StorageError(e.to_string()))?;
462        self.index_tree
463            .clear()
464            .map_err(|e| MemoryError::StorageError(e.to_string()))?;
465        Ok(())
466    }
467
468    fn name(&self) -> &'static str {
469        "sled"
470    }
471}
472
473/// Text chunking strategy
474#[derive(Debug, Clone)]
475pub enum ChunkingStrategy {
476    /// Fixed size chunks with overlap
477    FixedSize { chunk_size: usize, overlap: usize },
478    /// Split by sentences
479    Sentence { max_sentences: usize },
480    /// Split by paragraphs
481    Paragraph,
482    /// No chunking - use full text
483    None,
484}
485
486impl Default for ChunkingStrategy {
487    fn default() -> Self {
488        ChunkingStrategy::FixedSize {
489            chunk_size: 512,
490            overlap: 64,
491        }
492    }
493}
494
495/// Text chunker for splitting documents
496pub struct TextChunker {
497    strategy: ChunkingStrategy,
498}
499
500impl TextChunker {
501    pub fn new(strategy: ChunkingStrategy) -> Self {
502        Self { strategy }
503    }
504
505    /// Split text into chunks
506    pub fn chunk(&self, text: &str) -> Vec<String> {
507        match &self.strategy {
508            ChunkingStrategy::FixedSize {
509                chunk_size,
510                overlap,
511            } => self.chunk_fixed_size(text, *chunk_size, *overlap),
512            ChunkingStrategy::Sentence { max_sentences } => {
513                self.chunk_by_sentences(text, *max_sentences)
514            }
515            ChunkingStrategy::Paragraph => self.chunk_by_paragraphs(text),
516            ChunkingStrategy::None => vec![text.to_string()],
517        }
518    }
519
520    fn chunk_fixed_size(&self, text: &str, chunk_size: usize, overlap: usize) -> Vec<String> {
521        let chars: Vec<char> = text.chars().collect();
522        let mut chunks = Vec::new();
523        let mut start = 0;
524
525        while start < chars.len() {
526            let end = (start + chunk_size).min(chars.len());
527            let chunk: String = chars[start..end].iter().collect();
528
529            if !chunk.trim().is_empty() {
530                chunks.push(chunk.trim().to_string());
531            }
532
533            if end >= chars.len() {
534                break;
535            }
536
537            start = if overlap < chunk_size {
538                start + chunk_size - overlap
539            } else {
540                start + chunk_size
541            };
542        }
543
544        chunks
545    }
546
547    fn chunk_by_sentences(&self, text: &str, max_sentences: usize) -> Vec<String> {
548        let sentences: Vec<&str> = text
549            .split(['.', '!', '?'])
550            .filter(|s| !s.trim().is_empty())
551            .collect();
552
553        sentences
554            .chunks(max_sentences)
555            .map(|chunk| {
556                chunk
557                    .iter()
558                    .map(|s| s.trim())
559                    .collect::<Vec<_>>()
560                    .join(". ")
561                    + "."
562            })
563            .collect()
564    }
565
566    fn chunk_by_paragraphs(&self, text: &str) -> Vec<String> {
567        text.split("\n\n")
568            .filter(|p| !p.trim().is_empty())
569            .map(|p| p.trim().to_string())
570            .collect()
571    }
572}
573
574/// RAG (Retrieval-Augmented Generation) Pipeline
575pub struct RAGPipeline<S: VectorStore, E: Embedder> {
576    store: Arc<S>,
577    embedder: Arc<E>,
578    chunker: TextChunker,
579    default_top_k: usize,
580}
581
582impl<S: VectorStore, E: Embedder> RAGPipeline<S, E> {
583    /// Create a new RAG pipeline
584    pub fn new(store: S, embedder: E) -> Self {
585        Self {
586            store: Arc::new(store),
587            embedder: Arc::new(embedder),
588            chunker: TextChunker::new(ChunkingStrategy::default()),
589            default_top_k: 5,
590        }
591    }
592
593    /// Set chunking strategy
594    pub fn with_chunking(mut self, strategy: ChunkingStrategy) -> Self {
595        self.chunker = TextChunker::new(strategy);
596        self
597    }
598
599    /// Set default top_k for retrieval
600    pub fn with_top_k(mut self, top_k: usize) -> Self {
601        self.default_top_k = top_k;
602        self
603    }
604
605    /// Index a document (with automatic chunking)
606    pub async fn index_document(
607        &self,
608        doc_id: &str,
609        content: &str,
610        metadata: Option<HashMap<String, serde_json::Value>>,
611    ) -> Result<usize, MemoryError> {
612        // First, remove any existing chunks for this document
613        self.store.delete_by_source(doc_id).await?;
614
615        // Chunk the content
616        let chunks = self.chunker.chunk(content);
617        let chunk_count = chunks.len();
618
619        // Generate embeddings for all chunks
620        let chunk_refs: Vec<&str> = chunks.iter().map(|s| s.as_str()).collect();
621        let embeddings = self.embedder.embed_batch(&chunk_refs).await?;
622
623        // Create and insert documents
624        for (i, (chunk, embedding)) in chunks.into_iter().zip(embeddings).enumerate() {
625            let chunk_id = format!("{}:chunk:{}", doc_id, i);
626            let mut doc = VectorDocument::new(chunk_id, chunk, embedding).with_source(doc_id, i);
627
628            if let Some(ref meta) = metadata {
629                for (k, v) in meta {
630                    doc = doc.with_metadata(k.clone(), v.clone());
631                }
632            }
633
634            self.store.insert(doc).await?;
635        }
636
637        tracing::info!(doc_id = doc_id, chunks = chunk_count, "Indexed document");
638
639        Ok(chunk_count)
640    }
641
642    /// Retrieve relevant context for a query
643    pub async fn retrieve(
644        &self,
645        query: &str,
646        top_k: Option<usize>,
647    ) -> Result<Vec<SearchResult>, MemoryError> {
648        let k = top_k.unwrap_or(self.default_top_k);
649
650        // Generate query embedding
651        let query_embedding = self.embedder.embed(query).await?;
652
653        // Search for similar documents
654        let results = self.store.search(&query_embedding, k).await?;
655
656        tracing::debug!(query = query, results = results.len(), "Retrieved context");
657
658        Ok(results)
659    }
660
661    /// Retrieve and format context as a string for LLM prompt
662    pub async fn retrieve_context(
663        &self,
664        query: &str,
665        top_k: Option<usize>,
666    ) -> Result<String, MemoryError> {
667        let results = self.retrieve(query, top_k).await?;
668
669        if results.is_empty() {
670            return Ok(String::new());
671        }
672
673        let context = results
674            .iter()
675            .enumerate()
676            .map(|(i, r)| {
677                format!(
678                    "[Source {}] (relevance: {:.2})\n{}",
679                    i + 1,
680                    r.score,
681                    r.document.content
682                )
683            })
684            .collect::<Vec<_>>()
685            .join("\n\n");
686
687        Ok(context)
688    }
689
690    /// Build an augmented prompt with retrieved context
691    pub async fn augment_prompt(
692        &self,
693        query: &str,
694        top_k: Option<usize>,
695    ) -> Result<String, MemoryError> {
696        let context = self.retrieve_context(query, top_k).await?;
697
698        if context.is_empty() {
699            return Ok(query.to_string());
700        }
701
702        Ok(format!(
703            "Use the following context to answer the question. If the context doesn't contain \
704            relevant information, say so and answer based on your knowledge.\n\n\
705            Context:\n{}\n\n\
706            Question: {}",
707            context, query
708        ))
709    }
710
711    /// Delete a document and all its chunks
712    pub async fn delete_document(&self, doc_id: &str) -> Result<usize, MemoryError> {
713        self.store.delete_by_source(doc_id).await
714    }
715
716    /// Get document count
717    pub async fn document_count(&self) -> Result<usize, MemoryError> {
718        self.store.count().await
719    }
720
721    /// Clear all documents
722    pub async fn clear(&self) -> Result<(), MemoryError> {
723        self.store.clear().await
724    }
725}
726
727/// Semantic memory for agents using vector store
728pub struct SemanticMemory<S: VectorStore, E: Embedder> {
729    rag: RAGPipeline<S, E>,
730    agent_id: String,
731}
732
733impl<S: VectorStore, E: Embedder> SemanticMemory<S, E> {
734    pub fn new(store: S, embedder: E, agent_id: impl Into<String>) -> Self {
735        Self {
736            rag: RAGPipeline::new(store, embedder),
737            agent_id: agent_id.into(),
738        }
739    }
740
741    /// Remember a piece of information
742    pub async fn remember(
743        &self,
744        content: &str,
745        tags: Option<Vec<String>>,
746    ) -> Result<(), MemoryError> {
747        let memory_id = format!("{}:memory:{}", self.agent_id, uuid::Uuid::new_v4());
748
749        let mut metadata = HashMap::new();
750        metadata.insert("agent_id".to_string(), serde_json::json!(self.agent_id));
751        metadata.insert(
752            "timestamp".to_string(),
753            serde_json::json!(chrono::Utc::now().to_rfc3339()),
754        );
755
756        if let Some(tags) = tags {
757            metadata.insert("tags".to_string(), serde_json::json!(tags));
758        }
759
760        self.rag
761            .index_document(&memory_id, content, Some(metadata))
762            .await?;
763        Ok(())
764    }
765
766    /// Recall relevant memories for a query
767    pub async fn recall(
768        &self,
769        query: &str,
770        top_k: usize,
771    ) -> Result<Vec<SearchResult>, MemoryError> {
772        self.rag.retrieve(query, Some(top_k)).await
773    }
774
775    /// Get formatted context for a query
776    pub async fn get_context(&self, query: &str, top_k: usize) -> Result<String, MemoryError> {
777        self.rag.retrieve_context(query, Some(top_k)).await
778    }
779
780    /// Forget memories matching a query (by deleting similar documents)
781    pub async fn forget(&self, query: &str, threshold: f32) -> Result<usize, MemoryError> {
782        let results = self.rag.retrieve(query, Some(100)).await?;
783
784        let mut deleted = 0;
785        for result in results {
786            if result.score >= threshold && self.rag.store.delete(&result.document.id).await? {
787                deleted += 1;
788            }
789        }
790
791        Ok(deleted)
792    }
793}
794
795/// Embedder wrapper that uses an LLMBackend for embeddings
796pub struct LLMEmbedder {
797    backend: Arc<dyn cortexai_providers::LLMBackend>,
798    dimension: usize,
799}
800
801impl LLMEmbedder {
802    /// Create a new LLM-based embedder
803    ///
804    /// Common dimensions:
805    /// - OpenAI text-embedding-3-small: 1536
806    /// - OpenAI text-embedding-3-large: 3072
807    /// - OpenAI text-embedding-ada-002: 1536
808    pub fn new(backend: Arc<dyn cortexai_providers::LLMBackend>, dimension: usize) -> Self {
809        Self { backend, dimension }
810    }
811
812    /// Create embedder for OpenAI text-embedding-3-small (1536 dimensions)
813    pub fn openai_small(backend: Arc<dyn cortexai_providers::LLMBackend>) -> Self {
814        Self::new(backend, 1536)
815    }
816
817    /// Create embedder for OpenAI text-embedding-3-large (3072 dimensions)
818    pub fn openai_large(backend: Arc<dyn cortexai_providers::LLMBackend>) -> Self {
819        Self::new(backend, 3072)
820    }
821}
822
823#[async_trait]
824impl Embedder for LLMEmbedder {
825    async fn embed(&self, text: &str) -> Result<Vec<f32>, MemoryError> {
826        self.backend
827            .embed(text)
828            .await
829            .map_err(|e| MemoryError::StorageError(format!("Embedding failed: {}", e)))
830    }
831
832    fn dimension(&self) -> usize {
833        self.dimension
834    }
835}
836
837#[cfg(test)]
838mod tests {
839    use super::*;
840
841    // Mock embedder for testing
842    struct MockEmbedder {
843        dimension: usize,
844    }
845
846    impl MockEmbedder {
847        fn new(dimension: usize) -> Self {
848            Self { dimension }
849        }
850    }
851
852    #[async_trait]
853    impl Embedder for MockEmbedder {
854        async fn embed(&self, text: &str) -> Result<Vec<f32>, MemoryError> {
855            // Simple deterministic embedding based on text hash
856            let hash = text.bytes().fold(0u64, |acc, b| acc.wrapping_add(b as u64));
857            let mut embedding = vec![0.0f32; self.dimension];
858
859            for (i, val) in embedding.iter_mut().enumerate() {
860                *val = ((hash.wrapping_add(i as u64) % 1000) as f32 / 1000.0) - 0.5;
861            }
862
863            // Normalize
864            let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
865            if norm > 0.0 {
866                for val in &mut embedding {
867                    *val /= norm;
868                }
869            }
870
871            Ok(embedding)
872        }
873
874        fn dimension(&self) -> usize {
875            self.dimension
876        }
877    }
878
879    #[tokio::test]
880    async fn test_in_memory_vector_store() {
881        let store = InMemoryVectorStore::new(4);
882
883        // Insert documents
884        let doc1 = VectorDocument::new("doc1", "Hello world", vec![1.0, 0.0, 0.0, 0.0]);
885        let doc2 = VectorDocument::new("doc2", "Goodbye world", vec![0.0, 1.0, 0.0, 0.0]);
886        let doc3 = VectorDocument::new("doc3", "Hello there", vec![0.9, 0.1, 0.0, 0.0]);
887
888        store.insert(doc1).await.unwrap();
889        store.insert(doc2).await.unwrap();
890        store.insert(doc3).await.unwrap();
891
892        assert_eq!(store.count().await.unwrap(), 3);
893
894        // Search
895        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 2).await.unwrap();
896        assert_eq!(results.len(), 2);
897        assert_eq!(results[0].document.id, "doc1"); // Most similar
898        assert_eq!(results[1].document.id, "doc3"); // Second most similar
899    }
900
901    #[tokio::test]
902    async fn test_sled_vector_store() {
903        let store = SledVectorStore::temporary(4).unwrap();
904
905        let doc = VectorDocument::new("test", "Test content", vec![1.0, 0.0, 0.0, 0.0])
906            .with_metadata("key", serde_json::json!("value"));
907
908        store.insert(doc.clone()).await.unwrap();
909
910        let retrieved = store.get("test").await.unwrap().unwrap();
911        assert_eq!(retrieved.content, "Test content");
912        assert_eq!(
913            retrieved.metadata.get("key"),
914            Some(&serde_json::json!("value"))
915        );
916
917        store.delete("test").await.unwrap();
918        assert!(store.get("test").await.unwrap().is_none());
919    }
920
921    #[tokio::test]
922    async fn test_text_chunker_fixed_size() {
923        let chunker = TextChunker::new(ChunkingStrategy::FixedSize {
924            chunk_size: 10,
925            overlap: 2,
926        });
927
928        let text = "Hello world, this is a test of the chunking system.";
929        let chunks = chunker.chunk(text);
930
931        assert!(chunks.len() > 1);
932        assert!(chunks.iter().all(|c| c.len() <= 12)); // Allow some variance
933    }
934
935    #[tokio::test]
936    async fn test_text_chunker_sentences() {
937        let chunker = TextChunker::new(ChunkingStrategy::Sentence { max_sentences: 2 });
938
939        let text = "First sentence. Second sentence. Third sentence. Fourth sentence.";
940        let chunks = chunker.chunk(text);
941
942        assert_eq!(chunks.len(), 2);
943    }
944
945    #[tokio::test]
946    async fn test_rag_pipeline() {
947        let store = InMemoryVectorStore::new(64);
948        let embedder = MockEmbedder::new(64);
949        let rag = RAGPipeline::new(store, embedder)
950            .with_chunking(ChunkingStrategy::None)
951            .with_top_k(3);
952
953        // Index documents
954        rag.index_document("doc1", "Rust is a systems programming language.", None)
955            .await
956            .unwrap();
957        rag.index_document("doc2", "Python is great for data science.", None)
958            .await
959            .unwrap();
960        rag.index_document("doc3", "Rust has excellent memory safety.", None)
961            .await
962            .unwrap();
963
964        assert_eq!(rag.document_count().await.unwrap(), 3);
965
966        // Retrieve
967        let results = rag.retrieve("Tell me about Rust", Some(2)).await.unwrap();
968        assert_eq!(results.len(), 2);
969
970        // Get context
971        let context = rag
972            .retrieve_context("Rust programming", Some(2))
973            .await
974            .unwrap();
975        assert!(!context.is_empty());
976    }
977
978    #[tokio::test]
979    async fn test_semantic_memory() {
980        let store = InMemoryVectorStore::new(64);
981        let embedder = MockEmbedder::new(64);
982        let memory = SemanticMemory::new(store, embedder, "test-agent");
983
984        // Remember things
985        memory
986            .remember(
987                "The capital of France is Paris.",
988                Some(vec!["geography".to_string()]),
989            )
990            .await
991            .unwrap();
992        memory
993            .remember(
994                "Rust was created by Mozilla.",
995                Some(vec!["programming".to_string()]),
996            )
997            .await
998            .unwrap();
999
1000        // Recall
1001        let results = memory
1002            .recall("What is the capital of France?", 5)
1003            .await
1004            .unwrap();
1005        assert!(!results.is_empty());
1006    }
1007
1008    #[tokio::test]
1009    async fn test_delete_by_source() {
1010        let store = InMemoryVectorStore::new(4);
1011
1012        let doc1 = VectorDocument::new("chunk1", "Part 1", vec![1.0, 0.0, 0.0, 0.0])
1013            .with_source("doc1", 0);
1014        let doc2 = VectorDocument::new("chunk2", "Part 2", vec![0.0, 1.0, 0.0, 0.0])
1015            .with_source("doc1", 1);
1016        let doc3 =
1017            VectorDocument::new("other", "Other", vec![0.0, 0.0, 1.0, 0.0]).with_source("doc2", 0);
1018
1019        store.insert(doc1).await.unwrap();
1020        store.insert(doc2).await.unwrap();
1021        store.insert(doc3).await.unwrap();
1022
1023        assert_eq!(store.count().await.unwrap(), 3);
1024
1025        let deleted = store.delete_by_source("doc1").await.unwrap();
1026        assert_eq!(deleted, 2);
1027        assert_eq!(store.count().await.unwrap(), 1);
1028    }
1029}