use std::path::{Path, PathBuf};
use std::sync::Arc;
use anyhow::{Context, Result};
use async_trait::async_trait;
use regex::Regex;
use serde::{Deserialize, Serialize};
use super::case::EvaluationCase;
use super::trial::TrialResult;
/// One evaluation fixture: the conversation to run plus the behavior
/// expected from it.
///
/// When deserializing, a missing `category` falls back to the value returned
/// by `default_category`, and a missing `model` becomes `None`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Fixture {
    /// Fixture name; exposed as the case name by `FixtureCase`.
    pub name: String,
    /// Category label; exposed as the case category by `FixtureCase`.
    #[serde(default = "default_category")]
    pub category: String,
    /// Optional model identifier. Not interpreted anywhere in this file —
    /// presumably consumed by `FixtureRunner` implementations (TODO confirm).
    #[serde(default)]
    pub model: Option<String>,
    /// Messages that make up the fixture's conversation.
    pub messages: Vec<FixtureMessage>,
    /// Expectations that `evaluate` checks against the run's outcome.
    pub expected: ExpectedBehavior,
}
/// Serde default for `Fixture::category`: the literal `"fixture"`.
fn default_category() -> String {
    String::from("fixture")
}
/// A single message in a fixture's conversation.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FixtureMessage {
    /// Role of the message author (the tests in this file use `"user"`).
    pub role: String,
    /// Text content of the message.
    pub content: String,
}
/// What a fixture run is expected to produce.
///
/// Both fields default to empty when absent from the serialized form.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct ExpectedBehavior {
    /// Exact, ordered list of tool names the run must produce. An empty list
    /// means the tool sequence is not checked by `evaluate`.
    #[serde(default)]
    pub tool_sequence: Vec<String>,
    /// Additional checks; `evaluate` applies them in order and stops at the
    /// first one that fails.
    #[serde(default)]
    pub assertions: Vec<Assertion>,
}
/// A bundle of optional checks applied to a run outcome.
///
/// Every field that is `Some` must hold for the assertion to pass; fields
/// left as `None` are skipped by `evaluate` and omitted when serializing.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct Assertion {
    /// Substring that must occur in the output text.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub contains: Option<String>,
    /// Regular-expression pattern that must match the output text.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub regex: Option<String>,
    /// Tool name that must appear somewhere in the outcome's tool sequence.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_called: Option<String>,
    /// Finish reason the outcome must report; `evaluate` treats a missing
    /// finish reason on the outcome as the empty string.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
}
/// Observable result of running a fixture once, as reported by a
/// `FixtureRunner`.
///
/// Derives `PartialEq`/`Eq` so outcomes can be compared directly (for example
/// with `assert_eq!` in runner tests).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RunOutcome {
    /// Text output that `contains` and `regex` assertions are checked against.
    pub output_text: String,
    /// Names of the tools invoked, in call order.
    pub tool_sequence: Vec<String>,
    /// Reason the run finished, if the runner reported one.
    pub finish_reason: Option<String>,
    /// Duration of the run in milliseconds, as measured by the runner.
    pub duration_ms: u64,
}
/// Executes a fixture and reports what happened.
///
/// Implementations must be `Send + Sync`; `FixtureCase` holds its runner as
/// an `Arc<dyn FixtureRunner>`.
#[async_trait]
pub trait FixtureRunner: Send + Sync {
/// Runs `fixture` for the trial identified by `trial_id`.
///
/// # Errors
///
/// An `Err` is implementation-defined; `FixtureCase::run` converts it into a
/// failed trial whose message starts with `runner error:`.
async fn run(&self, fixture: &Fixture, trial_id: usize) -> Result<RunOutcome>;
}
/// An evaluation case that pairs a shared [`Fixture`] with the
/// [`FixtureRunner`] used to execute it.
pub struct FixtureCase {
    fixture: Arc<Fixture>,
    runner: Arc<dyn FixtureRunner>,
}

impl FixtureCase {
    /// Creates a case that will run `fixture` through `runner`.
    pub fn new(fixture: Arc<Fixture>, runner: Arc<dyn FixtureRunner>) -> Self {
        FixtureCase { fixture, runner }
    }

    /// Borrows the fixture this case was built from.
    pub fn fixture(&self) -> &Fixture {
        self.fixture.as_ref()
    }
}
// Bridges a fixture run to the `EvaluationCase` interface.
#[async_trait]
impl EvaluationCase for FixtureCase {
    /// Name of the underlying fixture.
    fn name(&self) -> &str {
        self.fixture.name.as_str()
    }

    /// Category of the underlying fixture.
    fn category(&self) -> &str {
        self.fixture.category.as_str()
    }

