use std::{io::BufReader, path::Path};
use serde::Deserialize;
use crate::{
error::BenchError,
scenario::{DatasetLoader, EvalResult, Evaluator, Scenario, exact_match},
};
/// One task record as it appears in a tau-bench JSON file.
///
/// `TauBenchLoader::load` deserializes the whole file as a JSON array of these
/// records and converts each one into a `Scenario`. No field is optional (none
/// is an `Option` or has a serde default), so a record missing any of them makes
/// deserialization of the whole file fail.
#[derive(Debug, Deserialize)]
struct TauBenchTask {
    /// Task identifier; becomes the scenario id.
    task_id: String,
    /// Instruction text given to the agent; becomes the scenario prompt.
    instruction: String,
    /// Expected action names. Only copied into the scenario metadata under
    /// `"expected_actions"`; the evaluator in this module does not score them.
    expected_actions: Vec<String>,
    /// Reference answer; becomes the scenario's expected value.
    ground_truth: String,
    /// Task domain (the test fixtures use `retail` and `airline`); stored in
    /// the scenario metadata under `"domain"`.
    domain: String,
}
/// Dataset loader for tau-bench style task files.
///
/// The input file must contain a single JSON array of task objects, each with
/// `task_id`, `instruction`, `expected_actions`, `ground_truth` and `domain`.
#[derive(Debug)]
pub struct TauBenchLoader;

impl DatasetLoader for TauBenchLoader {
    fn name(&self) -> &'static str {
        "tau-bench"
    }

    /// Reads `path` and converts every task record into a [`Scenario`].
    ///
    /// `task_id` becomes the scenario id, `instruction` the prompt and
    /// `ground_truth` the expected answer. `domain` and `expected_actions` are
    /// kept in the scenario metadata under keys of the same names.
    ///
    /// # Errors
    ///
    /// - Failure to open the file is propagated with `?`, which converts the
    ///   `std::io::Error` into a [`BenchError`].
    /// - Malformed JSON, or a record missing a required field, yields
    ///   [`BenchError::InvalidFormat`]. Its message is prefixed with the file
    ///   path.
    fn load(&self, path: &Path) -> Result<Vec<Scenario>, BenchError> {
        let file = std::fs::File::open(path)?;
        let reader = BufReader::new(file);
        // serde_json's error reports line/column but not which file it came from.
        // Prefix the path so failures are traceable when many datasets are loaded.
        let tasks: Vec<TauBenchTask> = serde_json::from_reader(reader)
            .map_err(|e| BenchError::InvalidFormat(format!("{}: {e}", path.display())))?;

        let scenarios = tasks
            .into_iter()
            .map(|t| Scenario {
                id: t.task_id,
                prompt: t.instruction,
                expected: t.ground_truth,
                metadata: serde_json::json!({
                    "domain": t.domain,
                    "expected_actions": t.expected_actions,
                }),
            })
            .collect();
        Ok(scenarios)
    }
}
/// Evaluator that scores a tau-bench scenario by checking the agent's final
/// answer against the scenario's expected text.
///
/// Only the final answer is compared. Anything stored in the scenario metadata
/// (such as `expected_actions`) is not consulted.
#[derive(Debug)]
pub struct TauBenchEvaluator;

impl Evaluator for TauBenchEvaluator {
    /// Produces a binary result for `agent_response`.
    ///
    /// When [`exact_match`] accepts the response against `scenario.expected`,
    /// the result has `passed == true` and score `1.0`. Otherwise it has
    /// `passed == false` and score `0.0`. `details` is
    /// `task_completion=true` or `task_completion=false` accordingly.
    fn evaluate(&self, scenario: &Scenario, agent_response: &str) -> EvalResult {
        let passed = exact_match(agent_response, &scenario.expected);
        EvalResult {
            scenario_id: scenario.id.clone(),
            score: if passed { 1.0 } else { 0.0 },
            passed,
            // `bool`'s Display impl renders exactly "true" / "false".
            details: format!("task_completion={passed}"),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const FIXTURE: &str = r#"[
    {
        "task_id": "retail_001",
        "instruction": "Find product X",
        "expected_actions": ["search", "select"],
        "ground_truth": "Product X found",
        "domain": "retail"
    },
    {
        "task_id": "airline_002",
        "instruction": "Book flight to Paris",
        "expected_actions": ["search_flight", "book"],
        "ground_truth": "Flight booked",
        "domain": "airline"
    }
]"#;

    /// A record that omits the required `ground_truth` field.
    const MISSING_GROUND_TRUTH: &str = r#"[
    {
        "task_id": "retail_003",
        "instruction": "Return item Y",
        "expected_actions": ["lookup_order", "return"],
        "domain": "retail"
    }
]"#;

