// sgr_agent/benchmark.rs

//! Benchmark suite: 5 fixed tasks to measure agent quality.
//!
//! Each task is deterministic — same input, measurable output.
//! Score = average across all tasks (0.0–1.0).
//!
//! Run after every self-evolution patch to detect regressions.
//!
//! Inspired by Karpathy's autoresearch: fixed budget, single metric.
use serde::{Deserialize, Serialize};
use std::path::Path;

13/// A single benchmark task.
14#[derive(Debug, Clone)]
15pub struct BenchmarkTask {
16    /// Short name
17    pub name: &'static str,
18    /// Prompt sent to agent
19    pub prompt: &'static str,
20    /// Max steps budget (Karpathy: fixed budget per experiment)
21    pub max_steps: usize,
22    /// How to verify success — function checks the output
23    pub verify: fn(&BenchmarkResult) -> f64,
24}
25
26/// Result of running one benchmark task.
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct BenchmarkResult {
29    pub name: String,
30    pub steps: usize,
31    pub completed: bool,
32    pub tool_errors: usize,
33    pub loop_warnings: usize,
34    /// Agent's final output (finish summary)
35    pub output: String,
36    /// Score for this task (0.0–1.0)
37    pub score: f64,
38}
39
40/// Full benchmark report.
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct BenchmarkReport {
43    pub timestamp: u64,
44    pub commit: String,
45    pub results: Vec<BenchmarkResult>,
46    /// Average score across all tasks
47    pub avg_score: f64,
48    /// Standard deviation (Knuth: measure uncertainty)
49    pub std_dev: f64,
50}
51
// ---------------------------------------------------------------------------
// The 5 benchmark tasks
// ---------------------------------------------------------------------------

56/// Task 1: Simple Q&A — can agent use finish tool correctly?
57const TASK_QA: BenchmarkTask = BenchmarkTask {
58    name: "qa_simple",
59    prompt: "What is the capital of France? Answer with the finish tool.",
60    max_steps: 3,
61    verify: verify_qa,
62};
63
64fn verify_qa(r: &BenchmarkResult) -> f64 {
65    if !r.completed {
66        return 0.0;
67    }
68    let has_paris = r.output.to_lowercase().contains("paris");
69    let efficiency = if r.steps <= 1 {
70        1.0
71    } else {
72        0.8 / r.steps as f64
73    };
74    if has_paris {
75        (0.7 + efficiency * 0.3).min(1.0)
76    } else {
77        0.1
78    }
79}
80
81/// Task 2: File read — can agent read a file and extract info?
82const TASK_READ: BenchmarkTask = BenchmarkTask {
83    name: "read_file",
84    prompt: "Read the file Cargo.toml in the current directory and tell me the package name. Use finish tool with the name.",
85    max_steps: 5,
86    verify: verify_read,
87};
88
89fn verify_read(r: &BenchmarkResult) -> f64 {
90    if !r.completed {
91        return 0.0;
92    }
93    // Should find "rust-code" or whatever the package name is
94    let output_lower = r.output.to_lowercase();
95    let found_name = output_lower.contains("rust-code")
96        || output_lower.contains("sgr-agent")
97        || output_lower.contains("package");
98    let efficiency = (3.0 / r.steps.max(1) as f64).min(1.0);
99    if found_name {
100        0.6 + efficiency * 0.4
101    } else if r.completed {
102        0.3
103    } else {
104        0.0
105    }
106}
107
108/// Task 3: Code search — can agent find something in the codebase?
109const TASK_SEARCH: BenchmarkTask = BenchmarkTask {
110    name: "code_search",
111    prompt: "Search for the function 'parse_spec' in the codebase and tell me which file it's in. Use finish tool.",
112    max_steps: 8,
113    verify: verify_search,
114};
115
116fn verify_search(r: &BenchmarkResult) -> f64 {
117    if !r.completed {
118        return 0.0;
119    }
120    let output_lower = r.output.to_lowercase();
121    let found = output_lower.contains("openapi")
122        || output_lower.contains("spec.rs")
123        || output_lower.contains("parse_spec");
124    let no_errors = r.tool_errors == 0;
125    let efficiency = (5.0 / r.steps.max(1) as f64).min(1.0);
126    let mut score = 0.0;
127    if found {
128        score += 0.5;
129    }
130    if no_errors {
131        score += 0.2;
132    }
133    score += efficiency * 0.3;
134    score.min(1.0)
135}
136
137/// Task 4: Multi-step — can agent do read + analyze + answer?
138const TASK_MULTI: BenchmarkTask = BenchmarkTask {
139    name: "multi_step",
140    prompt: "Read crates/sgr-agent/src/lib.rs, count how many pub mod declarations it has, and answer with the count using finish tool.",
141    max_steps: 10,
142    verify: verify_multi,
143};
144
145fn verify_multi(r: &BenchmarkResult) -> f64 {
146    if !r.completed {
147        return 0.0;
148    }
149    // Should have a number in the output
150    let has_number = r.output.chars().any(|c| c.is_ascii_digit());
151    let no_loops = r.loop_warnings == 0;
152    let efficiency = (4.0 / r.steps.max(1) as f64).min(1.0);
153    let mut score = 0.0;
154    if has_number {
155        score += 0.5;
156    }
157    if no_loops {
158        score += 0.2;
159    }
160    score += efficiency * 0.3;
161    score.min(1.0)
162}
163
164/// Task 5: Tool chaining — can agent use git_status + analysis?
165const TASK_GIT: BenchmarkTask = BenchmarkTask {
166    name: "git_status",
167    prompt: "Check git status of this repo. Tell me which branch we're on and if there are uncommitted changes. Use finish tool.",
168    max_steps: 5,
169    verify: verify_git,
170};
171
172fn verify_git(r: &BenchmarkResult) -> f64 {
173    if !r.completed {
174        return 0.0;
175    }
176    let output_lower = r.output.to_lowercase();
177    let has_branch = output_lower.contains("master")
178        || output_lower.contains("main")
179        || output_lower.contains("branch");
180    let has_status = output_lower.contains("clean")
181        || output_lower.contains("uncommitted")
182        || output_lower.contains("modified")
183        || output_lower.contains("changes");
184    let efficiency = (3.0 / r.steps.max(1) as f64).min(1.0);
185    let mut score = 0.0;
186    if has_branch {
187        score += 0.35;
188    }
189    if has_status {
190        score += 0.35;
191    }
192    score += efficiency * 0.3;
193    score.min(1.0)
194}
195
// ---------------------------------------------------------------------------
// Public API
// ---------------------------------------------------------------------------

