atomr_agents_eval/
pairwise.rs1use std::sync::Arc;
4
5use async_trait::async_trait;
6use atomr_agents_core::{Result, Value};
7use serde::{Deserialize, Serialize};
8
9use crate::judge::JudgeModel;
10use crate::scorer::{AsyncScorer, ScorerOutcome};
11
/// Verdict of a single pairwise comparison between "Response A" and "Response B".
///
/// Because of `rename_all = "snake_case"`, the variants are (de)serialized as
/// `"a"`, `"b"` and `"tie"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PairwiseChoice {
 /// The judge preferred Response A.
 A,
 /// The judge preferred Response B.
 B,
 /// No preference; also the fallback when the judge's verdict line is not
 /// recognised as `A` or `B`.
 Tie,
}
19
/// Scorer that asks a judge model which of two responses is better with
/// respect to a single criterion.
pub struct PairwiseScorer {
 /// Judge model that receives the comparison prompt and returns a free-text reply.
 pub model: Arc<dyn JudgeModel>,
 /// Criterion text interpolated into the judge prompt; the `AsyncScorer`
 /// implementation additionally passes it as the task prompt.
 pub criteria_label: String,
}
24
25impl PairwiseScorer {
26 pub fn new(model: Arc<dyn JudgeModel>, criteria_label: impl Into<String>) -> Self {
27 Self {
28 model,
29 criteria_label: criteria_label.into(),
30 }
31 }
32
33 pub async fn compare(&self, prompt: &str, a: &Value, b: &Value) -> Result<(PairwiseChoice, String)> {
34 let p = format!(
35 "Pairwise preference task. Criterion: {}\n\nPrompt:\n{prompt}\n\nResponse A:\n{a}\n\nResponse B:\n{b}\n\nReply on the first line with one of: A, B, or TIE. Then on the next line a short justification.",
36 self.criteria_label
37 );
38 let reply = self.model.judge(&p).await?;
39 let choice = reply
40 .lines()
41 .next()
42 .map(|s| s.trim().to_uppercase())
43 .unwrap_or_default();
44 let pc = match choice.as_str() {
45 "A" => PairwiseChoice::A,
46 "B" => PairwiseChoice::B,
47 _ => PairwiseChoice::Tie,
48 };
49 let note = reply.lines().nth(1).unwrap_or("").trim().to_string();
50 Ok((pc, note))
51 }
52}
53
54#[async_trait]
55impl AsyncScorer for PairwiseScorer {
56 async fn score(&self, expected: &Value, actual: &Value) -> ScorerOutcome {
67 let prompt = self.criteria_label.clone();
68 match self.compare(&prompt, expected, actual).await {
69 Ok((choice, note)) => {
70 let (passed, score) = match choice {
71 PairwiseChoice::A => (true, 1.0),
72 PairwiseChoice::Tie => (true, 0.5),
73 PairwiseChoice::B => (false, 0.0),
74 };
75 ScorerOutcome { passed, score, note }
76 }
77 Err(e) => ScorerOutcome {
78 passed: false,
79 score: 0.0,
80 note: format!("pairwise error: {e}"),
81 },
82 }
83 }
84}
85
86pub fn preference_rate(votes: &[PairwiseChoice]) -> f32 {
89 if votes.is_empty() {
90 return 0.0;
91 }
92 let a_score: f32 = votes
93 .iter()
94 .map(|c| match c {
95 PairwiseChoice::A => 1.0,
96 PairwiseChoice::Tie => 0.5,
97 PairwiseChoice::B => 0.0,
98 })
99 .sum();
100 a_score / votes.len() as f32
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use parking_lot::Mutex;

    /// Judge double that ignores the prompt and replays canned replies in order.
    struct ScriptedJudge {
        replies: Mutex<Vec<String>>,
    }

    #[async_trait]
    impl JudgeModel for ScriptedJudge {
        async fn judge(&self, _prompt: &str) -> Result<String> {
            // Panics if the script has run out of replies, which fails the test.
            let next = self.replies.lock().remove(0);
            Ok(next)
        }
    }

    /// Builds a scorer whose judge answers exactly once, with `reply`.
    fn scorer_with_reply(reply: &str, criteria: &str) -> PairwiseScorer {
        let judge = Arc::new(ScriptedJudge {
            replies: Mutex::new(vec![reply.into()]),
        });
        PairwiseScorer::new(judge, criteria)
    }

    /// Scores the fixed `"expected"` / `"actual"` pair through the trait method.
    async fn score_fixture(scorer: &PairwiseScorer) -> ScorerOutcome {
        let expected = Value::String("expected".into());
        let actual = Value::String("actual".into());
        AsyncScorer::score(scorer, &expected, &actual).await
    }

    #[tokio::test]
    async fn pairwise_picks_a_or_b() {
        let scorer = scorer_with_reply("A\nclearer answer", "helpfulness");
        let first = Value::String("a".into());
        let second = Value::String("b".into());
        let (choice, note) = scorer.compare("hi", &first, &second).await.unwrap();
        assert_eq!(choice, PairwiseChoice::A);
        assert!(note.contains("clearer"));
    }

    #[tokio::test]
    async fn async_scorer_picks_a_as_pass_with_score_one() {
        let scorer = scorer_with_reply("A\nclearer", "helpfulness");
        let out = score_fixture(&scorer).await;
        assert!(out.passed);
        assert!((out.score - 1.0).abs() < 1e-6);
        assert!(out.note.contains("clearer"));
    }

    #[tokio::test]
    async fn async_scorer_b_choice_fails_with_score_zero() {
        let scorer = scorer_with_reply("B\nbetter", "quality");
        let out = score_fixture(&scorer).await;
        assert!(!out.passed);
        assert!((out.score - 0.0).abs() < 1e-6);
    }

    #[tokio::test]
    async fn async_scorer_tie_passes_with_half_score() {
        let scorer = scorer_with_reply("TIE\nequal", "quality");
        let out = score_fixture(&scorer).await;
        assert!(out.passed);
        assert!((out.score - 0.5).abs() < 1e-6);
    }

    #[test]
    fn preference_rate_averages_choices() {
        // Two A votes, one B and one tie: (1 + 1 + 0 + 0.5) / 4.
        let votes = [
            PairwiseChoice::A,
            PairwiseChoice::A,
            PairwiseChoice::B,
            PairwiseChoice::Tie,
        ];
        let rate = preference_rate(&votes);
        assert!((rate - 0.625).abs() < 1e-5);
    }
}