use std::path::{Path, PathBuf};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tracing::{debug, info, warn};
use super::{CaseResult, TaskQualityAdapter, TaskQualityResult};
/// Configuration for the τ²-bench task-quality adapter.
///
/// Default values come from the `Default` impl; `Tau2Config::builder()`
/// offers a fluent way to override individual fields.
#[derive(Debug, Clone)]
pub struct Tau2Config {
/// File or directory the scenarios are loaded from.
pub dataset_path: PathBuf,
/// Upper bound on how many loaded scenarios are kept; `None` keeps them all.
pub max_scenarios: Option<usize>,
/// Turn budget per scenario.
/// NOTE(review): the visible code only logs this value in the placeholder
/// agent session — confirm how it is enforced once execution is wired.
pub max_turns_per_scenario: usize,
/// Per-scenario timeout in seconds.
/// NOTE(review): not read anywhere in the visible code.
pub scenario_timeout_secs: u64,
/// Verbose-scoring flag.
/// NOTE(review): not read anywhere in the visible code.
pub verbose_scoring: bool,
}
impl Default for Tau2Config {
    /// Returns the stock configuration: scenarios under
    /// `./tau2-bench/scenarios`, no scenario cap, 20 turns, a 120 second
    /// timeout, and verbose scoring switched off.
    fn default() -> Self {
        let dataset_path = PathBuf::from("./tau2-bench/scenarios");
        Self {
            dataset_path,
            max_scenarios: None,
            max_turns_per_scenario: 20,
            scenario_timeout_secs: 120,
            verbose_scoring: false,
        }
    }
}
impl Tau2Config {
pub fn builder() -> Tau2ConfigBuilder {
Tau2ConfigBuilder::default()
}
}
/// Fluent builder for `Tau2Config`.
///
/// Starts from `Tau2Config::default()` (via the derived `Default`) and lets
/// callers override one field per method call before finishing with `build`.
#[derive(Debug, Clone, Default)]
pub struct Tau2ConfigBuilder {
// Configuration accumulated so far; returned as-is by `build`.
config: Tau2Config,
}
impl Tau2ConfigBuilder {
pub fn dataset_path(mut self, path: impl Into<PathBuf>) -> Self {
self.config.dataset_path = path.into();
self
}
pub fn max_scenarios(mut self, max: usize) -> Self {
self.config.max_scenarios = Some(max);
self
}
pub fn max_turns_per_scenario(mut self, max: usize) -> Self {
self.config.max_turns_per_scenario = max;
self
}
pub fn scenario_timeout_secs(mut self, secs: u64) -> Self {
self.config.scenario_timeout_secs = secs;
self
}
pub fn verbose_scoring(mut self, verbose: bool) -> Self {
self.config.verbose_scoring = verbose;
self
}
pub fn build(self) -> Tau2Config {
self.config
}
}
/// One τ²-bench scenario as stored on disk.
///
/// Field names are (de)serialized in camelCase (for example `userRequest`,
/// `expectedActions`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Tau2Scenario {
/// Scenario identifier; scenarios loaded from a directory are sorted by it.
pub id: String,
/// Free-text description of the scenario.
pub description: String,
/// Domain label; included in the per-scenario debug log.
pub domain: String,
/// Initial context text for the scenario.
pub initial_context: String,
/// The user's request text.
pub user_request: String,
/// Actions declared as available in this scenario.
pub available_actions: Vec<Tau2Action>,
/// Actions the agent is expected to take; scoring compares against these.
pub expected_actions: Vec<Tau2ExpectedAction>,
/// Scoring mode, pass threshold and keyword requirements.
pub success_criteria: Tau2SuccessCriteria,
/// Scenario-level turn limit; falls back to `default_max_turns()` (20) when
/// the field is absent from the JSON.
#[serde(default = "default_max_turns")]
pub max_turns: usize,
}
/// Serde fallback used for `Tau2Scenario::max_turns` when the JSON omits it.
fn default_max_turns() -> usize {
    const FALLBACK_MAX_TURNS: usize = 20;
    FALLBACK_MAX_TURNS
}
/// Declaration of an action listed in a scenario's `available_actions`.
///
/// Field names are (de)serialized in camelCase.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Tau2Action {
/// Action name.
pub name: String,
/// Human-readable description of the action.
pub description: String,
/// Arbitrary JSON describing the action's parameters.
/// NOTE(review): the tests in this file use a JSON-Schema-like object —
/// confirm the intended schema.
pub parameters: serde_json::Value,
/// Side-effect flag; defaults to `false` when absent from the JSON.
#[serde(default)]
pub has_side_effects: bool,
}
/// An action the agent is expected to perform in a scenario.
///
/// Field names are (de)serialized in camelCase.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Tau2ExpectedAction {
/// Name the agent's action must have.
pub action_name: String,
/// Arguments the agent's call must contain; compared with
/// `partial_json_match`, so extra object keys on the agent side are allowed.
pub expected_args: serde_json::Value,
/// Defaults to `true` when absent from the JSON.
/// NOTE(review): the scoring code visible in this file never reads this flag.
#[serde(default = "default_order_matters")]
pub order_matters: bool,
}
/// Serde fallback used for `Tau2ExpectedAction::order_matters` when the JSON
/// omits it.
fn default_order_matters() -> bool {
    const ORDER_MATTERS_BY_DEFAULT: bool = true;
    ORDER_MATTERS_BY_DEFAULT
}
/// Rules that decide how a scenario is scored and when it passes.
///
/// Field names are (de)serialized in camelCase.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Tau2SuccessCriteria {
/// How agent actions are compared with the expected actions.
pub mode: ScoringMode,
/// Minimum score (inclusive) for the scenario to pass; defaults to
/// `default_pass_threshold()` (0.5) when absent from the JSON.
#[serde(default = "default_pass_threshold")]
pub pass_threshold: f64,
/// Required keywords; defaults to an empty list.
/// NOTE(review): the visible scoring code only checks whether this list is
/// empty, and only for scenarios without expected actions.
#[serde(default)]
pub required_keywords: Vec<String>,
}
/// Serde fallback used for `Tau2SuccessCriteria::pass_threshold` when the JSON
/// omits it.
fn default_pass_threshold() -> f64 {
    const FALLBACK_PASS_THRESHOLD: f64 = 0.5;
    FALLBACK_PASS_THRESHOLD
}
/// Strategy used to compare the agent's actions with the expected ones.
///
/// Serialized as the lowercase variant name (`"exact"` / `"partial"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ScoringMode {
/// All-or-nothing: the agent must produce the same number of actions, in the
/// same order, each with a matching name and matching arguments.
Exact,
/// Partial credit: the score is the fraction of expected actions that were
/// matched by the agent's actions.
Partial,
}
/// Response payload with a success flag, JSON data and a message.
///
/// Field names are (de)serialized in camelCase.
/// NOTE(review): nothing in the visible code constructs or reads this type —
/// presumably it is the environment's reply to an agent action; confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Tau2Response {
/// Success flag.
pub success: bool,
/// Arbitrary JSON payload.
pub data: serde_json::Value,
/// Message text.
pub message: String,
/// Scenario-complete flag; defaults to `false` when absent from the JSON.
#[serde(default)]
pub scenario_complete: bool,
}
/// Aggregate report over a set of scenario results, as produced by
/// `Tau2Adapter::generate_report`.
///
/// Field names are (de)serialized in camelCase.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Tau2Report {
/// Suite label; `generate_report` sets it to `"tau2-bench"`.
pub suite: String,
/// Model name the report was generated for.
pub model: String,
/// Number of scenario results in the report.
pub total_scenarios: usize,
/// Number of results whose `passed` flag is set.
pub passed_scenarios: usize,
/// `passed_scenarios / total_scenarios`, or `0.0` when there are no results.
pub accuracy: f64,
/// The individual scenario results.
pub scenarios: Vec<Tau2ScenarioResult>,
}
/// Outcome of running and scoring a single scenario.
///
/// Field names are (de)serialized in camelCase.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Tau2ScenarioResult {
/// Identifier of the scenario this result belongs to.
pub scenario_id: String,
/// Whether the score reached the scenario's pass threshold.
pub passed: bool,
/// Score assigned by the scoring code (`0.0` for execution errors).
pub score: f64,
/// Number of actions the agent took.
pub actions_taken: usize,
/// Number of turns used. In the visible code this is set to the same value
/// as `actions_taken`.
pub turns_used: usize,
/// Why the scenario failed; `None` on success, and then omitted from the
/// serialized output.
#[serde(skip_serializing_if = "Option::is_none")]
pub failure_reason: Option<String>,
}
/// Task-quality adapter that loads τ²-bench scenarios, runs them against a
/// model and scores the resulting actions.
pub struct Tau2Adapter {
// Settings controlling dataset location and scenario limits.
config: Tau2Config,
}
impl Tau2Adapter {
    /// Creates an adapter that runs with the given configuration.
    pub fn new(config: Tau2Config) -> Self {
        Self { config }
    }

