Skip to main content

decy_ownership/
classifier.rs

//! ML classifier infrastructure for ownership inference (DECY-ML-011).
//!
//! Provides infrastructure for training and using ML classifiers
//! for ownership pattern recognition. Designed to integrate with
//! external ML libraries (e.g., aprender for RandomForest).
//!
//! # Architecture
//!
//! The classifier system uses a trait-based design:
//! - [`OwnershipClassifier`]: Core trait for all classifiers
//! - [`RuleBasedClassifier`]: Baseline deterministic classifier
//! - [`EnsembleClassifier`]: Combines multiple classifiers
//!
//! # Training Flow
//!
//! 1. Collect labeled data via [`TrainingDataset`](crate::training_data::TrainingDataset)
//! 2. Train a classifier using [`ClassifierTrainer`]
//! 3. Validate it using [`ClassifierEvaluator`]
//! 4. Deploy via [`HybridClassifier`](crate::hybrid_classifier::HybridClassifier)

21use std::collections::HashMap;
22
23use crate::ml_features::{AllocationKind, InferredOwnership, OwnershipFeatures};
24use crate::retraining_pipeline::TrainingSample;
25use crate::training_data::TrainingDataset;
26
27/// A classification prediction with confidence.
28#[derive(Debug, Clone)]
29pub struct ClassifierPrediction {
30    /// Predicted ownership kind.
31    pub prediction: InferredOwnership,
32    /// Confidence score (0.0 - 1.0).
33    pub confidence: f64,
34    /// Alternative predictions with confidences.
35    pub alternatives: Vec<(InferredOwnership, f64)>,
36}
37
38impl ClassifierPrediction {
39    /// Create a new prediction.
40    pub fn new(prediction: InferredOwnership, confidence: f64) -> Self {
41        Self {
42            prediction,
43            confidence,
44            alternatives: Vec::new(),
45        }
46    }
47
48    /// Add an alternative prediction.
49    pub fn with_alternative(mut self, kind: InferredOwnership, confidence: f64) -> Self {
50        self.alternatives.push((kind, confidence));
51        self
52    }
53
54    /// Check if prediction is confident (above threshold).
55    pub fn is_confident(&self, threshold: f64) -> bool {
56        self.confidence >= threshold
57    }
58}
59
60/// Core trait for ownership classifiers.
61pub trait OwnershipClassifier: Send + Sync {
62    /// Classify a feature vector.
63    fn classify(&self, features: &OwnershipFeatures) -> ClassifierPrediction;
64
65    /// Classify multiple samples (batch prediction).
66    fn classify_batch(&self, features: &[OwnershipFeatures]) -> Vec<ClassifierPrediction> {
67        features.iter().map(|f| self.classify(f)).collect()
68    }
69
70    /// Get classifier name.
71    fn name(&self) -> &str;
72
73    /// Check if classifier is trained.
74    fn is_trained(&self) -> bool;
75}
76
/// Deterministic, rule-driven baseline classifier.
///
/// Each rule inspects a few [`OwnershipFeatures`] flags and, when it fires,
/// reports a fixed confidence taken from [`RuleWeights`]. In outline:
/// - heap allocation (`malloc`/`calloc`) that is freed → `Owned` or `Vec`
/// - `const` pointer that is never freed → `Borrowed` or `Slice`
/// - written, non-`const`, non-heap pointer that is never freed →
///   `BorrowedMut` or `SliceMut`
/// - array decay + size parameter → `Slice` / `SliceMut`
/// - anything else → `RawPointer` with low confidence
///
/// The exact conditions and their priority are defined by this type's
/// [`OwnershipClassifier::classify`] implementation.
#[derive(Debug, Clone, Default)]
pub struct RuleBasedClassifier {
    /// Confidence reported by each rule.
    weights: RuleWeights,
}

/// Per-rule confidence values used by [`RuleBasedClassifier`].
#[derive(Debug, Clone)]
pub struct RuleWeights {
    /// Confidence for the freed-heap-allocation → `Owned` rule.
    pub malloc_free: f64,
    /// Confidence for the heap-array → `Vec` rule.
    pub array_alloc: f64,
    /// Confidence for the `const` pointer → `Borrowed` rule.
    pub const_qual: f64,
    /// Confidence for the written pointer → `BorrowedMut` rule.
    pub write_ops: f64,
    /// Confidence for the slice rules (array decay + size parameter).
    pub size_param: f64,
}

impl Default for RuleWeights {
    /// Baseline weights, in strictly descending order from `malloc_free`
    /// (0.95) down to `size_param` (0.75).
    fn default() -> Self {
        RuleWeights {
            malloc_free: 0.95,
            array_alloc: 0.90,
            const_qual: 0.85,
            write_ops: 0.80,
            size_param: 0.75,
        }
    }
}

impl RuleBasedClassifier {
    /// Build a classifier that uses [`RuleWeights::default`].
    pub fn new() -> Self {
        RuleBasedClassifier {
            weights: RuleWeights::default(),
        }
    }

