1use crate::transpiler::{CodeFeatures, TranspilerVerdict};
6use std::path::Path;
7
/// A single labelled sample: the features of a piece of code together with
/// its bug / not-bug label.
#[derive(Debug, Clone)]
pub struct TrainingExample {
    /// Feature values describing the code sample.
    pub features: CodeFeatures,
    /// Label for the sample; `true` marks it as a bug.
    // NOTE(review): presumably produced by `verdict_to_label` — confirm against callers.
    pub is_bug: bool,
}
16
/// Settings passed to [`ModelTrainer::train`] and
/// [`ModelTrainer::cross_validate`].
///
/// The `Default` implementation yields `train_ratio = 0.8`, `cv_folds = 5`,
/// `seed = 42` and `min_examples = 100`.
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    /// Fraction of the data intended for the training partition.
    // NOTE(review): presumably forwarded to `train_test_split` — confirm in trainer impls.
    pub train_ratio: f64,
    /// Number of cross-validation folds.
    pub cv_folds: usize,
    /// Seed used to make randomised steps reproducible.
    pub seed: u64,
    /// Minimum number of examples.
    // NOTE(review): presumably the bound behind `TrainingError::InsufficientData` — confirm.
    pub min_examples: usize,
}
29
30impl Default for TrainingConfig {
31 fn default() -> Self {
32 Self {
33 train_ratio: 0.8,
34 cv_folds: 5,
35 seed: 42,
36 min_examples: 100,
37 }
38 }
39}
40
/// Evaluation metrics for one model, or for one cross-validation fold.
///
/// `Default` sets every field to zero.
#[derive(Debug, Clone, Default)]
pub struct TrainingMetrics {
    /// Classification accuracy.
    pub accuracy: f64,
    /// Precision of the positive predictions.
    pub precision: f64,
    /// Recall of the positive class.
    pub recall: f64,
    /// F1 score; see `TrainingMetrics::calculate_f1` for the formula.
    pub f1_score: f64,
    /// Area under the ROC curve.
    pub auc_roc: f64,
    /// Number of examples in the training partition.
    pub train_size: usize,
    /// Number of examples in the test partition.
    pub test_size: usize,
}
59
60impl TrainingMetrics {
61 #[must_use]
63 pub fn calculate_f1(precision: f64, recall: f64) -> f64 {
64 if precision + recall == 0.0 {
65 0.0
66 } else {
67 2.0 * precision * recall / (precision + recall)
68 }
69 }
70}
71
/// Per-fold metrics from cross-validation plus summary statistics over them.
///
/// `Default` yields an empty fold list with all statistics at zero.
#[derive(Debug, Clone, Default)]
pub struct CrossValidationResults {
    /// Metrics of each individual fold.
    pub fold_metrics: Vec<TrainingMetrics>,
    /// Arithmetic mean of the folds' accuracy.
    pub mean_accuracy: f64,
    /// Standard deviation of the folds' accuracy.
    pub std_accuracy: f64,
    /// Arithmetic mean of the folds' F1 score.
    pub mean_f1: f64,
}
84
85impl CrossValidationResults {
86 #[must_use]
88 pub fn summarize(fold_metrics: Vec<TrainingMetrics>) -> Self {
89 if fold_metrics.is_empty() {
90 return Self::default();
91 }
92
93 let n = fold_metrics.len() as f64;
94 let mean_accuracy = fold_metrics.iter().map(|m| m.accuracy).sum::<f64>() / n;
95 let mean_f1 = fold_metrics.iter().map(|m| m.f1_score).sum::<f64>() / n;
96
97 let variance = fold_metrics
98 .iter()
99 .map(|m| (m.accuracy - mean_accuracy).powi(2))
100 .sum::<f64>()
101 / n;
102 let std_accuracy = variance.sqrt();
103
104 Self {
105 fold_metrics,
106 mean_accuracy,
107 std_accuracy,
108 mean_f1,
109 }
110 }
111}
112
/// A fitted model that can score code features, be written to disk and
/// describe itself.
///
/// Implementors must be `Send + Sync`.
pub trait TrainedModel: Send + Sync {
    /// Scores `features` and returns the prediction as an `f64`.
    // NOTE(review): presumably a bug probability in [0, 1] — confirm with implementors.
    fn predict(&self, features: &CodeFeatures) -> f64;

    /// Saves the model to `path`.
    ///
    /// # Errors
    ///
    /// Returns a `std::io::Error` when the implementation fails to save.
    fn save(&self, path: &Path) -> std::io::Result<()>;

