1use async_trait::async_trait;
2use autoagents_llm::{chat::ChatMessage, error::LLMError};
3use regex::Regex;
4use serde::{Deserialize, Serialize};
5use std::sync::Arc;
6use tokio::sync::broadcast;
7
8mod sliding_window;
9pub use sliding_window::SlidingWindowMemory;
10
#[cfg(test)]
mod tests {
    use super::*;
    use autoagents_llm::chat::{ChatMessage, ChatRole, MessageType};
    use autoagents_llm::error::LLMError;
    use std::sync::Arc;

    // ---------------------------------------------------------------------
    // Fixtures
    // ---------------------------------------------------------------------

    /// Builds a text [`ChatMessage`] with the given role and content.
    fn text_message(role: ChatRole, content: &str) -> ChatMessage {
        ChatMessage {
            role,
            message_type: MessageType::Text,
            content: content.to_string(),
        }
    }

    /// Builds a [`MessageEvent`] labelled `label` that wraps a text message
    /// with the given chat role and content.
    fn event_with_role(label: &str, role: ChatRole, content: &str) -> MessageEvent {
        MessageEvent {
            role: label.to_string(),
            msg: text_message(role, content),
        }
    }

    /// Builds an event labelled `"user"` that carries a user text message.
    fn user_event(content: &str) -> MessageEvent {
        event_with_role("user", ChatRole::User, content)
    }

    /// Two-message fixture shared by the recall tests: a user message followed
    /// by an assistant message.
    fn conversation() -> Vec<ChatMessage> {
        vec![
            text_message(ChatRole::User, "first message"),
            text_message(ChatRole::Assistant, "second message"),
        ]
    }

    /// Minimal in-memory [`MemoryProvider`] used to exercise the trait.
    ///
    /// When `should_fail` is set, `remember`, `recall` and `clear` all return
    /// an error instead of touching `messages`.
    struct MockMemoryProvider {
        messages: Vec<ChatMessage>,
        should_fail: bool,
    }

    impl MockMemoryProvider {
        /// Empty provider whose operations succeed.
        fn new() -> Self {
            Self {
                messages: Vec::new(),
                should_fail: false,
            }
        }

        /// Empty provider whose fallible operations always return an error.
        fn with_failure() -> Self {
            Self {
                messages: Vec::new(),
                should_fail: true,
            }
        }

        /// Succeeding provider pre-populated with `messages`.
        fn with_messages(messages: Vec<ChatMessage>) -> Self {
            Self {
                messages,
                should_fail: false,
            }
        }
    }

    #[async_trait::async_trait]
    impl MemoryProvider for MockMemoryProvider {
        /// Appends a clone of `message`, or fails when `should_fail` is set.
        async fn remember(&mut self, message: &ChatMessage) -> Result<(), LLMError> {
            if self.should_fail {
                return Err(LLMError::ProviderError("Mock memory error".to_string()));
            }
            self.messages.push(message.clone());
            Ok(())
        }

        /// Returns the first `limit` stored messages in insertion order (all of
        /// them when `limit` is `None`). The query string is ignored.
        async fn recall(
            &self,
            _query: &str,
            limit: Option<usize>,
        ) -> Result<Vec<ChatMessage>, LLMError> {
            if self.should_fail {
                return Err(LLMError::ProviderError("Mock recall error".to_string()));
            }
            let limit = limit.unwrap_or(self.messages.len());
            Ok(self.messages.iter().take(limit).cloned().collect())
        }

        /// Drops every stored message, or fails when `should_fail` is set.
        async fn clear(&mut self) -> Result<(), LLMError> {
            if self.should_fail {
                return Err(LLMError::ProviderError("Mock clear error".to_string()));
            }
            self.messages.clear();
            Ok(())
        }

        fn memory_type(&self) -> MemoryType {
            MemoryType::SlidingWindow
        }

        fn size(&self) -> usize {
            self.messages.len()
        }
    }

    // ---------------------------------------------------------------------
    // MemoryType
    // ---------------------------------------------------------------------

    #[test]
    fn test_memory_type_serialization() {
        let sliding_window = MemoryType::SlidingWindow;
        let serialized = serde_json::to_string(&sliding_window).unwrap();
        assert_eq!(serialized, "\"SlidingWindow\"");
    }

    #[test]
    fn test_memory_type_deserialization() {
        let deserialized: MemoryType = serde_json::from_str("\"SlidingWindow\"").unwrap();
        assert_eq!(deserialized, MemoryType::SlidingWindow);
    }

    #[test]
    fn test_memory_type_debug() {
        let sliding_window = MemoryType::SlidingWindow;
        let debug_str = format!("{sliding_window:?}");
        assert!(debug_str.contains("SlidingWindow"));
    }

