// traitclaw_eval/metrics.rs

use std::sync::Arc;

use async_trait::async_trait;

use crate::runner::AsyncMetric;
13#[async_trait]
21pub trait JudgeProvider: Send + Sync + 'static {
22 async fn complete(&self, prompt: &str) -> traitclaw_core::Result<String>;
24}
25
26pub struct LlmJudgeMetric<P: JudgeProvider> {
53 provider: Arc<P>,
54 criteria: Vec<(String, String)>,
55}
56
57impl<P: JudgeProvider> LlmJudgeMetric<P> {
58 #[must_use]
60 pub fn new(provider: P) -> Self {
61 Self {
62 provider: Arc::new(provider),
63 criteria: Vec::new(),
64 }
65 }
66
67 #[must_use]
71 pub fn with_criteria(mut self, name: impl Into<String>, prompt: impl Into<String>) -> Self {
72 self.criteria.push((name.into(), prompt.into()));
73 self
74 }
75}
76
77#[async_trait]
78impl<P: JudgeProvider> AsyncMetric for LlmJudgeMetric<P> {
79 fn name(&self) -> &'static str {
80 "llm_judge"
81 }
82
83 async fn score(&self, input: &str, actual_output: &str, _kw: &[&str]) -> f64 {
84 let criteria_text = if self.criteria.is_empty() {
85 "Is this a high-quality response?".to_string()
86 } else {
87 self.criteria
88 .iter()
89 .map(|(name, prompt)| format!("- {name}: {prompt}"))
90 .collect::<Vec<_>>()
91 .join("\n")
92 };
93
94 let prompt = format!(
95 "Evaluate the following agent response:\n\nInput: {input}\n\nResponse: {actual_output}\n\nCriteria:\n{criteria_text}\n\nProvide a score from 0.0 to 1.0. Respond with only: Score: <number>"
96 );
97
98 match self.provider.complete(&prompt).await {
99 Ok(response) => parse_score(&response),
100 Err(_) => 0.0,
101 }
102 }
103}
104
/// Extracts a score in `[0.0, 1.0]` from a judge's free-form reply.
///
/// Scans the reply line by line, accepting either a `Score: <number>` prefix
/// or a bare number; the first finite match is clamped to `[0.0, 1.0]` and
/// returned. Returns 0.0 when no usable number appears.
///
/// Fix: non-finite parses are rejected. Previously `"Score: NaN"` parsed
/// successfully and `f64::clamp` passed the NaN straight through to callers
/// (clamp only panics on NaN *bounds*); `"inf"` likewise slipped in.
pub(crate) fn parse_score(response: &str) -> f64 {
    for line in response.lines() {
        let line = line.trim();
        // Prefer the explicit "Score:" form; otherwise try the bare line.
        // (A full line "Score: x" can never itself parse as f64, so this
        // collapses the two original branches without changing behavior.)
        let candidate = line.strip_prefix("Score:").map(str::trim).unwrap_or(line);
        if let Ok(score) = candidate.parse::<f64>() {
            if score.is_finite() {
                return score.clamp(0.0, 1.0);
            }
        }
    }
    0.0
}
122
123pub struct SchemaValidationMetric {
145 schema: serde_json::Value,
146}
147
148impl SchemaValidationMetric {
149 #[must_use]
151 pub fn new(schema: serde_json::Value) -> Self {
152 Self { schema }
153 }
154}
155
156#[async_trait]
157impl AsyncMetric for SchemaValidationMetric {
158 fn name(&self) -> &'static str {
159 "schema_validation"
160 }
161
162 async fn score(&self, _input: &str, actual_output: &str, _kw: &[&str]) -> f64 {
163 let Ok(output_val) = serde_json::from_str::<serde_json::Value>(actual_output) else {
165 return 0.0; };
167
168 let schema_obj = match &self.schema {
170 serde_json::Value::Object(m) => m,
171 _ => return if output_val == self.schema { 1.0 } else { 0.0 },
172 };
173
174 let output_obj = match &output_val {
175 serde_json::Value::Object(m) => m,
176 _ => return 0.0,
177 };
178
179 if schema_obj.is_empty() {
180 return 1.0;
181 }
182
183 let present = schema_obj
184 .keys()
185 .filter(|k| output_obj.contains_key(*k))
186 .count();
187 present as f64 / schema_obj.len() as f64
188 }
189}
190
/// Metric that scores output by the fraction of expected tool names it
/// mentions.
pub struct ToolUsageMetric {
    // Tool names matched case-insensitively as substrings of the output.
    expected_tools: Vec<String>,
}
211
212impl ToolUsageMetric {
213 #[must_use]
215 pub fn new(expected_tools: impl IntoIterator<Item = impl Into<String>>) -> Self {
216 Self {
217 expected_tools: expected_tools.into_iter().map(Into::into).collect(),
218 }
219 }
220}
221
222#[async_trait]
223impl AsyncMetric for ToolUsageMetric {
224 fn name(&self) -> &'static str {
225 "tool_usage"
226 }
227
228 async fn score(&self, _input: &str, actual_output: &str, _kw: &[&str]) -> f64 {
229 if self.expected_tools.is_empty() {
230 return 1.0;
231 }
232
233 let output_lower = actual_output.to_lowercase();
234 let found = self
235 .expected_tools
236 .iter()
237 .filter(|tool| output_lower.contains(tool.to_lowercase().as_str()))
238 .count();
239
240 found as f64 / self.expected_tools.len() as f64
241 }
242}
243
#[cfg(test)]
mod tests {
    use super::*;

