atomr_agents_eval/
scorer.rs1use async_trait::async_trait;
2use atomr_agents_core::Value;
3use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
6pub struct ScorerOutcome {
7 pub passed: bool,
8 pub score: f32,
9 pub note: String,
10}
11
12pub trait Scorer: Send + Sync + 'static {
15 fn score(&self, expected: &Value, actual: &Value) -> ScorerOutcome;
16}
17
18#[async_trait]
23pub trait AsyncScorer: Send + Sync + 'static {
24 async fn score(&self, expected: &Value, actual: &Value) -> ScorerOutcome;
25}
26
27#[async_trait]
28impl<S: Scorer> AsyncScorer for S {
29 async fn score(&self, expected: &Value, actual: &Value) -> ScorerOutcome {
30 Scorer::score(self, expected, actual)
31 }
32}
33
/// Stateless scorer that checks whether the actual value's text contains a
/// required substring.
///
/// Being a unit struct, it is free to copy and has a trivial default, so the
/// common traits are derived to make it usable in debug output, config
/// structs, and `..Default::default()` construction.
#[derive(Debug, Clone, Copy, Default)]
pub struct ContainsScorer;
36
37impl Scorer for ContainsScorer {
38 fn score(&self, expected: &Value, actual: &Value) -> ScorerOutcome {
39 let needle = expected
40 .get("must_contain")
41 .and_then(|v| v.as_str())
42 .unwrap_or("");
43 let hay = match actual {
44 Value::String(s) => s.clone(),
45 other => serde_json::to_string(other).unwrap_or_default(),
46 };
47 let passed = hay.contains(needle);
48 ScorerOutcome {
49 passed,
50 score: if passed { 1.0 } else { 0.0 },
51 note: if passed {
52 format!("found {needle:?}")
53 } else {
54 format!("missing {needle:?} in {hay:?}")
55 },
56 }
57 }
58}
59
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Builds the `expected` payload `{"must_contain": <needle>}`.
    fn must_contain(needle: &str) -> Value {
        serde_json::json!({ "must_contain": needle })
    }

    /// A sync `Scorer` can be driven through the `AsyncScorer` interface, for
    /// both a hit and a miss.
    #[tokio::test]
    async fn blanket_async_promotes_sync_scorer() {
        let scorer = ContainsScorer;
        let actual = Value::String("oh hi there".into());

        // Called by explicit trait path: `ContainsScorer` implements both
        // traits, so the trait to use is named directly.
        let hit = AsyncScorer::score(&scorer, &must_contain("hi"), &actual).await;
        assert!(hit.passed);
        assert!((hit.score - 1.0).abs() < 1e-6);

        let miss = AsyncScorer::score(&scorer, &must_contain("missing"), &actual).await;
        assert!(!miss.passed);
    }

    /// The blanket impl also works behind a `dyn AsyncScorer` trait object.
    #[tokio::test]
    async fn blanket_works_through_trait_object() {
        let erased: Arc<dyn AsyncScorer> = Arc::new(ContainsScorer);
        let expected = must_contain("yes");
        let actual = Value::String("yes please".into());

        let outcome = erased.score(&expected, &actual).await;
        assert!(outcome.passed);
    }
}