Skip to main content

engine/
consolidate.rs

1//! CE-6: DBSCAN-based adaptive memory consolidation with soft-deprecation.
2use common::Memory;
3use std::collections::{HashMap, HashSet};
4use std::time::{SystemTime, UNIX_EPOCH};
5
6const DEFAULT_EPSILON: f32 = 0.92;
7const DEFAULT_MIN_SAMPLES: usize = 2;
8const DEFAULT_SOFT_DEPRECATION_DAYS: u64 = 30;
9
10#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
11pub struct ConsolidationConfig {
12    pub enabled: bool,
13    pub epsilon: f32,
14    pub min_samples: usize,
15    pub soft_deprecation_days: u64,
16}
17impl Default for ConsolidationConfig {
18    fn default() -> Self {
19        Self {
20            enabled: true,
21            epsilon: DEFAULT_EPSILON,
22            min_samples: DEFAULT_MIN_SAMPLES,
23            soft_deprecation_days: DEFAULT_SOFT_DEPRECATION_DAYS,
24        }
25    }
26}
27
28#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
29pub struct ConsolidationLogEntry {
30    pub run_at: u64,
31    pub memories_scanned: usize,
32    pub clusters_found: usize,
33    pub memories_deprecated: usize,
34    pub anchor_ids: Vec<String>,
35    pub deprecated_ids: Vec<String>,
36}
37
38#[derive(Debug, Default)]
39pub struct ConsolidateResult {
40    pub memories_scanned: usize,
41    pub clusters_found: usize,
42    pub memories_deprecated: usize,
43    pub anchor_ids: Vec<String>,
44    pub deprecated_ids: Vec<String>,
45}
46impl ConsolidateResult {
47    pub fn to_log_entry(&self) -> ConsolidationLogEntry {
48        let run_at = SystemTime::now()
49            .duration_since(UNIX_EPOCH)
50            .unwrap_or_default()
51            .as_secs();
52        ConsolidationLogEntry {
53            run_at,
54            memories_scanned: self.memories_scanned,
55            clusters_found: self.clusters_found,
56            memories_deprecated: self.memories_deprecated,
57            anchor_ids: self.anchor_ids.clone(),
58            deprecated_ids: self.deprecated_ids.clone(),
59        }
60    }
61}
62
63pub fn run_dbscan(
64    memories: &[(Memory, Vec<f32>)],
65    config: &ConsolidationConfig,
66) -> (ConsolidateResult, Vec<(Memory, Vec<f32>)>) {
67    let n = memories.len();
68    let now_secs = SystemTime::now()
69        .duration_since(UNIX_EPOCH)
70        .unwrap_or_default()
71        .as_secs();
72    let active: Vec<usize> = (0..n)
73        .filter(|&i| memories[i].0.expires_at.is_none())
74        .collect();
75    let mut result = ConsolidateResult {
76        memories_scanned: active.len(),
77        ..Default::default()
78    };
79    if active.len() < config.min_samples {
80        return (result, Vec::new());
81    }
82
83    let mut neighbors: HashMap<usize, Vec<usize>> = HashMap::new();
84    for p in 0..active.len() {
85        for q in (p + 1)..active.len() {
86            let sim = cosine_sim(&memories[active[p]].1, &memories[active[q]].1);
87            if sim >= config.epsilon {
88                neighbors.entry(p).or_default().push(q);
89                neighbors.entry(q).or_default().push(p);
90            }
91        }
92    }
93
94    let min_nb = config.min_samples.saturating_sub(1).max(1);
95    let core: HashSet<usize> = (0..active.len())
96        .filter(|p| neighbors.get(p).map_or(0, |v| v.len()) >= min_nb)
97        .collect();
98
99    let mut visited: HashSet<usize> = HashSet::new();
100    let mut clusters: Vec<Vec<usize>> = Vec::new();
101    for &cp in &core {
102        if visited.contains(&cp) {
103            continue;
104        }
105        let mut cluster = Vec::new();
106        let mut stack = vec![cp];
107        while let Some(node) = stack.pop() {
108            if visited.insert(node) {
109                cluster.push(node);
110                if let Some(nbrs) = neighbors.get(&node) {
111                    for &nb in nbrs {
112                        if core.contains(&nb) && !visited.contains(&nb) {
113                            stack.push(nb);
114                        }
115                    }
116                }
117            }
118        }
119        if cluster.len() >= config.min_samples {
120            clusters.push(cluster);
121        }
122    }
123
124    result.clusters_found = clusters.len();
125    if clusters.is_empty() {
126        return (result, Vec::new());
127    }
128
129    let expires_at = now_secs + config.soft_deprecation_days * 86400;
130    let mut updated: Vec<(Memory, Vec<f32>)> = Vec::new();
131    for cluster in &clusters {
132        let anchor_p = cluster
133            .iter()
134            .copied()
135            .max_by(|&a, &b| {
136                let ma = &memories[active[a]].0;
137                let mb = &memories[active[b]].0;
138                ma.importance
139                    .partial_cmp(&mb.importance)
140                    .unwrap_or(std::cmp::Ordering::Equal)
141                    .then_with(|| ma.created_at.cmp(&mb.created_at))
142            })
143            .unwrap();
144        result
145            .anchor_ids
146            .push(memories[active[anchor_p]].0.id.clone());
147        for &p in cluster {
148            if p == anchor_p {
149                continue;
150            }
151            let (mem, emb) = &memories[active[p]];
152            let deprecated = Memory {
153                expires_at: Some(expires_at),
154                ..mem.clone()
155            };
156            result.deprecated_ids.push(deprecated.id.clone());
157            updated.push((deprecated, emb.clone()));
158        }
159    }
160    result.memories_deprecated = result.deprecated_ids.len();
161    (result, updated)
162}
163
/// Cosine similarity of two embeddings, clamped to `[-1.0, 1.0]`.
///
/// Returns `0.0` when the vectors cannot be compared: when either is empty, when
/// their lengths differ, or when either has zero magnitude.
fn cosine_sim(a: &[f32], b: &[f32]) -> f32 {
    // Equal, non-zero lengths imply that both slices are non-empty.
    let comparable = !a.is_empty() && a.len() == b.len();
    if !comparable {
        return 0.0;
    }
    let magnitude = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (magnitude(a), magnitude(b));
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    let dot = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum::<f32>();
    (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
}
176
#[cfg(test)]
mod tests {
    use super::*;
    use common::MemoryType;

    /// Builds an active (non-expiring) episodic memory whose content equals its id.
    fn mk(id: &str, imp: f32) -> Memory {
        Memory {
            id: id.to_string(),
            memory_type: MemoryType::Episodic,
            content: id.to_string(),
            agent_id: "a".to_string(),
            session_id: None,
            importance: imp,
            tags: vec![],
            metadata: None,
            created_at: 1000000,
            last_accessed_at: 1000000,
            access_count: 0,
            ttl_seconds: None,
            expires_at: None,
        }
    }

    /// One-hot vector of length `dim` with `1.0` at index `i`.
    fn unit(dim: usize, i: usize) -> Vec<f32> {
        let mut v = vec![0.0f32; dim];
        v[i] = 1.0;
        v
    }

    /// `base` shifted by `n` along axis 0, then L2-normalised.
    fn near(base: &[f32], n: f32) -> Vec<f32> {
        let mut v: Vec<f32> = base
            .iter()
            .enumerate()
            .map(|(i, x)| x + if i == 0 { n } else { 0.0 })
            .collect();
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        for x in &mut v {
            *x /= norm;
        }
        v
    }

    /// Current Unix time in seconds.
    fn unix_now() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    #[test]
    fn identical_sim() {
        let v = vec![1.0f32, 0.0, 0.0];
        assert!((cosine_sim(&v, &v) - 1.0).abs() < 1e-5);
    }
    #[test]
    fn orthogonal_sim() {
        assert!(cosine_sim(&unit(3, 0), &unit(3, 1)).abs() < 1e-5);
    }
    #[test]
    fn opposite_sim() {
        assert!((cosine_sim(&[1.0, 0.0], &[-3.0, 0.0]) + 1.0).abs() < 1e-5);
    }
    #[test]
    fn degenerate_sim_is_zero() {
        // Empty, mismatched-length, and zero-magnitude inputs are not comparable.
        assert_eq!(cosine_sim(&[], &[]), 0.0);
        assert_eq!(cosine_sim(&[1.0, 0.0], &[1.0]), 0.0);
        assert_eq!(cosine_sim(&[0.0, 0.0], &[1.0, 0.0]), 0.0);
    }
    #[test]
    fn default_config_uses_constants() {
        let c = ConsolidationConfig::default();
        assert!(c.enabled);
        assert_eq!(c.epsilon, DEFAULT_EPSILON);
        assert_eq!(c.min_samples, DEFAULT_MIN_SAMPLES);
        assert_eq!(c.soft_deprecation_days, DEFAULT_SOFT_DEPRECATION_DAYS);
    }
    #[test]
    fn no_cluster_single() {
        let (r, u) = run_dbscan(
            &[(mk("a", 0.5), unit(4, 0))],
            &ConsolidationConfig::default(),
        );
        assert_eq!(r.clusters_found, 0);
        assert!(u.is_empty());
    }
    #[test]
    fn two_similar_cluster() {
        let b = vec![1.0f32, 0.0, 0.0, 0.0];
        let (r, u) = run_dbscan(
            &[
                (mk("a", 0.8), near(&b, 0.01)),
                (mk("b", 0.3), near(&b, 0.02)),
            ],
            &ConsolidationConfig::default(),
        );
        assert_eq!(r.clusters_found, 1);
        assert_eq!(r.anchor_ids, vec!["a"]);
        assert_eq!(r.deprecated_ids, vec!["b"]);
        assert_eq!(r.memories_deprecated, 1);
        assert!(u[0].0.expires_at.is_some());
    }
    #[test]
    fn orthogonal_no_cluster() {
        let (r, _) = run_dbscan(
            &[(mk("a", 0.5), unit(4, 0)), (mk("b", 0.5), unit(4, 1))],
            &ConsolidationConfig::default(),
        );
        assert_eq!(r.clusters_found, 0);
    }
    #[test]
    fn deprecated_excluded() {
        let b = vec![1.0f32, 0.0, 0.0, 0.0];
        let mut m = mk("b", 0.3);
        m.expires_at = Some(unix_now() + 2 * 86400);
        let (r, _) = run_dbscan(
            &[(mk("a", 0.8), near(&b, 0.01)), (m, near(&b, 0.02))],
            &ConsolidationConfig::default(),
        );
        assert_eq!(r.clusters_found, 0);
    }
    #[test]
    fn scanned_counts_only_active() {
        let b = vec![1.0f32, 0.0, 0.0, 0.0];
        let mut expiring = mk("c", 0.5);
        expiring.expires_at = Some(unix_now() + 86400);
        let (r, _) = run_dbscan(
            &[
                (mk("a", 0.8), near(&b, 0.01)),
                (mk("b", 0.3), unit(4, 1)),
                (expiring, near(&b, 0.02)),
            ],
            &ConsolidationConfig::default(),
        );
        assert_eq!(r.memories_scanned, 2);
        assert_eq!(r.clusters_found, 0);
    }
    #[test]
    fn deprecation_window_matches_config() {
        let b = vec![1.0f32, 0.0, 0.0, 0.0];
        let cfg = ConsolidationConfig {
            soft_deprecation_days: 7,
            ..ConsolidationConfig::default()
        };
        let before = unix_now();
        let (_, u) = run_dbscan(
            &[
                (mk("a", 0.8), near(&b, 0.01)),
                (mk("b", 0.3), near(&b, 0.02)),
            ],
            &cfg,
        );
        let after = unix_now();
        let exp = u[0].0.expires_at.expect("deprecated copy carries an expiry");
        assert!(exp >= before + 7 * 86400 && exp <= after + 7 * 86400);
    }
    #[test]
    fn idempotent() {
        let b = vec![1.0f32, 0.0, 0.0, 0.0];
        let ea = near(&b, 0.01);
        let eb = near(&b, 0.02);
        let (_, u) = run_dbscan(
            &[(mk("a", 0.8), ea.clone()), (mk("b", 0.3), eb.clone())],
            &ConsolidationConfig::default(),
        );
        let dep = u[0].0.clone();
        let (r2, _) = run_dbscan(
            &[(mk("a", 0.8), ea), (dep, eb)],
            &ConsolidationConfig::default(),
        );
        assert_eq!(r2.clusters_found, 0);
    }
    #[test]
    fn log_entry_mirrors_result() {
        let b = vec![1.0f32, 0.0, 0.0, 0.0];
        let (r, _) = run_dbscan(
            &[
                (mk("a", 0.8), near(&b, 0.01)),
                (mk("b", 0.3), near(&b, 0.02)),
            ],
            &ConsolidationConfig::default(),
        );
        let e = r.to_log_entry();
        assert_eq!(e.memories_scanned, r.memories_scanned);
        assert_eq!(e.clusters_found, r.clusters_found);
        assert_eq!(e.memories_deprecated, r.memories_deprecated);
        assert_eq!(e.anchor_ids, r.anchor_ids);
        assert_eq!(e.deprecated_ids, r.deprecated_ids);
        assert!(e.run_at > 0);
    }
}