use crate::generator::GeneratedCode;
use serde::{Deserialize, Serialize};
/// Lexical quality signals describing a single code snippet.
///
/// Fields are declared in the same order in which `to_array` lays them out
/// (index 0 = `loc` ... index 7 = `comment_ratio`).
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct CodeQualityFeatures {
/// Number of lines in the snippet (`FeatureExtractor::extract` counts `str::lines`).
pub loc: u32,
/// AST depth. `FeatureExtractor::extract` leaves this at 0;
/// `extract_from_generated` copies it from the generated code's `ast_depth`.
pub ast_depth: u32,
/// Number of distinct identifier-like tokens found in the snippet.
pub unique_identifiers: u32,
/// Heuristic branching-complexity estimate; the extractor starts it at 1.
pub complexity: u32,
/// Whether a control-flow keyword (`if`, `for`, `while`, `match`) was detected.
pub has_control_flow: bool,
/// Whether a function-definition keyword (`def`, `fn`, `function`) was detected.
pub has_functions: bool,
/// Whether an error-handling marker (`try:`, `except`, `catch`) was detected.
pub has_error_handling: bool,
/// Fraction of lines that, once trimmed, start with `#` or `//`; 0.0 for empty input.
pub comment_ratio: f32,
}
impl CodeQualityFeatures {
    /// Flattens the features into a fixed-size `f32` vector.
    ///
    /// Layout: `[loc, ast_depth, unique_identifiers, complexity,
    /// has_control_flow, has_functions, has_error_handling, comment_ratio]`.
    /// Counts are converted to `f32`; boolean flags become `1.0` or `0.0`.
    #[must_use]
    pub fn to_array(&self) -> [f32; 8] {
        let as_flag = |set: bool| f32::from(u8::from(set));
        let as_count = |count: u32| count as f32;
        [
            as_count(self.loc),
            as_count(self.ast_depth),
            as_count(self.unique_identifiers),
            as_count(self.complexity),
            as_flag(self.has_control_flow),
            as_flag(self.has_functions),
            as_flag(self.has_error_handling),
            self.comment_ratio,
        ]
    }

    /// Rebuilds features from a vector laid out as produced by [`Self::to_array`].
    ///
    /// Count slots are floored at zero (NaN also maps to zero) and then
    /// converted with a saturating, truncating cast. Flag slots are `true`
    /// only when strictly greater than `0.5`. The comment ratio is copied
    /// through unchanged.
    #[must_use]
    #[allow(clippy::cast_sign_loss)]
    pub fn from_array(arr: [f32; 8]) -> Self {
        let [loc, ast_depth, unique_identifiers, complexity, control_flow, functions, error_handling, comment_ratio] =
            arr;
        // `max(0.0)` removes negatives and NaN before the float-to-int cast.
        let to_count = |value: f32| value.max(0.0) as u32;
        let to_flag = |value: f32| value > 0.5;
        Self {
            loc: to_count(loc),
            ast_depth: to_count(ast_depth),
            unique_identifiers: to_count(unique_identifiers),
            complexity: to_count(complexity),
            has_control_flow: to_flag(control_flow),
            has_functions: to_flag(functions),
            has_error_handling: to_flag(error_handling),
            comment_ratio,
        }
    }
}
/// Stateless extractor that derives [`CodeQualityFeatures`] from source text.
///
/// Every feature is a cheap lexical heuristic (line, token and keyword scans);
/// the code is never parsed.
#[derive(Debug, Default)]
pub struct FeatureExtractor;

impl FeatureExtractor {
    /// Creates a new extractor.
    #[must_use]
    pub fn new() -> Self {
        Self
    }

