Skip to main content

autoagents_core/agent/memory/
sliding_window.rs

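//! Simple sliding window memory implementation.
//!
//! This module provides a memory that keeps a fixed-size window of the most recent
//! conversation messages. By default it behaves as a FIFO (First In, First Out) queue
//! that drops the oldest message when full; alternatively it can keep every message and
//! flag the history for summarization instead (see `TrimStrategy`).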
5use async_trait::async_trait;
6use autoagents_llm::{chat::ChatMessage, error::LLMError};
7use std::collections::VecDeque;
8
9use super::{MemoryProvider, MemoryType};
10
/// Strategy for handling memory when the window size limit is reached.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrimStrategy {
    /// Drop the oldest message to make room for the new one (FIFO behavior).
    Drop,
    /// Keep every message and flag the memory as needing summarization.
    ///
    /// The memory does not summarize on its own: it only sets its `needs_summary`
    /// flag, and the history is collapsed when `replace_with_summary` is called.
    Summarize,
}
19
20/// Simple sliding window memory that keeps the N most recent messages.
21///
22/// This implementation uses a FIFO strategy where old messages are automatically
23/// removed when the window size limit is reached. It's suitable for:
24/// - Simple conversation contexts
25/// - Memory-constrained environments
26/// - Cases where only recent context matters
27///
28#[derive(Debug, Clone)]
29pub struct SlidingWindowMemory {
30    messages: VecDeque<ChatMessage>,
31    window_size: usize,
32    trim_strategy: TrimStrategy,
33    needs_summary: bool,
34}
35
36impl SlidingWindowMemory {
37    /// Create a new sliding window memory with the specified window size.
38    ///
39    /// # Arguments
40    ///
41    /// * `window_size` - Maximum number of messages to keep in memory
42    ///
43    /// # Panics
44    ///
45    /// Panics if `window_size` is 0
46    ///
47    pub fn new(window_size: usize) -> Self {
48        Self::with_strategy(window_size, TrimStrategy::Drop)
49    }
50
51    /// Create a new sliding window memory with specified trim strategy
52    ///
53    /// # Arguments
54    ///
55    /// * `window_size` - Maximum number of messages to keep in memory
56    /// * `strategy` - How to handle overflow when window is full
57    pub fn with_strategy(window_size: usize, strategy: TrimStrategy) -> Self {
58        if window_size == 0 {
59            panic!("Window size must be greater than 0");
60        }
61
62        Self {
63            messages: VecDeque::with_capacity(window_size),
64            window_size,
65            trim_strategy: strategy,
66            needs_summary: false,
67        }
68    }
69
70    /// Get the configured window size.
71    ///
72    /// # Returns
73    ///
74    /// The maximum number of messages this memory can hold
75    pub fn window_size(&self) -> usize {
76        self.window_size
77    }
78
79    /// Get all stored messages in chronological order.
80    ///
81    /// # Returns
82    ///
83    /// A vector containing all messages from oldest to newest
84    pub fn messages(&self) -> Vec<ChatMessage> {
85        Vec::from(self.messages.clone())
86    }
87
88    /// Get the most recent N messages.
89    ///
90    /// # Arguments
91    ///
92    /// * `limit` - Maximum number of recent messages to return
93    ///
94    /// # Returns
95    ///
96    /// A vector containing the most recent messages, up to `limit`
97    pub fn recent_messages(&self, limit: usize) -> Vec<ChatMessage> {
98        let len = self.messages.len();
99        let start = len.saturating_sub(limit);
100        self.messages.range(start..).cloned().collect()
101    }
102
103    /// Check if memory needs summarization
104    pub fn needs_summary(&self) -> bool {
105        self.needs_summary
106    }
107
108    /// Mark memory as needing summarization
109    pub fn mark_for_summary(&mut self) {
110        self.needs_summary = true;
111    }
112
113    /// Replace all messages with a summary
114    ///
115    /// # Arguments
116    ///
117    /// * `summary` - The summary text to replace all messages with
118    pub fn replace_with_summary(&mut self, summary: String) {
119        self.messages.clear();
120        self.messages
121            .push_back(ChatMessage::assistant().content(summary).build());
122        self.needs_summary = false;
123    }
124}
125
126#[async_trait]
127impl MemoryProvider for SlidingWindowMemory {
128    async fn remember(&mut self, message: &ChatMessage) -> Result<(), LLMError> {
129        if self.messages.len() >= self.window_size {
130            match self.trim_strategy {
131                TrimStrategy::Drop => {
132                    self.messages.pop_front();
133                }
134                TrimStrategy::Summarize => {
135                    self.mark_for_summary();
136                }
137            }
138        }
139        self.messages.push_back(message.clone());
140        Ok(())
141    }
142
143    async fn recall(
144        &self,
145        _query: &str,
146        limit: Option<usize>,
147    ) -> Result<Vec<ChatMessage>, LLMError> {
148        let limit = limit.unwrap_or(self.messages.len());
149        Ok(self.recent_messages(limit))
150    }
151
152    async fn clear(&mut self) -> Result<(), LLMError> {
153        self.messages.clear();
154        Ok(())
155    }
156
157    fn memory_type(&self) -> MemoryType {
158        MemoryType::SlidingWindow
159    }
160
161    fn size(&self) -> usize {
162        self.messages.len()
163    }
164
165    fn needs_summary(&self) -> bool {
166        self.needs_summary
167    }
168
169    fn mark_for_summary(&mut self) {
170        self.needs_summary = true;
171    }
172
173    fn replace_with_summary(&mut self, summary: String) {
174        self.messages.clear();
175        self.messages
176            .push_back(ChatMessage::assistant().content(summary).build());
177        self.needs_summary = false;
178    }
179
180    fn clone_box(&self) -> Box<dyn MemoryProvider> {
181        Box::new(self.clone())
182    }
183
184    fn preload(&mut self, data: Vec<ChatMessage>) -> bool {
185        self.messages.clear();
186        for msg in data {
187            self.messages.push_back(msg);
188        }
189        true
190    }
191
192    fn export(&self) -> Vec<ChatMessage> {
193        Vec::from(self.messages.clone())
194    }
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use autoagents_llm::chat::{ChatMessage, ChatRole, MessageType};

    /// Builds a plain-text message authored by the user.
    fn user_text(content: impl Into<String>) -> ChatMessage {
        ChatMessage {
            role: ChatRole::User,
            message_type: MessageType::Text,
            content: content.into(),
        }
    }

    /// Builds `Message 1` ..= `Message {count}` in chronological order.
    fn numbered(count: usize) -> Vec<ChatMessage> {
        (1..=count)
            .map(|i| user_text(format!("Message {i}")))
            .collect()
    }

    #[test]
    fn test_new_sliding_window_memory() {
        let memory = SlidingWindowMemory::new(5);
        assert_eq!(memory.window_size(), 5);
        assert_eq!(memory.size(), 0);
        assert!(memory.is_empty());
        assert_eq!(memory.memory_type(), MemoryType::SlidingWindow);
    }

    #[test]
    fn test_sliding_window_memory_with_strategy() {
        let memory = SlidingWindowMemory::with_strategy(3, TrimStrategy::Summarize);
        assert_eq!(memory.window_size(), 3);
        assert_eq!(memory.size(), 0);
        assert!(memory.is_empty());
    }

    #[test]
    #[should_panic(expected = "Window size must be greater than 0")]
    fn test_new_sliding_window_memory_zero_size() {
        let _ = SlidingWindowMemory::new(0);
    }

    #[tokio::test]
    async fn test_remember_single_message() {
        let mut memory = SlidingWindowMemory::new(3);

        memory.remember(&user_text("Hello")).await.unwrap();
        assert_eq!(memory.size(), 1);
        assert!(!memory.is_empty());

        let stored = memory.messages();
        assert_eq!(stored.len(), 1);
        assert_eq!(stored[0].content, "Hello");
    }

    #[tokio::test]
    async fn test_remember_multiple_messages() {
        let mut memory = SlidingWindowMemory::new(3);
        for msg in numbered(3) {
            memory.remember(&msg).await.unwrap();
        }

        assert_eq!(memory.size(), 3);
        let stored = memory.messages();
        assert_eq!(stored.len(), 3);
        assert_eq!(stored[0].content, "Message 1");
        assert_eq!(stored[2].content, "Message 3");
    }

    #[tokio::test]
    async fn test_sliding_window_overflow_drop_strategy() {
        let mut memory = SlidingWindowMemory::with_strategy(2, TrimStrategy::Drop);

        // Three messages into a window of two: the first one must be evicted.
        for msg in numbered(3) {
            memory.remember(&msg).await.unwrap();
        }

        assert_eq!(memory.size(), 2);
        let stored = memory.messages();
        assert_eq!(stored[0].content, "Message 2");
        assert_eq!(stored[1].content, "Message 3");
    }

    #[tokio::test]
    async fn test_sliding_window_overflow_summarize_strategy() {
        let mut memory = SlidingWindowMemory::with_strategy(2, TrimStrategy::Summarize);

        // The third message overflows a window of two, which must raise the
        // summarize flag without evicting anything.
        for text in ["First message", "Second message", "Third message"] {
            memory.remember(&user_text(text)).await.unwrap();
        }

        assert!(memory.needs_summary());
        assert_eq!(memory.size(), 3); // Still contains all messages until summarized
    }

    #[tokio::test]
    async fn test_recall_all_messages() {
        let mut memory = SlidingWindowMemory::new(3);
        for msg in numbered(3) {
            memory.remember(&msg).await.unwrap();
        }

        let recalled = memory.recall("", None).await.unwrap();
        assert_eq!(recalled.len(), 3);
        assert_eq!(recalled[0].content, "Message 1");
        assert_eq!(recalled[2].content, "Message 3");
    }

    #[tokio::test]
    async fn test_recall_with_limit() {
        let mut memory = SlidingWindowMemory::new(5);
        for msg in numbered(5) {
            memory.remember(&msg).await.unwrap();
        }

        let recalled = memory.recall("", Some(2)).await.unwrap();
        assert_eq!(recalled.len(), 2);
        assert_eq!(recalled[0].content, "Message 4");
        assert_eq!(recalled[1].content, "Message 5");
    }

    #[tokio::test]
    async fn test_clear_memory() {
        let mut memory = SlidingWindowMemory::new(3);
        memory.remember(&user_text("Test message")).await.unwrap();

        assert_eq!(memory.size(), 1);
        memory.clear().await.unwrap();
        assert_eq!(memory.size(), 0);
        assert!(memory.is_empty());
    }

    #[test]
    fn test_recent_messages() {
        let mut memory = SlidingWindowMemory::new(5);
        // Fill the internal deque directly, bypassing `remember`.
        memory.messages.extend(numbered(5));

        let recent = memory.recent_messages(3);
        assert_eq!(recent.len(), 3);
        assert_eq!(recent[0].content, "Message 3");
        assert_eq!(recent[2].content, "Message 5");
    }

    #[test]
    fn test_recent_messages_limit_exceeds_size() {
        let mut memory = SlidingWindowMemory::new(5);
        // Only two messages stored, far fewer than the requested limit.
        memory.messages.extend(numbered(2));

        let recent = memory.recent_messages(10);
        assert_eq!(recent.len(), 2);
        assert_eq!(recent[0].content, "Message 1");
        assert_eq!(recent[1].content, "Message 2");
    }

    #[test]
    fn test_mark_for_summary() {
        let mut memory = SlidingWindowMemory::new(3);
        assert!(!memory.needs_summary());

        memory.mark_for_summary();
        assert!(memory.needs_summary());
    }

    #[test]
    fn test_replace_with_summary() {
        let mut memory = SlidingWindowMemory::new(3);
        memory.messages.extend(numbered(3));

        memory.mark_for_summary();
        assert!(memory.needs_summary());
        assert_eq!(memory.size(), 3);

        memory.replace_with_summary("This is a summary".to_string());

        assert!(!memory.needs_summary());
        assert_eq!(memory.size(), 1);
        let stored = memory.messages();
        assert_eq!(stored[0].content, "This is a summary");
        assert_eq!(stored[0].role, ChatRole::Assistant);
    }

    #[test]
    fn test_memory_provider_trait_methods() {
        let memory = SlidingWindowMemory::new(3);

        // Trait-level accessors on a fresh memory.
        assert_eq!(memory.memory_type(), MemoryType::SlidingWindow);
        assert_eq!(memory.size(), 0);
        assert!(memory.is_empty());
        assert!(!memory.needs_summary());
        assert!(memory.get_event_receiver().is_none());
    }
}