Skip to main content

verificar/transpiler/
ml_oracle.rs

1//! ML-based enhancements for TranspilerOracle
2//!
3//! Integrates aprender ML capabilities for smarter test prioritization
4//! and bug prediction. See GH-2.
5
6use crate::generator::GeneratedCode;
7
/// Features extracted from source code for ML prediction.
///
/// Every field is a plain counter or flag, so two feature sets can be
/// compared with `==` (handy when asserting on extracted features).
/// `Default` yields all-zero counters and all-false flags.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CodeFeatures {
    /// AST depth
    pub ast_depth: usize,
    /// Number of AST nodes.
    ///
    /// Note: `CodeFeatures::from_program` fills this with the number of
    /// source lines, which is only a rough stand-in for a real node count.
    pub node_count: usize,
    /// Cyclomatic complexity estimate
    pub cyclomatic_complexity: usize,
    /// Number of unique identifiers
    pub identifier_count: usize,
    /// Number of function calls
    pub call_count: usize,
    /// Has loops
    pub has_loops: bool,
    /// Has conditionals
    pub has_conditionals: bool,
    /// Has exception handling
    pub has_exceptions: bool,
}
28
29impl CodeFeatures {
30    /// Extract features from a generated program
31    #[must_use]
32    pub fn from_program(program: &GeneratedCode) -> Self {
33        Self {
34            ast_depth: program.ast_depth,
35            node_count: program.code.lines().count(),
36            cyclomatic_complexity: estimate_complexity(&program.code),
37            identifier_count: count_identifiers(&program.code),
38            call_count: count_calls(&program.code),
39            has_loops: program.code.contains("for") || program.code.contains("while"),
40            has_conditionals: program.code.contains("if"),
41            has_exceptions: program.code.contains("try") || program.code.contains("except"),
42        }
43    }
44
45    /// Convert to feature vector for ML
46    #[must_use]
47    pub fn to_vec(&self) -> Vec<f64> {
48        vec![
49            self.ast_depth as f64,
50            self.node_count as f64,
51            self.cyclomatic_complexity as f64,
52            self.identifier_count as f64,
53            self.call_count as f64,
54            if self.has_loops { 1.0 } else { 0.0 },
55            if self.has_conditionals { 1.0 } else { 0.0 },
56            if self.has_exceptions { 1.0 } else { 0.0 },
57        ]
58    }
59}
60
/// Estimate cyclomatic complexity from code.
///
/// Starts at 1 (a single straight-line path) and adds one for every
/// occurrence of a branching keyword: `if`, `elif`, `for`, `while`, `and`,
/// `or`, `except`.
///
/// Keywords are matched as whole tokens (maximal runs of alphanumeric
/// characters and underscores). A plain substring search would count `for`
/// a second time as `or`, `elif` a second time as `if`, and would also pick
/// up unrelated words such as `import` or `random`.
///
/// This is a purely textual estimate: keywords inside string literals or
/// comments are still counted.
fn estimate_complexity(code: &str) -> usize {
    let decision_points = code
        .split(|c: char| !c.is_alphanumeric() && c != '_')
        .filter(|token| {
            matches!(
                *token,
                "if" | "elif" | "for" | "while" | "and" | "or" | "except"
            )
        })
        .count();
    1 + decision_points
}
69
/// Count unique identifiers in code.
///
/// The text is split into tokens (maximal runs of alphanumeric characters
/// and underscores). A token counts as an identifier when it starts with a
/// letter or an underscore, so `_tmp` and `__init__` are included while
/// numeric literals such as `42` or `1e5` are not. Each distinct token is
/// counted once.
///
/// No keyword filtering is done: language keywords (`def`, `return`, ...)
/// are counted like any other name.
fn count_identifiers(code: &str) -> usize {
    code.split(|c: char| !c.is_alphanumeric() && c != '_')
        .filter(|token| {
            // Empty tokens (between adjacent separators) have no first
            // character and are dropped here as well.
            token
                .chars()
                .next()
                .is_some_and(|first| first.is_alphabetic() || first == '_')
        })
        .collect::<std::collections::HashSet<_>>()
        .len()
}
77
/// Approximate the number of function calls in `code`.
///
/// Every opening parenthesis is counted, so this is an upper bound:
/// parentheses used for grouping, tuples or function definitions are
/// included in the total as well.
fn count_calls(code: &str) -> usize {
    code.chars().filter(|&ch| ch == '(').count()
}
82
83/// Bug predictor using ML model
84pub trait BugPredictor: Send + Sync {
85    /// Predict probability that a test case will expose a bug
86    fn predict_bug_probability(&self, features: &CodeFeatures) -> f64;
87
88    /// Batch prediction for efficiency
89    fn predict_batch(&self, features: &[CodeFeatures]) -> Vec<f64> {
90        features
91            .iter()
92            .map(|f| self.predict_bug_probability(f))
93            .collect()
94    }
95}
96
/// Test prioritizer using similarity
///
/// Implementors maintain an index of known failing tests (grown through
/// `add_failing_test`) and use it to rank candidate tests.
pub trait TestPrioritizer: Send + Sync {
    /// Prioritize test cases by similarity to known failing tests
    ///
    /// NOTE(review): the meaning of the returned `usize` values and of `k`
    /// is not fixed by this trait. Presumably the result holds indices into
    /// `tests`, limited to the top `k` entries — confirm against the
    /// concrete implementations.
    fn prioritize(&self, tests: &[GeneratedCode], k: usize) -> Vec<usize>;

    /// Add a failing test to the index
    ///
    /// Takes `&mut self`, so callers need exclusive access to the
    /// prioritizer while registering a failure.
    fn add_failing_test(&mut self, test: &GeneratedCode);

