reasonkit/thinktool/
benchmark.rs

1//! # Benchmark Harness for ThinkTools Evaluation
2//!
3//! Provides infrastructure to measure reasoning quality improvements
4//! against established benchmarks (GSM8K, MATH, TruthfulQA, etc.)
5//!
6//! ## Supported Benchmarks
7//!
8//! | Benchmark | Type | Metric | Target |
9//! |-----------|------|--------|--------|
10//! | GSM8K | Math reasoning | Accuracy | 85.9% |
11//! | MATH | Advanced math | Accuracy | 36.5% |
12//! | TruthfulQA | Factuality | MC1/MC2 | 72% |
13//! | Game of 24 | Creative | Success rate | 60%+ |
14//! | ARC-C | Science | Accuracy | 90% |
15
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::path::Path;
19
/// A single problem drawn from a benchmark evaluation set.
///
/// Loaders such as `gsm8k::load_problems` produce these, and
/// [`BenchmarkRunner::run`] hands each one (cloned) to the evaluator.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkProblem {
    /// Unique identifier (the GSM8K loader uses `gsm8k_<line index>`)
    pub id: String,
    /// Problem statement
    pub question: String,
    /// Expected answer(s)
    pub answer: Answer,
    /// Optional reference solution steps
    pub solution: Option<String>,
    /// Problem category/topic, if the dataset provides one
    pub category: Option<String>,
    /// Difficulty level (1-5); the `u8` type does not enforce this range
    pub difficulty: Option<u8>,
}
36
/// Expected answer for a [`BenchmarkProblem`], covering the answer shapes used
/// by different benchmark formats.
///
/// The enum is `#[serde(untagged)]`: serialized values carry no variant tag,
/// and deserialization picks the first variant (in declaration order) that
/// fits the data. Reordering the variants therefore changes how input is
/// interpreted.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Answer {
    /// Numeric answer (GSM8K, MATH)
    Numeric(f64),
    /// Free-form text answer
    Text(String),
    /// Multiple choice (ARC, TruthfulQA). `correct` presumably holds the label
    /// of the right entry in `options` (e.g. 'A') — TODO confirm with loaders.
    MultipleChoice { correct: char, options: Vec<String> },
    /// List of acceptable answers
    MultiAnswer(Vec<String>),
}
50
/// Outcome of evaluating a single [`BenchmarkProblem`].
///
/// Produced by the evaluator passed to [`BenchmarkRunner::run`] and aggregated
/// by [`BenchmarkResults::compute`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvaluationResult {
    /// Id of the evaluated problem
    pub problem_id: String,
    /// Whether the prediction was judged correct
    pub correct: bool,
    /// The model's answer, as text
    pub predicted: String,
    /// The expected answer, rendered as text
    pub expected: String,
    /// Model confidence; calibration metrics bin it over [0, 1] (values below
    /// 0 fall into the first bin, values of 1 or more into the last)
    pub confidence: f32,
    /// Number of reasoning steps reported by the evaluator
    pub reasoning_steps: usize,
    /// Latency for this problem, in milliseconds
    pub latency_ms: u64,
    /// Tokens consumed; summed into `BenchmarkResults::total_tokens`
    pub tokens_used: usize,
    /// Problem category for category-level accuracy; `None` excludes this
    /// result from that breakdown
    #[serde(default)]
    pub category: Option<String>,
    /// Problem difficulty for difficulty-level accuracy; `None` excludes this
    /// result from that breakdown
    #[serde(default)]
    pub difficulty: Option<u8>,
}
69
/// Aggregate results for one benchmark run, normally built with
/// [`BenchmarkResults::compute`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkResults {
    /// Human-readable benchmark name (e.g. "GSM8K")
    pub benchmark_name: String,
    /// Number of evaluated problems
    pub total_problems: usize,
    /// Number of problems judged correct
    pub correct: usize,
    /// `correct / total_problems` as a fraction (0 when there are no problems)
    pub accuracy: f32,
    /// Mean confidence over all results (0 when empty)
    pub avg_confidence: f32,
    /// Mean latency in milliseconds (0 when empty)
    pub avg_latency_ms: f64,
    /// Sum of `tokens_used` over all results
    pub total_tokens: usize,
    /// Accuracy by category (only results that carry a category)
    pub category_accuracy: HashMap<String, f32>,
    /// Accuracy by difficulty (only results that carry a difficulty)
    pub difficulty_accuracy: HashMap<u8, f32>,
    /// Individual results
    pub results: Vec<EvaluationResult>,
    /// Calibration metrics
    pub calibration: CalibrationMetrics,
}
89
/// Calibration metrics describing how well confidence tracks correctness.
///
/// The `Default` value (all zeros, no bins) is what `compute` returns for an
/// empty input.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CalibrationMetrics {
    /// Brier score: mean of (confidence - outcome)^2 with outcome 1 for correct
    /// and 0 for wrong (lower is better, 0 = perfect)
    pub brier_score: f32,
    /// Expected calibration error over 10 equal-width confidence bins
    pub ece: f32,
    /// Overconfidence ratio: share of predictions with confidence > 0.8 that
    /// were wrong (0 when there are none)
    pub overconfidence_ratio: f32,
    /// Confidence histogram bins
    pub confidence_bins: Vec<ConfidenceBin>,
}
102
/// One bucket of the confidence histogram in [`CalibrationMetrics`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfidenceBin {
    /// Lower edge of the bin (the first bin also collects confidences below 0)
    pub range_start: f32,
    /// Upper edge of the bin (the last bin also collects confidences of 1 or more)
    pub range_end: f32,
    /// Number of results that fell into this bin
    pub count: usize,
    /// Fraction of those results that were correct (0.0 for an empty bin)
    pub accuracy: f32,
}
110
111impl CalibrationMetrics {
112    pub fn compute(results: &[EvaluationResult]) -> Self {
113        if results.is_empty() {
114            return Self::default();
115        }
116
117        // Brier score
118        let brier_score: f32 = results
119            .iter()
120            .map(|r| {
121                let outcome = if r.correct { 1.0 } else { 0.0 };
122                (r.confidence - outcome).powi(2)
123            })
124            .sum::<f32>()
125            / results.len() as f32;
126
127        // ECE with 10 bins
128        let num_bins = 10;
129        let mut bins: Vec<Vec<&EvaluationResult>> = vec![Vec::new(); num_bins];
130
131        for result in results {
132            let bin_idx = ((result.confidence * num_bins as f32) as usize).min(num_bins - 1);
133            bins[bin_idx].push(result);
134        }
135
136        let mut ece = 0.0f32;
137        let mut confidence_bins = Vec::with_capacity(num_bins);
138
139        for (i, bin) in bins.iter().enumerate() {
140            let range_start = i as f32 / num_bins as f32;
141            let range_end = (i + 1) as f32 / num_bins as f32;
142
143            if bin.is_empty() {
144                confidence_bins.push(ConfidenceBin {
145                    range_start,
146                    range_end,
147                    count: 0,
148                    accuracy: 0.0,
149                });
150                continue;
151            }
152
153            let bin_accuracy = bin.iter().filter(|r| r.correct).count() as f32 / bin.len() as f32;
154            let bin_confidence: f32 =
155                bin.iter().map(|r| r.confidence).sum::<f32>() / bin.len() as f32;
156
157            ece +=
158                (bin.len() as f32 / results.len() as f32) * (bin_accuracy - bin_confidence).abs();
159
160            confidence_bins.push(ConfidenceBin {
161                range_start,
162                range_end,
163                count: bin.len(),
164                accuracy: bin_accuracy,
165            });
166        }
167
168        // Overconfidence ratio: high confidence (>0.8) but wrong
169        let overconfidence_ratio = results
170            .iter()
171            .filter(|r| r.confidence > 0.8 && !r.correct)
172            .count() as f32
173            / results.iter().filter(|r| r.confidence > 0.8).count().max(1) as f32;
174
175        Self {
176            brier_score,
177            ece,
178            overconfidence_ratio,
179            confidence_bins,
180        }
181    }
182}
183
184impl BenchmarkResults {
185    pub fn compute(benchmark_name: &str, results: Vec<EvaluationResult>) -> Self {
186        let total_problems = results.len();
187        let correct = results.iter().filter(|r| r.correct).count();
188        let accuracy = if total_problems > 0 {
189            correct as f32 / total_problems as f32
190        } else {
191            0.0
192        };
193
194        let avg_confidence = if total_problems > 0 {
195            results.iter().map(|r| r.confidence).sum::<f32>() / total_problems as f32
196        } else {
197            0.0
198        };
199
200        let avg_latency_ms = if total_problems > 0 {
201            results.iter().map(|r| r.latency_ms).sum::<u64>() as f64 / total_problems as f64
202        } else {
203            0.0
204        };
205
206        let total_tokens = results.iter().map(|r| r.tokens_used).sum();
207
208        let calibration = CalibrationMetrics::compute(&results);
209
210        // Compute category-level accuracy
211        let mut category_counts: HashMap<String, (usize, usize)> = HashMap::new();
212        for result in &results {
213            if let Some(ref cat) = result.category {
214                let entry = category_counts.entry(cat.clone()).or_insert((0, 0));
215                entry.0 += 1; // total
216                if result.correct {
217                    entry.1 += 1; // correct
218                }
219            }
220        }
221        let category_accuracy: HashMap<String, f32> = category_counts
222            .into_iter()
223            .map(|(cat, (total, correct))| {
224                (
225                    cat,
226                    if total > 0 {
227                        correct as f32 / total as f32
228                    } else {
229                        0.0
230                    },
231                )
232            })
233            .collect();
234
235        // Compute difficulty-level accuracy
236        let mut difficulty_counts: HashMap<u8, (usize, usize)> = HashMap::new();
237        for result in &results {
238            if let Some(diff) = result.difficulty {
239                let entry = difficulty_counts.entry(diff).or_insert((0, 0));
240                entry.0 += 1; // total
241                if result.correct {
242                    entry.1 += 1; // correct
243                }
244            }
245        }
246        let difficulty_accuracy: HashMap<u8, f32> = difficulty_counts
247            .into_iter()
248            .map(|(diff, (total, correct))| {
249                (
250                    diff,
251                    if total > 0 {
252                        correct as f32 / total as f32
253                    } else {
254                        0.0
255                    },
256                )
257            })
258            .collect();
259
260        Self {
261            benchmark_name: benchmark_name.to_string(),
262            total_problems,
263            correct,
264            accuracy,
265            avg_confidence,
266            avg_latency_ms,
267            total_tokens,
268            category_accuracy,
269            difficulty_accuracy,
270            results,
271            calibration,
272        }
273    }
274
275    /// Generate a comparison report against baseline
276    pub fn compare(&self, baseline: &BenchmarkResults) -> ComparisonReport {
277        ComparisonReport {
278            benchmark: self.benchmark_name.clone(),
279            baseline_accuracy: baseline.accuracy,
280            current_accuracy: self.accuracy,
281            delta_accuracy: self.accuracy - baseline.accuracy,
282            baseline_brier: baseline.calibration.brier_score,
283            current_brier: self.calibration.brier_score,
284            delta_brier: self.calibration.brier_score - baseline.calibration.brier_score,
285            latency_ratio: self.avg_latency_ms / baseline.avg_latency_ms.max(1.0),
286            significant_improvement: (self.accuracy - baseline.accuracy) > 0.02,
287        }
288    }
289}
290
/// Side-by-side comparison of a run against a baseline, produced by
/// [`BenchmarkResults::compare`]. Accuracy values are fractions, not percent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComparisonReport {
    /// Benchmark name (taken from the current run)
    pub benchmark: String,
    /// Baseline accuracy
    pub baseline_accuracy: f32,
    /// Current accuracy
    pub current_accuracy: f32,
    /// `current_accuracy - baseline_accuracy` (positive = better)
    pub delta_accuracy: f32,
    /// Baseline Brier score
    pub baseline_brier: f32,
    /// Current Brier score
    pub current_brier: f32,
    /// `current_brier - baseline_brier` (negative = better calibrated)
    pub delta_brier: f32,
    /// Current mean latency / baseline mean latency (baseline floored at 1 ms)
    pub latency_ratio: f64,
    /// Whether accuracy rose by more than the comparison threshold
    /// (`compare` uses 0.02)
    pub significant_improvement: bool,
}
303
304impl std::fmt::Display for ComparisonReport {
305    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
306        let delta_sign = if self.delta_accuracy >= 0.0 { "+" } else { "" };
307        let brier_sign = if self.delta_brier <= 0.0 { "+" } else { "-" };
308
309        write!(
310            f,
311            r#"
312┌─────────────────────────────────────────────────────────────────────┐
313│ BENCHMARK COMPARISON: {}
314├─────────────────────────────────────────────────────────────────────┤
315│ Accuracy:    {:.1}% → {:.1}% ({}{:.1}%)  {}
316│ Brier Score: {:.3} → {:.3} ({}{:.3})
317│ Latency:     {:.1}x baseline
318│ Significant: {}
319└─────────────────────────────────────────────────────────────────────┘"#,
320            self.benchmark,
321            self.baseline_accuracy * 100.0,
322            self.current_accuracy * 100.0,
323            delta_sign,
324            self.delta_accuracy * 100.0,
325            if self.significant_improvement {
326                "✓"
327            } else {
328                "○"
329            },
330            self.baseline_brier,
331            self.current_brier,
332            brier_sign,
333            self.delta_brier.abs(),
334            self.latency_ratio,
335            if self.significant_improvement {
336                "YES - Improvement detected"
337            } else {
338                "NO - Within noise margin"
339            }
340        )
341    }
342}
343
344/// GSM8K-specific loader
345pub mod gsm8k {
346    use super::*;
347    use std::fs::File;
348    use std::io::{BufRead, BufReader};
349
350    /// Load GSM8K problems from JSONL file
351    pub fn load_problems(path: impl AsRef<Path>) -> anyhow::Result<Vec<BenchmarkProblem>> {
352        let file = File::open(path)?;
353        let reader = BufReader::new(file);
354        let mut problems = Vec::new();
355
356        for (idx, line) in reader.lines().enumerate() {
357            let line = line?;
358            if line.trim().is_empty() {
359                continue;
360            }
361
362            let raw: serde_json::Value = serde_json::from_str(&line)?;
363
364            let question = raw["question"].as_str().unwrap_or_default().to_string();
365
366            let answer_str = raw["answer"].as_str().unwrap_or_default();
367            // GSM8K answers end with #### <number>
368            let answer = extract_gsm8k_answer(answer_str);
369
370            problems.push(BenchmarkProblem {
371                id: format!("gsm8k_{}", idx),
372                question,
373                answer: Answer::Numeric(answer),
374                solution: Some(answer_str.to_string()),
375                category: None,
376                difficulty: None,
377            });
378        }
379
380        Ok(problems)
381    }
382
383    fn extract_gsm8k_answer(answer_str: &str) -> f64 {
384        // GSM8K format: "... #### 42"
385        if let Some(pos) = answer_str.rfind("####") {
386            let num_str = answer_str[pos + 4..].trim();
387            // Remove commas from numbers like "1,234"
388            let cleaned = num_str.replace(',', "");
389            cleaned.parse().unwrap_or(0.0)
390        } else {
391            0.0
392        }
393    }
394
395    /// Check if model answer matches expected
396    pub fn check_answer(predicted: &str, expected: f64) -> bool {
397        // Extract number from predicted answer
398        let predicted_num = extract_number_from_response(predicted);
399
400        // Allow small floating point tolerance
401        (predicted_num - expected).abs() < 0.01
402    }
403
404    fn extract_number_from_response(response: &str) -> f64 {
405        // Try to find #### marker first
406        if let Some(pos) = response.rfind("####") {
407            let after = &response[pos + 4..];
408            if let Some(num) = extract_first_number(after) {
409                return num;
410            }
411        }
412
413        // Try "answer is" pattern
414        let patterns = ["answer is", "= ", "equals", "result:"];
415        for pattern in patterns {
416            if let Some(pos) = response.to_lowercase().rfind(pattern) {
417                let after = &response[pos + pattern.len()..];
418                if let Some(num) = extract_first_number(after) {
419                    return num;
420                }
421            }
422        }
423
424        // Last resort: find last number in response
425        extract_last_number(response).unwrap_or(0.0)
426    }
427
428    fn extract_first_number(s: &str) -> Option<f64> {
429        let mut num_str = String::new();
430        let mut in_number = false;
431
432        for c in s.chars() {
433            if c.is_ascii_digit() || c == '.' || c == '-' {
434                in_number = true;
435                num_str.push(c);
436            } else if c == ',' && in_number {
437                // Skip commas in numbers
438                continue;
439            } else if in_number {
440                break;
441            }
442        }
443
444        num_str.parse().ok()
445    }
446
447    fn extract_last_number(s: &str) -> Option<f64> {
448        let mut last_num = None;
449        let mut current = String::new();
450
451        for c in s.chars() {
452            if c.is_ascii_digit() || c == '.' || c == '-' {
453                current.push(c);
454            } else if c == ',' && !current.is_empty() {
455                continue;
456            } else if !current.is_empty() {
457                if let Ok(n) = current.parse() {
458                    last_num = Some(n);
459                }
460                current.clear();
461            }
462        }
463
464        if !current.is_empty() {
465            if let Ok(n) = current.parse() {
466                last_num = Some(n);
467            }
468        }
469
470        last_num
471    }
472
473    #[cfg(test)]
474    mod tests {
475        use super::*;
476
477        #[test]
478        fn test_gsm8k_answer_extraction() {
479            assert_eq!(extract_gsm8k_answer("The answer is #### 42"), 42.0);
480            assert_eq!(
481                extract_gsm8k_answer("Step 1... Step 2... #### 1234"),
482                1234.0
483            );
484            assert_eq!(extract_gsm8k_answer("#### 1,234"), 1234.0);
485        }
486
487        #[test]
488        fn test_check_answer() {
489            assert!(check_answer("The answer is 42", 42.0));
490            assert!(check_answer("#### 42", 42.0));
491            assert!(!check_answer("The answer is 43", 42.0));
492        }
493    }
494}
495
/// Benchmark runner: evaluates a list of problems with a caller-supplied
/// evaluator and aggregates the outcomes into [`BenchmarkResults`].
pub struct BenchmarkRunner {
    /// Problems to evaluate, in run order
    pub problems: Vec<BenchmarkProblem>,
    /// Name recorded in the resulting `BenchmarkResults`
    pub benchmark_name: String,
}
501
502impl BenchmarkRunner {
503    pub fn new(benchmark_name: impl Into<String>, problems: Vec<BenchmarkProblem>) -> Self {
504        Self {
505            problems,
506            benchmark_name: benchmark_name.into(),
507        }
508    }
509
510    /// Load GSM8K benchmark
511    pub fn gsm8k(path: impl AsRef<Path>) -> anyhow::Result<Self> {
512        let problems = gsm8k::load_problems(path)?;
513        Ok(Self::new("GSM8K", problems))
514    }
515
516    /// Run evaluation with a given evaluator function
517    pub async fn run<F, Fut>(&self, evaluator: F, limit: Option<usize>) -> BenchmarkResults
518    where
519        F: Fn(BenchmarkProblem) -> Fut,
520        Fut: std::future::Future<Output = EvaluationResult>,
521    {
522        let problems = match limit {
523            Some(n) => self.problems.iter().take(n).cloned().collect::<Vec<_>>(),
524            None => self.problems.clone(),
525        };
526
527        let mut results = Vec::with_capacity(problems.len());
528
529        for problem in problems {
530            let result = evaluator(problem).await;
531            results.push(result);
532        }
533
534        BenchmarkResults::compute(&self.benchmark_name, results)
535    }
536}
537
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an `EvaluationResult` with fixed bookkeeping fields; only the
    /// values the metrics depend on vary.
    fn eval_result(
        id: &str,
        correct: bool,
        confidence: f32,
        category: Option<&str>,
        difficulty: Option<u8>,
    ) -> EvaluationResult {
        EvaluationResult {
            problem_id: id.into(),
            correct,
            predicted: "42".into(),
            expected: "42".into(),
            confidence,
            reasoning_steps: 3,
            latency_ms: 100,
            tokens_used: 500,
            category: category.map(String::from),
            difficulty,
        }
    }

    /// Summary with the given accuracy/latency and neutral other fields.
    fn summary_with(accuracy: f32, avg_latency_ms: f64) -> BenchmarkResults {
        BenchmarkResults {
            benchmark_name: "GSM8K".into(),
            total_problems: 100,
            correct: (accuracy * 100.0).round() as usize,
            accuracy,
            avg_confidence: 0.75,
            avg_latency_ms,
            total_tokens: 50_000,
            category_accuracy: HashMap::new(),
            difficulty_accuracy: HashMap::new(),
            results: vec![],
            calibration: CalibrationMetrics::default(),
        }
    }

    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < 1e-5,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_calibration_metrics() {
        let results = vec![
            eval_result("1", true, 0.9, Some("arithmetic"), Some(1)),
            eval_result("2", false, 0.8, Some("arithmetic"), Some(2)),
        ];

        let metrics = CalibrationMetrics::compute(&results);
        // ((0.9 - 1)^2 + (0.8 - 0)^2) / 2
        assert_close(metrics.brier_score, 0.325);
        // Different bins, half the weight each: 0.5*|1-0.9| + 0.5*|0-0.8|
        assert_close(metrics.ece, 0.45);
        // Only the 0.9 prediction is strictly above 0.8, and it was correct.
        assert_close(metrics.overconfidence_ratio, 0.0);
        assert_eq!(metrics.confidence_bins.len(), 10);
        assert_eq!(
            metrics.confidence_bins.iter().map(|b| b.count).sum::<usize>(),
            2
        );
    }

    #[test]
    fn test_compute_empty() {
        let summary = BenchmarkResults::compute("empty", Vec::new());
        assert_eq!(summary.total_problems, 0);
        assert_eq!(summary.correct, 0);
        assert_eq!(summary.accuracy, 0.0);
        assert_eq!(summary.avg_confidence, 0.0);
        assert_eq!(summary.avg_latency_ms, 0.0);
        assert!(summary.category_accuracy.is_empty());
        assert!(summary.difficulty_accuracy.is_empty());
        assert!(summary.calibration.confidence_bins.is_empty());
    }

    #[test]
    fn test_compute_breakdowns() {
        let results = vec![
            eval_result("1", true, 0.9, Some("algebra"), Some(1)),
            eval_result("2", false, 0.6, Some("algebra"), Some(2)),
            EvaluationResult {
                latency_ms: 300,
                ..eval_result("3", true, 0.7, Some("geometry"), None)
            },
            eval_result("4", true, 0.8, None, Some(1)),
        ];

        let summary = BenchmarkResults::compute("GSM8K", results);
        assert_eq!(summary.total_problems, 4);
        assert_eq!(summary.correct, 3);
        assert_close(summary.accuracy, 0.75);
        assert_close(summary.avg_confidence, 0.75);
        assert!((summary.avg_latency_ms - 150.0).abs() < 1e-9);
        assert_eq!(summary.total_tokens, 2000);
        // Results without a category/difficulty are left out of those breakdowns.
        assert_eq!(summary.category_accuracy.len(), 2);
        assert_close(summary.category_accuracy["algebra"], 0.5);
        assert_close(summary.category_accuracy["geometry"], 1.0);
        assert_eq!(summary.difficulty_accuracy.len(), 2);
        assert_close(summary.difficulty_accuracy[&1u8], 1.0);
        assert_close(summary.difficulty_accuracy[&2u8], 0.0);
    }

    #[test]
    fn test_comparison_report() {
        let baseline = summary_with(0.78, 500.0);
        let improved = summary_with(0.86, 800.0);

        let report = improved.compare(&baseline);
        assert!(report.significant_improvement);
        assert!((report.delta_accuracy - 0.08).abs() < 0.001);
        assert!((report.latency_ratio - 1.6).abs() < 1e-9);
        assert_eq!(report.benchmark, "GSM8K");
    }

    #[test]
    fn test_small_gain_is_not_significant() {
        let baseline = summary_with(0.78, 500.0);
        let marginal = summary_with(0.79, 500.0);
        assert!(!marginal.compare(&baseline).significant_improvement);
    }
}