//! Evaluation metrics that need async execution: an LLM-as-judge scorer, a
//! JSON shape check, and a tool-usage heuristic.

use std::sync::Arc;

use async_trait::async_trait;

use crate::runner::AsyncMetric;
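
/// A completion backend for the model acting as judge.
///
/// Implementors receive the fully rendered evaluation prompt and return the
/// judge's raw text reply, which is expected to contain a line such as
/// `Score: 0.85`. A minimal sketch, assuming you wrap an HTTP client of your
/// own (`MyClient` and its `chat` method are hypothetical):
///
/// ```ignore
/// struct MyJudge {
///     client: MyClient, // hypothetical HTTP client
/// }
///
/// #[async_trait]
/// impl JudgeProvider for MyJudge {
///     async fn complete(&self, prompt: &str) -> traitclaw_core::Result<String> {
///         // Forward the prompt and return the raw completion text.
///         self.client.chat(prompt).await
///     }
/// }
/// ```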
#[async_trait]
pub trait JudgeProvider: Send + Sync + 'static {
async fn complete(&self, prompt: &str) -> traitclaw_core::Result<String>;
}
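
/// Scores a response by asking a judge LLM to rate it against named criteria.
///
/// The prompt asks for `Score: <number>` on a 0.0 to 1.0 scale; see
/// [`parse_score`] for how the reply is interpreted. A usage sketch, assuming
/// some `MyJudge` type that implements [`JudgeProvider`]:
///
/// ```ignore
/// let metric = LlmJudgeMetric::new(MyJudge::default())
///     .with_criteria("accuracy", "Is the answer factually correct?")
///     .with_criteria("tone", "Is the response polite and professional?");
/// ```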
pub struct LlmJudgeMetric<P: JudgeProvider> {
provider: Arc<P>,
criteria: Vec<(String, String)>,
}

impl<P: JudgeProvider> LlmJudgeMetric<P> {
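    /// Creates a judge metric with no criteria; add them with
    /// [`with_criteria`](Self::with_criteria).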
#[must_use]
pub fn new(provider: P) -> Self {
Self {
provider: Arc::new(provider),
criteria: Vec::new(),
}
}
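
    /// Adds a named criterion; each one is rendered into the judge prompt as
    /// a `- name: prompt` bullet. Returns `self` for chaining.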
#[must_use]
pub fn with_criteria(mut self, name: impl Into<String>, prompt: impl Into<String>) -> Self {
self.criteria.push((name.into(), prompt.into()));
self
}
}

#[async_trait]
impl<P: JudgeProvider> AsyncMetric for LlmJudgeMetric<P> {
fn name(&self) -> &'static str {
"llm_judge"
}
async fn score(&self, input: &str, actual_output: &str, _kw: &[&str]) -> f64 {
let criteria_text = if self.criteria.is_empty() {
"Is this a high-quality response?".to_string()
} else {
self.criteria
.iter()
.map(|(name, prompt)| format!("- {name}: {prompt}"))
.collect::<Vec<_>>()
.join("\n")
};
let prompt = format!(
"Evaluate the following agent response:\n\nInput: {input}\n\nResponse: {actual_output}\n\nCriteria:\n{criteria_text}\n\nProvide a score from 0.0 to 1.0. Respond with only: Score: <number>"
);
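        // Ask the judge and parse its reply; a provider failure is treated as
        // the minimum score rather than propagated as an error.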
match self.provider.complete(&prompt).await {
Ok(response) => parse_score(&response),
Err(_) => 0.0,
}
}
}
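
/// Extracts a numeric score from the judge's reply.
///
/// Scans line by line for either a `Score: <number>` prefix or a bare number,
/// clamps the first match to `[0.0, 1.0]`, and falls back to `0.0` when no
/// line parses:
///
/// ```ignore
/// assert!((parse_score("Score: 0.75") - 0.75).abs() < 1e-6);
/// assert!((parse_score("1.5") - 1.0).abs() < 1e-6); // clamped
/// assert!(parse_score("no score here") == 0.0);
/// ```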
pub(crate) fn parse_score(response: &str) -> f64 {
for line in response.lines() {
let line = line.trim();
if let Some(rest) = line.strip_prefix("Score:") {
if let Ok(score) = rest.trim().parse::<f64>() {
return score.clamp(0.0, 1.0);
}
}
if let Ok(score) = line.parse::<f64>() {
return score.clamp(0.0, 1.0);
}
}
0.0
}
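
/// Checks that the output parses as JSON and shares its top-level keys with a
/// reference object, scoring the fraction of expected keys present.
///
/// This is a key-presence check, not full JSON Schema validation: the schema's
/// values serve only as documentation. A usage sketch:
///
/// ```ignore
/// let metric = SchemaValidationMetric::new(serde_json::json!({
///     "name": "string",
///     "score": "number",
/// }));
/// ```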
pub struct SchemaValidationMetric {
schema: serde_json::Value,
}

impl SchemaValidationMetric {
    /// Creates a metric whose expectation is the schema's set of top-level keys.
    #[must_use]
    pub fn new(schema: serde_json::Value) -> Self {
        Self { schema }
    }
}

#[async_trait]
impl AsyncMetric for SchemaValidationMetric {
fn name(&self) -> &'static str {
"schema_validation"
}
async fn score(&self, _input: &str, actual_output: &str, _kw: &[&str]) -> f64 {
        let Ok(output_val) = serde_json::from_str::<serde_json::Value>(actual_output) else {
            return 0.0;
        };
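        // A non-object schema degenerates to an exact-equality check against
        // the parsed output.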
let schema_obj = match &self.schema {
serde_json::Value::Object(m) => m,
_ => return if output_val == self.schema { 1.0 } else { 0.0 },
};
let output_obj = match &output_val {
serde_json::Value::Object(m) => m,
_ => return 0.0,
};
if schema_obj.is_empty() {
return 1.0;
}
let present = schema_obj
.keys()
.filter(|k| output_obj.contains_key(*k))
.count();
present as f64 / schema_obj.len() as f64
}
}
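
/// Scores the fraction of expected tool names that appear, case-insensitively,
/// as substrings of the output.
///
/// This is a text heuristic rather than an inspection of structured tool-call
/// records, so a response that merely mentions a tool name counts as usage:
///
/// ```ignore
/// let metric = ToolUsageMetric::new(["search", "calculator"]);
/// ```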
pub struct ToolUsageMetric {
expected_tools: Vec<String>,
}

impl ToolUsageMetric {
    /// Creates a metric from the tool names expected to appear in the output.
    #[must_use]
    pub fn new(expected_tools: impl IntoIterator<Item = impl Into<String>>) -> Self {
        Self {
            expected_tools: expected_tools.into_iter().map(Into::into).collect(),
        }
    }
}

#[async_trait]
impl AsyncMetric for ToolUsageMetric {
fn name(&self) -> &'static str {
"tool_usage"
}
async fn score(&self, _input: &str, actual_output: &str, _kw: &[&str]) -> f64 {
if self.expected_tools.is_empty() {
return 1.0;
}
let output_lower = actual_output.to_lowercase();
let found = self
.expected_tools
.iter()
.filter(|tool| output_lower.contains(tool.to_lowercase().as_str()))
.count();
found as f64 / self.expected_tools.len() as f64
}
}

#[cfg(test)]
mod tests {
use super::*;
struct MockJudge(String);
#[async_trait]
impl JudgeProvider for MockJudge {
async fn complete(&self, _prompt: &str) -> traitclaw_core::Result<String> {
Ok(self.0.clone())
}
}
#[tokio::test]
async fn test_llm_judge_parses_score() {
let metric = LlmJudgeMetric::new(MockJudge("Score: 0.85".to_string()))
.with_criteria("accuracy", "Is it accurate?");
let score = metric.score("input", "output", &[]).await;
assert!((score - 0.85).abs() < 1e-6, "expected 0.85, got {score}");
}
#[tokio::test]
async fn test_llm_judge_clamps_above_one() {
let metric = LlmJudgeMetric::new(MockJudge("Score: 1.5".to_string()));
let score = metric.score("in", "out", &[]).await;
assert!((score - 1.0).abs() < 1e-6);
}
#[tokio::test]
async fn test_llm_judge_invalid_response_returns_zero() {
let metric = LlmJudgeMetric::new(MockJudge("I cannot provide a score.".to_string()));
let score = metric.score("in", "out", &[]).await;
assert!((score - 0.0).abs() < 1e-6);
}
#[test]
fn test_parse_score_variants() {
assert!((parse_score("Score: 0.75") - 0.75).abs() < 1e-6);
assert!((parse_score("0.90") - 0.90).abs() < 1e-6);
assert!((parse_score("no score here") - 0.0).abs() < 1e-6);
        assert!((parse_score("Score: 1.5") - 1.0).abs() < 1e-6);
    }
#[tokio::test]
async fn test_schema_validation_valid_json() {
let metric = SchemaValidationMetric::new(serde_json::json!({
"name": "string",
"score": "number"
}));
let output = r#"{"name": "test", "score": 42}"#;
let score = metric.score("in", output, &[]).await;
assert!((score - 1.0).abs() < 1e-6, "expected 1.0, got {score}");
}
#[tokio::test]
async fn test_schema_validation_partial_keys() {
let metric = SchemaValidationMetric::new(serde_json::json!({
"name": "string",
"score": "number",
"extra": "string"
}));
        let output = r#"{"name": "test"}"#;
        let score = metric.score("in", output, &[]).await;
assert!(score < 0.5, "expected < 0.5, got {score}");
}
#[tokio::test]
async fn test_schema_validation_invalid_json() {
let metric = SchemaValidationMetric::new(serde_json::json!({"name": "string"}));
let score = metric.score("in", "not json at all", &[]).await;
assert!((score - 0.0).abs() < 1e-6);
}
#[tokio::test]
async fn test_tool_usage_all_found() {
let metric = ToolUsageMetric::new(vec!["search", "calculator"]);
let score = metric
.score("in", "I used search and calculator tools", &[])
.await;
assert!((score - 1.0).abs() < 1e-6);
}
#[tokio::test]
async fn test_tool_usage_partial() {
let metric = ToolUsageMetric::new(vec!["search", "calculator"]);
let score = metric.score("in", "I only used search", &[]).await;
assert!((score - 0.5).abs() < 1e-6);
}
#[tokio::test]
async fn test_tool_usage_none_found() {
let metric = ToolUsageMetric::new(vec!["search"]);
let score = metric.score("in", "I didn't call any tools", &[]).await;
assert!((score - 0.0).abs() < 1e-6);
}
#[tokio::test]
async fn test_tool_usage_empty_expected() {
let metric = ToolUsageMetric::new(Vec::<String>::new());
let score = metric.score("in", "anything", &[]).await;
assert!((score - 1.0).abs() < 1e-6);
}
}