    /// Returns the metadata describing this model.
    fn metadata(&self) -> ModelMetadata;
}
128
/// Descriptive information about a trained model, as returned by
/// [`TrainedModel::metadata`].
#[derive(Debug, Clone)]
pub struct ModelMetadata {
    /// Name of the model kind (the tests use `"RandomForest"`).
    pub model_type: String,
    /// When the model was trained, stored as free-form text.
    // NOTE(review): the tests use "2025-01-01"; the exact format is not enforced here — confirm.
    pub trained_at: String,
    /// Number of examples the model was trained on.
    pub train_examples: usize,
    /// Metrics recorded for the model.
    pub metrics: TrainingMetrics,
}
141
/// A strategy for fitting and evaluating models on labelled examples.
pub trait ModelTrainer {
    /// Fits a model on `examples` using the settings in `config`.
    ///
    /// # Errors
    ///
    /// Returns a [`TrainingError`] when training cannot be completed.
    // NOTE(review): presumably `InsufficientData` when fewer than
    // `config.min_examples` examples are supplied — confirm with implementors.
    fn train(
        &self,
        examples: &[TrainingExample],
        config: &TrainingConfig,
    ) -> Result<Box<dyn TrainedModel>, TrainingError>;

    /// Runs cross-validation over `examples` and returns the per-fold and
    /// aggregate results.
    ///
    /// # Errors
    ///
    /// Returns a [`TrainingError`] when cross-validation cannot be completed.
    // NOTE(review): presumably uses `config.cv_folds` folds — confirm with implementors.
    fn cross_validate(
        &self,
        examples: &[TrainingExample],
        config: &TrainingConfig,
    ) -> Result<CrossValidationResults, TrainingError>;
}
166
/// Errors reported by [`ModelTrainer`] operations.
#[derive(Debug, Clone)]
pub enum TrainingError {
    /// Fewer examples were supplied than are needed.
    InsufficientData {
        /// Number of examples needed.
        required: usize,
        /// Number of examples actually supplied.
        provided: usize,
    },
    /// The configuration is not usable; the payload explains why.
    InvalidConfig(String),
    /// Training itself failed; the payload explains why.
    TrainingFailed(String),
    /// An I/O problem, carried as text so the enum can stay `Clone`.
    IoError(String),
}
184
185impl std::fmt::Display for TrainingError {
186 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
187 match self {
188 Self::InsufficientData { required, provided } => {
189 write!(f, "Insufficient data: need {required}, got {provided}")
190 }
191 Self::InvalidConfig(msg) => write!(f, "Invalid config: {msg}"),
192 Self::TrainingFailed(msg) => write!(f, "Training failed: {msg}"),
193 Self::IoError(msg) => write!(f, "IO error: {msg}"),
194 }
195 }
196}
197
// Lets `TrainingError` be used wherever a `std::error::Error` is expected
// (for example behind `Box<dyn Error>`). No methods are overridden, so the
// trait's default implementations apply.
impl std::error::Error for TrainingError {}
199
200#[must_use]
202pub fn verdict_to_label(verdict: &TranspilerVerdict) -> bool {
203 !matches!(verdict, TranspilerVerdict::Pass)
204}
205
206#[must_use]
208pub fn train_test_split(
209 examples: &[TrainingExample],
210 train_ratio: f64,
211 seed: u64,
212) -> (Vec<TrainingExample>, Vec<TrainingExample>) {
213 use std::collections::hash_map::DefaultHasher;
214 use std::hash::{Hash, Hasher};
215
216 let mut train = Vec::new();
217 let mut test = Vec::new();
218
219 for (i, example) in examples.iter().enumerate() {
220 let mut hasher = DefaultHasher::new();
221 (seed, i).hash(&mut hasher);
222 let hash = hasher.finish();
223 #[allow(clippy::cast_sign_loss)]
224 let threshold = (train_ratio * u64::MAX as f64) as u64;
225
226 if hash < threshold {
227 train.push(example.clone());
228 } else {
229 test.push(example.clone());
230 }
231 }
232
233 (train, test)
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` synthetic examples whose features and label vary with the
    /// index.
    fn sample_examples(n: usize) -> Vec<TrainingExample> {
        (0..n)
            .map(|i| TrainingExample {
                features: CodeFeatures {
                    ast_depth: i % 5,
                    cyclomatic_complexity: i % 10,
                    ..Default::default()
                },
                is_bug: i % 3 == 0,
            })
            .collect()
    }

    /// Reduces examples to plain tuples so two partitions can be compared
    /// element by element (`TrainingExample` does not implement `PartialEq`).
    fn fingerprint(examples: &[TrainingExample]) -> Vec<(usize, usize, bool)> {
        examples
            .iter()
            .map(|e| {
                (
                    e.features.ast_depth,
                    e.features.cyclomatic_complexity,
                    e.is_bug,
                )
            })
            .collect()
    }

    #[test]
    fn test_training_config_default() {
        let config = TrainingConfig::default();
        assert!((config.train_ratio - 0.8).abs() < f64::EPSILON);
        assert_eq!(config.cv_folds, 5);
        assert_eq!(config.seed, 42);
        assert_eq!(config.min_examples, 100);
    }

    #[test]
    fn test_training_metrics_f1() {
        // Compare the computed value with a tolerance; exact equality on a
        // rounded decimal literal is brittle.
        let f1 = TrainingMetrics::calculate_f1(0.8, 0.6);
        assert!((f1 - 0.685_714_285_714_285_7).abs() < 1e-12);
        assert!(TrainingMetrics::calculate_f1(0.0, 0.0).abs() < f64::EPSILON);
        assert!((TrainingMetrics::calculate_f1(1.0, 1.0) - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_cross_validation_summarize() {
        let folds = vec![
            TrainingMetrics {
                accuracy: 0.8,
                f1_score: 0.75,
                ..Default::default()
            },
            TrainingMetrics {
                accuracy: 0.85,
                f1_score: 0.80,
                ..Default::default()
            },
            TrainingMetrics {
                accuracy: 0.9,
                f1_score: 0.85,
                ..Default::default()
            },
        ];

        let cv = CrossValidationResults::summarize(folds);
        assert_eq!(cv.fold_metrics.len(), 3);
        assert!((cv.mean_accuracy - 0.85).abs() < 0.001);
        assert!((cv.mean_f1 - 0.8).abs() < 0.001);
        assert!(cv.std_accuracy > 0.0);
        // Population standard deviation: sqrt((0.05^2 + 0 + 0.05^2) / 3).
        assert!((cv.std_accuracy - 0.040_824_8).abs() < 0.001);
    }

    #[test]
    fn test_cross_validation_empty() {
        let cv = CrossValidationResults::summarize(vec![]);
        assert!(cv.mean_accuracy.abs() < f64::EPSILON);
        assert!(cv.std_accuracy.abs() < f64::EPSILON);
        assert!(cv.mean_f1.abs() < f64::EPSILON);
        assert_eq!(cv.fold_metrics.len(), 0);
    }

    #[test]
    fn test_verdict_to_label() {
        assert!(!verdict_to_label(&TranspilerVerdict::Pass));
        assert!(verdict_to_label(&TranspilerVerdict::OutputMismatch));
        assert!(verdict_to_label(&TranspilerVerdict::TranspileError(
            "err".into()
        )));
        assert!(verdict_to_label(&TranspilerVerdict::Timeout));
    }

    #[test]
    fn test_train_test_split_ratio() {
        let examples = sample_examples(1000);
        let (train, test) = train_test_split(&examples, 0.8, 42);

        let train_ratio = train.len() as f64 / examples.len() as f64;
        assert!(train_ratio > 0.7 && train_ratio < 0.9);
        assert_eq!(train.len() + test.len(), examples.len());
    }

    #[test]
    fn test_train_test_split_deterministic() {
        let examples = sample_examples(100);
        let (train1, test1) = train_test_split(&examples, 0.8, 42);
        let (train2, test2) = train_test_split(&examples, 0.8, 42);

        // Equal lengths alone would not catch a split that picks different
        // members each time, so compare the contents as well.
        assert_eq!(train1.len(), train2.len());
        assert_eq!(fingerprint(&train1), fingerprint(&train2));
        assert_eq!(fingerprint(&test1), fingerprint(&test2));
    }

    #[test]
    fn test_training_error_display() {
        let err = TrainingError::InsufficientData {
            required: 100,
            provided: 50,
        };
        assert!(err.to_string().contains("100"));
        assert!(err.to_string().contains("50"));
    }

    #[test]
    fn test_model_metadata_clone() {
        let meta = ModelMetadata {
            model_type: "RandomForest".into(),
            trained_at: "2025-01-01".into(),
            train_examples: 1000,
            metrics: TrainingMetrics::default(),
        };
        let cloned = meta.clone();
        assert_eq!(cloned.model_type, meta.model_type);
        assert_eq!(cloned.trained_at, meta.trained_at);
        assert_eq!(cloned.train_examples, meta.train_examples);
    }

    // The tests below are placeholders for functionality that does not exist
    // yet; they are ignored so the suite stays green until it is implemented.

    #[test]
    #[ignore = "requires aprender ml feature"]
    fn test_random_forest_training() {
        unimplemented!("RandomForest training not yet implemented")
    }

    #[test]
    #[ignore = "requires aprender ml feature"]
    fn test_cross_validation_with_model() {
        unimplemented!("Cross-validation not yet implemented")
    }

    #[test]
    #[ignore = "requires aprender ml feature"]
    fn test_model_save_load() {
        unimplemented!("Model save/load not yet implemented")
    }

    #[test]
    #[ignore = "requires aprender ml feature"]
    fn test_stratified_split() {
        unimplemented!("Stratified split not yet implemented")
    }
}