use std::collections::{HashMap, HashSet};
use std::time::{SystemTime, UNIX_EPOCH};
use common::{Memory, MemoryType};
use crate::consolidate::cosine_sim;
/// Inclusive upper bound on `importance` for a memory to be considered for
/// compression; memories above this value are never clustered.
const COMPRESS_MAX_IMPORTANCE: f32 = 0.6;
/// Minimum cosine similarity between two embeddings for the memories to be
/// treated as neighbors during clustering.
const COMPRESS_EPSILON: f32 = 0.88;
/// Minimum number of candidates required to attempt compression, and the
/// minimum number of members a cluster must have to be summarized.
const COMPRESS_MIN_SAMPLES: usize = 2;
/// Grace period (30 days, in seconds) added to the current time to compute
/// the `expires_at` of originals that were folded into a summary.
const SOFT_DEPRECATION_SECS: u64 = 30 * 86_400;
/// A memory paired with its embedding vector.
pub type MemoryEmbedPair = (Memory, Vec<f32>);
/// Counters and identifiers describing the outcome of one `compress_memories` run.
#[derive(Debug, Default)]
pub struct CompressResult {
/// Number of memories that passed the candidate filter (low importance and
/// no `expires_at`), not the total number of memories passed in.
pub memories_scanned: usize,
/// Number of clusters that reached the minimum cluster size.
pub clusters_found: usize,
/// Number of summary memories produced (one per cluster).
pub summaries_created: usize,
/// Number of original memories returned with an `expires_at` set.
pub originals_deprecated: usize,
/// Ids of the generated summary memories, in the order they were created.
pub summary_ids: Vec<String>,
/// Ids of the original memories that were marked for expiry.
pub deprecated_ids: Vec<String>,
}
/// Folds clusters of similar, low-importance memories into summary memories.
///
/// A memory is a candidate when its `importance` is at most
/// `COMPRESS_MAX_IMPORTANCE` and it has no `expires_at`. Candidates whose
/// embeddings have a cosine similarity of at least `COMPRESS_EPSILON` are
/// linked, and every connected group of linked candidates with at least
/// `COMPRESS_MIN_SAMPLES` members becomes one cluster.
///
/// For each cluster a single `MemoryType::Semantic` summary is built:
/// * `content` is the members' contents joined with `" | "` in input order,
/// * `importance` is the highest member importance,
/// * `created_at` is the oldest member timestamp,
/// * the embedding is the L2-normalized centroid of the member embeddings.
///
/// # Returns
///
/// `(summaries, deprecated, result)` where `summaries` holds the new summary
/// memories, `deprecated` holds clones of the clustered originals with
/// `expires_at` set to now + `SOFT_DEPRECATION_SECS`, and `result` carries
/// counters and ids. The input slice is not modified and nothing is persisted
/// here.
///
/// Clusters are discovered in ascending candidate order, so for a given input
/// and clock value the output (ordering, summary ids, chosen `agent_id`) is
/// deterministic.
///
/// NOTE(review): clustering does not partition by `agent_id`; the summary
/// inherits the `agent_id` of the first member. Confirm callers only pass
/// memories belonging to a single agent.
pub fn compress_memories(
    memories: &[MemoryEmbedPair],
) -> (Vec<MemoryEmbedPair>, Vec<MemoryEmbedPair>, CompressResult) {
    // A clock set before the epoch degrades to 0 instead of panicking.
    let now_secs = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();

    let candidate_indices: Vec<usize> = memories
        .iter()
        .enumerate()
        .filter(|(_, (m, _))| m.importance <= COMPRESS_MAX_IMPORTANCE && m.expires_at.is_none())
        .map(|(i, _)| i)
        .collect();

    let mut result = CompressResult {
        memories_scanned: candidate_indices.len(),
        ..Default::default()
    };
    if candidate_indices.len() < COMPRESS_MIN_SAMPLES {
        return (Vec::new(), Vec::new(), result);
    }

    let clusters = cluster_candidates(memories, &candidate_indices);
    result.clusters_found = clusters.len();
    if clusters.is_empty() {
        return (Vec::new(), Vec::new(), result);
    }

    let expires_at = now_secs.saturating_add(SOFT_DEPRECATION_SECS);
    let mut summaries: Vec<MemoryEmbedPair> = Vec::with_capacity(clusters.len());
    let mut deprecated: Vec<MemoryEmbedPair> = Vec::new();

    for cluster in &clusters {
        let members: Vec<&MemoryEmbedPair> = cluster
            .iter()
            .map(|&p| &memories[candidate_indices[p]])
            .collect();

        let summary_content = members
            .iter()
            .map(|(m, _)| m.content.as_str())
            .collect::<Vec<_>>()
            .join(" | ");
        let max_importance = members
            .iter()
            .map(|(m, _)| m.importance)
            .fold(f32::NEG_INFINITY, f32::max);
        let oldest_created_at = members
            .iter()
            .map(|(m, _)| m.created_at)
            .min()
            .unwrap_or(now_secs);
        let agent_id = members[0].0.agent_id.clone();

        // The multiplier exceeds 2^63, so a plain `*` overflows (and panics in
        // overflow-checked builds) for any cluster seed >= 2. Wrapping keeps
        // the hash-style mixing; multiplying by an odd constant is injective
        // modulo 2^64, so ids stay distinct within one run.
        let summary_id = format!(
            "mem_compress_{:x}",
            now_secs ^ (cluster[0] as u64).wrapping_mul(0x9e37_79b9_7f4a_7c15)
        );

        let centroid = normalized_centroid(&members);

        let summary_memory = Memory {
            id: summary_id.clone(),
            memory_type: MemoryType::Semantic,
            content: summary_content,
            agent_id,
            session_id: None,
            importance: max_importance,
            tags: vec!["compressed".to_string()],
            metadata: Some(serde_json::json!({
                "compressed_from": cluster.len(),
                "compressed_at": now_secs,
            })),
            created_at: oldest_created_at,
            last_accessed_at: now_secs,
            access_count: 0,
            ttl_seconds: None,
            expires_at: None,
        };
        summaries.push((summary_memory, centroid));
        result.summary_ids.push(summary_id);

        // Originals are not dropped; they are returned with an expiry so they
        // can age out after the grace period.
        for (mem, emb) in &members {
            let dep = Memory {
                expires_at: Some(expires_at),
                ..(*mem).clone()
            };
            result.deprecated_ids.push(dep.id.clone());
            deprecated.push((dep, (*emb).clone()));
        }
    }

    result.summaries_created = summaries.len();
    result.originals_deprecated = deprecated.len();
    (summaries, deprecated, result)
}