    /// Creates an adapter that runs with `Tau2Config::default()`.
    pub fn with_defaults() -> Self {
        Self { config: Tau2Config::default() }
    }

    /// Returns the configuration this adapter was built with.
    pub fn config(&self) -> &Tau2Config {
        &self.config
    }

    /// Loads the scenarios found at `config.dataset_path`, keeping at most
    /// `config.max_scenarios` of them when a cap is configured.
    ///
    /// # Errors
    ///
    /// Returns `BenchError::WorkloadNotFound` when the dataset path does not
    /// exist, and forwards any error from `load_scenarios_from_path`.
    async fn load_scenarios(&self) -> crate::Result<Vec<Tau2Scenario>> {
        let dataset_path = &self.config.dataset_path;
        if !dataset_path.exists() {
            return Err(crate::BenchError::WorkloadNotFound {
                path: dataset_path.display().to_string(),
            });
        }
        info!(path = %dataset_path.display(), "loading τ²-bench scenarios");
        let mut scenarios = load_scenarios_from_path(dataset_path).await?;
        // Drop everything past the cap in place instead of rebuilding the vector.
        if let Some(max) = self.config.max_scenarios {
            scenarios.truncate(max);
        }
        info!(count = scenarios.len(), "loaded τ²-bench scenarios");
        Ok(scenarios)
    }

