use std::collections::HashSet;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Instant;
use zeph_common::timestamp;
use zeph_core::agent::Agent;
use zeph_core::instructions::InstructionBlock;
use zeph_llm::any::AnyProvider;
use zeph_llm::provider::LlmProvider as _;
use zeph_memory::semantic::SemanticMemory;
use zeph_skills::registry::SkillRegistry;
use zeph_tools::executor::{ToolError, ToolExecutor, ToolOutput};
use crate::channel::BenchmarkChannel;
use crate::error::BenchError;
use crate::loaders::tau2_bench::{ActionTrace, TauBenchEvaluator};
use crate::results::{BenchRun, RunStatus, ScenarioResult};
use crate::scenario::{DatasetLoader, Evaluator, Scenario};
/// Selects the system prompt given to the agent and how its final reply is treated.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ResponseMode {
    /// Prompt for the shortest possible answer; the final reply is normalised by
    /// `post_process_response` before it is returned.
    TerseAnswer,
    /// Prompt the agent to act through its tools; the final reply is returned unchanged.
    ToolUse,
}
/// Whether a scenario run attaches semantic memory to the agent.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum MemoryMode {
    /// Run without memory (the default).
    #[default]
    Off,
    /// Attach memory, provided the runner was also given memory parameters.
    On,
}
/// Settings needed to give each scenario its own SQLite-backed memory store.
#[derive(Clone, Debug)]
pub struct BenchMemoryParams {
    /// Directory under which the per-scenario database files are created.
    pub data_dir: PathBuf,
    /// Embedding model name handed to the memory backend.
    pub embedding_model: String,
    /// Run identifier; becomes part of every per-scenario database file name.
    pub run_id: String,
    /// Dataset name. Not read anywhere in this file.
    pub dataset: String,
}
/// Per-run knobs controlling which scenarios execute and whether memory is attached.
#[derive(Default, Debug)]
pub struct RunOptions {
    /// When set, only the scenario whose id equals this value is run.
    pub scenario_filter: Option<String>,
    /// Ids of scenarios to skip (for example ones finished by an earlier run).
    pub completed_ids: HashSet<String>,
    /// Memory setting forwarded to every scenario run.
    pub memory_mode: MemoryMode,
}
/// Tool executor used when a scenario needs no tools: it never produces any output.
struct NoopExecutor;