    /// Build a classifier that reports the supplied per-rule confidences.
    pub fn with_weights(weights: RuleWeights) -> Self {
        RuleBasedClassifier { weights }
    }
}
129
130impl OwnershipClassifier for RuleBasedClassifier {
131    fn classify(&self, features: &OwnershipFeatures) -> ClassifierPrediction {
132        // Rule 1: malloc + free → Owned (Box)
133        if matches!(
134            features.allocation_site,
135            AllocationKind::Malloc | AllocationKind::Calloc
136        ) && features.deallocation_count > 0
137            && !features.has_size_param
138        {
139            return ClassifierPrediction::new(InferredOwnership::Owned, self.weights.malloc_free);
140        }
141
142        // Rule 2: malloc + size + array_decay → Vec
143        if matches!(
144            features.allocation_site,
145            AllocationKind::Malloc | AllocationKind::Calloc
146        ) && (features.has_size_param || features.is_array_decay)
147            && features.deallocation_count > 0
148        {
149            return ClassifierPrediction::new(InferredOwnership::Vec, self.weights.array_alloc);
150        }
151
152        // Rule 3: const pointer → Borrowed (&T)
153        if features.is_const && features.deallocation_count == 0 {
154            // Check for slice pattern
155            if features.is_array_decay && features.has_size_param {
156                return ClassifierPrediction::new(
157                    InferredOwnership::Slice,
158                    self.weights.size_param,
159                );
160            }
161            return ClassifierPrediction::new(InferredOwnership::Borrowed, self.weights.const_qual);
162        }
163
164        // Rule 4: non-const with writes → BorrowedMut (&mut T)
165        if !features.is_const
166            && features.write_count > 0
167            && features.deallocation_count == 0
168            && !matches!(
169                features.allocation_site,
170                AllocationKind::Malloc | AllocationKind::Calloc
171            )
172        {
173            // Check for mutable slice pattern
174            if features.is_array_decay && features.has_size_param {
175                return ClassifierPrediction::new(
176                    InferredOwnership::SliceMut,
177                    self.weights.size_param,
178                );
179            }
180            return ClassifierPrediction::new(
181                InferredOwnership::BorrowedMut,
182                self.weights.write_ops,
183            );
184        }
185
186        // Rule 5: array decay with size → Slice
187        if features.is_array_decay && features.has_size_param {
188            let ownership = if features.is_const {
189                InferredOwnership::Slice
190            } else {
191                InferredOwnership::SliceMut
192            };
193            return ClassifierPrediction::new(ownership, self.weights.size_param);
194        }
195
196        // Default: low confidence RawPointer
197        ClassifierPrediction::new(InferredOwnership::RawPointer, 0.3)
198    }
199
200    fn name(&self) -> &str {
201        "RuleBasedClassifier"
202    }
203
204    fn is_trained(&self) -> bool {
205        true // Always trained (deterministic)
206    }
207}
208
/// Per-class counts and totals collected while evaluating a classifier.
///
/// Classes are identified by string keys. [`ClassifierEvaluator`] uses the
/// `Debug` rendering of the ownership kind (e.g. `"Owned"`).
#[derive(Debug, Clone, Default)]
pub struct EvaluationMetrics {
    /// True positives per class.
    pub true_positives: HashMap<String, usize>,
    /// False positives per class.
    pub false_positives: HashMap<String, usize>,
    /// False negatives per class.
    pub false_negatives: HashMap<String, usize>,
    /// Total samples evaluated.
    pub total_samples: usize,
    /// Correct predictions.
    pub correct: usize,
}

impl EvaluationMetrics {
    /// Read `class` from one of the per-class count maps as `f64`, treating a
    /// missing entry as zero.
    fn class_count(counts: &HashMap<String, usize>, class: &str) -> f64 {
        counts.get(class).copied().unwrap_or(0) as f64
    }

    /// Fraction of samples predicted correctly, or 0.0 when no samples were
    /// evaluated.
    pub fn accuracy(&self) -> f64 {
        if self.total_samples == 0 {
            return 0.0;
        }
        self.correct as f64 / self.total_samples as f64
    }

    /// Precision `tp / (tp + fp)` for `class`, or 0.0 if the class was never
    /// predicted.
    pub fn precision(&self, class: &str) -> f64 {
        let tp = Self::class_count(&self.true_positives, class);
        let fp = Self::class_count(&self.false_positives, class);

        if tp + fp == 0.0 {
            return 0.0;
        }
        tp / (tp + fp)
    }

    /// Recall `tp / (tp + fn)` for `class`, or 0.0 if the class never
    /// occurred as a label.
    pub fn recall(&self, class: &str) -> f64 {
        let tp = Self::class_count(&self.true_positives, class);
        let fn_ = Self::class_count(&self.false_negatives, class);

        if tp + fn_ == 0.0 {
            return 0.0;
        }
        tp / (tp + fn_)
    }

    /// F1 score (harmonic mean of precision and recall) for `class`, or 0.0
    /// when both are zero.
    pub fn f1_score(&self, class: &str) -> f64 {
        let p = self.precision(class);
        let r = self.recall(class);

        if p + r == 0.0 {
            return 0.0;
        }
        2.0 * p * r / (p + r)
    }

