1use crate::data::CodeFeatures;
7use std::collections::HashMap;
8
/// Ranks test cases so that those most likely to reveal a bug come first, learning from
/// feedback about which tests actually revealed bugs.
///
/// Each test is reduced to a coarse `FeatureSignature` (a bucketed summary of its
/// `CodeFeatures`). For every signature the prioritizer keeps a count of tests that revealed a
/// bug and a count of tests that did not; `prioritize` turns these into the parameters of a
/// Beta distribution over that signature's failure probability and samples from it.
#[derive(Debug, Clone)]
pub struct RLTestPrioritizer {
    /// Per-signature number of tests that did NOT reveal a bug (used as `beta - 1`).
    success_counts: HashMap<FeatureSignature, f64>,
    /// Per-signature number of tests that revealed a bug (used as `alpha - 1`).
    failure_counts: HashMap<FeatureSignature, f64>,
    /// Exploration rate; `new` sets it to `0.1` and `with_exploration_rate` clamps it to
    /// `[0.0, 1.0]`.
    exploration_rate: f64,
    /// Number of feedback observations recorded through `update_feedback`.
    total_tests: usize,
}
24
/// Coarse, hashable summary of a test's `CodeFeatures`, used as the key for the learned
/// success/failure statistics.
///
/// Each numeric feature is reduced to one of three buckets (0 = low, 1 = medium, 2 = high) by
/// `FeatureSignature::from_features`, so that similar tests share statistics.
#[derive(Debug, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)]
struct FeatureSignature {
    /// Bucket of `CodeFeatures::ast_depth`.
    depth_bucket: u8,
    /// Bucket of `CodeFeatures::num_operators`.
    operator_bucket: u8,
    /// Bucket of `CodeFeatures::cyclomatic_complexity`.
    complexity_bucket: u8,
    /// Copied unchanged from `CodeFeatures::uses_edge_values`.
    uses_edge_values: bool,
}
37
38impl FeatureSignature {
39 fn from_features(features: &CodeFeatures) -> Self {
40 Self {
41 depth_bucket: match features.ast_depth {
42 0..=5 => 0,
43 6..=10 => 1,
44 _ => 2,
45 },
46 operator_bucket: match features.num_operators {
47 0..=10 => 0,
48 11..=30 => 1,
49 _ => 2,
50 },
51 complexity_bucket: if features.cyclomatic_complexity <= 5.0 {
52 0
53 } else if features.cyclomatic_complexity <= 15.0 {
54 1
55 } else {
56 2
57 },
58 uses_edge_values: features.uses_edge_values,
59 }
60 }
61}
62
63impl RLTestPrioritizer {
64 #[must_use]
66 pub fn new() -> Self {
67 Self {
68 success_counts: HashMap::new(),
69 failure_counts: HashMap::new(),
70 exploration_rate: 0.1,
71 total_tests: 0,
72 }
73 }
74
75 #[must_use]
77 pub fn with_exploration_rate(mut self, rate: f64) -> Self {
78 self.exploration_rate = rate.clamp(0.0, 1.0);
79 self
80 }
81
82 pub fn prioritize(&self, features: &[CodeFeatures]) -> Vec<usize> {
86 let mut rng = rand::rng();
87
88 let mut scored: Vec<(usize, f64)> = features
89 .iter()
90 .enumerate()
91 .map(|(i, f)| {
92 let sig = FeatureSignature::from_features(f);
93 let score = self.sample_failure_probability(&sig, &mut rng);
94 (i, score)
95 })
96 .collect();
97
98 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
100
101 scored.into_iter().map(|(i, _)| i).collect()
102 }
103
104 fn sample_failure_probability<R: rand::Rng>(&self, sig: &FeatureSignature, rng: &mut R) -> f64 {
106 use rand_distr::{Beta, Distribution};
107
108 let alpha = self.failure_counts.get(sig).copied().unwrap_or(0.0) + 1.0;
110 let beta = self.success_counts.get(sig).copied().unwrap_or(0.0) + 1.0;
111
112 #[allow(clippy::unwrap_used)]
115 let beta_dist = Beta::new(alpha, beta).unwrap_or_else(|_| Beta::new(1.0, 1.0).unwrap());
116 beta_dist.sample(rng)
117 }
118
119 pub fn update_feedback(&mut self, features: &CodeFeatures, revealed_bug: bool) {
126 let sig = FeatureSignature::from_features(features);
127
128 if revealed_bug {
129 *self.failure_counts.entry(sig).or_insert(0.0) += 1.0;
130 } else {
131 *self.success_counts.entry(sig).or_insert(0.0) += 1.0;
132 }
133
134 self.total_tests += 1;
135 }
136
137 #[must_use]
139 pub fn failure_rate(&self, features: &CodeFeatures) -> f64 {
140 let sig = FeatureSignature::from_features(features);
141 let failures = self.failure_counts.get(&sig).copied().unwrap_or(0.0);
142 let successes = self.success_counts.get(&sig).copied().unwrap_or(0.0);
143 let total = failures + successes;
144
145 if total == 0.0 {
146 0.5 } else {
148 failures / total
149 }
150 }
151
152 #[must_use]
154 pub const fn total_tests(&self) -> usize {
155 self.total_tests
156 }
157
158 #[must_use]
160 pub fn num_signatures(&self) -> usize {
161 let mut sigs = self.success_counts.keys().collect::<Vec<_>>();
162 sigs.extend(self.failure_counts.keys());
163 sigs.sort_unstable();
164 sigs.dedup();
165 sigs.len()
166 }
167}
168
169impl Default for RLTestPrioritizer {
170 fn default() -> Self {
171 Self::new()
172 }
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rl_prioritizer_initial() {
        let prioritizer = RLTestPrioritizer::new();
        assert_eq!(prioritizer.total_tests(), 0);
        assert_eq!(prioritizer.num_signatures(), 0);
    }

    #[test]
    fn test_rl_prioritizer_feedback() {
        let mut prioritizer = RLTestPrioritizer::new();

        let features = CodeFeatures {
            ast_depth: 5,
            num_operators: 10,
            num_control_flow: 2,
            cyclomatic_complexity: 3.0,
            uses_edge_values: false,
            ..Default::default()
        };

        prioritizer.update_feedback(&features, true);
        assert_eq!(prioritizer.total_tests(), 1);
        assert_eq!(prioritizer.num_signatures(), 1);

        // One failure and no successes: the observed failure rate is exactly 1.
        let rate = prioritizer.failure_rate(&features);
        assert!((rate - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_failure_rate_without_feedback_is_one_half() {
        let prioritizer = RLTestPrioritizer::new();
        let rate = prioritizer.failure_rate(&CodeFeatures::default());
        assert!((rate - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_rl_prioritizer_learning() {
        let mut prioritizer = RLTestPrioritizer::new();

        let buggy_features = CodeFeatures {
            ast_depth: 10,
            num_operators: 50,
            num_control_flow: 10,
            cyclomatic_complexity: 15.0,
            uses_edge_values: true,
            ..Default::default()
        };

        let clean_features = CodeFeatures {
            ast_depth: 3,
            num_operators: 5,
            num_control_flow: 1,
            cyclomatic_complexity: 2.0,
            uses_edge_values: false,
            ..Default::default()
        };

        for _ in 0..10 {
            prioritizer.update_feedback(&buggy_features, true);
            prioritizer.update_feedback(&clean_features, false);
        }

        let buggy_rate = prioritizer.failure_rate(&buggy_features);
        let clean_rate = prioritizer.failure_rate(&clean_features);
        assert!(buggy_rate > clean_rate);
    }

    #[test]
    fn test_rl_prioritizer_ordering() {
        // Pin the exploration rate to 0.0 so the order depends only on the recorded feedback.
        let mut prioritizer = RLTestPrioritizer::new().with_exploration_rate(0.0);

        let features = vec![
            CodeFeatures {
                ast_depth: 3,
                num_operators: 5,
                cyclomatic_complexity: 2.0,
                uses_edge_values: false,
                ..Default::default()
            },
            CodeFeatures {
                ast_depth: 10,
                num_operators: 50,
                cyclomatic_complexity: 15.0,
                uses_edge_values: true,
                ..Default::default()
            },
        ];

        // 50 observations per signature give Beta(51, 1) vs Beta(1, 51) posteriors. The chance
        // that the two Thompson samples come out in the wrong order is about 1e-30.
        for _ in 0..50 {
            prioritizer.update_feedback(&features[1], true);
            prioritizer.update_feedback(&features[0], false);
        }

        let order = prioritizer.prioritize(&features);
        assert_eq!(order, vec![1, 0]);
    }

    #[test]
    fn test_prioritize_empty_input() {
        let prioritizer = RLTestPrioritizer::new();
        assert!(prioritizer.prioritize(&[]).is_empty());
    }

    #[test]
    fn test_num_signatures_counts_each_signature_once() {
        let mut prioritizer = RLTestPrioritizer::new();
        let shallow = CodeFeatures {
            ast_depth: 3,
            ..Default::default()
        };
        let deep = CodeFeatures {
            ast_depth: 20,
            ..Default::default()
        };

        // `shallow` gets both outcomes but must still count as a single signature.
        prioritizer.update_feedback(&shallow, true);
        prioritizer.update_feedback(&shallow, false);
        prioritizer.update_feedback(&deep, false);

        assert_eq!(prioritizer.num_signatures(), 2);
        assert_eq!(prioritizer.total_tests(), 3);
    }

    #[test]
    fn test_exploration_rate() {
        let prioritizer = RLTestPrioritizer::new().with_exploration_rate(0.2);
        assert!((prioritizer.exploration_rate - 0.2).abs() < f64::EPSILON);
    }

    #[test]
    fn test_exploration_rate_is_clamped() {
        let high = RLTestPrioritizer::new().with_exploration_rate(1.5);
        assert!((high.exploration_rate - 1.0).abs() < f64::EPSILON);

        let low = RLTestPrioritizer::new().with_exploration_rate(-0.5);
        assert!(low.exploration_rate.abs() < f64::EPSILON);
    }

    #[test]
    fn test_feature_signature_buckets() {
        let features = CodeFeatures {
            ast_depth: 7,
            num_operators: 15,
            cyclomatic_complexity: 8.0,
            uses_edge_values: true,
            ..Default::default()
        };

        let sig = FeatureSignature::from_features(&features);
        assert_eq!(sig.depth_bucket, 1);
        assert_eq!(sig.operator_bucket, 1);
        assert_eq!(sig.complexity_bucket, 1);
        assert!(sig.uses_edge_values);
    }

    #[test]
    fn test_feature_signature_bucket_boundaries() {
        let buckets = |ast_depth, num_operators, cyclomatic_complexity| {
            let sig = FeatureSignature::from_features(&CodeFeatures {
                ast_depth,
                num_operators,
                cyclomatic_complexity,
                ..Default::default()
            });
            (sig.depth_bucket, sig.operator_bucket, sig.complexity_bucket)
        };

        // Upper bounds are inclusive.
        assert_eq!(buckets(5, 10, 5.0), (0, 0, 0));
        assert_eq!(buckets(6, 11, 5.5), (1, 1, 1));
        assert_eq!(buckets(10, 30, 15.0), (1, 1, 1));
        assert_eq!(buckets(11, 31, 15.5), (2, 2, 2));
    }
}