    /// Writes `json` to a fresh temp file and returns the raw loader result.
    fn try_load_from_str(json: &str) -> Result<Vec<Scenario>, BenchError> {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("tau_bench.json");
        std::fs::write(&path, json).unwrap();
        TauBenchLoader.load(&path)
    }

    /// Like `try_load_from_str`, but panics if loading fails.
    fn load_from_str(json: &str) -> Vec<Scenario> {
        try_load_from_str(json).expect("fixture should load")
    }

    #[test]
    fn loader_reports_name() {
        assert_eq!(TauBenchLoader.name(), "tau-bench");
    }

    #[test]
    fn load_parses_scenario_count() {
        assert_eq!(load_from_str(FIXTURE).len(), 2);
    }

    #[test]
    fn load_empty_array_yields_no_scenarios() {
        assert!(load_from_str("[]").is_empty());
    }

    #[test]
    fn load_builds_correct_ids() {
        let scenarios = load_from_str(FIXTURE);
        assert_eq!(scenarios[0].id, "retail_001");
        assert_eq!(scenarios[1].id, "airline_002");
    }

    #[test]
    fn load_maps_prompt_and_expected() {
        let scenarios = load_from_str(FIXTURE);
        assert_eq!(scenarios[0].prompt, "Find product X");
        assert_eq!(scenarios[0].expected, "Product X found");
        assert_eq!(scenarios[1].prompt, "Book flight to Paris");
        assert_eq!(scenarios[1].expected, "Flight booked");
    }

    #[test]
    fn load_stores_domain_in_metadata() {
        let scenarios = load_from_str(FIXTURE);
        assert_eq!(scenarios[0].metadata["domain"], "retail");
        assert_eq!(scenarios[1].metadata["domain"], "airline");
    }

    #[test]
    fn load_stores_expected_actions_in_metadata() {
        let scenarios = load_from_str(FIXTURE);
        assert!(scenarios[0].metadata["expected_actions"].is_array());
        assert_eq!(
            scenarios[0].metadata["expected_actions"],
            serde_json::json!(["search", "select"])
        );
        assert_eq!(
            scenarios[1].metadata["expected_actions"],
            serde_json::json!(["search_flight", "book"])
        );
    }

    #[test]
    fn evaluator_exact_match_passes() {
        let scenarios = load_from_str(FIXTURE);
        let result = TauBenchEvaluator.evaluate(&scenarios[0], "Product X found");
        assert!(result.passed);
        assert!((result.score - 1.0).abs() < f64::EPSILON);
        assert_eq!(result.scenario_id, "retail_001");
    }

    #[test]
    fn evaluator_wrong_answer_fails() {
        let scenarios = load_from_str(FIXTURE);
        let result = TauBenchEvaluator.evaluate(&scenarios[0], "Product not found");
        assert!(!result.passed);
        assert!(result.score < f64::EPSILON);
        assert_eq!(result.scenario_id, "retail_001");
    }

    #[test]
    fn evaluator_details_format() {
        let scenarios = load_from_str(FIXTURE);
        let pass_result = TauBenchEvaluator.evaluate(&scenarios[0], "Product X found");
        assert_eq!(pass_result.details, "task_completion=true");
        let fail_result = TauBenchEvaluator.evaluate(&scenarios[0], "wrong answer");
        assert_eq!(fail_result.details, "task_completion=false");
    }

    #[test]
    fn load_invalid_json_returns_error() {
        // Pin the variant, not just "some error": malformed JSON must surface
        // as InvalidFormat rather than, say, an I/O error.
        assert!(matches!(
            try_load_from_str("not json"),
            Err(BenchError::InvalidFormat(_))
        ));
    }

    #[test]
    fn load_missing_required_field_returns_invalid_format() {
        assert!(matches!(
            try_load_from_str(MISSING_GROUND_TRUTH),
            Err(BenchError::InvalidFormat(_))
        ));
    }

    #[test]
    fn load_missing_file_returns_error() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("does_not_exist.json");
        assert!(TauBenchLoader.load(&path).is_err());
    }
}