use std::path::Path;
use crate::error::BenchError;
/// Who authored a [`Turn`] in a scenario's conversation.
///
/// The enum carries no data, so it also derives `Copy` (cheap to pass by value)
/// and `Hash` (usable as a `HashMap`/`HashSet` key) alongside the equality traits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Role {
    /// A user message; [`Scenario::primary_prompt`] returns the first one of these.
    User,
    /// An assistant message.
    Assistant,
}
/// One message in a scenario's conversation.
#[derive(Debug, Clone)]
pub struct Turn {
    /// Who produced this message.
    pub role: Role,
    /// The message text.
    pub content: String,
}
/// One benchmark case: an identified list of conversation turns together with
/// its expected answer and free-form metadata.
///
/// Build single-prompt cases with [`Scenario::single`]; use
/// [`Scenario::primary_prompt`] to fetch the first user message.
#[derive(Debug, Clone)]
pub struct Scenario {
    /// Identifier of the scenario; it is quoted in `primary_prompt`'s error message.
    pub id: String,
    /// Conversation turns; `primary_prompt` scans them front to back.
    pub turns: Vec<Turn>,
    /// Expected answer. NOTE(review): presumably the reference an `Evaluator`
    /// compares the agent response against — confirm in the implementations.
    pub expected: String,
    /// Arbitrary JSON metadata. Its schema is not constrained here and presumably
    /// depends on the `DatasetLoader` that produced the scenario — confirm.
    pub metadata: serde_json::Value,
}
impl Scenario {
    /// Creates a scenario with exactly one turn: a [`Role::User`] message holding `prompt`.
    ///
    /// `id`, `prompt` and `expected` accept anything convertible into `String`;
    /// `metadata` is stored unchanged.
    #[must_use]
    pub fn single(
        id: impl Into<String>,
        prompt: impl Into<String>,
        expected: impl Into<String>,
        metadata: serde_json::Value,
    ) -> Self {
        // Convert in declaration order: id, then prompt, then expected.
        let id = id.into();
        let opening = Turn {
            role: Role::User,
            content: prompt.into(),
        };
        let expected = expected.into();
        Self {
            id,
            turns: vec![opening],
            expected,
            metadata,
        }
    }

    /// Returns the text of the first [`Role::User`] turn, skipping any turns before it.
    ///
    /// # Errors
    ///
    /// Returns [`BenchError::InvalidFormat`] when no turn has the user role,
    /// including when `turns` is empty.
    pub fn primary_prompt(&self) -> Result<&str, BenchError> {
        match self.turns.iter().find(|turn| turn.role == Role::User) {
            Some(turn) => Ok(turn.content.as_str()),
            None => Err(BenchError::InvalidFormat(format!(
                "scenario '{}' has no user turn",
                self.id
            ))),
        }
    }
}
/// Outcome of grading one agent response, as returned by [`Evaluator::evaluate`].
#[derive(Debug, Clone)]
pub struct EvalResult {
    /// Identifier of the evaluated scenario (presumably the `Scenario::id` — confirm).
    pub scenario_id: String,
    /// Numeric score. NOTE(review): no range is enforced here. `token_f1` in this
    /// module yields values in [0, 1], but each `Evaluator` decides what it stores.
    pub score: f64,
    /// Whether the response was judged as passing.
    pub passed: bool,
    /// Free-text explanation accompanying the verdict; content is evaluator-defined.
    pub details: String,
}
/// A source of benchmark [`Scenario`]s loaded from a filesystem path.
pub trait DatasetLoader {
    /// Static name identifying this loader. NOTE(review): presumably used for
    /// selection or logging — confirm at call sites.
    fn name(&self) -> &'static str;

