1use std::collections::HashMap;
22
23use crate::ml_features::{AllocationKind, InferredOwnership, OwnershipFeatures};
24use crate::retraining_pipeline::TrainingSample;
25use crate::training_data::TrainingDataset;
26
27#[derive(Debug, Clone)]
29pub struct ClassifierPrediction {
30 pub prediction: InferredOwnership,
32 pub confidence: f64,
34 pub alternatives: Vec<(InferredOwnership, f64)>,
36}
37
38impl ClassifierPrediction {
39 pub fn new(prediction: InferredOwnership, confidence: f64) -> Self {
41 Self {
42 prediction,
43 confidence,
44 alternatives: Vec::new(),
45 }
46 }
47
48 pub fn with_alternative(mut self, kind: InferredOwnership, confidence: f64) -> Self {
50 self.alternatives.push((kind, confidence));
51 self
52 }
53
54 pub fn is_confident(&self, threshold: f64) -> bool {
56 self.confidence >= threshold
57 }
58}
59
60pub trait OwnershipClassifier: Send + Sync {
62 fn classify(&self, features: &OwnershipFeatures) -> ClassifierPrediction;
64
65 fn classify_batch(&self, features: &[OwnershipFeatures]) -> Vec<ClassifierPrediction> {
67 features.iter().map(|f| self.classify(f)).collect()
68 }
69
70 fn name(&self) -> &str;
72
73 fn is_trained(&self) -> bool;
75}
76
/// Classifier driven by a fixed set of hand-written rules.
#[derive(Debug, Clone, Default)]
pub struct RuleBasedClassifier {
    /// Confidence values reported by the individual rules.
    weights: RuleWeights,
}

/// Confidence value reported by each rule of [`RuleBasedClassifier`] when
/// that rule decides the prediction.
#[derive(Debug, Clone)]
pub struct RuleWeights {
    /// Reported for `Owned` predictions (default `0.95`).
    pub malloc_free: f64,
    /// Reported for `Vec` predictions (default `0.90`).
    pub array_alloc: f64,
    /// Reported for `Borrowed` predictions (default `0.85`).
    pub const_qual: f64,
    /// Reported for `BorrowedMut` predictions (default `0.80`).
    pub write_ops: f64,
    /// Reported for `Slice` and `SliceMut` predictions (default `0.75`).
    pub size_param: f64,
}

impl Default for RuleWeights {
    /// Returns the built-in confidence values.
    fn default() -> Self {
        let (malloc_free, array_alloc, const_qual, write_ops, size_param) =
            (0.95, 0.90, 0.85, 0.80, 0.75);
        Self {
            malloc_free,
            array_alloc,
            const_qual,
            write_ops,
            size_param,
        }
    }
}

impl RuleBasedClassifier {
    /// Creates a classifier that uses the default [`RuleWeights`].
    pub fn new() -> Self {
        Self::with_weights(RuleWeights::default())
    }

