/// A single BoolQ example: a passage, a yes/no question about it, and the
/// gold answer.
///
/// `Eq` is derived alongside `PartialEq` because every field has total
/// equality, which lets items be used wherever full equivalence is required.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BoolQItem {
    /// Text of the passage the question refers to.
    pub passage: String,
    /// The yes/no question asked about the passage.
    pub question: String,
    /// Gold label: `true` for "yes", `false` for "no".
    pub answer: bool,
}
/// An in-memory collection of [`BoolQItem`]s to evaluate against.
pub struct BoolQDataset {
/// The examples, in the order they are scored.
pub items: Vec<BoolQItem>,
}
impl BoolQDataset {
pub fn from_items(items: Vec<BoolQItem>) -> Self {
Self { items }
}
pub fn len(&self) -> usize {
self.items.len()
}
pub fn is_empty(&self) -> bool {
self.items.is_empty()
}
}
/// Aggregate outcome of scoring a batch of predictions against gold answers.
#[derive(Debug, Clone)]
pub struct BoolQResult {
/// Fraction of scored examples answered correctly (`correct / total`),
/// or `0.0` when `total` is zero.
pub accuracy: f32,
/// `accuracy` multiplied by 100.
pub accuracy_pct: f32,
/// Number of predictions that matched the gold answer.
pub correct: usize,
/// Number of examples that were scored.
pub total: usize,
/// Number of predictions that were "yes".
pub yes_predicted: usize,
/// Number of predictions that were "no".
pub no_predicted: usize,
}
/// Stateless evaluator for BoolQ-style yes/no predictions; carries no fields.
pub struct BoolQEvaluator;
impl BoolQEvaluator {
    /// Creates a new evaluator.
    pub fn new() -> Self {
        Self
    }

    /// Formats `item` as a text prompt.
    ///
    /// The passage and question are interpolated verbatim and the prompt ends
    /// with `Answer:` so that the completion starts with the answer itself.
    pub fn build_prompt(&self, item: &BoolQItem) -> String {
        format!(
            "Passage: {}\nQuestion: {}\nAnswer:",
            item.passage, item.question
        )
    }

    /// Parses a yes/no answer from the beginning of a completion.
    ///
    /// Leading whitespace is skipped and matching is ASCII case-insensitive.
    /// The answer must be a whole word: `"yes"` / `"no"` followed by the end
    /// of the text or by a non-alphanumeric character. Words that merely
    /// start with those letters (`"None"`, `"Nothing"`, `"Yesterday"`) are
    /// therefore not mistaken for an answer.
    ///
    /// Returns `Some(true)` for "yes", `Some(false)` for "no", and `None`
    /// when the completion does not start with either word.
    pub fn extract_answer(completion: &str) -> Option<bool> {
        let trimmed = completion.trim_start();
        if Self::starts_with_word(trimmed, "yes") {
            Some(true)
        } else if Self::starts_with_word(trimmed, "no") {
            Some(false)
        } else {
            None
        }
    }

    /// Returns `true` when the answer parsed from `completion` equals `gold`.
    ///
    /// A completion with no parseable answer never scores as correct.
    pub fn score(&self, completion: &str, gold: bool) -> bool {
        Self::extract_answer(completion) == Some(gold)
    }

    /// Scores free-text completions against the dataset's gold answers.
    ///
    /// Items and completions are paired by position; if the two lengths
    /// differ, only the common prefix is scored and `total` reflects that.
    /// Completions with no parseable answer count towards `total` (and so
    /// lower the accuracy) but are tallied as neither "yes" nor "no".
    pub fn evaluate_completions(
        &self,
        dataset: &BoolQDataset,
        completions: &[String],
    ) -> BoolQResult {
        Self::tally(
            dataset
                .items
                .iter()
                .zip(completions)
                .map(|(item, completion)| (Self::extract_answer(completion), item.answer)),
        )
    }

    /// Scores logit pairs against the dataset's gold answers.
    ///
    /// For each pair, "yes" is predicted when `pair[0]` is strictly greater
    /// than `pair[1]`; ties and NaN comparisons therefore fall through to
    /// "no". Items and pairs are matched by position and only the common
    /// prefix is scored when the lengths differ.
    pub fn evaluate_logits(&self, dataset: &BoolQDataset, logit_pairs: &[[f32; 2]]) -> BoolQResult {
        Self::tally(
            dataset
                .items
                .iter()
                .zip(logit_pairs)
                .map(|(item, pair)| (Some(pair[0] > pair[1]), item.answer)),
        )
    }

    /// Returns `true` when `text` begins with `word` (ASCII case-insensitive)
    /// and the match is not immediately followed by an alphanumeric character.
    fn starts_with_word(text: &str, word: &str) -> bool {
        // `get` yields `None` when `text` is too short or when the cut would
        // land inside a multi-byte character, so this never panics.
        match text.get(..word.len()) {
            Some(head) if head.eq_ignore_ascii_case(word) => text[word.len()..]
                .chars()
                .next()
                .map_or(true, |next| !next.is_alphanumeric()),
            _ => false,
        }
    }

    /// Folds `(prediction, gold)` pairs into a [`BoolQResult`].
    fn tally<I>(outcomes: I) -> BoolQResult
    where
        I: Iterator<Item = (Option<bool>, bool)>,
    {
        let mut total = 0usize;
        let mut correct = 0usize;
        let mut yes_predicted = 0usize;
        let mut no_predicted = 0usize;
        for (prediction, gold) in outcomes {
            total += 1;
            match prediction {
                Some(true) => yes_predicted += 1,
                Some(false) => no_predicted += 1,
                None => {}
            }
            if prediction == Some(gold) {
                correct += 1;
            }
        }
        // Guard the division so an empty evaluation reports 0.0 rather than NaN.
        let accuracy = if total == 0 {
            0.0_f32
        } else {
            correct as f32 / total as f32
        };
        BoolQResult {
            accuracy,
            accuracy_pct: accuracy * 100.0,
            correct,
            total,
            yes_predicted,
            no_predicted,
        }
    }
}
impl Default for BoolQEvaluator {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `extract_answer` on every completion and compares it with the
    /// expected parse, naming the offending input on failure.
    fn assert_extracts(cases: &[(&str, Option<bool>)]) {
        for &(completion, expected) in cases {
            assert_eq!(
                BoolQEvaluator::extract_answer(completion),
                expected,
                "unexpected parse for completion {:?}",
                completion
            );
        }
    }

    #[test]
    fn extract_answer_yes() {
        assert_extracts(&[("Yes, that is correct.", Some(true))]);
    }

    #[test]
    fn extract_answer_no() {
        assert_extracts(&[("No, it is not.", Some(false))]);
    }

    #[test]
    fn extract_answer_case_insensitive() {
        assert_extracts(&[
            ("YES", Some(true)),
            ("NO", Some(false)),
            ("yes.", Some(true)),
        ]);
    }

    #[test]
    fn extract_answer_leading_whitespace() {
        assert_extracts(&[(" yes", Some(true)), ("\t\nno", Some(false))]);
    }

    #[test]
    fn extract_answer_none_for_garbage() {
        assert_extracts(&[("maybe", None), ("", None), ("I don't know", None)]);
    }

    #[test]
    fn extract_answer_short_string_no_panic() {
        // Inputs shorter than either answer word must yield `None`, not panic.
        assert_extracts(&[("y", None), ("n", None)]);
    }
}