1use crate::data::CodeFeatures;
11use serde::{Deserialize, Serialize};
12use std::time::Instant;
13
/// 2x2 confusion matrix for a binary classifier.
///
/// Counts are accumulated with [`ConfusionMatrix::from_predictions`] and the
/// derived metrics (accuracy, precision, recall, ...) are computed on demand.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ConfusionMatrix {
    /// True positives: predicted positive, actually positive.
    pub tp: usize,
    /// True negatives: predicted negative, actually negative.
    pub tn: usize,
    /// False positives: predicted positive, actually negative.
    pub fp: usize,
    /// False negatives: predicted negative, actually positive.
    /// Raw identifier because `fn` is a Rust keyword.
    pub r#fn: usize,
}
26
27impl ConfusionMatrix {
28 #[must_use]
30 pub fn from_predictions(predictions: &[bool], ground_truth: &[bool]) -> Self {
31 let mut matrix = Self::default();
32
33 for (pred, truth) in predictions.iter().zip(ground_truth.iter()) {
34 match (pred, truth) {
35 (true, true) => matrix.tp += 1,
36 (false, false) => matrix.tn += 1,
37 (true, false) => matrix.fp += 1,
38 (false, true) => matrix.r#fn += 1,
39 }
40 }
41
42 matrix
43 }
44
45 #[must_use]
47 pub fn total(&self) -> usize {
48 self.tp + self.tn + self.fp + self.r#fn
49 }
50
51 #[must_use]
53 pub fn accuracy(&self) -> f64 {
54 let total = self.total();
55 if total == 0 {
56 return 0.0;
57 }
58 (self.tp + self.tn) as f64 / total as f64
59 }
60
61 #[must_use]
63 pub fn precision(&self) -> f64 {
64 let denom = self.tp + self.fp;
65 if denom == 0 {
66 return 0.0;
67 }
68 self.tp as f64 / denom as f64
69 }
70
71 #[must_use]
73 pub fn recall(&self) -> f64 {
74 let denom = self.tp + self.r#fn;
75 if denom == 0 {
76 return 0.0;
77 }
78 self.tp as f64 / denom as f64
79 }
80
81 #[must_use]
83 pub fn specificity(&self) -> f64 {
84 let denom = self.tn + self.fp;
85 if denom == 0 {
86 return 0.0;
87 }
88 self.tn as f64 / denom as f64
89 }
90
91 #[must_use]
93 pub fn f1_score(&self) -> f64 {
94 let precision = self.precision();
95 let recall = self.recall();
96 if precision + recall == 0.0 {
97 return 0.0;
98 }
99 2.0 * (precision * recall) / (precision + recall)
100 }
101
102 #[must_use]
104 pub fn to_ascii(&self) -> String {
105 format!(
106 r"
107Confusion Matrix
108================
109 Predicted
110 Pos Neg
111Actual Pos {:>5} {:>5} (TP, FN)
112Actual Neg {:>5} {:>5} (FP, TN)
113
114Accuracy: {:.3}
115Precision: {:.3}
116Recall: {:.3}
117F1 Score: {:.3}
118",
119 self.tp,
120 self.r#fn,
121 self.fp,
122 self.tn,
123 self.accuracy(),
124 self.precision(),
125 self.recall(),
126 self.f1_score()
127 )
128 }
129}
130
/// One operating point on a ROC curve.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RocPoint {
    /// Score threshold at which this point was measured.
    pub threshold: f64,
    /// True-positive rate (sensitivity) at this threshold.
    pub tpr: f64,
    /// False-positive rate (1 - specificity) at this threshold.
    pub fpr: f64,
}
141
/// Receiver-operating-characteristic curve with its trapezoidal AUC.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RocCurve {
    /// Curve points ordered from (0, 0) to (1, 1).
    pub points: Vec<RocPoint>,
    /// Area under the curve, in [0, 1].
    pub auc: f64,
}
150
151impl RocCurve {
152 #[must_use]
154 pub fn from_scores(scores: &[f64], ground_truth: &[bool]) -> Self {
155 if scores.is_empty() || scores.len() != ground_truth.len() {
156 return Self::default();
157 }
158
159 let mut indexed: Vec<(f64, bool)> = scores
161 .iter()
162 .zip(ground_truth.iter())
163 .map(|(&s, &t)| (s, t))
164 .collect();
165 indexed.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
166
167 let total_positives = ground_truth.iter().filter(|&&t| t).count();
168 let total_negatives = ground_truth.len() - total_positives;
169
170 if total_positives == 0 || total_negatives == 0 {
171 return Self::default();
172 }
173
174 let mut points = Vec::new();
175 let mut tp = 0;
176 let mut fp = 0;
177
178 points.push(RocPoint {
180 threshold: 1.0,
181 tpr: 0.0,
182 fpr: 0.0,
183 });
184
185 let mut prev_score = f64::INFINITY;
186
187 for (score, is_positive) in &indexed {
188 #[allow(clippy::float_cmp)]
190 if *score != prev_score {
191 let tpr = f64::from(tp) / total_positives as f64;
192 let fpr = f64::from(fp) / total_negatives as f64;
193 points.push(RocPoint {
194 threshold: *score,
195 tpr,
196 fpr,
197 });
198 prev_score = *score;
199 }
200
201 if *is_positive {
202 tp += 1;
203 } else {
204 fp += 1;
205 }
206 }
207
208 points.push(RocPoint {
210 threshold: 0.0,
211 tpr: 1.0,
212 fpr: 1.0,
213 });
214
215 let auc = Self::calculate_auc(&points);
217
218 Self { points, auc }
219 }
220
221 fn calculate_auc(points: &[RocPoint]) -> f64 {
223 let mut auc = 0.0;
224
225 for i in 1..points.len() {
226 let width = points[i].fpr - points[i - 1].fpr;
227 let height = (points[i].tpr + points[i - 1].tpr) / 2.0;
228 auc += width * height;
229 }
230
231 auc.abs()
232 }
233
234 #[must_use]
236 pub fn to_ascii(&self) -> String {
237 use std::fmt::Write;
238
239 let mut output = String::new();
240 output.push_str("ROC Curve\n");
241 output.push_str("=========\n");
242 let _ = writeln!(output, "AUC: {:.4}\n", self.auc);
243
244 let grid_size = 10;
246 let mut grid = vec![vec!['.'; grid_size]; grid_size];
247
248 #[allow(clippy::cast_sign_loss)]
250 for point in &self.points {
251 let x = (point.fpr * (grid_size - 1) as f64).round() as usize;
252 let y = ((1.0 - point.tpr) * (grid_size - 1) as f64).round() as usize;
253 if x < grid_size && y < grid_size {
254 grid[y][x] = '*';
255 }
256 }
257
258 for (i, row) in grid.iter_mut().enumerate() {
260 if row[i] == '.' {
261 row[i] = '-';
262 }
263 }
264
265 output.push_str("TPR\n");
266 output.push_str("1.0 |");
267 for row in &grid {
268 output.push_str(&row.iter().collect::<String>());
269 output.push_str("|\n |");
270 }
271 output.push_str(&"-".repeat(grid_size));
272 output.push_str("| FPR\n 0");
273 output.push_str(&" ".repeat(grid_size - 2));
274 output.push_str("1.0\n");
275
276 output
277 }
278}
279
/// Normalized importance score for a single named input feature.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureImportance {
    /// Feature name (one of the fields of `CodeFeatures`).
    pub name: String,
    /// Importance in [0, 1]; scores sum to 1.0 when any feature mattered.
    pub importance: f64,
}
288
289pub fn calculate_feature_importance(
291 features: &[CodeFeatures],
292 labels: &[bool],
293 predictor: &dyn Fn(&CodeFeatures) -> f64,
294) -> Vec<FeatureImportance> {
295 let baseline_score = calculate_accuracy(features, labels, predictor);
296 let feature_names = [
297 "ast_depth",
298 "num_operators",
299 "num_control_flow",
300 "cyclomatic_complexity",
301 "uses_edge_values",
302 ];
303
304 let mut importances = Vec::new();
305
306 for (idx, name) in feature_names.iter().enumerate() {
307 let permuted_features: Vec<CodeFeatures> = features
309 .iter()
310 .enumerate()
311 .map(|(i, f)| {
312 let mut permuted = f.clone();
313 let swap_idx = (i + 1) % features.len();
314 match idx {
315 0 => permuted.ast_depth = features[swap_idx].ast_depth,
316 1 => permuted.num_operators = features[swap_idx].num_operators,
317 2 => permuted.num_control_flow = features[swap_idx].num_control_flow,
318 3 => permuted.cyclomatic_complexity = features[swap_idx].cyclomatic_complexity,
319 4 => permuted.uses_edge_values = features[swap_idx].uses_edge_values,
320 _ => {}
321 }
322 permuted
323 })
324 .collect();
325
326 let permuted_score = calculate_accuracy(&permuted_features, labels, predictor);
327 let importance = (baseline_score - permuted_score).max(0.0);
328
329 importances.push(FeatureImportance {
330 name: (*name).to_string(),
331 importance,
332 });
333 }
334
335 let total: f64 = importances.iter().map(|f| f.importance).sum();
337 if total > 0.0 {
338 for f in &mut importances {
339 f.importance /= total;
340 }
341 }
342
343 importances.sort_by(|a, b| {
345 b.importance
346 .partial_cmp(&a.importance)
347 .unwrap_or(std::cmp::Ordering::Equal)
348 });
349
350 importances
351}
352
353fn calculate_accuracy(
354 features: &[CodeFeatures],
355 labels: &[bool],
356 predictor: &dyn Fn(&CodeFeatures) -> f64,
357) -> f64 {
358 let correct: usize = features
359 .iter()
360 .zip(labels.iter())
361 .map(|(f, &l)| {
362 let pred = predictor(f) > 0.5;
363 usize::from(pred == l)
364 })
365 .sum();
366
367 if features.is_empty() {
368 return 0.0;
369 }
370 correct as f64 / features.len() as f64
371}
372
/// Aggregate timing results from [`benchmark_inference`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkResult {
    /// Total predictor invocations (iterations x feature count).
    pub num_predictions: usize,
    /// Wall-clock time for the whole run, in milliseconds.
    pub total_time_ms: f64,
    /// Throughput; 0.0 when no time elapsed.
    pub predictions_per_sec: f64,
    /// Mean latency per prediction, in microseconds; 0.0 when nothing ran.
    pub avg_latency_us: f64,
}
385
386pub fn benchmark_inference<F>(
388 predictor: F,
389 features: &[CodeFeatures],
390 iterations: usize,
391) -> BenchmarkResult
392where
393 F: Fn(&CodeFeatures) -> f64,
394{
395 let start = Instant::now();
396
397 for _ in 0..iterations {
398 for f in features {
399 let _ = predictor(f);
400 }
401 }
402
403 let elapsed = start.elapsed();
404 let total_time_ms = elapsed.as_secs_f64() * 1000.0;
405 let num_predictions = iterations * features.len();
406
407 let predictions_per_sec = if total_time_ms > 0.0 {
408 num_predictions as f64 / (total_time_ms / 1000.0)
409 } else {
410 0.0
411 };
412
413 let avg_latency_us = if num_predictions > 0 {
414 (total_time_ms * 1000.0) / num_predictions as f64
415 } else {
416 0.0
417 };
418
419 BenchmarkResult {
420 num_predictions,
421 total_time_ms,
422 predictions_per_sec,
423 avg_latency_us,
424 }
425}
426
/// Side-by-side comparison of a baseline and a trained model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelComparison {
    /// Metrics for the reference model.
    pub baseline: ComparisonMetrics,
    /// Metrics for the candidate model.
    pub trained: ComparisonMetrics,
    /// trained.accuracy - baseline.accuracy (may be negative).
    pub accuracy_improvement: f64,
    /// trained.f1_score - baseline.f1_score (may be negative).
    pub f1_improvement: f64,
    /// Throughput ratio trained/baseline; 1.0 when the baseline rate is 0.
    pub speedup: f64,
}
441
/// Headline metrics for one model, as consumed by [`ModelComparison`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ComparisonMetrics {
    /// Display name of the model.
    pub name: String,
    /// Classification accuracy in [0, 1].
    pub accuracy: f64,
    /// F1 score in [0, 1].
    pub f1_score: f64,
    /// Inference throughput.
    pub predictions_per_sec: f64,
}
454
455impl ModelComparison {
456 #[must_use]
458 pub fn compare(baseline: ComparisonMetrics, trained: ComparisonMetrics) -> Self {
459 let accuracy_improvement = trained.accuracy - baseline.accuracy;
460 let f1_improvement = trained.f1_score - baseline.f1_score;
461 let speedup = if baseline.predictions_per_sec > 0.0 {
462 trained.predictions_per_sec / baseline.predictions_per_sec
463 } else {
464 1.0
465 };
466
467 Self {
468 baseline,
469 trained,
470 accuracy_improvement,
471 f1_improvement,
472 speedup,
473 }
474 }
475
476 #[must_use]
478 pub fn to_ascii(&self) -> String {
479 format!(
480 r"
481Model Comparison
482================
483 Baseline Trained Delta
484Name {:<12} {:<12}
485Accuracy {:<12.4} {:<12.4} {:+.4}
486F1 Score {:<12.4} {:<12.4} {:+.4}
487Pred/sec {:<12.0} {:<12.0} {:.2}x
488",
489 self.baseline.name,
490 self.trained.name,
491 self.baseline.accuracy,
492 self.trained.accuracy,
493 self.accuracy_improvement,
494 self.baseline.f1_score,
495 self.trained.f1_score,
496 self.f1_improvement,
497 self.baseline.predictions_per_sec,
498 self.trained.predictions_per_sec,
499 self.speedup
500 )
501 }
502}
503
// Unit tests for the metrics module: confusion-matrix counting and derived
// metrics, ROC curve construction/AUC, feature importance, benchmarking,
// and model comparison, plus serde round-trips for the plain data types.
#[cfg(test)]
mod tests {
    use super::*;

