use crate::intern::Sym;
use crate::metric::Metric;
use crate::optimizer::ExampleSet;
use crate::recursive::llm::Llm;
/// Builder for evaluating an LLM against an example dataset using a metric.
///
/// Construct with [`Evaluate::new`], configure via the chained setters
/// (`instruction`, `output_field`, `threshold`, `demos`), then call `run`
/// (blocking) or `run_async`.
pub struct Evaluate<'a, L: Llm, M: Metric> {
    // Model used to generate a prediction per example.
    llm: &'a L,
    // Scoring function comparing prediction against expected output.
    metric: M,
    // Optional instruction prepended to every prompt (none by default).
    instruction: Option<&'a str>,
    // Which output field holds the expected answer; when `None`, the first
    // output field of each example is used.
    output_field: Option<Sym>,
    // Minimum score for an example to count as passed (default 0.5).
    threshold: f64,
    // Dataset indices used as demonstrations in the prompt context and
    // excluded from scoring (empty by default).
    demo_indices: &'a [u32],
}
impl<'a, L: Llm, M: Metric> Evaluate<'a, L, M> {
pub fn new(llm: &'a L, metric: M) -> Self {
Self {
llm,
metric,
instruction: None,
output_field: None,
threshold: 0.5,
demo_indices: &[],
}
}
pub fn instruction(mut self, instruction: &'a str) -> Self {
self.instruction = Some(instruction);
self
}
pub fn output_field(mut self, sym: Sym) -> Self {
self.output_field = Some(sym);
self
}
pub fn threshold(mut self, threshold: f64) -> Self {
self.threshold = threshold;
self
}
pub fn demos(mut self, indices: &'a [u32]) -> Self {
self.demo_indices = indices;
self
}
pub fn run(self, dataset: &ExampleSet<'_>) -> EvalResult {
crate::recursive::shared::block_on(self.run_async(dataset))
}
pub async fn run_async(self, dataset: &ExampleSet<'_>) -> EvalResult {
let mut per_example = Vec::with_capacity(dataset.len());
let mut total_tokens = 0u64;
let context = self.build_demo_context(dataset);
for (idx, view) in dataset.iter().enumerate() {
if self.demo_indices.contains(&(idx as u32)) {
continue;
}
let input_text: String = view
.inputs()
.map(|(sym, val)| format!("{}: {}", sym.as_str(), val))
.collect::<Vec<_>>()
.join("\n");
let expected = if let Some(out_sym) = self.output_field {
view.get_output(out_sym).unwrap_or("").to_string()
} else {
view.outputs()
.next()
.map(|(_, val)| val.to_string())
.unwrap_or_default()
};
let prompt = self.build_prompt(&input_text);
let prediction = match self.llm.generate(&prompt, &context, None).await {
Ok(output) => {
total_tokens += output.total_tokens() as u64;
output.text.to_string()
}
Err(_) => String::new(),
};
let score = self.metric.evaluate(&prediction, &expected);
let passed = score >= self.threshold;
per_example.push(ExampleResult {
index: idx,
prediction,
expected,
score,
passed,
});
}
EvalResult::from_examples(per_example, total_tokens)
}
fn build_prompt(&self, input_text: &str) -> String {
match self.instruction {
Some(inst) => format!("{}\n\n{}", inst, input_text),
None => input_text.to_string(),
}
}
fn build_demo_context(&self, dataset: &ExampleSet<'_>) -> String {
if self.demo_indices.is_empty() {
return String::new();
}
let mut context = String::new();
for &idx in self.demo_indices {
let idx = idx as usize;
if let Some(view) = dataset.iter().nth(idx) {
for (sym, val) in view.inputs() {
context.push_str(&format!("{}: {}\n", sym.as_str(), val));
}
for (sym, val) in view.outputs() {
context.push_str(&format!("{}: {}\n", sym.as_str(), val));
}
context.push('\n');
}
}
context
}
}
/// Aggregated statistics over a set of per-example evaluation results.
#[derive(Debug, Clone)]
pub struct EvalResult {
    /// Mean score across evaluated examples (0.0 when empty).
    pub mean: f64,
    /// Median score (0.0 when empty).
    pub median: f64,
    /// Sample standard deviation (n-1 denominator; 0.0 for fewer than two examples).
    pub std_dev: f64,
    /// Number of examples whose score met the pass threshold.
    pub pass_count: usize,
    /// Total number of evaluated examples.
    pub total: usize,
    /// Individual result for each evaluated example.
    pub per_example: Vec<ExampleResult>,
    /// Sum of token counts reported by successful LLM calls.
    pub total_tokens: u64,
}
impl EvalResult {
    /// Aggregates per-example results into summary statistics.
    ///
    /// An empty input yields all-zero statistics while still carrying
    /// `total_tokens` through.
    fn from_examples(per_example: Vec<ExampleResult>, total_tokens: u64) -> Self {
        let total = per_example.len();
        if total == 0 {
            return Self {
                mean: 0.0,
                median: 0.0,
                std_dev: 0.0,
                pass_count: 0,
                total: 0,
                per_example,
                total_tokens,
            };
        }
        let scores: Vec<f64> = per_example.iter().map(|e| e.score).collect();
        let mean = scores.iter().sum::<f64>() / total as f64;
        let median = Self::compute_median(&scores);
        let std_dev = Self::compute_std_dev(&scores, mean);
        let pass_count = per_example.iter().filter(|e| e.passed).count();
        Self {
            mean,
            median,
            std_dev,
            pass_count,
            total,
            per_example,
            total_tokens,
        }
    }

