1#![allow(dead_code)]
15use chrono::{DateTime, Utc};
18use serde::{Deserialize, Serialize};
19use std::collections::HashMap;
20use std::path::Path;
21
/// Identifier of a [`MemoryEntry`]; `MemoryEntry::new` fills it from `generate_memory_id`.
pub type MemoryId = String;

/// Dense embedding vector, compared with [`SimilarityMetric::calculate`].
pub type Embedding = Vec<f32>;
31
/// A single stored memory: its text, optional embedding, classification, and the
/// bookkeeping used for search, pruning, and expiry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEntry {
    /// Identifier; `MemoryEntry::new` fills it from `generate_memory_id`.
    pub id: MemoryId,
    /// The remembered text.
    pub content: String,
    /// Optional embedding used by similarity search; omitted from serialized
    /// output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub embedding: Option<Embedding>,
    /// Category of the memory.
    pub memory_type: MemoryType,
    /// Where the memory came from.
    pub source: MemorySource,
    /// Importance score. `MemoryEntry::new` sets 0.5 and `with_importance` clamps
    /// to `[0.0, 1.0]`. `VectorStore::prune` keeps the most important entries.
    pub importance: f32,
    /// Number of times this entry was returned by `VectorStore::search`.
    pub access_count: u64,
    /// Time of the last search hit (initially the creation time).
    pub last_accessed: DateTime<Utc>,
    /// Creation time.
    pub created_at: DateTime<Utc>,
    /// Expiry time; `None` means the entry never expires (see `is_expired`).
    pub expires_at: Option<DateTime<Utc>>,
    /// Optional agent identifier (indexed in the SQLite schema).
    pub agent_id: Option<String>,
    /// Optional session identifier (indexed in the SQLite schema).
    pub session_id: Option<String>,
    /// Arbitrary JSON metadata.
    pub metadata: HashMap<String, serde_json::Value>,
    /// Free-form tags; `KnowledgeBase` tags document chunks as `doc:<document id>`.
    pub tags: Vec<String>,
}
65
66impl MemoryEntry {
67 pub fn new(content: impl Into<String>, memory_type: MemoryType, source: MemorySource) -> Self {
69 let now = Utc::now();
70 Self {
71 id: generate_memory_id(),
72 content: content.into(),
73 embedding: None,
74 memory_type,
75 source,
76 importance: 0.5,
77 access_count: 0,
78 last_accessed: now,
79 created_at: now,
80 expires_at: None,
81 agent_id: None,
82 session_id: None,
83 metadata: HashMap::new(),
84 tags: Vec::new(),
85 }
86 }
87
88 pub fn with_embedding(mut self, embedding: Embedding) -> Self {
90 self.embedding = Some(embedding);
91 self
92 }
93
94 pub fn with_importance(mut self, importance: f32) -> Self {
96 self.importance = importance.clamp(0.0, 1.0);
97 self
98 }
99
100 pub fn with_agent(mut self, agent_id: impl Into<String>) -> Self {
102 self.agent_id = Some(agent_id.into());
103 self
104 }
105
106 pub fn with_session(mut self, session_id: impl Into<String>) -> Self {
108 self.session_id = Some(session_id.into());
109 self
110 }
111
112 pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
114 self.tags.push(tag.into());
115 self
116 }
117
118 pub fn expires_in(mut self, duration: chrono::Duration) -> Self {
120 self.expires_at = Some(Utc::now() + duration);
121 self
122 }
123
124 pub fn is_expired(&self) -> bool {
126 self.expires_at.map(|exp| Utc::now() > exp).unwrap_or(false)
127 }
128}
129
/// Category of a memory; serialized in snake_case (e.g. `short_term`).
///
/// The store filters on it (`VectorStore::search_by_type`) but attaches no
/// other behavior to individual variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MemoryType {
    ShortTerm,
    LongTerm,
    Episodic,
    Semantic,
    Procedural,
    Preference,
    Cache,
}
149
/// Origin of a memory; serialized with snake_case variant names.
///
/// `KnowledgeBase` uses `Document` for indexed chunks. When loading from SQLite,
/// a value that fails to deserialize falls back to `UserInput`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MemorySource {
    /// A message within a conversation.
    Conversation {
        session_id: String,
        message_id: String,
    },
    /// A chunk of a document at `path`.
    Document { path: String, chunk_index: u32 },
    /// Direct user input.
    UserInput,
    /// Produced by an agent's reasoning.
    AgentReasoning { agent_id: String },
    /// Output of a tool invocation.
    ToolResult { tool_name: String },
    /// Content fetched from a URL.
    WebPage { url: String },
    /// A summary derived from other memories, identified by `source_ids`.
    Summary { source_ids: Vec<String> },
    /// Any other origin, described by `source_type`.
    Custom { source_type: String },
}
174
/// Configuration for a [`VectorStore`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorStoreConfig {
    /// Embedding model the stored vectors are expected to come from.
    pub embedding_model: EmbeddingModel,
    /// Expected embedding length.
    /// NOTE(review): not validated against added embeddings in this module.
    pub embedding_dim: usize,
    /// Metric used to score search results.
    pub similarity_metric: SimilarityMetric,
    /// Entry cap; exceeding it after `VectorStore::add` triggers pruning.
    pub max_entries: usize,
    /// Database path.
    /// NOTE(review): not read anywhere in this module. Persistence is chosen via
    /// `VectorStore::with_persistence` or `MemoryConfig::db_path`.
    pub db_path: Option<String>,
}
193
194impl Default for VectorStoreConfig {
195 fn default() -> Self {
196 Self {
197 embedding_model: EmbeddingModel::default(),
198 embedding_dim: 384,
199 similarity_metric: SimilarityMetric::Cosine,
200 max_entries: 100_000,
201 db_path: None,
202 }
203 }
204}
205
/// Known embedding models; `MiniLM` is the default.
///
/// See [`EmbeddingModel::dimension`] for each model's vector length.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EmbeddingModel {
    OpenAISmall,
    OpenAILarge,
    OpenAIAda,
    #[default]
    MiniLM,
    MPNet,
    Cohere,
    GoogleGecko,
    Voyage,
    /// A model served by Ollama, identified by name.
    Ollama { model: String },
    /// Any other model with an explicitly configured dimension.
    Custom { name: String, dim: usize },
}
232
233impl EmbeddingModel {
234 pub fn dimension(&self) -> usize {
236 match self {
237 EmbeddingModel::OpenAISmall => 1536,
238 EmbeddingModel::OpenAILarge => 3072,
239 EmbeddingModel::OpenAIAda => 1536,
240 EmbeddingModel::MiniLM => 384,
241 EmbeddingModel::MPNet => 768,
242 EmbeddingModel::Cohere => 1024,
243 EmbeddingModel::GoogleGecko => 768,
244 EmbeddingModel::Voyage => 1024,
245 EmbeddingModel::Ollama { .. } => 4096, EmbeddingModel::Custom { dim, .. } => *dim,
247 }
248 }
249}
250
/// How two embeddings are compared; `Cosine` is the default.
///
/// See [`SimilarityMetric::calculate`] for the exact formulas.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SimilarityMetric {
    #[default]
    Cosine,
    Euclidean,
    DotProduct,
    Manhattan,
}
261
262impl SimilarityMetric {
263 pub fn calculate(&self, a: &[f32], b: &[f32]) -> f32 {
265 assert_eq!(a.len(), b.len(), "Vector dimensions must match");
266
267 match self {
268 SimilarityMetric::Cosine => {
269 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
270 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
271 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
272 if norm_a == 0.0 || norm_b == 0.0 {
273 0.0
274 } else {
275 dot / (norm_a * norm_b)
276 }
277 }
278 SimilarityMetric::Euclidean => {
279 let dist: f32 = a
280 .iter()
281 .zip(b.iter())
282 .map(|(x, y)| (x - y).powi(2))
283 .sum::<f32>()
284 .sqrt();
285 1.0 / (1.0 + dist) }
287 SimilarityMetric::DotProduct => a.iter().zip(b.iter()).map(|(x, y)| x * y).sum(),
288 SimilarityMetric::Manhattan => {
289 let dist: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum();
290 1.0 / (1.0 + dist)
291 }
292 }
293 }
294}
295
/// One hit returned by `VectorStore::search`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    /// Snapshot (clone) of the matched entry, taken after its access stats were updated.
    pub entry: MemoryEntry,
    /// Similarity score from the store's configured metric (higher is better).
    pub score: f32,
    /// Zero-based position in the result list.
    pub rank: usize,
}
306
/// Brute-force in-memory vector store, optionally mirrored to SQLite.
pub struct VectorStore {
    /// Store configuration (metric, capacity, ...).
    config: VectorStoreConfig,
    /// All entries held in memory; searches scan this list linearly.
    entries: Vec<MemoryEntry>,
    /// SQLite connection when created via `with_persistence`, otherwise `None`.
    db: Option<rusqlite::Connection>,
}
313
314impl VectorStore {
315 pub fn new(config: VectorStoreConfig) -> Self {
317 Self {
318 config,
319 entries: Vec::new(),
320 db: None,
321 }
322 }
323
324 pub fn with_persistence(
326 config: VectorStoreConfig,
327 db_path: impl AsRef<Path>,
328 ) -> Result<Self, MemoryError> {
329 let db = rusqlite::Connection::open(db_path.as_ref())
330 .map_err(|e| MemoryError::Database(e.to_string()))?;
331
332 db.execute_batch(
334 r#"
335 CREATE TABLE IF NOT EXISTS memory_entries (
336 id TEXT PRIMARY KEY,
337 content TEXT NOT NULL,
338 embedding BLOB,
339 memory_type TEXT NOT NULL,
340 source TEXT NOT NULL,
341 importance REAL NOT NULL,
342 access_count INTEGER NOT NULL DEFAULT 0,
343 last_accessed TEXT NOT NULL,
344 created_at TEXT NOT NULL,
345 expires_at TEXT,
346 agent_id TEXT,
347 session_id TEXT,
348 metadata TEXT,
349 tags TEXT
350 );
351
352 CREATE INDEX IF NOT EXISTS idx_memory_type ON memory_entries(memory_type);
353 CREATE INDEX IF NOT EXISTS idx_agent_id ON memory_entries(agent_id);
354 CREATE INDEX IF NOT EXISTS idx_session_id ON memory_entries(session_id);
355 CREATE INDEX IF NOT EXISTS idx_created_at ON memory_entries(created_at);
356 CREATE INDEX IF NOT EXISTS idx_importance ON memory_entries(importance DESC);
357 "#,
358 )
359 .map_err(|e| MemoryError::Database(e.to_string()))?;
360
361 let mut store = Self {
362 config,
363 entries: Vec::new(),
364 db: Some(db),
365 };
366
367 store.load_from_db()?;
368 Ok(store)
369 }
370
371 fn load_from_db(&mut self) -> Result<(), MemoryError> {
373 if let Some(ref db) = self.db {
374 let mut stmt = db
375 .prepare(
376 "SELECT id, content, embedding, memory_type, source, importance,
377 access_count, last_accessed, created_at, expires_at,
378 agent_id, session_id, metadata, tags
379 FROM memory_entries
380 ORDER BY importance DESC, created_at DESC",
381 )
382 .map_err(|e| MemoryError::Database(e.to_string()))?;
383
384 let entries = stmt
385 .query_map([], |row| {
386 let embedding_blob: Option<Vec<u8>> = row.get(2)?;
387 let embedding = embedding_blob.map(|blob| {
388 blob.chunks(4)
389 .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap_or([0; 4])))
390 .collect()
391 });
392
393 Ok(MemoryEntry {
394 id: row.get(0)?,
395 content: row.get(1)?,
396 embedding,
397 memory_type: serde_json::from_str(&row.get::<_, String>(3)?)
398 .unwrap_or(MemoryType::LongTerm),
399 source: serde_json::from_str(&row.get::<_, String>(4)?)
400 .unwrap_or(MemorySource::UserInput),
401 importance: row.get(5)?,
402 access_count: row.get(6)?,
403 last_accessed: row
404 .get::<_, String>(7)?
405 .parse()
406 .unwrap_or_else(|_| Utc::now()),
407 created_at: row
408 .get::<_, String>(8)?
409 .parse()
410 .unwrap_or_else(|_| Utc::now()),
411 expires_at: row
412 .get::<_, Option<String>>(9)?
413 .and_then(|s| s.parse().ok()),
414 agent_id: row.get(10)?,
415 session_id: row.get(11)?,
416 metadata: row
417 .get::<_, Option<String>>(12)?
418 .and_then(|s| serde_json::from_str(&s).ok())
419 .unwrap_or_default(),
420 tags: row
421 .get::<_, Option<String>>(13)?
422 .and_then(|s| serde_json::from_str(&s).ok())
423 .unwrap_or_default(),
424 })
425 })
426 .map_err(|e| MemoryError::Database(e.to_string()))?;
427
428 self.entries = entries.filter_map(|e| e.ok()).collect();
429 }
430 Ok(())
431 }
432
433 pub fn add(&mut self, entry: MemoryEntry) -> Result<MemoryId, MemoryError> {
435 let id = entry.id.clone();
436
437 if let Some(ref db) = self.db {
439 let embedding_blob: Option<Vec<u8>> = entry
440 .embedding
441 .as_ref()
442 .map(|emb| emb.iter().flat_map(|f| f.to_le_bytes()).collect());
443
444 db.execute(
445 "INSERT OR REPLACE INTO memory_entries
446 (id, content, embedding, memory_type, source, importance,
447 access_count, last_accessed, created_at, expires_at,
448 agent_id, session_id, metadata, tags)
449 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14)",
450 rusqlite::params![
451 entry.id,
452 entry.content,
453 embedding_blob,
454 serde_json::to_string(&entry.memory_type).unwrap_or_default(),
455 serde_json::to_string(&entry.source).unwrap_or_default(),
456 entry.importance,
457 entry.access_count,
458 entry.last_accessed.to_rfc3339(),
459 entry.created_at.to_rfc3339(),
460 entry.expires_at.map(|e| e.to_rfc3339()),
461 entry.agent_id,
462 entry.session_id,
463 serde_json::to_string(&entry.metadata).ok(),
464 serde_json::to_string(&entry.tags).ok(),
465 ],
466 )
467 .map_err(|e| MemoryError::Database(e.to_string()))?;
468 }
469
470 self.entries.push(entry);
471
472 if self.entries.len() > self.config.max_entries {
474 self.prune()?;
475 }
476
477 Ok(id)
478 }
479
480 pub fn search(&mut self, query_embedding: &Embedding, limit: usize) -> Vec<SearchResult> {
482 let mut results: Vec<(usize, f32)> = self
483 .entries
484 .iter()
485 .enumerate()
486 .filter(|(_, e)| !e.is_expired() && e.embedding.is_some())
487 .map(|(i, e)| {
488 let score = self
489 .config
490 .similarity_metric
491 .calculate(query_embedding, e.embedding.as_ref().unwrap());
492 (i, score)
493 })
494 .collect();
495
496 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
498
499 results
501 .into_iter()
502 .take(limit)
503 .enumerate()
504 .map(|(rank, (idx, score))| {
505 self.entries[idx].access_count += 1;
506 self.entries[idx].last_accessed = Utc::now();
507
508 SearchResult {
509 entry: self.entries[idx].clone(),
510 score,
511 rank,
512 }
513 })
514 .collect()
515 }
516
517 pub fn search_by_type(&self, memory_type: MemoryType, limit: usize) -> Vec<&MemoryEntry> {
519 self.entries
520 .iter()
521 .filter(|e| e.memory_type == memory_type && !e.is_expired())
522 .take(limit)
523 .collect()
524 }
525
526 pub fn search_by_tags(&self, tags: &[String], limit: usize) -> Vec<&MemoryEntry> {
528 self.entries
529 .iter()
530 .filter(|e| !e.is_expired() && tags.iter().any(|t| e.tags.contains(t)))
531 .take(limit)
532 .collect()
533 }
534
535 pub fn get(&self, id: &str) -> Option<&MemoryEntry> {
537 self.entries.iter().find(|e| e.id == id)
538 }
539
540 pub fn delete(&mut self, id: &str) -> Result<bool, MemoryError> {
542 if let Some(pos) = self.entries.iter().position(|e| e.id == id) {
543 self.entries.remove(pos);
544
545 if let Some(ref db) = self.db {
546 db.execute("DELETE FROM memory_entries WHERE id = ?1", [id])
547 .map_err(|e| MemoryError::Database(e.to_string()))?;
548 }
549
550 Ok(true)
551 } else {
552 Ok(false)
553 }
554 }
555
556 fn prune(&mut self) -> Result<(), MemoryError> {
558 self.entries.retain(|e| !e.is_expired());
560
561 if self.entries.len() > self.config.max_entries {
563 self.entries.sort_by(|a, b| {
564 b.importance
565 .partial_cmp(&a.importance)
566 .unwrap_or(std::cmp::Ordering::Equal)
567 });
568 self.entries.truncate(self.config.max_entries);
569 }
570
571 Ok(())
572 }
573
574 pub fn stats(&self) -> VectorStoreStats {
576 VectorStoreStats {
577 total_entries: self.entries.len(),
578 entries_by_type: self.entries.iter().fold(HashMap::new(), |mut acc, e| {
579 *acc.entry(format!("{:?}", e.memory_type)).or_insert(0) += 1;
580 acc
581 }),
582 total_access_count: self.entries.iter().map(|e| e.access_count).sum(),
583 avg_importance: if self.entries.is_empty() {
584 0.0
585 } else {
586 self.entries.iter().map(|e| e.importance).sum::<f32>() / self.entries.len() as f32
587 },
588 }
589 }
590}
591
/// Aggregate statistics produced by `VectorStore::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorStoreStats {
    /// Number of entries held in memory (expired entries included).
    pub total_entries: usize,
    /// Entry counts keyed by the `Debug` name of the `MemoryType` (e.g. `"LongTerm"`).
    pub entries_by_type: HashMap<String, usize>,
    /// Sum of `access_count` over all entries.
    pub total_access_count: u64,
    /// Mean importance, or `0.0` when the store is empty.
    pub avg_importance: f32,
}
600
/// A source document managed by a `KnowledgeBase`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Document {
    /// Key under which the `KnowledgeBase` stores the document.
    pub id: String,
    /// Human-readable title.
    pub title: String,
    /// Full text that gets chunked.
    pub content: String,
    /// Kind of document.
    /// NOTE(review): the chunkers in this module do not read it; the configured
    /// `ChunkingStrategy` decides how content is split.
    pub doc_type: DocumentType,
    /// Origin of the document; recorded as the `path` of each indexed chunk's `MemorySource::Document`.
    pub source: String,
    /// Chunks of `content`; `KnowledgeBase::add_document` may (re)compute them.
    pub chunks: Vec<DocumentChunk>,
    /// Creation time.
    pub created_at: DateTime<Utc>,
    /// Last update time.
    pub updated_at: DateTime<Utc>,
    /// Arbitrary JSON metadata.
    pub metadata: HashMap<String, serde_json::Value>,
}
627
/// Format of a [`Document`]; serialized with snake_case variant names.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DocumentType {
    Text,
    Markdown,
    /// Source code in the named language.
    Code { language: String },
    Html,
    Pdf,
    Json,
    Yaml,
    Csv,
    /// Any other format, identified by MIME type.
    Custom { mime_type: String },
}
642
/// One piece of a [`Document`] produced by a chunking strategy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentChunk {
    /// Zero-based position of the chunk in the document's chunk list.
    pub index: u32,
    /// Chunk text.
    pub content: String,
    /// Start offset of the chunk within the document content.
    /// NOTE(review): units and exactness depend on the chunker that produced it.
    pub start_pos: usize,
    /// End offset of the chunk within the document content (see `start_pos`).
    pub end_pos: usize,
    /// Embedding for the chunk. The built-in chunkers leave it `None`; only
    /// chunks with an embedding are indexed into the vector store.
    pub embedding: Option<Embedding>,
    /// Estimated token count (`estimate_tokens`, about 4 bytes per token).
    pub token_count: u32,
}
659
/// Controls how `KnowledgeBase` splits document content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkingConfig {
    /// Target chunk size in estimated tokens. `FixedSize` uses `chunk_size * 4` characters.
    pub chunk_size: usize,
    /// Intended overlap between consecutive chunks.
    /// NOTE(review): not currently read by any chunker in this module.
    pub chunk_overlap: usize,
    /// Splitting strategy.
    pub strategy: ChunkingStrategy,
}
670
671impl Default for ChunkingConfig {
672 fn default() -> Self {
673 Self {
674 chunk_size: 512,
675 chunk_overlap: 50,
676 strategy: ChunkingStrategy::Semantic,
677 }
678 }
679}
680
/// Algorithm used by `KnowledgeBase` to split content into chunks; `Semantic` is the default.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ChunkingStrategy {
    /// Fixed runs of `chunk_size * 4` characters.
    FixedSize,
    /// Sentences split on `.`, `!` and `?`, packed up to `chunk_size` tokens.
    Sentence,
    /// One chunk per non-empty blank-line-separated paragraph.
    Paragraph,
    /// Blank-line-separated paragraphs packed up to `chunk_size` tokens.
    #[default]
    Semantic,
    /// A new chunk at each line that starts with a declaration keyword (`fn `, `class `, `def `, ...).
    Code,
}
697
/// Document store whose embedded chunks are indexed in a [`VectorStore`].
pub struct KnowledgeBase {
    /// Index of embedded document chunks.
    vector_store: VectorStore,
    /// Documents by id. Held in memory only; not persisted even by `with_persistence`.
    documents: HashMap<String, Document>,
    /// How new documents are chunked.
    chunking_config: ChunkingConfig,
}
707
708impl KnowledgeBase {
709 pub fn new(vector_config: VectorStoreConfig) -> Self {
711 Self {
712 vector_store: VectorStore::new(vector_config),
713 documents: HashMap::new(),
714 chunking_config: ChunkingConfig::default(),
715 }
716 }
717
718 pub fn with_persistence(
720 vector_config: VectorStoreConfig,
721 db_path: impl AsRef<Path>,
722 ) -> Result<Self, MemoryError> {
723 Ok(Self {
724 vector_store: VectorStore::with_persistence(vector_config, db_path)?,
725 documents: HashMap::new(),
726 chunking_config: ChunkingConfig::default(),
727 })
728 }
729
730 pub fn add_document(&mut self, mut document: Document) -> Result<String, MemoryError> {
732 document.chunks = self.chunk_document(&document.content);
734
735 let doc_id = document.id.clone();
736
737 for chunk in &document.chunks {
739 if let Some(ref embedding) = chunk.embedding {
740 let entry = MemoryEntry::new(
741 &chunk.content,
742 MemoryType::Semantic,
743 MemorySource::Document {
744 path: document.source.clone(),
745 chunk_index: chunk.index,
746 },
747 )
748 .with_embedding(embedding.clone())
749 .with_tag(format!("doc:{}", doc_id));
750
751 self.vector_store.add(entry)?;
752 }
753 }
754
755 self.documents.insert(doc_id.clone(), document);
756 Ok(doc_id)
757 }
758
759 fn chunk_document(&self, content: &str) -> Vec<DocumentChunk> {
761 match self.chunking_config.strategy {
762 ChunkingStrategy::Semantic => self.semantic_chunk(content),
763 ChunkingStrategy::Paragraph => self.paragraph_chunk(content),
764 ChunkingStrategy::Sentence => self.sentence_chunk(content),
765 ChunkingStrategy::FixedSize => self.fixed_chunk(content),
766 ChunkingStrategy::Code => self.code_chunk(content),
767 }
768 }
769
770 fn semantic_chunk(&self, content: &str) -> Vec<DocumentChunk> {
771 let mut chunks = Vec::new();
773 let mut current_chunk = String::new();
774 let mut start_pos = 0;
775 let mut chunk_index = 0;
776
777 for para in content.split("\n\n") {
778 let para = para.trim();
779 if para.is_empty() {
780 continue;
781 }
782
783 let para_tokens = estimate_tokens(para);
784 let current_tokens = estimate_tokens(¤t_chunk);
785
786 if current_tokens + para_tokens > self.chunking_config.chunk_size
787 && !current_chunk.is_empty()
788 {
789 let end_pos = start_pos + current_chunk.len();
791 chunks.push(DocumentChunk {
792 index: chunk_index,
793 content: current_chunk.trim().to_string(),
794 start_pos,
795 end_pos,
796 embedding: None,
797 token_count: estimate_tokens(¤t_chunk) as u32,
798 });
799 chunk_index += 1;
800 start_pos = end_pos;
801 current_chunk = String::new();
802 }
803
804 if !current_chunk.is_empty() {
805 current_chunk.push_str("\n\n");
806 }
807 current_chunk.push_str(para);
808 }
809
810 if !current_chunk.is_empty() {
812 let end_pos = start_pos + current_chunk.len();
813 chunks.push(DocumentChunk {
814 index: chunk_index,
815 content: current_chunk.trim().to_string(),
816 start_pos,
817 end_pos,
818 embedding: None,
819 token_count: estimate_tokens(¤t_chunk) as u32,
820 });
821 }
822
823 chunks
824 }
825
826 fn paragraph_chunk(&self, content: &str) -> Vec<DocumentChunk> {
827 content
828 .split("\n\n")
829 .filter(|p| !p.trim().is_empty())
830 .enumerate()
831 .scan(0usize, |pos, (i, para)| {
832 let start = *pos;
833 *pos += para.len() + 2;
834 Some(DocumentChunk {
835 index: i as u32,
836 content: para.trim().to_string(),
837 start_pos: start,
838 end_pos: *pos,
839 embedding: None,
840 token_count: estimate_tokens(para) as u32,
841 })
842 })
843 .collect()
844 }
845
846 fn sentence_chunk(&self, content: &str) -> Vec<DocumentChunk> {
847 let sentences: Vec<&str> = content
849 .split(['.', '!', '?'])
850 .filter(|s| !s.trim().is_empty())
851 .collect();
852
853 let mut chunks = Vec::new();
854 let mut current = String::new();
855 let mut start = 0;
856 let mut idx = 0;
857
858 for sentence in sentences {
859 let sentence = sentence.trim();
860 if estimate_tokens(¤t) + estimate_tokens(sentence)
861 > self.chunking_config.chunk_size
862 && !current.is_empty() {
863 chunks.push(DocumentChunk {
864 index: idx,
865 content: current.clone(),
866 start_pos: start,
867 end_pos: start + current.len(),
868 embedding: None,
869 token_count: estimate_tokens(¤t) as u32,
870 });
871 idx += 1;
872 start += current.len();
873 current.clear();
874 }
875 if !current.is_empty() {
876 current.push(' ');
877 }
878 current.push_str(sentence);
879 current.push('.');
880 }
881
882 if !current.is_empty() {
883 chunks.push(DocumentChunk {
884 index: idx,
885 content: current.clone(),
886 start_pos: start,
887 end_pos: start + current.len(),
888 embedding: None,
889 token_count: estimate_tokens(¤t) as u32,
890 });
891 }
892
893 chunks
894 }
895
896 fn fixed_chunk(&self, content: &str) -> Vec<DocumentChunk> {
897 let chars_per_chunk = self.chunking_config.chunk_size * 4; content
899 .chars()
900 .collect::<Vec<_>>()
901 .chunks(chars_per_chunk)
902 .enumerate()
903 .map(|(i, chars)| {
904 let s: String = chars.iter().collect();
905 DocumentChunk {
906 index: i as u32,
907 content: s.clone(),
908 start_pos: i * chars_per_chunk,
909 end_pos: (i + 1) * chars_per_chunk,
910 embedding: None,
911 token_count: estimate_tokens(&s) as u32,
912 }
913 })
914 .collect()
915 }
916
917 fn code_chunk(&self, content: &str) -> Vec<DocumentChunk> {
918 let mut chunks = Vec::new();
920 let mut current = String::new();
921 let mut start = 0;
922 let mut idx = 0;
923
924 for line in content.lines() {
925 let is_boundary = line.starts_with("fn ")
926 || line.starts_with("pub fn ")
927 || line.starts_with("async fn ")
928 || line.starts_with("impl ")
929 || line.starts_with("struct ")
930 || line.starts_with("enum ")
931 || line.starts_with("trait ")
932 || line.starts_with("class ")
933 || line.starts_with("def ")
934 || line.starts_with("function ")
935 || line.starts_with("const ")
936 || line.starts_with("export ");
937
938 if is_boundary && !current.is_empty() {
939 chunks.push(DocumentChunk {
940 index: idx,
941 content: current.clone(),
942 start_pos: start,
943 end_pos: start + current.len(),
944 embedding: None,
945 token_count: estimate_tokens(¤t) as u32,
946 });
947 idx += 1;
948 start += current.len();
949 current.clear();
950 }
951
952 current.push_str(line);
953 current.push('\n');
954 }
955
956 if !current.is_empty() {
957 chunks.push(DocumentChunk {
958 index: idx,
959 content: current.clone(),
960 start_pos: start,
961 end_pos: start + current.len(),
962 embedding: None,
963 token_count: estimate_tokens(¤t) as u32,
964 });
965 }
966
967 chunks
968 }
969
970 pub fn retrieve(&mut self, query_embedding: &Embedding, limit: usize) -> Vec<SearchResult> {
972 self.vector_store.search(query_embedding, limit)
973 }
974
975 pub fn get_document(&self, id: &str) -> Option<&Document> {
977 self.documents.get(id)
978 }
979
980 pub fn list_documents(&self) -> Vec<&Document> {
982 self.documents.values().collect()
983 }
984
985 pub fn delete_document(&mut self, id: &str) -> bool {
987 self.documents.remove(id).is_some()
988 }
989}
990
/// Token-budgeted collection of prompt segments, assembled by `ContextWindow::build`.
#[derive(Debug, Clone)]
pub struct ContextWindow {
    /// Total token budget of the model context.
    pub max_tokens: usize,
    /// Tokens held back for the model's answer (`new` sets a quarter of `max_tokens`).
    pub reserved_for_response: usize,
    /// Segments added so far; `build` sorts them in place.
    segments: Vec<ContextSegment>,
}
1005
/// One candidate piece of prompt text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextSegment {
    /// Label for the segment; `ContextWindow::build` does not read it.
    pub segment_type: ContextSegmentType,
    /// Text to include.
    pub content: String,
    /// Caller-supplied token cost used for budgeting.
    pub tokens: usize,
    /// Higher priority is placed (and kept) first.
    pub priority: u32,
    /// Required segments are placed before all others and truncated to fit
    /// rather than skipped.
    pub required: bool,
}
1020
/// Kind of a [`ContextSegment`]; serialized with snake_case variant names.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ContextSegmentType {
    SystemPrompt,
    UserPreferences,
    ConversationHistory,
    RetrievedContext,
    ToolResults,
    CurrentQuery,
    /// Any other segment kind, identified by name.
    Custom { name: String },
}
1033
1034impl ContextWindow {
1035 pub fn new(max_tokens: usize) -> Self {
1037 Self {
1038 max_tokens,
1039 reserved_for_response: max_tokens / 4, segments: Vec::new(),
1041 }
1042 }
1043
1044 pub fn add_segment(&mut self, segment: ContextSegment) {
1046 self.segments.push(segment);
1047 }
1048
1049 pub fn build(&mut self) -> String {
1051 let available = self.max_tokens - self.reserved_for_response;
1052
1053 self.segments
1055 .sort_by(|a, b| match (a.required, b.required) {
1056 (true, false) => std::cmp::Ordering::Less,
1057 (false, true) => std::cmp::Ordering::Greater,
1058 _ => b.priority.cmp(&a.priority),
1059 });
1060
1061 let mut total_tokens = 0;
1062 let mut result = Vec::new();
1063
1064 for segment in &self.segments {
1065 if total_tokens + segment.tokens <= available {
1066 result.push(segment.content.clone());
1067 total_tokens += segment.tokens;
1068 } else if segment.required {
1069 let remaining = available.saturating_sub(total_tokens);
1071 if remaining > 0 {
1072 let truncated = truncate_to_tokens(&segment.content, remaining);
1073 result.push(truncated);
1074 break;
1075 }
1076 }
1077 }
1078
1079 result.join("\n\n")
1080 }
1081
1082 pub fn token_usage(&self) -> (usize, usize) {
1084 let used: usize = self.segments.iter().map(|s| s.tokens).sum();
1085 (used, self.max_tokens - self.reserved_for_response)
1086 }
1087}
1088
/// A cached value with optional expiry and usage count.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheEntry<T> {
    /// Key the value is stored under.
    pub key: String,
    /// Cached value.
    pub value: T,
    /// Insertion time.
    pub created_at: DateTime<Utc>,
    /// Expiry time; `None` never expires.
    pub expires_at: Option<DateTime<Utc>>,
    /// Number of successful `AgentCache::get` hits.
    pub access_count: u64,
}
1102
1103impl<T> CacheEntry<T> {
1104 pub fn is_expired(&self) -> bool {
1105 self.expires_at.map(|exp| Utc::now() > exp).unwrap_or(false)
1106 }
1107}
1108
/// Bounded key/value cache with optional per-entry TTL.
pub struct AgentCache<T> {
    /// Entries by key.
    entries: HashMap<String, CacheEntry<T>>,
    /// Entry count above which `set` evicts.
    max_size: usize,
}
1114
1115impl<T: Clone> AgentCache<T> {
1116 pub fn new(max_size: usize) -> Self {
1118 Self {
1119 entries: HashMap::new(),
1120 max_size,
1121 }
1122 }
1123
1124 pub fn get(&mut self, key: &str) -> Option<T> {
1126 if let Some(entry) = self.entries.get_mut(key) {
1127 if entry.is_expired() {
1128 self.entries.remove(key);
1129 return None;
1130 }
1131 entry.access_count += 1;
1132 Some(entry.value.clone())
1133 } else {
1134 None
1135 }
1136 }
1137
1138 pub fn set(&mut self, key: impl Into<String>, value: T, ttl: Option<chrono::Duration>) {
1140 let key = key.into();
1141 let now = Utc::now();
1142
1143 self.entries.insert(
1144 key.clone(),
1145 CacheEntry {
1146 key,
1147 value,
1148 created_at: now,
1149 expires_at: ttl.map(|d| now + d),
1150 access_count: 0,
1151 },
1152 );
1153
1154 if self.entries.len() > self.max_size {
1156 self.evict_lru();
1157 }
1158 }
1159
1160 pub fn remove(&mut self, key: &str) -> Option<T> {
1162 self.entries.remove(key).map(|e| e.value)
1163 }
1164
1165 pub fn clear(&mut self) {
1167 self.entries.clear();
1168 }
1169
1170 fn evict_lru(&mut self) {
1172 self.entries.retain(|_, v| !v.is_expired());
1174
1175 if self.entries.len() > self.max_size {
1177 let mut entries: Vec<_> = self
1179 .entries
1180 .iter()
1181 .map(|(k, v)| (k.clone(), v.access_count))
1182 .collect();
1183 entries.sort_by_key(|(_, count)| *count);
1184
1185 let to_remove = self.entries.len() - self.max_size;
1186 let keys_to_remove: Vec<String> = entries
1187 .into_iter()
1188 .take(to_remove)
1189 .map(|(k, _)| k)
1190 .collect();
1191
1192 for key in keys_to_remove {
1193 self.entries.remove(&key);
1194 }
1195 }
1196 }
1197}
1198
/// Top-level configuration for a [`MemoryManager`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryConfig {
    /// Configuration shared by the memory store and the knowledge base's vector store.
    pub vector_store: VectorStoreConfig,
    /// Chunking settings intended for the knowledge base.
    /// NOTE(review): confirm `MemoryManager::new` applies them.
    pub chunking: ChunkingConfig,
    /// Token budget used by `MemoryManager::build_context`.
    pub context_window_tokens: usize,
    /// Maximum number of cached results.
    pub cache_size: usize,
    /// SQLite path for memories. When set, the knowledge base uses `<path>_kb`.
    /// `None` keeps everything in memory.
    pub db_path: Option<String>,
    /// NOTE(review): not read anywhere in this module.
    pub auto_summarize: bool,
    /// NOTE(review): not read anywhere in this module.
    pub summarize_threshold: usize,
}
1221
1222impl Default for MemoryConfig {
1223 fn default() -> Self {
1224 Self {
1225 vector_store: VectorStoreConfig::default(),
1226 chunking: ChunkingConfig::default(),
1227 context_window_tokens: 8192,
1228 cache_size: 1000,
1229 db_path: None,
1230 auto_summarize: true,
1231 summarize_threshold: 20,
1232 }
1233 }
1234}
1235
/// Facade over agent memories, a document knowledge base, and a string result cache.
pub struct MemoryManager {
    /// Configuration the manager was built with.
    config: MemoryConfig,
    /// Store for remembered entries (`remember` / `recall`).
    vector_store: VectorStore,
    /// Document store with its own vector index (`add_document` / `retrieve`).
    knowledge_base: KnowledgeBase,
    /// Cache for arbitrary string results.
    cache: AgentCache<String>,
}
1243
1244impl MemoryManager {
1245 pub fn new(config: MemoryConfig) -> Result<Self, MemoryError> {
1247 let vector_store = if let Some(ref path) = config.db_path {
1248 VectorStore::with_persistence(config.vector_store.clone(), path)?
1249 } else {
1250 VectorStore::new(config.vector_store.clone())
1251 };
1252
1253 let knowledge_base = if let Some(ref path) = config.db_path {
1254 let kb_path = format!("{}_kb", path);
1255 KnowledgeBase::with_persistence(config.vector_store.clone(), kb_path)?
1256 } else {
1257 KnowledgeBase::new(config.vector_store.clone())
1258 };
1259
1260 Ok(Self {
1261 config: config.clone(),
1262 vector_store,
1263 knowledge_base,
1264 cache: AgentCache::new(config.cache_size),
1265 })
1266 }
1267
1268 pub fn remember(
1270 &mut self,
1271 content: impl Into<String>,
1272 memory_type: MemoryType,
1273 source: MemorySource,
1274 ) -> Result<MemoryId, MemoryError> {
1275 let entry = MemoryEntry::new(content, memory_type, source);
1276 self.vector_store.add(entry)
1277 }
1278
1279 pub fn remember_with_embedding(
1281 &mut self,
1282 content: impl Into<String>,
1283 embedding: Embedding,
1284 memory_type: MemoryType,
1285 source: MemorySource,
1286 ) -> Result<MemoryId, MemoryError> {
1287 let entry = MemoryEntry::new(content, memory_type, source).with_embedding(embedding);
1288 self.vector_store.add(entry)
1289 }
1290
1291 pub fn recall(&mut self, query_embedding: &Embedding, limit: usize) -> Vec<SearchResult> {
1293 self.vector_store.search(query_embedding, limit)
1294 }
1295
1296 pub fn recall_by_type(&self, memory_type: MemoryType, limit: usize) -> Vec<&MemoryEntry> {
1298 self.vector_store.search_by_type(memory_type, limit)
1299 }
1300
1301 pub fn add_document(&mut self, document: Document) -> Result<String, MemoryError> {
1303 self.knowledge_base.add_document(document)
1304 }
1305
1306 pub fn retrieve(&mut self, query_embedding: &Embedding, limit: usize) -> Vec<SearchResult> {
1308 self.knowledge_base.retrieve(query_embedding, limit)
1309 }
1310
1311 pub fn build_context(
1313 &mut self,
1314 query_embedding: &Embedding,
1315 system_prompt: &str,
1316 conversation: &[String],
1317 ) -> String {
1318 let mut context = ContextWindow::new(self.config.context_window_tokens);
1319
1320 context.add_segment(ContextSegment {
1322 segment_type: ContextSegmentType::SystemPrompt,
1323 content: system_prompt.to_string(),
1324 tokens: estimate_tokens(system_prompt),
1325 priority: 100,
1326 required: true,
1327 });
1328
1329 let retrieved = self.recall(query_embedding, 5);
1331 if !retrieved.is_empty() {
1332 let retrieved_text: String = retrieved
1333 .iter()
1334 .map(|r| format!("- {}", r.entry.content))
1335 .collect::<Vec<_>>()
1336 .join("\n");
1337
1338 context.add_segment(ContextSegment {
1339 segment_type: ContextSegmentType::RetrievedContext,
1340 content: format!("Relevant context:\n{}", retrieved_text),
1341 tokens: estimate_tokens(&retrieved_text) + 20,
1342 priority: 80,
1343 required: false,
1344 });
1345 }
1346
1347 let conv_text = conversation.join("\n");
1349 context.add_segment(ContextSegment {
1350 segment_type: ContextSegmentType::ConversationHistory,
1351 content: conv_text.clone(),
1352 tokens: estimate_tokens(&conv_text),
1353 priority: 90,
1354 required: false,
1355 });
1356
1357 context.build()
1358 }
1359
1360 pub fn cache_result(
1362 &mut self,
1363 key: impl Into<String>,
1364 value: String,
1365 ttl: Option<chrono::Duration>,
1366 ) {
1367 self.cache.set(key, value, ttl);
1368 }
1369
1370 pub fn get_cached(&mut self, key: &str) -> Option<String> {
1372 self.cache.get(key)
1373 }
1374
1375 pub fn stats(&self) -> MemoryStats {
1377 MemoryStats {
1378 vector_store: self.vector_store.stats(),
1379 document_count: self.knowledge_base.list_documents().len(),
1380 }
1381 }
1382}
1383
/// Statistics reported by `MemoryManager::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryStats {
    /// Statistics of the memory store (not the knowledge base index).
    pub vector_store: VectorStoreStats,
    /// Number of documents held by the knowledge base.
    pub document_count: usize,
}
1390
/// Errors produced by the memory subsystem; each variant carries a message.
#[derive(Debug, Clone)]
pub enum MemoryError {
    /// Storage failure; `VectorStore` wraps SQLite errors in this variant.
    Database(String),
    /// Not constructed in this module; presumably for embedding-provider failures.
    Embedding(String),
    /// Not constructed in this module; presumably for missing items.
    NotFound(String),
    /// Not constructed in this module; presumably for rejected arguments.
    InvalidInput(String),
    /// Not constructed in this module; presumably for I/O failures.
    Io(String),
}
1409
1410impl std::fmt::Display for MemoryError {
1411 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1412 match self {
1413 MemoryError::Database(e) => write!(f, "Database error: {}", e),
1414 MemoryError::Embedding(e) => write!(f, "Embedding error: {}", e),
1415 MemoryError::NotFound(e) => write!(f, "Not found: {}", e),
1416 MemoryError::InvalidInput(e) => write!(f, "Invalid input: {}", e),
1417 MemoryError::Io(e) => write!(f, "IO error: {}", e),
1418 }
1419 }
1420}
1421
1422impl std::error::Error for MemoryError {}
1423
/// Generates a process-unique memory id of the form `mem_<unix-nanos-hex>_<seq-hex>`.
///
/// The timestamp alone is not unique. Entries created within one clock tick
/// would collide, which is common with coarse clocks or when many chunks are
/// indexed in a tight loop. A colliding id silently overwrites the earlier row
/// via `INSERT OR REPLACE`. A process-wide counter disambiguates ids created in
/// the same tick. A clock set before the Unix epoch yields a timestamp of 0
/// instead of panicking.
fn generate_memory_id() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    static SEQUENCE: AtomicU64 = AtomicU64::new(0);

    let timestamp = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_nanos())
        .unwrap_or(0);
    // Relaxed suffices: only the uniqueness of each fetched value matters, not ordering.
    let seq = SEQUENCE.fetch_add(1, Ordering::Relaxed);
    format!("mem_{:x}_{:x}", timestamp, seq)
}
1436
/// Rough token estimate: one token per 4 bytes of UTF-8, rounded up.
///
/// Uses integer arithmetic. The former `f32` division lost precision for
/// inputs above 2^24 bytes and could under-count by one.
fn estimate_tokens(text: &str) -> usize {
    (text.len() + 3) / 4
}
1441
/// Truncates `text` to roughly `max_tokens` tokens (4 bytes per token), appending `"..."` when cut.
///
/// The cut backs off to the nearest UTF-8 char boundary. Slicing at an
/// arbitrary byte index would panic on multi-byte characters. The byte budget
/// saturates instead of overflowing for huge `max_tokens`.
fn truncate_to_tokens(text: &str, max_tokens: usize) -> String {
    let max_bytes = max_tokens.saturating_mul(4);
    if text.len() <= max_bytes {
        return text.to_string();
    }

    let mut end = max_bytes;
    while !text.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}...", &text[..end])
}
1451
/// Source of text embeddings; async via `async_trait`, and shareable across threads (`Send + Sync`).
#[async_trait::async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Embeds a single text.
    async fn embed(&self, text: &str) -> Result<Embedding, MemoryError>;

    /// Embeds several texts, presumably returning one embedding per input in order.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Embedding>, MemoryError>;

    /// Length of the vectors this provider returns.
    fn dimension(&self) -> usize;
}
1468
/// OpenAI embedding provider configuration (the `embed` implementation is currently a stub).
pub struct OpenAIEmbedding {
    // Not read yet: `embed` does not call the API.
    #[allow(dead_code)]
    api_key: String,
    /// Model name, e.g. `text-embedding-3-small` (the default).
    model: String,
}
1475
1476impl OpenAIEmbedding {
1477 pub fn new(api_key: impl Into<String>) -> Self {
1478 Self {
1479 api_key: api_key.into(),
1480 model: "text-embedding-3-small".to_string(),
1481 }
1482 }
1483
1484 pub fn with_model(mut self, model: impl Into<String>) -> Self {
1485 self.model = model.into();
1486 self
1487 }
1488}
1489
1490#[async_trait::async_trait]
1491impl EmbeddingProvider for OpenAIEmbedding {
1492 async fn embed(&self, _text: &str) -> Result<Embedding, MemoryError> {
1493 Ok(vec![0.0; 1536])
1496 }
1497
1498 async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Embedding>, MemoryError> {
1499 let mut results = Vec::new();
1500 for text in texts {
1501 results.push(self.embed(text).await?);
1502 }
1503 Ok(results)
1504 }
1505
1506 fn dimension(&self) -> usize {
1507 match self.model.as_str() {
1508 "text-embedding-3-large" => 3072,
1509 _ => 1536,
1510 }
1511 }
1512}
1513
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_entry_creation() {
        let entry = MemoryEntry::new("Test content", MemoryType::LongTerm, MemorySource::UserInput);

        assert!(!entry.id.is_empty());
        assert_eq!(entry.content, "Test content");
        assert_eq!(entry.memory_type, MemoryType::LongTerm);
    }

    #[test]
    fn test_similarity_metrics() {
        let x_axis: Vec<f32> = vec![1.0, 0.0, 0.0];
        let same_direction = x_axis.clone();
        let y_axis: Vec<f32> = vec![0.0, 1.0, 0.0];
        let close_to = |actual: f32, expected: f32| (actual - expected).abs() < 0.001;

        assert!(close_to(SimilarityMetric::Cosine.calculate(&x_axis, &same_direction), 1.0));
        assert!(close_to(SimilarityMetric::Cosine.calculate(&x_axis, &y_axis), 0.0));
    }

    #[test]
    fn test_vector_store() {
        let mut store = VectorStore::new(VectorStoreConfig::default());
        let entry = MemoryEntry::new("Test", MemoryType::ShortTerm, MemorySource::UserInput)
            .with_embedding(vec![1.0, 0.0, 0.0]);

        let id = store.add(entry).unwrap();

        assert!(!id.is_empty());
        assert!(store.get(&id).is_some());
    }

    #[test]
    fn test_context_window() {
        let mut window = ContextWindow::new(1000);
        window.add_segment(ContextSegment {
            segment_type: ContextSegmentType::SystemPrompt,
            content: "You are helpful".to_string(),
            tokens: 10,
            priority: 100,
            required: true,
        });

        assert!(window.build().contains("You are helpful"));
    }

    #[test]
    fn test_cache() {
        let mut cache: AgentCache<String> = AgentCache::new(10);
        cache.set("key1", "value1".to_string(), None);

        assert_eq!(cache.get("key1"), Some("value1".to_string()));
        assert_eq!(cache.get("key2"), None);
    }

    #[test]
    fn test_estimate_tokens() {
        for (text, expected) in [("hello", 2), ("hello world", 3)] {
            assert_eq!(estimate_tokens(text), expected);
        }
    }
}