Skip to main content

entrenar/eval/generative/
code.rs

//! Code generation evaluation metrics
//!
//! Provides pass@k — the unbiased estimator for functional correctness
//! of code generation models (Chen et al., 2021, "Evaluating Large Language
//! Models Trained on Code").

/// Compute pass@k: the unbiased estimator of functional correctness.
///
/// Given `n` generated samples of which `c` are correct (pass the tests),
/// pass@k is the probability that at least one of `k` samples drawn without
/// replacement from the `n` is correct:
///
/// `pass@k = 1 - C(n-c, k) / C(n, k)`
///
/// The binomial ratio is evaluated as a product of per-draw factors in log
/// space, so it never overflows for large `n`. The final `1 - exp(x)` uses
/// `exp_m1`, which avoids catastrophic cancellation when pass@k is tiny
/// (e.g. `c = 1` with a very large `n`).
///
/// Returns a value in `[0, 1]`. `1.0` means every possible draw of `k`
/// samples contains at least one correct sample.
///
/// # Arguments
/// * `n` - Total number of generated code samples
/// * `c` - Number of correct (passing) samples
/// * `k` - Number of samples to consider (typically 1, 10, or 100)
///
/// # Edge Cases
/// * If `c == 0` or `k == 0`, returns 0.0 (checked first)
/// * If `c >= n`, returns 1.0
/// * If `k > n - c` (fewer than `k` incorrect samples, which includes
///   `k > n`), returns 1.0
pub fn pass_at_k(n: usize, c: usize, k: usize) -> f64 {
    // No correct samples, or no draws at all: nothing can pass.
    if c == 0 || k == 0 {
        return 0.0;
    }
    // With fewer than `k` incorrect samples, every draw of `k` includes a
    // correct one. `c >= n` is tested first so `n - c` cannot underflow.
    // This also covers `k > n`.
    if c >= n || k > n - c {
        return 1.0;
    }

    // C(n-c, k) / C(n, k) = prod_{i=0..k} (n-c-i) / (n-i)
    //                     = prod_{i=0..k} (1 - c / (n-i))
    // For i < k <= n - c we have n - i > c, so each factor lies in (0, 1)
    // and `ln_1p` receives an argument in (-1, 0).
    let c_f = c as f64;
    let log_ratio: f64 = (0..k).map(|i| (-c_f / (n - i) as f64).ln_1p()).sum();

    // 1 - exp(log_ratio), computed without cancellation when log_ratio ≈ 0.
    -log_ratio.exp_m1()
}