    /// Creates a classifier that reports the supplied confidence values.
    pub fn with_weights(weights: RuleWeights) -> Self {
        Self { weights }
    }
}
129
130impl OwnershipClassifier for RuleBasedClassifier {
131 fn classify(&self, features: &OwnershipFeatures) -> ClassifierPrediction {
132 if matches!(
134 features.allocation_site,
135 AllocationKind::Malloc | AllocationKind::Calloc
136 ) && features.deallocation_count > 0
137 && !features.has_size_param
138 {
139 return ClassifierPrediction::new(InferredOwnership::Owned, self.weights.malloc_free);
140 }
141
142 if matches!(
144 features.allocation_site,
145 AllocationKind::Malloc | AllocationKind::Calloc
146 ) && (features.has_size_param || features.is_array_decay)
147 && features.deallocation_count > 0
148 {
149 return ClassifierPrediction::new(InferredOwnership::Vec, self.weights.array_alloc);
150 }
151
152 if features.is_const && features.deallocation_count == 0 {
154 if features.is_array_decay && features.has_size_param {
156 return ClassifierPrediction::new(
157 InferredOwnership::Slice,
158 self.weights.size_param,
159 );
160 }
161 return ClassifierPrediction::new(InferredOwnership::Borrowed, self.weights.const_qual);
162 }
163
164 if !features.is_const
166 && features.write_count > 0
167 && features.deallocation_count == 0
168 && !matches!(
169 features.allocation_site,
170 AllocationKind::Malloc | AllocationKind::Calloc
171 )
172 {
173 if features.is_array_decay && features.has_size_param {
175 return ClassifierPrediction::new(
176 InferredOwnership::SliceMut,
177 self.weights.size_param,
178 );
179 }
180 return ClassifierPrediction::new(
181 InferredOwnership::BorrowedMut,
182 self.weights.write_ops,
183 );
184 }
185
186 if features.is_array_decay && features.has_size_param {
188 let ownership = if features.is_const {
189 InferredOwnership::Slice
190 } else {
191 InferredOwnership::SliceMut
192 };
193 return ClassifierPrediction::new(ownership, self.weights.size_param);
194 }
195
196 ClassifierPrediction::new(InferredOwnership::RawPointer, 0.3)
198 }
199
200 fn name(&self) -> &str {
201 "RuleBasedClassifier"
202 }
203
204 fn is_trained(&self) -> bool {
205 true }
207}
208
/// Per-class confusion counts plus overall totals for one evaluation run.
///
/// Classes are identified by name; a class missing from a map counts as zero.
#[derive(Debug, Clone, Default)]
pub struct EvaluationMetrics {
    /// Correct predictions, keyed by class name.
    pub true_positives: HashMap<String, usize>,
    /// Wrong predictions, keyed by the class that was predicted.
    pub false_positives: HashMap<String, usize>,
    /// Wrong predictions, keyed by the class that was expected.
    pub false_negatives: HashMap<String, usize>,
    /// Number of samples evaluated.
    pub total_samples: usize,
    /// Number of samples predicted correctly.
    pub correct: usize,
}

impl EvaluationMetrics {
    /// Looks up the count recorded for `class`, treating a missing entry as zero.
    fn count(counts: &HashMap<String, usize>, class: &str) -> f64 {
        counts.get(class).copied().unwrap_or(0) as f64
    }

    /// Fraction of samples predicted correctly, or `0.0` when there are no samples.
    pub fn accuracy(&self) -> f64 {
        if self.total_samples == 0 {
            return 0.0;
        }
        self.correct as f64 / self.total_samples as f64
    }

    /// Precision for `class`: `tp / (tp + fp)`, or `0.0` when the class was
    /// never predicted.
    pub fn precision(&self, class: &str) -> f64 {
        let tp = Self::count(&self.true_positives, class);
        let fp = Self::count(&self.false_positives, class);

        if tp + fp == 0.0 {
            return 0.0;
        }
        tp / (tp + fp)
    }

    /// Recall for `class`: `tp / (tp + fn)`, or `0.0` when the class never
    /// occurred as an expected label.
    pub fn recall(&self, class: &str) -> f64 {
        let tp = Self::count(&self.true_positives, class);
        let fn_ = Self::count(&self.false_negatives, class);

        if tp + fn_ == 0.0 {
            return 0.0;
        }
        tp / (tp + fn_)
    }

    /// Harmonic mean of precision and recall for `class`, or `0.0` when both
    /// are zero.
    pub fn f1_score(&self, class: &str) -> f64 {
        let p = self.precision(class);
        let r = self.recall(class);

        if p + r == 0.0 {
            return 0.0;
        }
        2.0 * p * r / (p + r)
    }

