use std::future::Future;
use zeph_llm::any::AnyProvider;
use zeph_llm::provider::{LlmProvider as _, Message, Role};
use crate::error::BenchError;
use super::data::StructuredUserInstructions;
/// One message in a simulated conversation, tagged with the role that produced it.
#[derive(Debug, Clone)]
pub struct Turn {
/// Speaker of this message.
pub role: Role,
/// Message text.
pub text: String,
}
/// Outcome of a multi-turn run: the recorded conversation plus the last agent reply.
#[derive(Debug)]
pub struct MultiTurnResult {
/// Recorded turns, in the order they were produced.
pub transcript: Vec<Turn>,
/// Text of the most recent agent reply; empty when the agent never replied.
pub final_response: String,
}
/// Drives a conversation between an LLM-simulated user and an agent callback.
pub struct MultiTurnDriver {
// Scenario fields used to build the simulator's system prompt.
instructions: StructuredUserInstructions,
// Provider asked to produce each simulated user message.
simulator_provider: AnyProvider,
// Upper bound on the number of user/agent exchanges in one run.
max_turns: usize,
}
impl MultiTurnDriver {
#[must_use]
pub fn new(
instructions: StructuredUserInstructions,
simulator_provider: AnyProvider,
max_turns: usize,
) -> Self {
Self {
instructions,
simulator_provider,
max_turns,
}
}
pub async fn drive<F, Fut>(&self, mut agent_turn: F) -> Result<MultiTurnResult, BenchError>
where
F: FnMut(String) -> Fut,
Fut: Future<Output = Result<String, BenchError>>,
{
let mut transcript: Vec<Turn> = Vec::new();
let mut final_response = String::new();
for turn_idx in 0..self.max_turns {
let user_msg = self.generate_user_message(&transcript, turn_idx).await?;
if is_done_signal(&user_msg) {
break;
}
transcript.push(Turn {
role: Role::User,
text: user_msg.clone(),
});
let agent_response = agent_turn(user_msg).await?;
final_response.clone_from(&agent_response);
transcript.push(Turn {
role: Role::Assistant,
text: agent_response.clone(),
});
if agent_response.contains("transfer_to_human_agents") {
break;
}
}
Ok(MultiTurnResult {
transcript,
final_response,
})
}
async fn generate_user_message(
&self,
history: &[Turn],
turn_idx: usize,
) -> Result<String, BenchError> {
let messages = self.build_simulator_messages(history, turn_idx);
match self.call_simulator(&messages).await {
Ok(msg) => Ok(msg),
Err(_) => {
self.call_simulator(&messages).await.map_err(|e| {
BenchError::InvalidFormat(format!("simulator LLM failed after retry: {e}"))
})
}
}
}
async fn call_simulator(&self, messages: &[Message]) -> Result<String, BenchError> {
self.simulator_provider
.chat(messages)
.await
.map_err(|e| BenchError::InvalidFormat(format!("simulator LLM error: {e}")))
}
fn build_simulator_messages(&self, history: &[Turn], turn_idx: usize) -> Vec<Message> {
let i = &self.instructions;
let mut system = format!(
"You are a customer calling customer support for the {} domain.\n\
Reason for call: {}\n\
Task instructions: {}",
i.domain, i.reason_for_call, i.task_instructions
);
if let Some(known) = &i.known_info {
system.push_str("\nInformation you have: ");
system.push_str(known);
}
system.push_str(
"\n\nGenerate the next customer message based on the conversation so far. \
When the task is complete or the agent has resolved your request, output exactly: [DONE]",
);
let mut messages = vec![Message::from_legacy(Role::System, system)];
for turn in history {
messages.push(Message::from_legacy(turn.role, turn.text.clone()));
}
if turn_idx == 0 {
messages.push(Message::from_legacy(
Role::User,
"Generate the opening customer message.",
));
}
messages
}
}
/// Reports whether `text` signals the end of the simulated conversation.
///
/// Only the last line that is non-empty after trimming is inspected. The signal is
/// present when that line contains `[done]`, compared ASCII case-insensitively.
/// Text without any non-empty line never signals completion.
fn is_done_signal(text: &str) -> bool {
    const MARKER: &[u8] = b"[done]";

    let tail = text
        .lines()
        .rev()
        .map(str::trim)
        .find(|line| !line.is_empty());

    match tail {
        // The marker is pure ASCII, so scanning byte windows finds exactly the
        // matches a lowercase-then-substring search would.
        Some(line) => line
            .as_bytes()
            .windows(MARKER.len())
            .any(|window| window.eq_ignore_ascii_case(MARKER)),
        None => false,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use zeph_llm::{any::AnyProvider, mock::MockProvider};

    /// Builds instruction fixtures; `unknown_info` is always left unset.
    fn instructions(
        domain: &str,
        reason: &str,
        task: &str,
        known: Option<&str>,
    ) -> StructuredUserInstructions {
        StructuredUserInstructions {
            domain: domain.into(),
            reason_for_call: reason.into(),
            task_instructions: task.into(),
            known_info: known.map(Into::into),
            unknown_info: None,
        }
    }

    #[test]
    fn done_signal_exact() {
        assert!(is_done_signal("[DONE]"));
    }

    #[test]
    fn done_signal_case_insensitive() {
        for variant in &["[done]", "[Done]"] {
            assert!(is_done_signal(variant));
        }
    }

    #[test]
    fn done_signal_substring_last_line() {
        assert!(is_done_signal("Thank you for your help! [DONE]"));
    }

    #[test]
    fn done_signal_only_in_last_line() {
        assert!(is_done_signal("first line\n[DONE]"));
    }

    #[test]
    fn done_signal_not_in_last_line_returns_false() {
        assert!(!is_done_signal("[DONE]\nthis comes after"));
    }

    #[test]
    fn done_signal_not_present() {
        assert!(!is_done_signal("Hello, I need help with my order."));
    }

    #[test]
    fn done_signal_empty_string() {
        assert!(!is_done_signal(""));
    }

    #[tokio::test]
    async fn drive_terminates_on_done_signal() {
        // The mock presumably replays these in order: one real message, then the
        // done marker.
        let provider = AnyProvider::Mock(MockProvider::with_responses(vec![
            "I need to cancel order #W0001".into(),
            "[DONE]".into(),
        ]));
        let driver = MultiTurnDriver::new(
            instructions(
                "retail",
                "Cancel order",
                "Cancel order #W0001",
                Some("Order id: #W0001"),
            ),
            provider,
            10,
        );

        let result = driver
            .drive(|_| async { Ok(String::from("Order cancelled successfully.")) })
            .await
            .unwrap();

        // One user turn plus one agent turn; the done marker itself is not recorded.
        assert_eq!(result.transcript.len(), 2);
        assert_eq!(result.final_response, "Order cancelled successfully.");
    }

    #[tokio::test]
    async fn drive_terminates_on_max_turns() {
        let provider = AnyProvider::Mock(MockProvider::with_responses(vec![
            "Keep going".into(),
            "Keep going".into(),
            "Keep going".into(),
        ]));
        let driver =
            MultiTurnDriver::new(instructions("retail", "Test", "Do nothing", None), provider, 2);

        let result = driver
            .drive(|_| async { Ok(String::from("Agent response")) })
            .await
            .unwrap();

        // Two exchanges, each contributing a user turn and an agent turn.
        assert_eq!(result.transcript.len(), 4);
    }

    #[tokio::test]
    async fn drive_terminates_on_human_transfer() {
        let provider = AnyProvider::Mock(MockProvider::with_responses(vec![
            "Please escalate my issue".into(),
        ]));
        let driver = MultiTurnDriver::new(
            instructions("airline", "Escalate", "Escalate to human", None),
            provider,
            10,
        );

        let result = driver
            .drive(|_| async {
                Ok(String::from(
                    "I will transfer you now. transfer_to_human_agents called.",
                ))
            })
            .await
            .unwrap();

        // The hand-off reply is recorded, then the loop stops.
        assert_eq!(result.transcript.len(), 2);
    }
}