200/// All 5 benchmark tasks.
201pub fn all_tasks() -> Vec<BenchmarkTask> {
202    vec![TASK_QA, TASK_READ, TASK_SEARCH, TASK_MULTI, TASK_GIT]
203}
204
205/// Compute aggregate report from individual results.
206pub fn compute_report(results: Vec<BenchmarkResult>, commit: &str) -> BenchmarkReport {
207    let n = results.len() as f64;
208    let avg = if n > 0.0 {
209        results.iter().map(|r| r.score).sum::<f64>() / n
210    } else {
211        0.0
212    };
213    let variance = if n > 1.0 {
214        results.iter().map(|r| (r.score - avg).powi(2)).sum::<f64>() / (n - 1.0)
215    } else {
216        0.0
217    };
218    let std_dev = variance.sqrt();
219    let ts = std::time::SystemTime::now()
220        .duration_since(std::time::UNIX_EPOCH)
221        .unwrap_or_default()
222        .as_secs();
223
224    BenchmarkReport {
225        timestamp: ts,
226        commit: commit.to_string(),
227        results,
228        avg_score: avg,
229        std_dev,
230    }
231}
232
233/// Format report for display (Knuth: literate output).
234pub fn format_report(report: &BenchmarkReport) -> String {
235    let mut out = format!(
236        "## Benchmark Report\n\n\
237         Commit: {} | Score: {:.3} ± {:.3}\n\n\
238         | Task | Steps | Errors | Score | Status |\n\
239         |------|-------|--------|-------|--------|\n",
240        report.commit, report.avg_score, report.std_dev,
241    );
242    for r in &report.results {
243        let status = if r.score >= 0.8 {
244            "✓"
245        } else if r.score >= 0.5 {
246            "~"
247        } else {
248            "✗"
249        };
250        out.push_str(&format!(
251            "| {} | {} | {} | {:.2} | {} |\n",
252            r.name, r.steps, r.tool_errors, r.score, status,
253        ));
254    }
255    out.push_str(&format!(
256        "\n**Average: {:.3} ± {:.3}**\n",
257        report.avg_score, report.std_dev,
258    ));
259    out
260}
261
262/// Save benchmark report to JSONL log.
263pub fn log_benchmark(agent_home: &str, report: &BenchmarkReport) -> Result<(), String> {
264    let path = Path::new(agent_home).join("benchmark.jsonl");
265    let line = serde_json::to_string(report).map_err(|e| format!("serialize: {}", e))?;
266    use std::io::Write;
267    let mut f = std::fs::OpenOptions::new()
268        .create(true)
269        .append(true)
270        .open(&path)
271        .map_err(|e| format!("open: {}", e))?;
272    writeln!(f, "{}", line).map_err(|e| format!("write: {}", e))?;
273    Ok(())
274}
275
276/// Load benchmark history.
277pub fn load_benchmarks(agent_home: &str) -> Vec<BenchmarkReport> {
278    let path = Path::new(agent_home).join("benchmark.jsonl");
279    let content = match std::fs::read_to_string(&path) {
280        Ok(c) => c,
281        Err(_) => return Vec::new(),
282    };
283    content
284        .lines()
285        .filter(|l| !l.trim().is_empty())
286        .filter_map(|l| serde_json::from_str(l).ok())
287        .collect()
288}
289
290/// Compare two reports: did we improve? (Karpathy: keep/discard decision)
291pub fn compare(before: &BenchmarkReport, after: &BenchmarkReport) -> &'static str {
292    if after.avg_score > before.avg_score + before.std_dev * 0.5 {
293        "keep" // statistically significant improvement
294    } else if after.avg_score < before.avg_score - before.std_dev * 0.5 {
295        "discard" // regression
296    } else {
297        "neutral" // within noise
298    }
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn all_tasks_has_five() {
        assert_eq!(all_tasks().len(), 5);
    }

    #[test]
    fn verify_qa_correct() {
        let r = BenchmarkResult {
            name: "qa".into(),
            steps: 1,
            completed: true,
            tool_errors: 0,
            loop_warnings: 0,
            output: "The capital of France is Paris.".into(),
            score: 0.0,
        };
        let s = verify_qa(&r);
        assert!(
            s > 0.9,
            "correct answer in 1 step should score >0.9, got {}",
            s
        );
    }

    #[test]
    fn verify_qa_wrong() {
        let r = BenchmarkResult {
            name: "qa".into(),
            steps: 1,
            completed: true,
            tool_errors: 0,
            loop_warnings: 0,
            output: "I don't know".into(),
            score: 0.0,
        };
        assert!(verify_qa(&r) < 0.5);
    }

    #[test]
    fn verify_qa_not_completed() {
        let r = BenchmarkResult {
            name: "qa".into(),
            steps: 3,
            completed: false,
            tool_errors: 1,
            loop_warnings: 0,
            output: "".into(),
            score: 0.0,
        };
        assert_eq!(verify_qa(&r), 0.0);
    }

    #[test]
    fn compute_report_avg_and_stddev() {
        let results = vec![
            BenchmarkResult {
                name: "a".into(),
                steps: 1,
                completed: true,
                tool_errors: 0,
                loop_warnings: 0,
                output: "".into(),
                score: 0.8,
            },
            BenchmarkResult {
                name: "b".into(),
                steps: 2,
                completed: true,
                tool_errors: 0,
                loop_warnings: 0,
                output: "".into(),
                score: 0.6,
            },
        ];
        let report = compute_report(results, "abc123");
        assert!((report.avg_score - 0.7).abs() < 0.001);
        assert!(report.std_dev > 0.0);
    }

    #[test]
    fn compare_improvement() {
        let before = BenchmarkReport {
            timestamp: 0,
            commit: "a".into(),
            results: vec![],
            avg_score: 0.5,
            std_dev: 0.1,
        };
        let after = BenchmarkReport {
            timestamp: 1,
            commit: "b".into(),
            results: vec![],
            avg_score: 0.7,
            std_dev: 0.1,
        };
        assert_eq!(compare(&before, &after), "keep");
    }

    #[test]
    fn compare_regression() {
        let before = BenchmarkReport {
            timestamp: 0,
            commit: "a".into(),
            results: vec![],
            avg_score: 0.8,
            std_dev: 0.05,
        };
        let after = BenchmarkReport {
            timestamp: 1,
            commit: "b".into(),
            results: vec![],
            avg_score: 0.6,
            std_dev: 0.05,
        };
        assert_eq!(compare(&before, &after), "discard");
    }

    #[test]
    fn compare_neutral() {
        let before = BenchmarkReport {
            timestamp: 0,
            commit: "a".into(),
            results: vec![],
            avg_score: 0.7,
            std_dev: 0.15,
        };
        let after = BenchmarkReport {
            timestamp: 1,
            commit: "b".into(),
            results: vec![],
            avg_score: 0.72,
            std_dev: 0.15,
        };
        assert_eq!(compare(&before, &after), "neutral");
    }

    #[test]
    fn format_report_markdown() {
        let report = BenchmarkReport {
            timestamp: 0,
            commit: "abc123".into(),
            results: vec![BenchmarkResult {
                name: "test".into(),
                steps: 2,
                completed: true,
                tool_errors: 0,
                loop_warnings: 0,
                output: "done".into(),
                score: 0.9,
            }],
            avg_score: 0.9,
            std_dev: 0.0,
        };
        let md = format_report(&report);
        assert!(md.contains("abc123"));
        assert!(md.contains("0.900"));
        assert!(md.contains("test"));
    }

    #[test]
    fn log_and_load_benchmarks() {
        let dir = tempfile::tempdir().unwrap();
        let home = dir.path().to_str().unwrap();
        let report = BenchmarkReport {
            timestamp: 12345,
            commit: "test".into(),
            results: vec![],
            avg_score: 0.75,
            std_dev: 0.1,
        };
        log_benchmark(home, &report).unwrap();
        let history = load_benchmarks(home);
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].avg_score, 0.75);
    }
}