use crate::dataset::{EvalDataset, EvalExample};
pub struct Gsm8kEvaluator;
impl Gsm8kEvaluator {
pub fn new() -> Self {
Self
}
pub fn extract_final_answer(text: &str) -> Option<f64> {
let marker = "####";
let marker_pos = text.rfind(marker)?;
let after = &text[marker_pos + marker.len()..];
let trimmed = after.trim();
let no_commas: String = trimmed.chars().filter(|&c| c != ',').collect();
no_commas.parse::<f64>().ok()
}
pub fn score(&self, completion: &str, gold: &str) -> bool {
let Some(pred_val) = Self::extract_final_answer(completion) else {
return false;
};
let Some(gold_val) = Self::extract_final_answer(gold) else {
return false;
};
let tol = 1e-6_f64 * pred_val.abs().max(1.0);
(pred_val - gold_val).abs() < tol
}
pub fn evaluate_dataset(&self, dataset: &EvalDataset, completions: &[String]) -> Gsm8kResult {
let mut correct: usize = 0;
let mut total: usize = 0;
let mut no_answer_extracted: usize = 0;
for (example, completion) in dataset.examples.iter().zip(completions.iter()) {
let Some(ref gold) = example.expected_output else {
continue;
};
total += 1;
if Self::extract_final_answer(completion).is_none() {
no_answer_extracted += 1;
continue;
}
if self.score(completion, gold) {
correct += 1;
}
}
let accuracy = if total == 0 {
0.0_f32
} else {
correct as f32 / total as f32
};
Gsm8kResult {
correct,
total,
accuracy,
no_answer_extracted,
}
}
}
impl Default for Gsm8kEvaluator {
fn default() -> Self {
Self::new()
}
}
/// Aggregate counts and accuracy for one evaluation run, as filled in by
/// `Gsm8kEvaluator::evaluate_dataset`.
#[derive(Debug, Clone)]
pub struct Gsm8kResult {
    /// Number of scored examples judged correct.
    pub correct: usize,
    /// Number of examples that were scored.
    pub total: usize,
    /// Fraction of scored examples judged correct (`0.0` when none were scored).
    pub accuracy: f32,
    /// Number of scored examples for which no final answer could be extracted.
    pub no_answer_extracted: usize,
}

impl Gsm8kResult {
    /// Returns `accuracy` expressed as a percentage.
    pub fn accuracy_pct(&self) -> f32 {
        100.0 * self.accuracy
    }

    /// Returns the share of scored examples with no extractable answer, or
    /// `0.0` when `total` is zero.
    pub fn no_answer_rate(&self) -> f32 {
        match self.total {
            0 => 0.0,
            scored => self.no_answer_extracted as f32 / scored as f32,
        }
    }
}
/// Builds an `EvalExample` from borrowed strings.
///
/// `id` and `input` are copied into the example, `gold_answer` becomes its
/// `expected_output`, and `metadata` starts out empty.
pub fn gsm8k_example(id: &str, input: &str, gold_answer: &str) -> EvalExample {
    let expected_output = Some(String::from(gold_answer));
    let metadata = std::collections::HashMap::new();
    EvalExample {
        id: String::from(id),
        input: String::from(input),
        expected_output,
        metadata,
    }
}