// verificar/ml/rl_prioritizer.rs

//! Reinforcement learning-based test prioritizer
//!
//! Implements test case prioritization using Thompson Sampling (a contextual bandit approach).
//! Based on Spieker et al. (2017): "Reinforcement Learning for Automatic Test Case Prioritization
//! and Selection in Continuous Integration"

6use crate::data::CodeFeatures;
7use std::collections::HashMap;
8
9/// RL-based test prioritizer using Thompson Sampling
10///
11/// Learns optimal test prioritization policy by tracking success/failure rates
12/// for different feature combinations.
13#[derive(Debug, Clone)]
14pub struct RLTestPrioritizer {
15    /// Success counts (alpha) for each feature signature
16    success_counts: HashMap<FeatureSignature, f64>,
17    /// Failure counts (beta) for each feature signature
18    failure_counts: HashMap<FeatureSignature, f64>,
19    /// Exploration parameter (higher = more exploration)
20    exploration_rate: f64,
21    /// Total tests executed
22    total_tests: usize,
23}
24
/// Compact, discretized view of `CodeFeatures` used as a hash-table key.
///
/// Each numeric feature is mapped to one of three buckets (`0`, `1`, `2`) so
/// that similar test cases share feedback statistics.
#[derive(Debug, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)]
struct FeatureSignature {
    /// Bucketed AST depth: `0` for 0..=5, `1` for 6..=10, `2` otherwise.
    depth_bucket: u8,
    /// Bucketed operator count: `0` for 0..=10, `1` for 11..=30, `2` otherwise.
    operator_bucket: u8,
    /// Bucketed cyclomatic complexity (a float): `0` for `<= 5.0`,
    /// `1` for `<= 15.0`, `2` otherwise (including NaN).
    complexity_bucket: u8,
    /// Whether the code uses edge values (copied verbatim from the features).
    uses_edge_values: bool,
}
37
38impl FeatureSignature {
39    fn from_features(features: &CodeFeatures) -> Self {
40        Self {
41            depth_bucket: match features.ast_depth {
42                0..=5 => 0,
43                6..=10 => 1,
44                _ => 2,
45            },
46            operator_bucket: match features.num_operators {
47                0..=10 => 0,
48                11..=30 => 1,
49                _ => 2,
50            },
51            complexity_bucket: if features.cyclomatic_complexity <= 5.0 {
52                0
53            } else if features.cyclomatic_complexity <= 15.0 {
54                1
55            } else {
56                2
57            },
58            uses_edge_values: features.uses_edge_values,
59        }
60    }
61}
62
63impl RLTestPrioritizer {
64    /// Create a new RL test prioritizer
65    #[must_use]
66    pub fn new() -> Self {
67        Self {
68            success_counts: HashMap::new(),
69            failure_counts: HashMap::new(),
70            exploration_rate: 0.1,
71            total_tests: 0,
72        }
73    }
74
75    /// Create prioritizer with custom exploration rate
76    #[must_use]
77    pub fn with_exploration_rate(mut self, rate: f64) -> Self {
78        self.exploration_rate = rate.clamp(0.0, 1.0);
79        self
80    }
81
82    /// Prioritize test cases using Thompson Sampling
83    ///
84    /// Returns indices sorted by priority (highest failure probability first)
85    pub fn prioritize(&self, features: &[CodeFeatures]) -> Vec<usize> {
86        let mut rng = rand::rng();
87
88        let mut scored: Vec<(usize, f64)> = features
89            .iter()
90            .enumerate()
91            .map(|(i, f)| {
92                let sig = FeatureSignature::from_features(f);
93                let score = self.sample_failure_probability(&sig, &mut rng);
94                (i, score)
95            })
96            .collect();
97
98        // Sort by score descending (highest failure probability first)
99        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
100
101        scored.into_iter().map(|(i, _)| i).collect()
102    }
103
104    /// Sample failure probability from Beta distribution (Thompson Sampling)
105    fn sample_failure_probability<R: rand::Rng>(&self, sig: &FeatureSignature, rng: &mut R) -> f64 {
106        use rand_distr::{Beta, Distribution};
107
108        // Get counts with Laplace smoothing (prior: Beta(1,1))
109        let alpha = self.failure_counts.get(sig).copied().unwrap_or(0.0) + 1.0;
110        let beta = self.success_counts.get(sig).copied().unwrap_or(0.0) + 1.0;
111
112        // Sample from Beta(alpha, beta)
113        // Beta distribution creation is mathematically guaranteed with positive alpha, beta >= 1.0
114        #[allow(clippy::unwrap_used)]
115        let beta_dist = Beta::new(alpha, beta).unwrap_or_else(|_| Beta::new(1.0, 1.0).unwrap());
116        beta_dist.sample(rng)
117    }
118
119    /// Update with feedback from test execution
120    ///
121    /// # Arguments
122    ///
123    /// * `features` - Features of the executed test case
124    /// * `revealed_bug` - True if test revealed a bug, false otherwise
125    pub fn update_feedback(&mut self, features: &CodeFeatures, revealed_bug: bool) {
126        let sig = FeatureSignature::from_features(features);
127
128        if revealed_bug {
129            *self.failure_counts.entry(sig).or_insert(0.0) += 1.0;
130        } else {
131            *self.success_counts.entry(sig).or_insert(0.0) += 1.0;
132        }
133
134        self.total_tests += 1;
135    }
136
137    /// Get current failure rate estimate for a feature signature
138    #[must_use]
139    pub fn failure_rate(&self, features: &CodeFeatures) -> f64 {
140        let sig = FeatureSignature::from_features(features);
141        let failures = self.failure_counts.get(&sig).copied().unwrap_or(0.0);
142        let successes = self.success_counts.get(&sig).copied().unwrap_or(0.0);
143        let total = failures + successes;
144
145        if total == 0.0 {
146            0.5 // Prior: unknown tests have 50% failure rate
147        } else {
148            failures / total
149        }
150    }
151
152    /// Get total number of tests executed
153    #[must_use]
154    pub const fn total_tests(&self) -> usize {
155        self.total_tests
156    }
157
158    /// Get number of tracked feature signatures
159    #[must_use]
160    pub fn num_signatures(&self) -> usize {
161        let mut sigs = self.success_counts.keys().collect::<Vec<_>>();
162        sigs.extend(self.failure_counts.keys());
163        sigs.sort_unstable();
164        sigs.dedup();
165        sigs.len()
166    }
167}
168
169impl Default for RLTestPrioritizer {
170    fn default() -> Self {
171        Self::new()
172    }
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rl_prioritizer_initial() {
        let prioritizer = RLTestPrioritizer::new();
        assert_eq!(prioritizer.total_tests(), 0);
        assert_eq!(prioritizer.num_signatures(), 0);
    }

    #[test]
    fn test_rl_prioritizer_feedback() {
        let mut prioritizer = RLTestPrioritizer::new();

        let features = CodeFeatures {
            ast_depth: 5,
            num_operators: 10,
            num_control_flow: 2,
            cyclomatic_complexity: 3.0,
            uses_edge_values: false,
            ..Default::default()
        };

        // Simulate test revealing bug
        prioritizer.update_feedback(&features, true);
        assert_eq!(prioritizer.total_tests(), 1);

        // Check failure rate increased
        let rate = prioritizer.failure_rate(&features);
        assert!(rate > 0.0);
    }

    #[test]
    fn test_failure_rate_prior_for_unseen_signature() {
        let prioritizer = RLTestPrioritizer::new();
        let rate = prioritizer.failure_rate(&CodeFeatures::default());
        assert!((rate - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_rl_prioritizer_learning() {
        let mut prioritizer = RLTestPrioritizer::new();

        let buggy_features = CodeFeatures {
            ast_depth: 10,
            num_operators: 50,
            num_control_flow: 10,
            cyclomatic_complexity: 15.0,
            uses_edge_values: true,
            ..Default::default()
        };

        let clean_features = CodeFeatures {
            ast_depth: 3,
            num_operators: 5,
            num_control_flow: 1,
            cyclomatic_complexity: 2.0,
            uses_edge_values: false,
            ..Default::default()
        };

        // Simulate multiple test executions
        for _ in 0..10 {
            prioritizer.update_feedback(&buggy_features, true);
            prioritizer.update_feedback(&clean_features, false);
        }

        // Buggy features should have higher failure rate
        let buggy_rate = prioritizer.failure_rate(&buggy_features);
        let clean_rate = prioritizer.failure_rate(&clean_features);
        assert!(buggy_rate > clean_rate);
    }

    #[test]
    fn test_rl_prioritizer_ordering() {
        let mut prioritizer = RLTestPrioritizer::new();

        let features = [
            CodeFeatures {
                ast_depth: 3,
                num_operators: 5,
                cyclomatic_complexity: 2.0,
                uses_edge_values: false,
                ..Default::default()
            },
            CodeFeatures {
                ast_depth: 10,
                num_operators: 50,
                cyclomatic_complexity: 15.0,
                uses_edge_values: true,
                ..Default::default()
            },
        ];

        // Train: second feature reveals bugs, first never does.
        for _ in 0..5 {
            prioritizer.update_feedback(&features[1], true);
            prioritizer.update_feedback(&features[0], false);
        }

        // Thompson Sampling is randomized, so a single call can (rarely)
        // misorder. With Beta(6, 1) vs Beta(1, 6) posteriors the per-call
        // misorder probability is about 0.1%, so requiring the buggy test
        // first in at least 90 of 100 calls is effectively deterministic.
        let mut buggy_first = 0_u32;
        for _ in 0..100 {
            let order = prioritizer.prioritize(&features);
            assert_eq!(order.len(), 2);
            assert!(order.contains(&0));
            assert!(order.contains(&1));
            if order[0] == 1 {
                buggy_first += 1;
            }
        }
        assert!(
            buggy_first >= 90,
            "buggy test ranked first only {buggy_first}/100 times"
        );
    }

    #[test]
    fn test_prioritize_empty() {
        let prioritizer = RLTestPrioritizer::new();
        assert!(prioritizer.prioritize(&[]).is_empty());
    }

    #[test]
    fn test_num_signatures_counts_each_signature_once() {
        let mut prioritizer = RLTestPrioritizer::new();
        let shallow = CodeFeatures {
            ast_depth: 1,
            ..Default::default()
        };
        let deep = CodeFeatures {
            ast_depth: 20,
            ..Default::default()
        };

        // Same signature seen as both success and failure counts once.
        prioritizer.update_feedback(&shallow, true);
        prioritizer.update_feedback(&shallow, false);
        assert_eq!(prioritizer.num_signatures(), 1);

        prioritizer.update_feedback(&deep, true);
        assert_eq!(prioritizer.num_signatures(), 2);
        assert_eq!(prioritizer.total_tests(), 3);
    }

    #[test]
    fn test_exploration_rate() {
        let prioritizer = RLTestPrioritizer::new().with_exploration_rate(0.2);
        assert!((prioritizer.exploration_rate - 0.2).abs() < f64::EPSILON);
    }

    #[test]
    fn test_exploration_rate_is_clamped() {
        let high = RLTestPrioritizer::new().with_exploration_rate(2.0);
        assert!((high.exploration_rate - 1.0).abs() < f64::EPSILON);

        let low = RLTestPrioritizer::new().with_exploration_rate(-0.5);
        assert!(low.exploration_rate.abs() < f64::EPSILON);
    }

    #[test]
    fn test_feature_signature_buckets() {
        let features = CodeFeatures {
            ast_depth: 7,
            num_operators: 15,
            cyclomatic_complexity: 8.0,
            uses_edge_values: true,
            ..Default::default()
        };

        let sig = FeatureSignature::from_features(&features);
        assert_eq!(sig.depth_bucket, 1); // 6-10
        assert_eq!(sig.operator_bucket, 1); // 11-30
        assert_eq!(sig.complexity_bucket, 1); // (5.0, 15.0]
        assert!(sig.uses_edge_values);
    }

    #[test]
    fn test_feature_signature_bucket_boundaries() {
        let depth = |ast_depth| {
            FeatureSignature::from_features(&CodeFeatures {
                ast_depth,
                ..Default::default()
            })
            .depth_bucket
        };
        assert_eq!(depth(5), 0);
        assert_eq!(depth(6), 1);
        assert_eq!(depth(10), 1);
        assert_eq!(depth(11), 2);

        let operators = |num_operators| {
            FeatureSignature::from_features(&CodeFeatures {
                num_operators,
                ..Default::default()
            })
            .operator_bucket
        };
        assert_eq!(operators(10), 0);
        assert_eq!(operators(11), 1);
        assert_eq!(operators(30), 1);
        assert_eq!(operators(31), 2);

        let complexity = |cyclomatic_complexity| {
            FeatureSignature::from_features(&CodeFeatures {
                cyclomatic_complexity,
                ..Default::default()
            })
            .complexity_bucket
        };
        assert_eq!(complexity(5.0), 0);
        assert_eq!(complexity(5.5), 1);
        assert_eq!(complexity(15.0), 1);
        assert_eq!(complexity(15.1), 2);
    }
}