Skip to main content

codemem_engine/persistence/
compaction.rs

1use crate::CodememEngine;
2use codemem_core::{GraphBackend, GraphNode, NodeKind, RelationshipType};
3use std::collections::{HashMap, HashSet};
4
// ── Chunk compaction weights ────────────────────────────────────────────────
// Each weight set below sums to 1.0. The cold-start set applies when storage
// holds no memories: the memory-link factor can never score then, so its share
// is spread over the remaining factors.

/// Chunk weight for normalised edge degree (normal mode).
const CHUNK_W_CENTRALITY: f64 = 0.30;
/// Chunk weight for having a `parent_symbol` (normal mode).
const CHUNK_W_PARENT: f64 = 0.20;
/// Chunk weight for being linked to a memory node (normal mode).
const CHUNK_W_MEMORY: f64 = 0.30;
/// Chunk weight for normalised non-whitespace size (normal mode).
const CHUNK_W_SIZE: f64 = 0.20;

/// Chunk weight for normalised edge degree (cold start, no memories).
const CHUNK_COLD_W_CENTRALITY: f64 = 0.40;
/// Chunk weight for having a `parent_symbol` (cold start, no memories).
const CHUNK_COLD_W_PARENT: f64 = 0.30;
/// Chunk weight for normalised non-whitespace size (cold start, no memories).
const CHUNK_COLD_W_SIZE: f64 = 0.30;

// ── Symbol compaction weights ───────────────────────────────────────────────

/// Symbol weight for normalised `Calls`-edge degree (normal mode).
const SYM_W_CALLS: f64 = 0.30;
/// Symbol weight for visibility (normal mode).
const SYM_W_VISIBILITY: f64 = 0.20;
/// Symbol weight for node kind (normal mode).
const SYM_W_KIND: f64 = 0.15;
/// Symbol weight for being linked to a memory node (normal mode).
const SYM_W_MEMORY: f64 = 0.20;
/// Symbol weight for normalised line span (normal mode).
const SYM_W_SIZE: f64 = 0.15;