    /// Extracts quality features from raw source text.
    ///
    /// `ast_depth` is always 0 here because it cannot be derived from text
    /// alone; use [`Self::extract_from_generated`] when it is available.
    ///
    /// Control-flow and function keywords are only recognised when they start
    /// on a word boundary, so `"elif "` is not mistaken for `"if "` and
    /// `"before "` is not mistaken for `"for "`.
    #[must_use]
    pub fn extract(&self, code: &str) -> CodeQualityFeatures {
        let loc = Self::saturating_u32(code.lines().count());
        let comment_lines = code
            .lines()
            .map(str::trim)
            .filter(|line| line.starts_with('#') || line.starts_with("//"))
            .count();
        let comment_ratio = if loc > 0 {
            comment_lines as f32 / loc as f32
        } else {
            0.0
        };

        let has_control_flow = Self::contains_keyword(code, &["if ", "for ", "while ", "match "]);
        let has_functions = Self::contains_keyword(code, &["def ", "fn ", "function "]);
        // Error-handling markers stay plain substring checks (e.g. `.catch(`).
        let has_error_handling = ["try:", "except", "catch"]
            .iter()
            .any(|marker| code.contains(*marker));

        CodeQualityFeatures {
            loc,
            ast_depth: 0,
            unique_identifiers: self.count_identifiers(code),
            complexity: self.estimate_complexity(code),
            has_control_flow,
            has_functions,
            has_error_handling,
            comment_ratio,
        }
    }

    /// Extracts features from generated code, taking `ast_depth` from the
    /// generator's metadata instead of leaving it at 0.
    ///
    /// A depth that does not fit in `u32` saturates at `u32::MAX`.
    #[must_use]
    pub fn extract_from_generated(&self, generated: &GeneratedCode) -> CodeQualityFeatures {
        CodeQualityFeatures {
            ast_depth: u32::try_from(generated.ast_depth).unwrap_or(u32::MAX),
            ..self.extract(&generated.code)
        }
    }

    /// Counts distinct identifier-like tokens: maximal runs of alphanumeric
    /// characters or `_` whose first character is a letter or `_`.
    ///
    /// Language keywords are not excluded, so `def` and `if` count as well.
    fn count_identifiers(&self, code: &str) -> u32 {
        use std::collections::HashSet;
        let identifiers: HashSet<&str> = code
            .split(|ch: char| !Self::is_word_char(ch))
            .filter(|token| {
                token
                    .chars()
                    .next()
                    .is_some_and(|first| first.is_alphabetic() || first == '_')
            })
            .collect();
        Self::saturating_u32(identifiers.len())
    }

    /// Estimates branching complexity: 1 plus one point per branch keyword
    /// and per boolean operator.
    ///
    /// Keywords must start on a word boundary, so an `elif` contributes
    /// exactly one point rather than also matching the embedded `"if "`.
    fn estimate_complexity(&self, code: &str) -> u32 {
        const BRANCH_KEYWORDS: [&str; 7] =
            ["if ", "elif ", "else:", "for ", "while ", "case ", "match "];
        const BOOLEAN_OPERATORS: [&str; 4] = [" and ", " or ", "&&", "||"];

        let keyword_hits: usize = BRANCH_KEYWORDS
            .iter()
            .map(|keyword| Self::count_keyword(code, keyword))
            .sum();
        // The operators are already delimited by spaces or are pure symbols.
        let operator_hits: usize = BOOLEAN_OPERATORS
            .iter()
            .map(|operator| code.matches(*operator).count())
            .sum();

        Self::saturating_u32(keyword_hits + operator_hits).saturating_add(1)
    }

    /// Returns `true` for characters that can appear inside an identifier.
    fn is_word_char(ch: char) -> bool {
        ch.is_alphanumeric() || ch == '_'
    }

    /// Returns `true` when byte offset `start` is not directly preceded by an
    /// identifier character, i.e. a match there begins a new word.
    fn at_word_start(code: &str, start: usize) -> bool {
        !code[..start]
            .chars()
            .next_back()
            .is_some_and(Self::is_word_char)
    }

    /// Counts occurrences of `keyword` that begin on a word boundary.
    fn count_keyword(code: &str, keyword: &str) -> usize {
        code.match_indices(keyword)
            .filter(|&(start, _)| Self::at_word_start(code, start))
            .count()
    }

