1use async_trait::async_trait;
2use autoagents_llm::{chat::ChatMessage, error::LLMError};
3use regex::Regex;
4use serde::{Deserialize, Serialize};
5use std::sync::Arc;
6#[cfg(not(target_arch = "wasm32"))]
7use tokio::sync::broadcast;
8
9mod sliding_window;
10pub use sliding_window::SlidingWindowMemory;
11
#[cfg(test)]
mod tests {
    use super::*;
    use autoagents_llm::chat::{ChatMessage, ChatRole, MessageType};
    use autoagents_llm::error::LLMError;
    use std::sync::Arc;

    // ---------------------------------------------------------------------
    // Fixture helpers: keep each test focused on the behavior under test
    // instead of repeating struct literals.
    // ---------------------------------------------------------------------

    /// Builds a plain-text [`ChatMessage`] with the given role and content.
    fn text_message(role: ChatRole, content: &str) -> ChatMessage {
        ChatMessage {
            role,
            message_type: MessageType::Text,
            content: content.to_string(),
        }
    }

    /// Builds a [`MessageEvent`] tagged with `role` that wraps a text message.
    fn text_event(role: &str, chat_role: ChatRole, content: &str) -> MessageEvent {
        MessageEvent {
            role: role.to_string(),
            msg: text_message(chat_role, content),
        }
    }

    /// Builds a `"user"` event wrapping a user-authored text message.
    fn user_event(content: &str) -> MessageEvent {
        text_event("user", ChatRole::User, content)
    }

    /// Two-message history (user, then assistant) shared by the recall tests.
    fn two_message_history() -> Vec<ChatMessage> {
        vec![
            text_message(ChatRole::User, "first message"),
            text_message(ChatRole::Assistant, "second message"),
        ]
    }

    /// In-memory [`MemoryProvider`] used to exercise the trait, including its
    /// default methods. When `should_fail` is set, every fallible operation
    /// returns an error.
    #[derive(Clone)]
    struct MockMemoryProvider {
        messages: Vec<ChatMessage>,
        should_fail: bool,
    }

    impl MockMemoryProvider {
        fn new() -> Self {
            Self {
                messages: Vec::new(),
                should_fail: false,
            }
        }

        fn with_failure() -> Self {
            Self {
                messages: Vec::new(),
                should_fail: true,
            }
        }

        fn with_messages(messages: Vec<ChatMessage>) -> Self {
            Self {
                messages,
                should_fail: false,
            }
        }
    }

    #[async_trait::async_trait]
    impl MemoryProvider for MockMemoryProvider {
        async fn remember(&mut self, message: &ChatMessage) -> Result<(), LLMError> {
            if self.should_fail {
                return Err(LLMError::ProviderError("Mock memory error".to_string()));
            }
            self.messages.push(message.clone());
            Ok(())
        }

        async fn recall(
            &self,
            _query: &str,
            limit: Option<usize>,
        ) -> Result<Vec<ChatMessage>, LLMError> {
            if self.should_fail {
                return Err(LLMError::ProviderError("Mock recall error".to_string()));
            }
            let limit = limit.unwrap_or(self.messages.len());
            Ok(self.messages.iter().take(limit).cloned().collect())
        }

        async fn clear(&mut self) -> Result<(), LLMError> {
            if self.should_fail {
                return Err(LLMError::ProviderError("Mock clear error".to_string()));
            }
            self.messages.clear();
            Ok(())
        }

        fn memory_type(&self) -> MemoryType {
            MemoryType::SlidingWindow
        }

        fn size(&self) -> usize {
            self.messages.len()
        }

        fn clone_box(&self) -> Box<dyn MemoryProvider> {
            Box::new(self.clone())
        }
    }

    // ---------------------------------------------------------------------
    // MemoryType
    // ---------------------------------------------------------------------

    #[test]
    fn test_memory_type_serialization() {
        let sliding_window = MemoryType::SlidingWindow;
        let serialized = serde_json::to_string(&sliding_window).unwrap();
        assert_eq!(serialized, "\"SlidingWindow\"");
    }

