Skip to main content

zeph_bench/loaders/
gaia.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::{
5    io::{BufRead as _, BufReader},
6    path::Path,
7};
8
9use serde::Deserialize;
10
11use crate::{
12    error::BenchError,
13    scenario::{DatasetLoader, EvalResult, Evaluator, Scenario, gaia_normalized_exact_match},
14};
15
16#[derive(Debug, Deserialize)]
17struct GaiaRecord {
18    task_id: String,
19    #[serde(rename = "Question")]
20    question: String,
21    #[serde(rename = "Level")]
22    level: u8,
23    #[serde(rename = "Final answer")]
24    final_answer: String,
25    #[serde(rename = "Annotator Metadata")]
26    annotator_metadata: Option<serde_json::Value>,
27}
28
/// Loads GAIA benchmark scenarios from a JSONL file with an optional level filter.
///
/// **Source**: [`gaia-benchmark/GAIA`](https://huggingface.co/datasets/gaia-benchmark/GAIA)
/// on `HuggingFace`.
///
/// **Schema**: one JSON object per line:
/// ```json
/// {
///   "task_id": "...",
///   "Question": "...",
///   "Level": 1,
///   "Final answer": "...",
///   "Annotator Metadata": { ... }
/// }
/// ```
///
/// Scenarios are mapped as:
/// - `id` — `task_id`.
/// - `prompt` — `Question`.
/// - `expected` — `Final answer`.
/// - `metadata` — `{"level": N, "annotator_metadata": {...}}`.
///
/// When [`level`][GaiaLoader::level] is `Some(n)`, only records whose `Level` field
/// equals `n` become scenarios. Every non-blank line is still parsed first, so a
/// malformed line is reported as an error even if its level would be filtered out.
///
/// The [`Default`] loader has no level filter and is equivalent to
/// [`GaiaLoader::all_levels`].
///
/// # Examples
///
/// ```no_run
/// use std::path::Path;
/// use zeph_bench::loaders::GaiaLoader;
/// use zeph_bench::scenario::DatasetLoader;
///
/// // Load all levels.
/// let all = GaiaLoader::all_levels().load(Path::new("/data/gaia.jsonl")).unwrap();
///
/// // Load only level-1 tasks.
/// let easy = GaiaLoader::with_level(1).load(Path::new("/data/gaia.jsonl")).unwrap();
/// assert!(easy.len() <= all.len());
/// ```
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct GaiaLoader {
    /// Optional level filter. When `Some(n)`, only scenarios where `Level == n` are loaded.
    pub level: Option<u8>,
}
73
74impl GaiaLoader {
75    /// Create a loader that loads scenarios from all difficulty levels.
76    ///
77    /// # Examples
78    ///
79    /// ```
80    /// use zeph_bench::loaders::GaiaLoader;
81    ///
82    /// let loader = GaiaLoader::all_levels();
83    /// assert!(loader.level.is_none());
84    /// ```
85    #[must_use]
86    pub fn all_levels() -> Self {
87        Self { level: None }
88    }
89
90    /// Create a loader that only loads scenarios whose `Level` field equals `level`.
91    ///
92    /// GAIA levels run from 1 (easy) to 3 (hard).
93    ///
94    /// # Examples
95    ///
96    /// ```
97    /// use zeph_bench::loaders::GaiaLoader;
98    ///
99    /// let loader = GaiaLoader::with_level(2);
100    /// assert_eq!(loader.level, Some(2));
101    /// ```
102    #[must_use]
103    pub fn with_level(level: u8) -> Self {
104        Self { level: Some(level) }
105    }
106}
107
108impl DatasetLoader for GaiaLoader {
109    fn name(&self) -> &'static str {
110        "gaia"
111    }
112
113    /// # Errors
114    ///
115    /// Returns [`BenchError::Io`] when the file cannot be read and
116    /// [`BenchError::InvalidFormat`] when a JSONL line cannot be parsed.
117    fn load(&self, path: &Path) -> Result<Vec<Scenario>, BenchError> {
118        let file = std::fs::File::open(path)?;
119        let reader = BufReader::new(file);
120
121        let mut scenarios = Vec::new();
122        for (line_number, line) in reader.lines().enumerate() {
123            let line = line?;
124            let trimmed = line.trim();
125            if trimmed.is_empty() {
126                continue;
127            }
128            let record: GaiaRecord = serde_json::from_str(trimmed)
129                .map_err(|e| BenchError::InvalidFormat(format!("line {line_number}: {e}")))?;
130
131            if let Some(filter_level) = self.level
132                && record.level != filter_level
133            {
134                continue;
135            }
136
137            let metadata = serde_json::json!({
138                "level": record.level,
139                "annotator_metadata": record.annotator_metadata,
140            });
141
142            scenarios.push(Scenario::single(
143                record.task_id,
144                record.question,
145                record.final_answer,
146                metadata,
147            ));
148        }
149        Ok(scenarios)
150    }
151}
152
153/// Evaluates GAIA responses using GAIA-normalized exact match.
154///
155/// Normalization (applied to both prediction and reference):
156/// 1. Keep only alphanumeric characters and whitespace.
157/// 2. Convert to lowercase.
158/// 3. Remove the articles `a`, `an`, and `the`.
159/// 4. Collapse whitespace.
160///
161/// This matches the official GAIA leaderboard evaluation script.
162/// Score is `1.0` on match, `0.0` otherwise.
163///
164/// # Examples
165///
166/// ```
167/// use zeph_bench::{Scenario, loaders::GaiaEvaluator};
168/// use zeph_bench::scenario::Evaluator;
169///
170/// let scenario = Scenario::single("t1", "Capital of Japan?", "Tokyo", serde_json::json!({"level": 1}));
171///
172/// // Article "The" is stripped before comparison.
173/// assert!(GaiaEvaluator.evaluate(&scenario, "The Tokyo").passed);
174/// assert!(!GaiaEvaluator.evaluate(&scenario, "Osaka").passed);
175/// ```
176#[derive(Debug)]
177pub struct GaiaEvaluator;
178
179impl Evaluator for GaiaEvaluator {
180    fn evaluate(&self, scenario: &Scenario, agent_response: &str) -> EvalResult {
181        let passed = gaia_normalized_exact_match(agent_response, &scenario.expected);
182        EvalResult {
183            scenario_id: scenario.id.clone(),
184            score: if passed { 1.0 } else { 0.0 },
185            passed,
186            details: format!(
187                "gaia_normalized_exact_match={}",
188                if passed { "true" } else { "false" }
189            ),
190        }
191    }
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    const FIXTURE: &str = r#"{"task_id": "t1", "Question": "What year did WWII end?", "Level": 1, "Final answer": "1945", "Annotator Metadata": {"difficulty": "easy"}}
{"task_id": "t2", "Question": "Who wrote Hamlet?", "Level": 2, "Final answer": "Shakespeare", "Annotator Metadata": null}
{"task_id": "t3", "Question": "Capital of Japan?", "Level": 1, "Final answer": "Tokyo", "Annotator Metadata": null}
"#;

    /// Writes `jsonl` to a temporary file and loads it with the given level filter.
    fn load_from_str(jsonl: &str, level: Option<u8>) -> Vec<Scenario> {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("gaia.jsonl");
        std::fs::write(&path, jsonl).unwrap();
        GaiaLoader { level }.load(&path).unwrap()
    }

    #[test]
    fn load_all_levels_parses_scenario_count() {
        let scenarios = load_from_str(FIXTURE, None);
        assert_eq!(scenarios.len(), 3);
    }

    #[test]
    fn load_filters_by_level() {
        let scenarios = load_from_str(FIXTURE, Some(1));
        assert_eq!(scenarios.len(), 2);
        for s in &scenarios {
            assert_eq!(s.metadata["level"], 1);
        }
    }

    #[test]
    fn load_level_without_matches_returns_empty() {
        assert!(load_from_str(FIXTURE, Some(3)).is_empty());
    }

    #[test]
    fn load_skips_blank_lines() {
        let jsonl = format!("\n{FIXTURE}\n   \n");
        let scenarios = load_from_str(&jsonl, None);
        assert_eq!(scenarios.len(), 3);
    }

    #[test]
    fn load_maps_task_id_to_scenario_id() {
        let scenarios = load_from_str(FIXTURE, None);
        assert_eq!(scenarios[0].id, "t1");
        assert_eq!(scenarios[1].id, "t2");
    }

    #[test]
    fn load_maps_prompt_and_expected() {
        let scenarios = load_from_str(FIXTURE, None);
        assert_eq!(
            scenarios[0].primary_prompt().unwrap(),
            "What year did WWII end?"
        );
        assert_eq!(scenarios[0].expected, "1945");
    }

    #[test]
    fn load_stores_level_in_metadata() {
        let scenarios = load_from_str(FIXTURE, None);
        assert_eq!(scenarios[1].metadata["level"], 2);
    }

    #[test]
    fn load_stores_annotator_metadata() {
        let scenarios = load_from_str(FIXTURE, None);
        assert_eq!(
            scenarios[0].metadata["annotator_metadata"]["difficulty"],
            "easy"
        );
        assert!(scenarios[1].metadata["annotator_metadata"].is_null());
    }

    #[test]
    fn load_missing_annotator_metadata_is_null() {
        let jsonl = r#"{"task_id": "t9", "Question": "Q?", "Level": 3, "Final answer": "A"}"#;
        let scenarios = load_from_str(jsonl, None);
        assert_eq!(scenarios.len(), 1);
        assert!(scenarios[0].metadata["annotator_metadata"].is_null());
    }

    #[test]
    fn evaluator_normalized_match_passes() {
        let scenarios = load_from_str(FIXTURE, None);
        // An exact (already normalized) answer must pass.
        let result = GaiaEvaluator.evaluate(&scenarios[0], "1945");
        assert!(result.passed);
    }

    #[test]
    fn evaluator_wrong_answer_fails() {
        let scenarios = load_from_str(FIXTURE, None);
        let result = GaiaEvaluator.evaluate(&scenarios[0], "1944");
        assert!(!result.passed);
        assert!(result.score < f64::EPSILON);
    }

    #[test]
    fn evaluator_strips_article_the() {
        let scenarios = load_from_str(FIXTURE, None);
        // scenario[2]: expected = "Tokyo"
        let result = GaiaEvaluator.evaluate(&scenarios[2], "The Tokyo");
        assert!(result.passed);
    }

    #[test]
    fn load_invalid_jsonl_returns_error() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("bad.jsonl");
        std::fs::write(&path, "not json\n").unwrap();
        assert!(GaiaLoader::all_levels().load(&path).is_err());
    }

    #[test]
    fn all_levels_constructor() {
        let loader = GaiaLoader::all_levels();
        assert!(loader.level.is_none());
    }

    #[test]
    fn with_level_constructor() {
        let loader = GaiaLoader::with_level(2);
        assert_eq!(loader.level, Some(2));
    }
}