// verificar/ml/training.rs
//! Bug prediction model training pipeline
//!
//! Trains ML models on verification data. See VERIFICAR-051.

use crate::transpiler::{CodeFeatures, TranspilerVerdict};
use std::path::Path;

/// A single labeled training example: extracted code features paired with a
/// binary bug label (derived from a transpiler verdict via `verdict_to_label`).
#[derive(Debug, Clone)]
pub struct TrainingExample {
    /// Extracted code features
    pub features: CodeFeatures,
    /// True if this example exposed a bug (non-Pass verdict)
    pub is_bug: bool,
}

/// Configuration for a training run (split ratio, CV folds, seed, data floor).
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    /// Train/test split ratio (0.0 to 1.0): fraction routed to the train set
    pub train_ratio: f64,
    /// Number of cross-validation folds
    pub cv_folds: usize,
    /// Random seed for reproducibility
    pub seed: u64,
    /// Minimum examples required for training
    pub min_examples: usize,
}

30impl Default for TrainingConfig {
31    fn default() -> Self {
32        Self {
33            train_ratio: 0.8,
34            cv_folds: 5,
35            seed: 42,
36            min_examples: 100,
37        }
38    }
39}
40
/// Evaluation metrics produced by a training run (classification quality
/// plus the sizes of the splits they were measured on).
#[derive(Debug, Clone, Default)]
pub struct TrainingMetrics {
    /// Accuracy (correct / total)
    pub accuracy: f64,
    /// Precision (TP / (TP + FP))
    pub precision: f64,
    /// Recall (TP / (TP + FN))
    pub recall: f64,
    /// F1 score (harmonic mean of precision and recall)
    pub f1_score: f64,
    /// Area under ROC curve
    pub auc_roc: f64,
    /// Number of training examples
    pub train_size: usize,
    /// Number of test examples
    pub test_size: usize,
}

60impl TrainingMetrics {
61    /// Calculate F1 from precision and recall
62    #[must_use]
63    pub fn calculate_f1(precision: f64, recall: f64) -> f64 {
64        if precision + recall == 0.0 {
65            0.0
66        } else {
67            2.0 * precision * recall / (precision + recall)
68        }
69    }
70}
71
/// Aggregated results of a k-fold cross-validation run.
#[derive(Debug, Clone, Default)]
pub struct CrossValidationResults {
    /// Metrics for each fold
    pub fold_metrics: Vec<TrainingMetrics>,
    /// Mean accuracy across folds
    pub mean_accuracy: f64,
    /// Std deviation of accuracy (population std dev: divides by N, not N-1)
    pub std_accuracy: f64,
    /// Mean F1 across folds
    pub mean_f1: f64,
}

85impl CrossValidationResults {
86    /// Calculate summary statistics from fold metrics
87    #[must_use]
88    pub fn summarize(fold_metrics: Vec<TrainingMetrics>) -> Self {
89        if fold_metrics.is_empty() {
90            return Self::default();
91        }
92
93        let n = fold_metrics.len() as f64;
94        let mean_accuracy = fold_metrics.iter().map(|m| m.accuracy).sum::<f64>() / n;
95        let mean_f1 = fold_metrics.iter().map(|m| m.f1_score).sum::<f64>() / n;
96
97        let variance = fold_metrics
98            .iter()
99            .map(|m| (m.accuracy - mean_accuracy).powi(2))
100            .sum::<f64>()
101            / n;
102        let std_accuracy = variance.sqrt();
103
104        Self {
105            fold_metrics,
106            mean_accuracy,
107            std_accuracy,
108            mean_f1,
109        }
110    }
111}
112
/// A trained model that can score features and be persisted to disk.
///
/// `Send + Sync` is required so models can be shared across threads.
pub trait TrainedModel: Send + Sync {
    /// Predict bug probability for features
    ///
    /// NOTE(review): the name suggests a probability in [0.0, 1.0], but the
    /// trait does not enforce it — confirm implementors uphold that range.
    fn predict(&self, features: &CodeFeatures) -> f64;

    /// Save model to file
    ///
    /// # Errors
    ///
    /// Returns IO error if save fails
    fn save(&self, path: &Path) -> std::io::Result<()>;

    /// Model version/metadata
    fn metadata(&self) -> ModelMetadata;
}

/// Model metadata for serialization (type, provenance, and quality metrics).
#[derive(Debug, Clone)]
pub struct ModelMetadata {
    /// Model type name
    pub model_type: String,
    /// Training timestamp (free-form string; format set by the trainer)
    pub trained_at: String,
    /// Number of training examples
    pub train_examples: usize,
    /// Training metrics
    pub metrics: TrainingMetrics,
}