    /// Number of failing tests in index
    fn failing_count(&self) -> usize;
}
108
109/// Simple baseline predictor (always returns 0.5)
110#[derive(Debug, Clone, Default)]
111pub struct BaselinePredictor;
112
113impl BugPredictor for BaselinePredictor {
114    fn predict_bug_probability(&self, _features: &CodeFeatures) -> f64 {
115        0.5
116    }
117}
118
119/// Complexity-based predictor (higher complexity = higher bug probability)
120#[derive(Debug, Clone)]
121pub struct ComplexityPredictor {
122    /// Complexity threshold for high bug probability
123    pub threshold: usize,
124}
125
126impl Default for ComplexityPredictor {
127    fn default() -> Self {
128        Self { threshold: 5 }
129    }
130}
131
132impl BugPredictor for ComplexityPredictor {
133    fn predict_bug_probability(&self, features: &CodeFeatures) -> f64 {
134        let score = features.cyclomatic_complexity as f64 / self.threshold as f64;
135        score.min(1.0)
136    }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Language;

    /// A small Python function with one conditional, no loops and no
    /// exception handling.
    fn sample_program() -> GeneratedCode {
        GeneratedCode {
            code: "def foo(x):\n    if x > 0:\n        return x\n    return 0".to_string(),
            language: Language::Python,
            ast_depth: 3,
            features: vec!["function".to_string(), "conditional".to_string()],
        }
    }

    #[test]
    fn test_code_features_from_program() {
        let program = sample_program();
        let features = CodeFeatures::from_program(&program);

        assert_eq!(features.ast_depth, 3);
        assert!(features.has_conditionals);
        assert!(!features.has_loops);
        assert!(!features.has_exceptions);
        assert!(features.node_count > 0);
    }

    #[test]
    fn test_code_features_to_vec() {
        let features = CodeFeatures {
            ast_depth: 3,
            node_count: 10,
            cyclomatic_complexity: 5,
            identifier_count: 8,
            call_count: 2,
            has_loops: true,
            has_conditionals: true,
            has_exceptions: false,
        };

        let vec = features.to_vec();
        assert_eq!(vec.len(), 8);
        assert_eq!(vec[0], 3.0); // ast_depth
        assert_eq!(vec[5], 1.0); // has_loops
        assert_eq!(vec[7], 0.0); // has_exceptions
    }

    #[test]
    fn test_estimate_complexity() {
        let simple = "x = 1";
        let complex = "if x and y or z:\n    for i in range(10):\n        pass";

        // Code without branching keywords sits at the base complexity of 1.
        assert_eq!(estimate_complexity(simple), 1);
        assert!(estimate_complexity(complex) > estimate_complexity(simple));
    }

    #[test]
    fn test_count_identifiers() {
        let code = "def foo(x, y):\n    return x + y";
        let count = count_identifiers(code);
        // Exactly five distinct names: def, foo, x, y, return
        assert_eq!(count, 5);
    }

    #[test]
    fn test_baseline_predictor() {
        let predictor = BaselinePredictor;
        let features = CodeFeatures::default();

        assert_eq!(predictor.predict_bug_probability(&features), 0.5);
    }

    #[test]
    fn test_complexity_predictor() {
        let predictor = ComplexityPredictor { threshold: 5 };

        let low = CodeFeatures {
            cyclomatic_complexity: 2,
            ..Default::default()
        };
        let at_threshold = CodeFeatures {
            cyclomatic_complexity: 5,
            ..Default::default()
        };
        let high = CodeFeatures {
            cyclomatic_complexity: 10,
            ..Default::default()
        };

        assert!(predictor.predict_bug_probability(&low) < 0.5);
        assert_eq!(predictor.predict_bug_probability(&at_threshold), 1.0);
        assert_eq!(predictor.predict_bug_probability(&high), 1.0);
    }

    #[test]
    fn test_batch_prediction() {
        let predictor = BaselinePredictor;
        let features = vec![CodeFeatures::default(), CodeFeatures::default()];

        let predictions = predictor.predict_batch(&features);
        assert_eq!(predictions.len(), 2);
        assert!(predictions.iter().all(|&p| p == 0.5));

        // An empty batch yields an empty result.
        assert!(predictor.predict_batch(&[]).is_empty());
    }

    // RED PHASE: placeholders for the aprender integration. They are ignored
    // by default and panic via `unimplemented!` when run explicitly, until
    // the corresponding types exist.

    #[test]
    #[ignore = "requires aprender ml feature"]
    fn test_random_forest_predictor() {
        // TODO: Implement RandomForestPredictor using aprender
        // let predictor = RandomForestPredictor::train(&training_data);
        // assert!(predictor.predict_bug_probability(&features) >= 0.0);
        // assert!(predictor.predict_bug_probability(&features) <= 1.0);
        unimplemented!("RandomForestPredictor not yet implemented")
    }

    #[test]
    #[ignore = "requires aprender ml feature"]
    fn test_hnsw_prioritizer() {
        // TODO: Implement HNSWPrioritizer using aprender
        // let mut prioritizer = HNSWPrioritizer::new();
        // prioritizer.add_failing_test(&failing_test);
        // let priorities = prioritizer.prioritize(&tests, 5);
        // assert_eq!(priorities.len(), 5);
        unimplemented!("HNSWPrioritizer not yet implemented")
    }

    #[test]
    #[ignore = "requires aprender ml feature"]
    fn test_incremental_learning() {
        // TODO: Implement incremental model updates
        // let mut predictor = IncrementalPredictor::new();
        // predictor.update(&new_examples);
        // assert!(predictor.version() > 0);
        unimplemented!("IncrementalPredictor not yet implemented")
    }
}