use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::Path;
use serde::Deserialize;
use serde_json::Value;
use crate::accuracy::{AccuracyResult, McEvaluator, McLogitEvaluator};
use crate::dataset::{McDataset, MultipleChoiceQuestion};
use crate::error::EvalError;
/// Raw shape of one HellaSwag JSONL record, as deserialized by serde.
///
/// `ind` and `label` are kept as untyped JSON [`Value`]s because the loader
/// accepts each of them either as a string or as a number and normalises them
/// itself. Extra JSON fields are ignored (no `deny_unknown_fields`).
#[derive(Debug, Deserialize)]
struct HellaSwagRecord {
    /// Example identifier (string or number); stringified into the item id.
    ind: Value,
    /// Activity/category label; becomes the MC question's `subject`.
    activity_label: String,
    /// Context text; becomes the MC question text.
    ctx: String,
    /// Candidate continuations; become the MC choices.
    endings: Vec<String>,
    /// Index of the correct ending, as a JSON number or a numeric string.
    label: Value,
}
/// One normalised HellaSwag example, ready for evaluation.
#[derive(Debug, Clone, PartialEq)]
pub struct HellaSwagItem {
    /// Example identifier (the record's `ind`, stringified if it was numeric).
    pub id: String,
    /// Activity/category label; exposed as the MC question's `subject`.
    pub activity_label: String,
    /// Context text; used as the MC question text.
    pub ctx: String,
    /// Candidate endings; used as the MC choices.
    pub endings: Vec<String>,
    /// Index of the correct ending within `endings` (passed through as the
    /// MC question's `correct_answer`).
    pub label: usize,
}
/// An in-memory HellaSwag split: one [`HellaSwagItem`] per parsed record.
///
/// Derives the common traits expected of public types. `Default` yields an
/// empty dataset.
#[derive(Debug, Clone, Default)]
pub struct HellaSwagDataset {
    /// Items in file/insertion order.
    pub items: Vec<HellaSwagItem>,
}
impl HellaSwagDataset {
pub fn from_items(items: Vec<HellaSwagItem>) -> Self {
Self { items }
}
pub fn from_jsonl(path: &Path) -> Result<Self, EvalError> {
let file = File::open(path)?;
let reader = BufReader::new(file);
let mut items = Vec::new();
for (line_no, line_result) in reader.lines().enumerate() {
let line = line_result?;
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}
let record: HellaSwagRecord = serde_json::from_str(trimmed).map_err(|e| {
EvalError::ParseError(format!("hellaswag: line {}: {}", line_no + 1, e))
})?;
let id = match &record.ind {
Value::String(s) => s.clone(),
Value::Number(n) => n.to_string(),
other => {
return Err(EvalError::ParseError(format!(
"hellaswag: line {}: unexpected type for \"ind\": {}",
line_no + 1,
other
)))
}
};
let label: usize = match &record.label {
Value::Number(n) => n.as_u64().ok_or_else(|| {
EvalError::ParseError(format!(
"hellaswag: line {}: \"label\" is not a non-negative integer",
line_no + 1
))
})? as usize,
Value::String(s) => s.trim().parse::<usize>().map_err(|e| {
EvalError::ParseError(format!(
"hellaswag: line {}: cannot parse string \"label\": {}",
line_no + 1,
e
))
})?,
other => {
return Err(EvalError::ParseError(format!(
"hellaswag: line {}: unexpected type for \"label\": {}",
line_no + 1,
other
)))
}
};
items.push(HellaSwagItem {
id,
activity_label: record.activity_label,
ctx: record.ctx,
endings: record.endings,
label,
});
}
Ok(Self { items })
}
pub fn len(&self) -> usize {
self.items.len()
}
pub fn is_empty(&self) -> bool {
self.items.is_empty()
}
pub fn as_mc_dataset(&self) -> McDataset {
let mut mc = McDataset::new("hellaswag");
for item in &self.items {
mc.add(MultipleChoiceQuestion {
id: item.id.clone(),
question: item.ctx.clone(),
choices: item.endings.clone(),
correct_answer: item.label,
subject: Some(item.activity_label.clone()),
difficulty: None,
});
}
mc
}
}
/// Summary of a HellaSwag evaluation run.
#[derive(Debug, Clone)]
pub struct HellaSwagResult {
    /// Fraction answered correctly. `from_accuracy` computes it as
    /// `correct / total` and uses 0.0 when `total` is 0.
    pub accuracy: f32,
    /// `accuracy` scaled by 100.
    pub accuracy_pct: f32,
    /// Number of questions answered correctly.
    pub correct: usize,
    /// Number of questions scored.
    pub total: usize,
}
impl HellaSwagResult {
    /// Converts a generic accuracy tally into a HellaSwag summary.
    ///
    /// The fraction is recomputed from the raw counts. An empty tally
    /// (`total == 0`) reports 0.0 instead of dividing by zero.
    fn from_accuracy(acc: AccuracyResult) -> Self {
        let (correct, total) = (acc.correct, acc.total);
        let accuracy = match total {
            0 => 0.0_f32,
            n => correct as f32 / n as f32,
        };
        let accuracy_pct = accuracy * 100.0;
        Self {
            correct,
            total,
            accuracy,
            accuracy_pct,
        }
    }
}
/// Scores a [`HellaSwagDataset`] from either generated completions or
/// per-choice logits.
///
/// Both inner evaluators are built in `new` with the same prompt template.
pub struct HellaSwagEvaluator {
    /// Scores free-text completions (used by `evaluate_completions`).
    mc: McEvaluator,
    /// Scores per-choice logits (used by `evaluate_logits`).
    mc_logit: McLogitEvaluator,
}
impl HellaSwagEvaluator {
    /// Creates an evaluator whose completion and logit scorers share one
    /// prompt template.
    ///
    /// The template has slots for the question and four lettered options
    /// (A–D). The placeholders are presumably filled in by the MC evaluators.
    pub fn new() -> Self {
        let prompt_template =
            String::from("{question}\nA) {a}\nB) {b}\nC) {c}\nD) {d}\nAnswer:");
        let mc_logit = McLogitEvaluator {
            prompt_template: prompt_template.clone(),
        };
        let mc = McEvaluator { prompt_template };
        Self { mc, mc_logit }
    }

    /// Scores model `completions` against `dataset`.
    ///
    /// The dataset is converted with `as_mc_dataset` and handed to the inner
    /// [`McEvaluator`]. How completions are matched to questions, including
    /// length mismatches, is decided by `McEvaluator::evaluate_dataset`.
    pub fn evaluate_completions(
        &self,
        dataset: &HellaSwagDataset,
        completions: &[String],
    ) -> HellaSwagResult {
        HellaSwagResult::from_accuracy(
            self.mc
                .evaluate_dataset(&dataset.as_mc_dataset(), completions),
        )
    }

    /// Scores `per_choice_logits` (one vector per question) against `dataset`.
    ///
    /// The dataset is converted with `as_mc_dataset` and handed to the inner
    /// [`McLogitEvaluator`]. How logits are matched to questions and choices is
    /// decided by `McLogitEvaluator::evaluate_dataset`.
    pub fn evaluate_logits(
        &self,
        dataset: &HellaSwagDataset,
        per_choice_logits: &[Vec<f32>],
    ) -> HellaSwagResult {
        HellaSwagResult::from_accuracy(
            self.mc_logit
                .evaluate_dataset(&dataset.as_mc_dataset(), per_choice_logits),
        )
    }
}
impl Default for HellaSwagEvaluator {
fn default() -> Self {
Self::new()
}
}