use crate::data::CodeFeatures;
use std::collections::HashMap;
/// Test prioritizer that learns from pass/fail feedback.
///
/// Feedback is aggregated per [`FeatureSignature`], i.e. per coarse bucket of
/// code metrics, rather than per individual test case.
#[derive(Debug, Clone)]
pub struct RLTestPrioritizer {
    /// Per-signature tally of feedback reports where no bug was revealed.
    success_counts: HashMap<FeatureSignature, f64>,
    /// Per-signature tally of feedback reports where a bug was revealed.
    failure_counts: HashMap<FeatureSignature, f64>,
    /// Exploration rate; starts at `0.1` in `new` and is set through
    /// `with_exploration_rate`, which clamps it to `[0.0, 1.0]`.
    exploration_rate: f64,
    /// Number of feedback reports recorded so far.
    total_tests: usize,
}
/// Coarse, hashable summary of a piece of code's metrics.
///
/// Each numeric metric is reduced to a small bucket index so that many
/// similar inputs share one map key. The field order is significant: the
/// derived `Ord` compares fields top to bottom.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct FeatureSignature {
    /// Bucket index derived from the AST depth.
    depth_bucket: u8,
    /// Bucket index derived from the operator count.
    operator_bucket: u8,
    /// Bucket index derived from the cyclomatic complexity.
    complexity_bucket: u8,
    /// Whether the code uses edge values.
    uses_edge_values: bool,
}
impl FeatureSignature {
    /// Reduces the raw metrics in `features` to a bucketed signature.
    ///
    /// Bucket boundaries:
    /// - `ast_depth`: `0..=5` -> 0, `6..=10` -> 1, anything else -> 2
    /// - `num_operators`: `0..=10` -> 0, `11..=30` -> 1, anything else -> 2
    /// - `cyclomatic_complexity`: `<= 5.0` -> 0, `<= 15.0` -> 1, otherwise 2
    ///
    /// `uses_edge_values` is copied through unchanged.
    fn from_features(features: &CodeFeatures) -> Self {
        let depth_bucket: u8 = match features.ast_depth {
            0..=5 => 0,
            6..=10 => 1,
            _ => 2,
        };

        let operator_bucket: u8 = match features.num_operators {
            0..=10 => 0,
            11..=30 => 1,
            _ => 2,
        };

        // A value that fails both comparisons (e.g. a float NaN) falls
        // through to the highest bucket.
        let complexity_bucket: u8 = match features.cyclomatic_complexity {
            c if c <= 5.0 => 0,
            c if c <= 15.0 => 1,
            _ => 2,
        };

        Self {
            depth_bucket,
            operator_bucket,
            complexity_bucket,
            uses_edge_values: features.uses_edge_values,
        }
    }
}
impl RLTestPrioritizer {
#[must_use]
pub fn new() -> Self {
Self {
success_counts: HashMap::new(),
failure_counts: HashMap::new(),
exploration_rate: 0.1,
total_tests: 0,
}
}
#[must_use]
pub fn with_exploration_rate(mut self, rate: f64) -> Self {
self.exploration_rate = rate.clamp(0.0, 1.0);
self
}
pub fn prioritize(&self, features: &[CodeFeatures]) -> Vec<usize> {
let mut rng = rand::rng();
let mut scored: Vec<(usize, f64)> = features
.iter()
.enumerate()
.map(|(i, f)| {
let sig = FeatureSignature::from_features(f);
let score = self.sample_failure_probability(&sig, &mut rng);
(i, score)
})
.collect();
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scored.into_iter().map(|(i, _)| i).collect()
}
fn sample_failure_probability<R: rand::Rng>(&self, sig: &FeatureSignature, rng: &mut R) -> f64 {
use rand_distr::{Beta, Distribution};
let alpha = self.failure_counts.get(sig).copied().unwrap_or(0.0) + 1.0;
let beta = self.success_counts.get(sig).copied().unwrap_or(0.0) + 1.0;
#[allow(clippy::unwrap_used)]
let beta_dist = Beta::new(alpha, beta).unwrap_or_else(|_| Beta::new(1.0, 1.0).unwrap());
beta_dist.sample(rng)
}
pub fn update_feedback(&mut self, features: &CodeFeatures, revealed_bug: bool) {
let sig = FeatureSignature::from_features(features);
if revealed_bug {
*self.failure_counts.entry(sig).or_insert(0.0) += 1.0;
} else {
*self.success_counts.entry(sig).or_insert(0.0) += 1.0;
}
self.total_tests += 1;
}
#[must_use]
pub fn failure_rate(&self, features: &CodeFeatures) -> f64 {
let sig = FeatureSignature::from_features(features);
let failures = self.failure_counts.get(&sig).copied().unwrap_or(0.0);
let successes = self.success_counts.get(&sig).copied().unwrap_or(0.0);
let total = failures + successes;
if total == 0.0 {
0.5 } else {
failures / total
}
}
#[must_use]
pub const fn total_tests(&self) -> usize {
self.total_tests
}
#[must_use]
pub fn num_signatures(&self) -> usize {
let mut sigs = self.success_counts.keys().collect::<Vec<_>>();
sigs.extend(self.failure_counts.keys());
sigs.sort_unstable();
sigs.dedup();
sigs.len()
}
}
impl Default for RLTestPrioritizer {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A new prioritizer has recorded no tests and knows no signatures.
    #[test]
    fn test_rl_prioritizer_initial() {
        let fresh = RLTestPrioritizer::new();

        assert_eq!(fresh.total_tests(), 0);
        assert_eq!(fresh.num_signatures(), 0);
    }

