1use common::Memory;
3use std::collections::{HashMap, HashSet};
4use std::time::{SystemTime, UNIX_EPOCH};
5
/// Default cosine-similarity threshold: two memories whose embeddings score at
/// or above this value are treated as neighbours.
const DEFAULT_EPSILON: f32 = 0.92;
/// Default minimum cluster size, counting the memory itself.
const DEFAULT_MIN_SAMPLES: usize = 2;
/// Default number of days added to the current time when stamping
/// `expires_at` on a soft-deprecated memory.
const DEFAULT_SOFT_DEPRECATION_DAYS: u64 = 30;
9
10#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
11pub struct ConsolidationConfig {
12 pub enabled: bool,
13 pub epsilon: f32,
14 pub min_samples: usize,
15 pub soft_deprecation_days: u64,
16}
17impl Default for ConsolidationConfig {
18 fn default() -> Self {
19 Self {
20 enabled: true,
21 epsilon: DEFAULT_EPSILON,
22 min_samples: DEFAULT_MIN_SAMPLES,
23 soft_deprecation_days: DEFAULT_SOFT_DEPRECATION_DAYS,
24 }
25 }
26}
27
/// Serializable record of one consolidation run, produced by
/// [`ConsolidateResult::to_log_entry`].
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ConsolidationLogEntry {
    /// Seconds since the Unix epoch at the moment the entry was built (i.e.
    /// when `to_log_entry` was called, not when clustering started).
    pub run_at: u64,
    /// Memories eligible for clustering (those with no `expires_at` set).
    pub memories_scanned: usize,
    /// Clusters that reached the configured minimum size.
    pub clusters_found: usize,
    /// Memories given an expiry during the run.
    pub memories_deprecated: usize,
    /// One id per cluster: the member that was kept.
    pub anchor_ids: Vec<String>,
    /// Ids of the memories that were soft-deprecated.
    pub deprecated_ids: Vec<String>,
}
37
/// Statistics and ids produced by one [`run_dbscan`] call.
#[derive(Debug, Default)]
pub struct ConsolidateResult {
    /// Memories considered for clustering (those with no `expires_at` set).
    pub memories_scanned: usize,
    /// Clusters with at least `min_samples` members.
    pub clusters_found: usize,
    /// Set to `deprecated_ids.len()` by `run_dbscan` before it returns.
    pub memories_deprecated: usize,
    /// One id per cluster: the member that was kept.
    pub anchor_ids: Vec<String>,
    /// Ids of cluster members that were given an `expires_at`.
    pub deprecated_ids: Vec<String>,
}
46impl ConsolidateResult {
47 pub fn to_log_entry(&self) -> ConsolidationLogEntry {
48 let run_at = SystemTime::now()
49 .duration_since(UNIX_EPOCH)
50 .unwrap_or_default()
51 .as_secs();
52 ConsolidationLogEntry {
53 run_at,
54 memories_scanned: self.memories_scanned,
55 clusters_found: self.clusters_found,
56 memories_deprecated: self.memories_deprecated,
57 anchor_ids: self.anchor_ids.clone(),
58 deprecated_ids: self.deprecated_ids.clone(),
59 }
60 }
61}
62
63pub fn run_dbscan(
64 memories: &[(Memory, Vec<f32>)],
65 config: &ConsolidationConfig,
66) -> (ConsolidateResult, Vec<(Memory, Vec<f32>)>) {
67 let n = memories.len();
68 let now_secs = SystemTime::now()
69 .duration_since(UNIX_EPOCH)
70 .unwrap_or_default()
71 .as_secs();
72 let active: Vec<usize> = (0..n)
73 .filter(|&i| memories[i].0.expires_at.is_none())
74 .collect();
75 let mut result = ConsolidateResult {
76 memories_scanned: active.len(),
77 ..Default::default()
78 };
79 if active.len() < config.min_samples {
80 return (result, Vec::new());
81 }
82
83 let mut neighbors: HashMap<usize, Vec<usize>> = HashMap::new();
84 for p in 0..active.len() {
85 for q in (p + 1)..active.len() {
86 let sim = cosine_sim(&memories[active[p]].1, &memories[active[q]].1);
87 if sim >= config.epsilon {
88 neighbors.entry(p).or_default().push(q);
89 neighbors.entry(q).or_default().push(p);
90 }
91 }
92 }
93
94 let min_nb = config.min_samples.saturating_sub(1).max(1);
95 let core: HashSet<usize> = (0..active.len())
96 .filter(|p| neighbors.get(p).map_or(0, |v| v.len()) >= min_nb)
97 .collect();
98
99 let mut visited: HashSet<usize> = HashSet::new();
100 let mut clusters: Vec<Vec<usize>> = Vec::new();
101 for &cp in &core {
102 if visited.contains(&cp) {
103 continue;
104 }
105 let mut cluster = Vec::new();
106 let mut stack = vec![cp];
107 while let Some(node) = stack.pop() {
108 if visited.insert(node) {
109 cluster.push(node);
110 if let Some(nbrs) = neighbors.get(&node) {
111 for &nb in nbrs {
112 if core.contains(&nb) && !visited.contains(&nb) {
113 stack.push(nb);
114 }
115 }
116 }
117 }
118 }
119 if cluster.len() >= config.min_samples {
120 clusters.push(cluster);
121 }
122 }
123
124 result.clusters_found = clusters.len();
125 if clusters.is_empty() {
126 return (result, Vec::new());
127 }
128
129 let expires_at = now_secs + config.soft_deprecation_days * 86400;
130 let mut updated: Vec<(Memory, Vec<f32>)> = Vec::new();
131 for cluster in &clusters {
132 let anchor_p = cluster
133 .iter()
134 .copied()
135 .max_by(|&a, &b| {
136 let ma = &memories[active[a]].0;
137 let mb = &memories[active[b]].0;
138 ma.importance
139 .partial_cmp(&mb.importance)
140 .unwrap_or(std::cmp::Ordering::Equal)
141 .then_with(|| ma.created_at.cmp(&mb.created_at))
142 })
143 .unwrap();
144 result
145 .anchor_ids
146 .push(memories[active[anchor_p]].0.id.clone());
147 for &p in cluster {
148 if p == anchor_p {
149 continue;
150 }
151 let (mem, emb) = &memories[active[p]];
152 let deprecated = Memory {
153 expires_at: Some(expires_at),
154 ..mem.clone()
155 };
156 result.deprecated_ids.push(deprecated.id.clone());
157 updated.push((deprecated, emb.clone()));
158 }
159 }
160 result.memories_deprecated = result.deprecated_ids.len();
161 (result, updated)
162}
163
/// Cosine similarity of two embeddings, clamped to `[-1.0, 1.0]`.
///
/// Scores `0.0` (orthogonal) when either slice is empty, when the lengths
/// differ, or when either vector has zero magnitude.
fn cosine_sim(a: &[f32], b: &[f32]) -> f32 {
    let comparable = !a.is_empty() && !b.is_empty() && a.len() == b.len();
    if !comparable {
        return 0.0;
    }
    let magnitude = |v: &[f32]| -> f32 { v.iter().map(|x| x * x).sum::<f32>().sqrt() };
    let (mag_a, mag_b) = (magnitude(a), magnitude(b));
    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }
    let dot = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum::<f32>();
    (dot / (mag_a * mag_b)).clamp(-1.0, 1.0)
}
176
#[cfg(test)]
mod tests {
    use super::*;
    use common::MemoryType;