    /// Runs one scenario against `model`, scores the actions the agent took
    /// and compares the score with the scenario's pass threshold (inclusive).
    ///
    /// # Errors
    ///
    /// Forwards any error returned by the agent session.
    async fn execute_scenario(
        &self,
        scenario: &Tau2Scenario,
        model: &str,
    ) -> crate::Result<Tau2ScenarioResult> {
        debug!(
            scenario_id = %scenario.id,
            domain = %scenario.domain,
            "executing τ²-bench scenario"
        );
        let agent_actions = self.run_agent_session(scenario, model).await?;
        let score = self.score_scenario(scenario, &agent_actions);
        let threshold = scenario.success_criteria.pass_threshold;
        let passed = score >= threshold;
        let failure_reason =
            (!passed).then(|| format!("Score {score:.2} below threshold {:.2}", threshold));
        Ok(Tau2ScenarioResult {
            scenario_id: scenario.id.clone(),
            passed,
            score,
            actions_taken: agent_actions.len(),
            // The session does not report turns separately, so each action
            // is counted as one turn.
            turns_used: agent_actions.len(),
            failure_reason,
        })
    }

    /// Placeholder agent session: logs the request and returns no actions,
    /// because real LLM execution is not wired up yet.
    async fn run_agent_session(
        &self,
        scenario: &Tau2Scenario,
        model: &str,
    ) -> crate::Result<Vec<AgentAction>> {
        debug!(
            model = model,
            scenario_id = %scenario.id,
            max_turns = self.config.max_turns_per_scenario,
            "agent session placeholder — real LLM execution not yet wired"
        );
        Ok(Vec::new())
    }

