// zeph-bench 0.21.2
//
// Benchmark harness for evaluating Zeph agent performance on standardized datasets
// Documentation
// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
// SPDX-License-Identifier: MIT OR Apache-2.0

//! Multi-turn user simulator for tau2-bench.
//!
//! [`MultiTurnDriver`] replaces the single-prompt collapse in the MVP loader with a proper
//! conversation loop. Each turn:
//!
//! 1. The driver prompts an LLM simulator to generate the next user message from
//!    the structured task instructions and conversation history.
//! 2. The generated message is sent to the agent via a caller-provided async callback.
//! 3. The agent response is collected and appended to the transcript.
//! 4. Termination is checked: the simulator outputs `[DONE]` (case-insensitive substring
//!    in its last non-empty line; checked before the message is sent to the agent), the
//!    agent response contains `transfer_to_human_agents`, or `max_turns` is reached.
//!
//! A single retry is attempted on transient LLM simulator failures (any error from the
//! provider). If the retry also fails, the scenario is treated as a hard failure.
//!
//! # Backward compatibility
//!
//! The single-turn `build_prompt` function in the loader module is retained.
//! Callers that do not need multi-turn simulation continue to use
//! [`crate::runner::BenchRunner::run_dataset_with_env_factory`] unchanged.

use std::future::Future;

use zeph_llm::any::AnyProvider;
use zeph_llm::provider::{LlmProvider as _, Message, Role};

use crate::error::BenchError;

use super::data::StructuredUserInstructions;

/// One turn in a multi-turn conversation transcript.
///
/// Roles are recorded from the agent's point of view: the driver stores messages
/// produced by the user simulator as [`Role::User`] and the agent's replies as
/// [`Role::Assistant`].
#[derive(Debug, Clone)]
pub struct Turn {
    /// Role of the speaker.
    pub role: Role,
    /// Message text.
    pub text: String,
}

/// Result of driving a complete multi-turn conversation.
///
/// Produced by [`MultiTurnDriver::drive`].
#[derive(Debug)]
pub struct MultiTurnResult {
    /// Full conversation transcript (alternating user/assistant turns).
    ///
    /// A simulator message carrying the `[DONE]` signal is never recorded here.
    pub transcript: Vec<Turn>,
    /// Text of the final assistant response.
    ///
    /// Empty when the agent was never called (for example `max_turns == 0`, or the
    /// simulator signalled `[DONE]` on its very first message).
    pub final_response: String,
}

/// Drives a multi-turn user scenario through an agent.
///
/// Construct via [`MultiTurnDriver::new`] and call [`drive`][Self::drive] with an
/// async callback that sends a user message to the agent and returns the agent's response.
///
/// The driver itself holds no conversation state; each call to `drive` starts from an
/// empty transcript.
///
/// # Examples
///
/// ```no_run
/// # async fn example() -> Result<(), zeph_bench::BenchError> {
/// use zeph_bench::loaders::tau2_bench::driver::MultiTurnDriver;
/// use zeph_bench::loaders::tau2_bench::data::StructuredUserInstructions;
/// use zeph_llm::{any::AnyProvider, mock::MockProvider};
///
/// let instructions = StructuredUserInstructions {
///     domain: "retail".into(),
///     reason_for_call: "Cancel my order".into(),
///     task_instructions: "Cancel order #W0001".into(),
///     known_info: Some("Order id: #W0001".into()),
///     unknown_info: None,
/// };
/// let provider = AnyProvider::Mock(MockProvider::with_responses(vec![
///     "Hello, I need to cancel order #W0001".into(),
///     "[DONE]".into(),
/// ]));
/// let driver = MultiTurnDriver::new(instructions, provider, 5);
/// let result = driver
///     .drive(|msg| async move { Ok(format!("Order cancelled: {msg}")) })
///     .await?;
/// assert!(!result.transcript.is_empty());
/// # Ok(())
/// # }
/// ```
pub struct MultiTurnDriver {
    /// Scenario description rendered into the simulator's system prompt.
    instructions: StructuredUserInstructions,
    /// LLM provider that plays the simulated customer.
    simulator_provider: AnyProvider,
    /// Upper bound on the number of simulator/agent round-trips per `drive` call.
    max_turns: usize,
}

impl MultiTurnDriver {
    /// Create a new driver.
    ///
    /// `max_turns` caps the conversation length. When reached, the driver returns
    /// the transcript without a `[DONE]` signal — the evaluator scores whatever the
    /// agent did within the turn budget.
    #[must_use]
    pub fn new(
        instructions: StructuredUserInstructions,
        simulator_provider: AnyProvider,
        max_turns: usize,
    ) -> Self {
        Self {
            instructions,
            simulator_provider,
            max_turns,
        }
    }

