use std::path::Path;
use serde::Deserialize;
use crate::{
error::BenchError,
scenario::{DatasetLoader, EvalResult, Evaluator, Scenario, token_f1},
};
const PASS_THRESHOLD: f64 = 0.5;
#[derive(Debug, Deserialize)]
struct LocomoSession {
session_id: String,
qa: Vec<LocomoQa>,
}
#[derive(Debug, Deserialize)]
struct LocomoQa {
question: String,
answer: String,
}
#[derive(Debug)]
pub struct LocomoLoader;
impl DatasetLoader for LocomoLoader {
fn name(&self) -> &'static str {
"locomo"
}
fn load(&self, path: &Path) -> Result<Vec<Scenario>, BenchError> {
let content = std::fs::read_to_string(path)?;
let sessions: Vec<LocomoSession> =
serde_json::from_str(&content).map_err(|e| BenchError::InvalidFormat(e.to_string()))?;
let mut scenarios = Vec::new();
for session in sessions {
for (idx, qa) in session.qa.iter().enumerate() {
scenarios.push(Scenario::single(
format!("{}_{}", session.session_id, idx),
qa.question.clone(),
qa.answer.clone(),
serde_json::Value::Null,
));
}
}
Ok(scenarios)
}
}
#[derive(Debug)]
pub struct LocomoEvaluator;
impl Evaluator for LocomoEvaluator {
fn evaluate(&self, scenario: &Scenario, agent_response: &str) -> EvalResult {
let normalized_response = normalize_for_f1(agent_response);
let normalized_expected = normalize_for_f1(&scenario.expected);
let score = token_f1(&normalized_response, &normalized_expected);
EvalResult {
scenario_id: scenario.id.clone(),
score,
passed: score >= PASS_THRESHOLD,
details: format!("token_f1={score:.4}"),
}
}
}
fn normalize_for_f1(s: &str) -> String {
s.chars()
.filter(|c| c.is_alphanumeric() || c.is_whitespace())
.collect::<String>()
.to_lowercase()
}
#[cfg(test)]
mod tests {
use super::*;
const FIXTURE: &str = r#"[
{
"session_id": "s1",
"qa": [
{"question": "What is Rust?", "answer": "A systems programming language"},
{"question": "Is it fast?", "answer": "Yes"}
]
}
]"#;
fn load_from_str(json: &str) -> Vec<Scenario> {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("locomo.json");
std::fs::write(&path, json).unwrap();
LocomoLoader.load(&path).unwrap()
}
#[test]
fn load_parses_scenario_count() {
let scenarios = load_from_str(FIXTURE);
assert_eq!(scenarios.len(), 2);
}
#[test]
fn load_builds_correct_ids() {
let scenarios = load_from_str(FIXTURE);
assert_eq!(scenarios[0].id, "s1_0");
assert_eq!(scenarios[1].id, "s1_1");
}
#[test]
fn load_maps_prompt_and_expected() {
let scenarios = load_from_str(FIXTURE);
assert_eq!(scenarios[0].primary_prompt().unwrap(), "What is Rust?");
assert_eq!(scenarios[0].expected, "A systems programming language");
}
#[test]
fn evaluator_perfect_match_passes() {
let scenarios = load_from_str(FIXTURE);
let result = LocomoEvaluator.evaluate(&scenarios[0], "A systems programming language");
assert!((result.score - 1.0).abs() < f64::EPSILON);
assert!(result.passed);
}
#[test]
fn evaluator_no_match_fails() {
let scenarios = load_from_str(FIXTURE);
let result = LocomoEvaluator.evaluate(&scenarios[0], "completely different response xyz");
assert!(!result.passed);
}
#[test]
fn load_invalid_json_returns_error() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("bad.json");
std::fs::write(&path, "not json").unwrap();
assert!(LocomoLoader.load(&path).is_err());
}
}