use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::error::EvalError;
/// Advances a 64-bit linear congruential generator by one step.
///
/// Returns `state * MULTIPLIER + INCREMENT`, with both operations wrapping
/// on overflow (i.e. computed modulo 2^64). The function is pure: the same
/// `state` always yields the same successor.
#[inline]
fn lcg_step(state: u64) -> u64 {
    const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    const INCREMENT: u64 = 1_442_695_040_888_963_407;

    // Wrapping multiplication is commutative, so operand order is irrelevant.
    let scaled = MULTIPLIER.wrapping_mul(state);
    scaled.wrapping_add(INCREMENT)
}
/// One free-form evaluation record: an input string, an optional reference
/// output, and arbitrary JSON metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvalExample {
/// Identifier of the example. `EvalDataset::from_jsonl` uses the record's
/// string `"id"` when present and otherwise the zero-based line index.
pub id: String,
/// The input text of the example.
pub input: String,
/// Reference output, if any. `EvalDataset::from_jsonl` leaves this `None`
/// when `"expected_output"` is absent or is not a JSON string.
pub expected_output: Option<String>,
/// Free-form key/value metadata; empty when the record carries no
/// `"metadata"` object.
pub metadata: HashMap<String, Value>,
}
/// One multiple-choice item: a question, its answer options, and the
/// position of the correct option.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultipleChoiceQuestion {
/// Identifier of the question. `McDataset::from_jsonl` uses the record's
/// string `"id"` when present and otherwise the zero-based line index.
pub id: String,
/// The question text.
pub question: String,
/// The answer options, in the order they appear in the source record.
pub choices: Vec<String>,
/// Position of the correct option.
/// NOTE(review): presumably a zero-based index into `choices` — confirm
/// against the scoring code.
pub correct_answer: usize,
/// Optional subject label; used by `McDataset::filter_by_subject` and
/// `McDataset::subjects`.
pub subject: Option<String>,
/// Optional difficulty label.
pub difficulty: Option<String>,
}
/// A named, ordered collection of [`EvalExample`]s.
pub struct EvalDataset {
/// Name of the dataset.
pub name: String,
/// The examples, in insertion order.
pub examples: Vec<EvalExample>,
}
impl EvalDataset {
pub fn new(name: &str) -> Self {
Self {
name: name.to_string(),
examples: Vec::new(),
}
}
pub fn add(&mut self, example: EvalExample) {
self.examples.push(example);
}
pub fn len(&self) -> usize {
self.examples.len()
}
pub fn is_empty(&self) -> bool {
self.examples.is_empty()
}
pub fn from_jsonl(name: &str, jsonl: &str) -> Result<Self, EvalError> {
let mut dataset = EvalDataset::new(name);
for (line_no, line) in jsonl.lines().enumerate() {
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}
let v: Value = serde_json::from_str(trimmed)
.map_err(|e| EvalError::ParseError(format!("line {}: {}", line_no + 1, e)))?;
let obj = v.as_object().ok_or_else(|| {
EvalError::InvalidFormat(format!("line {} is not a JSON object", line_no + 1))
})?;
let input = obj
.get("input")
.and_then(Value::as_str)
.ok_or_else(|| {
EvalError::InvalidFormat(format!(
"line {}: missing \"input\" field",
line_no + 1
))
})?
.to_string();
let id = obj
.get("id")
.and_then(Value::as_str)
.map(str::to_string)
.unwrap_or_else(|| format!("{}", line_no));
let expected_output = obj
.get("expected_output")
.and_then(Value::as_str)
.map(str::to_string);
let metadata: HashMap<String, Value> = obj
.get("metadata")
.and_then(Value::as_object)
.map(|m| m.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
.unwrap_or_default();
dataset.add(EvalExample {
id,
input,
expected_output,
metadata,
});
}
Ok(dataset)
}
pub fn to_jsonl(&self) -> String {
self.examples
.iter()
.filter_map(|ex| serde_json::to_string(ex).ok())
.collect::<Vec<_>>()
.join("\n")
}
pub fn sample(&self, n: usize, seed: u64) -> EvalDataset {
let count = n.min(self.len());
let mut indices: Vec<usize> = (0..self.len()).collect();
let mut state = seed;
for i in (1..indices.len()).rev() {
state = lcg_step(state);
let j = (state >> 33) as usize % (i + 1);
indices.swap(i, j);
}
let mut sampled = EvalDataset::new(&self.name);
for &idx in indices.iter().take(count) {
sampled.add(self.examples[idx].clone());
}
sampled
}
pub fn split(&self, train_ratio: f32) -> (EvalDataset, EvalDataset) {
let split_at = ((self.len() as f32) * train_ratio.clamp(0.0, 1.0)) as usize;
let mut train = EvalDataset::new(&format!("{}-train", self.name));
let mut test = EvalDataset::new(&format!("{}-test", self.name));
for (i, ex) in self.examples.iter().enumerate() {
if i < split_at {
train.add(ex.clone());
} else {
test.add(ex.clone());
}
}
(train, test)
}
}
/// A named, ordered collection of [`MultipleChoiceQuestion`]s.
pub struct McDataset {
/// Name of the dataset.
pub name: String,
/// The questions, in insertion order.
pub questions: Vec<MultipleChoiceQuestion>,
}
impl McDataset {
pub fn new(name: &str) -> Self {
Self {
name: name.to_string(),
questions: Vec::new(),
}
}
pub fn add(&mut self, q: MultipleChoiceQuestion) {
self.questions.push(q);
}
pub fn len(&self) -> usize {
self.questions.len()
}
pub fn is_empty(&self) -> bool {
self.questions.is_empty()
}
pub fn from_jsonl(name: &str, jsonl: &str) -> Result<Self, EvalError> {
let mut dataset = McDataset::new(name);
for (line_no, line) in jsonl.lines().enumerate() {
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}
let v: Value = serde_json::from_str(trimmed)
.map_err(|e| EvalError::ParseError(format!("line {}: {}", line_no + 1, e)))?;
let obj = v.as_object().ok_or_else(|| {
EvalError::InvalidFormat(format!("line {} is not a JSON object", line_no + 1))
})?;
let id = obj
.get("id")
.and_then(Value::as_str)
.map(str::to_string)
.unwrap_or_else(|| format!("{}", line_no));
let question = obj
.get("question")
.and_then(Value::as_str)
.ok_or_else(|| {
EvalError::InvalidFormat(format!(
"line {}: missing \"question\" field",
line_no + 1
))
})?
.to_string();
let choices: Vec<String> = obj
.get("choices")
.and_then(Value::as_array)
.ok_or_else(|| {
EvalError::InvalidFormat(format!(
"line {}: missing or invalid \"choices\" field",
line_no + 1
))
})?
.iter()
.enumerate()
.map(|(i, c)| {
c.as_str().map(str::to_string).ok_or_else(|| {
EvalError::InvalidFormat(format!(
"line {}: choice {} is not a string",
line_no + 1,
i
))
})
})
.collect::<Result<Vec<_>, _>>()?;
let correct_answer = obj
.get("correct_answer")
.and_then(Value::as_u64)
.ok_or_else(|| {
EvalError::InvalidFormat(format!(
"line {}: missing or invalid \"correct_answer\" field",
line_no + 1
))
})? as usize;
let subject = obj
.get("subject")
.and_then(Value::as_str)
.map(str::to_string);
let difficulty = obj
.get("difficulty")
.and_then(Value::as_str)
.map(str::to_string);
dataset.add(MultipleChoiceQuestion {
id,
question,
choices,
correct_answer,
subject,
difficulty,
});
}
Ok(dataset)
}
pub fn filter_by_subject(&self, subject: &str) -> McDataset {
let mut out = McDataset::new(&format!("{}-{}", self.name, subject));
for q in &self.questions {
if q.subject.as_deref() == Some(subject) {
out.add(q.clone());
}
}
out
}
pub fn subjects(&self) -> Vec<String> {
let mut seen: Vec<String> = self
.questions
.iter()
.filter_map(|q| q.subject.clone())
.collect();
seen.sort();
seen.dedup();
seen
}
}