Skip to main content

engine/
consolidate.rs

//! CE-6: DBSCAN-based adaptive memory consolidation with soft-deprecation.
2use common::Memory;
3use std::collections::{HashMap, HashSet};
4use std::time::{SystemTime, UNIX_EPOCH};
5
/// Default minimum cosine similarity for two memories to count as neighbours.
const DEFAULT_EPSILON: f32 = 0.92;
/// Default `min_samples`: the smallest group that is treated as a cluster.
const DEFAULT_MIN_SAMPLES: usize = 2;
/// Default number of days a soft-deprecated memory is kept before it expires.
const DEFAULT_SOFT_DEPRECATION_DAYS: u64 = 30;
9
10#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
11pub struct ConsolidationConfig {
12    pub enabled: bool,
13    pub epsilon: f32,
14    pub min_samples: usize,
15    pub soft_deprecation_days: u64,
16}
17impl Default for ConsolidationConfig {
18    fn default() -> Self {
19        Self {
20            enabled: true,
21            epsilon: DEFAULT_EPSILON,
22            min_samples: DEFAULT_MIN_SAMPLES,
23            soft_deprecation_days: DEFAULT_SOFT_DEPRECATION_DAYS,
24        }
25    }
26}
27
28#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
29pub struct ConsolidationLogEntry {
30    pub run_at: u64,
31    pub memories_scanned: usize,
32    pub clusters_found: usize,
33    pub memories_deprecated: usize,
34    pub anchor_ids: Vec<String>,
35    pub deprecated_ids: Vec<String>,
36}
37
38#[derive(Debug, Default)]
39pub struct ConsolidateResult {
40    pub memories_scanned: usize,
41    pub clusters_found: usize,
42    pub memories_deprecated: usize,
43    pub anchor_ids: Vec<String>,
44    pub deprecated_ids: Vec<String>,
45}
46impl ConsolidateResult {
47    pub fn to_log_entry(&self) -> ConsolidationLogEntry {
48        let run_at = SystemTime::now()
49            .duration_since(UNIX_EPOCH)
50            .unwrap_or_default()
51            .as_secs();
52        ConsolidationLogEntry {
53            run_at,
54            memories_scanned: self.memories_scanned,
55            clusters_found: self.clusters_found,
56            memories_deprecated: self.memories_deprecated,
57            anchor_ids: self.anchor_ids.clone(),
58            deprecated_ids: self.deprecated_ids.clone(),
59        }
60    }
61}
62
63pub fn run_dbscan(
64    memories: &[(Memory, Vec<f32>)],
65    config: &ConsolidationConfig,
66) -> (ConsolidateResult, Vec<(Memory, Vec<f32>)>) {
67    let n = memories.len();
68    let now_secs = SystemTime::now()
69        .duration_since(UNIX_EPOCH)
70        .unwrap_or_default()
71        .as_secs();
72    let active: Vec<usize> = (0..n)
73        .filter(|&i| memories[i].0.expires_at.is_none())
74        .collect();
75    let mut result = ConsolidateResult {
76        memories_scanned: active.len(),
77        ..Default::default()
78    };
79    if active.len() < config.min_samples {
80        return (result, Vec::new());
81    }
82
83    let mut neighbors: HashMap<usize, Vec<usize>> = HashMap::new();
84    for p in 0..active.len() {
85        for q in (p + 1)..active.len() {
86            let sim = cosine_sim(&memories[active[p]].1, &memories[active[q]].1);
87            if sim >= config.epsilon {
88                neighbors.entry(p).or_default().push(q);
89                neighbors.entry(q).or_default().push(p);
90            }
91        }
92    }
93
94    let min_nb = config.min_samples.saturating_sub(1).max(1);
95    let core: HashSet<usize> = (0..active.len())
96        .filter(|p| neighbors.get(p).map_or(0, |v| v.len()) >= min_nb)
97        .collect();
98
99    let mut visited: HashSet<usize> = HashSet::new();
100    let mut clusters: Vec<Vec<usize>> = Vec::new();
101    for &cp in &core {
102        if visited.contains(&cp) {
103            continue;
104        }
105        let mut cluster = Vec::new();
106        let mut stack = vec![cp];
107        while let Some(node) = stack.pop() {
108            if visited.insert(node) {
109                cluster.push(node);
110                if let Some(nbrs) = neighbors.get(&node) {
111                    for &nb in nbrs {
112                        if core.contains(&nb) && !visited.contains(&nb) {
113                            stack.push(nb);
114                        }
115                    }
116                }
117            }
118        }
119        if cluster.len() >= config.min_samples {
120            clusters.push(cluster);
121        }
122    }
123
124    result.clusters_found = clusters.len();
125    if clusters.is_empty() {
126        return (result, Vec::new());
127    }
128
129    let expires_at = now_secs + config.soft_deprecation_days * 86400;
130    let mut updated: Vec<(Memory, Vec<f32>)> = Vec::new();
131    for cluster in &clusters {
132        let anchor_p = cluster
133            .iter()
134            .copied()
135            .max_by(|&a, &b| {
136                let ma = &memories[active[a]].0;
137                let mb = &memories[active[b]].0;
138                ma.importance
139                    .partial_cmp(&mb.importance)
140                    .unwrap_or(std::cmp::Ordering::Equal)
141                    .then_with(|| ma.created_at.cmp(&mb.created_at))
142            })
143            .unwrap();
144        result
145            .anchor_ids
146            .push(memories[active[anchor_p]].0.id.clone());
147        for &p in cluster {
148            if p == anchor_p {
149                continue;
150            }
151            let (mem, emb) = &memories[active[p]];
152            let deprecated = Memory {
153                expires_at: Some(expires_at),
154                ..mem.clone()
155            };
156            result.deprecated_ids.push(deprecated.id.clone());
157            updated.push((deprecated, emb.clone()));
158        }
159    }
160    result.memories_deprecated = result.deprecated_ids.len();
161    (result, updated)
162}
163
/// Cosine similarity of two equal-length embeddings, in `[-1.0, 1.0]`.
///
/// Returns `0.0` when either slice is empty, when the lengths differ, when
/// either vector has zero norm, or when the result is not a number (for
/// example because an input component is NaN or infinite).
///
/// Sums are accumulated in `f64`: squaring `f32` components in `f32` overflows
/// to infinity or underflows to zero for very large or very small magnitudes,
/// which would turn a perfectly valid pair into NaN or `0.0`.
pub(crate) fn cosine_sim(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // Single pass over both vectors for the dot product and both squared norms.
    let (mut dot, mut norm_a_sq, mut norm_b_sq) = (0.0_f64, 0.0_f64, 0.0_f64);
    for (&x, &y) in a.iter().zip(b) {
        let (x, y) = (f64::from(x), f64::from(y));
        dot += x * y;
        norm_a_sq += x * x;
        norm_b_sq += y * y;
    }
    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        return 0.0;
    }

    let sim = dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt());
    if sim.is_nan() {
        return 0.0;
    }
    // Clamp away rounding overshoot; the narrowing cast is exact enough for a
    // value already confined to [-1, 1].
    sim.clamp(-1.0, 1.0) as f32
}
176
177/// CE-10a: Check whether `new_embedding` is a near-duplicate of any memory in
178/// `candidates` (a slice of `(memory_id, embedding)` pairs).
179///
180/// Returns the ID of the **first** candidate whose cosine similarity with
181/// `new_embedding` is ≥ `threshold`.  Uses the same `cosine_sim` as DBSCAN so
182/// results are consistent.
183///
184/// Complexity: O(N × D) where N = candidates.len() and D = embedding dimension.
185/// For large namespaces, callers should use ANN search instead (see the store
186/// handler which uses `engine.search(top_k=1)` for O(log N) behaviour).
187pub fn detect_near_duplicate(
188    candidates: &[(String, Vec<f32>)],
189    new_embedding: &[f32],
190    threshold: f32,
191) -> Option<String> {
192    for (id, embedding) in candidates {
193        if embedding.len() != new_embedding.len() {
194            continue;
195        }
196        let sim = cosine_sim(new_embedding, embedding);
197        if sim >= threshold {
198            return Some(id.clone());
199        }
200    }
201    None
202}
203
#[cfg(test)]
mod tests {
    use super::*;
    use common::MemoryType;

