1use async_trait::async_trait;
2use autoagents_llm::{chat::ChatMessage, error::LLMError};
3use regex::Regex;
4use serde::{Deserialize, Serialize};
5use std::sync::Arc;
6#[cfg(not(target_arch = "wasm32"))]
7use tokio::sync::broadcast;
8
9mod sliding_window;
10pub use sliding_window::SlidingWindowMemory;
11
#[cfg(test)]
mod tests {
    use super::*;
    use autoagents_llm::chat::{ChatMessage, ChatRole, MessageType};
    use autoagents_llm::error::LLMError;
    use std::sync::Arc;

    /// Builds a plain-text `ChatMessage` with the given role and content.
    fn text_message(role: ChatRole, content: &str) -> ChatMessage {
        ChatMessage {
            role,
            message_type: MessageType::Text,
            content: content.to_string(),
        }
    }

    /// Builds a `MessageEvent` whose role label is `event_role` and whose
    /// message is a plain-text message authored by `chat_role`.
    fn text_event(event_role: &str, chat_role: ChatRole, content: &str) -> MessageEvent {
        MessageEvent {
            role: event_role.to_string(),
            msg: text_message(chat_role, content),
        }
    }

    /// Shorthand for the most common fixture: a user-authored text event.
    fn user_event(content: &str) -> MessageEvent {
        text_event("user", ChatRole::User, content)
    }

    /// The two-message history (user, then assistant) shared by the recall tests.
    fn two_message_history() -> Vec<ChatMessage> {
        vec![
            text_message(ChatRole::User, "first message"),
            text_message(ChatRole::Assistant, "second message"),
        ]
    }

    /// In-memory `MemoryProvider` test double.
    ///
    /// Stores messages in a `Vec`; when `should_fail` is set, every async
    /// operation returns a `ProviderError` instead of touching the store.
    struct MockMemoryProvider {
        messages: Vec<ChatMessage>,
        should_fail: bool,
    }

    impl MockMemoryProvider {
        /// Empty provider whose operations succeed.
        fn new() -> Self {
            Self {
                messages: Vec::new(),
                should_fail: false,
            }
        }

        /// Empty provider whose async operations always fail.
        fn with_failure() -> Self {
            Self {
                messages: Vec::new(),
                should_fail: true,
            }
        }

        /// Provider pre-populated with `messages` whose operations succeed.
        fn with_messages(messages: Vec<ChatMessage>) -> Self {
            Self {
                messages,
                should_fail: false,
            }
        }
    }

    #[async_trait::async_trait]
    impl MemoryProvider for MockMemoryProvider {
        async fn remember(&mut self, message: &ChatMessage) -> Result<(), LLMError> {
            if self.should_fail {
                return Err(LLMError::ProviderError("Mock memory error".to_string()));
            }
            self.messages.push(message.clone());
            Ok(())
        }

        async fn recall(
            &self,
            _query: &str,
            limit: Option<usize>,
        ) -> Result<Vec<ChatMessage>, LLMError> {
            if self.should_fail {
                return Err(LLMError::ProviderError("Mock recall error".to_string()));
            }
            // `None` means "everything"; messages come back oldest first.
            let limit = limit.unwrap_or(self.messages.len());
            Ok(self.messages.iter().take(limit).cloned().collect())
        }

        async fn clear(&mut self) -> Result<(), LLMError> {
            if self.should_fail {
                return Err(LLMError::ProviderError("Mock clear error".to_string()));
            }
            self.messages.clear();
            Ok(())
        }

        fn memory_type(&self) -> MemoryType {
            MemoryType::SlidingWindow
        }

        fn size(&self) -> usize {
            self.messages.len()
        }
    }

    // ---------------------------------------------------------------------
    // MemoryType
    // ---------------------------------------------------------------------

    #[test]
    fn test_memory_type_serialization() {
        let sliding_window = MemoryType::SlidingWindow;
        let serialized = serde_json::to_string(&sliding_window).unwrap();
        assert_eq!(serialized, "\"SlidingWindow\"");
    }

    #[test]
    fn test_memory_type_deserialization() {
        let deserialized: MemoryType = serde_json::from_str("\"SlidingWindow\"").unwrap();
        assert_eq!(deserialized, MemoryType::SlidingWindow);
    }

    #[test]
    fn test_memory_type_debug() {
        let sliding_window = MemoryType::SlidingWindow;
        let debug_str = format!("{sliding_window:?}");
        assert!(debug_str.contains("SlidingWindow"));
    }

