1use serde::{Deserialize, Serialize};
20
21use crate::ml::CommitFeatures;
22
/// Broad category of code that a (potential) defect is associated with.
///
/// Each category carries a default risk multiplier (see [`DefectCategory::weight`])
/// that scales the base defect probability in [`DefectPredictor::predict`].
/// Categories are assigned heuristically by [`DefectCategory::classify`].
// NOTE: variant order determines discriminants (derived `Hash`, serde variant
// indices for non-self-describing formats) — append new variants at the end.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DefectCategory {
    /// AST / syntax-tree code (parsing, nodes, visitors, transforms).
    AstTransform,
    /// Ownership, borrowing, moves and lifetimes.
    OwnershipBorrow,
    /// Mapping onto standard-library types and APIs (`std::`, collections, `Vec`, `String`).
    StdlibMapping,
    /// Anything not matched by a more specific category; also the `Default`.
    Other,
}
35
36impl Default for DefectCategory {
37 fn default() -> Self {
38 Self::Other
39 }
40}
41
42impl DefectCategory {
43 #[must_use]
45 pub fn weight(&self) -> f32 {
46 match self {
47 Self::AstTransform => 2.0,
48 Self::OwnershipBorrow => 1.5,
49 Self::StdlibMapping => 1.2,
50 Self::Other => 1.0,
51 }
52 }
53
54 #[must_use]
56 pub fn all() -> &'static [Self] {
57 &[
58 Self::AstTransform,
59 Self::OwnershipBorrow,
60 Self::StdlibMapping,
61 Self::Other,
62 ]
63 }
64
65 #[must_use]
67 pub fn classify(code: &str) -> Self {
68 if code.contains("ast")
70 || code.contains("node")
71 || code.contains("parse")
72 || code.contains("transform")
73 || code.contains("visitor")
74 {
75 return Self::AstTransform;
76 }
77
78 if code.contains("borrow")
80 || code.contains("lifetime")
81 || code.contains("move")
82 || code.contains("&mut")
83 || code.contains("owned")
84 {
85 return Self::OwnershipBorrow;
86 }
87
88 if code.contains("std::")
90 || code.contains("collections")
91 || code.contains("HashMap")
92 || code.contains("Vec")
93 || code.contains("String")
94 {
95 return Self::StdlibMapping;
96 }
97
98 Self::Other
99 }
100}
101
/// Per-category multipliers applied to the base defect probability.
///
/// The default values equal those returned by [`DefectCategory::weight`].
/// Individual entries can be read and changed by category with
/// [`CategoryWeights::get`] and [`CategoryWeights::set`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CategoryWeights {
    /// Multiplier for [`DefectCategory::AstTransform`].
    pub ast_transform: f32,
    /// Multiplier for [`DefectCategory::OwnershipBorrow`].
    pub ownership_borrow: f32,
    /// Multiplier for [`DefectCategory::StdlibMapping`].
    pub stdlib_mapping: f32,
    /// Multiplier for [`DefectCategory::Other`].
    pub other: f32,
}
114
115impl Default for CategoryWeights {
116 fn default() -> Self {
117 Self {
118 ast_transform: 2.0,
119 ownership_borrow: 1.5,
120 stdlib_mapping: 1.2,
121 other: 1.0,
122 }
123 }
124}
125
126impl CategoryWeights {
127 #[must_use]
129 pub fn get(&self, category: DefectCategory) -> f32 {
130 match category {
131 DefectCategory::AstTransform => self.ast_transform,
132 DefectCategory::OwnershipBorrow => self.ownership_borrow,
133 DefectCategory::StdlibMapping => self.stdlib_mapping,
134 DefectCategory::Other => self.other,
135 }
136 }
137
138 pub fn set(&mut self, category: DefectCategory, weight: f32) {
140 match category {
141 DefectCategory::AstTransform => self.ast_transform = weight,
142 DefectCategory::OwnershipBorrow => self.ownership_borrow = weight,
143 DefectCategory::StdlibMapping => self.stdlib_mapping = weight,
144 DefectCategory::Other => self.other = weight,
145 }
146 }
147}
148
/// A labelled training example for [`DefectPredictor::train`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DefectSample {
    /// Commit-level features fed to the model.
    pub features: CommitFeatures,
    /// Ground-truth label: `true` marks a defective commit (target 1.0 in training).
    pub is_defect: bool,
    /// Optional defect category. It is `None` unless set via
    /// [`DefectSample::with_category`], and `DefectPredictor::train` does not
    /// read it.
    pub category: Option<DefectCategory>,
}
159
160impl DefectSample {
161 #[must_use]
163 pub fn new(features: CommitFeatures, is_defect: bool) -> Self {
164 Self {
165 features,
166 is_defect,
167 category: None,
168 }
169 }
170
171 #[must_use]
173 pub fn with_category(mut self, category: DefectCategory) -> Self {
174 self.category = Some(category);
175 self
176 }
177}
178
/// Result of a single defect prediction.
#[derive(Debug, Clone)]
pub struct DefectPrediction {
    /// Logistic model output before category weighting, in `[0, 1]`.
    pub base_probability: f32,
    /// `base_probability` scaled by the category multiplier and capped at 1.0.
    /// It equals `base_probability` when no code text was classified.
    pub weighted_probability: f32,
    /// Category the input code was classified into.
    pub category: DefectCategory,
    /// Fixed confidence level chosen by the predictor: 0.8 once trained, 0.5 before.
    pub confidence: f32,
}
191
192impl DefectPrediction {
193 #[must_use]
195 pub fn priority_score(&self) -> f32 {
196 self.weighted_probability * self.confidence
197 }
198}
199
/// Logistic-regression defect predictor with category-based risk weighting.
///
/// It computes a base probability from [`CommitFeatures`] (a sigmoid over
/// normalized features) and scales it by a per-[`DefectCategory`] multiplier.
#[derive(Debug)]
pub struct DefectPredictor {
    /// Per-category multipliers applied in `predict`.
    weights: CategoryWeights,
    /// One linear weight per entry of `CommitFeatures::to_array()`.
    feature_weights: [f32; 8],
    /// Intercept term of the logistic model.
    bias: f32,
    /// Statistics recorded by the most recent `train` call.
    stats: DefectPredictorStats,
    /// Set once `train` has completed successfully.
    is_trained: bool,
}
217
/// Summary of the data seen by the most recent [`DefectPredictor::train`] call.
#[derive(Debug, Clone, Default)]
pub struct DefectPredictorStats {
    /// Number of training samples.
    pub n_samples: usize,
    /// Number of training samples labelled as defects.
    pub n_defects: usize,
    /// Fraction of training samples classified correctly at a 0.5 threshold.
    /// It is `None` until training has run.
    pub accuracy: Option<f32>,
}
228
229impl Default for DefectPredictor {
230 fn default() -> Self {
231 Self::new()
232 }
233}
234
235impl DefectPredictor {
236 #[must_use]
238 pub fn new() -> Self {
239 Self {
240 weights: CategoryWeights::default(),
241 feature_weights: [
242 0.15, 0.10, 0.08, 0.12, -0.20, 0.25, -0.15, 0.10, ],
251 bias: 0.1,
252 stats: DefectPredictorStats::default(),
253 is_trained: false,
254 }
255 }
256
257 #[must_use]
259 pub fn with_weights(weights: CategoryWeights) -> Self {
260 Self {
261 weights,
262 ..Self::new()
263 }
264 }
265
266 pub fn train(&mut self, samples: &[DefectSample]) -> crate::Result<()> {
274 if samples.is_empty() {
275 return Err(crate::Error::Data("No training samples".to_string()));
276 }
277
278 self.stats.n_samples = samples.len();
279 self.stats.n_defects = samples.iter().filter(|s| s.is_defect).count();
280
281 let learning_rate = 0.01;
284 let epochs = 100;
285
286 for _ in 0..epochs {
287 let mut gradient = [0.0f32; 8];
288 let mut bias_gradient = 0.0f32;
289
290 for sample in samples {
291 let arr = sample.features.to_array();
292 let pred = self.predict_raw(&sample.features);
293 let target = if sample.is_defect { 1.0 } else { 0.0 };
294 let error = pred - target;
295
296 for (i, &val) in arr.iter().enumerate() {
297 gradient[i] += error * val;
298 }
299 bias_gradient += error;
300 }
301
302 let n = samples.len() as f32;
304 for (i, grad) in gradient.iter().enumerate() {
305 self.feature_weights[i] -= learning_rate * grad / n;
306 }
307 self.bias -= learning_rate * bias_gradient / n;
308 }
309
310 let correct = samples
312 .iter()
313 .filter(|s| {
314 let pred = self.predict_raw(&s.features) >= 0.5;
315 pred == s.is_defect
316 })
317 .count();
318 self.stats.accuracy = Some(correct as f32 / samples.len() as f32);
319 self.is_trained = true;
320
321 Ok(())
322 }
323
324 fn predict_raw(&self, features: &CommitFeatures) -> f32 {
326 let arr = features.to_array();
327 let mut score = self.bias;
328
329 for (i, &val) in arr.iter().enumerate() {
330 let normalized = match i {
332 0 => (val / 100.0).min(1.0), 1 => (val / 50.0).min(1.0), 2 => (val / 10.0).min(1.0), 3 => val.min(1.0), 4 => val, 5 => (val / 10.0).min(1.0), 6 => val, 7 => (val / 30.0).min(1.0), _ => val,
341 };
342 score += self.feature_weights[i] * normalized;
343 }
344
345 1.0 / (1.0 + (-score).exp())
347 }
348
349 #[must_use]
351 pub fn predict(&self, features: &CommitFeatures, code: &str) -> DefectPrediction {
352 let base_probability = self.predict_raw(features);
353 let category = DefectCategory::classify(code);
354 let weight = self.weights.get(category);
355
356 let weighted_probability = (base_probability * weight).min(1.0);
358
359 let confidence = if self.is_trained { 0.8 } else { 0.5 };
361
362 DefectPrediction {
363 base_probability,
364 weighted_probability,
365 category,
366 confidence,
367 }
368 }
369
370 #[must_use]
372 pub fn predict_features(&self, features: &CommitFeatures) -> DefectPrediction {
373 let base_probability = self.predict_raw(features);
374
375 DefectPrediction {
376 base_probability,
377 weighted_probability: base_probability,
378 category: DefectCategory::Other,
379 confidence: if self.is_trained { 0.8 } else { 0.5 },
380 }
381 }
382
383 #[must_use]
385 pub fn stats(&self) -> &DefectPredictorStats {
386 &self.stats
387 }
388
389 #[must_use]
391 pub fn is_trained(&self) -> bool {
392 self.is_trained
393 }
394
395 #[must_use]
397 pub fn category_weights(&self) -> &CategoryWeights {
398 &self.weights
399 }
400
401 pub fn prioritize(&self, samples: &[(CommitFeatures, String)]) -> Vec<usize> {
405 let mut scored: Vec<(usize, f32)> = samples
406 .iter()
407 .enumerate()
408 .map(|(i, (features, code))| {
409 let pred = self.predict(features, code);
410 (i, pred.priority_score())
411 })
412 .collect();
413
414 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
415
416 scored.into_iter().map(|(i, _)| i).collect()
417 }
418}
419
#[cfg(test)]
mod tests {
    use super::*;

