1use crate::CodememEngine;
5use codemem_core::{CodememError, GraphBackend, MemoryNode, RelationshipType};
6use std::collections::{HashMap, HashSet};
7
/// Changed line numbers collected for one file section of a unified diff.
#[derive(Clone, Debug)]
pub struct DiffHunk {
    /// Path of the file, taken from the text following `+++ b/` in the diff.
    pub file_path: String,
    /// Line numbers (new-file numbering) of the `+` lines.
    pub added_lines: Vec<u32>,
    /// Line numbers (old-file numbering) of the `-` lines.
    pub removed_lines: Vec<u32>,
}
17
/// Graph node ids that a diff touches, as produced by `diff_to_symbols`.
#[derive(Clone, Debug, Default)]
pub struct DiffSymbolMapping {
    /// Ids of `sym:` nodes whose line range overlaps at least one changed line.
    pub changed_symbols: Vec<String>,
    /// Ids of `sym:` nodes that have a `Contains` edge to a changed symbol.
    pub containing_symbols: Vec<String>,
    /// `file:`-prefixed ids of the files that appear in the diff.
    pub changed_files: Vec<String>,
}
28
29#[derive(Debug, Clone, serde::Serialize)]
31pub struct SymbolInfo {
32 pub id: String,
33 pub label: String,
34 pub kind: String,
35 pub file_path: Option<String>,
36 pub line_start: Option<u32>,
37 pub pagerank: f64,
38}
39
40#[derive(Debug, Clone, serde::Serialize)]
42pub struct MissingChange {
43 pub symbol: String,
44 pub reason: String,
45}
46
47#[derive(Debug, Clone, serde::Serialize)]
49pub struct BlastRadiusReport {
50 pub changed_symbols: Vec<SymbolInfo>,
51 pub direct_dependents: Vec<SymbolInfo>,
52 pub transitive_dependents: Vec<SymbolInfo>,
53 pub affected_files: Vec<String>,
54 pub affected_modules: Vec<String>,
55 pub risk_score: f64,
56 pub missing_changes: Vec<MissingChange>,
57 pub relevant_memories: Vec<MemorySnippet>,
58}
59
60#[derive(Debug, Clone, serde::Serialize)]
62pub struct MemorySnippet {
63 pub id: String,
64 pub content: String,
65 pub memory_type: String,
66 pub importance: f64,
67}
68
69impl From<&MemoryNode> for MemorySnippet {
70 fn from(m: &MemoryNode) -> Self {
71 Self {
72 id: m.id.clone(),
73 content: m.content.clone(),
74 memory_type: m.memory_type.to_string(),
75 importance: m.importance,
76 }
77 }
78}
79
80pub fn parse_diff(diff: &str) -> Vec<DiffHunk> {
84 let mut hunks = Vec::new();
85 let mut current_file: Option<String> = None;
86 let mut added_lines: Vec<u32> = Vec::new();
87 let mut removed_lines: Vec<u32> = Vec::new();
88 let mut new_line: u32 = 0;
89 let mut old_line: u32 = 0;
90
91 for line in diff.lines() {
92 if line.starts_with("+++ b/") {
93 if let Some(ref file) = current_file {
95 if !added_lines.is_empty() || !removed_lines.is_empty() {
96 hunks.push(DiffHunk {
97 file_path: file.clone(),
98 added_lines: std::mem::take(&mut added_lines),
99 removed_lines: std::mem::take(&mut removed_lines),
100 });
101 }
102 }
103 current_file = line.strip_prefix("+++ b/").map(|s| s.to_string());
104 } else if line.starts_with("@@ ") {
105 if let Some((new_start, old_start)) = parse_hunk_header(line) {
107 new_line = new_start;
108 old_line = old_start;
109 }
110 } else if current_file.is_some() {
111 if line.starts_with('+') && !line.starts_with("+++") {
112 added_lines.push(new_line);
113 new_line += 1;
114 } else if line.starts_with('-') && !line.starts_with("---") {
115 removed_lines.push(old_line);
116 old_line += 1;
117 } else {
118 new_line += 1;
120 old_line += 1;
121 }
122 }
123 }
124
125 if let Some(file) = current_file {
127 if !added_lines.is_empty() || !removed_lines.is_empty() {
128 hunks.push(DiffHunk {
129 file_path: file,
130 added_lines,
131 removed_lines,
132 });
133 }
134 }
135
136 hunks
137}
138
/// Extracts the starting line numbers from a unified-diff hunk header such as
/// `@@ -10,5 +12,7 @@`.
///
/// The second whitespace-separated field must begin with `-` and the third
/// with `+`; the number before the optional `,count` in each is parsed.
///
/// Returns `Some((new_start, old_start))` — note the new-file start comes
/// first — or `None` when the header is too short or malformed.
fn parse_hunk_header(line: &str) -> Option<(u32, u32)> {
    // Skip the leading "@@" token; the next two fields are the old/new ranges.
    let mut fields = line.split_whitespace().skip(1);
    let old_range = fields.next()?.strip_prefix('-')?;
    let new_range = fields.next()?.strip_prefix('+')?;

    // "start,count" -> start; a bare "start" has no comma and parses as-is.
    let start_of = |range: &str| -> Option<u32> { range.split(',').next()?.parse().ok() };

    let old_start = start_of(old_range)?;
    let new_start = start_of(new_range)?;
    Some((new_start, old_start))
}
153
154impl CodememEngine {
157 pub fn diff_to_symbols(&self, diff: &str) -> Result<DiffSymbolMapping, CodememError> {
159 let hunks = parse_diff(diff);
160 let graph = self.lock_graph()?;
161 let all_nodes = graph.get_all_nodes();
162
163 let mut mapping = DiffSymbolMapping::default();
164 let mut seen_symbols: HashSet<String> = HashSet::new();
165 let mut seen_files: HashSet<String> = HashSet::new();
166
167 let mut file_symbols: HashMap<&str, Vec<&codemem_core::GraphNode>> = HashMap::new();
169 for node in &all_nodes {
170 if !node.id.starts_with("sym:") {
171 continue;
172 }
173 if let Some(fp) = node.payload.get("file_path").and_then(|v| v.as_str()) {
174 file_symbols.entry(fp).or_default().push(node);
175 }
176 }
177
178 for hunk in &hunks {
179 let file_id = format!("file:{}", hunk.file_path);
180 if seen_files.insert(file_id.clone()) {
181 mapping.changed_files.push(file_id);
182 }
183
184 let changed_lines: HashSet<u32> = hunk
185 .added_lines
186 .iter()
187 .chain(hunk.removed_lines.iter())
188 .copied()
189 .collect();
190
191 if let Some(nodes) = file_symbols.get(hunk.file_path.as_str()) {
193 for node in nodes {
194 let line_start = node
195 .payload
196 .get("line_start")
197 .and_then(|v| v.as_u64())
198 .unwrap_or(0) as u32;
199 let line_end = node
200 .payload
201 .get("line_end")
202 .and_then(|v| v.as_u64())
203 .unwrap_or(line_start as u64) as u32;
204
205 let overlaps = changed_lines
206 .iter()
207 .any(|&l| l >= line_start && l <= line_end);
208 if overlaps && seen_symbols.insert(node.id.clone()) {
209 mapping.changed_symbols.push(node.id.clone());
210 }
211 }
212 }
213 }
214
215 let changed_set: HashSet<&str> =
217 mapping.changed_symbols.iter().map(|s| s.as_str()).collect();
218 for node in &all_nodes {
219 if !node.id.starts_with("sym:") || changed_set.contains(node.id.as_str()) {
220 continue;
221 }
222 let edges = graph.get_edges(&node.id).unwrap_or_default();
224 let contains_changed = edges.iter().any(|e| {
225 e.relationship == RelationshipType::Contains && changed_set.contains(e.dst.as_str())
226 });
227 if contains_changed && seen_symbols.insert(node.id.clone()) {
228 mapping.containing_symbols.push(node.id.clone());
229 }
230 }
231
232 Ok(mapping)
233 }
234
235 pub fn blast_radius(
238 &self,
239 diff: &str,
240 depth: usize,
241 ) -> Result<BlastRadiusReport, CodememError> {
242 let mapping = self.diff_to_symbols(diff)?;
243 let graph = self.lock_graph()?;
244
245 let mut changed_infos = Vec::new();
246 let mut direct_deps = Vec::new();
247 let mut transitive_deps = Vec::new();
248 let mut affected_files: HashSet<String> = HashSet::new();
249 let mut affected_modules: HashSet<String> = HashSet::new();
250 let mut seen: HashSet<String> = HashSet::new();
251 let mut risk_score: f64 = 0.0;
252
253 for sym_id in &mapping.changed_symbols {
255 if let Some(info) = node_to_symbol_info(&**graph, sym_id) {
256 risk_score += info.pagerank;
257 if let Some(ref fp) = info.file_path {
258 affected_files.insert(fp.clone());
259 }
260 seen.insert(sym_id.clone());
261 changed_infos.push(info);
262 }
263 }
264 for sym_id in &mapping.containing_symbols {
265 if let Some(info) = node_to_symbol_info(&**graph, sym_id) {
266 if let Some(ref fp) = info.file_path {
267 affected_files.insert(fp.clone());
268 }
269 seen.insert(sym_id.clone());
270 changed_infos.push(info);
271 }
272 }
273
274 let all_changed: Vec<&str> = mapping
276 .changed_symbols
277 .iter()
278 .chain(mapping.containing_symbols.iter())
279 .map(|s| s.as_str())
280 .collect();
281
282 for &start_id in &all_changed {
283 let edges = graph.get_edges(start_id).unwrap_or_default();
285 for edge in &edges {
286 let dependent_id = if edge.dst == start_id {
288 &edge.src
289 } else {
290 continue; };
292 if !dependent_id.starts_with("sym:") || !seen.insert(dependent_id.clone()) {
293 continue;
294 }
295 if matches!(
296 edge.relationship,
297 RelationshipType::Calls
298 | RelationshipType::Imports
299 | RelationshipType::Inherits
300 | RelationshipType::Implements
301 | RelationshipType::Overrides
302 ) {
303 if let Some(info) = node_to_symbol_info(&**graph, dependent_id) {
304 if let Some(ref fp) = info.file_path {
305 affected_files.insert(fp.clone());
306 }
307 direct_deps.push(info);
308 }
309 }
310 }
311 }
312
313 if depth > 1 {
317 let mut frontier: Vec<String> = direct_deps.iter().map(|d| d.id.clone()).collect();
318 for _ in 1..depth {
319 let mut next_frontier = Vec::new();
320 for node_id in &frontier {
321 let edges = graph.get_edges(node_id).unwrap_or_default();
322 for edge in &edges {
323 if edge.dst != *node_id {
325 continue;
326 }
327 if !matches!(
328 edge.relationship,
329 RelationshipType::Calls
330 | RelationshipType::Imports
331 | RelationshipType::Inherits
332 | RelationshipType::Implements
333 | RelationshipType::Overrides
334 ) {
335 continue;
336 }
337 let dep_id = &edge.src;
338 if !dep_id.starts_with("sym:") || !seen.insert(dep_id.clone()) {
339 continue;
340 }
341 if let Some(info) = node_to_symbol_info(&**graph, dep_id) {
342 if let Some(ref fp) = info.file_path {
343 affected_files.insert(fp.clone());
344 }
345 if info.kind == "Module" {
346 affected_modules.insert(info.id.clone());
347 }
348 next_frontier.push(dep_id.clone());
349 transitive_deps.push(info);
350 }
351 }
352 }
353 if next_frontier.is_empty() {
354 break;
355 }
356 frontier = next_frontier;
357 }
358 }
359
360 for info in changed_infos.iter().chain(direct_deps.iter()) {
362 if info.kind == "Module" {
363 affected_modules.insert(info.id.clone());
364 }
365 }
366
367 let transitive_count = direct_deps.len() + transitive_deps.len();
372 risk_score += (transitive_count as f64 + 1.0).ln();
373
374 drop(graph);
376
377 let mut relevant_memories = Vec::new();
379 for sym_id in mapping
380 .changed_symbols
381 .iter()
382 .chain(mapping.containing_symbols.iter())
383 .take(20)
384 {
385 if let Ok(results) = self.get_node_memories(sym_id, 1, None) {
386 for r in &results {
387 relevant_memories.push(MemorySnippet::from(&r.memory));
388 }
389 }
390 }
391 let mut seen_mem_ids: HashSet<String> = HashSet::new();
393 relevant_memories.retain(|m| seen_mem_ids.insert(m.id.clone()));
394
395 let graph = self.lock_graph()?;
397 let missing_changes = detect_missing_changes(&**graph, &mapping.changed_symbols, &seen);
398
399 let affected_files: Vec<String> = affected_files.into_iter().collect();
400 let affected_modules: Vec<String> = affected_modules.into_iter().collect();
401
402 Ok(BlastRadiusReport {
403 changed_symbols: changed_infos,
404 direct_dependents: direct_deps,
405 transitive_dependents: transitive_deps,
406 affected_files,
407 affected_modules,
408 risk_score,
409 missing_changes,
410 relevant_memories,
411 })
412 }
413}
414
415fn node_to_symbol_info(graph: &dyn GraphBackend, node_id: &str) -> Option<SymbolInfo> {
418 let node = graph.get_node(node_id).ok()??;
419 Some(SymbolInfo {
420 id: node.id.clone(),
421 label: node.label.clone(),
422 kind: node.kind.to_string(),
423 file_path: node
424 .payload
425 .get("file_path")
426 .and_then(|v| v.as_str())
427 .map(String::from),
428 line_start: node
429 .payload
430 .get("line_start")
431 .and_then(|v| v.as_u64())
432 .map(|v| v as u32),
433 pagerank: graph.get_pagerank(&node.id),
434 })
435}
436
437fn detect_missing_changes(
440 graph: &dyn GraphBackend,
441 changed_symbols: &[String],
442 already_in_diff: &HashSet<String>,
443) -> Vec<MissingChange> {
444 let mut missing = Vec::new();
445
446 let mut caller_sets: HashMap<String, HashSet<String>> = HashMap::new();
449
450 for sym_id in changed_symbols {
451 let edges = graph.get_edges(sym_id).unwrap_or_default();
452 let callers: HashSet<String> = edges
453 .iter()
454 .filter(|e| e.dst == *sym_id && e.relationship == RelationshipType::Calls)
455 .map(|e| e.src.clone())
456 .collect();
457 if !callers.is_empty() {
458 caller_sets.insert(sym_id.clone(), callers);
459 }
460 }
461
462 let mut sibling_counts: HashMap<String, usize> = HashMap::new();
464 for callers in caller_sets.values() {
465 for caller_id in callers {
466 let edges = graph.get_edges(caller_id).unwrap_or_default();
467 for edge in &edges {
468 if edge.src == *caller_id
469 && edge.relationship == RelationshipType::Calls
470 && edge.dst.starts_with("sym:")
471 && !already_in_diff.contains(&edge.dst)
472 {
473 *sibling_counts.entry(edge.dst.clone()).or_insert(0) += 1;
474 }
475 }
476 }
477 }
478
479 let threshold = (changed_symbols.len() / 2).max(2);
481 for (sibling, count) in &sibling_counts {
482 if *count >= threshold {
483 missing.push(MissingChange {
484 symbol: sibling.clone(),
485 reason: format!(
486 "shares {} callers with {} changed symbols",
487 count,
488 changed_symbols.len()
489 ),
490 });
491 }
492 }
493
494 missing
495}
496
497#[cfg(test)]
498#[path = "tests/review_tests.rs"]
499mod tests;