batuta/bug_hunter/localization/
multi_channel.rs1use std::collections::HashMap;
7use std::path::{Path, PathBuf};
8
9use crate::bug_hunter::types::{ChannelWeights, LocalizationStrategy, SbflFormula};
10
11use super::scoring::{MutationData, ScoredLocation, SpectrumData, TestCoverage};
12
13pub struct MultiChannelLocalizer {
15 pub strategy: LocalizationStrategy,
16 pub weights: ChannelWeights,
17 pub sbfl_formula: SbflFormula,
18 pub spectrum_data: SpectrumData,
19 pub mutation_data: MutationData,
20 pub static_findings: HashMap<(PathBuf, usize), f64>,
22 pub error_message: Option<String>,
24}
25
26impl MultiChannelLocalizer {
27 pub fn new(strategy: LocalizationStrategy, weights: ChannelWeights) -> Self {
28 Self {
29 strategy,
30 weights,
31 sbfl_formula: SbflFormula::Ochiai,
32 spectrum_data: SpectrumData::default(),
33 mutation_data: MutationData::default(),
34 static_findings: HashMap::new(),
35 error_message: None,
36 }
37 }
38
39 pub fn add_coverage(&mut self, coverage: &[TestCoverage]) {
41 for test in coverage {
42 if test.passed {
43 self.spectrum_data.total_passed += 1;
44 for (loc, count) in &test.executed_lines {
45 *self.spectrum_data.passed_coverage.entry(loc.clone()).or_insert(0) += count;
46 }
47 } else {
48 self.spectrum_data.total_failed += 1;
49 for (loc, count) in &test.executed_lines {
50 *self.spectrum_data.failed_coverage.entry(loc.clone()).or_insert(0) += count;
51 }
52 }
53 }
54 }
55
56 pub fn add_static_finding(&mut self, file: &Path, line: usize, score: f64) {
58 self.static_findings.insert((file.to_path_buf(), line), score);
59 }
60
61 pub fn set_error_message(&mut self, msg: &str) {
63 self.error_message = Some(msg.to_string());
64 }
65
66 pub(crate) fn compute_semantic_score(&self, _file: &Path, line: usize, content: &str) -> f64 {
68 let Some(ref error_msg) = self.error_message else {
69 return 0.0;
70 };
71
72 let error_lower = error_msg.to_lowercase();
74 let error_words: Vec<&str> =
75 error_lower.split_whitespace().filter(|w| w.len() > 3).collect();
76
77 let line_content = content.lines().nth(line.saturating_sub(1)).unwrap_or("");
78 let line_lower = line_content.to_lowercase();
79
80 let matches = error_words.iter().filter(|w| line_lower.contains(*w)).count();
81
82 if error_words.is_empty() {
83 0.0
84 } else {
85 (matches as f64 / error_words.len() as f64).min(1.0)
86 }
87 }
88
89 pub fn localize(&self, project_path: &Path) -> Vec<ScoredLocation> {
91 let mut locations: HashMap<(PathBuf, usize), ScoredLocation> = HashMap::new();
92
93 for key in self.spectrum_data.failed_coverage.keys() {
95 locations
96 .entry(key.clone())
97 .or_insert_with(|| ScoredLocation::new(key.0.clone(), key.1));
98 }
99 for key in self.mutation_data.mutants.keys() {
100 locations
101 .entry(key.clone())
102 .or_insert_with(|| ScoredLocation::new(key.0.clone(), key.1));
103 }
104 for key in self.static_findings.keys() {
105 locations
106 .entry(key.clone())
107 .or_insert_with(|| ScoredLocation::new(key.0.clone(), key.1));
108 }
109
110 let mut result: Vec<ScoredLocation> = locations
112 .into_iter()
113 .map(|(key, mut loc)| {
114 loc.spectrum_score =
116 self.spectrum_data.compute_score(&key.0, key.1, self.sbfl_formula);
117
118 loc.mutation_score = self.mutation_data.compute_score(&key.0, key.1);
120
121 loc.static_score = *self.static_findings.get(&key).unwrap_or(&0.0);
123
124 if self.error_message.is_some() {
126 let file_path = project_path.join(&key.0);
127 if let Ok(content) = std::fs::read_to_string(&file_path) {
128 loc.semantic_score = self.compute_semantic_score(&key.0, key.1, &content);
129 }
130 }
131
132 match self.strategy {
134 LocalizationStrategy::Sbfl => {
135 loc.final_score = loc.spectrum_score;
136 }
137 LocalizationStrategy::Mbfl => {
138 loc.final_score = loc.mutation_score;
139 }
140 LocalizationStrategy::Causal => {
141 loc.final_score = loc.spectrum_score;
144 }
145 LocalizationStrategy::MultiChannel | LocalizationStrategy::Hybrid => {
146 loc.compute_final_score(&self.weights);
147 }
148 }
149
150 loc
151 })
152 .collect();
153
154 result.sort_by(|a, b| {
156 b.final_score.partial_cmp(&a.final_score).unwrap_or(std::cmp::Ordering::Equal)
157 });
158
159 result
160 }
161}