1use crate::{RragError, RragResult};
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, VecDeque};
10use std::sync::Arc;
11use tokio::sync::RwLock;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct ConversationMessage {
16 pub id: String,
18
19 pub role: MessageRole,
21
22 pub content: String,
24
25 pub metadata: HashMap<String, serde_json::Value>,
27
28 pub timestamp: chrono::DateTime<chrono::Utc>,
30
31 pub token_count: Option<usize>,
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
37pub enum MessageRole {
38 User,
40 Assistant,
42 System,
44 Tool,
46}
47
48impl ConversationMessage {
49 pub fn new(role: MessageRole, content: impl Into<String>) -> Self {
51 Self {
52 id: uuid::Uuid::new_v4().to_string(),
53 role,
54 content: content.into(),
55 metadata: HashMap::new(),
56 timestamp: chrono::Utc::now(),
57 token_count: None,
58 }
59 }
60
61 pub fn user(content: impl Into<String>) -> Self {
63 Self::new(MessageRole::User, content)
64 }
65
66 pub fn assistant(content: impl Into<String>) -> Self {
68 Self::new(MessageRole::Assistant, content)
69 }
70
71 pub fn system(content: impl Into<String>) -> Self {
73 Self::new(MessageRole::System, content)
74 }
75
76 pub fn tool(content: impl Into<String>) -> Self {
78 Self::new(MessageRole::Tool, content)
79 }
80
81 pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
83 self.metadata.insert(key.into(), value);
84 self
85 }
86
87 pub fn with_token_count(mut self, count: usize) -> Self {
89 self.token_count = Some(count);
90 self
91 }
92
93 pub fn estimated_tokens(&self) -> usize {
95 self.token_count.unwrap_or_else(|| {
96 self.content.len() / 4
98 })
99 }
100}
101
102#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct MemorySummary {
105 pub summary: String,
107
108 pub message_count: usize,
110
111 pub original_tokens: usize,
113
114 pub summary_tokens: usize,
116
117 pub start_time: chrono::DateTime<chrono::Utc>,
119 pub end_time: chrono::DateTime<chrono::Utc>,
121
122 pub metadata: HashMap<String, serde_json::Value>,
124}
125
126#[async_trait]
128pub trait Memory: Send + Sync {
129 async fn add_message(&self, conversation_id: &str, role: &str, content: &str)
131 -> RragResult<()>;
132
133 async fn add_structured_message(
135 &self,
136 conversation_id: &str,
137 message: ConversationMessage,
138 ) -> RragResult<()>;
139
140 async fn get_conversation_history(&self, conversation_id: &str) -> RragResult<Vec<String>>;
142
143 async fn get_messages(&self, conversation_id: &str) -> RragResult<Vec<ConversationMessage>>;
145
146 async fn get_recent_messages(
148 &self,
149 conversation_id: &str,
150 limit: usize,
151 ) -> RragResult<Vec<ConversationMessage>>;
152
153 async fn clear_conversation(&self, conversation_id: &str) -> RragResult<()>;
155
156 async fn get_memory_variables(
158 &self,
159 conversation_id: &str,
160 ) -> RragResult<HashMap<String, String>>;
161
162 async fn save_context(
164 &self,
165 conversation_id: &str,
166 context: HashMap<String, String>,
167 ) -> RragResult<()>;
168
169 async fn health_check(&self) -> RragResult<bool>;
171}
172
173pub struct ConversationBufferMemory {
175 conversations: Arc<RwLock<HashMap<String, VecDeque<ConversationMessage>>>>,
177
178 config: BufferMemoryConfig,
180}
181
/// Limits applied by [`ConversationBufferMemory`] after each insertion.
#[derive(Clone, Debug)]
pub struct BufferMemoryConfig {
    /// Keep at most this many messages per conversation (`None` = unbounded).
    pub max_messages: Option<usize>,
    /// Drop leading messages older than this many seconds (`None` = no expiry).
    pub max_age_seconds: Option<u64>,
    /// Key under which `get_memory_variables` returns the joined history.
    pub memory_key: String,
}
194
195impl Default for BufferMemoryConfig {
196 fn default() -> Self {
197 Self {
198 max_messages: Some(100),
199 max_age_seconds: Some(3600 * 24), memory_key: "history".to_string(),
201 }
202 }
203}
204
205impl ConversationBufferMemory {
206 pub fn new() -> Self {
208 Self {
209 conversations: Arc::new(RwLock::new(HashMap::new())),
210 config: BufferMemoryConfig::default(),
211 }
212 }
213
214 pub fn with_config(config: BufferMemoryConfig) -> Self {
216 Self {
217 conversations: Arc::new(RwLock::new(HashMap::new())),
218 config,
219 }
220 }
221
222 async fn cleanup_old_messages(&self, conversation_id: &str) {
224 let mut conversations = self.conversations.write().await;
225
226 if let Some(messages) = conversations.get_mut(conversation_id) {
227 if let Some(max_age) = self.config.max_age_seconds {
229 let cutoff_time = chrono::Utc::now() - chrono::Duration::seconds(max_age as i64);
230 while let Some(front) = messages.front() {
231 if front.timestamp < cutoff_time {
232 messages.pop_front();
233 } else {
234 break;
235 }
236 }
237 }
238
239 if let Some(max_messages) = self.config.max_messages {
241 while messages.len() > max_messages {
242 messages.pop_front();
243 }
244 }
245 }
246 }
247}
248
249impl Default for ConversationBufferMemory {
250 fn default() -> Self {
251 Self::new()
252 }
253}
254
255#[async_trait]
256impl Memory for ConversationBufferMemory {
257 async fn add_message(
258 &self,
259 conversation_id: &str,
260 role: &str,
261 content: &str,
262 ) -> RragResult<()> {
263 let role = match role.to_lowercase().as_str() {
264 "user" => MessageRole::User,
265 "assistant" => MessageRole::Assistant,
266 "system" => MessageRole::System,
267 "tool" => MessageRole::Tool,
268 _ => MessageRole::User, };
270
271 let message = ConversationMessage::new(role, content);
272 self.add_structured_message(conversation_id, message).await
273 }
274
275 async fn add_structured_message(
276 &self,
277 conversation_id: &str,
278 message: ConversationMessage,
279 ) -> RragResult<()> {
280 let mut conversations = self.conversations.write().await;
281
282 let messages = conversations
283 .entry(conversation_id.to_string())
284 .or_insert_with(VecDeque::new);
285
286 messages.push_back(message);
287
288 drop(conversations);
290
291 self.cleanup_old_messages(conversation_id).await;
293
294 Ok(())
295 }
296
297 async fn get_conversation_history(&self, conversation_id: &str) -> RragResult<Vec<String>> {
298 let conversations = self.conversations.read().await;
299
300 if let Some(messages) = conversations.get(conversation_id) {
301 let history = messages
302 .iter()
303 .map(|msg| format!("{:?}: {}", msg.role, msg.content))
304 .collect();
305 Ok(history)
306 } else {
307 Ok(Vec::new())
308 }
309 }
310
311 async fn get_messages(&self, conversation_id: &str) -> RragResult<Vec<ConversationMessage>> {
312 let conversations = self.conversations.read().await;
313
314 if let Some(messages) = conversations.get(conversation_id) {
315 Ok(messages.iter().cloned().collect())
316 } else {
317 Ok(Vec::new())
318 }
319 }
320
321 async fn get_recent_messages(
322 &self,
323 conversation_id: &str,
324 limit: usize,
325 ) -> RragResult<Vec<ConversationMessage>> {
326 let conversations = self.conversations.read().await;
327
328 if let Some(messages) = conversations.get(conversation_id) {
329 let recent: Vec<ConversationMessage> =
330 messages.iter().rev().take(limit).rev().cloned().collect();
331 Ok(recent)
332 } else {
333 Ok(Vec::new())
334 }
335 }
336
337 async fn clear_conversation(&self, conversation_id: &str) -> RragResult<()> {
338 let mut conversations = self.conversations.write().await;
339 conversations.remove(conversation_id);
340 Ok(())
341 }
342
343 async fn get_memory_variables(
344 &self,
345 conversation_id: &str,
346 ) -> RragResult<HashMap<String, String>> {
347 let history = self.get_conversation_history(conversation_id).await?;
348 let mut variables = HashMap::new();
349
350 variables.insert(self.config.memory_key.clone(), history.join("\n"));
351
352 Ok(variables)
353 }
354
355 async fn save_context(
356 &self,
357 _conversation_id: &str,
358 _context: HashMap<String, String>,
359 ) -> RragResult<()> {
360 Ok(())
363 }
364
365 async fn health_check(&self) -> RragResult<bool> {
366 Ok(true)
367 }
368}
369
370pub struct ConversationTokenBufferMemory {
372 buffer: ConversationBufferMemory,
374
375 token_config: TokenBufferConfig,
377}
378
379#[derive(Debug, Clone)]
381pub struct TokenBufferConfig {
382 pub max_tokens: usize,
384
385 pub buffer_tokens: usize,
387
388 pub overflow_strategy: TokenOverflowStrategy,
390}
391
/// What [`ConversationTokenBufferMemory`] does when a conversation exceeds its budget.
#[derive(Clone, Debug)]
pub enum TokenOverflowStrategy {
    /// Evict messages from the front until under the headroom target.
    RemoveOldest,
    /// Currently drops the oldest half of the messages; no summary text is kept.
    Summarize,
    /// Not implemented: overflow handling returns an error.
    Truncate,
}
404
405impl Default for TokenBufferConfig {
406 fn default() -> Self {
407 Self {
408 max_tokens: 4000,
409 buffer_tokens: 500,
410 overflow_strategy: TokenOverflowStrategy::RemoveOldest,
411 }
412 }
413}
414
415impl ConversationTokenBufferMemory {
416 pub fn new() -> Self {
418 Self {
419 buffer: ConversationBufferMemory::new(),
420 token_config: TokenBufferConfig::default(),
421 }
422 }
423
424 pub fn with_config(buffer_config: BufferMemoryConfig, token_config: TokenBufferConfig) -> Self {
426 Self {
427 buffer: ConversationBufferMemory::with_config(buffer_config),
428 token_config,
429 }
430 }
431
432 async fn calculate_total_tokens(&self, conversation_id: &str) -> RragResult<usize> {
434 let messages = self.buffer.get_messages(conversation_id).await?;
435 let total = messages.iter().map(|msg| msg.estimated_tokens()).sum();
436 Ok(total)
437 }
438
439 async fn handle_token_overflow(&self, conversation_id: &str) -> RragResult<()> {
441 let current_tokens = self.calculate_total_tokens(conversation_id).await?;
442
443 if current_tokens <= self.token_config.max_tokens {
444 return Ok(());
445 }
446
447 match self.token_config.overflow_strategy {
448 TokenOverflowStrategy::RemoveOldest => {
449 let mut conversations = self.buffer.conversations.write().await;
450
451 if let Some(messages) = conversations.get_mut(conversation_id) {
452 while !messages.is_empty() {
453 let total: usize = messages.iter().map(|msg| msg.estimated_tokens()).sum();
454 if total <= self.token_config.max_tokens - self.token_config.buffer_tokens {
455 break;
456 }
457 messages.pop_front();
458 }
459 }
460 }
461 TokenOverflowStrategy::Summarize => {
462 let mut conversations = self.buffer.conversations.write().await;
465
466 if let Some(messages) = conversations.get_mut(conversation_id) {
467 let remove_count = messages.len() / 2;
469 for _ in 0..remove_count {
470 messages.pop_front();
471 }
472 }
473 }
474 TokenOverflowStrategy::Truncate => {
475 return Err(RragError::memory(
477 "token_overflow",
478 "Truncate strategy not implemented",
479 ));
480 }
481 }
482
483 Ok(())
484 }
485}
486
487impl Default for ConversationTokenBufferMemory {
488 fn default() -> Self {
489 Self::new()
490 }
491}
492
493#[async_trait]
494impl Memory for ConversationTokenBufferMemory {
495 async fn add_message(
496 &self,
497 conversation_id: &str,
498 role: &str,
499 content: &str,
500 ) -> RragResult<()> {
501 self.buffer
502 .add_message(conversation_id, role, content)
503 .await?;
504 self.handle_token_overflow(conversation_id).await?;
505 Ok(())
506 }
507
508 async fn add_structured_message(
509 &self,
510 conversation_id: &str,
511 message: ConversationMessage,
512 ) -> RragResult<()> {
513 self.buffer
514 .add_structured_message(conversation_id, message)
515 .await?;
516 self.handle_token_overflow(conversation_id).await?;
517 Ok(())
518 }
519
520 async fn get_conversation_history(&self, conversation_id: &str) -> RragResult<Vec<String>> {
521 self.buffer.get_conversation_history(conversation_id).await
522 }
523
524 async fn get_messages(&self, conversation_id: &str) -> RragResult<Vec<ConversationMessage>> {
525 self.buffer.get_messages(conversation_id).await
526 }
527
528 async fn get_recent_messages(
529 &self,
530 conversation_id: &str,
531 limit: usize,
532 ) -> RragResult<Vec<ConversationMessage>> {
533 self.buffer
534 .get_recent_messages(conversation_id, limit)
535 .await
536 }
537
538 async fn clear_conversation(&self, conversation_id: &str) -> RragResult<()> {
539 self.buffer.clear_conversation(conversation_id).await
540 }
541
542 async fn get_memory_variables(
543 &self,
544 conversation_id: &str,
545 ) -> RragResult<HashMap<String, String>> {
546 let mut variables = self.buffer.get_memory_variables(conversation_id).await?;
547
548 let token_count = self.calculate_total_tokens(conversation_id).await?;
550 variables.insert("token_count".to_string(), token_count.to_string());
551 variables.insert(
552 "max_tokens".to_string(),
553 self.token_config.max_tokens.to_string(),
554 );
555
556 Ok(variables)
557 }
558
559 async fn save_context(
560 &self,
561 conversation_id: &str,
562 context: HashMap<String, String>,
563 ) -> RragResult<()> {
564 self.buffer.save_context(conversation_id, context).await
565 }
566
567 async fn health_check(&self) -> RragResult<bool> {
568 self.buffer.health_check().await
569 }
570}
571
572pub struct ConversationSummaryMemory {
574 current_messages: Arc<RwLock<HashMap<String, VecDeque<ConversationMessage>>>>,
576
577 summaries: Arc<RwLock<HashMap<String, Vec<MemorySummary>>>>,
579
580 config: SummaryMemoryConfig,
582}
583
/// Thresholds and variable names for [`ConversationSummaryMemory`].
#[derive(Clone, Debug)]
pub struct SummaryMemoryConfig {
    /// Summarize once more than this many unsummarized messages are held.
    pub max_messages_before_summary: usize,
    /// Summarize once the estimated token total exceeds this value.
    pub max_tokens_before_summary: usize,
    /// Number of newest messages left verbatim after summarizing.
    pub keep_recent_messages: usize,
    /// Variable name for the combined summaries-plus-messages history.
    pub memory_key: String,
    /// Variable name for the summaries-only text.
    pub summary_key: String,
}
602
603impl Default for SummaryMemoryConfig {
604 fn default() -> Self {
605 Self {
606 max_messages_before_summary: 20,
607 max_tokens_before_summary: 2000,
608 keep_recent_messages: 5,
609 memory_key: "history".to_string(),
610 summary_key: "summary".to_string(),
611 }
612 }
613}
614
615impl ConversationSummaryMemory {
616 pub fn new() -> Self {
618 Self {
619 current_messages: Arc::new(RwLock::new(HashMap::new())),
620 summaries: Arc::new(RwLock::new(HashMap::new())),
621 config: SummaryMemoryConfig::default(),
622 }
623 }
624
625 pub fn with_config(config: SummaryMemoryConfig) -> Self {
627 Self {
628 current_messages: Arc::new(RwLock::new(HashMap::new())),
629 summaries: Arc::new(RwLock::new(HashMap::new())),
630 config,
631 }
632 }
633
634 async fn should_summarize(&self, conversation_id: &str) -> RragResult<bool> {
636 let messages = self.current_messages.read().await;
637
638 if let Some(msg_deque) = messages.get(conversation_id) {
639 if msg_deque.len() > self.config.max_messages_before_summary {
641 return Ok(true);
642 }
643
644 let total_tokens: usize = msg_deque.iter().map(|msg| msg.estimated_tokens()).sum();
646 if total_tokens > self.config.max_tokens_before_summary {
647 return Ok(true);
648 }
649 }
650
651 Ok(false)
652 }
653
654 async fn summarize_conversation(&self, conversation_id: &str) -> RragResult<()> {
656 let mut messages = self.current_messages.write().await;
657 let mut summaries = self.summaries.write().await;
658
659 if let Some(msg_deque) = messages.get_mut(conversation_id) {
660 if msg_deque.len() <= self.config.keep_recent_messages {
661 return Ok(());
662 }
663
664 let to_summarize_count = msg_deque.len() - self.config.keep_recent_messages;
666
667 let mut to_summarize = Vec::new();
669 for _ in 0..to_summarize_count {
670 if let Some(msg) = msg_deque.pop_front() {
671 to_summarize.push(msg);
672 }
673 }
674
675 if !to_summarize.is_empty() {
676 let summary_text = format!(
678 "Summary of {} messages from {} to {}",
679 to_summarize.len(),
680 to_summarize
681 .first()
682 .unwrap()
683 .timestamp
684 .format("%Y-%m-%d %H:%M:%S"),
685 to_summarize
686 .last()
687 .unwrap()
688 .timestamp
689 .format("%Y-%m-%d %H:%M:%S")
690 );
691
692 let original_tokens = to_summarize.iter().map(|msg| msg.estimated_tokens()).sum();
693
694 let summary = MemorySummary {
695 summary: summary_text,
696 message_count: to_summarize.len(),
697 original_tokens,
698 summary_tokens: 50, start_time: to_summarize.first().unwrap().timestamp,
700 end_time: to_summarize.last().unwrap().timestamp,
701 metadata: HashMap::new(),
702 };
703
704 summaries
706 .entry(conversation_id.to_string())
707 .or_insert_with(Vec::new)
708 .push(summary);
709 }
710 }
711
712 Ok(())
713 }
714}
715
716impl Default for ConversationSummaryMemory {
717 fn default() -> Self {
718 Self::new()
719 }
720}
721
722#[async_trait]
723impl Memory for ConversationSummaryMemory {
724 async fn add_message(
725 &self,
726 conversation_id: &str,
727 role: &str,
728 content: &str,
729 ) -> RragResult<()> {
730 let role = match role.to_lowercase().as_str() {
731 "user" => MessageRole::User,
732 "assistant" => MessageRole::Assistant,
733 "system" => MessageRole::System,
734 "tool" => MessageRole::Tool,
735 _ => MessageRole::User,
736 };
737
738 let message = ConversationMessage::new(role, content);
739 self.add_structured_message(conversation_id, message).await
740 }
741
742 async fn add_structured_message(
743 &self,
744 conversation_id: &str,
745 message: ConversationMessage,
746 ) -> RragResult<()> {
747 {
749 let mut messages = self.current_messages.write().await;
750 let msg_deque = messages
751 .entry(conversation_id.to_string())
752 .or_insert_with(VecDeque::new);
753 msg_deque.push_back(message);
754 }
755
756 if self.should_summarize(conversation_id).await? {
758 self.summarize_conversation(conversation_id).await?;
759 }
760
761 Ok(())
762 }
763
764 async fn get_conversation_history(&self, conversation_id: &str) -> RragResult<Vec<String>> {
765 let messages = self.current_messages.read().await;
766 let summaries = self.summaries.read().await;
767
768 let mut history = Vec::new();
769
770 if let Some(summary_list) = summaries.get(conversation_id) {
772 for summary in summary_list {
773 history.push(format!("Summary: {}", summary.summary));
774 }
775 }
776
777 if let Some(msg_deque) = messages.get(conversation_id) {
779 for msg in msg_deque {
780 history.push(format!("{:?}: {}", msg.role, msg.content));
781 }
782 }
783
784 Ok(history)
785 }
786
787 async fn get_messages(&self, conversation_id: &str) -> RragResult<Vec<ConversationMessage>> {
788 let messages = self.current_messages.read().await;
789
790 if let Some(msg_deque) = messages.get(conversation_id) {
791 Ok(msg_deque.iter().cloned().collect())
792 } else {
793 Ok(Vec::new())
794 }
795 }
796
797 async fn get_recent_messages(
798 &self,
799 conversation_id: &str,
800 limit: usize,
801 ) -> RragResult<Vec<ConversationMessage>> {
802 let messages = self.current_messages.read().await;
803
804 if let Some(msg_deque) = messages.get(conversation_id) {
805 let recent: Vec<ConversationMessage> =
806 msg_deque.iter().rev().take(limit).rev().cloned().collect();
807 Ok(recent)
808 } else {
809 Ok(Vec::new())
810 }
811 }
812
813 async fn clear_conversation(&self, conversation_id: &str) -> RragResult<()> {
814 let mut messages = self.current_messages.write().await;
815 let mut summaries = self.summaries.write().await;
816
817 messages.remove(conversation_id);
818 summaries.remove(conversation_id);
819
820 Ok(())
821 }
822
823 async fn get_memory_variables(
824 &self,
825 conversation_id: &str,
826 ) -> RragResult<HashMap<String, String>> {
827 let mut variables = HashMap::new();
828
829 let history = self.get_conversation_history(conversation_id).await?;
831 variables.insert(self.config.memory_key.clone(), history.join("\n"));
832
833 let summaries = self.summaries.read().await;
835 if let Some(summary_list) = summaries.get(conversation_id) {
836 let summary_text = summary_list
837 .iter()
838 .map(|s| s.summary.clone())
839 .collect::<Vec<_>>()
840 .join("\n");
841 variables.insert(self.config.summary_key.clone(), summary_text);
842 }
843
844 Ok(variables)
845 }
846
847 async fn save_context(
848 &self,
849 _conversation_id: &str,
850 _context: HashMap<String, String>,
851 ) -> RragResult<()> {
852 Ok(())
854 }
855
856 async fn health_check(&self) -> RragResult<bool> {
857 Ok(true)
858 }
859}
860
861pub struct MemoryService {
863 memory: Arc<dyn Memory>,
865
866 config: MemoryServiceConfig,
868}
869
870#[derive(Debug, Clone)]
872pub struct MemoryServiceConfig {
873 pub default_conversation_settings: ConversationSettings,
875
876 pub enable_persistence: bool,
878
879 pub persistence_interval_seconds: u64,
881}
882
/// Per-conversation limits carried in [`MemoryServiceConfig`].
#[derive(Clone, Debug)]
pub struct ConversationSettings {
    /// Maximum message count (`None` = unbounded).
    pub max_messages: Option<usize>,
    /// Maximum message age in hours (`None` = no expiry).
    pub max_age_hours: Option<u64>,
    /// Message count at which auto-summarization should start (`None` = never).
    pub auto_summarize_threshold: Option<usize>,
}
895
896impl Default for MemoryServiceConfig {
897 fn default() -> Self {
898 Self {
899 default_conversation_settings: ConversationSettings::default(),
900 enable_persistence: false,
901 persistence_interval_seconds: 300, }
903 }
904}
905
906impl Default for ConversationSettings {
907 fn default() -> Self {
908 Self {
909 max_messages: Some(100),
910 max_age_hours: Some(24),
911 auto_summarize_threshold: Some(50),
912 }
913 }
914}
915
916impl MemoryService {
917 pub fn new(memory: Arc<dyn Memory>) -> Self {
919 Self {
920 memory,
921 config: MemoryServiceConfig::default(),
922 }
923 }
924
925 pub fn with_config(memory: Arc<dyn Memory>, config: MemoryServiceConfig) -> Self {
927 Self { memory, config }
928 }
929
930 pub async fn add_user_message(&self, conversation_id: &str, content: &str) -> RragResult<()> {
932 self.memory
933 .add_message(conversation_id, "user", content)
934 .await
935 }
936
937 pub async fn add_assistant_message(
939 &self,
940 conversation_id: &str,
941 content: &str,
942 ) -> RragResult<()> {
943 self.memory
944 .add_message(conversation_id, "assistant", content)
945 .await
946 }
947
948 pub async fn get_conversation_context(&self, conversation_id: &str) -> RragResult<String> {
950 let variables = self.memory.get_memory_variables(conversation_id).await?;
951
952 Ok(variables.get("history").unwrap_or(&String::new()).clone())
954 }
955
956 pub async fn get_prompt_variables(
958 &self,
959 conversation_id: &str,
960 ) -> RragResult<HashMap<String, String>> {
961 self.memory.get_memory_variables(conversation_id).await
962 }
963
964 pub async fn health_check(&self) -> RragResult<bool> {
966 self.memory.health_check().await
967 }
968}
969
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder helpers set role, content, metadata and the explicit token count.
    #[tokio::test]
    async fn test_conversation_message() {
        let source = serde_json::Value::String("test".to_string());
        let msg = ConversationMessage::user("Hello world")
            .with_metadata("source", source)
            .with_token_count(10);

        assert_eq!(msg.role, MessageRole::User);
        assert_eq!(msg.content, "Hello world");
        assert_eq!(msg.estimated_tokens(), 10);
        assert_eq!(msg.metadata["source"].as_str(), Some("test"));
    }

    /// History, raw messages and the recent-window view stay in insertion order.
    #[tokio::test]
    async fn test_buffer_memory() {
        let memory = ConversationBufferMemory::new();
        let conv_id = "test_conversation";

        for (role, text) in [("user", "Hello"), ("assistant", "Hi there!")] {
            memory.add_message(conv_id, role, text).await.unwrap();
        }

        let history = memory.get_conversation_history(conv_id).await.unwrap();
        assert_eq!(history.len(), 2);
        assert!(history[0].contains("Hello"));
        assert!(history[1].contains("Hi there!"));

        let messages = memory.get_messages(conv_id).await.unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].role, MessageRole::User);
        assert_eq!(messages[1].role, MessageRole::Assistant);

        let recent = memory.get_recent_messages(conv_id, 1).await.unwrap();
        assert_eq!(recent.len(), 1);
        assert_eq!(recent[0].content, "Hi there!");
    }

    /// Exceeding the token budget evicts old messages under RemoveOldest.
    #[tokio::test]
    async fn test_token_buffer_memory() {
        let token_config = TokenBufferConfig {
            max_tokens: 100,
            buffer_tokens: 10,
            overflow_strategy: TokenOverflowStrategy::RemoveOldest,
        };
        let memory = ConversationTokenBufferMemory::with_config(
            BufferMemoryConfig::default(),
            token_config,
        );
        let conv_id = "test_token_conversation";

        for i in 0..20 {
            let text = format!("Message number {} with some content", i);
            memory.add_message(conv_id, "user", &text).await.unwrap();
        }

        let total_tokens = memory.calculate_total_tokens(conv_id).await.unwrap();
        assert!(
            total_tokens <= 100,
            "Total tokens {} should be <= 100",
            total_tokens
        );

        let kept = memory.get_messages(conv_id).await.unwrap().len();
        assert!(
            kept < 20,
            "Should have removed some messages due to token limit"
        );
    }

    /// The service facade round-trips messages through the backing memory.
    #[tokio::test]
    async fn test_memory_service() {
        let service = MemoryService::new(Arc::new(ConversationBufferMemory::new()));
        let conv_id = "service_test";

        service
            .add_user_message(conv_id, "How are you?")
            .await
            .unwrap();
        service
            .add_assistant_message(conv_id, "I'm doing well, thank you!")
            .await
            .unwrap();

        let context = service.get_conversation_context(conv_id).await.unwrap();
        assert!(context.contains("How are you?"));
        assert!(context.contains("I'm doing well"));

        let variables = service.get_prompt_variables(conv_id).await.unwrap();
        assert!(variables.contains_key("history"));

        assert!(service.health_check().await.unwrap());
    }

    /// Crossing the message threshold folds old messages into a summary.
    #[tokio::test]
    async fn test_summary_memory() {
        let memory = ConversationSummaryMemory::with_config(SummaryMemoryConfig {
            max_messages_before_summary: 3,
            max_tokens_before_summary: 1000,
            keep_recent_messages: 1,
            memory_key: "history".to_string(),
            summary_key: "summary".to_string(),
        });
        let conv_id = "summary_test";

        let turns = [
            ("user", "First message"),
            ("assistant", "First response"),
            ("user", "Second message"),
            ("assistant", "Second response"),
        ];
        for (role, text) in turns {
            memory.add_message(conv_id, role, text).await.unwrap();
        }

        let messages = memory.get_messages(conv_id).await.unwrap();
        assert!(messages.len() <= 1, "Should have summarized old messages");

        let variables = memory.get_memory_variables(conv_id).await.unwrap();
        assert!(
            variables.contains_key("summary"),
            "Should have summary in variables"
        );
    }
}