Skip to main content

codemem_engine/persistence/
compaction.rs

1use crate::CodememEngine;
2use codemem_core::{GraphBackend, GraphNode, NodeKind, RelationshipType};
3use std::collections::{HashMap, HashSet};
4
// ── Chunk compaction weights ────────────────────────────────────────────────
// Each weight multiplies one factor of the chunk score in `compact_chunks`:
// CENTRALITY scales the normalized edge count, PARENT the "has a parent_symbol"
// flag, MEMORY the "linked to a memory" flag, and SIZE the normalized
// non-whitespace character count.
// Normal weights (when memories exist); the four values sum to 1.0:
const CHUNK_W_CENTRALITY: f64 = 0.3;
const CHUNK_W_PARENT: f64 = 0.2;
const CHUNK_W_MEMORY: f64 = 0.3;
const CHUNK_W_SIZE: f64 = 0.2;
// Cold-start weights (no memories → redistribute memory weight); there is no
// memory weight here (the code uses 0.0) and the three values sum to 1.0:
const CHUNK_COLD_W_CENTRALITY: f64 = 0.4;
const CHUNK_COLD_W_PARENT: f64 = 0.3;
const CHUNK_COLD_W_SIZE: f64 = 0.3;

// ── Symbol compaction weights ───────────────────────────────────────────────
// Each weight multiplies one factor of the symbol score in `compact_symbols`:
// CALLS scales the normalized count of `Calls` edges, VISIBILITY the visibility
// score, KIND the per-node-kind score, MEMORY the "linked to a memory" flag, and
// SIZE the normalized line span.
// Normal weights; the five values sum to 1.0:
const SYM_W_CALLS: f64 = 0.30;
const SYM_W_VISIBILITY: f64 = 0.20;
const SYM_W_KIND: f64 = 0.15;
const SYM_W_MEMORY: f64 = 0.20;
const SYM_W_SIZE: f64 = 0.15;
// Cold-start weights; the memory weight (0.20) is moved to calls and size, and
// the four values sum to 1.0:
const SYM_COLD_W_CALLS: f64 = 0.40;
const SYM_COLD_W_VISIBILITY: f64 = 0.20;
const SYM_COLD_W_KIND: f64 = 0.15;
const SYM_COLD_W_SIZE: f64 = 0.25;
28
29/// Check whether a graph node has any edge linking it to a memory node
30/// (i.e. an edge whose other endpoint is not a code-structural ID).
31fn has_memory_link_edge(graph: &dyn GraphBackend, node_id: &str) -> bool {
32    graph
33        .get_edges(node_id)
34        .map(|edges| {
35            edges.iter().any(|e| {
36                let other = if e.src == node_id { &e.dst } else { &e.src };
37                !other.starts_with("sym:")
38                    && !other.starts_with("file:")
39                    && !other.starts_with("chunk:")
40                    && !other.starts_with("pkg:")
41                    && !other.starts_with("contains:")
42                    && !other.starts_with("ref:")
43            })
44        })
45        .unwrap_or(false)
46}
47
48impl CodememEngine {
49    // ── Graph Compaction ────────────────────────────────────────────────────
50
51    /// Compact chunk and symbol graph-nodes after indexing.
52    /// Returns (chunks_pruned, symbols_pruned).
53    pub fn compact_graph(&self, seen_files: &HashSet<String>) -> (usize, usize) {
54        let mut graph = match self.lock_graph() {
55            Ok(g) => g,
56            Err(_) => return (0, 0),
57        };
58
59        // Fetch all nodes once and share between both compaction passes.
60        let all_nodes = graph.get_all_nodes();
61        let chunks_pruned = self.compact_chunks(seen_files, &mut graph, &all_nodes);
62        let symbols_pruned = self.compact_symbols(seen_files, &mut graph, &all_nodes);
63
64        if chunks_pruned > 0 || symbols_pruned > 0 {
65            // compute_centrality: updates node.centrality with degree centrality.
66            // recompute_centrality: caches PageRank + betweenness for hybrid scoring.
67            // Both are needed — they populate different data used by different scoring paths.
68            graph.compute_centrality();
69            graph.recompute_centrality();
70        }
71
72        (chunks_pruned, symbols_pruned)
73    }
74
75    /// Pass 1: Score and prune low-value chunks, transferring line ranges to parent symbols.
76    ///
77    /// Scoring weights adjust on cold start: when no memories exist yet, the
78    /// `memory_link_score` weight (normally 0.3) is redistributed to the other
79    /// factors so compaction still produces meaningful rankings.
80    fn compact_chunks(
81        &self,
82        seen_files: &HashSet<String>,
83        graph: &mut std::sync::MutexGuard<'_, codemem_storage::graph::GraphEngine>,
84        all_nodes: &[GraphNode],
85    ) -> usize {
86        let max_chunks_per_file = self.config.chunking.max_retained_chunks_per_file;
87        let chunk_score_threshold = self.config.chunking.min_chunk_score_threshold;
88
89        // Cold-start detection: if no memories exist, memory_link_score is always 0
90        // and its weight should be redistributed to other factors.
91        let has_memories = self
92            .storage
93            .list_memory_ids()
94            .map(|ids| !ids.is_empty())
95            .unwrap_or(false);
96        let (w_centrality, w_parent, w_memory, w_size) = if has_memories {
97            (
98                CHUNK_W_CENTRALITY,
99                CHUNK_W_PARENT,
100                CHUNK_W_MEMORY,
101                CHUNK_W_SIZE,
102            )
103        } else {
104            // Redistribute memory_link weight when no memories exist
105            (
106                CHUNK_COLD_W_CENTRALITY,
107                CHUNK_COLD_W_PARENT,
108                0.0,
109                CHUNK_COLD_W_SIZE,
110            )
111        };
112
113        let mut chunks_by_file: HashMap<String, Vec<(String, f64)>> = HashMap::new();
114
115        let mut max_degree: f64 = 1.0;
116        let mut max_non_ws: f64 = 1.0;
117
118        let chunk_nodes: Vec<&GraphNode> = all_nodes
119            .iter()
120            .filter(|n| n.kind == NodeKind::Chunk)
121            .collect();
122
123        for node in &chunk_nodes {
124            let degree = graph
125                .get_edges(&node.id)
126                .map(|edges| edges.len() as f64)
127                .unwrap_or(0.0);
128            max_degree = max_degree.max(degree);
129
130            let non_ws = node
131                .payload
132                .get("non_ws_chars")
133                .and_then(|v| v.as_f64())
134                .unwrap_or(0.0);
135            max_non_ws = max_non_ws.max(non_ws);
136        }
137
138        for node in &chunk_nodes {
139            let file_path = match node.payload.get("file_path").and_then(|v| v.as_str()) {
140                Some(fp) => fp.to_string(),
141                None => continue,
142            };
143            if !seen_files.contains(&file_path) {
144                continue;
145            }
146
147            let degree = graph
148                .get_edges(&node.id)
149                .map(|edges| edges.len() as f64)
150                .unwrap_or(0.0);
151            let centrality_rank = degree / max_degree;
152
153            let has_symbol_parent = if node.payload.contains_key("parent_symbol") {
154                1.0
155            } else {
156                0.0
157            };
158
159            let memory_link_score = if has_memories && has_memory_link_edge(&**graph, &node.id) {
160                1.0
161            } else {
162                0.0
163            };
164
165            let non_ws = node
166                .payload
167                .get("non_ws_chars")
168                .and_then(|v| v.as_f64())
169                .unwrap_or(0.0);
170            let non_ws_rank = non_ws / max_non_ws;
171
172            let chunk_score = centrality_rank * w_centrality
173                + has_symbol_parent * w_parent
174                + memory_link_score * w_memory
175                + non_ws_rank * w_size;
176
177            chunks_by_file
178                .entry(file_path)
179                .or_default()
180                .push((node.id.clone(), chunk_score));
181        }
182
183        let mut symbol_count_by_file: HashMap<String, usize> = HashMap::new();
184        for node in all_nodes {
185            if matches!(
186                node.kind,
187                NodeKind::Function
188                    | NodeKind::Method
189                    | NodeKind::Class
190                    | NodeKind::Interface
191                    | NodeKind::Type
192                    | NodeKind::Constant
193                    | NodeKind::Test
194            ) {
195                if let Some(fp) = node.payload.get("file_path").and_then(|v| v.as_str()) {
196                    *symbol_count_by_file.entry(fp.to_string()).or_default() += 1;
197                }
198            }
199        }
200
201        let mut chunks_pruned = 0usize;
202        for (file_path, mut chunks) in chunks_by_file {
203            chunks.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
204
205            let sym_count = symbol_count_by_file.get(&file_path).copied().unwrap_or(0);
206            let k = max_chunks_per_file.min(chunks.len()).max(3.max(sym_count));
207
208            for (i, (chunk_id, score)) in chunks.iter().enumerate() {
209                // Prune aggressively: remove if beyond the top-k slots OR below the
210                // quality threshold. Using || ensures both caps are enforced independently
211                // (keep at most k chunks, and never keep any chunk below threshold).
212                if i >= k || *score < chunk_score_threshold {
213                    self.transfer_chunk_ranges_to_parent(graph, chunk_id);
214
215                    if let Err(e) = self.storage.delete_graph_edges_for_node(chunk_id) {
216                        tracing::warn!("Failed to delete graph edges for chunk {chunk_id}: {e}");
217                    }
218                    if let Err(e) = self.storage.delete_graph_node(chunk_id) {
219                        tracing::warn!("Failed to delete graph node for chunk {chunk_id}: {e}");
220                    }
221                    if let Err(e) = graph.remove_node(chunk_id) {
222                        tracing::warn!("Failed to remove chunk {chunk_id} from graph: {e}");
223                    }
224                    chunks_pruned += 1;
225                }
226            }
227        }
228
229        chunks_pruned
230    }
231
232    /// When pruning a chunk, transfer its line range to the parent symbol node.
233    fn transfer_chunk_ranges_to_parent(
234        &self,
235        graph: &mut std::sync::MutexGuard<'_, codemem_storage::graph::GraphEngine>,
236        chunk_id: &str,
237    ) {
238        if let Ok(Some(chunk_node)) = graph.get_node(chunk_id) {
239            if let Some(parent_sym) = chunk_node
240                .payload
241                .get("parent_symbol")
242                .and_then(|v| v.as_str())
243            {
244                let parent_id = format!("sym:{parent_sym}");
245                if let Ok(Some(mut parent_node)) = graph.get_node(&parent_id) {
246                    let line_start = chunk_node
247                        .payload
248                        .get("line_start")
249                        .and_then(|v| v.as_u64())
250                        .unwrap_or(0);
251                    let line_end = chunk_node
252                        .payload
253                        .get("line_end")
254                        .and_then(|v| v.as_u64())
255                        .unwrap_or(0);
256                    let ranges = parent_node
257                        .payload
258                        .entry("covered_ranges".to_string())
259                        .or_insert_with(|| serde_json::json!([]));
260                    if let Some(arr) = ranges.as_array_mut() {
261                        arr.push(serde_json::json!([line_start, line_end]));
262                    }
263                    let count = parent_node
264                        .payload
265                        .entry("pruned_chunk_count".to_string())
266                        .or_insert_with(|| serde_json::json!(0));
267                    if let Some(n) = count.as_u64() {
268                        *count = serde_json::json!(n + 1);
269                    }
270                    let _ = self.storage.insert_graph_node(&parent_node);
271                    let _ = graph.add_node(parent_node);
272                }
273            }
274        }
275    }
276
277    /// Pass 2: Score and prune low-value symbol nodes, transferring ranges to parent files.
278    ///
279    /// Like chunk compaction, scoring weights adjust on cold start: when no memories
280    /// exist yet, the `memory_link_val` weight (normally 0.20) is redistributed to
281    /// call connectivity and code size factors.
282    fn compact_symbols(
283        &self,
284        seen_files: &HashSet<String>,
285        graph: &mut std::sync::MutexGuard<'_, codemem_storage::graph::GraphEngine>,
286        all_nodes: &[GraphNode],
287    ) -> usize {
288        let max_syms_per_file = self.config.chunking.max_retained_symbols_per_file;
289        let sym_score_threshold = self.config.chunking.min_symbol_score_threshold;
290
291        // Cold-start: redistribute memory_link weight when no memories exist
292        let has_memories = self
293            .storage
294            .list_memory_ids()
295            .map(|ids| !ids.is_empty())
296            .unwrap_or(false);
297        let (w_calls, w_vis, w_kind, w_mem, w_size) = if has_memories {
298            (
299                SYM_W_CALLS,
300                SYM_W_VISIBILITY,
301                SYM_W_KIND,
302                SYM_W_MEMORY,
303                SYM_W_SIZE,
304            )
305        } else {
306            // Redistribute memory weight to calls and code size
307            (
308                SYM_COLD_W_CALLS,
309                SYM_COLD_W_VISIBILITY,
310                SYM_COLD_W_KIND,
311                0.0,
312                SYM_COLD_W_SIZE,
313            )
314        };
315
316        let sym_nodes: Vec<&GraphNode> = all_nodes
317            .iter()
318            .filter(|n| n.id.starts_with("sym:"))
319            .collect();
320
321        let mut max_calls_degree: f64 = 1.0;
322        let mut max_code_size: f64 = 1.0;
323
324        for node in &sym_nodes {
325            let calls_degree = graph
326                .get_edges(&node.id)
327                .map(|edges| {
328                    edges
329                        .iter()
330                        .filter(|e| e.relationship == RelationshipType::Calls)
331                        .count() as f64
332                })
333                .unwrap_or(0.0);
334            max_calls_degree = max_calls_degree.max(calls_degree);
335
336            let line_start = node
337                .payload
338                .get("line_start")
339                .and_then(|v| v.as_f64())
340                .unwrap_or(0.0);
341            let line_end = node
342                .payload
343                .get("line_end")
344                .and_then(|v| v.as_f64())
345                .unwrap_or(0.0);
346            let code_size = (line_end - line_start).max(0.0);
347            max_code_size = max_code_size.max(code_size);
348        }
349
350        let mut syms_by_file: HashMap<String, Vec<(String, f64, bool, bool)>> = HashMap::new();
351
352        for node in &sym_nodes {
353            let file_path = match node.payload.get("file_path").and_then(|v| v.as_str()) {
354                Some(fp) => fp.to_string(),
355                None => continue,
356            };
357            if !seen_files.contains(&file_path) {
358                continue;
359            }
360
361            let calls_degree = graph
362                .get_edges(&node.id)
363                .map(|edges| {
364                    edges
365                        .iter()
366                        .filter(|e| e.relationship == RelationshipType::Calls)
367                        .count() as f64
368                })
369                .unwrap_or(0.0);
370            let call_connectivity = calls_degree / max_calls_degree;
371
372            let visibility_score = match node
373                .payload
374                .get("visibility")
375                .and_then(|v| v.as_str())
376                .unwrap_or("private")
377            {
378                "public" => 1.0,
379                "crate" => 0.5,
380                _ => 0.0,
381            };
382
383            let kind_score = match node.kind {
384                NodeKind::Class | NodeKind::Interface => 1.0,
385                NodeKind::Module => 1.0,
386                NodeKind::Function | NodeKind::Method => 0.6,
387                NodeKind::Test => 0.3,
388                NodeKind::Constant => 0.1,
389                _ => 0.5,
390            };
391
392            let mem_linked = has_memories && has_memory_link_edge(&**graph, &node.id);
393            let memory_link_val = if mem_linked { 1.0 } else { 0.0 };
394
395            let line_start = node
396                .payload
397                .get("line_start")
398                .and_then(|v| v.as_f64())
399                .unwrap_or(0.0);
400            let line_end = node
401                .payload
402                .get("line_end")
403                .and_then(|v| v.as_f64())
404                .unwrap_or(0.0);
405            let code_size = (line_end - line_start).max(0.0);
406            let code_size_rank = code_size / max_code_size;
407
408            let symbol_score = call_connectivity * w_calls
409                + visibility_score * w_vis
410                + kind_score * w_kind
411                + memory_link_val * w_mem
412                + code_size_rank * w_size;
413
414            let is_structural = matches!(
415                node.kind,
416                NodeKind::Class | NodeKind::Interface | NodeKind::Module
417            );
418
419            syms_by_file.entry(file_path).or_default().push((
420                node.id.clone(),
421                symbol_score,
422                is_structural,
423                mem_linked,
424            ));
425        }
426
427        let mut symbols_pruned = 0usize;
428        for (_file_path, mut syms) in syms_by_file {
429            syms.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
430
431            let public_count = syms
432                .iter()
433                .filter(|(id, ..)| {
434                    graph
435                        .get_node(id)
436                        .ok()
437                        .flatten()
438                        .and_then(|n| {
439                            n.payload
440                                .get("visibility")
441                                .and_then(|v| v.as_str())
442                                .map(|v| v == "public")
443                        })
444                        .unwrap_or(false)
445                })
446                .count();
447            let k = max_syms_per_file.max(public_count);
448
449            for (i, (sym_id, score, is_structural, mem_linked)) in syms.iter().enumerate() {
450                if *is_structural || *mem_linked {
451                    continue;
452                }
453                if i < k && *score >= sym_score_threshold {
454                    continue;
455                }
456
457                self.transfer_symbol_ranges_to_file(graph, sym_id);
458
459                if let Err(e) = self.storage.delete_graph_edges_for_node(sym_id) {
460                    tracing::warn!("Failed to delete graph edges for symbol {sym_id}: {e}");
461                }
462                if let Err(e) = self.storage.delete_graph_node(sym_id) {
463                    tracing::warn!("Failed to delete graph node for symbol {sym_id}: {e}");
464                }
465                if let Err(e) = graph.remove_node(sym_id) {
466                    tracing::warn!("Failed to remove symbol {sym_id} from graph: {e}");
467                }
468                symbols_pruned += 1;
469            }
470        }
471
472        symbols_pruned
473    }
474
475    /// When pruning a symbol, transfer its line range to the parent file node.
476    fn transfer_symbol_ranges_to_file(
477        &self,
478        graph: &mut std::sync::MutexGuard<'_, codemem_storage::graph::GraphEngine>,
479        sym_id: &str,
480    ) {
481        if let Ok(Some(sym_node)) = graph.get_node(sym_id) {
482            if let Some(fp) = sym_node.payload.get("file_path").and_then(|v| v.as_str()) {
483                let file_id = format!("file:{fp}");
484                if let Ok(Some(mut file_node)) = graph.get_node(&file_id) {
485                    let line_start = sym_node
486                        .payload
487                        .get("line_start")
488                        .and_then(|v| v.as_u64())
489                        .unwrap_or(0);
490                    let line_end = sym_node
491                        .payload
492                        .get("line_end")
493                        .and_then(|v| v.as_u64())
494                        .unwrap_or(0);
495                    if line_end > line_start {
496                        let ranges = file_node
497                            .payload
498                            .entry("pruned_symbol_ranges".to_string())
499                            .or_insert_with(|| serde_json::json!([]));
500                        if let Some(arr) = ranges.as_array_mut() {
501                            arr.push(serde_json::json!([line_start, line_end]));
502                        }
503                        let _ = self.storage.insert_graph_node(&file_node);
504                        let _ = graph.add_node(file_node);
505                    }
506                }
507            }
508        }
509    }
510}