use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;

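/// A single problem drawn from a benchmark dataset, paired with its reference answer.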
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkProblem {
    pub id: String,
    pub question: String,
    pub answer: Answer,
    pub solution: Option<String>,
    pub category: Option<String>,
    pub difficulty: Option<u8>,
}

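/// The expected answer for a problem. `#[serde(untagged)]` lets each variant be
/// deserialized directly from its natural JSON shape (number, string, object, or array).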
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Answer {
    Numeric(f64),
    Text(String),
    MultipleChoice { correct: char, options: Vec<String> },
    MultiAnswer(Vec<String>),
}

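/// The outcome of evaluating a single problem: what was predicted, whether it was
/// correct, and the cost (latency, tokens) of producing it.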
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvaluationResult {
    pub problem_id: String,
    pub correct: bool,
    pub predicted: String,
    pub expected: String,
    pub confidence: f32,
    pub reasoning_steps: usize,
    pub latency_ms: u64,
    pub tokens_used: usize,
    #[serde(default)]
    pub category: Option<String>,
    #[serde(default)]
    pub difficulty: Option<u8>,
}

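/// Aggregated results for one benchmark run, including per-category and
/// per-difficulty accuracy breakdowns and calibration metrics.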
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkResults {
    pub benchmark_name: String,
    pub total_problems: usize,
    pub correct: usize,
    pub accuracy: f32,
    pub avg_confidence: f32,
    pub avg_latency_ms: f64,
    pub total_tokens: usize,
    pub category_accuracy: HashMap<String, f32>,
    pub difficulty_accuracy: HashMap<u8, f32>,
    pub results: Vec<EvaluationResult>,
    pub calibration: CalibrationMetrics,
}

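/// How well confidence tracks correctness: Brier score, expected calibration
/// error (ECE), the fraction of high-confidence (> 0.8) answers that were wrong,
/// and a per-bin reliability breakdown.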
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CalibrationMetrics {
    pub brier_score: f32,
    pub ece: f32,
    pub overconfidence_ratio: f32,
    pub confidence_bins: Vec<ConfidenceBin>,
}

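/// Accuracy within one confidence interval `[range_start, range_end)`.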
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfidenceBin {
    pub range_start: f32,
    pub range_end: f32,
    pub count: usize,
    pub accuracy: f32,
}

impl CalibrationMetrics {
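    /// Computes calibration metrics over a set of results.
    ///
    /// The Brier score is the mean squared error between confidence and the 0/1
    /// outcome; ECE is the bin-count-weighted mean absolute gap between each
    /// bin's average confidence and its accuracy, using 10 equal-width bins.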
    pub fn compute(results: &[EvaluationResult]) -> Self {
        if results.is_empty() {
            return Self::default();
        }

        let brier_score: f32 = results
            .iter()
            .map(|r| {
                let outcome = if r.correct { 1.0 } else { 0.0 };
                (r.confidence - outcome).powi(2)
            })
            .sum::<f32>()
            / results.len() as f32;

        let num_bins = 10;
        let mut bins: Vec<Vec<&EvaluationResult>> = vec![Vec::new(); num_bins];

        for result in results {
            let bin_idx = ((result.confidence * num_bins as f32) as usize).min(num_bins - 1);
            bins[bin_idx].push(result);
        }

        let mut ece = 0.0f32;
        let mut confidence_bins = Vec::with_capacity(num_bins);

        for (i, bin) in bins.iter().enumerate() {
            let range_start = i as f32 / num_bins as f32;
            let range_end = (i + 1) as f32 / num_bins as f32;

            if bin.is_empty() {
                confidence_bins.push(ConfidenceBin {
                    range_start,
                    range_end,
                    count: 0,
                    accuracy: 0.0,
                });
                continue;
            }

            let bin_accuracy = bin.iter().filter(|r| r.correct).count() as f32 / bin.len() as f32;
            let bin_confidence: f32 =
                bin.iter().map(|r| r.confidence).sum::<f32>() / bin.len() as f32;

            ece +=
                (bin.len() as f32 / results.len() as f32) * (bin_accuracy - bin_confidence).abs();

            confidence_bins.push(ConfidenceBin {
                range_start,
                range_end,
                count: bin.len(),
                accuracy: bin_accuracy,
            });
        }

        let overconfidence_ratio = results
            .iter()
            .filter(|r| r.confidence > 0.8 && !r.correct)
            .count() as f32
            / results.iter().filter(|r| r.confidence > 0.8).count().max(1) as f32;

        Self {
            brier_score,
            ece,
            overconfidence_ratio,
            confidence_bins,
        }
    }
}

impl BenchmarkResults {
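    /// Aggregates individual results into a `BenchmarkResults`, computing overall
    /// accuracy, average confidence and latency, per-category and per-difficulty
    /// accuracy, and calibration metrics.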
    pub fn compute(benchmark_name: &str, results: Vec<EvaluationResult>) -> Self {
        let total_problems = results.len();
        let correct = results.iter().filter(|r| r.correct).count();
        let accuracy = if total_problems > 0 {
            correct as f32 / total_problems as f32
        } else {
            0.0
        };

        let avg_confidence = if total_problems > 0 {
            results.iter().map(|r| r.confidence).sum::<f32>() / total_problems as f32
        } else {
            0.0
        };

        let avg_latency_ms = if total_problems > 0 {
            results.iter().map(|r| r.latency_ms).sum::<u64>() as f64 / total_problems as f64
        } else {
            0.0
        };

        let total_tokens = results.iter().map(|r| r.tokens_used).sum();

        let calibration = CalibrationMetrics::compute(&results);

        let mut category_counts: HashMap<String, (usize, usize)> = HashMap::new();
        for result in &results {
            if let Some(ref cat) = result.category {
                let entry = category_counts.entry(cat.clone()).or_insert((0, 0));
                entry.0 += 1;
                if result.correct {
                    entry.1 += 1;
                }
            }
        }
        let category_accuracy: HashMap<String, f32> = category_counts
            .into_iter()
            .map(|(cat, (total, correct))| {
                (
                    cat,
                    if total > 0 {
                        correct as f32 / total as f32
                    } else {
                        0.0
                    },
                )
            })
            .collect();

        let mut difficulty_counts: HashMap<u8, (usize, usize)> = HashMap::new();
        for result in &results {
            if let Some(diff) = result.difficulty {
                let entry = difficulty_counts.entry(diff).or_insert((0, 0));
                entry.0 += 1;
                if result.correct {
                    entry.1 += 1;
                }
            }
        }
        let difficulty_accuracy: HashMap<u8, f32> = difficulty_counts
            .into_iter()
            .map(|(diff, (total, correct))| {
                (
                    diff,
                    if total > 0 {
                        correct as f32 / total as f32
                    } else {
                        0.0
                    },
                )
            })
            .collect();

        Self {
            benchmark_name: benchmark_name.to_string(),
            total_problems,
            correct,
            accuracy,
            avg_confidence,
            avg_latency_ms,
            total_tokens,
            category_accuracy,
            difficulty_accuracy,
            results,
            calibration,
        }
    }

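    /// Compares this run against a baseline run of the same benchmark.
    ///
    /// `significant_improvement` uses a fixed threshold of +2 percentage points
    /// of accuracy; it is a heuristic, not a statistical significance test.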
    pub fn compare(&self, baseline: &BenchmarkResults) -> ComparisonReport {
        ComparisonReport {
            benchmark: self.benchmark_name.clone(),
            baseline_accuracy: baseline.accuracy,
            current_accuracy: self.accuracy,
            delta_accuracy: self.accuracy - baseline.accuracy,
            baseline_brier: baseline.calibration.brier_score,
            current_brier: self.calibration.brier_score,
            delta_brier: self.calibration.brier_score - baseline.calibration.brier_score,
            latency_ratio: self.avg_latency_ms / baseline.avg_latency_ms.max(1.0),
            significant_improvement: (self.accuracy - baseline.accuracy) > 0.02,
        }
    }
}

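/// Side-by-side comparison of a current run against a baseline run.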
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComparisonReport {
    pub benchmark: String,
    pub baseline_accuracy: f32,
    pub current_accuracy: f32,
    pub delta_accuracy: f32,
    pub baseline_brier: f32,
    pub current_brier: f32,
    pub delta_brier: f32,
    pub latency_ratio: f64,
    pub significant_improvement: bool,
}

impl std::fmt::Display for ComparisonReport {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let delta_sign = if self.delta_accuracy >= 0.0 { "+" } else { "" };
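        // For the Brier score, lower is better, so a non-positive delta is shown
        // with "+" (improvement) and a positive delta with "-".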
        let brier_sign = if self.delta_brier <= 0.0 { "+" } else { "-" };

        write!(
            f,
            r#"
┌─────────────────────────────────────────────────────────────────────┐
│ BENCHMARK COMPARISON: {}
├─────────────────────────────────────────────────────────────────────┤
│ Accuracy:      {:.1}% → {:.1}% ({}{:.1}%) {}
│ Brier Score:   {:.3} → {:.3} ({}{:.3})
│ Latency:       {:.1}x baseline
│ Significant:   {}
└─────────────────────────────────────────────────────────────────────┘"#,
            self.benchmark,
            self.baseline_accuracy * 100.0,
            self.current_accuracy * 100.0,
            delta_sign,
            self.delta_accuracy * 100.0,
            if self.significant_improvement {
                "✓"
            } else {
                "○"
            },
            self.baseline_brier,
            self.current_brier,
            brier_sign,
            self.delta_brier.abs(),
            self.latency_ratio,
            if self.significant_improvement {
                "YES - Improvement detected"
            } else {
                "NO - Within noise margin"
            }
        )
    }
}

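/// Loading and answer checking for the GSM8K grade-school math benchmark
/// (JSONL files with `question` and `answer` fields, where the reference
/// answer follows a `####` marker).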
pub mod gsm8k {
    use super::*;
    use std::fs::File;
    use std::io::{BufRead, BufReader};

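    /// Loads problems from a GSM8K-style JSONL file, one JSON object per line.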
    pub fn load_problems(path: impl AsRef<Path>) -> anyhow::Result<Vec<BenchmarkProblem>> {
        let file = File::open(path)?;
        let reader = BufReader::new(file);
        let mut problems = Vec::new();

        for (idx, line) in reader.lines().enumerate() {
            let line = line?;
            if line.trim().is_empty() {
                continue;
            }

            let raw: serde_json::Value = serde_json::from_str(&line)?;

            let question = raw["question"].as_str().unwrap_or_default().to_string();

            let answer_str = raw["answer"].as_str().unwrap_or_default();
            let answer = extract_gsm8k_answer(answer_str);

            problems.push(BenchmarkProblem {
                id: format!("gsm8k_{}", idx),
                question,
                answer: Answer::Numeric(answer),
                solution: Some(answer_str.to_string()),
                category: None,
                difficulty: None,
            });
        }

        Ok(problems)
    }

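    /// Pulls the reference numeric answer that follows the final `####` marker,
    /// stripping thousands separators; returns 0.0 if no marker is present.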
    fn extract_gsm8k_answer(answer_str: &str) -> f64 {
        if let Some(pos) = answer_str.rfind("####") {
            let num_str = answer_str[pos + 4..].trim();
            let cleaned = num_str.replace(',', "");
            cleaned.parse().unwrap_or(0.0)
        } else {
            0.0
        }
    }

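    /// Checks a free-form model response against the expected numeric answer,
    /// tolerating an absolute error of 0.01.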
    pub fn check_answer(predicted: &str, expected: f64) -> bool {
        let predicted_num = extract_number_from_response(predicted);

        (predicted_num - expected).abs() < 0.01
    }

    fn extract_number_from_response(response: &str) -> f64 {
        if let Some(pos) = response.rfind("####") {
            let after = &response[pos + 4..];
            if let Some(num) = extract_first_number(after) {
                return num;
            }
        }

        // Search the lowercased response and slice the lowercased string as well,
        // so the byte offset returned by `rfind` is always valid for the slice.
        let lower = response.to_lowercase();
        let patterns = ["answer is", "= ", "equals", "result:"];
        for pattern in patterns {
            if let Some(pos) = lower.rfind(pattern) {
                let after = &lower[pos + pattern.len()..];
                if let Some(num) = extract_first_number(after) {
                    return num;
                }
            }
        }

        extract_last_number(response).unwrap_or(0.0)
    }

    fn extract_first_number(s: &str) -> Option<f64> {
        let mut num_str = String::new();
        let mut in_number = false;

        for c in s.chars() {
            if c.is_ascii_digit() || c == '.' || c == '-' {
                in_number = true;
                num_str.push(c);
            } else if c == ',' && in_number {
                continue;
            } else if in_number {
                break;
            }
        }

        num_str.parse().ok()
    }

    fn extract_last_number(s: &str) -> Option<f64> {
        let mut last_num = None;
        let mut current = String::new();

        for c in s.chars() {
            if c.is_ascii_digit() || c == '.' || c == '-' {
                current.push(c);
            } else if c == ',' && !current.is_empty() {
                continue;
            } else if !current.is_empty() {
                if let Ok(n) = current.parse() {
                    last_num = Some(n);
                }
                current.clear();
            }
        }

        if !current.is_empty() {
            if let Ok(n) = current.parse() {
                last_num = Some(n);
            }
        }

        last_num
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_gsm8k_answer_extraction() {
            assert_eq!(extract_gsm8k_answer("The answer is #### 42"), 42.0);
            assert_eq!(
                extract_gsm8k_answer("Step 1... Step 2... #### 1234"),
                1234.0
            );
            assert_eq!(extract_gsm8k_answer("#### 1,234"), 1234.0);
        }

        #[test]
        fn test_check_answer() {
            assert!(check_answer("The answer is 42", 42.0));
            assert!(check_answer("#### 42", 42.0));
            assert!(!check_answer("The answer is 43", 42.0));
        }
    }
}

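/// Drives an evaluator over a fixed set of problems and aggregates the results.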
pub struct BenchmarkRunner {
    pub problems: Vec<BenchmarkProblem>,
    pub benchmark_name: String,
}

impl BenchmarkRunner {
    pub fn new(benchmark_name: impl Into<String>, problems: Vec<BenchmarkProblem>) -> Self {
        Self {
            problems,
            benchmark_name: benchmark_name.into(),
        }
    }

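    /// Convenience constructor that loads problems from a GSM8K JSONL file.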
    pub fn gsm8k(path: impl AsRef<Path>) -> anyhow::Result<Self> {
        let problems = gsm8k::load_problems(path)?;
        Ok(Self::new("GSM8K", problems))
    }

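    /// Runs `evaluator` on each problem sequentially (optionally limited to the
    /// first `limit` problems) and aggregates the outcomes.
    ///
    /// A minimal usage sketch; the file path and the `solve()` helper are
    /// placeholders for whatever produces an `EvaluationResult`, so the example
    /// is marked `ignore`:
    ///
    /// ```ignore
    /// let runner = BenchmarkRunner::gsm8k("data/gsm8k_test.jsonl")?;
    /// let results = runner
    ///     .run(|problem| async move { solve(problem).await }, Some(100))
    ///     .await;
    /// println!("{} accuracy: {:.1}%", results.benchmark_name, results.accuracy * 100.0);
    /// ```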
    pub async fn run<F, Fut>(&self, evaluator: F, limit: Option<usize>) -> BenchmarkResults
    where
        F: Fn(BenchmarkProblem) -> Fut,
        Fut: std::future::Future<Output = EvaluationResult>,
    {
        let problems = match limit {
            Some(n) => self.problems.iter().take(n).cloned().collect::<Vec<_>>(),
            None => self.problems.clone(),
        };

        let mut results = Vec::with_capacity(problems.len());

        for problem in problems {
            let result = evaluator(problem).await;
            results.push(result);
        }

        BenchmarkResults::compute(&self.benchmark_name, results)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_calibration_metrics() {
        let results = vec![
            EvaluationResult {
                problem_id: "1".into(),
                correct: true,
                predicted: "42".into(),
                expected: "42".into(),
                confidence: 0.9,
                reasoning_steps: 3,
                latency_ms: 100,
                tokens_used: 500,
                category: Some("arithmetic".into()),
                difficulty: Some(1),
            },
            EvaluationResult {
                problem_id: "2".into(),
                correct: false,
                predicted: "41".into(),
                expected: "42".into(),
                confidence: 0.8,
                reasoning_steps: 3,
                latency_ms: 120,
                tokens_used: 520,
                category: Some("arithmetic".into()),
                difficulty: Some(2),
            },
        ];

        let metrics = CalibrationMetrics::compute(&results);
        assert!(metrics.brier_score > 0.0);
        assert!(metrics.brier_score < 1.0);
    }

    #[test]
    fn test_comparison_report() {
        let baseline = BenchmarkResults {
            benchmark_name: "GSM8K".into(),
            total_problems: 100,
            correct: 78,
            accuracy: 0.78,
            avg_confidence: 0.75,
            avg_latency_ms: 500.0,
            total_tokens: 50000,
            category_accuracy: HashMap::new(),
            difficulty_accuracy: HashMap::new(),
            results: vec![],
            calibration: CalibrationMetrics::default(),
        };

        let improved = BenchmarkResults {
            benchmark_name: "GSM8K".into(),
            total_problems: 100,
            correct: 86,
            accuracy: 0.86,
            avg_confidence: 0.82,
            avg_latency_ms: 800.0,
            total_tokens: 75000,
            category_accuracy: HashMap::new(),
            difficulty_accuracy: HashMap::new(),
            results: vec![],
            calibration: CalibrationMetrics::default(),
        };

        let report = improved.compare(&baseline);
        assert!(report.significant_improvement);
        assert!((report.delta_accuracy - 0.08).abs() < 0.001);
    }
}