rexis_rag/
memory.rs

//! # RRAG Memory System
//!
//! Conversation memory and context management with Rust-native async patterns.
//! Designed for efficient state management; the memory types defined in this module keep all
//! conversation state in process memory.
5
6use crate::{RragError, RragResult};
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, VecDeque};
10use std::sync::Arc;
11use tokio::sync::RwLock;
12
/// Conversation message
///
/// A single entry in a conversation as stored by the [`Memory`] implementations in this
/// module. Usually built with the role constructors (e.g. [`ConversationMessage::user`]) and
/// the `with_*` builder methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConversationMessage {
    /// Message ID (`ConversationMessage::new` fills it with a random UUID v4 string)
    pub id: String,

    /// Role (user, assistant, system, tool)
    pub role: MessageRole,

    /// Message content
    pub content: String,

    /// Free-form message metadata (see `ConversationMessage::with_metadata`)
    pub metadata: HashMap<String, serde_json::Value>,

    /// Timestamp in UTC (`ConversationMessage::new` sets it to the creation time)
    pub timestamp: chrono::DateTime<chrono::Utc>,

    /// Token count (if available); when `None`, `estimated_tokens` falls back to a heuristic
    pub token_count: Option<usize>,
}
34
/// Message roles in conversation
///
/// The string-based `Memory::add_message` implementations in this module map the names
/// `"user"`, `"assistant"`, `"system"` and `"tool"` (case-insensitively) onto these variants.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum MessageRole {
    /// Message from the user
    User,
    /// Message from the AI assistant
    Assistant,
    /// System message
    System,
    /// Message from a tool execution
    Tool,
}
47
48impl ConversationMessage {
49    /// Create a new conversation message
50    pub fn new(role: MessageRole, content: impl Into<String>) -> Self {
51        Self {
52            id: uuid::Uuid::new_v4().to_string(),
53            role,
54            content: content.into(),
55            metadata: HashMap::new(),
56            timestamp: chrono::Utc::now(),
57            token_count: None,
58        }
59    }
60
61    /// Create a user message
62    pub fn user(content: impl Into<String>) -> Self {
63        Self::new(MessageRole::User, content)
64    }
65
66    /// Create an assistant message
67    pub fn assistant(content: impl Into<String>) -> Self {
68        Self::new(MessageRole::Assistant, content)
69    }
70
71    /// Create a system message
72    pub fn system(content: impl Into<String>) -> Self {
73        Self::new(MessageRole::System, content)
74    }
75
76    /// Create a tool message
77    pub fn tool(content: impl Into<String>) -> Self {
78        Self::new(MessageRole::Tool, content)
79    }
80
81    /// Add metadata to the message
82    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
83        self.metadata.insert(key.into(), value);
84        self
85    }
86
87    /// Set the token count for the message
88    pub fn with_token_count(mut self, count: usize) -> Self {
89        self.token_count = Some(count);
90        self
91    }
92
93    /// Get estimated token count (simple heuristic if not set)
94    pub fn estimated_tokens(&self) -> usize {
95        self.token_count.unwrap_or_else(|| {
96            // Simple estimation: ~4 characters per token
97            self.content.len() / 4
98        })
99    }
100}
101
/// Memory summary for efficient storage
///
/// Produced by `ConversationSummaryMemory` when it folds older messages of a conversation
/// into a single record.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemorySummary {
    /// Summary text
    pub summary: String,

    /// Number of messages summarized
    pub message_count: usize,

    /// Token count of original messages (sum of their estimated tokens)
    pub original_tokens: usize,

    /// Token count of summary (an estimate)
    pub summary_tokens: usize,

    /// Time range covered - start time (timestamp of the first summarized message)
    pub start_time: chrono::DateTime<chrono::Utc>,
    /// Time range covered - end time (timestamp of the last summarized message)
    pub end_time: chrono::DateTime<chrono::Utc>,

    /// Summary metadata
    pub metadata: HashMap<String, serde_json::Value>,
}
125
/// Core memory trait for conversation management
///
/// Implementations keep messages per conversation ID and expose them both as structured
/// [`ConversationMessage`]s and as prompt-ready strings. Every method takes `&self`, so
/// implementations manage their own interior mutability (the types in this module use
/// `tokio::sync::RwLock`).
#[async_trait]
pub trait Memory: Send + Sync {
    /// Add a message to the conversation
    ///
    /// `role` is a free-form role name. The implementations in this module match it
    /// case-insensitively against "user"/"assistant"/"system"/"tool" and treat anything else
    /// as [`MessageRole::User`].
    async fn add_message(&self, conversation_id: &str, role: &str, content: &str)
        -> RragResult<()>;