    #[test]
    fn test_memory_type_clone() {
        let sliding_window = MemoryType::SlidingWindow;
        let cloned = sliding_window.clone();
        assert_eq!(sliding_window, cloned);
    }

    #[test]
    fn test_memory_type_equality() {
        let type1 = MemoryType::SlidingWindow;
        let type2 = MemoryType::SlidingWindow;
        assert_eq!(type1, type2);
    }

    // ---------------------------------------------------------------------
    // MessageCondition
    // ---------------------------------------------------------------------

    #[test]
    fn test_message_condition_any() {
        let condition = MessageCondition::Any;
        assert!(condition.matches(&user_event("test message")));
    }

    #[test]
    fn test_message_condition_eq() {
        let condition = MessageCondition::Eq("test message".to_string());
        assert!(condition.matches(&user_event("test message")));
        assert!(!condition.matches(&user_event("different message")));
    }

    #[test]
    fn test_message_condition_contains() {
        let condition = MessageCondition::Contains("test".to_string());
        assert!(condition.matches(&user_event("this is a test message")));
        assert!(!condition.matches(&user_event("this is different")));
    }

    #[test]
    fn test_message_condition_not_contains() {
        let condition = MessageCondition::NotContains("error".to_string());
        assert!(condition.matches(&user_event("this is a test message")));
        assert!(!condition.matches(&user_event("this is an error message")));
    }

    #[test]
    fn test_message_condition_role_is() {
        let condition = MessageCondition::RoleIs("user".to_string());
        assert!(condition.matches(&user_event("test message")));

        let assistant_event = event_with_role("assistant", ChatRole::Assistant, "test message");
        assert!(!condition.matches(&assistant_event));
    }

    #[test]
    fn test_message_condition_role_not() {
        let condition = MessageCondition::RoleNot("system".to_string());
        assert!(condition.matches(&user_event("test message")));

        let system_event = event_with_role("system", ChatRole::System, "test message");
        assert!(!condition.matches(&system_event));
    }

    #[test]
    fn test_message_condition_len_gt() {
        let condition = MessageCondition::LenGt(5);
        assert!(condition.matches(&user_event("this is a long message")));
        assert!(!condition.matches(&user_event("hi")));
    }

    #[test]
    fn test_message_condition_custom() {
        let condition = MessageCondition::Custom(Arc::new(|msg| msg.content.starts_with("hello")));
        assert!(condition.matches(&user_event("hello world")));
        assert!(!condition.matches(&user_event("goodbye world")));
    }

    #[test]
    fn test_message_condition_empty() {
        let condition = MessageCondition::Empty;
        assert!(condition.matches(&user_event("")));
        assert!(!condition.matches(&user_event("not empty")));
    }

    #[test]
    fn test_message_condition_all() {
        let condition = MessageCondition::All(vec![
            MessageCondition::Contains("test".to_string()),
            MessageCondition::LenGt(5),
            MessageCondition::RoleIs("user".to_string()),
        ]);

        assert!(condition.matches(&user_event("this is a test message")));
        // "hi" fails both the `Contains` and the `LenGt` sub-conditions.
        assert!(!condition.matches(&user_event("hi")));
    }

    #[test]
    fn test_message_condition_any_of() {
        let condition = MessageCondition::AnyOf(vec![
            MessageCondition::Contains("hello".to_string()),
            MessageCondition::Contains("goodbye".to_string()),
            MessageCondition::Empty,
        ]);

        // Each of the three alternatives is sufficient on its own.
        assert!(condition.matches(&user_event("hello world")));
        assert!(condition.matches(&user_event("goodbye world")));
        assert!(condition.matches(&user_event("")));
        assert!(!condition.matches(&user_event("test message")));
    }

    #[test]
    fn test_message_condition_regex() {
        let condition = MessageCondition::Regex(r"\d+".to_string());
        assert!(condition.matches(&user_event("there are 123 items")));
        assert!(!condition.matches(&user_event("no numbers here")));
    }

    #[test]
    fn test_message_condition_regex_invalid() {
        // A pattern that fails to compile is treated as "no match" by `matches`.
        let condition = MessageCondition::Regex("[invalid regex".to_string());
        assert!(!condition.matches(&user_event("test message")));
    }

    // ---------------------------------------------------------------------
    // MessageEvent
    // ---------------------------------------------------------------------

    #[test]
    fn test_message_event_creation() {
        let event = user_event("test message");
        assert_eq!(event.role, "user");
        assert_eq!(event.msg.content, "test message");
    }

    #[test]
    fn test_message_event_clone() {
        let event = user_event("test message");
        let cloned = event.clone();
        assert_eq!(event.role, cloned.role);
        assert_eq!(event.msg.content, cloned.msg.content);
    }

