autoagents_core/agent/memory/
sliding_window.rs

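//! Simple sliding window memory implementation.
//!
//! This module provides a bounded conversation memory that keeps a window of the
//! most recent messages. When the window is full, the configured
//! [`TrimStrategy`] either drops the oldest message (FIFO, the default) or keeps
//! every message and flags the memory for summarization.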
5use async_trait::async_trait;
6use autoagents_llm::{chat::ChatMessage, error::LLMError};
7use std::collections::VecDeque;
8
9use super::{MemoryProvider, MemoryType};
10
/// Strategy for handling memory when the window size limit is reached.
///
/// The default is [`TrimStrategy::Drop`], matching `SlidingWindowMemory::new`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TrimStrategy {
    /// Drop the oldest message before storing a new one (FIFO behavior).
    #[default]
    Drop,
    /// Keep every message and flag the memory as needing summarization.
    ///
    /// The memory may temporarily exceed its window size; the caller is
    /// expected to check the flag and replace the contents with a summary.
    Summarize,
}
19
20/// Simple sliding window memory that keeps the N most recent messages.
21///
22/// This implementation uses a FIFO strategy where old messages are automatically
23/// removed when the window size limit is reached. It's suitable for:
24/// - Simple conversation contexts
25/// - Memory-constrained environments
26/// - Cases where only recent context matters
27///
28#[derive(Debug, Clone)]
29pub struct SlidingWindowMemory {
30    messages: VecDeque<ChatMessage>,
31    window_size: usize,
32    trim_strategy: TrimStrategy,
33    needs_summary: bool,
34}
35
36impl SlidingWindowMemory {
37    /// Create a new sliding window memory with the specified window size.
38    ///
39    /// # Arguments
40    ///
41    /// * `window_size` - Maximum number of messages to keep in memory
42    ///
43    /// # Panics
44    ///
45    /// Panics if `window_size` is 0
46    ///
47    pub fn new(window_size: usize) -> Self {
48        Self::with_strategy(window_size, TrimStrategy::Drop)
49    }
50
51    /// Create a new sliding window memory with specified trim strategy
52    ///
53    /// # Arguments
54    ///
55    /// * `window_size` - Maximum number of messages to keep in memory
56    /// * `strategy` - How to handle overflow when window is full
57    pub fn with_strategy(window_size: usize, strategy: TrimStrategy) -> Self {
58        if window_size == 0 {
59            panic!("Window size must be greater than 0");
60        }
61
62        Self {
63            messages: VecDeque::with_capacity(window_size),
64            window_size,
65            trim_strategy: strategy,
66            needs_summary: false,
67        }
68    }
69
70    /// Get the configured window size.
71    ///
72    /// # Returns
73    ///
74    /// The maximum number of messages this memory can hold
75    pub fn window_size(&self) -> usize {
76        self.window_size
77    }
78
79    /// Get all stored messages in chronological order.
80    ///
81    /// # Returns
82    ///
83    /// A vector containing all messages from oldest to newest
84    pub fn messages(&self) -> Vec<ChatMessage> {
85        Vec::from(self.messages.clone())
86    }
87
88    /// Get the most recent N messages.
89    ///
90    /// # Arguments
91    ///
92    /// * `limit` - Maximum number of recent messages to return
93    ///
94    /// # Returns
95    ///
96    /// A vector containing the most recent messages, up to `limit`
97    pub fn recent_messages(&self, limit: usize) -> Vec<ChatMessage> {
98        let len = self.messages.len();
99        let start = len.saturating_sub(limit);
100        self.messages.range(start..).cloned().collect()
101    }
102
103    /// Check if memory needs summarization
104    pub fn needs_summary(&self) -> bool {
105        self.needs_summary
106    }
107
108    /// Mark memory as needing summarization
109    pub fn mark_for_summary(&mut self) {
110        self.needs_summary = true;
111    }
112
113    /// Replace all messages with a summary
114    ///
115    /// # Arguments
116    ///
117    /// * `summary` - The summary text to replace all messages with
118    pub fn replace_with_summary(&mut self, summary: String) {
119        self.messages.clear();
120        self.messages
121            .push_back(ChatMessage::assistant().content(summary).build());
122        self.needs_summary = false;
123    }
124}
125
126#[async_trait]
127impl MemoryProvider for SlidingWindowMemory {
128    async fn remember(&mut self, message: &ChatMessage) -> Result<(), LLMError> {
129        if self.messages.len() >= self.window_size {
130            match self.trim_strategy {
131                TrimStrategy::Drop => {
132                    self.messages.pop_front();
133                }
134                TrimStrategy::Summarize => {
135                    self.mark_for_summary();
136                }
137            }
138        }
139        self.messages.push_back(message.clone());
140        Ok(())
141    }
142
143    async fn recall(
144        &self,
145        _query: &str,
146        limit: Option<usize>,
147    ) -> Result<Vec<ChatMessage>, LLMError> {
148        let limit = limit.unwrap_or(self.messages.len());
149        Ok(self.recent_messages(limit))
150    }
151
152    async fn clear(&mut self) -> Result<(), LLMError> {
153        self.messages.clear();
154        Ok(())
155    }
156
157    fn memory_type(&self) -> MemoryType {
158        MemoryType::SlidingWindow
159    }
160
161    fn size(&self) -> usize {
162        self.messages.len()
163    }
164
165    fn needs_summary(&self) -> bool {
166        self.needs_summary
167    }
168
169    fn mark_for_summary(&mut self) {
170        self.needs_summary = true;
171    }
172
173    fn replace_with_summary(&mut self, summary: String) {
174        self.messages.clear();
175        self.messages
176            .push_back(ChatMessage::assistant().content(summary).build());
177        self.needs_summary = false;
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use autoagents_llm::chat::{ChatMessage, ChatRole, MessageType};

    /// Build a plain-text user message with the given content.
    fn user_text(content: impl Into<String>) -> ChatMessage {
        ChatMessage {
            role: ChatRole::User,
            message_type: MessageType::Text,
            content: content.into(),
        }
    }

    /// Remember `Message 1` ..= `Message {count}` through the public API.
    async fn remember_numbered(memory: &mut SlidingWindowMemory, count: usize) {
        for i in 1..=count {
            memory
                .remember(&user_text(format!("Message {i}")))
                .await
                .unwrap();
        }
    }

    /// Push `Message 1` ..= `Message {count}` straight into the internal deque,
    /// bypassing the trim strategy.
    fn push_numbered(memory: &mut SlidingWindowMemory, count: usize) {
        for i in 1..=count {
            memory.messages.push_back(user_text(format!("Message {i}")));
        }
    }

    #[test]
    fn test_new_sliding_window_memory() {
        let memory = SlidingWindowMemory::new(5);
        assert_eq!(memory.window_size(), 5);
        assert_eq!(memory.size(), 0);
        assert!(memory.is_empty());
        assert_eq!(memory.memory_type(), MemoryType::SlidingWindow);
    }

    #[test]
    fn test_sliding_window_memory_with_strategy() {
        let memory = SlidingWindowMemory::with_strategy(3, TrimStrategy::Summarize);
        assert_eq!(memory.window_size(), 3);
        assert_eq!(memory.size(), 0);
        assert!(memory.is_empty());
    }

    #[test]
    #[should_panic(expected = "Window size must be greater than 0")]
    fn test_new_sliding_window_memory_zero_size() {
        SlidingWindowMemory::new(0);
    }

    #[tokio::test]
    async fn test_remember_single_message() {
        let mut memory = SlidingWindowMemory::new(3);
        memory.remember(&user_text("Hello")).await.unwrap();

        assert_eq!(memory.size(), 1);
        assert!(!memory.is_empty());

        let messages = memory.messages();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].content, "Hello");
    }

    #[tokio::test]
    async fn test_remember_multiple_messages() {
        let mut memory = SlidingWindowMemory::new(3);
        remember_numbered(&mut memory, 3).await;

        assert_eq!(memory.size(), 3);
        let messages = memory.messages();
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].content, "Message 1");
        assert_eq!(messages[2].content, "Message 3");
    }

    #[tokio::test]
    async fn test_sliding_window_overflow_drop_strategy() {
        let mut memory = SlidingWindowMemory::with_strategy(2, TrimStrategy::Drop);

        // Add 3 messages to a window of size 2
        remember_numbered(&mut memory, 3).await;

        // Should only keep the last 2 messages
        assert_eq!(memory.size(), 2);
        let messages = memory.messages();
        assert_eq!(messages[0].content, "Message 2");
        assert_eq!(messages[1].content, "Message 3");
    }

    #[tokio::test]
    async fn test_sliding_window_overflow_summarize_strategy() {
        let mut memory = SlidingWindowMemory::with_strategy(2, TrimStrategy::Summarize);

        // Filling the window exactly must not request a summary yet
        for content in ["First message", "Second message"] {
            memory.remember(&user_text(content)).await.unwrap();
        }
        assert!(!memory.needs_summary());

        // Add third message - should trigger summarize flag
        memory.remember(&user_text("Third message")).await.unwrap();

        assert!(memory.needs_summary());
        assert_eq!(memory.size(), 3); // Still contains all messages until summarized
    }

    #[tokio::test]
    async fn test_recall_all_messages() {
        let mut memory = SlidingWindowMemory::new(3);
        remember_numbered(&mut memory, 3).await;

        let recalled = memory.recall("", None).await.unwrap();
        assert_eq!(recalled.len(), 3);
        assert_eq!(recalled[0].content, "Message 1");
        assert_eq!(recalled[2].content, "Message 3");
    }

    #[tokio::test]
    async fn test_recall_with_limit() {
        let mut memory = SlidingWindowMemory::new(5);
        remember_numbered(&mut memory, 5).await;

        let recalled = memory.recall("", Some(2)).await.unwrap();
        assert_eq!(recalled.len(), 2);
        assert_eq!(recalled[0].content, "Message 4");
        assert_eq!(recalled[1].content, "Message 5");
    }

    #[tokio::test]
    async fn test_clear_memory() {
        let mut memory = SlidingWindowMemory::new(3);
        memory.remember(&user_text("Test message")).await.unwrap();

        assert_eq!(memory.size(), 1);
        memory.clear().await.unwrap();
        assert_eq!(memory.size(), 0);
        assert!(memory.is_empty());
    }

    #[test]
    fn test_recent_messages() {
        let mut memory = SlidingWindowMemory::new(5);
        push_numbered(&mut memory, 5);

        let recent = memory.recent_messages(3);
        assert_eq!(recent.len(), 3);
        assert_eq!(recent[0].content, "Message 3");
        assert_eq!(recent[2].content, "Message 5");
    }

    #[test]
    fn test_recent_messages_limit_exceeds_size() {
        let mut memory = SlidingWindowMemory::new(5);
        push_numbered(&mut memory, 2);

        let recent = memory.recent_messages(10);
        assert_eq!(recent.len(), 2);
        assert_eq!(recent[0].content, "Message 1");
        assert_eq!(recent[1].content, "Message 2");
    }

    #[test]
    fn test_mark_for_summary() {
        let mut memory = SlidingWindowMemory::new(3);
        assert!(!memory.needs_summary());

        memory.mark_for_summary();
        assert!(memory.needs_summary());
    }

    #[test]
    fn test_replace_with_summary() {
        let mut memory = SlidingWindowMemory::new(3);
        push_numbered(&mut memory, 3);

        memory.mark_for_summary();
        assert!(memory.needs_summary());
        assert_eq!(memory.size(), 3);

        memory.replace_with_summary("This is a summary".to_string());

        assert!(!memory.needs_summary());
        assert_eq!(memory.size(), 1);
        let messages = memory.messages();
        assert_eq!(messages[0].content, "This is a summary");
        assert_eq!(messages[0].role, ChatRole::Assistant);
    }

    #[test]
    fn test_memory_provider_trait_methods() {
        let memory = SlidingWindowMemory::new(3);

        // Test trait methods
        assert_eq!(memory.memory_type(), MemoryType::SlidingWindow);
        assert_eq!(memory.size(), 0);
        assert!(memory.is_empty());
        assert!(!memory.needs_summary());
        assert!(memory.get_event_receiver().is_none());
    }
}