1use crate::code_graph::{CodeGraph, CodeNode, NodeKind};
7use std::collections::HashSet;
8
9#[derive(Debug, Clone, Default)]
13pub struct GidContext {
14 pub nodes_touched: Vec<NodeInfo>,
16 pub max_callers: usize,
18 pub total_blast_radius: usize,
20 pub hub_nodes: Vec<NodeInfo>,
22}
23
/// Owned, flattened summary of one code-graph node plus its edge counts.
#[derive(Debug, Clone)]
pub struct NodeInfo {
    /// Graph node identifier.
    pub id: String,
    /// Symbol name.
    pub name: String,
    /// Path of the file containing the node.
    pub file: String,
    /// Lower-case kind label (e.g. "function", "class").
    pub kind: String,
    /// Number of callers as reported by the code graph.
    pub callers: usize,
    /// Number of callees as reported by the code graph.
    pub callees: usize,
    /// Source line of the node, when known.
    pub line: Option<usize>,
}
35
36impl NodeInfo {
37 pub fn from_code_node(node: &CodeNode, callers: usize, callees: usize) -> Self {
38 Self {
39 id: node.id.clone(),
40 name: node.name.clone(),
41 file: node.file_path.clone(),
42 kind: match node.kind {
43 NodeKind::File => "file",
44 NodeKind::Class => "class",
45 NodeKind::Function => "function",
46 NodeKind::Module => "module",
47 }.to_string(),
48 callers,
49 callees,
50 line: node.line,
51 }
52 }
53}
54
/// Coarse category of a test-run failure, inferred from raw runner output.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ErrorType {
    Syntax,
    Import,
    Attribute,
    Assertion,
    Type,
    Name,
    Runtime,
    Timeout,
    Unknown,
}