    // Pairs: (T,T) tp, (T,F) fp, (F,F) tn, (F,T) fn, (T,T) tp.
    #[test]
    fn test_confusion_matrix_from_predictions() {
        let predictions = vec![true, true, false, false, true];
        let ground_truth = vec![true, false, false, true, true];

        let matrix = ConfusionMatrix::from_predictions(&predictions, &ground_truth);

        assert_eq!(matrix.tp, 2);
        assert_eq!(matrix.tn, 1);
        assert_eq!(matrix.fp, 1);
        assert_eq!(matrix.r#fn, 1);
    }

    // accuracy = 90/100, precision = 50/60, recall = 50/50.
    #[test]
    fn test_confusion_matrix_metrics() {
        let matrix = ConfusionMatrix {
            tp: 50,
            tn: 40,
            fp: 10,
            r#fn: 0,
        };

        assert!((matrix.accuracy() - 0.9).abs() < 0.001);
        assert!((matrix.precision() - 0.833).abs() < 0.01);
        assert!((matrix.recall() - 1.0).abs() < 0.001);
    }

    // A perfect classifier drives every derived metric to exactly 1.0.
    #[test]
    fn test_confusion_matrix_perfect() {
        let predictions = vec![true, false, true, false];
        let ground_truth = vec![true, false, true, false];

        let matrix = ConfusionMatrix::from_predictions(&predictions, &ground_truth);

        assert!((matrix.accuracy() - 1.0).abs() < f64::EPSILON);
        assert!((matrix.precision() - 1.0).abs() < f64::EPSILON);
        assert!((matrix.recall() - 1.0).abs() < f64::EPSILON);
        assert!((matrix.f1_score() - 1.0).abs() < f64::EPSILON);
    }

    // Smoke test: the ASCII rendering includes the header, a raw count,
    // and the metric labels.
    #[test]
    fn test_confusion_matrix_to_ascii() {
        let matrix = ConfusionMatrix {
            tp: 10,
            tn: 20,
            fp: 5,
            r#fn: 3,
        };

        let ascii = matrix.to_ascii();
        assert!(ascii.contains("Confusion Matrix"));
        assert!(ascii.contains("10"));
        assert!(ascii.contains("Accuracy"));
    }

    // Perfectly separable scores (all positives ranked above all
    // negatives) must yield a near-perfect AUC.
    #[test]
    fn test_roc_curve_from_scores() {
        let scores = vec![0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.0];
        let ground_truth = vec![
            true, true, true, true, true, false, false, false, false, false,
        ];

        let roc = RocCurve::from_scores(&scores, &ground_truth);

        assert!(roc.auc > 0.9); assert!(!roc.points.is_empty());
    }

    // All-tied scores (an uninformative classifier) still produce a valid
    // AUC in [0, 1].
    #[test]
    fn test_roc_curve_random() {
        let scores = vec![0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5];
        let ground_truth = vec![true, false, true, false, true, false, true, false];

        let roc = RocCurve::from_scores(&scores, &ground_truth);

        assert!(roc.auc >= 0.0 && roc.auc <= 1.0);
    }

    // Empty input falls back to the default (empty) curve with AUC 0.
    #[test]
    fn test_roc_curve_empty() {
        let roc = RocCurve::from_scores(&[], &[]);
        assert!((roc.auc - 0.0).abs() < f64::EPSILON);
    }

    // Smoke test for the plot header and AUC line.
    #[test]
    fn test_roc_curve_to_ascii() {
        let scores = vec![0.9, 0.8, 0.3, 0.2];
        let ground_truth = vec![true, true, false, false];

        let roc = RocCurve::from_scores(&scores, &ground_truth);
        let ascii = roc.to_ascii();

        assert!(ascii.contains("ROC Curve"));
        assert!(ascii.contains("AUC"));
    }

    // One entry per named feature, and the normalized importances must not
    // exceed 1 (plus float slack).
    #[test]
    fn test_feature_importance() {
        let features: Vec<CodeFeatures> = (0..100)
            .map(|i| CodeFeatures {
                ast_depth: (i % 10) as u32,
                num_operators: (i % 20) as u32,
                num_control_flow: (i % 5) as u32,
                cyclomatic_complexity: (i % 15) as f32,
                uses_edge_values: i % 3 == 0,
                ..Default::default()
            })
            .collect();
        let labels: Vec<bool> = (0..100).map(|i| i % 4 == 0).collect();

        // Predictor depends only on ast_depth, so that feature should
        // dominate once permuted.
        let predictor = |f: &CodeFeatures| f.ast_depth as f64 * 0.1;
        let importance = calculate_feature_importance(&features, &labels, &predictor);

        assert_eq!(importance.len(), 5);
        let total: f64 = importance.iter().map(|f| f.importance).sum();
        assert!(total <= 1.1);
    }

    // 10 iterations x 100 features = 1000 predictions, with nonzero timing.
    #[test]
    fn test_benchmark_inference() {
        let features: Vec<CodeFeatures> = (0..100).map(|_| CodeFeatures::default()).collect();

        let predictor = |_: &CodeFeatures| 0.5;
        let result = benchmark_inference(predictor, &features, 10);

        assert_eq!(result.num_predictions, 1000);
        assert!(result.total_time_ms > 0.0);
        assert!(result.predictions_per_sec > 0.0);
    }

    // Deltas are trained - baseline; lower trained throughput means
    // speedup < 1.
    #[test]
    fn test_model_comparison() {
        let baseline = ComparisonMetrics {
            name: "Baseline".to_string(),
            accuracy: 0.7,
            f1_score: 0.65,
            predictions_per_sec: 10000.0,
        };

        let trained = ComparisonMetrics {
            name: "Trained".to_string(),
            accuracy: 0.85,
            f1_score: 0.82,
            predictions_per_sec: 8000.0,
        };

        let comparison = ModelComparison::compare(baseline, trained);

        assert!((comparison.accuracy_improvement - 0.15).abs() < 0.001);
        assert!((comparison.f1_improvement - 0.17).abs() < 0.001);
        assert!(comparison.speedup < 1.0); }

    // Smoke test: the table carries the header and both model names.
    #[test]
    fn test_model_comparison_to_ascii() {
        let baseline = ComparisonMetrics {
            name: "Baseline".to_string(),
            accuracy: 0.7,
            f1_score: 0.65,
            predictions_per_sec: 10000.0,
        };

        let trained = ComparisonMetrics {
            name: "Trained".to_string(),
            accuracy: 0.85,
            f1_score: 0.82,
            predictions_per_sec: 15000.0,
        };

        let comparison = ModelComparison::compare(baseline, trained);
        let ascii = comparison.to_ascii();

        assert!(ascii.contains("Model Comparison"));
        assert!(ascii.contains("Baseline"));
        assert!(ascii.contains("Trained"));
    }

    // Empty input: zero total and a defined (0.0) accuracy, no panic.
    #[test]
    fn test_confusion_matrix_empty() {
        let matrix = ConfusionMatrix::from_predictions(&[], &[]);
        assert_eq!(matrix.total(), 0);
        assert!((matrix.accuracy() - 0.0).abs() < f64::EPSILON);
    }

    // specificity = tn / (tn + fp) = 80 / 100.
    #[test]
    fn test_confusion_matrix_specificity() {
        let matrix = ConfusionMatrix {
            tp: 10,
            tn: 80,
            fp: 20,
            r#fn: 10,
        };

        assert!((matrix.specificity() - 0.8).abs() < 0.001);
    }

    // Derived Debug output names the type.
    #[test]
    fn test_roc_point_debug() {
        let point = RocPoint {
            threshold: 0.5,
            tpr: 0.8,
            fpr: 0.2,
        };
        let debug = format!("{:?}", point);
        assert!(debug.contains("RocPoint"));
    }

    // serde round-trip smoke test for FeatureImportance.
    #[test]
    fn test_feature_importance_serialize() {
        let fi = FeatureImportance {
            name: "test".to_string(),
            importance: 0.5,
        };
        let json = serde_json::to_string(&fi).unwrap();
        assert!(json.contains("test"));
    }

    // serde smoke test: field names survive serialization.
    #[test]
    fn test_benchmark_result_serialize() {
        let result = BenchmarkResult {
            num_predictions: 1000,
            total_time_ms: 100.0,
            predictions_per_sec: 10000.0,
            avg_latency_us: 100.0,
        };
        let json = serde_json::to_string(&result).unwrap();
        assert!(json.contains("num_predictions"));
    }

    // A trivial predictor should comfortably exceed 1000 predictions/sec;
    // guards against the timing math being inverted.
    #[test]
    fn test_benchmark_predictions_per_sec_formula() {
        let features: Vec<CodeFeatures> = (0..10).map(|_| CodeFeatures::default()).collect();
        let predictor = |_: &CodeFeatures| 0.5;

        let result = benchmark_inference(predictor, &features, 10);

        assert_eq!(result.num_predictions, 100);
        assert!(
            result.predictions_per_sec > 1000.0,
            "predictions_per_sec should be > 1000, got {}",
            result.predictions_per_sec
        );
    }

    // The plot must contain both curve markers ('*') and the chance
    // diagonal ('-').
    #[test]
    fn test_roc_curve_grid_positions() {
        let scores = vec![0.9, 0.1];
        let ground_truth = vec![true, false];

        let roc = RocCurve::from_scores(&scores, &ground_truth);
        let ascii = roc.to_ascii();

        assert!(ascii.contains('*'), "ROC plot should contain star markers");
        assert!(
            ascii.contains('-'),
            "ROC plot should contain diagonal markers"
        );
    }
}