use std::fmt::Write as _;
use std::io::BufReader;
use std::path::{Path, PathBuf};
use crate::{
error::BenchError,
scenario::{DatasetLoader, Scenario},
};
#[cfg(test)]
use super::data::EvaluationCriteria;
use super::data::{Domain, Task, UserInstructions};
pub struct DataPaths {
pub tasks_json: PathBuf,
pub db_json: PathBuf,
pub split_tasks_json: PathBuf,
}
impl DataPaths {
#[must_use]
pub fn resolve(root: &Path, domain: Domain) -> Self {
let dir = root.join(domain.as_str());
Self {
tasks_json: dir.join("tasks.json"),
db_json: dir.join("db.json"),
split_tasks_json: dir.join("split_tasks.json"),
}
}
}
pub struct Tau2BenchLoader {
pub domain: Domain,
}
impl Tau2BenchLoader {
#[must_use]
pub fn retail() -> Self {
Self {
domain: Domain::Retail,
}
}
#[must_use]
pub fn airline() -> Self {
Self {
domain: Domain::Airline,
}
}
}
impl DatasetLoader for Tau2BenchLoader {
fn name(&self) -> &'static str {
match self.domain {
Domain::Retail => "tau2-bench-retail",
Domain::Airline => "tau2-bench-airline",
}
}
fn load(&self, path: &Path) -> Result<Vec<Scenario>, BenchError> {
let file = std::fs::File::open(path)
.map_err(|e| BenchError::InvalidFormat(format!("open tasks.json: {e}")))?;
let tasks: Vec<Task> = serde_json::from_reader(BufReader::new(file))
.map_err(|e| BenchError::InvalidFormat(format!("parse tasks.json: {e}")))?;
let loader_name = self.name();
let mut scenarios = Vec::with_capacity(tasks.len());
for task in tasks {
let id = format!("{}_{}", loader_name, task.id);
let prompt = build_prompt(&task);
let metadata = serde_json::json!({
"domain": loader_name,
"tau2_task_id": task.id,
"evaluation_criteria": task.evaluation_criteria,
});
scenarios.push(Scenario::single(id, prompt, "", metadata));
}
Ok(scenarios)
}
}
fn build_prompt(task: &Task) -> String {
match &task.user_scenario.instructions {
UserInstructions::Plain(s) => s.clone(),
UserInstructions::Structured(i) => {
let mut buf = String::new();
writeln!(buf, "You are speaking to a customer support agent.").ok();
writeln!(buf, "\nReason for call:\n{}", i.reason_for_call).ok();
if let Some(known) = &i.known_info {
writeln!(buf, "\nKnown information about you:\n{known}").ok();
}
writeln!(buf, "\nTask instructions:\n{}", i.task_instructions).ok();
buf
}
}
}
pub fn db_json_path(tasks_json: &Path) -> Result<PathBuf, BenchError> {
tasks_json
.parent()
.map(|dir| dir.join("db.json"))
.ok_or_else(|| {
BenchError::InvalidFormat(
"tasks.json must have a parent directory containing db.json".into(),
)
})
}
#[cfg(test)]
mod tests {
use super::*;
const TASKS_FIXTURE: &str = r##"[
{
"id": "0",
"user_scenario": {
"instructions": {
"domain": "retail",
"reason_for_call": "Cancel my order",
"task_instructions": "Cancel order #W0001",
"known_info": "Order id: #W0001"
}
},
"evaluation_criteria": {
"actions": [
{
"action_id": "a1",
"requestor": "assistant",
"name": "cancel_pending_order",
"arguments": {"order_id": "#W0001", "reason": "no_longer_needed"},
"compare_args": ["order_id", "reason"]
}
],
"reward_basis": ["ACTION"]
}
},
{
"id": "1",
"user_scenario": {
"instructions": "Simple plain prompt"
},
"evaluation_criteria": {
"actions": [],
"reward_basis": ["ACTION"]
}
},
{
"id": "2",
"user_scenario": {
"instructions": "DB-only task"
},
"evaluation_criteria": {
"actions": [],
"reward_basis": ["DB"]
}
}
]"##;
fn load_from_str(json: &str, domain: Domain) -> Vec<Scenario> {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("tasks.json");
std::fs::write(&path, json).unwrap();
let loader = Tau2BenchLoader { domain };
loader.load(&path).unwrap()
}
#[test]
fn load_all_tasks_regardless_of_reward_basis() {
let scenarios = load_from_str(TASKS_FIXTURE, Domain::Retail);
assert_eq!(scenarios.len(), 3);
}
#[test]
fn load_builds_correct_ids() {
let scenarios = load_from_str(TASKS_FIXTURE, Domain::Retail);
assert_eq!(scenarios[0].id, "tau2-bench-retail_0");
assert_eq!(scenarios[1].id, "tau2-bench-retail_1");
assert_eq!(scenarios[2].id, "tau2-bench-retail_2");
}
#[test]
fn load_prompt_from_structured_instructions() {
let scenarios = load_from_str(TASKS_FIXTURE, Domain::Retail);
let prompt = scenarios[0].primary_prompt().unwrap();
assert!(prompt.contains("Cancel my order") || prompt.contains("Cancel order"));
}
#[test]
fn load_prompt_from_plain_instructions() {
let scenarios = load_from_str(TASKS_FIXTURE, Domain::Retail);
let prompt = scenarios[1].primary_prompt().unwrap();
assert_eq!(prompt, "Simple plain prompt");
}
#[test]
fn metadata_contains_evaluation_criteria() {
let scenarios = load_from_str(TASKS_FIXTURE, Domain::Retail);
let criteria_value = scenarios[0].metadata.get("evaluation_criteria").unwrap();
let criteria: EvaluationCriteria = serde_json::from_value(criteria_value.clone()).unwrap();
assert_eq!(criteria.actions.len(), 1);
assert_eq!(criteria.actions[0].name, "cancel_pending_order");
}
#[test]
fn metadata_roundtrip_preserves_arguments() {
let scenarios = load_from_str(TASKS_FIXTURE, Domain::Retail);
let criteria_value = scenarios[0].metadata.get("evaluation_criteria").unwrap();
let criteria: EvaluationCriteria = serde_json::from_value(criteria_value.clone()).unwrap();
let arg = criteria.actions[0].arguments.get("order_id").unwrap();
assert_eq!(arg.as_str(), Some("#W0001"));
}
#[test]
fn airline_loader_name() {
let loader = Tau2BenchLoader::airline();
assert_eq!(loader.name(), "tau2-bench-airline");
}
}