1use async_trait::async_trait;
32use serde::{Deserialize, Serialize};
33use std::collections::HashMap;
34use std::sync::Arc;
35use tokio::sync::RwLock;
36
37use cortexai_core::errors::MemoryError;
38
/// A single embedded text chunk stored in a [`VectorStore`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorDocument {
    /// Unique identifier; chunk ids produced by the RAG pipeline have the
    /// form `"{doc_id}:chunk:{index}"`.
    pub id: String,
    /// The raw text that was embedded.
    pub content: String,
    /// Embedding vector; its length must match the store's dimension.
    pub embedding: Vec<f32>,
    /// Arbitrary key/value metadata attached to this document.
    pub metadata: HashMap<String, serde_json::Value>,
    /// Id of the source document this chunk came from, if any.
    pub source_id: Option<String>,
    /// Position of this chunk within its source document, if any.
    pub chunk_index: Option<usize>,
}
55
56impl VectorDocument {
57 pub fn new(id: impl Into<String>, content: impl Into<String>, embedding: Vec<f32>) -> Self {
58 Self {
59 id: id.into(),
60 content: content.into(),
61 embedding,
62 metadata: HashMap::new(),
63 source_id: None,
64 chunk_index: None,
65 }
66 }
67
68 pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
69 self.metadata.insert(key.into(), value);
70 self
71 }
72
73 pub fn with_source(mut self, source_id: impl Into<String>, chunk_index: usize) -> Self {
74 self.source_id = Some(source_id.into());
75 self.chunk_index = Some(chunk_index);
76 self
77 }
78}
79
/// A document returned from a similarity search, paired with its score.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    /// The matched document (cloned out of the store).
    pub document: VectorDocument,
    /// Cosine similarity to the query embedding; higher is more similar.
    pub score: f32,
}
88
/// Converts text into embedding vectors.
///
/// Implementations must be `Send + Sync` so they can be shared across tasks.
#[async_trait]
pub trait Embedder: Send + Sync {
    /// Embeds a single text into a vector of length [`Embedder::dimension`].
    async fn embed(&self, text: &str) -> Result<Vec<f32>, MemoryError>;

    /// Embeds many texts. The default implementation calls [`Embedder::embed`]
    /// sequentially, one await per text; backends with a batch endpoint
    /// should override this.
    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, MemoryError> {
        let mut results = Vec::with_capacity(texts.len());
        for text in texts {
            results.push(self.embed(text).await?);
        }
        Ok(results)
    }

    /// Length of the vectors produced by [`Embedder::embed`].
    fn dimension(&self) -> usize;
}
107
/// Async storage backend for embedded documents with similarity search.
#[async_trait]
pub trait VectorStore: Send + Sync {
    /// Inserts a document; implementations replace any existing document
    /// with the same id.
    async fn insert(&self, doc: VectorDocument) -> Result<(), MemoryError>;

    /// Inserts many documents. Default: sequential `insert` calls — not
    /// atomic, so an error leaves earlier documents inserted.
    async fn insert_batch(&self, docs: Vec<VectorDocument>) -> Result<(), MemoryError> {
        for doc in docs {
            self.insert(doc).await?;
        }
        Ok(())
    }

    /// Returns up to `top_k` documents most similar to `query_embedding`,
    /// best match first.
    async fn search(
        &self,
        query_embedding: &[f32],
        top_k: usize,
    ) -> Result<Vec<SearchResult>, MemoryError>;

    /// Fetches a document by id.
    async fn get(&self, id: &str) -> Result<Option<VectorDocument>, MemoryError>;

    /// Deletes a document by id; returns `true` if it existed.
    async fn delete(&self, id: &str) -> Result<bool, MemoryError>;

    /// Deletes all documents whose `source_id` matches; returns the count.
    async fn delete_by_source(&self, source_id: &str) -> Result<usize, MemoryError>;

    /// Total number of stored documents.
    async fn count(&self) -> Result<usize, MemoryError>;

    /// Removes all documents.
    async fn clear(&self) -> Result<(), MemoryError>;