    #[test]
    fn test_memory_type_clone() {
        let sliding_window = MemoryType::SlidingWindow;
        let cloned = sliding_window.clone();
        assert_eq!(sliding_window, cloned);
    }

    #[test]
    fn test_memory_type_equality() {
        let type1 = MemoryType::SlidingWindow;
        let type2 = MemoryType::SlidingWindow;
        assert_eq!(type1, type2);
    }

    // ---------------------------------------------------------------------
    // MessageCondition
    // ---------------------------------------------------------------------

    #[test]
    fn test_message_condition_any() {
        let condition = MessageCondition::Any;
        assert!(condition.matches(&user_event("test message")));
    }

    #[test]
    fn test_message_condition_eq() {
        let condition = MessageCondition::Eq("test message".to_string());
        assert!(condition.matches(&user_event("test message")));
        assert!(!condition.matches(&user_event("different message")));
    }

    #[test]
    fn test_message_condition_contains() {
        let condition = MessageCondition::Contains("test".to_string());
        assert!(condition.matches(&user_event("this is a test message")));
        assert!(!condition.matches(&user_event("this is different")));
    }

    #[test]
    fn test_message_condition_not_contains() {
        let condition = MessageCondition::NotContains("error".to_string());
        assert!(condition.matches(&user_event("this is a test message")));
        assert!(!condition.matches(&user_event("this is an error message")));
    }

    #[test]
    fn test_message_condition_role_is() {
        let condition = MessageCondition::RoleIs("user".to_string());
        assert!(condition.matches(&user_event("test message")));

        let assistant_event = text_event("assistant", ChatRole::Assistant, "test message");
        assert!(!condition.matches(&assistant_event));
    }

    #[test]
    fn test_message_condition_role_not() {
        let condition = MessageCondition::RoleNot("system".to_string());
        assert!(condition.matches(&user_event("test message")));

        let system_event = text_event("system", ChatRole::System, "test message");
        assert!(!condition.matches(&system_event));
    }

    #[test]
    fn test_message_condition_len_gt() {
        let condition = MessageCondition::LenGt(5);
        assert!(condition.matches(&user_event("this is a long message")));
        assert!(!condition.matches(&user_event("hi")));
    }

    #[test]
    fn test_message_condition_custom() {
        let condition = MessageCondition::Custom(Arc::new(|msg| msg.content.starts_with("hello")));
        assert!(condition.matches(&user_event("hello world")));
        assert!(!condition.matches(&user_event("goodbye world")));
    }

    #[test]
    fn test_message_condition_empty() {
        let condition = MessageCondition::Empty;
        assert!(condition.matches(&user_event("")));
        assert!(!condition.matches(&user_event("not empty")));
    }

    #[test]
    fn test_message_condition_all() {
        let condition = MessageCondition::All(vec![
            MessageCondition::Contains("test".to_string()),
            MessageCondition::LenGt(5),
            MessageCondition::RoleIs("user".to_string()),
        ]);

        assert!(condition.matches(&user_event("this is a test message")));
        assert!(!condition.matches(&user_event("hi")));
    }

    #[test]
    fn test_message_condition_any_of() {
        let condition = MessageCondition::AnyOf(vec![
            MessageCondition::Contains("hello".to_string()),
            MessageCondition::Contains("goodbye".to_string()),
            MessageCondition::Empty,
        ]);

        assert!(condition.matches(&user_event("hello world")));
        assert!(condition.matches(&user_event("goodbye world")));
        assert!(condition.matches(&user_event("")));
        assert!(!condition.matches(&user_event("test message")));
    }

    #[test]
    fn test_message_condition_empty_combinators() {
        // `All` over no conditions is vacuously true; `AnyOf` over none is false.
        let event = user_event("anything");
        assert!(MessageCondition::All(vec![]).matches(&event));
        assert!(!MessageCondition::AnyOf(vec![]).matches(&event));
    }

    #[test]
    fn test_message_condition_nested_combinators() {
        let condition = MessageCondition::All(vec![
            MessageCondition::RoleNot("system".to_string()),
            MessageCondition::AnyOf(vec![
                MessageCondition::Contains("hello".to_string()),
                MessageCondition::Regex(r"^\d+$".to_string()),
            ]),
        ]);

        assert!(condition.matches(&user_event("hello there")));
        assert!(condition.matches(&user_event("42")));
        assert!(!condition.matches(&user_event("4 2")));
        assert!(!condition.matches(&text_event("system", ChatRole::System, "hello there")));
    }

