use async_trait::async_trait;
use crate::error::Result;
use crate::eval::set::{EvalScore, EvalStatus, Invocation};
/// Scores an actual agent invocation against an expected (reference) invocation.
///
/// Implementors must be `Send + Sync + 'static`, so an evaluator can be shared
/// across threads and held for the lifetime of the program.
#[async_trait]
pub trait Evaluator: Send + Sync + 'static {
/// Returns the identifier of the metric this evaluator produces.
fn name(&self) -> &str;
/// Compares `actual` against `expected` and returns the resulting [`EvalScore`].
///
/// # Errors
///
/// Whether and when an error is returned is implementation-defined.
async fn evaluate(&self, expected: &Invocation, actual: &Invocation) -> Result<EvalScore>;
}
/// Evaluator configuration for comparing tool-call trajectories.
///
/// Holds the minimum score an invocation must reach to be reported as passed.
#[derive(Debug)]
pub struct TrajectoryMatch {
    /// Pass/fail cut-off applied to the computed score.
    threshold: f64,
}

impl TrajectoryMatch {
    /// Threshold used by [`Default`]: only a perfect score passes.
    const DEFAULT_THRESHOLD: f64 = 1.0;

    /// Creates a trajectory evaluator that passes at or above `threshold`.
    #[must_use]
    pub fn new(threshold: f64) -> Self {
        TrajectoryMatch { threshold }
    }
}

impl Default for TrajectoryMatch {
    /// Builds an evaluator with a threshold of `1.0`.
    fn default() -> Self {
        Self::new(Self::DEFAULT_THRESHOLD)
    }
}
#[async_trait]
impl Evaluator for TrajectoryMatch {
    /// Returns the metric identifier, `"tool_trajectory_avg_score"`.
    fn name(&self) -> &str {
        "tool_trajectory_avg_score"
    }

    /// Scores how closely the actual tool-call sequence follows the expected one.
    ///
    /// Tool uses are compared position by position; a position counts as a match
    /// when both the tool name and its arguments are equal. The score is the
    /// number of matching positions divided by the length of the longer
    /// sequence, so missing *and* extra calls both lower it. Two empty
    /// sequences agree completely and score `1.0`.
    ///
    /// The invocation passes when the score reaches the configured threshold
    /// (with a `1e-9` tolerance for floating-point rounding).
    ///
    /// `details` reports `matched`, `expected` and `actual` counts.
    ///
    /// # Errors
    ///
    /// This implementation never returns an error.
    async fn evaluate(&self, expected: &Invocation, actual: &Invocation) -> Result<EvalScore> {
        let expected_calls = &expected.intermediate_data.tool_uses;
        let actual_calls = &actual.intermediate_data.tool_uses;

        // `zip` stops at the shorter sequence, which is exactly the set of
        // positions where a comparison is possible.
        let matched = expected_calls
            .iter()
            .zip(actual_calls.iter())
            .filter(|(want, got)| want.name == got.name && want.args == got.args)
            .count();

        let longest = expected_calls.len().max(actual_calls.len());
        let score = if longest == 0 {
            // No tools expected and none used: the trajectories are identical.
            // Dividing 0 by a clamped denominator would wrongly report 0.0 here.
            1.0
        } else {
            (matched as f64) / (longest as f64)
        };

        let status = if score + 1e-9 >= self.threshold {
            EvalStatus::Passed
        } else {
            EvalStatus::Failed
        };

        Ok(EvalScore {
            score,
            status,
            details: serde_json::json!({
                "matched": matched,
                "expected": expected_calls.len(),
                "actual": actual_calls.len(),
            }),
        })
    }
}
/// Evaluator configuration for comparing final responses.
///
/// Holds the minimum score an invocation must reach to be reported as passed.
#[derive(Debug)]
pub struct ResponseMatch {
    /// Pass/fail cut-off applied to the computed score.
    threshold: f64,
}

impl ResponseMatch {
    /// Threshold used by [`Default`].
    const DEFAULT_THRESHOLD: f64 = 0.8;

    /// Creates a response evaluator that passes at or above `threshold`.
    #[must_use]
    pub fn new(threshold: f64) -> Self {
        ResponseMatch { threshold }
    }
}

impl Default for ResponseMatch {
    /// Builds an evaluator with a threshold of `0.8`.
    fn default() -> Self {
        Self::new(Self::DEFAULT_THRESHOLD)
    }
}
/// Returns the concatenated text of `c`, or an empty string when it is `None`.
fn response_text(c: &Option<crate::genai_types::Content>) -> String {
    match c {
        Some(content) => content.text_concat(),
        None => String::new(),
    }
}
#[async_trait]
impl Evaluator for ResponseMatch {
    /// Returns the metric identifier, `"final_response_match_v1"`.
    fn name(&self) -> &str {
        "final_response_match_v1"
    }