impl ToolExecutor for NoopExecutor {
    /// Ignores the given text and reports that no tool output is available.
    async fn execute(&self, _ignored: &str) -> Result<Option<ToolOutput>, ToolError> {
        let no_output = None;
        Ok(no_output)
    }
}
/// Drives benchmark scenarios against a single LLM provider.
///
/// Built with `BenchRunner::new`; memory parameters are optional and are added with
/// `with_memory_params`.
pub struct BenchRunner {
// Provider cloned into every agent that is built for a scenario.
provider: AnyProvider,
// Memory settings; `None` means scenarios run without memory even when `MemoryMode::On`
// is requested.
memory_params: Option<BenchMemoryParams>,
}
impl BenchRunner {
/// Creates a runner for `provider` with no memory parameters configured.
#[must_use]
pub fn new(provider: AnyProvider) -> Self {
    let memory_params = None;
    Self {
        provider,
        memory_params,
    }
}
#[must_use]
pub fn with_memory_params(mut self, params: BenchMemoryParams) -> Self {
self.memory_params = Some(params);
self
}
/// Loads the dataset at `path`, runs every selected scenario in terse-answer mode and
/// scores each reply with `evaluator`.
///
/// Scenarios are selected by `filter_scenarios` using `opts`. For each one the elapsed
/// wall-clock time (milliseconds, saturating at `u64::MAX`) and the first 200 characters of
/// the reply are recorded, and the aggregate is recomputed after every result.
///
/// The returned run still has `status` set to `RunStatus::Running` and an empty
/// `finished_at`; presumably the caller finalises both — verify against callers.
///
/// NOTE(review): the span guards created with `.entered()` stay alive across `.await`
/// points. The `tracing` documentation recommends `Instrument` for async code — confirm
/// this is acceptable for how the runner is driven.
///
/// # Errors
///
/// Propagates loader errors, the "scenario not found" error from `filter_scenarios`, and
/// the first error returned by a scenario run (which aborts the whole run).
pub async fn run_dataset<L, E>(
    &self,
    loader: &L,
    evaluator: &E,
    path: &Path,
    opts: RunOptions,
) -> Result<BenchRun, BenchError>
where
    L: DatasetLoader,
    E: Evaluator,
{
    let scenarios = loader.load(path)?;
    let dataset_name = loader.name();
    let selected = filter_scenarios(&scenarios, &opts, dataset_name)?;
    let _dataset_span = tracing::info_span!(
        "bench.run_dataset",
        dataset = dataset_name,
        scenarios = selected.len(),
    )
    .entered();

    let mut run = BenchRun {
        dataset: dataset_name.to_owned(),
        model: self.provider.model_identifier().to_owned(),
        run_id: uuid(),
        started_at: timestamp::utc_now_rfc3339(),
        finished_at: String::new(),
        status: RunStatus::Running,
        results: Vec::new(),
        aggregate: crate::results::Aggregate::default(),
    };

    for scenario in selected {
        let _scenario_span = tracing::info_span!("bench.scenario", id = %scenario.id).entered();
        let started = Instant::now();
        let pending = self.run_one(scenario, opts.memory_mode);
        let answer = Box::pin(pending).await?;
        let elapsed_ms = u64::try_from(started.elapsed().as_millis()).unwrap_or(u64::MAX);
        let verdict = evaluator.evaluate(scenario, &answer);
        // Only a short prefix of the reply is kept in the stored result.
        let response_excerpt: String = answer.chars().take(200).collect();
        let outcome = ScenarioResult {
            scenario_id: scenario.id.clone(),
            score: verdict.score,
            response_excerpt,
            error: None,
            elapsed_ms,
        };
        run.results.push(outcome);
        run.recompute_aggregate();
    }

    Ok(run)
}
/// Loads the dataset at `path` and runs every selected scenario in tool-use mode, with a
/// fresh tool environment per scenario.
///
/// For each scenario `env_factory` supplies the executor handed to the agent and the
/// `ActionTrace` from which a `TauBenchEvaluator` is built; that evaluator scores the reply.
/// Timing, the 200-character excerpt and aggregate recomputation work as in `run_dataset`.
///
/// The returned run still has `status` set to `RunStatus::Running` and an empty
/// `finished_at`; presumably the caller finalises both — verify against callers.
///
/// # Errors
///
/// Propagates loader errors, the "scenario not found" error from `filter_scenarios`, and
/// the first error from `env_factory`, evaluator construction, or a scenario run.
pub async fn run_dataset_with_env_factory<L, F, X>(
    &self,
    loader: &L,
    env_factory: F,
    path: &Path,
    opts: RunOptions,
) -> Result<BenchRun, BenchError>
where
    L: DatasetLoader,
    F: Fn(&Scenario) -> Result<(X, ActionTrace), BenchError>,
    X: ToolExecutor + Send + Sync + 'static,
{
    let scenarios = loader.load(path)?;
    let dataset_name = loader.name();
    let selected = filter_scenarios(&scenarios, &opts, dataset_name)?;
    let _dataset_span = tracing::info_span!(
        "bench.run_dataset_with_env_factory",
        dataset = dataset_name,
        scenarios = selected.len(),
    )
    .entered();

    let mut run = BenchRun {
        dataset: dataset_name.to_owned(),
        model: self.provider.model_identifier().to_owned(),
        run_id: uuid(),
        started_at: timestamp::utc_now_rfc3339(),
        finished_at: String::new(),
        status: RunStatus::Running,
        results: Vec::new(),
        aggregate: crate::results::Aggregate::default(),
    };

    for scenario in selected {
        let _scenario_span = tracing::info_span!("bench.scenario", id = %scenario.id).entered();
        let (executor, trace) = env_factory(scenario)?;
        let evaluator = TauBenchEvaluator::from_scenario(scenario, trace)?;
        // The clock starts only after the environment and evaluator exist.
        let started = Instant::now();
        let pending = self.run_one_with_executor(
            scenario,
            executor,
            opts.memory_mode,
            ResponseMode::ToolUse,
        );
        let answer = Box::pin(pending).await?;
        let elapsed_ms = u64::try_from(started.elapsed().as_millis()).unwrap_or(u64::MAX);
        let verdict = evaluator.evaluate(scenario, &answer);
        let response_excerpt: String = answer.chars().take(200).collect();
        let outcome = ScenarioResult {
            scenario_id: scenario.id.clone(),
            score: verdict.score,
            response_excerpt,
            error: None,
            elapsed_ms,
        };
        run.results.push(outcome);
        run.recompute_aggregate();
    }

    Ok(run)
}
/// Runs `scenario` without any tools, in terse-answer mode.
///
/// Delegates to `run_one_with_executor` with a `NoopExecutor`.
async fn run_one(
    &self,
    scenario: &Scenario,
    memory_mode: MemoryMode,
) -> Result<String, BenchError> {
    let executor = NoopExecutor;
    let mode = ResponseMode::TerseAnswer;
    let pending = self.run_one_with_executor(scenario, executor, memory_mode, mode);
    Box::pin(pending).await
}
/// Runs a single scenario through a freshly built agent and returns its final reply.
///
/// The scenario's turns are replayed through a `BenchmarkChannel`. The agent gets an empty
/// skill registry, the given `executor`, and one system-prompt instruction block chosen by
/// `mode`. When `memory_mode` is `MemoryMode::On` and memory parameters were configured, a
/// per-scenario SQLite-backed `SemanticMemory` is attached; its database file is removed
/// (best effort) once the scenario is done or when memory setup fails.
///
/// The reply is the text of the last response captured by the channel, or an empty string
/// when there is none. In `ResponseMode::TerseAnswer` it is normalised with
/// `post_process_response`; in `ResponseMode::ToolUse` it is returned unchanged.
///
/// # Errors
///
/// Returns `BenchError::InvalidFormat` when the scenario has no user turn, when memory
/// initialisation fails or takes longer than 10 seconds, or when `create_conversation` fails.
// One linear setup -> run -> teardown sequence; kept in a single function on purpose.
#[allow(clippy::too_many_lines)]
async fn run_one_with_executor<X: ToolExecutor + Send + Sync + 'static>(
    &self,
    scenario: &Scenario,
    executor: X,
    memory_mode: MemoryMode,
    mode: ResponseMode,
) -> Result<String, BenchError> {
    let _span = tracing::info_span!(
        "bench.run_one",
        scenario_id = %scenario.id,
        mode = ?mode,
    )
    .entered();
    let channel = BenchmarkChannel::from_turns(scenario.turns.clone());
    if channel.total() == 0 {
        return Err(BenchError::InvalidFormat(format!(
            "scenario '{}' has no user turn",
            scenario.id
        )));
    }
    let registry = SkillRegistry::empty();
    let system_content = match mode {
        ResponseMode::TerseAnswer => concat!(
            "You are an evaluation assistant. ",
            "Answer every question with the shortest possible response. ",
            "Give only the final answer — no explanation, no full sentences, ",
            "no punctuation unless it is part of the answer. ",
            "If the answer is a single word or number, respond with only that word or number."
        ),
        ResponseMode::ToolUse => concat!(
            "You are a customer-service agent. ",
            "Use the available tools to help the user. ",
            "Always call a tool when one applies; do not ask the user to perform actions you can perform yourself. ",
            "When you have completed the user's request, respond with a brief confirmation."
        ),
    };
    let blocks = vec![InstructionBlock {
        source: PathBuf::from("<bench-system-prompt>"),
        content: system_content.to_owned(),
    }];
    let base_agent = Agent::new(self.provider.clone(), channel, registry, None, 1, executor)
        .with_instruction_blocks(blocks);
    let (mut agent, scenario_db) = if memory_mode == MemoryMode::On
        && let Some(ref params) = self.memory_params
    {
        // One database file per (run, scenario) so scenarios never share memory.
        let scenario_db = params
            .data_dir
            .join(format!("bench-{}-{}.db", params.run_id, scenario.id));
        debug_assert!(
            scenario_db.to_string_lossy().contains("bench-"),
            "NFR-001: bench SQLite path must be namespaced with 'bench-'"
        );
        tracing::debug!(
            scenario_id = %scenario.id,
            path = %scenario_db.display(),
            "bench: memory init start"
        );
        // Memory setup runs inside its own async block so that every failure funnels through
        // one place, where the (possibly already created) database file can be cleaned up.
        let init = async {
            let memory = Arc::new(
                tokio::time::timeout(
                    std::time::Duration::from_secs(10),
                    SemanticMemory::with_sqlite_backend(
                        scenario_db.to_string_lossy().as_ref(),
                        self.provider.clone(),
                        &params.embedding_model,
                        // NOTE(review): meaning of the two weights is not visible here.
                        0.7,
                        0.3,
                    ),
                )
                .await
                .map_err(|_| {
                    BenchError::InvalidFormat(format!(
                        "SemanticMemory init timed out for scenario '{}'",
                        scenario.id
                    ))
                })?
                .map_err(|e| BenchError::InvalidFormat(format!("SemanticMemory init: {e}")))?,
            );
            tracing::debug!(scenario_id = %scenario.id, "bench: memory init done");
            let conv_id = memory
                .sqlite()
                .create_conversation()
                .await
                .map_err(|e| BenchError::InvalidFormat(format!("create_conversation: {e}")))?;
            Ok::<_, BenchError>((memory, conv_id))
        }
        .await;
        let (memory, conv_id) = match init {
            Ok(parts) => parts,
            Err(err) => {
                // Best effort: the file may not exist yet, so the result is ignored.
                let _ = std::fs::remove_file(&scenario_db);
                return Err(err);
            }
        };
        let wired_agent = base_agent.with_memory(memory, conv_id, 200, 20, 100_000);
        (wired_agent, Some(scenario_db))
    } else {
        (base_agent, None)
    };
    // The outcome of the run is discarded: whatever responses the channel captured are used
    // below regardless. NOTE(review): confirm that dropping this result is intended.
    let _ = Box::pin(agent.run()).await;
    let channel = agent.into_channel();
    tracing::debug!(
        count = channel.tool_outputs().len(),
        "bench: tool outputs captured"
    );
    let responses = channel.into_responses();
    if let Some(ref db_path) = scenario_db {
        // Best-effort removal of the per-scenario database file.
        let _ = std::fs::remove_file(db_path);
    }
    let raw = responses
        .into_iter()
        .last()
        .map(|r| r.text)
        .unwrap_or_default();
    Ok(match mode {
        ResponseMode::TerseAnswer => post_process_response(&raw),
        ResponseMode::ToolUse => raw,
    })
}
}
/// Picks the scenarios that should run, in dataset order.
///
/// A scenario is selected when its id is not in `opts.completed_ids` and, if
/// `opts.scenario_filter` is set, its id equals the filter. The result can be empty when the
/// filtered scenario has already been completed.
///
/// # Errors
///
/// Returns `BenchError::InvalidFormat` when a filter is set but no scenario in the dataset
/// has that id; `loader_name` is only used in that message.
fn filter_scenarios<'a>(
    scenarios: &'a [Scenario],
    opts: &RunOptions,
    loader_name: &str,
) -> Result<Vec<&'a Scenario>, BenchError> {
    let wanted = opts.scenario_filter.as_ref();

    // An unknown filter id is an error even if completed ids would have hidden it anyway.
    if let Some(filter) = wanted {
        let known = scenarios.iter().any(|s| s.id == *filter);
        if !known {
            return Err(BenchError::InvalidFormat(format!(
                "scenario '{filter}' not found in dataset '{loader_name}'"
            )));
        }
    }

    let selected = scenarios
        .iter()
        .filter(|s| !opts.completed_ids.contains(&s.id))
        .filter(|s| match wanted {
            Some(filter) => s.id == *filter,
            None => true,
        })
        .collect();
    Ok(selected)
}
/// Reduces a model reply to a bare answer.
///
/// Takes the first line that is non-empty after trimming, strips `*`, `_`, backticks, spaces
/// and tabs from both ends, removes every remaining `**` pair and backtick, and trims
/// surrounding whitespace. Returns an empty string when `raw` has no non-empty line.
fn post_process_response(raw: &str) -> String {
    // Characters stripped from both ends of the chosen line.
    const EDGE_NOISE: &[char] = &['*', '_', '`', ' ', '\t'];

    for line in raw.lines() {
        let candidate = line.trim();
        if candidate.is_empty() {
            continue;
        }
        let core = candidate.trim_matches(EDGE_NOISE);
        let without_bold = core.replace("**", "");
        let without_ticks = without_bold.replace('`', "");
        return without_ticks.trim().to_owned();
    }
    String::new()
}
/// Builds a run identifier of the form `bench-<secs>-<nanos>` from the current wall-clock
/// time, with both parts in lowercase hexadecimal.
///
/// Despite the name this is a timestamp, not an RFC 4122 UUID. If the system clock is before
/// the Unix epoch, both parts are zero.
fn uuid() -> String {
    use std::time::{Duration, SystemTime, UNIX_EPOCH};

    let since_epoch = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        Err(_) => Duration::default(),
    };
    let secs = since_epoch.as_secs();
    let nanos = since_epoch.subsec_nanos();
    format!("bench-{secs:x}-{nanos:x}")
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Default options select every scenario and leave memory off.
    #[test]
    fn run_options_default_is_empty() {
        let opts = RunOptions::default();
        assert!(opts.scenario_filter.is_none());
        assert!(opts.completed_ids.is_empty());
        assert_eq!(opts.memory_mode, MemoryMode::Off);
    }

    #[test]
    fn memory_mode_default_is_off() {
        assert_eq!(MemoryMode::default(), MemoryMode::Off);
    }

    /// The builder stores the memory parameters it is given.
    #[test]
    fn with_memory_params_sets_isolation() {
        use zeph_llm::{any::AnyProvider, mock::MockProvider};
        let provider = AnyProvider::Mock(MockProvider::with_responses(vec![]));
        let params = BenchMemoryParams {
            data_dir: std::path::PathBuf::from("/tmp/bench-data"),
            embedding_model: "nomic-embed-text".into(),
            run_id: "bench-abc".into(),
            dataset: "locomo".into(),
        };
        // `params` is not used again, so it is moved rather than cloned.
        let runner = BenchRunner::new(provider).with_memory_params(params);
        assert!(runner.memory_params.is_some());
        let stored = runner.memory_params.unwrap();
        assert_eq!(stored.run_id, "bench-abc");
        assert_eq!(stored.dataset, "locomo");
    }

    /// The per-scenario database file name carries the `bench-` prefix.
    #[test]
    fn nfr_001_sqlite_path_namespaced() {
        let params = BenchMemoryParams {
            data_dir: std::path::PathBuf::from("/tmp/bench-data"),
            embedding_model: "nomic-embed-text".into(),
            run_id: "run-xyz".into(),
            dataset: "locomo".into(),
        };
        let scenario_id = "s1_0";
        let scenario_db = params
            .data_dir
            .join(format!("bench-{}-{}.db", params.run_id, scenario_id));
        // Check the file name itself: the directory already contains "bench-", so a
        // whole-path substring test would pass however the file is named.
        let file_name = scenario_db
            .file_name()
            .map(|name| name.to_string_lossy().into_owned())
            .unwrap_or_default();
        assert!(
            file_name.starts_with("bench-"),
            "NFR-001: SQLite path must contain bench- prefix"
        );
        assert_eq!(file_name, "bench-run-xyz-s1_0.db");
    }

    #[test]
    fn now_rfc3339_has_correct_format() {
        let ts = timestamp::utc_now_rfc3339();
        assert_eq!(ts.len(), 20);
        assert!(ts.ends_with('Z'));
        assert!(ts.contains('T'));
    }

    #[test]
    fn uuid_generates_non_empty_string() {
        let id = uuid();
        assert!(id.starts_with("bench-"));
        assert!(id.len() > 10);
    }

    #[test]
    fn post_process_takes_first_line() {
        let raw = "1945\n\nWorld War II ended in 1945.";
        assert_eq!(post_process_response(raw), "1945");
    }

    #[test]
    fn post_process_strips_markdown_bold() {
        assert_eq!(post_process_response("**1945**"), "1945");
    }

    #[test]
    fn post_process_strips_backticks() {
        assert_eq!(post_process_response("`Au`"), "Au");
    }

    #[test]
    fn post_process_trims_whitespace() {
        assert_eq!(post_process_response(" Paris "), "Paris");
    }

    #[test]
    fn post_process_empty_input_returns_empty() {
        assert_eq!(post_process_response(""), "");
    }

    #[test]
    fn post_process_skips_empty_leading_lines() {
        let raw = "\n\n \nParis";
        assert_eq!(post_process_response(raw), "Paris");
    }
}