    /// Short human-readable backend name (used in logs/diagnostics).
    fn name(&self) -> &'static str;
}
147
/// Non-persistent vector store backed by a `HashMap` behind an async `RwLock`.
pub struct InMemoryVectorStore {
    // Documents keyed by id, shared behind a reader/writer lock.
    documents: Arc<RwLock<HashMap<String, VectorDocument>>>,
    // Required embedding length for all inserted documents and queries.
    dimension: usize,
}
153
154impl InMemoryVectorStore {
155 pub fn new(dimension: usize) -> Self {
156 Self {
157 documents: Arc::new(RwLock::new(HashMap::new())),
158 dimension,
159 }
160 }
161
162 fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
164 if a.len() != b.len() {
165 return 0.0;
166 }
167
168 let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
169 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
170 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
171
172 if norm_a == 0.0 || norm_b == 0.0 {
173 return 0.0;
174 }
175
176 dot_product / (norm_a * norm_b)
177 }
178}
179
180#[async_trait]
181impl VectorStore for InMemoryVectorStore {
182 async fn insert(&self, doc: VectorDocument) -> Result<(), MemoryError> {
183 if doc.embedding.len() != self.dimension {
184 return Err(MemoryError::StorageError(format!(
185 "Embedding dimension mismatch: expected {}, got {}",
186 self.dimension,
187 doc.embedding.len()
188 )));
189 }
190
191 let mut docs = self.documents.write().await;
192 docs.insert(doc.id.clone(), doc);
193 Ok(())
194 }
195
196 async fn search(
197 &self,
198 query_embedding: &[f32],
199 top_k: usize,
200 ) -> Result<Vec<SearchResult>, MemoryError> {
201 if query_embedding.len() != self.dimension {
202 return Err(MemoryError::StorageError(format!(
203 "Query embedding dimension mismatch: expected {}, got {}",
204 self.dimension,
205 query_embedding.len()
206 )));
207 }
208
209 let docs = self.documents.read().await;
210
211 let mut results: Vec<SearchResult> = docs
212 .values()
213 .map(|doc| SearchResult {
214 document: doc.clone(),
215 score: Self::cosine_similarity(query_embedding, &doc.embedding),
216 })
217 .collect();
218
219 results.sort_by(|a, b| {
221 b.score
222 .partial_cmp(&a.score)
223 .unwrap_or(std::cmp::Ordering::Equal)
224 });
225
226 results.truncate(top_k);
228
229 Ok(results)
230 }
231
232 async fn get(&self, id: &str) -> Result<Option<VectorDocument>, MemoryError> {
233 let docs = self.documents.read().await;
234 Ok(docs.get(id).cloned())
235 }
236
237 async fn delete(&self, id: &str) -> Result<bool, MemoryError> {
238 let mut docs = self.documents.write().await;
239 Ok(docs.remove(id).is_some())
240 }
241
242 async fn delete_by_source(&self, source_id: &str) -> Result<usize, MemoryError> {
243 let mut docs = self.documents.write().await;
244 let to_remove: Vec<_> = docs
245 .iter()
246 .filter(|(_, doc)| doc.source_id.as_deref() == Some(source_id))
247 .map(|(id, _)| id.clone())
248 .collect();
249
250 let count = to_remove.len();
251 for id in to_remove {
252 docs.remove(&id);
253 }
254
255 Ok(count)
256 }
257
258 async fn count(&self) -> Result<usize, MemoryError> {
259 let docs = self.documents.read().await;
260 Ok(docs.len())
261 }
262
263 async fn clear(&self) -> Result<(), MemoryError> {
264 let mut docs = self.documents.write().await;
265 docs.clear();
266 Ok(())
267 }
268
269 fn name(&self) -> &'static str {
270 "in-memory"
271 }
272}
273
/// Persistent vector store backed by an embedded sled database.
pub struct SledVectorStore {
    // Kept so the whole database can be flushed explicitly via `flush`.
    db: sled::Db,
    // Primary storage: document id -> JSON-serialized VectorDocument.
    documents_tree: sled::Tree,
    // Secondary index: "source:{source_id}:{doc_id}" -> doc id bytes,
    // used to find all chunks of a source document.
    index_tree: sled::Tree,
    // Required embedding length for inserts and queries.
    dimension: usize,
}
281
282impl SledVectorStore {
283 pub fn new<P: AsRef<std::path::Path>>(path: P, dimension: usize) -> Result<Self, MemoryError> {
285 let db = sled::open(path).map_err(|e| MemoryError::StorageError(e.to_string()))?;
286
287 let documents_tree = db
288 .open_tree("vector_documents")
289 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
290
291 let index_tree = db
292 .open_tree("vector_index")
293 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
294
295 Ok(Self {
296 db,
297 documents_tree,
298 index_tree,
299 dimension,
300 })
301 }
302
303 pub fn temporary(dimension: usize) -> Result<Self, MemoryError> {
305 let db = sled::Config::new()
306 .temporary(true)
307 .open()
308 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
309
310 let documents_tree = db
311 .open_tree("vector_documents")
312 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
313
314 let index_tree = db
315 .open_tree("vector_index")
316 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
317
318 Ok(Self {
319 db,
320 documents_tree,
321 index_tree,
322 dimension,
323 })
324 }
325
326 pub fn flush(&self) -> Result<(), MemoryError> {
328 self.db
329 .flush()
330 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
331 Ok(())
332 }
333
334 fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
335 InMemoryVectorStore::cosine_similarity(a, b)
336 }
337}
338
339#[async_trait]
340impl VectorStore for SledVectorStore {
341 async fn insert(&self, doc: VectorDocument) -> Result<(), MemoryError> {
342 if doc.embedding.len() != self.dimension {
343 return Err(MemoryError::StorageError(format!(
344 "Embedding dimension mismatch: expected {}, got {}",
345 self.dimension,
346 doc.embedding.len()
347 )));
348 }
349
350 let doc_bytes =
351 serde_json::to_vec(&doc).map_err(|e| MemoryError::SerializationError(e.to_string()))?;
352
353 self.documents_tree
354 .insert(doc.id.as_bytes(), doc_bytes)
355 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
356
357 if let Some(source_id) = &doc.source_id {
359 let key = format!("source:{}:{}", source_id, doc.id);
360 self.index_tree
361 .insert(key.as_bytes(), doc.id.as_bytes())
362 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
363 }
364
365 Ok(())
366 }
367
368 async fn search(
369 &self,
370 query_embedding: &[f32],
371 top_k: usize,
372 ) -> Result<Vec<SearchResult>, MemoryError> {
373 if query_embedding.len() != self.dimension {
374 return Err(MemoryError::StorageError(format!(
375 "Query embedding dimension mismatch: expected {}, got {}",
376 self.dimension,
377 query_embedding.len()
378 )));
379 }
380
381 let mut results = Vec::new();
382
383 for item in self.documents_tree.iter() {
384 let (_, value) = item.map_err(|e| MemoryError::StorageError(e.to_string()))?;
385 let doc: VectorDocument = serde_json::from_slice(&value)
386 .map_err(|e| MemoryError::SerializationError(e.to_string()))?;
387
388 let score = Self::cosine_similarity(query_embedding, &doc.embedding);
389 results.push(SearchResult {
390 document: doc,
391 score,
392 });
393 }
394
395 results.sort_by(|a, b| {
397 b.score
398 .partial_cmp(&a.score)
399 .unwrap_or(std::cmp::Ordering::Equal)
400 });
401 results.truncate(top_k);
402
403 Ok(results)
404 }
405
406 async fn get(&self, id: &str) -> Result<Option<VectorDocument>, MemoryError> {
407 match self.documents_tree.get(id.as_bytes()) {
408 Ok(Some(bytes)) => {
409 let doc: VectorDocument = serde_json::from_slice(&bytes)
410 .map_err(|e| MemoryError::SerializationError(e.to_string()))?;
411 Ok(Some(doc))
412 }
413 Ok(None) => Ok(None),
414 Err(e) => Err(MemoryError::StorageError(e.to_string())),
415 }
416 }
417
418 async fn delete(&self, id: &str) -> Result<bool, MemoryError> {
419 if let Some(doc) = self.get(id).await? {
421 if let Some(source_id) = &doc.source_id {
422 let key = format!("source:{}:{}", source_id, id);
423 let _ = self.index_tree.remove(key.as_bytes());
424 }
425 }
426
427 let removed = self
428 .documents_tree
429 .remove(id.as_bytes())
430 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
431
432 Ok(removed.is_some())
433 }
434
435 async fn delete_by_source(&self, source_id: &str) -> Result<usize, MemoryError> {
436 let prefix = format!("source:{}:", source_id);
437 let mut ids_to_remove = Vec::new();
438
439 for item in self.index_tree.scan_prefix(prefix.as_bytes()) {
440 let (_, value) = item.map_err(|e| MemoryError::StorageError(e.to_string()))?;
441 let id = String::from_utf8(value.to_vec())
442 .map_err(|e| MemoryError::SerializationError(e.to_string()))?;
443 ids_to_remove.push(id);
444 }
445
446 let count = ids_to_remove.len();
447 for id in ids_to_remove {
448 self.delete(&id).await?;
449 }
450
451 Ok(count)
452 }
453
454 async fn count(&self) -> Result<usize, MemoryError> {
455 Ok(self.documents_tree.len())
456 }
457
458 async fn clear(&self) -> Result<(), MemoryError> {
459 self.documents_tree
460 .clear()
461 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
462 self.index_tree
463 .clear()
464 .map_err(|e| MemoryError::StorageError(e.to_string()))?;
465 Ok(())
466 }
467
468 fn name(&self) -> &'static str {
469 "sled"
470 }
471}
472
/// How [`TextChunker`] splits a document before embedding.
#[derive(Debug, Clone)]
pub enum ChunkingStrategy {
    /// Sliding character window of `chunk_size`, with `overlap` characters
    /// shared between consecutive chunks.
    FixedSize { chunk_size: usize, overlap: usize },
    /// Groups of up to `max_sentences` sentences per chunk.
    Sentence { max_sentences: usize },
    /// One chunk per blank-line-separated paragraph.
    Paragraph,
    /// No splitting: the whole text becomes a single chunk.
    None,
}
485
486impl Default for ChunkingStrategy {
487 fn default() -> Self {
488 ChunkingStrategy::FixedSize {
489 chunk_size: 512,
490 overlap: 64,
491 }
492 }
493}
494
/// Splits raw text into chunks according to a [`ChunkingStrategy`].
pub struct TextChunker {
    // The configured splitting strategy (immutable after construction).
    strategy: ChunkingStrategy,
}
499
500impl TextChunker {
501 pub fn new(strategy: ChunkingStrategy) -> Self {
502 Self { strategy }
503 }
504
505 pub fn chunk(&self, text: &str) -> Vec<String> {
507 match &self.strategy {
508 ChunkingStrategy::FixedSize {
509 chunk_size,
510 overlap,
511 } => self.chunk_fixed_size(text, *chunk_size, *overlap),
512 ChunkingStrategy::Sentence { max_sentences } => {
513 self.chunk_by_sentences(text, *max_sentences)
514 }
515 ChunkingStrategy::Paragraph => self.chunk_by_paragraphs(text),
516 ChunkingStrategy::None => vec![text.to_string()],
517 }
518 }
519
520 fn chunk_fixed_size(&self, text: &str, chunk_size: usize, overlap: usize) -> Vec<String> {
521 let chars: Vec<char> = text.chars().collect();
522 let mut chunks = Vec::new();
523 let mut start = 0;
524
525 while start < chars.len() {
526 let end = (start + chunk_size).min(chars.len());
527 let chunk: String = chars[start..end].iter().collect();
528
529 if !chunk.trim().is_empty() {
530 chunks.push(chunk.trim().to_string());
531 }
532
533 if end >= chars.len() {
534 break;
535 }
536
537 start = if overlap < chunk_size {
538 start + chunk_size - overlap
539 } else {
540 start + chunk_size
541 };
542 }
543
544 chunks
545 }
546
547 fn chunk_by_sentences(&self, text: &str, max_sentences: usize) -> Vec<String> {
548 let sentences: Vec<&str> = text
549 .split(['.', '!', '?'])
550 .filter(|s| !s.trim().is_empty())
551 .collect();
552
553 sentences
554 .chunks(max_sentences)
555 .map(|chunk| {
556 chunk
557 .iter()
558 .map(|s| s.trim())
559 .collect::<Vec<_>>()
560 .join(". ")
561 + "."
562 })
563 .collect()
564 }
565
566 fn chunk_by_paragraphs(&self, text: &str) -> Vec<String> {
567 text.split("\n\n")
568 .filter(|p| !p.trim().is_empty())
569 .map(|p| p.trim().to_string())
570 .collect()
571 }
572}
573
/// Retrieval-augmented generation pipeline: chunks documents, embeds the
/// chunks, stores them in a [`VectorStore`], and retrieves relevant context
/// for queries.
pub struct RAGPipeline<S: VectorStore, E: Embedder> {
    store: Arc<S>,
    embedder: Arc<E>,
    chunker: TextChunker,
    // Number of results returned when callers pass `top_k = None`.
    default_top_k: usize,
}
581
582impl<S: VectorStore, E: Embedder> RAGPipeline<S, E> {
583 pub fn new(store: S, embedder: E) -> Self {
585 Self {
586 store: Arc::new(store),
587 embedder: Arc::new(embedder),
588 chunker: TextChunker::new(ChunkingStrategy::default()),
589 default_top_k: 5,
590 }
591 }
592
593 pub fn with_chunking(mut self, strategy: ChunkingStrategy) -> Self {
595 self.chunker = TextChunker::new(strategy);
596 self
597 }
598
599 pub fn with_top_k(mut self, top_k: usize) -> Self {
601 self.default_top_k = top_k;
602 self
603 }
604
605 pub async fn index_document(
607 &self,
608 doc_id: &str,
609 content: &str,
610 metadata: Option<HashMap<String, serde_json::Value>>,
611 ) -> Result<usize, MemoryError> {
612 self.store.delete_by_source(doc_id).await?;
614
615 let chunks = self.chunker.chunk(content);
617 let chunk_count = chunks.len();
618
619 let chunk_refs: Vec<&str> = chunks.iter().map(|s| s.as_str()).collect();
621 let embeddings = self.embedder.embed_batch(&chunk_refs).await?;
622
623 for (i, (chunk, embedding)) in chunks.into_iter().zip(embeddings).enumerate() {
625 let chunk_id = format!("{}:chunk:{}", doc_id, i);
626 let mut doc = VectorDocument::new(chunk_id, chunk, embedding).with_source(doc_id, i);
627
628 if let Some(ref meta) = metadata {
629 for (k, v) in meta {
630 doc = doc.with_metadata(k.clone(), v.clone());
631 }
632 }
633
634 self.store.insert(doc).await?;
635 }
636
637 tracing::info!(doc_id = doc_id, chunks = chunk_count, "Indexed document");
638
639 Ok(chunk_count)
640 }
641
642 pub async fn retrieve(
644 &self,
645 query: &str,
646 top_k: Option<usize>,
647 ) -> Result<Vec<SearchResult>, MemoryError> {
648 let k = top_k.unwrap_or(self.default_top_k);
649
650 let query_embedding = self.embedder.embed(query).await?;
652
653 let results = self.store.search(&query_embedding, k).await?;
655
656 tracing::debug!(query = query, results = results.len(), "Retrieved context");
657
658 Ok(results)
659 }
660
661 pub async fn retrieve_context(
663 &self,
664 query: &str,
665 top_k: Option<usize>,
666 ) -> Result<String, MemoryError> {
667 let results = self.retrieve(query, top_k).await?;
668
669 if results.is_empty() {
670 return Ok(String::new());
671 }
672
673 let context = results
674 .iter()
675 .enumerate()
676 .map(|(i, r)| {
677 format!(
678 "[Source {}] (relevance: {:.2})\n{}",
679 i + 1,
680 r.score,
681 r.document.content
682 )
683 })
684 .collect::<Vec<_>>()
685 .join("\n\n");
686
687 Ok(context)
688 }
689
690 pub async fn augment_prompt(
692 &self,
693 query: &str,
694 top_k: Option<usize>,
695 ) -> Result<String, MemoryError> {
696 let context = self.retrieve_context(query, top_k).await?;
697
698 if context.is_empty() {
699 return Ok(query.to_string());
700 }
701
702 Ok(format!(
703 "Use the following context to answer the question. If the context doesn't contain \
704 relevant information, say so and answer based on your knowledge.\n\n\
705 Context:\n{}\n\n\
706 Question: {}",
707 context, query
708 ))
709 }
710
711 pub async fn delete_document(&self, doc_id: &str) -> Result<usize, MemoryError> {
713 self.store.delete_by_source(doc_id).await
714 }
715
716 pub async fn document_count(&self) -> Result<usize, MemoryError> {
718 self.store.count().await
719 }
720
721 pub async fn clear(&self) -> Result<(), MemoryError> {
723 self.store.clear().await
724 }
725}
726
/// Long-term semantic memory for a single agent, built on a [`RAGPipeline`].
pub struct SemanticMemory<S: VectorStore, E: Embedder> {
    rag: RAGPipeline<S, E>,
    // Namespaces memory ids and is recorded in each memory's metadata.
    agent_id: String,
}
732
733impl<S: VectorStore, E: Embedder> SemanticMemory<S, E> {
734 pub fn new(store: S, embedder: E, agent_id: impl Into<String>) -> Self {
735 Self {
736 rag: RAGPipeline::new(store, embedder),
737 agent_id: agent_id.into(),
738 }
739 }
740
741 pub async fn remember(
743 &self,
744 content: &str,
745 tags: Option<Vec<String>>,
746 ) -> Result<(), MemoryError> {
747 let memory_id = format!("{}:memory:{}", self.agent_id, uuid::Uuid::new_v4());
748
749 let mut metadata = HashMap::new();
750 metadata.insert("agent_id".to_string(), serde_json::json!(self.agent_id));
751 metadata.insert(
752 "timestamp".to_string(),
753 serde_json::json!(chrono::Utc::now().to_rfc3339()),
754 );
755
756 if let Some(tags) = tags {
757 metadata.insert("tags".to_string(), serde_json::json!(tags));
758 }
759
760 self.rag
761 .index_document(&memory_id, content, Some(metadata))
762 .await?;
763 Ok(())
764 }
765
766 pub async fn recall(
768 &self,
769 query: &str,
770 top_k: usize,
771 ) -> Result<Vec<SearchResult>, MemoryError> {
772 self.rag.retrieve(query, Some(top_k)).await
773 }
774
775 pub async fn get_context(&self, query: &str, top_k: usize) -> Result<String, MemoryError> {
777 self.rag.retrieve_context(query, Some(top_k)).await
778 }
779
780 pub async fn forget(&self, query: &str, threshold: f32) -> Result<usize, MemoryError> {
782 let results = self.rag.retrieve(query, Some(100)).await?;
783
784 let mut deleted = 0;
785 for result in results {
786 if result.score >= threshold && self.rag.store.delete(&result.document.id).await? {
787 deleted += 1;
788 }
789 }
790
791 Ok(deleted)
792 }
793}
794
/// [`Embedder`] implementation that delegates to an LLM provider backend.
pub struct LLMEmbedder {
    backend: Arc<dyn cortexai_providers::LLMBackend>,
    // Declared embedding length. NOTE(review): not validated against the
    // vectors the backend actually returns — confirm upstream.
    dimension: usize,
}
800
801impl LLMEmbedder {
802 pub fn new(backend: Arc<dyn cortexai_providers::LLMBackend>, dimension: usize) -> Self {
809 Self { backend, dimension }
810 }
811
812 pub fn openai_small(backend: Arc<dyn cortexai_providers::LLMBackend>) -> Self {
814 Self::new(backend, 1536)
815 }
816
817 pub fn openai_large(backend: Arc<dyn cortexai_providers::LLMBackend>) -> Self {
819 Self::new(backend, 3072)
820 }
821}
822
823#[async_trait]
824impl Embedder for LLMEmbedder {
825 async fn embed(&self, text: &str) -> Result<Vec<f32>, MemoryError> {
826 self.backend
827 .embed(text)
828 .await
829 .map_err(|e| MemoryError::StorageError(format!("Embedding failed: {}", e)))
830 }
831
832 fn dimension(&self) -> usize {
833 self.dimension
834 }
835}
836
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic, offline embedder: derives a normalized pseudo-embedding
    /// from a byte-sum hash of the input text.
    struct MockEmbedder {
        dimension: usize,
    }

    impl MockEmbedder {
        fn new(dimension: usize) -> Self {
            Self { dimension }
        }
    }

    #[async_trait]
    impl Embedder for MockEmbedder {
        async fn embed(&self, text: &str) -> Result<Vec<f32>, MemoryError> {
            let seed = text.bytes().fold(0u64, |acc, b| acc.wrapping_add(b as u64));
            let mut vector = vec![0.0f32; self.dimension];

            for (i, slot) in vector.iter_mut().enumerate() {
                *slot = ((seed.wrapping_add(i as u64) % 1000) as f32 / 1000.0) - 0.5;
            }

            // L2-normalize so cosine similarities are well-behaved.
            let magnitude: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
            if magnitude > 0.0 {
                for slot in &mut vector {
                    *slot /= magnitude;
                }
            }

            Ok(vector)
        }

        fn dimension(&self) -> usize {
            self.dimension
        }
    }

    #[tokio::test]
    async fn test_in_memory_vector_store() {
        let store = InMemoryVectorStore::new(4);

        store
            .insert(VectorDocument::new(
                "doc1",
                "Hello world",
                vec![1.0, 0.0, 0.0, 0.0],
            ))
            .await
            .unwrap();
        store
            .insert(VectorDocument::new(
                "doc2",
                "Goodbye world",
                vec![0.0, 1.0, 0.0, 0.0],
            ))
            .await
            .unwrap();
        store
            .insert(VectorDocument::new(
                "doc3",
                "Hello there",
                vec![0.9, 0.1, 0.0, 0.0],
            ))
            .await
            .unwrap();

        assert_eq!(store.count().await.unwrap(), 3);

        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 2).await.unwrap();
        assert_eq!(results.len(), 2);
        // doc1 is an exact match; doc3 is the next closest.
        assert_eq!(results[0].document.id, "doc1");
        assert_eq!(results[1].document.id, "doc3");
    }

    #[tokio::test]
    async fn test_sled_vector_store() {
        let store = SledVectorStore::temporary(4).unwrap();

        let doc = VectorDocument::new("test", "Test content", vec![1.0, 0.0, 0.0, 0.0])
            .with_metadata("key", serde_json::json!("value"));

        store.insert(doc.clone()).await.unwrap();

        let retrieved = store.get("test").await.unwrap().unwrap();
        assert_eq!(retrieved.content, "Test content");
        assert_eq!(
            retrieved.metadata.get("key"),
            Some(&serde_json::json!("value"))
        );

        store.delete("test").await.unwrap();
        assert!(store.get("test").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_text_chunker_fixed_size() {
        let chunker = TextChunker::new(ChunkingStrategy::FixedSize {
            chunk_size: 10,
            overlap: 2,
        });

        let chunks = chunker.chunk("Hello world, this is a test of the chunking system.");

        assert!(chunks.len() > 1);
        assert!(chunks.iter().all(|c| c.len() <= 12));
    }

    #[tokio::test]
    async fn test_text_chunker_sentences() {
        let chunker = TextChunker::new(ChunkingStrategy::Sentence { max_sentences: 2 });

        let chunks =
            chunker.chunk("First sentence. Second sentence. Third sentence. Fourth sentence.");

        assert_eq!(chunks.len(), 2);
    }

    #[tokio::test]
    async fn test_rag_pipeline() {
        let rag = RAGPipeline::new(InMemoryVectorStore::new(64), MockEmbedder::new(64))
            .with_chunking(ChunkingStrategy::None)
            .with_top_k(3);

        rag.index_document("doc1", "Rust is a systems programming language.", None)
            .await
            .unwrap();
        rag.index_document("doc2", "Python is great for data science.", None)
            .await
            .unwrap();
        rag.index_document("doc3", "Rust has excellent memory safety.", None)
            .await
            .unwrap();

        assert_eq!(rag.document_count().await.unwrap(), 3);

        let results = rag.retrieve("Tell me about Rust", Some(2)).await.unwrap();
        assert_eq!(results.len(), 2);

        let context = rag
            .retrieve_context("Rust programming", Some(2))
            .await
            .unwrap();
        assert!(!context.is_empty());
    }

    #[tokio::test]
    async fn test_semantic_memory() {
        let memory = SemanticMemory::new(
            InMemoryVectorStore::new(64),
            MockEmbedder::new(64),
            "test-agent",
        );

        memory
            .remember(
                "The capital of France is Paris.",
                Some(vec!["geography".to_string()]),
            )
            .await
            .unwrap();
        memory
            .remember(
                "Rust was created by Mozilla.",
                Some(vec!["programming".to_string()]),
            )
            .await
            .unwrap();

        let recalled = memory
            .recall("What is the capital of France?", 5)
            .await
            .unwrap();
        assert!(!recalled.is_empty());
    }

    #[tokio::test]
    async fn test_delete_by_source() {
        let store = InMemoryVectorStore::new(4);

        let chunk_a = VectorDocument::new("chunk1", "Part 1", vec![1.0, 0.0, 0.0, 0.0])
            .with_source("doc1", 0);
        let chunk_b = VectorDocument::new("chunk2", "Part 2", vec![0.0, 1.0, 0.0, 0.0])
            .with_source("doc1", 1);
        let unrelated =
            VectorDocument::new("other", "Other", vec![0.0, 0.0, 1.0, 0.0]).with_source("doc2", 0);

        store.insert(chunk_a).await.unwrap();
        store.insert(chunk_b).await.unwrap();
        store.insert(unrelated).await.unwrap();

        assert_eq!(store.count().await.unwrap(), 3);

        let deleted = store.delete_by_source("doc1").await.unwrap();
        assert_eq!(deleted, 2);
        assert_eq!(store.count().await.unwrap(), 1);
    }
}