    /// Re-scores the stored predictions with a different metric, using the
    /// default pass threshold of 0.5.
    pub fn rescore<M2: Metric>(&self, metric: &M2) -> EvalResult {
        self.rescore_with_threshold(metric, 0.5)
    }

    /// Re-scores the stored predictions with a different metric and an
    /// explicit pass threshold, without re-running the LLM. `total_tokens`
    /// is carried over unchanged since no new calls are made.
    pub fn rescore_with_threshold<M2: Metric>(&self, metric: &M2, threshold: f64) -> EvalResult {
        let per_example: Vec<ExampleResult> = self
            .per_example
            .iter()
            .map(|ex| {
                let score = metric.evaluate(&ex.prediction, &ex.expected);
                let passed = score >= threshold;
                ExampleResult {
                    index: ex.index,
                    prediction: ex.prediction.clone(),
                    expected: ex.expected.clone(),
                    score,
                    passed,
                }
            })
            .collect();
        EvalResult::from_examples(per_example, self.total_tokens)
    }

    /// Fraction of evaluated examples that passed (0.0 when empty).
    pub fn pass_rate(&self) -> f64 {
        if self.total == 0 {
            0.0
        } else {
            self.pass_count as f64 / self.total as f64
        }
    }

    /// Median of `scores`; the mean of the two middle values for an even
    /// count, 0.0 for an empty slice.
    fn compute_median(scores: &[f64]) -> f64 {
        if scores.is_empty() {
            return 0.0;
        }
        let mut sorted: Vec<f64> = scores.to_vec();
        // total_cmp gives a deterministic total order over f64 (including
        // NaN), and the unstable sort avoids an extra allocation.
        sorted.sort_unstable_by(f64::total_cmp);
        let mid = sorted.len() / 2;
        if sorted.len() % 2 == 0 {
            (sorted[mid - 1] + sorted[mid]) / 2.0
        } else {
            sorted[mid]
        }
    }

