// codemem_engine/consolidation/cluster.rs
use super::union_find::UnionFind;
2use super::ConsolidationResult;
3use crate::CodememEngine;
4use codemem_core::{CodememError, GraphBackend, VectorBackend};
5use codemem_storage::vector::cosine_similarity;
6use serde_json::json;
7use std::collections::{HashMap, HashSet};
8
impl CodememEngine {
    /// Consolidation cycle that merges near-duplicate memories.
    ///
    /// Memories are bucketed by `namespace:memory_type`, then pairs within a
    /// bucket whose embedding cosine similarity meets `similarity_threshold`
    /// (default 0.92) are linked in a union-find structure. For each resulting
    /// connected component, the highest-importance member is kept and the rest
    /// are deleted from storage and from the vector/graph/BM25 indexes.
    ///
    /// Returns a [`ConsolidationResult`] whose `affected` field is the number
    /// of memories queued for deletion. NOTE(review): this count is reported
    /// even if a later batch delete fails (failures only log a warning), so it
    /// may overstate what was actually removed — confirm this is intended.
    ///
    /// # Errors
    /// Propagates storage listing/fetch errors and lock-acquisition errors.
    pub fn consolidate_cluster(
        &self,
        similarity_threshold: Option<f64>,
    ) -> Result<ConsolidationResult, CodememError> {
        let similarity_threshold = similarity_threshold.unwrap_or(0.92);

        // Load every memory up front; indices into `memories` are used as
        // union-find node ids for the rest of the pass.
        let ids = self.storage.list_memory_ids()?;
        let id_refs: Vec<&str> = ids.iter().map(|s| s.as_str()).collect();
        let memories = self.storage.get_memories_batch(&id_refs)?;

        // Bucket by "namespace:type" so only same-kind memories can merge.
        // A missing namespace falls back to "default".
        let mut groups: HashMap<String, Vec<usize>> = HashMap::new();
        for (idx, m) in memories.iter().enumerate() {
            let key = format!(
                "{}:{}",
                m.namespace.as_deref().unwrap_or("default"),
                m.memory_type
            );
            groups.entry(key).or_default().push(idx);
        }

        let n = memories.len();
        let mut uf = UnionFind::new(n);

        // Held across the whole similarity pass so the k-NN branch can query
        // the vector index; explicitly dropped before the deletion phase below.
        let vector = self.lock_vector()?;

        // Reverse map from memory id to its index, used to translate k-NN
        // search hits back into union-find node ids.
        let id_to_idx: HashMap<&str, usize> = memories
            .iter()
            .enumerate()
            .map(|(i, m)| (m.id.as_str(), i))
            .collect();

        for member_indices in groups.values() {
            // A singleton bucket has nothing to merge.
            if member_indices.len() <= 1 {
                continue;
            }

            if member_indices.len() <= 50 {
                // Small bucket: exact O(k^2) pairwise comparison.
                for i in 0..member_indices.len() {
                    for j in (i + 1)..member_indices.len() {
                        let idx_a = member_indices[i];
                        let idx_b = member_indices[j];

                        let id_a = &memories[idx_a].id;
                        let id_b = &memories[idx_b].id;

                        // Embedding fetch errors are treated the same as a
                        // missing embedding (ok().flatten()); if either side
                        // lacks an embedding, fall back to content-hash
                        // equality: identical content merges (1.0), anything
                        // else does not (0.0).
                        let sim = match (
                            self.storage.get_embedding(id_a).ok().flatten(),
                            self.storage.get_embedding(id_b).ok().flatten(),
                        ) {
                            (Some(emb_a), Some(emb_b)) => cosine_similarity(&emb_a, &emb_b),
                            _ => {
                                if memories[idx_a].content_hash == memories[idx_b].content_hash {
                                    1.0
                                } else {
                                    0.0
                                }
                            }
                        };

                        if sim >= similarity_threshold {
                            uf.union(idx_a, idx_b);
                        }
                    }
                }
            } else {
                // Large bucket: approximate via k-NN search against the
                // (global) vector index, filtered down to this bucket's ids.
                // NOTE(review): because the search is index-wide, near
                // neighbors outside the bucket can crowd out in-bucket ones
                // within the top-k — confirm this recall trade-off is accepted.
                let k_neighbors = 10.min(member_indices.len());

                let group_ids: HashSet<&str> = member_indices
                    .iter()
                    .map(|&idx| memories[idx].id.as_str())
                    .collect();

                for &idx_a in member_indices {
                    let id_a = &memories[idx_a].id;
                    // NOTE(review): unlike the pairwise branch, a memory with
                    // no stored embedding is skipped entirely here, so exact
                    // content-hash duplicates without embeddings never merge
                    // in large buckets — confirm this asymmetry is intended.
                    let embedding = match self.storage.get_embedding(id_a).ok().flatten() {
                        Some(e) => e,
                        None => continue,
                    };

                    // +1 because the query memory itself is expected among the
                    // results and is filtered out below. Search failures
                    // degrade to an empty neighbor list.
                    let neighbors = vector
                        .search(&embedding, k_neighbors + 1)
                        .unwrap_or_default();

                    for (neighbor_id, _) in &neighbors {
                        if neighbor_id == id_a {
                            continue;
                        }
                        // Only merge within the same namespace:type bucket.
                        if !group_ids.contains(neighbor_id.as_str()) {
                            continue;
                        }

                        let idx_b = match id_to_idx.get(neighbor_id.as_str()) {
                            Some(&idx) => idx,
                            None => continue,
                        };

                        // Re-fetch the neighbor's stored embedding rather than
                        // trusting the index hit; missing embedding falls back
                        // to content-hash equality as in the pairwise branch.
                        let sim = match self.storage.get_embedding(neighbor_id).ok().flatten() {
                            Some(emb_b) => cosine_similarity(&embedding, &emb_b),
                            None => {
                                if memories[idx_a].content_hash == memories[idx_b].content_hash {
                                    1.0
                                } else {
                                    0.0
                                }
                            }
                        };

                        if sim >= similarity_threshold {
                            uf.union(idx_a, idx_b);
                        }
                    }
                }
            }
        }
        // Release the vector lock before the deletion phase re-acquires it
        // together with the graph and BM25 locks.
        drop(vector);

        let clusters = uf.groups(n);

        let mut merged_count = 0usize;
        let mut kept_count = 0usize;
        let mut ids_to_delete: Vec<String> = Vec::new();

        for cluster in &clusters {
            // Singleton component: nothing merged, the memory survives as-is.
            if cluster.len() <= 1 {
                kept_count += 1;
                continue;
            }

            // Sort members by importance descending; the first entry is the
            // survivor. NaN importances compare as Equal, leaving their
            // relative order unspecified.
            let mut members: Vec<(usize, f64)> = cluster
                .iter()
                .map(|&idx| (idx, memories[idx].importance))
                .collect();
            members.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
            // Exactly one survivor per multi-member cluster.
            kept_count += 1;

            for &(idx, _) in members.iter().skip(1) {
                ids_to_delete.push(memories[idx].id.clone());
                merged_count += 1;
            }
        }

        // Delete in batches of 100. Deletion is best-effort throughout: a
        // failed cascade delete is logged and index cleanup still proceeds
        // for that batch.
        for batch in ids_to_delete.chunks(100) {
            let batch_refs: Vec<&str> = batch.iter().map(|s| s.as_str()).collect();
            if let Err(e) = self.storage.delete_memories_batch_cascade(&batch_refs) {
                tracing::warn!(
                    "Failed to batch-delete {} memories during cluster consolidation: {e}",
                    batch.len()
                );
            }

            // Locks are taken per batch (graph, then vector, then bm25) and
            // released before the next batch, keeping hold times short.
            // NOTE(review): presumably this acquisition order is the
            // project-wide convention to avoid deadlock — verify against
            // other call sites.
            let mut graph = self.lock_graph()?;
            let mut vector = self.lock_vector()?;
            let mut bm25 = self.lock_bm25()?;
            for id in batch {
                if let Err(e) = vector.remove(id) {
                    tracing::warn!(
                        "Failed to remove {id} from vector index during cluster consolidation: {e}"
                    );
                }
                if let Err(e) = graph.remove_node(id) {
                    tracing::warn!(
                        "Failed to remove {id} from graph during cluster consolidation: {e}"
                    );
                }
                bm25.remove_document(id);
            }
            drop(bm25);
            drop(vector);
            drop(graph);
        }

        // Rebuild the vector index only if anything was actually merged.
        if merged_count > 0 {
            let mut vector = self.lock_vector()?;
            self.rebuild_vector_index_internal(&mut vector);
            drop(vector);
        }

        self.save_index();

        // Record the cycle in the consolidation log; failure is non-fatal.
        if let Err(e) = self
            .storage
            .insert_consolidation_log("cluster", merged_count)
        {
            tracing::warn!("Failed to log cluster consolidation: {e}");
        }

        Ok(ConsolidationResult {
            cycle: "cluster".to_string(),
            affected: merged_count,
            details: json!({
                "merged": merged_count,
                "kept": kept_count,
                "similarity_threshold": similarity_threshold,
                "algorithm": "semantic_cosine",
            }),
        })
    }
}