    #[test]
    fn test_memory_type_deserialization() {
        let deserialized: MemoryType = serde_json::from_str("\"SlidingWindow\"").unwrap();
        assert_eq!(deserialized, MemoryType::SlidingWindow);
    }

    #[test]
    fn test_custom_memory_type_serialization() {
        let custom = MemoryType::Custom;
        let serialized = serde_json::to_string(&custom).unwrap();
        assert_eq!(serialized, "\"Custom\"");

        let deserialized: MemoryType = serde_json::from_str("\"Custom\"").unwrap();
        assert_eq!(deserialized, MemoryType::Custom);
    }

    #[test]
    fn test_memory_type_debug() {
        let sliding_window = MemoryType::SlidingWindow;
        let debug_str = format!("{sliding_window:?}");
        assert!(debug_str.contains("SlidingWindow"));
    }

    #[test]
    fn test_memory_type_clone() {
        let sliding_window = MemoryType::SlidingWindow;
        let cloned = sliding_window.clone();
        assert_eq!(sliding_window, cloned);
    }

    #[test]
    fn test_memory_type_equality() {
        let type1 = MemoryType::SlidingWindow;
        let type2 = MemoryType::SlidingWindow;
        assert_eq!(type1, type2);
    }

    // ---------------------------------------------------------------------
    // MessageCondition
    // ---------------------------------------------------------------------

    #[test]
    fn test_message_condition_any() {
        let condition = MessageCondition::Any;
        assert!(condition.matches(&user_event("test message")));
    }

    #[test]
    fn test_message_condition_eq() {
        let condition = MessageCondition::Eq("test message".to_string());
        assert!(condition.matches(&user_event("test message")));
        assert!(!condition.matches(&user_event("different message")));
    }

    #[test]
    fn test_message_condition_contains() {
        let condition = MessageCondition::Contains("test".to_string());
        assert!(condition.matches(&user_event("this is a test message")));
        assert!(!condition.matches(&user_event("this is different")));
    }

    #[test]
    fn test_message_condition_not_contains() {
        let condition = MessageCondition::NotContains("error".to_string());
        assert!(condition.matches(&user_event("this is a test message")));
        assert!(!condition.matches(&user_event("this is an error message")));
    }

    #[test]
    fn test_message_condition_role_is() {
        let condition = MessageCondition::RoleIs("user".to_string());
        assert!(condition.matches(&user_event("test message")));

        let assistant_event = text_event("assistant", ChatRole::Assistant, "test message");
        assert!(!condition.matches(&assistant_event));
    }

    #[test]
    fn test_message_condition_role_not() {
        let condition = MessageCondition::RoleNot("system".to_string());
        assert!(condition.matches(&user_event("test message")));

        let system_event = text_event("system", ChatRole::System, "test message");
        assert!(!condition.matches(&system_event));
    }

    #[test]
    fn test_message_condition_len_gt() {
        let condition = MessageCondition::LenGt(5);
        assert!(condition.matches(&user_event("this is a long message")));
        assert!(!condition.matches(&user_event("hi")));
    }

    #[test]
    fn test_message_condition_custom() {
        let condition = MessageCondition::Custom(Arc::new(|msg| msg.content.starts_with("hello")));
        assert!(condition.matches(&user_event("hello world")));
        assert!(!condition.matches(&user_event("goodbye world")));
    }

    #[test]
    fn test_message_condition_empty() {
        let condition = MessageCondition::Empty;
        assert!(condition.matches(&user_event("")));
        assert!(!condition.matches(&user_event("not empty")));
    }

    #[test]
    fn test_message_condition_all() {
        let condition = MessageCondition::All(vec![
            MessageCondition::Contains("test".to_string()),
            MessageCondition::LenGt(5),
            MessageCondition::RoleIs("user".to_string()),
        ]);

        assert!(condition.matches(&user_event("this is a test message")));
        assert!(!condition.matches(&user_event("hi")));
    }