    /// Builds a live (no `expires_at`) episodic memory whose id and content are `id`.
    fn mk(id: &str, imp: f32) -> Memory {
        Memory {
            id: id.to_string(),
            memory_type: MemoryType::Episodic,
            content: id.to_string(),
            agent_id: "a".to_string(),
            session_id: None,
            importance: imp,
            tags: vec![],
            metadata: None,
            created_at: 1000000,
            last_accessed_at: 1000000,
            access_count: 0,
            ttl_seconds: None,
            expires_at: None,
        }
    }

    /// Basis vector of length `dim` with `1.0` at index `i`.
    fn unit(dim: usize, i: usize) -> Vec<f32> {
        let mut v = vec![0.0f32; dim];
        v[i] = 1.0;
        v
    }

    /// Unit-length copy of `base` nudged by `n` along axis 1.
    ///
    /// The nudge must not be along axis 0: every base used here points along
    /// axis 0, and growing that component then re-normalising gives back
    /// exactly `base`, so "near" vectors would silently be exact duplicates.
    /// Panics if `base` has fewer than two components.
    fn near(base: &[f32], n: f32) -> Vec<f32> {
        let mut v = base.to_vec();
        v[1] += n;
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        for x in &mut v {
            *x /= norm;
        }
        v
    }

    /// Current wall-clock time in whole seconds since the Unix epoch.
    fn now_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    #[test]
    fn identical_sim() {
        let v = vec![1.0f32, 0.0, 0.0];
        assert!((cosine_sim(&v, &v) - 1.0).abs() < 1e-5);
    }

