autoagents_core/agent/memory/
sliding_window.rs1use async_trait::async_trait;
6use autoagents_llm::{chat::ChatMessage, error::LLMError};
7use std::collections::VecDeque;
8
9use super::{MemoryProvider, MemoryType};
10
/// Policy applied by `SlidingWindowMemory` when a message arrives while the
/// window is already full.
///
/// The enum carries no data, so it is `Copy` and can be compared and hashed
/// directly (e.g. to branch on or log the configured strategy).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TrimStrategy {
    /// Evict the oldest stored message to make room for the new one.
    Drop,
    /// Keep every message and flag the memory as needing a summary, leaving
    /// the compaction to the caller.
    Summarize,
}
19
20#[derive(Debug, Clone)]
29pub struct SlidingWindowMemory {
30 messages: VecDeque<ChatMessage>,
31 window_size: usize,
32 trim_strategy: TrimStrategy,
33 needs_summary: bool,
34}
35
36impl SlidingWindowMemory {
37 pub fn new(window_size: usize) -> Self {
48 Self::with_strategy(window_size, TrimStrategy::Drop)
49 }
50
51 pub fn with_strategy(window_size: usize, strategy: TrimStrategy) -> Self {
58 if window_size == 0 {
59 panic!("Window size must be greater than 0");
60 }
61
62 Self {
63 messages: VecDeque::with_capacity(window_size),
64 window_size,
65 trim_strategy: strategy,
66 needs_summary: false,
67 }
68 }
69
70 pub fn window_size(&self) -> usize {
76 self.window_size
77 }
78
79 pub fn messages(&self) -> Vec<ChatMessage> {
85 Vec::from(self.messages.clone())
86 }
87
88 pub fn recent_messages(&self, limit: usize) -> Vec<ChatMessage> {
98 let len = self.messages.len();
99 let start = len.saturating_sub(limit);
100 self.messages.range(start..).cloned().collect()
101 }
102
103 pub fn needs_summary(&self) -> bool {
105 self.needs_summary
106 }
107
108 pub fn mark_for_summary(&mut self) {
110 self.needs_summary = true;
111 }
112
113 pub fn replace_with_summary(&mut self, summary: String) {
119 self.messages.clear();
120 self.messages
121 .push_back(ChatMessage::assistant().content(summary).build());
122 self.needs_summary = false;
123 }
124}
125
126#[async_trait]
127impl MemoryProvider for SlidingWindowMemory {
128 async fn remember(&mut self, message: &ChatMessage) -> Result<(), LLMError> {
129 if self.messages.len() >= self.window_size {
130 match self.trim_strategy {
131 TrimStrategy::Drop => {
132 self.messages.pop_front();
133 }
134 TrimStrategy::Summarize => {
135 self.mark_for_summary();
136 }
137 }
138 }
139 self.messages.push_back(message.clone());
140 Ok(())
141 }
142
143 async fn recall(
144 &self,
145 _query: &str,
146 limit: Option<usize>,
147 ) -> Result<Vec<ChatMessage>, LLMError> {
148 let limit = limit.unwrap_or(self.messages.len());
149 Ok(self.recent_messages(limit))
150 }
151
152 async fn clear(&mut self) -> Result<(), LLMError> {
153 self.messages.clear();
154 Ok(())
155 }
156
157 fn memory_type(&self) -> MemoryType {
158 MemoryType::SlidingWindow
159 }
160
161 fn size(&self) -> usize {
162 self.messages.len()
163 }
164
165 fn needs_summary(&self) -> bool {
166 self.needs_summary
167 }
168
169 fn mark_for_summary(&mut self) {
170 self.needs_summary = true;
171 }
172
173 fn replace_with_summary(&mut self, summary: String) {
174 self.messages.clear();
175 self.messages
176 .push_back(ChatMessage::assistant().content(summary).build());
177 self.needs_summary = false;
178 }
179
180 fn clone_box(&self) -> Box<dyn MemoryProvider> {
181 Box::new(self.clone())
182 }
183
184 fn preload(&mut self, data: Vec<ChatMessage>) -> bool {
185 self.messages.clear();
186 for msg in data {
187 self.messages.push_back(msg);
188 }
189 true
190 }
191
192 fn export(&self) -> Vec<ChatMessage> {
193 Vec::from(self.messages.clone())
194 }
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use autoagents_llm::chat::{ChatMessage, ChatRole, MessageType};

    /// Builds a plain-text message authored by the user.
    fn user_message(content: impl Into<String>) -> ChatMessage {
        ChatMessage {
            role: ChatRole::User,
            message_type: MessageType::Text,
            content: content.into(),
        }
    }

    /// Builds the user messages "Message 1" through "Message {count}".
    fn numbered_messages(count: usize) -> Vec<ChatMessage> {
        (1..=count)
            .map(|i| user_message(format!("Message {i}")))
            .collect()
    }

    /// Stores each message, in order, through `MemoryProvider::remember`.
    async fn remember_all(memory: &mut SlidingWindowMemory, messages: &[ChatMessage]) {
        for message in messages {
            memory.remember(message).await.unwrap();
        }
    }

    #[test]
    fn test_new_sliding_window_memory() {
        let memory = SlidingWindowMemory::new(5);
        assert_eq!(memory.window_size(), 5);
        assert_eq!(memory.size(), 0);
        assert!(memory.is_empty());
        assert_eq!(memory.memory_type(), MemoryType::SlidingWindow);
    }

    #[test]
    fn test_sliding_window_memory_with_strategy() {
        let memory = SlidingWindowMemory::with_strategy(3, TrimStrategy::Summarize);
        assert_eq!(memory.window_size(), 3);
        assert_eq!(memory.size(), 0);
        assert!(memory.is_empty());
    }

    #[test]
    #[should_panic(expected = "Window size must be greater than 0")]
    fn test_new_sliding_window_memory_zero_size() {
        let _ = SlidingWindowMemory::new(0);
    }

    #[tokio::test]
    async fn test_remember_single_message() {
        let mut memory = SlidingWindowMemory::new(3);
        memory.remember(&user_message("Hello")).await.unwrap();

        assert_eq!(memory.size(), 1);
        assert!(!memory.is_empty());

        let stored = memory.messages();
        assert_eq!(stored.len(), 1);
        assert_eq!(stored[0].content, "Hello");
    }

    #[tokio::test]
    async fn test_remember_multiple_messages() {
        let mut memory = SlidingWindowMemory::new(3);
        remember_all(&mut memory, &numbered_messages(3)).await;

        assert_eq!(memory.size(), 3);
        let stored = memory.messages();
        assert_eq!(stored.len(), 3);
        assert_eq!(stored[0].content, "Message 1");
        assert_eq!(stored[2].content, "Message 3");
    }

    #[tokio::test]
    async fn test_sliding_window_overflow_drop_strategy() {
        let mut memory = SlidingWindowMemory::with_strategy(2, TrimStrategy::Drop);
        remember_all(&mut memory, &numbered_messages(3)).await;

        // The oldest message was evicted to make room for the third one.
        assert_eq!(memory.size(), 2);
        let stored = memory.messages();
        assert_eq!(stored[0].content, "Message 2");
        assert_eq!(stored[1].content, "Message 3");
    }

    #[tokio::test]
    async fn test_sliding_window_overflow_summarize_strategy() {
        let mut memory = SlidingWindowMemory::with_strategy(2, TrimStrategy::Summarize);
        let conversation = [
            user_message("First message"),
            user_message("Second message"),
            user_message("Third message"),
        ];
        remember_all(&mut memory, &conversation).await;

        // Nothing is dropped; the overflow only raises the summary flag.
        assert!(memory.needs_summary());
        assert_eq!(memory.size(), 3);
    }

    #[tokio::test]
    async fn test_recall_all_messages() {
        let mut memory = SlidingWindowMemory::new(3);
        remember_all(&mut memory, &numbered_messages(3)).await;

        let recalled = memory.recall("", None).await.unwrap();
        assert_eq!(recalled.len(), 3);
        assert_eq!(recalled[0].content, "Message 1");
        assert_eq!(recalled[2].content, "Message 3");
    }

    #[tokio::test]
    async fn test_recall_with_limit() {
        let mut memory = SlidingWindowMemory::new(5);
        remember_all(&mut memory, &numbered_messages(5)).await;

        let recalled = memory.recall("", Some(2)).await.unwrap();
        assert_eq!(recalled.len(), 2);
        assert_eq!(recalled[0].content, "Message 4");
        assert_eq!(recalled[1].content, "Message 5");
    }

    #[tokio::test]
    async fn test_clear_memory() {
        let mut memory = SlidingWindowMemory::new(3);
        memory.remember(&user_message("Test message")).await.unwrap();
        assert_eq!(memory.size(), 1);

        memory.clear().await.unwrap();
        assert_eq!(memory.size(), 0);
        assert!(memory.is_empty());
    }

    #[test]
    fn test_recent_messages() {
        let mut memory = SlidingWindowMemory::new(5);
        memory.messages.extend(numbered_messages(5));

        let recent = memory.recent_messages(3);
        assert_eq!(recent.len(), 3);
        assert_eq!(recent[0].content, "Message 3");
        assert_eq!(recent[2].content, "Message 5");
    }

    #[test]
    fn test_recent_messages_limit_exceeds_size() {
        let mut memory = SlidingWindowMemory::new(5);
        memory.messages.extend(numbered_messages(2));

        let recent = memory.recent_messages(10);
        assert_eq!(recent.len(), 2);
        assert_eq!(recent[0].content, "Message 1");
        assert_eq!(recent[1].content, "Message 2");
    }

    #[test]
    fn test_mark_for_summary() {
        let mut memory = SlidingWindowMemory::new(3);
        assert!(!memory.needs_summary());

        memory.mark_for_summary();
        assert!(memory.needs_summary());
    }

    #[test]
    fn test_replace_with_summary() {
        let mut memory = SlidingWindowMemory::new(3);
        memory.messages.extend(numbered_messages(3));

        memory.mark_for_summary();
        assert!(memory.needs_summary());
        assert_eq!(memory.size(), 3);

        memory.replace_with_summary("This is a summary".to_string());

        assert!(!memory.needs_summary());
        assert_eq!(memory.size(), 1);
        let stored = memory.messages();
        assert_eq!(stored[0].content, "This is a summary");
        assert_eq!(stored[0].role, ChatRole::Assistant);
    }

    #[test]
    fn test_memory_provider_trait_methods() {
        let memory = SlidingWindowMemory::new(3);

        assert_eq!(memory.memory_type(), MemoryType::SlidingWindow);
        assert_eq!(memory.size(), 0);
        assert!(memory.is_empty());
        assert!(!memory.needs_summary());
        assert!(memory.get_event_receiver().is_none());
    }
}