    #[test]
    fn test_message_condition_any_of() {
        let condition = MessageCondition::AnyOf(vec![
            MessageCondition::Contains("hello".to_string()),
            MessageCondition::Contains("goodbye".to_string()),
            MessageCondition::Empty,
        ]);

        assert!(condition.matches(&user_event("hello world")));
        assert!(condition.matches(&user_event("goodbye world")));
        assert!(condition.matches(&user_event("")));
        assert!(!condition.matches(&user_event("test message")));
    }

    #[test]
    fn test_message_condition_regex() {
        let condition = MessageCondition::Regex(r"\d+".to_string());
        assert!(condition.matches(&user_event("there are 123 items")));
        assert!(!condition.matches(&user_event("no numbers here")));
    }

    #[test]
    fn test_message_condition_regex_invalid() {
        // A pattern that fails to compile must never match.
        let condition = MessageCondition::Regex("[invalid regex".to_string());
        assert!(!condition.matches(&user_event("test message")));
    }

    // ---------------------------------------------------------------------
    // MessageEvent
    // ---------------------------------------------------------------------

    #[test]
    fn test_message_event_creation() {
        let event = user_event("test message");
        assert_eq!(event.role, "user");
        assert_eq!(event.msg.content, "test message");
    }

    #[test]
    fn test_message_event_clone() {
        let event = user_event("test message");
        let cloned = event.clone();
        assert_eq!(event.role, cloned.role);
        assert_eq!(event.msg.content, cloned.msg.content);
    }

    #[test]
    fn test_message_event_debug() {
        let event = user_event("test message");
        let debug_str = format!("{event:?}");
        assert!(debug_str.contains("MessageEvent"));
        assert!(debug_str.contains("user"));
    }

    // ---------------------------------------------------------------------
    // MemoryProvider (via MockMemoryProvider)
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_mock_memory_provider_remember() {
        let mut provider = MockMemoryProvider::new();
        let message = text_message(ChatRole::User, "test message");

        let result = provider.remember(&message).await;
        assert!(result.is_ok());
        assert_eq!(provider.size(), 1);
    }