    /// Unweighted mean of the per-class F1 scores, or `0.0` when no class
    /// has been recorded.
    ///
    /// The average runs over every class that appears in any of the three
    /// count maps. A class that was never predicted correctly has no
    /// `true_positives` entry but still has an F1 of zero; leaving it out
    /// would inflate the score.
    pub fn macro_f1(&self) -> f64 {
        // A sorted set removes duplicates and fixes the summation order, so
        // the floating-point result is identical from run to run.
        let classes: std::collections::BTreeSet<&str> = self
            .true_positives
            .keys()
            .chain(self.false_positives.keys())
            .chain(self.false_negatives.keys())
            .map(String::as_str)
            .collect();

        if classes.is_empty() {
            return 0.0;
        }

        let sum: f64 = classes.iter().map(|class| self.f1_score(class)).sum();
        sum / classes.len() as f64
    }
}
277
278pub struct ClassifierEvaluator {
280 samples: Vec<TrainingSample>,
282}
283
284impl ClassifierEvaluator {
285 pub fn new(samples: Vec<TrainingSample>) -> Self {
287 Self { samples }
288 }
289
290 pub fn from_dataset(dataset: &TrainingDataset) -> Self {
292 Self {
293 samples: dataset.to_training_samples(),
294 }
295 }
296
297 pub fn evaluate(&self, classifier: &dyn OwnershipClassifier) -> EvaluationMetrics {
299 let mut metrics = EvaluationMetrics {
300 total_samples: self.samples.len(),
301 ..Default::default()
302 };
303
304 for sample in &self.samples {
305 let prediction = classifier.classify(&sample.features);
306 let predicted_class = format!("{:?}", prediction.prediction);
307 let actual_class = format!("{:?}", sample.label);
308
309 if prediction.prediction == sample.label {
310 metrics.correct += 1;
311 *metrics.true_positives.entry(actual_class).or_insert(0) += 1;
312 } else {
313 *metrics
314 .false_positives
315 .entry(predicted_class.clone())
316 .or_insert(0) += 1;
317 *metrics.false_negatives.entry(actual_class).or_insert(0) += 1;
318 }
319 }
320
321 metrics
322 }
323
324 pub fn sample_count(&self) -> usize {
326 self.samples.len()
327 }
328}
329
/// Settings carried by a classifier trainer.
///
/// NOTE(review): none of the code visible in this file reads these fields
/// beyond the defaults, so the per-field descriptions below are inferred
/// from the names; confirm against the code that consumes them.
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    /// Presumably the share of data held out for validation (default `0.2`).
    pub validation_split: f64,
    /// Presumably the seed for any randomised step (default `42`).
    pub random_seed: u64,
    /// Presumably the upper bound on training iterations (default `100`).
    pub max_iterations: usize,
    /// Presumably the number of non-improving iterations tolerated before
    /// stopping (default `10`).
    pub early_stopping_patience: usize,
    /// Presumably the smallest change that counts as an improvement
    /// (default `0.001`).
    pub min_improvement: f64,
}