    #[test]
    fn orthogonal_sim() {
        assert!(cosine_sim(&unit(3, 0), &unit(3, 1)).abs() < 1e-5);
    }

    #[test]
    fn degenerate_sim_is_zero() {
        assert_eq!(cosine_sim(&[], &[]), 0.0);
        assert_eq!(cosine_sim(&unit(3, 0), &unit(4, 0)), 0.0);
        assert_eq!(cosine_sim(&[0.0; 3], &unit(3, 0)), 0.0);
    }

    #[test]
    fn near_is_similar_but_not_identical() {
        let b = unit(4, 0);
        let (x, y) = (near(&b, 0.01), near(&b, 0.02));
        assert_ne!(x, y);
        assert_ne!(x, b);
        assert!(cosine_sim(&x, &y) >= DEFAULT_EPSILON);
    }

    #[test]
    fn no_cluster_single() {
        let (r, u) = run_dbscan(
            &[(mk("a", 0.5), unit(4, 0))],
            &ConsolidationConfig::default(),
        );
        assert_eq!(r.clusters_found, 0);
        assert!(u.is_empty());
    }

    #[test]
    fn two_similar_cluster() {
        let b = vec![1.0f32, 0.0, 0.0, 0.0];
        let (r, u) = run_dbscan(
            &[
                (mk("a", 0.8), near(&b, 0.01)),
                (mk("b", 0.3), near(&b, 0.02)),
            ],
            &ConsolidationConfig::default(),
        );
        assert_eq!(r.memories_scanned, 2);
        assert_eq!(r.clusters_found, 1);
        assert_eq!(r.anchor_ids, vec!["a"]);
        assert_eq!(r.deprecated_ids, vec!["b"]);
        assert_eq!(r.memories_deprecated, 1);
        assert_eq!(u.len(), 1);
        assert_eq!(u[0].0.id, "b");
        assert!(u[0].0.expires_at.is_some());
    }

    #[test]
    fn cluster_excludes_dissimilar_memory() {
        let b = unit(4, 0);
        let (r, u) = run_dbscan(
            &[
                (mk("a", 0.9), near(&b, 0.01)),
                (mk("b", 0.5), near(&b, 0.02)),
                (mk("c", 0.1), near(&b, 0.03)),
                (mk("d", 0.7), unit(4, 2)),
            ],
            &ConsolidationConfig::default(),
        );
        assert_eq!(r.memories_scanned, 4);
        assert_eq!(r.clusters_found, 1);
        assert_eq!(r.anchor_ids, vec!["a"]);
        let mut deprecated = r.deprecated_ids.clone();
        deprecated.sort();
        assert_eq!(deprecated, vec!["b", "c"]);
        assert_eq!(r.memories_deprecated, 2);
        assert_eq!(u.len(), 2);
        assert!(u
            .iter()
            .all(|(m, _)| m.id != "d" && m.expires_at.is_some()));
    }