/// Model trainer trait: produces a `TrainedModel` from labeled examples and
/// evaluates training stability via cross-validation.
pub trait ModelTrainer {
    /// Train model on examples
    ///
    /// # Errors
    ///
    /// Returns error if training fails or insufficient data
    fn train(
        &self,
        examples: &[TrainingExample],
        config: &TrainingConfig,
    ) -> Result<Box<dyn TrainedModel>, TrainingError>;

    /// Run cross-validation
    ///
    /// # Errors
    ///
    /// Returns error if cross-validation fails
    fn cross_validate(
        &self,
        examples: &[TrainingExample],
        config: &TrainingConfig,
    ) -> Result<CrossValidationResults, TrainingError>;
}

/// Errors produced by the training pipeline.
#[derive(Debug, Clone)]
pub enum TrainingError {
    /// Not enough training examples
    InsufficientData {
        /// Minimum required examples
        required: usize,
        /// Actually provided examples
        provided: usize,
    },
    /// Invalid configuration
    InvalidConfig(String),
    /// Model training failed
    TrainingFailed(String),
    /// IO error during save/load (stored as a message so the error stays `Clone`)
    IoError(String),
}

185impl std::fmt::Display for TrainingError {
186    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
187        match self {
188            Self::InsufficientData { required, provided } => {
189                write!(f, "Insufficient data: need {required}, got {provided}")
190            }
191            Self::InvalidConfig(msg) => write!(f, "Invalid config: {msg}"),
192            Self::TrainingFailed(msg) => write!(f, "Training failed: {msg}"),
193            Self::IoError(msg) => write!(f, "IO error: {msg}"),
194        }
195    }
196}
197
// Marker impl: `Debug` + `Display` satisfy `Error`'s requirements, and the
// default `source()` returning `None` is correct (no wrapped causes kept).
impl std::error::Error for TrainingError {}