    /// Episodic, non-expiring memory whose `id` and `content` are both `id`.
    fn memory(id: &str, importance: f32) -> Memory {
        let label = id.to_string();
        Memory {
            id: label.clone(),
            memory_type: MemoryType::Episodic,
            content: label,
            agent_id: String::from("a"),
            session_id: None,
            importance,
            tags: Vec::new(),
            metadata: None,
            created_at: 1_000_000,
            last_accessed_at: 1_000_000,
            access_count: 0,
            ttl_seconds: None,
            expires_at: None,
        }
    }

    /// Unit vector of length `dim` pointing along `axis`.
    fn basis(dim: usize, axis: usize) -> Vec<f32> {
        (0..dim)
            .map(|i| if i == axis { 1.0 } else { 0.0 })
            .collect()
    }

    /// `base` with `delta` added to its first component, scaled to unit length.
    fn nudged(base: &[f32], delta: f32) -> Vec<f32> {
        let mut shifted = base.to_vec();
        if let Some(first) = shifted.first_mut() {
            *first += delta;
        }
        let norm = shifted.iter().map(|c| c * c).sum::<f32>().sqrt();
        shifted.iter_mut().for_each(|c| *c /= norm);
        shifted
    }

    /// Runs `run_dbscan` with the default configuration.
    fn consolidate(
        input: &[(Memory, Vec<f32>)],
    ) -> (ConsolidateResult, Vec<(Memory, Vec<f32>)>) {
        run_dbscan(input, &ConsolidationConfig::default())
    }

