1use super::types::{
7 CallChain, ChangeClassification, ChangeType, ContextFile, ContextSymbol, DiffChange,
8 ExpandedContext, ImpactLevel, ImpactSummary,
9};
10use crate::index::types::{DepGraph, FileEntry, IndexSymbol, IndexSymbolKind, SymbolIndex};
11use std::collections::{HashSet, VecDeque};
12
13use super::types::ContextDepth;
14
15pub struct ContextExpander<'a> {
17 index: &'a SymbolIndex,
18 graph: &'a DepGraph,
19}
20
21impl<'a> ContextExpander<'a> {
22 pub fn new(index: &'a SymbolIndex, graph: &'a DepGraph) -> Self {
24 Self { index, graph }
25 }
26
27 pub fn classify_change(
32 &self,
33 change: &DiffChange,
34 symbol: Option<&IndexSymbol>,
35 ) -> ChangeClassification {
36 if change.change_type == ChangeType::Deleted {
38 return ChangeClassification::Deletion;
39 }
40 if change.change_type == ChangeType::Renamed {
41 return ChangeClassification::FileRename;
42 }
43 if change.change_type == ChangeType::Added {
44 return ChangeClassification::NewCode;
45 }
46
47 if let Some(diff) = &change.diff_content {
49 let signature_indicators = [
51 "fn ",
52 "def ",
53 "function ",
54 "func ",
55 "pub fn ",
56 "async fn ",
57 "class ",
58 "struct ",
59 "enum ",
60 "interface ",
61 "type ",
62 "trait ",
63 ];
64 let has_signature_change = diff.lines().any(|line| {
65 let trimmed = line.trim_start_matches(['+', '-', ' ']);
66 signature_indicators
67 .iter()
68 .any(|ind| trimmed.starts_with(ind))
69 });
70
71 if has_signature_change {
72 let type_indicators =
74 ["class ", "struct ", "enum ", "interface ", "type ", "trait "];
75 if diff.lines().any(|line| {
76 let trimmed = line.trim_start_matches(['+', '-', ' ']);
77 type_indicators.iter().any(|ind| trimmed.starts_with(ind))
78 }) {
79 return ChangeClassification::TypeDefinitionChange;
80 }
81 return ChangeClassification::SignatureChange;
82 }
83
84 let import_indicators = ["import ", "from ", "require(", "use ", "#include"];
86 if diff.lines().any(|line| {
87 let trimmed = line.trim_start_matches(['+', '-', ' ']);
88 import_indicators.iter().any(|ind| trimmed.starts_with(ind))
89 }) {
90 return ChangeClassification::ImportChange;
91 }
92
93 let doc_indicators = ["///", "//!", "/**", "/*", "#", "\"\"\"", "'''"];
95 let all_doc_changes = diff
96 .lines()
97 .filter(|l| l.starts_with('+') || l.starts_with('-'))
98 .filter(|l| l.len() > 1) .all(|line| {
100 let trimmed = line[1..].trim();
101 trimmed.is_empty()
102 || doc_indicators.iter().any(|ind| trimmed.starts_with(ind))
103 });
104 if all_doc_changes {
105 return ChangeClassification::DocumentationOnly;
106 }
107 }
108
109 if let Some(sym) = symbol {
111 match sym.kind {
112 IndexSymbolKind::Class
113 | IndexSymbolKind::Struct
114 | IndexSymbolKind::Enum
115 | IndexSymbolKind::Interface
116 | IndexSymbolKind::Trait
117 | IndexSymbolKind::TypeAlias => {
118 return ChangeClassification::TypeDefinitionChange;
119 },
120 IndexSymbolKind::Function | IndexSymbolKind::Method => {
121 return ChangeClassification::ImplementationChange;
123 },
124 _ => {},
125 }
126 }
127
128 ChangeClassification::ImplementationChange
130 }
131
132 pub(crate) fn classification_score_multiplier(
134 &self,
135 classification: ChangeClassification,
136 ) -> f32 {
137 match classification {
138 ChangeClassification::Deletion => 1.5, ChangeClassification::SignatureChange => 1.3, ChangeClassification::TypeDefinitionChange => 1.2, ChangeClassification::FileRename => 1.1, ChangeClassification::ImportChange => 0.9, ChangeClassification::NewCode => 0.8, ChangeClassification::ImplementationChange => 0.7, ChangeClassification::DocumentationOnly => 0.3, }
147 }
148
149 fn get_caller_count(&self, symbol_id: u32) -> usize {
151 self.graph.get_callers(symbol_id).len() + self.graph.get_referencers(symbol_id).len()
152 }
153
154 pub fn expand(
156 &self,
157 changes: &[DiffChange],
158 depth: ContextDepth,
159 token_budget: u32,
160 ) -> ExpandedContext {
161 let mut changed_symbols = Vec::new();
162 let mut changed_files = Vec::new();
163 let mut dependent_symbols = Vec::new();
164 let mut dependent_files = Vec::new();
165 let mut related_tests = Vec::new();
166 let mut call_chains = Vec::new();
167
168 let mut seen_files: HashSet<u32> = HashSet::new();
169 let mut seen_symbols: HashSet<u32> = HashSet::new();
170 let mut change_classifications: Vec<ChangeClassification> = Vec::new();
171 let mut high_impact_symbols: HashSet<u32> = HashSet::new(); let mut path_overrides: std::collections::HashMap<u32, String> =
175 std::collections::HashMap::new();
176
177 for change in changes {
178 let (file, output_path) = if let Some(file) = self.index.get_file(&change.file_path) {
179 (file, change.file_path.clone())
180 } else if let Some(old_path) = &change.old_path {
181 if let Some(file) = self.index.get_file(old_path) {
182 path_overrides.insert(file.id.as_u32(), change.file_path.clone());
183 (file, change.file_path.clone())
184 } else {
185 continue;
186 }
187 } else {
188 continue;
189 };
190
191 if !seen_files.contains(&file.id.as_u32()) {
192 seen_files.insert(file.id.as_u32());
193 }
194
195 for (start, end) in &change.line_ranges {
197 for line in *start..=*end {
198 if let Some(symbol) = self.index.find_symbol_at_line(file.id, line) {
199 if !seen_symbols.contains(&symbol.id.as_u32()) {
200 seen_symbols.insert(symbol.id.as_u32());
201
202 let classification = self.classify_change(change, Some(symbol));
204 change_classifications.push(classification);
205
206 let caller_count = self.get_caller_count(symbol.id.as_u32());
211 let caller_bonus = (caller_count as f32 * 0.05).min(0.3); let base_score = 1.0 + caller_bonus;
213
214 if matches!(
216 classification,
217 ChangeClassification::SignatureChange
218 | ChangeClassification::TypeDefinitionChange
219 | ChangeClassification::Deletion
220 ) || caller_count > 5
221 {
222 high_impact_symbols.insert(symbol.id.as_u32());
223 }
224
225 let reason = match classification {
226 ChangeClassification::SignatureChange => {
227 format!("signature changed ({} callers)", caller_count)
228 },
229 ChangeClassification::TypeDefinitionChange => {
230 format!("type definition changed ({} usages)", caller_count)
231 },
232 ChangeClassification::Deletion => {
233 format!("deleted ({} callers will break)", caller_count)
234 },
235 _ => "directly modified".to_owned(),
236 };
237
238 changed_symbols.push(self.to_context_symbol(
239 symbol,
240 file,
241 &reason,
242 base_score,
243 path_overrides.get(&file.id.as_u32()).map(String::as_str),
244 ));
245 }
246 }
247 }
248 }
249
250 let file_classification = self.classify_change(change, None);
252 let file_multiplier = self.classification_score_multiplier(file_classification);
253
254 changed_files.push(ContextFile {
255 id: file.id.as_u32(),
256 path: output_path,
257 language: file.language.name().to_owned(),
258 relevance_reason: format!("{:?} ({:?})", change.change_type, file_classification),
259 relevance_score: file_multiplier,
260 tokens: file.tokens,
261 relevant_sections: change.line_ranges.clone(),
262 diff_content: change.diff_content.clone(),
263 snippets: Vec::new(),
264 });
265 }
266
267 let has_high_impact_change = change_classifications.iter().any(|c| {
269 matches!(
270 c,
271 ChangeClassification::SignatureChange
272 | ChangeClassification::TypeDefinitionChange
273 | ChangeClassification::Deletion
274 )
275 });
276
277 if depth >= ContextDepth::L2 {
279 let l2_files = self.expand_l2(&seen_files);
280 for file_id in &l2_files {
281 if !seen_files.contains(file_id) {
282 if let Some(file) = self.index.get_file_by_id(*file_id) {
283 seen_files.insert(*file_id);
284 let score = if has_high_impact_change { 0.9 } else { 0.8 };
286 let reason = if has_high_impact_change {
287 "imports changed file (breaking change detected)".to_owned()
288 } else {
289 "imports changed file".to_owned()
290 };
291 dependent_files.push(ContextFile {
292 id: file.id.as_u32(),
293 path: file.path.clone(),
294 language: file.language.name().to_owned(),
295 relevance_reason: reason,
296 relevance_score: score,
297 tokens: file.tokens,
298 relevant_sections: vec![],
299 diff_content: None,
300 snippets: Vec::new(),
301 });
302 }
303 }
304 }
305
306 let l2_symbols = self.expand_symbol_refs(&seen_symbols);
308 for symbol_id in &l2_symbols {
309 if !seen_symbols.contains(symbol_id) {
310 if let Some(symbol) = self.index.get_symbol(*symbol_id) {
311 if let Some(file) = self.index.get_file_by_id(symbol.file_id.as_u32()) {
312 seen_symbols.insert(*symbol_id);
313 let is_caller_of_high_impact = high_impact_symbols
315 .iter()
316 .any(|&hi_sym| self.graph.get_callers(hi_sym).contains(symbol_id));
317 let (reason, score) = if is_caller_of_high_impact {
318 ("calls changed symbol (may break)", 0.85)
319 } else {
320 ("references changed symbol", 0.7)
321 };
322 dependent_symbols.push(self.to_context_symbol(
323 symbol,
324 file,
325 reason,
326 score,
327 path_overrides.get(&file.id.as_u32()).map(String::as_str),
328 ));
329 }
330 }
331 }
332 }
333
334 if has_high_impact_change {
336 for &hi_sym_id in &high_impact_symbols {
337 let all_callers = self.graph.get_callers(hi_sym_id);
338 for caller_id in all_callers {
339 if !seen_symbols.contains(&caller_id) {
340 if let Some(caller) = self.index.get_symbol(caller_id) {
341 if let Some(file) =
342 self.index.get_file_by_id(caller.file_id.as_u32())
343 {
344 seen_symbols.insert(caller_id);
345 dependent_symbols.push(self.to_context_symbol(
346 caller,
347 file,
348 "calls modified symbol (potential breakage)",
349 0.9, path_overrides.get(&file.id.as_u32()).map(String::as_str),
351 ));
352 }
353 }
354 }
355 }
356 }
357 }
358 }
359
360 if depth >= ContextDepth::L3 {
361 let l3_files = self.expand_l3(&seen_files);
362 for file_id in &l3_files {
363 if !seen_files.contains(file_id) {
364 if let Some(file) = self.index.get_file_by_id(*file_id) {
365 seen_files.insert(*file_id);
366 dependent_files.push(ContextFile {
367 id: file.id.as_u32(),
368 path: file.path.clone(),
369 language: file.language.name().to_owned(),
370 relevance_reason: "transitively depends on changed file".to_owned(),
371 relevance_score: 0.5,
372 tokens: file.tokens,
373 relevant_sections: vec![],
374 diff_content: None,
375 snippets: Vec::new(),
376 });
377 }
378 }
379 }
380 }
381
382 let mut seen_test_ids: HashSet<u32> = HashSet::new();
384
385 for file in &self.index.files {
387 if self.is_test_file(&file.path) {
388 let imports = self.graph.get_imports(file.id.as_u32());
389 for &imported in &imports {
390 if seen_files.contains(&imported) && !seen_test_ids.contains(&file.id.as_u32())
391 {
392 seen_test_ids.insert(file.id.as_u32());
393 related_tests.push(ContextFile {
394 id: file.id.as_u32(),
395 path: file.path.clone(),
396 language: file.language.name().to_owned(),
397 relevance_reason: "imports changed file".to_owned(),
398 relevance_score: 0.95,
399 tokens: file.tokens,
400 relevant_sections: vec![],
401 diff_content: None,
402 snippets: Vec::new(),
403 });
404 break;
405 }
406 }
407 }
408 }
409
410 for cf in &changed_files {
412 for test_id in self.find_tests_by_naming(&cf.path) {
413 if !seen_test_ids.contains(&test_id) {
414 if let Some(file) = self.index.get_file_by_id(test_id) {
415 seen_test_ids.insert(test_id);
416 related_tests.push(ContextFile {
417 id: file.id.as_u32(),
418 path: file.path.clone(),
419 language: file.language.name().to_owned(),
420 relevance_reason: "test for changed file (naming convention)"
421 .to_owned(),
422 relevance_score: 0.85,
423 tokens: file.tokens,
424 relevant_sections: vec![],
425 diff_content: None,
426 snippets: Vec::new(),
427 });
428 }
429 }
430 }
431 }
432
433 for sym in &changed_symbols {
435 let chains = self.build_call_chains(sym.id, 3);
436 call_chains.extend(chains);
437 }
438
439 let impact_summary = self.compute_impact_summary(
441 &changed_files,
442 &dependent_files,
443 &changed_symbols,
444 &dependent_symbols,
445 &related_tests,
446 );
447
448 dependent_files.sort_by(|a, b| {
451 b.relevance_score
452 .partial_cmp(&a.relevance_score)
453 .unwrap_or(std::cmp::Ordering::Equal)
454 });
455 dependent_symbols.sort_by(|a, b| {
456 b.relevance_score
457 .partial_cmp(&a.relevance_score)
458 .unwrap_or(std::cmp::Ordering::Equal)
459 });
460 related_tests.sort_by(|a, b| {
461 b.relevance_score
462 .partial_cmp(&a.relevance_score)
463 .unwrap_or(std::cmp::Ordering::Equal)
464 });
465
466 let mut running_tokens = changed_files.iter().map(|f| f.tokens).sum::<u32>();
468
469 dependent_files.retain(|f| {
471 if running_tokens + f.tokens <= token_budget {
472 running_tokens += f.tokens;
473 true
474 } else {
475 false
476 }
477 });
478
479 related_tests.retain(|f| {
481 if running_tokens + f.tokens <= token_budget {
482 running_tokens += f.tokens;
483 true
484 } else {
485 false
486 }
487 });
488
489 ExpandedContext {
490 changed_symbols,
491 changed_files,
492 dependent_symbols,
493 dependent_files,
494 related_tests,
495 call_chains,
496 impact_summary,
497 total_tokens: running_tokens,
498 }
499 }
500
501 fn expand_l2(&self, file_ids: &HashSet<u32>) -> Vec<u32> {
503 let mut result = Vec::new();
504 for &file_id in file_ids {
505 result.extend(self.graph.get_importers(file_id));
506 }
507 result
508 }
509
510 fn expand_l3(&self, file_ids: &HashSet<u32>) -> Vec<u32> {
512 let mut result = Vec::new();
513 let mut visited: HashSet<u32> = file_ids.iter().copied().collect();
514 let mut queue: VecDeque<u32> = VecDeque::new();
515
516 for &file_id in file_ids {
517 for importer in self.graph.get_importers(file_id) {
518 if visited.insert(importer) {
519 result.push(importer);
520 queue.push_back(importer);
521 }
522 }
523 }
524
525 while let Some(current) = queue.pop_front() {
526 for importer in self.graph.get_importers(current) {
527 if visited.insert(importer) {
528 result.push(importer);
529 queue.push_back(importer);
530 }
531 }
532 }
533
534 result
535 }
536
537 fn expand_symbol_refs(&self, symbol_ids: &HashSet<u32>) -> Vec<u32> {
539 let mut result = Vec::new();
540 for &symbol_id in symbol_ids {
541 result.extend(self.graph.get_referencers(symbol_id));
542 result.extend(self.graph.get_callers(symbol_id));
543 }
544 result
545 }
546
547 pub(crate) fn is_test_file(&self, path: &str) -> bool {
549 let path_lower = path.to_lowercase();
550 path_lower.contains("test")
551 || path_lower.contains("spec")
552 || path_lower.contains("__tests__")
553 || path_lower.ends_with("_test.rs")
554 || path_lower.ends_with("_test.go")
555 || path_lower.ends_with("_test.py")
556 || path_lower.ends_with(".test.ts")
557 || path_lower.ends_with(".test.js")
558 || path_lower.ends_with(".spec.ts")
559 || path_lower.ends_with(".spec.js")
560 }
561
562 fn find_tests_by_naming(&self, source_path: &str) -> Vec<u32> {
568 let path_lower = source_path.to_lowercase();
569 let base_name = std::path::Path::new(&path_lower)
570 .file_stem()
571 .and_then(|s| s.to_str())
572 .unwrap_or("");
573
574 let mut test_ids = Vec::new();
575
576 if base_name.is_empty() {
577 return test_ids;
578 }
579
580 let test_patterns = [
582 format!("{}_test.", base_name),
583 format!("test_{}", base_name),
584 format!("{}.test.", base_name),
585 format!("{}.spec.", base_name),
586 format!("test/{}", base_name),
587 format!("tests/{}", base_name),
588 format!("__tests__/{}", base_name),
589 ];
590
591 for file in &self.index.files {
592 let file_lower = file.path.to_lowercase();
593 if self.is_test_file(&file.path) {
594 for pattern in &test_patterns {
595 if file_lower.contains(pattern) {
596 test_ids.push(file.id.as_u32());
597 break;
598 }
599 }
600 }
601 }
602
603 test_ids
604 }
605
606 fn to_context_symbol(
608 &self,
609 symbol: &IndexSymbol,
610 file: &FileEntry,
611 reason: &str,
612 score: f32,
613 path_override: Option<&str>,
614 ) -> ContextSymbol {
615 ContextSymbol {
616 id: symbol.id.as_u32(),
617 name: symbol.name.clone(),
618 kind: symbol.kind.name().to_owned(),
619 file_path: path_override.unwrap_or(&file.path).to_owned(),
620 start_line: symbol.span.start_line,
621 end_line: symbol.span.end_line,
622 signature: symbol.signature.clone(),
623 relevance_reason: reason.to_owned(),
624 relevance_score: score,
625 }
626 }
627
628 fn build_call_chains(&self, symbol_id: u32, max_depth: usize) -> Vec<CallChain> {
630 let mut chains = Vec::new();
631
632 let mut upstream = Vec::new();
634 self.collect_callers(symbol_id, &mut upstream, max_depth, &mut HashSet::new());
635 if !upstream.is_empty() {
636 upstream.reverse();
637 if let Some(sym) = self.index.get_symbol(symbol_id) {
638 upstream.push(sym.name.clone());
639 }
640 chains.push(CallChain {
641 symbols: upstream.clone(),
642 files: self.get_files_for_symbols(&upstream),
643 });
644 }
645
646 let mut downstream = Vec::new();
648 if let Some(sym) = self.index.get_symbol(symbol_id) {
649 downstream.push(sym.name.clone());
650 }
651 self.collect_callees(symbol_id, &mut downstream, max_depth, &mut HashSet::new());
652 if downstream.len() > 1 {
653 chains.push(CallChain {
654 symbols: downstream.clone(),
655 files: self.get_files_for_symbols(&downstream),
656 });
657 }
658
659 chains
660 }
661
662 fn collect_callers(
663 &self,
664 symbol_id: u32,
665 chain: &mut Vec<String>,
666 depth: usize,
667 visited: &mut HashSet<u32>,
668 ) {
669 if depth == 0 || visited.contains(&symbol_id) {
670 return;
671 }
672 visited.insert(symbol_id);
673
674 let callers = self.graph.get_callers(symbol_id);
675 if let Some(&caller_id) = callers.first() {
676 if let Some(sym) = self.index.get_symbol(caller_id) {
677 chain.push(sym.name.clone());
678 self.collect_callers(caller_id, chain, depth - 1, visited);
679 }
680 }
681 }
682
683 fn collect_callees(
684 &self,
685 symbol_id: u32,
686 chain: &mut Vec<String>,
687 depth: usize,
688 visited: &mut HashSet<u32>,
689 ) {
690 if depth == 0 || visited.contains(&symbol_id) {
691 return;
692 }
693 visited.insert(symbol_id);
694
695 let callees = self.graph.get_callees(symbol_id);
696 if let Some(&callee_id) = callees.first() {
697 if let Some(sym) = self.index.get_symbol(callee_id) {
698 chain.push(sym.name.clone());
699 self.collect_callees(callee_id, chain, depth - 1, visited);
700 }
701 }
702 }
703
704 fn get_files_for_symbols(&self, symbol_names: &[String]) -> Vec<String> {
705 let mut files = Vec::new();
706 let mut seen = HashSet::new();
707 for name in symbol_names {
708 for sym in self.index.find_symbols(name) {
709 if let Some(file) = self.index.get_file_by_id(sym.file_id.as_u32()) {
710 if seen.insert(file.id) {
711 files.push(file.path.clone());
712 }
713 }
714 }
715 }
716 files
717 }
718
719 fn compute_impact_summary(
721 &self,
722 changed_files: &[ContextFile],
723 dependent_files: &[ContextFile],
724 changed_symbols: &[ContextSymbol],
725 dependent_symbols: &[ContextSymbol],
726 related_tests: &[ContextFile],
727 ) -> ImpactSummary {
728 let direct_files = changed_files.len();
729 let transitive_files = dependent_files.len();
730 let affected_symbols = changed_symbols.len() + dependent_symbols.len();
731 let affected_tests = related_tests.len();
732
733 let level = if transitive_files > 20 || affected_symbols > 50 {
735 ImpactLevel::Critical
736 } else if transitive_files > 10 || affected_symbols > 20 {
737 ImpactLevel::High
738 } else if transitive_files > 3 || affected_symbols > 5 {
739 ImpactLevel::Medium
740 } else {
741 ImpactLevel::Low
742 };
743
744 let breaking_changes = changed_symbols
748 .iter()
749 .filter(|s| s.kind == "function" || s.kind == "method")
750 .filter(|s| s.signature.is_some())
751 .filter(|s| {
753 s.signature
754 .as_ref()
755 .is_some_and(|sig| sig.starts_with("pub ") || sig.starts_with("export "))
756 })
757 .map(|s| format!("{} public API signature may have changed", s.name))
758 .collect();
759
760 let description = format!(
761 "Changed {} files affecting {} dependent files and {} symbols. {} tests may need updating.",
762 direct_files, transitive_files, affected_symbols, affected_tests
763 );
764
765 ImpactSummary {
766 level,
767 direct_files,
768 transitive_files,
769 affected_symbols,
770 affected_tests,
771 breaking_changes,
772 description,
773 }
774 }
775}