// codemem_engine/persistence/compaction.rs

1use crate::CodememEngine;
2use codemem_core::{GraphBackend, GraphNode, NodeKind, RelationshipType};
3use std::collections::{HashMap, HashSet};
4
// ── Chunk compaction weights ────────────────────────────────────────────────
// Two weight sets are defined; each one sums to 1.0.

// Normal set, used when at least one memory exists.
/// Weight of a chunk's edge-degree rank.
const CHUNK_W_CENTRALITY: f64 = 0.30;
/// Weight of the chunk having a parent symbol.
const CHUNK_W_PARENT: f64 = 0.20;
/// Weight of the chunk being linked to a memory.
const CHUNK_W_MEMORY: f64 = 0.30;
/// Weight of the chunk's non-whitespace size rank.
const CHUNK_W_SIZE: f64 = 0.20;

// Cold-start set, used when no memories exist: the memory weight is
// redistributed across the remaining three factors.
/// Cold-start weight of a chunk's edge-degree rank.
const CHUNK_COLD_W_CENTRALITY: f64 = 0.40;
/// Cold-start weight of the chunk having a parent symbol.
const CHUNK_COLD_W_PARENT: f64 = 0.30;
/// Cold-start weight of the chunk's non-whitespace size rank.
const CHUNK_COLD_W_SIZE: f64 = 0.30;
15
// ── Symbol compaction weights ───────────────────────────────────────────────
// Two weight sets are defined; each one sums to 1.0.

// Normal set, used when at least one memory exists.
/// Weight of a symbol's `Calls`-edge connectivity rank.
const SYM_W_CALLS: f64 = 0.30;
/// Weight of the symbol's visibility score.
const SYM_W_VISIBILITY: f64 = 0.20;
/// Weight of the symbol's node-kind score.
const SYM_W_KIND: f64 = 0.15;
/// Weight of the symbol being linked to a memory.
const SYM_W_MEMORY: f64 = 0.20;
/// Weight of the symbol's line-span size rank.
const SYM_W_SIZE: f64 = 0.15;

