use std::collections::HashMap;
use serde::Serialize;
use crate::dataset::{EvalDataset, McDataset, MultipleChoiceQuestion};
/// Aggregate accuracy figures produced by the evaluators' `evaluate_dataset` methods.
#[derive(Debug, Serialize)]
pub struct AccuracyResult {
/// Number of items that were scored as correct.
pub correct: usize,
/// Number of items that were scored.
pub total: usize,
/// `correct / total` as a fraction; the evaluators in this file store `0.0` when `total` is zero.
pub accuracy: f32,
/// Accuracy fraction per subject, keyed by subject name. Evaluators that do not
/// track subjects leave this map empty.
pub by_subject: HashMap<String, f32>,
}
impl AccuracyResult {
    /// Returns the overall accuracy as a percentage, i.e. the `accuracy`
    /// fraction scaled by 100.
    pub fn accuracy_pct(&self) -> f32 {
        let fraction = self.accuracy;
        fraction * 100.0
    }
}
/// Multiple-choice evaluator that scores free-text completions by the answer
/// letter they start with.
#[derive(Debug, Clone)]
pub struct McEvaluator {
    /// Prompt template rendered by `format_question`. The placeholders
    /// `{question}`, `{a}`, `{b}`, `{c}` and `{d}` are substituted; all other
    /// text is emitted unchanged.
    pub prompt_template: String,
}
impl Default for McEvaluator {
fn default() -> Self {
Self::new()
}
}
impl McEvaluator {
pub fn new() -> Self {
Self {
prompt_template: "{question}\nA) {a}\nB) {b}\nC) {c}\nD) {d}\nAnswer:".to_string(),
}
}
pub fn with_template(template: &str) -> Self {
Self {
prompt_template: template.to_string(),
}
}
pub fn format_question(&self, q: &MultipleChoiceQuestion) -> String {
let get = |i: usize| -> &str { q.choices.get(i).map(String::as_str).unwrap_or("") };
let strip_label = |s: &str| -> String {
if s.len() >= 3 && s.chars().nth(1) == Some(':') {
s[2..].trim().to_string()
} else {
s.to_string()
}
};
self.prompt_template
.replace("{question}", &q.question)
.replace("{a}", &strip_label(get(0)))
.replace("{b}", &strip_label(get(1)))
.replace("{c}", &strip_label(get(2)))
.replace("{d}", &strip_label(get(3)))
}
pub fn score_completion(&self, completion: &str, correct_answer: usize) -> bool {
match self.extract_answer(completion) {
Some(idx) => idx == correct_answer,
None => false,
}
}
pub fn extract_answer(&self, completion: &str) -> Option<usize> {
let first = completion.trim().chars().next()?;
match first.to_ascii_uppercase() {
'A' => Some(0),
'B' => Some(1),
'C' => Some(2),
'D' => Some(3),
_ => None,
}
}
pub fn evaluate_dataset(&self, dataset: &McDataset, completions: &[String]) -> AccuracyResult {
let mut correct = 0usize;
let mut total = 0usize;
let mut by_subject_counts: HashMap<String, (usize, usize)> = HashMap::new();
for (q, completion) in dataset.questions.iter().zip(completions.iter()) {
total += 1;
let is_correct = self.score_completion(completion, q.correct_answer);
if is_correct {
correct += 1;
}
if let Some(ref subj) = q.subject {
let entry = by_subject_counts.entry(subj.clone()).or_insert((0, 0));
entry.1 += 1;
if is_correct {
entry.0 += 1;
}
}
}
let accuracy = if total == 0 {
0.0
} else {
correct as f32 / total as f32
};
let by_subject = by_subject_counts
.into_iter()
.map(|(subj, (c, t))| (subj, if t == 0 { 0.0 } else { c as f32 / t as f32 }))
.collect();
AccuracyResult {
correct,
total,
accuracy,
by_subject,
}
}
}
/// Outcome of scoring a single question from its per-choice scores.
#[derive(Debug, Clone)]
pub struct LogitMcResult {
/// Index of the highest-scoring choice; `0` when the slate was empty.
pub picked: usize,
/// Whether `picked` equals the expected answer index; always `false` for an empty slate.
pub correct: bool,
/// Copy of the per-choice scores that were compared.
pub per_choice: Vec<f32>,
}
/// Multiple-choice evaluator that selects the answer from per-choice scores
/// instead of parsing generated text.
#[derive(Debug, Clone)]
pub struct McLogitEvaluator {
/// Prompt template text, set by the constructors.
/// NOTE(review): no method visible in this file reads this field — confirm where it is consumed.
pub prompt_template: String,
}
impl Default for McLogitEvaluator {
fn default() -> Self {
Self::new()
}
}
impl McLogitEvaluator {
pub fn new() -> Self {
Self {
prompt_template: "{question}\nA) {a}\nB) {b}\nC) {c}\nD) {d}\nAnswer:".to_string(),
}
}
pub fn with_template(template: &str) -> Self {
Self {
prompt_template: template.to_string(),
}
}
pub fn score(&self, per_choice: &[f32], correct_answer: usize) -> LogitMcResult {
if per_choice.is_empty() {
return LogitMcResult {
picked: 0,
correct: false,
per_choice: Vec::new(),
};
}
let mut best_idx = 0usize;
let mut best_val = f32::NEG_INFINITY;
for (i, &v) in per_choice.iter().enumerate() {
if v > best_val {
best_val = v;
best_idx = i;
}
}
LogitMcResult {
picked: best_idx,
correct: best_idx == correct_answer,
per_choice: per_choice.to_vec(),
}
}
pub fn evaluate_dataset(
&self,
dataset: &McDataset,
per_question: &[Vec<f32>],
) -> AccuracyResult {
let mut correct = 0usize;
let mut total = 0usize;
let mut by_subject_counts: HashMap<String, (usize, usize)> = HashMap::new();
for (q, slate) in dataset.questions.iter().zip(per_question.iter()) {
total += 1;
let out = self.score(slate, q.correct_answer);
if out.correct {
correct += 1;
}
if let Some(ref subj) = q.subject {
let entry = by_subject_counts.entry(subj.clone()).or_insert((0, 0));
entry.1 += 1;
if out.correct {
entry.0 += 1;
}
}
}
let accuracy = if total == 0 {
0.0
} else {
correct as f32 / total as f32
};
let by_subject = by_subject_counts
.into_iter()
.map(|(s, (c, t))| (s, if t == 0 { 0.0 } else { c as f32 / t as f32 }))
.collect();
AccuracyResult {
correct,
total,
accuracy,
by_subject,
}
}
}
/// Evaluator that compares a completion with an expected output string.
#[derive(Debug, Clone)]
pub struct ExactMatchEvaluator {
    /// When `true`, both strings are trimmed and lower-cased before comparison.
    pub normalize: bool,
    /// When `true`, the completion only has to contain the expected string
    /// instead of being equal to it.
    pub partial_match: bool,
}
impl Default for ExactMatchEvaluator {
fn default() -> Self {
Self::new()
}
}
impl ExactMatchEvaluator {
    /// Creates a strict evaluator: no normalization and no partial matching.
    pub fn new() -> Self {
        Self {
            normalize: false,
            partial_match: false,
        }
    }