impl std::fmt::Display for ErrorType {
    /// Writes the Python-style exception name for the category; the two
    /// non-exception variants render as `Timeout` and `Unknown`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Syntax => "SyntaxError",
            Self::Import => "ImportError",
            Self::Attribute => "AttributeError",
            Self::Assertion => "AssertionError",
            Self::Type => "TypeError",
            Self::Name => "NameError",
            Self::Runtime => "RuntimeError",
            Self::Timeout => "Timeout",
            Self::Unknown => "Unknown",
        };
        f.write_str(label)
    }
}
84
85pub fn query_gid_context(files_changed: &[String], graph: &CodeGraph) -> GidContext {
90 let mut nodes = Vec::new();
91 let mut max_callers = 0;
92 let mut total_blast = 0;
93
94 for file in files_changed {
95 let file_nodes: Vec<&CodeNode> = graph.nodes.iter()
97 .filter(|n| {
98 n.file_path == *file
99 && !n.is_test
100 && matches!(n.kind, NodeKind::Function | NodeKind::Class)
101 })
102 .collect();
103
104 for node in file_nodes {
105 let callers = graph.get_callers(&node.id).len();
106 let callees = graph.get_callees(&node.id).len();
107
108 max_callers = max_callers.max(callers);
109 total_blast += callers;
110
111 nodes.push(NodeInfo::from_code_node(node, callers, callees));
112 }
113 }
114
115 nodes.sort_by(|a, b| b.callers.cmp(&a.callers));
117 nodes.truncate(10);
118
119 let hub_threshold = 10;
121 let hub_nodes: Vec<NodeInfo> = nodes.iter()
122 .filter(|n| n.callers >= hub_threshold)
123 .cloned()
124 .collect();
125
126 GidContext {
127 nodes_touched: nodes,
128 max_callers,
129 total_blast_radius: total_blast,
130 hub_nodes,
131 }
132}
133
134pub fn find_low_risk_alternatives(
137 graph: &CodeGraph,
138 failed_files: &[String],
139 max_callers: usize,
140) -> Vec<NodeInfo> {
141 let mut alternatives = Vec::new();
142
143 let packages: HashSet<String> = failed_files.iter()
145 .filter_map(|f| {
146 f.rsplitn(2, '/').nth(1).map(|s| s.to_string())
147 })
148 .collect();
149
150 for node in &graph.nodes {
151 if node.is_test {
152 continue;
153 }
154 if !matches!(node.kind, NodeKind::Function) {
155 continue;
156 }
157
158 let in_package = packages.iter().any(|pkg| node.file_path.starts_with(pkg));
160 if !in_package {
161 continue;
162 }
163
164 if failed_files.contains(&node.file_path) {
166 continue;
167 }
168
169 let callers = graph.get_callers(&node.id).len();
170 if callers <= max_callers {
171 let callees = graph.get_callees(&node.id).len();
172 alternatives.push(NodeInfo::from_code_node(node, callers, callees));
173 }
174 }
175
176 alternatives.sort_by_key(|n| n.callers);
178 alternatives.truncate(5);
179 alternatives
180}
181
182pub fn classify_error(raw_output: &str) -> ErrorType {
184 let checks: &[(ErrorType, &[&str])] = &[
185 (ErrorType::Syntax, &["SyntaxError:", "SyntaxError("]),
186 (ErrorType::Import, &["ImportError:", "ModuleNotFoundError:"]),
187 (ErrorType::Attribute, &["AttributeError:"]),
188 (ErrorType::Assertion, &["AssertionError:", "AssertionError(", "assert "]),
189 (ErrorType::Type, &["TypeError:"]),
190 (ErrorType::Name, &["NameError:"]),
191 (ErrorType::Timeout, &["TimeoutError", "timed out", "TIMEOUT"]),
192 ];
193
194 let mut best = ErrorType::Unknown;
195 let mut best_count = 0;
196
197 for (etype, patterns) in checks {
198 let count: usize = patterns.iter()
199 .map(|p| raw_output.matches(p).count())
200 .sum();
201 if count > best_count {
202 best_count = count;
203 best = etype.clone();
204 }
205 }
206
207 if best != ErrorType::Syntax && raw_output.contains("SyntaxError:") {
209 return ErrorType::Syntax;
210 }
211
212 best
213}
214
215pub fn extract_key_traceback(raw_output: &str, max_chars: usize) -> String {
217 let traceback_marker = "Traceback (most recent call last)";
218
219 if let Some(pos) = raw_output.find(traceback_marker) {
220 let chunk = &raw_output[pos..];
221 let end = chunk.find("\n\n")
222 .or_else(|| chunk.find("\n====="))
223 .or_else(|| chunk.find("\nFAILED"))
224 .unwrap_or(chunk.len());
225 return chunk[..end.min(max_chars)].to_string();
226 }
227
228 for marker in &["FAIL:", "ERROR:", "FAILED "] {
230 if let Some(pos) = raw_output.find(marker) {
231 let start = pos.saturating_sub(200);
232 let end = (pos + max_chars).min(raw_output.len());
233 return raw_output[start..end].to_string();
234 }
235 }
236
237 let start = raw_output.len().saturating_sub(max_chars);
239 raw_output[start..].to_string()
240}
241
242#[derive(Debug, Clone)]
246pub struct ImpactAnalysis {
247 pub affected_source: Vec<NodeInfo>,
249 pub affected_tests: Vec<NodeInfo>,
251 pub risk_level: RiskLevel,
253 pub summary: String,
255}
256
/// Coarse risk bucket for a change, declared from least to most risky.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RiskLevel {
    Low,
    Medium,
    High,
    Critical,
}