    /// Mid-sized change used as a neutral baseline.
    fn sample_features() -> CommitFeatures {
        CommitFeatures {
            lines_added: 50,
            lines_deleted: 10,
            files_changed: 3,
            churn_ratio: 0.8,
            has_test_changes: false,
            complexity_delta: 5.0,
            author_experience: 0.5,
            days_since_last_change: 7.0,
        }
    }

    /// Large, complex, untested change by a low-experience author.
    fn buggy_features() -> CommitFeatures {
        CommitFeatures {
            lines_added: 200,
            lines_deleted: 5,
            files_changed: 10,
            churn_ratio: 0.95,
            has_test_changes: false,
            complexity_delta: 15.0,
            author_experience: 0.1,
            days_since_last_change: 1.0,
        }
    }

    /// Small, tested change by a high-experience author.
    fn safe_features() -> CommitFeatures {
        CommitFeatures {
            lines_added: 10,
            lines_deleted: 5,
            files_changed: 1,
            churn_ratio: 0.3,
            has_test_changes: true,
            complexity_delta: -2.0,
            author_experience: 0.9,
            days_since_last_change: 14.0,
        }
    }

    // ----- DefectCategory -----

    #[test]
    fn test_defect_category_weights() {
        assert!((DefectCategory::AstTransform.weight() - 2.0).abs() < f32::EPSILON);
        assert!((DefectCategory::OwnershipBorrow.weight() - 1.5).abs() < f32::EPSILON);
        assert!((DefectCategory::StdlibMapping.weight() - 1.2).abs() < f32::EPSILON);
        assert!((DefectCategory::Other.weight() - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_defect_category_classify_ast() {
        assert_eq!(
            DefectCategory::classify("parse_ast_node"),
            DefectCategory::AstTransform
        );
        assert_eq!(
            DefectCategory::classify("transform_expression"),
            DefectCategory::AstTransform
        );
    }

    #[test]
    fn test_defect_category_classify_ownership() {
        assert_eq!(
            DefectCategory::classify("fix borrow checker"),
            DefectCategory::OwnershipBorrow
        );
        assert_eq!(
            DefectCategory::classify("lifetime issue"),
            DefectCategory::OwnershipBorrow
        );
    }

    #[test]
    fn test_defect_category_classify_stdlib() {
        assert_eq!(
            DefectCategory::classify("use std::collections::HashMap"),
            DefectCategory::StdlibMapping
        );
        assert_eq!(
            DefectCategory::classify("Vec::new()"),
            DefectCategory::StdlibMapping
        );
    }

    #[test]
    fn test_defect_category_classify_other() {
        assert_eq!(
            DefectCategory::classify("simple fix"),
            DefectCategory::Other
        );
    }

    #[test]
    fn test_defect_category_all() {
        use std::collections::HashSet;

        let all = DefectCategory::all();
        assert_eq!(all.len(), 4);
        // Each category must appear exactly once, including the default one.
        let unique: HashSet<DefectCategory> = all.iter().copied().collect();
        assert_eq!(unique.len(), all.len());
        assert!(all.contains(&DefectCategory::default()));
    }

    #[test]
    fn test_defect_category_default() {
        assert_eq!(DefectCategory::default(), DefectCategory::Other);
    }

    // ----- CategoryWeights -----

    #[test]
    fn test_category_weights_default() {
        let weights = CategoryWeights::default();
        assert!((weights.ast_transform - 2.0).abs() < f32::EPSILON);
        assert!((weights.other - 1.0).abs() < f32::EPSILON);
        // The defaults must stay in sync with the per-category constants.
        for &category in DefectCategory::all() {
            assert!((weights.get(category) - category.weight()).abs() < f32::EPSILON);
        }
    }

    #[test]
    fn test_category_weights_get() {
        let weights = CategoryWeights::default();
        assert!((weights.get(DefectCategory::AstTransform) - 2.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_category_weights_set() {
        let mut weights = CategoryWeights::default();
        weights.set(DefectCategory::AstTransform, 3.0);
        assert!((weights.ast_transform - 3.0).abs() < f32::EPSILON);
        assert!((weights.get(DefectCategory::AstTransform) - 3.0).abs() < f32::EPSILON);
        // Setting one category must not touch the others.
        assert!((weights.ownership_borrow - 1.5).abs() < f32::EPSILON);
        assert!((weights.stdlib_mapping - 1.2).abs() < f32::EPSILON);
        assert!((weights.other - 1.0).abs() < f32::EPSILON);
    }

    // ----- DefectSample -----

    #[test]
    fn test_defect_sample_new() {
        let sample = DefectSample::new(sample_features(), true);
        assert!(sample.is_defect);
        assert!(sample.category.is_none());
    }

    #[test]
    fn test_defect_sample_with_category() {
        let sample =
            DefectSample::new(sample_features(), true).with_category(DefectCategory::AstTransform);
        assert_eq!(sample.category, Some(DefectCategory::AstTransform));
    }

    // ----- DefectPrediction -----

    #[test]
    fn test_defect_prediction_priority_score() {
        let pred = DefectPrediction {
            base_probability: 0.8,
            weighted_probability: 0.9,
            category: DefectCategory::AstTransform,
            confidence: 0.7,
        };

        let score = pred.priority_score();
        // 0.9 * 0.7
        assert!((score - 0.63).abs() < 0.01);
    }

    // ----- DefectPredictor -----

    #[test]
    fn test_defect_predictor_new() {
        let predictor = DefectPredictor::new();
        assert!(!predictor.is_trained());
    }

    #[test]
    fn test_defect_predictor_with_weights() {
        let weights = CategoryWeights {
            ast_transform: 3.0,
            ..Default::default()
        };
        let predictor = DefectPredictor::with_weights(weights);
        assert!((predictor.category_weights().ast_transform - 3.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_defect_predictor_predict_features() {
        let predictor = DefectPredictor::new();
        let pred = predictor.predict_features(&sample_features());

        assert!(pred.base_probability >= 0.0);
        assert!(pred.base_probability <= 1.0);
        assert_eq!(pred.category, DefectCategory::Other);
    }

    #[test]
    fn test_defect_predictor_predict_with_code() {
        let predictor = DefectPredictor::new();
        let pred = predictor.predict(&sample_features(), "fix ast parser");

        assert_eq!(pred.category, DefectCategory::AstTransform);
        // The AST weight (2.0) can only raise the probability, up to the 1.0 cap.
        assert!(pred.weighted_probability >= pred.base_probability);
        assert!(pred.weighted_probability <= 1.0);
    }

    #[test]
    fn test_defect_predictor_probability_bounded() {
        let predictor = DefectPredictor::new();

        for features in &[sample_features(), buggy_features(), safe_features()] {
            let pred = predictor.predict_features(features);
            assert!(pred.base_probability >= 0.0);
            assert!(pred.base_probability <= 1.0);
            assert!(pred.weighted_probability >= 0.0);
            assert!(pred.weighted_probability <= 1.0);
        }
    }

    #[test]
    fn test_defect_predictor_buggy_higher_than_safe() {
        let predictor = DefectPredictor::new();

        let buggy_pred = predictor.predict_features(&buggy_features());
        let safe_pred = predictor.predict_features(&safe_features());

        assert!(buggy_pred.base_probability > safe_pred.base_probability);
    }

    #[test]
    fn test_defect_predictor_train_empty_fails() {
        let mut predictor = DefectPredictor::new();
        let result = predictor.train(&[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_defect_predictor_train() {
        let mut predictor = DefectPredictor::new();

        let samples = vec![
            DefectSample::new(buggy_features(), true),
            DefectSample::new(buggy_features(), true),
            DefectSample::new(safe_features(), false),
            DefectSample::new(safe_features(), false),
        ];

        let result = predictor.train(&samples);
        assert!(result.is_ok());
        assert!(predictor.is_trained());
    }

    #[test]
    fn test_defect_predictor_train_stats() {
        let mut predictor = DefectPredictor::new();

        let samples = vec![
            DefectSample::new(buggy_features(), true),
            DefectSample::new(safe_features(), false),
            DefectSample::new(safe_features(), false),
        ];

        predictor.train(&samples).unwrap();

        let stats = predictor.stats();
        assert_eq!(stats.n_samples, 3);
        assert_eq!(stats.n_defects, 1);
        let accuracy = stats.accuracy.expect("accuracy is recorded after training");
        assert!((0.0..=1.0).contains(&accuracy));
    }

    #[test]
    fn test_defect_predictor_confidence_after_training() {
        let mut predictor = DefectPredictor::new();

        // Untrained predictors report the lower confidence level.
        let pred_before = predictor.predict_features(&sample_features());
        assert!((pred_before.confidence - 0.5).abs() < f32::EPSILON);

        let samples = vec![
            DefectSample::new(buggy_features(), true),
            DefectSample::new(safe_features(), false),
        ];
        predictor.train(&samples).unwrap();

        // Training switches to the higher confidence level.
        let pred_after = predictor.predict_features(&sample_features());
        assert!((pred_after.confidence - 0.8).abs() < f32::EPSILON);
    }

    #[test]
    fn test_defect_predictor_prioritize() {
        let predictor = DefectPredictor::new();

        let samples = vec![
            (safe_features(), "simple code".to_string()),
            (buggy_features(), "ast transform bug".to_string()),
            (sample_features(), "normal code".to_string()),
        ];

        let order = predictor.prioritize(&samples);

        // The risky AST change must come first.
        assert_eq!(order[0], 1);
        // The result is a permutation of the input indices.
        let mut sorted = order.clone();
        sorted.sort_unstable();
        assert_eq!(sorted, vec![0, 1, 2]);
    }

    #[test]
    fn test_defect_predictor_prioritize_empty() {
        let predictor = DefectPredictor::new();
        let samples: Vec<(CommitFeatures, String)> = vec![];

        let order = predictor.prioritize(&samples);
        assert!(order.is_empty());
    }

    // ----- Debug formatting -----

    #[test]
    fn test_defect_category_debug() {
        let debug = format!("{:?}", DefectCategory::AstTransform);
        assert!(debug.contains("AstTransform"));
    }

    #[test]
    fn test_category_weights_debug() {
        let weights = CategoryWeights::default();
        let debug = format!("{weights:?}");
        assert!(debug.contains("CategoryWeights"));
    }

    #[test]
    fn test_defect_sample_debug() {
        let sample = DefectSample::new(sample_features(), true);
        let debug = format!("{sample:?}");
        assert!(debug.contains("DefectSample"));
    }

    #[test]
    fn test_defect_prediction_debug() {
        let pred = DefectPrediction {
            base_probability: 0.5,
            weighted_probability: 0.6,
            category: DefectCategory::Other,
            confidence: 0.7,
        };
        let debug = format!("{pred:?}");
        assert!(debug.contains("DefectPrediction"));
    }

    #[test]
    fn test_defect_predictor_debug() {
        let predictor = DefectPredictor::new();
        let debug = format!("{predictor:?}");
        assert!(debug.contains("DefectPredictor"));
    }

    #[test]
    fn test_defect_predictor_stats_debug() {
        let stats = DefectPredictorStats::default();
        let debug = format!("{stats:?}");
        assert!(debug.contains("DefectPredictorStats"));
    }

    // ----- Serde round-trips -----

    #[test]
    fn test_defect_category_serialize() {
        let category = DefectCategory::AstTransform;
        let json = serde_json::to_string(&category).unwrap();
        let restored: DefectCategory = serde_json::from_str(&json).unwrap();
        assert_eq!(category, restored);
    }

    #[test]
    fn test_category_weights_serialize() {
        let weights = CategoryWeights::default();
        let json = serde_json::to_string(&weights).unwrap();
        let restored: CategoryWeights = serde_json::from_str(&json).unwrap();
        assert!((weights.ast_transform - restored.ast_transform).abs() < f32::EPSILON);
    }

    #[test]
    fn test_defect_sample_serialize() {
        let sample =
            DefectSample::new(sample_features(), true).with_category(DefectCategory::AstTransform);
        let json = serde_json::to_string(&sample).unwrap();
        let restored: DefectSample = serde_json::from_str(&json).unwrap();
        assert_eq!(sample.is_defect, restored.is_defect);
        assert_eq!(sample.category, restored.category);
    }
}
809
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        /// The base probability is always a valid probability.
        #[test]
        fn prop_probability_bounded(
            lines_added in 0u32..1000,
            lines_deleted in 0u32..500,
            files_changed in 1u32..50,
            complexity_delta in -20.0f32..50.0,
        ) {
            let features = CommitFeatures {
                lines_added,
                lines_deleted,
                files_changed,
                churn_ratio: 0.5,
                has_test_changes: false,
                complexity_delta,
                author_experience: 0.5,
                days_since_last_change: 7.0,
            };

            let predictor = DefectPredictor::new();
            let pred = predictor.predict_features(&features);

            prop_assert!(pred.base_probability >= 0.0);
            prop_assert!(pred.base_probability <= 1.0);
        }

        /// With arbitrary code text, category weighting still yields a
        /// probability in `[0, 1]`.
        #[test]
        fn prop_weighted_probability_bounded(
            lines_added in 0u32..1000,
            code in "\\PC{0,64}",
        ) {
            let features = CommitFeatures {
                lines_added,
                ..Default::default()
            };

            let pred = DefectPredictor::new().predict(&features, &code);

            prop_assert!((0.0..=1.0).contains(&pred.base_probability));
            prop_assert!((0.0..=1.0).contains(&pred.weighted_probability));
        }

        /// Classification never panics and always returns a listed category,
        /// whatever the input text (including non-ASCII).
        #[test]
        fn prop_classify_total(code in "\\PC{0,64}") {
            let category = DefectCategory::classify(&code);
            prop_assert!(DefectCategory::all().contains(&category));
        }

        /// The default AST multiplier never ranks below the catch-all multiplier.
        #[test]
        fn prop_category_weight_increases(base_prob in 0.1f32..0.8) {
            let weights = CategoryWeights::default();

            let weighted = base_prob * weights.get(DefectCategory::AstTransform);
            let unweighted = base_prob * weights.get(DefectCategory::Other);

            prop_assert!(weighted >= unweighted);
        }

        /// The complexity-delta weight is positive, so more complexity means
        /// at least as much risk.
        #[test]
        fn prop_complexity_increases_probability(base_complexity in -5.0f32..5.0) {
            let predictor = DefectPredictor::new();

            let low = CommitFeatures {
                complexity_delta: base_complexity,
                ..Default::default()
            };

            let high = CommitFeatures {
                complexity_delta: base_complexity + 10.0,
                ..Default::default()
            };

            let low_pred = predictor.predict_features(&low);
            let high_pred = predictor.predict_features(&high);

            prop_assert!(high_pred.base_probability >= low_pred.base_probability);
        }

        /// Touching tests never increases the predicted risk.
        #[test]
        fn prop_tests_reduce_probability(lines_added in 10u32..100) {
            let predictor = DefectPredictor::new();

            let without_tests = CommitFeatures {
                lines_added,
                has_test_changes: false,
                ..Default::default()
            };

            let with_tests = CommitFeatures {
                lines_added,
                has_test_changes: true,
                ..Default::default()
            };

            let without_pred = predictor.predict_features(&without_tests);
            let with_pred = predictor.predict_features(&with_tests);

            prop_assert!(with_pred.base_probability <= without_pred.base_probability);
        }

        /// More author experience never increases the predicted risk.
        #[test]
        fn prop_experience_reduces_probability(lines_added in 10u32..100) {
            let predictor = DefectPredictor::new();

            let novice = CommitFeatures {
                lines_added,
                author_experience: 0.1,
                ..Default::default()
            };

            let expert = CommitFeatures {
                lines_added,
                author_experience: 0.9,
                ..Default::default()
            };

            let novice_pred = predictor.predict_features(&novice);
            let expert_pred = predictor.predict_features(&expert);

            prop_assert!(expert_pred.base_probability <= novice_pred.base_probability);
        }

        /// `prioritize` returns every input index exactly once.
        #[test]
        fn prop_prioritize_is_permutation(sizes in prop::collection::vec(0u32..1000, 0..16)) {
            let samples: Vec<(CommitFeatures, String)> = sizes
                .iter()
                .map(|&lines_added| {
                    (
                        CommitFeatures {
                            lines_added,
                            ..Default::default()
                        },
                        "ast".to_string(),
                    )
                })
                .collect();

            let mut order = DefectPredictor::new().prioritize(&samples);
            order.sort_unstable();

            prop_assert_eq!(order, (0..samples.len()).collect::<Vec<_>>());
        }
    }
}