    #[tokio::test]
    async fn test_mock_memory_provider_remember_failure() {
        let mut provider = MockMemoryProvider::with_failure();
        let message = text_message(ChatRole::User, "test message");

        let result = provider.remember(&message).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Mock memory error")
        );
    }

    #[tokio::test]
    async fn test_mock_memory_provider_recall() {
        let provider = MockMemoryProvider::with_messages(two_message_history());

        let result = provider.recall("", None).await;
        assert!(result.is_ok());
        let recalled = result.unwrap();
        assert_eq!(recalled.len(), 2);
        assert_eq!(recalled[0].content, "first message");
        assert_eq!(recalled[1].content, "second message");
    }

    #[tokio::test]
    async fn test_mock_memory_provider_recall_with_limit() {
        let provider = MockMemoryProvider::with_messages(two_message_history());

        let result = provider.recall("", Some(1)).await;
        assert!(result.is_ok());
        let recalled = result.unwrap();
        assert_eq!(recalled.len(), 1);
        assert_eq!(recalled[0].content, "first message");
    }

    #[tokio::test]
    async fn test_mock_memory_provider_recall_failure() {
        let provider = MockMemoryProvider::with_failure();

        let result = provider.recall("", None).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Mock recall error")
        );
    }

    #[tokio::test]
    async fn test_mock_memory_provider_clear() {
        let messages = vec![text_message(ChatRole::User, "first message")];
        let mut provider = MockMemoryProvider::with_messages(messages);
        assert_eq!(provider.size(), 1);

        let result = provider.clear().await;
        assert!(result.is_ok());
        assert_eq!(provider.size(), 0);
    }

    #[tokio::test]
    async fn test_mock_memory_provider_clear_failure() {
        let mut provider = MockMemoryProvider::with_failure();

        let result = provider.clear().await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Mock clear error"));
    }

    #[test]
    fn test_mock_memory_provider_memory_type() {
        let provider = MockMemoryProvider::new();
        assert_eq!(provider.memory_type(), MemoryType::SlidingWindow);
    }

    #[test]
    fn test_mock_memory_provider_size() {
        let provider = MockMemoryProvider::new();
        assert_eq!(provider.size(), 0);

        let messages = vec![text_message(ChatRole::User, "message")];
        let provider_with_messages = MockMemoryProvider::with_messages(messages);
        assert_eq!(provider_with_messages.size(), 1);
    }

    #[test]
    fn test_mock_memory_provider_is_empty() {
        let provider = MockMemoryProvider::new();
        assert!(provider.is_empty());

        let messages = vec![text_message(ChatRole::User, "message")];
        let provider_with_messages = MockMemoryProvider::with_messages(messages);
        assert!(!provider_with_messages.is_empty());
    }

    #[test]
    fn test_memory_provider_default_methods() {
        let provider = MockMemoryProvider::new();
        assert!(!provider.needs_summary());
        assert!(provider.get_event_receiver().is_none());
    }

    #[test]
    fn test_memory_provider_default_id_preload_export() {
        // The mock does not override these, so the trait defaults apply:
        // no id, preload is rejected without storing anything, export is empty.
        let mut provider = MockMemoryProvider::new();
        assert!(provider.id().is_none());
        assert!(!provider.preload(vec![text_message(ChatRole::User, "seed")]));
        assert_eq!(provider.size(), 0);
        assert!(provider.export().is_empty());
    }

    #[tokio::test]
    async fn test_memory_provider_remember_with_role() {
        let mut provider = MockMemoryProvider::new();
        let message = text_message(ChatRole::User, "test message");

        let result = provider
            .remember_with_role(&message, "custom_role".to_string())
            .await;
        assert!(result.is_ok());
        assert_eq!(provider.size(), 1);
    }

    #[test]
    fn test_memory_provider_mark_for_summary() {
        let mut provider = MockMemoryProvider::new();
        provider.mark_for_summary();
        // The default implementation is a no-op, so the flag stays false.
        assert!(!provider.needs_summary());
    }

    #[test]
    fn test_memory_provider_replace_with_summary() {
        let mut provider = MockMemoryProvider::new();
        provider.replace_with_summary("Summary text".to_string());
        // The default implementation is a no-op, so nothing is stored.
        assert_eq!(provider.size(), 0);
    }
}
718
/// A chat message paired with a free-form role label.
///
/// [`MessageCondition::matches`] evaluates conditions against values of this
/// type, and on non-wasm targets it is the item type of the receiver returned
/// by [`MemoryProvider::get_event_receiver`].
#[derive(Debug, Clone)]
pub struct MessageEvent {
    /// Role label as a plain string; this is what `RoleIs` / `RoleNot` compare against.
    pub role: String,
    /// The chat message carried by this event.
    pub msg: ChatMessage,
}
727
728#[derive(Clone)]
730pub enum MessageCondition {
731 Any,
733 Eq(String),
735 Contains(String),
737 NotContains(String),
739 RoleIs(String),
741 RoleNot(String),
743 LenGt(usize),
745 Custom(Arc<dyn Fn(&ChatMessage) -> bool + Send + Sync>),
747 Empty,
749 All(Vec<MessageCondition>),
751 AnyOf(Vec<MessageCondition>),
753 Regex(String),
755}
756
757impl MessageCondition {
758 pub fn matches(&self, event: &MessageEvent) -> bool {
760 match self {
761 MessageCondition::Any => true,
762 MessageCondition::Eq(text) => event.msg.content == *text,
763 MessageCondition::Contains(text) => event.msg.content.contains(text),
764 MessageCondition::NotContains(text) => !event.msg.content.contains(text),
765 MessageCondition::RoleIs(role) => event.role == *role,
766 MessageCondition::RoleNot(role) => event.role != *role,
767 MessageCondition::LenGt(len) => event.msg.content.len() > *len,
768 MessageCondition::Custom(func) => func(&event.msg),
769 MessageCondition::Empty => event.msg.content.is_empty(),
770 MessageCondition::All(inner) => inner.iter().all(|c| c.matches(event)),
771 MessageCondition::AnyOf(inner) => inner.iter().any(|c| c.matches(event)),
772 MessageCondition::Regex(regex) => Regex::new(regex)
773 .map(|re| re.is_match(&event.msg.content))
774 .unwrap_or(false),
775 }
776 }
777}
778
/// Tag returned by [`MemoryProvider::memory_type`] to identify the kind of memory.
///
/// With the derived serde implementations each variant (de)serializes as its
/// bare name, e.g. `"SlidingWindow"`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum MemoryType {
    /// Sliding-window memory (presumably what `SlidingWindowMemory` reports — confirm in `sliding_window`).
    SlidingWindow,
    /// Any other memory kind (presumably for user-supplied implementations — TODO confirm).
    Custom,
}
787
/// Interface for a store of chat messages.
///
/// Implementors must provide `remember`, `recall`, `clear`, `memory_type`,
/// `size` and `clone_box`. Every other method has a default that is either a
/// no-op or returns an "unsupported / nothing" value, so implementations only
/// override what they actually support.
#[async_trait]
pub trait MemoryProvider: Send + Sync {
    /// Records `message` in this memory.
    ///
    /// # Errors
    /// Returns an [`LLMError`] if the implementation fails to store the message.
    async fn remember(&mut self, message: &ChatMessage) -> Result<(), LLMError>;