    /// Loads the scenarios found at `path`. Whether `path` names a file or a
    /// directory is up to the implementation.
    ///
    /// # Errors
    ///
    /// Returns a [`BenchError`] when the data cannot be loaded; which variants are
    /// produced depends on the implementation.
    fn load(&self, path: &Path) -> Result<Vec<Scenario>, BenchError>;
}
/// Grades an agent's response to a [`Scenario`].
pub trait Evaluator {
    /// Scores `agent_response` for `scenario`. This method cannot fail: every
    /// outcome, including a wrong answer, is reported through the returned [`EvalResult`].
    fn evaluate(&self, scenario: &Scenario, agent_response: &str) -> EvalResult;
}
/// Token-level F1 score between `prediction` and `reference`, in `[0.0, 1.0]`.
///
/// Both strings are split on Unicode whitespace, and the overlap is counted over
/// *multisets* of tokens, as in the SQuAD evaluation script. A token matches at
/// most as many times as it appears in the reference, so repeating a correct
/// word lowers precision instead of being ignored. Tokens are compared verbatim:
/// no case folding or punctuation stripping is done here.
///
/// Returns `0.0` when either string has no tokens or when no token overlaps.
#[must_use]
// Token counts stay far below 2^52, so the usize -> f64 conversions are exact.
#[allow(clippy::cast_precision_loss)]
pub fn token_f1(prediction: &str, reference: &str) -> f64 {
    // How many copies of each reference token are still available for matching.
    let mut remaining: std::collections::HashMap<&str, usize> = std::collections::HashMap::new();
    let mut ref_len = 0_usize;
    for tok in reference.split_whitespace() {
        *remaining.entry(tok).or_default() += 1;
        ref_len += 1;
    }

    let mut pred_len = 0_usize;
    let mut common = 0_usize;
    for tok in prediction.split_whitespace() {
        pred_len += 1;
        if let Some(left) = remaining.get_mut(tok) {
            if *left > 0 {
                *left -= 1;
                common += 1;
            }
        }
    }

    // Covers empty inputs too: with no tokens on either side nothing can overlap.
    if common == 0 {
        return 0.0;
    }
    // 2PR/(P+R) with P = common/pred_len and R = common/ref_len simplifies to
    // 2*common/(pred_len+ref_len), which avoids an extra rounding step.
    2.0 * common as f64 / (pred_len + ref_len) as f64
}
/// Returns `true` when `prediction` and `reference` are equal after basic
/// normalization (`normalize_basic`). That normalization lowercases, drops every
/// character that is neither alphanumeric nor whitespace, and collapses
/// whitespace runs into single spaces.
#[must_use]
pub fn exact_match(prediction: &str, reference: &str) -> bool {
    let normalized_prediction = normalize_basic(prediction);
    let normalized_reference = normalize_basic(reference);
    normalized_prediction == normalized_reference
}
/// Returns `true` when `prediction` and `reference` are equal after GAIA-style
/// normalization (`normalize_gaia`). That normalization folds superscript and
/// subscript digits to ASCII, lowercases, drops characters that are neither
/// alphanumeric nor whitespace, and removes the articles "a", "an" and "the".
#[must_use]
pub fn gaia_normalized_exact_match(prediction: &str, reference: &str) -> bool {
    let normalized_prediction = normalize_gaia(prediction);
    let normalized_reference = normalize_gaia(reference);
    normalized_prediction == normalized_reference
}
/// Normalizes `s` for basic exact matching in three steps:
/// 1. lowercase it;
/// 2. remove every character that is neither alphanumeric nor whitespace;
/// 3. collapse whitespace runs into single spaces and trim both ends.
///
/// Lowercasing happens *before* filtering (the order the SQuAD normalizer uses).
/// Lowercasing can emit non-alphanumeric code points, e.g. `'İ'.to_lowercase()`
/// is `"i\u{307}"` with a combining dot, and this order lets the filter remove
/// them. As a result "İstanbul" and "istanbul" normalize to the same string.
fn normalize_basic(s: &str) -> String {
    s.to_lowercase()
        .chars()
        .filter(|c| c.is_alphanumeric() || c.is_whitespace())
        .collect::<String>()
        .split_whitespace()
        .collect::<Vec<_>>()
        .join(" ")
}
/// Normalizes `s` for GAIA-style exact matching in these steps:
/// 1. fold Unicode superscript/subscript digits to ASCII via [`ascii_fold_digit`];
/// 2. lowercase the whole string;
/// 3. remove every character that is neither alphanumeric nor whitespace;
/// 4. split on whitespace, discard the articles "a", "an" and "the", and re-join
///    the remaining tokens with single spaces.
///
/// Lowercasing precedes the character filter so that non-alphanumeric code points
/// produced by lowercasing are removed too. An example is the combining dot U+0307
/// in `'İ'.to_lowercase() == "i\u{307}"`.
fn normalize_gaia(s: &str) -> String {
    const ARTICLES: &[&str] = &["a", "an", "the"];
    let folded: String = s.chars().map(ascii_fold_digit).collect();
    let cleaned: String = folded
        .to_lowercase()
        .chars()
        .filter(|c| c.is_alphanumeric() || c.is_whitespace())
        .collect();
    cleaned
        .split_whitespace()
        .filter(|tok| !ARTICLES.contains(tok))
        .collect::<Vec<_>>()
        .join(" ")
}
/// Maps a Unicode superscript or subscript digit to its ASCII counterpart and
/// returns every other character unchanged.
///
/// - Subscripts ₀–₉ are contiguous at U+2080–U+2089.
/// - Superscripts are split: ⁰ is U+2070 and ⁴–⁹ are U+2074–U+2079.
///   U+2071 is superscript `i`, and U+2072/U+2073 are unassigned.
/// - ¹ ² ³ come from Latin-1 (U+00B9, U+00B2, U+00B3).
fn ascii_fold_digit(c: char) -> char {
    let value = match c {
        '\u{2080}'..='\u{2089}' => u32::from(c) - 0x2080,
        '\u{2070}' | '\u{2074}'..='\u{2079}' => u32::from(c) - 0x2070,
        '\u{00B9}' => 1,
        '\u{00B2}' => 2,
        '\u{00B3}' => 3,
        _ => return c,
    };
    // `value` is always 0..=9 here, so `from_digit` never falls back.
    char::from_digit(value, 10).unwrap_or(c)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a [`Turn`] from a role and borrowed text.
    fn turn(role: Role, content: &str) -> Turn {
        Turn {
            role,
            content: content.to_owned(),
        }
    }

    /// Builds a [`Scenario`] with `Null` metadata.
    fn scenario(id: &str, turns: Vec<Turn>, expected: &str) -> Scenario {
        Scenario {
            id: id.to_owned(),
            turns,
            expected: expected.to_owned(),
            metadata: serde_json::Value::Null,
        }
    }

    // ---- token_f1 ----------------------------------------------------------

    #[test]
    fn token_f1_identical() {
        let score = token_f1("hello world", "hello world");
        assert!((score - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn token_f1_no_overlap() {
        let score = token_f1("foo bar", "baz qux");
        assert!(score < f64::EPSILON);
    }

    #[test]
    fn token_f1_partial_overlap() {
        let score = token_f1("hello world foo", "hello world bar");
        assert!(score > 0.0, "expected a positive score, got {score}");
        assert!(score < 1.0, "expected a score below 1, got {score}");
    }

    #[test]
    fn token_f1_empty_prediction() {
        let score = token_f1("", "hello");
        assert!(score < f64::EPSILON);
    }

    #[test]
    fn token_f1_empty_reference() {
        let score = token_f1("hello", "");
        assert!(score < f64::EPSILON);
    }

    // ---- exact_match -------------------------------------------------------

    #[test]
    fn exact_match_identical() {
        assert!(exact_match("Hello, World!", "hello world"));
    }

    #[test]
    fn exact_match_differs() {
        let matched = exact_match("foo", "bar");
        assert!(!matched);
    }

    #[test]
    fn exact_match_strips_punctuation() {
        assert!(exact_match("answer: yes.", "answer yes"));
    }

    // ---- gaia_normalized_exact_match --------------------------------------

    #[test]
    fn gaia_normalized_strips_articles() {
        let (prediction, reference) = ("The quick brown fox", "quick brown fox");
        assert!(gaia_normalized_exact_match(prediction, reference));
    }

    #[test]
    fn gaia_normalized_strips_a_an() {
        let (prediction, reference) = ("a cat sat on an apple", "cat sat on apple");
        assert!(gaia_normalized_exact_match(prediction, reference));
    }

    #[test]
    fn gaia_normalized_differs() {
        let matched = gaia_normalized_exact_match("cat", "dog");
        assert!(!matched);
    }

    #[test]
    fn gaia_normalized_subscript_digits_match_ascii() {
        assert!(gaia_normalized_exact_match("H\u{2082}O", "H2O"));
    }

    // ---- Scenario ------------------------------------------------------------

    #[test]
    fn single_constructs_one_user_turn() {
        let s = Scenario::single("id1", "hello", "world", serde_json::Value::Null);
        assert_eq!(s.turns.len(), 1);
        let only = &s.turns[0];
        assert_eq!(only.role, Role::User);
        assert_eq!(only.content, "hello");
        assert_eq!(s.expected, "world");
    }

    #[test]
    fn primary_prompt_returns_first_user_turn_content() {
        let s = Scenario::single("id1", "What year?", "2026", serde_json::Value::Null);
        assert_eq!(s.primary_prompt().ok(), Some("What year?"));
    }

    #[test]
    fn primary_prompt_skips_leading_assistant_turns() {
        let s = scenario(
            "id2",
            vec![
                turn(Role::Assistant, "I am ready."),
                turn(Role::User, "What is Rust?"),
            ],
            "A systems language",
        );
        assert_eq!(s.primary_prompt().ok(), Some("What is Rust?"));
    }

    #[test]
    fn primary_prompt_errors_when_no_user_turn() {
        let assistant_only = scenario("id3", vec![turn(Role::Assistant, "assistant only")], "");
        assert!(assistant_only.primary_prompt().is_err());

        let no_turns = scenario("id4", Vec::new(), "");
        assert!(no_turns.primary_prompt().is_err());
    }
}