impl std::fmt::Display for RiskLevel {
    /// Writes the lower-case name of the bucket.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::Low => "low",
            Self::Medium => "medium",
            Self::High => "high",
            Self::Critical => "critical",
        })
    }
}
275
276pub fn analyze_impact(files_changed: &[String], graph: &CodeGraph) -> ImpactAnalysis {
278 let gid_ctx = query_gid_context(files_changed, graph);
279
280 let mut affected_source = Vec::new();
281 let mut affected_tests = Vec::new();
282 let mut seen = HashSet::new();
283
284 let changed_node_ids: Vec<String> = graph.nodes.iter()
286 .filter(|n| files_changed.contains(&n.file_path))
287 .map(|n| n.id.clone())
288 .collect();
289
290 for node_id in &changed_node_ids {
292 for impacted in graph.get_impact(node_id) {
293 if seen.insert(impacted.id.clone()) {
294 let callers = graph.get_callers(&impacted.id).len();
295 let callees = graph.get_callees(&impacted.id).len();
296 let info = NodeInfo::from_code_node(impacted, callers, callees);
297
298 if impacted.is_test {
299 affected_tests.push(info);
300 } else {
301 affected_source.push(info);
302 }
303 }
304 }
305 }
306
307 let risk_level = match gid_ctx.max_callers {
309 0..=5 => RiskLevel::Low,
310 6..=20 => RiskLevel::Medium,
311 21..=50 => RiskLevel::High,
312 _ => RiskLevel::Critical,
313 };
314
315 let summary = format!(
317 "Changing {} file(s) affects {} source nodes and {} test nodes. Risk: {} (max {} callers, blast radius {}).",
318 files_changed.len(),
319 affected_source.len(),
320 affected_tests.len(),
321 risk_level,
322 gid_ctx.max_callers,
323 gid_ctx.total_blast_radius,
324 );
325
326 ImpactAnalysis {
327 affected_source,
328 affected_tests,
329 risk_level,
330 summary,
331 }
332}
333
334pub fn format_impact_for_llm(analysis: &ImpactAnalysis) -> String {
336 let mut result = String::new();
337
338 result.push_str(&format!("## Impact Analysis\n\n{}\n\n", analysis.summary));
339
340 if !analysis.affected_source.is_empty() {
341 result.push_str("**Affected source code:**\n");
342 for node in analysis.affected_source.iter().take(10) {
343 result.push_str(&format!(
344 "- {} `{}` ({} callers)\n",
345 node.kind, node.name, node.callers
346 ));
347 }
348 if analysis.affected_source.len() > 10 {
349 result.push_str(&format!(" ...and {} more\n", analysis.affected_source.len() - 10));
350 }
351 result.push('\n');
352 }
353
354 if !analysis.affected_tests.is_empty() {
355 result.push_str("**Related tests:**\n");
356 for node in analysis.affected_tests.iter().take(10) {
357 result.push_str(&format!("- `{}` in {}\n", node.name, node.file));
358 }
359 if analysis.affected_tests.len() > 10 {
360 result.push_str(&format!(" ...and {} more\n", analysis.affected_tests.len() - 10));
361 }
362 result.push('\n');
363 }
364
365 if analysis.risk_level == RiskLevel::High || analysis.risk_level == RiskLevel::Critical {
366 result.push_str("⚠️ **High-risk change!** Consider:\n");
367 result.push_str("- Breaking the change into smaller pieces\n");
368 result.push_str("- Adding backward compatibility\n");
369 result.push_str("- Running full test suite before committing\n\n");
370 }
371
372 result
373}
374
/// One step taken by the agent, kept in working-memory history.
#[derive(Debug, Clone)]
pub enum Action {
    /// Edited `files`; `applied` and `total` are rendered as `applied/total`.
    Edit {
        files: Vec<String>,
        applied: usize,
        total: usize,
    },
    Revert,
    Read {
        file: String,
    },
    Search {
        pattern: String,
    },
    /// Graph query of kind `kind` against `target`.
    Query {
        kind: String,
        target: String,
    },
    Test,
    Other(String),
}

impl std::fmt::Display for Action {
    /// Compact one-line rendering. Paths are reduced to their last `/`
    /// component; free-form text is clipped to 30 bytes on a char boundary.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Final `/`-separated component of `path`.
        fn base_name(path: &str) -> &str {
            path.rsplit('/').next().unwrap_or(path)
        }

        /// `text` cut to at most 30 bytes without splitting a character.
        fn clip(text: &str) -> &str {
            const LIMIT: usize = 30;
            if text.len() <= LIMIT {
                return text;
            }
            let mut cut = LIMIT;
            while !text.is_char_boundary(cut) {
                cut -= 1;
            }
            &text[..cut]
        }