    /// Retrieves stored messages.
    ///
    /// How `query` is interpreted and what `limit` bounds is defined by the
    /// implementation (presumably `limit` caps the number of returned messages
    /// and `None` means no cap — confirm against the concrete provider).
    ///
    /// # Errors
    /// Returns an [`LLMError`] if retrieval fails.
    async fn recall(&self, query: &str, limit: Option<usize>)
        -> Result<Vec<ChatMessage>, LLMError>;

    /// Clears this memory.
    ///
    /// # Errors
    /// Returns an [`LLMError`] if the implementation fails to clear its contents.
    async fn clear(&mut self) -> Result<(), LLMError>;

    /// Returns the [`MemoryType`] tag of this provider.
    fn memory_type(&self) -> MemoryType;

    /// Returns the provider's current size (presumably the number of stored messages).
    fn size(&self) -> usize;

    /// Returns `true` when [`size`](MemoryProvider::size) is zero.
    fn is_empty(&self) -> bool {
        self.size() == 0
    }

    /// Reports whether this memory wants to be summarized. Defaults to `false`.
    fn needs_summary(&self) -> bool {
        false
    }

    /// Flags this memory for summarization. The default implementation does nothing.
    fn mark_for_summary(&mut self) {}

    /// Replaces the stored content with `_summary`. The default implementation does nothing.
    fn replace_with_summary(&mut self, _summary: String) {}

    /// Returns a broadcast receiver of [`MessageEvent`]s, if the provider offers one.
    /// Defaults to `None`.
    #[cfg(not(target_arch = "wasm32"))]
    fn get_event_receiver(&self) -> Option<broadcast::Receiver<MessageEvent>> {
        None
    }

    /// wasm32 variant of `get_event_receiver`: the tokio broadcast import is
    /// compiled out on this target, so the signature uses `()` and the default
    /// returns `None`.
    #[cfg(target_arch = "wasm32")]
    fn get_event_receiver(&self) -> Option<()> {
        None
    }

    /// Records `message` together with a role label.
    ///
    /// The default implementation ignores `_role` and forwards to
    /// [`remember`](MemoryProvider::remember).
    ///
    /// # Errors
    /// Propagates any error returned by `remember`.
    async fn remember_with_role(
        &mut self,
        message: &ChatMessage,
        _role: String,
    ) -> Result<(), LLMError> {
        self.remember(message).await
    }

    /// Returns a boxed clone of this provider, allowing `Box<dyn MemoryProvider>` to be duplicated.
    fn clone_box(&self) -> Box<dyn MemoryProvider>;

    /// Returns an optional identifier for this memory. Defaults to `None`.
    fn id(&self) -> Option<String> {
        None
    }

    /// Loads `_data` into this memory and reports whether it was accepted.
    ///
    /// The default implementation ignores the data and returns `false`.
    fn preload(&mut self, _data: Vec<ChatMessage>) -> bool {
        false
    }

    /// Returns the messages held by this memory. Defaults to an empty vector.
    fn export(&self) -> Vec<ChatMessage> {
        Vec::new()
    }
}