zeph_bench/loaders/
locomo.rs1use std::path::Path;
5
6use serde::Deserialize;
7
8use crate::{
9 error::BenchError,
10 scenario::{DatasetLoader, EvalResult, Evaluator, Scenario, token_f1},
11};
12
/// Minimum token-F1 score (inclusive) at which a response is marked as passed.
const PASS_THRESHOLD: f64 = 0.5;
14
15#[derive(Debug, Deserialize)]
16struct LocomoSession {
17 session_id: String,
18 qa: Vec<LocomoQa>,
19}
20
21#[derive(Debug, Deserialize)]
22struct LocomoQa {
23 question: String,
24 answer: String,
25}
26
27#[derive(Debug)]
58pub struct LocomoLoader;
59
60impl DatasetLoader for LocomoLoader {
61 fn name(&self) -> &'static str {
62 "locomo"
63 }
64
65 fn load(&self, path: &Path) -> Result<Vec<Scenario>, BenchError> {
70 let content = std::fs::read_to_string(path)?;
71 let sessions: Vec<LocomoSession> =
72 serde_json::from_str(&content).map_err(|e| BenchError::InvalidFormat(e.to_string()))?;
73
74 let mut scenarios = Vec::new();
75 for session in sessions {
76 for (idx, qa) in session.qa.iter().enumerate() {
77 scenarios.push(Scenario::single(
78 format!("{}_{}", session.session_id, idx),
79 qa.question.clone(),
80 qa.answer.clone(),
81 serde_json::Value::Null,
82 ));
83 }
84 }
85 Ok(scenarios)
86 }
87}
88
89#[derive(Debug)]
116pub struct LocomoEvaluator;
117
118impl Evaluator for LocomoEvaluator {
119 fn evaluate(&self, scenario: &Scenario, agent_response: &str) -> EvalResult {
120 let normalized_response = normalize_for_f1(agent_response);
123 let normalized_expected = normalize_for_f1(&scenario.expected);
124 let score = token_f1(&normalized_response, &normalized_expected);
125 EvalResult {
126 scenario_id: scenario.id.clone(),
127 score,
128 passed: score >= PASS_THRESHOLD,
129 details: format!("token_f1={score:.4}"),
130 }
131 }
132}
133
/// Prepares text for token-F1 comparison.
///
/// Drops every character that is neither alphanumeric nor whitespace, then
/// lowercases what is left. Whitespace is kept exactly as it appears, and
/// removed characters are not replaced by anything.
fn normalize_for_f1(s: &str) -> String {
    let mut kept = String::with_capacity(s.len());
    for ch in s.chars() {
        if ch.is_alphanumeric() || ch.is_whitespace() {
            kept.push(ch);
        }
    }
    // Lowercase the filtered string as a whole (not char by char) so that
    // context-sensitive mappings see the same neighbours as before.
    kept.to_lowercase()
}
145
#[cfg(test)]
mod tests {
    use super::*;

    /// One session (`s1`) holding two QA pairs; shared by most tests below.
    const FIXTURE: &str = r#"[
        {
            "session_id": "s1",
            "qa": [
                {"question": "What is Rust?", "answer": "A systems programming language"},
                {"question": "Is it fast?", "answer": "Yes"}
            ]
        }
    ]"#;

    /// Writes `json` to a file in a fresh temporary directory and loads it.
    ///
    /// Panics if the file cannot be written or the loader returns an error.
    fn load_from_str(json: &str) -> Vec<Scenario> {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("locomo.json");
        std::fs::write(&path, json).unwrap();
        LocomoLoader.load(&path).unwrap()
    }

    #[test]
    fn load_parses_scenario_count() {
        let scenarios = load_from_str(FIXTURE);
        assert_eq!(scenarios.len(), 2);
    }

    #[test]
    fn load_builds_correct_ids() {
        let scenarios = load_from_str(FIXTURE);
        assert_eq!(scenarios[0].id, "s1_0");
        assert_eq!(scenarios[1].id, "s1_1");
    }

    #[test]
    fn load_numbers_ids_per_session() {
        // The index restarts at zero for every session and file order is kept.
        let scenarios = load_from_str(
            r#"[
                {"session_id": "a", "qa": [{"question": "q1", "answer": "x"}]},
                {"session_id": "b", "qa": [
                    {"question": "q2", "answer": "y"},
                    {"question": "q3", "answer": "z"}
                ]}
            ]"#,
        );
        assert_eq!(scenarios.len(), 3);
        assert_eq!(scenarios[0].id, "a_0");
        assert_eq!(scenarios[1].id, "b_0");
        assert_eq!(scenarios[2].id, "b_1");
        assert_eq!(scenarios[2].expected, "z");
    }

    #[test]
    fn load_maps_prompt_and_expected() {
        let scenarios = load_from_str(FIXTURE);
        assert_eq!(scenarios[0].primary_prompt().unwrap(), "What is Rust?");
        assert_eq!(scenarios[0].expected, "A systems programming language");
    }

    #[test]
    fn load_empty_array_yields_no_scenarios() {
        assert!(load_from_str("[]").is_empty());
    }

    #[test]
    fn load_invalid_json_returns_error() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("bad.json");
        std::fs::write(&path, "not json").unwrap();
        assert!(LocomoLoader.load(&path).is_err());
    }

    #[test]
    fn load_missing_file_returns_error() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("missing.json");
        assert!(LocomoLoader.load(&path).is_err());
    }

    #[test]
    fn normalize_strips_punctuation_and_lowercases() {
        assert_eq!(normalize_for_f1("Hello, World!"), "hello world");
        assert_eq!(normalize_for_f1("?!."), "");
    }

    #[test]
    fn evaluator_perfect_match_passes() {
        let scenarios = load_from_str(FIXTURE);
        let result = LocomoEvaluator.evaluate(&scenarios[0], "A systems programming language");
        assert!((result.score - 1.0).abs() < f64::EPSILON);
        assert!(result.passed);
    }

    #[test]
    fn evaluator_ignores_case_and_punctuation() {
        // Normalizes to the same text as the expected answer, so it must score
        // exactly like the perfect match above.
        let scenarios = load_from_str(FIXTURE);
        let result = LocomoEvaluator.evaluate(&scenarios[0], "a SYSTEMS programming, language!");
        assert!((result.score - 1.0).abs() < f64::EPSILON);
        assert!(result.passed);
    }

    #[test]
    fn evaluator_reports_scenario_id_and_details() {
        let scenarios = load_from_str(FIXTURE);
        let result = LocomoEvaluator.evaluate(&scenarios[0], "A systems programming language");
        assert_eq!(result.scenario_id, "s1_0");
        assert_eq!(result.details, "token_f1=1.0000");
    }

    #[test]
    fn evaluator_no_match_fails() {
        let scenarios = load_from_str(FIXTURE);
        let result = LocomoEvaluator.evaluate(&scenarios[0], "completely different response xyz");
        assert!(!result.passed);
    }
}