chasm_cli/agency/memory.rs

// Copyright (c) 2024-2026 Nervosys LLC
// SPDX-License-Identifier: Apache-2.0
//! Memory and RAG (Retrieval-Augmented Generation) System
//!
//! Provides persistent context, knowledge retrieval, and semantic search for agents.
//!
//! ## Features
//!
//! - **Vector Store**: Semantic similarity search using embeddings
//! - **Memory Types**: Short-term, long-term, episodic, and semantic memory
//! - **Knowledge Base**: Structured document storage with chunking
//! - **Context Window**: Smart context management for LLM prompts
//! - **Caching**: Frequently accessed information caching

#![allow(dead_code)]
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;

// =============================================================================
// Core Types
// =============================================================================

/// Unique identifier for memory entries
pub type MemoryId = String;

/// Vector embedding (typically 384-1536 dimensions depending on model)
pub type Embedding = Vec<f32>;

/// Memory entry representing a piece of stored knowledge
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEntry {
    /// Unique identifier
    pub id: MemoryId,
    /// The content/text of this memory
    pub content: String,
    /// Vector embedding for similarity search
    #[serde(skip_serializing_if = "Option::is_none")]
    pub embedding: Option<Embedding>,
    /// Memory type classification
    pub memory_type: MemoryType,
    /// Source of this memory (conversation, document, user input, etc.)
    pub source: MemorySource,
    /// Importance score (0.0 - 1.0)
    pub importance: f32,
    /// Access count for LRU caching
    pub access_count: u64,
    /// Last accessed timestamp
    pub last_accessed: DateTime<Utc>,
    /// Creation timestamp
    pub created_at: DateTime<Utc>,
    /// Optional expiration
    pub expires_at: Option<DateTime<Utc>>,
    /// Associated agent ID
    pub agent_id: Option<String>,
    /// Associated session ID
    pub session_id: Option<String>,
    /// Custom metadata
    pub metadata: HashMap<String, serde_json::Value>,
    /// Tags for filtering
    pub tags: Vec<String>,
}

impl MemoryEntry {
    /// Create a new memory entry
    pub fn new(content: impl Into<String>, memory_type: MemoryType, source: MemorySource) -> Self {
        let now = Utc::now();
        Self {
            id: generate_memory_id(),
            content: content.into(),
            embedding: None,
            memory_type,
            source,
            importance: 0.5,
            access_count: 0,
            last_accessed: now,
            created_at: now,
            expires_at: None,
            agent_id: None,
            session_id: None,
            metadata: HashMap::new(),
            tags: Vec::new(),
        }
    }

    /// Set the embedding
    pub fn with_embedding(mut self, embedding: Embedding) -> Self {
        self.embedding = Some(embedding);
        self
    }

    /// Set importance score
    pub fn with_importance(mut self, importance: f32) -> Self {
        self.importance = importance.clamp(0.0, 1.0);
        self
    }

    /// Set agent ID
    pub fn with_agent(mut self, agent_id: impl Into<String>) -> Self {
        self.agent_id = Some(agent_id.into());
        self
    }

    /// Set session ID
    pub fn with_session(mut self, session_id: impl Into<String>) -> Self {
        self.session_id = Some(session_id.into());
        self
    }

    /// Add a tag
    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
        self.tags.push(tag.into());
        self
    }

    /// Set expiration
    pub fn expires_in(mut self, duration: chrono::Duration) -> Self {
        self.expires_at = Some(Utc::now() + duration);
        self
    }

    /// Check if expired
    pub fn is_expired(&self) -> bool {
        self.expires_at.map(|exp| Utc::now() > exp).unwrap_or(false)
    }
}
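
// A minimal usage sketch of the builder API above, kept as a test so it stays
// compilable. The content, tag, and agent values are illustrative only.
#[cfg(test)]
#[test]
fn example_memory_entry_builder() {
    let entry = MemoryEntry::new(
        "User prefers dark mode",
        MemoryType::Preference,
        MemorySource::UserInput,
    )
    .with_importance(0.9)
    .with_agent("agent-1")
    .with_tag("ui")
    .expires_in(chrono::Duration::days(30));

    assert_eq!(entry.importance, 0.9);
    assert_eq!(entry.agent_id.as_deref(), Some("agent-1"));
    assert!(!entry.is_expired());
}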

/// Types of memory
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MemoryType {
    /// Short-term working memory (current conversation context)
    ShortTerm,
    /// Long-term persistent memory (facts, preferences, learned info)
    LongTerm,
    /// Episodic memory (specific events and experiences)
    Episodic,
    /// Semantic memory (concepts, relationships, general knowledge)
    Semantic,
    /// Procedural memory (how to do things, workflows)
    Procedural,
    /// User preferences and settings
    Preference,
    /// Cached computation results
    Cache,
}

/// Source of memory entry
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MemorySource {
    /// From a conversation message
    Conversation {
        session_id: String,
        message_id: String,
    },
    /// From a document/file
    Document { path: String, chunk_index: u32 },
    /// Direct user input/instruction
    UserInput,
    /// Agent reasoning/reflection
    AgentReasoning { agent_id: String },
    /// External API or tool result
    ToolResult { tool_name: String },
    /// Web page or URL
    WebPage { url: String },
    /// System-generated summary
    Summary { source_ids: Vec<String> },
    /// Custom source
    Custom { source_type: String },
}

// =============================================================================
// Vector Store
// =============================================================================

/// Configuration for the vector store
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorStoreConfig {
    /// Embedding model to use
    pub embedding_model: EmbeddingModel,
    /// Dimension of embeddings
    pub embedding_dim: usize,
    /// Similarity metric
    pub similarity_metric: SimilarityMetric,
    /// Maximum entries before pruning
    pub max_entries: usize,
    /// Database path
    pub db_path: Option<String>,
}

impl Default for VectorStoreConfig {
    fn default() -> Self {
        Self {
            embedding_model: EmbeddingModel::default(),
            embedding_dim: 384,
            similarity_metric: SimilarityMetric::Cosine,
            max_entries: 100_000,
            db_path: None,
        }
    }
}

/// Embedding model options
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EmbeddingModel {
    /// OpenAI text-embedding-3-small (1536 dims)
    OpenAISmall,
    /// OpenAI text-embedding-3-large (3072 dims)
    OpenAILarge,
    /// OpenAI text-embedding-ada-002 (1536 dims)
    OpenAIAda,
    /// Sentence Transformers all-MiniLM-L6-v2 (384 dims)
    #[default]
    MiniLM,
    /// Sentence Transformers all-mpnet-base-v2 (768 dims)
    MPNet,
    /// Cohere embed-english-v3.0 (1024 dims)
    Cohere,
    /// Google text-embedding-004 (768 dims)
    GoogleGecko,
    /// Voyage AI voyage-2 (1024 dims)
    Voyage,
    /// Local model via Ollama
    Ollama { model: String },
    /// Custom model
    Custom { name: String, dim: usize },
}

impl EmbeddingModel {
    /// Get the dimension for this model
    pub fn dimension(&self) -> usize {
        match self {
            EmbeddingModel::OpenAISmall => 1536,
            EmbeddingModel::OpenAILarge => 3072,
            EmbeddingModel::OpenAIAda => 1536,
            EmbeddingModel::MiniLM => 384,
            EmbeddingModel::MPNet => 768,
            EmbeddingModel::Cohere => 1024,
            EmbeddingModel::GoogleGecko => 768,
            EmbeddingModel::Voyage => 1024,
            EmbeddingModel::Ollama { .. } => 4096, // Varies by model; 4096 matches llama-family defaults
            EmbeddingModel::Custom { dim, .. } => *dim,
        }
    }
}

/// Similarity metrics for vector search
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SimilarityMetric {
    #[default]
    Cosine,
    Euclidean,
    DotProduct,
    Manhattan,
}

impl SimilarityMetric {
    /// Calculate similarity between two vectors
    pub fn calculate(&self, a: &[f32], b: &[f32]) -> f32 {
        assert_eq!(a.len(), b.len(), "Vector dimensions must match");

        match self {
            SimilarityMetric::Cosine => {
                let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
                let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
                let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
                if norm_a == 0.0 || norm_b == 0.0 {
                    0.0
                } else {
                    dot / (norm_a * norm_b)
                }
            }
            SimilarityMetric::Euclidean => {
                let dist: f32 = a
                    .iter()
                    .zip(b.iter())
                    .map(|(x, y)| (x - y).powi(2))
                    .sum::<f32>()
                    .sqrt();
                1.0 / (1.0 + dist) // Convert distance to similarity
            }
            SimilarityMetric::DotProduct => a.iter().zip(b.iter()).map(|(x, y)| x * y).sum(),
            SimilarityMetric::Manhattan => {
                let dist: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum();
                1.0 / (1.0 + dist)
            }
        }
    }
}
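
// Worked example for the distance-to-similarity mapping above: under the
// Euclidean metric, a = [0, 0] and b = [3, 4] are distance 5 apart, so the
// similarity is 1 / (1 + 5) = 1/6. Kept as a test so the arithmetic is checked.
#[cfg(test)]
#[test]
fn example_euclidean_similarity() {
    let a = vec![0.0, 0.0];
    let b = vec![3.0, 4.0];
    let sim = SimilarityMetric::Euclidean.calculate(&a, &b);
    assert!((sim - 1.0 / 6.0).abs() < 1e-6);
}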

/// Search result from vector store
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    /// The memory entry
    pub entry: MemoryEntry,
    /// Similarity score (0.0 - 1.0)
    pub score: f32,
    /// Rank in results
    pub rank: usize,
}

/// Vector store for semantic search
pub struct VectorStore {
    config: VectorStoreConfig,
    entries: Vec<MemoryEntry>,
    db: Option<rusqlite::Connection>,
}

impl VectorStore {
    /// Create a new in-memory vector store
    pub fn new(config: VectorStoreConfig) -> Self {
        Self {
            config,
            entries: Vec::new(),
            db: None,
        }
    }

    /// Create a vector store with SQLite persistence
    pub fn with_persistence(
        config: VectorStoreConfig,
        db_path: impl AsRef<Path>,
    ) -> Result<Self, MemoryError> {
        let db = rusqlite::Connection::open(db_path.as_ref())
            .map_err(|e| MemoryError::Database(e.to_string()))?;

        // Initialize schema
        db.execute_batch(
            r#"
            CREATE TABLE IF NOT EXISTS memory_entries (
                id TEXT PRIMARY KEY,
                content TEXT NOT NULL,
                embedding BLOB,
                memory_type TEXT NOT NULL,
                source TEXT NOT NULL,
                importance REAL NOT NULL,
                access_count INTEGER NOT NULL DEFAULT 0,
                last_accessed TEXT NOT NULL,
                created_at TEXT NOT NULL,
                expires_at TEXT,
                agent_id TEXT,
                session_id TEXT,
                metadata TEXT,
                tags TEXT
            );

            CREATE INDEX IF NOT EXISTS idx_memory_type ON memory_entries(memory_type);
            CREATE INDEX IF NOT EXISTS idx_agent_id ON memory_entries(agent_id);
            CREATE INDEX IF NOT EXISTS idx_session_id ON memory_entries(session_id);
            CREATE INDEX IF NOT EXISTS idx_created_at ON memory_entries(created_at);
            CREATE INDEX IF NOT EXISTS idx_importance ON memory_entries(importance DESC);
        "#,
        )
        .map_err(|e| MemoryError::Database(e.to_string()))?;

        let mut store = Self {
            config,
            entries: Vec::new(),
            db: Some(db),
        };

        store.load_from_db()?;
        Ok(store)
    }

    /// Load entries from database
    fn load_from_db(&mut self) -> Result<(), MemoryError> {
        if let Some(ref db) = self.db {
            let mut stmt = db
                .prepare(
                    "SELECT id, content, embedding, memory_type, source, importance,
                        access_count, last_accessed, created_at, expires_at,
                        agent_id, session_id, metadata, tags
                 FROM memory_entries
                 ORDER BY importance DESC, created_at DESC",
                )
                .map_err(|e| MemoryError::Database(e.to_string()))?;

            let entries = stmt
                .query_map([], |row| {
                    let embedding_blob: Option<Vec<u8>> = row.get(2)?;
                    let embedding = embedding_blob.map(|blob| {
                        blob.chunks(4)
                            .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap_or([0; 4])))
                            .collect()
                    });

                    Ok(MemoryEntry {
                        id: row.get(0)?,
                        content: row.get(1)?,
                        embedding,
                        memory_type: serde_json::from_str(&row.get::<_, String>(3)?)
                            .unwrap_or(MemoryType::LongTerm),
                        source: serde_json::from_str(&row.get::<_, String>(4)?)
                            .unwrap_or(MemorySource::UserInput),
                        importance: row.get(5)?,
                        access_count: row.get(6)?,
                        last_accessed: row
                            .get::<_, String>(7)?
                            .parse()
                            .unwrap_or_else(|_| Utc::now()),
                        created_at: row
                            .get::<_, String>(8)?
                            .parse()
                            .unwrap_or_else(|_| Utc::now()),
                        expires_at: row
                            .get::<_, Option<String>>(9)?
                            .and_then(|s| s.parse().ok()),
                        agent_id: row.get(10)?,
                        session_id: row.get(11)?,
                        metadata: row
                            .get::<_, Option<String>>(12)?
                            .and_then(|s| serde_json::from_str(&s).ok())
                            .unwrap_or_default(),
                        tags: row
                            .get::<_, Option<String>>(13)?
                            .and_then(|s| serde_json::from_str(&s).ok())
                            .unwrap_or_default(),
                    })
                })
                .map_err(|e| MemoryError::Database(e.to_string()))?;

            self.entries = entries.filter_map(|e| e.ok()).collect();
        }
        Ok(())
    }

    /// Add a memory entry
    pub fn add(&mut self, entry: MemoryEntry) -> Result<MemoryId, MemoryError> {
        let id = entry.id.clone();

        // Persist to database if available
        if let Some(ref db) = self.db {
            let embedding_blob: Option<Vec<u8>> = entry
                .embedding
                .as_ref()
                .map(|emb| emb.iter().flat_map(|f| f.to_le_bytes()).collect());

            db.execute(
                "INSERT OR REPLACE INTO memory_entries
                 (id, content, embedding, memory_type, source, importance,
                  access_count, last_accessed, created_at, expires_at,
                  agent_id, session_id, metadata, tags)
                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14)",
                rusqlite::params![
                    entry.id,
                    entry.content,
                    embedding_blob,
                    serde_json::to_string(&entry.memory_type).unwrap_or_default(),
                    serde_json::to_string(&entry.source).unwrap_or_default(),
                    entry.importance,
                    entry.access_count,
                    entry.last_accessed.to_rfc3339(),
                    entry.created_at.to_rfc3339(),
                    entry.expires_at.map(|e| e.to_rfc3339()),
                    entry.agent_id,
                    entry.session_id,
                    serde_json::to_string(&entry.metadata).ok(),
                    serde_json::to_string(&entry.tags).ok(),
                ],
            )
            .map_err(|e| MemoryError::Database(e.to_string()))?;
        }

        self.entries.push(entry);

        // Prune if needed
        if self.entries.len() > self.config.max_entries {
            self.prune()?;
        }

        Ok(id)
    }

    /// Search for similar entries
    pub fn search(&mut self, query_embedding: &Embedding, limit: usize) -> Vec<SearchResult> {
        let mut results: Vec<(usize, f32)> = self
            .entries
            .iter()
            .enumerate()
            .filter(|(_, e)| !e.is_expired() && e.embedding.is_some())
            .map(|(i, e)| {
                let score = self
                    .config
                    .similarity_metric
                    .calculate(query_embedding, e.embedding.as_ref().unwrap());
                (i, score)
            })
            .collect();

        // Sort by score descending
        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        // Take top results and update access counts
        results
            .into_iter()
            .take(limit)
            .enumerate()
            .map(|(rank, (idx, score))| {
                self.entries[idx].access_count += 1;
                self.entries[idx].last_accessed = Utc::now();

                SearchResult {
                    entry: self.entries[idx].clone(),
                    score,
                    rank,
                }
            })
            .collect()
    }

    /// Search by memory type
    pub fn search_by_type(&self, memory_type: MemoryType, limit: usize) -> Vec<&MemoryEntry> {
        self.entries
            .iter()
            .filter(|e| e.memory_type == memory_type && !e.is_expired())
            .take(limit)
            .collect()
    }

    /// Search by tags
    pub fn search_by_tags(&self, tags: &[String], limit: usize) -> Vec<&MemoryEntry> {
        self.entries
            .iter()
            .filter(|e| !e.is_expired() && tags.iter().any(|t| e.tags.contains(t)))
            .take(limit)
            .collect()
    }

    /// Get entry by ID
    pub fn get(&self, id: &str) -> Option<&MemoryEntry> {
        self.entries.iter().find(|e| e.id == id)
    }

    /// Delete entry
    pub fn delete(&mut self, id: &str) -> Result<bool, MemoryError> {
        if let Some(pos) = self.entries.iter().position(|e| e.id == id) {
            self.entries.remove(pos);

            if let Some(ref db) = self.db {
                db.execute("DELETE FROM memory_entries WHERE id = ?1", [id])
                    .map_err(|e| MemoryError::Database(e.to_string()))?;
            }

            Ok(true)
        } else {
            Ok(false)
        }
    }

    /// Prune old/low-importance entries
    fn prune(&mut self) -> Result<(), MemoryError> {
        // Remove expired entries
        self.entries.retain(|e| !e.is_expired());

        // If still over limit, remove lowest importance entries
        if self.entries.len() > self.config.max_entries {
            self.entries.sort_by(|a, b| {
                b.importance
                    .partial_cmp(&a.importance)
                    .unwrap_or(std::cmp::Ordering::Equal)
            });
            self.entries.truncate(self.config.max_entries);
        }

        Ok(())
    }

    /// Get statistics
    pub fn stats(&self) -> VectorStoreStats {
        VectorStoreStats {
            total_entries: self.entries.len(),
            entries_by_type: self.entries.iter().fold(HashMap::new(), |mut acc, e| {
                *acc.entry(format!("{:?}", e.memory_type)).or_insert(0) += 1;
                acc
            }),
            total_access_count: self.entries.iter().map(|e| e.access_count).sum(),
            avg_importance: if self.entries.is_empty() {
                0.0
            } else {
                self.entries.iter().map(|e| e.importance).sum::<f32>() / self.entries.len() as f32
            },
        }
    }
}
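
// End-to-end sketch of add + search on the in-memory store. The 3-dimensional
// embeddings are toy values chosen so the expected ranking is obvious.
#[cfg(test)]
#[test]
fn example_vector_store_search_ranking() {
    let mut store = VectorStore::new(VectorStoreConfig::default());

    let close = MemoryEntry::new("close", MemoryType::Semantic, MemorySource::UserInput)
        .with_embedding(vec![1.0, 0.0, 0.0]);
    let far = MemoryEntry::new("far", MemoryType::Semantic, MemorySource::UserInput)
        .with_embedding(vec![0.0, 1.0, 0.0]);
    store.add(close).unwrap();
    store.add(far).unwrap();

    let results = store.search(&vec![0.9, 0.1, 0.0], 2);
    assert_eq!(results.len(), 2);
    assert_eq!(results[0].entry.content, "close");
    assert_eq!(results[0].rank, 0);
    assert_eq!(results[0].entry.access_count, 1); // search bumps access counts
}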

/// Vector store statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorStoreStats {
    pub total_entries: usize,
    pub entries_by_type: HashMap<String, usize>,
    pub total_access_count: u64,
    pub avg_importance: f32,
}

// =============================================================================
// Knowledge Base / RAG
// =============================================================================

/// Document for the knowledge base
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Document {
    /// Unique identifier
    pub id: String,
    /// Document title
    pub title: String,
    /// Full content
    pub content: String,
    /// Document type
    pub doc_type: DocumentType,
    /// Source URL or path
    pub source: String,
    /// Chunked content for embedding
    pub chunks: Vec<DocumentChunk>,
    /// Creation timestamp
    pub created_at: DateTime<Utc>,
    /// Last updated
    pub updated_at: DateTime<Utc>,
    /// Metadata
    pub metadata: HashMap<String, serde_json::Value>,
}

/// Document types
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DocumentType {
    Text,
    Markdown,
    Code { language: String },
    Html,
    Pdf,
    Json,
    Yaml,
    Csv,
    Custom { mime_type: String },
}

/// A chunk of a document
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentChunk {
    /// Chunk index
    pub index: u32,
    /// Chunk content
    pub content: String,
    /// Start position in original document
    pub start_pos: usize,
    /// End position in original document
    pub end_pos: usize,
    /// Vector embedding
    pub embedding: Option<Embedding>,
    /// Token count estimate
    pub token_count: u32,
}

/// Configuration for document chunking
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkingConfig {
    /// Target chunk size in tokens
    pub chunk_size: usize,
    /// Overlap between chunks in tokens
    pub chunk_overlap: usize,
    /// Chunking strategy
    pub strategy: ChunkingStrategy,
}

impl Default for ChunkingConfig {
    fn default() -> Self {
        Self {
            chunk_size: 512,
            chunk_overlap: 50,
            strategy: ChunkingStrategy::Semantic,
        }
    }
}

/// Chunking strategies
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ChunkingStrategy {
    /// Fixed-size character chunks
    FixedSize,
    /// Split on sentences
    Sentence,
    /// Split on paragraphs
    Paragraph,
    /// Semantic chunking (respects structure)
    #[default]
    Semantic,
    /// Code-aware chunking
    Code,
}

/// Knowledge base for RAG
pub struct KnowledgeBase {
    /// Vector store for semantic search
    vector_store: VectorStore,
    /// Documents
    documents: HashMap<String, Document>,
    /// Chunking configuration
    chunking_config: ChunkingConfig,
}

impl KnowledgeBase {
    /// Create a new knowledge base
    pub fn new(vector_config: VectorStoreConfig) -> Self {
        Self {
            vector_store: VectorStore::new(vector_config),
            documents: HashMap::new(),
            chunking_config: ChunkingConfig::default(),
        }
    }

    /// Create with persistence
    pub fn with_persistence(
        vector_config: VectorStoreConfig,
        db_path: impl AsRef<Path>,
    ) -> Result<Self, MemoryError> {
        Ok(Self {
            vector_store: VectorStore::with_persistence(vector_config, db_path)?,
            documents: HashMap::new(),
            chunking_config: ChunkingConfig::default(),
        })
    }

    /// Add a document
    pub fn add_document(&mut self, mut document: Document) -> Result<String, MemoryError> {
        // Chunk the document
        document.chunks = self.chunk_document(&document.content);

        let doc_id = document.id.clone();

        // Add chunks to the vector store. Only chunks that already carry an
        // embedding are indexed; freshly produced chunks start with None until
        // an EmbeddingProvider fills them in.
        for chunk in &document.chunks {
            if let Some(ref embedding) = chunk.embedding {
                let entry = MemoryEntry::new(
                    &chunk.content,
                    MemoryType::Semantic,
                    MemorySource::Document {
                        path: document.source.clone(),
                        chunk_index: chunk.index,
                    },
                )
                .with_embedding(embedding.clone())
                .with_tag(format!("doc:{}", doc_id));

                self.vector_store.add(entry)?;
            }
        }

        self.documents.insert(doc_id.clone(), document);
        Ok(doc_id)
    }

    /// Chunk a document
    fn chunk_document(&self, content: &str) -> Vec<DocumentChunk> {
        match self.chunking_config.strategy {
            ChunkingStrategy::Semantic => self.semantic_chunk(content),
            ChunkingStrategy::Paragraph => self.paragraph_chunk(content),
            ChunkingStrategy::Sentence => self.sentence_chunk(content),
            ChunkingStrategy::FixedSize => self.fixed_chunk(content),
            ChunkingStrategy::Code => self.code_chunk(content),
        }
    }

    fn semantic_chunk(&self, content: &str) -> Vec<DocumentChunk> {
        // Split on double newlines (paragraphs) but respect size limits
        let mut chunks = Vec::new();
        let mut current_chunk = String::new();
        let mut start_pos = 0;
        let mut chunk_index = 0;

        for para in content.split("\n\n") {
            let para = para.trim();
            if para.is_empty() {
                continue;
            }

            let para_tokens = estimate_tokens(para);
            let current_tokens = estimate_tokens(&current_chunk);

            if current_tokens + para_tokens > self.chunking_config.chunk_size
                && !current_chunk.is_empty()
            {
                // Save current chunk
                let end_pos = start_pos + current_chunk.len();
                chunks.push(DocumentChunk {
                    index: chunk_index,
                    content: current_chunk.trim().to_string(),
                    start_pos,
                    end_pos,
                    embedding: None,
                    token_count: estimate_tokens(&current_chunk) as u32,
                });
                chunk_index += 1;
                start_pos = end_pos;
                current_chunk = String::new();
            }

            if !current_chunk.is_empty() {
                current_chunk.push_str("\n\n");
            }
            current_chunk.push_str(para);
        }

        // Add remaining content
        if !current_chunk.is_empty() {
            let end_pos = start_pos + current_chunk.len();
            chunks.push(DocumentChunk {
                index: chunk_index,
                content: current_chunk.trim().to_string(),
                start_pos,
                end_pos,
                embedding: None,
                token_count: estimate_tokens(&current_chunk) as u32,
            });
        }

        chunks
    }

    fn paragraph_chunk(&self, content: &str) -> Vec<DocumentChunk> {
        content
            .split("\n\n")
            .filter(|p| !p.trim().is_empty())
            .enumerate()
            .scan(0usize, |pos, (i, para)| {
                let start = *pos;
                *pos += para.len() + 2;
                Some(DocumentChunk {
                    index: i as u32,
                    content: para.trim().to_string(),
                    start_pos: start,
                    end_pos: *pos,
                    embedding: None,
                    token_count: estimate_tokens(para) as u32,
                })
            })
            .collect()
    }

    fn sentence_chunk(&self, content: &str) -> Vec<DocumentChunk> {
        // Simple sentence splitting (could be improved with NLP)
        let sentences: Vec<&str> = content
            .split(['.', '!', '?'])
            .filter(|s| !s.trim().is_empty())
            .collect();

        let mut chunks = Vec::new();
        let mut current = String::new();
        let mut start = 0;
        let mut idx = 0;

        for sentence in sentences {
            let sentence = sentence.trim();
            if estimate_tokens(&current) + estimate_tokens(sentence)
                > self.chunking_config.chunk_size
                && !current.is_empty()
            {
                chunks.push(DocumentChunk {
                    index: idx,
                    content: current.clone(),
                    start_pos: start,
                    end_pos: start + current.len(),
                    embedding: None,
                    token_count: estimate_tokens(&current) as u32,
                });
                idx += 1;
                start += current.len();
                current.clear();
            }
            if !current.is_empty() {
                current.push(' ');
            }
            current.push_str(sentence);
            current.push('.');
        }

        if !current.is_empty() {
            chunks.push(DocumentChunk {
                index: idx,
                content: current.clone(),
                start_pos: start,
                end_pos: start + current.len(),
                embedding: None,
                token_count: estimate_tokens(&current) as u32,
            });
        }

        chunks
    }

    fn fixed_chunk(&self, content: &str) -> Vec<DocumentChunk> {
        let chars_per_chunk = self.chunking_config.chunk_size * 4; // Rough chars-per-token estimate
        content
            .chars()
            .collect::<Vec<_>>()
            .chunks(chars_per_chunk)
            .enumerate()
            .map(|(i, chars)| {
                let s: String = chars.iter().collect();
                DocumentChunk {
                    index: i as u32,
                    content: s.clone(),
                    start_pos: i * chars_per_chunk,
                    // Use the actual chunk length so the final (short) chunk
                    // does not overshoot the end of the document.
                    end_pos: i * chars_per_chunk + chars.len(),
                    embedding: None,
                    token_count: estimate_tokens(&s) as u32,
                }
            })
            .collect()
    }

    fn code_chunk(&self, content: &str) -> Vec<DocumentChunk> {
        // Split on function/class definitions (simple heuristic)
        let mut chunks = Vec::new();
        let mut current = String::new();
        let mut start = 0;
        let mut idx = 0;

        for line in content.lines() {
            let is_boundary = line.starts_with("fn ")
                || line.starts_with("pub fn ")
                || line.starts_with("async fn ")
                || line.starts_with("impl ")
                || line.starts_with("struct ")
                || line.starts_with("enum ")
                || line.starts_with("trait ")
                || line.starts_with("class ")
                || line.starts_with("def ")
                || line.starts_with("function ")
                || line.starts_with("const ")
                || line.starts_with("export ");

            if is_boundary && !current.is_empty() {
                chunks.push(DocumentChunk {
                    index: idx,
                    content: current.clone(),
                    start_pos: start,
                    end_pos: start + current.len(),
                    embedding: None,
                    token_count: estimate_tokens(&current) as u32,
                });
                idx += 1;
                start += current.len();
                current.clear();
            }

            current.push_str(line);
            current.push('\n');
        }

        if !current.is_empty() {
            chunks.push(DocumentChunk {
                index: idx,
                content: current.clone(),
                start_pos: start,
                end_pos: start + current.len(),
                embedding: None,
                token_count: estimate_tokens(&current) as u32,
            });
        }

        chunks
    }

    /// Retrieve relevant context for a query
    pub fn retrieve(&mut self, query_embedding: &Embedding, limit: usize) -> Vec<SearchResult> {
        self.vector_store.search(query_embedding, limit)
    }

    /// Get document by ID
    pub fn get_document(&self, id: &str) -> Option<&Document> {
        self.documents.get(id)
    }

    /// List all documents
    pub fn list_documents(&self) -> Vec<&Document> {
        self.documents.values().collect()
    }

    /// Delete document
    pub fn delete_document(&mut self, id: &str) -> bool {
        self.documents.remove(id).is_some()
    }
}
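
// Sketch of the ingestion path above: a two-paragraph document run through the
// default (semantic) chunker. All field values are illustrative only.
#[cfg(test)]
#[test]
fn example_knowledge_base_chunking() {
    let mut kb = KnowledgeBase::new(VectorStoreConfig::default());
    let doc = Document {
        id: "doc-1".to_string(),
        title: "Example".to_string(),
        content: "First paragraph.\n\nSecond paragraph.".to_string(),
        doc_type: DocumentType::Text,
        source: "example.txt".to_string(),
        chunks: Vec::new(), // filled in by add_document
        created_at: Utc::now(),
        updated_at: Utc::now(),
        metadata: HashMap::new(),
    };

    let id = kb.add_document(doc).unwrap();
    let stored = kb.get_document(&id).unwrap();
    // Both paragraphs fit within one 512-token budget, so one chunk results.
    assert_eq!(stored.chunks.len(), 1);
}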

// =============================================================================
// Context Window Manager
// =============================================================================

/// Manages context for LLM prompts
#[derive(Debug, Clone)]
pub struct ContextWindow {
    /// Maximum tokens for context
    pub max_tokens: usize,
    /// Reserved tokens for response
    pub reserved_for_response: usize,
    /// Context segments
    segments: Vec<ContextSegment>,
}

/// A segment of context
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextSegment {
    /// Segment type
    pub segment_type: ContextSegmentType,
    /// Content
    pub content: String,
    /// Token count
    pub tokens: usize,
    /// Priority (higher = more important)
    pub priority: u32,
    /// Whether this segment is required
    pub required: bool,
}

/// Types of context segments
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ContextSegmentType {
    SystemPrompt,
    UserPreferences,
    ConversationHistory,
    RetrievedContext,
    ToolResults,
    CurrentQuery,
    Custom { name: String },
}

impl ContextWindow {
    /// Create a new context window
    pub fn new(max_tokens: usize) -> Self {
        Self {
            max_tokens,
            reserved_for_response: max_tokens / 4, // Reserve 25% for response
            segments: Vec::new(),
        }
    }

    /// Add a context segment
    pub fn add_segment(&mut self, segment: ContextSegment) {
        self.segments.push(segment);
    }

    /// Build the final context, respecting token limits
    pub fn build(&mut self) -> String {
        let available = self.max_tokens - self.reserved_for_response;

        // Sort by priority (required first, then by priority)
        self.segments
            .sort_by(|a, b| match (a.required, b.required) {
                (true, false) => std::cmp::Ordering::Less,
                (false, true) => std::cmp::Ordering::Greater,
                _ => b.priority.cmp(&a.priority),
            });

        let mut total_tokens = 0;
        let mut result = Vec::new();

        for segment in &self.segments {
            if total_tokens + segment.tokens <= available {
                result.push(segment.content.clone());
                total_tokens += segment.tokens;
            } else if segment.required {
                // Truncate if required
                let remaining = available.saturating_sub(total_tokens);
                if remaining > 0 {
                    let truncated = truncate_to_tokens(&segment.content, remaining);
                    result.push(truncated);
                    break;
                }
            }
        }

        result.join("\n\n")
    }

    /// Get current token usage
    pub fn token_usage(&self) -> (usize, usize) {
        let used: usize = self.segments.iter().map(|s| s.tokens).sum();
        (used, self.max_tokens - self.reserved_for_response)
    }
}
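
// Budgeting sketch for ContextWindow::build: max_tokens = 100 leaves a
// 75-token budget (25% reserved for the response), so the low-priority
// optional segment below is dropped while the required system prompt survives.
#[cfg(test)]
#[test]
fn example_context_window_budgeting() {
    let mut ctx = ContextWindow::new(100);
    ctx.add_segment(ContextSegment {
        segment_type: ContextSegmentType::SystemPrompt,
        content: "system".to_string(),
        tokens: 50,
        priority: 100,
        required: true,
    });
    ctx.add_segment(ContextSegment {
        segment_type: ContextSegmentType::RetrievedContext,
        content: "retrieved".to_string(),
        tokens: 50,
        priority: 10,
        required: false,
    });

    let built = ctx.build();
    assert!(built.contains("system"));
    assert!(!built.contains("retrieved")); // 50 + 50 > 75, optional segment dropped
}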

// =============================================================================
// Cache
// =============================================================================

/// Cache entry
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheEntry<T> {
    pub key: String,
    pub value: T,
    pub created_at: DateTime<Utc>,
    pub expires_at: Option<DateTime<Utc>>,
    pub access_count: u64,
}

impl<T> CacheEntry<T> {
    pub fn is_expired(&self) -> bool {
        self.expires_at.map(|exp| Utc::now() > exp).unwrap_or(false)
    }
}

/// Simple LRU cache for agent computations
pub struct AgentCache<T> {
    entries: HashMap<String, CacheEntry<T>>,
    max_size: usize,
}

impl<T: Clone> AgentCache<T> {
    /// Create a new cache
    pub fn new(max_size: usize) -> Self {
        Self {
            entries: HashMap::new(),
            max_size,
        }
    }

    /// Get a value from cache
    pub fn get(&mut self, key: &str) -> Option<T> {
        if let Some(entry) = self.entries.get_mut(key) {
            if entry.is_expired() {
                self.entries.remove(key);
                return None;
            }
            entry.access_count += 1;
            Some(entry.value.clone())
        } else {
            None
        }
    }

    /// Set a value in cache
    pub fn set(&mut self, key: impl Into<String>, value: T, ttl: Option<chrono::Duration>) {
        let key = key.into();
        let now = Utc::now();

        self.entries.insert(
            key.clone(),
            CacheEntry {
                key,
                value,
                created_at: now,
                expires_at: ttl.map(|d| now + d),
                access_count: 0,
            },
        );

        // Evict if over size
        if self.entries.len() > self.max_size {
            self.evict_lru();
        }
    }

    /// Remove a value
    pub fn remove(&mut self, key: &str) -> Option<T> {
        self.entries.remove(key).map(|e| e.value)
    }

    /// Clear the cache
    pub fn clear(&mut self) {
        self.entries.clear();
    }

    /// Evict least recently used entries
    fn evict_lru(&mut self) {
        // Remove expired first
        self.entries.retain(|_, v| !v.is_expired());

        // If still over, remove least accessed
        if self.entries.len() > self.max_size {
            // Collect keys to remove (sorted by access count)
            let mut entries: Vec<_> = self
                .entries
                .iter()
                .map(|(k, v)| (k.clone(), v.access_count))
                .collect();
            entries.sort_by_key(|(_, count)| *count);

            let to_remove = self.entries.len() - self.max_size;
            let keys_to_remove: Vec<String> = entries
                .into_iter()
                .take(to_remove)
                .map(|(k, _)| k)
                .collect();

            for key in keys_to_remove {
                self.entries.remove(&key);
            }
        }
    }
}
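
// TTL sketch for AgentCache: an entry inserted with a negative duration is
// already expired, so the next get() treats it as a miss and evicts it.
#[cfg(test)]
#[test]
fn example_cache_ttl_expiry() {
    let mut cache: AgentCache<String> = AgentCache::new(10);

    cache.set(
        "stale",
        "old".to_string(),
        Some(chrono::Duration::milliseconds(-1)),
    );
    assert_eq!(cache.get("stale"), None);

    cache.set("fresh", "new".to_string(), Some(chrono::Duration::minutes(5)));
    assert_eq!(cache.get("fresh"), Some("new".to_string()));
}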

// =============================================================================
// Memory Manager (Unified Interface)
// =============================================================================

/// Configuration for the memory manager
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryConfig {
    /// Vector store configuration
    pub vector_store: VectorStoreConfig,
    /// Chunking configuration
    pub chunking: ChunkingConfig,
    /// Context window size
    pub context_window_tokens: usize,
    /// Cache size
    pub cache_size: usize,
    /// Database path (None for in-memory)
    pub db_path: Option<String>,
    /// Auto-summarize long conversations
    pub auto_summarize: bool,
    /// Summarize after this many messages
    pub summarize_threshold: usize,
}

impl Default for MemoryConfig {
    fn default() -> Self {
        Self {
            vector_store: VectorStoreConfig::default(),
            chunking: ChunkingConfig::default(),
            context_window_tokens: 8192,
            cache_size: 1000,
            db_path: None,
            auto_summarize: true,
            summarize_threshold: 20,
        }
    }
}

/// Unified memory manager for agents
pub struct MemoryManager {
    config: MemoryConfig,
    vector_store: VectorStore,
    knowledge_base: KnowledgeBase,
    cache: AgentCache<String>,
}

impl MemoryManager {
    /// Create a new memory manager
    pub fn new(config: MemoryConfig) -> Result<Self, MemoryError> {
        let vector_store = if let Some(ref path) = config.db_path {
            VectorStore::with_persistence(config.vector_store.clone(), path)?
        } else {
            VectorStore::new(config.vector_store.clone())
        };

        let knowledge_base = if let Some(ref path) = config.db_path {
            let kb_path = format!("{}_kb", path);
            KnowledgeBase::with_persistence(config.vector_store.clone(), kb_path)?
        } else {
            KnowledgeBase::new(config.vector_store.clone())
        };

        Ok(Self {
            config: config.clone(),
            vector_store,
            knowledge_base,
            cache: AgentCache::new(config.cache_size),
        })
    }

    /// Store a memory
    pub fn remember(
        &mut self,
        content: impl Into<String>,
        memory_type: MemoryType,
        source: MemorySource,
    ) -> Result<MemoryId, MemoryError> {
        let entry = MemoryEntry::new(content, memory_type, source);
        self.vector_store.add(entry)
    }

    /// Store a memory with embedding
    pub fn remember_with_embedding(
        &mut self,
        content: impl Into<String>,
        embedding: Embedding,
        memory_type: MemoryType,
        source: MemorySource,
    ) -> Result<MemoryId, MemoryError> {
        let entry = MemoryEntry::new(content, memory_type, source).with_embedding(embedding);
        self.vector_store.add(entry)
    }

    /// Recall memories similar to a query
    pub fn recall(&mut self, query_embedding: &Embedding, limit: usize) -> Vec<SearchResult> {
        self.vector_store.search(query_embedding, limit)
    }

    /// Recall by type
    pub fn recall_by_type(&self, memory_type: MemoryType, limit: usize) -> Vec<&MemoryEntry> {
        self.vector_store.search_by_type(memory_type, limit)
    }

    /// Add document to knowledge base
    pub fn add_document(&mut self, document: Document) -> Result<String, MemoryError> {
        self.knowledge_base.add_document(document)
    }

    /// Retrieve from knowledge base
    pub fn retrieve(&mut self, query_embedding: &Embedding, limit: usize) -> Vec<SearchResult> {
        self.knowledge_base.retrieve(query_embedding, limit)
    }

    /// Build context for a prompt
    pub fn build_context(
        &mut self,
        query_embedding: &Embedding,
        system_prompt: &str,
        conversation: &[String],
    ) -> String {
        let mut context = ContextWindow::new(self.config.context_window_tokens);

        // System prompt (required)
        context.add_segment(ContextSegment {
            segment_type: ContextSegmentType::SystemPrompt,
            content: system_prompt.to_string(),
            tokens: estimate_tokens(system_prompt),
            priority: 100,
            required: true,
        });

        // Retrieved context
        let retrieved = self.recall(query_embedding, 5);
        if !retrieved.is_empty() {
            let retrieved_text: String = retrieved
                .iter()
                .map(|r| format!("- {}", r.entry.content))
                .collect::<Vec<_>>()
                .join("\n");

            context.add_segment(ContextSegment {
                segment_type: ContextSegmentType::RetrievedContext,
                content: format!("Relevant context:\n{}", retrieved_text),
                tokens: estimate_tokens(&retrieved_text) + 20,
                priority: 80,
                required: false,
            });
        }

        // Conversation history
        let conv_text = conversation.join("\n");
        context.add_segment(ContextSegment {
            segment_type: ContextSegmentType::ConversationHistory,
            content: conv_text.clone(),
            tokens: estimate_tokens(&conv_text),
            priority: 90,
            required: false,
        });

        context.build()
    }

    /// Cache a computation result
    pub fn cache_result(
        &mut self,
        key: impl Into<String>,
        value: String,
        ttl: Option<chrono::Duration>,
    ) {
        self.cache.set(key, value, ttl);
    }

    /// Get cached result
    pub fn get_cached(&mut self, key: &str) -> Option<String> {
        self.cache.get(key)
    }

    /// Get statistics
    pub fn stats(&self) -> MemoryStats {
        MemoryStats {
            vector_store: self.vector_store.stats(),
            document_count: self.knowledge_base.list_documents().len(),
        }
    }
}
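
// Round-trip sketch of the unified interface: remember a fact with a toy
// 2-dimensional embedding, then recall it. The default config keeps everything
// in memory (db_path: None), so no SQLite file is touched.
#[cfg(test)]
#[test]
fn example_memory_manager_round_trip() {
    let mut mgr = MemoryManager::new(MemoryConfig::default()).unwrap();
    mgr.remember_with_embedding(
        "The build runs on nightly",
        vec![1.0, 0.0],
        MemoryType::LongTerm,
        MemorySource::UserInput,
    )
    .unwrap();

    let hits = mgr.recall(&vec![1.0, 0.0], 1);
    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].entry.content, "The build runs on nightly");
}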

/// Memory statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryStats {
    pub vector_store: VectorStoreStats,
    pub document_count: usize,
}

// =============================================================================
// Error Types
// =============================================================================

/// Memory system errors
#[derive(Debug, Clone)]
pub enum MemoryError {
    /// Database error
    Database(String),
    /// Embedding error
    Embedding(String),
    /// Not found
    NotFound(String),
    /// Invalid input
    InvalidInput(String),
    /// IO error
    Io(String),
}

impl std::fmt::Display for MemoryError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            MemoryError::Database(e) => write!(f, "Database error: {}", e),
            MemoryError::Embedding(e) => write!(f, "Embedding error: {}", e),
            MemoryError::NotFound(e) => write!(f, "Not found: {}", e),
            MemoryError::InvalidInput(e) => write!(f, "Invalid input: {}", e),
            MemoryError::Io(e) => write!(f, "IO error: {}", e),
        }
    }
}

impl std::error::Error for MemoryError {}

// =============================================================================
// Utility Functions
// =============================================================================

fn generate_memory_id() -> String {
    use std::time::{SystemTime, UNIX_EPOCH};
    let timestamp = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .as_nanos();
    format!("mem_{:x}", timestamp)
}

/// Estimate token count (rough approximation: 4 chars per token)
fn estimate_tokens(text: &str) -> usize {
    (text.len() as f32 / 4.0).ceil() as usize
}

/// Truncate text to approximate token count
fn truncate_to_tokens(text: &str, max_tokens: usize) -> String {
    let max_chars = max_tokens * 4;
    if text.chars().count() <= max_chars {
        text.to_string()
    } else {
        // Collect whole characters rather than byte-slicing, which would panic
        // on a multi-byte UTF-8 boundary.
        let truncated: String = text.chars().take(max_chars).collect();
        format!("{}...", truncated)
    }
}
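
// Regression-style check for the UTF-8 handling above: truncating a string of
// multi-byte characters must not panic, and the cut lands on a character
// boundary.
#[cfg(test)]
#[test]
fn example_truncate_multibyte_safe() {
    let text = "é".repeat(10); // 10 chars, 20 bytes
    let out = truncate_to_tokens(&text, 1); // budget of 4 chars
    assert_eq!(out, format!("{}...", "é".repeat(4)));
}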

// =============================================================================
// Embedding Provider Trait
// =============================================================================

/// Trait for embedding providers
#[async_trait::async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Generate embedding for text
    async fn embed(&self, text: &str) -> Result<Embedding, MemoryError>;

    /// Generate embeddings for multiple texts
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Embedding>, MemoryError>;

    /// Get the embedding dimension
    fn dimension(&self) -> usize;
}
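
// A deterministic, test-only implementor of the trait above. The byte-folding
// "embedding" is a toy stand-in for a real model, useful for exercising code
// that is generic over EmbeddingProvider without any network access.
#[cfg(test)]
struct StubEmbedding;

#[cfg(test)]
#[async_trait::async_trait]
impl EmbeddingProvider for StubEmbedding {
    async fn embed(&self, text: &str) -> Result<Embedding, MemoryError> {
        // Fold the input bytes into a fixed 4-dimensional vector.
        let mut v = vec![0.0f32; 4];
        for (i, b) in text.bytes().enumerate() {
            v[i % 4] += f32::from(b) / 255.0;
        }
        Ok(v)
    }

    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Embedding>, MemoryError> {
        let mut out = Vec::with_capacity(texts.len());
        for text in texts {
            out.push(self.embed(text).await?);
        }
        Ok(out)
    }

    fn dimension(&self) -> usize {
        4
    }
}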

/// OpenAI embedding provider
pub struct OpenAIEmbedding {
    #[allow(dead_code)]
    api_key: String,
    model: String,
}

impl OpenAIEmbedding {
    pub fn new(api_key: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: "text-embedding-3-small".to_string(),
        }
    }

    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }
}

#[async_trait::async_trait]
impl EmbeddingProvider for OpenAIEmbedding {
    async fn embed(&self, _text: &str) -> Result<Embedding, MemoryError> {
        // Implementation would call OpenAI API
        // For now, return a placeholder
        Ok(vec![0.0; 1536])
    }

    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Embedding>, MemoryError> {
        let mut results = Vec::new();
        for text in texts {
            results.push(self.embed(text).await?);
        }
        Ok(results)
    }

    fn dimension(&self) -> usize {
        match self.model.as_str() {
            "text-embedding-3-large" => 3072,
            _ => 1536,
        }
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_entry_creation() {
        let entry = MemoryEntry::new(
            "Test content",
            MemoryType::LongTerm,
            MemorySource::UserInput,
        );
        assert!(!entry.id.is_empty());
        assert_eq!(entry.content, "Test content");
        assert_eq!(entry.memory_type, MemoryType::LongTerm);
    }

    #[test]
    fn test_similarity_metrics() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let c = vec![0.0, 1.0, 0.0];

        assert!((SimilarityMetric::Cosine.calculate(&a, &b) - 1.0).abs() < 0.001);
        assert!((SimilarityMetric::Cosine.calculate(&a, &c) - 0.0).abs() < 0.001);
    }

    #[test]
    fn test_vector_store() {
        let config = VectorStoreConfig::default();
        let mut store = VectorStore::new(config);

        let entry = MemoryEntry::new("Test", MemoryType::ShortTerm, MemorySource::UserInput)
            .with_embedding(vec![1.0, 0.0, 0.0]);

        let id = store.add(entry).unwrap();
        assert!(!id.is_empty());
        assert!(store.get(&id).is_some());
    }

    #[test]
    fn test_context_window() {
        let mut ctx = ContextWindow::new(1000);

        ctx.add_segment(ContextSegment {
            segment_type: ContextSegmentType::SystemPrompt,
            content: "You are helpful".to_string(),
            tokens: 10,
            priority: 100,
            required: true,
        });

        let result = ctx.build();
        assert!(result.contains("You are helpful"));
    }

    #[test]
    fn test_cache() {
        let mut cache: AgentCache<String> = AgentCache::new(10);

        cache.set("key1", "value1".to_string(), None);
        assert_eq!(cache.get("key1"), Some("value1".to_string()));
        assert_eq!(cache.get("key2"), None);
    }

    #[test]
    fn test_estimate_tokens() {
        assert_eq!(estimate_tokens("hello"), 2); // 5 chars / 4 = 1.25 -> 2
        assert_eq!(estimate_tokens("hello world"), 3); // 11 chars / 4 = 2.75 -> 3
    }
}