    /// Macro-averaged F1: the unweighted mean of the per-class F1 scores.
    ///
    /// Every class that appears in any of the three count maps is included.
    /// A class the classifier never gets right (no true positives) therefore
    /// contributes an F1 of 0.0 instead of being dropped from the average.
    /// Returns 0.0 when no class has been recorded.
    pub fn macro_f1(&self) -> f64 {
        // BTreeSet de-duplicates the classes and fixes the summation order.
        let classes: std::collections::BTreeSet<&str> = self
            .true_positives
            .keys()
            .chain(self.false_positives.keys())
            .chain(self.false_negatives.keys())
            .map(String::as_str)
            .collect();
        if classes.is_empty() {
            return 0.0;
        }

        let sum: f64 = classes.iter().map(|class| self.f1_score(class)).sum();
        sum / classes.len() as f64
    }
}
277
278/// Classifier evaluator.
279pub struct ClassifierEvaluator {
280    /// Test samples.
281    samples: Vec<TrainingSample>,
282}
283
284impl ClassifierEvaluator {
285    /// Create evaluator from test samples.
286    pub fn new(samples: Vec<TrainingSample>) -> Self {
287        Self { samples }
288    }
289
290    /// Create from dataset (uses all samples).
291    pub fn from_dataset(dataset: &TrainingDataset) -> Self {
292        Self {
293            samples: dataset.to_training_samples(),
294        }
295    }
296
297    /// Evaluate a classifier.
298    pub fn evaluate(&self, classifier: &dyn OwnershipClassifier) -> EvaluationMetrics {
299        let mut metrics = EvaluationMetrics {
300            total_samples: self.samples.len(),
301            ..Default::default()
302        };
303
304        for sample in &self.samples {
305            let prediction = classifier.classify(&sample.features);
306            let predicted_class = format!("{:?}", prediction.prediction);
307            let actual_class = format!("{:?}", sample.label);
308
309            if prediction.prediction == sample.label {
310                metrics.correct += 1;
311                *metrics.true_positives.entry(actual_class).or_insert(0) += 1;
312            } else {
313                *metrics
314                    .false_positives
315                    .entry(predicted_class.clone())
316                    .or_insert(0) += 1;
317                *metrics.false_negatives.entry(actual_class).or_insert(0) += 1;
318            }
319        }
320
321        metrics
322    }
323
324    /// Get sample count.
325    pub fn sample_count(&self) -> usize {
326        self.samples.len()
327    }
328}
329
/// Settings for [`ClassifierTrainer`].
///
/// NOTE: the rule-based training path in this file does not read these
/// values. They are only stored and exposed via `ClassifierTrainer::config`.
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    /// Validation split ratio.
    pub validation_split: f64,
    /// Seed intended to make training runs reproducible.
    pub random_seed: u64,
    /// Upper bound on the number of training iterations.
    pub max_iterations: usize,
    /// Early-stopping patience (iterations without improvement to tolerate).
    pub early_stopping_patience: usize,
    /// Smallest gain that counts as an improvement for early stopping.
    pub min_improvement: f64,
}