impl Default for TrainingConfig {
    /// Returns the built-in settings.
    fn default() -> Self {
        let validation_split = 0.2;
        let random_seed = 42;
        let max_iterations = 100;
        let early_stopping_patience = 10;
        let min_improvement = 0.001;

        Self {
            validation_split,
            random_seed,
            max_iterations,
            early_stopping_patience,
            min_improvement,
        }
    }
}
356
357#[derive(Debug)]
359pub struct TrainingResult {
360 pub success: bool,
362 pub train_metrics: EvaluationMetrics,
364 pub validation_metrics: EvaluationMetrics,
366 pub iterations: usize,
368 pub duration_secs: f64,
370}
371
372impl TrainingResult {
373 pub fn success(
375 train_metrics: EvaluationMetrics,
376 validation_metrics: EvaluationMetrics,
377 iterations: usize,
378 duration_secs: f64,
379 ) -> Self {
380 Self {
381 success: true,
382 train_metrics,
383 validation_metrics,
384 iterations,
385 duration_secs,
386 }
387 }
388
389 pub fn failure() -> Self {
391 Self {
392 success: false,
393 train_metrics: EvaluationMetrics::default(),
394 validation_metrics: EvaluationMetrics::default(),
395 iterations: 0,
396 duration_secs: 0.0,
397 }
398 }
399}
400
401pub struct ClassifierTrainer {
403 config: TrainingConfig,
404}
405
406impl ClassifierTrainer {
407 pub fn new(config: TrainingConfig) -> Self {
409 Self { config }
410 }
411
412 pub fn with_defaults() -> Self {
414 Self::new(TrainingConfig::default())
415 }
416
417 pub fn train_rule_based(
421 &self,
422 _dataset: &TrainingDataset,
423 ) -> (RuleBasedClassifier, TrainingResult) {
424 let start = std::time::Instant::now();
425 let classifier = RuleBasedClassifier::new();
426 let duration = start.elapsed().as_secs_f64();
427
428 let result = TrainingResult::success(
429 EvaluationMetrics::default(),
430 EvaluationMetrics::default(),
431 1,
432 duration,
433 );
434
435 (classifier, result)
436 }
437
438 pub fn config(&self) -> &TrainingConfig {
440 &self.config
441 }
442}
443
444pub struct EnsembleClassifier {
446 classifiers: Vec<(Box<dyn OwnershipClassifier>, f64)>,
448 name: String,
450}
451
452impl EnsembleClassifier {
453 pub fn new(name: &str) -> Self {
455 Self {
456 classifiers: Vec::new(),
457 name: name.to_string(),
458 }
459 }
460
461 pub fn add_classifier<C: OwnershipClassifier + 'static>(&mut self, classifier: C, weight: f64) {
463 self.classifiers.push((Box::new(classifier), weight));
464 }
465
466 pub fn classifier_count(&self) -> usize {
468 self.classifiers.len()
469 }
470}
471
472impl OwnershipClassifier for EnsembleClassifier {
473 fn classify(&self, features: &OwnershipFeatures) -> ClassifierPrediction {
474 if self.classifiers.is_empty() {
475 return ClassifierPrediction::new(InferredOwnership::RawPointer, 0.0);
476 }
477
478 let mut votes: HashMap<String, f64> = HashMap::new();
480
481 for (classifier, weight) in &self.classifiers {
482 let prediction = classifier.classify(features);
483 let key = format!("{:?}", prediction.prediction);
484 *votes.entry(key.clone()).or_insert(0.0) += weight * prediction.confidence;
485 }
486
487 let total_weight: f64 = self.classifiers.iter().map(|(_, w)| w).sum();
489 let (best_class, best_score) = votes
490 .iter()
491 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
492 .map(|(k, v)| (k.clone(), *v))
493 .unwrap_or_else(|| ("RawPointer".to_string(), 0.0));
494
495 let prediction = match best_class.as_str() {
497 "Owned" => InferredOwnership::Owned,
498 "Borrowed" => InferredOwnership::Borrowed,
499 "BorrowedMut" => InferredOwnership::BorrowedMut,
500 "Vec" => InferredOwnership::Vec,
501 "Slice" => InferredOwnership::Slice,
502 "SliceMut" => InferredOwnership::SliceMut,
503 "Shared" => InferredOwnership::Shared,
504 _ => InferredOwnership::RawPointer,
505 };
506
507 let confidence = if total_weight > 0.0 {
508 best_score / total_weight
509 } else {
510 0.0
511 };
512
513 ClassifierPrediction::new(prediction, confidence)
514 }
515
516 fn name(&self) -> &str {
517 &self.name
518 }
519
520 fn is_trained(&self) -> bool {
521 self.classifiers.iter().all(|(c, _)| c.is_trained())
522 }
523}
524
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ml_features::OwnershipFeaturesBuilder;

    /// Default absolute tolerance for floating-point comparisons.
    const TOLERANCE: f64 = 0.001;

    /// True when `actual` lies strictly within `tolerance` of `expected`.
    fn close_to(actual: f64, expected: f64, tolerance: f64) -> bool {
        (actual - expected).abs() < tolerance
    }

    /// Classifies `features` with a default-weighted rule-based classifier.
    fn rule_prediction(features: &OwnershipFeatures) -> ClassifierPrediction {
        RuleBasedClassifier::new().classify(features)
    }

    /// Features of a `malloc` allocation that is deallocated once.
    fn freed_malloc() -> OwnershipFeatures {
        OwnershipFeaturesBuilder::default()
            .allocation_site(AllocationKind::Malloc)
            .deallocation_count(1)
            .build()
    }

    /// Features of a const-qualified pointer with everything else defaulted.
    fn const_pointer() -> OwnershipFeatures {
        OwnershipFeaturesBuilder::default()
            .const_qualified(true)
            .build()
    }

    /// Evaluates the default rule-based classifier on `samples`.
    fn evaluate_rules(samples: Vec<TrainingSample>) -> EvaluationMetrics {
        ClassifierEvaluator::new(samples).evaluate(&RuleBasedClassifier::new())
    }

    /// Builds a count map holding a single class.
    fn single(class: &str, count: usize) -> HashMap<String, usize> {
        let mut counts = HashMap::new();
        counts.insert(class.to_owned(), count);
        counts
    }

    /// Records all three confusion counts for `class`.
    fn record(metrics: &mut EvaluationMetrics, class: &str, tp: usize, fp: usize, fn_: usize) {
        metrics.true_positives.insert(class.to_owned(), tp);
        metrics.false_positives.insert(class.to_owned(), fp);
        metrics.false_negatives.insert(class.to_owned(), fn_);
    }

    // ---- ClassifierPrediction ----

    #[test]
    fn prediction_new() {
        let ClassifierPrediction {
            prediction,
            confidence,
            alternatives,
        } = ClassifierPrediction::new(InferredOwnership::Owned, 0.9);

        assert_eq!(prediction, InferredOwnership::Owned);
        assert!(close_to(confidence, 0.9, TOLERANCE));
        assert!(alternatives.is_empty());
    }

    #[test]
    fn prediction_with_alternative() {
        let base = ClassifierPrediction::new(InferredOwnership::Owned, 0.9);
        let extended = base.with_alternative(InferredOwnership::Borrowed, 0.1);

        assert_eq!(extended.alternatives.len(), 1);
    }

    #[test]
    fn prediction_is_confident() {
        let pred = ClassifierPrediction::new(InferredOwnership::Owned, 0.9);

        for (threshold, expected) in [(0.8, true), (0.95, false)] {
            assert_eq!(pred.is_confident(threshold), expected);
        }
    }

    // ---- RuleBasedClassifier ----

    #[test]
    fn rule_based_malloc_owned() {
        let features = OwnershipFeaturesBuilder::default()
            .allocation_site(AllocationKind::Malloc)
            .deallocation_count(1)
            .pointer_depth(1)
            .build();

        let pred = rule_prediction(&features);

        assert_eq!(pred.prediction, InferredOwnership::Owned);
        assert!(pred.confidence > 0.9);
    }

    #[test]
    fn rule_based_array_vec() {
        let features = OwnershipFeaturesBuilder::default()
            .allocation_site(AllocationKind::Malloc)
            .has_size_param(true)
            .deallocation_count(1)
            .pointer_depth(1)
            .build();

        assert_eq!(rule_prediction(&features).prediction, InferredOwnership::Vec);
    }

    #[test]
    fn rule_based_const_borrowed() {
        let features = OwnershipFeaturesBuilder::default()
            .const_qualified(true)
            .pointer_depth(1)
            .build();

        assert_eq!(
            rule_prediction(&features).prediction,
            InferredOwnership::Borrowed
        );
    }

    #[test]
    fn rule_based_mut_borrowed() {
        let features = OwnershipFeaturesBuilder::default()
            .const_qualified(false)
            .write_count(1)
            .pointer_depth(1)
            .build();

        assert_eq!(
            rule_prediction(&features).prediction,
            InferredOwnership::BorrowedMut
        );
    }

    #[test]
    fn rule_based_slice() {
        let features = OwnershipFeaturesBuilder::default()
            .const_qualified(true)
            .array_decay(true)
            .has_size_param(true)
            .pointer_depth(1)
            .build();

        assert_eq!(
            rule_prediction(&features).prediction,
            InferredOwnership::Slice
        );
    }

    #[test]
    fn rule_based_unknown() {
        let features = OwnershipFeaturesBuilder::default().pointer_depth(1).build();

        let pred = rule_prediction(&features);

        assert_eq!(pred.prediction, InferredOwnership::RawPointer);
        assert!(pred.confidence < 0.5);
    }

    #[test]
    fn rule_based_name() {
        assert_eq!(RuleBasedClassifier::new().name(), "RuleBasedClassifier");
    }

    #[test]
    fn rule_based_is_trained() {
        assert!(RuleBasedClassifier::new().is_trained());
    }

    #[test]
    fn rule_based_batch_classify() {
        let batch = [freed_malloc(), const_pointer()];

        let predictions = RuleBasedClassifier::new().classify_batch(&batch);

        assert_eq!(predictions.len(), 2);
        assert_eq!(predictions[0].prediction, InferredOwnership::Owned);
        assert_eq!(predictions[1].prediction, InferredOwnership::Borrowed);
    }

    // ---- EvaluationMetrics ----

    #[test]
    fn metrics_accuracy() {
        let metrics = EvaluationMetrics {
            correct: 80,
            total_samples: 100,
            ..EvaluationMetrics::default()
        };

        assert!(close_to(metrics.accuracy(), 0.8, TOLERANCE));
    }

    #[test]
    fn metrics_accuracy_empty() {
        let metrics = EvaluationMetrics::default();
        assert!(close_to(metrics.accuracy(), 0.0, TOLERANCE));
    }

    #[test]
    fn metrics_precision() {
        let metrics = EvaluationMetrics {
            true_positives: single("Owned", 80),
            false_positives: single("Owned", 20),
            ..EvaluationMetrics::default()
        };

        assert!(close_to(metrics.precision("Owned"), 0.8, TOLERANCE));
    }

    #[test]
    fn metrics_recall() {
        let metrics = EvaluationMetrics {
            true_positives: single("Owned", 80),
            false_negatives: single("Owned", 20),
            ..EvaluationMetrics::default()
        };

        assert!(close_to(metrics.recall("Owned"), 0.8, TOLERANCE));
    }

    #[test]
    fn metrics_f1_score() {
        let mut metrics = EvaluationMetrics::default();
        record(&mut metrics, "Owned", 80, 20, 20);

        assert!(close_to(metrics.f1_score("Owned"), 0.8, TOLERANCE));
    }

    // ---- ClassifierEvaluator ----

    #[test]
    fn evaluator_new() {
        let only_sample = TrainingSample::new(
            OwnershipFeaturesBuilder::default().build(),
            InferredOwnership::Owned,
            "test.c",
            1,
        );

        let evaluator = ClassifierEvaluator::new(vec![only_sample]);

        assert_eq!(evaluator.sample_count(), 1);
    }

    #[test]
    fn evaluator_evaluate() {
        let owned = TrainingSample::new(freed_malloc(), InferredOwnership::Owned, "test.c", 1);
        let borrowed =
            TrainingSample::new(const_pointer(), InferredOwnership::Borrowed, "test.c", 2);

        let metrics = evaluate_rules(vec![owned, borrowed]);

        assert_eq!(metrics.total_samples, 2);
        assert_eq!(metrics.correct, 2);
        assert!(close_to(metrics.accuracy(), 1.0, TOLERANCE));
    }

    // ---- TrainingConfig ----

    #[test]
    fn training_config_default() {
        let TrainingConfig {
            validation_split,
            random_seed,
            max_iterations,
            ..
        } = TrainingConfig::default();

        assert!(close_to(validation_split, 0.2, TOLERANCE));
        assert_eq!(random_seed, 42);
        assert_eq!(max_iterations, 100);
    }

    // ---- TrainingResult ----

    #[test]
    fn training_result_success() {
        let result = TrainingResult::success(
            EvaluationMetrics::default(),
            EvaluationMetrics::default(),
            10,
            1.5,
        );

        assert!(result.success);
        assert_eq!(result.iterations, 10);
        assert!(close_to(result.duration_secs, 1.5, TOLERANCE));
    }

    #[test]
    fn training_result_failure() {
        let result = TrainingResult::failure();

        assert!(!result.success);
        assert_eq!(result.iterations, 0);
    }

    // ---- ClassifierTrainer ----

    #[test]
    fn trainer_with_defaults() {
        let trainer = ClassifierTrainer::with_defaults();
        assert!(close_to(trainer.config().validation_split, 0.2, TOLERANCE));
    }

    #[test]
    fn trainer_train_rule_based() {
        let dataset = crate::training_data::TrainingDataset::new("test", "1.0.0");

        let (classifier, result) = ClassifierTrainer::with_defaults().train_rule_based(&dataset);

        assert!(result.success);
        assert!(classifier.is_trained());
    }

    // ---- EnsembleClassifier ----

    #[test]
    fn ensemble_new() {
        let ensemble = EnsembleClassifier::new("test_ensemble");

        assert_eq!(ensemble.name(), "test_ensemble");
        assert_eq!(ensemble.classifier_count(), 0);
    }

    #[test]
    fn ensemble_add_classifier() {
        let mut ensemble = EnsembleClassifier::new("test");
        ensemble.add_classifier(RuleBasedClassifier::new(), 1.0);

        assert_eq!(ensemble.classifier_count(), 1);
    }

    #[test]
    fn ensemble_classify_empty() {
        let ensemble = EnsembleClassifier::new("empty");
        let features = OwnershipFeaturesBuilder::default().build();

        let pred = ensemble.classify(&features);

        assert_eq!(pred.prediction, InferredOwnership::RawPointer);
        assert!(close_to(pred.confidence, 0.0, TOLERANCE));
    }

    #[test]
    fn ensemble_classify_single() {
        let mut ensemble = EnsembleClassifier::new("single");
        ensemble.add_classifier(RuleBasedClassifier::new(), 1.0);

        let pred = ensemble.classify(&freed_malloc());

        assert_eq!(pred.prediction, InferredOwnership::Owned);
    }

    #[test]
    fn ensemble_is_trained() {
        let mut ensemble = EnsembleClassifier::new("test");
        // No members yet: nothing contradicts "trained".
        assert!(ensemble.is_trained());

        ensemble.add_classifier(RuleBasedClassifier::new(), 1.0);
        assert!(ensemble.is_trained());
    }

    // ---- Additional rule coverage ----

    #[test]
    fn rule_based_slice_mut() {
        let features = OwnershipFeaturesBuilder::default()
            .const_qualified(false)
            .write_count(3)
            .array_decay(true)
            .has_size_param(true)
            .pointer_depth(1)
            .build();

        assert_eq!(
            rule_prediction(&features).prediction,
            InferredOwnership::SliceMut
        );
    }

    #[test]
    fn rule_based_rule5_non_const_slice_mut() {
        // With no writes the mutable-borrow rule is skipped, so the generic
        // array rule has to produce the answer.
        let features = OwnershipFeaturesBuilder::default()
            .const_qualified(false)
            .write_count(0)
            .array_decay(true)
            .has_size_param(true)
            .pointer_depth(1)
            .build();

        assert_eq!(
            rule_prediction(&features).prediction,
            InferredOwnership::SliceMut
        );
    }

    #[test]
    fn rule_based_stack_allocation_falls_through() {
        let features = OwnershipFeaturesBuilder::default()
            .allocation_site(AllocationKind::Stack)
            .pointer_depth(1)
            .build();

        assert_eq!(
            rule_prediction(&features).prediction,
            InferredOwnership::RawPointer
        );
    }

    #[test]
    fn rule_based_with_custom_weights() {
        let classifier = RuleBasedClassifier::with_weights(RuleWeights {
            malloc_free: 0.99,
            array_alloc: 0.98,
            const_qual: 0.97,
            write_ops: 0.96,
            size_param: 0.95,
        });
        let features = OwnershipFeaturesBuilder::default()
            .allocation_site(AllocationKind::Malloc)
            .deallocation_count(1)
            .pointer_depth(1)
            .build();

        let pred = classifier.classify(&features);

        assert_eq!(pred.prediction, InferredOwnership::Owned);
        assert!(close_to(pred.confidence, 0.99, TOLERANCE));
    }

    // ---- Additional metric coverage ----

    #[test]
    fn metrics_macro_f1() {
        let mut metrics = EvaluationMetrics::default();
        record(&mut metrics, "Owned", 80, 20, 20);
        record(&mut metrics, "Borrowed", 90, 10, 10);

        let f1 = metrics.macro_f1();

        assert!(f1 > 0.0);
        assert!(f1 <= 1.0);
    }

    #[test]
    fn metrics_macro_f1_empty() {
        let metrics = EvaluationMetrics::default();
        assert!(close_to(metrics.macro_f1(), 0.0, TOLERANCE));
    }

    #[test]
    fn metrics_f1_zero_precision_recall() {
        let metrics = EvaluationMetrics::default();
        assert!(close_to(metrics.f1_score("Unknown"), 0.0, TOLERANCE));
    }

    #[test]
    fn metrics_precision_no_predictions() {
        let metrics = EvaluationMetrics::default();
        assert!(close_to(metrics.precision("Unknown"), 0.0, TOLERANCE));
    }

    #[test]
    fn metrics_recall_no_positives() {
        let metrics = EvaluationMetrics::default();
        assert!(close_to(metrics.recall("Unknown"), 0.0, TOLERANCE));
    }

    // ---- Additional ensemble coverage ----

    #[test]
    fn ensemble_classify_multiple() {
        let mut ensemble = EnsembleClassifier::new("dual");
        for _ in 0..2 {
            ensemble.add_classifier(RuleBasedClassifier::new(), 1.0);
        }
        let features = OwnershipFeaturesBuilder::default()
            .const_qualified(true)
            .pointer_depth(1)
            .build();

        let pred = ensemble.classify(&features);

        assert_eq!(pred.prediction, InferredOwnership::Borrowed);
        assert!(pred.confidence > 0.0);
    }

    #[test]
    fn ensemble_zero_weight_classifiers() {
        let mut ensemble = EnsembleClassifier::new("zero_weight");
        ensemble.add_classifier(RuleBasedClassifier::new(), 0.0);

        let pred = ensemble.classify(&freed_malloc());

        // The weights sum to zero, so the confidence must fall back to zero.
        assert!(close_to(pred.confidence, 0.0, TOLERANCE));
    }

    // ---- End-to-end ----

    #[test]
    fn full_training_pipeline() {
        let synthetic = crate::training_data::SyntheticConfig {
            samples_per_pattern: 20,
            ..Default::default()
        };
        let dataset =
            crate::training_data::SyntheticDataGenerator::new(synthetic).generate_full_dataset();

        let (classifier, result) = ClassifierTrainer::with_defaults().train_rule_based(&dataset);
        assert!(result.success);

        let metrics = ClassifierEvaluator::from_dataset(&dataset).evaluate(&classifier);
        assert!(metrics.accuracy() > 0.8);
    }

    // ---- Mismatch handling ----

    #[test]
    fn evaluator_with_mismatched_predictions() {
        // These features classify as `Owned`, but the label says `Borrowed`.
        let owned_features = OwnershipFeatures {
            allocation_site: AllocationKind::Malloc,
            deallocation_count: 1,
            ..Default::default()
        };
        let mislabelled =
            TrainingSample::new(owned_features, InferredOwnership::Borrowed, "test.c", 1);

        let metrics = evaluate_rules(vec![mislabelled]);

        assert_eq!(metrics.total_samples, 1);
        assert!(!metrics.false_positives.is_empty() || !metrics.false_negatives.is_empty());
    }

    #[test]
    fn metrics_precision_both_tp_fp_zero() {
        let metrics = EvaluationMetrics::default();
        assert!(close_to(metrics.precision("Owned"), 0.0, f64::EPSILON));
    }

    #[test]
    fn metrics_recall_both_tp_fn_zero() {
        let metrics = EvaluationMetrics::default();
        assert!(close_to(metrics.recall("Owned"), 0.0, f64::EPSILON));
    }

    #[test]
    fn metrics_f1_both_precision_recall_zero() {
        let metrics = EvaluationMetrics::default();
        assert!(close_to(metrics.f1_score("Owned"), 0.0, f64::EPSILON));
    }

    #[test]
    fn metrics_macro_f1_empty_classes() {
        let metrics = EvaluationMetrics::default();
        assert!(close_to(metrics.macro_f1(), 0.0, f64::EPSILON));
    }

    #[test]
    fn metrics_macro_f1_with_classes() {
        let mut metrics = EvaluationMetrics::default();
        record(&mut metrics, "Owned", 10, 2, 1);
        record(&mut metrics, "Borrowed", 5, 1, 3);

        let f1 = metrics.macro_f1();

        assert!(f1 > 0.0 && f1 <= 1.0);
    }

    #[test]
    fn prediction_with_multiple_alternatives() {
        let pred = ClassifierPrediction::new(InferredOwnership::Owned, 0.9)
            .with_alternative(InferredOwnership::Borrowed, 0.5)
            .with_alternative(InferredOwnership::Vec, 0.3);

        assert_eq!(pred.alternatives.len(), 2);
        assert_eq!(pred.alternatives[0].0, InferredOwnership::Borrowed);
        assert_eq!(pred.alternatives[1].0, InferredOwnership::Vec);
    }

    #[test]
    fn evaluator_multiple_mismatches() {
        let owned_features = OwnershipFeatures {
            allocation_site: AllocationKind::Malloc,
            deallocation_count: 1,
            ..Default::default()
        };
        let borrow_features = OwnershipFeatures {
            is_const: true,
            ..Default::default()
        };

        let samples = vec![
            // Label matches what the rules predict.
            TrainingSample::new(owned_features, InferredOwnership::Owned, "a.c", 1),
            // Label disagrees with what the rules predict.
            TrainingSample::new(borrow_features, InferredOwnership::Owned, "b.c", 2),
        ];

        let metrics = evaluate_rules(samples);

        assert_eq!(metrics.total_samples, 2);
        assert!(metrics.correct >= 1);
        assert!(metrics.accuracy() >= 0.5);
    }
}