// ceylon_next/memory/advanced/working.rs

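//! Working Memory - Recent Context Management
//!
//! Working memory holds the agent's most recent context. It is designed for
//! fast access and limited capacity, similar to human short-term memory.
//!
//! # Features
//!
//! - Fixed capacity with LRU (least recently used) eviction; marking an entry
//!   as accessed moves it back to the front
//! - Token-aware context management: an approximate token count is tracked and
//!   context can be rendered within a token budget
//! - Summaries: when an entry has a summary, it replaces the full conversation
//!   in the rendered context (automatic summarization of evicted memories is
//!   not implemented yet; evicted entries are dropped)
//! - Importance-based filtering of stored entries (priority-based retention is
//!   not implemented yet; eviction ignores importance)
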
use std::collections::VecDeque;
use std::fmt::Write as _;
use std::sync::Arc;

use tokio::sync::RwLock;

use super::{EnhancedMemoryEntry, ImportanceLevel, MemoryConfig, MemoryType};

18/// Working memory for managing recent context
19pub struct WorkingMemory {
20    /// Recent memories (ordered by recency)
21    memories: Arc<RwLock<VecDeque<EnhancedMemoryEntry>>>,
22    /// Configuration
23    config: MemoryConfig,
24    /// Current token count estimate
25    token_count: Arc<RwLock<usize>>,
26}
27
28impl WorkingMemory {
29    /// Create a new working memory instance
30    pub fn new(config: MemoryConfig) -> Self {
31        Self {
32            memories: Arc::new(RwLock::new(VecDeque::new())),
33            config,
34            token_count: Arc::new(RwLock::new(0)),
35        }
36    }
37
38    /// Add a memory to working memory
39    pub async fn add(&self, mut entry: EnhancedMemoryEntry) -> Result<(), String> {
40        entry.memory_type = MemoryType::Working;
41
42        let mut memories = self.memories.write().await;
43        let mut token_count = self.token_count.write().await;
44
45        // Estimate tokens for this entry
46        let entry_tokens = Self::estimate_tokens(&entry);
47        *token_count += entry_tokens;
48
49        // Add to front (most recent)
50        memories.push_front(entry);
51
52        // Evict if over capacity
53        while memories.len() > self.config.working_memory_limit {
54            if let Some(evicted) = memories.pop_back() {
55                let evicted_tokens = Self::estimate_tokens(&evicted);
56                *token_count = token_count.saturating_sub(evicted_tokens);
57            }
58        }
59
60        Ok(())
61    }
62
63    /// Get all working memories
64    pub async fn get_all(&self) -> Vec<EnhancedMemoryEntry> {
65        let memories = self.memories.read().await;
66        memories.iter().cloned().collect()
67    }
68
69    /// Get recent memories up to a token limit
70    pub async fn get_recent_within_limit(&self, max_tokens: usize) -> Vec<EnhancedMemoryEntry> {
71        let memories = self.memories.read().await;
72        let mut result = Vec::new();
73        let mut current_tokens = 0;
74
75        for memory in memories.iter() {
76            let memory_tokens = Self::estimate_tokens(memory);
77            if current_tokens + memory_tokens > max_tokens {
78                break;
79            }
80            result.push(memory.clone());
81            current_tokens += memory_tokens;
82        }
83
84        result
85    }
86
87    /// Get memories by importance level
88    pub async fn get_by_importance(&self, min_importance: ImportanceLevel) -> Vec<EnhancedMemoryEntry> {
89        let memories = self.memories.read().await;
90        memories
91            .iter()
92            .filter(|m| m.importance >= min_importance)
93            .cloned()
94            .collect()
95    }
96
97    /// Mark a memory as accessed (moves it to front)
98    pub async fn mark_accessed(&self, memory_id: &str) -> Result<(), String> {
99        let mut memories = self.memories.write().await;
100
101        if let Some(pos) = memories.iter().position(|m| m.entry.id == memory_id) {
102            if let Some(mut memory) = memories.remove(pos) {
103                memory.mark_accessed();
104                memories.push_front(memory);
105            }
106        }
107
108        Ok(())
109    }
110
111    /// Clear all working memory
112    pub async fn clear(&self) {
113        let mut memories = self.memories.write().await;
114        let mut token_count = self.token_count.write().await;
115        memories.clear();
116        *token_count = 0;
117    }
118
119    /// Get current token count
120    pub async fn token_count(&self) -> usize {
121        *self.token_count.read().await
122    }
123
124    /// Get current memory count
125    pub async fn memory_count(&self) -> usize {
126        self.memories.read().await.len()
127    }
128
129    /// Create context message from working memory
130    pub async fn create_context(&self, max_tokens: Option<usize>) -> String {
131        let memories = if let Some(limit) = max_tokens {
132            self.get_recent_within_limit(limit).await
133        } else {
134            self.get_all().await
135        };
136
137        if memories.is_empty() {
138            return String::new();
139        }
140
141        let mut context = String::from("RECENT CONTEXT:\n");
142
143        for (i, memory) in memories.iter().enumerate() {
144            // Use summary if available, otherwise use full conversation
145            if let Some(summary) = &memory.summary {
146                context.push_str(&format!("{}. {}\n", i + 1, summary));
147            } else {
148                context.push_str(&format!("{}. Conversation from {}:\n",
149                    i + 1,
150                    Self::format_timestamp(memory.entry.created_at)
151                ));
152
153                for msg in &memory.entry.messages {
154                    if msg.role == "user" {
155                        context.push_str(&format!("  User: {}\n", msg.content));
156                    } else if msg.role == "assistant" {
157                        context.push_str(&format!("  Assistant: {}\n", msg.content));
158                    }
159                }
160            }
161
162            // Add key points if available
163            if !memory.key_points.is_empty() {
164                context.push_str("  Key points:\n");
165                for point in &memory.key_points {
166                    context.push_str(&format!("    - {}\n", point));
167                }
168            }
169
170            context.push_str("\n");
171        }
172
173        context
174    }
175
176    /// Estimate token count for a memory entry
177    fn estimate_tokens(entry: &EnhancedMemoryEntry) -> usize {
178        // Rough estimation: ~4 characters per token
179        let mut total_chars = 0;
180
181        // Count message content
182        for msg in &entry.entry.messages {
183            total_chars += msg.content.len();
184        }
185
186        // Count summary if present
187        if let Some(summary) = &entry.summary {
188            total_chars += summary.len();
189        }
190
191        // Count key points
192        for point in &entry.key_points {
193            total_chars += point.len();
194        }
195
196        (total_chars / 4).max(1)
197    }
198
199    /// Format timestamp for display
200    fn format_timestamp(timestamp: u64) -> String {
201        use std::time::{SystemTime, UNIX_EPOCH, Duration};
202
203        let dt = UNIX_EPOCH + Duration::from_secs(timestamp);
204        let elapsed = SystemTime::now().duration_since(dt).unwrap_or_default();
205
206        let seconds = elapsed.as_secs();
207        if seconds < 60 {
208            format!("{} seconds ago", seconds)
209        } else if seconds < 3600 {
210            format!("{} minutes ago", seconds / 60)
211        } else if seconds < 86400 {
212            format!("{} hours ago", seconds / 3600)
213        } else {
214            format!("{} days ago", seconds / 86400)
215        }
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::MemoryEntry;

    /// Build a working-memory entry with no messages for the given task.
    fn make_entry(task: &str) -> EnhancedMemoryEntry {
        let entry = MemoryEntry::new("agent-1".to_string(), task.to_string(), vec![]);
        EnhancedMemoryEntry::new(entry, MemoryType::Working)
    }

    /// Build a working memory that keeps at most `limit` entries.
    fn with_limit(limit: usize) -> WorkingMemory {
        WorkingMemory::new(MemoryConfig {
            working_memory_limit: limit,
            ..Default::default()
        })
    }

    #[tokio::test]
    async fn test_working_memory_basic() {
        let wm = with_limit(3);

        // Add memories, remembering their ids in insertion order.
        let mut added = Vec::new();
        for i in 0..5 {
            let enhanced = make_entry(&format!("task-{}", i));
            added.push(enhanced.entry.id.clone());
            wm.add(enhanced).await.unwrap();
        }

        // Should only keep the 3 most recent, newest first.
        assert_eq!(wm.memory_count().await, 3);
        let kept: Vec<_> = wm.get_all().await.iter().map(|m| m.entry.id.clone()).collect();
        let expected: Vec<_> = added.iter().rev().take(3).cloned().collect();
        assert_eq!(kept, expected);
    }

    #[tokio::test]
    async fn test_working_memory_access() {
        let config = MemoryConfig::default();
        let wm = WorkingMemory::new(config);

        let entry1 = MemoryEntry::new("agent-1".to_string(), "task-1".to_string(), vec![]);
        let id1 = entry1.id.clone();
        let enhanced1 = EnhancedMemoryEntry::new(entry1, MemoryType::Working);
        wm.add(enhanced1).await.unwrap();

        let entry2 = MemoryEntry::new("agent-1".to_string(), "task-2".to_string(), vec![]);
        let enhanced2 = EnhancedMemoryEntry::new(entry2, MemoryType::Working);
        wm.add(enhanced2).await.unwrap();

        // Access first entry (should move it to front)
        wm.mark_accessed(&id1).await.unwrap();

        let memories = wm.get_all().await;
        assert_eq!(memories[0].entry.id, id1);
    }

    #[tokio::test]
    async fn test_mark_accessed_protects_from_eviction() {
        let wm = with_limit(2);

        let a = make_entry("task-a");
        let id_a = a.entry.id.clone();
        wm.add(a).await.unwrap();
        wm.add(make_entry("task-b")).await.unwrap();

        // Touching A makes B the least recently used, so B is evicted next.
        wm.mark_accessed(&id_a).await.unwrap();

        let c = make_entry("task-c");
        let id_c = c.entry.id.clone();
        wm.add(c).await.unwrap();

        let kept: Vec<_> = wm.get_all().await.iter().map(|m| m.entry.id.clone()).collect();
        assert_eq!(kept, vec![id_c, id_a]);
    }

    #[tokio::test]
    async fn test_mark_accessed_unknown_id_is_noop() {
        let wm = with_limit(3);
        wm.add(make_entry("task-1")).await.unwrap();

        assert!(wm.mark_accessed("no-such-id").await.is_ok());
        assert_eq!(wm.memory_count().await, 1);
    }

    #[tokio::test]
    async fn test_clear_resets_counts() {
        let wm = with_limit(3);
        wm.add(make_entry("task-1")).await.unwrap();

        assert_eq!(wm.memory_count().await, 1);
        // Every entry is estimated at one token or more.
        assert!(wm.token_count().await >= 1);

        wm.clear().await;
        assert_eq!(wm.memory_count().await, 0);
        assert_eq!(wm.token_count().await, 0);
    }

    #[tokio::test]
    async fn test_token_budget_and_context() {
        let wm = with_limit(3);
        assert_eq!(wm.create_context(None).await, "");

        wm.add(make_entry("task-1")).await.unwrap();

        // Every entry costs at least one token, so a zero budget selects nothing.
        assert!(wm.get_recent_within_limit(0).await.is_empty());
        assert_eq!(wm.create_context(Some(0)).await, "");

        assert_eq!(wm.get_recent_within_limit(usize::MAX).await.len(), 1);
        assert!(wm.create_context(None).await.starts_with("RECENT CONTEXT:\n"));
    }

    #[tokio::test]
    async fn test_working_memory_importance() {
        let config = MemoryConfig::default();
        let wm = WorkingMemory::new(config);

        let entry1 = MemoryEntry::new("agent-1".to_string(), "task-1".to_string(), vec![]);
        let mut enhanced1 = EnhancedMemoryEntry::new(entry1, MemoryType::Working);
        enhanced1.importance = ImportanceLevel::Critical;
        wm.add(enhanced1).await.unwrap();

        let entry2 = MemoryEntry::new("agent-1".to_string(), "task-2".to_string(), vec![]);
        let mut enhanced2 = EnhancedMemoryEntry::new(entry2, MemoryType::Working);
        enhanced2.importance = ImportanceLevel::Low;
        wm.add(enhanced2).await.unwrap();

        let important = wm.get_by_importance(ImportanceLevel::High).await;
        assert_eq!(important.len(), 1);
        assert_eq!(important[0].importance, ImportanceLevel::Critical);
    }
}
287}