1#![allow(dead_code)]
15use chrono::{DateTime, Utc};
18use serde::{Deserialize, Serialize};
19use std::collections::HashMap;
20use std::path::Path;
21
/// Identifier of a stored memory entry (a plain `String`).
pub type MemoryId = String;

/// Embedding vector: a list of `f32` components.
pub type Embedding = Vec<f32>;
31
/// A single stored memory together with its bookkeeping data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEntry {
    /// Identifier of the entry; `MemoryEntry::new` fills it from `generate_memory_id`.
    pub id: MemoryId,
    /// Text of the memory.
    pub content: String,
    /// Optional embedding of `content`; left out of serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub embedding: Option<Embedding>,
    /// Category of the memory.
    pub memory_type: MemoryType,
    /// Where the memory came from.
    pub source: MemorySource,
    /// Importance score; `new` starts at 0.5 and `with_importance` clamps to `0.0..=1.0`.
    pub importance: f32,
    /// Access counter; starts at 0 and is incremented by `VectorStore::search` for returned hits.
    pub access_count: u64,
    /// Time of the last access recorded for this entry.
    pub last_accessed: DateTime<Utc>,
    /// Creation time.
    pub created_at: DateTime<Utc>,
    /// Expiry time; `None` means `is_expired` always returns `false`.
    pub expires_at: Option<DateTime<Utc>>,
    /// Optional owning agent.
    pub agent_id: Option<String>,
    /// Optional owning session.
    pub session_id: Option<String>,
    /// Free-form JSON metadata.
    pub metadata: HashMap<String, serde_json::Value>,
    /// Tags used by `VectorStore::search_by_tags`.
    pub tags: Vec<String>,
}
65
66impl MemoryEntry {
67 pub fn new(content: impl Into<String>, memory_type: MemoryType, source: MemorySource) -> Self {
69 let now = Utc::now();
70 Self {
71 id: generate_memory_id(),
72 content: content.into(),
73 embedding: None,
74 memory_type,
75 source,
76 importance: 0.5,
77 access_count: 0,
78 last_accessed: now,
79 created_at: now,
80 expires_at: None,
81 agent_id: None,
82 session_id: None,
83 metadata: HashMap::new(),
84 tags: Vec::new(),
85 }
86 }
87
88 pub fn with_embedding(mut self, embedding: Embedding) -> Self {
90 self.embedding = Some(embedding);
91 self
92 }
93
94 pub fn with_importance(mut self, importance: f32) -> Self {
96 self.importance = importance.clamp(0.0, 1.0);
97 self
98 }
99
100 pub fn with_agent(mut self, agent_id: impl Into<String>) -> Self {
102 self.agent_id = Some(agent_id.into());
103 self
104 }
105
106 pub fn with_session(mut self, session_id: impl Into<String>) -> Self {
108 self.session_id = Some(session_id.into());
109 self
110 }
111
112 pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
114 self.tags.push(tag.into());
115 self
116 }
117
118 pub fn expires_in(mut self, duration: chrono::Duration) -> Self {
120 self.expires_at = Some(Utc::now() + duration);
121 self
122 }
123
124 pub fn is_expired(&self) -> bool {
126 self.expires_at.map(|exp| Utc::now() > exp).unwrap_or(false)
127 }
128}
129
/// Category of a memory entry; serialized with snake_case variant names.
///
/// `KnowledgeBase::add_document` stores document chunks as `Semantic`, and
/// `VectorStore::load_from_db` falls back to `LongTerm` when a stored value
/// cannot be parsed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MemoryType {
    ShortTerm,
    LongTerm,
    Episodic,
    Semantic,
    Procedural,
    Preference,
    Cache,
}
149
/// Origin of a memory entry; serialized with snake_case variant names.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MemorySource {
    /// A message identified by session and message id.
    Conversation {
        session_id: String,
        message_id: String,
    },
    /// A chunk of a document; used by `KnowledgeBase::add_document`.
    Document { path: String, chunk_index: u32 },
    /// Also the fallback used by `VectorStore::load_from_db` for unparsable rows.
    UserInput,
    AgentReasoning { agent_id: String },
    ToolResult { tool_name: String },
    WebPage { url: String },
    /// A summary referring to other entries by id.
    Summary { source_ids: Vec<String> },
    Custom { source_type: String },
}
174
175#[derive(Debug, Clone, Serialize, Deserialize)]
181pub struct VectorStoreConfig {
182 pub embedding_model: EmbeddingModel,
184 pub embedding_dim: usize,
186 pub similarity_metric: SimilarityMetric,
188 pub max_entries: usize,
190 pub db_path: Option<String>,
192}
193
194impl Default for VectorStoreConfig {
195 fn default() -> Self {
196 Self {
197 embedding_model: EmbeddingModel::default(),
198 embedding_dim: 384,
199 similarity_metric: SimilarityMetric::Cosine,
200 max_entries: 100_000,
201 db_path: None,
202 }
203 }
204}
205
206#[derive(Debug, Clone, Default, Serialize, Deserialize)]
208#[serde(rename_all = "snake_case")]
209pub enum EmbeddingModel {
210 OpenAISmall,
212 OpenAILarge,
214 OpenAIAda,
216 #[default]
218 MiniLM,
219 MPNet,
221 Cohere,
223 GoogleGecko,
225 Voyage,
227 Ollama { model: String },
229 Custom { name: String, dim: usize },
231}
232
233impl EmbeddingModel {
234 pub fn dimension(&self) -> usize {
236 match self {
237 EmbeddingModel::OpenAISmall => 1536,
238 EmbeddingModel::OpenAILarge => 3072,
239 EmbeddingModel::OpenAIAda => 1536,
240 EmbeddingModel::MiniLM => 384,
241 EmbeddingModel::MPNet => 768,
242 EmbeddingModel::Cohere => 1024,
243 EmbeddingModel::GoogleGecko => 768,
244 EmbeddingModel::Voyage => 1024,
245 EmbeddingModel::Ollama { .. } => 4096, EmbeddingModel::Custom { dim, .. } => *dim,
247 }
248 }
249}
250
251#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
253#[serde(rename_all = "snake_case")]
254pub enum SimilarityMetric {
255 #[default]
256 Cosine,
257 Euclidean,
258 DotProduct,
259 Manhattan,
260}
261
262impl SimilarityMetric {
263 pub fn calculate(&self, a: &[f32], b: &[f32]) -> f32 {
265 assert_eq!(a.len(), b.len(), "Vector dimensions must match");
266
267 match self {
268 SimilarityMetric::Cosine => {
269 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
270 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
271 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
272 if norm_a == 0.0 || norm_b == 0.0 {
273 0.0
274 } else {
275 dot / (norm_a * norm_b)
276 }
277 }
278 SimilarityMetric::Euclidean => {
279 let dist: f32 = a
280 .iter()
281 .zip(b.iter())
282 .map(|(x, y)| (x - y).powi(2))
283 .sum::<f32>()
284 .sqrt();
285 1.0 / (1.0 + dist) }
287 SimilarityMetric::DotProduct => a.iter().zip(b.iter()).map(|(x, y)| x * y).sum(),
288 SimilarityMetric::Manhattan => {
289 let dist: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum();
290 1.0 / (1.0 + dist)
291 }
292 }
293 }
294}
295
/// One hit returned by `VectorStore::search`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    /// Copy of the matching entry, taken after its access statistics were updated.
    pub entry: MemoryEntry,
    /// Score produced by the store's similarity metric.
    pub score: f32,
    /// Zero-based position in the result list (0 is the best match).
    pub rank: usize,
}
306
307pub struct VectorStore {
309 config: VectorStoreConfig,
310 entries: Vec<MemoryEntry>,
311 db: Option<rusqlite::Connection>,
312}
313
314impl VectorStore {
315 pub fn new(config: VectorStoreConfig) -> Self {
317 Self {
318 config,
319 entries: Vec::new(),
320 db: None,
321 }
322 }
323
324 pub fn with_persistence(
326 config: VectorStoreConfig,
327 db_path: impl AsRef<Path>,
328 ) -> Result<Self, MemoryError> {
329 let db = rusqlite::Connection::open(db_path.as_ref())
330 .map_err(|e| MemoryError::Database(e.to_string()))?;
331
332 db.execute_batch(
334 r#"
335 CREATE TABLE IF NOT EXISTS memory_entries (
336 id TEXT PRIMARY KEY,
337 content TEXT NOT NULL,
338 embedding BLOB,
339 memory_type TEXT NOT NULL,
340 source TEXT NOT NULL,
341 importance REAL NOT NULL,
342 access_count INTEGER NOT NULL DEFAULT 0,
343 last_accessed TEXT NOT NULL,
344 created_at TEXT NOT NULL,
345 expires_at TEXT,
346 agent_id TEXT,
347 session_id TEXT,
348 metadata TEXT,
349 tags TEXT
350 );
351
352 CREATE INDEX IF NOT EXISTS idx_memory_type ON memory_entries(memory_type);
353 CREATE INDEX IF NOT EXISTS idx_agent_id ON memory_entries(agent_id);
354 CREATE INDEX IF NOT EXISTS idx_session_id ON memory_entries(session_id);
355 CREATE INDEX IF NOT EXISTS idx_created_at ON memory_entries(created_at);
356 CREATE INDEX IF NOT EXISTS idx_importance ON memory_entries(importance DESC);
357 "#,
358 )
359 .map_err(|e| MemoryError::Database(e.to_string()))?;
360
361 let mut store = Self {
362 config,
363 entries: Vec::new(),
364 db: Some(db),
365 };
366
367 store.load_from_db()?;
368 Ok(store)
369 }
370
371 fn load_from_db(&mut self) -> Result<(), MemoryError> {
373 if let Some(ref db) = self.db {
374 let mut stmt = db
375 .prepare(
376 "SELECT id, content, embedding, memory_type, source, importance,
377 access_count, last_accessed, created_at, expires_at,
378 agent_id, session_id, metadata, tags
379 FROM memory_entries
380 ORDER BY importance DESC, created_at DESC",
381 )
382 .map_err(|e| MemoryError::Database(e.to_string()))?;
383
384 let entries = stmt
385 .query_map([], |row| {
386 let embedding_blob: Option<Vec<u8>> = row.get(2)?;
387 let embedding = embedding_blob.map(|blob| {
388 blob.chunks(4)
389 .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap_or([0; 4])))
390 .collect()
391 });
392
393 Ok(MemoryEntry {
394 id: row.get(0)?,
395 content: row.get(1)?,
396 embedding,
397 memory_type: serde_json::from_str(&row.get::<_, String>(3)?)
398 .unwrap_or(MemoryType::LongTerm),
399 source: serde_json::from_str(&row.get::<_, String>(4)?)
400 .unwrap_or(MemorySource::UserInput),
401 importance: row.get(5)?,
402 access_count: row.get(6)?,
403 last_accessed: row
404 .get::<_, String>(7)?
405 .parse()
406 .unwrap_or_else(|_| Utc::now()),
407 created_at: row
408 .get::<_, String>(8)?
409 .parse()
410 .unwrap_or_else(|_| Utc::now()),
411 expires_at: row
412 .get::<_, Option<String>>(9)?
413 .and_then(|s| s.parse().ok()),
414 agent_id: row.get(10)?,
415 session_id: row.get(11)?,
416 metadata: row
417 .get::<_, Option<String>>(12)?
418 .and_then(|s| serde_json::from_str(&s).ok())
419 .unwrap_or_default(),
420 tags: row
421 .get::<_, Option<String>>(13)?
422 .and_then(|s| serde_json::from_str(&s).ok())
423 .unwrap_or_default(),
424 })
425 })
426 .map_err(|e| MemoryError::Database(e.to_string()))?;
427
428 self.entries = entries.filter_map(|e| e.ok()).collect();
429 }
430 Ok(())
431 }
432
433 pub fn add(&mut self, entry: MemoryEntry) -> Result<MemoryId, MemoryError> {
435 let id = entry.id.clone();
436
437 if let Some(ref db) = self.db {
439 let embedding_blob: Option<Vec<u8>> = entry
440 .embedding
441 .as_ref()
442 .map(|emb| emb.iter().flat_map(|f| f.to_le_bytes()).collect());
443
444 db.execute(
445 "INSERT OR REPLACE INTO memory_entries
446 (id, content, embedding, memory_type, source, importance,
447 access_count, last_accessed, created_at, expires_at,
448 agent_id, session_id, metadata, tags)
449 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14)",
450 rusqlite::params![
451 entry.id,
452 entry.content,
453 embedding_blob,
454 serde_json::to_string(&entry.memory_type).unwrap_or_default(),
455 serde_json::to_string(&entry.source).unwrap_or_default(),
456 entry.importance,
457 entry.access_count,
458 entry.last_accessed.to_rfc3339(),
459 entry.created_at.to_rfc3339(),
460 entry.expires_at.map(|e| e.to_rfc3339()),
461 entry.agent_id,
462 entry.session_id,
463 serde_json::to_string(&entry.metadata).ok(),
464 serde_json::to_string(&entry.tags).ok(),
465 ],
466 )
467 .map_err(|e| MemoryError::Database(e.to_string()))?;
468 }
469
470 self.entries.push(entry);
471
472 if self.entries.len() > self.config.max_entries {
474 self.prune()?;
475 }
476
477 Ok(id)
478 }
479
480 pub fn search(&mut self, query_embedding: &Embedding, limit: usize) -> Vec<SearchResult> {
482 let mut results: Vec<(usize, f32)> = self
483 .entries
484 .iter()
485 .enumerate()
486 .filter(|(_, e)| !e.is_expired() && e.embedding.is_some())
487 .map(|(i, e)| {
488 let score = self
489 .config
490 .similarity_metric
491 .calculate(query_embedding, e.embedding.as_ref().unwrap());
492 (i, score)
493 })
494 .collect();
495
496 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
498
499 results
501 .into_iter()
502 .take(limit)
503 .enumerate()
504 .map(|(rank, (idx, score))| {
505 self.entries[idx].access_count += 1;
506 self.entries[idx].last_accessed = Utc::now();
507
508 SearchResult {
509 entry: self.entries[idx].clone(),
510 score,
511 rank,
512 }
513 })
514 .collect()
515 }
516
517 pub fn search_by_type(&self, memory_type: MemoryType, limit: usize) -> Vec<&MemoryEntry> {
519 self.entries
520 .iter()
521 .filter(|e| e.memory_type == memory_type && !e.is_expired())
522 .take(limit)
523 .collect()
524 }
525
526 pub fn search_by_tags(&self, tags: &[String], limit: usize) -> Vec<&MemoryEntry> {
528 self.entries
529 .iter()
530 .filter(|e| !e.is_expired() && tags.iter().any(|t| e.tags.contains(t)))
531 .take(limit)
532 .collect()
533 }
534
535 pub fn get(&self, id: &str) -> Option<&MemoryEntry> {
537 self.entries.iter().find(|e| e.id == id)
538 }
539
540 pub fn delete(&mut self, id: &str) -> Result<bool, MemoryError> {
542 if let Some(pos) = self.entries.iter().position(|e| e.id == id) {
543 self.entries.remove(pos);
544
545 if let Some(ref db) = self.db {
546 db.execute("DELETE FROM memory_entries WHERE id = ?1", [id])
547 .map_err(|e| MemoryError::Database(e.to_string()))?;
548 }
549
550 Ok(true)
551 } else {
552 Ok(false)
553 }
554 }
555
556 fn prune(&mut self) -> Result<(), MemoryError> {
558 self.entries.retain(|e| !e.is_expired());
560
561 if self.entries.len() > self.config.max_entries {
563 self.entries.sort_by(|a, b| {
564 b.importance
565 .partial_cmp(&a.importance)
566 .unwrap_or(std::cmp::Ordering::Equal)
567 });
568 self.entries.truncate(self.config.max_entries);
569 }
570
571 Ok(())
572 }
573
574 pub fn stats(&self) -> VectorStoreStats {
576 VectorStoreStats {
577 total_entries: self.entries.len(),
578 entries_by_type: self.entries.iter().fold(HashMap::new(), |mut acc, e| {
579 *acc.entry(format!("{:?}", e.memory_type)).or_insert(0) += 1;
580 acc
581 }),
582 total_access_count: self.entries.iter().map(|e| e.access_count).sum(),
583 avg_importance: if self.entries.is_empty() {
584 0.0
585 } else {
586 self.entries.iter().map(|e| e.importance).sum::<f32>() / self.entries.len() as f32
587 },
588 }
589 }
590}
591
/// Snapshot produced by `VectorStore::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorStoreStats {
    /// Number of entries held in memory.
    pub total_entries: usize,
    /// Entry count per memory type, keyed by the type's `Debug` name.
    pub entries_by_type: HashMap<String, usize>,
    /// Sum of `access_count` over all entries.
    pub total_access_count: u64,
    /// Mean importance, or `0.0` when the store is empty.
    pub avg_importance: f32,
}
600
/// A document managed by a `KnowledgeBase`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Document {
    /// Key under which the knowledge base stores the document.
    pub id: String,
    /// Title of the document.
    pub title: String,
    /// Full text; this is what the chunkers split.
    pub content: String,
    /// Format of the document.
    pub doc_type: DocumentType,
    /// Origin of the document; copied into `MemorySource::Document::path` for indexed chunks.
    pub source: String,
    /// Chunks of `content`.
    pub chunks: Vec<DocumentChunk>,
    /// Creation time.
    pub created_at: DateTime<Utc>,
    /// Last update time.
    pub updated_at: DateTime<Utc>,
    /// Free-form JSON metadata.
    pub metadata: HashMap<String, serde_json::Value>,
}
627
/// Format of a document; serialized with snake_case variant names.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DocumentType {
    Text,
    Markdown,
    /// Source code in the named language.
    Code { language: String },
    Html,
    Pdf,
    Json,
    Yaml,
    Csv,
    /// Any other format, described by a MIME type string.
    Custom { mime_type: String },
}
642
/// One piece of a chunked document.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentChunk {
    /// Position of the chunk within its document, starting at 0.
    pub index: u32,
    /// Text of the chunk.
    pub content: String,
    /// Start offset of the chunk as computed by the chunker that produced it.
    pub start_pos: usize,
    /// End offset of the chunk as computed by the chunker that produced it.
    pub end_pos: usize,
    /// Optional embedding; the chunkers in this file always leave it `None`.
    pub embedding: Option<Embedding>,
    /// Token estimate for the chunk, from `estimate_tokens`.
    pub token_count: u32,
}
659
660#[derive(Debug, Clone, Serialize, Deserialize)]
662pub struct ChunkingConfig {
663 pub chunk_size: usize,
665 pub chunk_overlap: usize,
667 pub strategy: ChunkingStrategy,
669}
670
671impl Default for ChunkingConfig {
672 fn default() -> Self {
673 Self {
674 chunk_size: 512,
675 chunk_overlap: 50,
676 strategy: ChunkingStrategy::Semantic,
677 }
678 }
679}
680
/// Chunker selection; `Semantic` is the default. Serialized in snake_case.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ChunkingStrategy {
    /// Dispatches to `fixed_chunk`.
    FixedSize,
    /// Dispatches to `sentence_chunk`.
    Sentence,
    /// Dispatches to `paragraph_chunk`.
    Paragraph,
    /// Dispatches to `semantic_chunk`.
    #[default]
    Semantic,
    /// Dispatches to `code_chunk`.
    Code,
}
697
698pub struct KnowledgeBase {
700 vector_store: VectorStore,
702 documents: HashMap<String, Document>,
704 chunking_config: ChunkingConfig,
706}
707
708impl KnowledgeBase {
709 pub fn new(vector_config: VectorStoreConfig) -> Self {
711 Self {
712 vector_store: VectorStore::new(vector_config),
713 documents: HashMap::new(),
714 chunking_config: ChunkingConfig::default(),
715 }
716 }
717
718 pub fn with_persistence(
720 vector_config: VectorStoreConfig,
721 db_path: impl AsRef<Path>,
722 ) -> Result<Self, MemoryError> {
723 Ok(Self {
724 vector_store: VectorStore::with_persistence(vector_config, db_path)?,
725 documents: HashMap::new(),
726 chunking_config: ChunkingConfig::default(),
727 })
728 }
729
730 pub fn add_document(&mut self, mut document: Document) -> Result<String, MemoryError> {
732 document.chunks = self.chunk_document(&document.content);
734
735 let doc_id = document.id.clone();
736
737 for chunk in &document.chunks {
739 if let Some(ref embedding) = chunk.embedding {
740 let entry = MemoryEntry::new(
741 &chunk.content,
742 MemoryType::Semantic,
743 MemorySource::Document {
744 path: document.source.clone(),
745 chunk_index: chunk.index,
746 },
747 )
748 .with_embedding(embedding.clone())
749 .with_tag(format!("doc:{}", doc_id));
750
751 self.vector_store.add(entry)?;
752 }
753 }
754
755 self.documents.insert(doc_id.clone(), document);
756 Ok(doc_id)
757 }
758
759 fn chunk_document(&self, content: &str) -> Vec<DocumentChunk> {
761 match self.chunking_config.strategy {
762 ChunkingStrategy::Semantic => self.semantic_chunk(content),
763 ChunkingStrategy::Paragraph => self.paragraph_chunk(content),
764 ChunkingStrategy::Sentence => self.sentence_chunk(content),
765 ChunkingStrategy::FixedSize => self.fixed_chunk(content),
766 ChunkingStrategy::Code => self.code_chunk(content),
767 }
768 }
769
770 fn semantic_chunk(&self, content: &str) -> Vec<DocumentChunk> {
771 let mut chunks = Vec::new();
773 let mut current_chunk = String::new();
774 let mut start_pos = 0;
775 let mut chunk_index = 0;
776
777 for para in content.split("\n\n") {
778 let para = para.trim();
779 if para.is_empty() {
780 continue;
781 }
782
783 let para_tokens = estimate_tokens(para);
784 let current_tokens = estimate_tokens(¤t_chunk);
785
786 if current_tokens + para_tokens > self.chunking_config.chunk_size
787 && !current_chunk.is_empty()
788 {
789 let end_pos = start_pos + current_chunk.len();
791 chunks.push(DocumentChunk {
792 index: chunk_index,
793 content: current_chunk.trim().to_string(),
794 start_pos,
795 end_pos,
796 embedding: None,
797 token_count: estimate_tokens(¤t_chunk) as u32,
798 });
799 chunk_index += 1;
800 start_pos = end_pos;
801 current_chunk = String::new();
802 }
803
804 if !current_chunk.is_empty() {
805 current_chunk.push_str("\n\n");
806 }
807 current_chunk.push_str(para);
808 }
809
810 if !current_chunk.is_empty() {
812 let end_pos = start_pos + current_chunk.len();
813 chunks.push(DocumentChunk {
814 index: chunk_index,
815 content: current_chunk.trim().to_string(),
816 start_pos,
817 end_pos,
818 embedding: None,
819 token_count: estimate_tokens(¤t_chunk) as u32,
820 });
821 }
822
823 chunks
824 }
825
826 fn paragraph_chunk(&self, content: &str) -> Vec<DocumentChunk> {
827 content
828 .split("\n\n")
829 .filter(|p| !p.trim().is_empty())
830 .enumerate()
831 .scan(0usize, |pos, (i, para)| {
832 let start = *pos;
833 *pos += para.len() + 2;
834 Some(DocumentChunk {
835 index: i as u32,
836 content: para.trim().to_string(),
837 start_pos: start,
838 end_pos: *pos,
839 embedding: None,
840 token_count: estimate_tokens(para) as u32,
841 })
842 })
843 .collect()
844 }
845
846 fn sentence_chunk(&self, content: &str) -> Vec<DocumentChunk> {
847 let sentences: Vec<&str> = content
849 .split(|c| c == '.' || c == '!' || c == '?')
850 .filter(|s| !s.trim().is_empty())
851 .collect();
852
853 let mut chunks = Vec::new();
854 let mut current = String::new();
855 let mut start = 0;
856 let mut idx = 0;
857
858 for sentence in sentences {
859 let sentence = sentence.trim();
860 if estimate_tokens(¤t) + estimate_tokens(sentence)
861 > self.chunking_config.chunk_size
862 {
863 if !current.is_empty() {
864 chunks.push(DocumentChunk {
865 index: idx,
866 content: current.clone(),
867 start_pos: start,
868 end_pos: start + current.len(),
869 embedding: None,
870 token_count: estimate_tokens(¤t) as u32,
871 });
872 idx += 1;
873 start += current.len();
874 current.clear();
875 }
876 }
877 if !current.is_empty() {
878 current.push(' ');
879 }
880 current.push_str(sentence);
881 current.push('.');
882 }
883
884 if !current.is_empty() {
885 chunks.push(DocumentChunk {
886 index: idx,
887 content: current.clone(),
888 start_pos: start,
889 end_pos: start + current.len(),
890 embedding: None,
891 token_count: estimate_tokens(¤t) as u32,
892 });
893 }
894
895 chunks
896 }
897
898 fn fixed_chunk(&self, content: &str) -> Vec<DocumentChunk> {
899 let chars_per_chunk = self.chunking_config.chunk_size * 4; content
901 .chars()
902 .collect::<Vec<_>>()
903 .chunks(chars_per_chunk)
904 .enumerate()
905 .map(|(i, chars)| {
906 let s: String = chars.iter().collect();
907 DocumentChunk {
908 index: i as u32,
909 content: s.clone(),
910 start_pos: i * chars_per_chunk,
911 end_pos: (i + 1) * chars_per_chunk,
912 embedding: None,
913 token_count: estimate_tokens(&s) as u32,
914 }
915 })
916 .collect()
917 }
918
919 fn code_chunk(&self, content: &str) -> Vec<DocumentChunk> {
920 let mut chunks = Vec::new();
922 let mut current = String::new();
923 let mut start = 0;
924 let mut idx = 0;
925
926 for line in content.lines() {
927 let is_boundary = line.starts_with("fn ")
928 || line.starts_with("pub fn ")
929 || line.starts_with("async fn ")
930 || line.starts_with("impl ")
931 || line.starts_with("struct ")
932 || line.starts_with("enum ")
933 || line.starts_with("trait ")
934 || line.starts_with("class ")
935 || line.starts_with("def ")
936 || line.starts_with("function ")
937 || line.starts_with("const ")
938 || line.starts_with("export ");
939
940 if is_boundary && !current.is_empty() {
941 chunks.push(DocumentChunk {
942 index: idx,
943 content: current.clone(),
944 start_pos: start,
945 end_pos: start + current.len(),
946 embedding: None,
947 token_count: estimate_tokens(¤t) as u32,
948 });
949 idx += 1;
950 start += current.len();
951 current.clear();
952 }
953
954 current.push_str(line);
955 current.push('\n');
956 }
957
958 if !current.is_empty() {
959 chunks.push(DocumentChunk {
960 index: idx,
961 content: current.clone(),
962 start_pos: start,
963 end_pos: start + current.len(),
964 embedding: None,
965 token_count: estimate_tokens(¤t) as u32,
966 });
967 }
968
969 chunks
970 }
971
972 pub fn retrieve(&mut self, query_embedding: &Embedding, limit: usize) -> Vec<SearchResult> {
974 self.vector_store.search(query_embedding, limit)
975 }
976
977 pub fn get_document(&self, id: &str) -> Option<&Document> {
979 self.documents.get(id)
980 }
981
982 pub fn list_documents(&self) -> Vec<&Document> {
984 self.documents.values().collect()
985 }
986
987 pub fn delete_document(&mut self, id: &str) -> bool {
989 self.documents.remove(id).is_some()
990 }
991}
992
/// Token-budgeted collection of context segments that `build` assembles into one string.
#[derive(Debug, Clone)]
pub struct ContextWindow {
    /// Total token budget.
    pub max_tokens: usize,
    /// Tokens held back from the budget; `new` sets this to a quarter of `max_tokens`.
    pub reserved_for_response: usize,
    /// Segments added so far.
    segments: Vec<ContextSegment>,
}
1007
/// One piece of text competing for space in a `ContextWindow`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextSegment {
    /// Kind of content the segment carries.
    pub segment_type: ContextSegmentType,
    /// Text of the segment.
    pub content: String,
    /// Token cost charged against the budget (supplied by the caller).
    pub tokens: usize,
    /// Ordering key within the same `required` group; higher comes first.
    pub priority: u32,
    /// Required segments are placed first and truncated rather than dropped when they do not fit.
    pub required: bool,
}
1022
/// Kind of content carried by a `ContextSegment`; serialized in snake_case.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ContextSegmentType {
    SystemPrompt,
    UserPreferences,
    ConversationHistory,
    RetrievedContext,
    ToolResults,
    CurrentQuery,
    /// Caller-defined kind identified by name.
    Custom { name: String },
}
1035
1036impl ContextWindow {
1037 pub fn new(max_tokens: usize) -> Self {
1039 Self {
1040 max_tokens,
1041 reserved_for_response: max_tokens / 4, segments: Vec::new(),
1043 }
1044 }
1045
1046 pub fn add_segment(&mut self, segment: ContextSegment) {
1048 self.segments.push(segment);
1049 }
1050
1051 pub fn build(&mut self) -> String {
1053 let available = self.max_tokens - self.reserved_for_response;
1054
1055 self.segments
1057 .sort_by(|a, b| match (a.required, b.required) {
1058 (true, false) => std::cmp::Ordering::Less,
1059 (false, true) => std::cmp::Ordering::Greater,
1060 _ => b.priority.cmp(&a.priority),
1061 });
1062
1063 let mut total_tokens = 0;
1064 let mut result = Vec::new();
1065
1066 for segment in &self.segments {
1067 if total_tokens + segment.tokens <= available {
1068 result.push(segment.content.clone());
1069 total_tokens += segment.tokens;
1070 } else if segment.required {
1071 let remaining = available.saturating_sub(total_tokens);
1073 if remaining > 0 {
1074 let truncated = truncate_to_tokens(&segment.content, remaining);
1075 result.push(truncated);
1076 break;
1077 }
1078 }
1079 }
1080
1081 result.join("\n\n")
1082 }
1083
1084 pub fn token_usage(&self) -> (usize, usize) {
1086 let used: usize = self.segments.iter().map(|s| s.tokens).sum();
1087 (used, self.max_tokens - self.reserved_for_response)
1088 }
1089}
1090
1091#[derive(Debug, Clone, Serialize, Deserialize)]
1097pub struct CacheEntry<T> {
1098 pub key: String,
1099 pub value: T,
1100 pub created_at: DateTime<Utc>,
1101 pub expires_at: Option<DateTime<Utc>>,
1102 pub access_count: u64,
1103}
1104
1105impl<T> CacheEntry<T> {
1106 pub fn is_expired(&self) -> bool {
1107 self.expires_at.map(|exp| Utc::now() > exp).unwrap_or(false)
1108 }
1109}
1110
1111pub struct AgentCache<T> {
1113 entries: HashMap<String, CacheEntry<T>>,
1114 max_size: usize,
1115}
1116
1117impl<T: Clone> AgentCache<T> {
1118 pub fn new(max_size: usize) -> Self {
1120 Self {
1121 entries: HashMap::new(),
1122 max_size,
1123 }
1124 }
1125
1126 pub fn get(&mut self, key: &str) -> Option<T> {
1128 if let Some(entry) = self.entries.get_mut(key) {
1129 if entry.is_expired() {
1130 self.entries.remove(key);
1131 return None;
1132 }
1133 entry.access_count += 1;
1134 Some(entry.value.clone())
1135 } else {
1136 None
1137 }
1138 }
1139
1140 pub fn set(&mut self, key: impl Into<String>, value: T, ttl: Option<chrono::Duration>) {
1142 let key = key.into();
1143 let now = Utc::now();
1144
1145 self.entries.insert(
1146 key.clone(),
1147 CacheEntry {
1148 key,
1149 value,
1150 created_at: now,
1151 expires_at: ttl.map(|d| now + d),
1152 access_count: 0,
1153 },
1154 );
1155
1156 if self.entries.len() > self.max_size {
1158 self.evict_lru();
1159 }
1160 }
1161
1162 pub fn remove(&mut self, key: &str) -> Option<T> {
1164 self.entries.remove(key).map(|e| e.value)
1165 }
1166
1167 pub fn clear(&mut self) {
1169 self.entries.clear();
1170 }
1171
1172 fn evict_lru(&mut self) {
1174 self.entries.retain(|_, v| !v.is_expired());
1176
1177 if self.entries.len() > self.max_size {
1179 let mut entries: Vec<_> = self
1181 .entries
1182 .iter()
1183 .map(|(k, v)| (k.clone(), v.access_count))
1184 .collect();
1185 entries.sort_by_key(|(_, count)| *count);
1186
1187 let to_remove = self.entries.len() - self.max_size;
1188 let keys_to_remove: Vec<String> = entries
1189 .into_iter()
1190 .take(to_remove)
1191 .map(|(k, _)| k)
1192 .collect();
1193
1194 for key in keys_to_remove {
1195 self.entries.remove(&key);
1196 }
1197 }
1198 }
1199}
1200
1201#[derive(Debug, Clone, Serialize, Deserialize)]
1207pub struct MemoryConfig {
1208 pub vector_store: VectorStoreConfig,
1210 pub chunking: ChunkingConfig,
1212 pub context_window_tokens: usize,
1214 pub cache_size: usize,
1216 pub db_path: Option<String>,
1218 pub auto_summarize: bool,
1220 pub summarize_threshold: usize,
1222}
1223
1224impl Default for MemoryConfig {
1225 fn default() -> Self {
1226 Self {
1227 vector_store: VectorStoreConfig::default(),
1228 chunking: ChunkingConfig::default(),
1229 context_window_tokens: 8192,
1230 cache_size: 1000,
1231 db_path: None,
1232 auto_summarize: true,
1233 summarize_threshold: 20,
1234 }
1235 }
1236}
1237
1238pub struct MemoryManager {
1240 config: MemoryConfig,
1241 vector_store: VectorStore,
1242 knowledge_base: KnowledgeBase,
1243 cache: AgentCache<String>,
1244}
1245
1246impl MemoryManager {
1247 pub fn new(config: MemoryConfig) -> Result<Self, MemoryError> {
1249 let vector_store = if let Some(ref path) = config.db_path {
1250 VectorStore::with_persistence(config.vector_store.clone(), path)?
1251 } else {
1252 VectorStore::new(config.vector_store.clone())
1253 };
1254
1255 let knowledge_base = if let Some(ref path) = config.db_path {
1256 let kb_path = format!("{}_kb", path);
1257 KnowledgeBase::with_persistence(config.vector_store.clone(), kb_path)?
1258 } else {
1259 KnowledgeBase::new(config.vector_store.clone())
1260 };
1261
1262 Ok(Self {
1263 config: config.clone(),
1264 vector_store,
1265 knowledge_base,
1266 cache: AgentCache::new(config.cache_size),
1267 })
1268 }
1269
1270 pub fn remember(
1272 &mut self,
1273 content: impl Into<String>,
1274 memory_type: MemoryType,
1275 source: MemorySource,
1276 ) -> Result<MemoryId, MemoryError> {
1277 let entry = MemoryEntry::new(content, memory_type, source);
1278 self.vector_store.add(entry)
1279 }
1280
1281 pub fn remember_with_embedding(
1283 &mut self,
1284 content: impl Into<String>,
1285 embedding: Embedding,
1286 memory_type: MemoryType,
1287 source: MemorySource,
1288 ) -> Result<MemoryId, MemoryError> {
1289 let entry = MemoryEntry::new(content, memory_type, source).with_embedding(embedding);
1290 self.vector_store.add(entry)
1291 }
1292
1293 pub fn recall(&mut self, query_embedding: &Embedding, limit: usize) -> Vec<SearchResult> {
1295 self.vector_store.search(query_embedding, limit)
1296 }
1297
1298 pub fn recall_by_type(&self, memory_type: MemoryType, limit: usize) -> Vec<&MemoryEntry> {
1300 self.vector_store.search_by_type(memory_type, limit)
1301 }
1302
1303 pub fn add_document(&mut self, document: Document) -> Result<String, MemoryError> {
1305 self.knowledge_base.add_document(document)
1306 }
1307
1308 pub fn retrieve(&mut self, query_embedding: &Embedding, limit: usize) -> Vec<SearchResult> {
1310 self.knowledge_base.retrieve(query_embedding, limit)
1311 }
1312
1313 pub fn build_context(
1315 &mut self,
1316 query_embedding: &Embedding,
1317 system_prompt: &str,
1318 conversation: &[String],
1319 ) -> String {
1320 let mut context = ContextWindow::new(self.config.context_window_tokens);
1321
1322 context.add_segment(ContextSegment {
1324 segment_type: ContextSegmentType::SystemPrompt,
1325 content: system_prompt.to_string(),
1326 tokens: estimate_tokens(system_prompt),
1327 priority: 100,
1328 required: true,
1329 });
1330
1331 let retrieved = self.recall(query_embedding, 5);
1333 if !retrieved.is_empty() {
1334 let retrieved_text: String = retrieved
1335 .iter()
1336 .map(|r| format!("- {}", r.entry.content))
1337 .collect::<Vec<_>>()
1338 .join("\n");
1339
1340 context.add_segment(ContextSegment {
1341 segment_type: ContextSegmentType::RetrievedContext,
1342 content: format!("Relevant context:\n{}", retrieved_text),
1343 tokens: estimate_tokens(&retrieved_text) + 20,
1344 priority: 80,
1345 required: false,
1346 });
1347 }
1348
1349 let conv_text = conversation.join("\n");
1351 context.add_segment(ContextSegment {
1352 segment_type: ContextSegmentType::ConversationHistory,
1353 content: conv_text.clone(),
1354 tokens: estimate_tokens(&conv_text),
1355 priority: 90,
1356 required: false,
1357 });
1358
1359 context.build()
1360 }
1361
1362 pub fn cache_result(
1364 &mut self,
1365 key: impl Into<String>,
1366 value: String,
1367 ttl: Option<chrono::Duration>,
1368 ) {
1369 self.cache.set(key, value, ttl);
1370 }
1371
1372 pub fn get_cached(&mut self, key: &str) -> Option<String> {
1374 self.cache.get(key)
1375 }
1376
1377 pub fn stats(&self) -> MemoryStats {
1379 MemoryStats {
1380 vector_store: self.vector_store.stats(),
1381 document_count: self.knowledge_base.list_documents().len(),
1382 }
1383 }
1384}
1385
/// Snapshot produced by `MemoryManager::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryStats {
    /// Statistics of the manager's memory store.
    pub vector_store: VectorStoreStats,
    /// Number of documents held by the knowledge base.
    pub document_count: usize,
}
1392
/// Errors reported by the memory subsystem; each variant carries a message.
#[derive(Debug, Clone)]
pub enum MemoryError {
    /// A database operation failed.
    Database(String),
    /// An embedding could not be produced.
    Embedding(String),
    /// A requested item does not exist.
    NotFound(String),
    /// The caller supplied invalid input.
    InvalidInput(String),
    /// An I/O operation failed.
    Io(String),
}