    /// Add a structured message
    ///
    /// The message is stored as given (its ID, metadata and timestamp are kept) by the
    /// implementations in this module.
    async fn add_structured_message(
        &self,
        conversation_id: &str,
        message: ConversationMessage,
    ) -> RragResult<()>;

    /// Get conversation history
    ///
    /// Returns one rendered string per entry; the implementations in this module render
    /// messages as `"{role:?}: {content}"`. Unknown conversations yield an empty list.
    async fn get_conversation_history(&self, conversation_id: &str) -> RragResult<Vec<String>>;

    /// Get structured conversation history, in stored order
    async fn get_messages(&self, conversation_id: &str) -> RragResult<Vec<ConversationMessage>>;

    /// Get recent messages with limit
    ///
    /// Returns at most `limit` of the most recent messages, in stored order.
    async fn get_recent_messages(
        &self,
        conversation_id: &str,
        limit: usize,
    ) -> RragResult<Vec<ConversationMessage>>;

    /// Clear conversation history (all state kept for `conversation_id`)
    async fn clear_conversation(&self, conversation_id: &str) -> RragResult<()>;

    /// Get memory variables for prompt injection, keyed by template variable name
    async fn get_memory_variables(
        &self,
        conversation_id: &str,
    ) -> RragResult<HashMap<String, String>>;

    /// Save arbitrary context
    ///
    /// NOTE(review): the implementations in this module accept and ignore this context.
    async fn save_context(
        &self,
        conversation_id: &str,
        context: HashMap<String, String>,
    ) -> RragResult<()>;

    /// Health check; `Ok(true)` means the memory is usable
    async fn health_check(&self) -> RragResult<bool>;
}
172
/// Buffer memory - keeps recent messages in memory
///
/// Conversations live in process memory, in a map from conversation ID to a deque of
/// messages behind an async `RwLock`. The limits in [`BufferMemoryConfig`] are enforced every
/// time a message is added.
pub struct ConversationBufferMemory {
    /// Stored conversations, keyed by conversation ID; new messages are pushed to the back
    conversations: Arc<RwLock<HashMap<String, VecDeque<ConversationMessage>>>>,

    /// Configuration
    config: BufferMemoryConfig,
}

/// Configuration for buffer-based memory
#[derive(Debug, Clone)]
pub struct BufferMemoryConfig {
    /// Maximum messages to keep per conversation (`None` = unlimited); oldest are dropped first
    pub max_messages: Option<usize>,

    /// Maximum age of messages in seconds (`None` = no age limit)
    pub max_age_seconds: Option<u64>,

