use crate::data::CodeFeatures;
use crate::Result;
use rand::seq::SliceRandom;
use rand::SeedableRng;
use serde::{Deserialize, Serialize};
use std::path::Path;
14
15#[derive(Debug, Clone, Default, Serialize, Deserialize)]
17pub struct ModelMetrics {
18 pub true_positives: usize,
20 pub true_negatives: usize,
22 pub false_positives: usize,
24 pub false_negatives: usize,
26 pub precision: f64,
28 pub recall: f64,
30 pub f1_score: f64,
32 pub accuracy: f64,
34 pub auc: f64,
36}
37
38impl ModelMetrics {
39 #[must_use]
41 pub fn compute(predictions: &[bool], ground_truth: &[bool]) -> Self {
42 let mut tp = 0;
43 let mut tn = 0;
44 let mut fp = 0;
45 let mut r#fn = 0;
46
47 for (pred, truth) in predictions.iter().zip(ground_truth.iter()) {
48 match (pred, truth) {
49 (true, true) => tp += 1,
50 (false, false) => tn += 1,
51 (true, false) => fp += 1,
52 (false, true) => r#fn += 1,
53 }
54 }
55
56 let precision = if tp + fp > 0 {
57 tp as f64 / (tp + fp) as f64
58 } else {
59 0.0
60 };
61
62 let recall = if tp + r#fn > 0 {
63 tp as f64 / (tp + r#fn) as f64
64 } else {
65 0.0
66 };
67
68 let f1_score = if precision + recall > 0.0 {
69 2.0 * (precision * recall) / (precision + recall)
70 } else {
71 0.0
72 };
73
74 let total = tp + tn + fp + r#fn;
75 let accuracy = if total > 0 {
76 (tp + tn) as f64 / total as f64
77 } else {
78 0.0
79 };
80
81 let tpr = recall;
83 let tnr = if tn + fp > 0 {
84 tn as f64 / (tn + fp) as f64
85 } else {
86 0.0
87 };
88 let auc = (tpr + tnr) / 2.0;
89
90 Self {
91 true_positives: tp,
92 true_negatives: tn,
93 false_positives: fp,
94 false_negatives: r#fn,
95 precision,
96 recall,
97 f1_score,
98 accuracy,
99 auc,
100 }
101 }
102
103 #[must_use]
105 pub fn average(metrics: &[ModelMetrics]) -> Self {
106 if metrics.is_empty() {
107 return Self::default();
108 }
109
110 let n = metrics.len() as f64;
111 Self {
112 true_positives: metrics.iter().map(|m| m.true_positives).sum::<usize>() / metrics.len(),
113 true_negatives: metrics.iter().map(|m| m.true_negatives).sum::<usize>() / metrics.len(),
114 false_positives: metrics.iter().map(|m| m.false_positives).sum::<usize>()
115 / metrics.len(),
116 false_negatives: metrics.iter().map(|m| m.false_negatives).sum::<usize>()
117 / metrics.len(),
118 precision: metrics.iter().map(|m| m.precision).sum::<f64>() / n,
119 recall: metrics.iter().map(|m| m.recall).sum::<f64>() / n,
120 f1_score: metrics.iter().map(|m| m.f1_score).sum::<f64>() / n,
121 accuracy: metrics.iter().map(|m| m.accuracy).sum::<f64>() / n,
122 auc: metrics.iter().map(|m| m.auc).sum::<f64>() / n,
123 }
124 }
125}
126
127#[derive(Debug, Clone)]
129pub struct TrainingConfig {
130 pub train_ratio: f64,
132 pub cv_folds: usize,
134 pub seed: u64,
136 pub n_trees: usize,
138 pub max_depth: usize,
140}
141
142impl Default for TrainingConfig {
143 fn default() -> Self {
144 Self {
145 train_ratio: 0.8,
146 cv_folds: 5,
147 seed: 42,
148 n_trees: 100,
149 max_depth: 10,
150 }
151 }
152}
153
154#[derive(Debug, Clone, Serialize, Deserialize)]
156pub struct TrainingResult {
157 pub train_metrics: ModelMetrics,
159 pub test_metrics: ModelMetrics,
161 pub cv_metrics: Vec<ModelMetrics>,
163 pub cv_average: ModelMetrics,
165 pub train_samples: usize,
167 pub test_samples: usize,
169}
170
171#[derive(Debug)]
173pub struct ModelTrainer {
174 config: TrainingConfig,
175}
176
177impl ModelTrainer {
178 #[must_use]
180 pub fn new() -> Self {
181 Self {
182 config: TrainingConfig::default(),
183 }
184 }
185
186 #[must_use]
188 pub fn with_config(config: TrainingConfig) -> Self {
189 Self { config }
190 }
191
192 #[must_use]
194 pub fn train_ratio(mut self, ratio: f64) -> Self {
195 self.config.train_ratio = ratio.clamp(0.1, 0.99);
196 self
197 }
198
199 #[must_use]
201 pub fn cv_folds(mut self, folds: usize) -> Self {
202 self.config.cv_folds = folds.max(2);
203 self
204 }
205
206 #[must_use]
208 pub fn seed(mut self, seed: u64) -> Self {
209 self.config.seed = seed;
210 self
211 }
212
213 pub fn train_test_split(
215 &self,
216 features: &[CodeFeatures],
217 labels: &[bool],
218 ) -> (Vec<CodeFeatures>, Vec<bool>, Vec<CodeFeatures>, Vec<bool>) {
219 let mut rng = rand::rngs::StdRng::seed_from_u64(self.config.seed);
220
221 let positives: Vec<usize> = labels
223 .iter()
224 .enumerate()
225 .filter(|(_, &l)| l)
226 .map(|(i, _)| i)
227 .collect();
228 let negatives: Vec<usize> = labels
229 .iter()
230 .enumerate()
231 .filter(|(_, &l)| !l)
232 .map(|(i, _)| i)
233 .collect();
234
235 let mut pos_shuffled = positives.clone();
237 let mut neg_shuffled = negatives.clone();
238 pos_shuffled.shuffle(&mut rng);
239 neg_shuffled.shuffle(&mut rng);
240
241 #[allow(clippy::cast_sign_loss)]
243 let pos_split = (pos_shuffled.len() as f64 * self.config.train_ratio) as usize;
244 #[allow(clippy::cast_sign_loss)]
245 let neg_split = (neg_shuffled.len() as f64 * self.config.train_ratio) as usize;
246
247 let train_indices: Vec<usize> = pos_shuffled[..pos_split]
248 .iter()
249 .chain(neg_shuffled[..neg_split].iter())
250 .copied()
251 .collect();
252
253 let test_indices: Vec<usize> = pos_shuffled[pos_split..]
254 .iter()
255 .chain(neg_shuffled[neg_split..].iter())
256 .copied()
257 .collect();
258
259 let train_features: Vec<CodeFeatures> =
260 train_indices.iter().map(|&i| features[i].clone()).collect();
261 let train_labels: Vec<bool> = train_indices.iter().map(|&i| labels[i]).collect();
262 let test_features: Vec<CodeFeatures> =
263 test_indices.iter().map(|&i| features[i].clone()).collect();
264 let test_labels: Vec<bool> = test_indices.iter().map(|&i| labels[i]).collect();
265
266 (train_features, train_labels, test_features, test_labels)
267 }
268
269 pub fn cross_validate(
275 &self,
276 features: &[CodeFeatures],
277 labels: &[bool],
278 ) -> Result<Vec<ModelMetrics>> {
279 let mut rng = rand::rngs::StdRng::seed_from_u64(self.config.seed);
280 let n = features.len();
281 let fold_size = n / self.config.cv_folds;
282
283 let mut indices: Vec<usize> = (0..n).collect();
285 indices.shuffle(&mut rng);
286
287 let mut metrics = Vec::with_capacity(self.config.cv_folds);
288
289 for fold in 0..self.config.cv_folds {
290 let start = fold * fold_size;
291 let end = if fold == self.config.cv_folds - 1 {
292 n
293 } else {
294 start + fold_size
295 };
296
297 let test_indices: Vec<usize> = indices[start..end].to_vec();
299
300 let train_indices: Vec<usize> = indices[..start]
302 .iter()
303 .chain(indices[end..].iter())
304 .copied()
305 .collect();
306
307 let train_features: Vec<CodeFeatures> =
308 train_indices.iter().map(|&i| features[i].clone()).collect();
309 let train_labels: Vec<bool> = train_indices.iter().map(|&i| labels[i]).collect();
310 let test_features: Vec<CodeFeatures> =
311 test_indices.iter().map(|&i| features[i].clone()).collect();
312 let test_labels: Vec<bool> = test_indices.iter().map(|&i| labels[i]).collect();
313
314 let fold_metrics = self.train_and_evaluate(
316 &train_features,
317 &train_labels,
318 &test_features,
319 &test_labels,
320 )?;
321 metrics.push(fold_metrics);
322 }
323
324 Ok(metrics)
325 }
326
327 fn train_and_evaluate(
329 &self,
330 _train_features: &[CodeFeatures],
331 _train_labels: &[bool],
332 test_features: &[CodeFeatures],
333 test_labels: &[bool],
334 ) -> Result<ModelMetrics> {
335 let predictor = super::BugPredictor::new();
338
339 let predictions: Vec<bool> = test_features
340 .iter()
341 .map(|f| predictor.predict(f) > 0.5)
342 .collect();
343
344 Ok(ModelMetrics::compute(&predictions, test_labels))
345 }
346
347 pub fn train(&self, features: &[CodeFeatures], labels: &[bool]) -> Result<TrainingResult> {
353 let (train_features, train_labels, test_features, test_labels) =
355 self.train_test_split(features, labels);
356
357 let cv_metrics = self.cross_validate(&train_features, &train_labels)?;
359 let cv_average = ModelMetrics::average(&cv_metrics);
360
361 let train_metrics = self.train_and_evaluate(
363 &train_features,
364 &train_labels,
365 &train_features,
366 &train_labels,
367 )?;
368 let test_metrics =
369 self.train_and_evaluate(&train_features, &train_labels, &test_features, &test_labels)?;
370
371 Ok(TrainingResult {
372 train_metrics,
373 test_metrics,
374 cv_metrics,
375 cv_average,
376 train_samples: train_features.len(),
377 test_samples: test_features.len(),
378 })
379 }
380}
381
382impl Default for ModelTrainer {
383 fn default() -> Self {
384 Self::new()
385 }
386}
387
388#[derive(Debug, Clone, Serialize, Deserialize)]
390pub struct SerializedModel {
391 pub version: String,
393 pub config: TrainingConfig,
395 pub metrics: ModelMetrics,
397 pub weights: Vec<f64>,
399}
400
401impl SerializedModel {
402 pub fn save(&self, path: &str) -> Result<()> {
408 let json = serde_json::to_string_pretty(self)
409 .map_err(|e| crate::Error::Data(format!("Serialization failed: {e}")))?;
410 std::fs::write(path, json)
411 .map_err(|e| crate::Error::Data(format!("Failed to write file: {e}")))?;
412 Ok(())
413 }
414
415 pub fn load(path: &str) -> Result<Self> {
421 let json = std::fs::read_to_string(path)
422 .map_err(|e| crate::Error::Data(format!("Failed to read file: {e}")))?;
423 let model: Self = serde_json::from_str(&json)
424 .map_err(|e| crate::Error::Data(format!("Deserialization failed: {e}")))?;
425 Ok(model)
426 }
427}
428
429impl Serialize for TrainingConfig {
431 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
432 where
433 S: serde::Serializer,
434 {
435 use serde::ser::SerializeStruct;
436 let mut state = serializer.serialize_struct("TrainingConfig", 5)?;
437 state.serialize_field("train_ratio", &self.train_ratio)?;
438 state.serialize_field("cv_folds", &self.cv_folds)?;
439 state.serialize_field("seed", &self.seed)?;
440 state.serialize_field("n_trees", &self.n_trees)?;
441 state.serialize_field("max_depth", &self.max_depth)?;
442 state.end()
443 }
444}
445
446impl<'de> Deserialize<'de> for TrainingConfig {
447 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
448 where
449 D: serde::Deserializer<'de>,
450 {
451 #[derive(Deserialize)]
452 struct Helper {
453 train_ratio: f64,
454 cv_folds: usize,
455 seed: u64,
456 n_trees: usize,
457 max_depth: usize,
458 }
459
460 let helper = Helper::deserialize(deserializer)?;
461 Ok(Self {
462 train_ratio: helper.train_ratio,
463 cv_folds: helper.cv_folds,
464 seed: helper.seed,
465 n_trees: helper.n_trees,
466 max_depth: helper.max_depth,
467 })
468 }
469}
470
471#[cfg(test)]
472mod tests {
473 use super::*;
474
475 fn sample_data() -> (Vec<CodeFeatures>, Vec<bool>) {
476 let features: Vec<CodeFeatures> = (0..100)
477 .map(|i| CodeFeatures {
478 ast_depth: (i % 10) as u32,
479 num_operators: (i % 20) as u32,
480 num_control_flow: (i % 5) as u32,
481 cyclomatic_complexity: (i % 15) as f32,
482 uses_edge_values: i % 3 == 0,
483 ..Default::default()
484 })
485 .collect();
486 let labels: Vec<bool> = (0..100).map(|i| i % 4 == 0).collect();
487 (features, labels)
488 }
489
490 #[test]
491 fn test_model_metrics_compute() {
492 let predictions = vec![true, true, false, false, true];
493 let ground_truth = vec![true, false, false, true, true];
494
495 let metrics = ModelMetrics::compute(&predictions, &ground_truth);
496
497 assert_eq!(metrics.true_positives, 2);
498 assert_eq!(metrics.true_negatives, 1);
499 assert_eq!(metrics.false_positives, 1);
500 assert_eq!(metrics.false_negatives, 1);
501 assert!((metrics.precision - 0.666).abs() < 0.01);
502 assert!((metrics.recall - 0.666).abs() < 0.01);
503 }
504
505 #[test]
506 fn test_model_metrics_perfect() {
507 let predictions = vec![true, false, true, false];
508 let ground_truth = vec![true, false, true, false];
509
510 let metrics = ModelMetrics::compute(&predictions, &ground_truth);
511
512 assert!((metrics.precision - 1.0).abs() < f64::EPSILON);
513 assert!((metrics.recall - 1.0).abs() < f64::EPSILON);
514 assert!((metrics.f1_score - 1.0).abs() < f64::EPSILON);
515 assert!((metrics.accuracy - 1.0).abs() < f64::EPSILON);
516 }
517
518 #[test]
519 fn test_model_metrics_average() {
520 let metrics = vec![
521 ModelMetrics {
522 precision: 0.8,
523 recall: 0.7,
524 f1_score: 0.75,
525 accuracy: 0.85,
526 auc: 0.9,
527 ..Default::default()
528 },
529 ModelMetrics {
530 precision: 0.6,
531 recall: 0.9,
532 f1_score: 0.72,
533 accuracy: 0.75,
534 auc: 0.8,
535 ..Default::default()
536 },
537 ];
538
539 let avg = ModelMetrics::average(&metrics);
540
541 assert!((avg.precision - 0.7).abs() < f64::EPSILON);
542 assert!((avg.recall - 0.8).abs() < f64::EPSILON);
543 }
544
545 #[test]
546 fn test_training_config_default() {
547 let config = TrainingConfig::default();
548 assert!((config.train_ratio - 0.8).abs() < f64::EPSILON);
549 assert_eq!(config.cv_folds, 5);
550 assert_eq!(config.seed, 42);
551 }
552
553 #[test]
554 fn test_trainer_new() {
555 let trainer = ModelTrainer::new();
556 assert!((trainer.config.train_ratio - 0.8).abs() < f64::EPSILON);
557 }
558
559 #[test]
560 fn test_trainer_builder() {
561 let trainer = ModelTrainer::new().train_ratio(0.7).cv_folds(10).seed(123);
562
563 assert!((trainer.config.train_ratio - 0.7).abs() < f64::EPSILON);
564 assert_eq!(trainer.config.cv_folds, 10);
565 assert_eq!(trainer.config.seed, 123);
566 }
567
568 #[test]
569 fn test_train_test_split() {
570 let (features, labels) = sample_data();
571 let trainer = ModelTrainer::new();
572
573 let (train_f, train_l, test_f, test_l) = trainer.train_test_split(&features, &labels);
574
575 let total = features.len();
577 let train_expected = (total as f64 * 0.8) as usize;
578 assert!(train_f.len() >= train_expected - 5 && train_f.len() <= train_expected + 5);
579 assert_eq!(train_f.len(), train_l.len());
580 assert_eq!(test_f.len(), test_l.len());
581 }
582
583 #[test]
584 fn test_cross_validate() {
585 let (features, labels) = sample_data();
586 let trainer = ModelTrainer::new().cv_folds(5);
587
588 let cv_metrics = trainer.cross_validate(&features, &labels).unwrap();
589
590 assert_eq!(cv_metrics.len(), 5);
591 for m in &cv_metrics {
592 assert!((0.0..=1.0).contains(&m.accuracy));
593 }
594 }
595
596 #[test]
597 fn test_train_full_pipeline() {
598 let (features, labels) = sample_data();
599 let trainer = ModelTrainer::new();
600
601 let result = trainer.train(&features, &labels).unwrap();
602
603 assert!(result.train_samples > 0);
604 assert!(result.test_samples > 0);
605 assert_eq!(result.cv_metrics.len(), 5);
606 assert!((0.0..=1.0).contains(&result.test_metrics.accuracy));
607 }
608
609 #[test]
610 fn test_serialized_model() {
611 let model = SerializedModel {
612 version: "0.1.0".to_string(),
613 config: TrainingConfig::default(),
614 metrics: ModelMetrics::default(),
615 weights: vec![0.1, 0.2, 0.3],
616 };
617
618 let json = serde_json::to_string(&model).unwrap();
619 let loaded: SerializedModel = serde_json::from_str(&json).unwrap();
620
621 assert_eq!(loaded.version, "0.1.0");
622 assert_eq!(loaded.weights.len(), 3);
623 }
624
625 #[test]
626 fn test_training_result_serialize() {
627 let result = TrainingResult {
628 train_metrics: ModelMetrics::default(),
629 test_metrics: ModelMetrics::default(),
630 cv_metrics: vec![ModelMetrics::default()],
631 cv_average: ModelMetrics::default(),
632 train_samples: 80,
633 test_samples: 20,
634 };
635
636 let json = serde_json::to_string(&result).unwrap();
637 assert!(json.contains("train_samples"));
638 }
639
640 #[test]
641 fn test_model_metrics_empty() {
642 let metrics = ModelMetrics::compute(&[], &[]);
643 assert_eq!(metrics.true_positives, 0);
644 assert!((metrics.accuracy - 0.0).abs() < f64::EPSILON);
645 }
646
647 #[test]
648 fn test_model_metrics_all_negative() {
649 let predictions = vec![false, false, false];
650 let ground_truth = vec![false, false, false];
651
652 let metrics = ModelMetrics::compute(&predictions, &ground_truth);
653
654 assert_eq!(metrics.true_negatives, 3);
655 assert!((metrics.accuracy - 1.0).abs() < f64::EPSILON);
656 }
657
658 #[test]
659 fn test_trainer_ratio_clamp() {
660 let trainer = ModelTrainer::new().train_ratio(0.05);
661 assert!((trainer.config.train_ratio - 0.1).abs() < f64::EPSILON);
662
663 let trainer = ModelTrainer::new().train_ratio(1.5);
664 assert!((trainer.config.train_ratio - 0.99).abs() < f64::EPSILON);
665 }
666
667 #[test]
668 fn test_trainer_cv_folds_min() {
669 let trainer = ModelTrainer::new().cv_folds(1);
670 assert_eq!(trainer.config.cv_folds, 2);
671 }
672
673 #[test]
674 fn test_model_metrics_auc_calculation() {
675 let mut predictions = Vec::new();
680 let mut ground_truth = Vec::new();
681
682 for _ in 0..10 {
684 predictions.push(true);
685 ground_truth.push(true);
686 }
687 for _ in 0..20 {
689 predictions.push(false);
690 ground_truth.push(false);
691 }
692 for _ in 0..5 {
694 predictions.push(true);
695 ground_truth.push(false);
696 }
697 for _ in 0..3 {
699 predictions.push(false);
700 ground_truth.push(true);
701 }
702
703 let metrics = ModelMetrics::compute(&predictions, &ground_truth);
704
705 assert!(
709 metrics.auc > 0.7,
710 "AUC should be > 0.7, got {}",
711 metrics.auc
712 );
713 assert!(
714 metrics.auc < 0.85,
715 "AUC should be < 0.85, got {}",
716 metrics.auc
717 );
718
719 }
722
723 #[test]
724 fn test_model_metrics_average_fp_fn() {
725 let metrics = vec![
728 ModelMetrics {
729 false_positives: 10,
730 false_negatives: 20,
731 ..Default::default()
732 },
733 ModelMetrics {
734 false_positives: 30,
735 false_negatives: 40,
736 ..Default::default()
737 },
738 ];
739
740 let avg = ModelMetrics::average(&metrics);
741
742 assert_eq!(avg.false_positives, 20);
744 assert_eq!(avg.false_negatives, 30);
746 }
747
748 #[test]
749 fn test_trainer_with_config() {
750 let config = TrainingConfig {
753 train_ratio: 0.6,
754 cv_folds: 10,
755 seed: 12345,
756 n_trees: 50,
757 max_depth: 5,
758 };
759 let trainer = ModelTrainer::with_config(config);
760
761 assert!((trainer.config.train_ratio - 0.6).abs() < f64::EPSILON);
763 assert_eq!(trainer.config.cv_folds, 10);
764 assert_eq!(trainer.config.seed, 12345);
765 }
766}