    /// Sample standard deviation (n-1 denominator); 0.0 for fewer than two
    /// scores, where the statistic is undefined.
    fn compute_std_dev(scores: &[f64], mean: f64) -> f64 {
        if scores.len() <= 1 {
            return 0.0;
        }
        let variance =
            scores.iter().map(|s| (s - mean).powi(2)).sum::<f64>() / (scores.len() - 1) as f64;
        variance.sqrt()
    }
}
/// Outcome of evaluating a single dataset example.
#[derive(Debug, Clone)]
pub struct ExampleResult {
    /// Position of the example within the dataset it was drawn from.
    pub index: usize,
    /// Text produced by the LLM (empty if the call failed).
    pub prediction: String,
    /// Expected answer extracted from the example's output field.
    pub expected: String,
    /// Metric score for this prediction.
    pub score: f64,
    /// Whether `score` met the pass threshold.
    pub passed: bool,
}
#[cfg(test)]
mod tests {
    //! Unit tests for the evaluation harness, driven by `MockLlm` so no
    //! real model calls are made.
    use super::*;
    use crate::buffer::Buffer;
    use crate::intern::sym;
    use crate::metric::{Contains, ExactMatch, F1Token, FnMetric};
    use crate::optimizer::ExampleMeta;
    use crate::predict::FieldRange;
    use crate::recursive::llm::MockLlm;

    // Builds a single-input/single-output dataset from (input, expected)
    // pairs, packing all strings into one contiguous buffer and recording
    // the byte ranges in `ExampleMeta`. Unused range slots are filled with
    // `Sym::EMPTY` placeholders.
    fn build_dataset(pairs: &[(&str, &str)]) -> (Buffer, Vec<ExampleMeta>, Sym, Sym) {
        let input_sym = sym("question");
        let output_sym = sym("answer");
        let mut buf = Vec::new();
        let mut metas = Vec::new();
        for (input, expected) in pairs {
            let input_start = buf.len() as u32;
            buf.extend_from_slice(input.as_bytes());
            let input_end = buf.len() as u32;
            let output_start = buf.len() as u32;
            buf.extend_from_slice(expected.as_bytes());
            let output_end = buf.len() as u32;
            let meta = ExampleMeta {
                input_ranges: [
                    (input_sym, FieldRange::new(input_start, input_end)),
                    (Sym::EMPTY, FieldRange::new(0, 0)),
                    (Sym::EMPTY, FieldRange::new(0, 0)),
                    (Sym::EMPTY, FieldRange::new(0, 0)),
                ],
                input_count: 1,
                output_ranges: [
                    (output_sym, FieldRange::new(output_start, output_end)),
                    (Sym::EMPTY, FieldRange::new(0, 0)),
                ],
                output_count: 1,
            };
            metas.push(meta);
        }
        let buffer = Buffer::from_bytes(buf);
        (buffer, metas, input_sym, output_sym)
    }

    // All predictions correct -> perfect mean/median, zero deviation.
    #[test]
    fn test_evaluate_exact_match_all_correct() {
        let llm = MockLlm::new(|prompt, _| {
            if prompt.contains("2+2") {
                "4".to_string()
            } else if prompt.contains("3+3") {
                "6".to_string()
            } else {
                "unknown".to_string()
            }
        });
        let (buffer, metas, _input_sym, output_sym) =
            build_dataset(&[("What is 2+2?", "4"), ("What is 3+3?", "6")]);
        let dataset = ExampleSet::new(&buffer, &metas);
        let result = Evaluate::new(&llm, ExactMatch)
            .output_field(output_sym)
            .run(&dataset);
        assert_eq!(result.total, 2);
        assert_eq!(result.pass_count, 2);
        assert_eq!(result.mean, 1.0);
        assert_eq!(result.median, 1.0);
        assert_eq!(result.std_dev, 0.0);
    }

    // One of two correct -> mean 0.5 with a single pass.
    #[test]
    fn test_evaluate_partial_match() {
        let llm = MockLlm::new(|prompt, _| {
            if prompt.contains("2+2") {
                "4".to_string()
            } else {
                "wrong".to_string()
            }
        });
        let (buffer, metas, _input_sym, output_sym) =
            build_dataset(&[("What is 2+2?", "4"), ("What is 3+3?", "6")]);
        let dataset = ExampleSet::new(&buffer, &metas);
        let result = Evaluate::new(&llm, ExactMatch)
            .output_field(output_sym)
            .run(&dataset);
        assert_eq!(result.total, 2);
        assert_eq!(result.pass_count, 1);
        assert_eq!(result.mean, 0.5);
    }