    /// Memory key for prompt variables (the key `get_memory_variables` stores history under)
    pub memory_key: String,
}
194
195impl Default for BufferMemoryConfig {
196    fn default() -> Self {
197        Self {
198            max_messages: Some(100),
199            max_age_seconds: Some(3600 * 24), // 24 hours
200            memory_key: "history".to_string(),
201        }
202    }
203}
204
205impl ConversationBufferMemory {
206    /// Create a new buffer memory with default configuration
207    pub fn new() -> Self {
208        Self {
209            conversations: Arc::new(RwLock::new(HashMap::new())),
210            config: BufferMemoryConfig::default(),
211        }
212    }
213
214    /// Create a new buffer memory with custom configuration
215    pub fn with_config(config: BufferMemoryConfig) -> Self {
216        Self {
217            conversations: Arc::new(RwLock::new(HashMap::new())),
218            config,
219        }
220    }
221
222    /// Clean up old messages based on configuration
223    async fn cleanup_old_messages(&self, conversation_id: &str) {
224        let mut conversations = self.conversations.write().await;
225
226        if let Some(messages) = conversations.get_mut(conversation_id) {
227            // Remove old messages by age
228            if let Some(max_age) = self.config.max_age_seconds {
229                let cutoff_time = chrono::Utc::now() - chrono::Duration::seconds(max_age as i64);
230                while let Some(front) = messages.front() {
231                    if front.timestamp < cutoff_time {
232                        messages.pop_front();
233                    } else {
234                        break;
235                    }
236                }
237            }
238
239            // Limit by count
240            if let Some(max_messages) = self.config.max_messages {
241                while messages.len() > max_messages {
242                    messages.pop_front();
243                }
244            }
245        }
246    }
247}
248
249impl Default for ConversationBufferMemory {
250    fn default() -> Self {
251        Self::new()
252    }
253}
254
255#[async_trait]
256impl Memory for ConversationBufferMemory {
257    async fn add_message(
258        &self,
259        conversation_id: &str,
260        role: &str,
261        content: &str,
262    ) -> RragResult<()> {
263        let role = match role.to_lowercase().as_str() {
264            "user" => MessageRole::User,
265            "assistant" => MessageRole::Assistant,
266            "system" => MessageRole::System,
267            "tool" => MessageRole::Tool,
268            _ => MessageRole::User, // Default fallback
269        };
270
271        let message = ConversationMessage::new(role, content);
272        self.add_structured_message(conversation_id, message).await
273    }
274
275    async fn add_structured_message(
276        &self,
277        conversation_id: &str,
278        message: ConversationMessage,
279    ) -> RragResult<()> {
280        let mut conversations = self.conversations.write().await;
281
282        let messages = conversations
283            .entry(conversation_id.to_string())
284            .or_insert_with(VecDeque::new);
285
286        messages.push_back(message);
287
288        // Release the lock before cleanup
289        drop(conversations);
290
291        // Clean up old messages
292        self.cleanup_old_messages(conversation_id).await;
293
294        Ok(())
295    }
296
297    async fn get_conversation_history(&self, conversation_id: &str) -> RragResult<Vec<String>> {
298        let conversations = self.conversations.read().await;
299
300        if let Some(messages) = conversations.get(conversation_id) {
301            let history = messages
302                .iter()
303                .map(|msg| format!("{:?}: {}", msg.role, msg.content))
304                .collect();
305            Ok(history)
306        } else {
307            Ok(Vec::new())
308        }
309    }
310
311    async fn get_messages(&self, conversation_id: &str) -> RragResult<Vec<ConversationMessage>> {
312        let conversations = self.conversations.read().await;
313
314        if let Some(messages) = conversations.get(conversation_id) {
315            Ok(messages.iter().cloned().collect())
316        } else {
317            Ok(Vec::new())
318        }
319    }
320
321    async fn get_recent_messages(
322        &self,
323        conversation_id: &str,
324        limit: usize,
325    ) -> RragResult<Vec<ConversationMessage>> {
326        let conversations = self.conversations.read().await;
327
328        if let Some(messages) = conversations.get(conversation_id) {
329            let recent: Vec<ConversationMessage> =
330                messages.iter().rev().take(limit).rev().cloned().collect();
331            Ok(recent)
332        } else {
333            Ok(Vec::new())
334        }
335    }
336
337    async fn clear_conversation(&self, conversation_id: &str) -> RragResult<()> {
338        let mut conversations = self.conversations.write().await;
339        conversations.remove(conversation_id);
340        Ok(())
341    }
342
343    async fn get_memory_variables(
344        &self,
345        conversation_id: &str,
346    ) -> RragResult<HashMap<String, String>> {
347        let history = self.get_conversation_history(conversation_id).await?;
348        let mut variables = HashMap::new();
349
350        variables.insert(self.config.memory_key.clone(), history.join("\n"));
351
352        Ok(variables)
353    }
354
355    async fn save_context(
356        &self,
357        _conversation_id: &str,
358        _context: HashMap<String, String>,
359    ) -> RragResult<()> {
360        // Simple buffer memory doesn't persist additional context
361        // This could be extended to store context in message metadata
362        Ok(())
363    }
364
365    async fn health_check(&self) -> RragResult<bool> {
366        Ok(true)
367    }
368}
369
/// Token-aware buffer memory that respects token limits
///
/// Wraps a [`ConversationBufferMemory`], whose message-count and age limits still apply. After
/// each insert it enforces a token budget based on
/// [`ConversationMessage::estimated_tokens`].
pub struct ConversationTokenBufferMemory {
    /// Base buffer memory that actually stores the messages
    buffer: ConversationBufferMemory,

    /// Token-specific configuration
    token_config: TokenBufferConfig,
}

/// Configuration for token-aware memory buffer
#[derive(Debug, Clone)]
pub struct TokenBufferConfig {
    /// Maximum (estimated) tokens to keep per conversation; exceeding it triggers
    /// `overflow_strategy`
    pub max_tokens: usize,

    /// Buffer size to keep below max (for safety): `RemoveOldest` trims the conversation down
    /// to `max_tokens - buffer_tokens`. NOTE(review): expected to be <= `max_tokens`.
    pub buffer_tokens: usize,

    /// How to handle token overflow
    pub overflow_strategy: TokenOverflowStrategy,
}

/// Strategy for handling token overflow in memory buffer
#[derive(Debug, Clone)]
pub enum TokenOverflowStrategy {
    /// Remove oldest messages until the total fits under `max_tokens - buffer_tokens`
    RemoveOldest,

