Skip to main content

cortexai_agents/
multi_memory.rs

//! # Multi-Memory System
//!
//! Hierarchical memory system with short-term, long-term, and entity memory.
//!
//! Inspired by CrewAI's memory architecture.
//!
//! ## Memory Types
//!
//! - **Short-term Memory**: Recent conversation context; items expire after a configurable TTL
//! - **Long-term Memory**: Persistent facts and learnings, searchable by embedding similarity
//! - **Entity Memory**: Knowledge about specific entities (people, places, concepts)
//!
//! ## Example
//!
//! ```rust,ignore
//! use cortexai_agents::multi_memory::{MemoryConfig, MemorySource, MultiMemory};
//!
//! let memory = MultiMemory::new(MemoryConfig::default());
//!
//! // Store in short-term (recent context)
//! memory.short_term.add("User asked about Rust", MemorySource::User);
//!
//! // Store in long-term (persistent knowledge)
//! memory.long_term.store(
//!     "Rust is a systems programming language",
//!     embedding,
//!     MemorySource::External,
//! );
//!
//! // Store entity information (the entity must exist before attributes can be set)
//! memory.entity.get_or_create("Rust", "programming_language");
//! memory.entity.update_attribute("Rust", "creator", "Mozilla");
//!
//! // Query across all memory types (at most 5 long-term results)
//! let context = memory.recall("Rust", Some(&query_embedding), 5);
//! ```
33
34use std::collections::{HashMap, VecDeque};
35use std::sync::Arc;
36use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
37
38use parking_lot::RwLock;
39use serde::{Deserialize, Serialize};
40use tracing::{debug, info};
41
/// Tuning knobs for the multi-memory system and the three stores it owns.
#[derive(Debug, Clone)]
pub struct MemoryConfig {
    /// Upper bound on the number of items kept in short-term memory
    pub short_term_capacity: usize,
    /// How long a short-term item stays visible after it is added
    pub short_term_ttl: Duration,
    /// Upper bound on the number of entries kept in long-term memory
    pub long_term_capacity: usize,
    /// Minimum cosine similarity (0.0 to 1.0) a long-term entry needs to be recalled
    pub similarity_threshold: f32,
    /// Upper bound on the number of tracked entities
    pub entity_capacity: usize,
    /// Upper bound on the number of distinct attributes stored per entity
    pub max_attributes_per_entity: usize,
}