    /// Scores the agent's actions for `scenario` in the range `0.0..=1.0`.
    ///
    /// A scenario without expected actions scores `1.0` when it also requires
    /// no keywords and `0.0` otherwise. All other scenarios are scored
    /// according to the scenario's `ScoringMode`.
    fn score_scenario(&self, scenario: &Tau2Scenario, agent_actions: &[AgentAction]) -> f64 {
        if scenario.expected_actions.is_empty() {
            return if scenario.success_criteria.required_keywords.is_empty() {
                1.0
            } else {
                0.0
            };
        }
        match scenario.success_criteria.mode {
            ScoringMode::Exact => self.score_exact(scenario, agent_actions),
            ScoringMode::Partial => self.score_partial(scenario, agent_actions),
        }
    }

    /// All-or-nothing scoring: `1.0` only when the agent produced exactly as
    /// many actions as expected and every action, compared position by
    /// position, has the expected name and arguments; `0.0` otherwise.
    ///
    /// The per-action `order_matters` flag is not consulted here; exact mode
    /// always compares in order.
    fn score_exact(&self, scenario: &Tau2Scenario, agent_actions: &[AgentAction]) -> f64 {
        let expected = &scenario.expected_actions;
        if agent_actions.len() != expected.len() {
            return 0.0;
        }
        let every_step_matches = agent_actions.iter().zip(expected).all(|(got, want)| {
            got.name == want.action_name
                && partial_json_match(&want.expected_args, &got.arguments)
        });
        if every_step_matches {
            1.0
        } else {
            0.0
        }
    }

    /// Partial-credit scoring: the fraction of expected actions that can be
    /// paired with a distinct matching agent action, irrespective of order.
    ///
    /// Each agent action is credited to at most one expected action, so an
    /// agent that performs a step once does not receive credit for a scenario
    /// that expects the step twice. The pairing is a maximum matching, which
    /// means a loosely specified expectation never takes the only action
    /// that a stricter expectation could have used.
    fn score_partial(&self, scenario: &Tau2Scenario, agent_actions: &[AgentAction]) -> f64 {
        let expected = &scenario.expected_actions;
        if expected.is_empty() {
            return 1.0;
        }
        // For every expected action, the indices of the agent actions that
        // would satisfy it.
        let candidates: Vec<Vec<usize>> = expected
            .iter()
            .map(|want| {
                agent_actions
                    .iter()
                    .enumerate()
                    .filter(|(_, got)| {
                        got.name == want.action_name
                            && partial_json_match(&want.expected_args, &got.arguments)
                    })
                    .map(|(agent_idx, _)| agent_idx)
                    .collect()
            })
            .collect();
        // claimed_by[i] is the expected action currently paired with agent action i.
        let mut claimed_by: Vec<Option<usize>> = vec![None; agent_actions.len()];
        let mut matched = 0usize;
        for expected_idx in 0..expected.len() {
            let mut visited = vec![false; agent_actions.len()];
            if Self::claim_agent_action(expected_idx, &candidates, &mut claimed_by, &mut visited) {
                matched += 1;
            }
        }
        matched as f64 / expected.len() as f64
    }

