use common::Memory;
use std::collections::{HashMap, HashSet};
use std::time::{SystemTime, UNIX_EPOCH};
/// Default minimum cosine similarity (inclusive) for two memories to count as neighbours.
const DEFAULT_EPSILON: f32 = 0.92;
/// Default minimum cluster size; the point itself is included in the count.
const DEFAULT_MIN_SAMPLES: usize = 2;
/// Default number of days added to "now" to form a deprecated memory's `expires_at`.
const DEFAULT_SOFT_DEPRECATION_DAYS: u64 = 30;
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ConsolidationConfig {
pub enabled: bool,
pub epsilon: f32,
pub min_samples: usize,
pub soft_deprecation_days: u64,
}
impl Default for ConsolidationConfig {
fn default() -> Self {
Self {
enabled: true,
epsilon: DEFAULT_EPSILON,
min_samples: DEFAULT_MIN_SAMPLES,
soft_deprecation_days: DEFAULT_SOFT_DEPRECATION_DAYS,
}
}
}
/// Serializable record of one consolidation run, produced by
/// [`ConsolidateResult::to_log_entry`].
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ConsolidationLogEntry {
    /// When the entry was created, in whole seconds since the Unix epoch
    /// (0 if the system clock reported a pre-epoch time).
    pub run_at: u64,
    /// Number of memories that had no `expires_at` and were therefore considered.
    pub memories_scanned: usize,
    /// Number of clusters that met the `min_samples` size requirement.
    pub clusters_found: usize,
    /// Number of memories given an `expires_at` (as produced by `run_dbscan`,
    /// this equals `deprecated_ids.len()`).
    pub memories_deprecated: usize,
    /// Id of the memory kept from each cluster, one per cluster.
    pub anchor_ids: Vec<String>,
    /// Ids of the memories that were soft-deprecated.
    pub deprecated_ids: Vec<String>,
}
/// In-memory summary returned by `run_dbscan`; convert it with
/// [`ConsolidateResult::to_log_entry`] to persist it.
#[derive(Debug, Default)]
pub struct ConsolidateResult {
    /// Number of memories without an `expires_at` that took part in clustering.
    pub memories_scanned: usize,
    /// Number of clusters with at least `min_samples` members.
    pub clusters_found: usize,
    /// Number of memories soft-deprecated; set from `deprecated_ids.len()` by `run_dbscan`.
    pub memories_deprecated: usize,
    /// Id of the surviving ("anchor") memory of each cluster.
    pub anchor_ids: Vec<String>,
    /// Ids of the non-anchor cluster members that were given an `expires_at`.
    pub deprecated_ids: Vec<String>,
}
impl ConsolidateResult {
    /// Converts this run summary into a persistable [`ConsolidationLogEntry`].
    ///
    /// The entry is stamped with the current wall-clock time, as whole seconds
    /// since the Unix epoch. If the clock reports a time before the epoch,
    /// `run_at` is `0`. The id lists are copied, so `self` is left untouched.
    pub fn to_log_entry(&self) -> ConsolidationLogEntry {
        let Self {
            memories_scanned,
            clusters_found,
            memories_deprecated,
            anchor_ids,
            deprecated_ids,
        } = self;
        let run_at = match SystemTime::now().duration_since(UNIX_EPOCH) {
            Ok(since_epoch) => since_epoch.as_secs(),
            Err(_) => 0,
        };
        ConsolidationLogEntry {
            run_at,
            memories_scanned: *memories_scanned,
            clusters_found: *clusters_found,
            memories_deprecated: *memories_deprecated,
            anchor_ids: anchor_ids.to_vec(),
            deprecated_ids: deprecated_ids.to_vec(),
        }
    }
}
/// Clusters near-duplicate memories with a DBSCAN-style pass. In each cluster,
/// every memory except one (the "anchor") is soft-deprecated.
///
/// Only memories whose `expires_at` is `None` take part. Anything that already
/// has an expiry (deprecated by an earlier run, or TTL-bound) is skipped, which
/// makes repeated runs idempotent.
///
/// Neighbours and clusters:
/// - Two active memories are neighbours when [`cosine_sim`] of their embeddings
///   is `>= config.epsilon`. `epsilon` is a similarity threshold, not a distance.
/// - A memory is a core point when it has at least `max(min_samples - 1, 1)`
///   neighbours.
/// - Clusters are the connected components of core points. Only clusters with
///   at least `config.min_samples` members are kept.
/// - Non-core neighbours (DBSCAN "border" points) are never absorbed into a
///   cluster.
///
/// Anchor selection: within each cluster the anchor is the memory with the
/// highest `importance`, then the newest `created_at`, then the earliest
/// position in `memories`.
///
/// Every other member is returned as a copy with `expires_at` set to now plus
/// `config.soft_deprecation_days` days. The sum saturates at `u64::MAX` instead
/// of overflowing.
///
/// `config.enabled` is not consulted here; callers decide whether to run.
///
/// # Returns
/// A summary of the run, plus the updated `(Memory, embedding)` pairs for the
/// deprecated memories. Clusters are discovered in input order, and each
/// cluster's members are listed in input order, so the output is deterministic.
pub fn run_dbscan(
    memories: &[(Memory, Vec<f32>)],
    config: &ConsolidationConfig,
) -> (ConsolidateResult, Vec<(Memory, Vec<f32>)>) {
    let n = memories.len();
    let now_secs = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();
    // Indices into `memories` of entries that are not already expiring.
    let active: Vec<usize> = (0..n)
        .filter(|&i| memories[i].0.expires_at.is_none())
        .collect();
    let mut result = ConsolidateResult {
        memories_scanned: active.len(),
        ..Default::default()
    };
    if active.len() < config.min_samples {
        return (result, Vec::new());
    }

    // Symmetric epsilon-neighbourhood graph over positions in `active`.
    let mut neighbors: HashMap<usize, Vec<usize>> = HashMap::new();
    for p in 0..active.len() {
        for q in (p + 1)..active.len() {
            let sim = cosine_sim(&memories[active[p]].1, &memories[active[q]].1);
            if sim >= config.epsilon {
                neighbors.entry(p).or_default().push(q);
                neighbors.entry(q).or_default().push(p);
            }
        }
    }

    // `min_samples` counts the point itself, hence the `- 1`. At least one
    // neighbour is always required, so isolated points never become core.
    let min_nb = config.min_samples.saturating_sub(1).max(1);
    let core: HashSet<usize> = (0..active.len())
        .filter(|p| neighbors.get(p).map_or(0, |v| v.len()) >= min_nb)
        .collect();

    let mut visited: HashSet<usize> = HashSet::new();
    let mut clusters: Vec<Vec<usize>> = Vec::new();
    // Seed in index order rather than `HashSet` iteration order, which is
    // randomized per process. This keeps cluster order and the reported id
    // lists reproducible.
    for cp in (0..active.len()).filter(|p| core.contains(p)) {
        if visited.contains(&cp) {
            continue;
        }
        let mut cluster = Vec::new();
        let mut stack = vec![cp];
        while let Some(node) = stack.pop() {
            if visited.insert(node) {
                cluster.push(node);
                if let Some(nbrs) = neighbors.get(&node) {
                    for &nb in nbrs {
                        // Only core points extend the component; border points
                        // are intentionally left out.
                        if core.contains(&nb) && !visited.contains(&nb) {
                            stack.push(nb);
                        }
                    }
                }
            }
        }
        if cluster.len() >= config.min_samples {
            cluster.sort_unstable();
            clusters.push(cluster);
        }
    }
    result.clusters_found = clusters.len();
    if clusters.is_empty() {
        return (result, Vec::new());
    }

    // Saturate instead of overflowing. Debug builds would panic, and in release
    // builds a wrapped sum would land in the past, expiring the memories at once.
    let deprecation_secs = config.soft_deprecation_days.saturating_mul(86_400);
    let expires_at = now_secs.saturating_add(deprecation_secs);
    let mut updated: Vec<(Memory, Vec<f32>)> = Vec::new();
    for cluster in &clusters {
        let anchor_p = cluster
            .iter()
            .copied()
            .max_by(|&a, &b| {
                let ma = &memories[active[a]].0;
                let mb = &memories[active[b]].0;
                ma.importance
                    .partial_cmp(&mb.importance)
                    .unwrap_or(std::cmp::Ordering::Equal)
                    .then_with(|| ma.created_at.cmp(&mb.created_at))
                    // Full tie: rank the earlier input position higher, so the
                    // anchor does not depend on traversal order (`max_by` keeps
                    // the last maximum).
                    .then_with(|| b.cmp(&a))
            })
            .expect("every cluster contains at least its seed point");
        result
            .anchor_ids
            .push(memories[active[anchor_p]].0.id.clone());
        for &p in cluster {
            if p == anchor_p {
                continue;
            }
            let (mem, emb) = &memories[active[p]];
            let deprecated = Memory {
                expires_at: Some(expires_at),
                ..mem.clone()
            };
            result.deprecated_ids.push(deprecated.id.clone());
            updated.push((deprecated, emb.clone()));
        }
    }
    result.memories_deprecated = result.deprecated_ids.len();
    (result, updated)
}
/// Cosine similarity of `a` and `b`, clamped to `[-1.0, 1.0]`.
///
/// Returns `0.0` in three cases: either slice is empty, the lengths differ, or
/// either vector has zero magnitude.
pub(crate) fn cosine_sim(a: &[f32], b: &[f32]) -> f32 {
    let comparable = !a.is_empty() && !b.is_empty() && a.len() == b.len();
    if !comparable {
        return 0.0;
    }
    // A single pass accumulates the dot product and both squared norms. Each
    // accumulator is still summed left to right.
    let (dot, sq_a, sq_b) = a.iter().zip(b.iter()).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, sq_a, sq_b), (&x, &y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
    );
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
    }
}
/// Returns the id of the first candidate whose embedding is at least `threshold`
/// cosine-similar to `new_embedding`, or `None` if no candidate qualifies.
///
/// Candidates are checked in slice order. Those whose embedding length differs
/// from `new_embedding` are skipped outright, so they can never match, even with
/// a non-positive `threshold`.
pub fn detect_near_duplicate(
    candidates: &[(String, Vec<f32>)],
    new_embedding: &[f32],
    threshold: f32,
) -> Option<String> {
    candidates
        .iter()
        .filter(|(_, emb)| emb.len() == new_embedding.len())
        .find(|(_, emb)| cosine_sim(new_embedding, emb) >= threshold)
        .map(|(id, _)| id.clone())
}
// Unit tests for `cosine_sim` and `run_dbscan`.
#[cfg(test)]
mod tests {
    use super::*;
    use common::MemoryType;