    /// Returns `true` when any of `keywords` occurs on a word boundary.
    fn contains_keyword(code: &str, keywords: &[&str]) -> bool {
        keywords.iter().copied().any(|keyword| {
            code.match_indices(keyword)
                .any(|(start, _)| Self::at_word_start(code, start))
        })
    }

    /// Converts a count to `u32`, saturating instead of silently truncating.
    fn saturating_u32(value: usize) -> u32 {
        u32::try_from(value).unwrap_or(u32::MAX)
    }
}
/// Outcome of evaluating one set of features against a quality gate.
///
/// The enum has no payload, so equality is total and `Eq` is derived
/// alongside `PartialEq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QualityVerdict {
    /// The score reached the gate's threshold.
    Pass,
    /// The score fell below the gate's threshold.
    Filtered,
}
/// Linear scoring gate that passes or filters code based on its quality features.
///
/// The score is `bias` plus a weighted sum of normalized features, clamped to
/// `[0.0, 1.0]`; a score at or above `threshold` yields `QualityVerdict::Pass`.
#[derive(Debug)]
pub struct QualityGate {
/// Minimum score (inclusive) required for a `Pass` verdict.
threshold: f32,
/// Per-feature weights, indexed in the order produced by `CodeQualityFeatures::to_array`.
weights: [f32; 8],
/// Constant term the weighted sum starts from, before clamping.
bias: f32,
/// Running counters updated on every call to `evaluate`.
stats: QualityGateStats,
}
/// Counters describing how many evaluations a gate has passed or filtered.
#[derive(Debug, Clone, Default)]
pub struct QualityGateStats {
    /// Total number of evaluations recorded.
    pub total: usize,
    /// Evaluations that produced a `Pass` verdict.
    pub passed: usize,
    /// Evaluations that produced a `Filtered` verdict.
    pub filtered: usize,
}

impl QualityGateStats {
    /// Fraction of evaluations that were filtered, or `0.0` when nothing has
    /// been evaluated yet.
    #[must_use]
    pub fn filter_rate(&self) -> f32 {
        match self.total {
            0 => 0.0,
            total => self.filtered as f32 / total as f32,
        }
    }

    /// Fraction of evaluations that passed, or `0.0` when nothing has been
    /// evaluated yet.
    #[must_use]
    pub fn pass_rate(&self) -> f32 {
        match self.total {
            0 => 0.0,
            total => self.passed as f32 / total as f32,
        }
    }
}
impl Default for QualityGate {
    /// Builds a gate through `QualityGate::new` with a pass threshold of 0.7.
    fn default() -> Self {
        const DEFAULT_THRESHOLD: f32 = 0.7;
        Self::new(DEFAULT_THRESHOLD)
    }
}
impl QualityGate {
    /// Creates a gate with the built-in weights and a bias of `0.3`.
    ///
    /// The built-in weights reward size, depth, identifier variety,
    /// complexity and the three structural flags, and slightly penalize a
    /// high comment ratio.
    #[must_use]
    pub fn new(threshold: f32) -> Self {
        Self::with_weights(
            threshold,
            [0.05, 0.15, 0.10, 0.20, 0.25, 0.15, 0.10, -0.05],
            0.3,
        )
    }

    /// Creates a gate with caller-supplied weights and bias.
    ///
    /// `weights` are indexed in the order of `CodeQualityFeatures::to_array`.
    #[must_use]
    pub fn with_weights(threshold: f32, weights: [f32; 8], bias: f32) -> Self {
        let stats = QualityGateStats::default();
        Self {
            threshold,
            weights,
            bias,
            stats,
        }
    }

    /// Scores `features`, records the outcome in the running statistics and
    /// returns the verdict. A score equal to the threshold passes.
    pub fn evaluate(&mut self, features: &CodeQualityFeatures) -> QualityVerdict {
        let verdict = if self.score(features) >= self.threshold {
            QualityVerdict::Pass
        } else {
            QualityVerdict::Filtered
        };
        self.stats.total += 1;
        match verdict {
            QualityVerdict::Pass => self.stats.passed += 1,
            QualityVerdict::Filtered => self.stats.filtered += 1,
        }
        verdict
    }