    #[test]
    fn identical_sim() {
        let v = vec![1.0f32, 0.0, 0.0];
        let sim = cosine_sim(&v, &v);
        assert!((sim - 1.0).abs() < 1e-5);
    }

    #[test]
    fn orthogonal_sim() {
        let sim = cosine_sim(&basis(3, 0), &basis(3, 1));
        assert!(sim.abs() < 1e-5);
    }

    #[test]
    fn no_cluster_single() {
        let input = [(memory("a", 0.5), basis(4, 0))];
        let (summary, updated) = consolidate(&input);
        assert_eq!(summary.clusters_found, 0);
        assert!(updated.is_empty());
    }

    #[test]
    fn two_similar_cluster() {
        let base = vec![1.0f32, 0.0, 0.0, 0.0];
        let input = [
            (memory("a", 0.8), nudged(&base, 0.01)),
            (memory("b", 0.3), nudged(&base, 0.02)),
        ];
        let (summary, updated) = consolidate(&input);
        assert_eq!(summary.clusters_found, 1);
        assert_eq!(summary.anchor_ids, vec!["a"]);
        assert_eq!(summary.deprecated_ids, vec!["b"]);
        assert!(updated[0].0.expires_at.is_some());
    }

    #[test]
    fn orthogonal_no_cluster() {
        let input = [
            (memory("a", 0.5), basis(4, 0)),
            (memory("b", 0.5), basis(4, 1)),
        ];
        let (summary, _) = consolidate(&input);
        assert_eq!(summary.clusters_found, 0);
    }

    #[test]
    fn deprecated_excluded() {
        let base = vec![1.0f32, 0.0, 0.0, 0.0];
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();
        // A memory that already carries an expiry must not be clustered again.
        let mut expiring = memory("b", 0.3);
        expiring.expires_at = Some(now + 2 * 86400);
        let input = [
            (memory("a", 0.8), nudged(&base, 0.01)),
            (expiring, nudged(&base, 0.02)),
        ];
        let (summary, _) = consolidate(&input);
        assert_eq!(summary.clusters_found, 0);
    }

    #[test]
    fn idempotent() {
        let base = vec![1.0f32, 0.0, 0.0, 0.0];
        let emb_a = nudged(&base, 0.01);
        let emb_b = nudged(&base, 0.02);

        let first_pass = [
            (memory("a", 0.8), emb_a.clone()),
            (memory("b", 0.3), emb_b.clone()),
        ];
        let (_, updated) = consolidate(&first_pass);
        let deprecated = updated[0].0.clone();

        // Feeding the deprecated copy back in must not produce a new cluster.
        let second_pass = [(memory("a", 0.8), emb_a), (deprecated, emb_b)];
        let (summary, _) = consolidate(&second_pass);
        assert_eq!(summary.clusters_found, 0);
    }
}