    #[test]
    fn test_message_condition_regex() {
        let condition = MessageCondition::Regex(r"\d+".to_string());
        assert!(condition.matches(&user_event("there are 123 items")));
        assert!(!condition.matches(&user_event("no numbers here")));
    }

    #[test]
    fn test_message_condition_regex_invalid() {
        // An uncompilable pattern must evaluate to "no match", not panic.
        let condition = MessageCondition::Regex("[invalid regex".to_string());
        assert!(!condition.matches(&user_event("test message")));
    }

    // ---------------------------------------------------------------------
    // MessageEvent
    // ---------------------------------------------------------------------

    #[test]
    fn test_message_event_creation() {
        let event = MessageEvent {
            role: "user".to_string(),
            msg: text_message(ChatRole::User, "test message"),
        };
        assert_eq!(event.role, "user");
        assert_eq!(event.msg.content, "test message");
    }

    #[test]
    fn test_message_event_clone() {
        let event = user_event("test message");
        let cloned = event.clone();
        assert_eq!(event.role, cloned.role);
        assert_eq!(event.msg.content, cloned.msg.content);
    }

    #[test]
    fn test_message_event_debug() {
        let event = user_event("test message");
        let debug_str = format!("{event:?}");
        assert!(debug_str.contains("MessageEvent"));
        assert!(debug_str.contains("user"));
    }

    // ---------------------------------------------------------------------
    // MemoryProvider (via MockMemoryProvider)
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_mock_memory_provider_remember() {
        let mut provider = MockMemoryProvider::new();
        let message = text_message(ChatRole::User, "test message");

        let result = provider.remember(&message).await;
        assert!(result.is_ok());
        assert_eq!(provider.size(), 1);
    }

    #[tokio::test]
    async fn test_mock_memory_provider_remember_failure() {
        let mut provider = MockMemoryProvider::with_failure();
        let message = text_message(ChatRole::User, "test message");

        let result = provider.remember(&message).await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Mock memory error"));
    }

    #[tokio::test]
    async fn test_mock_memory_provider_recall() {
        let provider = MockMemoryProvider::with_messages(two_message_history());

        let result = provider.recall("", None).await;
        assert!(result.is_ok());
        let recalled = result.unwrap();
        assert_eq!(recalled.len(), 2);
        assert_eq!(recalled[0].content, "first message");
        assert_eq!(recalled[1].content, "second message");
    }

    #[tokio::test]
    async fn test_mock_memory_provider_recall_with_limit() {
        let provider = MockMemoryProvider::with_messages(two_message_history());

        let result = provider.recall("", Some(1)).await;
        assert!(result.is_ok());
        let recalled = result.unwrap();
        assert_eq!(recalled.len(), 1);
        assert_eq!(recalled[0].content, "first message");
    }

    #[tokio::test]
    async fn test_mock_memory_provider_recall_failure() {
        let provider = MockMemoryProvider::with_failure();

        let result = provider.recall("", None).await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Mock recall error"));
    }

    #[tokio::test]
    async fn test_mock_memory_provider_clear() {
        let mut provider =
            MockMemoryProvider::with_messages(vec![text_message(ChatRole::User, "first message")]);
        assert_eq!(provider.size(), 1);

        let result = provider.clear().await;
        assert!(result.is_ok());
        assert_eq!(provider.size(), 0);
    }

    #[tokio::test]
    async fn test_mock_memory_provider_clear_failure() {
        let mut provider = MockMemoryProvider::with_failure();

        let result = provider.clear().await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Mock clear error"));
    }

    #[test]
    fn test_mock_memory_provider_memory_type() {
        let provider = MockMemoryProvider::new();
        assert_eq!(provider.memory_type(), MemoryType::SlidingWindow);
    }

    #[test]
    fn test_mock_memory_provider_size() {
        let provider = MockMemoryProvider::new();
        assert_eq!(provider.size(), 0);

        let provider_with_messages =
            MockMemoryProvider::with_messages(vec![text_message(ChatRole::User, "message")]);
        assert_eq!(provider_with_messages.size(), 1);
    }