    /// Tries to pair expected action `expected_idx` with one of its candidate
    /// agent actions, re-homing an earlier pairing when that frees a
    /// candidate (one augmenting-path step of bipartite matching).
    ///
    /// `visited` marks the agent actions already tried during this attempt so
    /// the search terminates. Returns `true` when a pairing was made.
    fn claim_agent_action(
        expected_idx: usize,
        candidates: &[Vec<usize>],
        claimed_by: &mut [Option<usize>],
        visited: &mut [bool],
    ) -> bool {
        for &agent_idx in &candidates[expected_idx] {
            if visited[agent_idx] {
                continue;
            }
            visited[agent_idx] = true;
            let current_owner = claimed_by[agent_idx];
            let available = match current_owner {
                None => true,
                // Taken: usable only if the current owner can move elsewhere.
                Some(owner) => Self::claim_agent_action(owner, candidates, claimed_by, visited),
            };
            if available {
                claimed_by[agent_idx] = Some(expected_idx);
                return true;
            }
        }
        false
    }

    /// Summarises `results` for `model`: total and passed scenario counts,
    /// accuracy (`passed / total`, or `0.0` when `results` is empty) and a
    /// copy of the individual results.
    pub fn generate_report(&self, model: &str, results: &[Tau2ScenarioResult]) -> Tau2Report {
        let total_scenarios = results.len();
        let passed_scenarios = results.iter().filter(|r| r.passed).count();
        let accuracy = if total_scenarios == 0 {
            0.0
        } else {
            passed_scenarios as f64 / total_scenarios as f64
        };
        Tau2Report {
            suite: "tau2-bench".to_string(),
            model: model.to_string(),
            total_scenarios,
            passed_scenarios,
            accuracy,
            scenarios: results.to_vec(),
        }
    }
}
impl Default for Tau2Adapter {
fn default() -> Self {
Self::with_defaults()
}
}
#[async_trait]
impl TaskQualityAdapter for Tau2Adapter {
    /// Short identifier of this adapter.
    fn name(&self) -> &str {
        "tau2"
    }