/// Groups candidate positions (indices into `candidate_indices`) into clusters
/// of transitively similar embeddings, each sorted ascending, in ascending
/// order of their smallest member.
fn cluster_candidates(
    memories: &[MemoryEmbedPair],
    candidate_indices: &[usize],
) -> Vec<Vec<usize>> {
    let n = candidate_indices.len();

    // Pairwise similarity graph over candidate positions.
    let mut neighbors: HashMap<usize, Vec<usize>> = HashMap::new();
    for p in 0..n {
        let i = candidate_indices[p];
        for q in (p + 1)..n {
            let j = candidate_indices[q];
            let sim = cosine_sim(&memories[i].1, &memories[j].1);
            if sim >= COMPRESS_EPSILON {
                neighbors.entry(p).or_default().push(q);
                neighbors.entry(q).or_default().push(p);
            }
        }
    }

    // A point is "core" when it has enough neighbors besides itself.
    let min_nb = COMPRESS_MIN_SAMPLES.saturating_sub(1).max(1);
    let core: HashSet<usize> = (0..n)
        .filter(|p| neighbors.get(p).map_or(0, Vec::len) >= min_nb)
        .collect();

    let mut visited: HashSet<usize> = HashSet::new();
    let mut clusters: Vec<Vec<usize>> = Vec::new();
    // Seed in ascending order rather than hash-set order so that cluster
    // order and each cluster's first member are reproducible.
    for start in 0..n {
        if !core.contains(&start) || visited.contains(&start) {
            continue;
        }
        let mut cluster = Vec::new();
        let mut stack = vec![start];
        while let Some(node) = stack.pop() {
            if !visited.insert(node) {
                continue;
            }
            cluster.push(node);
            if let Some(nbrs) = neighbors.get(&node) {
                stack.extend(
                    nbrs.iter()
                        .copied()
                        .filter(|nb| core.contains(nb) && !visited.contains(nb)),
                );
            }
        }
        if cluster.len() >= COMPRESS_MIN_SAMPLES {
            // Keep members in input order regardless of traversal order.
            cluster.sort_unstable();
            clusters.push(cluster);
        }
    }
    clusters
}

