1use serde::{Deserialize, Serialize};
7
8use crate::graph::traversal::{self, Direction, TraversalOptions};
9use crate::graph::CodeGraph;
10use crate::types::{CodeUnitType, EdgeType};
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct RegressionOracle {
17 pub changed_unit: u64,
19 pub likely_failures: Vec<TestPrediction>,
21 pub recommended_tests: Vec<TestPrediction>,
23 pub safe_to_skip: Vec<TestId>,
25 pub minimum_test_set: Vec<TestId>,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct TestPrediction {
32 pub test: TestId,
34 pub failure_probability: f64,
36 pub reason: String,
38 pub dependency_path: Vec<u64>,
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct TestId {
45 pub file: String,
46 pub function: String,
47 pub line: u32,
48 pub unit_id: u64,
49}
50
51pub struct RegressionPredictor<'g> {
55 graph: &'g CodeGraph,
56}
57
58impl<'g> RegressionPredictor<'g> {
59 pub fn new(graph: &'g CodeGraph) -> Self {
60 Self { graph }
61 }
62
63 pub fn predict(&self, changed_unit: u64, max_depth: u32) -> RegressionOracle {
65 let all_tests: Vec<TestId> = self
67 .graph
68 .units()
69 .iter()
70 .filter(|u| u.unit_type == CodeUnitType::Test)
71 .map(|u| TestId {
72 file: u.file_path.display().to_string(),
73 function: u.name.clone(),
74 line: u.span.start_line,
75 unit_id: u.id,
76 })
77 .collect();
78
79 let options = TraversalOptions {
81 max_depth: max_depth as i32,
82 edge_types: vec![
83 EdgeType::Calls,
84 EdgeType::Imports,
85 EdgeType::Tests,
86 EdgeType::References,
87 EdgeType::UsesType,
88 ],
89 direction: Direction::Backward,
90 };
91 let reachable = traversal::bfs(self.graph, changed_unit, &options);
92 let reachable_ids: std::collections::HashSet<u64> =
93 reachable.iter().map(|(id, _)| *id).collect();
94
95 let mut likely_failures = Vec::new();
96 let mut recommended_tests = Vec::new();
97 let mut safe_to_skip = Vec::new();
98
99 for test in &all_tests {
100 let directly_tests = self
102 .graph
103 .edges_from(test.unit_id)
104 .iter()
105 .any(|e| e.edge_type == EdgeType::Tests && e.target_id == changed_unit);
106
107 if directly_tests {
108 likely_failures.push(TestPrediction {
110 test: test.clone(),
111 failure_probability: 0.85,
112 reason: "Directly tests the changed unit".to_string(),
113 dependency_path: vec![test.unit_id, changed_unit],
114 });
115 } else if reachable_ids.contains(&test.unit_id) {
116 let depth = reachable
118 .iter()
119 .find(|(id, _)| *id == test.unit_id)
120 .map(|(_, d)| *d)
121 .unwrap_or(0);
122
123 let probability = 0.6 / (1.0 + depth as f64 * 0.3);
124
125 if probability > 0.3 {
126 recommended_tests.push(TestPrediction {
127 test: test.clone(),
128 failure_probability: probability,
129 reason: format!("Transitively depends on changed unit (depth {})", depth),
130 dependency_path: vec![test.unit_id, changed_unit],
131 });
132 } else {
133 safe_to_skip.push(test.clone());
134 }
135 } else {
136 safe_to_skip.push(test.clone());
137 }
138 }
139
140 likely_failures.sort_by(|a, b| {
142 b.failure_probability
143 .partial_cmp(&a.failure_probability)
144 .unwrap_or(std::cmp::Ordering::Equal)
145 });
146 recommended_tests.sort_by(|a, b| {
147 b.failure_probability
148 .partial_cmp(&a.failure_probability)
149 .unwrap_or(std::cmp::Ordering::Equal)
150 });
151
152 let minimum_test_set: Vec<TestId> = likely_failures
154 .iter()
155 .map(|p| p.test.clone())
156 .chain(
157 recommended_tests
158 .iter()
159 .filter(|p| p.failure_probability > 0.4)
160 .map(|p| p.test.clone()),
161 )
162 .collect();
163
164 RegressionOracle {
165 changed_unit,
166 likely_failures,
167 recommended_tests,
168 safe_to_skip,
169 minimum_test_set,
170 }
171 }
172
173 pub fn minimal_test_set(&self, changed_unit: u64) -> Vec<TestId> {
175 self.predict(changed_unit, 5).minimum_test_set
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{CodeUnit, CodeUnitType, Edge, Language, Span};
    use std::path::PathBuf;

    /// Builds a graph with three units:
    /// * `process`, a function;
    /// * `test_process`, a test with a `Tests` edge to `process`;
    /// * `helper`, a function with no edges.
    fn test_graph() -> CodeGraph {
        let mut graph = CodeGraph::with_default_dimension();
        let func = graph.add_unit(CodeUnit::new(
            CodeUnitType::Function,
            Language::Rust,
            "process".to_string(),
            "mod::process".to_string(),
            PathBuf::from("src/lib.rs"),
            Span::new(10, 0, 30, 0),
        ));
        let test = graph.add_unit(CodeUnit::new(
            CodeUnitType::Test,
            Language::Rust,
            "test_process".to_string(),
            "mod::test_process".to_string(),
            PathBuf::from("tests/test_lib.rs"),
            Span::new(1, 0, 10, 0),
        ));
        let _helper = graph.add_unit(CodeUnit::new(
            CodeUnitType::Function,
            Language::Rust,
            "helper".to_string(),
            "mod::helper".to_string(),
            PathBuf::from("src/lib.rs"),
            Span::new(40, 0, 50, 0),
        ));
        // If this edge were not recorded, the direct-test assertions below would fail.
        let _ = graph.add_edge(Edge::new(test, func, EdgeType::Tests));
        graph
    }

    /// Resolves a fixture unit's id by name, so tests do not depend on the
    /// order in which the graph assigns ids.
    fn unit_id(graph: &CodeGraph, name: &str) -> u64 {
        graph
            .units()
            .iter()
            .find(|u| u.name == name)
            .map(|u| u.id)
            .expect("fixture unit should exist")
    }

    #[test]
    fn predict_finds_direct_test() {
        let graph = test_graph();
        let func = unit_id(&graph, "process");
        let test = unit_id(&graph, "test_process");
        let oracle = RegressionPredictor::new(&graph).predict(func, 5);
        assert_eq!(oracle.changed_unit, func);
        assert_eq!(oracle.likely_failures.len(), 1);
        assert_eq!(oracle.likely_failures[0].test.unit_id, test);
        assert!(oracle.likely_failures[0].failure_probability > 0.5);
        assert!(oracle.safe_to_skip.is_empty());
    }

    #[test]
    fn minimal_test_set_not_empty() {
        let graph = test_graph();
        let func = unit_id(&graph, "process");
        let test = unit_id(&graph, "test_process");
        let minimal = RegressionPredictor::new(&graph).minimal_test_set(func);
        assert!(minimal.iter().any(|t| t.unit_id == test));
    }

    #[test]
    fn unrelated_change_skips_test() {
        let graph = test_graph();
        let helper = unit_id(&graph, "helper");
        let test = unit_id(&graph, "test_process");
        let oracle = RegressionPredictor::new(&graph).predict(helper, 5);
        assert!(oracle.likely_failures.is_empty());
        assert!(oracle.recommended_tests.is_empty());
        assert!(oracle.minimum_test_set.is_empty());
        assert_eq!(oracle.safe_to_skip.len(), 1);
        assert_eq!(oracle.safe_to_skip[0].unit_id, test);
    }
}
225}