    /// Runs the fixture once and grades the outcome with `evaluate`.
    ///
    /// This always returns `Ok`: a runner error or an unmet expectation is
    /// reported as a failed `TrialResult` rather than as an `Err`.
    async fn run(&self, trial_id: usize) -> Result<TrialResult> {
        // Only used to time the runner-error path; on the normal path the
        // duration reported by the runner itself is recorded.
        let started = std::time::Instant::now();

        let trial = match self.runner.run(&self.fixture, trial_id).await {
            Err(e) => TrialResult::failure(
                trial_id,
                started.elapsed().as_millis() as u64,
                format!("runner error: {e:#}"),
            ),
            Ok(outcome) => match evaluate(&self.fixture.expected, &outcome) {
                Err(reason) => TrialResult::failure(trial_id, outcome.duration_ms, reason),
                Ok(()) => TrialResult::success(trial_id, outcome.duration_ms),
            },
        };
        Ok(trial)
    }
}
pub fn evaluate(expected: &ExpectedBehavior, outcome: &RunOutcome) -> Result<(), String> {
if !expected.tool_sequence.is_empty() && expected.tool_sequence != outcome.tool_sequence {
return Err(format!(
"tool_sequence mismatch: expected {:?}, got {:?}",
expected.tool_sequence, outcome.tool_sequence
));
}
for a in &expected.assertions {
if let Some(needle) = &a.contains
&& !outcome.output_text.contains(needle.as_str())
{
return Err(format!(
"output_text missing expected substring: {needle:?}"
));
}
if let Some(pat) = &a.regex {
let re =
Regex::new(pat).map_err(|e| format!("invalid regex in fixture: {pat:?} ({e})"))?;
if !re.is_match(&outcome.output_text) {
return Err(format!("output_text did not match regex: {pat:?}"));
}
}
if let Some(name) = &a.tool_called
&& !outcome.tool_sequence.iter().any(|t| t == name)
{
return Err(format!(
"expected tool `{name}` to be called; got {:?}",
outcome.tool_sequence
));
}
if let Some(expected_reason) = &a.finish_reason {
let got = outcome.finish_reason.as_deref().unwrap_or("");
if got != expected_reason {
return Err(format!(
"finish_reason mismatch: expected {expected_reason:?}, got {got:?}"
));
}
}
}
Ok(())
}
/// Reads the file at `path` and parses its contents as a YAML [`Fixture`].
///
/// # Errors
///
/// Fails if the file cannot be read as a string or if its contents do not
/// deserialize into a `Fixture`; either error is annotated with the path.
pub fn load_fixture_file(path: impl AsRef<Path>) -> Result<Fixture> {
    let file = path.as_ref();
    std::fs::read_to_string(file)
        .with_context(|| format!("reading fixture {}", file.display()))
        .and_then(|raw| {
            serde_yml::from_str(&raw)
                .with_context(|| format!("parsing fixture {}", file.display()))
        })
}
pub fn load_fixtures_from_dir(dir: impl AsRef<Path>) -> Result<Vec<Fixture>> {
let dir = dir.as_ref();
let mut out = Vec::new();
let mut paths: Vec<PathBuf> = Vec::new();
let entries =
std::fs::read_dir(dir).with_context(|| format!("reading fixture dir {}", dir.display()))?;
for entry in entries {
let entry = entry?;
let path = entry.path();
if !path.is_file() {
continue;
}
match path.extension().and_then(|s| s.to_str()) {
Some("yaml") | Some("yml") => paths.push(path),
_ => {}
}
}
paths.sort();
for p in paths {
out.push(load_fixture_file(&p)?);
}
Ok(out)
}
// Unit tests for `evaluate`, the directory loader, and `FixtureCase`.
#[cfg(test)]
mod tests {
use super::*;
// Builds a `RunOutcome` with the given tool sequence and output text, a
// finish reason of "end_turn", and a fixed duration of 5 ms.
fn happy_outcome(seq: Vec<&str>, text: &str) -> RunOutcome {
RunOutcome {
output_text: text.to_string(),
tool_sequence: seq.into_iter().map(String::from).collect(),
finish_reason: Some("end_turn".into()),
duration_ms: 5,
}
}
// Assertion with only `contains` set.
fn contains(s: &str) -> Assertion {
Assertion {
contains: Some(s.into()),
..Default::default()
}
}
// Assertion with only `tool_called` set.
fn tool_called(s: &str) -> Assertion {
Assertion {
tool_called: Some(s.into()),
..Default::default()
}
}
// Assertion with only `finish_reason` set.
fn finish_reason(s: &str) -> Assertion {
Assertion {
finish_reason: Some(s.into()),
..Default::default()
}
}
// Assertion with only `regex` set.
fn regex_match(s: &str) -> Assertion {
Assertion {
regex: Some(s.into()),
..Default::default()
}
}
// A matching tool sequence plus satisfied `contains`, `tool_called` and
// `finish_reason` assertions evaluates to `Ok`.
#[test]
fn evaluate_passes_when_all_assertions_hold() {
let expected = ExpectedBehavior {
tool_sequence: vec!["read_file".into(), "edit_file".into()],
assertions: vec![
contains("fn bar"),
tool_called("edit_file"),
finish_reason("end_turn"),
],
};
let outcome = happy_outcome(vec!["read_file", "edit_file"], "updated: fn bar() {}");
evaluate(&expected, &outcome).expect("should pass");
}
// A shorter observed tool sequence is reported as a tool_sequence mismatch.
#[test]
fn evaluate_fails_on_tool_sequence_mismatch() {
let expected = ExpectedBehavior {
tool_sequence: vec!["read_file".into(), "edit_file".into()],
..Default::default()
};
let outcome = happy_outcome(vec!["edit_file"], "");
let err = evaluate(&expected, &outcome).unwrap_err();
assert!(err.contains("tool_sequence mismatch"));
}
// A `contains` assertion fails when the substring is absent from the output.
#[test]
fn evaluate_fails_on_missing_substring() {
let expected = ExpectedBehavior {
assertions: vec![contains("bar")],
..Default::default()
};
let outcome = happy_outcome(vec![], "only foo here");
assert!(evaluate(&expected, &outcome).is_err());
}
// A `regex` assertion passes when the pattern matches the output text.
#[test]
fn evaluate_regex_assertion() {
let expected = ExpectedBehavior {
assertions: vec![regex_match(r"^updated:")],
..Default::default()
};
let outcome = happy_outcome(vec![], "updated: ok");
evaluate(&expected, &outcome).expect("matches");
}
// Writes one `.yaml`, one `.yml` and one `.txt` file into a temp dir and
// expects exactly the two YAML fixtures back, ordered by file name.
#[test]
fn load_fixtures_from_tmpdir_in_sorted_order() {
let dir = tempfile::tempdir().unwrap();
let a = r#"
name: aa
category: test
messages:
- { role: user, content: "hi" }
expected:
assertions:
- contains: "hi"
"#;
let b = r#"
name: bb
category: test
messages:
- { role: user, content: "go" }
expected:
assertions:
- finish_reason: end_turn
"#;
std::fs::write(dir.path().join("a_first.yaml"), a).unwrap();
std::fs::write(dir.path().join("b_second.yml"), b).unwrap();
std::fs::write(dir.path().join("ignore_me.txt"), "").unwrap();
let fixtures = load_fixtures_from_dir(dir.path()).unwrap();
assert_eq!(fixtures.len(), 2);
assert_eq!(fixtures[0].name, "aa");
assert_eq!(fixtures[1].name, "bb");
}
// Runner stub that returns a clone of a canned outcome, ignoring the
// fixture and trial id it is given.
struct StubRunner {
outcome: RunOutcome,
}
#[async_trait]
impl FixtureRunner for StubRunner {
async fn run(&self, _: &Fixture, _: usize) -> Result<RunOutcome> {
Ok(self.outcome.clone())
}
}
// `FixtureCase::run` yields a successful trial when the assertions hold and
// a failed trial carrying the evaluation message when they do not.
#[tokio::test]
async fn fixture_case_bridges_to_trial_result() {
let fixture = Arc::new(Fixture {
name: "f1".into(),
category: "smoke".into(),
model: None,
messages: vec![FixtureMessage {
role: "user".into(),
content: "hi".into(),
}],
expected: ExpectedBehavior {
tool_sequence: vec![],
assertions: vec![contains("hi")],
},
});
let runner = Arc::new(StubRunner {
outcome: happy_outcome(vec![], "hi there"),
});
let case = FixtureCase::new(fixture.clone(), runner);
let r = case.run(0).await.unwrap();
assert!(r.success);
// Same fixture, but expecting a substring the stubbed output lacks.
let fixture_bad = Arc::new(Fixture {
expected: ExpectedBehavior {
assertions: vec![contains("BYE")],
..fixture.expected.clone()
},
..(*fixture).clone()
});
let runner = Arc::new(StubRunner {
outcome: happy_outcome(vec![], "hi there"),
});
let case = FixtureCase::new(fixture_bad, runner);
let r = case.run(0).await.unwrap();
assert!(!r.success);
assert!(
r.error
.as_deref()
.unwrap()
.contains("missing expected substring")
);
}
}