    /// Returns whether `completion` matches `expected` under the configured
    /// rules.
    ///
    /// * With `normalize`, both strings are trimmed and lower-cased first;
    ///   otherwise they are compared as given, without copying them.
    /// * With `partial_match`, `completion` only has to contain `expected`;
    ///   otherwise the two strings must be equal.
    ///
    /// An empty expected answer (after normalization, if enabled) matches only
    /// an empty completion, also in partial-match mode: a plain substring test
    /// would accept every completion and inflate the accuracy.
    pub fn score(&self, completion: &str, expected: &str) -> bool {
        if self.normalize {
            let completion = completion.trim().to_lowercase();
            let expected = expected.trim().to_lowercase();
            self.matches(&completion, &expected)
        } else {
            self.matches(completion, expected)
        }
    }

    /// Compares two already-prepared strings according to `partial_match`.
    fn matches(&self, completion: &str, expected: &str) -> bool {
        if self.partial_match && !expected.is_empty() {
            completion.contains(expected)
        } else {
            // Exact mode, or an empty expected answer: require equality.
            completion == expected
        }
    }

    /// Scores `completions` against the expected outputs of `dataset`.
    ///
    /// Examples and completions are paired by position; when the two lengths
    /// differ, the surplus items on the longer side are ignored. Examples
    /// without an expected output are skipped and do not count towards
    /// `total`. `by_subject` is always empty, and the accuracy is `0.0` when
    /// nothing was scored.
    pub fn evaluate_dataset(
        &self,
        dataset: &EvalDataset,
        completions: &[String],
    ) -> AccuracyResult {
        let mut correct = 0usize;
        let mut total = 0usize;

        for (ex, completion) in dataset.examples.iter().zip(completions) {
            if let Some(expected) = &ex.expected_output {
                total += 1;
                correct += usize::from(self.score(completion, expected));
            }
        }

        let accuracy = if total == 0 {
            0.0
        } else {
            correct as f32 / total as f32
        };

        AccuracyResult {
            correct,
            total,
            accuracy,
            by_subject: HashMap::new(),
        }
    }
}