1use crate::generator::GeneratedCode;
18use serde::{Deserialize, Serialize};
19
20#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
22pub struct CodeQualityFeatures {
23 pub loc: u32,
25 pub ast_depth: u32,
27 pub unique_identifiers: u32,
29 pub complexity: u32,
31 pub has_control_flow: bool,
33 pub has_functions: bool,
35 pub has_error_handling: bool,
37 pub comment_ratio: f32,
39}
40
41impl CodeQualityFeatures {
42 #[must_use]
44 pub fn to_array(&self) -> [f32; 8] {
45 [
46 self.loc as f32,
47 self.ast_depth as f32,
48 self.unique_identifiers as f32,
49 self.complexity as f32,
50 if self.has_control_flow { 1.0 } else { 0.0 },
51 if self.has_functions { 1.0 } else { 0.0 },
52 if self.has_error_handling { 1.0 } else { 0.0 },
53 self.comment_ratio,
54 ]
55 }
56
57 #[must_use]
59 #[allow(clippy::cast_sign_loss)]
60 pub fn from_array(arr: [f32; 8]) -> Self {
61 Self {
62 loc: arr[0].max(0.0) as u32,
63 ast_depth: arr[1].max(0.0) as u32,
64 unique_identifiers: arr[2].max(0.0) as u32,
65 complexity: arr[3].max(0.0) as u32,
66 has_control_flow: arr[4] > 0.5,
67 has_functions: arr[5] > 0.5,
68 has_error_handling: arr[6] > 0.5,
69 comment_ratio: arr[7],
70 }
71 }
72}
73
74#[derive(Debug, Default)]
76pub struct FeatureExtractor;
77
78impl FeatureExtractor {
79 #[must_use]
81 pub fn new() -> Self {
82 Self
83 }
84
85 #[must_use]
87 pub fn extract(&self, code: &str) -> CodeQualityFeatures {
88 let lines: Vec<&str> = code.lines().collect();
89 let loc = lines.len() as u32;
90
91 let unique_identifiers = self.count_identifiers(code);
93
94 let complexity = self.estimate_complexity(code);
96
97 let has_control_flow = code.contains("if ")
99 || code.contains("for ")
100 || code.contains("while ")
101 || code.contains("match ");
102
103 let has_functions =
104 code.contains("def ") || code.contains("fn ") || code.contains("function ");
105
106 let has_error_handling =
107 code.contains("try:") || code.contains("except") || code.contains("catch");
108
109 let comment_lines = lines
111 .iter()
112 .filter(|l| l.trim().starts_with('#') || l.trim().starts_with("//"))
113 .count();
114 let comment_ratio = if loc > 0 {
115 comment_lines as f32 / loc as f32
116 } else {
117 0.0
118 };
119
120 CodeQualityFeatures {
121 loc,
122 ast_depth: 0, unique_identifiers,
124 complexity,
125 has_control_flow,
126 has_functions,
127 has_error_handling,
128 comment_ratio,
129 }
130 }
131
132 #[must_use]
134 pub fn extract_from_generated(&self, generated: &GeneratedCode) -> CodeQualityFeatures {
135 let mut features = self.extract(&generated.code);
136 features.ast_depth = generated.ast_depth as u32;
137 features
138 }
139
140 fn count_identifiers(&self, code: &str) -> u32 {
141 use std::collections::HashSet;
142
143 let mut identifiers = HashSet::new();
144 let mut current = String::new();
145
146 for ch in code.chars() {
147 if ch.is_alphanumeric() || ch == '_' {
148 current.push(ch);
149 } else {
150 if !current.is_empty()
151 && current
152 .chars()
153 .next()
154 .is_some_and(|c| c.is_alphabetic() || c == '_')
155 {
156 identifiers.insert(current.clone());
157 }
158 current.clear();
159 }
160 }
161
162 if !current.is_empty()
163 && current
164 .chars()
165 .next()
166 .is_some_and(|c| c.is_alphabetic() || c == '_')
167 {
168 identifiers.insert(current);
169 }
170
171 identifiers.len() as u32
172 }
173
174 fn estimate_complexity(&self, code: &str) -> u32 {
175 let mut complexity = 1u32; let keywords = ["if ", "elif ", "else:", "for ", "while ", "case ", "match "];
179 for kw in keywords {
180 complexity += code.matches(kw).count() as u32;
181 }
182
183 complexity += code.matches(" and ").count() as u32;
185 complexity += code.matches(" or ").count() as u32;
186 complexity += code.matches("&&").count() as u32;
187 complexity += code.matches("||").count() as u32;
188
189 complexity
190 }
191}
192
/// Outcome of running a sample through a [`QualityGate`].
///
/// The enum has no payload, so it derives `Eq` alongside `PartialEq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QualityVerdict {
    /// The score reached the gate's threshold.
    Pass,
    /// The score fell below the gate's threshold.
    Filtered,
}
201
202#[derive(Debug)]
204pub struct QualityGate {
205 threshold: f32,
207 weights: [f32; 8],
209 bias: f32,
211 stats: QualityGateStats,
213}
214
/// Counters describing how many samples a gate has evaluated.
#[derive(Debug, Clone, Default)]
pub struct QualityGateStats {
    /// Number of samples evaluated.
    pub total: usize,
    /// Number of samples that passed.
    pub passed: usize,
    /// Number of samples that were filtered out.
    pub filtered: usize,
}