    /// Scores the actual final response by how many expected tokens it contains.
    ///
    /// Both responses are lower-cased and split on whitespace. The score is the
    /// fraction of expected tokens (duplicates counted individually) that occur
    /// as whole tokens in the actual response. An expected response with no
    /// tokens scores `1.0` and passes outright.
    ///
    /// The invocation passes when the score reaches the configured threshold
    /// (with a `1e-9` tolerance for floating-point rounding).
    ///
    /// # Errors
    ///
    /// This implementation never returns an error.
    async fn evaluate(&self, expected: &Invocation, actual: &Invocation) -> Result<EvalScore> {
        let expected_text = response_text(&expected.final_response).to_lowercase();
        let actual_text = response_text(&actual.final_response).to_lowercase();

        let wanted: Vec<&str> = expected_text.split_whitespace().collect();
        if wanted.is_empty() {
            // Nothing to look for, so the response trivially satisfies it.
            return Ok(EvalScore {
                score: 1.0,
                status: EvalStatus::Passed,
                details: serde_json::json!({"reason": "empty expected"}),
            });
        }

        // Set membership gives whole-token matching; substrings never count.
        let vocabulary: std::collections::HashSet<&str> = actual_text.split_whitespace().collect();
        let hit = wanted
            .iter()
            .filter(|token| vocabulary.contains(*token))
            .count();

        let score = (hit as f64) / (wanted.len() as f64);
        let passed = score + 1e-9 >= self.threshold;
        let status = if passed {
            EvalStatus::Passed
        } else {
            EvalStatus::Failed
        };

        let details = serde_json::json!({
            "expected_tokens": wanted.len(),
            "hit": hit,
        });
        Ok(EvalScore {
            score,
            status,
            details,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::eval::set::{IntermediateData, ToolUse};
    use crate::genai_types::Content;

    /// Builds an invocation with the given final response text and tool uses.
    fn inv(final_text: &str, tool_uses: Vec<ToolUse>) -> Invocation {
        let intermediate_data = IntermediateData {
            tool_uses,
            ..Default::default()
        };
        Invocation {
            user_content: Content::user_text(""),
            final_response: Some(Content::model_text(final_text)),
            intermediate_data,
            invocation_id: String::new(),
            creation_timestamp: 0.0,
        }
    }

    /// Builds a single call to tool `f` with the given arguments.
    fn call_f(args: serde_json::Value) -> ToolUse {
        ToolUse {
            name: "f".into(),
            args,
        }
    }

    /// Builds an invocation that has a final response but no tool uses.
    fn text_only(final_text: &str) -> Invocation {
        inv(final_text, vec![])
    }

    #[tokio::test]
    async fn trajectory_exact_match() {
        let evaluator = TrajectoryMatch::new(1.0);
        let invocation = inv("", vec![call_f(serde_json::json!({"x": 1}))]);

        // Comparing an invocation with itself must be a perfect match.
        let outcome = evaluator
            .evaluate(&invocation, &invocation)
            .await
            .expect("trajectory evaluation should succeed");

        assert!((outcome.score - 1.0).abs() < 1e-9);
        assert_eq!(outcome.status, EvalStatus::Passed);
    }

    #[tokio::test]
    async fn response_match_token_score() {
        let evaluator = ResponseMatch::new(0.5);
        let expected = text_only("hello world");
        let actual = text_only("Why, hello there");

        let outcome = evaluator
            .evaluate(&expected, &actual)
            .await
            .expect("response evaluation should succeed");

        // Only "hello" of the two expected tokens is present.
        assert!((outcome.score - 0.5).abs() < 1e-9);
        assert_eq!(outcome.status, EvalStatus::Passed);
    }

    #[tokio::test]
    async fn response_match_rejects_substring_hits() {
        let evaluator = ResponseMatch::new(0.5);
        let expected = text_only("cat");
        let actual = text_only("concatenate strings");

        let outcome = evaluator
            .evaluate(&expected, &actual)
            .await
            .expect("response evaluation should succeed");

        // "cat" appears only inside "concatenate", which must not count.
        assert!(outcome.score.abs() < 1e-9);
        assert_eq!(outcome.status, EvalStatus::Failed);
    }

    #[tokio::test]
    async fn default_thresholds_are_strict() {
        // One of five expected tokens is far below the default response threshold.
        let response_outcome = ResponseMatch::default()
            .evaluate(
                &text_only("alpha beta gamma delta epsilon"),
                &text_only("alpha"),
            )
            .await
            .expect("response evaluation should succeed");
        assert_eq!(response_outcome.status, EvalStatus::Failed);

        // A missing tool call cannot satisfy the default trajectory threshold.
        let with_call = inv("", vec![call_f(serde_json::json!({}))]);
        let without_calls = text_only("");
        let trajectory_outcome = TrajectoryMatch::default()
            .evaluate(&with_call, &without_calls)
            .await
            .expect("trajectory evaluation should succeed");
        assert_eq!(trajectory_outcome.status, EvalStatus::Failed);
    }
}