1use std::path::{Path, PathBuf};
33use std::sync::Arc;
34
35use anyhow::{Context, Result};
36use async_trait::async_trait;
37use regex::Regex;
38use serde::{Deserialize, Serialize};
39
40use super::case::EvaluationCase;
41use super::trial::TrialResult;
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct Fixture {
46 pub name: String,
48 #[serde(default = "default_category")]
50 pub category: String,
51 #[serde(default)]
53 pub model: Option<String>,
54 pub messages: Vec<FixtureMessage>,
56 pub expected: ExpectedBehavior,
58}
59
/// Serde default for [`Fixture::category`]: the literal `"fixture"`.
fn default_category() -> String {
    String::from("fixture")
}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct FixtureMessage {
67 pub role: String,
69 pub content: String,
71}
72
73#[derive(Debug, Clone, Default, Serialize, Deserialize)]
75pub struct ExpectedBehavior {
76 #[serde(default)]
78 pub tool_sequence: Vec<String>,
79 #[serde(default)]
81 pub assertions: Vec<Assertion>,
82}
83
84#[derive(Debug, Clone, Default, Serialize, Deserialize)]
89pub struct Assertion {
90 #[serde(default, skip_serializing_if = "Option::is_none")]
92 pub contains: Option<String>,
93 #[serde(default, skip_serializing_if = "Option::is_none")]
95 pub regex: Option<String>,
96 #[serde(default, skip_serializing_if = "Option::is_none")]
98 pub tool_called: Option<String>,
99 #[serde(default, skip_serializing_if = "Option::is_none")]
101 pub finish_reason: Option<String>,
102}
103
/// What a [`FixtureRunner`] observed while executing a fixture.
#[derive(Clone, Debug, Default)]
pub struct RunOutcome {
    /// Text produced by the run; target of the `contains` and `regex`
    /// assertions.
    pub output_text: String,
    /// Names of the tools that were called, in call order.
    pub tool_sequence: Vec<String>,
    /// Reason the run finished, if the runner reported one.
    pub finish_reason: Option<String>,
    /// Duration in milliseconds as reported by the runner; forwarded into the
    /// trial result by [`FixtureCase`].
    pub duration_ms: u64,
}
116
117#[async_trait]
123pub trait FixtureRunner: Send + Sync {
124 async fn run(&self, fixture: &Fixture, trial_id: usize) -> Result<RunOutcome>;
129}
130
131pub struct FixtureCase {
133 fixture: Arc<Fixture>,
134 runner: Arc<dyn FixtureRunner>,
135}
136
137impl FixtureCase {
138 pub fn new(fixture: Arc<Fixture>, runner: Arc<dyn FixtureRunner>) -> Self {
140 Self { fixture, runner }
141 }
142
143 pub fn fixture(&self) -> &Fixture {
145 &self.fixture
146 }
147}
148
149#[async_trait]
150impl EvaluationCase for FixtureCase {
151 fn name(&self) -> &str {
152 &self.fixture.name
153 }
154 fn category(&self) -> &str {
155 &self.fixture.category
156 }
157 async fn run(&self, trial_id: usize) -> Result<TrialResult> {
158 let started = std::time::Instant::now();
159 let outcome = match self.runner.run(&self.fixture, trial_id).await {
160 Ok(o) => o,
161 Err(e) => {
162 return Ok(TrialResult::failure(
163 trial_id,
164 started.elapsed().as_millis() as u64,
165 format!("runner error: {e:#}"),
166 ));
167 }
168 };
169 match evaluate(&self.fixture.expected, &outcome) {
170 Ok(()) => Ok(TrialResult::success(trial_id, outcome.duration_ms)),
171 Err(reason) => Ok(TrialResult::failure(trial_id, outcome.duration_ms, reason)),
172 }
173 }
174}
175
176pub fn evaluate(expected: &ExpectedBehavior, outcome: &RunOutcome) -> Result<(), String> {
179 if !expected.tool_sequence.is_empty() && expected.tool_sequence != outcome.tool_sequence {
180 return Err(format!(
181 "tool_sequence mismatch: expected {:?}, got {:?}",
182 expected.tool_sequence, outcome.tool_sequence
183 ));
184 }
185 for a in &expected.assertions {
186 if let Some(needle) = &a.contains
187 && !outcome.output_text.contains(needle.as_str())
188 {
189 return Err(format!(
190 "output_text missing expected substring: {needle:?}"
191 ));
192 }
193 if let Some(pat) = &a.regex {
194 let re =
195 Regex::new(pat).map_err(|e| format!("invalid regex in fixture: {pat:?} ({e})"))?;
196 if !re.is_match(&outcome.output_text) {
197 return Err(format!("output_text did not match regex: {pat:?}"));
198 }
199 }
200 if let Some(name) = &a.tool_called
201 && !outcome.tool_sequence.iter().any(|t| t == name)
202 {
203 return Err(format!(
204 "expected tool `{name}` to be called; got {:?}",
205 outcome.tool_sequence
206 ));
207 }
208 if let Some(expected_reason) = &a.finish_reason {
209 let got = outcome.finish_reason.as_deref().unwrap_or("");
210 if got != expected_reason {
211 return Err(format!(
212 "finish_reason mismatch: expected {expected_reason:?}, got {got:?}"
213 ));
214 }
215 }
216 }
217 Ok(())
218}
219
220pub fn load_fixture_file(path: impl AsRef<Path>) -> Result<Fixture> {
222 let path = path.as_ref();
223 let raw = std::fs::read_to_string(path)
224 .with_context(|| format!("reading fixture {}", path.display()))?;
225 let fixture: Fixture =
226 serde_yml::from_str(&raw).with_context(|| format!("parsing fixture {}", path.display()))?;
227 Ok(fixture)
228}
229
230pub fn load_fixtures_from_dir(dir: impl AsRef<Path>) -> Result<Vec<Fixture>> {
233 let dir = dir.as_ref();
234 let mut out = Vec::new();
235 let mut paths: Vec<PathBuf> = Vec::new();
236 let entries =
237 std::fs::read_dir(dir).with_context(|| format!("reading fixture dir {}", dir.display()))?;
238 for entry in entries {
239 let entry = entry?;
240 let path = entry.path();
241 if !path.is_file() {
242 continue;
243 }
244 match path.extension().and_then(|s| s.to_str()) {
245 Some("yaml") | Some("yml") => paths.push(path),
246 _ => {}
247 }
248 }
249 paths.sort();
251 for p in paths {
252 out.push(load_fixture_file(&p)?);
253 }
254 Ok(out)
255}
256
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an outcome with the given tool calls and output text that
    /// finished with `end_turn` after 5 ms.
    fn happy_outcome(seq: Vec<&str>, text: &str) -> RunOutcome {
        RunOutcome {
            output_text: text.to_string(),
            tool_sequence: seq.into_iter().map(String::from).collect(),
            finish_reason: Some("end_turn".into()),
            duration_ms: 5,
        }
    }

    /// Assertion with only `contains` set.
    fn contains(s: &str) -> Assertion {
        Assertion {
            contains: Some(s.into()),
            ..Default::default()
        }
    }

    /// Assertion with only `tool_called` set.
    fn tool_called(s: &str) -> Assertion {
        Assertion {
            tool_called: Some(s.into()),
            ..Default::default()
        }
    }

    /// Assertion with only `finish_reason` set.
    fn finish_reason(s: &str) -> Assertion {
        Assertion {
            finish_reason: Some(s.into()),
            ..Default::default()
        }
    }

    /// Assertion with only `regex` set.
    fn regex_match(s: &str) -> Assertion {
        Assertion {
            regex: Some(s.into()),
            ..Default::default()
        }
    }

    /// Expectation consisting of a single assertion and no tool sequence.
    fn expecting(assertion: Assertion) -> ExpectedBehavior {
        ExpectedBehavior {
            assertions: vec![assertion],
            ..Default::default()
        }
    }

    #[test]
    fn evaluate_passes_when_all_assertions_hold() {
        let expected = ExpectedBehavior {
            tool_sequence: vec!["read_file".into(), "edit_file".into()],
            assertions: vec![
                contains("fn bar"),
                tool_called("edit_file"),
                finish_reason("end_turn"),
            ],
        };
        let outcome = happy_outcome(vec!["read_file", "edit_file"], "updated: fn bar() {}");
        evaluate(&expected, &outcome).expect("should pass");
    }

    #[test]
    fn evaluate_fails_on_tool_sequence_mismatch() {
        let expected = ExpectedBehavior {
            tool_sequence: vec!["read_file".into(), "edit_file".into()],
            ..Default::default()
        };
        let outcome = happy_outcome(vec!["edit_file"], "");
        let err = evaluate(&expected, &outcome).unwrap_err();
        assert!(err.contains("tool_sequence mismatch"));
    }

    #[test]
    fn evaluate_skips_tool_sequence_check_when_none_expected() {
        let expected = ExpectedBehavior::default();
        let outcome = happy_outcome(vec!["read_file"], "");
        evaluate(&expected, &outcome).expect("empty expectation accepts any tool calls");
    }

    #[test]
    fn evaluate_fails_on_missing_substring() {
        let expected = expecting(contains("bar"));
        let outcome = happy_outcome(vec![], "only foo here");
        assert!(evaluate(&expected, &outcome).is_err());
    }

    #[test]
    fn evaluate_regex_assertion() {
        let expected = expecting(regex_match(r"^updated:"));
        let outcome = happy_outcome(vec![], "updated: ok");
        evaluate(&expected, &outcome).expect("matches");
    }

    #[test]
    fn evaluate_fails_on_regex_mismatch() {
        let expected = expecting(regex_match(r"^updated:"));
        let outcome = happy_outcome(vec![], "nothing changed");
        let err = evaluate(&expected, &outcome).unwrap_err();
        assert!(err.contains("did not match regex"));
    }

    #[test]
    fn evaluate_reports_invalid_regex() {
        // An unclosed group does not compile.
        let expected = expecting(regex_match("("));
        let outcome = happy_outcome(vec![], "anything");
        let err = evaluate(&expected, &outcome).unwrap_err();
        assert!(err.contains("invalid regex in fixture"));
    }

    #[test]
    fn evaluate_fails_when_tool_not_called() {
        let expected = expecting(tool_called("edit_file"));
        let outcome = happy_outcome(vec!["read_file"], "");
        let err = evaluate(&expected, &outcome).unwrap_err();
        assert!(err.contains("expected tool `edit_file` to be called"));
    }

    #[test]
    fn evaluate_fails_on_finish_reason_mismatch() {
        let expected = expecting(finish_reason("max_tokens"));
        let outcome = happy_outcome(vec![], "");
        let err = evaluate(&expected, &outcome).unwrap_err();
        assert!(err.contains("finish_reason mismatch"));

        // An outcome without a finish reason cannot satisfy a non-empty one.
        let expected = expecting(finish_reason("end_turn"));
        let outcome = RunOutcome {
            finish_reason: None,
            ..happy_outcome(vec![], "")
        };
        let err = evaluate(&expected, &outcome).unwrap_err();
        assert!(err.contains("finish_reason mismatch"));
    }

    #[test]
    fn load_fixtures_from_tmpdir_in_sorted_order() {
        let dir = tempfile::tempdir().unwrap();
        let a = r#"
name: aa
category: test
messages:
  - { role: user, content: "hi" }
expected:
  assertions:
    - contains: "hi"
"#;
        let b = r#"
name: bb
category: test
messages:
  - { role: user, content: "go" }
expected:
  assertions:
    - finish_reason: end_turn
"#;
        std::fs::write(dir.path().join("a_first.yaml"), a).unwrap();
        std::fs::write(dir.path().join("b_second.yml"), b).unwrap();
        std::fs::write(dir.path().join("ignore_me.txt"), "").unwrap();

        let fixtures = load_fixtures_from_dir(dir.path()).unwrap();
        assert_eq!(fixtures.len(), 2);
        assert_eq!(fixtures[0].name, "aa");
        assert_eq!(fixtures[1].name, "bb");
    }

    #[test]
    fn load_fixtures_from_missing_dir_is_an_error() {
        let dir = tempfile::tempdir().unwrap();
        let missing = dir.path().join("does_not_exist");
        assert!(load_fixtures_from_dir(&missing).is_err());
    }

    #[test]
    fn load_fixture_file_applies_serde_defaults() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("minimal.yaml");
        std::fs::write(&path, "name: minimal\nmessages: []\nexpected: {}\n").unwrap();

        let fixture = load_fixture_file(&path).unwrap();
        assert_eq!(fixture.name, "minimal");
        assert_eq!(fixture.category, "fixture");
        assert!(fixture.model.is_none());
        assert!(fixture.messages.is_empty());
        assert!(fixture.expected.tool_sequence.is_empty());
        assert!(fixture.expected.assertions.is_empty());
    }

    #[test]
    fn load_fixture_file_reports_parse_errors_with_context() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("broken.yaml");
        // Valid YAML, but the required `name`, `messages` and `expected`
        // fields are missing.
        std::fs::write(&path, "category: test\n").unwrap();

        let err = load_fixture_file(&path).unwrap_err();
        assert!(format!("{err:#}").contains("parsing fixture"));
    }

    /// Runner that always returns a fixed outcome.
    struct StubRunner {
        outcome: RunOutcome,
    }

    #[async_trait]
    impl FixtureRunner for StubRunner {
        async fn run(&self, _: &Fixture, _: usize) -> Result<RunOutcome> {
            Ok(self.outcome.clone())
        }
    }

    /// Runner that always fails.
    struct FailingRunner;

    #[async_trait]
    impl FixtureRunner for FailingRunner {
        async fn run(&self, _: &Fixture, _: usize) -> Result<RunOutcome> {
            Err(anyhow::anyhow!("boom"))
        }
    }

    #[tokio::test]
    async fn fixture_case_bridges_to_trial_result() {
        let fixture = Arc::new(Fixture {
            name: "f1".into(),
            category: "smoke".into(),
            model: None,
            messages: vec![FixtureMessage {
                role: "user".into(),
                content: "hi".into(),
            }],
            expected: ExpectedBehavior {
                tool_sequence: vec![],
                assertions: vec![contains("hi")],
            },
        });
        let runner = Arc::new(StubRunner {
            outcome: happy_outcome(vec![], "hi there"),
        });
        let case = FixtureCase::new(fixture.clone(), runner);
        let r = case.run(0).await.unwrap();
        assert!(r.success);

        // Same fixture, but expecting text the stubbed output does not contain.
        let fixture_bad = Arc::new(Fixture {
            expected: ExpectedBehavior {
                assertions: vec![contains("BYE")],
                ..fixture.expected.clone()
            },
            ..(*fixture).clone()
        });
        let runner = Arc::new(StubRunner {
            outcome: happy_outcome(vec![], "hi there"),
        });
        let case = FixtureCase::new(fixture_bad, runner);
        let r = case.run(0).await.unwrap();
        assert!(!r.success);
        assert!(
            r.error
                .as_deref()
                .unwrap()
                .contains("missing expected substring")
        );
    }

    #[tokio::test]
    async fn fixture_case_reports_runner_error_as_failed_trial() {
        let fixture = Arc::new(Fixture {
            name: "f2".into(),
            category: "smoke".into(),
            model: None,
            messages: vec![],
            expected: ExpectedBehavior::default(),
        });
        let case = FixtureCase::new(fixture, Arc::new(FailingRunner));

        // The runner error must not surface as `Err`; it becomes a failed trial.
        let r = case.run(3).await.unwrap();
        assert!(!r.success);
        let message = r.error.as_deref().unwrap();
        assert!(message.contains("runner error"));
        assert!(message.contains("boom"));
    }
}