use crate::accuracy::{AccuracyResult, McEvaluator, McLogitEvaluator};
use crate::dataset::{McDataset, MultipleChoiceQuestion};
/// A single WinoGrande item: one sentence plus two candidate options.
#[derive(Debug, Clone, PartialEq)]
pub struct WinoGrandeItem {
/// Sentence text; copied verbatim into the question when the item is
/// converted by `WinoGrandeDataset::as_mc_dataset`.
pub sentence: String,
/// First candidate; becomes choice index 0 after conversion.
pub option1: String,
/// Second candidate; becomes choice index 1 after conversion.
pub option2: String,
/// Label of the correct option. `1` selects `option1`; the conversion in
/// `WinoGrandeDataset::as_mc_dataset` treats every other value as `option2`.
pub answer: u8,
}
/// An ordered collection of [`WinoGrandeItem`]s.
///
/// Derives `Debug`, `Clone` and `PartialEq` so the dataset can be logged,
/// duplicated and compared like the item type it contains.
#[derive(Debug, Clone, PartialEq)]
pub struct WinoGrandeDataset {
    /// Items in evaluation order; an item's position is used to build its
    /// question id when the dataset is converted to multiple-choice form.
    pub items: Vec<WinoGrandeItem>,
}
impl WinoGrandeDataset {
pub fn from_items(items: Vec<WinoGrandeItem>) -> Self {
Self { items }
}
pub fn len(&self) -> usize {
self.items.len()
}
pub fn is_empty(&self) -> bool {
self.items.is_empty()
}
pub fn as_mc_dataset(&self) -> McDataset {
let mut mc = McDataset::new("winogrande");
for (i, item) in self.items.iter().enumerate() {
let correct_answer: usize = if item.answer == 1 { 0 } else { 1 };
mc.add(MultipleChoiceQuestion {
id: format!("winogrande-{}", i),
question: item.sentence.clone(),
choices: vec![item.option1.clone(), item.option2.clone()],
correct_answer,
subject: None,
difficulty: None,
});
}
mc
}
}
/// Summary of one WinoGrande evaluation run.
#[derive(Debug, Clone)]
pub struct WinoGrandeResult {
/// `correct / total` as computed by `from_accuracy`; `0.0` when `total` is zero.
pub accuracy: f32,
/// `accuracy` multiplied by 100.
pub accuracy_pct: f32,
/// Count of correct answers, copied from the underlying accuracy result.
pub correct: usize,
/// Count of evaluated questions, copied from the underlying accuracy result.
pub total: usize,
}
impl WinoGrandeResult {
    /// Builds a result from the raw counters of an [`AccuracyResult`].
    ///
    /// The ratio is recomputed here from `correct` and `total`; an empty run
    /// (`total == 0`) reports an accuracy of `0.0` instead of dividing by zero.
    fn from_accuracy(acc: AccuracyResult) -> Self {
        let (correct, total) = (acc.correct, acc.total);
        let accuracy = match total {
            0 => 0.0,
            _ => correct as f32 / total as f32,
        };
        Self {
            accuracy,
            accuracy_pct: accuracy * 100.0,
            correct,
            total,
        }
    }

    /// Accuracy as a percentage (the stored `accuracy_pct` field).
    pub fn accuracy_pct(&self) -> f32 {
        self.accuracy_pct
    }
}
/// WinoGrande evaluator that delegates scoring to two multiple-choice
/// evaluators sharing the same prompt template.
pub struct WinoGrandeEvaluator {
/// Used by `evaluate_completions` and `build_prompt`.
mc: McEvaluator,
/// Used by `evaluate_logits`.
mc_logit: McLogitEvaluator,
}
impl WinoGrandeEvaluator {
    /// Creates an evaluator whose completion-based and logit-based
    /// sub-evaluators share one two-option (`A`/`B`) prompt template.
    pub fn new() -> Self {
        let prompt_template = String::from("{question}\nA) {a}\nB) {b}\nAnswer:");
        let mc_logit = McLogitEvaluator {
            prompt_template: prompt_template.clone(),
        };
        let mc = McEvaluator { prompt_template };
        Self { mc, mc_logit }
    }