    #[test]
    fn importance_tie_prefers_newer() {
        let b = unit(4, 0);
        let older = mk("old", 0.5);
        let newer = Memory {
            created_at: 2000000,
            ..mk("new", 0.5)
        };
        let (r, _) = run_dbscan(
            &[(older, near(&b, 0.01)), (newer, near(&b, 0.02))],
            &ConsolidationConfig::default(),
        );
        assert_eq!(r.anchor_ids, vec!["new"]);
        assert_eq!(r.deprecated_ids, vec!["old"]);
    }

    #[test]
    fn min_samples_respected() {
        let b = unit(4, 0);
        let cfg = ConsolidationConfig {
            min_samples: 3,
            ..ConsolidationConfig::default()
        };

        let pair = [
            (mk("a", 0.8), near(&b, 0.01)),
            (mk("b", 0.3), near(&b, 0.02)),
        ];
        let (r, u) = run_dbscan(&pair, &cfg);
        assert_eq!(r.clusters_found, 0);
        assert!(u.is_empty());

        let triple = [
            (mk("a", 0.8), near(&b, 0.01)),
            (mk("b", 0.3), near(&b, 0.02)),
            (mk("c", 0.1), near(&b, 0.03)),
        ];
        let (r, u) = run_dbscan(&triple, &cfg);
        assert_eq!(r.clusters_found, 1);
        assert_eq!(r.anchor_ids, vec!["a"]);
        assert_eq!(u.len(), 2);
    }

    #[test]
    fn deprecation_expiry_uses_grace_period() {
        let b = unit(4, 0);
        let cfg = ConsolidationConfig {
            soft_deprecation_days: 7,
            ..ConsolidationConfig::default()
        };
        let before = now_secs();
        let (_, u) = run_dbscan(
            &[
                (mk("a", 0.8), near(&b, 0.01)),
                (mk("b", 0.3), near(&b, 0.02)),
            ],
            &cfg,
        );
        let after = now_secs();
        let exp = u[0].0.expires_at.expect("b should be deprecated");
        assert!(exp >= before + 7 * 86_400);
        assert!(exp <= after + 7 * 86_400);
    }

    #[test]
    fn orthogonal_no_cluster() {
        let (r, _) = run_dbscan(
            &[(mk("a", 0.5), unit(4, 0)), (mk("b", 0.5), unit(4, 1))],
            &ConsolidationConfig::default(),
        );
        assert_eq!(r.clusters_found, 0);
    }

    #[test]
    fn deprecated_excluded() {
        let b = vec![1.0f32, 0.0, 0.0, 0.0];
        let mut m = mk("b", 0.3);
        m.expires_at = Some(now_secs() + 2 * 86400);
        let (r, _) = run_dbscan(
            &[(mk("a", 0.8), near(&b, 0.01)), (m, near(&b, 0.02))],
            &ConsolidationConfig::default(),
        );
        assert_eq!(r.memories_scanned, 1);
        assert_eq!(r.clusters_found, 0);
    }

    #[test]
    fn idempotent() {
        let b = vec![1.0f32, 0.0, 0.0, 0.0];
        let ea = near(&b, 0.01);
        let eb = near(&b, 0.02);
        let (_, u) = run_dbscan(
            &[(mk("a", 0.8), ea.clone()), (mk("b", 0.3), eb.clone())],
            &ConsolidationConfig::default(),
        );
        let dep = u[0].0.clone();
        let (r2, _) = run_dbscan(
            &[(mk("a", 0.8), ea), (dep, eb)],
            &ConsolidationConfig::default(),
        );
        assert_eq!(r2.clusters_found, 0);
    }
}