use crate::accuracy::{AccuracyResult, McEvaluator, McLogitEvaluator};
use crate::dataset::McDataset;
/// Identifies which of the two ARC splits an evaluator or result refers to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ArcSplit {
    /// The split labelled `"ARC-Easy"`.
    Easy,
    /// The split labelled `"ARC-Challenge"`.
    Challenge,
}

impl ArcSplit {
    /// Returns the fixed, human-readable label of this split.
    ///
    /// The label is `"ARC-Easy"` for [`ArcSplit::Easy`] and
    /// `"ARC-Challenge"` for [`ArcSplit::Challenge`].
    pub fn name(&self) -> &'static str {
        // Deliberately no catch-all arm: adding a variant must force an update here.
        match *self {
            Self::Challenge => "ARC-Challenge",
            Self::Easy => "ARC-Easy",
        }
    }
}
/// Evaluator bound to a single [`ArcSplit`].
///
/// It owns one [`McEvaluator`] and one [`McLogitEvaluator`], both configured
/// with the same prompt template, and forwards evaluation requests to them.
pub struct ArcEvaluator {
    /// Split this evaluator was constructed for.
    split: ArcSplit,
    /// Backend used by [`ArcEvaluator::evaluate_completions`].
    mc: McEvaluator,
    /// Backend used by [`ArcEvaluator::evaluate_logits`].
    mc_logit: McLogitEvaluator,
}

impl ArcEvaluator {
    /// Builds an evaluator tagged with [`ArcSplit::Easy`].
    pub fn easy() -> Self {
        Self::new(ArcSplit::Easy)
    }

    /// Builds an evaluator tagged with [`ArcSplit::Challenge`].
    pub fn challenge() -> Self {
        Self::new(ArcSplit::Challenge)
    }

    /// Shared constructor: both inner evaluators receive their own copy of the
    /// same four-option (A–D) prompt template.
    fn new(split: ArcSplit) -> Self {
        const PROMPT_TEMPLATE: &str = "{question}\nA) {a}\nB) {b}\nC) {c}\nD) {d}\nAnswer:";

        // Each evaluator owns its template, so two independent `String`s are built.
        let mc = McEvaluator {
            prompt_template: String::from(PROMPT_TEMPLATE),
        };
        let mc_logit = McLogitEvaluator {
            prompt_template: String::from(PROMPT_TEMPLATE),
        };

        Self {
            split,
            mc,
            mc_logit,
        }
    }

    /// Returns the split this evaluator was created for.
    pub fn split(&self) -> ArcSplit {
        self.split
    }

    /// Evaluates `completions` against `dataset` by delegating to the inner
    /// [`McEvaluator`]'s `evaluate_dataset`.
    ///
    /// NOTE(review): how completions are paired with dataset items is decided
    /// by `McEvaluator`; confirm any length/ordering requirements there.
    pub fn evaluate_completions(
        &self,
        dataset: &McDataset,
        completions: &[String],
    ) -> AccuracyResult {
        self.mc.evaluate_dataset(dataset, completions)
    }

    /// Evaluates `per_choice_logits` against `dataset` by delegating to the
    /// inner [`McLogitEvaluator`]'s `evaluate_dataset`.
    ///
    /// NOTE(review): presumably one inner `Vec<f32>` per dataset item with one
    /// entry per answer choice — verify against `McLogitEvaluator`.
    pub fn evaluate_logits(
        &self,
        dataset: &McDataset,
        per_choice_logits: &[Vec<f32>],
    ) -> AccuracyResult {
        self.mc_logit.evaluate_dataset(dataset, per_choice_logits)
    }
}
/// Accuracy figures for one evaluation run, tagged with the split they belong to.
#[derive(Debug, Clone)]
pub struct ArcResult {
    /// Split the figures were measured on.
    pub split: ArcSplit,
    /// Accuracy as reported by the underlying `AccuracyResult`.
    /// NOTE(review): presumably a fraction in `[0, 1]`, since
    /// [`ArcResult::accuracy_pct`] scales it by 100 — confirm at the producer.
    pub accuracy: f32,
    /// `correct` count copied from the underlying `AccuracyResult`.
    pub correct: usize,
    /// `total` count copied from the underlying `AccuracyResult`.
    pub total: usize,
}

impl ArcResult {
    /// Wraps a generic [`AccuracyResult`] together with the `split` it was
    /// computed for, copying the `accuracy`, `correct` and `total` fields as-is.
    pub fn from_accuracy_result(split: ArcSplit, result: AccuracyResult) -> Self {
        let accuracy = result.accuracy;
        let correct = result.correct;
        let total = result.total;

        Self {
            split,
            accuracy,
            correct,
            total,
        }
    }

    /// Returns the stored accuracy multiplied by 100.
    pub fn accuracy_pct(&self) -> f32 {
        const PERCENT_SCALE: f32 = 100.0;
        self.accuracy * PERCENT_SCALE
    }

    /// Returns the label of the stored split, as given by [`ArcSplit::name`].
    pub fn split_name(&self) -> &'static str {
        self.split.name()
    }
}