Skip to main content

atomr_agents_eval/
pairwise.rs

1//! Pairwise eval — judge picks A vs B and emits a preference.
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use atomr_agents_core::{Result, Value};
7use serde::{Deserialize, Serialize};
8
9use crate::judge::JudgeModel;
10use crate::scorer::{AsyncScorer, ScorerOutcome};
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
13#[serde(rename_all = "snake_case")]
14pub enum PairwiseChoice {
15    A,
16    B,
17    Tie,
18}
19
20pub struct PairwiseScorer {
21    pub model: Arc<dyn JudgeModel>,
22    pub criteria_label: String,
23}
24
25impl PairwiseScorer {
26    pub fn new(model: Arc<dyn JudgeModel>, criteria_label: impl Into<String>) -> Self {
27        Self {
28            model,
29            criteria_label: criteria_label.into(),
30        }
31    }
32
33    pub async fn compare(&self, prompt: &str, a: &Value, b: &Value) -> Result<(PairwiseChoice, String)> {
34        let p = format!(
35            "Pairwise preference task. Criterion: {}\n\nPrompt:\n{prompt}\n\nResponse A:\n{a}\n\nResponse B:\n{b}\n\nReply on the first line with one of: A, B, or TIE. Then on the next line a short justification.",
36            self.criteria_label
37        );
38        let reply = self.model.judge(&p).await?;
39        let choice = reply
40            .lines()
41            .next()
42            .map(|s| s.trim().to_uppercase())
43            .unwrap_or_default();
44        let pc = match choice.as_str() {
45            "A" => PairwiseChoice::A,
46            "B" => PairwiseChoice::B,
47            _ => PairwiseChoice::Tie,
48        };
49        let note = reply.lines().nth(1).unwrap_or("").trim().to_string();
50        Ok((pc, note))
51    }
52}
53
54#[async_trait]
55impl AsyncScorer for PairwiseScorer {
56    /// Treat `expected` as Response A and `actual` as Response B and
57    /// run a pairwise judgment. The criterion label doubles as the
58    /// task prompt context — sufficient for cases where the comparison
59    /// criterion is fully described by `criteria_label`. Callers
60    /// needing a richer prompt should use `compare()` directly.
61    ///
62    /// Score mapping:
63    /// - A wins → score 1.0, passed=true
64    /// - tie    → score 0.5, passed=true
65    /// - B wins → score 0.0, passed=false
66    async fn score(&self, expected: &Value, actual: &Value) -> ScorerOutcome {
67        let prompt = self.criteria_label.clone();
68        match self.compare(&prompt, expected, actual).await {
69            Ok((choice, note)) => {
70                let (passed, score) = match choice {
71                    PairwiseChoice::A => (true, 1.0),
72                    PairwiseChoice::Tie => (true, 0.5),
73                    PairwiseChoice::B => (false, 0.0),
74                };
75                ScorerOutcome { passed, score, note }
76            }
77            Err(e) => ScorerOutcome {
78                passed: false,
79                score: 0.0,
80                note: format!("pairwise error: {e}"),
81            },
82        }
83    }
84}
85
86/// Aggregate a series of pairwise comparisons into a preference rate
87/// for option A.
88pub fn preference_rate(votes: &[PairwiseChoice]) -> f32 {
89    if votes.is_empty() {
90        return 0.0;
91    }
92    let a_score: f32 = votes
93        .iter()
94        .map(|c| match c {
95            PairwiseChoice::A => 1.0,
96            PairwiseChoice::Tie => 0.5,
97            PairwiseChoice::B => 0.0,
98        })
99        .sum();
100    a_score / votes.len() as f32
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use parking_lot::Mutex;

    /// Judge double that hands back pre-recorded replies, oldest first.
    struct ScriptedJudge {
        replies: Mutex<Vec<String>>,
    }

    #[async_trait]
    impl JudgeModel for ScriptedJudge {
        async fn judge(&self, _prompt: &str) -> Result<String> {
            // Panics if the script runs dry: a test that asks for more
            // replies than it scripted is itself broken.
            let next = self.replies.lock().remove(0);
            Ok(next)
        }
    }

    /// Scorer whose judge answers exactly once, with `reply`.
    fn scorer_replying(reply: &str, criterion: &str) -> PairwiseScorer {
        let judge = Arc::new(ScriptedJudge {
            replies: Mutex::new(vec![reply.to_string()]),
        });
        PairwiseScorer::new(judge, criterion)
    }

    /// Drive the `AsyncScorer` impl with fixed expected/actual strings.
    async fn score_fixture(scorer: &PairwiseScorer) -> ScorerOutcome {
        let expected = Value::String("expected".into());
        let actual = Value::String("actual".into());
        AsyncScorer::score(scorer, &expected, &actual).await
    }

    #[tokio::test]
    async fn pairwise_picks_a_or_b() {
        let scorer = scorer_replying("A\nclearer answer", "helpfulness");
        let (choice, note) = scorer
            .compare("hi", &Value::String("a".into()), &Value::String("b".into()))
            .await
            .unwrap();
        assert_eq!(choice, PairwiseChoice::A);
        assert!(note.contains("clearer"));
    }

    #[tokio::test]
    async fn async_scorer_picks_a_as_pass_with_score_one() {
        let scorer = scorer_replying("A\nclearer", "helpfulness");
        let outcome = score_fixture(&scorer).await;
        assert!(outcome.passed);
        assert!((outcome.score - 1.0).abs() < 1e-6);
        assert!(outcome.note.contains("clearer"));
    }

    #[tokio::test]
    async fn async_scorer_b_choice_fails_with_score_zero() {
        let scorer = scorer_replying("B\nbetter", "quality");
        let outcome = score_fixture(&scorer).await;
        assert!(!outcome.passed);
        assert!((outcome.score - 0.0).abs() < 1e-6);
    }

    #[tokio::test]
    async fn async_scorer_tie_passes_with_half_score() {
        let scorer = scorer_replying("TIE\nequal", "quality");
        let outcome = score_fixture(&scorer).await;
        assert!(outcome.passed);
        assert!((outcome.score - 0.5).abs() < 1e-6);
    }

    #[test]
    fn preference_rate_averages_choices() {
        let votes = [
            PairwiseChoice::A,
            PairwiseChoice::A,
            PairwiseChoice::B,
            PairwiseChoice::Tie,
        ];
        // (1 + 1 + 0 + 0.5) / 4 = 0.625
        assert!((preference_rate(&votes) - 0.625).abs() < 1e-5);
    }
}