1use async_trait::async_trait;
2use autoagents_llm::{chat::ChatMessage, error::LLMError};
3use regex::Regex;
4use serde::{Deserialize, Serialize};
5use std::sync::Arc;
6#[cfg(not(target_arch = "wasm32"))]
7use tokio::sync::broadcast;
8
9mod sliding_window;
10pub use sliding_window::SlidingWindowMemory;
11
#[cfg(test)]
mod tests {
    use super::*;
    use autoagents_llm::chat::{ChatMessage, ChatRole, MessageType};
    use autoagents_llm::error::LLMError;
    use std::sync::Arc;

    /// Builds a plain-text [`ChatMessage`] with the given role and content.
    fn text_message(role: ChatRole, content: &str) -> ChatMessage {
        ChatMessage {
            role,
            message_type: MessageType::Text,
            content: content.to_string(),
        }
    }

    /// Shorthand for a text message carrying [`ChatRole::User`].
    fn user_message(content: &str) -> ChatMessage {
        text_message(ChatRole::User, content)
    }

    /// Wraps `msg` in a [`MessageEvent`] labelled with `role`.
    fn make_event(role: &str, msg: ChatMessage) -> MessageEvent {
        MessageEvent {
            role: role.to_string(),
            msg,
        }
    }

    /// Event labelled `"user"` that carries a user text message.
    fn user_event(content: &str) -> MessageEvent {
        make_event("user", user_message(content))
    }

    /// Two-message fixture shared by the recall tests.
    fn first_and_second() -> Vec<ChatMessage> {
        vec![
            text_message(ChatRole::User, "first message"),
            text_message(ChatRole::Assistant, "second message"),
        ]
    }

    /// In-memory [`MemoryProvider`] whose fallible operations can be forced to fail.
    #[derive(Clone)]
    struct MockMemoryProvider {
        messages: Vec<ChatMessage>,
        should_fail: bool,
    }

    impl MockMemoryProvider {
        /// Empty provider whose operations succeed.
        fn new() -> Self {
            Self {
                messages: Vec::new(),
                should_fail: false,
            }
        }

        /// Empty provider whose `remember`, `recall` and `clear` all return errors.
        fn with_failure() -> Self {
            Self {
                messages: Vec::new(),
                should_fail: true,
            }
        }

        /// Provider pre-populated with `messages`; operations succeed.
        fn with_messages(messages: Vec<ChatMessage>) -> Self {
            Self {
                messages,
                should_fail: false,
            }
        }
    }

    #[async_trait::async_trait]
    impl MemoryProvider for MockMemoryProvider {
        async fn remember(&mut self, message: &ChatMessage) -> Result<(), LLMError> {
            if self.should_fail {
                return Err(LLMError::ProviderError("Mock memory error".to_string()));
            }
            self.messages.push(message.clone());
            Ok(())
        }

        async fn recall(
            &self,
            _query: &str,
            limit: Option<usize>,
        ) -> Result<Vec<ChatMessage>, LLMError> {
            if self.should_fail {
                return Err(LLMError::ProviderError("Mock recall error".to_string()));
            }
            let limit = limit.unwrap_or(self.messages.len());
            Ok(self.messages.iter().take(limit).cloned().collect())
        }

        async fn clear(&mut self) -> Result<(), LLMError> {
            if self.should_fail {
                return Err(LLMError::ProviderError("Mock clear error".to_string()));
            }
            self.messages.clear();
            Ok(())
        }

        fn memory_type(&self) -> MemoryType {
            MemoryType::SlidingWindow
        }

        fn size(&self) -> usize {
            self.messages.len()
        }

        // Return a real copy instead of panicking so the mock stays usable
        // anywhere a boxed provider has to be duplicated.
        fn clone_box(&self) -> Box<dyn MemoryProvider> {
            Box::new(self.clone())
        }
    }

