use std::collections::HashMap;
use crate::accuracy::{AccuracyResult, McEvaluator, McLogitEvaluator};
use crate::dataset::{McDataset, MultipleChoiceQuestion};
fn extract_subject(question: &MultipleChoiceQuestion) -> String {
if let Some(ref subj) = question.subject {
return subj.clone();
}
match question.id.find('/') {
Some(pos) => question.id[..pos].to_string(),
None => question.id.clone(),
}
}
/// Aggregate MMLU score: overall numbers plus a per-subject breakdown.
#[derive(Debug)]
pub struct MmluResult {
    /// Overall accuracy as reported by the underlying evaluator
    /// (presumably a fraction in `[0, 1]` — TODO confirm against `AccuracyResult`).
    pub accuracy: f32,
    /// `accuracy` multiplied by 100, i.e. expressed as a percentage.
    pub accuracy_pct: f32,
    /// Number of questions the underlying evaluator counted as correct.
    pub correct: usize,
    /// Number of questions the underlying evaluator counted in total.
    pub total: usize,
    /// Per-subject results keyed by subject name. Each entry's accuracy is
    /// that subject's `correct / total`.
    pub by_subject: HashMap<String, AccuracyResult>,
}
/// MMLU evaluator that scores either free-text completions or per-choice logits,
/// and also groups the results by subject.
pub struct MmluEvaluator {
    /// Completion-based scorer; backs `evaluate_completions` and
    /// `evaluate_by_subject_completions`.
    mc: McEvaluator,
    /// Logit-based scorer; backs `evaluate_logits` and
    /// `evaluate_by_subject_logits`.
    mc_logit: McLogitEvaluator,
}
impl MmluEvaluator {
    /// Prompt used by [`MmluEvaluator::new`]: the question, four lettered options
    /// (A-D) and an `Answer:` cue. The `{…}` placeholders are presumably filled in
    /// by `McEvaluator` / `McLogitEvaluator` — confirm there.
    pub const DEFAULT_PROMPT_TEMPLATE: &str =
        "{question}\nA) {a}\nB) {b}\nC) {c}\nD) {d}\nAnswer:";

    /// Creates an evaluator whose completion and logit scorers both use
    /// [`Self::DEFAULT_PROMPT_TEMPLATE`].
    pub fn new() -> Self {
        Self::with_template(Self::DEFAULT_PROMPT_TEMPLATE)
    }

    /// Creates an evaluator whose completion and logit scorers both use `template`.
    ///
    /// The template is stored unchanged. It presumably needs the same placeholders
    /// as [`Self::DEFAULT_PROMPT_TEMPLATE`] — verify against the underlying scorers.
    pub fn with_template(template: impl Into<String>) -> Self {
        let template = template.into();
        Self {
            mc: McEvaluator {
                prompt_template: template.clone(),
            },
            mc_logit: McLogitEvaluator {
                prompt_template: template,
            },
        }
    }

    /// Scores free-text `completions` against `dataset`, overall and per subject.
    ///
    /// `completions[i]` is paired with `dataset.questions[i]`. The per-subject
    /// breakdown covers only the first `min(questions, completions)` pairs.
    pub fn evaluate_completions(&self, dataset: &McDataset, completions: &[String]) -> MmluResult {
        let overall = self.mc.evaluate_dataset(dataset, completions);
        let by_subject = self.evaluate_by_subject_completions(dataset, completions);
        Self::assemble_result(overall.accuracy, overall.correct, overall.total, by_subject)
    }

    /// Scores per-choice logits against `dataset`, overall and per subject.
    ///
    /// `per_choice_logits[i]` holds the logits for `dataset.questions[i]`. The
    /// per-subject breakdown covers only the first `min(questions, logits)` entries.
    pub fn evaluate_logits(
        &self,
        dataset: &McDataset,
        per_choice_logits: &[Vec<f32>],
    ) -> MmluResult {
        let overall = self.mc_logit.evaluate_dataset(dataset, per_choice_logits);
        let by_subject = self.evaluate_by_subject_logits(dataset, per_choice_logits);
        Self::assemble_result(overall.accuracy, overall.correct, overall.total, by_subject)
    }

    /// Per-subject accuracy for free-text completions.
    ///
    /// Each pair is judged by `McEvaluator::score_completion`.
    pub fn evaluate_by_subject_completions(
        &self,
        dataset: &McDataset,
        completions: &[String],
    ) -> HashMap<String, AccuracyResult> {
        Self::compute_by_subject(dataset, completions.len(), |i| {
            self.mc
                .score_completion(&completions[i], dataset.questions[i].correct_answer)
        })
    }

    /// Per-subject accuracy for per-choice logits.
    ///
    /// Each entry is judged by the `correct` flag of `McLogitEvaluator::score`.
    pub fn evaluate_by_subject_logits(
        &self,
        dataset: &McDataset,
        per_choice_logits: &[Vec<f32>],
    ) -> HashMap<String, AccuracyResult> {
        Self::compute_by_subject(dataset, per_choice_logits.len(), |i| {
            self.mc_logit
                .score(&per_choice_logits[i], dataset.questions[i].correct_answer)
                .correct
        })
    }

    /// Packs the overall numbers and the per-subject breakdown into an
    /// [`MmluResult`], deriving `accuracy_pct` from `accuracy`.
    fn assemble_result(
        accuracy: f32,
        correct: usize,
        total: usize,
        by_subject: HashMap<String, AccuracyResult>,
    ) -> MmluResult {
        MmluResult {
            accuracy,
            accuracy_pct: accuracy * 100.0,
            correct,
            total,
            by_subject,
        }
    }

    /// Groups the first `min(dataset.questions.len(), annotation_count)` questions
    /// by subject and tallies how many of them `is_correct` accepts.
    ///
    /// `is_correct` receives the question index and is called exactly once per
    /// tallied question, in ascending index order. The nested `by_subject` map of
    /// each returned [`AccuracyResult`] is left empty.
    fn compute_by_subject(
        dataset: &McDataset,
        annotation_count: usize,
        mut is_correct: impl FnMut(usize) -> bool,
    ) -> HashMap<String, AccuracyResult> {
        // Only questions that have a matching annotation can be scored.
        let n = dataset.questions.len().min(annotation_count);

        // subject -> (correct, total)
        let mut counts: HashMap<String, (usize, usize)> = HashMap::new();
        for (i, question) in dataset.questions.iter().take(n).enumerate() {
            let (correct, total) = counts.entry(extract_subject(question)).or_default();
            *total += 1;
            if is_correct(i) {
                *correct += 1;
            }
        }

        counts
            .into_iter()
            .map(|(subject, (correct, total))| {
                // An entry is only created for a question that is then counted, so
                // `total >= 1` and the division is always defined.
                let accuracy = correct as f32 / total as f32;
                let result = AccuracyResult {
                    correct,
                    total,
                    accuracy,
                    by_subject: HashMap::new(),
                };
                (subject, result)
            })
            .collect()
    }
}
impl Default for MmluEvaluator {
fn default() -> Self {
Self::new()
}
}