impl QualityGateStats {
    /// Fraction of evaluated samples that were filtered, or `0.0` when nothing
    /// has been evaluated.
    #[must_use]
    pub fn filter_rate(&self) -> f32 {
        self.share_of_total(self.filtered)
    }

    /// Fraction of evaluated samples that passed, or `0.0` when nothing has
    /// been evaluated.
    #[must_use]
    pub fn pass_rate(&self) -> f32 {
        self.share_of_total(self.passed)
    }

    /// Divides `count` by `total`, returning `0.0` instead of dividing by zero.
    fn share_of_total(&self, count: usize) -> f32 {
        match self.total {
            0 => 0.0,
            total => count as f32 / total as f32,
        }
    }
}
247
248impl Default for QualityGate {
249 fn default() -> Self {
250 Self::new(0.7)
251 }
252}
253
254impl QualityGate {
255 #[must_use]
257 pub fn new(threshold: f32) -> Self {
258 let weights = [
260 0.05, 0.15, 0.10, 0.20, 0.25, 0.15, 0.10, -0.05, ];
269
270 Self {
271 threshold,
272 weights,
273 bias: 0.3, stats: QualityGateStats::default(),
275 }
276 }
277
278 #[must_use]
280 pub fn with_weights(threshold: f32, weights: [f32; 8], bias: f32) -> Self {
281 Self {
282 threshold,
283 weights,
284 bias,
285 stats: QualityGateStats::default(),
286 }
287 }
288
289 pub fn evaluate(&mut self, features: &CodeQualityFeatures) -> QualityVerdict {
291 let score = self.score(features);
292 self.stats.total += 1;
293
294 if score >= self.threshold {
295 self.stats.passed += 1;
296 QualityVerdict::Pass
297 } else {
298 self.stats.filtered += 1;
299 QualityVerdict::Filtered
300 }
301 }
302
303 #[must_use]
305 pub fn score(&self, features: &CodeQualityFeatures) -> f32 {
306 let arr = features.to_array();
307 let mut score = self.bias;
308
309 for (i, &val) in arr.iter().enumerate() {
310 let normalized = match i {
312 0 => (val / 100.0).min(1.0), 1 => (val / 10.0).min(1.0), 2 => (val / 50.0).min(1.0), 3 => (val / 20.0).min(1.0), 4..=6 => val, 7 => val, _ => val,
319 };
320 score += self.weights[i] * normalized;
321 }
322
323 score.clamp(0.0, 1.0)
324 }
325
326 #[must_use]
328 pub fn stats(&self) -> &QualityGateStats {
329 &self.stats
330 }
331
332 pub fn reset_stats(&mut self) {
334 self.stats = QualityGateStats::default();
335 }
336
337 #[must_use]
339 pub fn threshold(&self) -> f32 {
340 self.threshold
341 }
342
343 pub fn set_threshold(&mut self, threshold: f32) {
345 self.threshold = threshold;
346 }
347
348 pub fn filter_batch<'a>(&mut self, codes: &'a [GeneratedCode]) -> Vec<&'a GeneratedCode> {
350 let extractor = FeatureExtractor::new();
351
352 codes
353 .iter()
354 .filter(|code| {
355 let features = extractor.extract_from_generated(code);
356 self.evaluate(&features) == QualityVerdict::Pass
357 })
358 .collect()
359 }
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Language;

    /// A one-line snippet with no branches and no functions.
    fn sample_code_simple() -> &'static str {
        "x = 1"
    }

    /// A small Python program with two functions, a branch and a loop.
    fn sample_code_complex() -> &'static str {
        r#"def factorial(n):
    if n <= 1:
        return 1
    else:
        return n * factorial(n - 1)