    #[test]
    fn test_memory_type_serialization() {
        let sliding_window = MemoryType::SlidingWindow;
        let serialized = serde_json::to_string(&sliding_window).unwrap();
        assert_eq!(serialized, "\"SlidingWindow\"");
    }

    #[test]
    fn test_memory_type_deserialization() {
        let deserialized: MemoryType = serde_json::from_str("\"SlidingWindow\"").unwrap();
        assert_eq!(deserialized, MemoryType::SlidingWindow);
    }

    #[test]
    fn test_memory_type_debug() {
        let sliding_window = MemoryType::SlidingWindow;
        let debug_str = format!("{sliding_window:?}");
        assert!(debug_str.contains("SlidingWindow"));
    }

    #[test]
    fn test_memory_type_clone() {
        let sliding_window = MemoryType::SlidingWindow;
        let cloned = sliding_window.clone();
        assert_eq!(sliding_window, cloned);
    }

    #[test]
    fn test_memory_type_equality() {
        let type1 = MemoryType::SlidingWindow;
        let type2 = MemoryType::SlidingWindow;
        assert_eq!(type1, type2);
    }

    #[test]
    fn test_message_condition_any() {
        let condition = MessageCondition::Any;
        assert!(condition.matches(&user_event("test message")));
    }

    #[test]
    fn test_message_condition_eq() {
        let condition = MessageCondition::Eq("test message".to_string());
        assert!(condition.matches(&user_event("test message")));
        assert!(!condition.matches(&user_event("different message")));
    }

    #[test]
    fn test_message_condition_contains() {
        let condition = MessageCondition::Contains("test".to_string());
        assert!(condition.matches(&user_event("this is a test message")));
        assert!(!condition.matches(&user_event("this is different")));
    }

    #[test]
    fn test_message_condition_not_contains() {
        let condition = MessageCondition::NotContains("error".to_string());
        assert!(condition.matches(&user_event("this is a test message")));
        assert!(!condition.matches(&user_event("this is an error message")));
    }

    #[test]
    fn test_message_condition_role_is() {
        let condition = MessageCondition::RoleIs("user".to_string());
        assert!(condition.matches(&user_event("test message")));

        let assistant_event = make_event(
            "assistant",
            text_message(ChatRole::Assistant, "test message"),
        );
        assert!(!condition.matches(&assistant_event));
    }

    #[test]
    fn test_message_condition_role_not() {
        let condition = MessageCondition::RoleNot("system".to_string());
        assert!(condition.matches(&user_event("test message")));

        let system_event = make_event("system", text_message(ChatRole::System, "test message"));
        assert!(!condition.matches(&system_event));
    }

    #[test]
    fn test_message_condition_len_gt() {
        let condition = MessageCondition::LenGt(5);
        assert!(condition.matches(&user_event("this is a long message")));
        assert!(!condition.matches(&user_event("hi")));
    }

    #[test]
    fn test_message_condition_custom() {
        let condition = MessageCondition::Custom(Arc::new(|msg| msg.content.starts_with("hello")));
        assert!(condition.matches(&user_event("hello world")));
        assert!(!condition.matches(&user_event("goodbye world")));
    }

    #[test]
    fn test_message_condition_empty() {
        let condition = MessageCondition::Empty;
        assert!(condition.matches(&user_event("")));
        assert!(!condition.matches(&user_event("not empty")));
    }

    #[test]
    fn test_message_condition_all() {
        let condition = MessageCondition::All(vec![
            MessageCondition::Contains("test".to_string()),
            MessageCondition::LenGt(5),
            MessageCondition::RoleIs("user".to_string()),
        ]);

        assert!(condition.matches(&user_event("this is a test message")));
        assert!(!condition.matches(&user_event("hi")));
    }

    #[test]
    fn test_message_condition_any_of() {
        let condition = MessageCondition::AnyOf(vec![
            MessageCondition::Contains("hello".to_string()),
            MessageCondition::Contains("goodbye".to_string()),
            MessageCondition::Empty,
        ]);

        assert!(condition.matches(&user_event("hello world")));
        assert!(condition.matches(&user_event("goodbye world")));
        assert!(condition.matches(&user_event("")));
        assert!(!condition.matches(&user_event("test message")));
    }

