Skip to main content

codemem_engine/persistence/
compaction.rs

1use crate::CodememEngine;
2use codemem_core::{GraphBackend, GraphNode, NodeKind, RelationshipType};
3use std::collections::{HashMap, HashSet};
4
5/// Check whether a graph node has any edge linking it to a memory node
6/// (i.e. an edge whose other endpoint is not a code-structural ID).
7fn has_memory_link_edge(graph: &dyn GraphBackend, node_id: &str) -> bool {
8    graph
9        .get_edges(node_id)
10        .map(|edges| {
11            edges.iter().any(|e| {
12                let other = if e.src == node_id { &e.dst } else { &e.src };
13                !other.starts_with("sym:")
14                    && !other.starts_with("file:")
15                    && !other.starts_with("chunk:")
16                    && !other.starts_with("pkg:")
17                    && !other.starts_with("contains:")
18                    && !other.starts_with("ref:")
19            })
20        })
21        .unwrap_or(false)
22}
23
24impl CodememEngine {
25    // ── Graph Compaction ────────────────────────────────────────────────────
26
27    /// Compact chunk and symbol graph-nodes after indexing.
28    /// Returns (chunks_pruned, symbols_pruned).
29    pub fn compact_graph(&self, seen_files: &HashSet<String>) -> (usize, usize) {
30        let mut graph = match self.lock_graph() {
31            Ok(g) => g,
32            Err(_) => return (0, 0),
33        };
34
35        // Fetch all nodes once and share between both compaction passes.
36        let all_nodes = graph.get_all_nodes();
37        let chunks_pruned = self.compact_chunks(seen_files, &mut graph, &all_nodes);
38        let symbols_pruned = self.compact_symbols(seen_files, &mut graph, &all_nodes);
39
40        if chunks_pruned > 0 || symbols_pruned > 0 {
41            // compute_centrality: updates node.centrality with degree centrality.
42            // recompute_centrality: caches PageRank + betweenness for hybrid scoring.
43            // Both are needed — they populate different data used by different scoring paths.
44            graph.compute_centrality();
45            graph.recompute_centrality();
46        }
47
48        (chunks_pruned, symbols_pruned)
49    }
50
51    /// Pass 1: Score and prune low-value chunks, transferring line ranges to parent symbols.
52    ///
53    /// Scoring weights adjust on cold start: when no memories exist yet, the
54    /// `memory_link_score` weight (normally 0.3) is redistributed to the other
55    /// factors so compaction still produces meaningful rankings.
56    fn compact_chunks(
57        &self,
58        seen_files: &HashSet<String>,
59        graph: &mut std::sync::MutexGuard<'_, codemem_storage::graph::GraphEngine>,
60        all_nodes: &[GraphNode],
61    ) -> usize {
62        let max_chunks_per_file = self.config.chunking.max_retained_chunks_per_file;
63        let chunk_score_threshold = self.config.chunking.min_chunk_score_threshold;
64
65        // Cold-start detection: if no memories exist, memory_link_score is always 0
66        // and its weight should be redistributed to other factors.
67        let has_memories = self
68            .storage
69            .list_memory_ids()
70            .map(|ids| !ids.is_empty())
71            .unwrap_or(false);
72        let (w_centrality, w_parent, w_memory, w_size) = if has_memories {
73            (0.3, 0.2, 0.3, 0.2)
74        } else {
75            // Redistribute memory_link weight: centrality 0.4, parent 0.3, memory 0.0, size 0.3
76            (0.4, 0.3, 0.0, 0.3)
77        };
78
79        let mut chunks_by_file: HashMap<String, Vec<(String, f64)>> = HashMap::new();
80
81        let mut max_degree: f64 = 1.0;
82        let mut max_non_ws: f64 = 1.0;
83
84        let chunk_nodes: Vec<&GraphNode> = all_nodes
85            .iter()
86            .filter(|n| n.kind == NodeKind::Chunk)
87            .collect();
88
89        for node in &chunk_nodes {
90            let degree = graph
91                .get_edges(&node.id)
92                .map(|edges| edges.len() as f64)
93                .unwrap_or(0.0);
94            max_degree = max_degree.max(degree);
95
96            let non_ws = node
97                .payload
98                .get("non_ws_chars")
99                .and_then(|v| v.as_f64())
100                .unwrap_or(0.0);
101            max_non_ws = max_non_ws.max(non_ws);
102        }
103
104        for node in &chunk_nodes {
105            let file_path = match node.payload.get("file_path").and_then(|v| v.as_str()) {
106                Some(fp) => fp.to_string(),
107                None => continue,
108            };
109            if !seen_files.contains(&file_path) {
110                continue;
111            }
112
113            let degree = graph
114                .get_edges(&node.id)
115                .map(|edges| edges.len() as f64)
116                .unwrap_or(0.0);
117            let centrality_rank = degree / max_degree;
118
119            let has_symbol_parent = if node.payload.contains_key("parent_symbol") {
120                1.0
121            } else {
122                0.0
123            };
124
125            let memory_link_score = if has_memories && has_memory_link_edge(&**graph, &node.id) {
126                1.0
127            } else {
128                0.0
129            };
130
131            let non_ws = node
132                .payload
133                .get("non_ws_chars")
134                .and_then(|v| v.as_f64())
135                .unwrap_or(0.0);
136            let non_ws_rank = non_ws / max_non_ws;
137
138            let chunk_score = centrality_rank * w_centrality
139                + has_symbol_parent * w_parent
140                + memory_link_score * w_memory
141                + non_ws_rank * w_size;
142
143            chunks_by_file
144                .entry(file_path)
145                .or_default()
146                .push((node.id.clone(), chunk_score));
147        }
148
149        let mut symbol_count_by_file: HashMap<String, usize> = HashMap::new();
150        for node in all_nodes {
151            if matches!(
152                node.kind,
153                NodeKind::Function
154                    | NodeKind::Method
155                    | NodeKind::Class
156                    | NodeKind::Interface
157                    | NodeKind::Type
158                    | NodeKind::Constant
159                    | NodeKind::Test
160            ) {
161                if let Some(fp) = node.payload.get("file_path").and_then(|v| v.as_str()) {
162                    *symbol_count_by_file.entry(fp.to_string()).or_default() += 1;
163                }
164            }
165        }
166
167        let mut chunks_pruned = 0usize;
168        for (file_path, mut chunks) in chunks_by_file {
169            chunks.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
170
171            let sym_count = symbol_count_by_file.get(&file_path).copied().unwrap_or(0);
172            let k = max_chunks_per_file.min(chunks.len()).max(3.max(sym_count));
173
174            for (i, (chunk_id, score)) in chunks.iter().enumerate() {
175                // Prune aggressively: remove if beyond the top-k slots OR below the
176                // quality threshold. Using || ensures both caps are enforced independently
177                // (keep at most k chunks, and never keep any chunk below threshold).
178                if i >= k || *score < chunk_score_threshold {
179                    self.transfer_chunk_ranges_to_parent(graph, chunk_id);
180
181                    if let Err(e) = self.storage.delete_graph_edges_for_node(chunk_id) {
182                        tracing::warn!("Failed to delete graph edges for chunk {chunk_id}: {e}");
183                    }
184                    if let Err(e) = self.storage.delete_graph_node(chunk_id) {
185                        tracing::warn!("Failed to delete graph node for chunk {chunk_id}: {e}");
186                    }
187                    if let Err(e) = graph.remove_node(chunk_id) {
188                        tracing::warn!("Failed to remove chunk {chunk_id} from graph: {e}");
189                    }
190                    chunks_pruned += 1;
191                }
192            }
193        }
194
195        chunks_pruned
196    }
197
198    /// When pruning a chunk, transfer its line range to the parent symbol node.
199    fn transfer_chunk_ranges_to_parent(
200        &self,
201        graph: &mut std::sync::MutexGuard<'_, codemem_storage::graph::GraphEngine>,
202        chunk_id: &str,
203    ) {
204        if let Ok(Some(chunk_node)) = graph.get_node(chunk_id) {
205            if let Some(parent_sym) = chunk_node
206                .payload
207                .get("parent_symbol")
208                .and_then(|v| v.as_str())
209            {
210                let parent_id = format!("sym:{parent_sym}");
211                if let Ok(Some(mut parent_node)) = graph.get_node(&parent_id) {
212                    let line_start = chunk_node
213                        .payload
214                        .get("line_start")
215                        .and_then(|v| v.as_u64())
216                        .unwrap_or(0);
217                    let line_end = chunk_node
218                        .payload
219                        .get("line_end")
220                        .and_then(|v| v.as_u64())
221                        .unwrap_or(0);
222                    let ranges = parent_node
223                        .payload
224                        .entry("covered_ranges".to_string())
225                        .or_insert_with(|| serde_json::json!([]));
226                    if let Some(arr) = ranges.as_array_mut() {
227                        arr.push(serde_json::json!([line_start, line_end]));
228                    }
229                    let count = parent_node
230                        .payload
231                        .entry("pruned_chunk_count".to_string())
232                        .or_insert_with(|| serde_json::json!(0));
233                    if let Some(n) = count.as_u64() {
234                        *count = serde_json::json!(n + 1);
235                    }
236                    let _ = self.storage.insert_graph_node(&parent_node);
237                    let _ = graph.add_node(parent_node);
238                }
239            }
240        }
241    }
242
243    /// Pass 2: Score and prune low-value symbol nodes, transferring ranges to parent files.
244    ///
245    /// Like chunk compaction, scoring weights adjust on cold start: when no memories
246    /// exist yet, the `memory_link_val` weight (normally 0.20) is redistributed to
247    /// call connectivity and code size factors.
248    fn compact_symbols(
249        &self,
250        seen_files: &HashSet<String>,
251        graph: &mut std::sync::MutexGuard<'_, codemem_storage::graph::GraphEngine>,
252        all_nodes: &[GraphNode],
253    ) -> usize {
254        let max_syms_per_file = self.config.chunking.max_retained_symbols_per_file;
255        let sym_score_threshold = self.config.chunking.min_symbol_score_threshold;
256
257        // Cold-start: redistribute memory_link weight when no memories exist
258        let has_memories = self
259            .storage
260            .list_memory_ids()
261            .map(|ids| !ids.is_empty())
262            .unwrap_or(false);
263        let (w_calls, w_vis, w_kind, w_mem, w_size) = if has_memories {
264            (0.30, 0.20, 0.15, 0.20, 0.15)
265        } else {
266            // Redistribute memory weight to calls and code size
267            (0.40, 0.20, 0.15, 0.0, 0.25)
268        };
269
270        let sym_nodes: Vec<&GraphNode> = all_nodes
271            .iter()
272            .filter(|n| n.id.starts_with("sym:"))
273            .collect();
274
275        let mut max_calls_degree: f64 = 1.0;
276        let mut max_code_size: f64 = 1.0;
277
278        for node in &sym_nodes {
279            let calls_degree = graph
280                .get_edges(&node.id)
281                .map(|edges| {
282                    edges
283                        .iter()
284                        .filter(|e| e.relationship == RelationshipType::Calls)
285                        .count() as f64
286                })
287                .unwrap_or(0.0);
288            max_calls_degree = max_calls_degree.max(calls_degree);
289
290            let line_start = node
291                .payload
292                .get("line_start")
293                .and_then(|v| v.as_f64())
294                .unwrap_or(0.0);
295            let line_end = node
296                .payload
297                .get("line_end")
298                .and_then(|v| v.as_f64())
299                .unwrap_or(0.0);
300            let code_size = (line_end - line_start).max(0.0);
301            max_code_size = max_code_size.max(code_size);
302        }
303
304        let mut syms_by_file: HashMap<String, Vec<(String, f64, bool, bool)>> = HashMap::new();
305
306        for node in &sym_nodes {
307            let file_path = match node.payload.get("file_path").and_then(|v| v.as_str()) {
308                Some(fp) => fp.to_string(),
309                None => continue,
310            };
311            if !seen_files.contains(&file_path) {
312                continue;
313            }
314
315            let calls_degree = graph
316                .get_edges(&node.id)
317                .map(|edges| {
318                    edges
319                        .iter()
320                        .filter(|e| e.relationship == RelationshipType::Calls)
321                        .count() as f64
322                })
323                .unwrap_or(0.0);
324            let call_connectivity = calls_degree / max_calls_degree;
325
326            let visibility_score = match node
327                .payload
328                .get("visibility")
329                .and_then(|v| v.as_str())
330                .unwrap_or("private")
331            {
332                "public" => 1.0,
333                "crate" => 0.5,
334                _ => 0.0,
335            };
336
337            let kind_score = match node.kind {
338                NodeKind::Class | NodeKind::Interface => 1.0,
339                NodeKind::Module => 1.0,
340                NodeKind::Function | NodeKind::Method => 0.6,
341                NodeKind::Test => 0.3,
342                NodeKind::Constant => 0.1,
343                _ => 0.5,
344            };
345
346            let mem_linked = has_memories && has_memory_link_edge(&**graph, &node.id);
347            let memory_link_val = if mem_linked { 1.0 } else { 0.0 };
348
349            let line_start = node
350                .payload
351                .get("line_start")
352                .and_then(|v| v.as_f64())
353                .unwrap_or(0.0);
354            let line_end = node
355                .payload
356                .get("line_end")
357                .and_then(|v| v.as_f64())
358                .unwrap_or(0.0);
359            let code_size = (line_end - line_start).max(0.0);
360            let code_size_rank = code_size / max_code_size;
361
362            let symbol_score = call_connectivity * w_calls
363                + visibility_score * w_vis
364                + kind_score * w_kind
365                + memory_link_val * w_mem
366                + code_size_rank * w_size;
367
368            let is_structural = matches!(
369                node.kind,
370                NodeKind::Class | NodeKind::Interface | NodeKind::Module
371            );
372
373            syms_by_file.entry(file_path).or_default().push((
374                node.id.clone(),
375                symbol_score,
376                is_structural,
377                mem_linked,
378            ));
379        }
380
381        let mut symbols_pruned = 0usize;
382        for (_file_path, mut syms) in syms_by_file {
383            syms.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
384
385            let public_count = syms
386                .iter()
387                .filter(|(id, ..)| {
388                    graph
389                        .get_node(id)
390                        .ok()
391                        .flatten()
392                        .and_then(|n| {
393                            n.payload
394                                .get("visibility")
395                                .and_then(|v| v.as_str())
396                                .map(|v| v == "public")
397                        })
398                        .unwrap_or(false)
399                })
400                .count();
401            let k = max_syms_per_file.max(public_count);
402
403            for (i, (sym_id, score, is_structural, mem_linked)) in syms.iter().enumerate() {
404                if *is_structural || *mem_linked {
405                    continue;
406                }
407                if i < k && *score >= sym_score_threshold {
408                    continue;
409                }
410
411                self.transfer_symbol_ranges_to_file(graph, sym_id);
412
413                if let Err(e) = self.storage.delete_graph_edges_for_node(sym_id) {
414                    tracing::warn!("Failed to delete graph edges for symbol {sym_id}: {e}");
415                }
416                if let Err(e) = self.storage.delete_graph_node(sym_id) {
417                    tracing::warn!("Failed to delete graph node for symbol {sym_id}: {e}");
418                }
419                if let Err(e) = graph.remove_node(sym_id) {
420                    tracing::warn!("Failed to remove symbol {sym_id} from graph: {e}");
421                }
422                symbols_pruned += 1;
423            }
424        }
425
426        symbols_pruned
427    }
428
429    /// When pruning a symbol, transfer its line range to the parent file node.
430    fn transfer_symbol_ranges_to_file(
431        &self,
432        graph: &mut std::sync::MutexGuard<'_, codemem_storage::graph::GraphEngine>,
433        sym_id: &str,
434    ) {
435        if let Ok(Some(sym_node)) = graph.get_node(sym_id) {
436            if let Some(fp) = sym_node.payload.get("file_path").and_then(|v| v.as_str()) {
437                let file_id = format!("file:{fp}");
438                if let Ok(Some(mut file_node)) = graph.get_node(&file_id) {
439                    let line_start = sym_node
440                        .payload
441                        .get("line_start")
442                        .and_then(|v| v.as_u64())
443                        .unwrap_or(0);
444                    let line_end = sym_node
445                        .payload
446                        .get("line_end")
447                        .and_then(|v| v.as_u64())
448                        .unwrap_or(0);
449                    if line_end > line_start {
450                        let ranges = file_node
451                            .payload
452                            .entry("pruned_symbol_ranges".to_string())
453                            .or_insert_with(|| serde_json::json!([]));
454                        if let Some(arr) = ranges.as_array_mut() {
455                            arr.push(serde_json::json!([line_start, line_end]));
456                        }
457                        let _ = self.storage.insert_graph_node(&file_node);
458                        let _ = graph.add_node(file_node);
459                    }
460                }
461            }
462        }
463    }
464}