    /// Summarize old messages. NOTE(review): no summarization happens yet; the current handler
    /// drops the oldest half of the conversation instead.
    Summarize,

    /// Truncate message content. NOTE(review): not implemented; the current handler returns an
    /// error.
    Truncate,
}
404
405impl Default for TokenBufferConfig {
406    fn default() -> Self {
407        Self {
408            max_tokens: 4000,
409            buffer_tokens: 500,
410            overflow_strategy: TokenOverflowStrategy::RemoveOldest,
411        }
412    }
413}
414
415impl ConversationTokenBufferMemory {
416    /// Create a new token buffer memory with default configuration
417    pub fn new() -> Self {
418        Self {
419            buffer: ConversationBufferMemory::new(),
420            token_config: TokenBufferConfig::default(),
421        }
422    }
423
424    /// Create a new token buffer memory with custom configuration
425    pub fn with_config(buffer_config: BufferMemoryConfig, token_config: TokenBufferConfig) -> Self {
426        Self {
427            buffer: ConversationBufferMemory::with_config(buffer_config),
428            token_config,
429        }
430    }
431
432    /// Calculate total tokens in conversation
433    async fn calculate_total_tokens(&self, conversation_id: &str) -> RragResult<usize> {
434        let messages = self.buffer.get_messages(conversation_id).await?;
435        let total = messages.iter().map(|msg| msg.estimated_tokens()).sum();
436        Ok(total)
437    }
438
439    /// Handle token overflow
440    async fn handle_token_overflow(&self, conversation_id: &str) -> RragResult<()> {
441        let current_tokens = self.calculate_total_tokens(conversation_id).await?;
442
443        if current_tokens <= self.token_config.max_tokens {
444            return Ok(());
445        }
446
447        match self.token_config.overflow_strategy {
448            TokenOverflowStrategy::RemoveOldest => {
449                let mut conversations = self.buffer.conversations.write().await;
450
451                if let Some(messages) = conversations.get_mut(conversation_id) {
452                    while !messages.is_empty() {
453                        let total: usize = messages.iter().map(|msg| msg.estimated_tokens()).sum();
454                        if total <= self.token_config.max_tokens - self.token_config.buffer_tokens {
455                            break;
456                        }
457                        messages.pop_front();
458                    }
459                }
460            }
461            TokenOverflowStrategy::Summarize => {
462                // This would require integration with an LLM for summarization
463                // For now, fall back to removing oldest
464                let mut conversations = self.buffer.conversations.write().await;
465
466                if let Some(messages) = conversations.get_mut(conversation_id) {
467                    // Remove half the messages as a simple strategy
468                    let remove_count = messages.len() / 2;
469                    for _ in 0..remove_count {
470                        messages.pop_front();
471                    }
472                }
473            }
474            TokenOverflowStrategy::Truncate => {
475                // Truncate message content (not implemented in this example)
476                return Err(RragError::memory(
477                    "token_overflow",
478                    "Truncate strategy not implemented",
479                ));
480            }
481        }
482
483        Ok(())
484    }
485}
486
487impl Default for ConversationTokenBufferMemory {
488    fn default() -> Self {
489        Self::new()
490    }
491}
492
493#[async_trait]
494impl Memory for ConversationTokenBufferMemory {
495    async fn add_message(
496        &self,
497        conversation_id: &str,
498        role: &str,
499        content: &str,
500    ) -> RragResult<()> {
501        self.buffer
502            .add_message(conversation_id, role, content)
503            .await?;
504        self.handle_token_overflow(conversation_id).await?;
505        Ok(())
506    }
507
508    async fn add_structured_message(
509        &self,
510        conversation_id: &str,
511        message: ConversationMessage,
512    ) -> RragResult<()> {
513        self.buffer
514            .add_structured_message(conversation_id, message)
515            .await?;
516        self.handle_token_overflow(conversation_id).await?;
517        Ok(())
518    }
519
520    async fn get_conversation_history(&self, conversation_id: &str) -> RragResult<Vec<String>> {
521        self.buffer.get_conversation_history(conversation_id).await
522    }
523
524    async fn get_messages(&self, conversation_id: &str) -> RragResult<Vec<ConversationMessage>> {
525        self.buffer.get_messages(conversation_id).await
526    }
527
528    async fn get_recent_messages(
529        &self,
530        conversation_id: &str,
531        limit: usize,
532    ) -> RragResult<Vec<ConversationMessage>> {
533        self.buffer
534            .get_recent_messages(conversation_id, limit)
535            .await
536    }
537
538    async fn clear_conversation(&self, conversation_id: &str) -> RragResult<()> {
539        self.buffer.clear_conversation(conversation_id).await
540    }
541
542    async fn get_memory_variables(
543        &self,
544        conversation_id: &str,
545    ) -> RragResult<HashMap<String, String>> {
546        let mut variables = self.buffer.get_memory_variables(conversation_id).await?;
547
548        // Add token information
549        let token_count = self.calculate_total_tokens(conversation_id).await?;
550        variables.insert("token_count".to_string(), token_count.to_string());
551        variables.insert(
552            "max_tokens".to_string(),
553            self.token_config.max_tokens.to_string(),
554        );
555
556        Ok(variables)
557    }
558
559    async fn save_context(
560        &self,
561        conversation_id: &str,
562        context: HashMap<String, String>,
563    ) -> RragResult<()> {
564        self.buffer.save_context(conversation_id, context).await
565    }
566
567    async fn health_check(&self) -> RragResult<bool> {
568        self.buffer.health_check().await
569    }
570}
571
/// Summary memory that automatically summarizes old conversations
///
/// Recent messages are kept verbatim. Once a conversation crosses a threshold in
/// [`SummaryMemoryConfig`], all but the most recent `keep_recent_messages` messages are
/// replaced by a [`MemorySummary`]. NOTE(review): summaries are currently placeholders that
/// record only the message count and time range, not the content.
pub struct ConversationSummaryMemory {
    /// Current (not yet summarized) messages, keyed by conversation ID
    current_messages: Arc<RwLock<HashMap<String, VecDeque<ConversationMessage>>>>,