    #[test]
    fn test_message_event_debug() {
        let event = user_event("test message");
        let debug_str = format!("{event:?}");
        assert!(debug_str.contains("MessageEvent"));
        assert!(debug_str.contains("user"));
    }

    // ---------------------------------------------------------------------
    // MockMemoryProvider / MemoryProvider
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_mock_memory_provider_remember() {
        let mut provider = MockMemoryProvider::new();
        let message = text_message(ChatRole::User, "test message");

        let result = provider.remember(&message).await;
        assert!(result.is_ok());
        assert_eq!(provider.size(), 1);
    }

    #[tokio::test]
    async fn test_mock_memory_provider_remember_failure() {
        let mut provider = MockMemoryProvider::with_failure();
        let message = text_message(ChatRole::User, "test message");

        let result = provider.remember(&message).await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Mock memory error"));
    }

    #[tokio::test]
    async fn test_mock_memory_provider_recall() {
        let provider = MockMemoryProvider::with_messages(conversation());

        let result = provider.recall("", None).await;
        assert!(result.is_ok());
        let recalled = result.unwrap();
        assert_eq!(recalled.len(), 2);
        assert_eq!(recalled[0].content, "first message");
        assert_eq!(recalled[1].content, "second message");
    }

    #[tokio::test]
    async fn test_mock_memory_provider_recall_with_limit() {
        let provider = MockMemoryProvider::with_messages(conversation());

        let result = provider.recall("", Some(1)).await;
        assert!(result.is_ok());
        let recalled = result.unwrap();
        assert_eq!(recalled.len(), 1);
        assert_eq!(recalled[0].content, "first message");
    }

    #[tokio::test]
    async fn test_mock_memory_provider_recall_failure() {
        let provider = MockMemoryProvider::with_failure();

        let result = provider.recall("", None).await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Mock recall error"));
    }

    #[tokio::test]
    async fn test_mock_memory_provider_clear() {
        let messages = vec![text_message(ChatRole::User, "first message")];
        let mut provider = MockMemoryProvider::with_messages(messages);
        assert_eq!(provider.size(), 1);

        let result = provider.clear().await;
        assert!(result.is_ok());
        assert_eq!(provider.size(), 0);
    }

    #[tokio::test]
    async fn test_mock_memory_provider_clear_failure() {
        let mut provider = MockMemoryProvider::with_failure();

        let result = provider.clear().await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Mock clear error"));
    }

    #[test]
    fn test_mock_memory_provider_memory_type() {
        let provider = MockMemoryProvider::new();
        assert_eq!(provider.memory_type(), MemoryType::SlidingWindow);
    }

    #[test]
    fn test_mock_memory_provider_size() {
        let provider = MockMemoryProvider::new();
        assert_eq!(provider.size(), 0);

        let messages = vec![text_message(ChatRole::User, "message")];
        let provider_with_messages = MockMemoryProvider::with_messages(messages);
        assert_eq!(provider_with_messages.size(), 1);
    }

    #[test]
    fn test_mock_memory_provider_is_empty() {
        let provider = MockMemoryProvider::new();
        assert!(provider.is_empty());

        let messages = vec![text_message(ChatRole::User, "message")];
        let provider_with_messages = MockMemoryProvider::with_messages(messages);
        assert!(!provider_with_messages.is_empty());
    }

    #[test]
    fn test_memory_provider_default_methods() {
        // The mock does not override these, so the trait defaults are exercised.
        let provider = MockMemoryProvider::new();
        assert!(!provider.needs_summary());
        assert!(provider.get_event_receiver().is_none());
    }

    #[tokio::test]
    async fn test_memory_provider_remember_with_role() {
        let mut provider = MockMemoryProvider::new();
        let message = text_message(ChatRole::User, "test message");

        // The default `remember_with_role` forwards to `remember`.
        let result = provider
            .remember_with_role(&message, "custom_role".to_string())
            .await;
        assert!(result.is_ok());
        assert_eq!(provider.size(), 1);
    }

    #[test]
    fn test_memory_provider_mark_for_summary() {
        let mut provider = MockMemoryProvider::new();
        // The default `mark_for_summary` is a no-op, so `needs_summary` stays false.
        provider.mark_for_summary();
        assert!(!provider.needs_summary());
    }

