use std::sync::Arc;
use async_trait::async_trait;
use crate::{EvalReport, EvalSuite, TestResult};
/// A scoring metric whose evaluation can await asynchronous work.
///
/// Implementors must be `Send + Sync + 'static`; [`EvalRunner`] stores them
/// as `Arc<dyn AsyncMetric>`.
#[async_trait]
pub trait AsyncMetric: Send + Sync + 'static {
/// Returns the metric's name.
///
/// [`EvalRunner::run`] uses this string as the key under which the metric's
/// score is stored for each test case.
fn name(&self) -> &'static str;
/// Scores `actual_output`, the agent's answer to `input`, given the
/// `expected_keywords` of the test case.
///
/// [`EvalRunner::run`] treats a case as passing when the returned value is
/// greater than or equal to the runner's threshold.
// NOTE(review): scores are presumably meant to lie in 0.0..=1.0 (the default
// threshold is 0.7) — confirm against the metric implementations.
async fn score(&self, input: &str, actual_output: &str, expected_keywords: &[&str]) -> f64;
}
/// The system under evaluation: anything that can turn an input prompt into
/// a textual response.
#[async_trait]
pub trait EvalAgent: Send + Sync {
/// Produces the agent's response to `input`.
///
/// # Errors
///
/// Returns an error when the agent fails to produce a response;
/// [`EvalRunner::run`] propagates that error to its caller.
async fn respond(&self, input: &str) -> traitclaw_core::Result<String>;
}
/// Runs an evaluation suite against an agent and scores every response with
/// the registered metrics.
///
/// Configured builder-style through `metric` and `threshold`.
pub struct EvalRunner {
/// Metrics applied to every test case, in registration order.
metrics: Vec<Arc<dyn AsyncMetric>>,
/// Minimum score (inclusive) each metric must reach for a case to pass.
threshold: f64,
}
impl EvalRunner {
    /// Creates a runner with no metrics registered and a pass threshold of `0.7`.
    ///
    /// While no metric is registered, [`EvalRunner::run`] falls back to a
    /// built-in keyword match reported under the name `"keyword_match"`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            metrics: Vec::new(),
            threshold: 0.7,
        }
    }

    /// Registers `metric` and returns the updated runner.
    ///
    /// Metrics are evaluated in registration order for every test case.
    #[must_use]
    pub fn metric(mut self, metric: Box<dyn AsyncMetric>) -> Self {
        self.metrics.push(Arc::from(metric));
        self
    }

    /// Sets the minimum score (inclusive) every metric must reach for a case
    /// to count as passed, and returns the updated runner.
    #[must_use]
    pub fn threshold(mut self, threshold: f64) -> Self {
        self.threshold = threshold;
        self
    }

    /// Evaluates every case of `suite` against `agent` and summarises the outcome.
    ///
    /// For each case the agent's response is scored by every registered
    /// metric (or by the built-in keyword match when none is registered).
    /// A case passes only if *every* score computed for it is at least the
    /// threshold. `average_score` in the report is the mean over all
    /// individual scores of all cases, or `0.0` when no score was computed.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by [`EvalAgent::respond`]; cases
    /// after the failing one are not evaluated.
    pub async fn run(
        &self,
        agent: &dyn EvalAgent,
        suite: &EvalSuite,
    ) -> traitclaw_core::Result<EvalReport> {
        let mut results = Vec::with_capacity(suite.cases().len());
        let mut total_score = 0.0;
        let mut passed_count = 0;
        let mut score_count = 0_u32;

        for case in suite.cases() {
            let actual_output = agent.respond(&case.input).await?;
            let keywords: Vec<&str> = case.expected_keywords.iter().map(String::as_str).collect();

            let mut scores =
                std::collections::HashMap::with_capacity(self.metrics.len().max(1));
            // Judge each score as it is produced. The map keeps a single entry
            // per metric name, so checking the map afterwards would overlook a
            // failing score that a same-named metric later overwrote.
            let mut all_pass = true;

            for metric in &self.metrics {
                let s = metric.score(&case.input, &actual_output, &keywords).await;
                all_pass &= s >= self.threshold;
                scores.insert(metric.name().to_string(), s);
                total_score += s;
                score_count += 1;
            }

            // Fallback so that a runner without metrics still yields a score.
            if self.metrics.is_empty() {
                let kw_score = score_keywords(&actual_output, &keywords);
                all_pass &= kw_score >= self.threshold;
                scores.insert("keyword_match".to_string(), kw_score);
                total_score += kw_score;
                score_count += 1;
            }

            if all_pass {
                passed_count += 1;
            }
            results.push(TestResult {
                case_id: case.id.clone(),
                actual_output,
                scores,
                passed: all_pass,
            });
        }

        // An empty suite records no scores; report 0.0 rather than dividing by zero.
        let average_score = if score_count > 0 {
            total_score / f64::from(score_count)
        } else {
            0.0
        };

        Ok(EvalReport {
            suite_name: suite.name().to_string(),
            results,
            average_score,
            passed: passed_count,
            total: suite.cases().len(),
        })
    }
}
impl Default for EvalRunner {
fn default() -> Self {
Self::new()
}
}
/// Returns the fraction of `keywords` that occur in `output`, ignoring case.
///
/// Both the output and each keyword are lower-cased before the substring
/// test, so a keyword such as `"Hello"` matches the output `"hello world"`.
/// An empty keyword list scores `1.0` because there is nothing to miss.
fn score_keywords(output: &str, keywords: &[&str]) -> f64 {
    if keywords.is_empty() {
        return 1.0;
    }
    let haystack = output.to_lowercase();
    let matched = keywords
        .iter()
        // Lower-case the keyword as well; comparing a lower-cased haystack
        // with a mixed-case needle could never match.
        .filter(|kw| haystack.contains(kw.to_lowercase().as_str()))
        .count();
    matched as f64 / keywords.len() as f64
}
/// Wraps a synchronous [`crate::Metric`] so it can be used wherever an
/// [`AsyncMetric`] is expected.
pub struct SyncMetricAdapter<M: crate::Metric>(pub M);
/// Forwards both trait methods to the wrapped synchronous metric.
#[async_trait]
impl<M: crate::Metric> AsyncMetric for SyncMetricAdapter<M> {
    /// Reports the wrapped metric's own name.
    fn name(&self) -> &'static str {
        let inner = &self.0;
        inner.name()
    }

    /// Scores by calling the wrapped metric directly; nothing is awaited.
    async fn score(&self, input: &str, actual_output: &str, expected_keywords: &[&str]) -> f64 {
        let inner = &self.0;
        inner.score(input, actual_output, expected_keywords)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{EvalSuite, TestCase};

    /// Agent stub that answers every prompt with `echo: <prompt>`.
    struct EchoAgent;

    #[async_trait]
    impl EvalAgent for EchoAgent {
        async fn respond(&self, input: &str) -> traitclaw_core::Result<String> {
            Ok(format!("echo: {input}"))
        }
    }

    /// Metric stub that always reports the same score under a fixed name.
    struct FixedMetric {
        value: f64,
        label: &'static str,
    }

    #[async_trait]
    impl AsyncMetric for FixedMetric {
        fn name(&self) -> &'static str {
            self.label
        }

        async fn score(&self, _: &str, _: &str, _: &[&str]) -> f64 {
            self.value
        }
    }

    /// Metric stub scoring the fraction of keywords found in the lower-cased output.
    struct KeywordAsyncMetric;

    #[async_trait]
    impl AsyncMetric for KeywordAsyncMetric {
        fn name(&self) -> &'static str {
            "keyword"
        }

        async fn score(&self, _: &str, output: &str, kw: &[&str]) -> f64 {
            match kw.len() {
                0 => 1.0,
                total => {
                    let haystack = output.to_lowercase();
                    let hits = kw
                        .iter()
                        .copied()
                        .filter(|needle| haystack.contains(*needle))
                        .count();
                    hits as f64 / total as f64
                }
            }
        }
    }

    /// Builds a runner that uses only the keyword metric with the given threshold.
    fn keyword_runner(threshold: f64) -> EvalRunner {
        EvalRunner::new()
            .metric(Box::new(KeywordAsyncMetric))
            .threshold(threshold)
    }

    #[tokio::test]
    async fn test_eval_runner_three_cases() {
        let mut three_cases = EvalSuite::new("suite");
        for (id, prompt) in [("c1", "hello"), ("c2", "world"), ("c3", "foo")] {
            three_cases = three_cases.add_case(TestCase::new(id, prompt).expect_contains("echo"));
        }

        let outcome = keyword_runner(0.8)
            .run(&EchoAgent, &three_cases)
            .await
            .unwrap();

        assert_eq!(outcome.results.len(), 3);
        assert_eq!(outcome.total, 3);
        assert_eq!(outcome.passed, 3);
    }

    #[tokio::test]
    async fn test_eval_runner_threshold_fail() {
        let unmatched_case = TestCase::new("c1", "hello").expect_contains("xyzabc");
        let single_case = EvalSuite::new("s").add_case(unmatched_case);

        let outcome = keyword_runner(0.8)
            .run(&EchoAgent, &single_case)
            .await
            .unwrap();

        assert_eq!(outcome.passed, 0, "case with 0.0 keyword score should fail");
    }

    #[tokio::test]
    async fn test_eval_runner_average_score() {
        let two_cases = EvalSuite::new("s")
            .add_case(TestCase::new("c1", "hello"))
            .add_case(TestCase::new("c2", "world"));
        let constant = FixedMetric {
            value: 0.8,
            label: "m",
        };
        let fixed_runner = EvalRunner::new().metric(Box::new(constant)).threshold(0.7);

        let outcome = fixed_runner.run(&EchoAgent, &two_cases).await.unwrap();

        assert!((outcome.average_score - 0.8).abs() < 1e-6);
        assert_eq!(outcome.passed, 2);
    }

    #[tokio::test]
    async fn test_sync_metric_adapter() {
        let wrapped = SyncMetricAdapter(crate::KeywordMetric);

        let value = wrapped.score("in", "hello world", &["hello"]).await;

        assert!((value - 1.0).abs() < 1e-6);
    }

    #[tokio::test]
    async fn test_empty_suite_gives_zero_results() {
        let no_cases = EvalSuite::new("empty");
        let default_threshold_runner = EvalRunner::new().metric(Box::new(KeywordAsyncMetric));

        let outcome = default_threshold_runner
            .run(&EchoAgent, &no_cases)
            .await
            .unwrap();

        assert_eq!(outcome.results.len(), 0);
        assert_eq!(outcome.total, 0);
    }
}