impl std::fmt::Display for MemoryError {
    /// Writes `"<category>: <message>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (label, detail) = match self {
            Self::Database(detail) => ("Database error", detail),
            Self::Embedding(detail) => ("Embedding error", detail),
            Self::NotFound(detail) => ("Not found", detail),
            Self::InvalidInput(detail) => ("Invalid input", detail),
            Self::Io(detail) => ("IO error", detail),
        };
        write!(f, "{}: {}", label, detail)
    }
}

impl std::error::Error for MemoryError {}
1425
/// Generates an id of the form `mem_<nanos>_<seq>` (both in hex).
///
/// The timestamp alone is not unique: two calls can observe the same clock
/// reading, especially on platforms with a coarse clock. A process-wide
/// counter is appended so ids created by one process never collide. A clock
/// set before the Unix epoch yields a zero timestamp instead of a panic.
fn generate_memory_id() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    static SEQUENCE: AtomicU64 = AtomicU64::new(0);

    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_nanos())
        .unwrap_or(0);
    let seq = SEQUENCE.fetch_add(1, Ordering::Relaxed);

    format!("mem_{:x}_{:x}", nanos, seq)
}
1438
/// Estimates the token count of `text` as its byte length divided by four,
/// rounded up.
///
/// Integer arithmetic is used on purpose: going through `f32` loses
/// precision for lengths above 2^24 and then under-counts.
fn estimate_tokens(text: &str) -> usize {
    let len = text.len();
    len / 4 + usize::from(len % 4 != 0)
}
1443
/// Shortens `text` to roughly `max_tokens` tokens (four bytes per token).
///
/// Text that already fits is returned unchanged. Otherwise it is cut at the
/// last character boundary at or before the byte limit and `...` is
/// appended. Cutting at a raw byte offset would panic inside a multi-byte
/// character.
fn truncate_to_tokens(text: &str, max_tokens: usize) -> String {
    let max_bytes = max_tokens.saturating_mul(4);
    if text.len() <= max_bytes {
        return text.to_string();
    }

    // Offset 0 is always a boundary, so this loop terminates.
    let mut cut = max_bytes;
    while !text.is_char_boundary(cut) {
        cut -= 1;
    }

    format!("{}...", &text[..cut])
}
1453
/// Source of embedding vectors for text.
#[async_trait::async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Embeds a single text.
    async fn embed(&self, text: &str) -> Result<Embedding, MemoryError>;

    /// Embeds several texts, returning one embedding per input.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Embedding>, MemoryError>;

    /// Length of the vectors this provider produces.
    fn dimension(&self) -> usize;
}
1470
1471pub struct OpenAIEmbedding {
1473 #[allow(dead_code)]
1474 api_key: String,
1475 model: String,
1476}
1477
1478impl OpenAIEmbedding {
1479 pub fn new(api_key: impl Into<String>) -> Self {
1480 Self {
1481 api_key: api_key.into(),
1482 model: "text-embedding-3-small".to_string(),
1483 }
1484 }
1485
1486 pub fn with_model(mut self, model: impl Into<String>) -> Self {
1487 self.model = model.into();
1488 self
1489 }
1490}
1491
1492#[async_trait::async_trait]
1493impl EmbeddingProvider for OpenAIEmbedding {
1494 async fn embed(&self, _text: &str) -> Result<Embedding, MemoryError> {
1495 Ok(vec![0.0; 1536])
1498 }
1499
1500 async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Embedding>, MemoryError> {
1501 let mut results = Vec::new();
1502 for text in texts {
1503 results.push(self.embed(text).await?);
1504 }
1505 Ok(results)
1506 }
1507
1508 fn dimension(&self) -> usize {
1509 match self.model.as_str() {
1510 "text-embedding-3-large" => 3072,
1511 _ => 1536,
1512 }
1513 }
1514}
1515
#[cfg(test)]
mod tests {
    use super::*;