    #[test]
    fn test_message_condition_regex() {
        let condition = MessageCondition::Regex(r"\d+".to_string());
        assert!(condition.matches(&user_event("there are 123 items")));
        assert!(!condition.matches(&user_event("no numbers here")));
    }

    #[test]
    fn test_message_condition_regex_invalid() {
        // A pattern that fails to compile must never match.
        let condition = MessageCondition::Regex("[invalid regex".to_string());
        assert!(!condition.matches(&user_event("test message")));
    }

    #[test]
    fn test_message_event_creation() {
        let event = MessageEvent {
            role: "user".to_string(),
            msg: user_message("test message"),
        };
        assert_eq!(event.role, "user");
        assert_eq!(event.msg.content, "test message");
    }

    #[test]
    fn test_message_event_clone() {
        let event = user_event("test message");
        let cloned = event.clone();
        assert_eq!(event.role, cloned.role);
        assert_eq!(event.msg.content, cloned.msg.content);
    }

    #[test]
    fn test_message_event_debug() {
        let event = user_event("test message");
        let debug_str = format!("{event:?}");
        assert!(debug_str.contains("MessageEvent"));
        assert!(debug_str.contains("user"));
    }

    #[tokio::test]
    async fn test_mock_memory_provider_remember() {
        let mut provider = MockMemoryProvider::new();
        let message = user_message("test message");

        let result = provider.remember(&message).await;
        assert!(result.is_ok());
        assert_eq!(provider.size(), 1);
    }