    /// Computes the quality score in `[0.0, 1.0]` without touching statistics.
    ///
    /// The four count features are divided by a cap (100 lines, depth 10,
    /// 50 identifiers, complexity 20) and limited to `1.0`; the flags and the
    /// comment ratio are used as-is. The result is `bias` plus the weighted
    /// sum, clamped to the unit interval.
    #[must_use]
    pub fn score(&self, features: &CodeQualityFeatures) -> f32 {
        // Normalization caps for array slots 0..=3; later slots are not scaled.
        const CAPS: [f32; 4] = [100.0, 10.0, 50.0, 20.0];

        let values = features.to_array();
        let total = values
            .iter()
            .zip(self.weights.iter())
            .enumerate()
            .fold(self.bias, |acc, (slot, (&value, &weight))| {
                let normalized = match CAPS.get(slot) {
                    Some(&cap) => (value / cap).min(1.0),
                    None => value,
                };
                acc + weight * normalized
            });
        total.clamp(0.0, 1.0)
    }

    /// Returns the running evaluation statistics.
    #[must_use]
    pub fn stats(&self) -> &QualityGateStats {
        &self.stats
    }

    /// Resets all evaluation counters to zero.
    pub fn reset_stats(&mut self) {
        self.stats = Default::default();
    }

    /// Returns the current pass threshold.
    #[must_use]
    pub fn threshold(&self) -> f32 {
        self.threshold
    }

    /// Replaces the pass threshold.
    pub fn set_threshold(&mut self, threshold: f32) {
        self.threshold = threshold;
    }