    // The instruction must appear at the start of the prompt; the mock only
    // answers correctly when it does.
    #[test]
    fn test_evaluate_with_instruction() {
        let llm = MockLlm::new(|prompt, _| {
            if prompt.starts_with("Be concise.") {
                "4".to_string()
            } else {
                "The answer is four".to_string()
            }
        });
        let (buffer, metas, _input_sym, output_sym) = build_dataset(&[("What is 2+2?", "4")]);
        let dataset = ExampleSet::new(&buffer, &metas);
        let result = Evaluate::new(&llm, ExactMatch)
            .instruction("Be concise.")
            .output_field(output_sym)
            .run(&dataset);
        assert_eq!(result.total, 1);
        assert_eq!(result.mean, 1.0);
    }

    // A partial F1 overlap above the 0.5 threshold still counts as a pass.
    #[test]
    fn test_evaluate_with_threshold() {
        let llm = MockLlm::new(|_, _| "the quick brown fox".to_string());
        let (buffer, metas, _input_sym, output_sym) =
            build_dataset(&[("Repeat", "the quick brown fox jumps")]);
        let dataset = ExampleSet::new(&buffer, &metas);
        let result = Evaluate::new(&llm, F1Token)
            .output_field(output_sym)
            .threshold(0.5)
            .run(&dataset);
        assert_eq!(result.total, 1);
        assert!(result.mean > 0.5);
        assert_eq!(result.pass_count, 1);
    }

    // A threshold of 1.0 with a non-exact prediction produces zero passes.
    #[test]
    fn test_evaluate_with_high_threshold() {
        let llm = MockLlm::new(|_, _| "the quick brown fox".to_string());
        let (buffer, metas, _input_sym, output_sym) =
            build_dataset(&[("Repeat", "the quick brown fox jumps over")]);
        let dataset = ExampleSet::new(&buffer, &metas);
        let result = Evaluate::new(&llm, ExactMatch)
            .output_field(output_sym)
            .threshold(1.0)
            .run(&dataset);
        assert_eq!(result.total, 1);
        assert_eq!(result.pass_count, 0);
    }

    // An empty dataset yields all-zero statistics without calling the LLM.
    #[test]
    fn test_evaluate_empty_dataset() {
        let llm = MockLlm::new(|_, _| String::new());
        static BUFFER: Buffer = Buffer::Static(b"");
        let dataset = ExampleSet::new(&BUFFER, &[]);
        let result = Evaluate::new(&llm, ExactMatch).run(&dataset);
        assert_eq!(result.total, 0);
        assert_eq!(result.mean, 0.0);
        assert_eq!(result.median, 0.0);
        assert_eq!(result.std_dev, 0.0);
        assert_eq!(result.pass_count, 0);
    }

    // Contains accepts the expected answer embedded in a longer prediction.
    #[test]
    fn test_evaluate_contains_metric() {
        let llm = MockLlm::new(|_, _| "The answer is 42, obviously.".to_string());
        let (buffer, metas, _input_sym, output_sym) = build_dataset(&[("What is it?", "42")]);
        let dataset = ExampleSet::new(&buffer, &metas);
        let result = Evaluate::new(&llm, Contains)
            .output_field(output_sym)
            .run(&dataset);
        assert_eq!(result.total, 1);
        assert_eq!(result.mean, 1.0);
        assert_eq!(result.pass_count, 1);
    }

    // Rescoring with a laxer metric flips a failure to a pass while keeping
    // the stored prediction/expected text intact.
    #[test]
    fn test_evaluate_rescore() {
        let llm = MockLlm::new(|_, _| "The answer is 42".to_string());
        let (buffer, metas, _input_sym, output_sym) = build_dataset(&[("Question", "42")]);
        let dataset = ExampleSet::new(&buffer, &metas);
        let result = Evaluate::new(&llm, ExactMatch)
            .output_field(output_sym)
            .run(&dataset);
        assert_eq!(result.mean, 0.0);
        assert_eq!(result.pass_count, 0);
        let rescored = result.rescore(&Contains);
        assert_eq!(rescored.mean, 1.0);
        assert_eq!(rescored.pass_count, 1);
        assert_eq!(rescored.total, 1);
        assert_eq!(rescored.per_example[0].prediction, "The answer is 42");
        assert_eq!(rescored.per_example[0].expected, "42");
    }