    /// Loads the configured scenarios, executes each one against `model` and
    /// aggregates the outcomes into a `TaskQualityResult`.
    ///
    /// A scenario whose execution returns an error is recorded as a failed
    /// case with score `0.0` rather than aborting the run. When no scenarios
    /// are found an empty result with accuracy `0.0` is returned.
    ///
    /// # Errors
    ///
    /// Fails only when the scenarios cannot be loaded.
    async fn run(&self, model: &str) -> crate::Result<TaskQualityResult> {
        info!(model = model, "starting τ²-bench task quality evaluation");
        let scenarios = self.load_scenarios().await?;
        if scenarios.is_empty() {
            warn!("no τ²-bench scenarios found — returning empty result");
            return Ok(TaskQualityResult {
                adapter_name: self.name().to_string(),
                model: model.to_string(),
                total_cases: 0,
                passed_cases: 0,
                accuracy: 0.0,
                cases: Vec::new(),
            });
        }

        let mut outcomes = Vec::with_capacity(scenarios.len());
        for scenario in &scenarios {
            let outcome = self
                .execute_scenario(scenario, model)
                .await
                .unwrap_or_else(|e| {
                    warn!(
                        scenario_id = %scenario.id,
                        error = %e,
                        "scenario execution failed — marking as failed"
                    );
                    Tau2ScenarioResult {
                        scenario_id: scenario.id.clone(),
                        passed: false,
                        score: 0.0,
                        actions_taken: 0,
                        turns_used: 0,
                        failure_reason: Some(format!("Execution error: {e}")),
                    }
                });
            outcomes.push(outcome);
        }

        let report = self.generate_report(model, &outcomes);
        debug!(
            accuracy = report.accuracy,
            passed = report.passed_scenarios,
            total = report.total_scenarios,
            "τ²-bench evaluation complete"
        );

        // The report already carries the totals, so they are reused rather
        // than counted a second time.
        let total_cases = report.total_scenarios;
        let passed_cases = report.passed_scenarios;
        let accuracy = report.accuracy;
        let cases = outcomes
            .into_iter()
            .map(|outcome| CaseResult {
                case_id: outcome.scenario_id,
                passed: outcome.passed,
                score: outcome.score,
                details: outcome.failure_reason,
            })
            .collect();

        Ok(TaskQualityResult {
            adapter_name: self.name().to_string(),
            model: model.to_string(),
            total_cases,
            passed_cases,
            accuracy,
            cases,
        })
    }
}
/// A single action taken by the agent during a scenario, as consumed by the
/// scoring code.
#[derive(Debug, Clone)]
struct AgentAction {
// Name of the invoked action; compared with `Tau2ExpectedAction::action_name`.
name: String,
// JSON arguments of the call; compared with `Tau2ExpectedAction::expected_args`.
arguments: serde_json::Value,
}
/// Loads τ²-bench scenarios from `path`.
///
/// * When `path` is a file, it must contain exactly one scenario as JSON; a
///   parse failure is an error.
/// * When `path` is a directory, every direct entry with a `.json` extension
///   is read. Entries that fail to parse are logged and skipped, directories
///   are ignored, and the resulting scenarios are sorted by `id`.
/// * Any other kind of path yields an empty list.
///
/// # Errors
///
/// Returns `BenchError::Io` (with the offending path in the message) when a
/// scenario file cannot be read, `BenchError::WorkloadValidation` when a
/// single-file dataset does not parse, and forwards directory-listing errors.
async fn load_scenarios_from_path(path: &Path) -> crate::Result<Vec<Tau2Scenario>> {
    let mut scenarios = Vec::new();
    if path.is_file() {
        let content = read_scenario_file(path).await?;
        let scenario: Tau2Scenario =
            serde_json::from_str(&content).map_err(|e| crate::BenchError::WorkloadValidation {
                field: "scenario".to_string(),
                reason: format!("failed to parse τ²-bench scenario {}: {e}", path.display()),
            })?;
        scenarios.push(scenario);
    } else if path.is_dir() {
        let mut entries = tokio::fs::read_dir(path).await?;
        while let Some(entry) = entries.next_entry().await? {
            let entry_path = entry.path();
            if entry_path.extension().and_then(|ext| ext.to_str()) != Some("json") {
                continue;
            }
            // A directory that happens to be named `*.json` is not a scenario
            // file; skip it instead of failing the whole load. `metadata`
            // follows symlinks, so linked scenario files are still read. If
            // the lookup fails, fall through so the read reports the error.
            let is_directory = tokio::fs::metadata(&entry_path)
                .await
                .map(|meta| meta.is_dir())
                .unwrap_or(false);
            if is_directory {
                continue;
            }
            let content = read_scenario_file(&entry_path).await?;
            match serde_json::from_str::<Tau2Scenario>(&content) {
                Ok(scenario) => scenarios.push(scenario),
                Err(e) => {
                    warn!(
                        path = %entry_path.display(),
                        error = %e,
                        "skipping invalid τ²-bench scenario file"
                    );
                }
            }
        }
        scenarios.sort_by(|a, b| a.id.cmp(&b.id));
    }
    Ok(scenarios)
}

