1use std::collections::{HashMap, HashSet};
12use std::time::{SystemTime, UNIX_EPOCH};
13
14use common::{Memory, MemoryType};
15
16use crate::consolidate::cosine_sim;
17
/// Importance ceiling for compression: only memories whose `importance` is at
/// or below this value are considered as candidates.
const COMPRESS_MAX_IMPORTANCE: f32 = 0.6;

/// Minimum cosine similarity between two candidate embeddings for them to be
/// treated as neighbours when clustering.
const COMPRESS_MIN_SIMILARITY_DOC_ANCHOR: () = ();
const COMPRESS_EPSILON: f32 = 0.88;

/// Minimum number of candidates needed to attempt compression, and the minimum
/// size a cluster must reach before it is summarised.
const COMPRESS_MIN_SAMPLES: usize = 2;

/// Offset (in seconds) from "now" at which deprecated originals get their
/// `expires_at` set: thirty days.
const SOFT_DEPRECATION_SECS: u64 = 30 * 24 * 60 * 60;
30
/// A memory paired with its embedding vector. This is the element type of
/// both the input and the outputs of `compress_memories`.
pub type MemoryEmbedPair = (Memory, Vec<f32>);
33
/// Counters and identifiers describing the outcome of one compression pass.
#[derive(Default, Debug)]
pub struct CompressResult {
    /// How many memories passed the eligibility filter (low importance, no expiry).
    pub memories_scanned: usize,
    /// How many similarity clusters reached the minimum size.
    pub clusters_found: usize,
    /// How many summary memories were produced (one per cluster).
    pub summaries_created: usize,
    /// How many original memories were given an expiry.
    pub originals_deprecated: usize,
    /// Ids of the newly produced summary memories.
    pub summary_ids: Vec<String>,
    /// Ids of the originals that were given an expiry.
    pub deprecated_ids: Vec<String>,
}
50
51pub fn compress_memories(
60 memories: &[MemoryEmbedPair],
61) -> (Vec<MemoryEmbedPair>, Vec<MemoryEmbedPair>, CompressResult) {
62 let now_secs = SystemTime::now()
63 .duration_since(UNIX_EPOCH)
64 .unwrap_or_default()
65 .as_secs();
66
67 let candidate_indices: Vec<usize> = memories
69 .iter()
70 .enumerate()
71 .filter(|(_, (m, _))| m.importance <= COMPRESS_MAX_IMPORTANCE && m.expires_at.is_none())
72 .map(|(i, _)| i)
73 .collect();
74
75 let mut result = CompressResult {
76 memories_scanned: candidate_indices.len(),
77 ..Default::default()
78 };
79
80 if candidate_indices.len() < COMPRESS_MIN_SAMPLES {
81 return (Vec::new(), Vec::new(), result);
82 }
83
84 let n = candidate_indices.len();
87
88 let mut neighbors: HashMap<usize, Vec<usize>> = HashMap::new();
90 for p in 0..n {
91 for q in (p + 1)..n {
92 let i = candidate_indices[p];
93 let j = candidate_indices[q];
94 let sim = cosine_sim(&memories[i].1, &memories[j].1);
95 if sim >= COMPRESS_EPSILON {
96 neighbors.entry(p).or_default().push(q);
97 neighbors.entry(q).or_default().push(p);
98 }
99 }
100 }
101
102 let min_nb = COMPRESS_MIN_SAMPLES.saturating_sub(1).max(1);
104 let core: HashSet<usize> = (0..n)
105 .filter(|p| neighbors.get(p).map_or(0, |v| v.len()) >= min_nb)
106 .collect();
107
108 let mut visited: HashSet<usize> = HashSet::new();
110 let mut clusters: Vec<Vec<usize>> = Vec::new();
111
112 for &cp in &core {
113 if visited.contains(&cp) {
114 continue;
115 }
116 let mut cluster = Vec::new();
117 let mut stack = vec![cp];
118 while let Some(node) = stack.pop() {
119 if visited.insert(node) {
120 cluster.push(node);
121 if let Some(nbrs) = neighbors.get(&node) {
122 for &nb in nbrs {
123 if core.contains(&nb) && !visited.contains(&nb) {
124 stack.push(nb);
125 }
126 }
127 }
128 }
129 }
130 if cluster.len() >= COMPRESS_MIN_SAMPLES {
131 clusters.push(cluster);
132 }
133 }
134
135 result.clusters_found = clusters.len();
136 if clusters.is_empty() {
137 return (Vec::new(), Vec::new(), result);
138 }
139
140 let expires_at = now_secs + SOFT_DEPRECATION_SECS;
143 let mut summaries: Vec<(Memory, Vec<f32>)> = Vec::new();
144 let mut deprecated: Vec<(Memory, Vec<f32>)> = Vec::new();
145
146 for cluster in &clusters {
147 let members: Vec<&(Memory, Vec<f32>)> = cluster
149 .iter()
150 .map(|&p| &memories[candidate_indices[p]])
151 .collect();
152
153 let summary_content = members
155 .iter()
156 .map(|(m, _)| m.content.as_str())
157 .collect::<Vec<_>>()
158 .join(" | ");
159
160 let max_importance = members
162 .iter()
163 .map(|(m, _)| m.importance)
164 .fold(f32::NEG_INFINITY, f32::max);
165 let oldest_created_at = members
166 .iter()
167 .map(|(m, _)| m.created_at)
168 .min()
169 .unwrap_or(now_secs);
170
171 let agent_id = members[0].0.agent_id.clone();
172 let summary_id = format!(
173 "mem_compress_{:x}",
174 now_secs ^ (cluster[0] as u64 * 0x9e3779b97f4a7c15)
175 );
176
177 let dim = members[0].1.len();
179 let centroid: Vec<f32> = if dim > 0 {
180 let mut sum = vec![0.0f32; dim];
181 let mut valid = 0usize;
182 for (_, emb) in &members {
183 if emb.len() == dim {
184 for (i, v) in emb.iter().enumerate() {
185 sum[i] += v;
186 }
187 valid += 1;
188 }
189 }
190 if valid > 0 {
191 let norm_factor = valid as f32;
192 let mut centroid: Vec<f32> = sum.into_iter().map(|v| v / norm_factor).collect();
193 let norm: f32 = centroid.iter().map(|x| x * x).sum::<f32>().sqrt();
195 if norm > 1e-8 {
196 for v in &mut centroid {
197 *v /= norm;
198 }
199 }
200 centroid
201 } else {
202 vec![0.0f32; dim]
203 }
204 } else {
205 Vec::new()
206 };
207
208 let summary_memory = Memory {
209 id: summary_id.clone(),
210 memory_type: MemoryType::Semantic,
211 content: summary_content,
212 agent_id: agent_id.clone(),
213 session_id: None,
214 importance: max_importance,
215 tags: vec!["compressed".to_string()],
216 metadata: Some(serde_json::json!({
217 "compressed_from": cluster.len(),
218 "compressed_at": now_secs,
219 })),
220 created_at: oldest_created_at,
221 last_accessed_at: now_secs,
222 access_count: 0,
223 ttl_seconds: None,
224 expires_at: None,
225 };
226
227 summaries.push((summary_memory, centroid));
228 result.summary_ids.push(summary_id);
229
230 for (mem, emb) in &members {
232 let dep = Memory {
233 expires_at: Some(expires_at),
234 ..(*mem).clone()
235 };
236 result.deprecated_ids.push(dep.id.clone());
237 deprecated.push((dep, (*emb).clone()));
238 }
239 }
240
241 result.summaries_created = summaries.len();
242 result.originals_deprecated = deprecated.len();
243
244 (summaries, deprecated, result)
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    /// Episodic memory owned by `agent1`, created and last accessed at
    /// t = 1_000_000, with no expiry.
    fn episodic(id: &str, content: &str, importance: f32) -> Memory {
        Memory {
            id: id.to_string(),
            memory_type: MemoryType::Episodic,
            content: content.to_string(),
            agent_id: "agent1".to_string(),
            session_id: None,
            importance,
            tags: vec![],
            metadata: None,
            created_at: 1_000_000,
            last_accessed_at: 1_000_000,
            access_count: 0,
            ttl_seconds: None,
            expires_at: None,
        }
    }

    /// Unit-length copy of `base` after adding `noise` to its first component.
    fn unit_near(base: &[f32], noise: f32) -> Vec<f32> {
        let mut v = base.to_vec();
        if let Some(first) = v.first_mut() {
            *first += noise;
        }
        let len = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        v.iter_mut().for_each(|x| *x /= len);
        v
    }

    #[test]
    fn test_compress_empty() {
        let (summaries, deprecated, stats) = compress_memories(&[]);
        assert!(summaries.is_empty());
        assert!(deprecated.is_empty());
        assert_eq!(stats.clusters_found, 0);
    }

    #[test]
    fn test_compress_high_importance_skipped() {
        let axis = [1.0f32, 0.0, 0.0, 0.0];
        let pairs = [
            (episodic("a", "content a", 0.9), unit_near(&axis, 0.01)),
            (episodic("b", "content b", 0.8), unit_near(&axis, 0.02)),
        ];
        let (summaries, _, stats) = compress_memories(&pairs);
        assert_eq!(stats.memories_scanned, 0);
        assert!(summaries.is_empty());
    }

    #[test]
    fn test_compress_two_similar_low_importance() {
        let axis = [1.0f32, 0.0, 0.0, 0.0];
        let pairs = [
            (episodic("a", "The API latency is high", 0.4), unit_near(&axis, 0.01)),
            (episodic("b", "API response times are slow", 0.3), unit_near(&axis, 0.02)),
        ];
        let (summaries, deprecated, stats) = compress_memories(&pairs);

        assert_eq!(stats.clusters_found, 1);
        assert_eq!(stats.summaries_created, 1);
        assert_eq!(stats.originals_deprecated, 2);
        assert_eq!((summaries.len(), deprecated.len()), (1, 2));

        let summary = &summaries[0].0.content;
        assert!(summary.contains("API latency"));
        assert!(summary.contains("API response"));
        assert!(deprecated.iter().all(|(m, _)| m.expires_at.is_some()));
    }

    #[test]
    fn test_compress_orthogonal_no_cluster() {
        let pairs = [
            (episodic("a", "vector search", 0.4), vec![1.0f32, 0.0, 0.0, 0.0]),
            (episodic("b", "graph traversal", 0.3), vec![0.0f32, 1.0, 0.0, 0.0]),
        ];
        let (summaries, _, stats) = compress_memories(&pairs);
        assert_eq!(stats.clusters_found, 0);
        assert!(summaries.is_empty());
    }

    #[test]
    fn test_detect_near_duplicate() {
        use crate::consolidate::detect_near_duplicate;

        let candidates = vec![
            ("mem_1".to_string(), vec![1.0f32, 0.0, 0.0]),
            ("mem_2".to_string(), vec![0.0f32, 1.0, 0.0]),
        ];

        // Nearly parallel to mem_1, normalised to unit length.
        let raw = [0.9999f32, 0.01, 0.0];
        let len = raw.iter().map(|x| x * x).sum::<f32>().sqrt();
        let dup: Vec<f32> = raw.iter().map(|x| x / len).collect();

        assert_eq!(
            detect_near_duplicate(&candidates, &dup, 0.95),
            Some("mem_1".to_string())
        );

        // Orthogonal to both candidates: no match.
        let ortho = vec![0.0f32, 0.0, 1.0];
        assert!(detect_near_duplicate(&candidates, &ortho, 0.95).is_none());
    }
}