    /// Drive the conversation until termination.
    ///
    /// `agent_turn` is called once per turn with the user message generated by the
    /// simulator. It should send the message to the agent and return the agent's response.
    ///
    /// Termination conditions (checked in order):
    /// 1. Simulator outputs `[DONE]` (case-insensitive) anywhere in its last non-empty
    ///    line. That message is not forwarded to the agent and not recorded.
    /// 2. Agent response contains `transfer_to_human_agents` (signals end of session).
    /// 3. `max_turns` is reached.
    ///
    /// The returned [`MultiTurnResult::final_response`] is the text of the last agent
    /// reply, or an empty string when the agent was never called.
    ///
    /// # Errors
    ///
    /// Returns [`BenchError`] if the simulator LLM fails on both the initial attempt
    /// and the single retry, or if `agent_turn` returns an error.
    pub async fn drive<F, Fut>(&self, mut agent_turn: F) -> Result<MultiTurnResult, BenchError>
    where
        F: FnMut(String) -> Fut,
        Fut: Future<Output = Result<String, BenchError>>,
    {
        let mut transcript: Vec<Turn> = Vec::new();

        for _ in 0..self.max_turns {
            let user_msg = self.generate_user_message(&transcript).await?;

            if is_done_signal(&user_msg) {
                break;
            }

            transcript.push(Turn {
                role: Role::User,
                text: user_msg.clone(),
            });

            let agent_response = agent_turn(user_msg).await?;
            // Evaluate the hand-off marker before the response is moved into the transcript.
            let transferred = agent_response.contains("transfer_to_human_agents");

            transcript.push(Turn {
                role: Role::Assistant,
                text: agent_response,
            });

            if transferred {
                break;
            }
        }

        // Turns are only ever pushed as (user, assistant) pairs, so a non-empty
        // transcript always ends with the most recent agent reply.
        let final_response = transcript
            .last()
            .map(|turn| turn.text.clone())
            .unwrap_or_default();

        Ok(MultiTurnResult {
            transcript,
            final_response,
        })
    }

    /// Call the simulator LLM to generate the next user message.
    ///
    /// Retries once on any error before propagating. When the retry fails too, the
    /// returned error message carries both the retry error and the first error so the
    /// original cause is not lost.
    async fn generate_user_message(&self, history: &[Turn]) -> Result<String, BenchError> {
        let messages = self.build_simulator_messages(history);

        let first_err = match self.call_simulator(&messages).await {
            Ok(msg) => return Ok(msg),
            Err(e) => e,
        };

        // Single retry on transient failure.
        self.call_simulator(&messages).await.map_err(|retry_err| {
            BenchError::InvalidFormat(format!(
                "simulator LLM failed after retry: {retry_err} (first attempt: {first_err})"
            ))
        })
    }

    /// Send `messages` to the simulator provider and return its text reply.
    ///
    /// Any provider error is wrapped in [`BenchError::InvalidFormat`].
    async fn call_simulator(&self, messages: &[Message]) -> Result<String, BenchError> {
        self.simulator_provider
            .chat(messages)
            .await
            .map_err(|e| BenchError::InvalidFormat(format!("simulator LLM error: {e}")))
    }

    /// Build the chat request that asks the simulator for the next customer message.
    ///
    /// Layout: a system prompt describing the scenario, a fixed kickoff user message,
    /// then the transcript so far with roles inverted. The transcript is stored from
    /// the agent's point of view (customer = `User`, agent = `Assistant`), but the
    /// simulator *is* the customer: its own earlier messages must appear as
    /// `Assistant` turns and the agent's replies as `User` turns. Without the
    /// inversion the request would end on an `Assistant` message, which chat
    /// providers either treat as a prefix to continue or reject outright.
    fn build_simulator_messages(&self, history: &[Turn]) -> Vec<Message> {
        let i = &self.instructions;
        let mut system = format!(
            "You are a customer calling customer support for the {} domain.\n\
             Reason for call: {}\n\
             Task instructions: {}",
            i.domain, i.reason_for_call, i.task_instructions
        );
        if let Some(known) = &i.known_info {
            system.push_str("\nInformation you have: ");
            system.push_str(known);
        }
        // NOTE(review): `unknown_info` is not forwarded to the simulator — confirm
        // whether that omission is intentional.
        system.push_str(
            "\n\nGenerate the next customer message based on the conversation so far. \
             When the task is complete or the agent has resolved your request, output exactly: [DONE]",
        );

        let mut messages = Vec::with_capacity(history.len() + 2);
        messages.push(Message::from_legacy(Role::System, system));
        // Always lead with the kickoff so the conversation starts on a user turn and
        // strictly alternates once the inverted history is appended.
        messages.push(Message::from_legacy(
            Role::User,
            "Generate the opening customer message.",
        ));

        for turn in history {
            let simulator_role = match turn.role {
                Role::User => Role::Assistant,
                Role::Assistant => Role::User,
                other => other,
            };
            messages.push(Message::from_legacy(simulator_role, turn.text.clone()));
        }

        messages
    }
}

