oxibonsai_eval/
dataset.rs1use std::collections::HashMap;
8
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11
12use crate::error::EvalError;
13
/// Advances a 64-bit linear congruential generator by one step.
///
/// Returns `state * MULTIPLIER + INCREMENT` evaluated with wrapping
/// (mod 2^64) arithmetic. The constants are the multiplier/increment pair
/// known from Knuth's MMIX generator. The function is pure: equal inputs
/// always produce equal outputs.
#[inline]
fn lcg_step(state: u64) -> u64 {
    const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    const INCREMENT: u64 = 1_442_695_040_888_963_407;

    let scaled = state.wrapping_mul(MULTIPLIER);
    scaled.wrapping_add(INCREMENT)
}
30
/// A single free-form evaluation example: an input and an optional
/// reference output, plus arbitrary JSON metadata.
///
/// The type derives `Serialize`/`Deserialize`, so it corresponds to a JSON
/// object with the keys `id`, `input`, `expected_output` and `metadata`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvalExample {
    /// Identifier of the example. When an example is loaded through
    /// `EvalDataset::from_jsonl` and the line carries no string `id`, the
    /// zero-based line index (rendered as a string) is used instead.
    pub id: String,
    /// The input text of the example.
    pub input: String,
    /// The reference output, if one is known; `None` otherwise.
    pub expected_output: Option<String>,
    /// Free-form key/value annotations stored as raw JSON values.
    pub metadata: HashMap<String, Value>,
}
47
/// A multiple-choice question with its answer options and the position of
/// the correct one.
///
/// The type derives `Serialize`/`Deserialize`, so it corresponds to a JSON
/// object with the keys `id`, `question`, `choices`, `correct_answer`,
/// `subject` and `difficulty`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultipleChoiceQuestion {
    /// Identifier of the question. When a question is loaded through
    /// `McDataset::from_jsonl` and the line carries no string `id`, the
    /// zero-based line index (rendered as a string) is used instead.
    pub id: String,
    /// The question text.
    pub question: String,
    /// The answer options, in the order they were supplied.
    pub choices: Vec<String>,
    /// Position of the correct option — presumably a zero-based index into
    /// `choices`. NOTE(review): nothing in this file checks that the value
    /// is smaller than `choices.len()`; confirm where that is validated.
    pub correct_answer: usize,
    /// Optional subject label; `McDataset::filter_by_subject` and
    /// `McDataset::subjects` operate on this field.
    pub subject: Option<String>,
    /// Optional difficulty label, stored as free-form text.
    pub difficulty: Option<String>,
}
68
69pub struct EvalDataset {
75 pub name: String,
77 pub examples: Vec<EvalExample>,
79}
80
81impl EvalDataset {
82 pub fn new(name: &str) -> Self {
84 Self {
85 name: name.to_string(),
86 examples: Vec::new(),
87 }
88 }
89
90 pub fn add(&mut self, example: EvalExample) {
92 self.examples.push(example);
93 }
94
95 pub fn len(&self) -> usize {
97 self.examples.len()
98 }
99
100 pub fn is_empty(&self) -> bool {
102 self.examples.is_empty()
103 }
104
105 pub fn from_jsonl(name: &str, jsonl: &str) -> Result<Self, EvalError> {
110 let mut dataset = EvalDataset::new(name);
111 for (line_no, line) in jsonl.lines().enumerate() {
112 let trimmed = line.trim();
113 if trimmed.is_empty() {
114 continue;
115 }
116 let v: Value = serde_json::from_str(trimmed)
117 .map_err(|e| EvalError::ParseError(format!("line {}: {}", line_no + 1, e)))?;
118 let obj = v.as_object().ok_or_else(|| {
119 EvalError::InvalidFormat(format!("line {} is not a JSON object", line_no + 1))
120 })?;
121
122 let input = obj
123 .get("input")
124 .and_then(Value::as_str)
125 .ok_or_else(|| {
126 EvalError::InvalidFormat(format!(
127 "line {}: missing \"input\" field",
128 line_no + 1
129 ))
130 })?
131 .to_string();
132
133 let id = obj
134 .get("id")
135 .and_then(Value::as_str)
136 .map(str::to_string)
137 .unwrap_or_else(|| format!("{}", line_no));
138
139 let expected_output = obj
140 .get("expected_output")
141 .and_then(Value::as_str)
142 .map(str::to_string);
143
144 let metadata: HashMap<String, Value> = obj
145 .get("metadata")
146 .and_then(Value::as_object)
147 .map(|m| m.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
148 .unwrap_or_default();
149
150 dataset.add(EvalExample {
151 id,
152 input,
153 expected_output,
154 metadata,
155 });
156 }
157 Ok(dataset)
158 }
159
160 pub fn to_jsonl(&self) -> String {
162 self.examples
163 .iter()
164 .filter_map(|ex| serde_json::to_string(ex).ok())
165 .collect::<Vec<_>>()
166 .join("\n")
167 }
168
169 pub fn sample(&self, n: usize, seed: u64) -> EvalDataset {
174 let count = n.min(self.len());
175 let mut indices: Vec<usize> = (0..self.len()).collect();
176
177 let mut state = seed;
179 for i in (1..indices.len()).rev() {
180 state = lcg_step(state);
181 let j = (state >> 33) as usize % (i + 1);
182 indices.swap(i, j);
183 }
184
185 let mut sampled = EvalDataset::new(&self.name);
186 for &idx in indices.iter().take(count) {
187 sampled.add(self.examples[idx].clone());
188 }
189 sampled
190 }
191
192 pub fn sample_with_seed(&self, n: usize, seed: u64) -> EvalDataset {
198 self.sample(n, seed)
199 }
200
201 pub fn split(&self, train_ratio: f32) -> (EvalDataset, EvalDataset) {
206 let split_at = ((self.len() as f32) * train_ratio.clamp(0.0, 1.0)) as usize;
207 let mut train = EvalDataset::new(&format!("{}-train", self.name));
208 let mut test = EvalDataset::new(&format!("{}-test", self.name));
209 for (i, ex) in self.examples.iter().enumerate() {
210 if i < split_at {
211 train.add(ex.clone());
212 } else {
213 test.add(ex.clone());
214 }
215 }
216 (train, test)
217 }
218}
219
220pub struct McDataset {
226 pub name: String,
228 pub questions: Vec<MultipleChoiceQuestion>,
230}
231
232impl McDataset {
233 pub fn new(name: &str) -> Self {
235 Self {
236 name: name.to_string(),
237 questions: Vec::new(),
238 }
239 }
240
241 pub fn add(&mut self, q: MultipleChoiceQuestion) {
243 self.questions.push(q);
244 }
245
246 pub fn len(&self) -> usize {
248 self.questions.len()
249 }
250
251 pub fn is_empty(&self) -> bool {
253 self.questions.is_empty()
254 }
255
256 pub fn from_jsonl(name: &str, jsonl: &str) -> Result<Self, EvalError> {
261 let mut dataset = McDataset::new(name);
262 for (line_no, line) in jsonl.lines().enumerate() {
263 let trimmed = line.trim();
264 if trimmed.is_empty() {
265 continue;
266 }
267 let v: Value = serde_json::from_str(trimmed)
268 .map_err(|e| EvalError::ParseError(format!("line {}: {}", line_no + 1, e)))?;
269 let obj = v.as_object().ok_or_else(|| {
270 EvalError::InvalidFormat(format!("line {} is not a JSON object", line_no + 1))
271 })?;
272
273 let id = obj
274 .get("id")
275 .and_then(Value::as_str)
276 .map(str::to_string)
277 .unwrap_or_else(|| format!("{}", line_no));
278
279 let question = obj
280 .get("question")
281 .and_then(Value::as_str)
282 .ok_or_else(|| {
283 EvalError::InvalidFormat(format!(
284 "line {}: missing \"question\" field",
285 line_no + 1
286 ))
287 })?
288 .to_string();
289
290 let choices: Vec<String> = obj
291 .get("choices")
292 .and_then(Value::as_array)
293 .ok_or_else(|| {
294 EvalError::InvalidFormat(format!(
295 "line {}: missing or invalid \"choices\" field",
296 line_no + 1
297 ))
298 })?
299 .iter()
300 .enumerate()
301 .map(|(i, c)| {
302 c.as_str().map(str::to_string).ok_or_else(|| {
303 EvalError::InvalidFormat(format!(
304 "line {}: choice {} is not a string",
305 line_no + 1,
306 i
307 ))
308 })
309 })
310 .collect::<Result<Vec<_>, _>>()?;
311
312 let correct_answer = obj
313 .get("correct_answer")
314 .and_then(Value::as_u64)
315 .ok_or_else(|| {
316 EvalError::InvalidFormat(format!(
317 "line {}: missing or invalid \"correct_answer\" field",
318 line_no + 1
319 ))
320 })? as usize;
321
322 let subject = obj
323 .get("subject")
324 .and_then(Value::as_str)
325 .map(str::to_string);
326 let difficulty = obj
327 .get("difficulty")
328 .and_then(Value::as_str)
329 .map(str::to_string);
330
331 dataset.add(MultipleChoiceQuestion {
332 id,
333 question,
334 choices,
335 correct_answer,
336 subject,
337 difficulty,
338 });
339 }
340 Ok(dataset)
341 }
342
343 pub fn filter_by_subject(&self, subject: &str) -> McDataset {
345 let mut out = McDataset::new(&format!("{}-{}", self.name, subject));
346 for q in &self.questions {
347 if q.subject.as_deref() == Some(subject) {
348 out.add(q.clone());
349 }
350 }
351 out
352 }
353
354 pub fn sample_with_seed(&self, n: usize, seed: u64) -> McDataset {
358 let count = n.min(self.len());
359 let mut indices: Vec<usize> = (0..self.len()).collect();
360
361 let mut state = seed;
362 for i in (1..indices.len()).rev() {
363 state = lcg_step(state);
364 let j = (state >> 33) as usize % (i + 1);
365 indices.swap(i, j);
366 }
367
368 let mut sampled = McDataset::new(&self.name);
369 for &idx in indices.iter().take(count) {
370 sampled.add(self.questions[idx].clone());
371 }
372 sampled
373 }
374
375 pub fn subjects(&self) -> Vec<String> {
377 let mut seen: Vec<String> = self
378 .questions
379 .iter()
380 .filter_map(|q| q.subject.clone())
381 .collect();
382 seen.sort();
383 seen.dedup();
384 seen
385 }
386}