    /// One failing report bumps the counter and yields a positive rate.
    #[test]
    fn test_rl_prioritizer_feedback() {
        let sample = CodeFeatures {
            ast_depth: 5,
            num_operators: 10,
            num_control_flow: 2,
            cyclomatic_complexity: 3.0,
            uses_edge_values: false,
            ..Default::default()
        };
        let mut prioritizer = RLTestPrioritizer::new();

        prioritizer.update_feedback(&sample, true);

        assert_eq!(prioritizer.total_tests(), 1);
        let observed = prioritizer.failure_rate(&sample);
        assert!(observed > 0.0);
    }

    /// A signature that always fails ends up with a higher failure rate than
    /// one that always passes.
    #[test]
    fn test_rl_prioritizer_learning() {
        let buggy_features = CodeFeatures {
            ast_depth: 10,
            num_operators: 50,
            num_control_flow: 10,
            cyclomatic_complexity: 15.0,
            uses_edge_values: true,
            ..Default::default()
        };
        let clean_features = CodeFeatures {
            ast_depth: 3,
            num_operators: 5,
            num_control_flow: 1,
            cyclomatic_complexity: 2.0,
            uses_edge_values: false,
            ..Default::default()
        };
        let mut prioritizer = RLTestPrioritizer::new();

        for _round in 0..10 {
            prioritizer.update_feedback(&buggy_features, true);
            prioritizer.update_feedback(&clean_features, false);
        }

        let buggy_rate = prioritizer.failure_rate(&buggy_features);
        let clean_rate = prioritizer.failure_rate(&clean_features);
        assert!(buggy_rate > clean_rate);
    }

    /// `prioritize` returns every input index exactly once. The order itself
    /// is randomized, so it is not asserted.
    #[test]
    fn test_rl_prioritizer_ordering() {
        let candidates = [
            CodeFeatures {
                ast_depth: 3,
                num_operators: 5,
                cyclomatic_complexity: 2.0,
                uses_edge_values: false,
                ..Default::default()
            },
            CodeFeatures {
                ast_depth: 10,
                num_operators: 50,
                cyclomatic_complexity: 15.0,
                uses_edge_values: true,
                ..Default::default()
            },
        ];
        let mut prioritizer = RLTestPrioritizer::new();

        for _round in 0..5 {
            prioritizer.update_feedback(&candidates[1], true);
            prioritizer.update_feedback(&candidates[0], false);
        }

        let order = prioritizer.prioritize(&candidates);
        assert_eq!(order.len(), 2);
        assert!(order.contains(&0));
        assert!(order.contains(&1));
    }

    /// The builder stores an in-range exploration rate unchanged.
    #[test]
    fn test_exploration_rate() {
        let configured = RLTestPrioritizer::new().with_exploration_rate(0.2);

        assert!((configured.exploration_rate - 0.2).abs() < f64::EPSILON);
    }

    /// Mid-range metrics land in the middle bucket of each dimension.
    #[test]
    fn test_feature_signature_buckets() {
        let mid_range = CodeFeatures {
            ast_depth: 7,
            num_operators: 15,
            cyclomatic_complexity: 8.0,
            uses_edge_values: true,
            ..Default::default()
        };

        let sig = FeatureSignature::from_features(&mid_range);

        assert_eq!(sig.depth_bucket, 1);
        assert_eq!(sig.operator_bucket, 1);
        assert_eq!(sig.complexity_bucket, 1);
        assert!(sig.uses_edge_values);
    }
}