// Cold-start set, used when no memories exist: the memory weight is
// redistributed to call connectivity and code size.
/// Cold-start weight of a symbol's `Calls`-edge connectivity rank.
const SYM_COLD_W_CALLS: f64 = 0.40;
/// Cold-start weight of the symbol's visibility score.
const SYM_COLD_W_VISIBILITY: f64 = 0.20;
/// Cold-start weight of the symbol's node-kind score.
const SYM_COLD_W_KIND: f64 = 0.15;
/// Cold-start weight of the symbol's line-span size rank.
const SYM_COLD_W_SIZE: f64 = 0.25;
28
29/// Check whether a graph node has any edge linking it to a memory node
30/// (i.e. an edge whose other endpoint is not a code-structural ID).
31fn has_memory_link_edge(graph: &dyn GraphBackend, node_id: &str) -> bool {
32    graph
33        .get_edges(node_id)
34        .map(|edges| {
35            edges.iter().any(|e| {
36                let other = if e.src == node_id { &e.dst } else { &e.src };
37                !other.starts_with("sym:")
38                    && !other.starts_with("file:")
39                    && !other.starts_with("chunk:")
40                    && !other.starts_with("pkg:")
41                    && !other.starts_with("contains:")
42                    && !other.starts_with("ref:")
43            })
44        })
45        .unwrap_or(false)
46}
47
48impl CodememEngine {
49    // ── Graph Compaction ────────────────────────────────────────────────────
50
51    /// Compact chunk and symbol graph-nodes after indexing.
52    /// Returns (chunks_pruned, symbols_pruned).
53    pub fn compact_graph(&self, seen_files: &HashSet<String>) -> (usize, usize) {
54        let mut graph = match self.lock_graph() {
55            Ok(g) => g,
56            Err(_) => return (0, 0),
57        };
58
59        // Fetch all nodes once and share between both compaction passes.
60        let all_nodes = graph.get_all_nodes();
61        let chunks_pruned = self.compact_chunks(seen_files, &mut graph, &all_nodes);
62        let symbols_pruned = self.compact_symbols(seen_files, &mut graph, &all_nodes);
63
64        if chunks_pruned > 0 || symbols_pruned > 0 {
65            // compute_centrality: updates node.centrality with degree centrality.
66            // recompute_centrality: caches PageRank + betweenness for hybrid scoring.
67            // Both are needed — they populate different data used by different scoring paths.
68            graph.compute_centrality();
69            graph.recompute_centrality();
70        }
71
72        (chunks_pruned, symbols_pruned)
73    }
74
75    /// Pass 1: Score and prune low-value chunks, transferring line ranges to parent symbols.
76    ///
77    /// Scoring weights adjust on cold start: when no memories exist yet, the
78    /// `memory_link_score` weight (normally 0.3) is redistributed to the other
79    /// factors so compaction still produces meaningful rankings.
80    fn compact_chunks(
81        &self,
82        seen_files: &HashSet<String>,
83        graph: &mut std::sync::MutexGuard<'_, codemem_storage::graph::GraphEngine>,
84        all_nodes: &[GraphNode],
85    ) -> usize {
86        let max_chunks_per_file = self.config.chunking.max_retained_chunks_per_file;
87        let chunk_score_threshold = self.config.chunking.min_chunk_score_threshold;
88
89        // Cold-start detection: if no memories exist, memory_link_score is always 0
90        // and its weight should be redistributed to other factors.
91        let has_memories = self
92            .storage
93            .list_memory_ids()
94            .map(|ids| !ids.is_empty())
95            .unwrap_or(false);
96        let (w_centrality, w_parent, w_memory, w_size) = if has_memories {
97            (
98                CHUNK_W_CENTRALITY,
99                CHUNK_W_PARENT,
100                CHUNK_W_MEMORY,
101                CHUNK_W_SIZE,
102            )
103        } else {
104            // Redistribute memory_link weight when no memories exist
105            (
106                CHUNK_COLD_W_CENTRALITY,
107                CHUNK_COLD_W_PARENT,
108                0.0,
109                CHUNK_COLD_W_SIZE,
110            )
111        };
112
113        let mut chunks_by_file: HashMap<String, Vec<(String, f64)>> = HashMap::new();
114
115        let mut max_degree: f64 = 1.0;
116        let mut max_non_ws: f64 = 1.0;
117
118        let chunk_nodes: Vec<&GraphNode> = all_nodes
119            .iter()
120            .filter(|n| n.kind == NodeKind::Chunk)
121            .collect();
122
123        for node in &chunk_nodes {
124            let degree = graph
125                .get_edges(&node.id)
126                .map(|edges| edges.len() as f64)
127                .unwrap_or(0.0);
128            max_degree = max_degree.max(degree);
129
130            let non_ws = node
131                .payload
132                .get("non_ws_chars")
133                .and_then(|v| v.as_f64())
134                .unwrap_or(0.0);
135            max_non_ws = max_non_ws.max(non_ws);
136        }
137
138        for node in &chunk_nodes {
139            let file_path = match node.payload.get("file_path").and_then(|v| v.as_str()) {
140                Some(fp) => fp.to_string(),
141                None => continue,
142            };
143            if !seen_files.contains(&file_path) {
144                continue;
145            }
146
147            let degree = graph
148                .get_edges(&node.id)
149                .map(|edges| edges.len() as f64)
150                .unwrap_or(0.0);
151            let centrality_rank = degree / max_degree;
152
153            let has_symbol_parent = if node.payload.contains_key("parent_symbol") {
154                1.0
155            } else {
156                0.0
157            };
158
159            let memory_link_score = if has_memories && has_memory_link_edge(&**graph, &node.id) {
160                1.0
161            } else {
162                0.0
163            };
164
165            let non_ws = node
166                .payload
167                .get("non_ws_chars")
168                .and_then(|v| v.as_f64())
169                .unwrap_or(0.0);
170            let non_ws_rank = non_ws / max_non_ws;
171
172            let chunk_score = centrality_rank * w_centrality
173                + has_symbol_parent * w_parent
174                + memory_link_score * w_memory
175                + non_ws_rank * w_size;
176
177            chunks_by_file
178                .entry(file_path)
179                .or_default()
180                .push((node.id.clone(), chunk_score));
181        }
182
183        let mut symbol_count_by_file: HashMap<String, usize> = HashMap::new();
184        for node in all_nodes {
185            if matches!(
186                node.kind,
187                NodeKind::Function
188                    | NodeKind::Method
189                    | NodeKind::Class
190                    | NodeKind::Interface
191                    | NodeKind::Type
192                    | NodeKind::Constant
193                    | NodeKind::Test
194            ) {
195                if let Some(fp) = node.payload.get("file_path").and_then(|v| v.as_str()) {
196                    *symbol_count_by_file.entry(fp.to_string()).or_default() += 1;
197                }
198            }
199        }
200
201        let mut chunks_pruned = 0usize;
202        for (file_path, mut chunks) in chunks_by_file {
203            chunks.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
204
205            let sym_count = symbol_count_by_file.get(&file_path).copied().unwrap_or(0);
206            let k = max_chunks_per_file.min(chunks.len()).max(3.max(sym_count));
207
208            for (i, (chunk_id, score)) in chunks.iter().enumerate() {
209                // Prune aggressively: remove if beyond the top-k slots OR below the
210                // quality threshold. Using || ensures both caps are enforced independently
211                // (keep at most k chunks, and never keep any chunk below threshold).
212                if i >= k || *score < chunk_score_threshold {
213                    self.transfer_chunk_ranges_to_parent(graph, chunk_id);
214
215                    if let Err(e) = self.storage.delete_graph_edges_for_node(chunk_id) {
216                        tracing::warn!("Failed to delete graph edges for chunk {chunk_id}: {e}");
217                    }
218                    if let Err(e) = self.storage.delete_graph_node(chunk_id) {
219                        tracing::warn!("Failed to delete graph node for chunk {chunk_id}: {e}");
220                    }
221                    if let Err(e) = graph.remove_node(chunk_id) {
222                        tracing::warn!("Failed to remove chunk {chunk_id} from graph: {e}");
223                    }
224                    chunks_pruned += 1;
225                }
226            }
227        }
228
229        chunks_pruned
230    }
231
232    /// When pruning a chunk, transfer its line range to the parent symbol node.
233    fn transfer_chunk_ranges_to_parent(
234        &self,
235        graph: &mut std::sync::MutexGuard<'_, codemem_storage::graph::GraphEngine>,
236        chunk_id: &str,
237    ) {
238        if let Ok(Some(chunk_node)) = graph.get_node(chunk_id) {
239            if let Some(parent_sym) = chunk_node
240                .payload
241                .get("parent_symbol")
242                .and_then(|v| v.as_str())
243            {
244                let parent_id = format!("sym:{parent_sym}");
245                if let Ok(Some(mut parent_node)) = graph.get_node(&parent_id) {
246                    let line_start = chunk_node
247                        .payload
248                        .get("line_start")
249                        .and_then(|v| v.as_u64())
250                        .unwrap_or(0);
251                    let line_end = chunk_node
252                        .payload
253                        .get("line_end")
254                        .and_then(|v| v.as_u64())
255                        .unwrap_or(0);
256                    let ranges = parent_node
257                        .payload
258                        .entry("covered_ranges".to_string())
259                        .or_insert_with(|| serde_json::json!([]));
260                    if let Some(arr) = ranges.as_array_mut() {
261                        arr.push(serde_json::json!([line_start, line_end]));
262                    }
263                    let count = parent_node
264                        .payload
265                        .entry("pruned_chunk_count".to_string())
266                        .or_insert_with(|| serde_json::json!(0));
267                    if let Some(n) = count.as_u64() {
268                        *count = serde_json::json!(n + 1);
269                    }
270                    let _ = self.storage.insert_graph_node(&parent_node);
271                    let _ = graph.add_node(parent_node);
272                }
273            }
274        }
275    }
276
277    /// Pass 2: Score and prune low-value symbol nodes, transferring ranges to parent files.
278    ///
279    /// Like chunk compaction, scoring weights adjust on cold start: when no memories
280    /// exist yet, the `memory_link_val` weight (normally 0.20) is redistributed to
281    /// call connectivity and code size factors.
282    fn compact_symbols(
283        &self,
284        seen_files: &HashSet<String>,
285        graph: &mut std::sync::MutexGuard<'_, codemem_storage::graph::GraphEngine>,
286        all_nodes: &[GraphNode],
287    ) -> usize {
288        let max_syms_per_file = self.config.chunking.max_retained_symbols_per_file;
289        let sym_score_threshold = self.config.chunking.min_symbol_score_threshold;
290
291        // Cold-start: redistribute memory_link weight when no memories exist
292        let has_memories = self
293            .storage
294            .list_memory_ids()
295            .map(|ids| !ids.is_empty())
296            .unwrap_or(false);
297        let (w_calls, w_vis, w_kind, w_mem, w_size) = if has_memories {
298            (
299                SYM_W_CALLS,
300                SYM_W_VISIBILITY,
301                SYM_W_KIND,
302                SYM_W_MEMORY,
303                SYM_W_SIZE,
304            )
305        } else {
306            // Redistribute memory weight to calls and code size
307            (
308                SYM_COLD_W_CALLS,
309                SYM_COLD_W_VISIBILITY,
310                SYM_COLD_W_KIND,
311                0.0,
312                SYM_COLD_W_SIZE,
313            )
314        };
315
316        // Only compact ast-grep symbols. SCIP-sourced symbols (both explicit and
317        // synthetic containment nodes) should not be pruned by heuristic scoring.
318        let sym_nodes: Vec<&GraphNode> = all_nodes
319            .iter()
320            .filter(|n| {
321                n.id.starts_with("sym:")
322                    && !matches!(
323                        n.payload.get("source").and_then(|v| v.as_str()),
324                        Some("scip" | "scip-synthetic")
325                    )
326            })
327            .collect();
328
329        let mut max_calls_degree: f64 = 1.0;
330        let mut max_code_size: f64 = 1.0;
331
332        for node in &sym_nodes {
333            let calls_degree = graph
334                .get_edges(&node.id)
335                .map(|edges| {
336                    edges
337                        .iter()
338                        .filter(|e| e.relationship == RelationshipType::Calls)
339                        .count() as f64
340                })
341                .unwrap_or(0.0);
342            max_calls_degree = max_calls_degree.max(calls_degree);
343
344            let line_start = node
345                .payload
346                .get("line_start")
347                .and_then(|v| v.as_f64())
348                .unwrap_or(0.0);
349            let line_end = node
350                .payload
351                .get("line_end")
352                .and_then(|v| v.as_f64())
353                .unwrap_or(0.0);
354            let code_size = (line_end - line_start).max(0.0);
355            max_code_size = max_code_size.max(code_size);
356        }
357
358        let mut syms_by_file: HashMap<String, Vec<(String, f64, bool, bool)>> = HashMap::new();
359
360        for node in &sym_nodes {
361            let file_path = match node.payload.get("file_path").and_then(|v| v.as_str()) {
362                Some(fp) => fp.to_string(),
363                None => continue,
364            };
365            if !seen_files.contains(&file_path) {
366                continue;
367            }
368
369            let calls_degree = graph
370                .get_edges(&node.id)
371                .map(|edges| {
372                    edges
373                        .iter()
374                        .filter(|e| e.relationship == RelationshipType::Calls)
375                        .count() as f64
376                })
377                .unwrap_or(0.0);
378            let call_connectivity = calls_degree / max_calls_degree;
379
380            let visibility_score = match node
381                .payload
382                .get("visibility")
383                .and_then(|v| v.as_str())
384                .unwrap_or("private")
385            {
386                "public" => 1.0,
387                "crate" => 0.5,
388                _ => 0.0,
389            };
390
391            let kind_score = match node.kind {
392                NodeKind::Class | NodeKind::Interface => 1.0,
393                NodeKind::Module => 1.0,
394                NodeKind::Function | NodeKind::Method => 0.6,
395                NodeKind::Test => 0.3,
396                NodeKind::Constant => 0.1,
397                _ => 0.5,
398            };
399
400            let mem_linked = has_memories && has_memory_link_edge(&**graph, &node.id);
401            let memory_link_val = if mem_linked { 1.0 } else { 0.0 };
402
403            let line_start = node
404                .payload
405                .get("line_start")
406                .and_then(|v| v.as_f64())
407                .unwrap_or(0.0);
408            let line_end = node
409                .payload
410                .get("line_end")
411                .and_then(|v| v.as_f64())
412                .unwrap_or(0.0);
413            let code_size = (line_end - line_start).max(0.0);
414            let code_size_rank = code_size / max_code_size;
415
416            let symbol_score = call_connectivity * w_calls
417                + visibility_score * w_vis
418                + kind_score * w_kind
419                + memory_link_val * w_mem
420                + code_size_rank * w_size;
421
422            let is_structural = matches!(
423                node.kind,
424                NodeKind::Class | NodeKind::Interface | NodeKind::Module
425            );
426
427            syms_by_file.entry(file_path).or_default().push((
428                node.id.clone(),
429                symbol_score,
430                is_structural,
431                mem_linked,
432            ));
433        }
434
435        let mut symbols_pruned = 0usize;
436        for (_file_path, mut syms) in syms_by_file {
437            syms.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
438
439            let public_count = syms
440                .iter()
441                .filter(|(id, ..)| {
442                    graph
443                        .get_node(id)
444                        .ok()
445                        .flatten()
446                        .and_then(|n| {
447                            n.payload
448                                .get("visibility")
449                                .and_then(|v| v.as_str())
450                                .map(|v| v == "public")
451                        })
452                        .unwrap_or(false)
453                })
454                .count();
455            let k = max_syms_per_file.max(public_count);
456
457            for (i, (sym_id, score, is_structural, mem_linked)) in syms.iter().enumerate() {
458                if *is_structural || *mem_linked {
459                    continue;
460                }
461                if i < k && *score >= sym_score_threshold {
462                    continue;
463                }
464
465                self.transfer_symbol_ranges_to_file(graph, sym_id);
466
467                if let Err(e) = self.storage.delete_graph_edges_for_node(sym_id) {
468                    tracing::warn!("Failed to delete graph edges for symbol {sym_id}: {e}");
469                }
470                if let Err(e) = self.storage.delete_graph_node(sym_id) {
471                    tracing::warn!("Failed to delete graph node for symbol {sym_id}: {e}");
472                }
473                if let Err(e) = graph.remove_node(sym_id) {
474                    tracing::warn!("Failed to remove symbol {sym_id} from graph: {e}");
475                }
476                symbols_pruned += 1;
477            }
478        }
479
480        symbols_pruned
481    }
482
483    /// When pruning a symbol, transfer its line range to the parent file node.
484    fn transfer_symbol_ranges_to_file(
485        &self,
486        graph: &mut std::sync::MutexGuard<'_, codemem_storage::graph::GraphEngine>,
487        sym_id: &str,
488    ) {
489        if let Ok(Some(sym_node)) = graph.get_node(sym_id) {
490            if let Some(fp) = sym_node.payload.get("file_path").and_then(|v| v.as_str()) {
491                let file_id = format!("file:{fp}");
492                if let Ok(Some(mut file_node)) = graph.get_node(&file_id) {
493                    let line_start = sym_node
494                        .payload
495                        .get("line_start")
496                        .and_then(|v| v.as_u64())
497                        .unwrap_or(0);
498                    let line_end = sym_node
499                        .payload
500                        .get("line_end")
501                        .and_then(|v| v.as_u64())
502                        .unwrap_or(0);
503                    if line_end > line_start {
504                        let ranges = file_node
505                            .payload
506                            .entry("pruned_symbol_ranges".to_string())
507                            .or_insert_with(|| serde_json::json!([]));
508                        if let Some(arr) = ranges.as_array_mut() {
509                            arr.push(serde_json::json!([line_start, line_end]));
510                        }
511                        let _ = self.storage.insert_graph_node(&file_node);
512                        let _ = graph.add_node(file_node);
513                    }
514                }
515            }
516        }
517    }
518}