rexis_rag/agent/memory/
conversation.rs

1//! Conversation memory storage with persistence
2
3use crate::error::{RragError, RragResult};
4use crate::storage::{Memory, MemoryValue};
5use rexis_llm::{ChatMessage, MessageRole}; // Use re-exported rsllm types
6use uuid::Uuid;
7
8/// Conversation memory backed by persistent storage
9pub struct ConversationMemoryStore {
10    /// Storage backend
11    storage: std::sync::Arc<dyn Memory>,
12
13    /// Session namespace (session::{session_id}::conversation)
14    namespace: String,
15
16    /// Maximum number of messages to keep
17    max_length: usize,
18
19    /// Whether to persist messages
20    persist: bool,
21}
22
23impl ConversationMemoryStore {
24    /// Create a new conversation memory store
25    pub fn new(
26        storage: std::sync::Arc<dyn Memory>,
27        session_id: String,
28        max_length: usize,
29        persist: bool,
30    ) -> Self {
31        let namespace = format!("session::{}::conversation", session_id);
32
33        Self {
34            storage,
35            namespace,
36            max_length,
37            persist,
38        }
39    }
40
41    /// Add a message to conversation history
42    pub async fn add_message(&self, message: ChatMessage) -> RragResult<()> {
43        if !self.persist {
44            // TODO: Keep in-memory cache for non-persistent mode
45            return Ok(());
46        }
47
48        // Get current message count
49        let count = self.count().await?;
50
51        // Store message
52        let key = self.message_key(count);
53        let value = self.message_to_value(&message)?;
54
55        self.storage.set(&key, value).await?;
56
57        // Prune if exceeded max length
58        if count + 1 > self.max_length {
59            self.prune_old_messages().await?;
60        }
61
62        Ok(())
63    }
64
65    /// Get all messages in order
66    pub async fn get_messages(&self) -> RragResult<Vec<ChatMessage>> {
67        if !self.persist {
68            // TODO: Return in-memory cache
69            return Ok(Vec::new());
70        }
71
72        let count = self.count().await?;
73        let mut messages = Vec::with_capacity(count);
74
75        for idx in 0..count {
76            let key = self.message_key(idx);
77            if let Some(value) = self.storage.get(&key).await? {
78                let message = self.value_to_message(&value)?;
79                messages.push(message);
80            }
81        }
82
83        Ok(messages)
84    }
85
86    /// Get the number of messages
87    pub async fn count(&self) -> RragResult<usize> {
88        if !self.persist {
89            return Ok(0);
90        }
91
92        let count_key = format!("{}::count", self.namespace);
93        if let Some(value) = self.storage.get(&count_key).await? {
94            if let Some(count) = value.as_integer() {
95                return Ok(count as usize);
96            }
97        }
98
99        Ok(0)
100    }
101
102    /// Clear all messages except system message
103    pub async fn clear(&self) -> RragResult<()> {
104        if !self.persist {
105            return Ok(());
106        }
107
108        // Get system message if it exists
109        let system_msg = if self.count().await? > 0 {
110            let key = self.message_key(0);
111            if let Some(value) = self.storage.get(&key).await? {
112                let msg = self.value_to_message(&value)?;
113                if matches!(msg.role, MessageRole::System) {
114                    Some(msg)
115                } else {
116                    None
117                }
118            } else {
119                None
120            }
121        } else {
122            None
123        };
124
125        // Clear all messages
126        self.storage.clear(Some(&self.namespace)).await?;
127
128        // Restore system message if it existed
129        if let Some(msg) = system_msg {
130            self.add_message(msg).await?;
131        }
132
133        Ok(())
134    }
135
136    /// Generate message key
137    fn message_key(&self, index: usize) -> String {
138        format!("{}::msg_{}", self.namespace, index)
139    }
140
141    /// Convert ChatMessage to MemoryValue
142    fn message_to_value(&self, message: &ChatMessage) -> RragResult<MemoryValue> {
143        let json = serde_json::to_value(message).map_err(|e| {
144            RragError::storage(
145                "serialize_message",
146                std::io::Error::new(std::io::ErrorKind::Other, e),
147            )
148        })?;
149
150        Ok(MemoryValue::Json(json))
151    }
152
153    /// Convert MemoryValue to ChatMessage
154    fn value_to_message(&self, value: &MemoryValue) -> RragResult<ChatMessage> {
155        if let Some(json) = value.as_json() {
156            let message = serde_json::from_value(json.clone()).map_err(|e| {
157                RragError::storage(
158                    "deserialize_message",
159                    std::io::Error::new(std::io::ErrorKind::Other, e),
160                )
161            })?;
162
163            Ok(message)
164        } else {
165            Err(RragError::storage(
166                "invalid_message_type",
167                std::io::Error::new(std::io::ErrorKind::InvalidData, "Expected JSON value"),
168            ))
169        }
170    }
171
172    /// Prune old messages to maintain max_length
173    async fn prune_old_messages(&self) -> RragResult<()> {
174        let count = self.count().await?;
175
176        if count <= self.max_length {
177            return Ok(());
178        }
179
180        // Keep system message (index 0) if it exists
181        let has_system = if let Some(value) = self.storage.get(&self.message_key(0)).await? {
182            let msg = self.value_to_message(&value)?;
183            matches!(msg.role, MessageRole::System)
184        } else {
185            false
186        };
187
188        let start_idx = if has_system { 1 } else { 0 };
189        let to_remove = count - self.max_length;
190
191        // Delete old messages
192        let mut keys_to_delete = Vec::new();
193        for idx in start_idx..(start_idx + to_remove) {
194            keys_to_delete.push(self.message_key(idx));
195        }
196
197        self.storage.mdelete(&keys_to_delete).await?;
198
199        // Shift remaining messages down
200        for idx in (start_idx + to_remove)..count {
201            let old_key = self.message_key(idx);
202            let new_key = self.message_key(idx - to_remove);
203
204            if let Some(value) = self.storage.get(&old_key).await? {
205                self.storage.set(&new_key, value).await?;
206                self.storage.delete(&old_key).await?;
207            }
208        }
209
210        // Update count
211        let count_key = format!("{}::count", self.namespace);
212        self.storage
213            .set(&count_key, MemoryValue::Integer((count - to_remove) as i64))
214            .await?;
215
216        Ok(())
217    }
218
219    /// Check if conversation is empty
220    pub async fn is_empty(&self) -> RragResult<bool> {
221        Ok(self.count().await? == 0)
222    }
223}
224
225/// Generate a unique session ID
226pub fn generate_session_id() -> String {
227    Uuid::new_v4().to_string()
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::InMemoryStorage;
    use std::sync::Arc;

    /// Add, read back, count, then clear (which must keep the system prompt).
    #[tokio::test]
    async fn test_conversation_memory_store() {
        let storage = Arc::new(InMemoryStorage::new());
        let session_id = generate_session_id();
        let store = ConversationMemoryStore::new(storage, session_id, 10, true);

        // Add messages
        store
            .add_message(ChatMessage::system("You are a helpful assistant"))
            .await
            .unwrap();
        store.add_message(ChatMessage::user("Hello")).await.unwrap();
        store
            .add_message(ChatMessage::assistant("Hi there!"))
            .await
            .unwrap();

        // Get messages
        let messages = store.get_messages().await.unwrap();
        assert_eq!(messages.len(), 3);

        // Check count
        assert_eq!(store.count().await.unwrap(), 3);

        // Clear
        store.clear().await.unwrap();

        // System message should remain
        let messages = store.get_messages().await.unwrap();
        assert_eq!(messages.len(), 1);
        assert!(matches!(messages[0].role, rexis_llm::MessageRole::System));
    }

    /// Exceeding `max_length` drops the oldest non-system message, keeps the
    /// system prompt, and preserves the order of the survivors.
    #[tokio::test]
    async fn test_prune_keeps_system_and_newest_messages() {
        let storage = Arc::new(InMemoryStorage::new());
        let store = ConversationMemoryStore::new(storage, generate_session_id(), 3, true);

        store
            .add_message(ChatMessage::system("You are a helpful assistant"))
            .await
            .unwrap();
        store.add_message(ChatMessage::user("first")).await.unwrap();

        // Capture the survivors' serialized form before moving them into the
        // store so they can be compared after the storage round-trip.
        let second = ChatMessage::user("second");
        let expected_second = serde_json::to_value(&second).unwrap();
        store.add_message(second).await.unwrap();

        let third = ChatMessage::assistant("third");
        let expected_third = serde_json::to_value(&third).unwrap();
        store.add_message(third).await.unwrap();

        assert_eq!(store.count().await.unwrap(), 3);

        let messages = store.get_messages().await.unwrap();
        assert_eq!(messages.len(), 3);
        assert!(matches!(messages[0].role, rexis_llm::MessageRole::System));
        assert_eq!(serde_json::to_value(&messages[1]).unwrap(), expected_second);
        assert_eq!(serde_json::to_value(&messages[2]).unwrap(), expected_third);
    }
}
268}