Skip to main content

codemem_engine/
review.rs

1//! Diff-aware review pipeline: parse unified diffs, map changed lines to symbols,
2//! compute blast radius via multi-hop graph traversal.
3
4use crate::CodememEngine;
5use codemem_core::{CodememError, GraphBackend, MemoryNode, RelationshipType};
6use std::collections::{HashMap, HashSet};
7
8// ── Types ────────────────────────────────────────────────────────────────
9
/// Changed line numbers for one file section of a unified diff.
///
/// [`parse_diff`] produces one value per `+++ b/<path>` section that contains
/// at least one added or removed line; all `@@` hunks of that section are
/// merged into the two line lists.
#[derive(Debug, Clone)]
pub struct DiffHunk {
    /// Path taken from the `+++ b/<path>` header, without the `b/` prefix.
    pub file_path: String,
    /// Line numbers of `+` lines, counted in the new file starting from each
    /// hunk's new-side start.
    pub added_lines: Vec<u32>,
    /// Line numbers of `-` lines, counted in the old file starting from each
    /// hunk's old-side start.
    pub removed_lines: Vec<u32>,
}
17
/// Graph node IDs touched by a unified diff, as produced by
/// `CodememEngine::diff_to_symbols`.
#[derive(Debug, Clone, Default)]
pub struct DiffSymbolMapping {
    /// `sym:` node IDs whose `line_start..=line_end` range includes at least
    /// one added or removed line of the diff.
    pub changed_symbols: Vec<String>,
    /// `sym:` node IDs that have a `Contains` edge to a changed symbol and are
    /// not themselves listed in `changed_symbols`.
    pub containing_symbols: Vec<String>,
    /// `file:<path>` IDs, one per distinct file that has changes in the diff.
    pub changed_files: Vec<String>,
}
28
29/// Information about a symbol for the blast radius report.
30#[derive(Debug, Clone, serde::Serialize)]
31pub struct SymbolInfo {
32    pub id: String,
33    pub label: String,
34    pub kind: String,
35    pub file_path: Option<String>,
36    pub line_start: Option<u32>,
37    pub pagerank: f64,
38}
39
40/// A potentially missing change detected by pattern analysis.
41#[derive(Debug, Clone, serde::Serialize)]
42pub struct MissingChange {
43    pub symbol: String,
44    pub reason: String,
45}
46
47/// A file historically coupled with changed files but absent from the diff.
48#[derive(Debug, Clone, serde::Serialize)]
49pub struct MissingCoChange {
50    pub file_path: String,
51    pub coupled_with: Vec<String>,
52    pub strength: f64,
53}
54
55/// Full blast radius report for a diff.
56#[derive(Debug, Clone, serde::Serialize)]
57pub struct BlastRadiusReport {
58    pub changed_symbols: Vec<SymbolInfo>,
59    pub direct_dependents: Vec<SymbolInfo>,
60    pub transitive_dependents: Vec<SymbolInfo>,
61    pub affected_files: Vec<String>,
62    pub affected_modules: Vec<String>,
63    pub risk_score: f64,
64    pub missing_changes: Vec<MissingChange>,
65    pub missing_co_changes: Vec<MissingCoChange>,
66    pub relevant_memories: Vec<MemorySnippet>,
67}
68
69/// Lightweight memory reference for the report (avoids serializing full MemoryNode).
70#[derive(Debug, Clone, serde::Serialize)]
71pub struct MemorySnippet {
72    pub id: String,
73    pub content: String,
74    pub memory_type: String,
75    pub importance: f64,
76}
77
78impl From<&MemoryNode> for MemorySnippet {
79    fn from(m: &MemoryNode) -> Self {
80        Self {
81            id: m.id.clone(),
82            content: m.content.clone(),
83            memory_type: m.memory_type.to_string(),
84            importance: m.importance,
85        }
86    }
87}
88
89// ── Diff Parsing ─────────────────────────────────────────────────────────
90
91/// Parse a unified diff into hunks with file paths and changed line numbers.
92pub fn parse_diff(diff: &str) -> Vec<DiffHunk> {
93    let mut hunks = Vec::new();
94    let mut current_file: Option<String> = None;
95    let mut added_lines: Vec<u32> = Vec::new();
96    let mut removed_lines: Vec<u32> = Vec::new();
97    let mut new_line: u32 = 0;
98    let mut old_line: u32 = 0;
99
100    for line in diff.lines() {
101        if line.starts_with("+++ b/") {
102            // Flush previous file
103            if let Some(ref file) = current_file {
104                if !added_lines.is_empty() || !removed_lines.is_empty() {
105                    hunks.push(DiffHunk {
106                        file_path: file.clone(),
107                        added_lines: std::mem::take(&mut added_lines),
108                        removed_lines: std::mem::take(&mut removed_lines),
109                    });
110                }
111            }
112            current_file = line.strip_prefix("+++ b/").map(|s| s.to_string());
113        } else if line.starts_with("@@ ") {
114            // Parse hunk header: @@ -old_start,old_count +new_start,new_count @@
115            if let Some((new_start, old_start)) = parse_hunk_header(line) {
116                new_line = new_start;
117                old_line = old_start;
118            }
119        } else if current_file.is_some() {
120            if line.starts_with('+') && !line.starts_with("+++") {
121                added_lines.push(new_line);
122                new_line += 1;
123            } else if line.starts_with('-') && !line.starts_with("---") {
124                removed_lines.push(old_line);
125                old_line += 1;
126            } else {
127                // Context line
128                new_line += 1;
129                old_line += 1;
130            }
131        }
132    }
133
134    // Flush last file
135    if let Some(file) = current_file {
136        if !added_lines.is_empty() || !removed_lines.is_empty() {
137            hunks.push(DiffHunk {
138                file_path: file,
139                added_lines,
140                removed_lines,
141            });
142        }
143    }
144
145    hunks
146}
147
/// Extract the start lines from an `@@ -old_start[,old_count] +new_start[,new_count] @@`
/// header.
///
/// Returns `(new_start, old_start)` — note the order — or `None` when the
/// header has fewer than three whitespace-separated tokens, a range lacks its
/// `-` / `+` sign, or a start value is not a valid `u32`.
fn parse_hunk_header(line: &str) -> Option<(u32, u32)> {
    // Token 0 is the leading "@@"; the two ranges follow it.
    let mut tokens = line.split_whitespace().skip(1);
    let old_range = tokens.next()?.strip_prefix('-')?;
    let new_range = tokens.next()?.strip_prefix('+')?;

    // The start is whatever precedes the optional ",count" suffix.
    let start_of = |range: &str| -> Option<u32> { range.split(',').next()?.parse().ok() };

    let old_start = start_of(old_range)?;
    let new_start = start_of(new_range)?;
    Some((new_start, old_start))
}
162
163// ── Diff to Symbols ──────────────────────────────────────────────────────
164
165impl CodememEngine {
166    /// Map a unified diff to affected symbol IDs using the graph's line range data.
167    pub fn diff_to_symbols(&self, diff: &str) -> Result<DiffSymbolMapping, CodememError> {
168        let hunks = parse_diff(diff);
169        let graph = self.lock_graph()?;
170        let all_nodes = graph.get_all_nodes();
171
172        let mut mapping = DiffSymbolMapping::default();
173        let mut seen_symbols: HashSet<String> = HashSet::new();
174        let mut seen_files: HashSet<String> = HashSet::new();
175
176        // Build file→symbols index to avoid O(nodes × hunks) scan
177        let mut file_symbols: HashMap<&str, Vec<&codemem_core::GraphNode>> = HashMap::new();
178        for node in &all_nodes {
179            if !node.id.starts_with("sym:") {
180                continue;
181            }
182            if let Some(fp) = node.payload.get("file_path").and_then(|v| v.as_str()) {
183                file_symbols.entry(fp).or_default().push(node);
184            }
185        }
186
187        for hunk in &hunks {
188            let file_id = format!("file:{}", hunk.file_path);
189            if seen_files.insert(file_id.clone()) {
190                mapping.changed_files.push(file_id);
191            }
192
193            let changed_lines: HashSet<u32> = hunk
194                .added_lines
195                .iter()
196                .chain(hunk.removed_lines.iter())
197                .copied()
198                .collect();
199
200            // Only check symbols in this file (indexed lookup)
201            if let Some(nodes) = file_symbols.get(hunk.file_path.as_str()) {
202                for node in nodes {
203                    let line_start = node
204                        .payload
205                        .get("line_start")
206                        .and_then(|v| v.as_u64())
207                        .unwrap_or(0) as u32;
208                    let line_end = node
209                        .payload
210                        .get("line_end")
211                        .and_then(|v| v.as_u64())
212                        .unwrap_or(line_start as u64) as u32;
213
214                    let overlaps = changed_lines
215                        .iter()
216                        .any(|&l| l >= line_start && l <= line_end);
217                    if overlaps && seen_symbols.insert(node.id.clone()) {
218                        mapping.changed_symbols.push(node.id.clone());
219                    }
220                }
221            }
222        }
223
224        // Find containing symbols (parents of changed symbols via CONTAINS edges)
225        let changed_set: HashSet<&str> =
226            mapping.changed_symbols.iter().map(|s| s.as_str()).collect();
227        for node in &all_nodes {
228            if !node.id.starts_with("sym:") || changed_set.contains(node.id.as_str()) {
229                continue;
230            }
231            // Check if this symbol contains any changed symbol
232            let edges = graph.get_edges(&node.id).unwrap_or_default();
233            let contains_changed = edges.iter().any(|e| {
234                e.relationship == RelationshipType::Contains && changed_set.contains(e.dst.as_str())
235            });
236            if contains_changed && seen_symbols.insert(node.id.clone()) {
237                mapping.containing_symbols.push(node.id.clone());
238            }
239        }
240
241        Ok(mapping)
242    }
243
244    /// Compute the blast radius for a diff: changed symbols, dependents, risk score,
245    /// relevant memories, and missing change detection.
246    pub fn blast_radius(
247        &self,
248        diff: &str,
249        depth: usize,
250    ) -> Result<BlastRadiusReport, CodememError> {
251        let mapping = self.diff_to_symbols(diff)?;
252        let graph = self.lock_graph()?;
253
254        let mut changed_infos = Vec::new();
255        let mut direct_deps = Vec::new();
256        let mut transitive_deps = Vec::new();
257        let mut affected_files: HashSet<String> = HashSet::new();
258        let mut affected_modules: HashSet<String> = HashSet::new();
259        let mut seen: HashSet<String> = HashSet::new();
260        let mut risk_score: f64 = 0.0;
261
262        // Collect changed symbol info + their PageRank for risk scoring
263        for sym_id in &mapping.changed_symbols {
264            if let Some(info) = node_to_symbol_info(&**graph, sym_id) {
265                risk_score += info.pagerank;
266                if let Some(ref fp) = info.file_path {
267                    affected_files.insert(fp.clone());
268                }
269                seen.insert(sym_id.clone());
270                changed_infos.push(info);
271            }
272        }
273        for sym_id in &mapping.containing_symbols {
274            if let Some(info) = node_to_symbol_info(&**graph, sym_id) {
275                if let Some(ref fp) = info.file_path {
276                    affected_files.insert(fp.clone());
277                }
278                seen.insert(sym_id.clone());
279                changed_infos.push(info);
280            }
281        }
282
283        // BFS from changed symbols to find dependents
284        let all_changed: Vec<&str> = mapping
285            .changed_symbols
286            .iter()
287            .chain(mapping.containing_symbols.iter())
288            .map(|s| s.as_str())
289            .collect();
290
291        for &start_id in &all_changed {
292            // Get direct dependents (1-hop incoming edges: who CALLS/IMPORTS this symbol?)
293            let edges = graph.get_edges(start_id).unwrap_or_default();
294            for edge in &edges {
295                // Incoming edges: other symbols that depend on this one
296                let dependent_id = if edge.dst == start_id {
297                    &edge.src
298                } else {
299                    continue; // outgoing edge, skip
300                };
301                if !dependent_id.starts_with("sym:") || !seen.insert(dependent_id.clone()) {
302                    continue;
303                }
304                if matches!(
305                    edge.relationship,
306                    RelationshipType::Calls
307                        | RelationshipType::Imports
308                        | RelationshipType::Inherits
309                        | RelationshipType::Implements
310                        | RelationshipType::Overrides
311                ) {
312                    if let Some(info) = node_to_symbol_info(&**graph, dependent_id) {
313                        if let Some(ref fp) = info.file_path {
314                            affected_files.insert(fp.clone());
315                        }
316                        direct_deps.push(info);
317                    }
318                }
319            }
320        }
321
322        // Transitive dependents (2+ hops) via iterative incoming-edge traversal.
323        // BFS follows outgoing edges (wrong direction for "who depends on me?").
324        // Instead, walk incoming edges layer by layer.
325        if depth > 1 {
326            let mut frontier: Vec<String> = direct_deps.iter().map(|d| d.id.clone()).collect();
327            for _ in 1..depth {
328                let mut next_frontier = Vec::new();
329                for node_id in &frontier {
330                    let edges = graph.get_edges(node_id).unwrap_or_default();
331                    for edge in &edges {
332                        // Only follow incoming dependency edges
333                        if edge.dst != *node_id {
334                            continue;
335                        }
336                        if !matches!(
337                            edge.relationship,
338                            RelationshipType::Calls
339                                | RelationshipType::Imports
340                                | RelationshipType::Inherits
341                                | RelationshipType::Implements
342                                | RelationshipType::Overrides
343                        ) {
344                            continue;
345                        }
346                        let dep_id = &edge.src;
347                        if !dep_id.starts_with("sym:") || !seen.insert(dep_id.clone()) {
348                            continue;
349                        }
350                        if let Some(info) = node_to_symbol_info(&**graph, dep_id) {
351                            if let Some(ref fp) = info.file_path {
352                                affected_files.insert(fp.clone());
353                            }
354                            if info.kind == "Module" {
355                                affected_modules.insert(info.id.clone());
356                            }
357                            next_frontier.push(dep_id.clone());
358                            transitive_deps.push(info);
359                        }
360                    }
361                }
362                if next_frontier.is_empty() {
363                    break;
364                }
365                frontier = next_frontier;
366            }
367        }
368
369        // Detect affected modules from all symbols
370        for info in changed_infos.iter().chain(direct_deps.iter()) {
371            if info.kind == "Module" {
372                affected_modules.insert(info.id.clone());
373            }
374        }
375
376        // Risk score: Σ(pagerank) + log(transitive_count + 1)
377        // Additive so that diffs touching zero-pagerank symbols (common when
378        // centrality hasn't been computed or symbols have no edges) still get
379        // a nonzero risk score from their dependent count.
380        let transitive_count = direct_deps.len() + transitive_deps.len();
381        risk_score += (transitive_count as f64 + 1.0).ln();
382
383        // Drop graph lock before accessing storage
384        drop(graph);
385
386        // Find relevant memories connected to changed symbols
387        let mut relevant_memories = Vec::new();
388        for sym_id in mapping
389            .changed_symbols
390            .iter()
391            .chain(mapping.containing_symbols.iter())
392            .take(20)
393        {
394            if let Ok(results) = self.get_node_memories(sym_id, 1, None) {
395                for r in &results {
396                    relevant_memories.push(MemorySnippet::from(&r.memory));
397                }
398            }
399        }
400        // Dedup memories by ID
401        let mut seen_mem_ids: HashSet<String> = HashSet::new();
402        relevant_memories.retain(|m| seen_mem_ids.insert(m.id.clone()));
403
404        // Missing change detection: find symbols with similar caller patterns
405        let graph = self.lock_graph()?;
406        let missing_changes = detect_missing_changes(&**graph, &mapping.changed_symbols, &seen);
407
408        // Missing co-change detection: files that historically change together
409        // with the diff's changed files but are absent from the diff.
410        let missing_co_changes = detect_missing_co_changes(&**graph, &mapping.changed_files);
411
412        let affected_files: Vec<String> = affected_files.into_iter().collect();
413        let affected_modules: Vec<String> = affected_modules.into_iter().collect();
414
415        Ok(BlastRadiusReport {
416            changed_symbols: changed_infos,
417            direct_dependents: direct_deps,
418            transitive_dependents: transitive_deps,
419            affected_files,
420            affected_modules,
421            risk_score,
422            missing_changes,
423            missing_co_changes,
424            relevant_memories,
425        })
426    }
427}
428
429// ── Helpers ──────────────────────────────────────────────────────────────
430
431fn node_to_symbol_info(graph: &dyn GraphBackend, node_id: &str) -> Option<SymbolInfo> {
432    let node = graph.get_node(node_id).ok()??;
433    Some(SymbolInfo {
434        id: node.id.clone(),
435        label: node.label.clone(),
436        kind: node.kind.to_string(),
437        file_path: node
438            .payload
439            .get("file_path")
440            .and_then(|v| v.as_str())
441            .map(String::from),
442        line_start: node
443            .payload
444            .get("line_start")
445            .and_then(|v| v.as_u64())
446            .map(|v| v as u32),
447        pagerank: graph.get_pagerank(&node.id),
448    })
449}
450
451/// Detect potentially missing changes: symbols that share callers with changed symbols
452/// but aren't in the diff.
453fn detect_missing_changes(
454    graph: &dyn GraphBackend,
455    changed_symbols: &[String],
456    already_in_diff: &HashSet<String>,
457) -> Vec<MissingChange> {
458    let mut missing = Vec::new();
459
460    // For each changed symbol, find its callers. Then find what else those callers call.
461    // If a sibling is called by the same callers but not in the diff, flag it.
462    let mut caller_sets: HashMap<String, HashSet<String>> = HashMap::new();
463
464    for sym_id in changed_symbols {
465        let edges = graph.get_edges(sym_id).unwrap_or_default();
466        let callers: HashSet<String> = edges
467            .iter()
468            .filter(|e| e.dst == *sym_id && e.relationship == RelationshipType::Calls)
469            .map(|e| e.src.clone())
470            .collect();
471        if !callers.is_empty() {
472            caller_sets.insert(sym_id.clone(), callers);
473        }
474    }
475
476    // Find siblings: other symbols called by the same callers
477    let mut sibling_counts: HashMap<String, usize> = HashMap::new();
478    for callers in caller_sets.values() {
479        for caller_id in callers {
480            let edges = graph.get_edges(caller_id).unwrap_or_default();
481            for edge in &edges {
482                if edge.src == *caller_id
483                    && edge.relationship == RelationshipType::Calls
484                    && edge.dst.starts_with("sym:")
485                    && !already_in_diff.contains(&edge.dst)
486                {
487                    *sibling_counts.entry(edge.dst.clone()).or_insert(0) += 1;
488                }
489            }
490        }
491    }
492
493    // Flag siblings that share multiple callers with changed symbols
494    let threshold = (changed_symbols.len() / 2).max(2);
495    for (sibling, count) in &sibling_counts {
496        if *count >= threshold {
497            missing.push(MissingChange {
498                symbol: sibling.clone(),
499                reason: format!(
500                    "shares {} callers with {} changed symbols",
501                    count,
502                    changed_symbols.len()
503                ),
504            });
505        }
506    }
507
508    missing
509}
510
511/// Detect files historically coupled (CoChanged edges) with changed files but absent
512/// from the diff. For each missing file, collects which changed files it's coupled with
513/// and averages the coupling strength. Results are sorted by strength descending.
514fn detect_missing_co_changes(
515    graph: &dyn GraphBackend,
516    changed_files: &[String],
517) -> Vec<MissingCoChange> {
518    let changed_set: HashSet<&str> = changed_files.iter().map(|s| s.as_str()).collect();
519
520    // Map: missing_file_id -> Vec<(coupled_changed_file, weight)>
521    let mut candidates: HashMap<String, Vec<(String, f64)>> = HashMap::new();
522
523    for file_id in changed_files {
524        let edges = graph.get_edges(file_id).unwrap_or_default();
525        for edge in &edges {
526            if edge.relationship != RelationshipType::CoChanged {
527                continue;
528            }
529            // Find the other end of the edge
530            let other = if edge.src == *file_id {
531                &edge.dst
532            } else {
533                &edge.src
534            };
535            // Skip files already in the diff
536            if changed_set.contains(other.as_str()) {
537                continue;
538            }
539            // Strip "file:" prefix for display
540            let file_path = file_id.strip_prefix("file:").unwrap_or(file_id);
541            candidates
542                .entry(other.clone())
543                .or_default()
544                .push((file_path.to_string(), edge.weight));
545        }
546    }
547
548    // Minimum average coupling strength to surface a missing co-change.
549    // Below this threshold, the coupling is too weak to be actionable.
550    const MIN_CO_CHANGE_STRENGTH: f64 = 0.3;
551
552    let mut result: Vec<MissingCoChange> = candidates
553        .into_iter()
554        .filter_map(|(file_id, couplings)| {
555            let strength = couplings.iter().map(|(_, w)| w).sum::<f64>() / couplings.len() as f64;
556            if strength < MIN_CO_CHANGE_STRENGTH {
557                return None;
558            }
559            let coupled_with = couplings.into_iter().map(|(f, _)| f).collect();
560            let file_path = file_id
561                .strip_prefix("file:")
562                .unwrap_or(&file_id)
563                .to_string();
564            Some(MissingCoChange {
565                file_path,
566                coupled_with,
567                strength,
568            })
569        })
570        .collect();
571
572    result.sort_by(|a, b| {
573        b.strength
574            .partial_cmp(&a.strength)
575            .unwrap_or(std::cmp::Ordering::Equal)
576    });
577    result
578}
579
580#[cfg(test)]
581#[path = "tests/review_tests.rs"]
582mod tests;