    /// A new entry gets an id and keeps the supplied content and type.
    #[test]
    fn test_memory_entry_creation() {
        let content = "Test content";
        let kind = MemoryType::LongTerm;

        let entry = MemoryEntry::new(content, kind, MemorySource::UserInput);

        assert!(!entry.id.is_empty());
        assert_eq!(entry.content, "Test content");
        assert_eq!(entry.memory_type, MemoryType::LongTerm);
    }

    /// Cosine similarity is 1 for identical and 0 for orthogonal vectors.
    #[test]
    fn test_similarity_metrics() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let c = vec![0.0, 1.0, 0.0];
        let cosine = SimilarityMetric::Cosine;

        let same = cosine.calculate(&a, &b);
        let orthogonal = cosine.calculate(&a, &c);

        assert!((same - 1.0).abs() < 0.001);
        assert!((orthogonal - 0.0).abs() < 0.001);
    }

    /// An added entry can be fetched by the id returned from `add`.
    #[test]
    fn test_vector_store() {
        let mut store = VectorStore::new(VectorStoreConfig::default());
        let entry = MemoryEntry::new("Test", MemoryType::ShortTerm, MemorySource::UserInput)
            .with_embedding(vec![1.0, 0.0, 0.0]);

        let id = store.add(entry).unwrap();

        assert!(!id.is_empty());
        assert!(store.get(&id).is_some());
    }

    /// A required segment that fits shows up in the built context.
    #[test]
    fn test_context_window() {
        let mut ctx = ContextWindow::new(1000);
        let system = ContextSegment {
            segment_type: ContextSegmentType::SystemPrompt,
            content: "You are helpful".to_string(),
            tokens: 10,
            priority: 100,
            required: true,
        };
        ctx.add_segment(system);

        let result = ctx.build();

        assert!(result.contains("You are helpful"));
    }

    /// Stored keys are returned, unknown keys are not.
    #[test]
    fn test_cache() {
        let mut cache: AgentCache<String> = AgentCache::new(10);
        cache.set("key1", "value1".to_string(), None);

        let hit = cache.get("key1");
        let miss = cache.get("key2");

        assert_eq!(hit, Some("value1".to_string()));
        assert_eq!(miss, None);
    }

    /// Token estimates round the byte length divided by four upwards.
    #[test]
    fn test_estimate_tokens() {
        let short = estimate_tokens("hello");
        let longer = estimate_tokens("hello world");

        assert_eq!(short, 2);
        assert_eq!(longer, 3);
    }
}