Skip to main content

verificar/ml/
aprender.rs

1//! aprender integration for bug prediction
2//!
3//! This module provides integration with the aprender ML library
4//! for training and inference of bug prediction models.
5
6#[cfg(feature = "ml")]
7use aprender::tree::RandomForestClassifier;
8#[cfg(feature = "ml")]
9use aprender::Matrix;
10
11use crate::data::CodeFeatures;
12use crate::Result;
13
/// Bug predictor backed by an aprender random forest.
///
/// When the `ml` feature is enabled the struct owns the fitted
/// `RandomForestClassifier`. When it is disabled the struct only carries a
/// zero-sized placeholder field.
#[derive(Debug)]
pub struct AprenderBugPredictor {
    /// Placeholder field used when the `ml` feature is disabled.
    #[cfg(not(feature = "ml"))]
    _phantom: std::marker::PhantomData<()>,
    /// Random forest classifier used when the `ml` feature is enabled.
    #[cfg(feature = "ml")]
    model: RandomForestClassifier,
}
22
23impl AprenderBugPredictor {
24    /// Train a new bug prediction model
25    ///
26    /// # Arguments
27    ///
28    /// * `features` - Training features extracted from code
29    /// * `labels` - Ground truth labels (true if bug, false if correct)
30    ///
31    /// # Errors
32    ///
33    /// Returns error if training fails
34    #[cfg(feature = "ml")]
35    pub fn train(features: &[CodeFeatures], labels: &[bool]) -> Result<Self> {
36        // Convert CodeFeatures to Matrix<f32>
37        let n_samples = features.len();
38        let n_features = 5; // ast_depth, num_operators, num_control_flow, cyclomatic_complexity, uses_edge_values
39
40        let mut data = Vec::with_capacity(n_samples * n_features);
41        for f in features {
42            data.push(f.ast_depth as f32);
43            data.push(f.num_operators as f32);
44            data.push(f.num_control_flow as f32);
45            data.push(f.cyclomatic_complexity);
46            data.push(if f.uses_edge_values { 1.0 } else { 0.0 });
47        }
48
49        let x = Matrix::from_vec(n_samples, n_features, data)
50            .map_err(|e| crate::Error::Data(format!("Failed to create matrix: {e}")))?;
51
52        let y: Vec<usize> = labels.iter().map(|&b| usize::from(b)).collect();
53
54        let mut model = RandomForestClassifier::new(100)
55            .with_max_depth(10)
56            .with_random_state(42);
57
58        model
59            .fit(&x, &y)
60            .map_err(|e| crate::Error::Data(format!("Training failed: {e}")))?;
61
62        Ok(Self { model })
63    }
64
65    /// Train a new bug prediction model (no-op without ml feature)
66    ///
67    /// # Errors
68    ///
69    /// Always returns error without 'ml' feature enabled
70    #[cfg(not(feature = "ml"))]
71    pub fn train(_features: &[CodeFeatures], _labels: &[bool]) -> Result<Self> {
72        Err(crate::Error::Data(
73            "aprender integration requires 'ml' feature".to_string(),
74        ))
75    }
76
77    /// Predict probability of a bug
78    ///
79    /// Returns probability in range [0, 1]
80    ///
81    /// Note: Currently returns hard predictions (0.0 or 1.0) as aprender's
82    /// RandomForestClassifier doesn't expose predict_proba yet.
83    #[cfg(feature = "ml")]
84    pub fn predict(&self, features: &CodeFeatures) -> f32 {
85        let data = vec![
86            features.ast_depth as f32,
87            features.num_operators as f32,
88            features.num_control_flow as f32,
89            features.cyclomatic_complexity,
90            if features.uses_edge_values { 1.0 } else { 0.0 },
91        ];
92
93        let Ok(x) = Matrix::from_vec(1, 5, data) else {
94            return 0.5; // Fallback
95        };
96
97        let predictions = self.model.predict(&x);
98        // Convert class label (0 or 1) to probability
99        if predictions.is_empty() {
100            0.5
101        } else {
102            predictions[0] as f32
103        }
104    }
105
106    /// Predict probability of a bug (fallback without ml feature)
107    #[cfg(not(feature = "ml"))]
108    pub fn predict(&self, _features: &CodeFeatures) -> f32 {
109        0.5 // Neutral probability
110    }
111
112    /// Save model to file
113    ///
114    /// # Errors
115    ///
116    /// Returns error if serialization fails
117    ///
118    /// Note: Model persistence requires serde serialization support in aprender.
119    /// Planned for future release with SafeTensors format.
120    pub fn save(&self, _path: &str) -> Result<()> {
121        Err(crate::Error::Data(
122            "Model serialization not yet implemented".to_string(),
123        ))
124    }
125
126    /// Load model from file
127    ///
128    /// # Errors
129    ///
130    /// Returns error if deserialization fails
131    ///
132    /// Note: Model persistence requires serde serialization support in aprender.
133    /// Planned for future release with SafeTensors format.
134    pub fn load(_path: &str) -> Result<Self> {
135        Err(crate::Error::Data(
136            "Model deserialization not yet implemented".to_string(),
137        ))
138    }
139}
140
#[cfg(all(test, feature = "ml"))]
mod tests {
    use super::*;

    /// Smoke test: fit on one correct and one buggy sample, then check that
    /// the score for an unseen sample stays inside [0, 1].
    #[test]
    fn test_train_and_predict() {
        // Small, simple sample labelled as correct.
        let clean = CodeFeatures {
            ast_depth: 5,
            num_operators: 10,
            num_control_flow: 2,
            cyclomatic_complexity: 3.0,
            uses_edge_values: false,
            ..Default::default()
        };
        // Larger, more complex sample labelled as buggy.
        let buggy = CodeFeatures {
            ast_depth: 10,
            num_operators: 50,
            num_control_flow: 10,
            cyclomatic_complexity: 15.0,
            uses_edge_values: true,
            ..Default::default()
        };

        let training_set = [clean, buggy];
        let ground_truth = [false, true];

        let predictor = AprenderBugPredictor::train(&training_set, &ground_truth).unwrap();

        // Sample that was not part of the training set.
        let unseen = CodeFeatures {
            ast_depth: 3,
            num_operators: 5,
            num_control_flow: 1,
            cyclomatic_complexity: 2.0,
            uses_edge_values: false,
            ..Default::default()
        };

        let score = predictor.predict(&unseen);
        assert!((0.0..=1.0).contains(&score));
    }
}