/// Averages the members' embeddings and scales the mean to unit length;
/// embeddings whose length differs from the first member's are ignored.
fn normalized_centroid(members: &[&MemoryEmbedPair]) -> Vec<f32> {
    let dim = members.first().map_or(0, |pair| pair.1.len());
    if dim == 0 {
        return Vec::new();
    }

    let mut sum = vec![0.0f32; dim];
    let mut valid = 0usize;
    for pair in members {
        let emb = &pair.1;
        if emb.len() != dim {
            continue;
        }
        for (acc, v) in sum.iter_mut().zip(emb.iter()) {
            *acc += v;
        }
        valid += 1;
    }

    // `valid` is at least 1 here: the first member always matches its own
    // dimension, so the division below cannot be by zero.
    let count = valid as f32;
    let mut centroid: Vec<f32> = sum.into_iter().map(|v| v / count).collect();
    let norm = centroid.iter().map(|x| x * x).sum::<f32>().sqrt();
    // Leave a (near-)zero mean untouched instead of dividing by ~0.
    if norm > 1e-8 {
        for v in &mut centroid {
            *v /= norm;
        }
    }
    centroid
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an episodic, non-expiring memory owned by `agent1` with fixed
    /// timestamps, varying only id, content and importance.
    fn mk_mem(id: &str, content: &str, importance: f32) -> Memory {
        let stamp = 1_000_000;
        Memory {
            id: id.to_string(),
            memory_type: MemoryType::Episodic,
            content: content.to_string(),
            agent_id: "agent1".to_string(),
            session_id: None,
            importance,
            tags: vec![],
            metadata: None,
            created_at: stamp,
            last_accessed_at: stamp,
            access_count: 0,
            ttl_seconds: None,
            expires_at: None,
        }
    }

    /// Returns `base` with `noise` added to its first component, rescaled to
    /// unit length.
    fn near_vec(base: &[f32], noise: f32) -> Vec<f32> {
        let mut shifted = Vec::with_capacity(base.len());
        for (idx, &component) in base.iter().enumerate() {
            let bump = if idx == 0 { noise } else { 0.0 };
            shifted.push(component + bump);
        }
        let length = shifted.iter().map(|c| c * c).sum::<f32>().sqrt();
        shifted.iter_mut().for_each(|c| *c /= length);
        shifted
    }

    /// Unit vector along the first axis, used as the shared cluster center.
    fn unit_x() -> Vec<f32> {
        vec![1.0f32, 0.0, 0.0, 0.0]
    }

    #[test]
    fn test_compress_empty() {
        let input: Vec<MemoryEmbedPair> = Vec::new();
        let (summaries, deprecated, stats) = compress_memories(&input);
        assert!(summaries.is_empty());
        assert!(deprecated.is_empty());
        assert_eq!(stats.clusters_found, 0);
    }

    #[test]
    fn test_compress_high_importance_skipped() {
        // Both memories sit above the importance ceiling, so none is scanned.
        let center = unit_x();
        let first = (mk_mem("a", "content a", 0.9), near_vec(&center, 0.01));
        let second = (mk_mem("b", "content b", 0.8), near_vec(&center, 0.02));
        let input = vec![first, second];

        let (summaries, _, stats) = compress_memories(&input);
        assert_eq!(stats.memories_scanned, 0);
        assert!(summaries.is_empty());
    }

    #[test]
    fn test_compress_two_similar_low_importance() {
        let center = unit_x();
        let first = (
            mk_mem("a", "The API latency is high", 0.4),
            near_vec(&center, 0.01),
        );
        let second = (
            mk_mem("b", "API response times are slow", 0.3),
            near_vec(&center, 0.02),
        );
        let input = vec![first, second];

        let (summaries, deprecated, stats) = compress_memories(&input);

        assert_eq!(stats.clusters_found, 1);
        assert_eq!(stats.summaries_created, 1);
        assert_eq!(stats.originals_deprecated, 2);
        assert_eq!(summaries.len(), 1);
        assert_eq!(deprecated.len(), 2);

        let merged = &summaries[0].0.content;
        assert!(merged.contains("API latency"));
        assert!(merged.contains("API response"));

        // Every original folded into the summary must carry an expiry.
        assert!(deprecated.iter().all(|(m, _)| m.expires_at.is_some()));
    }

    #[test]
    fn test_compress_orthogonal_no_cluster() {
        let along_x = (
            mk_mem("a", "vector search", 0.4),
            vec![1.0f32, 0.0, 0.0, 0.0],
        );
        let along_y = (
            mk_mem("b", "graph traversal", 0.3),
            vec![0.0f32, 1.0, 0.0, 0.0],
        );
        let input = vec![along_x, along_y];

        let (summaries, _, stats) = compress_memories(&input);
        assert_eq!(stats.clusters_found, 0);
        assert!(summaries.is_empty());
    }

    #[test]
    fn test_detect_near_duplicate() {
        use crate::consolidate::detect_near_duplicate;

        let candidates = vec![
            ("mem_1".to_string(), vec![1.0f32, 0.0, 0.0]),
            ("mem_2".to_string(), vec![0.0f32, 1.0, 0.0]),
        ];

        // A probe almost parallel to mem_1, normalized in place.
        let mut dup = vec![0.9999f32, 0.01, 0.0];
        let norm: f32 = dup.iter().map(|x| x * x).sum::<f32>().sqrt();
        for x in &mut dup {
            *x /= norm;
        }
        let found = detect_near_duplicate(&candidates, &dup, 0.95);
        assert_eq!(found, Some("mem_1".to_string()));

        // A probe orthogonal to every candidate matches nothing.
        let ortho = vec![0.0f32, 0.0, 1.0];
        assert!(detect_near_duplicate(&candidates, &ortho, 0.95).is_none());
    }
}