1use super::types::{
7 CallChain, ChangeClassification, ChangeType, ContextFile, ContextSymbol, DiffChange,
8 ExpandedContext, ImpactLevel, ImpactSummary,
9};
10use crate::index::types::{DepGraph, FileEntry, IndexSymbol, IndexSymbolKind, SymbolIndex};
11use std::collections::{HashSet, VecDeque};
12
13use super::types::ContextDepth;
14
15pub struct ContextExpander<'a> {
17 index: &'a SymbolIndex,
18 graph: &'a DepGraph,
19}
20
21impl<'a> ContextExpander<'a> {
22 pub fn new(index: &'a SymbolIndex, graph: &'a DepGraph) -> Self {
24 Self { index, graph }
25 }
26
27 pub fn classify_change(
32 &self,
33 change: &DiffChange,
34 symbol: Option<&IndexSymbol>,
35 ) -> ChangeClassification {
36 if change.change_type == ChangeType::Deleted {
38 return ChangeClassification::Deletion;
39 }
40 if change.change_type == ChangeType::Renamed {
41 return ChangeClassification::FileRename;
42 }
43 if change.change_type == ChangeType::Added {
44 return ChangeClassification::NewCode;
45 }
46
47 if let Some(diff) = &change.diff_content {
49 let signature_indicators = [
51 "fn ",
52 "def ",
53 "function ",
54 "func ",
55 "pub fn ",
56 "async fn ",
57 "class ",
58 "struct ",
59 "enum ",
60 "interface ",
61 "type ",
62 "trait ",
63 ];
64 let has_signature_change = diff.lines().any(|line| {
65 let trimmed = line.trim_start_matches(['+', '-', ' ']);
66 signature_indicators
67 .iter()
68 .any(|ind| trimmed.starts_with(ind))
69 });
70
71 if has_signature_change {
72 let type_indicators =
74 ["class ", "struct ", "enum ", "interface ", "type ", "trait "];
75 if diff.lines().any(|line| {
76 let trimmed = line.trim_start_matches(['+', '-', ' ']);
77 type_indicators.iter().any(|ind| trimmed.starts_with(ind))
78 }) {
79 return ChangeClassification::TypeDefinitionChange;
80 }
81 return ChangeClassification::SignatureChange;
82 }
83
84 let import_indicators = ["import ", "from ", "require(", "use ", "#include"];
86 if diff.lines().any(|line| {
87 let trimmed = line.trim_start_matches(['+', '-', ' ']);
88 import_indicators.iter().any(|ind| trimmed.starts_with(ind))
89 }) {
90 return ChangeClassification::ImportChange;
91 }
92
93 let doc_indicators = ["///", "//!", "/**", "/*", "#", "\"\"\"", "'''"];
95 let all_doc_changes = diff
96 .lines()
97 .filter(|l| l.starts_with('+') || l.starts_with('-'))
98 .filter(|l| l.len() > 1) .all(|line| {
100 let trimmed = line[1..].trim();
101 trimmed.is_empty()
102 || doc_indicators.iter().any(|ind| trimmed.starts_with(ind))
103 });
104 if all_doc_changes {
105 return ChangeClassification::DocumentationOnly;
106 }
107 }
108
109 if let Some(sym) = symbol {
111 match sym.kind {
112 IndexSymbolKind::Class
113 | IndexSymbolKind::Struct
114 | IndexSymbolKind::Enum
115 | IndexSymbolKind::Interface
116 | IndexSymbolKind::Trait
117 | IndexSymbolKind::TypeAlias => {
118 return ChangeClassification::TypeDefinitionChange;
119 }
120 IndexSymbolKind::Function | IndexSymbolKind::Method => {
121 return ChangeClassification::ImplementationChange;
123 }
124 _ => {}
125 }
126 }
127
128 ChangeClassification::ImplementationChange
130 }
131
132 pub(crate) fn classification_score_multiplier(&self, classification: ChangeClassification) -> f32 {
134 match classification {
135 ChangeClassification::Deletion => 1.5, ChangeClassification::SignatureChange => 1.3, ChangeClassification::TypeDefinitionChange => 1.2, ChangeClassification::FileRename => 1.1, ChangeClassification::ImportChange => 0.9, ChangeClassification::NewCode => 0.8, ChangeClassification::ImplementationChange => 0.7, ChangeClassification::DocumentationOnly => 0.3, }
144 }
145
146 fn get_caller_count(&self, symbol_id: u32) -> usize {
148 self.graph.get_callers(symbol_id).len() + self.graph.get_referencers(symbol_id).len()
149 }
150
151 pub fn expand(
153 &self,
154 changes: &[DiffChange],
155 depth: ContextDepth,
156 token_budget: u32,
157 ) -> ExpandedContext {
158 let mut changed_symbols = Vec::new();
159 let mut changed_files = Vec::new();
160 let mut dependent_symbols = Vec::new();
161 let mut dependent_files = Vec::new();
162 let mut related_tests = Vec::new();
163 let mut call_chains = Vec::new();
164
165 let mut seen_files: HashSet<u32> = HashSet::new();
166 let mut seen_symbols: HashSet<u32> = HashSet::new();
167 let mut change_classifications: Vec<ChangeClassification> = Vec::new();
168 let mut high_impact_symbols: HashSet<u32> = HashSet::new(); let mut path_overrides: std::collections::HashMap<u32, String> =
172 std::collections::HashMap::new();
173
174 for change in changes {
175 let (file, output_path) = if let Some(file) = self.index.get_file(&change.file_path) {
176 (file, change.file_path.clone())
177 } else if let Some(old_path) = &change.old_path {
178 if let Some(file) = self.index.get_file(old_path) {
179 path_overrides.insert(file.id.as_u32(), change.file_path.clone());
180 (file, change.file_path.clone())
181 } else {
182 continue;
183 }
184 } else {
185 continue;
186 };
187
188 if !seen_files.contains(&file.id.as_u32()) {
189 seen_files.insert(file.id.as_u32());
190 }
191
192 for (start, end) in &change.line_ranges {
194 for line in *start..=*end {
195 if let Some(symbol) = self.index.find_symbol_at_line(file.id, line) {
196 if !seen_symbols.contains(&symbol.id.as_u32()) {
197 seen_symbols.insert(symbol.id.as_u32());
198
199 let classification = self.classify_change(change, Some(symbol));
201 change_classifications.push(classification);
202
203 let caller_count = self.get_caller_count(symbol.id.as_u32());
208 let caller_bonus = (caller_count as f32 * 0.05).min(0.3); let base_score = 1.0 + caller_bonus;
210
211 if matches!(
213 classification,
214 ChangeClassification::SignatureChange
215 | ChangeClassification::TypeDefinitionChange
216 | ChangeClassification::Deletion
217 ) || caller_count > 5
218 {
219 high_impact_symbols.insert(symbol.id.as_u32());
220 }
221
222 let reason = match classification {
223 ChangeClassification::SignatureChange => {
224 format!("signature changed ({} callers)", caller_count)
225 }
226 ChangeClassification::TypeDefinitionChange => {
227 format!("type definition changed ({} usages)", caller_count)
228 }
229 ChangeClassification::Deletion => {
230 format!("deleted ({} callers will break)", caller_count)
231 }
232 _ => "directly modified".to_owned(),
233 };
234
235 changed_symbols.push(self.to_context_symbol(
236 symbol,
237 file,
238 &reason,
239 base_score,
240 path_overrides.get(&file.id.as_u32()).map(String::as_str),
241 ));
242 }
243 }
244 }
245 }
246
247 let file_classification = self.classify_change(change, None);
249 let file_multiplier = self.classification_score_multiplier(file_classification);
250
251 changed_files.push(ContextFile {
252 id: file.id.as_u32(),
253 path: output_path,
254 language: file.language.name().to_owned(),
255 relevance_reason: format!("{:?} ({:?})", change.change_type, file_classification),
256 relevance_score: file_multiplier,
257 tokens: file.tokens,
258 relevant_sections: change.line_ranges.clone(),
259 diff_content: change.diff_content.clone(),
260 snippets: Vec::new(),
261 });
262 }
263
264 let has_high_impact_change = change_classifications.iter().any(|c| {
266 matches!(
267 c,
268 ChangeClassification::SignatureChange
269 | ChangeClassification::TypeDefinitionChange
270 | ChangeClassification::Deletion
271 )
272 });
273
274 if depth >= ContextDepth::L2 {
276 let l2_files = self.expand_l2(&seen_files);
277 for file_id in &l2_files {
278 if !seen_files.contains(file_id) {
279 if let Some(file) = self.index.get_file_by_id(*file_id) {
280 seen_files.insert(*file_id);
281 let score = if has_high_impact_change { 0.9 } else { 0.8 };
283 let reason = if has_high_impact_change {
284 "imports changed file (breaking change detected)".to_owned()
285 } else {
286 "imports changed file".to_owned()
287 };
288 dependent_files.push(ContextFile {
289 id: file.id.as_u32(),
290 path: file.path.clone(),
291 language: file.language.name().to_owned(),
292 relevance_reason: reason,
293 relevance_score: score,
294 tokens: file.tokens,
295 relevant_sections: vec![],
296 diff_content: None,
297 snippets: Vec::new(),
298 });
299 }
300 }
301 }
302
303 let l2_symbols = self.expand_symbol_refs(&seen_symbols);
305 for symbol_id in &l2_symbols {
306 if !seen_symbols.contains(symbol_id) {
307 if let Some(symbol) = self.index.get_symbol(*symbol_id) {
308 if let Some(file) = self.index.get_file_by_id(symbol.file_id.as_u32()) {
309 seen_symbols.insert(*symbol_id);
310 let is_caller_of_high_impact = high_impact_symbols
312 .iter()
313 .any(|&hi_sym| self.graph.get_callers(hi_sym).contains(symbol_id));
314 let (reason, score) = if is_caller_of_high_impact {
315 ("calls changed symbol (may break)", 0.85)
316 } else {
317 ("references changed symbol", 0.7)
318 };
319 dependent_symbols.push(self.to_context_symbol(
320 symbol,
321 file,
322 reason,
323 score,
324 path_overrides.get(&file.id.as_u32()).map(String::as_str),
325 ));
326 }
327 }
328 }
329 }
330
331 if has_high_impact_change {
333 for &hi_sym_id in &high_impact_symbols {
334 let all_callers = self.graph.get_callers(hi_sym_id);
335 for caller_id in all_callers {
336 if !seen_symbols.contains(&caller_id) {
337 if let Some(caller) = self.index.get_symbol(caller_id) {
338 if let Some(file) =
339 self.index.get_file_by_id(caller.file_id.as_u32())
340 {
341 seen_symbols.insert(caller_id);
342 dependent_symbols.push(self.to_context_symbol(
343 caller,
344 file,
345 "calls modified symbol (potential breakage)",
346 0.9, path_overrides.get(&file.id.as_u32()).map(String::as_str),
348 ));
349 }
350 }
351 }
352 }
353 }
354 }
355 }
356
357 if depth >= ContextDepth::L3 {
358 let l3_files = self.expand_l3(&seen_files);
359 for file_id in &l3_files {
360 if !seen_files.contains(file_id) {
361 if let Some(file) = self.index.get_file_by_id(*file_id) {
362 seen_files.insert(*file_id);
363 dependent_files.push(ContextFile {
364 id: file.id.as_u32(),
365 path: file.path.clone(),
366 language: file.language.name().to_owned(),
367 relevance_reason: "transitively depends on changed file".to_owned(),
368 relevance_score: 0.5,
369 tokens: file.tokens,
370 relevant_sections: vec![],
371 diff_content: None,
372 snippets: Vec::new(),
373 });
374 }
375 }
376 }
377 }
378
379 let mut seen_test_ids: HashSet<u32> = HashSet::new();
381
382 for file in &self.index.files {
384 if self.is_test_file(&file.path) {
385 let imports = self.graph.get_imports(file.id.as_u32());
386 for &imported in &imports {
387 if seen_files.contains(&imported) && !seen_test_ids.contains(&file.id.as_u32())
388 {
389 seen_test_ids.insert(file.id.as_u32());
390 related_tests.push(ContextFile {
391 id: file.id.as_u32(),
392 path: file.path.clone(),
393 language: file.language.name().to_owned(),
394 relevance_reason: "imports changed file".to_owned(),
395 relevance_score: 0.95,
396 tokens: file.tokens,
397 relevant_sections: vec![],
398 diff_content: None,
399 snippets: Vec::new(),
400 });
401 break;
402 }
403 }
404 }
405 }
406
407 for cf in &changed_files {
409 for test_id in self.find_tests_by_naming(&cf.path) {
410 if !seen_test_ids.contains(&test_id) {
411 if let Some(file) = self.index.get_file_by_id(test_id) {
412 seen_test_ids.insert(test_id);
413 related_tests.push(ContextFile {
414 id: file.id.as_u32(),
415 path: file.path.clone(),
416 language: file.language.name().to_owned(),
417 relevance_reason: "test for changed file (naming convention)".to_owned(),
418 relevance_score: 0.85,
419 tokens: file.tokens,
420 relevant_sections: vec![],
421 diff_content: None,
422 snippets: Vec::new(),
423 });
424 }
425 }
426 }
427 }
428
429 for sym in &changed_symbols {
431 let chains = self.build_call_chains(sym.id, 3);
432 call_chains.extend(chains);
433 }
434
435 let impact_summary = self.compute_impact_summary(
437 &changed_files,
438 &dependent_files,
439 &changed_symbols,
440 &dependent_symbols,
441 &related_tests,
442 );
443
444 dependent_files.sort_by(|a, b| {
447 b.relevance_score
448 .partial_cmp(&a.relevance_score)
449 .unwrap_or(std::cmp::Ordering::Equal)
450 });
451 dependent_symbols.sort_by(|a, b| {
452 b.relevance_score
453 .partial_cmp(&a.relevance_score)
454 .unwrap_or(std::cmp::Ordering::Equal)
455 });
456 related_tests.sort_by(|a, b| {
457 b.relevance_score
458 .partial_cmp(&a.relevance_score)
459 .unwrap_or(std::cmp::Ordering::Equal)
460 });
461
462 let mut running_tokens = changed_files.iter().map(|f| f.tokens).sum::<u32>();
464
465 dependent_files.retain(|f| {
467 if running_tokens + f.tokens <= token_budget {
468 running_tokens += f.tokens;
469 true
470 } else {
471 false
472 }
473 });
474
475 related_tests.retain(|f| {
477 if running_tokens + f.tokens <= token_budget {
478 running_tokens += f.tokens;
479 true
480 } else {
481 false
482 }
483 });
484
485 ExpandedContext {
486 changed_symbols,
487 changed_files,
488 dependent_symbols,
489 dependent_files,
490 related_tests,
491 call_chains,
492 impact_summary,
493 total_tokens: running_tokens,
494 }
495 }
496
497 fn expand_l2(&self, file_ids: &HashSet<u32>) -> Vec<u32> {
499 let mut result = Vec::new();
500 for &file_id in file_ids {
501 result.extend(self.graph.get_importers(file_id));
502 }
503 result
504 }
505
506 fn expand_l3(&self, file_ids: &HashSet<u32>) -> Vec<u32> {
508 let mut result = Vec::new();
509 let mut visited: HashSet<u32> = file_ids.iter().copied().collect();
510 let mut queue: VecDeque<u32> = VecDeque::new();
511
512 for &file_id in file_ids {
513 for importer in self.graph.get_importers(file_id) {
514 if visited.insert(importer) {
515 result.push(importer);
516 queue.push_back(importer);
517 }
518 }
519 }
520
521 while let Some(current) = queue.pop_front() {
522 for importer in self.graph.get_importers(current) {
523 if visited.insert(importer) {
524 result.push(importer);
525 queue.push_back(importer);
526 }
527 }
528 }
529
530 result
531 }
532
533 fn expand_symbol_refs(&self, symbol_ids: &HashSet<u32>) -> Vec<u32> {
535 let mut result = Vec::new();
536 for &symbol_id in symbol_ids {
537 result.extend(self.graph.get_referencers(symbol_id));
538 result.extend(self.graph.get_callers(symbol_id));
539 }
540 result
541 }
542
543 pub(crate) fn is_test_file(&self, path: &str) -> bool {
545 let path_lower = path.to_lowercase();
546 path_lower.contains("test")
547 || path_lower.contains("spec")
548 || path_lower.contains("__tests__")
549 || path_lower.ends_with("_test.rs")
550 || path_lower.ends_with("_test.go")
551 || path_lower.ends_with("_test.py")
552 || path_lower.ends_with(".test.ts")
553 || path_lower.ends_with(".test.js")
554 || path_lower.ends_with(".spec.ts")
555 || path_lower.ends_with(".spec.js")
556 }
557
558 fn find_tests_by_naming(&self, source_path: &str) -> Vec<u32> {
564 let path_lower = source_path.to_lowercase();
565 let base_name = std::path::Path::new(&path_lower)
566 .file_stem()
567 .and_then(|s| s.to_str())
568 .unwrap_or("");
569
570 let mut test_ids = Vec::new();
571
572 if base_name.is_empty() {
573 return test_ids;
574 }
575
576 let test_patterns = [
578 format!("{}_test.", base_name),
579 format!("test_{}", base_name),
580 format!("{}.test.", base_name),
581 format!("{}.spec.", base_name),
582 format!("test/{}", base_name),
583 format!("tests/{}", base_name),
584 format!("__tests__/{}", base_name),
585 ];
586
587 for file in &self.index.files {
588 let file_lower = file.path.to_lowercase();
589 if self.is_test_file(&file.path) {
590 for pattern in &test_patterns {
591 if file_lower.contains(pattern) {
592 test_ids.push(file.id.as_u32());
593 break;
594 }
595 }
596 }
597 }
598
599 test_ids
600 }
601
602 fn to_context_symbol(
604 &self,
605 symbol: &IndexSymbol,
606 file: &FileEntry,
607 reason: &str,
608 score: f32,
609 path_override: Option<&str>,
610 ) -> ContextSymbol {
611 ContextSymbol {
612 id: symbol.id.as_u32(),
613 name: symbol.name.clone(),
614 kind: symbol.kind.name().to_owned(),
615 file_path: path_override.unwrap_or(&file.path).to_owned(),
616 start_line: symbol.span.start_line,
617 end_line: symbol.span.end_line,
618 signature: symbol.signature.clone(),
619 relevance_reason: reason.to_owned(),
620 relevance_score: score,
621 }
622 }
623
624 fn build_call_chains(&self, symbol_id: u32, max_depth: usize) -> Vec<CallChain> {
626 let mut chains = Vec::new();
627
628 let mut upstream = Vec::new();
630 self.collect_callers(symbol_id, &mut upstream, max_depth, &mut HashSet::new());
631 if !upstream.is_empty() {
632 upstream.reverse();
633 if let Some(sym) = self.index.get_symbol(symbol_id) {
634 upstream.push(sym.name.clone());
635 }
636 chains.push(CallChain {
637 symbols: upstream.clone(),
638 files: self.get_files_for_symbols(&upstream),
639 });
640 }
641
642 let mut downstream = Vec::new();
644 if let Some(sym) = self.index.get_symbol(symbol_id) {
645 downstream.push(sym.name.clone());
646 }
647 self.collect_callees(symbol_id, &mut downstream, max_depth, &mut HashSet::new());
648 if downstream.len() > 1 {
649 chains.push(CallChain {
650 symbols: downstream.clone(),
651 files: self.get_files_for_symbols(&downstream),
652 });
653 }
654
655 chains
656 }
657
658 fn collect_callers(
659 &self,
660 symbol_id: u32,
661 chain: &mut Vec<String>,
662 depth: usize,
663 visited: &mut HashSet<u32>,
664 ) {
665 if depth == 0 || visited.contains(&symbol_id) {
666 return;
667 }
668 visited.insert(symbol_id);
669
670 let callers = self.graph.get_callers(symbol_id);
671 if let Some(&caller_id) = callers.first() {
672 if let Some(sym) = self.index.get_symbol(caller_id) {
673 chain.push(sym.name.clone());
674 self.collect_callers(caller_id, chain, depth - 1, visited);
675 }
676 }
677 }
678
679 fn collect_callees(
680 &self,
681 symbol_id: u32,
682 chain: &mut Vec<String>,
683 depth: usize,
684 visited: &mut HashSet<u32>,
685 ) {
686 if depth == 0 || visited.contains(&symbol_id) {
687 return;
688 }
689 visited.insert(symbol_id);
690
691 let callees = self.graph.get_callees(symbol_id);
692 if let Some(&callee_id) = callees.first() {
693 if let Some(sym) = self.index.get_symbol(callee_id) {
694 chain.push(sym.name.clone());
695 self.collect_callees(callee_id, chain, depth - 1, visited);
696 }
697 }
698 }
699
700 fn get_files_for_symbols(&self, symbol_names: &[String]) -> Vec<String> {
701 let mut files = Vec::new();
702 let mut seen = HashSet::new();
703 for name in symbol_names {
704 for sym in self.index.find_symbols(name) {
705 if let Some(file) = self.index.get_file_by_id(sym.file_id.as_u32()) {
706 if seen.insert(file.id) {
707 files.push(file.path.clone());
708 }
709 }
710 }
711 }
712 files
713 }
714
715 fn compute_impact_summary(
717 &self,
718 changed_files: &[ContextFile],
719 dependent_files: &[ContextFile],
720 changed_symbols: &[ContextSymbol],
721 dependent_symbols: &[ContextSymbol],
722 related_tests: &[ContextFile],
723 ) -> ImpactSummary {
724 let direct_files = changed_files.len();
725 let transitive_files = dependent_files.len();
726 let affected_symbols = changed_symbols.len() + dependent_symbols.len();
727 let affected_tests = related_tests.len();
728
729 let level = if transitive_files > 20 || affected_symbols > 50 {
731 ImpactLevel::Critical
732 } else if transitive_files > 10 || affected_symbols > 20 {
733 ImpactLevel::High
734 } else if transitive_files > 3 || affected_symbols > 5 {
735 ImpactLevel::Medium
736 } else {
737 ImpactLevel::Low
738 };
739
740 let breaking_changes = changed_symbols
744 .iter()
745 .filter(|s| s.kind == "function" || s.kind == "method")
746 .filter(|s| s.signature.is_some())
747 .filter(|s| {
749 s.signature
750 .as_ref()
751 .is_some_and(|sig| sig.starts_with("pub ") || sig.starts_with("export "))
752 })
753 .map(|s| format!("{} public API signature may have changed", s.name))
754 .collect();
755
756 let description = format!(
757 "Changed {} files affecting {} dependent files and {} symbols. {} tests may need updating.",
758 direct_files, transitive_files, affected_symbols, affected_tests
759 );
760
761 ImpactSummary {
762 level,
763 direct_files,
764 transitive_files,
765 affected_symbols,
766 affected_tests,
767 breaking_changes,
768 description,
769 }
770 }
771}