ceylon_next/memory/advanced/
working.rs1use super::{EnhancedMemoryEntry, ImportanceLevel, MemoryConfig, MemoryType};
14use std::collections::VecDeque;
15use std::sync::Arc;
16use tokio::sync::RwLock;
17
18pub struct WorkingMemory {
20 memories: Arc<RwLock<VecDeque<EnhancedMemoryEntry>>>,
22 config: MemoryConfig,
24 token_count: Arc<RwLock<usize>>,
26}
27
28impl WorkingMemory {
29 pub fn new(config: MemoryConfig) -> Self {
31 Self {
32 memories: Arc::new(RwLock::new(VecDeque::new())),
33 config,
34 token_count: Arc::new(RwLock::new(0)),
35 }
36 }
37
38 pub async fn add(&self, mut entry: EnhancedMemoryEntry) -> Result<(), String> {
40 entry.memory_type = MemoryType::Working;
41
42 let mut memories = self.memories.write().await;
43 let mut token_count = self.token_count.write().await;
44
45 let entry_tokens = Self::estimate_tokens(&entry);
47 *token_count += entry_tokens;
48
49 memories.push_front(entry);
51
52 while memories.len() > self.config.working_memory_limit {
54 if let Some(evicted) = memories.pop_back() {
55 let evicted_tokens = Self::estimate_tokens(&evicted);
56 *token_count = token_count.saturating_sub(evicted_tokens);
57 }
58 }
59
60 Ok(())
61 }
62
63 pub async fn get_all(&self) -> Vec<EnhancedMemoryEntry> {
65 let memories = self.memories.read().await;
66 memories.iter().cloned().collect()
67 }
68
69 pub async fn get_recent_within_limit(&self, max_tokens: usize) -> Vec<EnhancedMemoryEntry> {
71 let memories = self.memories.read().await;
72 let mut result = Vec::new();
73 let mut current_tokens = 0;
74
75 for memory in memories.iter() {
76 let memory_tokens = Self::estimate_tokens(memory);
77 if current_tokens + memory_tokens > max_tokens {
78 break;
79 }
80 result.push(memory.clone());
81 current_tokens += memory_tokens;
82 }
83
84 result
85 }
86
87 pub async fn get_by_importance(&self, min_importance: ImportanceLevel) -> Vec<EnhancedMemoryEntry> {
89 let memories = self.memories.read().await;
90 memories
91 .iter()
92 .filter(|m| m.importance >= min_importance)
93 .cloned()
94 .collect()
95 }
96
97 pub async fn mark_accessed(&self, memory_id: &str) -> Result<(), String> {
99 let mut memories = self.memories.write().await;
100
101 if let Some(pos) = memories.iter().position(|m| m.entry.id == memory_id) {
102 if let Some(mut memory) = memories.remove(pos) {
103 memory.mark_accessed();
104 memories.push_front(memory);
105 }
106 }
107
108 Ok(())
109 }
110
111 pub async fn clear(&self) {
113 let mut memories = self.memories.write().await;
114 let mut token_count = self.token_count.write().await;
115 memories.clear();
116 *token_count = 0;
117 }
118
119 pub async fn token_count(&self) -> usize {
121 *self.token_count.read().await
122 }
123
124 pub async fn memory_count(&self) -> usize {
126 self.memories.read().await.len()
127 }
128
129 pub async fn create_context(&self, max_tokens: Option<usize>) -> String {
131 let memories = if let Some(limit) = max_tokens {
132 self.get_recent_within_limit(limit).await
133 } else {
134 self.get_all().await
135 };
136
137 if memories.is_empty() {
138 return String::new();
139 }
140
141 let mut context = String::from("RECENT CONTEXT:\n");
142
143 for (i, memory) in memories.iter().enumerate() {
144 if let Some(summary) = &memory.summary {
146 context.push_str(&format!("{}. {}\n", i + 1, summary));
147 } else {
148 context.push_str(&format!("{}. Conversation from {}:\n",
149 i + 1,
150 Self::format_timestamp(memory.entry.created_at)
151 ));
152
153 for msg in &memory.entry.messages {
154 if msg.role == "user" {
155 context.push_str(&format!(" User: {}\n", msg.content));
156 } else if msg.role == "assistant" {
157 context.push_str(&format!(" Assistant: {}\n", msg.content));
158 }
159 }
160 }
161
162 if !memory.key_points.is_empty() {
164 context.push_str(" Key points:\n");
165 for point in &memory.key_points {
166 context.push_str(&format!(" - {}\n", point));
167 }
168 }
169
170 context.push_str("\n");
171 }
172
173 context
174 }
175
176 fn estimate_tokens(entry: &EnhancedMemoryEntry) -> usize {
178 let mut total_chars = 0;
180
181 for msg in &entry.entry.messages {
183 total_chars += msg.content.len();
184 }
185
186 if let Some(summary) = &entry.summary {
188 total_chars += summary.len();
189 }
190
191 for point in &entry.key_points {
193 total_chars += point.len();
194 }
195
196 (total_chars / 4).max(1)
197 }
198
199 fn format_timestamp(timestamp: u64) -> String {
201 use std::time::{SystemTime, UNIX_EPOCH, Duration};
202
203 let dt = UNIX_EPOCH + Duration::from_secs(timestamp);
204 let elapsed = SystemTime::now().duration_since(dt).unwrap_or_default();
205
206 let seconds = elapsed.as_secs();
207 if seconds < 60 {
208 format!("{} seconds ago", seconds)
209 } else if seconds < 3600 {
210 format!("{} minutes ago", seconds / 60)
211 } else if seconds < 86400 {
212 format!("{} hours ago", seconds / 3600)
213 } else {
214 format!("{} days ago", seconds / 86400)
215 }
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::MemoryEntry;

    /// Builds a `MemoryEntry` for "agent-1" with `task` as the second
    /// constructor argument and an empty third argument.
    fn blank_entry(task: &str) -> MemoryEntry {
        MemoryEntry::new("agent-1".to_string(), task.to_string(), vec![])
    }

    /// Adding five entries to a store limited to three keeps exactly three.
    #[tokio::test]
    async fn test_working_memory_basic() {
        let wm = WorkingMemory::new(MemoryConfig {
            working_memory_limit: 3,
            ..Default::default()
        });

        for i in 0..5 {
            let source = blank_entry(&format!("task-{}", i));
            wm.add(EnhancedMemoryEntry::new(source, MemoryType::Working))
                .await
                .unwrap();
        }

        assert_eq!(wm.memory_count().await, 3);
    }

    /// Marking an older entry as accessed moves it to the front.
    #[tokio::test]
    async fn test_working_memory_access() {
        let wm = WorkingMemory::new(MemoryConfig::default());

        let first = blank_entry("task-1");
        let first_id = first.id.clone();
        wm.add(EnhancedMemoryEntry::new(first, MemoryType::Working))
            .await
            .unwrap();

        wm.add(EnhancedMemoryEntry::new(blank_entry("task-2"), MemoryType::Working))
            .await
            .unwrap();

        wm.mark_accessed(&first_id).await.unwrap();

        let stored = wm.get_all().await;
        assert_eq!(stored[0].entry.id, first_id);
    }

    /// Filtering at `High` returns the `Critical` entry and drops the `Low` one.
    #[tokio::test]
    async fn test_working_memory_importance() {
        let wm = WorkingMemory::new(MemoryConfig::default());

        let mut critical = EnhancedMemoryEntry::new(blank_entry("task-1"), MemoryType::Working);
        critical.importance = ImportanceLevel::Critical;
        wm.add(critical).await.unwrap();

        let mut low = EnhancedMemoryEntry::new(blank_entry("task-2"), MemoryType::Working);
        low.importance = ImportanceLevel::Low;
        wm.add(low).await.unwrap();

        let important = wm.get_by_importance(ImportanceLevel::High).await;
        assert_eq!(important.len(), 1);
        assert_eq!(important[0].importance, ImportanceLevel::Critical);
    }
}
287}