    #[tokio::test]
    async fn test_mock_memory_provider_remember_failure() {
        let mut provider = MockMemoryProvider::with_failure();
        let message = user_message("test message");

        let result = provider.remember(&message).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Mock memory error")
        );
    }

    #[tokio::test]
    async fn test_mock_memory_provider_recall() {
        let provider = MockMemoryProvider::with_messages(first_and_second());

        let result = provider.recall("", None).await;
        assert!(result.is_ok());
        let recalled = result.unwrap();
        assert_eq!(recalled.len(), 2);
        assert_eq!(recalled[0].content, "first message");
        assert_eq!(recalled[1].content, "second message");
    }

    #[tokio::test]
    async fn test_mock_memory_provider_recall_with_limit() {
        let provider = MockMemoryProvider::with_messages(first_and_second());

        let result = provider.recall("", Some(1)).await;
        assert!(result.is_ok());
        let recalled = result.unwrap();
        assert_eq!(recalled.len(), 1);
        assert_eq!(recalled[0].content, "first message");
    }

    #[tokio::test]
    async fn test_mock_memory_provider_recall_failure() {
        let provider = MockMemoryProvider::with_failure();

        let result = provider.recall("", None).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Mock recall error")
        );
    }

    #[tokio::test]
    async fn test_mock_memory_provider_clear() {
        let mut provider = MockMemoryProvider::with_messages(vec![user_message("first message")]);
        assert_eq!(provider.size(), 1);

        let result = provider.clear().await;
        assert!(result.is_ok());
        assert_eq!(provider.size(), 0);
    }

    #[tokio::test]
    async fn test_mock_memory_provider_clear_failure() {
        let mut provider = MockMemoryProvider::with_failure();

        let result = provider.clear().await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Mock clear error"));
    }

    #[test]
    fn test_mock_memory_provider_memory_type() {
        let provider = MockMemoryProvider::new();
        assert_eq!(provider.memory_type(), MemoryType::SlidingWindow);
    }

    #[test]
    fn test_mock_memory_provider_size() {
        let provider = MockMemoryProvider::new();
        assert_eq!(provider.size(), 0);

        let provider_with_messages =
            MockMemoryProvider::with_messages(vec![user_message("message")]);
        assert_eq!(provider_with_messages.size(), 1);
    }

    #[test]
    fn test_mock_memory_provider_is_empty() {
        let provider = MockMemoryProvider::new();
        assert!(provider.is_empty());

        let provider_with_messages =
            MockMemoryProvider::with_messages(vec![user_message("message")]);
        assert!(!provider_with_messages.is_empty());
    }

    #[test]
    fn test_mock_memory_provider_clone_box() {
        let provider = MockMemoryProvider::with_messages(vec![user_message("message")]);
        let cloned = provider.clone_box();
        assert_eq!(cloned.size(), 1);
        assert_eq!(cloned.memory_type(), MemoryType::SlidingWindow);
    }

    #[test]
    fn test_memory_provider_default_methods() {
        let provider = MockMemoryProvider::new();
        assert!(!provider.needs_summary());
        assert!(provider.get_event_receiver().is_none());
    }

    #[test]
    fn test_memory_provider_default_persistence_methods() {
        // The mock overrides none of these, so the trait defaults are exercised.
        let mut provider = MockMemoryProvider::new();
        assert!(provider.id().is_none());
        assert!(!provider.preload(vec![user_message("ignored")]));
        assert!(provider.export().is_empty());
        assert_eq!(provider.size(), 0);
    }

    #[tokio::test]
    async fn test_memory_provider_remember_with_role() {
        let mut provider = MockMemoryProvider::new();
        let message = user_message("test message");

        let result = provider
            .remember_with_role(&message, "custom_role".to_string())
            .await;
        assert!(result.is_ok());
        assert_eq!(provider.size(), 1);
    }

    #[test]
    fn test_memory_provider_mark_for_summary() {
        let mut provider = MockMemoryProvider::new();
        // The default implementation is a no-op, so the flag stays unset.
        provider.mark_for_summary();
        assert!(!provider.needs_summary());
    }

    #[test]
    fn test_memory_provider_replace_with_summary() {
        let mut provider = MockMemoryProvider::new();
        // The default implementation is a no-op, so nothing gets stored.
        provider.replace_with_summary("Summary text".to_string());
        assert_eq!(provider.size(), 0);
    }
}
707
/// A chat message paired with a role label.
///
/// This is the value [`MessageCondition::matches`] inspects, and the item type
/// of the broadcast receiver returned by `MemoryProvider::get_event_receiver`
/// on non-wasm targets.
#[derive(Debug, Clone)]
pub struct MessageEvent {
    /// Role label compared by the `RoleIs` / `RoleNot` conditions. It is a
    /// free-form string stored separately from `msg.role`.
    pub role: String,
    /// The chat message itself; content-based conditions read `msg.content`.
    pub msg: ChatMessage,
}
716
717#[derive(Clone)]
719pub enum MessageCondition {
720 Any,
722 Eq(String),
724 Contains(String),
726 NotContains(String),
728 RoleIs(String),
730 RoleNot(String),
732 LenGt(usize),
734 Custom(Arc<dyn Fn(&ChatMessage) -> bool + Send + Sync>),
736 Empty,
738 All(Vec<MessageCondition>),
740 AnyOf(Vec<MessageCondition>),
742 Regex(String),
744}
745
746impl MessageCondition {
747 pub fn matches(&self, event: &MessageEvent) -> bool {
749 match self {
750 MessageCondition::Any => true,
751 MessageCondition::Eq(text) => event.msg.content == *text,
752 MessageCondition::Contains(text) => event.msg.content.contains(text),
753 MessageCondition::NotContains(text) => !event.msg.content.contains(text),
754 MessageCondition::RoleIs(role) => event.role == *role,
755 MessageCondition::RoleNot(role) => event.role != *role,
756 MessageCondition::LenGt(len) => event.msg.content.len() > *len,
757 MessageCondition::Custom(func) => func(&event.msg),
758 MessageCondition::Empty => event.msg.content.is_empty(),
759 MessageCondition::All(inner) => inner.iter().all(|c| c.matches(event)),
760 MessageCondition::AnyOf(inner) => inner.iter().any(|c| c.matches(event)),
761 MessageCondition::Regex(regex) => Regex::new(regex)
762 .map(|re| re.is_match(&event.msg.content))
763 .unwrap_or(false),
764 }
765 }
766}
767
/// Identifies the kind of memory backend, as reported by
/// `MemoryProvider::memory_type`.
///
/// With the serde derives, a variant serializes to its name as a JSON string
/// (e.g. `"SlidingWindow"`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum MemoryType {
    /// Sliding-window memory (see the `sliding_window` module declared in this file).
    SlidingWindow,
}
774
/// Interface for conversation-memory backends.
///
/// Implementors must provide `remember`, `recall`, `clear`, `memory_type`,
/// `size` and `clone_box`; every other method has a default implementation
/// that can be overridden.
#[async_trait]
pub trait MemoryProvider: Send + Sync {
    /// Records `message` in memory.
    ///
    /// # Errors
    /// Returns an [`LLMError`] if the implementation fails to store the message.
    async fn remember(&mut self, message: &ChatMessage) -> Result<(), LLMError>;

