use crate::error::{CliError, Result};
use crate::output;
use colored::Colorize;
use std::path::Path;
use std::time::Instant;
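
/// A single benchmark problem parsed from one line of the JSONL input.
/// `prompt` and `test` are required; `task_id` and `canonical_solution`
/// default to `None` when absent.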
#[derive(Debug, serde::Deserialize)]
pub(super) struct CodeBenchProblem {
pub(super) prompt: String,
pub(super) test: String,
#[serde(default)]
pub(super) task_id: Option<String>,
#[serde(default)]
pub(super) canonical_solution: Option<String>,
}
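
/// The outcome of evaluating one problem: whether it passed and, on
/// failure, a human-readable reason.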
#[derive(Debug)]
pub(super) struct CodeBenchResult {
pub(super) passed: bool,
pub(super) error: Option<String>,
}
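
/// Runs code evaluation end to end: validates the paths, loads the JSONL
/// benchmark at `data_path`, evaluates each problem against the model at
/// `model_path`, and prints a summary via `print_code_eval_results`.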
pub(crate) fn run_code_eval(
model_path: &Path,
data_path: Option<&Path>,
max_tokens: usize,
threshold: f32,
json_output: bool,
) -> Result<()> {
let data_path = data_path.ok_or_else(|| {
CliError::ValidationFailed(
"--data <benchmark.jsonl> is required for code evaluation.\n\
Format: one JSON object per line with 'prompt' and 'test' fields.\n\
Example: {\"prompt\": \"def add(a, b):\\n\", \"test\": \"assert add(1, 2) == 3\"}"
.to_string(),
)
})?;
if !data_path.exists() {
return Err(CliError::FileNotFound(data_path.to_path_buf()));
}
if !model_path.exists() {
return Err(CliError::FileNotFound(model_path.to_path_buf()));
}
let content = std::fs::read_to_string(data_path)
.map_err(|e| CliError::ValidationFailed(format!("Cannot read benchmark data: {e}")))?;
    let problems: Vec<CodeBenchProblem> = content
        .lines()
        .enumerate()
        // Enumerate before filtering so reported line numbers match the file,
        // even when blank lines are skipped.
        .filter(|(_, l)| !l.trim().is_empty())
        .map(|(i, line)| {
            serde_json::from_str(line).map_err(|e| {
                CliError::ValidationFailed(format!("Invalid JSON on line {}: {e}", i + 1))
            })
        })
        .collect::<Result<Vec<_>>>()?;
if problems.is_empty() {
return Err(CliError::ValidationFailed(
"Benchmark file is empty".to_string(),
));
}
if !json_output {
output::section("APR Code Evaluation");
println!();
output::kv("Model", model_path.display());
output::kv("Benchmark", data_path.display());
output::kv("Problems", problems.len());
output::kv("Max tokens", max_tokens);
output::kv("Pass threshold", format!("{:.1}%", threshold));
println!();
}
let start = Instant::now();
let mut results = Vec::with_capacity(problems.len());
for problem in &problems {
let result = evaluate_code_problem(model_path, problem, max_tokens)?;
results.push(result);
}
let elapsed = start.elapsed().as_secs_f32();
print_code_eval_results(
model_path,
data_path,
&problems,
&results,
elapsed,
threshold,
json_output,
)?;
Ok(())
}
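
/// Prints per-problem outcomes and the aggregate pass@1 rate, either as
/// pretty-printed JSON or as a colorized listing with a final PASS/FAIL
/// verdict against `threshold` (a percentage).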
#[allow(clippy::disallowed_methods)]
pub(super) fn print_code_eval_results(
model_path: &Path,
data_path: &Path,
problems: &[CodeBenchProblem],
results: &[CodeBenchResult],
elapsed: f32,
threshold: f32,
json_output: bool,
) -> Result<()> {
let total = results.len();
let passed = results.iter().filter(|r| r.passed).count();
let pass_rate = if total > 0 {
passed as f32 / total as f32 * 100.0
} else {
0.0
};
if json_output {
let output = serde_json::json!({
"model": model_path.display().to_string(),
"benchmark": data_path.display().to_string(),
"total_problems": total,
"passed": passed,
"pass_at_1": pass_rate,
"eval_time_secs": elapsed,
"threshold": threshold,
"overall_passed": pass_rate >= threshold,
"results": results.iter().zip(problems.iter()).enumerate().map(|(i, (r, p))| {
serde_json::json!({
"problem": i,
"task_id": p.task_id,
"passed": r.passed,
"error": r.error,
})
}).collect::<Vec<_>>(),
});
println!(
"{}",
serde_json::to_string_pretty(&output).unwrap_or_default()
);
} else {
for (i, (result, problem)) in results.iter().zip(problems.iter()).enumerate() {
let status = if result.passed {
"PASS".green().to_string()
} else {
"FAIL".red().to_string()
};
let default_task = format!("problem_{i}");
let task = problem.task_id.as_deref().unwrap_or(&default_task);
let error_suffix = result
.error
.as_ref()
.map(|e| format!(" ({e})"))
.unwrap_or_default();
println!(" [{status}] {task}{error_suffix}");
}
println!();
output::kv("Total", total);
output::kv("Passed", passed);
output::kv("Pass@1", format!("{pass_rate:.1}%"));
output::kv("Time", format!("{elapsed:.2}s"));
println!();
if pass_rate >= threshold {
println!(
"{}",
format!("PASS: {pass_rate:.1}% >= {threshold:.1}%").green()
);
} else {
println!(
"{}",
format!("FAIL: {pass_rate:.1}% < {threshold:.1}%").red()
);
}
}
Ok(())
}
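
/// Evaluates a single problem. Without inference support this is
/// structural validation only: empty prompts or tests fail outright, and a
/// canonical solution, when present, is checked by a cheap heuristic
/// rather than by running the tests.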
pub(super) fn evaluate_code_problem(
_model_path: &Path,
problem: &CodeBenchProblem,
_max_tokens: usize,
) -> Result<CodeBenchResult> {
if problem.prompt.trim().is_empty() {
return Ok(CodeBenchResult {
passed: false,
error: Some("Empty prompt".to_string()),
});
}
if problem.test.trim().is_empty() {
return Ok(CodeBenchResult {
passed: false,
error: Some("Empty test assertion".to_string()),
});
}
    if let Some(ref solution) = problem.canonical_solution {
        // Structural heuristic only: accept a non-empty solution containing
        // at least one statement-like token; the tests are not executed.
        let has_content = !solution.trim().is_empty();
        let has_statement =
            solution.contains("return") || solution.contains("print") || solution.contains('=');
        if has_content && has_statement {
return Ok(CodeBenchResult {
passed: true,
error: None,
});
}
return Ok(CodeBenchResult {
passed: false,
error: Some("Canonical solution validation failed".to_string()),
});
}
Ok(CodeBenchResult {
passed: false,
error: Some("Inference required (enable with --features inference)".to_string()),
})
}
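
/// Computes pass@k for each `k` in `k_values`. Each element of
/// `per_problem_correct` is a `(task_id, entry_point, correct_count)`
/// tuple. With one sample per problem the outcomes are pooled into a
/// single estimate; with several samples the per-problem estimates are
/// averaged.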
pub(super) fn compute_multisample_pass_at_k(
per_problem_correct: &[(String, String, usize)],
num_samples: usize,
k_values: &[usize],
) -> Vec<(usize, f64)> {
    let total = per_problem_correct.len();
    if total == 0 {
        // Guard the division in the averaged branch below: averaging over
        // zero problems would yield NaN.
        return k_values.iter().map(|&k| (k, 0.0)).collect();
    }
k_values
.iter()
.map(|&k| {
            let rate = if num_samples == 1 {
                // One completion per problem: pool the pass/fail outcomes
                // across the whole set and estimate once.
                let passed = per_problem_correct.iter().filter(|p| p.2 > 0).count();
                compute_pass_at_k(total, passed, k)
            } else {
                // Several completions per problem: estimate pass@k for each
                // problem from its correct count, then average.
                let sum: f64 = per_problem_correct
                    .iter()
                    .map(|(_tid, _ep, c)| compute_pass_at_k(num_samples, *c, k))
                    .sum();
                sum / total as f64
            };
(k, rate)
})
.collect()
}
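
/// Builds the JSON report for a pass@k run: aggregate counts, the pass@k
/// table, per-problem results, and an optional caller-supplied extra
/// key/value pair.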
pub(super) fn build_passk_json(
benchmark: &str,
model_path: &Path,
per_problem_correct: &[(String, String, usize)],
num_samples: usize,
temperature: f32,
k_values: &[usize],
elapsed: f32,
mode: &str,
extra: Option<(&str, &str)>,
) -> serde_json::Value {
let total = per_problem_correct.len();
let passed = per_problem_correct.iter().filter(|p| p.2 > 0).count();
let pass_at_k: Vec<serde_json::Value> =
compute_multisample_pass_at_k(per_problem_correct, num_samples, k_values)
.iter()
.map(|(k, rate)| serde_json::json!({"k": k, "rate": rate}))
.collect();
let per_problem: Vec<serde_json::Value> = per_problem_correct
.iter()
.map(|(tid, ep, c)| {
let mut v = serde_json::json!({
"task_id": tid,
"correct": c,
"samples": num_samples,
"passed": *c > 0,
});
if !ep.is_empty() {
v["entry_point"] = serde_json::json!(ep);
}
v
})
.collect();
let mut out = serde_json::json!({
"benchmark": benchmark,
"model": model_path.display().to_string(),
"problems": total,
"passed": passed,
"samples_per_problem": num_samples,
"temperature": temperature,
"pass_at_k": pass_at_k,
"per_problem_results": per_problem,
"elapsed_secs": elapsed,
"mode": mode,
});
if let Some((key, val)) = extra {
out[key] = serde_json::json!(val);
}
out
}
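
/// Emits evaluation results in the requested format: pretty-printed JSON
/// on stdout, or the human-readable results table followed, for
/// multi-sample runs, by the pass@k breakdown.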
pub(super) fn emit_eval_results(
benchmark: &str,
model_path: &Path,
per_problem_correct: &[(String, String, usize)],
num_samples: usize,
temperature: f32,
k_values: &[usize],
elapsed: f32,
mode: &str,
json_output: bool,
extra: Option<(&str, &str)>,
) {
let total = per_problem_correct.len();
let passed = per_problem_correct.iter().filter(|p| p.2 > 0).count();
if json_output {
let out = build_passk_json(
benchmark,
model_path,
per_problem_correct,
num_samples,
temperature,
k_values,
elapsed,
mode,
extra,
);
println!("{}", serde_json::to_string_pretty(&out).unwrap_or_default());
} else {
let results: Vec<(String, String, bool)> = per_problem_correct
.iter()
.map(|(tid, ep, c)| (tid.clone(), ep.clone(), *c > 0))
.collect();
super::inference::print_humaneval_results(&results, total, passed, k_values, elapsed, mode);
if num_samples > 1 {
print_multisample_table(per_problem_correct, num_samples, temperature, k_values);
}
}
}
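
/// Prints the multi-sample pass@k table for a run with `num_samples`
/// completions per problem at sampling temperature `temperature`.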
pub(super) fn print_multisample_table(
per_problem_correct: &[(String, String, usize)],
num_samples: usize,
temperature: f32,
k_values: &[usize],
) {
let rates = compute_multisample_pass_at_k(per_problem_correct, num_samples, k_values);
println!();
println!(" Multi-sample pass@k (n={num_samples}, T={temperature:.2}):");
for (k, rate) in &rates {
println!(" pass@{k}: {:.4} ({:.1}%)", rate, rate * 100.0);
}
}
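
/// Runs `run_fn` once per sample and accumulates per-problem correct
/// counts. A failure on the very first sample is treated as inference
/// being unavailable and aborts the loop; failures on later samples are
/// skipped. Returns whether at least one sample completed.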
pub(super) fn run_multisample_loop<F, E>(
per_problem_correct: &mut [(String, String, usize)],
num_samples: usize,
json_output: bool,
mut run_fn: F,
) -> bool
where
F: FnMut() -> std::result::Result<(usize, Vec<(String, String, bool)>), E>,
{
let mut inference_ok = false;
for sample_idx in 0..num_samples {
if !json_output && num_samples > 1 {
eprint!("\r Sample {}/{}...", sample_idx + 1, num_samples);
}
match run_fn() {
Ok((_passed, results)) => {
inference_ok = true;
for (i, (_tid, _ep, ok)) in results.iter().enumerate() {
if *ok && i < per_problem_correct.len() {
per_problem_correct[i].2 += 1;
}
}
}
            // The first sample failing means inference itself is unavailable;
            // report it once and fall back instead of retrying.
            Err(_e) if sample_idx == 0 => {
                eprintln!(" Inference failed (falling back to structural validation)");
                break;
            }
            // Later failures: drop the sample and continue with the rest.
            Err(_) => {}
}
}
if !json_output && num_samples > 1 {
eprintln!();
}
inference_ok
}
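
/// Unbiased pass@k estimator from Chen et al., "Evaluating Large Language
/// Models Trained on Code" (2021): `pass@k = 1 - C(n-c, k) / C(n, k)`,
/// where `n` is the number of samples and `c` the number that passed.
/// Computed as a running product to avoid large binomial coefficients.
///
/// For example, with `n = 10` samples of which `c = 3` are correct,
/// `pass@1` reduces to `c / n = 0.3`.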
pub(super) fn compute_pass_at_k(n: usize, c: usize, k: usize) -> f64 {
if n == 0 || k == 0 {
return 0.0;
}
if c >= n {
return 1.0;
}
if k > n {
return if c > 0 { 1.0 } else { 0.0 };
}
    // Multiply (n-c-i)/(n-i) for i in 0..k; the product equals C(n-c, k) / C(n, k).
    let mut result = 1.0f64;
    for i in 0..k {
        let ni = n as f64 - i as f64;
        let nci = (n - c) as f64 - i as f64;
        // A negative term here means C(n-c, k) = 0, so pass@k is exactly 1.
        if ni <= 0.0 || nci < 0.0 {
            return 1.0;
        }
        result *= nci / ni;
    }
1.0 - result
}
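
// A minimal sanity-check sketch for the estimator above. The module and test
// names here are illustrative additions, not part of the original CLI; they
// exercise the closed-form identities the estimator must satisfy.
#[cfg(test)]
mod pass_at_k_sanity {
    use super::compute_pass_at_k;

    #[test]
    fn pass_at_1_reduces_to_c_over_n() {
        // pass@1 = 1 - (n - c)/n = c/n.
        assert!((compute_pass_at_k(10, 3, 1) - 0.3).abs() < 1e-12);
    }

    #[test]
    fn degenerate_inputs() {
        assert_eq!(compute_pass_at_k(0, 0, 1), 0.0); // no samples drawn
        assert_eq!(compute_pass_at_k(5, 5, 3), 1.0); // every sample correct
        assert_eq!(compute_pass_at_k(4, 0, 8), 0.0); // k > n with no passes
        assert_eq!(compute_pass_at_k(4, 1, 8), 1.0); // k > n with a pass
    }
}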