    #[test]
    fn test_mock_memory_provider_is_empty() {
        let provider = MockMemoryProvider::new();
        assert!(provider.is_empty());

        let provider_with_messages =
            MockMemoryProvider::with_messages(vec![text_message(ChatRole::User, "message")]);
        assert!(!provider_with_messages.is_empty());
    }

    #[test]
    fn test_memory_provider_default_methods() {
        let provider = MockMemoryProvider::new();
        assert!(!provider.needs_summary());
        assert!(provider.get_event_receiver().is_none());
    }

    #[tokio::test]
    async fn test_memory_provider_remember_with_role() {
        let mut provider = MockMemoryProvider::new();
        let message = text_message(ChatRole::User, "test message");

        // The default implementation ignores the role and delegates to `remember`.
        let result = provider
            .remember_with_role(&message, "custom_role".to_string())
            .await;
        assert!(result.is_ok());
        assert_eq!(provider.size(), 1);
    }

    #[test]
    fn test_memory_provider_mark_for_summary() {
        let mut provider = MockMemoryProvider::new();
        provider.mark_for_summary();
        // The default `mark_for_summary` is a no-op, so no summary is requested.
        assert!(!provider.needs_summary());
    }

    #[test]
    fn test_memory_provider_replace_with_summary() {
        let mut provider = MockMemoryProvider::new();
        provider.replace_with_summary("Summary text".to_string());
        // The default `replace_with_summary` is a no-op, so nothing is stored.
        assert_eq!(provider.size(), 0);
    }
}
699
/// A chat message paired with a role label, as evaluated by [`MessageCondition::matches`].
///
/// This is also the item type of the broadcast channel returned by
/// [`MemoryProvider::get_event_receiver`] on non-wasm targets.
///
/// `role` is a free-form string that is compared independently of `msg.role`
/// (by `RoleIs` / `RoleNot`); nothing here enforces that the two agree.
#[derive(Debug, Clone)]
pub struct MessageEvent {
    /// Role label for this event (e.g. `"user"`), compared as a plain string
    /// by `MessageCondition::RoleIs` / `MessageCondition::RoleNot`.
    pub role: String,
    /// The underlying chat message; its `content` is what the content-based
    /// conditions inspect, and `Custom` predicates receive it directly.
    pub msg: ChatMessage,
}
708
/// A composable predicate over a [`MessageEvent`], evaluated with
/// [`MessageCondition::matches`].
///
/// Only `Clone` is derived. `Debug` cannot be derived because the `Custom`
/// variant holds a `dyn Fn`, which has no `Debug` impl. Cloning is cheap for
/// `Custom`, since the closure sits behind an `Arc`.
#[derive(Clone)]
pub enum MessageCondition {
    /// Matches every event.
    Any,
    /// Matches when the message content is exactly equal to the string.
    Eq(String),
    /// Matches when the message content contains the string as a substring.
    Contains(String),
    /// Matches when the message content does not contain the substring.
    NotContains(String),
    /// Matches when the event's role label (`MessageEvent::role`) equals the string.
    RoleIs(String),
    /// Matches when the event's role label differs from the string.
    RoleNot(String),
    /// Matches when the content is strictly longer than the given number of
    /// bytes (`String::len`, not a character count).
    LenGt(usize),
    /// Matches when the predicate returns `true`. The predicate receives only
    /// the `ChatMessage`, not the event's role label.
    Custom(Arc<dyn Fn(&ChatMessage) -> bool + Send + Sync>),
    /// Matches when the content is the empty string.
    Empty,
    /// Matches when every inner condition matches (vacuously `true` when the list is empty).
    All(Vec<MessageCondition>),
    /// Matches when at least one inner condition matches (`false` when the list is empty).
    AnyOf(Vec<MessageCondition>),
    /// Matches when the pattern finds a match anywhere in the content.
    ///
    /// The pattern is compiled on every evaluation, and a pattern that fails to
    /// compile never matches (no error is surfaced).
    /// NOTE(review): consider precompiling if this is evaluated on a hot path.
    Regex(String),
}
737
738impl MessageCondition {
739 pub fn matches(&self, event: &MessageEvent) -> bool {
741 match self {
742 MessageCondition::Any => true,
743 MessageCondition::Eq(text) => event.msg.content == *text,
744 MessageCondition::Contains(text) => event.msg.content.contains(text),
745 MessageCondition::NotContains(text) => !event.msg.content.contains(text),
746 MessageCondition::RoleIs(role) => event.role == *role,
747 MessageCondition::RoleNot(role) => event.role != *role,
748 MessageCondition::LenGt(len) => event.msg.content.len() > *len,
749 MessageCondition::Custom(func) => func(&event.msg),
750 MessageCondition::Empty => event.msg.content.is_empty(),
751 MessageCondition::All(inner) => inner.iter().all(|c| c.matches(event)),
752 MessageCondition::AnyOf(inner) => inner.iter().any(|c| c.matches(event)),
753 MessageCondition::Regex(regex) => Regex::new(regex)
754 .map(|re| re.is_match(&event.msg.content))
755 .unwrap_or(false),
756 }
757 }
758}
759
/// Identifies which memory strategy a [`MemoryProvider`] implements
/// (reported by [`MemoryProvider::memory_type`]).
///
/// Serialized with serde's default unit-variant encoding, i.e. as the bare
/// variant name string (`"SlidingWindow"` in JSON).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum MemoryType {
    /// Sliding-window memory.
    /// NOTE(review): presumably the strategy of the re-exported
    /// `SlidingWindowMemory`; confirm in `sliding_window.rs`.
    SlidingWindow,
}
766
/// Interface for a store of [`ChatMessage`]s.
///
/// Messages are added with `remember`, read back with `recall`, and removed
/// with `clear`. Implementors must be `Send + Sync`. The async methods are
/// desugared by `#[async_trait]`, so implementations must carry that attribute too.
///
/// Only `remember`, `recall`, `clear`, `memory_type`, and `size` are required.
/// The other methods have defaults: no summarization support, no event stream,
/// and the role argument of `remember_with_role` is ignored.
#[async_trait]
pub trait MemoryProvider: Send + Sync {
    /// Stores `message`.
    ///
    /// # Errors
    /// Returns an [`LLMError`] if the implementation fails to store the message.
    async fn remember(&mut self, message: &ChatMessage) -> Result<(), LLMError>;