impl Default for MemoryConfig {
    /// 100 short-term items living one hour, 10 000 long-term entries recalled at
    /// similarity >= 0.7, and 1 000 entities with up to 50 attributes each.
    fn default() -> Self {
        const ONE_HOUR: Duration = Duration::from_secs(60 * 60);

        MemoryConfig {
            short_term_capacity: 100,
            short_term_ttl: ONE_HOUR,
            long_term_capacity: 10_000,
            similarity_threshold: 0.7,
            entity_capacity: 1_000,
            max_attributes_per_entity: 50,
        }
    }
}
71
72/// A memory item with metadata
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct MemoryItem {
75    /// Unique ID
76    pub id: String,
77    /// The content
78    pub content: String,
79    /// When it was created
80    pub created_at: u64,
81    /// When it was last accessed
82    pub last_accessed: u64,
83    /// Access count
84    pub access_count: u32,
85    /// Optional importance score (0.0 to 1.0)
86    pub importance: f32,
87    /// Source of the memory
88    pub source: MemorySource,
89    /// Optional metadata
90    pub metadata: HashMap<String, String>,
91}
92
93impl MemoryItem {
94    pub fn new(content: impl Into<String>, source: MemorySource) -> Self {
95        let now = SystemTime::now()
96            .duration_since(UNIX_EPOCH)
97            .unwrap_or_default()
98            .as_secs();
99
100        Self {
101            id: uuid::Uuid::new_v4().to_string(),
102            content: content.into(),
103            created_at: now,
104            last_accessed: now,
105            access_count: 0,
106            importance: 0.5,
107            source,
108            metadata: HashMap::new(),
109        }
110    }
111
112    pub fn with_importance(mut self, importance: f32) -> Self {
113        self.importance = importance.clamp(0.0, 1.0);
114        self
115    }
116
117    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
118        self.metadata.insert(key.into(), value.into());
119        self
120    }
121
122    fn touch(&mut self) {
123        self.last_accessed = SystemTime::now()
124            .duration_since(UNIX_EPOCH)
125            .unwrap_or_default()
126            .as_secs();
127        self.access_count += 1;
128    }
129}
130
131/// Source of a memory
132#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
133pub enum MemorySource {
134    /// From user input
135    User,
136    /// From agent response
137    Agent,
138    /// From tool execution
139    Tool,
140    /// From external knowledge base
141    External,
142    /// System-generated
143    System,
144}
145
146// ============================================================================
147// Short-Term Memory
148// ============================================================================
149
150/// Short-term memory for recent conversation context
151pub struct ShortTermMemory {
152    items: RwLock<VecDeque<(MemoryItem, Instant)>>,
153    capacity: usize,
154    ttl: Duration,
155}
156
157impl ShortTermMemory {
158    pub fn new(capacity: usize, ttl: Duration) -> Self {
159        Self {
160            items: RwLock::new(VecDeque::with_capacity(capacity)),
161            capacity,
162            ttl,
163        }
164    }
165
166    /// Add a new memory item
167    pub fn add(&self, content: impl Into<String>, source: MemorySource) {
168        let item = MemoryItem::new(content, source);
169        self.add_item(item);
170    }
171
172    /// Add a memory item
173    pub fn add_item(&self, item: MemoryItem) {
174        let mut items = self.items.write();
175
176        // Remove expired items
177        let now = Instant::now();
178        while let Some((_, created)) = items.front() {
179            if now.duration_since(*created) > self.ttl {
180                items.pop_front();
181            } else {
182                break;
183            }
184        }
185
186        // Enforce capacity
187        while items.len() >= self.capacity {
188            items.pop_front();
189        }
190
191        items.push_back((item, now));
192    }
193
194    /// Get all current memories (not expired)
195    pub fn get_all(&self) -> Vec<MemoryItem> {
196        let items = self.items.read();
197        let now = Instant::now();
198
199        items
200            .iter()
201            .filter(|(_, created)| now.duration_since(*created) <= self.ttl)
202            .map(|(item, _)| item.clone())
203            .collect()
204    }
205
206    /// Get the N most recent memories
207    pub fn get_recent(&self, n: usize) -> Vec<MemoryItem> {
208        let items = self.items.read();
209        let now = Instant::now();
210
211        items
212            .iter()
213            .rev()
214            .filter(|(_, created)| now.duration_since(*created) <= self.ttl)
215            .take(n)
216            .map(|(item, _)| item.clone())
217            .collect()
218    }
219
220    /// Search for memories containing text
221    pub fn search(&self, query: &str) -> Vec<MemoryItem> {
222        let query_lower = query.to_lowercase();
223        let items = self.items.read();
224        let now = Instant::now();
225
226        items
227            .iter()
228            .filter(|(_, created)| now.duration_since(*created) <= self.ttl)
229            .filter(|(item, _)| item.content.to_lowercase().contains(&query_lower))
230            .map(|(item, _)| item.clone())
231            .collect()
232    }
233
234    /// Clear all memories
235    pub fn clear(&self) {
236        self.items.write().clear();
237    }
238
239    /// Get current count of memories
240    pub fn len(&self) -> usize {
241        let items = self.items.read();
242        let now = Instant::now();
243        items
244            .iter()
245            .filter(|(_, created)| now.duration_since(*created) <= self.ttl)
246            .count()
247    }
248
249    pub fn is_empty(&self) -> bool {
250        self.len() == 0
251    }
252
253    /// Get as formatted context string
254    pub fn as_context(&self, max_items: usize) -> String {
255        let items = self.get_recent(max_items);
256        items
257            .iter()
258            .map(|item| format!("[{}] {}", format_source(&item.source), item.content))
259            .collect::<Vec<_>>()
260            .join("\n")
261    }
262}
263
264fn format_source(source: &MemorySource) -> &'static str {
265    match source {
266        MemorySource::User => "User",
267        MemorySource::Agent => "Agent",
268        MemorySource::Tool => "Tool",
269        MemorySource::External => "External",
270        MemorySource::System => "System",
271    }
272}
273
274// ============================================================================
275// Long-Term Memory
276// ============================================================================
277
278/// Entry in long-term memory with embedding
279#[derive(Debug, Clone, Serialize, Deserialize)]
280pub struct LongTermEntry {
281    pub item: MemoryItem,
282    pub embedding: Vec<f32>,
283}
284
285/// Long-term memory with semantic search
286pub struct LongTermMemory {
287    entries: RwLock<Vec<LongTermEntry>>,
288    capacity: usize,
289    similarity_threshold: f32,
290}
291
292impl LongTermMemory {
293    pub fn new(capacity: usize, similarity_threshold: f32) -> Self {
294        Self {
295            entries: RwLock::new(Vec::with_capacity(capacity)),
296            capacity,
297            similarity_threshold,
298        }
299    }
300
301    /// Store a memory with its embedding
302    pub fn store(
303        &self,
304        content: impl Into<String>,
305        embedding: Vec<f32>,
306        source: MemorySource,
307    ) -> String {
308        let item = MemoryItem::new(content, source);
309        self.store_item(item, embedding)
310    }
311
312    /// Store a memory item with embedding
313    pub fn store_item(&self, item: MemoryItem, embedding: Vec<f32>) -> String {
314        let id = item.id.clone();
315        let entry = LongTermEntry { item, embedding };
316
317        let mut entries = self.entries.write();
318
319        // Enforce capacity - remove oldest by access time
320        while entries.len() >= self.capacity {
321            if let Some(idx) = entries
322                .iter()
323                .enumerate()
324                .min_by_key(|(_, e)| e.item.last_accessed)
325                .map(|(idx, _)| idx)
326            {
327                entries.remove(idx);
328            }
329        }
330
331        entries.push(entry);
332        debug!(id = %id, "Long-term memory stored");
333        id
334    }
335
336    /// Search for similar memories using embedding
337    pub fn search(&self, query_embedding: &[f32], limit: usize) -> Vec<(MemoryItem, f32)> {
338        let mut entries = self.entries.write();
339
340        let mut results: Vec<_> = entries
341            .iter_mut()
342            .map(|entry| {
343                let similarity = cosine_similarity(&entry.embedding, query_embedding);
344                (entry, similarity)
345            })
346            .filter(|(_, sim)| *sim >= self.similarity_threshold)
347            .collect();
348
349        // Sort by similarity descending
350        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
351
352        // Take top N and update access info
353        results
354            .into_iter()
355            .take(limit)
356            .map(|(entry, sim)| {
357                entry.item.touch();
358                (entry.item.clone(), sim)
359            })
360            .collect()
361    }
362
363    /// Get all entries (for export/debugging)
364    pub fn get_all(&self) -> Vec<LongTermEntry> {
365        self.entries.read().clone()
366    }
367
368    /// Remove a specific memory
369    pub fn remove(&self, id: &str) -> bool {
370        let mut entries = self.entries.write();
371        let initial_len = entries.len();
372        entries.retain(|e| e.item.id != id);
373        entries.len() < initial_len
374    }
375
376    /// Clear all memories
377    pub fn clear(&self) {
378        self.entries.write().clear();
379    }
380
381    pub fn len(&self) -> usize {
382        self.entries.read().len()
383    }
384
385    pub fn is_empty(&self) -> bool {
386        self.entries.read().is_empty()
387    }
388
389    /// Get memories by importance threshold
390    pub fn get_important(&self, min_importance: f32) -> Vec<MemoryItem> {
391        self.entries
392            .read()
393            .iter()
394            .filter(|e| e.item.importance >= min_importance)
395            .map(|e| e.item.clone())
396            .collect()
397    }
398}
399
/// Cosine of the angle between `a` and `b`, in `-1.0..=1.0`.
///
/// Returns `0.0` ("unrelated") when the vectors differ in length, are empty, or when
/// either has zero magnitude, since the angle is undefined in those cases.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (norm(a), norm(b));
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    dot / (norm_a * norm_b)
}
415
416// ============================================================================
417// Entity Memory
418// ============================================================================
419
420/// An entity with attributes
421#[derive(Debug, Clone, Serialize, Deserialize)]
422pub struct Entity {
423    /// Entity name/identifier
424    pub name: String,
425    /// Entity type (person, place, concept, etc.)
426    pub entity_type: String,
427    /// Attributes as key-value pairs
428    pub attributes: HashMap<String, String>,
429    /// When the entity was first seen
430    pub created_at: u64,
431    /// When the entity was last updated
432    pub updated_at: u64,
433    /// Confidence in the entity's existence (0.0 to 1.0)
434    pub confidence: f32,
435    /// Related entity names
436    pub relations: Vec<(String, String)>, // (relation_type, target_entity)
437}
438
439impl Entity {
440    pub fn new(name: impl Into<String>, entity_type: impl Into<String>) -> Self {
441        let now = SystemTime::now()
442            .duration_since(UNIX_EPOCH)
443            .unwrap_or_default()
444            .as_secs();
445
446        Self {
447            name: name.into(),
448            entity_type: entity_type.into(),
449            attributes: HashMap::new(),
450            created_at: now,
451            updated_at: now,
452            confidence: 1.0,
453            relations: Vec::new(),
454        }
455    }
456
457    pub fn set_attribute(&mut self, key: impl Into<String>, value: impl Into<String>) {
458        self.attributes.insert(key.into(), value.into());
459        self.touch();
460    }
461
462    pub fn get_attribute(&self, key: &str) -> Option<&String> {
463        self.attributes.get(key)
464    }
465
466    pub fn add_relation(&mut self, relation_type: impl Into<String>, target: impl Into<String>) {
467        self.relations.push((relation_type.into(), target.into()));
468        self.touch();
469    }
470
471    fn touch(&mut self) {
472        self.updated_at = SystemTime::now()
473            .duration_since(UNIX_EPOCH)
474            .unwrap_or_default()
475            .as_secs();
476    }
477}
478
479/// Entity memory for tracking knowledge about specific entities
480pub struct EntityMemory {
481    entities: RwLock<HashMap<String, Entity>>,
482    capacity: usize,
483    max_attributes: usize,
484}
485
486impl EntityMemory {
487    pub fn new(capacity: usize, max_attributes: usize) -> Self {
488        Self {
489            entities: RwLock::new(HashMap::with_capacity(capacity)),
490            capacity,
491            max_attributes,
492        }
493    }
494
495    /// Get or create an entity
496    pub fn get_or_create(&self, name: &str, entity_type: &str) -> Entity {
497        let mut entities = self.entities.write();
498
499        if let Some(entity) = entities.get(name) {
500            return entity.clone();
501        }
502
503        // Enforce capacity
504        while entities.len() >= self.capacity {
505            // Remove least recently updated entity
506            if let Some(oldest) = entities
507                .iter()
508                .min_by_key(|(_, e)| e.updated_at)
509                .map(|(k, _)| k.clone())
510            {
511                entities.remove(&oldest);
512            }
513        }
514
515        let entity = Entity::new(name, entity_type);
516        entities.insert(name.to_string(), entity.clone());
517        entity
518    }
519
520    /// Update an entity's attribute
521    pub fn update_attribute(&self, name: &str, key: &str, value: &str) -> bool {
522        let mut entities = self.entities.write();
523
524        if let Some(entity) = entities.get_mut(name) {
525            if entity.attributes.len() < self.max_attributes || entity.attributes.contains_key(key)
526            {
527                entity.set_attribute(key, value);
528                return true;
529            }
530        }
531        false
532    }
533
534    /// Get an entity by name
535    pub fn get(&self, name: &str) -> Option<Entity> {
536        self.entities.read().get(name).cloned()
537    }
538
539    /// Check if an entity exists
540    pub fn exists(&self, name: &str) -> bool {
541        self.entities.read().contains_key(name)
542    }
543
544    /// Add a relation between entities
545    pub fn add_relation(&self, source: &str, relation_type: &str, target: &str) -> bool {
546        let mut entities = self.entities.write();
547
548        if let Some(entity) = entities.get_mut(source) {
549            entity.add_relation(relation_type, target);
550            return true;
551        }
552        false
553    }
554
555    /// Get all entities of a specific type
556    pub fn get_by_type(&self, entity_type: &str) -> Vec<Entity> {
557        self.entities
558            .read()
559            .values()
560            .filter(|e| e.entity_type == entity_type)
561            .cloned()
562            .collect()
563    }
564
565    /// Search entities by attribute
566    pub fn search_by_attribute(&self, key: &str, value: &str) -> Vec<Entity> {
567        self.entities
568            .read()
569            .values()
570            .filter(|e| e.attributes.get(key).map(|v| v == value).unwrap_or(false))
571            .cloned()
572            .collect()
573    }
574
575    /// Get all entities
576    pub fn get_all(&self) -> Vec<Entity> {
577        self.entities.read().values().cloned().collect()
578    }
579
580    /// Remove an entity
581    pub fn remove(&self, name: &str) -> bool {
582        self.entities.write().remove(name).is_some()
583    }
584
585    /// Clear all entities
586    pub fn clear(&self) {
587        self.entities.write().clear();
588    }
589
590    pub fn len(&self) -> usize {
591        self.entities.read().len()
592    }
593
594    pub fn is_empty(&self) -> bool {
595        self.entities.read().is_empty()
596    }
597
598    /// Get entity as context string
599    pub fn entity_context(&self, name: &str) -> Option<String> {
600        self.get(name).map(|entity| {
601            let mut lines = vec![format!("{} ({})", entity.name, entity.entity_type)];
602
603            for (key, value) in &entity.attributes {
604                lines.push(format!("  - {}: {}", key, value));
605            }
606
607            for (relation, target) in &entity.relations {
608                lines.push(format!("  - {} -> {}", relation, target));
609            }
610
611            lines.join("\n")
612        })
613    }
614}
615
616// ============================================================================
617// Multi-Memory System
618// ============================================================================
619
620/// Unified multi-memory system combining all memory types
621pub struct MultiMemory {
622    /// Short-term memory for recent context
623    pub short_term: Arc<ShortTermMemory>,
624    /// Long-term memory for persistent knowledge
625    pub long_term: Arc<LongTermMemory>,
626    /// Entity memory for structured knowledge
627    pub entity: Arc<EntityMemory>,
628    /// Configuration
629    #[allow(dead_code)]
630    config: MemoryConfig,
631}
632
633impl MultiMemory {
634    pub fn new(config: MemoryConfig) -> Self {
635        Self {
636            short_term: Arc::new(ShortTermMemory::new(
637                config.short_term_capacity,
638                config.short_term_ttl,
639            )),
640            long_term: Arc::new(LongTermMemory::new(
641                config.long_term_capacity,
642                config.similarity_threshold,
643            )),
644            entity: Arc::new(EntityMemory::new(
645                config.entity_capacity,
646                config.max_attributes_per_entity,
647            )),
648            config,
649        }
650    }
651
652    /// Create with default configuration
653    pub fn default_config() -> Self {
654        Self::new(MemoryConfig::default())
655    }
656
657    /// Recall relevant memories for a query
658    pub fn recall(
659        &self,
660        query: &str,
661        query_embedding: Option<&[f32]>,
662        limit: usize,
663    ) -> RecallResult {
664        let mut result = RecallResult::default();
665
666        // Search short-term memory by text
667        result.short_term = self.short_term.search(query);
668
669        // Search long-term memory by embedding if available
670        if let Some(embedding) = query_embedding {
671            result.long_term = self.long_term.search(embedding, limit);
672        }
673
674        // Extract potential entity names from query (simple word extraction)
675        let words: Vec<&str> = query.split_whitespace().collect();
676        for word in words {
677            if let Some(entity) = self.entity.get(word) {
678                result.entities.push(entity);
679            }
680        }
681
682        result
683    }
684
685    /// Build context string from memories
686    pub fn build_context(
687        &self,
688        query: &str,
689        query_embedding: Option<&[f32]>,
690        max_short_term: usize,
691        max_long_term: usize,
692    ) -> String {
693        let mut sections = Vec::new();
694
695        // Recent context
696        let recent = self.short_term.get_recent(max_short_term);
697        if !recent.is_empty() {
698            let context = recent
699                .iter()
700                .map(|item| format!("- {}", item.content))
701                .collect::<Vec<_>>()
702                .join("\n");
703            sections.push(format!("## Recent Context\n{}", context));
704        }
705
706        // Relevant long-term memories
707        if let Some(embedding) = query_embedding {
708            let long_term = self.long_term.search(embedding, max_long_term);
709            if !long_term.is_empty() {
710                let context = long_term
711                    .iter()
712                    .map(|(item, score)| format!("- {} (relevance: {:.2})", item.content, score))
713                    .collect::<Vec<_>>()
714                    .join("\n");
715                sections.push(format!("## Relevant Knowledge\n{}", context));
716            }
717        }
718
719        // Related entities
720        let words: Vec<&str> = query.split_whitespace().collect();
721        let mut entity_contexts = Vec::new();
722        for word in words {
723            if let Some(ctx) = self.entity.entity_context(word) {
724                entity_contexts.push(ctx);
725            }
726        }
727        if !entity_contexts.is_empty() {
728            sections.push(format!(
729                "## Known Entities\n{}",
730                entity_contexts.join("\n\n")
731            ));
732        }
733
734        sections.join("\n\n")
735    }
736
737    /// Clear all memories
738    pub fn clear_all(&self) {
739        self.short_term.clear();
740        self.long_term.clear();
741        self.entity.clear();
742        info!("All memories cleared");
743    }
744
745    /// Get memory statistics
746    pub fn stats(&self) -> MemoryStats {
747        MemoryStats {
748            short_term_count: self.short_term.len(),
749            long_term_count: self.long_term.len(),
750            entity_count: self.entity.len(),
751        }
752    }
753}
754
755/// Result of a memory recall
756#[derive(Debug, Default)]
757pub struct RecallResult {
758    /// Memories from short-term
759    pub short_term: Vec<MemoryItem>,
760    /// Memories from long-term with similarity scores
761    pub long_term: Vec<(MemoryItem, f32)>,
762    /// Related entities
763    pub entities: Vec<Entity>,
764}
765
766impl RecallResult {
767    pub fn is_empty(&self) -> bool {
768        self.short_term.is_empty() && self.long_term.is_empty() && self.entities.is_empty()
769    }
770
771    pub fn total_count(&self) -> usize {
772        self.short_term.len() + self.long_term.len() + self.entities.len()
773    }
774}
775
776/// Memory statistics
777#[derive(Debug, Clone, Serialize, Deserialize)]
778pub struct MemoryStats {
779    pub short_term_count: usize,
780    pub long_term_count: usize,
781    pub entity_count: usize,
782}
783
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_short_term_memory_add_and_retrieve() {
        let memory = ShortTermMemory::new(10, Duration::from_secs(3600));

        memory.add("First message", MemorySource::User);
        memory.add("Second message", MemorySource::Agent);

        let all = memory.get_all();
        assert_eq!(all.len(), 2);
    }

    #[test]
    fn test_short_term_memory_capacity() {
        let memory = ShortTermMemory::new(3, Duration::from_secs(3600));

        memory.add("1", MemorySource::User);
        memory.add("2", MemorySource::User);
        memory.add("3", MemorySource::User);
        memory.add("4", MemorySource::User);

        let all = memory.get_all();
        assert_eq!(all.len(), 3);
        assert_eq!(all[0].content, "2");
    }

    #[test]
    fn test_short_term_memory_search() {
        let memory = ShortTermMemory::new(10, Duration::from_secs(3600));

        memory.add("Hello world", MemorySource::User);
        memory.add("Goodbye world", MemorySource::User);
        memory.add("Something else", MemorySource::User);

        let results = memory.search("world");
        assert_eq!(results.len(), 2);
    }

    /// `get_recent` and `as_context` list the newest item first.
    #[test]
    fn test_short_term_memory_recent_is_newest_first() {
        let memory = ShortTermMemory::new(10, Duration::from_secs(3600));

        memory.add("old", MemorySource::User);
        memory.add("new", MemorySource::Agent);

        let recent = memory.get_recent(1);
        assert_eq!(recent.len(), 1);
        assert_eq!(recent[0].content, "new");
        assert_eq!(memory.as_context(2), "[Agent] new\n[User] old");
    }

    #[test]
    fn test_long_term_memory_store_and_search() {
        let memory = LongTermMemory::new(100, 0.5);

        let embedding1 = vec![1.0, 0.0, 0.0];
        let embedding2 = vec![0.0, 1.0, 0.0];
        let embedding3 = vec![0.9, 0.1, 0.0]; // Similar to 1

        memory.store("First fact", embedding1.clone(), MemorySource::External);
        memory.store("Second fact", embedding2, MemorySource::External);
        memory.store("Third fact", embedding3, MemorySource::External);

        let results = memory.search(&embedding1, 2);
        assert_eq!(results.len(), 2);
        assert!(results[0].1 > results[1].1); // First should be more similar
    }

    #[test]
    fn test_long_term_memory_threshold() {
        let memory = LongTermMemory::new(100, 0.9); // High threshold

        let embedding1 = vec![1.0, 0.0, 0.0];
        let embedding2 = vec![0.0, 1.0, 0.0]; // Orthogonal, similarity = 0

        memory.store("First fact", embedding1.clone(), MemorySource::External);
        memory.store("Second fact", embedding2, MemorySource::External);

        let results = memory.search(&embedding1, 10);
        assert_eq!(results.len(), 1); // Only exact match passes threshold
    }

    /// A full long-term memory evicts its least recently accessed (here: oldest) entry.
    #[test]
    fn test_long_term_memory_capacity_evicts_oldest() {
        let memory = LongTermMemory::new(2, 0.0);

        memory.store("a", vec![1.0], MemorySource::External);
        memory.store("b", vec![1.0], MemorySource::External);
        memory.store("c", vec![1.0], MemorySource::External);

        let contents: Vec<String> = memory
            .get_all()
            .into_iter()
            .map(|entry| entry.item.content)
            .collect();
        assert_eq!(contents, vec!["b".to_string(), "c".to_string()]);
    }

    #[test]
    fn test_long_term_memory_remove() {
        let memory = LongTermMemory::new(10, 0.5);

        let id = memory.store("fact", vec![1.0, 0.0], MemorySource::External);
        assert!(memory.remove(&id));
        assert!(!memory.remove(&id));
        assert!(memory.is_empty());
    }

    #[test]
    fn test_entity_memory_create_and_update() {
        let memory = EntityMemory::new(100, 50);

        let entity = memory.get_or_create("Rust", "programming_language");
        assert_eq!(entity.name, "Rust");

        memory.update_attribute("Rust", "creator", "Mozilla");
        memory.update_attribute("Rust", "year", "2010");

        let updated = memory.get("Rust").unwrap();
        assert_eq!(updated.attributes.get("creator").unwrap(), "Mozilla");
        assert_eq!(updated.attributes.get("year").unwrap(), "2010");
    }

    /// New keys are rejected once the attribute cap is reached; existing keys can still change.
    #[test]
    fn test_entity_memory_attribute_limit() {
        let memory = EntityMemory::new(10, 2);
        memory.get_or_create("X", "thing");

        assert!(memory.update_attribute("X", "a", "1"));
        assert!(memory.update_attribute("X", "b", "2"));
        assert!(!memory.update_attribute("X", "c", "3"));
        assert!(memory.update_attribute("X", "a", "changed"));
        assert!(!memory.update_attribute("Missing", "a", "1"));

        let x = memory.get("X").unwrap();
        assert_eq!(x.attributes.len(), 2);
        assert_eq!(x.get_attribute("a").unwrap(), "changed");
    }

    #[test]
    fn test_entity_memory_relations() {
        let memory = EntityMemory::new(100, 50);

        memory.get_or_create("Rust", "programming_language");
        memory.get_or_create("Mozilla", "organization");

        memory.add_relation("Rust", "created_by", "Mozilla");

        let rust = memory.get("Rust").unwrap();
        assert_eq!(rust.relations.len(), 1);
        assert_eq!(
            rust.relations[0],
            ("created_by".to_string(), "Mozilla".to_string())
        );
    }

    #[test]
    fn test_entity_memory_search() {
        let memory = EntityMemory::new(100, 50);

        memory.get_or_create("Rust", "programming_language");
        memory.get_or_create("Python", "programming_language");
        memory.get_or_create("Mozilla", "organization");

        memory.update_attribute("Rust", "type", "compiled");
        memory.update_attribute("Python", "type", "interpreted");

        let langs = memory.get_by_type("programming_language");
        assert_eq!(langs.len(), 2);

        let compiled = memory.search_by_attribute("type", "compiled");
        assert_eq!(compiled.len(), 1);
        assert_eq!(compiled[0].name, "Rust");
    }

    #[test]
    fn test_entity_memory_context() {
        let memory = EntityMemory::new(10, 10);

        memory.get_or_create("Rust", "programming_language");
        memory.update_attribute("Rust", "creator", "Mozilla");
        memory.add_relation("Rust", "influenced_by", "OCaml");

        assert_eq!(
            memory.entity_context("Rust").unwrap(),
            "Rust (programming_language)\n  - creator: Mozilla\n  - influenced_by -> OCaml"
        );
        assert!(memory.entity_context("Unknown").is_none());
    }

    #[test]
    fn test_multi_memory_recall() {
        let memory = MultiMemory::default_config();

        // Add short-term memories - query must be substring of content
        memory
            .short_term
            .add("User asked about Rust programming", MemorySource::User);

        // Add long-term memories
        let embedding = vec![1.0, 0.0, 0.0];
        memory.long_term.store(
            "Rust is a systems programming language",
            embedding.clone(),
            MemorySource::External,
        );

        // Add entity
        memory.entity.get_or_create("Rust", "programming_language");
        memory.entity.update_attribute("Rust", "creator", "Mozilla");

        // Recall - use "Rust" as query since it's the entity name
        let result = memory.recall("Rust", Some(&embedding), 10);

        assert_eq!(result.short_term.len(), 1); // Contains "Rust"
        assert_eq!(result.long_term.len(), 1); // Matches embedding
        assert_eq!(result.entities.len(), 1); // "Rust" entity found by exact name
    }

    #[test]
    fn test_multi_memory_build_context() {
        let memory = MultiMemory::default_config();

        memory
            .short_term
            .add("Previous message", MemorySource::User);

        let embedding = vec![1.0, 0.0, 0.0];
        memory
            .long_term
            .store("Relevant fact", embedding.clone(), MemorySource::External);

        memory.entity.get_or_create("Test", "concept");
        memory
            .entity
            .update_attribute("Test", "description", "A test entity");

        let context = memory.build_context("Test query", Some(&embedding), 5, 5);

        assert!(context.contains("Previous message"));
        assert!(context.contains("Relevant fact"));
        assert!(context.contains("Test"));
    }

    #[test]
    fn test_memory_stats() {
        let memory = MultiMemory::default_config();

        memory.short_term.add("Message 1", MemorySource::User);
        memory.short_term.add("Message 2", MemorySource::User);

        memory
            .long_term
            .store("Fact 1", vec![1.0], MemorySource::External);

        memory.entity.get_or_create("Entity1", "type");
        memory.entity.get_or_create("Entity2", "type");
        memory.entity.get_or_create("Entity3", "type");

        let stats = memory.stats();
        assert_eq!(stats.short_term_count, 2);
        assert_eq!(stats.long_term_count, 1);
        assert_eq!(stats.entity_count, 3);
    }

    #[test]
    fn test_recall_result_counts() {
        let mut result = RecallResult::default();
        assert!(result.is_empty());
        assert_eq!(result.total_count(), 0);

        result
            .short_term
            .push(MemoryItem::new("a", MemorySource::User));
        result.entities.push(Entity::new("X", "thing"));

        assert!(!result.is_empty());
        assert_eq!(result.total_count(), 2);
    }

    #[test]
    fn test_cosine_similarity() {
        // Identical vectors
        assert!((cosine_similarity(&[1.0, 0.0], &[1.0, 0.0]) - 1.0).abs() < 0.001);

        // Orthogonal vectors
        assert!((cosine_similarity(&[1.0, 0.0], &[0.0, 1.0])).abs() < 0.001);

        // Opposite vectors
        assert!((cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]) + 1.0).abs() < 0.001);
    }

    /// Empty, mismatched-length and zero-magnitude inputs are treated as unrelated.
    #[test]
    fn test_cosine_similarity_degenerate_inputs() {
        assert_eq!(cosine_similarity(&[], &[]), 0.0);
        assert_eq!(cosine_similarity(&[1.0, 2.0], &[1.0]), 0.0);
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 1.0]), 0.0);
    }

    #[test]
    fn test_memory_item_builder() {
        let item = MemoryItem::new("Test content", MemorySource::User)
            .with_importance(0.8)
            .with_metadata("key", "value");

        assert_eq!(item.content, "Test content");
        assert_eq!(item.importance, 0.8);
        assert_eq!(item.metadata.get("key").unwrap(), "value");
    }

    #[test]
    fn test_memory_item_importance_is_clamped() {
        let high = MemoryItem::new("x", MemorySource::System).with_importance(1.5);
        let low = MemoryItem::new("y", MemorySource::System).with_importance(-0.2);

        assert_eq!(high.importance, 1.0);
        assert_eq!(low.importance, 0.0);
    }
}