// codemem_engine/review.rs

//! Diff-aware review pipeline: parse unified diffs, map changed lines to symbols,
//! and compute the blast radius via multi-hop graph traversal.
3
4use crate::CodememEngine;
5use codemem_core::{CodememError, GraphBackend, MemoryNode, RelationshipType};
6use std::collections::{HashMap, HashSet};
7
// ── Types ────────────────────────────────────────────────────────────────
9
/// Changed line numbers for one file of a unified diff.
///
/// Despite the name, [`parse_diff`] produces one value per file header: the
/// lines of every `@@` hunk under that header are accumulated together.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DiffHunk {
    /// Path of the file the changes belong to, as written in the diff header
    /// with its `a/` / `b/` prefix removed.
    pub file_path: String,
    /// Line numbers of `+` lines, counted in the new version of the file.
    pub added_lines: Vec<u32>,
    /// Line numbers of `-` lines, counted in the old version of the file.
    pub removed_lines: Vec<u32>,
}
17
/// Result of mapping a unified diff onto graph node IDs.
///
/// All three lists hold each ID at most once.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct DiffSymbolMapping {
    /// `sym:` IDs whose definition line range overlaps a changed line.
    pub changed_symbols: Vec<String>,
    /// `sym:` IDs that are the source of a `Contains` edge pointing at a changed
    /// symbol (i.e. parents of changed symbols) and are not changed themselves.
    pub containing_symbols: Vec<String>,
    /// `file:` IDs of the files touched by the diff.
    pub changed_files: Vec<String>,
}
28
29/// Information about a symbol for the blast radius report.
30#[derive(Debug, Clone, serde::Serialize)]
31pub struct SymbolInfo {
32    pub id: String,
33    pub label: String,
34    pub kind: String,
35    pub file_path: Option<String>,
36    pub line_start: Option<u32>,
37    pub pagerank: f64,
38}
39
40/// A potentially missing change detected by pattern analysis.
41#[derive(Debug, Clone, serde::Serialize)]
42pub struct MissingChange {
43    pub symbol: String,
44    pub reason: String,
45}
46
47/// Full blast radius report for a diff.
48#[derive(Debug, Clone, serde::Serialize)]
49pub struct BlastRadiusReport {
50    pub changed_symbols: Vec<SymbolInfo>,
51    pub direct_dependents: Vec<SymbolInfo>,
52    pub transitive_dependents: Vec<SymbolInfo>,
53    pub affected_files: Vec<String>,
54    pub affected_modules: Vec<String>,
55    pub risk_score: f64,
56    pub missing_changes: Vec<MissingChange>,
57    pub relevant_memories: Vec<MemorySnippet>,
58}
59
60/// Lightweight memory reference for the report (avoids serializing full MemoryNode).
61#[derive(Debug, Clone, serde::Serialize)]
62pub struct MemorySnippet {
63    pub id: String,
64    pub content: String,
65    pub memory_type: String,
66    pub importance: f64,
67}
68
69impl From<&MemoryNode> for MemorySnippet {
70    fn from(m: &MemoryNode) -> Self {
71        Self {
72            id: m.id.clone(),
73            content: m.content.clone(),
74            memory_type: m.memory_type.to_string(),
75            importance: m.importance,
76        }
77    }
78}
79
// ── Diff Parsing ─────────────────────────────────────────────────────────
81
82/// Parse a unified diff into hunks with file paths and changed line numbers.
83pub fn parse_diff(diff: &str) -> Vec<DiffHunk> {
84    let mut hunks = Vec::new();
85    let mut current_file: Option<String> = None;
86    let mut added_lines: Vec<u32> = Vec::new();
87    let mut removed_lines: Vec<u32> = Vec::new();
88    let mut new_line: u32 = 0;
89    let mut old_line: u32 = 0;
90
91    for line in diff.lines() {
92        if line.starts_with("+++ b/") {
93            // Flush previous file
94            if let Some(ref file) = current_file {
95                if !added_lines.is_empty() || !removed_lines.is_empty() {
96                    hunks.push(DiffHunk {
97                        file_path: file.clone(),
98                        added_lines: std::mem::take(&mut added_lines),
99                        removed_lines: std::mem::take(&mut removed_lines),
100                    });
101                }
102            }
103            current_file = line.strip_prefix("+++ b/").map(|s| s.to_string());
104        } else if line.starts_with("@@ ") {
105            // Parse hunk header: @@ -old_start,old_count +new_start,new_count @@
106            if let Some((new_start, old_start)) = parse_hunk_header(line) {
107                new_line = new_start;
108                old_line = old_start;
109            }
110        } else if current_file.is_some() {
111            if line.starts_with('+') && !line.starts_with("+++") {
112                added_lines.push(new_line);
113                new_line += 1;
114            } else if line.starts_with('-') && !line.starts_with("---") {
115                removed_lines.push(old_line);
116                old_line += 1;
117            } else {
118                // Context line
119                new_line += 1;
120                old_line += 1;
121            }
122        }
123    }
124
125    // Flush last file
126    if let Some(file) = current_file {
127        if !added_lines.is_empty() || !removed_lines.is_empty() {
128            hunks.push(DiffHunk {
129                file_path: file,
130                added_lines,
131                removed_lines,
132            });
133        }
134    }
135
136    hunks
137}
138
/// Parse a `@@` hunk header, returning `(new_start, old_start)`.
///
/// The expected shape is `@@ -old_start[,old_count] +new_start[,new_count] @@`;
/// the counts and anything after the two ranges are ignored. Returns `None`
/// when either range is missing, lacks its `-` / `+` sign, or does not start
/// with a number that fits in `u32`.
fn parse_hunk_header(line: &str) -> Option<(u32, u32)> {
    // Skip the leading "@@" token; the next two tokens are the old and new ranges.
    let mut tokens = line.split_whitespace().skip(1);
    let old_range = tokens.next()?.strip_prefix('-')?;
    let new_range = tokens.next()?.strip_prefix('+')?;

    // A range is "start" or "start,count"; only the start is needed.
    let start_of = |range: &str| -> Option<u32> { range.split(',').next()?.parse().ok() };

    let old_start = start_of(old_range)?;
    let new_start = start_of(new_range)?;
    Some((new_start, old_start))
}
153
// ── Diff to Symbols ──────────────────────────────────────────────────────
155
156impl CodememEngine {
157    /// Map a unified diff to affected symbol IDs using the graph's line range data.
158    pub fn diff_to_symbols(&self, diff: &str) -> Result<DiffSymbolMapping, CodememError> {
159        let hunks = parse_diff(diff);
160        let graph = self.lock_graph()?;
161        let all_nodes = graph.get_all_nodes();
162
163        let mut mapping = DiffSymbolMapping::default();
164        let mut seen_symbols: HashSet<String> = HashSet::new();
165        let mut seen_files: HashSet<String> = HashSet::new();
166
167        // Build file→symbols index to avoid O(nodes × hunks) scan
168        let mut file_symbols: HashMap<&str, Vec<&codemem_core::GraphNode>> = HashMap::new();
169        for node in &all_nodes {
170            if !node.id.starts_with("sym:") {
171                continue;
172            }
173            if let Some(fp) = node.payload.get("file_path").and_then(|v| v.as_str()) {
174                file_symbols.entry(fp).or_default().push(node);
175            }
176        }
177
178        for hunk in &hunks {
179            let file_id = format!("file:{}", hunk.file_path);
180            if seen_files.insert(file_id.clone()) {
181                mapping.changed_files.push(file_id);
182            }
183
184            let changed_lines: HashSet<u32> = hunk
185                .added_lines
186                .iter()
187                .chain(hunk.removed_lines.iter())
188                .copied()
189                .collect();
190
191            // Only check symbols in this file (indexed lookup)
192            if let Some(nodes) = file_symbols.get(hunk.file_path.as_str()) {
193                for node in nodes {
194                    let line_start = node
195                        .payload
196                        .get("line_start")
197                        .and_then(|v| v.as_u64())
198                        .unwrap_or(0) as u32;
199                    let line_end = node
200                        .payload
201                        .get("line_end")
202                        .and_then(|v| v.as_u64())
203                        .unwrap_or(line_start as u64) as u32;
204
205                    let overlaps = changed_lines
206                        .iter()
207                        .any(|&l| l >= line_start && l <= line_end);
208                    if overlaps && seen_symbols.insert(node.id.clone()) {
209                        mapping.changed_symbols.push(node.id.clone());
210                    }
211                }
212            }
213        }
214
215        // Find containing symbols (parents of changed symbols via CONTAINS edges)
216        let changed_set: HashSet<&str> =
217            mapping.changed_symbols.iter().map(|s| s.as_str()).collect();
218        for node in &all_nodes {
219            if !node.id.starts_with("sym:") || changed_set.contains(node.id.as_str()) {
220                continue;
221            }
222            // Check if this symbol contains any changed symbol
223            let edges = graph.get_edges(&node.id).unwrap_or_default();
224            let contains_changed = edges.iter().any(|e| {
225                e.relationship == RelationshipType::Contains && changed_set.contains(e.dst.as_str())
226            });
227            if contains_changed && seen_symbols.insert(node.id.clone()) {
228                mapping.containing_symbols.push(node.id.clone());
229            }
230        }
231
232        Ok(mapping)
233    }
234
235    /// Compute the blast radius for a diff: changed symbols, dependents, risk score,
236    /// relevant memories, and missing change detection.
237    pub fn blast_radius(
238        &self,
239        diff: &str,
240        depth: usize,
241    ) -> Result<BlastRadiusReport, CodememError> {
242        let mapping = self.diff_to_symbols(diff)?;
243        let graph = self.lock_graph()?;
244
245        let mut changed_infos = Vec::new();
246        let mut direct_deps = Vec::new();
247        let mut transitive_deps = Vec::new();
248        let mut affected_files: HashSet<String> = HashSet::new();
249        let mut affected_modules: HashSet<String> = HashSet::new();
250        let mut seen: HashSet<String> = HashSet::new();
251        let mut risk_score: f64 = 0.0;
252
253        // Collect changed symbol info + their PageRank for risk scoring
254        for sym_id in &mapping.changed_symbols {
255            if let Some(info) = node_to_symbol_info(&**graph, sym_id) {
256                risk_score += info.pagerank;
257                if let Some(ref fp) = info.file_path {
258                    affected_files.insert(fp.clone());
259                }
260                seen.insert(sym_id.clone());
261                changed_infos.push(info);
262            }
263        }
264        for sym_id in &mapping.containing_symbols {
265            if let Some(info) = node_to_symbol_info(&**graph, sym_id) {
266                if let Some(ref fp) = info.file_path {
267                    affected_files.insert(fp.clone());
268                }
269                seen.insert(sym_id.clone());
270                changed_infos.push(info);
271            }
272        }
273
274        // BFS from changed symbols to find dependents
275        let all_changed: Vec<&str> = mapping
276            .changed_symbols
277            .iter()
278            .chain(mapping.containing_symbols.iter())
279            .map(|s| s.as_str())
280            .collect();
281
282        for &start_id in &all_changed {
283            // Get direct dependents (1-hop incoming edges: who CALLS/IMPORTS this symbol?)
284            let edges = graph.get_edges(start_id).unwrap_or_default();
285            for edge in &edges {
286                // Incoming edges: other symbols that depend on this one
287                let dependent_id = if edge.dst == start_id {
288                    &edge.src
289                } else {
290                    continue; // outgoing edge, skip
291                };
292                if !dependent_id.starts_with("sym:") || !seen.insert(dependent_id.clone()) {
293                    continue;
294                }
295                if matches!(
296                    edge.relationship,
297                    RelationshipType::Calls
298                        | RelationshipType::Imports
299                        | RelationshipType::Inherits
300                        | RelationshipType::Implements
301                        | RelationshipType::Overrides
302                ) {
303                    if let Some(info) = node_to_symbol_info(&**graph, dependent_id) {
304                        if let Some(ref fp) = info.file_path {
305                            affected_files.insert(fp.clone());
306                        }
307                        direct_deps.push(info);
308                    }
309                }
310            }
311        }
312
313        // Transitive dependents (2+ hops) via iterative incoming-edge traversal.
314        // BFS follows outgoing edges (wrong direction for "who depends on me?").
315        // Instead, walk incoming edges layer by layer.
316        if depth > 1 {
317            let mut frontier: Vec<String> = direct_deps.iter().map(|d| d.id.clone()).collect();
318            for _ in 1..depth {
319                let mut next_frontier = Vec::new();
320                for node_id in &frontier {
321                    let edges = graph.get_edges(node_id).unwrap_or_default();
322                    for edge in &edges {
323                        // Only follow incoming dependency edges
324                        if edge.dst != *node_id {
325                            continue;
326                        }
327                        if !matches!(
328                            edge.relationship,
329                            RelationshipType::Calls
330                                | RelationshipType::Imports
331                                | RelationshipType::Inherits
332                                | RelationshipType::Implements
333                                | RelationshipType::Overrides
334                        ) {
335                            continue;
336                        }
337                        let dep_id = &edge.src;
338                        if !dep_id.starts_with("sym:") || !seen.insert(dep_id.clone()) {
339                            continue;
340                        }
341                        if let Some(info) = node_to_symbol_info(&**graph, dep_id) {
342                            if let Some(ref fp) = info.file_path {
343                                affected_files.insert(fp.clone());
344                            }
345                            if info.kind == "Module" {
346                                affected_modules.insert(info.id.clone());
347                            }
348                            next_frontier.push(dep_id.clone());
349                            transitive_deps.push(info);
350                        }
351                    }
352                }
353                if next_frontier.is_empty() {
354                    break;
355                }
356                frontier = next_frontier;
357            }
358        }
359
360        // Detect affected modules from all symbols
361        for info in changed_infos.iter().chain(direct_deps.iter()) {
362            if info.kind == "Module" {
363                affected_modules.insert(info.id.clone());
364            }
365        }
366
367        // Risk score: Σ(pagerank) + log(transitive_count + 1)
368        // Additive so that diffs touching zero-pagerank symbols (common when
369        // centrality hasn't been computed or symbols have no edges) still get
370        // a nonzero risk score from their dependent count.
371        let transitive_count = direct_deps.len() + transitive_deps.len();
372        risk_score += (transitive_count as f64 + 1.0).ln();
373
374        // Drop graph lock before accessing storage
375        drop(graph);
376
377        // Find relevant memories connected to changed symbols
378        let mut relevant_memories = Vec::new();
379        for sym_id in mapping
380            .changed_symbols
381            .iter()
382            .chain(mapping.containing_symbols.iter())
383            .take(20)
384        {
385            if let Ok(results) = self.get_node_memories(sym_id, 1, None) {
386                for r in &results {
387                    relevant_memories.push(MemorySnippet::from(&r.memory));
388                }
389            }
390        }
391        // Dedup memories by ID
392        let mut seen_mem_ids: HashSet<String> = HashSet::new();
393        relevant_memories.retain(|m| seen_mem_ids.insert(m.id.clone()));
394
395        // Missing change detection: find symbols with similar caller patterns
396        let graph = self.lock_graph()?;
397        let missing_changes = detect_missing_changes(&**graph, &mapping.changed_symbols, &seen);
398
399        let affected_files: Vec<String> = affected_files.into_iter().collect();
400        let affected_modules: Vec<String> = affected_modules.into_iter().collect();
401
402        Ok(BlastRadiusReport {
403            changed_symbols: changed_infos,
404            direct_dependents: direct_deps,
405            transitive_dependents: transitive_deps,
406            affected_files,
407            affected_modules,
408            risk_score,
409            missing_changes,
410            relevant_memories,
411        })
412    }
413}
414
// ── Helpers ──────────────────────────────────────────────────────────────
416
417fn node_to_symbol_info(graph: &dyn GraphBackend, node_id: &str) -> Option<SymbolInfo> {
418    let node = graph.get_node(node_id).ok()??;
419    Some(SymbolInfo {
420        id: node.id.clone(),
421        label: node.label.clone(),
422        kind: node.kind.to_string(),
423        file_path: node
424            .payload
425            .get("file_path")
426            .and_then(|v| v.as_str())
427            .map(String::from),
428        line_start: node
429            .payload
430            .get("line_start")
431            .and_then(|v| v.as_u64())
432            .map(|v| v as u32),
433        pagerank: graph.get_pagerank(&node.id),
434    })
435}
436
437/// Detect potentially missing changes: symbols that share callers with changed symbols
438/// but aren't in the diff.
439fn detect_missing_changes(
440    graph: &dyn GraphBackend,
441    changed_symbols: &[String],
442    already_in_diff: &HashSet<String>,
443) -> Vec<MissingChange> {
444    let mut missing = Vec::new();
445
446    // For each changed symbol, find its callers. Then find what else those callers call.
447    // If a sibling is called by the same callers but not in the diff, flag it.
448    let mut caller_sets: HashMap<String, HashSet<String>> = HashMap::new();
449
450    for sym_id in changed_symbols {
451        let edges = graph.get_edges(sym_id).unwrap_or_default();
452        let callers: HashSet<String> = edges
453            .iter()
454            .filter(|e| e.dst == *sym_id && e.relationship == RelationshipType::Calls)
455            .map(|e| e.src.clone())
456            .collect();
457        if !callers.is_empty() {
458            caller_sets.insert(sym_id.clone(), callers);
459        }
460    }
461
462    // Find siblings: other symbols called by the same callers
463    let mut sibling_counts: HashMap<String, usize> = HashMap::new();
464    for callers in caller_sets.values() {
465        for caller_id in callers {
466            let edges = graph.get_edges(caller_id).unwrap_or_default();
467            for edge in &edges {
468                if edge.src == *caller_id
469                    && edge.relationship == RelationshipType::Calls
470                    && edge.dst.starts_with("sym:")
471                    && !already_in_diff.contains(&edge.dst)
472                {
473                    *sibling_counts.entry(edge.dst.clone()).or_insert(0) += 1;
474                }
475            }
476        }
477    }
478
479    // Flag siblings that share multiple callers with changed symbols
480    let threshold = (changed_symbols.len() / 2).max(2);
481    for (sibling, count) in &sibling_counts {
482        if *count >= threshold {
483            missing.push(MissingChange {
484                symbol: sibling.clone(),
485                reason: format!(
486                    "shares {} callers with {} changed symbols",
487                    count,
488                    changed_symbols.len()
489                ),
490            });
491        }
492    }
493
494    missing
495}
496
// Unit tests for this module are compiled from `tests/review_tests.rs`, in test builds only.
#[cfg(test)]
#[path = "tests/review_tests.rs"]
mod tests;