    // Same results rescored with different thresholds pass/fail accordingly.
    #[test]
    fn test_evaluate_rescore_with_threshold() {
        let llm = MockLlm::new(|_, _| "partial match here".to_string());
        let (buffer, metas, _input_sym, output_sym) =
            build_dataset(&[("Q", "partial match here and more")]);
        let dataset = ExampleSet::new(&buffer, &metas);
        let result = Evaluate::new(&llm, F1Token)
            .output_field(output_sym)
            .run(&dataset);
        let strict = result.rescore_with_threshold(&F1Token, 0.99);
        assert_eq!(strict.pass_count, 0);
        let lenient = result.rescore_with_threshold(&F1Token, 0.1);
        assert_eq!(lenient.pass_count, 1);
    }

    // pass_rate is pass_count / total.
    #[test]
    fn test_eval_result_pass_rate() {
        let per_example = vec![
            ExampleResult {
                index: 0,
                prediction: "a".to_string(),
                expected: "a".to_string(),
                score: 1.0,
                passed: true,
            },
            ExampleResult {
                index: 1,
                prediction: "b".to_string(),
                expected: "c".to_string(),
                score: 0.0,
                passed: false,
            },
            ExampleResult {
                index: 2,
                prediction: "d".to_string(),
                expected: "d".to_string(),
                score: 1.0,
                passed: true,
            },
        ];
        let result = EvalResult::from_examples(per_example, 0);
        assert!((result.pass_rate() - 2.0 / 3.0).abs() < 1e-9);
    }

    // Empty results give a 0.0 pass rate rather than dividing by zero.
    #[test]
    fn test_eval_result_pass_rate_empty() {
        let result = EvalResult::from_examples(Vec::new(), 0);
        assert_eq!(result.pass_rate(), 0.0);
    }

    // Scores [0.0, 0.5, 1.0]: mean/median 0.5, sample std dev 0.5.
    #[test]
    fn test_eval_result_statistics() {
        let per_example = vec![
            ExampleResult {
                index: 0,
                prediction: String::new(),
                expected: String::new(),
                score: 0.0,
                passed: false,
            },
            ExampleResult {
                index: 1,
                prediction: String::new(),
                expected: String::new(),
                score: 0.5,
                passed: true,
            },
            ExampleResult {
                index: 2,
                prediction: String::new(),
                expected: String::new(),
                score: 1.0,
                passed: true,
            },
        ];
        let result = EvalResult::from_examples(per_example, 100);
        assert!((result.mean - 0.5).abs() < 1e-9);
        assert!((result.median - 0.5).abs() < 1e-9);
        assert!((result.std_dev - 0.5).abs() < 1e-9);
        assert_eq!(result.pass_count, 2);
        assert_eq!(result.total, 3);
        assert_eq!(result.total_tokens, 100);
    }

    // Even count: median is the mean of the two middle scores.
    #[test]
    fn test_eval_result_median_even_count() {
        let per_example: Vec<ExampleResult> = [0.2, 0.4, 0.6, 0.8]
            .iter()
            .enumerate()
            .map(|(i, &score)| ExampleResult {
                index: i,
                prediction: String::new(),
                expected: String::new(),
                score,
                passed: score >= 0.5,
            })
            .collect();
        let result = EvalResult::from_examples(per_example, 0);
        assert!((result.median - 0.5).abs() < 1e-9);
    }

