Skip to main content

oxibonsai_eval/
dataset.rs

1//! Dataset types and loaders for the evaluation harness.
2//!
3//! Provides [`EvalDataset`] for free-form text evaluation and [`McDataset`]
4//! for MMLU-style multiple-choice evaluation. Both support JSONL loading
5//! and deterministic sampling via a simple LCG (no external rand crate).
6
7use std::collections::HashMap;
8
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11
12use crate::error::EvalError;
13
14// ──────────────────────────────────────────────────────────────────────────────
// LCG helper (Knuth's MMIX linear congruential generator)
16// ──────────────────────────────────────────────────────────────────────────────
17
/// Advance the 64-bit linear congruential generator by one step and return
/// the new state.
///
/// Computes `state * a + c (mod 2^64)` with the constants of Knuth's MMIX
/// generator:
/// - multiplier `a` = 6364136223846793005
/// - increment  `c` = 1442695040888963407
/// - modulus        = 2^64 (implicit via wrapping `u64` arithmetic)
///
/// The function is pure: the same `state` always yields the same successor.
/// The samplers in this file derive their random index from the high bits of
/// the returned state (`state >> 33`).
#[inline]
fn lcg_step(state: u64) -> u64 {
    const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    const INCREMENT: u64 = 1_442_695_040_888_963_407;
    // Wrapping ops give the reduction modulo 2^64 without overflow panics in
    // debug builds.
    state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT)
}
30
31// ──────────────────────────────────────────────────────────────────────────────
32// EvalExample
33// ──────────────────────────────────────────────────────────────────────────────
34
/// A single evaluation example for free-form text tasks.
///
/// Derives `Serialize`/`Deserialize`, so one example corresponds to one JSON
/// object keyed by the field names below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvalExample {
    /// Unique identifier for this example.
    ///
    /// [`EvalDataset::from_jsonl`] falls back to the zero-based line index
    /// when the source object carries no string `"id"`.
    pub id: String,
    /// The input prompt / context fed to the model.
    pub input: String,
    /// Expected output, if known (used for scoring); `None` when the example
    /// has no reference answer.
    pub expected_output: Option<String>,
    /// Arbitrary key-value metadata; values are untyped JSON
    /// ([`serde_json::Value`]).
    pub metadata: HashMap<String, Value>,
}
47
48// ──────────────────────────────────────────────────────────────────────────────
49// MultipleChoiceQuestion
50// ──────────────────────────────────────────────────────────────────────────────
51
/// A multiple-choice question in MMLU format.
///
/// All fields are public and unchecked: the type itself does not enforce
/// that `correct_answer` is a valid index into `choices`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultipleChoiceQuestion {
    /// Unique identifier for this question.
    pub id: String,
    /// The question stem.
    pub question: String,
    /// Answer choices, e.g. `["A: option1", "B: option2", ...]`.
    pub choices: Vec<String>,
    /// Index of the correct choice (0-based) within `choices`.
    pub correct_answer: usize,
    /// Subject area (e.g. "high_school_biology"); `None` when unlabelled.
    pub subject: Option<String>,
    /// Difficulty label (e.g. "easy", "medium", "hard"); `None` when
    /// unlabelled.
    pub difficulty: Option<String>,
}
68
69// ──────────────────────────────────────────────────────────────────────────────
70// EvalDataset
71// ──────────────────────────────────────────────────────────────────────────────
72
/// A named collection of [`EvalExample`] instances.
///
/// Both fields are public, so callers may read or modify them directly in
/// addition to going through the inherent methods.
pub struct EvalDataset {
    /// Human-readable dataset name.
    pub name: String,
    /// All examples in insertion order.
    pub examples: Vec<EvalExample>,
}
80
81impl EvalDataset {
82    /// Create an empty dataset with the given name.
83    pub fn new(name: &str) -> Self {
84        Self {
85            name: name.to_string(),
86            examples: Vec::new(),
87        }
88    }
89
90    /// Append an example to the dataset.
91    pub fn add(&mut self, example: EvalExample) {
92        self.examples.push(example);
93    }
94
95    /// Return the number of examples.
96    pub fn len(&self) -> usize {
97        self.examples.len()
98    }
99
100    /// Return `true` if the dataset contains no examples.
101    pub fn is_empty(&self) -> bool {
102        self.examples.is_empty()
103    }
104
105    /// Parse a JSONL string into a dataset.
106    ///
107    /// Each line must be a JSON object with at least an `"input"` field.
108    /// `"id"`, `"expected_output"`, and `"metadata"` are optional.
109    pub fn from_jsonl(name: &str, jsonl: &str) -> Result<Self, EvalError> {
110        let mut dataset = EvalDataset::new(name);
111        for (line_no, line) in jsonl.lines().enumerate() {
112            let trimmed = line.trim();
113            if trimmed.is_empty() {
114                continue;
115            }
116            let v: Value = serde_json::from_str(trimmed)
117                .map_err(|e| EvalError::ParseError(format!("line {}: {}", line_no + 1, e)))?;
118            let obj = v.as_object().ok_or_else(|| {
119                EvalError::InvalidFormat(format!("line {} is not a JSON object", line_no + 1))
120            })?;
121
122            let input = obj
123                .get("input")
124                .and_then(Value::as_str)
125                .ok_or_else(|| {
126                    EvalError::InvalidFormat(format!(
127                        "line {}: missing \"input\" field",
128                        line_no + 1
129                    ))
130                })?
131                .to_string();
132
133            let id = obj
134                .get("id")
135                .and_then(Value::as_str)
136                .map(str::to_string)
137                .unwrap_or_else(|| format!("{}", line_no));
138
139            let expected_output = obj
140                .get("expected_output")
141                .and_then(Value::as_str)
142                .map(str::to_string);
143
144            let metadata: HashMap<String, Value> = obj
145                .get("metadata")
146                .and_then(Value::as_object)
147                .map(|m| m.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
148                .unwrap_or_default();
149
150            dataset.add(EvalExample {
151                id,
152                input,
153                expected_output,
154                metadata,
155            });
156        }
157        Ok(dataset)
158    }
159
160    /// Serialise the dataset back to JSONL format (one JSON object per line).
161    pub fn to_jsonl(&self) -> String {
162        self.examples
163            .iter()
164            .filter_map(|ex| serde_json::to_string(ex).ok())
165            .collect::<Vec<_>>()
166            .join("\n")
167    }
168
169    /// Return a deterministic random sample of `n` examples.
170    ///
171    /// Uses a 64-bit LCG seeded with `seed`. If `n >= self.len()`, returns a
172    /// clone of the entire dataset in a shuffled order.
173    pub fn sample(&self, n: usize, seed: u64) -> EvalDataset {
174        let count = n.min(self.len());
175        let mut indices: Vec<usize> = (0..self.len()).collect();
176
177        // Fisher-Yates shuffle driven by LCG
178        let mut state = seed;
179        for i in (1..indices.len()).rev() {
180            state = lcg_step(state);
181            let j = (state >> 33) as usize % (i + 1);
182            indices.swap(i, j);
183        }
184
185        let mut sampled = EvalDataset::new(&self.name);
186        for &idx in indices.iter().take(count) {
187            sampled.add(self.examples[idx].clone());
188        }
189        sampled
190    }
191
192    /// Explicit alias for [`EvalDataset::sample`], surfacing the seeded
193    /// nature of the sampler in the name.
194    ///
195    /// Given identical `(n, seed)` inputs, the returned dataset is bit-identical
196    /// across runs and across platforms (LCG constants are fixed).
197    pub fn sample_with_seed(&self, n: usize, seed: u64) -> EvalDataset {
198        self.sample(n, seed)
199    }
200
201    /// Split the dataset into train and test subsets.
202    ///
203    /// The first `floor(len * train_ratio)` examples become the training set;
204    /// the remainder form the test set. Order is preserved.
205    pub fn split(&self, train_ratio: f32) -> (EvalDataset, EvalDataset) {
206        let split_at = ((self.len() as f32) * train_ratio.clamp(0.0, 1.0)) as usize;
207        let mut train = EvalDataset::new(&format!("{}-train", self.name));
208        let mut test = EvalDataset::new(&format!("{}-test", self.name));
209        for (i, ex) in self.examples.iter().enumerate() {
210            if i < split_at {
211                train.add(ex.clone());
212            } else {
213                test.add(ex.clone());
214            }
215        }
216        (train, test)
217    }
218}
219
220// ──────────────────────────────────────────────────────────────────────────────
221// McDataset
222// ──────────────────────────────────────────────────────────────────────────────
223
/// A named collection of [`MultipleChoiceQuestion`] instances.
///
/// Both fields are public, so callers may read or modify them directly in
/// addition to going through the inherent methods.
pub struct McDataset {
    /// Human-readable dataset name.
    pub name: String,
    /// All questions in insertion order.
    pub questions: Vec<MultipleChoiceQuestion>,
}
231
232impl McDataset {
233    /// Create an empty multiple-choice dataset.
234    pub fn new(name: &str) -> Self {
235        Self {
236            name: name.to_string(),
237            questions: Vec::new(),
238        }
239    }
240
241    /// Append a question.
242    pub fn add(&mut self, q: MultipleChoiceQuestion) {
243        self.questions.push(q);
244    }
245
246    /// Return the number of questions.
247    pub fn len(&self) -> usize {
248        self.questions.len()
249    }
250
251    /// Return `true` if the dataset contains no questions.
252    pub fn is_empty(&self) -> bool {
253        self.questions.is_empty()
254    }
255
256    /// Parse a JSONL string into a multiple-choice dataset.
257    ///
258    /// Each line must have `"id"`, `"question"`, `"choices"` (array), and
259    /// `"correct_answer"` (integer). `"subject"` and `"difficulty"` are optional.
260    pub fn from_jsonl(name: &str, jsonl: &str) -> Result<Self, EvalError> {
261        let mut dataset = McDataset::new(name);
262        for (line_no, line) in jsonl.lines().enumerate() {
263            let trimmed = line.trim();
264            if trimmed.is_empty() {
265                continue;
266            }
267            let v: Value = serde_json::from_str(trimmed)
268                .map_err(|e| EvalError::ParseError(format!("line {}: {}", line_no + 1, e)))?;
269            let obj = v.as_object().ok_or_else(|| {
270                EvalError::InvalidFormat(format!("line {} is not a JSON object", line_no + 1))
271            })?;
272
273            let id = obj
274                .get("id")
275                .and_then(Value::as_str)
276                .map(str::to_string)
277                .unwrap_or_else(|| format!("{}", line_no));
278
279            let question = obj
280                .get("question")
281                .and_then(Value::as_str)
282                .ok_or_else(|| {
283                    EvalError::InvalidFormat(format!(
284                        "line {}: missing \"question\" field",
285                        line_no + 1
286                    ))
287                })?
288                .to_string();
289
290            let choices: Vec<String> = obj
291                .get("choices")
292                .and_then(Value::as_array)
293                .ok_or_else(|| {
294                    EvalError::InvalidFormat(format!(
295                        "line {}: missing or invalid \"choices\" field",
296                        line_no + 1
297                    ))
298                })?
299                .iter()
300                .enumerate()
301                .map(|(i, c)| {
302                    c.as_str().map(str::to_string).ok_or_else(|| {
303                        EvalError::InvalidFormat(format!(
304                            "line {}: choice {} is not a string",
305                            line_no + 1,
306                            i
307                        ))
308                    })
309                })
310                .collect::<Result<Vec<_>, _>>()?;
311
312            let correct_answer = obj
313                .get("correct_answer")
314                .and_then(Value::as_u64)
315                .ok_or_else(|| {
316                    EvalError::InvalidFormat(format!(
317                        "line {}: missing or invalid \"correct_answer\" field",
318                        line_no + 1
319                    ))
320                })? as usize;
321
322            let subject = obj
323                .get("subject")
324                .and_then(Value::as_str)
325                .map(str::to_string);
326            let difficulty = obj
327                .get("difficulty")
328                .and_then(Value::as_str)
329                .map(str::to_string);
330
331            dataset.add(MultipleChoiceQuestion {
332                id,
333                question,
334                choices,
335                correct_answer,
336                subject,
337                difficulty,
338            });
339        }
340        Ok(dataset)
341    }
342
343    /// Return a new dataset containing only questions with the given subject.
344    pub fn filter_by_subject(&self, subject: &str) -> McDataset {
345        let mut out = McDataset::new(&format!("{}-{}", self.name, subject));
346        for q in &self.questions {
347            if q.subject.as_deref() == Some(subject) {
348                out.add(q.clone());
349            }
350        }
351        out
352    }
353
354    /// Return a deterministic random sample of `n` questions.
355    ///
356    /// Uses the same 64-bit LCG scheme as [`EvalDataset::sample_with_seed`].
357    pub fn sample_with_seed(&self, n: usize, seed: u64) -> McDataset {
358        let count = n.min(self.len());
359        let mut indices: Vec<usize> = (0..self.len()).collect();
360
361        let mut state = seed;
362        for i in (1..indices.len()).rev() {
363            state = lcg_step(state);
364            let j = (state >> 33) as usize % (i + 1);
365            indices.swap(i, j);
366        }
367
368        let mut sampled = McDataset::new(&self.name);
369        for &idx in indices.iter().take(count) {
370            sampled.add(self.questions[idx].clone());
371        }
372        sampled
373    }
374
375    /// Return a sorted, deduplicated list of all subject names in this dataset.
376    pub fn subjects(&self) -> Vec<String> {
377        let mut seen: Vec<String> = self
378            .questions
379            .iter()
380            .filter_map(|q| q.subject.clone())
381            .collect();
382        seen.sort();
383        seen.dedup();
384        seen
385    }
386}