    /// Returns stored messages for `query`.
    ///
    /// How `query` is interpreted is implementation-defined.
    /// `limit` presumably caps the number of messages returned. The test mock
    /// treats `None` as "all" — confirm per implementation.
    ///
    /// # Errors
    /// Returns an [`LLMError`] if the implementation fails to read its store.
    async fn recall(&self, query: &str, limit: Option<usize>)
        -> Result<Vec<ChatMessage>, LLMError>;

    /// Removes the stored messages.
    ///
    /// # Errors
    /// Returns an [`LLMError`] if the implementation fails to clear its store.
    async fn clear(&mut self) -> Result<(), LLMError>;

    /// Reports which memory strategy this provider implements.
    fn memory_type(&self) -> MemoryType;

    /// Returns the number of stored entries; [`MemoryProvider::is_empty`] is derived from it.
    fn size(&self) -> usize;

    /// Returns `true` when [`MemoryProvider::size`] is zero.
    fn is_empty(&self) -> bool {
        self.size() == 0
    }

    /// Whether the provider is requesting that its contents be summarized.
    ///
    /// Default: always `false`.
    fn needs_summary(&self) -> bool {
        false
    }

    /// Flags the memory as needing a summary.
    ///
    /// Default: no-op, so `needs_summary` keeps returning `false`.
    fn mark_for_summary(&mut self) {}

    /// Presumably replaces the stored contents with a summary.
    ///
    /// Default: no-op; the summary is discarded.
    fn replace_with_summary(&mut self, _summary: String) {}

    /// Returns a receiver for the provider's stream of [`MessageEvent`]s, if it publishes one.
    ///
    /// Default: `None`.
    #[cfg(not(target_arch = "wasm32"))]
    fn get_event_receiver(&self) -> Option<broadcast::Receiver<MessageEvent>> {
        None
    }

    /// wasm32 stand-in for the event receiver.
    ///
    /// The tokio `broadcast` import is compiled out on this target, so the
    /// return type is a placeholder. Default: `None`.
    #[cfg(target_arch = "wasm32")]
    fn get_event_receiver(&self) -> Option<()> {
        None
    }

    /// Stores `message` together with an explicit role label.
    ///
    /// Default: ignores `_role` and delegates to [`MemoryProvider::remember`].
    ///
    /// # Errors
    /// Propagates any error from `remember`.
    async fn remember_with_role(
        &mut self,
        message: &ChatMessage,
        _role: String,
    ) -> Result<(), LLMError> {
        self.remember(message).await
    }
}