zeph_bench/loaders/
longmemeval.rs1use std::{
5 io::{BufRead as _, BufReader},
6 path::Path,
7};
8
9use serde::Deserialize;
10
11use crate::{
12 error::BenchError,
13 scenario::{DatasetLoader, EvalResult, Evaluator, Scenario, exact_match, token_f1},
14};
15
16#[derive(Debug, Deserialize)]
17struct LongMemEvalItem {
18 question_id: String,
19 question: String,
20 answer: String,
21 session_id: String,
22 sessions: serde_json::Value,
23}
24
/// Dataset loader for LongMemEval scenarios stored as JSON Lines (one record per line).
///
/// The type carries no state; all of its behaviour lives in its `DatasetLoader` implementation.
#[derive(Debug)]
pub struct LongMemEvalLoader;
55
56impl DatasetLoader for LongMemEvalLoader {
57 fn name(&self) -> &'static str {
58 "longmemeval"
59 }
60
61 fn load(&self, path: &Path) -> Result<Vec<Scenario>, BenchError> {
67 let file = std::fs::File::open(path)?;
68 let reader = BufReader::new(file);
69
70 let mut scenarios = Vec::new();
71 for (idx, line) in reader.lines().enumerate() {
72 let line = line?;
73 let trimmed = line.trim();
74 if trimmed.is_empty() {
75 continue;
76 }
77 let item: LongMemEvalItem = serde_json::from_str(trimmed)
78 .map_err(|e| BenchError::InvalidFormat(format!("line {}: {e}", idx + 1)))?;
79
80 scenarios.push(Scenario::single(
81 item.question_id,
82 item.question,
83 item.answer,
84 serde_json::json!({
85 "session_id": item.session_id,
86 "sessions": item.sessions,
87 }),
88 ));
89 }
90 Ok(scenarios)
91 }
92}
93
/// Evaluator for LongMemEval scenarios.
///
/// It combines an exact-match check with token-level F1 partial credit. The type itself is
/// stateless.
#[derive(Debug)]
pub struct LongMemEvalEvaluator;
114
115impl Evaluator for LongMemEvalEvaluator {
116 fn evaluate(&self, scenario: &Scenario, agent_response: &str) -> EvalResult {
117 let matched = exact_match(agent_response, &scenario.expected);
118 let f1 = token_f1(agent_response, &scenario.expected);
119 let score = if matched { 1.0 } else { f1 };
120 EvalResult {
121 scenario_id: scenario.id.clone(),
122 score,
123 passed: matched,
124 details: format!("exact_match={matched} token_f1={f1:.4}"),
125 }
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// Three well-formed records separated by newlines (no trailing newline).
    /// `q1` and `q2` share session `s1`.
    const FIXTURE: &str = concat!(
        r#"{"question_id":"q1","question":"What is Rust?","answer":"A systems language","session_id":"s1","sessions":[]}"#,
        "\n",
        r#"{"question_id":"q2","question":"Is it fast?","answer":"Yes","session_id":"s1","sessions":[]}"#,
        "\n",
        r#"{"question_id":"q3","question":"Creator?","answer":"Graydon Hoare","session_id":"s2","sessions":[]}"#,
    );

    /// Writes `contents` as `file_name` inside a fresh temporary directory.
    ///
    /// The returned `TempDir` must outlive every use of the path: dropping it deletes the
    /// directory.
    fn write_temp(file_name: &str, contents: &str) -> (tempfile::TempDir, std::path::PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join(file_name);
        std::fs::write(&path, contents).unwrap();
        (dir, path)
    }

    /// Loads `jsonl` through `LongMemEvalLoader`, panicking if loading fails.
    fn load_from_str(jsonl: &str) -> Vec<Scenario> {
        let (_dir, path) = write_temp("longmemeval.jsonl", jsonl);
        LongMemEvalLoader.load(&path).unwrap()
    }

    /// Loads `FIXTURE` and returns its first scenario (`q1`).
    fn first_fixture_scenario() -> Scenario {
        load_from_str(FIXTURE).into_iter().next().unwrap()
    }

    #[test]
    fn load_parses_scenario_count() {
        let scenarios = load_from_str(FIXTURE);
        assert_eq!(scenarios.len(), 3);
    }

    #[test]
    fn load_builds_correct_ids() {
        let scenarios = load_from_str(FIXTURE);
        // Guard the length so `zip` cannot silently skip a missing scenario.
        assert_eq!(scenarios.len(), 3);
        for (scenario, want) in scenarios.iter().zip(["q1", "q2", "q3"]) {
            assert_eq!(scenario.id, want);
        }
    }

    #[test]
    fn load_maps_prompt_and_expected() {
        let scenario = first_fixture_scenario();
        assert_eq!(scenario.primary_prompt().unwrap(), "What is Rust?");
        assert_eq!(scenario.expected, "A systems language");
    }

    #[test]
    fn load_stores_session_id_in_metadata() {
        let scenario = first_fixture_scenario();
        assert_eq!(scenario.metadata["session_id"], "s1");
    }

    #[test]
    fn load_stores_sessions_in_metadata() {
        let scenario = first_fixture_scenario();
        assert!(scenario.metadata["sessions"].is_array());
    }

    #[test]
    fn evaluator_exact_match_passes() {
        let scenario = first_fixture_scenario();
        let verdict = LongMemEvalEvaluator.evaluate(&scenario, "A systems language");
        assert!(verdict.passed);
        assert!((verdict.score - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn evaluator_wrong_answer_fails() {
        let scenario = first_fixture_scenario();
        let verdict = LongMemEvalEvaluator.evaluate(&scenario, "A web framework");
        assert!(!verdict.passed);
    }

    #[test]
    fn evaluator_partial_overlap_gives_token_f1_score() {
        let scenario = first_fixture_scenario();
        let response = "A systems framework";
        let verdict = LongMemEvalEvaluator.evaluate(&scenario, response);
        // Not an exact match, so the score must equal the token F1 overlap.
        assert!(!verdict.passed);
        let expected_f1 = token_f1(response, "A systems language");
        assert!((verdict.score - expected_f1).abs() < f64::EPSILON);
    }

    #[test]
    fn evaluator_details_contain_token_f1() {
        let scenario = first_fixture_scenario();
        let verdict = LongMemEvalEvaluator.evaluate(&scenario, "some answer");
        assert!(verdict.details.contains("token_f1="));
    }

    #[test]
    fn load_invalid_jsonl_returns_error() {
        let (_dir, path) = write_temp("bad.jsonl", "not json\n");
        assert!(LongMemEvalLoader.load(&path).is_err());
    }
}