chasm_cli/agency/memory.rs

// Copyright (c) 2024-2026 Nervosys LLC
// SPDX-License-Identifier: Apache-2.0
//! Memory and RAG (Retrieval-Augmented Generation) System
//!
//! Provides persistent context, knowledge retrieval, and semantic search for agents.
//!
//! ## Features
//!
//! - **Vector Store**: Semantic similarity search using embeddings
//! - **Memory Types**: Short-term, long-term, episodic, and semantic memory
//! - **Knowledge Base**: Structured document storage with chunking
//! - **Context Window**: Smart context management for LLM prompts
//! - **Caching**: Frequently accessed information caching

#![allow(dead_code)]
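
//! ## Example
//!
//! A minimal usage sketch (illustrative only; the fixed `vec![0.1; 384]`
//! stands in for output from a real [`EmbeddingProvider`]):
//!
//! ```ignore
//! let mut manager = MemoryManager::new(MemoryConfig::default())?;
//! manager.remember_with_embedding(
//!     "User prefers dark mode",
//!     vec![0.1; 384],
//!     MemoryType::Preference,
//!     MemorySource::UserInput,
//! )?;
//! let results = manager.recall(&vec![0.1; 384], 5);
//! ```
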
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;

// =============================================================================
// Core Types
// =============================================================================

/// Unique identifier for memory entries
pub type MemoryId = String;

/// Vector embedding (typically 384-3072 dimensions, depending on the model)
pub type Embedding = Vec<f32>;

/// Memory entry representing a piece of stored knowledge
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEntry {
    /// Unique identifier
    pub id: MemoryId,
    /// The content/text of this memory
    pub content: String,
    /// Vector embedding for similarity search
    #[serde(skip_serializing_if = "Option::is_none")]
    pub embedding: Option<Embedding>,
    /// Memory type classification
    pub memory_type: MemoryType,
    /// Source of this memory (conversation, document, user input, etc.)
    pub source: MemorySource,
    /// Importance score (0.0 - 1.0)
    pub importance: f32,
    /// Access count (updated when the entry is retrieved)
    pub access_count: u64,
    /// Last accessed timestamp
    pub last_accessed: DateTime<Utc>,
    /// Creation timestamp
    pub created_at: DateTime<Utc>,
    /// Optional expiration
    pub expires_at: Option<DateTime<Utc>>,
    /// Associated agent ID
    pub agent_id: Option<String>,
    /// Associated session ID
    pub session_id: Option<String>,
    /// Custom metadata
    pub metadata: HashMap<String, serde_json::Value>,
    /// Tags for filtering
    pub tags: Vec<String>,
}

impl MemoryEntry {
    /// Create a new memory entry
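    ///
    /// Builder-style construction (a sketch; the tag and importance values
    /// are illustrative):
    ///
    /// ```ignore
    /// let entry = MemoryEntry::new("Likes Rust", MemoryType::Preference, MemorySource::UserInput)
    ///     .with_importance(0.9)
    ///     .with_tag("preferences")
    ///     .expires_in(chrono::Duration::days(30));
    /// ```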
    pub fn new(content: impl Into<String>, memory_type: MemoryType, source: MemorySource) -> Self {
        let now = Utc::now();
        Self {
            id: generate_memory_id(),
            content: content.into(),
            embedding: None,
            memory_type,
            source,
            importance: 0.5,
            access_count: 0,
            last_accessed: now,
            created_at: now,
            expires_at: None,
            agent_id: None,
            session_id: None,
            metadata: HashMap::new(),
            tags: Vec::new(),
        }
    }

    /// Set the embedding
    pub fn with_embedding(mut self, embedding: Embedding) -> Self {
        self.embedding = Some(embedding);
        self
    }

    /// Set importance score
    pub fn with_importance(mut self, importance: f32) -> Self {
        self.importance = importance.clamp(0.0, 1.0);
        self
    }

    /// Set agent ID
    pub fn with_agent(mut self, agent_id: impl Into<String>) -> Self {
        self.agent_id = Some(agent_id.into());
        self
    }

    /// Set session ID
    pub fn with_session(mut self, session_id: impl Into<String>) -> Self {
        self.session_id = Some(session_id.into());
        self
    }

    /// Add a tag
    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
        self.tags.push(tag.into());
        self
    }

    /// Set expiration
    pub fn expires_in(mut self, duration: chrono::Duration) -> Self {
        self.expires_at = Some(Utc::now() + duration);
        self
    }

    /// Check if expired
    pub fn is_expired(&self) -> bool {
        self.expires_at.map(|exp| Utc::now() > exp).unwrap_or(false)
    }
}

/// Types of memory
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MemoryType {
    /// Short-term working memory (current conversation context)
    ShortTerm,
    /// Long-term persistent memory (facts, preferences, learned info)
    LongTerm,
    /// Episodic memory (specific events and experiences)
    Episodic,
    /// Semantic memory (concepts, relationships, general knowledge)
    Semantic,
    /// Procedural memory (how to do things, workflows)
    Procedural,
    /// User preferences and settings
    Preference,
    /// Cached computation results
    Cache,
}

/// Source of memory entry
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MemorySource {
    /// From a conversation message
    Conversation {
        session_id: String,
        message_id: String,
    },
    /// From a document/file
    Document { path: String, chunk_index: u32 },
    /// Direct user input/instruction
    UserInput,
    /// Agent reasoning/reflection
    AgentReasoning { agent_id: String },
    /// External API or tool result
    ToolResult { tool_name: String },
    /// Web page or URL
    WebPage { url: String },
    /// System-generated summary
    Summary { source_ids: Vec<String> },
    /// Custom source
    Custom { source_type: String },
}

// =============================================================================
// Vector Store
// =============================================================================

/// Configuration for the vector store
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorStoreConfig {
    /// Embedding model to use
    pub embedding_model: EmbeddingModel,
    /// Dimension of embeddings
    pub embedding_dim: usize,
    /// Similarity metric
    pub similarity_metric: SimilarityMetric,
    /// Maximum entries before pruning
    pub max_entries: usize,
    /// Database path
    pub db_path: Option<String>,
}

impl Default for VectorStoreConfig {
    fn default() -> Self {
        Self {
            embedding_model: EmbeddingModel::default(),
            embedding_dim: 384,
            similarity_metric: SimilarityMetric::Cosine,
            max_entries: 100_000,
            db_path: None,
        }
    }
}

/// Embedding model options
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EmbeddingModel {
    /// OpenAI text-embedding-3-small (1536 dims)
    OpenAISmall,
    /// OpenAI text-embedding-3-large (3072 dims)
    OpenAILarge,
    /// OpenAI text-embedding-ada-002 (1536 dims)
    OpenAIAda,
    /// Sentence Transformers all-MiniLM-L6-v2 (384 dims)
    #[default]
    MiniLM,
    /// Sentence Transformers all-mpnet-base-v2 (768 dims)
    MPNet,
    /// Cohere embed-english-v3.0 (1024 dims)
    Cohere,
    /// Google text-embedding-004 (768 dims)
    GoogleGecko,
    /// Voyage AI voyage-2 (1024 dims)
    Voyage,
    /// Local model via Ollama
    Ollama { model: String },
    /// Custom model
    Custom { name: String, dim: usize },
}

impl EmbeddingModel {
    /// Get the dimension for this model
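    ///
    /// For example:
    ///
    /// ```ignore
    /// assert_eq!(EmbeddingModel::MiniLM.dimension(), 384);
    /// assert_eq!(EmbeddingModel::OpenAILarge.dimension(), 3072);
    /// ```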
    pub fn dimension(&self) -> usize {
        match self {
            EmbeddingModel::OpenAISmall => 1536,
            EmbeddingModel::OpenAILarge => 3072,
            EmbeddingModel::OpenAIAda => 1536,
            EmbeddingModel::MiniLM => 384,
            EmbeddingModel::MPNet => 768,
            EmbeddingModel::Cohere => 1024,
            EmbeddingModel::GoogleGecko => 768,
            EmbeddingModel::Voyage => 1024,
            // Varies by model; 4096 is a common default for llama-family models.
            EmbeddingModel::Ollama { .. } => 4096,
            EmbeddingModel::Custom { dim, .. } => *dim,
        }
    }
}

/// Similarity metrics for vector search
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SimilarityMetric {
    #[default]
    Cosine,
    Euclidean,
    DotProduct,
    Manhattan,
}

impl SimilarityMetric {
    /// Calculate similarity between two vectors
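    ///
    /// For example, cosine similarity of identical vectors is 1.0 and of
    /// orthogonal vectors is 0.0:
    ///
    /// ```ignore
    /// let sim = SimilarityMetric::Cosine.calculate(&[1.0, 0.0], &[1.0, 0.0]);
    /// assert!((sim - 1.0).abs() < 1e-6);
    /// let orth = SimilarityMetric::Cosine.calculate(&[1.0, 0.0], &[0.0, 1.0]);
    /// assert!(orth.abs() < 1e-6);
    /// ```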
    pub fn calculate(&self, a: &[f32], b: &[f32]) -> f32 {
        assert_eq!(a.len(), b.len(), "Vector dimensions must match");

        match self {
            SimilarityMetric::Cosine => {
                let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
                let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
                let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
                if norm_a == 0.0 || norm_b == 0.0 {
                    0.0
                } else {
                    dot / (norm_a * norm_b)
                }
            }
            SimilarityMetric::Euclidean => {
                let dist: f32 = a
                    .iter()
                    .zip(b.iter())
                    .map(|(x, y)| (x - y).powi(2))
                    .sum::<f32>()
                    .sqrt();
                1.0 / (1.0 + dist) // Convert distance to similarity
            }
            SimilarityMetric::DotProduct => a.iter().zip(b.iter()).map(|(x, y)| x * y).sum(),
            SimilarityMetric::Manhattan => {
                let dist: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum();
                1.0 / (1.0 + dist)
            }
        }
    }
}

/// Search result from vector store
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    /// The memory entry
    pub entry: MemoryEntry,
    /// Similarity score (higher is more similar; the exact range depends on the metric)
    pub score: f32,
    /// Rank in results
    pub rank: usize,
}

/// Vector store for semantic search
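///
/// In-memory usage sketch (the embeddings here are toy 3-dimensional vectors):
///
/// ```ignore
/// let mut store = VectorStore::new(VectorStoreConfig::default());
/// let entry = MemoryEntry::new("fact", MemoryType::LongTerm, MemorySource::UserInput)
///     .with_embedding(vec![1.0, 0.0, 0.0]);
/// store.add(entry)?;
/// let hits = store.search(&vec![1.0, 0.0, 0.0], 3);
/// assert_eq!(hits[0].rank, 0);
/// ```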
pub struct VectorStore {
    config: VectorStoreConfig,
    entries: Vec<MemoryEntry>,
    db: Option<rusqlite::Connection>,
}

impl VectorStore {
    /// Create a new in-memory vector store
    pub fn new(config: VectorStoreConfig) -> Self {
        Self {
            config,
            entries: Vec::new(),
            db: None,
        }
    }

    /// Create a vector store with SQLite persistence
    pub fn with_persistence(
        config: VectorStoreConfig,
        db_path: impl AsRef<Path>,
    ) -> Result<Self, MemoryError> {
        let db = rusqlite::Connection::open(db_path.as_ref())
            .map_err(|e| MemoryError::Database(e.to_string()))?;

        // Initialize schema
        db.execute_batch(
            r#"
            CREATE TABLE IF NOT EXISTS memory_entries (
                id TEXT PRIMARY KEY,
                content TEXT NOT NULL,
                embedding BLOB,
                memory_type TEXT NOT NULL,
                source TEXT NOT NULL,
                importance REAL NOT NULL,
                access_count INTEGER NOT NULL DEFAULT 0,
                last_accessed TEXT NOT NULL,
                created_at TEXT NOT NULL,
                expires_at TEXT,
                agent_id TEXT,
                session_id TEXT,
                metadata TEXT,
                tags TEXT
            );

            CREATE INDEX IF NOT EXISTS idx_memory_type ON memory_entries(memory_type);
            CREATE INDEX IF NOT EXISTS idx_agent_id ON memory_entries(agent_id);
            CREATE INDEX IF NOT EXISTS idx_session_id ON memory_entries(session_id);
            CREATE INDEX IF NOT EXISTS idx_created_at ON memory_entries(created_at);
            CREATE INDEX IF NOT EXISTS idx_importance ON memory_entries(importance DESC);
        "#,
        )
        .map_err(|e| MemoryError::Database(e.to_string()))?;

        let mut store = Self {
            config,
            entries: Vec::new(),
            db: Some(db),
        };

        store.load_from_db()?;
        Ok(store)
    }

    /// Load entries from database
    fn load_from_db(&mut self) -> Result<(), MemoryError> {
        if let Some(ref db) = self.db {
            let mut stmt = db
                .prepare(
                    "SELECT id, content, embedding, memory_type, source, importance,
                        access_count, last_accessed, created_at, expires_at,
                        agent_id, session_id, metadata, tags
                 FROM memory_entries
                 ORDER BY importance DESC, created_at DESC",
                )
                .map_err(|e| MemoryError::Database(e.to_string()))?;

            let entries = stmt
                .query_map([], |row| {
                    let embedding_blob: Option<Vec<u8>> = row.get(2)?;
                    let embedding = embedding_blob.map(|blob| {
                        blob.chunks(4)
                            .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap_or([0; 4])))
                            .collect()
                    });

                    Ok(MemoryEntry {
                        id: row.get(0)?,
                        content: row.get(1)?,
                        embedding,
                        memory_type: serde_json::from_str(&row.get::<_, String>(3)?)
                            .unwrap_or(MemoryType::LongTerm),
                        source: serde_json::from_str(&row.get::<_, String>(4)?)
                            .unwrap_or(MemorySource::UserInput),
                        importance: row.get(5)?,
                        access_count: row.get(6)?,
                        last_accessed: row
                            .get::<_, String>(7)?
                            .parse()
                            .unwrap_or_else(|_| Utc::now()),
                        created_at: row
                            .get::<_, String>(8)?
                            .parse()
                            .unwrap_or_else(|_| Utc::now()),
                        expires_at: row
                            .get::<_, Option<String>>(9)?
                            .and_then(|s| s.parse().ok()),
                        agent_id: row.get(10)?,
                        session_id: row.get(11)?,
                        metadata: row
                            .get::<_, Option<String>>(12)?
                            .and_then(|s| serde_json::from_str(&s).ok())
                            .unwrap_or_default(),
                        tags: row
                            .get::<_, Option<String>>(13)?
                            .and_then(|s| serde_json::from_str(&s).ok())
                            .unwrap_or_default(),
                    })
                })
                .map_err(|e| MemoryError::Database(e.to_string()))?;

            self.entries = entries.filter_map(|e| e.ok()).collect();
        }
        Ok(())
    }

    /// Add a memory entry
    pub fn add(&mut self, entry: MemoryEntry) -> Result<MemoryId, MemoryError> {
        let id = entry.id.clone();

        // Persist to database if available
        if let Some(ref db) = self.db {
            let embedding_blob: Option<Vec<u8>> = entry
                .embedding
                .as_ref()
                .map(|emb| emb.iter().flat_map(|f| f.to_le_bytes()).collect());

            db.execute(
                "INSERT OR REPLACE INTO memory_entries
                 (id, content, embedding, memory_type, source, importance,
                  access_count, last_accessed, created_at, expires_at,
                  agent_id, session_id, metadata, tags)
                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14)",
                rusqlite::params![
                    entry.id,
                    entry.content,
                    embedding_blob,
                    serde_json::to_string(&entry.memory_type).unwrap_or_default(),
                    serde_json::to_string(&entry.source).unwrap_or_default(),
                    entry.importance,
                    entry.access_count,
                    entry.last_accessed.to_rfc3339(),
                    entry.created_at.to_rfc3339(),
                    entry.expires_at.map(|e| e.to_rfc3339()),
                    entry.agent_id,
                    entry.session_id,
                    serde_json::to_string(&entry.metadata).ok(),
                    serde_json::to_string(&entry.tags).ok(),
                ],
            )
            .map_err(|e| MemoryError::Database(e.to_string()))?;
        }

        self.entries.push(entry);

        // Prune if needed
        if self.entries.len() > self.config.max_entries {
            self.prune()?;
        }

        Ok(id)
    }

    /// Search for similar entries
    pub fn search(&mut self, query_embedding: &Embedding, limit: usize) -> Vec<SearchResult> {
        let mut results: Vec<(usize, f32)> = self
            .entries
            .iter()
            .enumerate()
            .filter(|(_, e)| !e.is_expired() && e.embedding.is_some())
            .map(|(i, e)| {
                let score = self
                    .config
                    .similarity_metric
                    .calculate(query_embedding, e.embedding.as_ref().unwrap());
                (i, score)
            })
            .collect();

        // Sort by score descending
        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        // Take top results and update access counts
        results
            .into_iter()
            .take(limit)
            .enumerate()
            .map(|(rank, (idx, score))| {
                self.entries[idx].access_count += 1;
                self.entries[idx].last_accessed = Utc::now();

                SearchResult {
                    entry: self.entries[idx].clone(),
                    score,
                    rank,
                }
            })
            .collect()
    }

    /// Search by memory type
    pub fn search_by_type(&self, memory_type: MemoryType, limit: usize) -> Vec<&MemoryEntry> {
        self.entries
            .iter()
            .filter(|e| e.memory_type == memory_type && !e.is_expired())
            .take(limit)
            .collect()
    }

    /// Search by tags
    pub fn search_by_tags(&self, tags: &[String], limit: usize) -> Vec<&MemoryEntry> {
        self.entries
            .iter()
            .filter(|e| !e.is_expired() && tags.iter().any(|t| e.tags.contains(t)))
            .take(limit)
            .collect()
    }

    /// Get entry by ID
    pub fn get(&self, id: &str) -> Option<&MemoryEntry> {
        self.entries.iter().find(|e| e.id == id)
    }

    /// Delete entry
    pub fn delete(&mut self, id: &str) -> Result<bool, MemoryError> {
        if let Some(pos) = self.entries.iter().position(|e| e.id == id) {
            self.entries.remove(pos);

            if let Some(ref db) = self.db {
                db.execute("DELETE FROM memory_entries WHERE id = ?1", [id])
                    .map_err(|e| MemoryError::Database(e.to_string()))?;
            }

            Ok(true)
        } else {
            Ok(false)
        }
    }

    /// Prune old/low-importance entries
    fn prune(&mut self) -> Result<(), MemoryError> {
        // Remove expired entries
        self.entries.retain(|e| !e.is_expired());

        // If still over limit, remove lowest importance entries
        if self.entries.len() > self.config.max_entries {
            self.entries.sort_by(|a, b| {
                b.importance
                    .partial_cmp(&a.importance)
                    .unwrap_or(std::cmp::Ordering::Equal)
            });
            self.entries.truncate(self.config.max_entries);
        }

        Ok(())
    }

    /// Get statistics
    pub fn stats(&self) -> VectorStoreStats {
        VectorStoreStats {
            total_entries: self.entries.len(),
            entries_by_type: self.entries.iter().fold(HashMap::new(), |mut acc, e| {
                *acc.entry(format!("{:?}", e.memory_type)).or_insert(0) += 1;
                acc
            }),
            total_access_count: self.entries.iter().map(|e| e.access_count).sum(),
            avg_importance: if self.entries.is_empty() {
                0.0
            } else {
                self.entries.iter().map(|e| e.importance).sum::<f32>() / self.entries.len() as f32
            },
        }
    }
}

/// Vector store statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorStoreStats {
    pub total_entries: usize,
    pub entries_by_type: HashMap<String, usize>,
    pub total_access_count: u64,
    pub avg_importance: f32,
}

// =============================================================================
// Knowledge Base / RAG
// =============================================================================

/// Document for the knowledge base
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Document {
    /// Unique identifier
    pub id: String,
    /// Document title
    pub title: String,
    /// Full content
    pub content: String,
    /// Document type
    pub doc_type: DocumentType,
    /// Source URL or path
    pub source: String,
    /// Chunked content for embedding
    pub chunks: Vec<DocumentChunk>,
    /// Creation timestamp
    pub created_at: DateTime<Utc>,
    /// Last updated
    pub updated_at: DateTime<Utc>,
    /// Metadata
    pub metadata: HashMap<String, serde_json::Value>,
}

/// Document types
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DocumentType {
    Text,
    Markdown,
    Code { language: String },
    Html,
    Pdf,
    Json,
    Yaml,
    Csv,
    Custom { mime_type: String },
}

/// A chunk of a document
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentChunk {
    /// Chunk index
    pub index: u32,
    /// Chunk content
    pub content: String,
    /// Start position in original document
    pub start_pos: usize,
    /// End position in original document
    pub end_pos: usize,
    /// Vector embedding
    pub embedding: Option<Embedding>,
    /// Token count estimate
    pub token_count: u32,
}

/// Configuration for document chunking
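///
/// For example, to favor smaller sentence-based chunks (values illustrative):
///
/// ```ignore
/// let cfg = ChunkingConfig {
///     chunk_size: 256,
///     chunk_overlap: 32,
///     strategy: ChunkingStrategy::Sentence,
/// };
/// ```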
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkingConfig {
    /// Target chunk size in tokens
    pub chunk_size: usize,
    /// Overlap between chunks in tokens (not yet applied by the built-in chunkers)
    pub chunk_overlap: usize,
    /// Chunking strategy
    pub strategy: ChunkingStrategy,
}

impl Default for ChunkingConfig {
    fn default() -> Self {
        Self {
            chunk_size: 512,
            chunk_overlap: 50,
            strategy: ChunkingStrategy::Semantic,
        }
    }
}

/// Chunking strategies
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ChunkingStrategy {
    /// Fixed-size character chunks
    FixedSize,
    /// Split on sentences
    Sentence,
    /// Split on paragraphs
    Paragraph,
    /// Semantic chunking (respects structure)
    #[default]
    Semantic,
    /// Code-aware chunking
    Code,
}

/// Knowledge base for RAG
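///
/// Sketch of adding a document and retrieving context (`query_embedding` is a
/// stand-in for a real embedding; chunks are only indexed once they carry
/// embeddings):
///
/// ```ignore
/// let mut kb = KnowledgeBase::new(VectorStoreConfig::default());
/// let doc = Document {
///     id: "doc1".into(),
///     title: "Notes".into(),
///     content: "First paragraph.\n\nSecond paragraph.".into(),
///     doc_type: DocumentType::Text,
///     source: "notes.txt".into(),
///     chunks: Vec::new(), // populated by add_document
///     created_at: Utc::now(),
///     updated_at: Utc::now(),
///     metadata: HashMap::new(),
/// };
/// kb.add_document(doc)?;
/// let results = kb.retrieve(&query_embedding, 3);
/// ```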
pub struct KnowledgeBase {
    /// Vector store for semantic search
    vector_store: VectorStore,
    /// Documents
    documents: HashMap<String, Document>,
    /// Chunking configuration
    chunking_config: ChunkingConfig,
}

impl KnowledgeBase {
    /// Create a new knowledge base
    pub fn new(vector_config: VectorStoreConfig) -> Self {
        Self {
            vector_store: VectorStore::new(vector_config),
            documents: HashMap::new(),
            chunking_config: ChunkingConfig::default(),
        }
    }

    /// Create with persistence
    pub fn with_persistence(
        vector_config: VectorStoreConfig,
        db_path: impl AsRef<Path>,
    ) -> Result<Self, MemoryError> {
        Ok(Self {
            vector_store: VectorStore::with_persistence(vector_config, db_path)?,
            documents: HashMap::new(),
            chunking_config: ChunkingConfig::default(),
        })
    }

    /// Add a document, chunking its content for retrieval.
    ///
    /// Note: only chunks that already carry embeddings are indexed in the
    /// vector store; freshly chunked content has `embedding: None`, so attach
    /// embeddings (e.g. via an [`EmbeddingProvider`]) to make it searchable.
    pub fn add_document(&mut self, mut document: Document) -> Result<String, MemoryError> {
        // Chunk the document
        document.chunks = self.chunk_document(&document.content);

        let doc_id = document.id.clone();

        // Add chunks to vector store
        for chunk in &document.chunks {
            if let Some(ref embedding) = chunk.embedding {
                let entry = MemoryEntry::new(
                    chunk.content.clone(),
                    MemoryType::Semantic,
                    MemorySource::Document {
                        path: document.source.clone(),
                        chunk_index: chunk.index,
                    },
                )
                .with_embedding(embedding.clone())
                .with_tag(format!("doc:{}", doc_id));

                self.vector_store.add(entry)?;
            }
        }

        self.documents.insert(doc_id.clone(), document);
        Ok(doc_id)
    }

    /// Chunk a document
    fn chunk_document(&self, content: &str) -> Vec<DocumentChunk> {
        match self.chunking_config.strategy {
            ChunkingStrategy::Semantic => self.semantic_chunk(content),
            ChunkingStrategy::Paragraph => self.paragraph_chunk(content),
            ChunkingStrategy::Sentence => self.sentence_chunk(content),
            ChunkingStrategy::FixedSize => self.fixed_chunk(content),
            ChunkingStrategy::Code => self.code_chunk(content),
        }
    }

    fn semantic_chunk(&self, content: &str) -> Vec<DocumentChunk> {
        // Split on double newlines (paragraphs) but respect size limits.
        // Note: start_pos/end_pos are approximate, since trimmed separators
        // between paragraphs are not counted.
        let mut chunks = Vec::new();
        let mut current_chunk = String::new();
        let mut start_pos = 0;
        let mut chunk_index = 0;

        for para in content.split("\n\n") {
            let para = para.trim();
            if para.is_empty() {
                continue;
            }

            let para_tokens = estimate_tokens(para);
            let current_tokens = estimate_tokens(&current_chunk);

            if current_tokens + para_tokens > self.chunking_config.chunk_size
                && !current_chunk.is_empty()
            {
                // Save current chunk
                let end_pos = start_pos + current_chunk.len();
                chunks.push(DocumentChunk {
                    index: chunk_index,
                    content: current_chunk.trim().to_string(),
                    start_pos,
                    end_pos,
                    embedding: None,
                    token_count: estimate_tokens(&current_chunk) as u32,
                });
                chunk_index += 1;
                start_pos = end_pos;
                current_chunk = String::new();
            }

            if !current_chunk.is_empty() {
                current_chunk.push_str("\n\n");
            }
            current_chunk.push_str(para);
        }

        // Add remaining content
        if !current_chunk.is_empty() {
            let end_pos = start_pos + current_chunk.len();
            chunks.push(DocumentChunk {
                index: chunk_index,
                content: current_chunk.trim().to_string(),
                start_pos,
                end_pos,
                embedding: None,
                token_count: estimate_tokens(&current_chunk) as u32,
            });
        }

        chunks
    }

    fn paragraph_chunk(&self, content: &str) -> Vec<DocumentChunk> {
        content
            .split("\n\n")
            .filter(|p| !p.trim().is_empty())
            .enumerate()
            .scan(0usize, |pos, (i, para)| {
                let start = *pos;
                *pos += para.len() + 2;
                Some(DocumentChunk {
                    index: i as u32,
                    content: para.trim().to_string(),
                    start_pos: start,
                    end_pos: *pos,
                    embedding: None,
                    token_count: estimate_tokens(para) as u32,
                })
            })
            .collect()
    }

    fn sentence_chunk(&self, content: &str) -> Vec<DocumentChunk> {
        // Simple sentence splitting (could be improved with NLP)
        let sentences: Vec<&str> = content
            .split(|c| c == '.' || c == '!' || c == '?')
            .filter(|s| !s.trim().is_empty())
            .collect();

        let mut chunks = Vec::new();
        let mut current = String::new();
        let mut start = 0;
        let mut idx = 0;

        for sentence in sentences {
            let sentence = sentence.trim();
            if estimate_tokens(&current) + estimate_tokens(sentence)
                > self.chunking_config.chunk_size
            {
                if !current.is_empty() {
                    chunks.push(DocumentChunk {
                        index: idx,
                        content: current.clone(),
                        start_pos: start,
                        end_pos: start + current.len(),
                        embedding: None,
                        token_count: estimate_tokens(&current) as u32,
                    });
                    idx += 1;
                    start += current.len();
                    current.clear();
                }
            }
            if !current.is_empty() {
                current.push(' ');
            }
            current.push_str(sentence);
            current.push('.');
        }

        if !current.is_empty() {
            chunks.push(DocumentChunk {
                index: idx,
                content: current.clone(),
                start_pos: start,
                end_pos: start + current.len(),
                embedding: None,
                token_count: estimate_tokens(&current) as u32,
            });
        }

        chunks
    }

    fn fixed_chunk(&self, content: &str) -> Vec<DocumentChunk> {
        let chars_per_chunk = self.chunking_config.chunk_size * 4; // Rough estimate: ~4 chars per token
        content
            .chars()
            .collect::<Vec<_>>()
            .chunks(chars_per_chunk)
            .enumerate()
            .map(|(i, chars)| {
                let s: String = chars.iter().collect();
                DocumentChunk {
                    index: i as u32,
                    content: s.clone(),
                    start_pos: i * chars_per_chunk,
                    // The final chunk may be shorter than chars_per_chunk.
                    end_pos: i * chars_per_chunk + chars.len(),
                    embedding: None,
                    token_count: estimate_tokens(&s) as u32,
                }
            })
            .collect()
    }

    fn code_chunk(&self, content: &str) -> Vec<DocumentChunk> {
        // Split on function/class definitions (simple heuristic)
        let mut chunks = Vec::new();
        let mut current = String::new();
        let mut start = 0;
        let mut idx = 0;

        for line in content.lines() {
            let is_boundary = line.starts_with("fn ")
                || line.starts_with("pub fn ")
                || line.starts_with("async fn ")
                || line.starts_with("impl ")
                || line.starts_with("struct ")
                || line.starts_with("enum ")
                || line.starts_with("trait ")
                || line.starts_with("class ")
                || line.starts_with("def ")
                || line.starts_with("function ")
                || line.starts_with("const ")
                || line.starts_with("export ");

            if is_boundary && !current.is_empty() {
                chunks.push(DocumentChunk {
                    index: idx,
                    content: current.clone(),
                    start_pos: start,
                    end_pos: start + current.len(),
                    embedding: None,
                    token_count: estimate_tokens(&current) as u32,
                });
                idx += 1;
                start += current.len();
                current.clear();
            }

            current.push_str(line);
            current.push('\n');
        }

        if !current.is_empty() {
            chunks.push(DocumentChunk {
                index: idx,
                content: current.clone(),
                start_pos: start,
                end_pos: start + current.len(),
                embedding: None,
                token_count: estimate_tokens(&current) as u32,
            });
        }

        chunks
    }

    /// Retrieve relevant context for a query
    pub fn retrieve(&mut self, query_embedding: &Embedding, limit: usize) -> Vec<SearchResult> {
        self.vector_store.search(query_embedding, limit)
    }

    /// Get document by ID
    pub fn get_document(&self, id: &str) -> Option<&Document> {
        self.documents.get(id)
    }

    /// List all documents
    pub fn list_documents(&self) -> Vec<&Document> {
        self.documents.values().collect()
    }

    /// Delete a document (note: its chunk entries remain in the vector store)
    pub fn delete_document(&mut self, id: &str) -> bool {
        self.documents.remove(id).is_some()
    }
}

// =============================================================================
// Context Window Manager
// =============================================================================

/// Manages context for LLM prompts
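///
/// Sketch: segments may be added in any order; `build` sorts them by priority
/// (required segments first) and drops optional segments that exceed the
/// token budget:
///
/// ```ignore
/// let mut ctx = ContextWindow::new(8192);
/// ctx.add_segment(ContextSegment {
///     segment_type: ContextSegmentType::SystemPrompt,
///     content: "You are a helpful assistant.".into(),
///     tokens: 8,
///     priority: 100,
///     required: true,
/// });
/// let prompt = ctx.build();
/// ```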
#[derive(Debug, Clone)]
pub struct ContextWindow {
    /// Maximum tokens for context
    pub max_tokens: usize,
    /// Reserved tokens for response
    pub reserved_for_response: usize,
    /// Context segments
    segments: Vec<ContextSegment>,
}

/// A segment of context
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextSegment {
    /// Segment type
    pub segment_type: ContextSegmentType,
    /// Content
    pub content: String,
    /// Token count
    pub tokens: usize,
    /// Priority (higher = more important)
    pub priority: u32,
    /// Whether this segment is required
    pub required: bool,
}

/// Types of context segments
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ContextSegmentType {
    SystemPrompt,
    UserPreferences,
    ConversationHistory,
    RetrievedContext,
    ToolResults,
    CurrentQuery,
    Custom { name: String },
}

impl ContextWindow {
    /// Create a new context window
    pub fn new(max_tokens: usize) -> Self {
        Self {
            max_tokens,
            reserved_for_response: max_tokens / 4, // Reserve 25% for response
            segments: Vec::new(),
        }
    }

    /// Add a context segment
    pub fn add_segment(&mut self, segment: ContextSegment) {
        self.segments.push(segment);
    }

    /// Build the final context, respecting token limits
    pub fn build(&mut self) -> String {
        let available = self.max_tokens.saturating_sub(self.reserved_for_response);

        // Sort by priority (required first, then by priority)
        self.segments
            .sort_by(|a, b| match (a.required, b.required) {
                (true, false) => std::cmp::Ordering::Less,
                (false, true) => std::cmp::Ordering::Greater,
                _ => b.priority.cmp(&a.priority),
            });

        let mut total_tokens = 0;
        let mut result = Vec::new();

        for segment in &self.segments {
            if total_tokens + segment.tokens <= available {
                result.push(segment.content.clone());
                total_tokens += segment.tokens;
            } else if segment.required {
                // Truncate a required segment to fit the remaining budget
                let remaining = available.saturating_sub(total_tokens);
                if remaining > 0 {
                    let truncated = truncate_to_tokens(&segment.content, remaining);
                    result.push(truncated);
                    break;
                }
            }
        }

        result.join("\n\n")
    }

    /// Get current token usage as (used, available)
    pub fn token_usage(&self) -> (usize, usize) {
        let used: usize = self.segments.iter().map(|s| s.tokens).sum();
        (used, self.max_tokens.saturating_sub(self.reserved_for_response))
    }
}

// =============================================================================
// Cache
// =============================================================================

/// Cache entry
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheEntry<T> {
    pub key: String,
    pub value: T,
    pub created_at: DateTime<Utc>,
    pub expires_at: Option<DateTime<Utc>>,
    pub access_count: u64,
}

impl<T> CacheEntry<T> {
    pub fn is_expired(&self) -> bool {
        self.expires_at.map(|exp| Utc::now() > exp).unwrap_or(false)
    }
}

/// Simple size-bounded cache for agent computations (evicts expired entries
/// first, then the least-frequently-accessed)
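///
/// Usage sketch with a 5-minute TTL:
///
/// ```ignore
/// let mut cache: AgentCache<String> = AgentCache::new(100);
/// cache.set("answer", "42".to_string(), Some(chrono::Duration::minutes(5)));
/// assert_eq!(cache.get("answer"), Some("42".to_string()));
/// ```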
pub struct AgentCache<T> {
    entries: HashMap<String, CacheEntry<T>>,
    max_size: usize,
}

impl<T: Clone> AgentCache<T> {
    /// Create a new cache
    pub fn new(max_size: usize) -> Self {
        Self {
            entries: HashMap::new(),
            max_size,
        }
    }

    /// Get a value from cache
    pub fn get(&mut self, key: &str) -> Option<T> {
        if let Some(entry) = self.entries.get_mut(key) {
            if entry.is_expired() {
                self.entries.remove(key);
                return None;
            }
            entry.access_count += 1;
            Some(entry.value.clone())
        } else {
            None
        }
    }

    /// Set a value in cache
    pub fn set(&mut self, key: impl Into<String>, value: T, ttl: Option<chrono::Duration>) {
        let key = key.into();
        let now = Utc::now();

        self.entries.insert(
            key.clone(),
            CacheEntry {
                key,
                value,
                created_at: now,
                expires_at: ttl.map(|d| now + d),
                access_count: 0,
            },
        );

        // Evict if over size
        if self.entries.len() > self.max_size {
            self.evict_lru();
        }
    }

    /// Remove a value
    pub fn remove(&mut self, key: &str) -> Option<T> {
        self.entries.remove(key).map(|e| e.value)
    }

    /// Clear the cache
    pub fn clear(&mut self) {
        self.entries.clear();
    }

    /// Evict expired entries first, then the least-frequently-accessed
    fn evict_lru(&mut self) {
        // Remove expired first
        self.entries.retain(|_, v| !v.is_expired());

        // If still over, remove least accessed
        if self.entries.len() > self.max_size {
            // Collect keys to remove (sorted by access count)
            let mut entries: Vec<_> = self
                .entries
                .iter()
                .map(|(k, v)| (k.clone(), v.access_count))
                .collect();
            entries.sort_by_key(|(_, count)| *count);

            let to_remove = self.entries.len() - self.max_size;
            let keys_to_remove: Vec<String> = entries
                .into_iter()
                .take(to_remove)
                .map(|(k, _)| k)
                .collect();

            for key in keys_to_remove {
                self.entries.remove(&key);
            }
        }
    }
}

// =============================================================================
// Memory Manager (Unified Interface)
// =============================================================================

/// Configuration for the memory manager
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryConfig {
    /// Vector store configuration
    pub vector_store: VectorStoreConfig,
    /// Chunking configuration
    pub chunking: ChunkingConfig,
    /// Context window size
    pub context_window_tokens: usize,
    /// Cache size
    pub cache_size: usize,
    /// Database path (None for in-memory)
    pub db_path: Option<String>,
    /// Auto-summarize long conversations
    pub auto_summarize: bool,
    /// Summarize after this many messages
    pub summarize_threshold: usize,
}

impl Default for MemoryConfig {
    fn default() -> Self {
        Self {
            vector_store: VectorStoreConfig::default(),
            chunking: ChunkingConfig::default(),
            context_window_tokens: 8192,
            cache_size: 1000,
            db_path: None,
            auto_summarize: true,
            summarize_threshold: 20,
        }
    }
}

/// Unified memory manager for agents
pub struct MemoryManager {
    config: MemoryConfig,
    vector_store: VectorStore,
    knowledge_base: KnowledgeBase,
    cache: AgentCache<String>,
}

impl MemoryManager {
    /// Create a new memory manager
    pub fn new(config: MemoryConfig) -> Result<Self, MemoryError> {
        let vector_store = if let Some(ref path) = config.db_path {
            VectorStore::with_persistence(config.vector_store.clone(), path)?
        } else {
            VectorStore::new(config.vector_store.clone())
        };

        let knowledge_base = if let Some(ref path) = config.db_path {
            let kb_path = format!("{}_kb", path);
            KnowledgeBase::with_persistence(config.vector_store.clone(), kb_path)?
        } else {
            KnowledgeBase::new(config.vector_store.clone())
        };

        Ok(Self {
            config: config.clone(),
            vector_store,
            knowledge_base,
            cache: AgentCache::new(config.cache_size),
        })
    }

    /// Store a memory
    pub fn remember(
        &mut self,
        content: impl Into<String>,
        memory_type: MemoryType,
        source: MemorySource,
    ) -> Result<MemoryId, MemoryError> {
        let entry = MemoryEntry::new(content, memory_type, source);
        self.vector_store.add(entry)
    }

    /// Store a memory with embedding
    pub fn remember_with_embedding(
        &mut self,
        content: impl Into<String>,
        embedding: Embedding,
        memory_type: MemoryType,
        source: MemorySource,
    ) -> Result<MemoryId, MemoryError> {
        let entry = MemoryEntry::new(content, memory_type, source).with_embedding(embedding);
        self.vector_store.add(entry)
    }

    /// Recall memories similar to a query
    pub fn recall(&mut self, query_embedding: &Embedding, limit: usize) -> Vec<SearchResult> {
        self.vector_store.search(query_embedding, limit)
    }

    /// Recall by type
    pub fn recall_by_type(&self, memory_type: MemoryType, limit: usize) -> Vec<&MemoryEntry> {
        self.vector_store.search_by_type(memory_type, limit)
    }

    /// Add document to knowledge base
    pub fn add_document(&mut self, document: Document) -> Result<String, MemoryError> {
        self.knowledge_base.add_document(document)
    }

    /// Retrieve from knowledge base
    pub fn retrieve(&mut self, query_embedding: &Embedding, limit: usize) -> Vec<SearchResult> {
        self.knowledge_base.retrieve(query_embedding, limit)
    }

    /// Build context for a prompt
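    ///
    /// Sketch (the embedding and history values are placeholders):
    ///
    /// ```ignore
    /// let prompt = manager.build_context(
    ///     &query_embedding,
    ///     "You are a helpful assistant.",
    ///     &["user: hi".to_string(), "assistant: hello!".to_string()],
    /// );
    /// ```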
    pub fn build_context(
        &mut self,
        query_embedding: &Embedding,
        system_prompt: &str,
        conversation: &[String],
    ) -> String {
        let mut context = ContextWindow::new(self.config.context_window_tokens);

        // System prompt (required)
        context.add_segment(ContextSegment {
            segment_type: ContextSegmentType::SystemPrompt,
            content: system_prompt.to_string(),
            tokens: estimate_tokens(system_prompt),
            priority: 100,
            required: true,
        });

        // Retrieved context
        let retrieved = self.recall(query_embedding, 5);
        if !retrieved.is_empty() {
            let retrieved_text: String = retrieved
                .iter()
                .map(|r| format!("- {}", r.entry.content))
                .collect::<Vec<_>>()
                .join("\n");

            context.add_segment(ContextSegment {
                segment_type: ContextSegmentType::RetrievedContext,
                content: format!("Relevant context:\n{}", retrieved_text),
                // Rough overhead allowance for the header line
                tokens: estimate_tokens(&retrieved_text) + 20,
                priority: 80,
                required: false,
            });
        }

        // Conversation history
        let conv_text = conversation.join("\n");
        let conv_tokens = estimate_tokens(&conv_text);
        context.add_segment(ContextSegment {
            segment_type: ContextSegmentType::ConversationHistory,
            content: conv_text,
            tokens: conv_tokens,
            priority: 90,
            required: false,
        });

        context.build()
    }

    /// Cache a computation result
    pub fn cache_result(
        &mut self,
        key: impl Into<String>,
        value: String,
        ttl: Option<chrono::Duration>,
    ) {
        self.cache.set(key, value, ttl);
    }

    /// Get cached result
    pub fn get_cached(&mut self, key: &str) -> Option<String> {
        self.cache.get(key)
    }

    /// Get statistics
    pub fn stats(&self) -> MemoryStats {
        MemoryStats {
            vector_store: self.vector_store.stats(),
            document_count: self.knowledge_base.list_documents().len(),
        }
    }
}

/// Memory statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryStats {
    pub vector_store: VectorStoreStats,
    pub document_count: usize,
}

// =============================================================================
// Error Types
// =============================================================================

/// Memory system errors
#[derive(Debug, Clone)]
pub enum MemoryError {
    /// Database error
    Database(String),
    /// Embedding error
    Embedding(String),
    /// Not found
    NotFound(String),
    /// Invalid input
    InvalidInput(String),
    /// IO error
    Io(String),
}

impl std::fmt::Display for MemoryError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            MemoryError::Database(e) => write!(f, "Database error: {}", e),
            MemoryError::Embedding(e) => write!(f, "Embedding error: {}", e),
            MemoryError::NotFound(e) => write!(f, "Not found: {}", e),
            MemoryError::InvalidInput(e) => write!(f, "Invalid input: {}", e),
            MemoryError::Io(e) => write!(f, "IO error: {}", e),
        }
    }
}

impl std::error::Error for MemoryError {}

// =============================================================================
// Utility Functions
// =============================================================================

fn generate_memory_id() -> String {
    use std::time::{SystemTime, UNIX_EPOCH};
    // Nanosecond timestamp; unique enough for single-process use, but not a
    // collision-proof ID scheme.
    let timestamp = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .as_nanos();
    format!("mem_{:x}", timestamp)
}

/// Estimate token count (rough approximation: 4 chars per token)
fn estimate_tokens(text: &str) -> usize {
    (text.len() as f32 / 4.0).ceil() as usize
}

/// Truncate text to an approximate token count
fn truncate_to_tokens(text: &str, max_tokens: usize) -> String {
    let max_chars = max_tokens * 4;
    if text.chars().count() <= max_chars {
        text.to_string()
    } else {
        // Truncate on char boundaries; slicing by byte index can panic on
        // multi-byte UTF-8 sequences.
        let truncated: String = text.chars().take(max_chars).collect();
        format!("{}...", truncated)
    }
}

// =============================================================================
// Embedding Provider Trait
// =============================================================================

/// Trait for embedding providers
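///
/// A minimal sketch of implementing the trait with a deterministic mock,
/// useful for tests (`MockEmbedding` is hypothetical, not part of this crate):
///
/// ```ignore
/// struct MockEmbedding;
///
/// #[async_trait::async_trait]
/// impl EmbeddingProvider for MockEmbedding {
///     async fn embed(&self, text: &str) -> Result<Embedding, MemoryError> {
///         // Derive a fixed-size vector from byte values; not semantically meaningful.
///         Ok((0..384usize)
///             .map(|i| *text.as_bytes().get(i).unwrap_or(&0) as f32 / 255.0)
///             .collect())
///     }
///
///     async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Embedding>, MemoryError> {
///         let mut out = Vec::new();
///         for t in texts {
///             out.push(self.embed(t).await?);
///         }
///         Ok(out)
///     }
///
///     fn dimension(&self) -> usize {
///         384
///     }
/// }
/// ```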
#[async_trait::async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Generate embedding for text
    async fn embed(&self, text: &str) -> Result<Embedding, MemoryError>;

    /// Generate embeddings for multiple texts
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Embedding>, MemoryError>;

    /// Get the embedding dimension
    fn dimension(&self) -> usize;
}

/// OpenAI embedding provider
pub struct OpenAIEmbedding {
    #[allow(dead_code)]
    api_key: String,
    model: String,
}

impl OpenAIEmbedding {
    pub fn new(api_key: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: "text-embedding-3-small".to_string(),
        }
    }

    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }
}

#[async_trait::async_trait]
impl EmbeddingProvider for OpenAIEmbedding {
    async fn embed(&self, _text: &str) -> Result<Embedding, MemoryError> {
        // Implementation would call the OpenAI embeddings API.
        // For now, return a zero-vector placeholder.
        Ok(vec![0.0; 1536])
    }

    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Embedding>, MemoryError> {
        let mut results = Vec::new();
        for text in texts {
            results.push(self.embed(text).await?);
        }
        Ok(results)
    }

    fn dimension(&self) -> usize {
        match self.model.as_str() {
            "text-embedding-3-large" => 3072,
            _ => 1536,
        }
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_entry_creation() {
        let entry = MemoryEntry::new(
            "Test content",
            MemoryType::LongTerm,
            MemorySource::UserInput,
        );
        assert!(!entry.id.is_empty());
        assert_eq!(entry.content, "Test content");
        assert_eq!(entry.memory_type, MemoryType::LongTerm);
    }

    #[test]
    fn test_similarity_metrics() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let c = vec![0.0, 1.0, 0.0];

        assert!((SimilarityMetric::Cosine.calculate(&a, &b) - 1.0).abs() < 0.001);
        assert!((SimilarityMetric::Cosine.calculate(&a, &c) - 0.0).abs() < 0.001);
    }

    #[test]
    fn test_vector_store() {
        let config = VectorStoreConfig::default();
        let mut store = VectorStore::new(config);

        let entry = MemoryEntry::new("Test", MemoryType::ShortTerm, MemorySource::UserInput)
            .with_embedding(vec![1.0, 0.0, 0.0]);

        let id = store.add(entry).unwrap();
        assert!(!id.is_empty());
        assert!(store.get(&id).is_some());
    }

    #[test]
    fn test_context_window() {
        let mut ctx = ContextWindow::new(1000);

        ctx.add_segment(ContextSegment {
            segment_type: ContextSegmentType::SystemPrompt,
            content: "You are helpful".to_string(),
            tokens: 10,
            priority: 100,
            required: true,
        });

        let result = ctx.build();
        assert!(result.contains("You are helpful"));
    }

    #[test]
    fn test_cache() {
        let mut cache: AgentCache<String> = AgentCache::new(10);

        cache.set("key1", "value1".to_string(), None);
        assert_eq!(cache.get("key1"), Some("value1".to_string()));
        assert_eq!(cache.get("key2"), None);
    }
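
    // A regression sketch for the char-boundary-safe truncation above
    // (assumes the char-based implementation of truncate_to_tokens).
    #[test]
    fn test_truncate_to_tokens_utf8() {
        // Multi-byte characters must not cause a panic when truncating.
        let text = "héllo wörld ".repeat(100);
        let truncated = truncate_to_tokens(&text, 4);
        assert!(truncated.ends_with("..."));
    }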

    #[test]
    fn test_estimate_tokens() {
        assert_eq!(estimate_tokens("hello"), 2); // 5 chars / 4 = 1.25 -> 2
        assert_eq!(estimate_tokens("hello world"), 3); // 11 chars / 4 = 2.75 -> 3
    }
}
1589}