    /// Builds an episodic `Memory` with the given id/content and importance.
    /// All other fields are fixed: agent "a", `created_at` 1_000_000, no expiry.
    fn mk(id: &str, imp: f32) -> Memory {
        Memory {
            id: id.to_string(),
            memory_type: MemoryType::Episodic,
            content: id.to_string(),
            agent_id: "a".to_string(),
            session_id: None,
            importance: imp,
            tags: vec![],
            metadata: None,
            created_at: 1000000,
            last_accessed_at: 1000000,
            access_count: 0,
            ttl_seconds: None,
            expires_at: None,
        }
    }

    /// One-hot vector of length `dim` with a 1.0 at index `i` (panics if `i >= dim`).
    fn unit(dim: usize, i: usize) -> Vec<f32> {
        let mut v = vec![0.0f32; dim];
        v[i] = 1.0;
        v
    }

    /// Adds `n` to component 0 of `base` and L2-normalises the result, giving a
    /// vector slightly perturbed from `base`.
    fn near(base: &[f32], n: f32) -> Vec<f32> {
        let mut v: Vec<f32> = base
            .iter()
            .enumerate()
            .map(|(i, x)| x + if i == 0 { n } else { 0.0 })
            .collect();
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        for x in &mut v {
            *x /= norm;
        }
        v
    }