    /// Scores text completions against `dataset`.
    ///
    /// The dataset is converted to multiple-choice form and handed, together
    /// with `completions`, to the completion-based evaluator; its counters are
    /// then summarised as a [`WinoGrandeResult`].
    pub fn evaluate_completions(
        &self,
        dataset: &WinoGrandeDataset,
        completions: &[String],
    ) -> WinoGrandeResult {
        let questions = dataset.as_mc_dataset();
        WinoGrandeResult::from_accuracy(self.mc.evaluate_dataset(&questions, completions))
    }

    /// Scores per-choice logits against `dataset`.
    ///
    /// `per_choice_logits` is forwarded unchanged to the logit-based
    /// evaluator (presumably one inner vector per item with one score per
    /// choice; confirm against `McLogitEvaluator`).
    pub fn evaluate_logits(
        &self,
        dataset: &WinoGrandeDataset,
        per_choice_logits: &[Vec<f32>],
    ) -> WinoGrandeResult {
        let questions = dataset.as_mc_dataset();
        let counters = self.mc_logit.evaluate_dataset(&questions, per_choice_logits);
        WinoGrandeResult::from_accuracy(counters)
    }

    /// Renders the prompt for a single item via the completion-based
    /// evaluator's `format_question`.
    ///
    /// The temporary question carries an empty id and a placeholder
    /// `correct_answer` of `0`, so `item.answer` does not influence the
    /// question passed to the formatter.
    pub fn build_prompt(&self, item: &WinoGrandeItem) -> String {
        let question = MultipleChoiceQuestion {
            id: String::new(),
            question: item.sentence.clone(),
            choices: vec![item.option1.clone(), item.option2.clone()],
            correct_answer: 0,
            subject: None,
            difficulty: None,
        };
        self.mc.format_question(&question)
    }
}
impl Default for WinoGrandeEvaluator {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an item from string slices so fixtures stay compact.
    fn item(sentence: &str, option1: &str, option2: &str, answer: u8) -> WinoGrandeItem {
        WinoGrandeItem {
            sentence: sentence.to_owned(),
            option1: option1.to_owned(),
            option2: option2.to_owned(),
            answer,
        }
    }

    /// Two-item fixture used by the size test.
    fn make_dataset() -> WinoGrandeDataset {
        let trophy = item(
            "The trophy doesn't fit in the suitcase because the ___ is too large.",
            "trophy",
            "suitcase",
            1,
        );
        let cat = item(
            "The cat sat on the mat because ___ was comfortable.",
            "mat",
            "cat",
            1,
        );
        WinoGrandeDataset::from_items(vec![trophy, cat])
    }

    /// One-item dataset with placeholder text and the given answer label.
    fn single_item_dataset(answer: u8) -> WinoGrandeDataset {
        WinoGrandeDataset::from_items(vec![item("S", "X", "Y", answer)])
    }

    #[test]
    fn winogrande_dataset_len() {
        let dataset = make_dataset();
        assert_eq!(dataset.len(), 2);
        assert!(!dataset.is_empty());
    }

    #[test]
    fn winogrande_empty_dataset() {
        let dataset = WinoGrandeDataset::from_items(Vec::new());
        assert!(dataset.is_empty());
    }

    #[test]
    fn winogrande_as_mc_dataset_correct_answer_index_zero_for_answer1() {
        let converted = single_item_dataset(1).as_mc_dataset();
        assert_eq!(converted.questions[0].correct_answer, 0);
        assert_eq!(converted.questions[0].choices.len(), 2);
    }

    #[test]
    fn winogrande_as_mc_dataset_correct_answer_index_one_for_answer2() {
        let converted = single_item_dataset(2).as_mc_dataset();
        assert_eq!(converted.questions[0].correct_answer, 1);
    }
}