    /// Stored summaries, keyed by conversation ID, in creation order
    summaries: Arc<RwLock<HashMap<String, Vec<MemorySummary>>>>,

    /// Configuration
    config: SummaryMemoryConfig,
}

/// Configuration for summary-based memory management
#[derive(Debug, Clone)]
pub struct SummaryMemoryConfig {
    /// Summarize once a conversation holds more than this many messages
    pub max_messages_before_summary: usize,

    /// Summarize once a conversation's estimated token total exceeds this value
    pub max_tokens_before_summary: usize,

    /// Number of recent messages to keep verbatim after summarization
    pub keep_recent_messages: usize,

    /// Variable name for the rendered history (summary lines followed by current messages)
    pub memory_key: String,

    /// Variable name for the summaries alone; only set once at least one summary exists
    pub summary_key: String,
}
602
603impl Default for SummaryMemoryConfig {
604    fn default() -> Self {
605        Self {
606            max_messages_before_summary: 20,
607            max_tokens_before_summary: 2000,
608            keep_recent_messages: 5,
609            memory_key: "history".to_string(),
610            summary_key: "summary".to_string(),
611        }
612    }
613}
614
615impl ConversationSummaryMemory {
616    /// Create a new summary memory with default configuration
617    pub fn new() -> Self {
618        Self {
619            current_messages: Arc::new(RwLock::new(HashMap::new())),
620            summaries: Arc::new(RwLock::new(HashMap::new())),
621            config: SummaryMemoryConfig::default(),
622        }
623    }
624
625    /// Create a new summary memory with custom configuration
626    pub fn with_config(config: SummaryMemoryConfig) -> Self {
627        Self {
628            current_messages: Arc::new(RwLock::new(HashMap::new())),
629            summaries: Arc::new(RwLock::new(HashMap::new())),
630            config,
631        }
632    }
633
634    /// Check if summarization is needed
635    async fn should_summarize(&self, conversation_id: &str) -> RragResult<bool> {
636        let messages = self.current_messages.read().await;
637
638        if let Some(msg_deque) = messages.get(conversation_id) {
639            // Check message count
640            if msg_deque.len() > self.config.max_messages_before_summary {
641                return Ok(true);
642            }
643
644            // Check token count
645            let total_tokens: usize = msg_deque.iter().map(|msg| msg.estimated_tokens()).sum();
646            if total_tokens > self.config.max_tokens_before_summary {
647                return Ok(true);
648            }
649        }
650
651        Ok(false)
652    }
653
654    /// Perform summarization (mock implementation)
655    async fn summarize_conversation(&self, conversation_id: &str) -> RragResult<()> {
656        let mut messages = self.current_messages.write().await;
657        let mut summaries = self.summaries.write().await;
658
659        if let Some(msg_deque) = messages.get_mut(conversation_id) {
660            if msg_deque.len() <= self.config.keep_recent_messages {
661                return Ok(());
662            }
663
664            // Calculate how many messages to summarize
665            let to_summarize_count = msg_deque.len() - self.config.keep_recent_messages;
666
667            // Extract messages to summarize
668            let mut to_summarize = Vec::new();
669            for _ in 0..to_summarize_count {
670                if let Some(msg) = msg_deque.pop_front() {
671                    to_summarize.push(msg);
672                }
673            }
674
675            if !to_summarize.is_empty() {
676                // Create a simple summary (in production, would use LLM)
677                let summary_text = format!(
678                    "Summary of {} messages from {} to {}",
679                    to_summarize.len(),
680                    to_summarize
681                        .first()
682                        .unwrap()
683                        .timestamp
684                        .format("%Y-%m-%d %H:%M:%S"),
685                    to_summarize
686                        .last()
687                        .unwrap()
688                        .timestamp
689                        .format("%Y-%m-%d %H:%M:%S")
690                );
691
692                let original_tokens = to_summarize.iter().map(|msg| msg.estimated_tokens()).sum();
693
694                let summary = MemorySummary {
695                    summary: summary_text,
696                    message_count: to_summarize.len(),
697                    original_tokens,
698                    summary_tokens: 50, // Estimated
699                    start_time: to_summarize.first().unwrap().timestamp,
700                    end_time: to_summarize.last().unwrap().timestamp,
701                    metadata: HashMap::new(),
702                };
703
704                // Store the summary
705                summaries
706                    .entry(conversation_id.to_string())
707                    .or_insert_with(Vec::new)
708                    .push(summary);
709            }
710        }
711
712        Ok(())
713    }
714}
715
716impl Default for ConversationSummaryMemory {
717    fn default() -> Self {
718        Self::new()
719    }
720}
721
722#[async_trait]
723impl Memory for ConversationSummaryMemory {
724    async fn add_message(
725        &self,
726        conversation_id: &str,
727        role: &str,
728        content: &str,
729    ) -> RragResult<()> {
730        let role = match role.to_lowercase().as_str() {
731            "user" => MessageRole::User,
732            "assistant" => MessageRole::Assistant,
733            "system" => MessageRole::System,
734            "tool" => MessageRole::Tool,
735            _ => MessageRole::User,
736        };
737
738        let message = ConversationMessage::new(role, content);
739        self.add_structured_message(conversation_id, message).await
740    }
741
742    async fn add_structured_message(
743        &self,
744        conversation_id: &str,
745        message: ConversationMessage,
746    ) -> RragResult<()> {
747        // Add the message
748        {
749            let mut messages = self.current_messages.write().await;
750            let msg_deque = messages
751                .entry(conversation_id.to_string())
752                .or_insert_with(VecDeque::new);
753            msg_deque.push_back(message);
754        }
755
756        // Check if summarization is needed
757        if self.should_summarize(conversation_id).await? {
758            self.summarize_conversation(conversation_id).await?;
759        }
760
761        Ok(())
762    }
763
764    async fn get_conversation_history(&self, conversation_id: &str) -> RragResult<Vec<String>> {
765        let messages = self.current_messages.read().await;
766        let summaries = self.summaries.read().await;
767
768        let mut history = Vec::new();
769
770        // Add summaries first
771        if let Some(summary_list) = summaries.get(conversation_id) {
772            for summary in summary_list {
773                history.push(format!("Summary: {}", summary.summary));
774            }
775        }
776
777        // Add current messages
778        if let Some(msg_deque) = messages.get(conversation_id) {
779            for msg in msg_deque {
780                history.push(format!("{:?}: {}", msg.role, msg.content));
781            }
782        }
783
784        Ok(history)
785    }
786
787    async fn get_messages(&self, conversation_id: &str) -> RragResult<Vec<ConversationMessage>> {
788        let messages = self.current_messages.read().await;
789
790        if let Some(msg_deque) = messages.get(conversation_id) {
791            Ok(msg_deque.iter().cloned().collect())
792        } else {
793            Ok(Vec::new())
794        }
795    }
796
797    async fn get_recent_messages(
798        &self,
799        conversation_id: &str,
800        limit: usize,
801    ) -> RragResult<Vec<ConversationMessage>> {
802        let messages = self.current_messages.read().await;
803
804        if let Some(msg_deque) = messages.get(conversation_id) {
805            let recent: Vec<ConversationMessage> =
806                msg_deque.iter().rev().take(limit).rev().cloned().collect();
807            Ok(recent)
808        } else {
809            Ok(Vec::new())
810        }
811    }
812
813    async fn clear_conversation(&self, conversation_id: &str) -> RragResult<()> {
814        let mut messages = self.current_messages.write().await;
815        let mut summaries = self.summaries.write().await;
816
817        messages.remove(conversation_id);
818        summaries.remove(conversation_id);
819
820        Ok(())
821    }
822
823    async fn get_memory_variables(
824        &self,
825        conversation_id: &str,
826    ) -> RragResult<HashMap<String, String>> {
827        let mut variables = HashMap::new();
828
829        // Get current conversation
830        let history = self.get_conversation_history(conversation_id).await?;
831        variables.insert(self.config.memory_key.clone(), history.join("\n"));
832
833        // Get summary
834        let summaries = self.summaries.read().await;
835        if let Some(summary_list) = summaries.get(conversation_id) {
836            let summary_text = summary_list
837                .iter()
838                .map(|s| s.summary.clone())
839                .collect::<Vec<_>>()
840                .join("\n");
841            variables.insert(self.config.summary_key.clone(), summary_text);
842        }
843
844        Ok(variables)
845    }
846
847    async fn save_context(
848        &self,
849        _conversation_id: &str,
850        _context: HashMap<String, String>,
851    ) -> RragResult<()> {
852        // Could store context in message metadata or separate storage
853        Ok(())
854    }
855
856    async fn health_check(&self) -> RragResult<bool> {
857        Ok(true)
858    }
859}
860
/// High-level memory service over any [`Memory`] implementation
///
/// Holds the memory behind an `Arc<dyn Memory>` and adds role-specific convenience methods
/// on top of it.
pub struct MemoryService {
    /// Active memory implementation
    memory: Arc<dyn Memory>,