    // A vector is perfectly similar to itself.
    #[test]
    fn identical_sim() {
        let v = vec![1.0f32, 0.0, 0.0];
        assert!((cosine_sim(&v, &v) - 1.0).abs() < 1e-5);
    }

    // Orthogonal unit vectors have (near-)zero similarity.
    #[test]
    fn orthogonal_sim() {
        assert!(cosine_sim(&unit(3, 0), &unit(3, 1)).abs() < 1e-5);
    }

    // With fewer active memories than `min_samples`, nothing clusters.
    #[test]
    fn no_cluster_single() {
        let (r, u) = run_dbscan(
            &[(mk("a", 0.5), unit(4, 0))],
            &ConsolidationConfig::default(),
        );
        assert_eq!(r.clusters_found, 0);
        assert!(u.is_empty());
    }

    // Two near-identical embeddings form one cluster. The higher-importance "a"
    // is the anchor, and "b" comes back with an `expires_at`.
    #[test]
    fn two_similar_cluster() {
        let b = vec![1.0f32, 0.0, 0.0, 0.0];
        let (r, u) = run_dbscan(
            &[
                (mk("a", 0.8), near(&b, 0.01)),
                (mk("b", 0.3), near(&b, 0.02)),
            ],
            &ConsolidationConfig::default(),
        );
        assert_eq!(r.clusters_found, 1);
        assert_eq!(r.anchor_ids, vec!["a"]);
        assert_eq!(r.deprecated_ids, vec!["b"]);
        assert!(u[0].0.expires_at.is_some());
    }

    // Dissimilar embeddings stay unclustered.
    #[test]
    fn orthogonal_no_cluster() {
        let (r, _) = run_dbscan(
            &[(mk("a", 0.5), unit(4, 0)), (mk("b", 0.5), unit(4, 1))],
            &ConsolidationConfig::default(),
        );
        assert_eq!(r.clusters_found, 0);
    }

    // A memory that already has `expires_at` is ignored. Only one active memory
    // is left, so no cluster forms.
    #[test]
    fn deprecated_excluded() {
        let b = vec![1.0f32, 0.0, 0.0, 0.0];
        let mut m = mk("b", 0.3);
        m.expires_at = Some(
            SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap()
                .as_secs()
                + 2 * 86400,
        );
        let (r, _) = run_dbscan(
            &[(mk("a", 0.8), near(&b, 0.01)), (m, near(&b, 0.02))],
            &ConsolidationConfig::default(),
        );
        assert_eq!(r.clusters_found, 0);
    }

    // Re-running on the output of a previous run finds nothing new to merge.
    #[test]
    fn idempotent() {
        let b = vec![1.0f32, 0.0, 0.0, 0.0];
        let ea = near(&b, 0.01);
        let eb = near(&b, 0.02);
        let (_, u) = run_dbscan(
            &[(mk("a", 0.8), ea.clone()), (mk("b", 0.3), eb.clone())],
            &ConsolidationConfig::default(),
        );
        let dep = u[0].0.clone();
        let (r2, _) = run_dbscan(
            &[(mk("a", 0.8), ea), (dep, eb)],
            &ConsolidationConfig::default(),
        );
        assert_eq!(r2.clusters_found, 0);
    }
}