    /// Retrieves stored messages.
    ///
    /// `query` is passed through to the implementation, which decides how (or
    /// whether) to use it. `limit` optionally caps the number of messages
    /// returned.
    ///
    /// # Errors
    /// Returns an [`LLMError`] if retrieval fails.
    async fn recall(&self, query: &str, limit: Option<usize>)
    -> Result<Vec<ChatMessage>, LLMError>;

    /// Removes all stored messages.
    ///
    /// # Errors
    /// Returns an [`LLMError`] if the implementation fails to clear its storage.
    async fn clear(&mut self) -> Result<(), LLMError>;

    /// Reports which kind of memory backend this is.
    fn memory_type(&self) -> MemoryType;

    /// Number of messages currently held.
    fn size(&self) -> usize;

    /// Returns `true` when `size()` is zero.
    fn is_empty(&self) -> bool {
        self.size() == 0
    }

    /// Whether the memory wants to be summarized. The default is always `false`.
    fn needs_summary(&self) -> bool {
        false
    }

    /// Flags the memory for summarization. The default does nothing.
    fn mark_for_summary(&mut self) {}

    /// Replaces the stored content with `_summary`. The default does nothing
    /// and discards the argument.
    fn replace_with_summary(&mut self, _summary: String) {}

    /// Returns a receiver of [`MessageEvent`]s, if this provider emits them.
    /// The default returns `None`.
    #[cfg(not(target_arch = "wasm32"))]
    fn get_event_receiver(&self) -> Option<broadcast::Receiver<MessageEvent>> {
        None
    }

    /// wasm32 counterpart of `get_event_receiver`: the tokio broadcast channel
    /// is not imported on that target, so the return type is a unit placeholder
    /// and the default is `None`.
    #[cfg(target_arch = "wasm32")]
    fn get_event_receiver(&self) -> Option<()> {
        None
    }

    /// Records `message` together with a role label.
    ///
    /// The default ignores `_role` and simply forwards to `remember`.
    ///
    /// # Errors
    /// Propagates whatever error `remember` returns.
    async fn remember_with_role(
        &mut self,
        message: &ChatMessage,
        _role: String,
    ) -> Result<(), LLMError> {
        self.remember(message).await
    }

    /// Produces a boxed copy of this provider, so that trait objects can be
    /// duplicated.
    fn clone_box(&self) -> Box<dyn MemoryProvider>;

    /// Optional identifier for this memory instance. The default is `None`.
    fn id(&self) -> Option<String> {
        None
    }

    /// Loads `_data` into memory, returning whether anything was loaded.
    /// The default discards the data and returns `false`.
    fn preload(&mut self, _data: Vec<ChatMessage>) -> bool {
        false
    }

    /// Exports the stored messages. The default returns an empty vector.
    fn export(&self) -> Vec<ChatMessage> {
        Vec::new()
    }
}