Skip to main content

agentic_codebase/engine/
regression.rs

//! Regression Oracle — Invention 3.
//!
//! Predict which tests are likely to fail based on a change, before running them.
//! Uses the code graph to trace from changed units to test units.
5
6use serde::{Deserialize, Serialize};
7
8use crate::graph::traversal::{self, Direction, TraversalOptions};
9use crate::graph::CodeGraph;
10use crate::types::{CodeUnitType, EdgeType};
11
12// ── Types ────────────────────────────────────────────────────────────────────
13
14/// Prediction of test outcomes.
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct RegressionOracle {
17    /// The unit being changed.
18    pub changed_unit: u64,
19    /// Tests predicted to fail.
20    pub likely_failures: Vec<TestPrediction>,
21    /// Tests that should pass but are worth running.
22    pub recommended_tests: Vec<TestPrediction>,
23    /// Tests that are definitely unaffected.
24    pub safe_to_skip: Vec<TestId>,
25    /// Minimum test set for confidence.
26    pub minimum_test_set: Vec<TestId>,
27}
28
29/// A single test prediction.
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct TestPrediction {
32    /// The test.
33    pub test: TestId,
34    /// Probability of failure.
35    pub failure_probability: f64,
36    /// Why we think it might fail.
37    pub reason: String,
38    /// Path from change to test.
39    pub dependency_path: Vec<u64>,
40}
41
42/// Identifier for a test.
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct TestId {
45    pub file: String,
46    pub function: String,
47    pub line: u32,
48    pub unit_id: u64,
49}
50
51// ── RegressionPredictor ──────────────────────────────────────────────────────
52
53/// Predicts test outcomes based on code changes.
54pub struct RegressionPredictor<'g> {
55    graph: &'g CodeGraph,
56}
57
58impl<'g> RegressionPredictor<'g> {
59    pub fn new(graph: &'g CodeGraph) -> Self {
60        Self { graph }
61    }
62
63    /// Predict test outcomes for a given change.
64    pub fn predict(&self, changed_unit: u64, max_depth: u32) -> RegressionOracle {
65        // Find all test units in the graph
66        let all_tests: Vec<TestId> = self
67            .graph
68            .units()
69            .iter()
70            .filter(|u| u.unit_type == CodeUnitType::Test)
71            .map(|u| TestId {
72                file: u.file_path.display().to_string(),
73                function: u.name.clone(),
74                line: u.span.start_line,
75                unit_id: u.id,
76            })
77            .collect();
78
79        // BFS backward from changed unit to find dependents
80        let options = TraversalOptions {
81            max_depth: max_depth as i32,
82            edge_types: vec![
83                EdgeType::Calls,
84                EdgeType::Imports,
85                EdgeType::Tests,
86                EdgeType::References,
87                EdgeType::UsesType,
88            ],
89            direction: Direction::Backward,
90        };
91        let reachable = traversal::bfs(self.graph, changed_unit, &options);
92        let reachable_ids: std::collections::HashSet<u64> =
93            reachable.iter().map(|(id, _)| *id).collect();
94
95        let mut likely_failures = Vec::new();
96        let mut recommended_tests = Vec::new();
97        let mut safe_to_skip = Vec::new();
98
99        for test in &all_tests {
100            // Check if this test is directly connected
101            let directly_tests = self
102                .graph
103                .edges_from(test.unit_id)
104                .iter()
105                .any(|e| e.edge_type == EdgeType::Tests && e.target_id == changed_unit);
106
107            if directly_tests {
108                // Direct test of the changed unit — high failure probability
109                likely_failures.push(TestPrediction {
110                    test: test.clone(),
111                    failure_probability: 0.85,
112                    reason: "Directly tests the changed unit".to_string(),
113                    dependency_path: vec![test.unit_id, changed_unit],
114                });
115            } else if reachable_ids.contains(&test.unit_id) {
116                // Transitively connected
117                let depth = reachable
118                    .iter()
119                    .find(|(id, _)| *id == test.unit_id)
120                    .map(|(_, d)| *d)
121                    .unwrap_or(0);
122
123                let probability = 0.6 / (1.0 + depth as f64 * 0.3);
124
125                if probability > 0.3 {
126                    recommended_tests.push(TestPrediction {
127                        test: test.clone(),
128                        failure_probability: probability,
129                        reason: format!("Transitively depends on changed unit (depth {})", depth),
130                        dependency_path: vec![test.unit_id, changed_unit],
131                    });
132                } else {
133                    safe_to_skip.push(test.clone());
134                }
135            } else {
136                safe_to_skip.push(test.clone());
137            }
138        }
139
140        // Sort by failure probability descending
141        likely_failures.sort_by(|a, b| {
142            b.failure_probability
143                .partial_cmp(&a.failure_probability)
144                .unwrap_or(std::cmp::Ordering::Equal)
145        });
146        recommended_tests.sort_by(|a, b| {
147            b.failure_probability
148                .partial_cmp(&a.failure_probability)
149                .unwrap_or(std::cmp::Ordering::Equal)
150        });
151
152        // Minimum test set = likely failures + high-probability recommended
153        let minimum_test_set: Vec<TestId> = likely_failures
154            .iter()
155            .map(|p| p.test.clone())
156            .chain(
157                recommended_tests
158                    .iter()
159                    .filter(|p| p.failure_probability > 0.4)
160                    .map(|p| p.test.clone()),
161            )
162            .collect();
163
164        RegressionOracle {
165            changed_unit,
166            likely_failures,
167            recommended_tests,
168            safe_to_skip,
169            minimum_test_set,
170        }
171    }
172
173    /// Get the minimum test set needed for confidence.
174    pub fn minimal_test_set(&self, changed_unit: u64) -> Vec<TestId> {
175        self.predict(changed_unit, 5).minimum_test_set
176    }
177}
178
179// ── Tests ────────────────────────────────────────────────────────────────────
180
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{CodeUnit, CodeUnitType, Edge, Language, Span};
    use std::path::PathBuf;

    /// Build a graph with one function and one test that has a `Tests` edge to
    /// it. Returns the graph together with the function's and the test's unit
    /// ids so the tests do not depend on how ids are allocated.
    fn test_graph() -> (CodeGraph, u64, u64) {
        let mut graph = CodeGraph::with_default_dimension();
        let func = graph.add_unit(CodeUnit::new(
            CodeUnitType::Function,
            Language::Rust,
            "process".to_string(),
            "mod::process".to_string(),
            PathBuf::from("src/lib.rs"),
            Span::new(10, 0, 30, 0),
        ));
        let test = graph.add_unit(CodeUnit::new(
            CodeUnitType::Test,
            Language::Rust,
            "test_process".to_string(),
            "mod::test_process".to_string(),
            PathBuf::from("tests/test_lib.rs"),
            Span::new(1, 0, 10, 0),
        ));
        let _ = graph.add_edge(Edge::new(test, func, EdgeType::Tests));
        (graph, func, test)
    }

    #[test]
    fn predict_finds_direct_test() {
        let (graph, func, test) = test_graph();
        let predictor = RegressionPredictor::new(&graph);
        let oracle = predictor.predict(func, 5);
        assert!(!oracle.likely_failures.is_empty());
        assert_eq!(oracle.likely_failures[0].test.unit_id, test);
        assert!(oracle.likely_failures[0].failure_probability > 0.5);
    }

    #[test]
    fn minimal_test_set_not_empty() {
        let (graph, func, _) = test_graph();
        let predictor = RegressionPredictor::new(&graph);
        let minimal = predictor.minimal_test_set(func);
        assert!(!minimal.is_empty());
    }

    #[test]
    fn unrelated_change_is_safe_to_skip() {
        let (mut graph, _, test) = test_graph();
        // A unit with no edges: nothing depends on it, so the test is unaffected.
        let unrelated = graph.add_unit(CodeUnit::new(
            CodeUnitType::Function,
            Language::Rust,
            "unrelated".to_string(),
            "mod::unrelated".to_string(),
            PathBuf::from("src/other.rs"),
            Span::new(1, 0, 5, 0),
        ));
        let predictor = RegressionPredictor::new(&graph);
        let oracle = predictor.predict(unrelated, 5);
        assert!(oracle.likely_failures.is_empty());
        assert!(oracle.recommended_tests.is_empty());
        assert!(oracle.minimum_test_set.is_empty());
        assert_eq!(oracle.safe_to_skip.len(), 1);
        assert_eq!(oracle.safe_to_skip[0].unit_id, test);
    }
}