Skip to main content

entrenar/finetune/
eval.rs

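//! Evaluation pipeline for generated tests
//!
//! Reports compile rate, test pass rate, mutation score, and coverage for generated
//! tests. Only the compile check shells out to a real tool (`rustfmt`); test execution,
//! mutation testing, coverage deltas, and inference latency are currently placeholders.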
4
5use std::process::Command;
6
/// Inference latency, in milliseconds, reported by batch evaluation while real
/// timing is not yet measured.
const DEFAULT_INFERENCE_LATENCY_MS: f32 = 250_f32;
9
/// Aggregate evaluation metrics for a batch of generated tests.
///
/// Rates are fractions in `0.0..=1.0`, not percentages; coverage deltas are in
/// percentage points.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct EvalMetrics {
    /// Fraction of generated tests that compile (0.0-1.0)
    pub compile_rate: f32,
    /// Fraction of compiled tests that pass (0.0-1.0)
    pub test_pass_rate: f32,
    /// Fraction of mutants killed by tests (0.0-1.0)
    pub mutation_score: f32,
    /// Branch coverage delta (percentage points)
    pub branch_coverage_delta: f32,
    /// Line coverage delta (percentage points)
    pub line_coverage_delta: f32,
    /// Average tests per function
    pub avg_tests_per_function: f32,
    /// Inference latency in milliseconds
    pub inference_latency_ms: f32,
}
28
29impl EvalMetrics {
30    /// Check if metrics meet minimum thresholds
31    #[must_use]
32    pub fn meets_minimum(&self) -> bool {
33        self.compile_rate >= 0.85 && self.test_pass_rate >= 0.80 && self.mutation_score >= 0.60
34    }
35
36    /// Check if metrics meet target thresholds
37    #[must_use]
38    pub fn meets_target(&self) -> bool {
39        self.compile_rate >= 0.92
40            && self.test_pass_rate >= 0.88
41            && self.mutation_score >= 0.72
42            && self.branch_coverage_delta >= 12.0
43    }
44
45    /// Check if metrics meet stretch goals
46    #[must_use]
47    pub fn meets_stretch(&self) -> bool {
48        self.compile_rate >= 0.97
49            && self.test_pass_rate >= 0.95
50            && self.mutation_score >= 0.80
51            && self.branch_coverage_delta >= 18.0
52    }
53}
54
55/// Single evaluation result
56#[derive(Debug, Clone)]
57pub struct EvalResult {
58    /// Function that was tested
59    pub function: String,
60    /// Generated test code
61    pub generated_tests: String,
62    /// Whether the tests compiled
63    pub compiles: bool,
64    /// Compile errors if any
65    pub compile_errors: Vec<String>,
66    /// Number of tests that passed
67    pub tests_passed: usize,
68    /// Number of tests that failed
69    pub tests_failed: usize,
70    /// Mutants killed
71    pub mutants_killed: usize,
72    /// Total mutants tested
73    pub mutants_total: usize,
74}
75
76impl EvalResult {
77    /// Check if compilation succeeded
78    #[must_use]
79    pub const fn compilation_success(&self) -> bool {
80        self.compiles
81    }
82
83    /// Get test pass rate
84    #[must_use]
85    pub fn test_pass_rate(&self) -> f32 {
86        let total = self.tests_passed + self.tests_failed;
87        if total == 0 {
88            0.0
89        } else {
90            self.tests_passed as f32 / total as f32
91        }
92    }
93
94    /// Get mutation score
95    #[must_use]
96    pub fn mutation_score(&self) -> f32 {
97        if self.mutants_total == 0 {
98            0.0
99        } else {
100            self.mutants_killed as f32 / self.mutants_total as f32
101        }
102    }
103}
104
/// Evaluates generated test code: a compile check, a test run, and optional
/// mutation testing.
#[derive(Clone, Debug)]
pub struct TestEvaluator {
    /// Directory in which code under evaluation is staged
    work_dir: std::path::PathBuf,
    /// Whether mutation testing runs after a successful compile check
    run_mutation: bool,
    /// Number of mutants to sample; `0` is intended to mean "all"
    mutation_sample_size: usize,
}
115
116impl TestEvaluator {
117    /// Create new evaluator
118    #[must_use]
119    pub fn new(work_dir: impl Into<std::path::PathBuf>) -> Self {
120        Self {
121            work_dir: work_dir.into(),
122            run_mutation: true,
123            mutation_sample_size: 50, // Stratified sample
124        }
125    }
126
127    /// Disable mutation testing (faster)
128    #[must_use]
129    pub const fn without_mutation(mut self) -> Self {
130        self.run_mutation = false;
131        self
132    }
133
134    /// Set mutation sample size
135    #[must_use]
136    pub const fn mutation_sample(mut self, n: usize) -> Self {
137        self.mutation_sample_size = n;
138        self
139    }
140
141    /// Evaluate a single generated test
142    pub fn evaluate(&self, function: &str, tests: &str) -> EvalResult {
143        let mut result = EvalResult {
144            function: function.to_string(),
145            generated_tests: tests.to_string(),
146            compiles: false,
147            compile_errors: Vec::new(),
148            tests_passed: 0,
149            tests_failed: 0,
150            mutants_killed: 0,
151            mutants_total: 0,
152        };
153
154        // Check if tests compile
155        match self.check_compile(tests) {
156            Ok(()) => {
157                result.compiles = true;
158
159                // Run tests
160                if let Ok((passed, failed)) = self.run_tests(tests) {
161                    result.tests_passed = passed;
162                    result.tests_failed = failed;
163                }
164
165                // Run mutation testing if enabled
166                if self.run_mutation {
167                    if let Ok((killed, total)) = self.run_mutation_tests(function, tests) {
168                        result.mutants_killed = killed;
169                        result.mutants_total = total;
170                    }
171                }
172            }
173            Err(errors) => {
174                result.compile_errors = errors;
175            }
176        }
177
178        result
179    }
180
181    /// Check if code compiles
182    fn check_compile(&self, code: &str) -> Result<(), Vec<String>> {
183        // Stage code in an intermediate file for compilation checking
184        let test_file = self.work_dir.join("_eval_test.rs");
185        if std::fs::write(&test_file, code).is_err() {
186            return Err(vec!["Failed to write test file".into()]);
187        }
188
189        // Try to parse with rustfmt
190        let output = Command::new("rustfmt").arg("--check").arg(&test_file).output();
191
192        // Clean up
193        let _ = std::fs::remove_file(&test_file);
194
195        match output {
196            Ok(o) if o.status.success() => Ok(()),
197            Ok(o) => {
198                let stderr = String::from_utf8_lossy(&o.stderr);
199                Err(stderr.lines().map(String::from).collect())
200            }
201            Err(e) => Err(vec![e.to_string()]),
202        }
203    }
204
205    /// Run tests and count pass/fail
206    fn run_tests(&self, _code: &str) -> Result<(usize, usize), String> {
207        // Simplified: would need full cargo test integration
208        // For now, assume tests pass if they compile
209        Ok((1, 0))
210    }
211
212    /// Run mutation tests
213    fn run_mutation_tests(&self, _function: &str, _tests: &str) -> Result<(usize, usize), String> {
214        // Simplified: would integrate with cargo-mutants
215        // Return mock values based on sample size
216        let total = self.mutation_sample_size.min(20);
217        let killed = (total as f32 * 0.72) as usize; // ~72% mutation score
218        Ok((killed, total))
219    }
220
221    /// Evaluate multiple samples and compute aggregate metrics
222    pub fn evaluate_batch(&self, samples: &[(String, String)]) -> EvalMetrics {
223        if samples.is_empty() {
224            return EvalMetrics::default();
225        }
226
227        let results: Vec<EvalResult> =
228            samples.iter().map(|(func, tests)| self.evaluate(func, tests)).collect();
229
230        let total = results.len() as f32;
231        let compiles = results.iter().filter(|r| r.compiles).count() as f32;
232
233        let total_passed: usize = results.iter().map(|r| r.tests_passed).sum();
234        let total_failed: usize = results.iter().map(|r| r.tests_failed).sum();
235        let total_tests = total_passed + total_failed;
236
237        let total_killed: usize = results.iter().map(|r| r.mutants_killed).sum();
238        let total_mutants: usize = results.iter().map(|r| r.mutants_total).sum();
239
240        EvalMetrics {
241            compile_rate: compiles / total,
242            test_pass_rate: if total_tests > 0 {
243                total_passed as f32 / total_tests as f32
244            } else {
245                0.0
246            },
247            mutation_score: if total_mutants > 0 {
248                total_killed as f32 / total_mutants as f32
249            } else {
250                0.0
251            },
252            branch_coverage_delta: 12.0, // Would need actual coverage measurement
253            line_coverage_delta: 15.0,
254            avg_tests_per_function: total_tests as f32 / total,
255            inference_latency_ms: DEFAULT_INFERENCE_LATENCY_MS, // Would measure actual inference
256        }
257    }
258}
259
260impl Default for TestEvaluator {
261    fn default() -> Self {
262        Self::new(std::env::temp_dir())
263    }
264}
265
/// Check if generated code contains a known tautological assertion.
///
/// Recognises `assert!(true)`, `assert_eq!(x, x)`, `assert_eq!(0, 0)` and
/// `assert!(1 == 1)`. Whitespace is ignored, so `assert!( true )` and
/// `assert_eq!(0,0)` are caught. The same forms followed by a custom message
/// argument, e.g. `assert!(true, "...")`, also count. Matching is textual, so
/// occurrences inside comments or string literals are reported too.
#[must_use]
pub fn contains_tautology(code: &str) -> bool {
    // Each entry is a tautological assertion with whitespace removed and its
    // closing delimiter cut off. A hit must be followed by `)` (end of the
    // macro call) or `,` (start of a message argument); this keeps e.g.
    // `assert!(true_flag)` and `assert!(1 == 10)` from matching.
    const TAUTOLOGY_HEADS: [&str; 4] =
        ["assert!(true", "assert_eq!(x,x", "assert_eq!(0,0", "assert!(1==1"];

    let compact: String = code.chars().filter(|c| !c.is_whitespace()).collect();
    let bytes = compact.as_bytes();

    TAUTOLOGY_HEADS.iter().any(|&head| {
        compact
            .match_indices(head)
            .any(|(start, _)| matches!(bytes.get(start + head.len()), Some(b')' | b',')))
    })
}
281
282/// Check if assertions are meaningful
283#[must_use]
284pub fn has_meaningful_assertions(code: &str) -> bool {
285    // Must have at least one assertion macro call (assert!, assert_eq!, etc.)
286    let has_assertion = code.contains("assert!(")
287        || code.contains("assert_eq!(")
288        || code.contains("assert_ne!(")
289        || code.contains("debug_assert!(")
290        || code.contains("prop_assert!");
291
292    if !has_assertion {
293        return false;
294    }
295
296    // Should not be only tautologies
297    !contains_tautology(code)
298}
299
300/// Extract test function count from code
301#[must_use]
302pub fn count_test_functions(code: &str) -> usize {
303    code.matches("#[test]").count()
304}
305
/// Heuristically checks whether `code` exercises edge cases.
///
/// The check is a case-insensitive substring search. It looks for keywords
/// (`empty`, `zero`, `none`, `null`, `max`, `min`, `overflow`, `underflow`,
/// `boundary`, `edge`) and literal shapes (`0)`, `0,`, `[])`, `&[]`, `""`).
/// Because it matches substrings, it can over-report: `10)` contains `0)` and
/// `maximum` contains `max`.
#[must_use]
pub fn has_edge_case_tests(code: &str) -> bool {
    // All patterns are stored lowercase, so `code` is lowercased once up front
    // instead of once per pattern. `None` needs no separate entry because it
    // matches as `none` after folding.
    const EDGE_PATTERNS: [&str; 15] = [
        "empty",
        "zero",
        "none",
        "null",
        "max",
        "min",
        "overflow",
        "underflow",
        "boundary",
        "edge",
        "0)",
        "0,",
        "[])",
        "&[]",
        "\"\"",
    ];

    let haystack = code.to_lowercase();
    EDGE_PATTERNS.iter().any(|&pattern| haystack.contains(pattern))
}
330
#[cfg(test)]
mod tests {
    // Tests that go through `evaluate`/`evaluate_batch` shell out to `rustfmt`,
    // so they only assert properties that hold whether or not rustfmt is
    // installed.
    use super::*;

    #[test]
    fn test_eval_metrics_default() {
        let metrics = EvalMetrics::default();
        assert_eq!(metrics.compile_rate, 0.0);
        assert!(!metrics.meets_minimum());
    }

    #[test]
    fn test_eval_metrics_minimum() {
        // Exactly on every minimum threshold: `>=` must accept the boundary.
        let metrics = EvalMetrics {
            compile_rate: 0.85,
            test_pass_rate: 0.80,
            mutation_score: 0.60,
            ..Default::default()
        };
        assert!(metrics.meets_minimum());
        assert!(!metrics.meets_target());
    }

    #[test]
    fn test_eval_metrics_target() {
        let metrics = EvalMetrics {
            compile_rate: 0.92,
            test_pass_rate: 0.88,
            mutation_score: 0.72,
            branch_coverage_delta: 12.0,
            ..Default::default()
        };
        assert!(metrics.meets_minimum());
        assert!(metrics.meets_target());
        assert!(!metrics.meets_stretch());
    }

    #[test]
    fn test_eval_metrics_stretch() {
        let metrics = EvalMetrics {
            compile_rate: 0.97,
            test_pass_rate: 0.95,
            mutation_score: 0.80,
            branch_coverage_delta: 18.0,
            ..Default::default()
        };
        assert!(metrics.meets_stretch());
    }

    #[test]
    fn test_eval_result_rates() {
        let result = EvalResult {
            function: String::new(),
            generated_tests: String::new(),
            compiles: true,
            compile_errors: vec![],
            tests_passed: 8,
            tests_failed: 2,
            mutants_killed: 14,
            mutants_total: 20,
        };

        // Exact float equality is safe here. IEEE division is correctly rounded,
        // so 8/10 and 14/20 give the same f32 values as the literals.
        assert_eq!(result.test_pass_rate(), 0.8);
        assert_eq!(result.mutation_score(), 0.7);
    }

    #[test]
    fn test_eval_result_zero_division() {
        let result = EvalResult {
            function: String::new(),
            generated_tests: String::new(),
            compiles: true,
            compile_errors: vec![],
            tests_passed: 0,
            tests_failed: 0,
            mutants_killed: 0,
            mutants_total: 0,
        };

        assert_eq!(result.test_pass_rate(), 0.0);
        assert_eq!(result.mutation_score(), 0.0);
    }

    #[test]
    fn test_contains_tautology() {
        assert!(contains_tautology("assert!(true)"));
        assert!(contains_tautology("assert_eq!(x, x)"));
        assert!(!contains_tautology("assert_eq!(x, y)"));
        assert!(!contains_tautology("assert!(result.is_ok())"));
    }

    #[test]
    fn test_has_meaningful_assertions() {
        assert!(has_meaningful_assertions("assert_eq!(foo(1), 2)"));
        assert!(!has_meaningful_assertions("no assertions here"));
        assert!(!has_meaningful_assertions("assert!(true)")); // Tautology
    }

    #[test]
    fn test_count_test_functions() {
        let code = r"
            #[test]
            fn test_one() {}

            #[test]
            fn test_two() {}
        ";
        assert_eq!(count_test_functions(code), 2);
    }

    #[test]
    fn test_has_edge_case_tests() {
        assert!(has_edge_case_tests("test_empty_input"));
        assert!(has_edge_case_tests("assert_eq!(foo([]), None)"));
        assert!(has_edge_case_tests("test with zero value: 0)"));
        assert!(!has_edge_case_tests("test_normal_case"));
    }

    #[test]
    fn test_evaluator_creation() {
        let eval = TestEvaluator::new("/tmp");
        assert!(eval.run_mutation);

        let eval_no_mut = eval.without_mutation();
        assert!(!eval_no_mut.run_mutation);
    }

    #[test]
    fn test_evaluator_batch_empty() {
        // An empty batch short-circuits to the all-zero defaults.
        let eval = TestEvaluator::default();
        let metrics = eval.evaluate_batch(&[]);
        assert_eq!(metrics.compile_rate, 0.0);
    }

    #[test]
    fn test_evaluator_mutation_sample() {
        let eval = TestEvaluator::default().mutation_sample(100);
        assert_eq!(eval.mutation_sample_size, 100);
    }

    #[test]
    fn test_evaluate_valid_rust() {
        let eval = TestEvaluator::default().without_mutation();
        let func = "pub fn add(a: i32, b: i32) -> i32 { a + b }";
        let tests = r"
#[test]
fn test_add() {
    assert_eq!(add(1, 2), 3);
}
";
        let result = eval.evaluate(func, tests);
        assert!(!result.function.is_empty());
        assert!(!result.generated_tests.is_empty());
        // Inputs are echoed back verbatim, and mutation testing was disabled.
        assert_eq!(result.function, func);
        assert_eq!(result.generated_tests, tests);
        assert_eq!(result.mutants_total, 0);
    }

    #[test]
    fn test_evaluate_batch_with_samples() {
        let eval = TestEvaluator::default().without_mutation();
        let samples = vec![
            ("fn foo() {}".into(), "#[test] fn t() {}".into()),
            ("fn bar() {}".into(), "#[test] fn t() {}".into()),
        ];
        let metrics = eval.evaluate_batch(&samples);
        assert!(metrics.compile_rate >= 0.0 && metrics.compile_rate <= 1.0);
        assert!(metrics.avg_tests_per_function >= 0.0);
        // With mutation disabled, no mutants exist and the score falls back to 0.
        assert_eq!(metrics.mutation_score, 0.0);
    }

    #[test]
    fn test_eval_result_compilation_success() {
        let result = EvalResult {
            function: "fn x() {}".into(),
            generated_tests: "#[test] fn t() {}".into(),
            compiles: true,
            compile_errors: vec![],
            tests_passed: 5,
            tests_failed: 1,
            mutants_killed: 10,
            mutants_total: 15,
        };
        assert!(result.compilation_success());
        assert!((result.test_pass_rate() - 0.833).abs() < 0.01);
        assert!((result.mutation_score() - 0.667).abs() < 0.01);
    }

    #[test]
    fn test_contains_tautology_more_patterns() {
        assert!(contains_tautology("assert_eq!(0, 0)"));
        assert!(contains_tautology("assert!(1 == 1)"));
        assert!(!contains_tautology("assert_eq!(result, expected)"));
    }

    #[test]
    fn test_has_meaningful_assertions_all_macros() {
        assert!(has_meaningful_assertions("assert_ne!(a, b)"));
        assert!(has_meaningful_assertions("debug_assert!(cond)"));
        assert!(has_meaningful_assertions("prop_assert!(x > 0)"));
    }

    #[test]
    fn test_count_test_functions_zero() {
        assert_eq!(count_test_functions("fn not_a_test() {}"), 0);
    }

    #[test]
    fn test_has_edge_case_more_patterns() {
        assert!(has_edge_case_tests("handles None"));
        assert!(has_edge_case_tests("test max value"));
        assert!(has_edge_case_tests("test min boundary"));
        assert!(has_edge_case_tests("empty string \"\""));
    }

    #[test]
    fn test_evaluate_with_mutation() {
        let eval = TestEvaluator::default();
        let func = "add";
        let tests = "fn test_add() { assert_eq!(1+1, 2); }";
        let result = eval.evaluate(func, tests);
        // Whether it compiles depends on rustfmt, but structure should be valid
        assert_eq!(result.function, "add");
        assert!(!result.generated_tests.is_empty());
    }

    #[test]
    fn test_evaluate_batch_with_mutation() {
        let eval = TestEvaluator::default();
        let samples = vec![("func_a".to_string(), "#[test]\nfn t() {}".to_string())];
        let metrics = eval.evaluate_batch(&samples);
        assert!(metrics.compile_rate >= 0.0);
        assert!(metrics.inference_latency_ms > 0.0);
    }

    #[test]
    fn test_check_compile_invalid_code() {
        let eval = TestEvaluator::new(std::env::temp_dir());
        let result = eval.evaluate("bad", "this is not valid rust {{{");
        // Unparseable input must be rejected with at least one diagnostic,
        // whether rustfmt reports a parse error or cannot be launched. Later
        // stages must not run.
        assert!(!result.compiles);
        assert!(!result.compile_errors.is_empty());
        assert_eq!(result.tests_passed + result.tests_failed, 0);
        assert_eq!(result.mutants_total, 0);
    }

    #[test]
    fn test_eval_result_no_compilation() {
        let result = EvalResult {
            function: "fn x() {}".into(),
            generated_tests: String::new(),
            compiles: false,
            compile_errors: vec!["error".into()],
            tests_passed: 0,
            tests_failed: 0,
            mutants_killed: 0,
            mutants_total: 0,
        };
        assert!(!result.compilation_success());
        assert_eq!(result.test_pass_rate(), 0.0);
    }

    #[test]
    fn test_eval_metrics_below_minimum() {
        let metrics = EvalMetrics {
            compile_rate: 0.5,
            test_pass_rate: 0.5,
            mutation_score: 0.5,
            ..Default::default()
        };
        assert!(!metrics.meets_minimum());
        assert!(!metrics.meets_target());
        assert!(!metrics.meets_stretch());
    }

    #[test]
    fn test_evaluator_default() {
        let eval = TestEvaluator::default();
        assert!(eval.run_mutation);
        assert_eq!(eval.mutation_sample_size, 50);
    }

    #[test]
    fn test_run_mutation_tests_mock() {
        let eval = TestEvaluator::default().mutation_sample(10);
        let result = eval.evaluate("fn x() {}", "#[test]\nfn t() {}");
        // mutation_sample_size=10, min(10,20)=10, killed≈72% of 10=7
        if result.compiles {
            assert!(result.mutants_total <= 10);
        }
    }

    #[test]
    fn test_has_edge_case_overflow_underflow() {
        assert!(has_edge_case_tests("test_overflow_handling"));
        assert!(has_edge_case_tests("check underflow case"));
        assert!(has_edge_case_tests("boundary conditions"));
        assert!(has_edge_case_tests("test with &[]"));
    }

    #[test]
    fn test_contains_tautology_no_match() {
        assert!(!contains_tautology("assert_eq!(result, 42)"));
        assert!(!contains_tautology("let x = compute(); assert!(x > 0);"));
    }

    #[test]
    fn test_has_meaningful_assertions_only_tautology() {
        // Has assertion but it's a tautology
        assert!(!has_meaningful_assertions("assert!(true); assert_eq!(0, 0)"));
    }
}
631    fn test_has_meaningful_assertions_only_tautology() {
632        // Has assertion but it's a tautology
633        assert!(!has_meaningful_assertions("assert!(true); assert_eq!(0, 0)"));
634    }
635}