Skip to main content

atomr_agents_eval/
scorer.rs

1use async_trait::async_trait;
2use atomr_agents_core::Value;
3use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
6pub struct ScorerOutcome {
7    pub passed: bool,
8    pub score: f32,
9    pub note: String,
10}
11
12/// Sync scorer — pure-CPU comparators (substring match, JSON shape,
13/// regex, etc.). Most scorers should implement this.
14pub trait Scorer: Send + Sync + 'static {
15    fn score(&self, expected: &Value, actual: &Value) -> ScorerOutcome;
16}
17
18/// Async-friendly scorer for impls that genuinely await — LLM judges,
19/// retrieval-grounded checks, anything network-bound. The blanket impl
20/// below promotes every sync `Scorer` into an `AsyncScorer`, so callers
21/// who hold `Arc<dyn AsyncScorer>` can accept both transparently.
22#[async_trait]
23pub trait AsyncScorer: Send + Sync + 'static {
24    async fn score(&self, expected: &Value, actual: &Value) -> ScorerOutcome;
25}
26
27#[async_trait]
28impl<S: Scorer> AsyncScorer for S {
29    async fn score(&self, expected: &Value, actual: &Value) -> ScorerOutcome {
30        Scorer::score(self, expected, actual)
31    }
32}
33
34/// Trivial substring-presence scorer.
35pub struct ContainsScorer;
36
37impl Scorer for ContainsScorer {
38    fn score(&self, expected: &Value, actual: &Value) -> ScorerOutcome {
39        let needle = expected
40            .get("must_contain")
41            .and_then(|v| v.as_str())
42            .unwrap_or("");
43        let hay = match actual {
44            Value::String(s) => s.clone(),
45            other => serde_json::to_string(other).unwrap_or_default(),
46        };
47        let passed = hay.contains(needle);
48        ScorerOutcome {
49            passed,
50            score: if passed { 1.0 } else { 0.0 },
51            note: if passed {
52                format!("found {needle:?}")
53            } else {
54                format!("missing {needle:?} in {hay:?}")
55            },
56        }
57    }
58}
59
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Builds an `expected` payload requiring `needle` to appear in the output.
    fn expecting(needle: &str) -> Value {
        serde_json::json!({ "must_contain": needle })
    }

    /// Wraps plain text as the `actual` value under test.
    fn answer(text: &str) -> Value {
        Value::String(text.into())
    }

    #[tokio::test]
    async fn blanket_async_promotes_sync_scorer() {
        // ContainsScorer only implements the sync `Scorer`, yet the blanket
        // impl lets it be awaited as an `AsyncScorer` with no wrapper type.
        let scorer = ContainsScorer;

        let hit = AsyncScorer::score(&scorer, &expecting("hi"), &answer("oh hi there")).await;
        assert!(hit.passed);
        assert!((hit.score - 1.0).abs() < 1e-6);

        let miss =
            AsyncScorer::score(&scorer, &expecting("missing"), &answer("oh hi there")).await;
        assert!(!miss.passed);
    }

    #[tokio::test]
    async fn blanket_works_through_trait_object() {
        // Dynamic dispatch through `Arc<dyn AsyncScorer>` must reach the
        // sync implementation.
        let scorer: Arc<dyn AsyncScorer> = Arc::new(ContainsScorer);
        let outcome = scorer.score(&expecting("yes"), &answer("yes please")).await;
        assert!(outcome.passed);
    }
}