    /// Judge stub that always replies with its canned string.
    struct MockJudge(String);

    #[async_trait]
    impl JudgeProvider for MockJudge {
        async fn complete(&self, _prompt: &str) -> traitclaw_core::Result<String> {
            Ok(self.0.clone())
        }
    }

    #[tokio::test]
    async fn test_llm_judge_parses_score() {
        let metric = LlmJudgeMetric::new(MockJudge("Score: 0.85".to_string()))
            .with_criteria("accuracy", "Is it accurate?");

        let score = metric.score("input", "output", &[]).await;
        assert!((score - 0.85).abs() < 1e-6, "expected 0.85, got {score}");
    }

    #[tokio::test]
    async fn test_llm_judge_clamps_above_one() {
        let metric = LlmJudgeMetric::new(MockJudge("Score: 1.5".to_string()));
        let score = metric.score("in", "out", &[]).await;
        assert!((score - 1.0).abs() < 1e-6);
    }

    #[tokio::test]
    async fn test_llm_judge_invalid_response_returns_zero() {
        let metric = LlmJudgeMetric::new(MockJudge("I cannot provide a score.".to_string()));
        let score = metric.score("in", "out", &[]).await;
        assert!((score - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_parse_score_variants() {
        assert!((parse_score("Score: 0.75") - 0.75).abs() < 1e-6);
        assert!((parse_score("0.90") - 0.90).abs() < 1e-6);
        assert!((parse_score("no score here") - 0.0).abs() < 1e-6);
        assert!((parse_score("Score: 1.5") - 1.0).abs() < 1e-6);
    }

    #[tokio::test]
    async fn test_schema_validation_valid_json() {
        let metric = SchemaValidationMetric::new(serde_json::json!({
            "name": "string",
            "score": "number"
        }));
        let output = r#"{"name": "test", "score": 42}"#;
        let score = metric.score("in", output, &[]).await;
        assert!((score - 1.0).abs() < 1e-6, "expected 1.0, got {score}");
    }

    #[tokio::test]
    async fn test_schema_validation_partial_keys() {
        let metric = SchemaValidationMetric::new(serde_json::json!({
            "name": "string",
            "score": "number",
            "extra": "string"
        }));
        let output = r#"{"name": "test"}"#;
        let score = metric.score("in", output, &[]).await;
        assert!(score < 0.5, "expected < 0.5, got {score}");
    }

    #[tokio::test]
    async fn test_schema_validation_invalid_json() {
        let metric = SchemaValidationMetric::new(serde_json::json!({"name": "string"}));
        let score = metric.score("in", "not json at all", &[]).await;
        assert!((score - 0.0).abs() < 1e-6);
    }

    #[tokio::test]
    async fn test_tool_usage_all_found() {
        let metric = ToolUsageMetric::new(vec!["search", "calculator"]);
        let score = metric
            .score("in", "I used search and calculator tools", &[])
            .await;
        assert!((score - 1.0).abs() < 1e-6);
    }

    #[tokio::test]
    async fn test_tool_usage_partial() {
        let metric = ToolUsageMetric::new(vec!["search", "calculator"]);
        let score = metric.score("in", "I only used search", &[]).await;
        assert!((score - 0.5).abs() < 1e-6);
    }

    #[tokio::test]
    async fn test_tool_usage_none_found() {
        let metric = ToolUsageMetric::new(vec!["search"]);
        let score = metric.score("in", "I didn't call any tools", &[]).await;
        assert!((score - 0.0).abs() < 1e-6);
    }

    #[tokio::test]
    async fn test_tool_usage_empty_expected() {
        let metric = ToolUsageMetric::new(Vec::<String>::new());
        let score = metric.score("in", "anything", &[]).await;
        assert!((score - 1.0).abs() < 1e-6);
    }
}