    /// Evaluates every item in `codes` and returns references to those that
    /// pass, in their original order. Each item is counted in the statistics.
    pub fn filter_batch<'a>(&mut self, codes: &'a [GeneratedCode]) -> Vec<&'a GeneratedCode> {
        let extractor = FeatureExtractor::new();
        let mut kept = Vec::new();
        for candidate in codes {
            let features = extractor.extract_from_generated(candidate);
            if self.evaluate(&features) == QualityVerdict::Pass {
                kept.push(candidate);
            }
        }
        kept
    }
}
// Unit tests covering feature extraction, the array round-trip of
// `CodeQualityFeatures`, and the scoring, statistics and batch filtering of
// `QualityGate`.
#[cfg(test)]
mod tests {
use super::*;
use crate::Language;
// Fixture: a single assignment with no control flow and no functions.
fn sample_code_simple() -> &'static str {
"x = 1"
}
// Fixture: a Python snippet containing `def`, `if`, `else:` and `for`.
fn sample_code_complex() -> &'static str {
r#"def factorial(n):
if n <= 1:
return 1
else:
return n * factorial(n - 1)
def main():
for i in range(10):
print(factorial(i))
"#
}
// Fixture: wraps `code` in a Python `GeneratedCode` with the given AST depth
// and no attached features.
fn sample_generated(code: &str, depth: usize) -> GeneratedCode {
GeneratedCode {
code: code.to_string(),
language: Language::Python,
ast_depth: depth,
features: vec![],
}
}
// --- FeatureExtractor ---
// A one-line assignment has one line and neither structural flag set.
#[test]
fn test_feature_extractor_simple() {
let extractor = FeatureExtractor::new();
let features = extractor.extract(sample_code_simple());
assert_eq!(features.loc, 1);
assert!(!features.has_control_flow);
assert!(!features.has_functions);
}
// The complex fixture sets both structural flags and raises complexity above the base of 1.
#[test]
fn test_feature_extractor_complex() {
let extractor = FeatureExtractor::new();
let features = extractor.extract(sample_code_complex());
assert!(features.loc > 5);
assert!(features.has_control_flow);
assert!(features.has_functions);
assert!(features.complexity > 1);
}
// `x`, `y` and `z` must each be counted as an identifier.
#[test]
fn test_feature_extractor_identifiers() {
let extractor = FeatureExtractor::new();
let features = extractor.extract("x = 1\ny = 2\nz = x + y");
assert!(features.unique_identifiers >= 3);
}
// Each `if` adds to the complexity estimate.
#[test]
fn test_feature_extractor_complexity() {
let extractor = FeatureExtractor::new();
let simple = extractor.extract("x = 1");
let complex = extractor.extract("if x:\n if y:\n pass");
assert!(complex.complexity > simple.complexity);
}
// Lines starting with `#` count as comments.
#[test]
fn test_feature_extractor_comment_ratio() {
let extractor = FeatureExtractor::new();
let no_comments = extractor.extract("x = 1\ny = 2");
let all_comments = extractor.extract("# comment\n# another");
assert!(no_comments.comment_ratio < 0.1);
assert!(all_comments.comment_ratio > 0.9);
}
// `try:` / `except` markers set the error-handling flag.
#[test]
fn test_feature_extractor_error_handling() {
let extractor = FeatureExtractor::new();
let with_try = extractor.extract("try:\n x = 1\nexcept:\n pass");
let without_try = extractor.extract("x = 1");
assert!(with_try.has_error_handling);
assert!(!without_try.has_error_handling);
}
// `extract_from_generated` copies the AST depth from the generated code.
#[test]
fn test_feature_extractor_from_generated() {
let extractor = FeatureExtractor::new();
let generated = sample_generated("x = 1", 3);
let features = extractor.extract_from_generated(&generated);
assert_eq!(features.ast_depth, 3);
}
// --- CodeQualityFeatures <-> array ---
// Counts land in slots 0..=3 and flags are encoded as 1.0 / 0.0.
#[test]
fn test_features_to_array() {
let features = CodeQualityFeatures {
loc: 10,
ast_depth: 3,
unique_identifiers: 5,
complexity: 4,
has_control_flow: true,
has_functions: false,
has_error_handling: true,
comment_ratio: 0.2,
};
let arr = features.to_array();
assert_eq!(arr[0], 10.0);
assert_eq!(arr[1], 3.0);
// Slot 4 is `has_control_flow` (true), slot 5 is `has_functions` (false).
assert_eq!(arr[4], 1.0); assert_eq!(arr[5], 0.0); }
// Flag slots decode back to booleans.
#[test]
fn test_features_from_array() {
let arr = [10.0, 3.0, 5.0, 4.0, 1.0, 0.0, 1.0, 0.2];
let features = CodeQualityFeatures::from_array(arr);
assert_eq!(features.loc, 10);
assert!(features.has_control_flow);
assert!(!features.has_functions);
}
// `to_array` followed by `from_array` preserves the checked fields.
#[test]
fn test_features_roundtrip() {
let original = CodeQualityFeatures {
loc: 15,
ast_depth: 4,
unique_identifiers: 8,
complexity: 6,
has_control_flow: true,
has_functions: true,
has_error_handling: false,
comment_ratio: 0.1,
};
let arr = original.to_array();
let restored = CodeQualityFeatures::from_array(arr);
assert_eq!(original.loc, restored.loc);
assert_eq!(original.has_control_flow, restored.has_control_flow);
}
// --- QualityGate ---
// The default gate uses a threshold of 0.7.
#[test]
fn test_quality_gate_default() {
let gate = QualityGate::default();
assert!((gate.threshold() - 0.7).abs() < f32::EPSILON);
}
// The one-line fixture is expected to score below a 0.5 threshold.
#[test]
fn test_quality_gate_simple_code_filtered() {
let mut gate = QualityGate::new(0.5);
let extractor = FeatureExtractor::new();
let features = extractor.extract(sample_code_simple());
let verdict = gate.evaluate(&features);
assert_eq!(verdict, QualityVerdict::Filtered);
}
// The complex fixture is expected to reach a 0.5 threshold.
#[test]
fn test_quality_gate_complex_code_passes() {
let mut gate = QualityGate::new(0.5);
let extractor = FeatureExtractor::new();
let features = extractor.extract(sample_code_complex());
let verdict = gate.evaluate(&features);
assert_eq!(verdict, QualityVerdict::Pass);
}
// Scores stay inside [0, 1] for simple, complex and empty input.
#[test]
fn test_quality_gate_score_bounded() {
let gate = QualityGate::new(0.5);
let extractor = FeatureExtractor::new();
for code in &[sample_code_simple(), sample_code_complex(), ""] {
let features = extractor.extract(code);
let score = gate.score(&features);
assert!(score >= 0.0);
assert!(score <= 1.0);
}
}
// Every evaluation is counted exactly once, as either passed or filtered.
#[test]
fn test_quality_gate_stats() {
let mut gate = QualityGate::new(0.5);
let extractor = FeatureExtractor::new();
let simple = extractor.extract(sample_code_simple());
let complex = extractor.extract(sample_code_complex());
gate.evaluate(&simple);
gate.evaluate(&complex);
let stats = gate.stats();
assert_eq!(stats.total, 2);
assert_eq!(stats.passed + stats.filtered, 2);
}
// Pass rate and filter rate add up to 1 once something has been evaluated.
#[test]
fn test_quality_gate_stats_rates() {
let mut gate = QualityGate::new(0.5);
let extractor = FeatureExtractor::new();
for _ in 0..10 {
let features = extractor.extract(sample_code_simple());
gate.evaluate(&features);
}
let stats = gate.stats();
let total_rate = stats.pass_rate() + stats.filter_rate();
assert!((total_rate - 1.0).abs() < 0.01);
}
// `reset_stats` clears the counters.
#[test]
fn test_quality_gate_reset_stats() {
let mut gate = QualityGate::new(0.5);
let extractor = FeatureExtractor::new();
let features = extractor.extract(sample_code_simple());
gate.evaluate(&features);
assert!(gate.stats().total > 0);
gate.reset_stats();
assert_eq!(gate.stats().total, 0);
}
// `set_threshold` is reflected by `threshold`.
#[test]
fn test_quality_gate_threshold_adjustment() {
let mut gate = QualityGate::new(0.5);
gate.set_threshold(0.8);
assert!((gate.threshold() - 0.8).abs() < f32::EPSILON);
}
// `with_weights` stores the supplied threshold.
#[test]
fn test_quality_gate_custom_weights() {
let weights = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1];
let gate = QualityGate::with_weights(0.5, weights, 0.2);
assert!((gate.threshold() - 0.5).abs() < f32::EPSILON);
}
// --- filter_batch ---
// At threshold 0.4 the factorial fixture must be among the survivors.
#[test]
fn test_filter_batch() {
let mut gate = QualityGate::new(0.4);
let codes = vec![
sample_generated(sample_code_simple(), 1),
sample_generated(sample_code_complex(), 4),
];
let passing = gate.filter_batch(&codes);
assert!(!passing.is_empty());
assert!(passing.iter().any(|c| c.code.contains("factorial")));
}
// An empty batch yields an empty result.
#[test]
fn test_filter_batch_empty() {
let mut gate = QualityGate::new(0.5);
let codes: Vec<GeneratedCode> = vec![];
let passing = gate.filter_batch(&codes);
assert!(passing.is_empty());
}
// With a threshold of 0.0 every item passes, since scores are clamped to >= 0.
#[test]
fn test_filter_batch_all_pass() {
let mut gate = QualityGate::new(0.0);
let codes = vec![
sample_generated(sample_code_simple(), 1),
sample_generated(sample_code_complex(), 4),
];
let passing = gate.filter_batch(&codes);
assert_eq!(passing.len(), 2);
}
// With a threshold of 1.0 the simple fixture never passes.
#[test]
fn test_filter_batch_none_pass() {
let mut gate = QualityGate::new(1.0);
let codes = vec![
sample_generated(sample_code_simple(), 1),
sample_generated(sample_code_simple(), 2),
];
let passing = gate.filter_batch(&codes);
assert!(passing.is_empty());
}
// --- Edge cases ---
// Empty input has zero lines and the base complexity of 1.
#[test]
fn test_empty_code() {
let extractor = FeatureExtractor::new();
let features = extractor.extract("");
assert_eq!(features.loc, 0);
assert_eq!(features.complexity, 1); }
// Whitespace-only lines still count as lines but contain no control flow.
#[test]
fn test_whitespace_only() {
let extractor = FeatureExtractor::new();
let features = extractor.extract(" \n\t\n ");
assert_eq!(features.loc, 3);
assert!(!features.has_control_flow);
}
// --- Derived trait smoke tests ---
// `QualityVerdict` supports equality comparison.
#[test]
fn test_quality_verdict_equality() {
assert_eq!(QualityVerdict::Pass, QualityVerdict::Pass);
assert_ne!(QualityVerdict::Pass, QualityVerdict::Filtered);
}
// Both rates are 0.0 when nothing has been evaluated (no division by zero).
#[test]
fn test_quality_gate_stats_empty() {
let stats = QualityGateStats::default();
assert_eq!(stats.filter_rate(), 0.0);
assert_eq!(stats.pass_rate(), 0.0);
}
// Default features are all zero / false.
#[test]
fn test_features_default() {
let features = CodeQualityFeatures::default();
assert_eq!(features.loc, 0);
assert!(!features.has_control_flow);
}
// `Debug` output includes the type name.
#[test]
fn test_features_debug() {
let features = CodeQualityFeatures::default();
let debug = format!("{features:?}");
assert!(debug.contains("CodeQualityFeatures"));
}
// `Debug` output includes the type name.
#[test]
fn test_feature_extractor_debug() {
let extractor = FeatureExtractor::new();
let debug = format!("{extractor:?}");
assert!(debug.contains("FeatureExtractor"));
}
// `Debug` output includes the type name.
#[test]
fn test_quality_gate_debug() {
let gate = QualityGate::default();
let debug = format!("{gate:?}");
assert!(debug.contains("QualityGate"));
}
}
// Property-based tests for `QualityGate::score` and `QualityGateStats`,
// driven by the `proptest` crate.
#[cfg(test)]
mod proptests {
use super::*;
use proptest::prelude::*;
proptest! {
// The score stays within [0, 1] for arbitrary count features, including
// values beyond the normalization caps.
#[test]
fn prop_score_bounded(
loc in 0u32..1000,
depth in 0u32..20,
ids in 0u32..100,
complexity in 1u32..50,
) {
let features = CodeQualityFeatures {
loc,
ast_depth: depth,
unique_identifiers: ids,
complexity,
..Default::default()
};
let gate = QualityGate::default();
let score = gate.score(&features);
prop_assert!(score >= 0.0);
prop_assert!(score <= 1.0);
}
// With the default weights, raising complexity never lowers the score.
#[test]
fn prop_complexity_increases_score(base_complexity in 1u32..10) {
let gate = QualityGate::default();
let low = CodeQualityFeatures {
complexity: base_complexity,
..Default::default()
};
let high = CodeQualityFeatures {
complexity: base_complexity + 10,
..Default::default()
};
let low_score = gate.score(&low);
let high_score = gate.score(&high);
prop_assert!(high_score >= low_score);
}
// With the default weights, setting the control-flow flag never lowers the score.
#[test]
fn prop_control_flow_increases_score(loc in 1u32..100) {
let gate = QualityGate::default();
let without = CodeQualityFeatures {
loc,
has_control_flow: false,
..Default::default()
};
let with = CodeQualityFeatures {
loc,
has_control_flow: true,
..Default::default()
};
let without_score = gate.score(&without);
let with_score = gate.score(&with);
prop_assert!(with_score >= without_score);
}
// When `total == passed + filtered` and is non-zero, the two rates sum to 1
// (within floating-point tolerance).
#[test]
fn prop_rates_sum_to_one(passed in 0usize..100, filtered in 0usize..100) {
let stats = QualityGateStats {
total: passed + filtered,
passed,
filtered,
};
// An empty total reports 0.0 for both rates, so the identity only holds otherwise.
if stats.total > 0 {
let sum = stats.pass_rate() + stats.filter_rate();
prop_assert!((sum - 1.0).abs() < 0.01);
}
}
}
}