        match self {
            Action::Edit { files, applied, total } => {
                let names = files
                    .iter()
                    .map(|path| base_name(path))
                    .collect::<Vec<&str>>()
                    .join(", ");
                write!(f, "EDIT {} ({}/{})", names, applied, total)
            }
            Action::Revert => f.write_str("REVERT"),
            Action::Read { file } => write!(f, "READ {}", base_name(file)),
            Action::Search { pattern } => write!(f, "SEARCH '{}'", clip(pattern)),
            Action::Query { kind, target } => write!(f, "GID {} {}", kind, target),
            Action::Test => f.write_str("TEST"),
            Action::Other(text) => write!(f, "{}", clip(text)),
        }
    }
}
425
426#[derive(Debug, Clone)]
428pub struct TestOutcome {
429 pub error_type: ErrorType,
431 pub primary: (usize, usize),
433 pub secondary: (usize, usize),
435 pub key_error_trace: String,
437 pub failed_secondary_names: Vec<String>,
439}
440
441impl TestOutcome {
442 pub fn new(
443 error_type: ErrorType,
444 primary_passed: usize,
445 primary_total: usize,
446 secondary_passed: usize,
447 secondary_total: usize,
448 ) -> Self {
449 Self {
450 error_type,
451 primary: (primary_passed, primary_total),
452 secondary: (secondary_passed, secondary_total),
453 key_error_trace: String::new(),
454 failed_secondary_names: Vec::new(),
455 }
456 }
457
458 pub fn with_trace(mut self, trace: String) -> Self {
459 self.key_error_trace = trace;
460 self
461 }
462
463 pub fn with_failed_names(mut self, names: Vec<String>) -> Self {
464 self.failed_secondary_names = names;
465 self
466 }
467
468 pub fn score(&self) -> i32 {
471 let secondary_clean = if self.secondary.1 == 0 || self.secondary.0 == self.secondary.1 { 1 } else { 0 };
472 (self.primary.0 as i32) * 1000 * secondary_clean + self.secondary.0 as i32
473 }
474}
475
476#[derive(Debug, Clone)]
478pub struct AttemptRecord {
479 pub round: usize,
480 pub action: Action,
481 pub gid_context: Option<GidContext>,
482 pub test_outcome: Option<TestOutcome>,
483 pub feedback: String,
485}
486
/// Accumulated test-outcome statistics for one code node.
#[derive(Debug, Clone)]
pub struct NodeRisk {
    /// Caller count recorded when the node was first seen.
    pub callers: usize,
    /// Number of test runs attributed to edits touching this node.
    pub times_tried: usize,
    /// How many of those runs had failures.
    pub times_failed: usize,
}
494
495pub struct WorkingMemory {
498 pub attempts: Vec<AttemptRecord>,
499 pub node_risk_map: std::collections::HashMap<String, NodeRisk>,
500 pub best_score: i32,
501 pub best_attempt: Option<usize>,
502 pub low_risk_alternatives: Vec<NodeInfo>,
504}
505
506impl Default for WorkingMemory {
507 fn default() -> Self {
508 Self::new()
509 }
510}
511
512impl WorkingMemory {
513 pub fn new() -> Self {
514 Self {
515 attempts: Vec::new(),
516 node_risk_map: std::collections::HashMap::new(),
517 best_score: -1,
518 best_attempt: None,
519 low_risk_alternatives: Vec::new(),
520 }
521 }
522
523 pub fn record_edit(
525 &mut self,
526 round: usize,
527 files: Vec<String>,
528 applied: usize,
529 total: usize,
530 gid_ctx: GidContext,
531 feedback: String,
532 ) {
533 self.attempts.push(AttemptRecord {
534 round,
535 action: Action::Edit { files, applied, total },
536 gid_context: Some(gid_ctx),
537 test_outcome: None,
538 feedback,
539 });
540 }
541
542 pub fn record_test(&mut self, round: usize, outcome: TestOutcome, raw_feedback: String) {
544 let score = outcome.score();
545
546 if score > self.best_score {
547 self.best_score = score;
548 self.best_attempt = Some(round);
549 }
550
551 if let Some(last_edit) = self.attempts.iter().rev().find(|a| matches!(a.action, Action::Edit { .. })) {
553 if let Some(ref gid) = last_edit.gid_context {
554 for node in &gid.nodes_touched {
555 let entry = self.node_risk_map.entry(node.name.clone()).or_insert(NodeRisk {
556 callers: node.callers,
557 times_tried: 0,
558 times_failed: 0,
559 });
560 entry.times_tried += 1;
561 if outcome.secondary.0 < outcome.secondary.1 || outcome.primary.0 < outcome.primary.1 {
562 entry.times_failed += 1;
563 }
564 }
565 }
566 }
567
568 self.attempts.push(AttemptRecord {
569 round,
570 action: Action::Test,
571 gid_context: None,
572 test_outcome: Some(outcome),
573 feedback: raw_feedback,
574 });
575 }
576
577 pub fn record_action(&mut self, round: usize, action: Action, feedback: String) {
579 self.attempts.push(AttemptRecord {
580 round,
581 action,
582 gid_context: None,
583 test_outcome: None,
584 feedback,
585 });
586 }
587
588 pub fn project_to_prompt(&self) -> String {
591 let mut out = String::new();
592
593 let test_attempts: Vec<&AttemptRecord> = self.attempts.iter()
595 .filter(|a| a.test_outcome.is_some())
596 .collect();
597
598 if !test_attempts.is_empty() {
599 out.push_str("## Attempt History\n\n");
600 out.push_str("| # | Target | Callers | Error | Primary | Secondary |\n");
601 out.push_str("|---|--------|---------|-------|---------|------------|\n");
602
603 for test_a in &test_attempts {
604 let t = test_a.test_outcome.as_ref().unwrap();
605
606 let edit_info = self.attempts.iter()
608 .filter(|a| a.round < test_a.round && matches!(a.action, Action::Edit { .. }))
609 .last();
610
611 let (target, callers) = if let Some(edit) = edit_info {
612 let target_str = match &edit.action {
613 Action::Edit { files, .. } => {
614 files.iter()
615 .map(|f| f.rsplit('/').next().unwrap_or(f))
616 .collect::<Vec<_>>()
617 .join(", ")
618 }
619 _ => "-".into(),
620 };
621 let callers_str = edit.gid_context.as_ref()
622 .map(|g| g.max_callers.to_string())
623 .unwrap_or("-".into());
624 (target_str, callers_str)
625 } else {
626 ("-".into(), "-".into())
627 };
628
629 out.push_str(&format!(
630 "| {} | {} | {} | {} | {}/{} | {}/{} |\n",
631 test_a.round,
632 target,
633 callers,
634 t.error_type,
635 t.primary.0, t.primary.1,
636 t.secondary.0, t.secondary.1,
637 ));
638 }
639 out.push('\n');
640 }
641
642 let mut risky: Vec<(&String, &NodeRisk)> = self.node_risk_map.iter()
644 .filter(|(_, r)| r.times_failed > 0)
645 .collect();
646 risky.sort_by(|a, b| b.1.callers.cmp(&a.1.callers));
647
648 if !risky.is_empty() {
649 out.push_str("## Node History\n");
650 for (name, risk) in risky.iter().take(10) {
651 out.push_str(&format!(
652 "- {} — {} callers, tried {}, failed {}\n",
653 name, risk.callers, risk.times_tried, risk.times_failed
654 ));
655 }
656 out.push('\n');
657 }
658
659 if !self.low_risk_alternatives.is_empty() {
661 out.push_str("## Low-Coupling Alternatives\n");
662 for alt in &self.low_risk_alternatives {
663 out.push_str(&format!(
664 "- {} ({}) — {} callers\n",
665 alt.name, alt.file.rsplit('/').next().unwrap_or(&alt.file), alt.callers
666 ));
667 }
668 out.push('\n');
669 }
670
671 if let Some(last_test) = self.attempts.iter().rev().find(|a| a.test_outcome.is_some()) {
673 let t = last_test.test_outcome.as_ref().unwrap();
674 out.push_str(&format!("## Latest Error (Round {})\n", last_test.round));
675 out.push_str(&format!("Type: {}\n", t.error_type));
676 out.push_str(&format!("Primary: {}/{}, Secondary: {}/{}\n",
677 t.primary.0, t.primary.1, t.secondary.0, t.secondary.1));
678
679 if !t.key_error_trace.is_empty() {
680 out.push_str(&format!("\n```\n{}\n```\n", t.key_error_trace));
681 }
682
683 if !t.failed_secondary_names.is_empty() {
685 let show: Vec<&str> = t.failed_secondary_names.iter().take(10).map(|s| s.as_str()).collect();
686 let remaining = t.failed_secondary_names.len().saturating_sub(10);
687 out.push_str(&format!("\nFailed: {}", show.join(", ")));
688 if remaining > 0 {
689 out.push_str(&format!(" (+{} more)", remaining));
690 }
691 out.push('\n');
692 }
693 }
694
695 if let Some(best_round) = self.best_attempt {
697 out.push_str(&format!(
698 "\n## Best Result: Round {} (score {})\n",
699 best_round, self.best_score
700 ));
701 }
702
703 out
704 }
705
706 pub fn last_feedback(&self) -> &str {
708 self.attempts.last()
709 .map(|a| a.feedback.as_str())
710 .unwrap_or("")
711 }
712}
713
#[cfg(test)]
mod tests {
    use super::*;
    use crate::code_graph::{CodeEdge, EdgeRelation};