/// Symbol weight for normalised `Calls`-edge degree (cold start, no memories).
const SYM_COLD_W_CALLS: f64 = 0.40;
/// Symbol weight for visibility (cold start, no memories).
const SYM_COLD_W_VISIBILITY: f64 = 0.20;
/// Symbol weight for node kind (cold start, no memories).
const SYM_COLD_W_KIND: f64 = 0.15;
/// Symbol weight for normalised line span (cold start, no memories).
const SYM_COLD_W_SIZE: f64 = 0.25;
28
29/// Check whether a graph node has any edge linking it to a memory node
30/// (i.e. an edge whose other endpoint is not a code-structural ID).
31fn has_memory_link_edge(graph: &dyn GraphBackend, node_id: &str) -> bool {
32    graph
33        .get_edges(node_id)
34        .map(|edges| {
35            edges.iter().any(|e| {
36                let other = if e.src == node_id { &e.dst } else { &e.src };
37                !other.starts_with("sym:")
38                    && !other.starts_with("file:")
39                    && !other.starts_with("chunk:")
40                    && !other.starts_with("pkg:")
41                    && !other.starts_with("contains:")
42                    && !other.starts_with("ref:")
43            })
44        })
45        .unwrap_or(false)
46}
47
48impl CodememEngine {
49    // ── Graph Compaction ────────────────────────────────────────────────────
50
51    /// Compact chunk and symbol graph-nodes after indexing.
52    /// Returns (chunks_pruned, symbols_pruned).
53    ///
54    /// # Panics
55    /// Panics if namespace is None. PageRank must be scoped to a namespace
56    /// to prevent cross-project score pollution in shared databases.
57    pub fn compact_graph(
58        &self,
59        seen_files: &HashSet<String>,
60        namespace: Option<&str>,
61    ) -> (usize, usize) {
62        let namespace = namespace.expect(
63            "compact_graph requires explicit namespace to prevent cross-project PageRank pollution",
64        );
65
66        let mut graph = match self.lock_graph() {
67            Ok(g) => g,
68            Err(e) => {
69                tracing::warn!(
70                    error = %e,
71                    namespace = %namespace,
72                    "compact_graph: failed to acquire graph lock (poisoned)"
73                );
74                return (0, 0);
75            }
76        };
77
78        // Fetch all nodes once and share between both compaction passes.
79        let all_nodes = graph.get_all_nodes();
80        let chunks_pruned = self.compact_chunks(seen_files, &mut **graph, &all_nodes);
81        let symbols_pruned = self.compact_symbols(seen_files, &mut **graph, &all_nodes);
82
83        if chunks_pruned > 0 || symbols_pruned > 0 {
84            // compute_centrality: updates node.centrality with degree centrality.
85            // recompute_centrality: caches PageRank for hybrid scoring, scoped to
86            // the current namespace so cross-project scores don't bleed in.
87            // Both are needed — they populate different data used by different scoring paths.
88            graph.compute_centrality();
89            graph.recompute_centrality_for_namespace(namespace);
90        }
91
92        (chunks_pruned, symbols_pruned)
93    }
94
95    /// Pass 1: Score and prune low-value chunks, transferring line ranges to parent symbols.
96    ///
97    /// Scoring weights adjust on cold start: when no memories exist yet, the
98    /// `memory_link_score` weight (normally 0.3) is redistributed to the other
99    /// factors so compaction still produces meaningful rankings.
100    fn compact_chunks(
101        &self,
102        seen_files: &HashSet<String>,
103        graph: &mut dyn codemem_core::GraphBackend,
104        all_nodes: &[GraphNode],
105    ) -> usize {
106        let max_chunks_per_file = self.config.chunking.max_retained_chunks_per_file;
107        let chunk_score_threshold = self.config.chunking.min_chunk_score_threshold;
108
109        // Cold-start detection: if no memories exist, memory_link_score is always 0
110        // and its weight should be redistributed to other factors.
111        let has_memories = self
112            .storage
113            .list_memory_ids()
114            .map(|ids| !ids.is_empty())
115            .unwrap_or(false);
116        let (w_centrality, w_parent, w_memory, w_size) = if has_memories {
117            (
118                CHUNK_W_CENTRALITY,
119                CHUNK_W_PARENT,
120                CHUNK_W_MEMORY,
121                CHUNK_W_SIZE,
122            )
123        } else {
124            // Redistribute memory_link weight when no memories exist
125            (
126                CHUNK_COLD_W_CENTRALITY,
127                CHUNK_COLD_W_PARENT,
128                0.0,
129                CHUNK_COLD_W_SIZE,
130            )
131        };
132
133        let mut chunks_by_file: HashMap<String, Vec<(String, f64)>> = HashMap::new();
134
135        let mut max_degree: f64 = 1.0;
136        let mut max_non_ws: f64 = 1.0;
137
138        let chunk_nodes: Vec<&GraphNode> = all_nodes
139            .iter()
140            .filter(|n| n.kind == NodeKind::Chunk)
141            .collect();
142
143        for node in &chunk_nodes {
144            let degree = graph
145                .get_edges(&node.id)
146                .map(|edges| edges.len() as f64)
147                .unwrap_or(0.0);
148            max_degree = max_degree.max(degree);
149
150            let non_ws = node
151                .payload
152                .get("non_ws_chars")
153                .and_then(|v| v.as_f64())
154                .unwrap_or(0.0);
155            max_non_ws = max_non_ws.max(non_ws);
156        }
157
158        for node in &chunk_nodes {
159            let file_path = match node.payload.get("file_path").and_then(|v| v.as_str()) {
160                Some(fp) => fp.to_string(),
161                None => continue,
162            };
163            if !seen_files.contains(&file_path) {
164                continue;
165            }
166
167            let degree = graph
168                .get_edges(&node.id)
169                .map(|edges| edges.len() as f64)
170                .unwrap_or(0.0);
171            let centrality_rank = degree / max_degree;
172
173            let has_symbol_parent = if node.payload.contains_key("parent_symbol") {
174                1.0
175            } else {
176                0.0
177            };
178
179            let memory_link_score = if has_memories && has_memory_link_edge(graph, &node.id) {
180                1.0
181            } else {
182                0.0
183            };
184
185            let non_ws = node
186                .payload
187                .get("non_ws_chars")
188                .and_then(|v| v.as_f64())
189                .unwrap_or(0.0);
190            let non_ws_rank = non_ws / max_non_ws;
191
192            let chunk_score = centrality_rank * w_centrality
193                + has_symbol_parent * w_parent
194                + memory_link_score * w_memory
195                + non_ws_rank * w_size;
196
197            chunks_by_file
198                .entry(file_path)
199                .or_default()
200                .push((node.id.clone(), chunk_score));
201        }
202
203        let mut symbol_count_by_file: HashMap<String, usize> = HashMap::new();
204        for node in all_nodes {
205            if matches!(
206                node.kind,
207                NodeKind::Function
208                    | NodeKind::Method
209                    | NodeKind::Class
210                    | NodeKind::Interface
211                    | NodeKind::Type
212                    | NodeKind::Constant
213                    | NodeKind::Test
214            ) {
215                if let Some(fp) = node.payload.get("file_path").and_then(|v| v.as_str()) {
216                    *symbol_count_by_file.entry(fp.to_string()).or_default() += 1;
217                }
218            }
219        }
220
221        let mut chunks_pruned = 0usize;
222        for (file_path, mut chunks) in chunks_by_file {
223            chunks.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
224
225            let sym_count = symbol_count_by_file.get(&file_path).copied().unwrap_or(0);
226            let k = max_chunks_per_file.min(chunks.len()).max(3.max(sym_count));
227
228            for (i, (chunk_id, score)) in chunks.iter().enumerate() {
229                // Prune aggressively: remove if beyond the top-k slots OR below the
230                // quality threshold. Using || ensures both caps are enforced independently
231                // (keep at most k chunks, and never keep any chunk below threshold).
232                if i >= k || *score < chunk_score_threshold {
233                    self.transfer_chunk_ranges_to_parent(graph, chunk_id);
234
235                    if let Err(e) = self.storage.delete_graph_edges_for_node(chunk_id) {
236                        tracing::warn!("Failed to delete graph edges for chunk {chunk_id}: {e}");
237                    }
238                    if let Err(e) = self.storage.delete_graph_node(chunk_id) {
239                        tracing::warn!("Failed to delete graph node for chunk {chunk_id}: {e}");
240                    }
241                    if let Err(e) = graph.remove_node(chunk_id) {
242                        tracing::warn!("Failed to remove chunk {chunk_id} from graph: {e}");
243                    }
244                    chunks_pruned += 1;
245                }
246            }
247        }
248
249        chunks_pruned
250    }
251
252    /// When pruning a chunk, transfer its line range to the parent symbol node.
253    fn transfer_chunk_ranges_to_parent(
254        &self,
255        graph: &mut dyn codemem_core::GraphBackend,
256        chunk_id: &str,
257    ) {
258        if let Ok(Some(chunk_node)) = graph.get_node(chunk_id) {
259            if let Some(parent_sym) = chunk_node
260                .payload
261                .get("parent_symbol")
262                .and_then(|v| v.as_str())
263            {
264                let parent_id = format!("sym:{parent_sym}");
265                if let Ok(Some(mut parent_node)) = graph.get_node(&parent_id) {
266                    let line_start = chunk_node
267                        .payload
268                        .get("line_start")
269                        .and_then(|v| v.as_u64())
270                        .unwrap_or(0);
271                    let line_end = chunk_node
272                        .payload
273                        .get("line_end")
274                        .and_then(|v| v.as_u64())
275                        .unwrap_or(0);
276                    let ranges = parent_node
277                        .payload
278                        .entry("covered_ranges".to_string())
279                        .or_insert_with(|| serde_json::json!([]));
280                    if let Some(arr) = ranges.as_array_mut() {
281                        arr.push(serde_json::json!([line_start, line_end]));
282                    }
283                    let count = parent_node
284                        .payload
285                        .entry("pruned_chunk_count".to_string())
286                        .or_insert_with(|| serde_json::json!(0));
287                    if let Some(n) = count.as_u64() {
288                        *count = serde_json::json!(n + 1);
289                    }
290                    let _ = self.storage.insert_graph_node(&parent_node);
291                    let _ = graph.add_node(parent_node);
292                }
293            }
294        }
295    }
296
297    /// Pass 2: Score and prune low-value symbol nodes, transferring ranges to parent files.
298    ///
299    /// Like chunk compaction, scoring weights adjust on cold start: when no memories
300    /// exist yet, the `memory_link_val` weight (normally 0.20) is redistributed to
301    /// call connectivity and code size factors.
302    fn compact_symbols(
303        &self,
304        seen_files: &HashSet<String>,
305        graph: &mut dyn codemem_core::GraphBackend,
306        all_nodes: &[GraphNode],
307    ) -> usize {
308        let max_syms_per_file = self.config.chunking.max_retained_symbols_per_file;
309        let sym_score_threshold = self.config.chunking.min_symbol_score_threshold;
310
311        // Cold-start: redistribute memory_link weight when no memories exist
312        let has_memories = self
313            .storage
314            .list_memory_ids()
315            .map(|ids| !ids.is_empty())
316            .unwrap_or(false);
317        let (w_calls, w_vis, w_kind, w_mem, w_size) = if has_memories {
318            (
319                SYM_W_CALLS,
320                SYM_W_VISIBILITY,
321                SYM_W_KIND,
322                SYM_W_MEMORY,
323                SYM_W_SIZE,
324            )
325        } else {
326            // Redistribute memory weight to calls and code size
327            (
328                SYM_COLD_W_CALLS,
329                SYM_COLD_W_VISIBILITY,
330                SYM_COLD_W_KIND,
331                0.0,
332                SYM_COLD_W_SIZE,
333            )
334        };
335
336        // Only compact ast-grep symbols. SCIP-sourced symbols (both explicit and
337        // synthetic containment nodes) should not be pruned by heuristic scoring.
338        let sym_nodes: Vec<&GraphNode> = all_nodes
339            .iter()
340            .filter(|n| {
341                n.id.starts_with("sym:")
342                    && !matches!(
343                        n.payload.get("source").and_then(|v| v.as_str()),
344                        Some("scip" | "scip-synthetic")
345                    )
346            })
347            .collect();
348
349        let mut max_calls_degree: f64 = 1.0;
350        let mut max_code_size: f64 = 1.0;
351
352        for node in &sym_nodes {
353            let calls_degree = graph
354                .get_edges(&node.id)
355                .map(|edges| {
356                    edges
357                        .iter()
358                        .filter(|e| e.relationship == RelationshipType::Calls)
359                        .count() as f64
360                })
361                .unwrap_or(0.0);
362            max_calls_degree = max_calls_degree.max(calls_degree);
363
364            let line_start = node
365                .payload
366                .get("line_start")
367                .and_then(|v| v.as_f64())
368                .unwrap_or(0.0);
369            let line_end = node
370                .payload
371                .get("line_end")
372                .and_then(|v| v.as_f64())
373                .unwrap_or(0.0);
374            let code_size = (line_end - line_start).max(0.0);
375            max_code_size = max_code_size.max(code_size);
376        }
377
378        let mut syms_by_file: HashMap<String, Vec<(String, f64, bool, bool)>> = HashMap::new();
379
380        for node in &sym_nodes {
381            let file_path = match node.payload.get("file_path").and_then(|v| v.as_str()) {
382                Some(fp) => fp.to_string(),
383                None => continue,
384            };
385            if !seen_files.contains(&file_path) {
386                continue;
387            }
388
389            let calls_degree = graph
390                .get_edges(&node.id)
391                .map(|edges| {
392                    edges
393                        .iter()
394                        .filter(|e| e.relationship == RelationshipType::Calls)
395                        .count() as f64
396                })
397                .unwrap_or(0.0);
398            let call_connectivity = calls_degree / max_calls_degree;
399
400            let visibility_score = match node
401                .payload
402                .get("visibility")
403                .and_then(|v| v.as_str())
404                .unwrap_or("private")
405            {
406                "public" => 1.0,
407                "crate" => 0.5,
408                _ => 0.0,
409            };
410
411            let kind_score = match node.kind {
412                NodeKind::Class | NodeKind::Interface => 1.0,
413                NodeKind::Module => 1.0,
414                NodeKind::Function | NodeKind::Method => 0.6,
415                NodeKind::Test => 0.3,
416                NodeKind::Constant => 0.1,
417                _ => 0.5,
418            };
419
420            let mem_linked = has_memories && has_memory_link_edge(graph, &node.id);
421            let memory_link_val = if mem_linked { 1.0 } else { 0.0 };
422
423            let line_start = node
424                .payload
425                .get("line_start")
426                .and_then(|v| v.as_f64())
427                .unwrap_or(0.0);
428            let line_end = node
429                .payload
430                .get("line_end")
431                .and_then(|v| v.as_f64())
432                .unwrap_or(0.0);
433            let code_size = (line_end - line_start).max(0.0);
434            let code_size_rank = code_size / max_code_size;
435
436            let symbol_score = call_connectivity * w_calls
437                + visibility_score * w_vis
438                + kind_score * w_kind
439                + memory_link_val * w_mem
440                + code_size_rank * w_size;
441
442            let is_structural = matches!(
443                node.kind,
444                NodeKind::Class | NodeKind::Interface | NodeKind::Module
445            );
446
447            syms_by_file.entry(file_path).or_default().push((
448                node.id.clone(),
449                symbol_score,
450                is_structural,
451                mem_linked,
452            ));
453        }
454
455        let mut symbols_pruned = 0usize;
456        for (_file_path, mut syms) in syms_by_file {
457            syms.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
458
459            let public_count = syms
460                .iter()
461                .filter(|(id, ..)| {
462                    graph
463                        .get_node(id)
464                        .ok()
465                        .flatten()
466                        .and_then(|n| {
467                            n.payload
468                                .get("visibility")
469                                .and_then(|v| v.as_str())
470                                .map(|v| v == "public")
471                        })
472                        .unwrap_or(false)
473                })
474                .count();
475            let k = max_syms_per_file.max(public_count);
476
477            for (i, (sym_id, score, is_structural, mem_linked)) in syms.iter().enumerate() {
478                if *is_structural || *mem_linked {
479                    continue;
480                }
481                if i < k && *score >= sym_score_threshold {
482                    continue;
483                }
484
485                self.transfer_symbol_ranges_to_file(graph, sym_id);
486
487                if let Err(e) = self.storage.delete_graph_edges_for_node(sym_id) {
488                    tracing::warn!("Failed to delete graph edges for symbol {sym_id}: {e}");
489                }
490                if let Err(e) = self.storage.delete_graph_node(sym_id) {
491                    tracing::warn!("Failed to delete graph node for symbol {sym_id}: {e}");
492                }
493                if let Err(e) = graph.remove_node(sym_id) {
494                    tracing::warn!("Failed to remove symbol {sym_id} from graph: {e}");
495                }
496                symbols_pruned += 1;
497            }
498        }
499
500        symbols_pruned
501    }
502
503    /// When pruning a symbol, transfer its line range to the parent file node.
504    fn transfer_symbol_ranges_to_file(
505        &self,
506        graph: &mut dyn codemem_core::GraphBackend,
507        sym_id: &str,
508    ) {
509        if let Ok(Some(sym_node)) = graph.get_node(sym_id) {
510            if let Some(fp) = sym_node.payload.get("file_path").and_then(|v| v.as_str()) {
511                let file_id = format!("file:{fp}");
512                if let Ok(Some(mut file_node)) = graph.get_node(&file_id) {
513                    let line_start = sym_node
514                        .payload
515                        .get("line_start")
516                        .and_then(|v| v.as_u64())
517                        .unwrap_or(0);
518                    let line_end = sym_node
519                        .payload
520                        .get("line_end")
521                        .and_then(|v| v.as_u64())
522                        .unwrap_or(0);
523                    if line_end > line_start {
524                        let ranges = file_node
525                            .payload
526                            .entry("pruned_symbol_ranges".to_string())
527                            .or_insert_with(|| serde_json::json!([]));
528                        if let Some(arr) = ranges.as_array_mut() {
529                            arr.push(serde_json::json!([line_start, line_end]));
530                        }
531                        let _ = self.storage.insert_graph_node(&file_node);
532                        let _ = graph.add_node(file_node);
533                    }
534                }
535            }
536        }
537    }
538}