impl Default for TrainingConfig {
    /// Split 0.2, seed 42, at most 100 iterations, patience 10, minimum
    /// improvement 0.001.
    fn default() -> Self {
        TrainingConfig {
            validation_split: 0.2,
            random_seed: 42,
            max_iterations: 100,
            early_stopping_patience: 10,
            min_improvement: 0.001,
        }
    }
}
356
357/// Result of classifier training.
358#[derive(Debug)]
359pub struct TrainingResult {
360    /// Training succeeded.
361    pub success: bool,
362    /// Final training metrics.
363    pub train_metrics: EvaluationMetrics,
364    /// Final validation metrics.
365    pub validation_metrics: EvaluationMetrics,
366    /// Number of iterations completed.
367    pub iterations: usize,
368    /// Training duration in seconds.
369    pub duration_secs: f64,
370}
371
372impl TrainingResult {
373    /// Create a successful result.
374    pub fn success(
375        train_metrics: EvaluationMetrics,
376        validation_metrics: EvaluationMetrics,
377        iterations: usize,
378        duration_secs: f64,
379    ) -> Self {
380        Self {
381            success: true,
382            train_metrics,
383            validation_metrics,
384            iterations,
385            duration_secs,
386        }
387    }
388
389    /// Create a failed result.
390    pub fn failure() -> Self {
391        Self {
392            success: false,
393            train_metrics: EvaluationMetrics::default(),
394            validation_metrics: EvaluationMetrics::default(),
395            iterations: 0,
396            duration_secs: 0.0,
397        }
398    }
399}
400
401/// Trainer for classifiers.
402pub struct ClassifierTrainer {
403    config: TrainingConfig,
404}
405
406impl ClassifierTrainer {
407    /// Create a new trainer with configuration.
408    pub fn new(config: TrainingConfig) -> Self {
409        Self { config }
410    }
411
412    /// Create with default configuration.
413    pub fn with_defaults() -> Self {
414        Self::new(TrainingConfig::default())
415    }
416
417    /// Train a rule-based classifier (returns pre-built classifier).
418    ///
419    /// Note: Rule-based classifier doesn't require training.
420    pub fn train_rule_based(
421        &self,
422        _dataset: &TrainingDataset,
423    ) -> (RuleBasedClassifier, TrainingResult) {
424        let start = std::time::Instant::now();
425        let classifier = RuleBasedClassifier::new();
426        let duration = start.elapsed().as_secs_f64();
427
428        let result = TrainingResult::success(
429            EvaluationMetrics::default(),
430            EvaluationMetrics::default(),
431            1,
432            duration,
433        );
434
435        (classifier, result)
436    }
437
438    /// Get training configuration.
439    pub fn config(&self) -> &TrainingConfig {
440        &self.config
441    }
442}
443
444/// Ensemble classifier combining multiple classifiers.
445pub struct EnsembleClassifier {
446    /// Component classifiers with weights.
447    classifiers: Vec<(Box<dyn OwnershipClassifier>, f64)>,
448    /// Ensemble name.
449    name: String,
450}
451
452impl EnsembleClassifier {
453    /// Create a new ensemble.
454    pub fn new(name: &str) -> Self {
455        Self {
456            classifiers: Vec::new(),
457            name: name.to_string(),
458        }
459    }
460
461    /// Add a classifier with weight.
462    pub fn add_classifier<C: OwnershipClassifier + 'static>(&mut self, classifier: C, weight: f64) {
463        self.classifiers.push((Box::new(classifier), weight));
464    }
465
466    /// Get number of classifiers in ensemble.
467    pub fn classifier_count(&self) -> usize {
468        self.classifiers.len()
469    }
470}
471
472impl OwnershipClassifier for EnsembleClassifier {
473    fn classify(&self, features: &OwnershipFeatures) -> ClassifierPrediction {
474        if self.classifiers.is_empty() {
475            return ClassifierPrediction::new(InferredOwnership::RawPointer, 0.0);
476        }
477
478        // Collect weighted votes
479        let mut votes: HashMap<String, f64> = HashMap::new();
480
481        for (classifier, weight) in &self.classifiers {
482            let prediction = classifier.classify(features);
483            let key = format!("{:?}", prediction.prediction);
484            *votes.entry(key.clone()).or_insert(0.0) += weight * prediction.confidence;
485        }
486
487        // Find highest voted class
488        let total_weight: f64 = self.classifiers.iter().map(|(_, w)| w).sum();
489        let (best_class, best_score) = votes
490            .iter()
491            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
492            .map(|(k, v)| (k.clone(), *v))
493            .unwrap_or_else(|| ("RawPointer".to_string(), 0.0));
494
495        // Convert back to InferredOwnership
496        let prediction = match best_class.as_str() {
497            "Owned" => InferredOwnership::Owned,
498            "Borrowed" => InferredOwnership::Borrowed,
499            "BorrowedMut" => InferredOwnership::BorrowedMut,
500            "Vec" => InferredOwnership::Vec,
501            "Slice" => InferredOwnership::Slice,
502            "SliceMut" => InferredOwnership::SliceMut,
503            "Shared" => InferredOwnership::Shared,
504            _ => InferredOwnership::RawPointer,
505        };
506
507        let confidence = if total_weight > 0.0 {
508            best_score / total_weight
509        } else {
510            0.0
511        };
512
513        ClassifierPrediction::new(prediction, confidence)
514    }
515
516    fn name(&self) -> &str {
517        &self.name
518    }
519
520    fn is_trained(&self) -> bool {
521        self.classifiers.iter().all(|(c, _)| c.is_trained())
522    }
523}
524
525#[cfg(test)]
526mod tests {
527    use super::*;
528    use crate::ml_features::OwnershipFeaturesBuilder;
529
530    // ========================================================================
531    // ClassifierPrediction tests
532    // ========================================================================
533
534    #[test]
535    fn prediction_new() {
536        let pred = ClassifierPrediction::new(InferredOwnership::Owned, 0.9);
537
538        assert!(matches!(pred.prediction, InferredOwnership::Owned));
539        assert!((pred.confidence - 0.9).abs() < 0.001);
540        assert!(pred.alternatives.is_empty());
541    }
542
543    #[test]
544    fn prediction_with_alternative() {
545        let pred = ClassifierPrediction::new(InferredOwnership::Owned, 0.9)
546            .with_alternative(InferredOwnership::Borrowed, 0.1);
547
548        assert_eq!(pred.alternatives.len(), 1);
549    }
550
551    #[test]
552    fn prediction_is_confident() {
553        let pred = ClassifierPrediction::new(InferredOwnership::Owned, 0.9);
554
555        assert!(pred.is_confident(0.8));
556        assert!(!pred.is_confident(0.95));
557    }
558
559    // ========================================================================
560    // RuleBasedClassifier tests
561    // ========================================================================
562
563    #[test]
564    fn rule_based_malloc_owned() {
565        let classifier = RuleBasedClassifier::new();
566        let features = OwnershipFeaturesBuilder::default()
567            .allocation_site(AllocationKind::Malloc)
568            .deallocation_count(1)
569            .pointer_depth(1)
570            .build();
571
572        let pred = classifier.classify(&features);
573
574        assert!(matches!(pred.prediction, InferredOwnership::Owned));
575        assert!(pred.confidence > 0.9);
576    }
577
578    #[test]
579    fn rule_based_array_vec() {
580        let classifier = RuleBasedClassifier::new();
581        let features = OwnershipFeaturesBuilder::default()
582            .allocation_site(AllocationKind::Malloc)
583            .has_size_param(true)
584            .deallocation_count(1)
585            .pointer_depth(1)
586            .build();
587
588        let pred = classifier.classify(&features);
589
590        assert!(matches!(pred.prediction, InferredOwnership::Vec));
591    }
592
593    #[test]
594    fn rule_based_const_borrowed() {
595        let classifier = RuleBasedClassifier::new();
596        let features = OwnershipFeaturesBuilder::default()
597            .const_qualified(true)
598            .pointer_depth(1)
599            .build();
600
601        let pred = classifier.classify(&features);
602
603        assert!(matches!(pred.prediction, InferredOwnership::Borrowed));
604    }
605
606    #[test]
607    fn rule_based_mut_borrowed() {
608        let classifier = RuleBasedClassifier::new();
609        let features = OwnershipFeaturesBuilder::default()
610            .const_qualified(false)
611            .write_count(1)
612            .pointer_depth(1)
613            .build();
614
615        let pred = classifier.classify(&features);
616
617        assert!(matches!(pred.prediction, InferredOwnership::BorrowedMut));
618    }
619
620    #[test]
621    fn rule_based_slice() {
622        let classifier = RuleBasedClassifier::new();
623        let features = OwnershipFeaturesBuilder::default()
624            .const_qualified(true)
625            .array_decay(true)
626            .has_size_param(true)
627            .pointer_depth(1)
628            .build();
629
630        let pred = classifier.classify(&features);
631
632        assert!(matches!(pred.prediction, InferredOwnership::Slice));
633    }
634
635    #[test]
636    fn rule_based_unknown() {
637        let classifier = RuleBasedClassifier::new();
638        let features = OwnershipFeaturesBuilder::default().pointer_depth(1).build();
639
640        let pred = classifier.classify(&features);
641
642        // Default to RawPointer with low confidence
643        assert!(matches!(pred.prediction, InferredOwnership::RawPointer));
644        assert!(pred.confidence < 0.5);
645    }
646
647    #[test]
648    fn rule_based_name() {
649        let classifier = RuleBasedClassifier::new();
650        assert_eq!(classifier.name(), "RuleBasedClassifier");
651    }
652
653    #[test]
654    fn rule_based_is_trained() {
655        let classifier = RuleBasedClassifier::new();
656        assert!(classifier.is_trained());
657    }
658
659    #[test]
660    fn rule_based_batch_classify() {
661        let classifier = RuleBasedClassifier::new();
662        let features = vec![
663            OwnershipFeaturesBuilder::default()
664                .allocation_site(AllocationKind::Malloc)
665                .deallocation_count(1)
666                .build(),
667            OwnershipFeaturesBuilder::default()
668                .const_qualified(true)
669                .build(),
670        ];
671
672        let predictions = classifier.classify_batch(&features);
673
674        assert_eq!(predictions.len(), 2);
675        assert!(matches!(
676            predictions[0].prediction,
677            InferredOwnership::Owned
678        ));
679        assert!(matches!(
680            predictions[1].prediction,
681            InferredOwnership::Borrowed
682        ));
683    }
684
685    // ========================================================================
686    // EvaluationMetrics tests
687    // ========================================================================
688
689    #[test]
690    fn metrics_accuracy() {
691        let metrics = EvaluationMetrics {
692            total_samples: 100,
693            correct: 80,
694            ..Default::default()
695        };
696
697        assert!((metrics.accuracy() - 0.8).abs() < 0.001);
698    }
699
700    #[test]
701    fn metrics_accuracy_empty() {
702        let metrics = EvaluationMetrics::default();
703        assert!((metrics.accuracy() - 0.0).abs() < 0.001);
704    }
705
706    #[test]
707    fn metrics_precision() {
708        let mut metrics = EvaluationMetrics::default();
709        metrics.true_positives.insert("Owned".to_string(), 80);
710        metrics.false_positives.insert("Owned".to_string(), 20);
711
712        assert!((metrics.precision("Owned") - 0.8).abs() < 0.001);
713    }
714
715    #[test]
716    fn metrics_recall() {
717        let mut metrics = EvaluationMetrics::default();
718        metrics.true_positives.insert("Owned".to_string(), 80);
719        metrics.false_negatives.insert("Owned".to_string(), 20);
720
721        assert!((metrics.recall("Owned") - 0.8).abs() < 0.001);
722    }
723
724    #[test]
725    fn metrics_f1_score() {
726        let mut metrics = EvaluationMetrics::default();
727        metrics.true_positives.insert("Owned".to_string(), 80);
728        metrics.false_positives.insert("Owned".to_string(), 20);
729        metrics.false_negatives.insert("Owned".to_string(), 20);
730
731        // Precision = 80/100 = 0.8, Recall = 80/100 = 0.8
732        // F1 = 2 * 0.8 * 0.8 / (0.8 + 0.8) = 0.8
733        assert!((metrics.f1_score("Owned") - 0.8).abs() < 0.001);
734    }
735
736    // ========================================================================
737    // ClassifierEvaluator tests
738    // ========================================================================
739
740    #[test]
741    fn evaluator_new() {
742        let samples = vec![TrainingSample::new(
743            OwnershipFeaturesBuilder::default().build(),
744            InferredOwnership::Owned,
745            "test.c",
746            1,
747        )];
748
749        let evaluator = ClassifierEvaluator::new(samples);
750        assert_eq!(evaluator.sample_count(), 1);
751    }
752
753    #[test]
754    fn evaluator_evaluate() {
755        let samples = vec![
756            TrainingSample::new(
757                OwnershipFeaturesBuilder::default()
758                    .allocation_site(AllocationKind::Malloc)
759                    .deallocation_count(1)
760                    .build(),
761                InferredOwnership::Owned,
762                "test.c",
763                1,
764            ),
765            TrainingSample::new(
766                OwnershipFeaturesBuilder::default()
767                    .const_qualified(true)
768                    .build(),
769                InferredOwnership::Borrowed,
770                "test.c",
771                2,
772            ),
773        ];
774
775        let evaluator = ClassifierEvaluator::new(samples);
776        let classifier = RuleBasedClassifier::new();
777        let metrics = evaluator.evaluate(&classifier);
778
779        assert_eq!(metrics.total_samples, 2);
780        assert_eq!(metrics.correct, 2);
781        assert!((metrics.accuracy() - 1.0).abs() < 0.001);
782    }
783
784    // ========================================================================
785    // TrainingConfig tests
786    // ========================================================================
787
788    #[test]
789    fn training_config_default() {
790        let config = TrainingConfig::default();
791
792        assert!((config.validation_split - 0.2).abs() < 0.001);
793        assert_eq!(config.random_seed, 42);
794        assert_eq!(config.max_iterations, 100);
795    }
796
797    // ========================================================================
798    // TrainingResult tests
799    // ========================================================================
800
801    #[test]
802    fn training_result_success() {
803        let result = TrainingResult::success(
804            EvaluationMetrics::default(),
805            EvaluationMetrics::default(),
806            10,
807            1.5,
808        );
809
810        assert!(result.success);
811        assert_eq!(result.iterations, 10);
812        assert!((result.duration_secs - 1.5).abs() < 0.001);
813    }
814
815    #[test]
816    fn training_result_failure() {
817        let result = TrainingResult::failure();
818
819        assert!(!result.success);
820        assert_eq!(result.iterations, 0);
821    }
822
823    // ========================================================================
824    // ClassifierTrainer tests
825    // ========================================================================
826
827    #[test]
828    fn trainer_with_defaults() {
829        let trainer = ClassifierTrainer::with_defaults();
830        assert!((trainer.config().validation_split - 0.2).abs() < 0.001);
831    }
832
833    #[test]
834    fn trainer_train_rule_based() {
835        let trainer = ClassifierTrainer::with_defaults();
836        let dataset = crate::training_data::TrainingDataset::new("test", "1.0.0");
837
838        let (classifier, result) = trainer.train_rule_based(&dataset);
839
840        assert!(result.success);
841        assert!(classifier.is_trained());
842    }
843
844    // ========================================================================
845    // EnsembleClassifier tests
846    // ========================================================================
847
848    #[test]
849    fn ensemble_new() {
850        let ensemble = EnsembleClassifier::new("test_ensemble");
851
852        assert_eq!(ensemble.name(), "test_ensemble");
853        assert_eq!(ensemble.classifier_count(), 0);
854    }
855
856    #[test]
857    fn ensemble_add_classifier() {
858        let mut ensemble = EnsembleClassifier::new("test");
859        ensemble.add_classifier(RuleBasedClassifier::new(), 1.0);
860
861        assert_eq!(ensemble.classifier_count(), 1);
862    }
863
864    #[test]
865    fn ensemble_classify_empty() {
866        let ensemble = EnsembleClassifier::new("empty");
867        let features = OwnershipFeaturesBuilder::default().build();
868
869        let pred = ensemble.classify(&features);
870
871        assert!(matches!(pred.prediction, InferredOwnership::RawPointer));
872        assert!((pred.confidence - 0.0).abs() < 0.001);
873    }
874
875    #[test]
876    fn ensemble_classify_single() {
877        let mut ensemble = EnsembleClassifier::new("single");
878        ensemble.add_classifier(RuleBasedClassifier::new(), 1.0);
879
880        let features = OwnershipFeaturesBuilder::default()
881            .allocation_site(AllocationKind::Malloc)
882            .deallocation_count(1)
883            .build();
884
885        let pred = ensemble.classify(&features);
886
887        assert!(matches!(pred.prediction, InferredOwnership::Owned));
888    }
889
890    #[test]
891    fn ensemble_is_trained() {
892        let mut ensemble = EnsembleClassifier::new("test");
893        // Empty ensemble is vacuously trained (all 0 classifiers are trained)
894        assert!(ensemble.is_trained());
895
896        ensemble.add_classifier(RuleBasedClassifier::new(), 1.0);
897        assert!(ensemble.is_trained()); // Rule-based always trained
898    }
899
900    // ========================================================================
901    // Integration tests
902    // ========================================================================
903
904    // ========================================================================
905    // Additional coverage: rule branches and edge cases
906    // ========================================================================
907
908    #[test]
909    fn rule_based_slice_mut() {
910        // Rule 4 inner: non-const + writes + array_decay + size_param → SliceMut
911        let classifier = RuleBasedClassifier::new();
912        let features = OwnershipFeaturesBuilder::default()
913            .const_qualified(false)
914            .write_count(3)
915            .array_decay(true)
916            .has_size_param(true)
917            .pointer_depth(1)
918            .build();
919
920        let pred = classifier.classify(&features);
921        assert!(matches!(pred.prediction, InferredOwnership::SliceMut));
922    }
923
924    #[test]
925    fn rule_based_rule5_non_const_slice_mut() {
926        // Rule 5: array_decay + size_param, non-const → SliceMut
927        // This hits Rule 5 when Rule 3/4 don't match
928        let classifier = RuleBasedClassifier::new();
929        let features = OwnershipFeaturesBuilder::default()
930            .const_qualified(false)
931            .write_count(0)
932            .array_decay(true)
933            .has_size_param(true)
934            .pointer_depth(1)
935            .build();
936
937        let pred = classifier.classify(&features);
938        assert!(matches!(pred.prediction, InferredOwnership::SliceMut));
939    }
940
941    #[test]
942    fn rule_based_stack_allocation_falls_through() {
943        // Stack allocation without other features → RawPointer
944        let classifier = RuleBasedClassifier::new();
945        let features = OwnershipFeaturesBuilder::default()
946            .allocation_site(AllocationKind::Stack)
947            .pointer_depth(1)
948            .build();
949
950        let pred = classifier.classify(&features);
951        assert!(matches!(pred.prediction, InferredOwnership::RawPointer));
952    }
953
954    #[test]
955    fn rule_based_with_custom_weights() {
956        let weights = RuleWeights {
957            malloc_free: 0.99,
958            array_alloc: 0.98,
959            const_qual: 0.97,
960            write_ops: 0.96,
961            size_param: 0.95,
962        };
963        let classifier = RuleBasedClassifier::with_weights(weights);
964        let features = OwnershipFeaturesBuilder::default()
965            .allocation_site(AllocationKind::Malloc)
966            .deallocation_count(1)
967            .pointer_depth(1)
968            .build();
969
970        let pred = classifier.classify(&features);
971        assert!(matches!(pred.prediction, InferredOwnership::Owned));
972        assert!((pred.confidence - 0.99).abs() < 0.001);
973    }
974
975    #[test]
976    fn metrics_macro_f1() {
977        let mut metrics = EvaluationMetrics::default();
978        metrics.true_positives.insert("Owned".to_string(), 80);
979        metrics.false_positives.insert("Owned".to_string(), 20);
980        metrics.false_negatives.insert("Owned".to_string(), 20);
981
982        metrics.true_positives.insert("Borrowed".to_string(), 90);
983        metrics.false_positives.insert("Borrowed".to_string(), 10);
984        metrics.false_negatives.insert("Borrowed".to_string(), 10);
985
986        let f1 = metrics.macro_f1();
987        assert!(f1 > 0.0);
988        assert!(f1 <= 1.0);
989    }
990
991    #[test]
992    fn metrics_macro_f1_empty() {
993        let metrics = EvaluationMetrics::default();
994        assert!((metrics.macro_f1() - 0.0).abs() < 0.001);
995    }
996
997    #[test]
998    fn metrics_f1_zero_precision_recall() {
999        let metrics = EvaluationMetrics::default();
1000        // No TP, no FP, no FN for "Unknown" → precision=0, recall=0, f1=0
1001        assert!((metrics.f1_score("Unknown") - 0.0).abs() < 0.001);
1002    }
1003
1004    #[test]
1005    fn metrics_precision_no_predictions() {
1006        let metrics = EvaluationMetrics::default();
1007        // No TP and no FP → 0.0
1008        assert!((metrics.precision("Unknown") - 0.0).abs() < 0.001);
1009    }
1010
1011    #[test]
1012    fn metrics_recall_no_positives() {
1013        let metrics = EvaluationMetrics::default();
1014        // No TP and no FN → 0.0
1015        assert!((metrics.recall("Unknown") - 0.0).abs() < 0.001);
1016    }
1017
1018    #[test]
1019    fn ensemble_classify_multiple() {
1020        // Two classifiers voting on same features
1021        let mut ensemble = EnsembleClassifier::new("dual");
1022        ensemble.add_classifier(RuleBasedClassifier::new(), 1.0);
1023        ensemble.add_classifier(RuleBasedClassifier::new(), 1.0);
1024
1025        let features = OwnershipFeaturesBuilder::default()
1026            .const_qualified(true)
1027            .pointer_depth(1)
1028            .build();
1029
1030        let pred = ensemble.classify(&features);
1031        assert!(matches!(pred.prediction, InferredOwnership::Borrowed));
1032        assert!(pred.confidence > 0.0);
1033    }
1034
1035    #[test]
1036    fn ensemble_zero_weight_classifiers() {
1037        let mut ensemble = EnsembleClassifier::new("zero_weight");
1038        ensemble.add_classifier(RuleBasedClassifier::new(), 0.0);
1039
1040        let features = OwnershipFeaturesBuilder::default()
1041            .allocation_site(AllocationKind::Malloc)
1042            .deallocation_count(1)
1043            .build();
1044
1045        let pred = ensemble.classify(&features);
1046        // Zero total weight → confidence 0.0
1047        assert!((pred.confidence - 0.0).abs() < 0.001);
1048    }
1049
1050    #[test]
1051    fn full_training_pipeline() {
1052        // Create synthetic dataset
1053        let config = crate::training_data::SyntheticConfig {
1054            samples_per_pattern: 20,
1055            ..Default::default()
1056        };
1057        let generator = crate::training_data::SyntheticDataGenerator::new(config);
1058        let dataset = generator.generate_full_dataset();
1059
1060        // Train classifier
1061        let trainer = ClassifierTrainer::with_defaults();
1062        let (classifier, result) = trainer.train_rule_based(&dataset);
1063
1064        assert!(result.success);
1065
1066        // Evaluate
1067        let evaluator = ClassifierEvaluator::from_dataset(&dataset);
1068        let metrics = evaluator.evaluate(&classifier);
1069
1070        // Should have high accuracy on synthetic data
1071        assert!(metrics.accuracy() > 0.8);
1072    }
1073
1074    // ========================================================================
1075    // Additional edge case coverage
1076    // ========================================================================
1077
1078    #[test]
1079    fn evaluator_with_mismatched_predictions() {
1080        // Force mismatches to exercise false positive/negative paths
1081        let owned_features = OwnershipFeatures {
1082            allocation_site: AllocationKind::Malloc,
1083            deallocation_count: 1,
1084            ..Default::default()
1085        };
1086        // Label says Borrowed, but features look like Owned → mismatch
1087        let samples = vec![TrainingSample::new(
1088            owned_features,
1089            InferredOwnership::Borrowed,
1090            "test.c",
1091            1,
1092        )];
1093
1094        let evaluator = ClassifierEvaluator::new(samples);
1095        let classifier = RuleBasedClassifier::new();
1096        let metrics = evaluator.evaluate(&classifier);
1097
1098        // Should have 0 correct (features match Owned, label says Borrowed)
1099        assert_eq!(metrics.total_samples, 1);
1100        // The prediction (Owned) != label (Borrowed), so false positive + false negative
1101        assert!(!metrics.false_positives.is_empty() || !metrics.false_negatives.is_empty());
1102    }
1103
1104    #[test]
1105    fn metrics_precision_both_tp_fp_zero() {
1106        let metrics = EvaluationMetrics::default();
1107        // No entries → both tp and fp are 0 → division guard returns 0.0
1108        assert!((metrics.precision("Owned") - 0.0).abs() < f64::EPSILON);
1109    }
1110
1111    #[test]
1112    fn metrics_recall_both_tp_fn_zero() {
1113        let metrics = EvaluationMetrics::default();
1114        assert!((metrics.recall("Owned") - 0.0).abs() < f64::EPSILON);
1115    }
1116
1117    #[test]
1118    fn metrics_f1_both_precision_recall_zero() {
1119        let metrics = EvaluationMetrics::default();
1120        assert!((metrics.f1_score("Owned") - 0.0).abs() < f64::EPSILON);
1121    }
1122
1123    #[test]
1124    fn metrics_macro_f1_empty_classes() {
1125        let metrics = EvaluationMetrics::default();
1126        // No classes → returns 0.0
1127        assert!((metrics.macro_f1() - 0.0).abs() < f64::EPSILON);
1128    }
1129
1130    #[test]
1131    fn metrics_macro_f1_with_classes() {
1132        let mut metrics = EvaluationMetrics::default();
1133        metrics.true_positives.insert("Owned".to_string(), 10);
1134        metrics.false_positives.insert("Owned".to_string(), 2);
1135        metrics.false_negatives.insert("Owned".to_string(), 1);
1136        metrics.true_positives.insert("Borrowed".to_string(), 5);
1137        metrics.false_positives.insert("Borrowed".to_string(), 1);
1138        metrics.false_negatives.insert("Borrowed".to_string(), 3);
1139
1140        let f1 = metrics.macro_f1();
1141        assert!(f1 > 0.0 && f1 <= 1.0);
1142    }
1143
1144    #[test]
1145    fn prediction_with_multiple_alternatives() {
1146        let pred = ClassifierPrediction::new(InferredOwnership::Owned, 0.9)
1147            .with_alternative(InferredOwnership::Borrowed, 0.5)
1148            .with_alternative(InferredOwnership::Vec, 0.3);
1149        assert_eq!(pred.alternatives.len(), 2);
1150        assert_eq!(pred.alternatives[0].0, InferredOwnership::Borrowed);
1151        assert_eq!(pred.alternatives[1].0, InferredOwnership::Vec);
1152    }
1153
1154    #[test]
1155    fn evaluator_multiple_mismatches() {
1156        // Multiple samples with various mismatches
1157        let owned_features = OwnershipFeatures {
1158            allocation_site: AllocationKind::Malloc,
1159            deallocation_count: 1,
1160            ..Default::default()
1161        };
1162        let borrow_features = OwnershipFeatures {
1163            is_const: true,
1164            ..Default::default()
1165        };
1166
1167        let samples = vec![
1168            // Correctly classified (Owned features, Owned label)
1169            TrainingSample::new(owned_features.clone(), InferredOwnership::Owned, "a.c", 1),
1170            // Misclassified (Borrowed features, but labeled Owned)
1171            TrainingSample::new(borrow_features, InferredOwnership::Owned, "b.c", 2),
1172        ];
1173
1174        let evaluator = ClassifierEvaluator::new(samples);
1175        let classifier = RuleBasedClassifier::new();
1176        let metrics = evaluator.evaluate(&classifier);
1177
1178        assert_eq!(metrics.total_samples, 2);
1179        // At least one should be correct
1180        assert!(metrics.correct >= 1);
1181        assert!(metrics.accuracy() >= 0.5);
1182    }
1183}