def main():
    for i in range(10):
        print(factorial(i))
"#
    }

    /// Wraps `code` in a Python `GeneratedCode` with the given AST depth.
    fn sample_generated(code: &str, depth: usize) -> GeneratedCode {
        GeneratedCode {
            code: code.to_owned(),
            language: Language::Python,
            ast_depth: depth,
            features: Vec::new(),
        }
    }

    // ---- FeatureExtractor ----

    #[test]
    fn test_feature_extractor_simple() {
        let fx = FeatureExtractor::new();
        let extracted = fx.extract(sample_code_simple());

        assert_eq!(extracted.loc, 1);
        assert!(!extracted.has_control_flow);
        assert!(!extracted.has_functions);
    }

    #[test]
    fn test_feature_extractor_complex() {
        let fx = FeatureExtractor::new();
        let extracted = fx.extract(sample_code_complex());

        assert!(extracted.loc > 5);
        assert!(extracted.has_control_flow);
        assert!(extracted.has_functions);
        assert!(extracted.complexity > 1);
    }

    #[test]
    fn test_feature_extractor_identifiers() {
        let fx = FeatureExtractor::new();
        let extracted = fx.extract("x = 1\ny = 2\nz = x + y");

        assert!(extracted.unique_identifiers >= 3);
    }

    #[test]
    fn test_feature_extractor_complexity() {
        let fx = FeatureExtractor::new();

        let flat = fx.extract("x = 1");
        let nested = fx.extract("if x:\n if y:\n pass");

        assert!(nested.complexity > flat.complexity);
    }

    #[test]
    fn test_feature_extractor_comment_ratio() {
        let fx = FeatureExtractor::new();

        let code_only = fx.extract("x = 1\ny = 2");
        let comments_only = fx.extract("# comment\n# another");

        assert!(code_only.comment_ratio < 0.1);
        assert!(comments_only.comment_ratio > 0.9);
    }

    #[test]
    fn test_feature_extractor_error_handling() {
        let fx = FeatureExtractor::new();

        let guarded = fx.extract("try:\n x = 1\nexcept:\n pass");
        let unguarded = fx.extract("x = 1");

        assert!(guarded.has_error_handling);
        assert!(!unguarded.has_error_handling);
    }

    #[test]
    fn test_feature_extractor_from_generated() {
        let fx = FeatureExtractor::new();
        let sample = sample_generated("x = 1", 3);

        let extracted = fx.extract_from_generated(&sample);

        assert_eq!(extracted.ast_depth, 3);
    }

    // ---- CodeQualityFeatures <-> array ----

    #[test]
    fn test_features_to_array() {
        let source = CodeQualityFeatures {
            loc: 10,
            ast_depth: 3,
            unique_identifiers: 5,
            complexity: 4,
            has_control_flow: true,
            has_functions: false,
            has_error_handling: true,
            comment_ratio: 0.2,
        };

        let flat = source.to_array();

        assert_eq!(flat[0], 10.0);
        assert_eq!(flat[1], 3.0);
        // Set flags encode as 1.0, cleared flags as 0.0.
        assert_eq!(flat[4], 1.0);
        assert_eq!(flat[5], 0.0);
    }

    #[test]
    fn test_features_from_array() {
        let flat = [10.0, 3.0, 5.0, 4.0, 1.0, 0.0, 1.0, 0.2];
        let rebuilt = CodeQualityFeatures::from_array(flat);

        assert_eq!(rebuilt.loc, 10);
        assert!(rebuilt.has_control_flow);
        assert!(!rebuilt.has_functions);
    }

    #[test]
    fn test_features_roundtrip() {
        let before = CodeQualityFeatures {
            loc: 15,
            ast_depth: 4,
            unique_identifiers: 8,
            complexity: 6,
            has_control_flow: true,
            has_functions: true,
            has_error_handling: false,
            comment_ratio: 0.1,
        };

        let after = CodeQualityFeatures::from_array(before.to_array());

        assert_eq!(before.loc, after.loc);
        assert_eq!(before.has_control_flow, after.has_control_flow);
    }

    // ---- QualityGate ----

    #[test]
    fn test_quality_gate_default() {
        let gate = QualityGate::default();
        assert!((gate.threshold() - 0.7).abs() < f32::EPSILON);
    }

    #[test]
    fn test_quality_gate_simple_code_filtered() {
        let mut gate = QualityGate::new(0.5);
        let fx = FeatureExtractor::new();

        let verdict = gate.evaluate(&fx.extract(sample_code_simple()));

        assert_eq!(verdict, QualityVerdict::Filtered);
    }

    #[test]
    fn test_quality_gate_complex_code_passes() {
        let mut gate = QualityGate::new(0.5);
        let fx = FeatureExtractor::new();

        let verdict = gate.evaluate(&fx.extract(sample_code_complex()));

        assert_eq!(verdict, QualityVerdict::Pass);
    }

    #[test]
    fn test_quality_gate_score_bounded() {
        let gate = QualityGate::new(0.5);
        let fx = FeatureExtractor::new();

        let snippets = [sample_code_simple(), sample_code_complex(), ""];
        for snippet in snippets.iter() {
            let score = gate.score(&fx.extract(snippet));

            assert!(score >= 0.0);
            assert!(score <= 1.0);
        }
    }

    #[test]
    fn test_quality_gate_stats() {
        let mut gate = QualityGate::new(0.5);
        let fx = FeatureExtractor::new();

        let simple_features = fx.extract(sample_code_simple());
        let complex_features = fx.extract(sample_code_complex());

        gate.evaluate(&simple_features);
        gate.evaluate(&complex_features);

        let counters = gate.stats();
        assert_eq!(counters.total, 2);
        assert_eq!(counters.passed + counters.filtered, 2);
    }

    #[test]
    fn test_quality_gate_stats_rates() {
        let mut gate = QualityGate::new(0.5);
        let fx = FeatureExtractor::new();

        for _ in 0..10 {
            let extracted = fx.extract(sample_code_simple());
            gate.evaluate(&extracted);
        }

        let counters = gate.stats();
        let combined = counters.pass_rate() + counters.filter_rate();

        assert!((combined - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_quality_gate_reset_stats() {
        let mut gate = QualityGate::new(0.5);
        let fx = FeatureExtractor::new();

        gate.evaluate(&fx.extract(sample_code_simple()));
        assert!(gate.stats().total > 0);

        gate.reset_stats();
        assert_eq!(gate.stats().total, 0);
    }

    #[test]
    fn test_quality_gate_threshold_adjustment() {
        let mut gate = QualityGate::new(0.5);

        gate.set_threshold(0.8);

        assert!((gate.threshold() - 0.8).abs() < f32::EPSILON);
    }

    #[test]
    fn test_quality_gate_custom_weights() {
        let uniform = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1];
        let gate = QualityGate::with_weights(0.5, uniform, 0.2);

        assert!((gate.threshold() - 0.5).abs() < f32::EPSILON);
    }

    // ---- filter_batch ----

    #[test]
    fn test_filter_batch() {
        let mut gate = QualityGate::new(0.4);
        let batch = vec![
            sample_generated(sample_code_simple(), 1),
            sample_generated(sample_code_complex(), 4),
        ];

        let kept = gate.filter_batch(&batch);

        assert!(!kept.is_empty());
        // The complex sample must be among the survivors.
        assert!(kept.iter().any(|item| item.code.contains("factorial")));
    }

    #[test]
    fn test_filter_batch_empty() {
        let mut gate = QualityGate::new(0.5);
        let batch: Vec<GeneratedCode> = Vec::new();

        let kept = gate.filter_batch(&batch);

        assert!(kept.is_empty());
    }

    #[test]
    fn test_filter_batch_all_pass() {
        // Scores are never below 0.0, so a threshold of 0.0 lets everything through.
        let mut gate = QualityGate::new(0.0);
        let batch = vec![
            sample_generated(sample_code_simple(), 1),
            sample_generated(sample_code_complex(), 4),
        ];

        let kept = gate.filter_batch(&batch);

        assert_eq!(kept.len(), 2);
    }

    #[test]
    fn test_filter_batch_none_pass() {
        // The simple snippet scores well below the maximum threshold of 1.0.
        let mut gate = QualityGate::new(1.0);
        let batch = vec![
            sample_generated(sample_code_simple(), 1),
            sample_generated(sample_code_simple(), 2),
        ];

        let kept = gate.filter_batch(&batch);

        assert!(kept.is_empty());
    }

    // ---- edge cases ----

    #[test]
    fn test_empty_code() {
        let fx = FeatureExtractor::new();
        let extracted = fx.extract("");

        assert_eq!(extracted.loc, 0);
        // Base complexity of code with no decision points.
        assert_eq!(extracted.complexity, 1);
    }

    #[test]
    fn test_whitespace_only() {
        let fx = FeatureExtractor::new();
        let extracted = fx.extract(" \n\t\n ");

        assert_eq!(extracted.loc, 3);
        assert!(!extracted.has_control_flow);
    }

    #[test]
    fn test_quality_verdict_equality() {
        assert_eq!(QualityVerdict::Pass, QualityVerdict::Pass);
        assert_ne!(QualityVerdict::Pass, QualityVerdict::Filtered);
    }

    #[test]
    fn test_quality_gate_stats_empty() {
        let counters = QualityGateStats::default();

        assert_eq!(counters.filter_rate(), 0.0);
        assert_eq!(counters.pass_rate(), 0.0);
    }

    #[test]
    fn test_features_default() {
        let blank = CodeQualityFeatures::default();

        assert_eq!(blank.loc, 0);
        assert!(!blank.has_control_flow);
    }

    // ---- Debug output ----

    #[test]
    fn test_features_debug() {
        let blank = CodeQualityFeatures::default();
        let rendered = format!("{blank:?}");
        assert!(rendered.contains("CodeQualityFeatures"));
    }

    #[test]
    fn test_feature_extractor_debug() {
        let fx = FeatureExtractor::new();
        let rendered = format!("{fx:?}");
        assert!(rendered.contains("FeatureExtractor"));
    }

    #[test]
    fn test_quality_gate_debug() {
        let gate = QualityGate::default();
        let rendered = format!("{gate:?}");
        assert!(rendered.contains("QualityGate"));
    }
}
749
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        /// The score stays within `[0, 1]` for any combination of count features.
        #[test]
        fn prop_score_bounded(
            loc in 0u32..1000,
            depth in 0u32..20,
            ids in 0u32..100,
            complexity in 1u32..50,
        ) {
            let gate = QualityGate::default();
            let sample = CodeQualityFeatures {
                loc,
                ast_depth: depth,
                unique_identifiers: ids,
                complexity,
                ..Default::default()
            };

            let score = gate.score(&sample);

            prop_assert!(score >= 0.0);
            prop_assert!(score <= 1.0);
        }

        /// With default weights, raising complexity never lowers the score.
        #[test]
        fn prop_complexity_increases_score(base_complexity in 1u32..10) {
            let gate = QualityGate::default();

            let simpler = CodeQualityFeatures {
                complexity: base_complexity,
                ..Default::default()
            };
            let harder = CodeQualityFeatures {
                complexity: base_complexity + 10,
                ..Default::default()
            };

            prop_assert!(gate.score(&harder) >= gate.score(&simpler));
        }

        /// With default weights, adding control flow never lowers the score.
        #[test]
        fn prop_control_flow_increases_score(loc in 1u32..100) {
            let gate = QualityGate::default();

            let plain = CodeQualityFeatures {
                loc,
                has_control_flow: false,
                ..Default::default()
            };
            let branching = CodeQualityFeatures {
                loc,
                has_control_flow: true,
                ..Default::default()
            };

            prop_assert!(gate.score(&branching) >= gate.score(&plain));
        }

        /// Pass rate and filter rate add up to one whenever anything was evaluated.
        #[test]
        fn prop_rates_sum_to_one(passed in 0usize..100, filtered in 0usize..100) {
            let counters = QualityGateStats {
                total: passed + filtered,
                passed,
                filtered,
            };

            if counters.total > 0 {
                let combined = counters.pass_rate() + counters.filter_rate();
                prop_assert!((combined - 1.0).abs() < 0.01);
            }
        }
    }
}