/// Reads one scenario file to a string, attaching the file path to any I/O
/// error so failures in directory mode are as traceable as in file mode.
async fn read_scenario_file(path: &Path) -> crate::Result<String> {
    tokio::fs::read_to_string(path).await.map_err(|e| {
        crate::BenchError::Io(std::io::Error::new(
            e.kind(),
            format!("failed to read scenario file {}: {e}", path.display()),
        ))
    })
}
fn partial_json_match(expected: &serde_json::Value, actual: &serde_json::Value) -> bool {
match (expected, actual) {
(serde_json::Value::Object(exp_map), serde_json::Value::Object(act_map)) => {
for (key, exp_value) in exp_map {
match act_map.get(key) {
Some(act_value) => {
if !partial_json_match(exp_value, act_value) {
return false;
}
}
None => return false,
}
}
true
}
(serde_json::Value::Array(exp_arr), serde_json::Value::Array(act_arr)) => {
if exp_arr.len() != act_arr.len() {
return false;
}
exp_arr.iter().zip(act_arr.iter()).all(|(e, a)| partial_json_match(e, a))
}
_ => expected == actual,
}
}
// Unit tests for the configuration builder, JSON argument matching, scenario
// scoring, report generation and scenario (de)serialization.
#[cfg(test)]
mod tests {
use super::*;
// Every builder setter must end up in the built configuration.
#[test]
fn test_tau2_config_builder() {
let config = Tau2Config::builder()
.dataset_path("/tmp/tau2-scenarios")
.max_scenarios(10)
.max_turns_per_scenario(15)
.scenario_timeout_secs(60)
.verbose_scoring(true)
.build();
assert_eq!(config.dataset_path, PathBuf::from("/tmp/tau2-scenarios"));
assert_eq!(config.max_scenarios, Some(10));
assert_eq!(config.max_turns_per_scenario, 15);
assert_eq!(config.scenario_timeout_secs, 60);
assert!(config.verbose_scoring);
}
// Pins the documented default values.
#[test]
fn test_tau2_config_defaults() {
let config = Tau2Config::default();
assert_eq!(config.dataset_path, PathBuf::from("./tau2-bench/scenarios"));
assert_eq!(config.max_scenarios, None);
assert_eq!(config.max_turns_per_scenario, 20);
assert_eq!(config.scenario_timeout_secs, 120);
assert!(!config.verbose_scoring);
}
// The adapter identifies itself as "tau2".
#[test]
fn test_adapter_name() {
let adapter = Tau2Adapter::with_defaults();
assert_eq!(adapter.name(), "tau2");
}
// Extra keys on the actual side do not prevent a match.
#[test]
fn test_partial_json_match_exact() {
let expected = serde_json::json!({"key": "value"});
let actual = serde_json::json!({"key": "value", "extra": 42});
assert!(partial_json_match(&expected, &actual));
}
// A key that is expected but absent from the actual value is a mismatch.
#[test]
fn test_partial_json_match_missing_key() {
let expected = serde_json::json!({"key": "value", "required": true});
let actual = serde_json::json!({"key": "value"});
assert!(!partial_json_match(&expected, &actual));
}
// Subset matching applies recursively to nested objects.
#[test]
fn test_partial_json_match_nested() {
let expected = serde_json::json!({"nested": {"inner": "val"}});
let actual = serde_json::json!({"nested": {"inner": "val", "extra": 1}, "top": true});
assert!(partial_json_match(&expected, &actual));
}
// Arrays must agree element by element.
#[test]
fn test_partial_json_match_array() {
let expected = serde_json::json!([1, 2, 3]);
let actual = serde_json::json!([1, 2, 3]);
assert!(partial_json_match(&expected, &actual));
let actual_diff = serde_json::json!([1, 2, 4]);
assert!(!partial_json_match(&expected, &actual_diff));
}
// Exact mode: both expected actions, in the expected order, score 1.0.
#[test]
fn test_scoring_exact_match() {
let adapter = Tau2Adapter::with_defaults();
let scenario = make_test_scenario(ScoringMode::Exact);
let actions = vec![
AgentAction {
name: "lookup_customer".to_string(),
arguments: serde_json::json!({"customer_id": "C123"}),
},
AgentAction {
name: "update_record".to_string(),
arguments: serde_json::json!({"field": "email", "value": "new@example.com"}),
},
];
let score = adapter.score_scenario(&scenario, &actions);
assert_eq!(score, 1.0);
}
// Exact mode: the right actions in the wrong order score 0.0.
#[test]
fn test_scoring_exact_wrong_order() {
let adapter = Tau2Adapter::with_defaults();
let scenario = make_test_scenario(ScoringMode::Exact);
let actions = vec![
AgentAction {
name: "update_record".to_string(),
arguments: serde_json::json!({"field": "email", "value": "new@example.com"}),
},
AgentAction {
name: "lookup_customer".to_string(),
arguments: serde_json::json!({"customer_id": "C123"}),
},
];
let score = adapter.score_scenario(&scenario, &actions);
assert_eq!(score, 0.0);
}
// Partial mode: one of the two expected actions yields half credit.
#[test]
fn test_scoring_partial_credit() {
let adapter = Tau2Adapter::with_defaults();
let scenario = make_test_scenario(ScoringMode::Partial);
let actions = vec![AgentAction {
name: "lookup_customer".to_string(),
arguments: serde_json::json!({"customer_id": "C123"}),
}];
let score = adapter.score_scenario(&scenario, &actions);
assert_eq!(score, 0.5); }
// Partial mode: no agent actions at all score 0.0.
#[test]
fn test_scoring_empty_actions() {
let adapter = Tau2Adapter::with_defaults();
let scenario = make_test_scenario(ScoringMode::Partial);
let actions: Vec<AgentAction> = Vec::new();
let score = adapter.score_scenario(&scenario, &actions);
assert_eq!(score, 0.0);
}
// The report counts passed scenarios and derives accuracy from them.
#[test]
fn test_generate_report() {
let adapter = Tau2Adapter::with_defaults();
let results = vec![
Tau2ScenarioResult {
scenario_id: "s1".to_string(),
passed: true,
score: 1.0,
actions_taken: 2,
turns_used: 2,
failure_reason: None,
},
Tau2ScenarioResult {
scenario_id: "s2".to_string(),
passed: false,
score: 0.3,
actions_taken: 1,
turns_used: 3,
failure_reason: Some("Score 0.30 below threshold 0.50".to_string()),
},
];
let report = adapter.generate_report("gemini-2.5-flash", &results);
assert_eq!(report.suite, "tau2-bench");
assert_eq!(report.model, "gemini-2.5-flash");
assert_eq!(report.total_scenarios, 2);
assert_eq!(report.passed_scenarios, 1);
assert_eq!(report.accuracy, 0.5);
}
// A scenario survives a serialize/deserialize round trip.
#[test]
fn test_scenario_serialization_roundtrip() {
let scenario = make_test_scenario(ScoringMode::Partial);
let json = serde_json::to_string(&scenario).unwrap();
let deserialized: Tau2Scenario = serde_json::from_str(&json).unwrap();
assert_eq!(deserialized.id, scenario.id);
assert_eq!(deserialized.domain, scenario.domain);
assert_eq!(deserialized.available_actions.len(), 2);
}
// Builds a customer-service scenario with two available actions and two
// expected actions (lookup_customer, then update_record), a 0.5 pass
// threshold and the requested scoring mode.
fn make_test_scenario(mode: ScoringMode) -> Tau2Scenario {
Tau2Scenario {
id: "test-scenario-1".to_string(),
description: "Test customer service scenario".to_string(),
domain: "customer-service".to_string(),
initial_context: "You are a customer service agent.".to_string(),
user_request: "Update my email address.".to_string(),
available_actions: vec![
Tau2Action {
name: "lookup_customer".to_string(),
description: "Look up customer by ID".to_string(),
parameters: serde_json::json!({
"type": "object",
"properties": {
"customer_id": {"type": "string"}
},
"required": ["customer_id"]
}),
has_side_effects: false,
},
Tau2Action {
name: "update_record".to_string(),
description: "Update a customer record field".to_string(),
parameters: serde_json::json!({
"type": "object",
"properties": {
"field": {"type": "string"},
"value": {"type": "string"}
},
"required": ["field", "value"]
}),
has_side_effects: true,
},
],
expected_actions: vec![
Tau2ExpectedAction {
action_name: "lookup_customer".to_string(),
expected_args: serde_json::json!({"customer_id": "C123"}),
order_matters: true,
},
Tau2ExpectedAction {
action_name: "update_record".to_string(),
expected_args: serde_json::json!({"field": "email", "value": "new@example.com"}),
order_matters: true,
},
],
success_criteria: Tau2SuccessCriteria {
mode,
pass_threshold: 0.5,
required_keywords: Vec::new(),
},
max_turns: 10,
}
}
}