atomr_agents_eval/
pairwise.rs1use std::sync::Arc;
4
5use async_trait::async_trait;
6use atomr_agents_core::{Result, Value};
7use serde::{Deserialize, Serialize};
8
9use crate::judge::JudgeModel;
10
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
12#[serde(rename_all = "snake_case")]
13pub enum PairwiseChoice {
14 A,
15 B,
16 Tie,
17}
18
19pub struct PairwiseScorer {
20 pub model: Arc<dyn JudgeModel>,
21 pub criteria_label: String,
22}
23
24impl PairwiseScorer {
25 pub fn new(model: Arc<dyn JudgeModel>, criteria_label: impl Into<String>) -> Self {
26 Self {
27 model,
28 criteria_label: criteria_label.into(),
29 }
30 }
31
32 pub async fn compare(&self, prompt: &str, a: &Value, b: &Value) -> Result<(PairwiseChoice, String)> {
33 let p = format!(
34 "Pairwise preference task. Criterion: {}\n\nPrompt:\n{prompt}\n\nResponse A:\n{a}\n\nResponse B:\n{b}\n\nReply on the first line with one of: A, B, or TIE. Then on the next line a short justification.",
35 self.criteria_label
36 );
37 let reply = self.model.judge(&p).await?;
38 let choice = reply
39 .lines()
40 .next()
41 .map(|s| s.trim().to_uppercase())
42 .unwrap_or_default();
43 let pc = match choice.as_str() {
44 "A" => PairwiseChoice::A,
45 "B" => PairwiseChoice::B,
46 _ => PairwiseChoice::Tie,
47 };
48 let note = reply.lines().nth(1).unwrap_or("").trim().to_string();
49 Ok((pc, note))
50 }
51}
52
53pub fn preference_rate(votes: &[PairwiseChoice]) -> f32 {
56 if votes.is_empty() {
57 return 0.0;
58 }
59 let a_score: f32 = votes
60 .iter()
61 .map(|c| match c {
62 PairwiseChoice::A => 1.0,
63 PairwiseChoice::Tie => 0.5,
64 PairwiseChoice::B => 0.0,
65 })
66 .sum();
67 a_score / votes.len() as f32
68}
69
#[cfg(test)]
mod tests {
    use super::*;
    use parking_lot::Mutex;

    /// Test double that hands out scripted replies in order and records every prompt it sees.
    struct ScriptedJudge {
        replies: Mutex<Vec<String>>,
        prompts: Mutex<Vec<String>>,
    }

    impl ScriptedJudge {
        fn new(replies: &[&str]) -> Arc<Self> {
            Arc::new(Self {
                replies: Mutex::new(replies.iter().map(|r| r.to_string()).collect()),
                prompts: Mutex::new(Vec::new()),
            })
        }
    }

    #[async_trait]
    impl JudgeModel for ScriptedJudge {
        async fn judge(&self, prompt: &str) -> Result<String> {
            self.prompts.lock().push(prompt.to_string());
            // Panics once the script is exhausted, which flags a test that over-calls the judge.
            Ok(self.replies.lock().remove(0))
        }
    }

    /// Runs one comparison against a judge scripted to return exactly `reply`.
    async fn compare_with_reply(reply: &str) -> (PairwiseChoice, String) {
        let scorer = PairwiseScorer::new(ScriptedJudge::new(&[reply]), "helpfulness");
        scorer
            .compare("hi", &Value::String("a".into()), &Value::String("b".into()))
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn pairwise_picks_a_or_b() {
        let (c, note) = compare_with_reply("A\nclearer answer").await;
        assert_eq!(c, PairwiseChoice::A);
        assert!(note.contains("clearer"));

        // Verdicts are case-insensitive.
        let (c, note) = compare_with_reply("b\nmore accurate").await;
        assert_eq!(c, PairwiseChoice::B);
        assert_eq!(note, "more accurate");
    }

    #[tokio::test]
    async fn pairwise_tie_and_unrecognized_verdicts_are_ties() {
        let (c, note) = compare_with_reply("TIE\nequally good").await;
        assert_eq!(c, PairwiseChoice::Tie);
        assert_eq!(note, "equally good");

        let (c, _) = compare_with_reply("undecided").await;
        assert_eq!(c, PairwiseChoice::Tie);

        let (c, note) = compare_with_reply("").await;
        assert_eq!(c, PairwiseChoice::Tie);
        assert!(note.is_empty());
    }

    #[tokio::test]
    async fn pairwise_prompt_mentions_criterion_and_task() {
        let judge = ScriptedJudge::new(&["A\nok"]);
        let scorer = PairwiseScorer::new(judge.clone(), "helpfulness");
        scorer
            .compare("hi", &Value::String("a".into()), &Value::String("b".into()))
            .await
            .unwrap();
        let prompts = judge.prompts.lock();
        assert_eq!(prompts.len(), 1);
        assert!(prompts[0].contains("Criterion: helpfulness"));
        assert!(prompts[0].contains("Prompt:\nhi"));
    }

    #[test]
    fn preference_rate_averages_choices() {
        let votes = [
            PairwiseChoice::A,
            PairwiseChoice::A,
            PairwiseChoice::B,
            PairwiseChoice::Tie,
        ];
        assert!((preference_rate(&votes) - 0.625).abs() < 1e-5);
    }

    #[test]
    fn preference_rate_extremes() {
        assert_eq!(preference_rate(&[]), 0.0);
        assert_eq!(preference_rate(&[PairwiseChoice::A; 3]), 1.0);
        assert_eq!(preference_rate(&[PairwiseChoice::B; 3]), 0.0);
        assert_eq!(preference_rate(&[PairwiseChoice::Tie; 2]), 0.5);
    }
}