1use async_trait::async_trait;
6use autoagents_llm::{chat::ChatMessage, error::LLMError};
7use std::collections::VecDeque;
8
9use super::{MemoryProvider, MemoryType};
10
/// Policy applied by [`SlidingWindowMemory`] when a new message arrives while
/// the window already holds `window_size` messages.
///
/// The enum is a plain, field-less tag, so it is cheap to copy and can be
/// compared for equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrimStrategy {
    /// Evict the oldest stored message to make room for the new one.
    Drop,
    /// Keep every stored message and flag the memory as needing a summary.
    Summarize,
}
19
/// Conversation memory backed by a queue of chat messages.
///
/// What happens once the queue holds `window_size` messages is decided by the
/// configured [`TrimStrategy`].
#[derive(Debug, Clone)]
pub struct SlidingWindowMemory {
    /// Stored messages; new messages are appended at the back.
    messages: VecDeque<ChatMessage>,
    /// Number of stored messages at which the trim strategy is applied.
    window_size: usize,
    /// Action taken when a message arrives and the window is already full.
    trim_strategy: TrimStrategy,
    /// Flag recording that the stored history should be replaced by a summary.
    needs_summary: bool,
}
35
36impl SlidingWindowMemory {
37 pub fn new(window_size: usize) -> Self {
48 Self::with_strategy(window_size, TrimStrategy::Drop)
49 }
50
51 pub fn with_strategy(window_size: usize, strategy: TrimStrategy) -> Self {
58 if window_size == 0 {
59 panic!("Window size must be greater than 0");
60 }
61
62 Self {
63 messages: VecDeque::with_capacity(window_size),
64 window_size,
65 trim_strategy: strategy,
66 needs_summary: false,
67 }
68 }
69
70 pub fn window_size(&self) -> usize {
76 self.window_size
77 }
78
79 pub fn messages(&self) -> Vec<ChatMessage> {
85 Vec::from(self.messages.clone())
86 }
87
88 pub fn recent_messages(&self, limit: usize) -> Vec<ChatMessage> {
98 let len = self.messages.len();
99 let start = len.saturating_sub(limit);
100 self.messages.range(start..).cloned().collect()
101 }
102
103 pub fn needs_summary(&self) -> bool {
105 self.needs_summary
106 }
107
108 pub fn mark_for_summary(&mut self) {
110 self.needs_summary = true;
111 }
112
113 pub fn replace_with_summary(&mut self, summary: String) {
119 self.messages.clear();
120 self.messages
121 .push_back(ChatMessage::assistant().content(summary).build());
122 self.needs_summary = false;
123 }
124}
125
126#[async_trait]
127impl MemoryProvider for SlidingWindowMemory {
128 async fn remember(&mut self, message: &ChatMessage) -> Result<(), LLMError> {
129 if self.messages.len() >= self.window_size {
130 match self.trim_strategy {
131 TrimStrategy::Drop => {
132 self.messages.pop_front();
133 }
134 TrimStrategy::Summarize => {
135 self.mark_for_summary();
136 }
137 }
138 }
139 self.messages.push_back(message.clone());
140 Ok(())
141 }
142
143 async fn recall(
144 &self,
145 _query: &str,
146 limit: Option<usize>,
147 ) -> Result<Vec<ChatMessage>, LLMError> {
148 let limit = limit.unwrap_or(self.messages.len());
149 Ok(self.recent_messages(limit))
150 }
151
152 async fn clear(&mut self) -> Result<(), LLMError> {
153 self.messages.clear();
154 Ok(())
155 }
156
157 fn memory_type(&self) -> MemoryType {
158 MemoryType::SlidingWindow
159 }
160
161 fn size(&self) -> usize {
162 self.messages.len()
163 }
164
165 fn needs_summary(&self) -> bool {
166 self.needs_summary
167 }
168
169 fn mark_for_summary(&mut self) {
170 self.needs_summary = true;
171 }
172
173 fn replace_with_summary(&mut self, summary: String) {
174 self.messages.clear();
175 self.messages
176 .push_back(ChatMessage::assistant().content(summary).build());
177 self.needs_summary = false;
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use autoagents_llm::chat::{ChatMessage, ChatRole, MessageType};

    /// Builds a plain-text user message with the given content.
    fn user_text(content: impl Into<String>) -> ChatMessage {
        ChatMessage {
            role: ChatRole::User,
            message_type: MessageType::Text,
            content: content.into(),
        }
    }

    /// Builds the user message `"Message {index}"`.
    fn numbered(index: usize) -> ChatMessage {
        user_text(format!("Message {index}"))
    }

    /// Stores messages `1..=count` through the `MemoryProvider::remember` API.
    async fn remember_numbered(memory: &mut SlidingWindowMemory, count: usize) {
        for index in 1..=count {
            memory.remember(&numbered(index)).await.unwrap();
        }
    }

    /// Pushes messages `1..=count` straight into the queue, bypassing
    /// `remember` and therefore the trim strategy.
    fn push_numbered(memory: &mut SlidingWindowMemory, count: usize) {
        memory.messages.extend((1..=count).map(numbered));
    }

    #[test]
    fn test_new_sliding_window_memory() {
        let memory = SlidingWindowMemory::new(5);

        assert_eq!(memory.window_size(), 5);
        assert_eq!(memory.size(), 0);
        assert!(memory.is_empty());
        assert_eq!(memory.memory_type(), MemoryType::SlidingWindow);
    }

    #[test]
    fn test_sliding_window_memory_with_strategy() {
        let memory = SlidingWindowMemory::with_strategy(3, TrimStrategy::Summarize);

        assert_eq!(memory.window_size(), 3);
        assert_eq!(memory.size(), 0);
        assert!(memory.is_empty());
    }

    #[test]
    #[should_panic(expected = "Window size must be greater than 0")]
    fn test_new_sliding_window_memory_zero_size() {
        SlidingWindowMemory::new(0);
    }

    #[tokio::test]
    async fn test_remember_single_message() {
        let mut memory = SlidingWindowMemory::new(3);

        memory.remember(&user_text("Hello")).await.unwrap();
        assert_eq!(memory.size(), 1);
        assert!(!memory.is_empty());

        let stored = memory.messages();
        assert_eq!(stored.len(), 1);
        assert_eq!(stored[0].content, "Hello");
    }

    #[tokio::test]
    async fn test_remember_multiple_messages() {
        let mut memory = SlidingWindowMemory::new(3);
        remember_numbered(&mut memory, 3).await;

        assert_eq!(memory.size(), 3);
        let stored = memory.messages();
        assert_eq!(stored.len(), 3);
        assert_eq!(stored[0].content, "Message 1");
        assert_eq!(stored[2].content, "Message 3");
    }

    #[tokio::test]
    async fn test_sliding_window_overflow_drop_strategy() {
        let mut memory = SlidingWindowMemory::with_strategy(2, TrimStrategy::Drop);

        // One message more than the window holds: the oldest must be evicted.
        remember_numbered(&mut memory, 3).await;

        assert_eq!(memory.size(), 2);
        let stored = memory.messages();
        assert_eq!(stored[0].content, "Message 2");
        assert_eq!(stored[1].content, "Message 3");
    }

    #[tokio::test]
    async fn test_sliding_window_overflow_summarize_strategy() {
        let mut memory = SlidingWindowMemory::with_strategy(2, TrimStrategy::Summarize);

        // The third message overflows the window of two.
        for content in ["First message", "Second message", "Third message"] {
            memory.remember(&user_text(content)).await.unwrap();
        }

        // Summarize keeps everything and only raises the flag.
        assert!(memory.needs_summary());
        assert_eq!(memory.size(), 3);
    }

    #[tokio::test]
    async fn test_recall_all_messages() {
        let mut memory = SlidingWindowMemory::new(3);
        remember_numbered(&mut memory, 3).await;

        let recalled = memory.recall("", None).await.unwrap();
        assert_eq!(recalled.len(), 3);
        assert_eq!(recalled[0].content, "Message 1");
        assert_eq!(recalled[2].content, "Message 3");
    }

    #[tokio::test]
    async fn test_recall_with_limit() {
        let mut memory = SlidingWindowMemory::new(5);
        remember_numbered(&mut memory, 5).await;

        let recalled = memory.recall("", Some(2)).await.unwrap();
        assert_eq!(recalled.len(), 2);
        assert_eq!(recalled[0].content, "Message 4");
        assert_eq!(recalled[1].content, "Message 5");
    }

    #[tokio::test]
    async fn test_clear_memory() {
        let mut memory = SlidingWindowMemory::new(3);
        memory.remember(&user_text("Test message")).await.unwrap();
        assert_eq!(memory.size(), 1);

        memory.clear().await.unwrap();
        assert_eq!(memory.size(), 0);
        assert!(memory.is_empty());
    }

    #[test]
    fn test_recent_messages() {
        let mut memory = SlidingWindowMemory::new(5);
        push_numbered(&mut memory, 5);

        let recent = memory.recent_messages(3);
        assert_eq!(recent.len(), 3);
        assert_eq!(recent[0].content, "Message 3");
        assert_eq!(recent[2].content, "Message 5");
    }

    #[test]
    fn test_recent_messages_limit_exceeds_size() {
        let mut memory = SlidingWindowMemory::new(5);
        push_numbered(&mut memory, 2);

        let recent = memory.recent_messages(10);
        assert_eq!(recent.len(), 2);
        assert_eq!(recent[0].content, "Message 1");
        assert_eq!(recent[1].content, "Message 2");
    }

    #[test]
    fn test_mark_for_summary() {
        let mut memory = SlidingWindowMemory::new(3);
        assert!(!memory.needs_summary());

        memory.mark_for_summary();
        assert!(memory.needs_summary());
    }

    #[test]
    fn test_replace_with_summary() {
        let mut memory = SlidingWindowMemory::new(3);
        push_numbered(&mut memory, 3);

        memory.mark_for_summary();
        assert!(memory.needs_summary());
        assert_eq!(memory.size(), 3);

        memory.replace_with_summary("This is a summary".to_string());

        assert!(!memory.needs_summary());
        assert_eq!(memory.size(), 1);
        let stored = memory.messages();
        assert_eq!(stored[0].content, "This is a summary");
        assert_eq!(stored[0].role, ChatRole::Assistant);
    }

    #[test]
    fn test_memory_provider_trait_methods() {
        let memory = SlidingWindowMemory::new(3);

        assert_eq!(memory.memory_type(), MemoryType::SlidingWindow);
        assert_eq!(memory.size(), 0);
        assert!(memory.is_empty());
        assert!(!memory.needs_summary());
        assert!(memory.get_event_receiver().is_none());
    }
}