    /// Service configuration. NOTE(review): stored but not read by any method in this file.
    config: MemoryServiceConfig,
}

/// Configuration for the memory service
#[derive(Debug, Clone)]
pub struct MemoryServiceConfig {
    /// Default conversation settings
    pub default_conversation_settings: ConversationSettings,

    /// Enable memory persistence. NOTE(review): not read anywhere in this file.
    pub enable_persistence: bool,

    /// Persistence interval in seconds. NOTE(review): not read anywhere in this file.
    pub persistence_interval_seconds: u64,
}

/// Settings for individual conversations
///
/// NOTE(review): none of these fields are read by the memory types in this file; their
/// limits come from each memory's own config instead.
#[derive(Debug, Clone)]
pub struct ConversationSettings {
    /// Maximum messages per conversation
    pub max_messages: Option<usize>,

    /// Maximum age for messages, in hours
    pub max_age_hours: Option<u64>,

    /// Auto-summarization threshold
    pub auto_summarize_threshold: Option<usize>,
}
895
896impl Default for MemoryServiceConfig {
897    fn default() -> Self {
898        Self {
899            default_conversation_settings: ConversationSettings::default(),
900            enable_persistence: false,
901            persistence_interval_seconds: 300, // 5 minutes
902        }
903    }
904}
905
906impl Default for ConversationSettings {
907    fn default() -> Self {
908        Self {
909            max_messages: Some(100),
910            max_age_hours: Some(24),
911            auto_summarize_threshold: Some(50),
912        }
913    }
914}
915
916impl MemoryService {
917    /// Create a new memory service with default configuration
918    pub fn new(memory: Arc<dyn Memory>) -> Self {
919        Self {
920            memory,
921            config: MemoryServiceConfig::default(),
922        }
923    }
924
925    /// Create a new memory service with custom configuration
926    pub fn with_config(memory: Arc<dyn Memory>, config: MemoryServiceConfig) -> Self {
927        Self { memory, config }
928    }
929
930    /// Add a user message
931    pub async fn add_user_message(&self, conversation_id: &str, content: &str) -> RragResult<()> {
932        self.memory
933            .add_message(conversation_id, "user", content)
934            .await
935    }
936
937    /// Add an assistant message
938    pub async fn add_assistant_message(
939        &self,
940        conversation_id: &str,
941        content: &str,
942    ) -> RragResult<()> {
943        self.memory
944            .add_message(conversation_id, "assistant", content)
945            .await
946    }
947
948    /// Get formatted conversation for prompts
949    pub async fn get_conversation_context(&self, conversation_id: &str) -> RragResult<String> {
950        let variables = self.memory.get_memory_variables(conversation_id).await?;
951
952        // Return the main history
953        Ok(variables.get("history").unwrap_or(&String::new()).clone())
954    }
955
956    /// Get memory variables for prompt templates
957    pub async fn get_prompt_variables(
958        &self,
959        conversation_id: &str,
960    ) -> RragResult<HashMap<String, String>> {
961        self.memory.get_memory_variables(conversation_id).await
962    }
963
964    /// Health check
965    pub async fn health_check(&self) -> RragResult<bool> {
966        self.memory.health_check().await
967    }
968}
969
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder helpers on `ConversationMessage` set role, content,
    /// metadata and token count.
    #[tokio::test]
    async fn test_conversation_message() {
        let message = ConversationMessage::user("Hello world")
            .with_metadata("source", serde_json::Value::String("test".to_string()))
            .with_token_count(10);

        assert_eq!(message.role, MessageRole::User);
        assert_eq!(message.content, "Hello world");
        assert_eq!(message.estimated_tokens(), 10);

        let source = message.metadata.get("source").and_then(|value| value.as_str());
        assert_eq!(source, Some("test"));
    }

    /// The buffer memory keeps messages in insertion order and exposes them
    /// as formatted history, raw messages, and a recent-message window.
    #[tokio::test]
    async fn test_buffer_memory() {
        let memory = ConversationBufferMemory::new();
        let conversation = "test_conversation";

        for (role, text) in [("user", "Hello"), ("assistant", "Hi there!")] {
            memory.add_message(conversation, role, text).await.unwrap();
        }

        let history = memory.get_conversation_history(conversation).await.unwrap();
        assert_eq!(history.len(), 2);
        assert!(history[0].contains("Hello"));
        assert!(history[1].contains("Hi there!"));

        let messages = memory.get_messages(conversation).await.unwrap();
        let roles: Vec<MessageRole> = messages.iter().map(|m| m.role.clone()).collect();
        assert_eq!(roles, vec![MessageRole::User, MessageRole::Assistant]);

        let recent = memory.get_recent_messages(conversation, 1).await.unwrap();
        assert_eq!(recent.len(), 1);
        assert_eq!(recent[0].content, "Hi there!");
    }

    /// Adding more content than the token budget allows must evict messages.
    #[tokio::test]
    async fn test_token_buffer_memory() {
        const MESSAGE_COUNT: usize = 20;

        let token_config = TokenBufferConfig {
            max_tokens: 100,
            buffer_tokens: 10,
            overflow_strategy: TokenOverflowStrategy::RemoveOldest,
        };
        let memory = ConversationTokenBufferMemory::with_config(
            BufferMemoryConfig::default(),
            token_config,
        );
        let conversation = "test_token_conversation";

        for index in 0..MESSAGE_COUNT {
            let content = format!("Message number {} with some content", index);
            memory
                .add_message(conversation, "user", &content)
                .await
                .unwrap();
        }

        let total_tokens = memory.calculate_total_tokens(conversation).await.unwrap();
        assert!(
            total_tokens <= 100,
            "Total tokens {} should be <= 100",
            total_tokens
        );

        let remaining = memory.get_messages(conversation).await.unwrap();
        assert!(
            remaining.len() < MESSAGE_COUNT,
            "Should have removed some messages due to token limit"
        );
    }

    /// `MemoryService` round-trips messages through the memory it wraps.
    #[tokio::test]
    async fn test_memory_service() {
        let service = MemoryService::new(Arc::new(ConversationBufferMemory::new()));
        let conversation = "service_test";

        service
            .add_user_message(conversation, "How are you?")
            .await
            .unwrap();
        service
            .add_assistant_message(conversation, "I'm doing well, thank you!")
            .await
            .unwrap();

        let context = service
            .get_conversation_context(conversation)
            .await
            .unwrap();
        for expected in ["How are you?", "I'm doing well"] {
            assert!(context.contains(expected));
        }

        let variables = service.get_prompt_variables(conversation).await.unwrap();
        assert!(variables.contains_key("history"));

        let healthy = service.health_check().await.unwrap();
        assert!(healthy);
    }

    /// Exceeding `max_messages_before_summary` collapses older messages into a summary.
    #[tokio::test]
    async fn test_summary_memory() {
        let memory = ConversationSummaryMemory::with_config(SummaryMemoryConfig {
            max_messages_before_summary: 3,
            max_tokens_before_summary: 1000,
            keep_recent_messages: 1,
            memory_key: "history".to_string(),
            summary_key: "summary".to_string(),
        });
        let conversation = "summary_test";

        let exchange = [
            ("user", "First message"),
            ("assistant", "First response"),
            ("user", "Second message"),
            ("assistant", "Second response"),
        ];
        for (role, text) in exchange {
            memory.add_message(conversation, role, text).await.unwrap();
        }

        let remaining = memory.get_messages(conversation).await.unwrap();
        assert!(remaining.len() <= 1, "Should have summarized old messages");

        let variables = memory.get_memory_variables(conversation).await.unwrap();
        assert!(
            variables.contains_key("summary"),
            "Should have summary in variables"
        );
    }
}