/// Report whether the simulator asked to end the conversation.
///
/// Only the last line that is non-empty after trimming is inspected; the signal is
/// an ASCII case-insensitive `[DONE]` occurring anywhere within that line. Text with
/// no non-empty line never signals completion.
fn is_done_signal(text: &str) -> bool {
    // Walk backwards so trailing blank lines are skipped and the first line with
    // content decides the outcome.
    for raw_line in text.lines().rev() {
        let candidate = raw_line.trim();
        if candidate.is_empty() {
            continue;
        }
        return candidate.to_ascii_lowercase().contains("[done]");
    }
    false
}

#[cfg(test)]
mod tests {
    use zeph_llm::{any::AnyProvider, mock::MockProvider};

    use super::*;

    /// Assemble scenario instructions for a test; `unknown_info` is always left unset.
    fn scenario(
        domain: &'static str,
        reason_for_call: &'static str,
        task_instructions: &'static str,
        known_info: Option<&'static str>,
    ) -> StructuredUserInstructions {
        StructuredUserInstructions {
            domain: domain.into(),
            reason_for_call: reason_for_call.into(),
            task_instructions: task_instructions.into(),
            known_info: known_info.map(Into::into),
            unknown_info: None,
        }
    }

    // --- is_done_signal -------------------------------------------------------

    #[test]
    fn done_signal_exact() {
        assert!(is_done_signal("[DONE]"));
    }

    #[test]
    fn done_signal_case_insensitive() {
        assert!(is_done_signal("[done]"));
        assert!(is_done_signal("[Done]"));
    }

    #[test]
    fn done_signal_substring_last_line() {
        assert!(is_done_signal("Thank you for your help! [DONE]"));
    }

    #[test]
    fn done_signal_only_in_last_line() {
        assert!(is_done_signal("first line\n[DONE]"));
    }

    #[test]
    fn done_signal_not_in_last_line_returns_false() {
        // The marker sits on the first line only; the last line decides, so this is false.
        assert!(!is_done_signal("[DONE]\nthis comes after"));
    }

    #[test]
    fn done_signal_not_present() {
        assert!(!is_done_signal("Hello, I need help with my order."));
    }

    #[test]
    fn done_signal_empty_string() {
        assert!(!is_done_signal(""));
    }

    // --- MultiTurnDriver::drive -----------------------------------------------

    #[tokio::test]
    async fn drive_terminates_on_done_signal() {
        // The simulator speaks once, then signals completion on its second call.
        let provider = AnyProvider::Mock(MockProvider::with_responses(vec![
            "I need to cancel order #W0001".into(),
            "[DONE]".into(),
        ]));
        let driver = MultiTurnDriver::new(
            scenario(
                "retail",
                "Cancel order",
                "Cancel order #W0001",
                Some("Order id: #W0001"),
            ),
            provider,
            10,
        );

        let result = driver
            .drive(|_msg| async move { Ok("Order cancelled successfully.".into()) })
            .await
            .unwrap();

        // Exactly one user/assistant pair was recorded before the [DONE] signal.
        assert_eq!(result.transcript.len(), 2);
        assert_eq!(result.final_response, "Order cancelled successfully.");
    }

    #[tokio::test]
    async fn drive_terminates_on_max_turns() {
        // The simulator never signals completion; only the turn budget stops the loop.
        let provider = AnyProvider::Mock(MockProvider::with_responses(vec![
            "Keep going".into(),
            "Keep going".into(),
            "Keep going".into(),
        ]));
        let driver =
            MultiTurnDriver::new(scenario("retail", "Test", "Do nothing", None), provider, 2);

        let result = driver
            .drive(|_msg| async move { Ok("Agent response".into()) })
            .await
            .unwrap();

        // A budget of 2 turns yields 2 user + 2 assistant entries.
        assert_eq!(result.transcript.len(), 4);
    }

    #[tokio::test]
    async fn drive_terminates_on_human_transfer() {
        let provider = AnyProvider::Mock(MockProvider::with_responses(vec![
            "Please escalate my issue".into(),
        ]));
        let driver = MultiTurnDriver::new(
            scenario("airline", "Escalate", "Escalate to human", None),
            provider,
            10,
        );

        let result = driver
            .drive(|_msg| async move {
                Ok("I will transfer you now. transfer_to_human_agents called.".into())
            })
            .await
            .unwrap();

        // A single round-trip, after which the hand-off marker ends the session.
        assert_eq!(result.transcript.len(), 2);
    }
}