Skip to main content

swink_agent_eval/
response.rs

1//! Response matching evaluator.
2//!
3//! Scores the agent's final response text against expected criteria:
4//! exact match, substring containment, regex pattern, or custom function.
5
6use std::panic::{AssertUnwindSafe, catch_unwind};
7
8use regex::Regex;
9use swink_agent::prefix_chars;
10
11use crate::evaluator::Evaluator;
12use crate::score::Score;
13use crate::types::{EvalCase, EvalMetricResult, Invocation, ResponseCriteria};
14
/// Evaluator that scores the final response text against expected criteria.
///
/// The type carries no state; all inputs come from the case and invocation
/// passed to `evaluate`, which returns `None` when the case has no
/// `expected_response` defined.
#[derive(Debug, Clone, Copy, Default)]
pub struct ResponseMatcher;
19
20impl Evaluator for ResponseMatcher {
21    fn name(&self) -> &'static str {
22        "response"
23    }
24
25    fn evaluate(&self, case: &EvalCase, invocation: &Invocation) -> Option<EvalMetricResult> {
26        let criteria = case.expected_response.as_ref()?;
27        let actual = invocation.final_response.as_deref().unwrap_or("");
28
29        let (score, details) = match criteria {
30            ResponseCriteria::Exact { expected } => {
31                if actual == expected {
32                    (Score::pass(), "exact match".to_string())
33                } else {
34                    (
35                        Score::fail(),
36                        format!("expected exact match, got: {}", truncate(actual, 100)),
37                    )
38                }
39            }
40            ResponseCriteria::Contains { substring } => {
41                if actual.contains(substring.as_str()) {
42                    (Score::pass(), format!("contains \"{substring}\""))
43                } else {
44                    (
45                        Score::fail(),
46                        format!(
47                            "expected to contain \"{substring}\", got: {}",
48                            truncate(actual, 100)
49                        ),
50                    )
51                }
52            }
53            ResponseCriteria::Regex { pattern } => match Regex::new(pattern) {
54                Ok(re) => {
55                    if re.is_match(actual) {
56                        (Score::pass(), format!("matches pattern /{pattern}/"))
57                    } else {
58                        (
59                            Score::fail(),
60                            format!("does not match /{pattern}/, got: {}", truncate(actual, 100)),
61                        )
62                    }
63                }
64                Err(e) => (Score::fail(), format!("invalid regex: {e}")),
65            },
66            ResponseCriteria::Custom(f) => match catch_unwind(AssertUnwindSafe(|| f(actual))) {
67                Ok(score) => {
68                    let details = format!("custom score: {:.2}", score.value);
69                    (score, details)
70                }
71                Err(payload) => {
72                    let msg = payload
73                        .downcast_ref::<&str>()
74                        .copied()
75                        .or_else(|| payload.downcast_ref::<String>().map(String::as_str))
76                        .unwrap_or("unknown panic");
77                    (Score::fail(), format!("custom matcher panicked: {msg}"))
78                }
79            },
80        };
81
82        Some(EvalMetricResult {
83            evaluator_name: "response".to_string(),
84            score,
85            details: Some(details),
86        })
87    }
88}
89
90/// Truncate a string to at most `max_len` characters, appending "..." if truncated.
91fn truncate(s: &str, max_len: usize) -> String {
92    if s.chars().count() <= max_len {
93        s.to_string()
94    } else {
95        format!("{}...", prefix_chars(s, max_len))
96    }
97}
98
#[cfg(test)]
mod tests {
    //! Unit tests for `truncate` and for every `ResponseCriteria` branch of
    //! `ResponseMatcher::evaluate`.

    use super::*;

    use std::sync::Arc;
    use std::time::Duration;

    use swink_agent::{AssistantMessage, ContentBlock, Cost, ModelSpec, StopReason, Usage};

    use crate::types::{EvalCase, Invocation, TurnRecord};

    /// Build an `EvalCase` whose only meaningful field is `expected_response`.
    fn minimal_case_with_response(criteria: ResponseCriteria) -> EvalCase {
        EvalCase {
            id: "test".to_string(),
            name: "Test".to_string(),
            description: None,
            system_prompt: "test".to_string(),
            user_messages: vec!["test".to_string()],
            expected_trajectory: None,
            expected_response: Some(criteria),
            expected_assertion: None,
            expected_interactions: None,
            few_shot_examples: vec![],
            budget: None,
            evaluators: vec![],
            metadata: serde_json::Value::Null,
            attachments: vec![],
            session_id: None,
            expected_environment_state: None,
            expected_tool_intent: None,
            semantic_tool_selection: false,
            state_capture: None,
        }
    }

    /// Build a single-turn `Invocation` whose final response is `text`.
    fn invocation_with_response(text: &str) -> Invocation {
        Invocation {
            turns: vec![TurnRecord {
                turn_index: 0,
                assistant_message: AssistantMessage {
                    content: vec![ContentBlock::Text {
                        text: text.to_string(),
                    }],
                    provider: "test".to_string(),
                    model_id: "test-model".to_string(),
                    usage: Usage::default(),
                    cost: Cost::default(),
                    stop_reason: StopReason::Stop,
                    error_message: None,
                    error_kind: None,
                    timestamp: 0,
                    cache_hint: None,
                },
                tool_calls: vec![],
                tool_results: vec![],
                duration: Duration::from_millis(10),
            }],
            total_usage: Usage::default(),
            total_cost: Cost::default(),
            total_duration: Duration::from_millis(10),
            final_response: Some(text.to_string()),
            stop_reason: StopReason::Stop,
            model: ModelSpec::new("test", "test-model"),
        }
    }

