rexis_rag/agent/memory/
compression.rs

1//! Memory compression and optimization strategies
2//!
3//! Provides utilities for compressing, archiving, and optimizing memory storage
4//! to manage memory growth over long conversations and agent lifecycles.
5
6use crate::error::RragResult;
7use crate::storage::{Memory, MemoryValue};
8use std::sync::Arc;
9
10#[cfg(feature = "rexis-llm-client")]
11use rexis_llm::{ChatMessage, Client};
12
/// Configuration for memory compression.
///
/// The size and item limits are the thresholds checked by
/// `MemoryCompressor::needs_compression`. The type is comparable with `==`
/// so configurations can be checked in tests and change detection.
#[derive(Debug, Clone, PartialEq)]
pub struct CompressionConfig {
    /// Maximum size in bytes before triggering compression.
    pub max_size_bytes: usize,

    /// Maximum number of items before triggering compression.
    pub max_items: usize,

    /// Compression ratio target (0.0 to 1.0).
    pub compression_ratio: f64,

    /// Enable LLM-based intelligent compression.
    pub use_llm_compression: bool,

    /// Minimum importance score to keep uncompressed.
    pub min_importance_threshold: f64,
}

impl Default for CompressionConfig {
    /// Returns the default limits: 10 MB, 10 000 items, a 0.5 ratio target,
    /// LLM compression enabled, and an importance threshold of 0.7.
    fn default() -> Self {
        Self {
            max_size_bytes: 10_000_000, // 10MB
            max_items: 10_000,
            compression_ratio: 0.5,
            use_llm_compression: true,
            min_importance_threshold: 0.7,
        }
    }
}
43
/// Memory compression strategy.
///
/// A plain, copyable tag naming how memory should be reduced. It derives
/// `Hash` in addition to `Eq` so it can be used as a `HashMap`/`HashSet` key
/// (for example, to tally results per strategy).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompressionStrategy {
    /// Remove oldest entries.
    RemoveOldest,

    /// Remove least important entries.
    RemoveLeastImportant,

    /// Merge similar entries.
    MergeSimilar,

    /// Summarize and archive.
    SummarizeAndArchive,

    /// Binary compression (gzip/zstd).
    BinaryCompression,
}
62
63/// Memory statistics for compression decisions
64#[derive(Debug, Clone)]
65pub struct MemoryStats {
66    /// Total size in bytes
67    pub total_bytes: usize,
68
69    /// Number of items
70    pub item_count: usize,
71
72    /// Average item size
73    pub avg_item_size: usize,
74
75    /// Oldest item timestamp
76    pub oldest_timestamp: Option<chrono::DateTime<chrono::Utc>>,
77
78    /// Newest item timestamp
79    pub newest_timestamp: Option<chrono::DateTime<chrono::Utc>>,
80}
81
82/// Memory compressor
83pub struct MemoryCompressor {
84    storage: Arc<dyn Memory>,
85    config: CompressionConfig,
86}
87
88impl MemoryCompressor {
89    /// Create a new memory compressor
90    pub fn new(storage: Arc<dyn Memory>, config: CompressionConfig) -> Self {
91        Self { storage, config }
92    }
93
94    /// Check if compression is needed based on stats
95    pub fn needs_compression(&self, stats: &MemoryStats) -> bool {
96        stats.total_bytes > self.config.max_size_bytes || stats.item_count > self.config.max_items
97    }
98
99    /// Calculate memory statistics for a namespace
100    pub async fn calculate_stats(&self, namespace: &str) -> RragResult<MemoryStats> {
101        use crate::storage::MemoryQuery;
102
103        let query = MemoryQuery::new().with_namespace(namespace.to_string());
104        let keys = self.storage.keys(&query).await?;
105
106        let mut total_bytes = 0;
107        let mut oldest: Option<chrono::DateTime<chrono::Utc>> = None;
108        let mut newest: Option<chrono::DateTime<chrono::Utc>> = None;
109
110        for key in &keys {
111            if let Some(value) = self.storage.get(key).await? {
112                // Estimate size (rough approximation)
113                total_bytes += match &value {
114                    MemoryValue::String(s) => s.len(),
115                    MemoryValue::Integer(_) => 8,
116                    MemoryValue::Float(_) => 8,
117                    MemoryValue::Boolean(_) => 1,
118                    MemoryValue::Json(j) => j.to_string().len(),
119                    MemoryValue::Bytes(b) => b.len(),
120                    MemoryValue::List(items) => items.len() * 16, // rough estimate
121                    MemoryValue::Map(m) => m.len() * 32,          // rough estimate
122                };
123
124                // Try to extract timestamp from JSON values
125                if let MemoryValue::Json(json) = value {
126                    if let Some(timestamp_str) = json.get("timestamp").and_then(|v| v.as_str()) {
127                        if let Ok(ts) = chrono::DateTime::parse_from_rfc3339(timestamp_str) {
128                            let utc_ts = ts.with_timezone(&chrono::Utc);
129                            oldest = Some(oldest.map_or(utc_ts, |o| o.min(utc_ts)));
130                            newest = Some(newest.map_or(utc_ts, |n| n.max(utc_ts)));
131                        }
132                    } else if let Some(created_str) =
133                        json.get("created_at").and_then(|v| v.as_str())
134                    {
135                        if let Ok(ts) = chrono::DateTime::parse_from_rfc3339(created_str) {
136                            let utc_ts = ts.with_timezone(&chrono::Utc);
137                            oldest = Some(oldest.map_or(utc_ts, |o| o.min(utc_ts)));
138                            newest = Some(newest.map_or(utc_ts, |n| n.max(utc_ts)));
139                        }
140                    }
141                }
142            }
143        }
144
145        let item_count = keys.len();
146        let avg_item_size = if item_count > 0 {
147            total_bytes / item_count
148        } else {
149            0
150        };
151
152        Ok(MemoryStats {
153            total_bytes,
154            item_count,
155            avg_item_size,
156            oldest_timestamp: oldest,
157            newest_timestamp: newest,
158        })
159    }
160
161    /// Compress conversation memory by summarizing old messages (requires 'rsllm-client' feature)
162    #[cfg(feature = "rexis-llm-client")]
163    pub async fn compress_conversation_memory(
164        &self,
165        namespace: &str,
166        llm_client: &Client,
167        keep_recent_count: usize,
168    ) -> RragResult<usize> {
169        use crate::storage::MemoryQuery;
170
171        let query = MemoryQuery::new().with_namespace(namespace.to_string());
172        let keys = self.storage.keys(&query).await?;
173
174        if keys.len() <= keep_recent_count {
175            return Ok(0); // Nothing to compress
176        }
177
178        // Get all messages with timestamps
179        let mut messages: Vec<(String, serde_json::Value, chrono::DateTime<chrono::Utc>)> =
180            Vec::new();
181
182        for key in &keys {
183            if let Some(value) = self.storage.get(key).await? {
184                if let MemoryValue::Json(json) = value {
185                    if let Some(timestamp_str) = json.get("timestamp").and_then(|v| v.as_str()) {
186                        if let Ok(ts) = chrono::DateTime::parse_from_rfc3339(timestamp_str) {
187                            messages.push((key.clone(), json, ts.with_timezone(&chrono::Utc)));
188                        }
189                    }
190                }
191            }
192        }
193
194        // Sort by timestamp (oldest first)
195        messages.sort_by(|a, b| a.2.cmp(&b.2));
196
197        // Keep recent messages, compress old ones
198        let to_compress = messages.len().saturating_sub(keep_recent_count);
199
200        if to_compress == 0 {
201            return Ok(0);
202        }
203
204        // Build text from old messages
205        let mut old_messages_text = String::new();
206        for (_, json, _) in messages.iter().take(to_compress) {
207            if let Some(role) = json.get("role").and_then(|v| v.as_str()) {
208                if let Some(content) = json.get("content").and_then(|v| v.as_str()) {
209                    old_messages_text.push_str(&format!("{}: {}\n", role, content));
210                }
211            }
212        }
213
214        // Generate summary using LLM
215        let summary_msg = ChatMessage::user(format!(
216            "Summarize these conversation messages in 2-3 sentences:\n\n{}",
217            old_messages_text
218        ));
219
220        let response = llm_client
221            .chat_completion(vec![summary_msg])
222            .await
223            .map_err(|e| crate::error::RragError::rsllm_client("conversation_compression", e))?;
224
225        let summary = response.content.trim().to_string();
226
227        // Store summary
228        let summary_key = format!("{}::summary::compressed", namespace);
229        self.storage
230            .set(
231                &summary_key,
232                MemoryValue::Json(serde_json::json!({
233                    "summary": summary,
234                    "compressed_count": to_compress,
235                    "compressed_at": chrono::Utc::now().to_rfc3339(),
236                })),
237            )
238            .await?;
239
240        // Delete old messages
241        let mut deleted = 0;
242        for (key, _, _) in messages.iter().take(to_compress) {
243            if self.storage.delete(key).await? {
244                deleted += 1;
245            }
246        }
247
248        tracing::info!(
249            namespace = namespace,
250            deleted = deleted,
251            "Compressed conversation memory"
252        );
253
254        Ok(deleted)
255    }
256
257    /// Remove old items based on timestamp
258    pub async fn remove_old_items(
259        &self,
260        namespace: &str,
261        older_than: chrono::DateTime<chrono::Utc>,
262    ) -> RragResult<usize> {
263        use crate::storage::MemoryQuery;
264
265        let query = MemoryQuery::new().with_namespace(namespace.to_string());
266        let keys = self.storage.keys(&query).await?;
267
268        let mut deleted = 0;
269
270        for key in keys {
271            if let Some(value) = self.storage.get(&key).await? {
272                if let MemoryValue::Json(json) = value {
273                    let should_delete = if let Some(timestamp_str) =
274                        json.get("timestamp").and_then(|v| v.as_str())
275                    {
276                        chrono::DateTime::parse_from_rfc3339(timestamp_str)
277                            .ok()
278                            .map(|ts| ts.with_timezone(&chrono::Utc) < older_than)
279                            .unwrap_or(false)
280                    } else if let Some(created_str) =
281                        json.get("created_at").and_then(|v| v.as_str())
282                    {
283                        chrono::DateTime::parse_from_rfc3339(created_str)
284                            .ok()
285                            .map(|ts| ts.with_timezone(&chrono::Utc) < older_than)
286                            .unwrap_or(false)
287                    } else {
288                        false
289                    };
290
291                    if should_delete && self.storage.delete(&key).await? {
292                        deleted += 1;
293                    }
294                }
295            }
296        }
297
298        tracing::info!(
299            namespace = namespace,
300            deleted = deleted,
301            "Removed old items"
302        );
303
304        Ok(deleted)
305    }
306
307    /// Remove least important items
308    pub async fn remove_least_important(
309        &self,
310        namespace: &str,
311        min_importance: f64,
312        max_to_remove: usize,
313    ) -> RragResult<usize> {
314        use crate::storage::MemoryQuery;
315
316        let query = MemoryQuery::new().with_namespace(namespace.to_string());
317        let keys = self.storage.keys(&query).await?;
318
319        let mut items_with_importance: Vec<(String, f64)> = Vec::new();
320
321        for key in keys {
322            if let Some(value) = self.storage.get(&key).await? {
323                if let MemoryValue::Json(json) = value {
324                    if let Some(importance) = json.get("importance").and_then(|v| v.as_f64()) {
325                        items_with_importance.push((key, importance));
326                    }
327                }
328            }
329        }
330
331        // Sort by importance (ascending)
332        items_with_importance.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
333
334        let mut deleted = 0;
335
336        for (key, importance) in items_with_importance.iter().take(max_to_remove) {
337            if *importance < min_importance && self.storage.delete(key).await? {
338                deleted += 1;
339            }
340        }
341
342        tracing::info!(
343            namespace = namespace,
344            deleted = deleted,
345            "Removed least important items"
346        );
347
348        Ok(deleted)
349    }
350}
351
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::InMemoryStorage;

    /// Builds a default-configured compressor over fresh in-memory storage
    /// and returns both, so tests can seed and inspect the storage directly.
    fn fixture() -> (Arc<InMemoryStorage>, MemoryCompressor) {
        let storage = Arc::new(InMemoryStorage::new());
        let compressor = MemoryCompressor::new(storage.clone(), CompressionConfig::default());
        (storage, compressor)
    }

    #[tokio::test]
    async fn test_memory_stats_calculation() {
        let (storage, compressor) = fixture();

        // Store some test data
        let namespace = "test";
        for i in 0..5 {
            let key = format!("{}::item::{}", namespace, i);
            let value = MemoryValue::Json(serde_json::json!({
                "id": i,
                "content": "test data",
                "timestamp": chrono::Utc::now().to_rfc3339(),
            }));
            storage.set(&key, value).await.unwrap();
        }

        let stats = compressor.calculate_stats(namespace).await.unwrap();

        assert_eq!(stats.item_count, 5);
        assert!(stats.total_bytes > 0);
        assert!(stats.avg_item_size > 0);
        assert!(stats.oldest_timestamp.is_some());
        assert!(stats.oldest_timestamp <= stats.newest_timestamp);
    }

    #[tokio::test]
    async fn test_needs_compression_thresholds() {
        let (_storage, compressor) = fixture();
        let limits = CompressionConfig::default();

        // Sitting exactly on both limits is still within bounds.
        let at_limit = MemoryStats {
            total_bytes: limits.max_size_bytes,
            item_count: limits.max_items,
            avg_item_size: 0,
            oldest_timestamp: None,
            newest_timestamp: None,
        };
        assert!(!compressor.needs_compression(&at_limit));

        let too_many = MemoryStats {
            item_count: limits.max_items + 1,
            ..at_limit.clone()
        };
        assert!(compressor.needs_compression(&too_many));

        let too_big = MemoryStats {
            total_bytes: limits.max_size_bytes + 1,
            ..at_limit
        };
        assert!(compressor.needs_compression(&too_big));
    }

    #[tokio::test]
    async fn test_remove_old_items() {
        let (storage, compressor) = fixture();

        let namespace = "test";

        // Store old item
        let old_time = chrono::Utc::now() - chrono::Duration::days(10);
        storage
            .set(
                &format!("{}::old", namespace),
                MemoryValue::Json(serde_json::json!({
                    "timestamp": old_time.to_rfc3339(),
                })),
            )
            .await
            .unwrap();

        // Store recent item
        storage
            .set(
                &format!("{}::recent", namespace),
                MemoryValue::Json(serde_json::json!({
                    "timestamp": chrono::Utc::now().to_rfc3339(),
                })),
            )
            .await
            .unwrap();

        // Remove items older than 5 days
        let cutoff = chrono::Utc::now() - chrono::Duration::days(5);
        let deleted = compressor
            .remove_old_items(namespace, cutoff)
            .await
            .unwrap();

        assert_eq!(deleted, 1);
        assert_eq!(storage.count(Some(namespace)).await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_remove_least_important() {
        let (storage, compressor) = fixture();

        let namespace = "test";
        for (i, importance) in [0.1_f64, 0.5, 0.9].iter().enumerate() {
            storage
                .set(
                    &format!("{}::item::{}", namespace, i),
                    MemoryValue::Json(serde_json::json!({
                        "importance": *importance,
                    })),
                )
                .await
                .unwrap();
        }

        // Two items fall below 0.7, but the cap limits this pass to one.
        let deleted = compressor
            .remove_least_important(namespace, 0.7, 1)
            .await
            .unwrap();
        assert_eq!(deleted, 1);
        assert_eq!(storage.count(Some(namespace)).await.unwrap(), 2);

        // A second pass removes the remaining low-importance item only.
        let deleted = compressor
            .remove_least_important(namespace, 0.7, 10)
            .await
            .unwrap();
        assert_eq!(deleted, 1);
        assert_eq!(storage.count(Some(namespace)).await.unwrap(), 1);
    }
}