Skip to main content

verificar/ml/
quality_gate.rs

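//! Quality Gate - ML-based pre-oracle filter
//!
//! Reduces oracle calls by filtering out low-value candidates before they
//! reach the oracle. Each candidate is reduced to a handful of heuristic code
//! features, which a weighted linear model (built-in default weights, or
//! custom ones via `QualityGate::with_weights`) combines into a score in
//! `[0, 1]`; candidates scoring at or above the gate's threshold pass. The
//! design calls for a RandomForest classifier trained on historical data;
//! the linear model is what this module currently implements.
//!
//! # Architecture
//!
//! ```text
//! Code → FeatureExtractor → QualityGate → Oracle (if passes)
//!                              ↓
//!                         Filtered (if low quality)
//! ```
//!
//! # Reference
//! VER-050: Quality Gate - ML-based pre-oracle filter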
16
17use crate::generator::GeneratedCode;
18use serde::{Deserialize, Serialize};
19
20/// Features extracted from code for quality prediction
21#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
22pub struct CodeQualityFeatures {
23    /// Lines of code
24    pub loc: u32,
25    /// AST depth
26    pub ast_depth: u32,
27    /// Number of unique identifiers
28    pub unique_identifiers: u32,
29    /// Cyclomatic complexity estimate
30    pub complexity: u32,
31    /// Has control flow (if/for/while)
32    pub has_control_flow: bool,
33    /// Has function definitions
34    pub has_functions: bool,
35    /// Has error handling (try/except)
36    pub has_error_handling: bool,
37    /// Ratio of comments to code
38    pub comment_ratio: f32,
39}
40
41impl CodeQualityFeatures {
42    /// Convert to feature array for ML model
43    #[must_use]
44    pub fn to_array(&self) -> [f32; 8] {
45        [
46            self.loc as f32,
47            self.ast_depth as f32,
48            self.unique_identifiers as f32,
49            self.complexity as f32,
50            if self.has_control_flow { 1.0 } else { 0.0 },
51            if self.has_functions { 1.0 } else { 0.0 },
52            if self.has_error_handling { 1.0 } else { 0.0 },
53            self.comment_ratio,
54        ]
55    }
56
57    /// Create from feature array
58    #[must_use]
59    #[allow(clippy::cast_sign_loss)]
60    pub fn from_array(arr: [f32; 8]) -> Self {
61        Self {
62            loc: arr[0].max(0.0) as u32,
63            ast_depth: arr[1].max(0.0) as u32,
64            unique_identifiers: arr[2].max(0.0) as u32,
65            complexity: arr[3].max(0.0) as u32,
66            has_control_flow: arr[4] > 0.5,
67            has_functions: arr[5] > 0.5,
68            has_error_handling: arr[6] > 0.5,
69            comment_ratio: arr[7],
70        }
71    }
72}
73
74/// Feature extractor for code quality
75#[derive(Debug, Default)]
76pub struct FeatureExtractor;
77
78impl FeatureExtractor {
79    /// Create new feature extractor
80    #[must_use]
81    pub fn new() -> Self {
82        Self
83    }
84
85    /// Extract features from code string
86    #[must_use]
87    pub fn extract(&self, code: &str) -> CodeQualityFeatures {
88        let lines: Vec<&str> = code.lines().collect();
89        let loc = lines.len() as u32;
90
91        // Count unique identifiers (simple heuristic)
92        let unique_identifiers = self.count_identifiers(code);
93
94        // Estimate complexity from control flow keywords
95        let complexity = self.estimate_complexity(code);
96
97        // Check for various code patterns
98        let has_control_flow = code.contains("if ")
99            || code.contains("for ")
100            || code.contains("while ")
101            || code.contains("match ");
102
103        let has_functions =
104            code.contains("def ") || code.contains("fn ") || code.contains("function ");
105
106        let has_error_handling =
107            code.contains("try:") || code.contains("except") || code.contains("catch");
108
109        // Comment ratio
110        let comment_lines = lines
111            .iter()
112            .filter(|l| l.trim().starts_with('#') || l.trim().starts_with("//"))
113            .count();
114        let comment_ratio = if loc > 0 {
115            comment_lines as f32 / loc as f32
116        } else {
117            0.0
118        };
119
120        CodeQualityFeatures {
121            loc,
122            ast_depth: 0, // Will be set from GeneratedCode if available
123            unique_identifiers,
124            complexity,
125            has_control_flow,
126            has_functions,
127            has_error_handling,
128            comment_ratio,
129        }
130    }
131
132    /// Extract features from GeneratedCode (includes AST depth)
133    #[must_use]
134    pub fn extract_from_generated(&self, generated: &GeneratedCode) -> CodeQualityFeatures {
135        let mut features = self.extract(&generated.code);
136        features.ast_depth = generated.ast_depth as u32;
137        features
138    }
139
140    fn count_identifiers(&self, code: &str) -> u32 {
141        use std::collections::HashSet;
142
143        let mut identifiers = HashSet::new();
144        let mut current = String::new();
145
146        for ch in code.chars() {
147            if ch.is_alphanumeric() || ch == '_' {
148                current.push(ch);
149            } else {
150                if !current.is_empty()
151                    && current
152                        .chars()
153                        .next()
154                        .is_some_and(|c| c.is_alphabetic() || c == '_')
155                {
156                    identifiers.insert(current.clone());
157                }
158                current.clear();
159            }
160        }
161
162        if !current.is_empty()
163            && current
164                .chars()
165                .next()
166                .is_some_and(|c| c.is_alphabetic() || c == '_')
167        {
168            identifiers.insert(current);
169        }
170
171        identifiers.len() as u32
172    }
173
174    fn estimate_complexity(&self, code: &str) -> u32 {
175        let mut complexity = 1u32; // Base complexity
176
177        // Count decision points
178        let keywords = ["if ", "elif ", "else:", "for ", "while ", "case ", "match "];
179        for kw in keywords {
180            complexity += code.matches(kw).count() as u32;
181        }
182
183        // Count logical operators
184        complexity += code.matches(" and ").count() as u32;
185        complexity += code.matches(" or ").count() as u32;
186        complexity += code.matches("&&").count() as u32;
187        complexity += code.matches("||").count() as u32;
188
189        complexity
190    }
191}
192
/// Outcome of running a candidate through the [`QualityGate`].
///
/// `Eq` and `Hash` are derived so verdicts can be used as set/map keys
/// (e.g. when tallying outcomes).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QualityVerdict {
    /// Score reached the gate's threshold; forward the code to the oracle.
    Pass,
    /// Score fell below the threshold; the code is dropped as low value.
    Filtered,
}
201
202/// Quality Gate classifier
203#[derive(Debug)]
204pub struct QualityGate {
205    /// Threshold for passing (0.0 to 1.0)
206    threshold: f32,
207    /// Feature weights (simple linear model)
208    weights: [f32; 8],
209    /// Bias term
210    bias: f32,
211    /// Statistics
212    stats: QualityGateStats,
213}
214
/// Running tallies kept by a quality gate.
///
/// `total` is expected to equal `passed + filtered`; both rates below are
/// computed relative to `total`.
#[derive(Debug, Clone, Default)]
pub struct QualityGateStats {
    /// How many candidates have been scored
    pub total: usize,
    /// How many of them reached the threshold
    pub passed: usize,
    /// How many of them were rejected
    pub filtered: usize,
}

impl QualityGateStats {
    /// Share of scored candidates that were filtered out (0.0 to 1.0 when
    /// the counts are consistent); `0.0` before anything has been scored.
    #[must_use]
    pub fn filter_rate(&self) -> f32 {
        self.share_of_total(self.filtered)
    }

    /// Share of scored candidates that passed (0.0 to 1.0 when the counts
    /// are consistent); `0.0` before anything has been scored.
    #[must_use]
    pub fn pass_rate(&self) -> f32 {
        self.share_of_total(self.passed)
    }

    /// `count / total` as `f32`, mapping an empty tally to `0.0` instead of NaN.
    fn share_of_total(&self, count: usize) -> f32 {
        match self.total {
            0 => 0.0,
            total => count as f32 / total as f32,
        }
    }
}
247
248impl Default for QualityGate {
249    fn default() -> Self {
250        Self::new(0.7)
251    }
252}
253
254impl QualityGate {
255    /// Create quality gate with threshold
256    #[must_use]
257    pub fn new(threshold: f32) -> Self {
258        // Default weights favoring complexity and control flow
259        let weights = [
260            0.05,  // loc: small positive
261            0.15,  // ast_depth: medium positive
262            0.10,  // unique_identifiers: positive
263            0.20,  // complexity: strong positive
264            0.25,  // has_control_flow: strong positive
265            0.15,  // has_functions: medium positive
266            0.10,  // has_error_handling: positive
267            -0.05, // comment_ratio: slightly negative (too many comments = template)
268        ];
269
270        Self {
271            threshold,
272            weights,
273            bias: 0.3, // Base score
274            stats: QualityGateStats::default(),
275        }
276    }
277
278    /// Create with custom weights
279    #[must_use]
280    pub fn with_weights(threshold: f32, weights: [f32; 8], bias: f32) -> Self {
281        Self {
282            threshold,
283            weights,
284            bias,
285            stats: QualityGateStats::default(),
286        }
287    }
288
289    /// Evaluate code and return verdict
290    pub fn evaluate(&mut self, features: &CodeQualityFeatures) -> QualityVerdict {
291        let score = self.score(features);
292        self.stats.total += 1;
293
294        if score >= self.threshold {
295            self.stats.passed += 1;
296            QualityVerdict::Pass
297        } else {
298            self.stats.filtered += 1;
299            QualityVerdict::Filtered
300        }
301    }
302
303    /// Get quality score (0.0 to 1.0)
304    #[must_use]
305    pub fn score(&self, features: &CodeQualityFeatures) -> f32 {
306        let arr = features.to_array();
307        let mut score = self.bias;
308
309        for (i, &val) in arr.iter().enumerate() {
310            // Normalize features to [0, 1] range approximately
311            let normalized = match i {
312                0 => (val / 100.0).min(1.0), // loc: normalize by 100
313                1 => (val / 10.0).min(1.0),  // ast_depth: normalize by 10
314                2 => (val / 50.0).min(1.0),  // identifiers: normalize by 50
315                3 => (val / 20.0).min(1.0),  // complexity: normalize by 20
316                4..=6 => val,                // booleans already 0/1
317                7 => val,                    // ratio already 0-1
318                _ => val,
319            };
320            score += self.weights[i] * normalized;
321        }
322
323        score.clamp(0.0, 1.0)
324    }
325
326    /// Get current statistics
327    #[must_use]
328    pub fn stats(&self) -> &QualityGateStats {
329        &self.stats
330    }
331
332    /// Reset statistics
333    pub fn reset_stats(&mut self) {
334        self.stats = QualityGateStats::default();
335    }
336
337    /// Get threshold
338    #[must_use]
339    pub fn threshold(&self) -> f32 {
340        self.threshold
341    }
342
343    /// Set threshold
344    pub fn set_threshold(&mut self, threshold: f32) {
345        self.threshold = threshold;
346    }
347
348    /// Batch evaluate and return passing codes
349    pub fn filter_batch<'a>(&mut self, codes: &'a [GeneratedCode]) -> Vec<&'a GeneratedCode> {
350        let extractor = FeatureExtractor::new();
351
352        codes
353            .iter()
354            .filter(|code| {
355                let features = extractor.extract_from_generated(code);
356                self.evaluate(&features) == QualityVerdict::Pass
357            })
358            .collect()
359    }
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Language;

    fn sample_code_simple() -> &'static str {
        "x = 1"
    }

    fn sample_code_complex() -> &'static str {
        r#"def factorial(n):
    if n <= 1:
        return 1
    else:
        return n * factorial(n - 1)

def main():
    for i in range(10):
        print(factorial(i))
"#
    }

    fn sample_generated(code: &str, depth: usize) -> GeneratedCode {
        GeneratedCode {
            code: code.to_string(),
            language: Language::Python,
            ast_depth: depth,
            features: vec![],
        }
    }

    // ========== RED PHASE: Feature Extraction Tests ==========

    #[test]
    fn test_feature_extractor_simple() {
        let extractor = FeatureExtractor::new();
        let features = extractor.extract(sample_code_simple());

        assert_eq!(features.loc, 1);
        assert!(!features.has_control_flow);
        assert!(!features.has_functions);
    }

    #[test]
    fn test_feature_extractor_complex() {
        let extractor = FeatureExtractor::new();
        let features = extractor.extract(sample_code_complex());

        assert!(features.loc > 5);
        assert!(features.has_control_flow);
        assert!(features.has_functions);
        assert!(features.complexity > 1);
    }

    #[test]
    fn test_feature_extractor_identifiers() {
        let extractor = FeatureExtractor::new();
        let features = extractor.extract("x = 1\ny = 2\nz = x + y");

        assert!(features.unique_identifiers >= 3);
    }

    #[test]
    fn test_feature_extractor_complexity() {
        let extractor = FeatureExtractor::new();

        let simple = extractor.extract("x = 1");
        let complex = extractor.extract("if x:\n    if y:\n        pass");

        assert_eq!(simple.complexity, 1); // base only
        assert_eq!(complex.complexity, 3); // base + two `if`s
        assert!(complex.complexity > simple.complexity);
    }

    #[test]
    fn test_feature_extractor_comment_ratio() {
        let extractor = FeatureExtractor::new();

        let no_comments = extractor.extract("x = 1\ny = 2");
        let all_comments = extractor.extract("# comment\n# another");

        assert!(no_comments.comment_ratio < 0.1);
        assert!(all_comments.comment_ratio > 0.9);
    }

    #[test]
    fn test_feature_extractor_error_handling() {
        let extractor = FeatureExtractor::new();

        let with_try = extractor.extract("try:\n    x = 1\nexcept:\n    pass");
        let without_try = extractor.extract("x = 1");

        assert!(with_try.has_error_handling);
        assert!(!without_try.has_error_handling);
    }

    #[test]
    fn test_feature_extractor_from_generated() {
        let extractor = FeatureExtractor::new();
        let generated = sample_generated("x = 1", 3);

        let features = extractor.extract_from_generated(&generated);

        assert_eq!(features.ast_depth, 3);
    }

    // ========== RED PHASE: Feature Array Tests ==========

    #[test]
    fn test_features_to_array() {
        let features = CodeQualityFeatures {
            loc: 10,
            ast_depth: 3,
            unique_identifiers: 5,
            complexity: 4,
            has_control_flow: true,
            has_functions: false,
            has_error_handling: true,
            comment_ratio: 0.2,
        };

        let arr = features.to_array();

        assert_eq!(arr[0], 10.0);
        assert_eq!(arr[1], 3.0);
        assert_eq!(arr[4], 1.0); // has_control_flow
        assert_eq!(arr[5], 0.0); // has_functions
    }

    #[test]
    fn test_features_from_array() {
        let arr = [10.0, 3.0, 5.0, 4.0, 1.0, 0.0, 1.0, 0.2];
        let features = CodeQualityFeatures::from_array(arr);

        assert_eq!(features.loc, 10);
        assert!(features.has_control_flow);
        assert!(!features.has_functions);
    }

    #[test]
    fn test_features_roundtrip() {
        let original = CodeQualityFeatures {
            loc: 15,
            ast_depth: 4,
            unique_identifiers: 8,
            complexity: 6,
            has_control_flow: true,
            has_functions: true,
            has_error_handling: false,
            comment_ratio: 0.1,
        };

        let arr = original.to_array();
        let restored = CodeQualityFeatures::from_array(arr);

        // Every field survives the roundtrip, not just a sample of them.
        assert_eq!(original, restored);
    }

    // ========== RED PHASE: Quality Gate Tests ==========

    #[test]
    fn test_quality_gate_default() {
        let gate = QualityGate::default();
        assert!((gate.threshold() - 0.7).abs() < f32::EPSILON);
    }

    #[test]
    fn test_quality_gate_simple_code_filtered() {
        let mut gate = QualityGate::new(0.5);
        let extractor = FeatureExtractor::new();

        let features = extractor.extract(sample_code_simple());
        let verdict = gate.evaluate(&features);

        // Simple code should be filtered
        assert_eq!(verdict, QualityVerdict::Filtered);
    }

    #[test]
    fn test_quality_gate_complex_code_passes() {
        let mut gate = QualityGate::new(0.5);
        let extractor = FeatureExtractor::new();

        let features = extractor.extract(sample_code_complex());
        let verdict = gate.evaluate(&features);

        // Complex code should pass
        assert_eq!(verdict, QualityVerdict::Pass);
    }

    #[test]
    fn test_quality_gate_score_bounded() {
        let gate = QualityGate::new(0.5);
        let extractor = FeatureExtractor::new();

        for code in &[sample_code_simple(), sample_code_complex(), ""] {
            let features = extractor.extract(code);
            let score = gate.score(&features);

            assert!(score >= 0.0);
            assert!(score <= 1.0);
        }
    }

    #[test]
    fn test_quality_gate_stats() {
        let mut gate = QualityGate::new(0.5);
        let extractor = FeatureExtractor::new();

        let simple = extractor.extract(sample_code_simple());
        let complex = extractor.extract(sample_code_complex());

        gate.evaluate(&simple);
        gate.evaluate(&complex);

        let stats = gate.stats();
        assert_eq!(stats.total, 2);
        assert_eq!(stats.passed + stats.filtered, 2);
        // The simple sample is filtered and the complex one passes.
        assert_eq!(stats.passed, 1);
        assert_eq!(stats.filtered, 1);
    }

    #[test]
    fn test_quality_gate_stats_rates() {
        let mut gate = QualityGate::new(0.5);
        let extractor = FeatureExtractor::new();

        // Add some evaluations
        for _ in 0..10 {
            let features = extractor.extract(sample_code_simple());
            gate.evaluate(&features);
        }

        let stats = gate.stats();
        let total_rate = stats.pass_rate() + stats.filter_rate();

        assert!((total_rate - 1.0).abs() < 0.01);
        assert_eq!(stats.filtered, 10);
        assert!((stats.filter_rate() - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_quality_gate_reset_stats() {
        let mut gate = QualityGate::new(0.5);
        let extractor = FeatureExtractor::new();

        let features = extractor.extract(sample_code_simple());
        gate.evaluate(&features);

        assert!(gate.stats().total > 0);

        gate.reset_stats();

        assert_eq!(gate.stats().total, 0);
    }

    #[test]
    fn test_quality_gate_threshold_adjustment() {
        let mut gate = QualityGate::new(0.5);

        gate.set_threshold(0.8);

        assert!((gate.threshold() - 0.8).abs() < f32::EPSILON);
    }

    #[test]
    fn test_quality_gate_custom_weights() {
        let weights = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1];
        let gate = QualityGate::with_weights(0.5, weights, 0.2);

        assert!((gate.threshold() - 0.5).abs() < f32::EPSILON);

        // All-zero features score exactly the bias.
        let zero = CodeQualityFeatures::default();
        assert!((gate.score(&zero) - 0.2).abs() < 1e-6);

        // Fully saturated features score bias + sum of weights, clamped to 1.
        let saturated = CodeQualityFeatures {
            loc: 100,
            ast_depth: 10,
            unique_identifiers: 50,
            complexity: 20,
            has_control_flow: true,
            has_functions: true,
            has_error_handling: true,
            comment_ratio: 1.0,
        };
        assert!((gate.score(&saturated) - 1.0).abs() < 1e-6);
    }

    // ========== RED PHASE: Batch Filter Tests ==========

    #[test]
    fn test_filter_batch() {
        let mut gate = QualityGate::new(0.4);

        let codes = vec![
            sample_generated(sample_code_simple(), 1),
            sample_generated(sample_code_complex(), 4),
        ];

        let passing = gate.filter_batch(&codes);

        // Only the complex code should pass
        assert_eq!(passing.len(), 1);
        assert!(passing.iter().any(|c| c.code.contains("factorial")));
        assert_eq!(gate.stats().total, 2);
        assert_eq!(gate.stats().passed, 1);
        assert_eq!(gate.stats().filtered, 1);
    }

    #[test]
    fn test_filter_batch_empty() {
        let mut gate = QualityGate::new(0.5);
        let codes: Vec<GeneratedCode> = vec![];

        let passing = gate.filter_batch(&codes);

        assert!(passing.is_empty());
    }

    #[test]
    fn test_filter_batch_all_pass() {
        let mut gate = QualityGate::new(0.0); // Accept everything

        let codes = vec![
            sample_generated(sample_code_simple(), 1),
            sample_generated(sample_code_complex(), 4),
        ];

        let passing = gate.filter_batch(&codes);

        assert_eq!(passing.len(), 2);
    }

    #[test]
    fn test_filter_batch_none_pass() {
        let mut gate = QualityGate::new(1.0); // Reject everything

        let codes = vec![
            sample_generated(sample_code_simple(), 1),
            sample_generated(sample_code_simple(), 2),
        ];

        let passing = gate.filter_batch(&codes);

        assert!(passing.is_empty());
    }

    // ========== RED PHASE: Edge Cases ==========

    #[test]
    fn test_empty_code() {
        let extractor = FeatureExtractor::new();
        let features = extractor.extract("");

        assert_eq!(features.loc, 0);
        assert_eq!(features.complexity, 1); // Base complexity
    }

    #[test]
    fn test_whitespace_only() {
        let extractor = FeatureExtractor::new();
        let features = extractor.extract("   \n\t\n   ");

        assert_eq!(features.loc, 3);
        assert!(!features.has_control_flow);
    }

    #[test]
    fn test_quality_verdict_equality() {
        assert_eq!(QualityVerdict::Pass, QualityVerdict::Pass);
        assert_ne!(QualityVerdict::Pass, QualityVerdict::Filtered);
    }

    #[test]
    fn test_quality_gate_stats_empty() {
        let stats = QualityGateStats::default();

        assert_eq!(stats.filter_rate(), 0.0);
        assert_eq!(stats.pass_rate(), 0.0);
    }

    #[test]
    fn test_features_default() {
        let features = CodeQualityFeatures::default();

        assert_eq!(features.loc, 0);
        assert!(!features.has_control_flow);
    }

    #[test]
    fn test_features_debug() {
        let features = CodeQualityFeatures::default();
        let debug = format!("{features:?}");
        assert!(debug.contains("CodeQualityFeatures"));
    }

    #[test]
    fn test_feature_extractor_debug() {
        let extractor = FeatureExtractor::new();
        let debug = format!("{extractor:?}");
        assert!(debug.contains("FeatureExtractor"));
    }

    #[test]
    fn test_quality_gate_debug() {
        let gate = QualityGate::default();
        let debug = format!("{gate:?}");
        assert!(debug.contains("QualityGate"));
    }
}
749
/// Property-based tests
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        /// Score is always bounded [0, 1], for any combination of all eight features
        #[test]
        fn prop_score_bounded(
            loc in any::<u32>(),
            depth in any::<u32>(),
            ids in any::<u32>(),
            complexity in any::<u32>(),
            has_control_flow in any::<bool>(),
            has_functions in any::<bool>(),
            has_error_handling in any::<bool>(),
            comment_ratio in 0.0f32..1.0,
        ) {
            let features = CodeQualityFeatures {
                loc,
                ast_depth: depth,
                unique_identifiers: ids,
                complexity,
                has_control_flow,
                has_functions,
                has_error_handling,
                comment_ratio,
            };

            let gate = QualityGate::default();
            let score = gate.score(&features);

            prop_assert!(score >= 0.0);
            prop_assert!(score <= 1.0);
        }

        /// `from_array` inverts `to_array` for counts exactly representable in f32
        #[test]
        fn prop_array_roundtrip(
            loc in 0u32..(1 << 24),
            depth in 0u32..(1 << 24),
            ids in 0u32..(1 << 24),
            complexity in 0u32..(1 << 24),
            has_control_flow in any::<bool>(),
            has_functions in any::<bool>(),
            has_error_handling in any::<bool>(),
            comment_ratio in 0.0f32..1.0,
        ) {
            let original = CodeQualityFeatures {
                loc,
                ast_depth: depth,
                unique_identifiers: ids,
                complexity,
                has_control_flow,
                has_functions,
                has_error_handling,
                comment_ratio,
            };

            let restored = CodeQualityFeatures::from_array(original.to_array());

            prop_assert_eq!(original, restored);
        }

        /// Higher complexity = higher score
        #[test]
        fn prop_complexity_increases_score(base_complexity in 1u32..10) {
            let gate = QualityGate::default();

            let low = CodeQualityFeatures {
                complexity: base_complexity,
                ..Default::default()
            };

            let high = CodeQualityFeatures {
                complexity: base_complexity + 10,
                ..Default::default()
            };

            let low_score = gate.score(&low);
            let high_score = gate.score(&high);

            prop_assert!(high_score >= low_score);
        }

        /// Control flow increases score
        #[test]
        fn prop_control_flow_increases_score(loc in 1u32..100) {
            let gate = QualityGate::default();

            let without = CodeQualityFeatures {
                loc,
                has_control_flow: false,
                ..Default::default()
            };

            let with = CodeQualityFeatures {
                loc,
                has_control_flow: true,
                ..Default::default()
            };

            let without_score = gate.score(&without);
            let with_score = gate.score(&with);

            prop_assert!(with_score >= without_score);
        }

        /// Pass rate + filter rate = 1.0
        #[test]
        fn prop_rates_sum_to_one(passed in 0usize..100, filtered in 0usize..100) {
            let stats = QualityGateStats {
                total: passed + filtered,
                passed,
                filtered,
            };

            if stats.total > 0 {
                let sum = stats.pass_rate() + stats.filter_rate();
                prop_assert!((sum - 1.0).abs() < 0.01);
            }
        }
    }
}