    // A custom closure-based metric ("abc" vs "abcd" -> length ratio 0.75).
    #[test]
    fn test_evaluate_with_fn_metric() {
        let llm = MockLlm::new(|_, _| "abc".to_string());
        let (buffer, metas, _input_sym, output_sym) = build_dataset(&[("Q", "abcd")]);
        let dataset = ExampleSet::new(&buffer, &metas);
        let length_metric = FnMetric::new("length_ratio", |pred, expected| {
            pred.len() as f64 / expected.len().max(1) as f64
        });
        let result = Evaluate::new(&llm, length_metric)
            .output_field(output_sym)
            .run(&dataset);
        assert_eq!(result.total, 1);
        assert!((result.mean - 0.75).abs() < 1e-9);
    }

    // Demo indices are excluded from scoring: only index 1 is evaluated.
    #[test]
    fn test_evaluate_with_demos() {
        let llm = MockLlm::new(|_prompt, _feedback| {
            "4".to_string()
        });
        let (buffer, metas, _input_sym, output_sym) = build_dataset(&[
            ("What is 1+1?", "2"), ("What is 2+2?", "4"), ]);
        let dataset = ExampleSet::new(&buffer, &metas);
        let demo_indices = [0u32];
        let result = Evaluate::new(&llm, ExactMatch)
            .output_field(output_sym)
            .demos(&demo_indices)
            .run(&dataset);
        assert_eq!(result.total, 1);
        assert_eq!(result.mean, 1.0);
        assert_eq!(result.per_example[0].index, 1);
    }

    // MockLlm reports no token usage, so the total stays zero.
    #[test]
    fn test_evaluate_preserves_tokens() {
        let llm = MockLlm::new(|_, _| "answer".to_string());
        let (buffer, metas, _input_sym, output_sym) = build_dataset(&[("Q1", "A1"), ("Q2", "A2")]);
        let dataset = ExampleSet::new(&buffer, &metas);
        let result = Evaluate::new(&llm, ExactMatch)
            .output_field(output_sym)
            .run(&dataset);
        assert_eq!(result.total_tokens, 0);
    }

    // Five examples, three exact matches: mean 0.6, median 1.0.
    #[test]
    fn test_evaluate_multiple_examples_statistics() {
        let llm = MockLlm::new(|_, _| "hello".to_string());
        let (buffer, metas, _input_sym, output_sym) = build_dataset(&[
            ("Q1", "hello"), ("Q2", "world"), ("Q3", "hello"), ("Q4", "goodbye"), ("Q5", "hello"), ]);
        let dataset = ExampleSet::new(&buffer, &metas);
        let result = Evaluate::new(&llm, ExactMatch)
            .output_field(output_sym)
            .threshold(0.5)
            .run(&dataset);
        assert_eq!(result.total, 5);
        assert_eq!(result.pass_count, 3);
        assert!((result.mean - 0.6).abs() < 1e-9);
        assert!((result.median - 1.0).abs() < 1e-9);
        assert!((result.pass_rate() - 0.6).abs() < 1e-9);
    }

    // Plain field round-trip on the record struct.
    #[test]
    fn test_example_result_fields() {
        let ex = ExampleResult {
            index: 42,
            prediction: "predicted".to_string(),
            expected: "expected".to_string(),
            score: 0.75,
            passed: true,
        };
        assert_eq!(ex.index, 42);
        assert_eq!(ex.prediction, "predicted");
        assert_eq!(ex.expected, "expected");
        assert!((ex.score - 0.75).abs() < f64::EPSILON);
        assert!(ex.passed);
    }

    // A single example: mean == median == its score, std dev 0.
    #[test]
    fn test_eval_result_single_example() {
        let per_example = vec![ExampleResult {
            index: 0,
            prediction: "a".to_string(),
            expected: "a".to_string(),
            score: 1.0,
            passed: true,
        }];
        let result = EvalResult::from_examples(per_example, 50);
        assert_eq!(result.mean, 1.0);
        assert_eq!(result.median, 1.0);
        assert_eq!(result.std_dev, 0.0);
        assert_eq!(result.pass_count, 1);
        assert_eq!(result.total, 1);
        assert_eq!(result.total_tokens, 50);
    }
}