1#[cfg(feature = "ml")]
7use aprender::tree::RandomForestClassifier;
8#[cfg(feature = "ml")]
9use aprender::Matrix;
10
11use crate::data::CodeFeatures;
12use crate::Result;
13
/// Bug predictor built on top of an `aprender` random-forest classifier.
///
/// The contents depend on the `ml` cargo feature: with the feature enabled the
/// struct owns the classifier; with it disabled the struct holds only a
/// zero-sized placeholder so the type still exists and the rest of the crate
/// compiles.
#[derive(Debug)]
pub struct AprenderBugPredictor {
    /// Zero-sized stand-in used when the `ml` feature is disabled.
    #[cfg(not(feature = "ml"))]
    _phantom: std::marker::PhantomData<()>,
    /// The underlying random-forest model (only present with the `ml` feature).
    #[cfg(feature = "ml")]
    model: RandomForestClassifier,
}
22
23impl AprenderBugPredictor {
24 #[cfg(feature = "ml")]
35 pub fn train(features: &[CodeFeatures], labels: &[bool]) -> Result<Self> {
36 let n_samples = features.len();
38 let n_features = 5; let mut data = Vec::with_capacity(n_samples * n_features);
41 for f in features {
42 data.push(f.ast_depth as f32);
43 data.push(f.num_operators as f32);
44 data.push(f.num_control_flow as f32);
45 data.push(f.cyclomatic_complexity);
46 data.push(if f.uses_edge_values { 1.0 } else { 0.0 });
47 }
48
49 let x = Matrix::from_vec(n_samples, n_features, data)
50 .map_err(|e| crate::Error::Data(format!("Failed to create matrix: {e}")))?;
51
52 let y: Vec<usize> = labels.iter().map(|&b| usize::from(b)).collect();
53
54 let mut model = RandomForestClassifier::new(100)
55 .with_max_depth(10)
56 .with_random_state(42);
57
58 model
59 .fit(&x, &y)
60 .map_err(|e| crate::Error::Data(format!("Training failed: {e}")))?;
61
62 Ok(Self { model })
63 }
64
65 #[cfg(not(feature = "ml"))]
71 pub fn train(_features: &[CodeFeatures], _labels: &[bool]) -> Result<Self> {
72 Err(crate::Error::Data(
73 "aprender integration requires 'ml' feature".to_string(),
74 ))
75 }
76
77 #[cfg(feature = "ml")]
84 pub fn predict(&self, features: &CodeFeatures) -> f32 {
85 let data = vec![
86 features.ast_depth as f32,
87 features.num_operators as f32,
88 features.num_control_flow as f32,
89 features.cyclomatic_complexity,
90 if features.uses_edge_values { 1.0 } else { 0.0 },
91 ];
92
93 let Ok(x) = Matrix::from_vec(1, 5, data) else {
94 return 0.5; };
96
97 let predictions = self.model.predict(&x);
98 if predictions.is_empty() {
100 0.5
101 } else {
102 predictions[0] as f32
103 }
104 }
105
106 #[cfg(not(feature = "ml"))]
108 pub fn predict(&self, _features: &CodeFeatures) -> f32 {
109 0.5 }
111
112 pub fn save(&self, _path: &str) -> Result<()> {
121 Err(crate::Error::Data(
122 "Model serialization not yet implemented".to_string(),
123 ))
124 }
125
126 pub fn load(_path: &str) -> Result<Self> {
135 Err(crate::Error::Data(
136 "Model deserialization not yet implemented".to_string(),
137 ))
138 }
139}
140
#[cfg(all(test, feature = "ml"))]
mod tests {
    use super::*;

    /// Trains on one low-complexity sample labelled `false` and one
    /// high-complexity sample labelled `true`, then checks that the score for
    /// an unseen simple sample lies within `0.0..=1.0`.
    #[test]
    fn test_train_and_predict() {
        let clean_sample = CodeFeatures {
            ast_depth: 5,
            num_operators: 10,
            num_control_flow: 2,
            cyclomatic_complexity: 3.0,
            uses_edge_values: false,
            ..Default::default()
        };
        let buggy_sample = CodeFeatures {
            ast_depth: 10,
            num_operators: 50,
            num_control_flow: 10,
            cyclomatic_complexity: 15.0,
            uses_edge_values: true,
            ..Default::default()
        };

        let training_set = [clean_sample, buggy_sample];
        let training_labels = [false, true];

        let predictor = AprenderBugPredictor::train(&training_set, &training_labels).unwrap();

        let unseen = CodeFeatures {
            ast_depth: 3,
            num_operators: 5,
            num_control_flow: 1,
            cyclomatic_complexity: 2.0,
            uses_edge_values: false,
            ..Default::default()
        };

        let score = predictor.predict(&unseen);
        let valid_scores = 0.0..=1.0;
        assert!(valid_scores.contains(&score));
    }
}