    #[test]
    fn test_classify_error() {
        let cases = [
            ("SyntaxError: invalid syntax", ErrorType::Syntax),
            ("ImportError: No module named 'foo'", ErrorType::Import),
            ("AssertionError: 1 != 2", ErrorType::Assertion),
        ];
        for (output, expected) in cases.iter() {
            assert_eq!(classify_error(output), *expected);
        }
    }

    #[test]
    fn test_classify_syntax_overrides() {
        // Two ImportErrors outnumber one SyntaxError, but syntax still wins.
        let output = "ImportError: ...\nSyntaxError: invalid syntax\nImportError: ...";
        assert_eq!(classify_error(output), ErrorType::Syntax);
    }

    #[test]
    fn test_risk_level() {
        // Builds a plain (non-test) function node.
        let function_node =
            |id: String, name: String, file: String, line: usize, line_count| CodeNode {
                id,
                kind: NodeKind::Function,
                name,
                file_path: file,
                line: Some(line),
                decorators: vec![],
                signature: None,
                docstring: None,
                line_count,
                is_test: false,
            };

        let hub_id = "func:core.py:hot_func";
        let mut graph = CodeGraph::default();
        graph.nodes.push(function_node(
            hub_id.to_string(),
            "hot_func".to_string(),
            "core.py".to_string(),
            10,
            20,
        ));

        // 30 distinct callers puts max_callers in the 21..=50 "high" band.
        for i in 0..30 {
            let caller_id = format!("func:caller{}.py:caller_{}", i, i);
            graph.nodes.push(function_node(
                caller_id.clone(),
                format!("caller_{}", i),
                format!("caller{}.py", i),
                1,
                5,
            ));
            graph
                .edges
                .push(CodeEdge::new(&caller_id, hub_id, EdgeRelation::Calls));
        }
        graph.build_indexes();

        let analysis = analyze_impact(&["core.py".into()], &graph);
        assert_eq!(analysis.risk_level, RiskLevel::High);
    }

    #[test]
    fn test_extract_traceback() {
        let output = concat!(
            "\n",
            "FAILED tests/test_foo.py::test_bar\n",
            "Traceback (most recent call last):\n",
            "  File \"tests/test_foo.py\", line 10, in test_bar\n",
            "    assert result == expected\n",
            "AssertionError: 1 != 2\n",
            "\n",
            "FAILED tests/test_other.py::test_baz\n",
        );
        let tb = extract_key_traceback(output, 500);
        assert!(tb.contains("Traceback (most recent call last)"));
        assert!(tb.contains("AssertionError: 1 != 2"));
    }
}