    #[test]
    fn test_memory_provider_replace_with_summary() {
        let mut provider = MockMemoryProvider::new();
        // The default `replace_with_summary` is a no-op, so nothing is stored.
        provider.replace_with_summary("Summary text".to_string());
        assert_eq!(provider.size(), 0);
    }
}
698
/// A chat message paired with a role label.
///
/// This is the value that [`MessageCondition::matches`] inspects and the item
/// type of the receiver returned by [`MemoryProvider::get_event_receiver`].
#[derive(Debug, Clone)]
pub struct MessageEvent {
    /// Role label as a free-form string; `MessageCondition::RoleIs` and
    /// `MessageCondition::RoleNot` compare against this field (not against
    /// `msg.role`).
    pub role: String,
    /// The chat message carried by this event.
    pub msg: ChatMessage,
}
707
708#[derive(Clone)]
710pub enum MessageCondition {
711 Any,
713 Eq(String),
715 Contains(String),
717 NotContains(String),
719 RoleIs(String),
721 RoleNot(String),
723 LenGt(usize),
725 Custom(Arc<dyn Fn(&ChatMessage) -> bool + Send + Sync>),
727 Empty,
729 All(Vec<MessageCondition>),
731 AnyOf(Vec<MessageCondition>),
733 Regex(String),
735}
736
737impl MessageCondition {
738 pub fn matches(&self, event: &MessageEvent) -> bool {
740 match self {
741 MessageCondition::Any => true,
742 MessageCondition::Eq(text) => event.msg.content == *text,
743 MessageCondition::Contains(text) => event.msg.content.contains(text),
744 MessageCondition::NotContains(text) => !event.msg.content.contains(text),
745 MessageCondition::RoleIs(role) => event.role == *role,
746 MessageCondition::RoleNot(role) => event.role != *role,
747 MessageCondition::LenGt(len) => event.msg.content.len() > *len,
748 MessageCondition::Custom(func) => func(&event.msg),
749 MessageCondition::Empty => event.msg.content.is_empty(),
750 MessageCondition::All(inner) => inner.iter().all(|c| c.matches(event)),
751 MessageCondition::AnyOf(inner) => inner.iter().any(|c| c.matches(event)),
752 MessageCondition::Regex(regex) => Regex::new(regex)
753 .map(|re| re.is_match(&event.msg.content))
754 .unwrap_or(false),
755 }
756 }
757}
758
/// Identifies the kind of memory a [`MemoryProvider`] implements; returned by
/// [`MemoryProvider::memory_type`].
///
/// Derives serde's `Serialize`/`Deserialize`; as a unit variant,
/// `SlidingWindow` is written as the JSON string `"SlidingWindow"` by
/// `serde_json`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum MemoryType {
    /// Sliding-window memory (presumably the strategy implemented by
    /// `SlidingWindowMemory` in the `sliding_window` module — confirm there).
    SlidingWindow,
}
765
/// Asynchronous storage interface for chat messages.
///
/// Implementors must provide `remember`, `recall`, `clear`, `memory_type` and
/// `size`. The remaining methods have default implementations that describe a
/// provider with no summarisation support and no event stream.
#[async_trait]
pub trait MemoryProvider: Send + Sync {
    /// Stores `message` in the memory.
    ///
    /// # Errors
    /// Returns an [`LLMError`] when the implementation fails to store it.
    async fn remember(&mut self, message: &ChatMessage) -> Result<(), LLMError>;

    /// Retrieves stored messages.
    ///
    /// `query` and `limit` are interpreted by the implementation; `limit` is
    /// presumably an upper bound on the number of messages returned, with
    /// `None` meaning "no explicit bound" — confirm against each implementor.
    ///
    /// # Errors
    /// Returns an [`LLMError`] when the implementation fails to read.
    async fn recall(&self, query: &str, limit: Option<usize>)
        -> Result<Vec<ChatMessage>, LLMError>;

    /// Removes the stored messages.
    ///
    /// # Errors
    /// Returns an [`LLMError`] when the implementation fails to clear.
    async fn clear(&mut self) -> Result<(), LLMError>;

    /// Reports which [`MemoryType`] this provider implements.
    fn memory_type(&self) -> MemoryType;

    /// Number of items currently held, as counted by the implementation.
    fn size(&self) -> usize;

    /// Returns `true` when [`size`](Self::size) is zero.
    fn is_empty(&self) -> bool {
        self.size() == 0
    }

    /// Whether the memory wants to be summarised. Defaults to `false`.
    fn needs_summary(&self) -> bool {
        false
    }

    /// Flags the memory as needing a summary. The default does nothing.
    fn mark_for_summary(&mut self) {}

    /// Replaces the stored contents with `_summary`. The default does nothing
    /// and discards the argument.
    fn replace_with_summary(&mut self, _summary: String) {}

    /// Returns a broadcast receiver for [`MessageEvent`]s, if the provider
    /// publishes any. Defaults to `None`.
    fn get_event_receiver(&self) -> Option<broadcast::Receiver<MessageEvent>> {
        None
    }

    /// Stores `message` together with an explicit role label.
    ///
    /// The default implementation ignores `_role` and simply forwards to
    /// [`remember`](Self::remember), propagating its result.
    async fn remember_with_role(
        &mut self,
        message: &ChatMessage,
        _role: String,
    ) -> Result<(), LLMError> {
        self.remember(message).await
    }
}