Skip to main content

atomr_agents_eval/
pairwise.rs

1//! Pairwise eval — judge picks A vs B and emits a preference.
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use atomr_agents_core::{Result, Value};
7use serde::{Deserialize, Serialize};
8
9use crate::judge::JudgeModel;
10
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
12#[serde(rename_all = "snake_case")]
13pub enum PairwiseChoice {
14    A,
15    B,
16    Tie,
17}
18
19pub struct PairwiseScorer {
20    pub model: Arc<dyn JudgeModel>,
21    pub criteria_label: String,
22}
23
24impl PairwiseScorer {
25    pub fn new(model: Arc<dyn JudgeModel>, criteria_label: impl Into<String>) -> Self {
26        Self {
27            model,
28            criteria_label: criteria_label.into(),
29        }
30    }
31
32    pub async fn compare(&self, prompt: &str, a: &Value, b: &Value) -> Result<(PairwiseChoice, String)> {
33        let p = format!(
34            "Pairwise preference task. Criterion: {}\n\nPrompt:\n{prompt}\n\nResponse A:\n{a}\n\nResponse B:\n{b}\n\nReply on the first line with one of: A, B, or TIE. Then on the next line a short justification.",
35            self.criteria_label
36        );
37        let reply = self.model.judge(&p).await?;
38        let choice = reply
39            .lines()
40            .next()
41            .map(|s| s.trim().to_uppercase())
42            .unwrap_or_default();
43        let pc = match choice.as_str() {
44            "A" => PairwiseChoice::A,
45            "B" => PairwiseChoice::B,
46            _ => PairwiseChoice::Tie,
47        };
48        let note = reply.lines().nth(1).unwrap_or("").trim().to_string();
49        Ok((pc, note))
50    }
51}
52
53/// Aggregate a series of pairwise comparisons into a preference rate
54/// for option A.
55pub fn preference_rate(votes: &[PairwiseChoice]) -> f32 {
56    if votes.is_empty() {
57        return 0.0;
58    }
59    let a_score: f32 = votes
60        .iter()
61        .map(|c| match c {
62            PairwiseChoice::A => 1.0,
63            PairwiseChoice::Tie => 0.5,
64            PairwiseChoice::B => 0.0,
65        })
66        .sum();
67    a_score / votes.len() as f32
68}
69
#[cfg(test)]
mod tests {
    use super::*;
    use parking_lot::Mutex;

    /// Judge double that hands back pre-recorded replies, oldest first.
    struct ScriptedJudge {
        replies: Mutex<Vec<String>>,
    }

    #[async_trait]
    impl JudgeModel for ScriptedJudge {
        async fn judge(&self, _prompt: &str) -> Result<String> {
            // Take the oldest scripted reply; panics if the script has run dry.
            let next = self.replies.lock().remove(0);
            Ok(next)
        }
    }

    #[tokio::test]
    async fn pairwise_picks_a_or_b() {
        let judge = ScriptedJudge {
            replies: Mutex::new(vec![String::from("A\nclearer answer")]),
        };
        let scorer = PairwiseScorer::new(Arc::new(judge), "helpfulness");
        let left = Value::String("a".into());
        let right = Value::String("b".into());

        let (choice, note) = scorer.compare("hi", &left, &right).await.unwrap();

        assert_eq!(choice, PairwiseChoice::A);
        assert!(note.contains("clearer"));
    }

    #[test]
    fn preference_rate_averages_choices() {
        use PairwiseChoice::{A, B, Tie};
        let votes = [A, A, B, Tie];
        // Points: 1 + 1 + 0 + 0.5 = 2.5 over 4 votes -> 0.625.
        let rate = preference_rate(&votes);
        assert!((rate - 0.625).abs() < 1e-5);
    }
}
110}