    /// Run the matcher for `criteria` against a response of `text`.
    fn evaluate_response(criteria: ResponseCriteria, text: &str) -> EvalMetricResult {
        let case = minimal_case_with_response(criteria);
        let invocation = invocation_with_response(text);
        ResponseMatcher
            .evaluate(&case, &invocation)
            .expect("criteria are set, so a result must be produced")
    }

    /// Assert that `result` carries the score produced by `Score::fail()`.
    fn assert_failed(result: &EvalMetricResult) {
        assert!((result.score.value - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn truncate_short_string() {
        assert_eq!(truncate("hello", 10), "hello");
    }

    #[test]
    fn truncate_long_string() {
        let long = "a".repeat(200);
        let result = truncate(&long, 100);
        assert_eq!(result.len(), 103); // 100 + "..."
        assert!(result.ends_with("..."));
    }

    #[test]
    fn truncate_multibyte_string_is_utf8_safe() {
        let text = format!("{}🙂tail", "a".repeat(99));
        let result = truncate(&text, 100);
        assert_eq!(result, format!("{}🙂...", "a".repeat(99)));
    }

    #[test]
    fn no_expected_response_yields_none() {
        let mut case = minimal_case_with_response(ResponseCriteria::Exact {
            expected: "unused".into(),
        });
        case.expected_response = None;
        let invocation = invocation_with_response("anything");

        assert!(ResponseMatcher.evaluate(&case, &invocation).is_none());
    }

    #[test]
    fn missing_final_response_is_treated_as_empty() {
        let case = minimal_case_with_response(ResponseCriteria::Exact {
            expected: "".into(),
        });
        let mut invocation = invocation_with_response("ignored");
        invocation.final_response = None;

        let result = ResponseMatcher.evaluate(&case, &invocation).unwrap();
        assert_eq!(result.details.as_deref(), Some("exact match"));
    }

    #[test]
    fn result_is_labelled_with_evaluator_name() {
        let result = evaluate_response(
            ResponseCriteria::Exact {
                expected: "hello".into(),
            },
            "hello",
        );
        assert_eq!(result.evaluator_name, ResponseMatcher.name());
    }

    #[test]
    fn exact_match_passes() {
        let result = evaluate_response(
            ResponseCriteria::Exact {
                expected: "hello".into(),
            },
            "hello",
        );
        assert_eq!(result.details.as_deref(), Some("exact match"));
    }

    #[test]
    fn exact_mismatch_fails() {
        let result = evaluate_response(
            ResponseCriteria::Exact {
                expected: "hello".into(),
            },
            "goodbye",
        );
        assert_failed(&result);
        assert_eq!(
            result.details.as_deref(),
            Some("expected exact match, got: goodbye")
        );
    }

    #[test]
    fn contains_substring_passes() {
        let result = evaluate_response(
            ResponseCriteria::Contains {
                substring: "world".into(),
            },
            "hello world",
        );
        assert_eq!(result.details.as_deref(), Some("contains \"world\""));
    }

    #[test]
    fn contains_missing_substring_fails() {
        let result = evaluate_response(
            ResponseCriteria::Contains {
                substring: "world".into(),
            },
            "hello there",
        );
        assert_failed(&result);
        assert_eq!(
            result.details.as_deref(),
            Some("expected to contain \"world\", got: hello there")
        );
    }

    #[test]
    fn regex_match_passes() {
        let result = evaluate_response(
            ResponseCriteria::Regex {
                pattern: r"^\d+ items$".into(),
            },
            "42 items",
        );
        let details = result.details.unwrap();
        assert!(
            details.starts_with("matches pattern /"),
            "expected match mention, got: {details}"
        );
    }

    #[test]
    fn regex_mismatch_fails() {
        let result = evaluate_response(
            ResponseCriteria::Regex {
                pattern: r"^\d+ items$".into(),
            },
            "no items",
        );
        assert_failed(&result);
        let details = result.details.unwrap();
        assert!(
            details.starts_with("does not match /"),
            "expected mismatch mention, got: {details}"
        );
        assert!(
            details.ends_with("got: no items"),
            "expected echoed response, got: {details}"
        );
    }

    #[test]
    fn invalid_regex_fails_with_error_detail() {
        let result = evaluate_response(
            ResponseCriteria::Regex {
                pattern: "(".into(),
            },
            "anything",
        );
        assert_failed(&result);
        let details = result.details.unwrap();
        assert!(
            details.starts_with("invalid regex: "),
            "expected regex error mention, got: {details}"
        );
    }

    #[test]
    fn custom_fn_panic_caught_as_failure() {
        let criteria = ResponseCriteria::Custom(Arc::new(|_: &str| -> Score {
            panic!("deliberate test panic");
        }));
        let case = minimal_case_with_response(criteria);
        let invocation = invocation_with_response("anything");

        let result = ResponseMatcher.evaluate(&case, &invocation).unwrap();
        assert!((result.score.value - 0.0).abs() < f64::EPSILON);
        let details = result.details.unwrap();
        assert!(
            details.contains("panicked"),
            "expected panic mention, got: {details}"
        );
        assert!(
            details.contains("deliberate test panic"),
            "expected panic message, got: {details}"
        );
    }
}