chasm_cli/agency/memory.rs

// Copyright (c) 2024-2026 Nervosys LLC
// SPDX-License-Identifier: Apache-2.0
//! Memory and RAG (Retrieval-Augmented Generation) System
//!
//! Provides persistent context, knowledge retrieval, and semantic search for agents.
//!
//! ## Features
//!
//! - **Vector Store**: Semantic similarity search using embeddings
//! - **Memory Types**: Short-term, long-term, episodic, and semantic memory
//! - **Knowledge Base**: Structured document storage with chunking
//! - **Context Window**: Smart context management for LLM prompts
//! - **Caching**: Frequently accessed information caching

#![allow(dead_code)]

use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;

// =============================================================================
// Core Types
// =============================================================================

/// Unique identifier for memory entries
pub type MemoryId = String;

/// Vector embedding (typically 384-1536 dimensions depending on model)
pub type Embedding = Vec<f32>;

/// Memory entry representing a piece of stored knowledge
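///
/// A minimal construction sketch using the builder helpers below; the
/// values are illustrative, not defaults:
///
/// ```ignore
/// let entry = MemoryEntry::new("User prefers dark mode", MemoryType::Preference, MemorySource::UserInput)
///     .with_importance(0.9)
///     .with_tag("ui")
///     .expires_in(chrono::Duration::days(30));
/// ```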
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEntry {
    /// Unique identifier
    pub id: MemoryId,
    /// The content/text of this memory
    pub content: String,
    /// Vector embedding for similarity search
    #[serde(skip_serializing_if = "Option::is_none")]
    pub embedding: Option<Embedding>,
    /// Memory type classification
    pub memory_type: MemoryType,
    /// Source of this memory (conversation, document, user input, etc.)
    pub source: MemorySource,
    /// Importance score (0.0 - 1.0)
    pub importance: f32,
    /// Access count for LRU caching
    pub access_count: u64,
    /// Last accessed timestamp
    pub last_accessed: DateTime<Utc>,
    /// Creation timestamp
    pub created_at: DateTime<Utc>,
    /// Optional expiration
    pub expires_at: Option<DateTime<Utc>>,
    /// Associated agent ID
    pub agent_id: Option<String>,
    /// Associated session ID
    pub session_id: Option<String>,
    /// Custom metadata
    pub metadata: HashMap<String, serde_json::Value>,
    /// Tags for filtering
    pub tags: Vec<String>,
}

impl MemoryEntry {
    /// Create a new memory entry
    pub fn new(content: impl Into<String>, memory_type: MemoryType, source: MemorySource) -> Self {
        let now = Utc::now();
        Self {
            id: generate_memory_id(),
            content: content.into(),
            embedding: None,
            memory_type,
            source,
            importance: 0.5,
            access_count: 0,
            last_accessed: now,
            created_at: now,
            expires_at: None,
            agent_id: None,
            session_id: None,
            metadata: HashMap::new(),
            tags: Vec::new(),
        }
    }

    /// Set the embedding
    pub fn with_embedding(mut self, embedding: Embedding) -> Self {
        self.embedding = Some(embedding);
        self
    }

    /// Set importance score
    pub fn with_importance(mut self, importance: f32) -> Self {
        self.importance = importance.clamp(0.0, 1.0);
        self
    }

    /// Set agent ID
    pub fn with_agent(mut self, agent_id: impl Into<String>) -> Self {
        self.agent_id = Some(agent_id.into());
        self
    }

    /// Set session ID
    pub fn with_session(mut self, session_id: impl Into<String>) -> Self {
        self.session_id = Some(session_id.into());
        self
    }

    /// Add a tag
    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
        self.tags.push(tag.into());
        self
    }

    /// Set expiration
    pub fn expires_in(mut self, duration: chrono::Duration) -> Self {
        self.expires_at = Some(Utc::now() + duration);
        self
    }

    /// Check if expired
    pub fn is_expired(&self) -> bool {
        self.expires_at.map(|exp| Utc::now() > exp).unwrap_or(false)
    }
}

/// Types of memory
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MemoryType {
    /// Short-term working memory (current conversation context)
    ShortTerm,
    /// Long-term persistent memory (facts, preferences, learned info)
    LongTerm,
    /// Episodic memory (specific events and experiences)
    Episodic,
    /// Semantic memory (concepts, relationships, general knowledge)
    Semantic,
    /// Procedural memory (how to do things, workflows)
    Procedural,
    /// User preferences and settings
    Preference,
    /// Cached computation results
    Cache,
}

/// Source of memory entry
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MemorySource {
    /// From a conversation message
    Conversation {
        session_id: String,
        message_id: String,
    },
    /// From a document/file
    Document { path: String, chunk_index: u32 },
    /// Direct user input/instruction
    UserInput,
    /// Agent reasoning/reflection
    AgentReasoning { agent_id: String },
    /// External API or tool result
    ToolResult { tool_name: String },
    /// Web page or URL
    WebPage { url: String },
    /// System-generated summary
    Summary { source_ids: Vec<String> },
    /// Custom source
    Custom { source_type: String },
}

// =============================================================================
// Vector Store
// =============================================================================

/// Configuration for the vector store
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorStoreConfig {
    /// Embedding model to use
    pub embedding_model: EmbeddingModel,
    /// Dimension of embeddings
    pub embedding_dim: usize,
    /// Similarity metric
    pub similarity_metric: SimilarityMetric,
    /// Maximum entries before pruning
    pub max_entries: usize,
    /// Database path
    pub db_path: Option<String>,
}

impl Default for VectorStoreConfig {
    fn default() -> Self {
        Self {
            embedding_model: EmbeddingModel::default(),
            embedding_dim: 384,
            similarity_metric: SimilarityMetric::Cosine,
            max_entries: 100_000,
            db_path: None,
        }
    }
}

/// Embedding model options
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EmbeddingModel {
    /// OpenAI text-embedding-3-small (1536 dims)
    OpenAISmall,
    /// OpenAI text-embedding-3-large (3072 dims)
    OpenAILarge,
    /// OpenAI text-embedding-ada-002 (1536 dims)
    OpenAIAda,
    /// Sentence Transformers all-MiniLM-L6-v2 (384 dims)
    #[default]
    MiniLM,
    /// Sentence Transformers all-mpnet-base-v2 (768 dims)
    MPNet,
    /// Cohere embed-english-v3.0 (1024 dims)
    Cohere,
    /// Google text-embedding-004 (768 dims)
    GoogleGecko,
    /// Voyage AI voyage-2 (1024 dims)
    Voyage,
    /// Local model via Ollama
    Ollama { model: String },
    /// Custom model
    Custom { name: String, dim: usize },
}

impl EmbeddingModel {
    /// Get the dimension for this model
    pub fn dimension(&self) -> usize {
        match self {
            EmbeddingModel::OpenAISmall => 1536,
            EmbeddingModel::OpenAILarge => 3072,
            EmbeddingModel::OpenAIAda => 1536,
            EmbeddingModel::MiniLM => 384,
            EmbeddingModel::MPNet => 768,
            EmbeddingModel::Cohere => 1024,
            EmbeddingModel::GoogleGecko => 768,
            EmbeddingModel::Voyage => 1024,
            // Varies by model; 4096 is a coarse default for
            // llama-family models served via Ollama
            EmbeddingModel::Ollama { .. } => 4096,
            EmbeddingModel::Custom { dim, .. } => *dim,
        }
    }
}

/// Similarity metrics for vector search
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SimilarityMetric {
    #[default]
    Cosine,
    Euclidean,
    DotProduct,
    Manhattan,
}

impl SimilarityMetric {
    /// Calculate similarity between two vectors
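    ///
    /// Panics if the two vectors differ in length.
    ///
    /// A small sketch of the expected behavior for the cosine metric
    /// (identical directions score 1.0, orthogonal ones 0.0):
    ///
    /// ```ignore
    /// let metric = SimilarityMetric::Cosine;
    /// assert!((metric.calculate(&[1.0, 0.0], &[1.0, 0.0]) - 1.0).abs() < 1e-6);
    /// assert!(metric.calculate(&[1.0, 0.0], &[0.0, 1.0]).abs() < 1e-6);
    /// ```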
    pub fn calculate(&self, a: &[f32], b: &[f32]) -> f32 {
        assert_eq!(a.len(), b.len(), "Vector dimensions must match");

        match self {
            SimilarityMetric::Cosine => {
                let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
                let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
                let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
                if norm_a == 0.0 || norm_b == 0.0 {
                    0.0
                } else {
                    dot / (norm_a * norm_b)
                }
            }
            SimilarityMetric::Euclidean => {
                let dist: f32 = a
                    .iter()
                    .zip(b.iter())
                    .map(|(x, y)| (x - y).powi(2))
                    .sum::<f32>()
                    .sqrt();
                1.0 / (1.0 + dist) // Convert distance to similarity
            }
            SimilarityMetric::DotProduct => a.iter().zip(b.iter()).map(|(x, y)| x * y).sum(),
            SimilarityMetric::Manhattan => {
                let dist: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum();
                1.0 / (1.0 + dist)
            }
        }
    }
}

/// Search result from vector store
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    /// The memory entry
    pub entry: MemoryEntry,
    /// Similarity score (0.0 - 1.0)
    pub score: f32,
    /// Rank in results
    pub rank: usize,
}

/// Vector store for semantic search
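///
/// A minimal in-memory usage sketch; the 3-dim embeddings are toy values,
/// real ones would come from an [`EmbeddingProvider`]:
///
/// ```ignore
/// let mut store = VectorStore::new(VectorStoreConfig::default());
/// let entry = MemoryEntry::new("Rust is memory safe", MemoryType::Semantic, MemorySource::UserInput)
///     .with_embedding(vec![1.0, 0.0, 0.0]);
/// let id = store.add(entry)?;
/// let hits = store.search(&vec![1.0, 0.0, 0.0], 3);
/// assert_eq!(hits[0].entry.id, id);
/// ```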
pub struct VectorStore {
    config: VectorStoreConfig,
    entries: Vec<MemoryEntry>,
    db: Option<rusqlite::Connection>,
}

impl VectorStore {
    /// Create a new in-memory vector store
    pub fn new(config: VectorStoreConfig) -> Self {
        Self {
            config,
            entries: Vec::new(),
            db: None,
        }
    }

    /// Create a vector store with SQLite persistence
    pub fn with_persistence(
        config: VectorStoreConfig,
        db_path: impl AsRef<Path>,
    ) -> Result<Self, MemoryError> {
        let db = rusqlite::Connection::open(db_path.as_ref())
            .map_err(|e| MemoryError::Database(e.to_string()))?;

        // Initialize schema
        db.execute_batch(
            r#"
            CREATE TABLE IF NOT EXISTS memory_entries (
                id TEXT PRIMARY KEY,
                content TEXT NOT NULL,
                embedding BLOB,
                memory_type TEXT NOT NULL,
                source TEXT NOT NULL,
                importance REAL NOT NULL,
                access_count INTEGER NOT NULL DEFAULT 0,
                last_accessed TEXT NOT NULL,
                created_at TEXT NOT NULL,
                expires_at TEXT,
                agent_id TEXT,
                session_id TEXT,
                metadata TEXT,
                tags TEXT
            );

            CREATE INDEX IF NOT EXISTS idx_memory_type ON memory_entries(memory_type);
            CREATE INDEX IF NOT EXISTS idx_agent_id ON memory_entries(agent_id);
            CREATE INDEX IF NOT EXISTS idx_session_id ON memory_entries(session_id);
            CREATE INDEX IF NOT EXISTS idx_created_at ON memory_entries(created_at);
            CREATE INDEX IF NOT EXISTS idx_importance ON memory_entries(importance DESC);
        "#,
        )
        .map_err(|e| MemoryError::Database(e.to_string()))?;

        let mut store = Self {
            config,
            entries: Vec::new(),
            db: Some(db),
        };

        store.load_from_db()?;
        Ok(store)
    }

    /// Load entries from database
    fn load_from_db(&mut self) -> Result<(), MemoryError> {
        if let Some(ref db) = self.db {
            let mut stmt = db
                .prepare(
                    "SELECT id, content, embedding, memory_type, source, importance,
                        access_count, last_accessed, created_at, expires_at,
                        agent_id, session_id, metadata, tags
                 FROM memory_entries
                 ORDER BY importance DESC, created_at DESC",
                )
                .map_err(|e| MemoryError::Database(e.to_string()))?;

            let entries = stmt
                .query_map([], |row| {
                    let embedding_blob: Option<Vec<u8>> = row.get(2)?;
                    let embedding = embedding_blob.map(|blob| {
                        blob.chunks(4)
                            .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap_or([0; 4])))
                            .collect()
                    });

                    Ok(MemoryEntry {
                        id: row.get(0)?,
                        content: row.get(1)?,
                        embedding,
                        memory_type: serde_json::from_str(&row.get::<_, String>(3)?)
                            .unwrap_or(MemoryType::LongTerm),
                        source: serde_json::from_str(&row.get::<_, String>(4)?)
                            .unwrap_or(MemorySource::UserInput),
                        importance: row.get(5)?,
                        access_count: row.get(6)?,
                        last_accessed: row
                            .get::<_, String>(7)?
                            .parse()
                            .unwrap_or_else(|_| Utc::now()),
                        created_at: row
                            .get::<_, String>(8)?
                            .parse()
                            .unwrap_or_else(|_| Utc::now()),
                        expires_at: row
                            .get::<_, Option<String>>(9)?
                            .and_then(|s| s.parse().ok()),
                        agent_id: row.get(10)?,
                        session_id: row.get(11)?,
                        metadata: row
                            .get::<_, Option<String>>(12)?
                            .and_then(|s| serde_json::from_str(&s).ok())
                            .unwrap_or_default(),
                        tags: row
                            .get::<_, Option<String>>(13)?
                            .and_then(|s| serde_json::from_str(&s).ok())
                            .unwrap_or_default(),
                    })
                })
                .map_err(|e| MemoryError::Database(e.to_string()))?;

            self.entries = entries.filter_map(|e| e.ok()).collect();
        }
        Ok(())
    }

    /// Add a memory entry
    pub fn add(&mut self, entry: MemoryEntry) -> Result<MemoryId, MemoryError> {
        let id = entry.id.clone();

        // Persist to database if available
        if let Some(ref db) = self.db {
            let embedding_blob: Option<Vec<u8>> = entry
                .embedding
                .as_ref()
                .map(|emb| emb.iter().flat_map(|f| f.to_le_bytes()).collect());

            db.execute(
                "INSERT OR REPLACE INTO memory_entries
                 (id, content, embedding, memory_type, source, importance,
                  access_count, last_accessed, created_at, expires_at,
                  agent_id, session_id, metadata, tags)
                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14)",
                rusqlite::params![
                    entry.id,
                    entry.content,
                    embedding_blob,
                    serde_json::to_string(&entry.memory_type).unwrap_or_default(),
                    serde_json::to_string(&entry.source).unwrap_or_default(),
                    entry.importance,
                    entry.access_count,
                    entry.last_accessed.to_rfc3339(),
                    entry.created_at.to_rfc3339(),
                    entry.expires_at.map(|e| e.to_rfc3339()),
                    entry.agent_id,
                    entry.session_id,
                    serde_json::to_string(&entry.metadata).ok(),
                    serde_json::to_string(&entry.tags).ok(),
                ],
            )
            .map_err(|e| MemoryError::Database(e.to_string()))?;
        }

        self.entries.push(entry);

        // Prune if needed
        if self.entries.len() > self.config.max_entries {
            self.prune()?;
        }

        Ok(id)
    }

    /// Search for similar entries
    pub fn search(&mut self, query_embedding: &Embedding, limit: usize) -> Vec<SearchResult> {
        let mut results: Vec<(usize, f32)> = self
            .entries
            .iter()
            .enumerate()
            .filter(|(_, e)| !e.is_expired() && e.embedding.is_some())
            .map(|(i, e)| {
                let score = self
                    .config
                    .similarity_metric
                    .calculate(query_embedding, e.embedding.as_ref().unwrap());
                (i, score)
            })
            .collect();

        // Sort by score descending
        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        // Take top results and update access counts
        results
            .into_iter()
            .take(limit)
            .enumerate()
            .map(|(rank, (idx, score))| {
                self.entries[idx].access_count += 1;
                self.entries[idx].last_accessed = Utc::now();

                SearchResult {
                    entry: self.entries[idx].clone(),
                    score,
                    rank,
                }
            })
            .collect()
    }

    /// Search by memory type
    pub fn search_by_type(&self, memory_type: MemoryType, limit: usize) -> Vec<&MemoryEntry> {
        self.entries
            .iter()
            .filter(|e| e.memory_type == memory_type && !e.is_expired())
            .take(limit)
            .collect()
    }

    /// Search by tags
    pub fn search_by_tags(&self, tags: &[String], limit: usize) -> Vec<&MemoryEntry> {
        self.entries
            .iter()
            .filter(|e| !e.is_expired() && tags.iter().any(|t| e.tags.contains(t)))
            .take(limit)
            .collect()
    }

    /// Get entry by ID
    pub fn get(&self, id: &str) -> Option<&MemoryEntry> {
        self.entries.iter().find(|e| e.id == id)
    }

    /// Delete entry
    pub fn delete(&mut self, id: &str) -> Result<bool, MemoryError> {
        if let Some(pos) = self.entries.iter().position(|e| e.id == id) {
            self.entries.remove(pos);

            if let Some(ref db) = self.db {
                db.execute("DELETE FROM memory_entries WHERE id = ?1", [id])
                    .map_err(|e| MemoryError::Database(e.to_string()))?;
            }

            Ok(true)
        } else {
            Ok(false)
        }
    }

    /// Prune expired and low-importance entries
    ///
    /// Note: pruning currently trims only the in-memory set;
    /// persisted rows are left in the database
    fn prune(&mut self) -> Result<(), MemoryError> {
        // Remove expired entries
        self.entries.retain(|e| !e.is_expired());

        // If still over limit, remove lowest importance entries
        if self.entries.len() > self.config.max_entries {
            self.entries.sort_by(|a, b| {
                b.importance
                    .partial_cmp(&a.importance)
                    .unwrap_or(std::cmp::Ordering::Equal)
            });
            self.entries.truncate(self.config.max_entries);
        }

        Ok(())
    }

    /// Get statistics
    pub fn stats(&self) -> VectorStoreStats {
        VectorStoreStats {
            total_entries: self.entries.len(),
            entries_by_type: self.entries.iter().fold(HashMap::new(), |mut acc, e| {
                *acc.entry(format!("{:?}", e.memory_type)).or_insert(0) += 1;
                acc
            }),
            total_access_count: self.entries.iter().map(|e| e.access_count).sum(),
            avg_importance: if self.entries.is_empty() {
                0.0
            } else {
                self.entries.iter().map(|e| e.importance).sum::<f32>() / self.entries.len() as f32
            },
        }
    }
}

/// Vector store statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorStoreStats {
    pub total_entries: usize,
    pub entries_by_type: HashMap<String, usize>,
    pub total_access_count: u64,
    pub avg_importance: f32,
}

// =============================================================================
// Knowledge Base / RAG
// =============================================================================

/// Document for the knowledge base
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Document {
    /// Unique identifier
    pub id: String,
    /// Document title
    pub title: String,
    /// Full content
    pub content: String,
    /// Document type
    pub doc_type: DocumentType,
    /// Source URL or path
    pub source: String,
    /// Chunked content for embedding
    pub chunks: Vec<DocumentChunk>,
    /// Creation timestamp
    pub created_at: DateTime<Utc>,
    /// Last updated
    pub updated_at: DateTime<Utc>,
    /// Metadata
    pub metadata: HashMap<String, serde_json::Value>,
}

/// Document types
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DocumentType {
    Text,
    Markdown,
    Code { language: String },
    Html,
    Pdf,
    Json,
    Yaml,
    Csv,
    Custom { mime_type: String },
}

/// A chunk of a document
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentChunk {
    /// Chunk index
    pub index: u32,
    /// Chunk content
    pub content: String,
    /// Start position in original document
    pub start_pos: usize,
    /// End position in original document
    pub end_pos: usize,
    /// Vector embedding
    pub embedding: Option<Embedding>,
    /// Token count estimate
    pub token_count: u32,
}

/// Configuration for document chunking
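///
/// A sketch of a custom configuration (the values are illustrative,
/// not recommendations):
///
/// ```ignore
/// let cfg = ChunkingConfig {
///     chunk_size: 256,
///     chunk_overlap: 32,
///     strategy: ChunkingStrategy::Paragraph,
/// };
/// ```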
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkingConfig {
    /// Target chunk size in tokens
    pub chunk_size: usize,
    /// Overlap between chunks in tokens (advisory; the chunkers below
    /// do not currently apply it)
    pub chunk_overlap: usize,
    /// Chunking strategy
    pub strategy: ChunkingStrategy,
}

impl Default for ChunkingConfig {
    fn default() -> Self {
        Self {
            chunk_size: 512,
            chunk_overlap: 50,
            strategy: ChunkingStrategy::Semantic,
        }
    }
}

/// Chunking strategies
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ChunkingStrategy {
    /// Fixed-size character chunks
    FixedSize,
    /// Split on sentences
    Sentence,
    /// Split on paragraphs
    Paragraph,
    /// Semantic chunking (respects structure)
    #[default]
    Semantic,
    /// Code-aware chunking
    Code,
}

/// Knowledge base for RAG
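///
/// A minimal indexing sketch; note that freshly chunked documents carry no
/// embeddings, so pre-chunk and embed (e.g. via an [`EmbeddingProvider`])
/// before adding if semantic retrieval is needed:
///
/// ```ignore
/// let mut kb = KnowledgeBase::new(VectorStoreConfig::default());
/// let doc_id = kb.add_document(doc)?; // `doc` is a prepared `Document`
/// let hits = kb.retrieve(&query_embedding, 5);
/// ```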
pub struct KnowledgeBase {
    /// Vector store for semantic search
    vector_store: VectorStore,
    /// Documents
    documents: HashMap<String, Document>,
    /// Chunking configuration
    chunking_config: ChunkingConfig,
}

impl KnowledgeBase {
    /// Create a new knowledge base
    pub fn new(vector_config: VectorStoreConfig) -> Self {
        Self {
            vector_store: VectorStore::new(vector_config),
            documents: HashMap::new(),
            chunking_config: ChunkingConfig::default(),
        }
    }

    /// Create with persistence
    pub fn with_persistence(
        vector_config: VectorStoreConfig,
        db_path: impl AsRef<Path>,
    ) -> Result<Self, MemoryError> {
        Ok(Self {
            vector_store: VectorStore::with_persistence(vector_config, db_path)?,
            documents: HashMap::new(),
            chunking_config: ChunkingConfig::default(),
        })
    }

    /// Add a document, indexing any chunks that carry embeddings
    pub fn add_document(&mut self, mut document: Document) -> Result<String, MemoryError> {
        // Chunk the document unless the caller supplied pre-built chunks
        // (re-chunking would discard any embeddings already attached)
        if document.chunks.is_empty() {
            document.chunks = self.chunk_document(&document.content);
        }

        let doc_id = document.id.clone();

        // Add chunks to vector store
        for chunk in &document.chunks {
            if let Some(ref embedding) = chunk.embedding {
                let entry = MemoryEntry::new(
                    &chunk.content,
                    MemoryType::Semantic,
                    MemorySource::Document {
                        path: document.source.clone(),
                        chunk_index: chunk.index,
                    },
                )
                .with_embedding(embedding.clone())
                .with_tag(format!("doc:{}", doc_id));

                self.vector_store.add(entry)?;
            }
        }

        self.documents.insert(doc_id.clone(), document);
        Ok(doc_id)
    }

    /// Chunk a document using the configured strategy
    fn chunk_document(&self, content: &str) -> Vec<DocumentChunk> {
        match self.chunking_config.strategy {
            ChunkingStrategy::Semantic => self.semantic_chunk(content),
            ChunkingStrategy::Paragraph => self.paragraph_chunk(content),
            ChunkingStrategy::Sentence => self.sentence_chunk(content),
            ChunkingStrategy::FixedSize => self.fixed_chunk(content),
            ChunkingStrategy::Code => self.code_chunk(content),
        }
    }

    fn semantic_chunk(&self, content: &str) -> Vec<DocumentChunk> {
        // Split on double newlines (paragraphs) but respect size limits.
        // Note: start_pos/end_pos are approximate, since trimming and
        // paragraph separators are not tracked exactly.
        let mut chunks = Vec::new();
        let mut current_chunk = String::new();
        let mut start_pos = 0;
        let mut chunk_index = 0;

        for para in content.split("\n\n") {
            let para = para.trim();
            if para.is_empty() {
                continue;
            }

            let para_tokens = estimate_tokens(para);
            let current_tokens = estimate_tokens(&current_chunk);

            if current_tokens + para_tokens > self.chunking_config.chunk_size
                && !current_chunk.is_empty()
            {
                // Save current chunk
                let end_pos = start_pos + current_chunk.len();
                chunks.push(DocumentChunk {
                    index: chunk_index,
                    content: current_chunk.trim().to_string(),
                    start_pos,
                    end_pos,
                    embedding: None,
                    token_count: estimate_tokens(&current_chunk) as u32,
                });
                chunk_index += 1;
                start_pos = end_pos;
                current_chunk = String::new();
            }

            if !current_chunk.is_empty() {
                current_chunk.push_str("\n\n");
            }
            current_chunk.push_str(para);
        }

        // Add remaining content
        if !current_chunk.is_empty() {
            let end_pos = start_pos + current_chunk.len();
            chunks.push(DocumentChunk {
                index: chunk_index,
                content: current_chunk.trim().to_string(),
                start_pos,
                end_pos,
                embedding: None,
                token_count: estimate_tokens(&current_chunk) as u32,
            });
        }

        chunks
    }

    fn paragraph_chunk(&self, content: &str) -> Vec<DocumentChunk> {
        content
            .split("\n\n")
            .filter(|p| !p.trim().is_empty())
            .enumerate()
            .scan(0usize, |pos, (i, para)| {
                let start = *pos;
                *pos += para.len() + 2;
                Some(DocumentChunk {
                    index: i as u32,
                    content: para.trim().to_string(),
                    start_pos: start,
                    end_pos: *pos,
                    embedding: None,
                    token_count: estimate_tokens(para) as u32,
                })
            })
            .collect()
    }

    fn sentence_chunk(&self, content: &str) -> Vec<DocumentChunk> {
        // Simple sentence splitting (could be improved with NLP)
        let sentences: Vec<&str> = content
            .split(['.', '!', '?'])
            .filter(|s| !s.trim().is_empty())
            .collect();

        let mut chunks = Vec::new();
        let mut current = String::new();
        let mut start = 0;
        let mut idx = 0;

        for sentence in sentences {
            let sentence = sentence.trim();
            if estimate_tokens(&current) + estimate_tokens(sentence)
                > self.chunking_config.chunk_size
                && !current.is_empty()
            {
                chunks.push(DocumentChunk {
                    index: idx,
                    content: current.clone(),
                    start_pos: start,
                    end_pos: start + current.len(),
                    embedding: None,
                    token_count: estimate_tokens(&current) as u32,
                });
                idx += 1;
                start += current.len();
                current.clear();
            }
            if !current.is_empty() {
                current.push(' ');
            }
            current.push_str(sentence);
            current.push('.');
        }

        if !current.is_empty() {
            chunks.push(DocumentChunk {
                index: idx,
                content: current.clone(),
                start_pos: start,
                end_pos: start + current.len(),
                embedding: None,
                token_count: estimate_tokens(&current) as u32,
            });
        }

        chunks
    }

    fn fixed_chunk(&self, content: &str) -> Vec<DocumentChunk> {
        let chars_per_chunk = self.chunking_config.chunk_size * 4; // Rough estimate
        content
            .chars()
            .collect::<Vec<_>>()
            .chunks(chars_per_chunk)
            .enumerate()
            .map(|(i, chars)| {
                let s: String = chars.iter().collect();
                DocumentChunk {
                    index: i as u32,
                    content: s.clone(),
                    start_pos: i * chars_per_chunk,
                    // Use the actual chunk length so the final (possibly
                    // short) chunk does not overshoot
                    end_pos: i * chars_per_chunk + chars.len(),
                    embedding: None,
                    token_count: estimate_tokens(&s) as u32,
                }
            })
            .collect()
    }

    fn code_chunk(&self, content: &str) -> Vec<DocumentChunk> {
        // Split on function/class definitions (simple heuristic)
        let mut chunks = Vec::new();
        let mut current = String::new();
        let mut start = 0;
        let mut idx = 0;

        for line in content.lines() {
            let is_boundary = line.starts_with("fn ")
                || line.starts_with("pub fn ")
                || line.starts_with("async fn ")
                || line.starts_with("impl ")
                || line.starts_with("struct ")
                || line.starts_with("enum ")
                || line.starts_with("trait ")
                || line.starts_with("class ")
                || line.starts_with("def ")
                || line.starts_with("function ")
                || line.starts_with("const ")
                || line.starts_with("export ");

            if is_boundary && !current.is_empty() {
                chunks.push(DocumentChunk {
                    index: idx,
                    content: current.clone(),
                    start_pos: start,
                    end_pos: start + current.len(),
                    embedding: None,
                    token_count: estimate_tokens(&current) as u32,
                });
                idx += 1;
                start += current.len();
                current.clear();
            }

            current.push_str(line);
            current.push('\n');
        }

        if !current.is_empty() {
            chunks.push(DocumentChunk {
                index: idx,
                content: current.clone(),
                start_pos: start,
                end_pos: start + current.len(),
                embedding: None,
                token_count: estimate_tokens(&current) as u32,
            });
        }

        chunks
    }

    /// Retrieve relevant context for a query
    pub fn retrieve(&mut self, query_embedding: &Embedding, limit: usize) -> Vec<SearchResult> {
        self.vector_store.search(query_embedding, limit)
    }

    /// Get document by ID
    pub fn get_document(&self, id: &str) -> Option<&Document> {
        self.documents.get(id)
    }

    /// List all documents
    pub fn list_documents(&self) -> Vec<&Document> {
        self.documents.values().collect()
    }

    /// Delete a document
    ///
    /// Note: this removes only the document record; its indexed chunks
    /// remain in the vector store (tagged `doc:{id}`) until pruned
    pub fn delete_document(&mut self, id: &str) -> bool {
        self.documents.remove(id).is_some()
    }
}

// =============================================================================
// Context Window Manager
// =============================================================================

/// Manages context for LLM prompts
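///
/// A minimal sketch of assembling a prompt within a token budget:
///
/// ```ignore
/// let mut ctx = ContextWindow::new(4096);
/// ctx.add_segment(ContextSegment {
///     segment_type: ContextSegmentType::SystemPrompt,
///     content: "You are a helpful assistant.".to_string(),
///     tokens: 8,
///     priority: 100,
///     required: true,
/// });
/// let prompt = ctx.build();
/// ```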
#[derive(Debug, Clone)]
pub struct ContextWindow {
    /// Maximum tokens for context
    pub max_tokens: usize,
    /// Reserved tokens for response
    pub reserved_for_response: usize,
    /// Context segments
    segments: Vec<ContextSegment>,
}

/// A segment of context
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextSegment {
    /// Segment type
    pub segment_type: ContextSegmentType,
    /// Content
    pub content: String,
    /// Token count
    pub tokens: usize,
    /// Priority (higher = more important)
    pub priority: u32,
    /// Whether this segment is required
    pub required: bool,
}

/// Types of context segments
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ContextSegmentType {
    SystemPrompt,
    UserPreferences,
    ConversationHistory,
    RetrievedContext,
    ToolResults,
    CurrentQuery,
    Custom { name: String },
}

impl ContextWindow {
    /// Create a new context window
    pub fn new(max_tokens: usize) -> Self {
        Self {
            max_tokens,
            reserved_for_response: max_tokens / 4, // Reserve 25% for response
            segments: Vec::new(),
        }
    }

    /// Add a context segment
    pub fn add_segment(&mut self, segment: ContextSegment) {
        self.segments.push(segment);
    }

    /// Build the final context, respecting token limits
    pub fn build(&mut self) -> String {
        // saturating_sub guards against a reservation larger than the window
        let available = self.max_tokens.saturating_sub(self.reserved_for_response);

        // Sort by priority (required first, then by priority)
        self.segments
            .sort_by(|a, b| match (a.required, b.required) {
                (true, false) => std::cmp::Ordering::Less,
                (false, true) => std::cmp::Ordering::Greater,
                _ => b.priority.cmp(&a.priority),
            });

        let mut total_tokens = 0;
        let mut result = Vec::new();

        for segment in &self.segments {
            if total_tokens + segment.tokens <= available {
                result.push(segment.content.clone());
                total_tokens += segment.tokens;
            } else if segment.required {
                // Truncate if required
                let remaining = available.saturating_sub(total_tokens);
                if remaining > 0 {
                    let truncated = truncate_to_tokens(&segment.content, remaining);
                    result.push(truncated);
                    break;
                }
            }
        }

        result.join("\n\n")
    }

    /// Get current token usage as (used, available)
    pub fn token_usage(&self) -> (usize, usize) {
        let used: usize = self.segments.iter().map(|s| s.tokens).sum();
        (used, self.max_tokens.saturating_sub(self.reserved_for_response))
    }
}

// =============================================================================
// Cache
// =============================================================================

/// Cache entry
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheEntry<T> {
    pub key: String,
    pub value: T,
    pub created_at: DateTime<Utc>,
    pub expires_at: Option<DateTime<Utc>>,
    pub access_count: u64,
}

impl<T> CacheEntry<T> {
    pub fn is_expired(&self) -> bool {
        self.expires_at.map(|exp| Utc::now() > exp).unwrap_or(false)
    }
}

/// Simple LRU cache for agent computations
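///
/// A small usage sketch with a five-minute TTL (key and value are
/// illustrative):
///
/// ```ignore
/// let mut cache: AgentCache<String> = AgentCache::new(100);
/// cache.set("weather:nyc", "72F".to_string(), Some(chrono::Duration::minutes(5)));
/// assert_eq!(cache.get("weather:nyc").as_deref(), Some("72F"));
/// ```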
pub struct AgentCache<T> {
    entries: HashMap<String, CacheEntry<T>>,
    max_size: usize,
}

impl<T: Clone> AgentCache<T> {
    /// Create a new cache
    pub fn new(max_size: usize) -> Self {
        Self {
            entries: HashMap::new(),
            max_size,
        }
    }

    /// Get a value from cache
    pub fn get(&mut self, key: &str) -> Option<T> {
        if let Some(entry) = self.entries.get_mut(key) {
            if entry.is_expired() {
                self.entries.remove(key);
                return None;
            }
            entry.access_count += 1;
            Some(entry.value.clone())
        } else {
            None
        }
    }

    /// Set a value in cache
    pub fn set(&mut self, key: impl Into<String>, value: T, ttl: Option<chrono::Duration>) {
        let key = key.into();
        let now = Utc::now();

        self.entries.insert(
            key.clone(),
            CacheEntry {
                key,
                value,
                created_at: now,
                expires_at: ttl.map(|d| now + d),
                access_count: 0,
            },
        );

        // Evict if over size
        if self.entries.len() > self.max_size {
            self.evict_lru();
        }
    }

    /// Remove a value
    pub fn remove(&mut self, key: &str) -> Option<T> {
        self.entries.remove(key).map(|e| e.value)
    }

    /// Clear the cache
    pub fn clear(&mut self) {
        self.entries.clear();
    }

    /// Evict entries: expired ones first, then the least accessed
    /// (an access-count approximation of LRU)
    fn evict_lru(&mut self) {
        // Remove expired first
        self.entries.retain(|_, v| !v.is_expired());

        // If still over, remove least accessed
        if self.entries.len() > self.max_size {
            // Collect keys to remove (sorted by access count)
            let mut entries: Vec<_> = self
                .entries
                .iter()
                .map(|(k, v)| (k.clone(), v.access_count))
                .collect();
            entries.sort_by_key(|(_, count)| *count);

            let to_remove = self.entries.len() - self.max_size;
            let keys_to_remove: Vec<String> = entries
                .into_iter()
                .take(to_remove)
                .map(|(k, _)| k)
                .collect();

            for key in keys_to_remove {
                self.entries.remove(&key);
            }
        }
    }
}

// =============================================================================
// Memory Manager (Unified Interface)
// =============================================================================

/// Configuration for the memory manager
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryConfig {
    /// Vector store configuration
    pub vector_store: VectorStoreConfig,
    /// Chunking configuration
    pub chunking: ChunkingConfig,
    /// Context window size
    pub context_window_tokens: usize,
    /// Cache size
    pub cache_size: usize,
    /// Database path (None for in-memory)
    pub db_path: Option<String>,
    /// Auto-summarize long conversations
    pub auto_summarize: bool,
    /// Summarize after this many messages
    pub summarize_threshold: usize,
}

impl Default for MemoryConfig {
    fn default() -> Self {
        Self {
            vector_store: VectorStoreConfig::default(),
            chunking: ChunkingConfig::default(),
            context_window_tokens: 8192,
            cache_size: 1000,
            db_path: None,
            auto_summarize: true,
            summarize_threshold: 20,
        }
    }
}

/// Unified memory manager for agents
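///
/// An end-to-end sketch; `query_embedding` and `history` are assumed to be
/// produced elsewhere (e.g. by an [`EmbeddingProvider`] and a session log):
///
/// ```ignore
/// let mut mm = MemoryManager::new(MemoryConfig::default())?;
/// mm.remember("User's name is Ada", MemoryType::LongTerm, MemorySource::UserInput)?;
/// let prompt = mm.build_context(&query_embedding, "You are helpful.", &history);
/// ```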
pub struct MemoryManager {
    config: MemoryConfig,
    vector_store: VectorStore,
    knowledge_base: KnowledgeBase,
    cache: AgentCache<String>,
}

impl MemoryManager {
    /// Create a new memory manager
    pub fn new(config: MemoryConfig) -> Result<Self, MemoryError> {
        let vector_store = if let Some(ref path) = config.db_path {
            VectorStore::with_persistence(config.vector_store.clone(), path)?
        } else {
            VectorStore::new(config.vector_store.clone())
        };

        let knowledge_base = if let Some(ref path) = config.db_path {
            let kb_path = format!("{}_kb", path);
            KnowledgeBase::with_persistence(config.vector_store.clone(), kb_path)?
        } else {
            KnowledgeBase::new(config.vector_store.clone())
        };

        Ok(Self {
            config: config.clone(),
            vector_store,
            knowledge_base,
            cache: AgentCache::new(config.cache_size),
        })
    }

    /// Store a memory
    pub fn remember(
        &mut self,
        content: impl Into<String>,
        memory_type: MemoryType,
        source: MemorySource,
    ) -> Result<MemoryId, MemoryError> {
        let entry = MemoryEntry::new(content, memory_type, source);
        self.vector_store.add(entry)
    }

    /// Store a memory with embedding
    pub fn remember_with_embedding(
        &mut self,
        content: impl Into<String>,
        embedding: Embedding,
        memory_type: MemoryType,
        source: MemorySource,
    ) -> Result<MemoryId, MemoryError> {
        let entry = MemoryEntry::new(content, memory_type, source).with_embedding(embedding);
        self.vector_store.add(entry)
    }

    /// Recall memories similar to a query
    pub fn recall(&mut self, query_embedding: &Embedding, limit: usize) -> Vec<SearchResult> {
        self.vector_store.search(query_embedding, limit)
    }

    /// Recall by type
    pub fn recall_by_type(&self, memory_type: MemoryType, limit: usize) -> Vec<&MemoryEntry> {
        self.vector_store.search_by_type(memory_type, limit)
    }

    /// Add document to knowledge base
    pub fn add_document(&mut self, document: Document) -> Result<String, MemoryError> {
        self.knowledge_base.add_document(document)
    }

    /// Retrieve from knowledge base
    pub fn retrieve(&mut self, query_embedding: &Embedding, limit: usize) -> Vec<SearchResult> {
        self.knowledge_base.retrieve(query_embedding, limit)
    }

    /// Build context for a prompt
    pub fn build_context(
        &mut self,
        query_embedding: &Embedding,
        system_prompt: &str,
        conversation: &[String],
    ) -> String {
        let mut context = ContextWindow::new(self.config.context_window_tokens);

        // System prompt (required)
        context.add_segment(ContextSegment {
            segment_type: ContextSegmentType::SystemPrompt,
            content: system_prompt.to_string(),
            tokens: estimate_tokens(system_prompt),
            priority: 100,
            required: true,
        });

        // Retrieved context
        let retrieved = self.recall(query_embedding, 5);
        if !retrieved.is_empty() {
            let retrieved_text: String = retrieved
                .iter()
                .map(|r| format!("- {}", r.entry.content))
                .collect::<Vec<_>>()
                .join("\n");

            context.add_segment(ContextSegment {
                segment_type: ContextSegmentType::RetrievedContext,
                content: format!("Relevant context:\n{}", retrieved_text),
                tokens: estimate_tokens(&retrieved_text) + 20,
                priority: 80,
                required: false,
            });
        }

        // Conversation history
        let conv_text = conversation.join("\n");
        context.add_segment(ContextSegment {
            segment_type: ContextSegmentType::ConversationHistory,
            content: conv_text.clone(),
            tokens: estimate_tokens(&conv_text),
            priority: 90,
            required: false,
        });

        context.build()
    }

    /// Cache a computation result
    pub fn cache_result(
        &mut self,
        key: impl Into<String>,
        value: String,
        ttl: Option<chrono::Duration>,
    ) {
        self.cache.set(key, value, ttl);
    }

    /// Get cached result
    pub fn get_cached(&mut self, key: &str) -> Option<String> {
        self.cache.get(key)
    }

    /// Get statistics
    pub fn stats(&self) -> MemoryStats {
        MemoryStats {
            vector_store: self.vector_store.stats(),
            document_count: self.knowledge_base.list_documents().len(),
        }
    }
}

/// Memory statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryStats {
    pub vector_store: VectorStoreStats,
    pub document_count: usize,
}

// =============================================================================
// Error Types
// =============================================================================

/// Memory system errors
#[derive(Debug, Clone)]
pub enum MemoryError {
    /// Database error
    Database(String),
    /// Embedding error
    Embedding(String),
    /// Not found
    NotFound(String),
    /// Invalid input
    InvalidInput(String),
    /// IO error
    Io(String),
}

impl std::fmt::Display for MemoryError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            MemoryError::Database(e) => write!(f, "Database error: {}", e),
            MemoryError::Embedding(e) => write!(f, "Embedding error: {}", e),
            MemoryError::NotFound(e) => write!(f, "Not found: {}", e),
            MemoryError::InvalidInput(e) => write!(f, "Invalid input: {}", e),
            MemoryError::Io(e) => write!(f, "IO error: {}", e),
        }
    }
}

impl std::error::Error for MemoryError {}

// =============================================================================
// Utility Functions
// =============================================================================

fn generate_memory_id() -> String {
    use std::time::{SystemTime, UNIX_EPOCH};
    let timestamp = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .as_nanos();
    format!("mem_{:x}", timestamp)
}

/// Estimate token count (rough approximation: 4 chars per token)
fn estimate_tokens(text: &str) -> usize {
    (text.len() as f32 / 4.0).ceil() as usize
}

/// Truncate text to approximate token count
fn truncate_to_tokens(text: &str, max_tokens: usize) -> String {
    let max_chars = max_tokens * 4;
    if text.len() <= max_chars {
        text.to_string()
    } else {
        // Back off to a char boundary so slicing multi-byte UTF-8
        // cannot panic
        let mut end = max_chars;
        while !text.is_char_boundary(end) {
            end -= 1;
        }
        format!("{}...", &text[..end])
    }
}

// =============================================================================
// Embedding Provider Trait
// =============================================================================

/// Trait for embedding providers
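///
/// A sketch of a trivial deterministic implementation for tests; this is
/// not a real embedding model, just a fixed-size placeholder vector:
///
/// ```ignore
/// struct MockEmbedding;
///
/// #[async_trait::async_trait]
/// impl EmbeddingProvider for MockEmbedding {
///     async fn embed(&self, text: &str) -> Result<Embedding, MemoryError> {
///         Ok(vec![text.len() as f32; 8])
///     }
///     async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Embedding>, MemoryError> {
///         let mut out = Vec::new();
///         for t in texts {
///             out.push(self.embed(t).await?);
///         }
///         Ok(out)
///     }
///     fn dimension(&self) -> usize {
///         8
///     }
/// }
/// ```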
#[async_trait::async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Generate embedding for text
    async fn embed(&self, text: &str) -> Result<Embedding, MemoryError>;

    /// Generate embeddings for multiple texts
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Embedding>, MemoryError>;

    /// Get the embedding dimension
    fn dimension(&self) -> usize;
}

/// OpenAI embedding provider
pub struct OpenAIEmbedding {
    #[allow(dead_code)]
    api_key: String,
    model: String,
}

impl OpenAIEmbedding {
    pub fn new(api_key: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: "text-embedding-3-small".to_string(),
        }
    }

    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }
}

#[async_trait::async_trait]
impl EmbeddingProvider for OpenAIEmbedding {
    async fn embed(&self, _text: &str) -> Result<Embedding, MemoryError> {
        // Implementation would call OpenAI API
        // For now, return a placeholder
        Ok(vec![0.0; 1536])
    }

    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Embedding>, MemoryError> {
        let mut results = Vec::new();
        for text in texts {
            results.push(self.embed(text).await?);
        }
        Ok(results)
    }

    fn dimension(&self) -> usize {
        match self.model.as_str() {
            "text-embedding-3-large" => 3072,
            _ => 1536,
        }
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_entry_creation() {
        let entry = MemoryEntry::new(
            "Test content",
            MemoryType::LongTerm,
            MemorySource::UserInput,
        );
        assert!(!entry.id.is_empty());
        assert_eq!(entry.content, "Test content");
        assert_eq!(entry.memory_type, MemoryType::LongTerm);
    }

    #[test]
    fn test_similarity_metrics() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let c = vec![0.0, 1.0, 0.0];

        assert!((SimilarityMetric::Cosine.calculate(&a, &b) - 1.0).abs() < 0.001);
        assert!((SimilarityMetric::Cosine.calculate(&a, &c) - 0.0).abs() < 0.001);
    }

    #[test]
    fn test_vector_store() {
        let config = VectorStoreConfig::default();
        let mut store = VectorStore::new(config);

        let entry = MemoryEntry::new("Test", MemoryType::ShortTerm, MemorySource::UserInput)
            .with_embedding(vec![1.0, 0.0, 0.0]);

        let id = store.add(entry).unwrap();
        assert!(!id.is_empty());
        assert!(store.get(&id).is_some());
    }

    #[test]
    fn test_context_window() {
        let mut ctx = ContextWindow::new(1000);

        ctx.add_segment(ContextSegment {
            segment_type: ContextSegmentType::SystemPrompt,
            content: "You are helpful".to_string(),
            tokens: 10,
            priority: 100,
            required: true,
        });

        let result = ctx.build();
        assert!(result.contains("You are helpful"));
    }

    #[test]
    fn test_cache() {
        let mut cache: AgentCache<String> = AgentCache::new(10);

        cache.set("key1", "value1".to_string(), None);
        assert_eq!(cache.get("key1"), Some("value1".to_string()));
        assert_eq!(cache.get("key2"), None);
    }
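
    // An additional sketch exercising the char-boundary-safe truncation
    // helper (2 tokens ≈ 8 chars under the 4-chars-per-token estimate)
    #[test]
    fn test_truncate_to_tokens() {
        let out = truncate_to_tokens("abcdefghijkl", 2);
        assert_eq!(out, "abcdefgh...");
        // Short input passes through unchanged
        assert_eq!(truncate_to_tokens("hi", 2), "hi");
    }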

    #[test]
    fn test_estimate_tokens() {
        assert_eq!(estimate_tokens("hello"), 2); // 5 chars / 4 = 1.25 -> 2
        assert_eq!(estimate_tokens("hello world"), 3); // 11 chars / 4 = 2.75 -> 3
    }
}
1587}