// rexis_rag/agent/memory/conversation.rs

use rexis_llm::{ChatMessage, MessageRole};
use uuid::Uuid;

use crate::error::{RragError, RragResult};
use crate::storage::{Memory, MemoryValue};

8pub struct ConversationMemoryStore {
10 storage: std::sync::Arc<dyn Memory>,
12
13 namespace: String,
15
16 max_length: usize,
18
19 persist: bool,
21}
22
23impl ConversationMemoryStore {
24 pub fn new(
26 storage: std::sync::Arc<dyn Memory>,
27 session_id: String,
28 max_length: usize,
29 persist: bool,
30 ) -> Self {
31 let namespace = format!("session::{}::conversation", session_id);
32
33 Self {
34 storage,
35 namespace,
36 max_length,
37 persist,
38 }
39 }
40
41 pub async fn add_message(&self, message: ChatMessage) -> RragResult<()> {
43 if !self.persist {
44 return Ok(());
46 }
47
48 let count = self.count().await?;
50
51 let key = self.message_key(count);
53 let value = self.message_to_value(&message)?;
54
55 self.storage.set(&key, value).await?;
56
57 if count + 1 > self.max_length {
59 self.prune_old_messages().await?;
60 }
61
62 Ok(())
63 }
64
65 pub async fn get_messages(&self) -> RragResult<Vec<ChatMessage>> {
67 if !self.persist {
68 return Ok(Vec::new());
70 }
71
72 let count = self.count().await?;
73 let mut messages = Vec::with_capacity(count);
74
75 for idx in 0..count {
76 let key = self.message_key(idx);
77 if let Some(value) = self.storage.get(&key).await? {
78 let message = self.value_to_message(&value)?;
79 messages.push(message);
80 }
81 }
82
83 Ok(messages)
84 }
85
86 pub async fn count(&self) -> RragResult<usize> {
88 if !self.persist {
89 return Ok(0);
90 }
91
92 let count_key = format!("{}::count", self.namespace);
93 if let Some(value) = self.storage.get(&count_key).await? {
94 if let Some(count) = value.as_integer() {
95 return Ok(count as usize);
96 }
97 }
98
99 Ok(0)
100 }
101
102 pub async fn clear(&self) -> RragResult<()> {
104 if !self.persist {
105 return Ok(());
106 }
107
108 let system_msg = if self.count().await? > 0 {
110 let key = self.message_key(0);
111 if let Some(value) = self.storage.get(&key).await? {
112 let msg = self.value_to_message(&value)?;
113 if matches!(msg.role, MessageRole::System) {
114 Some(msg)
115 } else {
116 None
117 }
118 } else {
119 None
120 }
121 } else {
122 None
123 };
124
125 self.storage.clear(Some(&self.namespace)).await?;
127
128 if let Some(msg) = system_msg {
130 self.add_message(msg).await?;
131 }
132
133 Ok(())
134 }
135
136 fn message_key(&self, index: usize) -> String {
138 format!("{}::msg_{}", self.namespace, index)
139 }
140
141 fn message_to_value(&self, message: &ChatMessage) -> RragResult<MemoryValue> {
143 let json = serde_json::to_value(message).map_err(|e| {
144 RragError::storage(
145 "serialize_message",
146 std::io::Error::new(std::io::ErrorKind::Other, e),
147 )
148 })?;
149
150 Ok(MemoryValue::Json(json))
151 }
152
153 fn value_to_message(&self, value: &MemoryValue) -> RragResult<ChatMessage> {
155 if let Some(json) = value.as_json() {
156 let message = serde_json::from_value(json.clone()).map_err(|e| {
157 RragError::storage(
158 "deserialize_message",
159 std::io::Error::new(std::io::ErrorKind::Other, e),
160 )
161 })?;
162
163 Ok(message)
164 } else {
165 Err(RragError::storage(
166 "invalid_message_type",
167 std::io::Error::new(std::io::ErrorKind::InvalidData, "Expected JSON value"),
168 ))
169 }
170 }
171
172 async fn prune_old_messages(&self) -> RragResult<()> {
174 let count = self.count().await?;
175
176 if count <= self.max_length {
177 return Ok(());
178 }
179
180 let has_system = if let Some(value) = self.storage.get(&self.message_key(0)).await? {
182 let msg = self.value_to_message(&value)?;
183 matches!(msg.role, MessageRole::System)
184 } else {
185 false
186 };
187
188 let start_idx = if has_system { 1 } else { 0 };
189 let to_remove = count - self.max_length;
190
191 let mut keys_to_delete = Vec::new();
193 for idx in start_idx..(start_idx + to_remove) {
194 keys_to_delete.push(self.message_key(idx));
195 }
196
197 self.storage.mdelete(&keys_to_delete).await?;
198
199 for idx in (start_idx + to_remove)..count {
201 let old_key = self.message_key(idx);
202 let new_key = self.message_key(idx - to_remove);
203
204 if let Some(value) = self.storage.get(&old_key).await? {
205 self.storage.set(&new_key, value).await?;
206 self.storage.delete(&old_key).await?;
207 }
208 }
209
210 let count_key = format!("{}::count", self.namespace);
212 self.storage
213 .set(&count_key, MemoryValue::Integer((count - to_remove) as i64))
214 .await?;
215
216 Ok(())
217 }
218
219 pub async fn is_empty(&self) -> RragResult<bool> {
221 Ok(self.count().await? == 0)
222 }
223}
224
225pub fn generate_session_id() -> String {
227 Uuid::new_v4().to_string()
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::InMemoryStorage;
    use std::sync::Arc;

    /// Round-trips a short conversation and checks that `clear` keeps the
    /// system prompt.
    #[tokio::test]
    async fn test_conversation_memory_store() {
        let storage = Arc::new(InMemoryStorage::new());
        let session_id = generate_session_id();
        let store = ConversationMemoryStore::new(storage, session_id, 10, true);

        store
            .add_message(ChatMessage::system("You are a helpful assistant"))
            .await
            .unwrap();
        store.add_message(ChatMessage::user("Hello")).await.unwrap();
        store
            .add_message(ChatMessage::assistant("Hi there!"))
            .await
            .unwrap();

        let messages = store.get_messages().await.unwrap();
        assert_eq!(messages.len(), 3);
        assert_eq!(store.count().await.unwrap(), 3);

        store.clear().await.unwrap();

        let messages = store.get_messages().await.unwrap();
        assert_eq!(messages.len(), 1);
        assert!(matches!(messages[0].role, rexis_llm::MessageRole::System));
    }

    /// Exceeding `max_length` prunes the oldest non-system messages and keeps
    /// the system prompt at index 0.
    #[tokio::test]
    async fn test_prune_keeps_system_message() {
        let storage = Arc::new(InMemoryStorage::new());
        let store = ConversationMemoryStore::new(storage, generate_session_id(), 3, true);

        store
            .add_message(ChatMessage::system("You are a helpful assistant"))
            .await
            .unwrap();
        for text in ["one", "two", "three", "four"] {
            store.add_message(ChatMessage::user(text)).await.unwrap();
        }

        assert_eq!(store.count().await.unwrap(), 3);
        let messages = store.get_messages().await.unwrap();
        assert_eq!(messages.len(), 3);
        assert!(matches!(messages[0].role, rexis_llm::MessageRole::System));
        assert!(messages[1..]
            .iter()
            .all(|m| matches!(m.role, rexis_llm::MessageRole::User)));
    }

    /// With persistence disabled nothing is stored.
    #[tokio::test]
    async fn test_disabled_persistence_is_noop() {
        let storage = Arc::new(InMemoryStorage::new());
        let store = ConversationMemoryStore::new(storage, generate_session_id(), 10, false);

        store.add_message(ChatMessage::user("Hello")).await.unwrap();

        assert!(store.is_empty().await.unwrap());
        assert!(store.get_messages().await.unwrap().is_empty());
    }
}