200/// Convert verdict to bug label
201#[must_use]
202pub fn verdict_to_label(verdict: &TranspilerVerdict) -> bool {
203    !matches!(verdict, TranspilerVerdict::Pass)
204}
205
206/// Split examples into train/test sets
207#[must_use]
208pub fn train_test_split(
209    examples: &[TrainingExample],
210    train_ratio: f64,
211    seed: u64,
212) -> (Vec<TrainingExample>, Vec<TrainingExample>) {
213    use std::collections::hash_map::DefaultHasher;
214    use std::hash::{Hash, Hasher};
215
216    let mut train = Vec::new();
217    let mut test = Vec::new();
218
219    for (i, example) in examples.iter().enumerate() {
220        let mut hasher = DefaultHasher::new();
221        (seed, i).hash(&mut hasher);
222        let hash = hasher.finish();
223        #[allow(clippy::cast_sign_loss)]
224        let threshold = (train_ratio * u64::MAX as f64) as u64;
225
226        if hash < threshold {
227            train.push(example.clone());
228        } else {
229            test.push(example.clone());
230        }
231    }
232
233    (train, test)
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic fixture: `n` examples with cycling feature values,
    /// where every third example is labeled as a bug.
    fn sample_examples(n: usize) -> Vec<TrainingExample> {
        (0..n)
            .map(|i| TrainingExample {
                features: CodeFeatures {
                    ast_depth: i % 5,
                    cyclomatic_complexity: i % 10,
                    ..Default::default()
                },
                is_bug: i % 3 == 0,
            })
            .collect()
    }

    #[test]
    fn test_training_config_default() {
        let config = TrainingConfig::default();
        // Tolerance-based float comparison (avoids clippy::float_cmp);
        // also pin the seed, which was previously unasserted.
        assert!((config.train_ratio - 0.8).abs() < f64::EPSILON);
        assert_eq!(config.cv_folds, 5);
        assert_eq!(config.seed, 42);
        assert_eq!(config.min_examples, 100);
    }

    #[test]
    fn test_training_metrics_f1() {
        // 2 * 0.8 * 0.6 / (0.8 + 0.6) = 0.685714...; compare computed
        // floats with a tolerance instead of exact equality.
        assert!((TrainingMetrics::calculate_f1(0.8, 0.6) - 0.685_714_285_714_285_7).abs() < 1e-12);
        assert!(TrainingMetrics::calculate_f1(0.0, 0.0).abs() < f64::EPSILON);
        assert!((TrainingMetrics::calculate_f1(1.0, 1.0) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_cross_validation_summarize() {
        let folds = vec![
            TrainingMetrics {
                accuracy: 0.8,
                f1_score: 0.75,
                ..Default::default()
            },
            TrainingMetrics {
                accuracy: 0.85,
                f1_score: 0.80,
                ..Default::default()
            },
            TrainingMetrics {
                accuracy: 0.9,
                f1_score: 0.85,
                ..Default::default()
            },
        ];

        let cv = CrossValidationResults::summarize(folds);
        assert!((cv.mean_accuracy - 0.85).abs() < 0.001);
        assert!((cv.mean_f1 - 0.8).abs() < 0.001);
        assert!(cv.std_accuracy > 0.0);
    }

    #[test]
    fn test_cross_validation_empty() {
        let cv = CrossValidationResults::summarize(vec![]);
        assert!(cv.mean_accuracy.abs() < f64::EPSILON);
        assert_eq!(cv.fold_metrics.len(), 0);
    }

    #[test]
    fn test_verdict_to_label() {
        assert!(!verdict_to_label(&TranspilerVerdict::Pass));
        assert!(verdict_to_label(&TranspilerVerdict::OutputMismatch));
        assert!(verdict_to_label(&TranspilerVerdict::TranspileError(
            "err".into()
        )));
        assert!(verdict_to_label(&TranspilerVerdict::Timeout));
    }

    #[test]
    fn test_train_test_split_ratio() {
        let examples = sample_examples(1000);
        let (train, test) = train_test_split(&examples, 0.8, 42);

        // Should be approximately 80/20 split
        let train_ratio = train.len() as f64 / examples.len() as f64;
        assert!(train_ratio > 0.7 && train_ratio < 0.9);
        assert_eq!(train.len() + test.len(), examples.len());
    }

    #[test]
    fn test_train_test_split_deterministic() {
        let examples = sample_examples(100);
        let (train1, _) = train_test_split(&examples, 0.8, 42);
        let (train2, _) = train_test_split(&examples, 0.8, 42);

        // Same seed must yield the same split — compare contents (via the
        // labels), not just the sizes.
        assert_eq!(train1.len(), train2.len());
        assert!(train1
            .iter()
            .zip(&train2)
            .all(|(a, b)| a.is_bug == b.is_bug));
    }

    #[test]
    fn test_training_error_display() {
        let err = TrainingError::InsufficientData {
            required: 100,
            provided: 50,
        };
        assert!(err.to_string().contains("100"));
        assert!(err.to_string().contains("50"));
    }

    #[test]
    fn test_model_metadata_clone() {
        let meta = ModelMetadata {
            model_type: "RandomForest".into(),
            trained_at: "2025-01-01".into(),
            train_examples: 1000,
            metrics: TrainingMetrics::default(),
        };
        let cloned = meta.clone();
        assert_eq!(cloned.model_type, meta.model_type);
    }

    // RED PHASE: Tests that require aprender integration

    #[test]
    #[ignore = "requires aprender ml feature"]
    fn test_random_forest_training() {
        // TODO: Implement with aprender RandomForestClassifier
        // let trainer = RandomForestTrainer::new();
        // let examples = sample_examples(1000);
        // let model = trainer.train(&examples, &TrainingConfig::default()).unwrap();
        // assert!(model.predict(&CodeFeatures::default()) >= 0.0);
        unimplemented!("RandomForest training not yet implemented")
    }

    #[test]
    #[ignore = "requires aprender ml feature"]
    fn test_cross_validation_with_model() {
        // TODO: Implement CV with actual model
        // let trainer = RandomForestTrainer::new();
        // let examples = sample_examples(500);
        // let cv = trainer.cross_validate(&examples, &TrainingConfig::default()).unwrap();
        // assert_eq!(cv.fold_metrics.len(), 5);
        // assert!(cv.mean_accuracy > 0.5);
        unimplemented!("Cross-validation not yet implemented")
    }

    #[test]
    #[ignore = "requires aprender ml feature"]
    fn test_model_save_load() {
        // TODO: Implement model persistence
        // let trainer = RandomForestTrainer::new();
        // let model = trainer.train(&examples, &config).unwrap();
        // model.save(Path::new("/tmp/model.bin")).unwrap();
        // let loaded = RandomForestTrainer::load(Path::new("/tmp/model.bin")).unwrap();
        // assert_eq!(model.predict(&features), loaded.predict(&features));
        unimplemented!("Model save/load not yet implemented")
    }

    #[test]
    #[ignore = "requires aprender ml feature"]
    fn test_stratified_split() {
        // TODO: Implement stratified sampling
        // let examples = sample_examples(1000);
        // let (train, test) = stratified_split(&examples, 0.8, 42);
        // let train_bug_ratio = train.iter().filter(|e| e.is_bug).count() as f64 / train.len() as f64;
        // let test_bug_ratio = test.iter().filter(|e| e.is_bug).count() as f64 / test.len() as f64;
        // assert!((train_bug_ratio - test_bug_ratio).abs() < 0.1);
        unimplemented!("Stratified split not yet implemented")
    }
}