use async_trait::async_trait;
use futures::StreamExt;
use futures::stream;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use crate::error::{AgentLoopError, Result};
use crate::llm_driver_registry::{
BoxedLlmDriver, DriverRegistry, LlmCallConfig, LlmCompletionMetadata, LlmDriver, LlmMessage,
LlmMessageRole, LlmResponseStream, LlmStreamEvent, ProviderType,
};
use crate::tool_types::ToolCall;
use llmsim::generator::{LoremGenerator, ResponseGenerator};
use llmsim::latency::LatencyProfile;
use llmsim::openai::{ChatCompletionRequest, Message, Role, Usage};
use llmsim::script::auto_tool_call_id;
use llmsim::stream::TokenStreamBuilder;
/// Configuration for the simulated LLM driver.
///
/// Build one with an associated constructor such as `fixed`, `echo`,
/// `sequence` or `scripted`, then refine it with the `with_*` builders.
#[derive(Debug, Clone)]
pub struct LlmSimConfig {
/// Strategy that produces the assistant text (or a failure) for each call.
pub response: ResponseConfig,
/// Tool calls attached to non-scripted responses; `None` means none.
/// Scripted responses carry their own tool calls and do not consult this field.
pub tool_calls: Option<ToolCallConfig>,
/// When `true`, streaming uses `LatencyProfile::fast()` instead of
/// `LatencyProfile::instant()`.
pub simulate_latency: bool,
/// Model name placed in the chat request handed to the lorem generator and
/// shown in the driver's `Debug` output. The `Done` metadata reports the
/// model of the individual call instead.
pub model_name: String,
/// Delay awaited before a response is generated. When `None`, a delay may
/// still be derived from a `-ttft-<ms>` marker in the requested model name.
pub response_delay: Option<std::time::Duration>,
/// Response id reported in the metadata of the final `Done` event.
pub response_id: Option<String>,
}
impl Default for LlmSimConfig {
    /// Baseline configuration: a fixed greeting, no tool calls, no latency
    /// simulation, no delay and no response id.
    fn default() -> Self {
        let greeting = String::from("Hello! I'm a simulated LLM response.");
        let model_name = String::from("llmsim-model");
        Self {
            model_name,
            response: ResponseConfig::Fixed(greeting),
            tool_calls: None,
            response_delay: None,
            response_id: None,
            simulate_latency: false,
        }
    }
}
impl LlmSimConfig {
pub fn fixed(response: impl Into<String>) -> Self {
Self {
response: ResponseConfig::Fixed(response.into()),
..Default::default()
}
}
pub fn echo() -> Self {
Self {
response: ResponseConfig::Echo,
..Default::default()
}
}
pub fn lorem(target_tokens: usize) -> Self {
Self {
response: ResponseConfig::Lorem { target_tokens },
..Default::default()
}
}
pub fn sequence(responses: Vec<String>) -> Self {
Self {
response: ResponseConfig::Sequence(responses),
..Default::default()
}
}
pub fn scripted(turns: Vec<SimTurn>) -> Self {
Self {
response: ResponseConfig::Scripted {
turns,
on_exhausted: OnExhausted::default(),
},
..Default::default()
}
}
pub fn with_on_exhausted(mut self, mode: OnExhausted) -> Self {
if let ResponseConfig::Scripted { on_exhausted, .. } = &mut self.response {
*on_exhausted = mode;
}
self
}
pub fn with_tool_calls(mut self, tool_calls: Vec<ToolCall>) -> Self {
self.tool_calls = Some(ToolCallConfig::Fixed(tool_calls));
self
}
pub fn with_tool_call_sequence(mut self, sequences: Vec<Vec<ToolCall>>) -> Self {
self.tool_calls = Some(ToolCallConfig::Sequence(sequences));
self
}
pub fn with_latency(mut self) -> Self {
self.simulate_latency = true;
self
}
pub fn with_model(mut self, model: impl Into<String>) -> Self {
self.model_name = model.into();
self
}
pub fn with_response_delay(mut self, delay: std::time::Duration) -> Self {
self.response_delay = Some(delay);
self
}
pub fn with_response_id(mut self, id: impl Into<String>) -> Self {
self.response_id = Some(id.into());
self
}
pub fn error(message: impl Into<String>) -> Self {
Self {
response: ResponseConfig::Error(message.into()),
..Default::default()
}
}
pub fn model_not_available() -> Self {
Self {
response: ResponseConfig::ModelNotAvailable,
..Default::default()
}
}
}
/// Strategy used to produce the assistant text (or a failure) for each call.
#[derive(Debug, Clone)]
pub enum ResponseConfig {
/// Always reply with the given text.
Fixed(String),
/// Reply with `Echo: ` followed by the text of the most recent user message
/// (empty when there is no user message).
Echo,
/// Reply with text produced by `LoremGenerator` built with `target_tokens`.
Lorem { target_tokens: usize },
/// Reply with the listed texts in order, wrapping around after the last one.
/// An empty list yields an empty reply.
Sequence(Vec<String>),
/// Play back scripted turns in order.
Scripted {
/// Turns to play back, one per call.
turns: Vec<SimTurn>,
/// What to do once every turn has been used.
on_exhausted: OnExhausted,
},
/// Reply with an empty string.
Empty,
/// Fail every call with an `LLM error: <message>` error.
Error(String),
/// Fail every call with a model-not-available error naming the requested model.
ModelNotAvailable,
}
/// One scripted turn of a `ResponseConfig::Scripted` playback.
#[derive(Debug, Clone, PartialEq)]
pub enum SimTurn {
/// Assistant text with no tool calls.
Assistant(String),
/// Tool calls with empty assistant text.
ToolCalls(Vec<SimToolCall>),
/// Assistant text together with tool calls.
Mixed {
text: String,
tool_calls: Vec<SimToolCall>,
},
/// The call fails with an LLM error built from the given simulated error.
Error(SimError),
}
/// Tool call as written in a script, before it is turned into a `ToolCall`.
#[derive(Debug, Clone, PartialEq)]
pub struct SimToolCall {
/// Name of the tool to invoke.
pub name: String,
/// JSON arguments passed to the tool.
pub arguments: serde_json::Value,
/// Explicit call id; when `None`, an id is generated from the turn index
/// and the call's position within the turn.
pub id: Option<String>,
}
/// Failure that a scripted turn can simulate.
///
/// Each variant maps to a numeric status code and a message that are embedded
/// in the resulting error text.
#[derive(Debug, Clone, PartialEq)]
pub enum SimError {
/// Rate limiting (status 429).
RateLimit,
/// Request timeout (status 504).
Timeout,
/// Invalid response (status 400) with a custom message.
InvalidResponse(String),
/// Any other failure (status 500) with a custom message.
Other(String),
}
impl SimError {
fn status_code(&self) -> u16 {
match self {
SimError::RateLimit => 429,
SimError::Timeout => 504,
SimError::InvalidResponse(_) => 400,
SimError::Other(_) => 500,
}
}
fn message(&self) -> String {
match self {
SimError::RateLimit => "Rate limit exceeded. Please retry after some time.".to_string(),
SimError::Timeout => "Request timed out".to_string(),
SimError::InvalidResponse(message) | SimError::Other(message) => message.clone(),
}
}
}
/// Policy applied when a scripted playback is asked for more turns than it has.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum OnExhausted {
/// Keep returning the final turn (the default).
#[default]
RepeatLast,
/// Fail the call with a configuration error.
Error,
/// Start over from the first turn.
Loop,
}
/// Strategy used to attach tool calls to non-scripted responses.
///
/// In every variant an empty list of tool calls is reported as "no tool calls".
#[derive(Debug, Clone)]
pub enum ToolCallConfig {
/// Attach the same tool calls to every response.
Fixed(Vec<ToolCall>),
/// Attach one entry per call, wrapping around after the last entry.
Sequence(Vec<Vec<ToolCall>>),
/// Attach the tool calls of the first pattern whose text occurs in the most
/// recent user message; no match means no tool calls.
Conditional {
patterns: Vec<ToolCallPattern>,
},
}
/// Substring rule used by `ToolCallConfig::Conditional`.
#[derive(Debug, Clone)]
pub struct ToolCallPattern {
/// Text that must occur in the most recent user message for the rule to match.
pub contains: String,
/// Tool calls returned when the rule matches.
pub tool_calls: Vec<ToolCall>,
}
impl ToolCallPattern {
    /// Creates a rule that yields `tool_calls` when the most recent user
    /// message contains `contains`.
    pub fn new(contains: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self {
        let contains: String = contains.into();
        Self {
            tool_calls,
            contains,
        }
    }
}
fn materialize_scripted_tool_calls(
turn_index: usize,
calls: Vec<SimToolCall>,
) -> Option<Vec<ToolCall>> {
if calls.is_empty() {
return None;
}
Some(
calls
.into_iter()
.enumerate()
.map(|(call_index, call)| ToolCall {
id: call
.id
.unwrap_or_else(|| auto_tool_call_id(turn_index, call_index)),
name: call.name,
arguments: call.arguments,
})
.collect(),
)
}
/// Simulated LLM driver that produces responses according to an `LlmSimConfig`.
///
/// Both counters live behind `Arc`, so clones of a driver share them and
/// advance through sequences and scripts together.
#[derive(Clone)]
pub struct LlmSimDriver {
/// Configuration that decides what each call returns.
config: LlmSimConfig,
/// Number of responses handed out so far by sequence or scripted playback.
response_counter: Arc<AtomicUsize>,
/// Number of tool-call entries handed out so far by `ToolCallConfig::Sequence`.
tool_call_counter: Arc<AtomicUsize>,
}
/// Result of generating one turn: the assistant text plus optional tool calls.
struct GeneratedTurn {
/// Assistant text; may be empty.
text: String,
/// Tool calls for the turn, or `None` when there are none.
tool_calls: Option<Vec<ToolCall>>,
}
impl LlmSimDriver {
/// Creates a driver for `config` with both playback counters at zero.
pub fn new(config: LlmSimConfig) -> Self {
    let zeroed_counter = || Arc::new(AtomicUsize::new(0));
    Self {
        response_counter: zeroed_counter(),
        tool_call_counter: zeroed_counter(),
        config,
    }
}
/// Creates a driver that uses `LlmSimConfig::default()`.
pub fn default_driver() -> Self {
    let config = LlmSimConfig::default();
    Self::new(config)
}
/// Produces the assistant text for every non-scripted, non-failing strategy.
///
/// # Panics
///
/// Panics for `Error`, `ModelNotAvailable` and `Scripted`, which callers
/// are expected to handle before reaching this method.
fn generate_response(&self, messages: &[LlmMessage]) -> String {
    match &self.config.response {
        ResponseConfig::Empty => String::new(),
        ResponseConfig::Fixed(text) => text.clone(),
        ResponseConfig::Echo => {
            // Echo the newest user message; fall back to empty text.
            let newest_user = messages
                .iter()
                .rev()
                .find(|message| message.role == LlmMessageRole::User);
            let last_user = match newest_user {
                Some(message) => message.content_as_text(),
                None => Default::default(),
            };
            format!("Echo: {}", last_user)
        }
        ResponseConfig::Lorem { target_tokens } => {
            LoremGenerator::new(*target_tokens).generate(&self.to_chat_request(messages))
        }
        ResponseConfig::Sequence(responses) => match responses.len() {
            // An empty sequence never advances the counter.
            0 => String::new(),
            len => {
                let call_number = self.response_counter.fetch_add(1, Ordering::SeqCst);
                responses[call_number % len].clone()
            }
        },
        ResponseConfig::Error(_)
        | ResponseConfig::ModelNotAvailable
        | ResponseConfig::Scripted { .. } => {
            unreachable!("Special configs handled in chat_completion_stream")
        }
    }
}
/// Selects the tool calls for a non-scripted response.
///
/// Returns `None` when no tool-call strategy is configured, when nothing
/// matches, or when the selected list is empty.
fn get_tool_calls(&self, messages: &[LlmMessage]) -> Option<Vec<ToolCall>> {
    let selected: &Vec<ToolCall> = match self.config.tool_calls.as_ref()? {
        ToolCallConfig::Fixed(calls) => calls,
        ToolCallConfig::Sequence(sequences) => {
            // An empty sequence never advances the counter.
            if sequences.is_empty() {
                return None;
            }
            let call_number = self.tool_call_counter.fetch_add(1, Ordering::SeqCst);
            &sequences[call_number % sequences.len()]
        }
        ToolCallConfig::Conditional { patterns } => {
            let newest_user = messages
                .iter()
                .rev()
                .find(|message| message.role == LlmMessageRole::User);
            let last_user = match newest_user {
                Some(message) => message.content_as_text(),
                None => Default::default(),
            };
            // Only the first matching pattern is considered.
            let matched = patterns
                .iter()
                .find(|pattern| last_user.contains(&pattern.contains))?;
            &matched.tool_calls
        }
    };

    if selected.is_empty() {
        None
    } else {
        Some(selected.clone())
    }
}
/// Generates the next turn: scripted playback when the configuration is
/// scripted, otherwise the configured text plus configured tool calls.
fn generate_turn(&self, messages: &[LlmMessage]) -> Result<GeneratedTurn> {
    match &self.config.response {
        ResponseConfig::Scripted {
            turns,
            on_exhausted,
        } => self.generate_scripted_turn(turns, *on_exhausted),
        _ => {
            let text = self.generate_response(messages);
            let tool_calls = self.get_tool_calls(messages);
            Ok(GeneratedTurn { text, tool_calls })
        }
    }
}
fn generate_scripted_turn(
&self,
turns: &[SimTurn],
on_exhausted: OnExhausted,
) -> Result<GeneratedTurn> {
if turns.is_empty() {
return Err(AgentLoopError::config(
"llmsim scripted config must contain at least one turn",
));
}
let turn_index = self.response_counter.fetch_add(1, Ordering::SeqCst);
let turn = if turn_index < turns.len() {
turns[turn_index].clone()
} else {
match on_exhausted {
OnExhausted::RepeatLast => turns[turns.len() - 1].clone(),
OnExhausted::Loop => turns[turn_index % turns.len()].clone(),
OnExhausted::Error => {
return Err(AgentLoopError::config("llmsim scripted config exhausted"));
}
}
};
match turn {
SimTurn::Assistant(text) => Ok(GeneratedTurn {
text,
tool_calls: None,
}),
SimTurn::ToolCalls(calls) => Ok(GeneratedTurn {
text: String::new(),
tool_calls: materialize_scripted_tool_calls(turn_index, calls),
}),
SimTurn::Mixed { text, tool_calls } => Ok(GeneratedTurn {
text,
tool_calls: materialize_scripted_tool_calls(turn_index, tool_calls),
}),
SimTurn::Error(error) => Err(AgentLoopError::llm(format!(
"LlmSim scripted error ({}): {}",
error.status_code(),
error.message()
))),
}
}
/// Builds a streaming chat request for the configured model from
/// `messages`, carrying over each message's role, text and tool-call id.
fn to_chat_request(&self, messages: &[LlmMessage]) -> ChatCompletionRequest {
    let mut converted: Vec<Message> = Vec::with_capacity(messages.len());
    for source in messages {
        converted.push(Message {
            role: match source.role {
                LlmMessageRole::System => Role::System,
                LlmMessageRole::User => Role::User,
                LlmMessageRole::Assistant => Role::Assistant,
                LlmMessageRole::Tool => Role::Tool,
            },
            content: Some(source.content_as_text()),
            name: None,
            tool_calls: None,
            tool_call_id: source.tool_call_id.clone(),
        });
    }

    ChatCompletionRequest {
        model: self.config.model_name.clone(),
        messages: converted,
        stream: true,
        // Every optional request parameter is left unset.
        temperature: None,
        top_p: None,
        n: None,
        max_tokens: None,
        max_completion_tokens: None,
        stop: None,
        presence_penalty: None,
        frequency_penalty: None,
        logit_bias: None,
        user: None,
        tools: None,
        tool_choice: None,
        seed: None,
        response_format: None,
    }
}
/// Chooses the latency profile for a call.
///
/// Latency is simulated when the configuration asks for it or when
/// `model_name` contains `-latency`; otherwise streaming is instant.
fn resolve_latency_profile(&self, model_name: &str) -> LatencyProfile {
    let latency_disabled = !self.config.simulate_latency && !model_name.contains("-latency");
    if latency_disabled {
        return LatencyProfile::instant();
    }
    LatencyProfile::fast()
}
/// Rough token estimate: one token per four bytes of `text`, never less
/// than one.
fn estimate_tokens(text: &str) -> u32 {
    let quarter_len = text.len() / 4;
    if quarter_len == 0 {
        1
    } else {
        quarter_len as u32
    }
}
}
#[async_trait]
impl LlmDriver for LlmSimDriver {
async fn chat_completion_stream(
&self,
messages: Vec<LlmMessage>,
config: &LlmCallConfig,
) -> Result<LlmResponseStream> {
if let ResponseConfig::Error(error_msg) = &self.config.response {
return Err(anyhow::anyhow!("LLM error: {}", error_msg).into());
}
if matches!(self.config.response, ResponseConfig::ModelNotAvailable) {
return Err(AgentLoopError::model_not_available(config.model.clone()));
}
let delay = self
.config
.response_delay
.or_else(|| parse_ttft_from_model_name(&config.model));
if let Some(delay) = delay {
tokio::time::sleep(delay).await;
}
let generated_turn = self.generate_turn(&messages)?;
let response_text = generated_turn.text;
let tool_calls = generated_turn.tool_calls;
let model_name = config.model.clone();
let response_id_for_done = self.config.response_id.clone();
let latency_profile = self.resolve_latency_profile(&model_name);
let prompt_tokens: u32 = messages
.iter()
.map(|m| Self::estimate_tokens(&m.content_as_text()))
.sum();
let completion_tokens = Self::estimate_tokens(&response_text);
let usage = Usage {
prompt_tokens,
completion_tokens,
total_tokens: prompt_tokens + completion_tokens,
};
let chunk_stream = TokenStreamBuilder::new(&model_name, &response_text)
.latency(latency_profile)
.usage(usage)
.build()
.into_chunk_stream();
let tool_calls_tail = tool_calls;
let model_name_done = model_name.clone();
let event_stream = chunk_stream.flat_map(move |chunk| {
let mut events: Vec<Result<LlmStreamEvent>> = Vec::new();
for choice in &chunk.choices {
if let Some(content) = &choice.delta.content
&& !content.is_empty()
{
events.push(Ok(LlmStreamEvent::TextDelta(content.clone())));
}
}
stream::iter(events)
});
let done_events: Vec<Result<LlmStreamEvent>> = {
let mut tail = Vec::new();
if let Some(calls) = tool_calls_tail {
tail.push(Ok(LlmStreamEvent::ToolCalls(calls)));
}
tail.push(Ok(LlmStreamEvent::Done(Box::new(LlmCompletionMetadata {
total_tokens: Some(prompt_tokens + completion_tokens),
prompt_tokens: Some(prompt_tokens),
completion_tokens: Some(completion_tokens),
cache_read_tokens: None,
cache_creation_tokens: None,
provider_cost_usd: None,
model: Some(model_name_done),
finish_reason: Some("stop".to_string()),
retry_metadata: None,
response_id: response_id_for_done,
phase: None,
}))));
tail
};
let full_stream = event_stream.chain(stream::iter(done_events));
Ok(Box::pin(full_stream))
}
}
impl std::fmt::Debug for LlmSimDriver {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("LlmSimDriver")
.field("model", &self.config.model_name)
.field("simulate_latency", &self.config.simulate_latency)
.finish()
}
}
/// Registers a factory for `ProviderType::LlmSim` that builds a driver with
/// the default configuration; the API key and base URL are ignored.
pub fn register_driver(registry: &mut DriverRegistry) {
    registry.register(ProviderType::LlmSim, |_api_key, _base_url| {
        let boxed: BoxedLlmDriver = Box::new(LlmSimDriver::default_driver());
        boxed
    });
}
/// Registers a factory for `ProviderType::LlmSim` that hands out clones of
/// one driver built from `config`; the API key and base URL are ignored.
pub fn register_driver_with_config(registry: &mut DriverRegistry, config: LlmSimConfig) {
    let prototype = LlmSimDriver::new(config);
    registry.register(ProviderType::LlmSim, move |_api_key, _base_url| {
        let boxed: BoxedLlmDriver = Box::new(prototype.clone());
        boxed
    });
}
/// Extracts a delay from a `-ttft-<milliseconds>` marker in `model_name`.
///
/// Only the first marker is examined. Returns `None` when there is no
/// marker, when no digits follow it, when the number does not fit in a
/// `u64`, or when it is zero.
fn parse_ttft_from_model_name(model_name: &str) -> Option<std::time::Duration> {
    let (_, after_marker) = model_name.split_once("-ttft-")?;
    // Keep only the leading run of ASCII digits.
    let digits_end = after_marker
        .find(|c: char| !c.is_ascii_digit())
        .unwrap_or(after_marker.len());
    after_marker[..digits_end]
        .parse::<u64>()
        .ok()
        .filter(|&millis| millis > 0)
        .map(std::time::Duration::from_millis)
}
/// Builds a boxed simulated driver from `config`.
pub fn create_driver(config: LlmSimConfig) -> BoxedLlmDriver {
    let driver = LlmSimDriver::new(config);
    Box::new(driver)
}
/// Scripted configuration for the auditor demo: list EC2 instances, list S3
/// buckets, then summarise and point at the audit log.
pub fn auditor_demo_script() -> LlmSimConfig {
    // One narrated turn that calls a single argument-less tool.
    let narrated_call = |narration: &str, tool: &str, call_id: &str| SimTurn::Mixed {
        text: narration.to_string(),
        tool_calls: vec![SimToolCall {
            name: tool.to_string(),
            arguments: serde_json::json!({}),
            id: Some(call_id.to_string()),
        }],
    };

    let summary = "Audit complete: inventoried EC2 instances and S3 buckets. \
        See /workspace/.audit.log for the per-tool-call audit trail \
        written by the post_tool_use hook bundle.";

    LlmSimConfig::scripted(vec![
        narrated_call(
            "Starting the audit. Listing EC2 instances first.",
            "aws_list_ec2_instances",
            "call_demo_ec2",
        ),
        narrated_call(
            "EC2 inventory captured. Listing S3 buckets next.",
            "aws_list_s3_buckets",
            "call_demo_s3",
        ),
        SimTurn::Assistant(summary.to_string()),
    ])
}
/// Scripted configuration for the guarded-bash demo: a destructive `bash`
/// call, a safe `bash` call, then a closing summary.
pub fn guarded_bash_demo_script() -> LlmSimConfig {
    // One narrated turn that runs a single `bash` command.
    let bash_step = |narration: &str, command: &str, call_id: &str| SimTurn::Mixed {
        text: narration.to_string(),
        tool_calls: vec![SimToolCall {
            name: "bash".to_string(),
            arguments: serde_json::json!({ "commands": command }),
            id: Some(call_id.to_string()),
        }],
    };

    let summary = "Guarded-bash demo complete. The first tool call should be \
        blocked by the pre_tool_use hook; the second should succeed.";

    LlmSimConfig::scripted(vec![
        bash_step(
            "Step 1: attempting a destructive command.",
            "rm -rf /",
            "call_demo_rm",
        ),
        bash_step(
            "Step 2: trying a safe command.",
            "ls -la /workspace",
            "call_demo_ls",
        ),
        SimTurn::Assistant(summary.to_string()),
    ])
}
#[cfg(test)]
mod tests {
use super::*;
use futures::StreamExt;
/// The auditor demo script has two single-tool turns followed by a summary.
#[test]
fn auditor_demo_script_calls_ec2_then_s3_then_summarises() {
    let config = auditor_demo_script();
    let ResponseConfig::Scripted { turns, .. } = &config.response else {
        panic!("expected Scripted, got {:?}", config.response);
    };
    assert_eq!(turns.len(), 3, "script has three turns");

    let expected_tools = ["aws_list_ec2_instances", "aws_list_s3_buckets"];
    for (index, expected_tool) in expected_tools.iter().enumerate() {
        match &turns[index] {
            SimTurn::Mixed { tool_calls, .. } => {
                assert_eq!(tool_calls.len(), 1);
                assert_eq!(tool_calls[0].name, *expected_tool);
            }
            other => panic!("turn {index} should be Mixed, got {other:?}"),
        }
    }

    match &turns[2] {
        SimTurn::Assistant(text) => assert!(
            text.contains("/workspace/.audit.log"),
            "summary mentions the audit log: {text:?}"
        ),
        other => panic!("turn 2 should be Assistant, got {other:?}"),
    }
}
/// Call configuration for model `test-model` with every optional setting unset.
fn make_config() -> LlmCallConfig {
    let model = String::from("test-model");
    LlmCallConfig {
        model,
        tools: Vec::new(),
        metadata: std::collections::HashMap::new(),
        temperature: None,
        max_tokens: None,
        reasoning_effort: None,
        previous_response_id: None,
        tool_search: None,
        prompt_cache: None,
    }
}
/// Text message with the user role.
fn user_message(content: &str) -> LlmMessage {
    let role = LlmMessageRole::User;
    LlmMessage::text(role, content)
}
/// Text message with the system role.
fn system_message(content: &str) -> LlmMessage {
    let role = LlmMessageRole::System;
    LlmMessage::text(role, content)
}
/// A fixed configuration returns its text and no tool calls.
#[tokio::test]
async fn test_fixed_response() {
    let driver = LlmSimDriver::new(LlmSimConfig::fixed("Hello, world!"));
    let call_config = make_config();
    let reply = driver
        .chat_completion(vec![user_message("Hi there")], &call_config)
        .await
        .unwrap();

    assert_eq!(reply.text, "Hello, world!");
    assert!(reply.tool_calls.is_none());
}
/// The echo strategy repeats the user message, not the system prompt.
#[tokio::test]
async fn test_echo_response() {
    let driver = LlmSimDriver::new(LlmSimConfig::echo());
    let conversation = vec![
        system_message("You are a helpful assistant"),
        user_message("What is 2+2?"),
    ];
    let call_config = make_config();
    let reply = driver
        .chat_completion(conversation, &call_config)
        .await
        .unwrap();

    assert_eq!(reply.text, "Echo: What is 2+2?");
}
/// A sequence configuration hands out its texts in order and wraps around.
#[tokio::test]
async fn test_sequence_response() {
    let texts = ["First", "Second", "Third"];
    let driver = LlmSimDriver::new(LlmSimConfig::sequence(
        texts.iter().map(|text| text.to_string()).collect(),
    ));
    let prompt = vec![user_message("test")];
    let call_config = make_config();

    // The fourth call wraps back to the first entry.
    for expected in ["First", "Second", "Third", "First"] {
        let reply = driver
            .chat_completion(prompt.clone(), &call_config)
            .await
            .unwrap();
        assert_eq!(reply.text, expected);
    }
}
/// The lorem strategy produces non-empty text of more than five words.
#[tokio::test]
async fn test_lorem_response() {
    let driver = LlmSimDriver::new(LlmSimConfig::lorem(50));
    let call_config = make_config();
    let reply = driver
        .chat_completion(vec![user_message("Generate text")], &call_config)
        .await
        .unwrap();

    assert!(!reply.text.is_empty());
    let word_count = reply.text.split_whitespace().count();
    assert!(word_count > 5);
}
/// Fixed tool calls are returned alongside the fixed text.
#[tokio::test]
async fn test_fixed_tool_calls() {
    let weather_call = ToolCall {
        id: "call_123".to_string(),
        name: "get_weather".to_string(),
        arguments: serde_json::json!({"city": "NYC"}),
    };
    let sim_config =
        LlmSimConfig::fixed("Let me check the weather.").with_tool_calls(vec![weather_call]);
    let driver = LlmSimDriver::new(sim_config);
    let call_config = make_config();

    let reply = driver
        .chat_completion(vec![user_message("What's the weather?")], &call_config)
        .await
        .unwrap();

    assert_eq!(reply.text, "Let me check the weather.");
    let calls = reply.tool_calls.expect("Expected tool calls");
    assert_eq!(calls.len(), 1);
    assert_eq!(calls[0].name, "get_weather");
    assert_eq!(calls[0].id, "call_123");
}
/// A tool-call sequence yields one entry per call; an empty entry means no
/// tool calls.
#[tokio::test]
async fn test_tool_call_sequence() {
    let search_call = ToolCall {
        id: "call_1".to_string(),
        name: "search".to_string(),
        arguments: serde_json::json!({"q": "rust"}),
    };
    let fetch_call = ToolCall {
        id: "call_2".to_string(),
        name: "fetch".to_string(),
        arguments: serde_json::json!({"url": "https://example.com"}),
    };
    let sim_config = LlmSimConfig::fixed("Processing...").with_tool_call_sequence(vec![
        vec![search_call],
        vec![fetch_call],
        Vec::new(),
    ]);
    let driver = LlmSimDriver::new(sim_config);
    let prompt = vec![user_message("test")];
    let call_config = make_config();

    for expected_tool in [Some("search"), Some("fetch"), None] {
        let reply = driver
            .chat_completion(prompt.clone(), &call_config)
            .await
            .unwrap();
        match expected_tool {
            Some(tool_name) => {
                let calls = reply.tool_calls.expect("Expected tool calls");
                assert_eq!(calls[0].name, tool_name);
            }
            None => assert!(reply.tool_calls.is_none()),
        }
    }
}
#[tokio::test]
async fn test_scripted_multi_turn_tool_call_agent_sequence() {
let driver = LlmSimDriver::new(
LlmSimConfig::scripted(vec![
SimTurn::ToolCalls(vec![SimToolCall {
name: "bash".to_string(),
arguments: serde_json::json!({"command": "echo hello > /tmp/x.txt"}),
id: None,
}]),
SimTurn::ToolCalls(vec![SimToolCall {
name: "bash".to_string(),
arguments: serde_json::json!({"command": "sed -i s/hello/world/ /tmp/x.txt"}),
id: None,
}]),
SimTurn::Assistant("done".to_string()),
])
.with_on_exhausted(OnExhausted::Error),
);
let messages = vec![user_message("create /tmp/x.txt then change hello to world")];
let first = driver
.chat_completion(messages.clone(), &make_config())
.await
.unwrap();
let first_calls = first.tool_calls.expect("first turn should call bash");
assert_eq!(first.text, "");
assert_eq!(first_calls[0].name, "bash");
assert_eq!(first_calls[0].id, "call_llmsim_0_0");
let second = driver
.chat_completion(messages.clone(), &make_config())
.await
.unwrap();
let second_calls = second.tool_calls.expect("second turn should call bash");
assert_eq!(second_calls[0].name, "bash");
assert_eq!(second_calls[0].id, "call_llmsim_1_0");
let final_response = driver
.chat_completion(messages.clone(), &make_config())
.await
.unwrap();
assert_eq!(final_response.text, "done");
assert!(final_response.tool_calls.is_none());
let exhausted = driver
.chat_completion(messages, &make_config())
.await
.unwrap_err();
assert!(matches!(exhausted, AgentLoopError::Configuration(_)));
}
/// A mixed scripted turn streams its text and then emits its tool calls.
#[tokio::test]
async fn test_scripted_mixed_turn_streams_text_and_tool_calls() {
    let mixed_turn = SimTurn::Mixed {
        text: "Let me check".to_string(),
        tool_calls: vec![SimToolCall {
            name: "search".to_string(),
            arguments: serde_json::json!({"q": "rust"}),
            id: Some("call_search".to_string()),
        }],
    };
    let driver = LlmSimDriver::new(LlmSimConfig::scripted(vec![mixed_turn]));
    let call_config = make_config();
    let mut events = driver
        .chat_completion_stream(vec![user_message("find rust")], &call_config)
        .await
        .unwrap();

    let mut deltas: Vec<String> = Vec::new();
    let mut emitted_calls = None;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            LlmStreamEvent::TextDelta(delta) => deltas.push(delta),
            LlmStreamEvent::ToolCalls(calls) => emitted_calls = Some(calls),
            // `Done` and any other event are irrelevant here.
            _ => {}
        }
    }

    assert!(!deltas.is_empty(), "scripted text should stream");
    assert_eq!(deltas.concat(), "Let me check");
    let calls = emitted_calls.expect("mixed turn should emit tool calls");
    assert_eq!(calls[0].id, "call_search");
    assert_eq!(calls[0].name, "search");
}
/// The default policy repeats the last turn; `Loop` starts over.
#[tokio::test]
async fn test_scripted_on_exhausted_modes() {
    let prompt = vec![user_message("test")];
    let call_config = make_config();

    let repeat = LlmSimDriver::new(LlmSimConfig::scripted(vec![
        SimTurn::Assistant("one".to_string()),
        SimTurn::Assistant("two".to_string()),
    ]));
    for expected in ["one", "two", "two"] {
        let reply = repeat
            .chat_completion(prompt.clone(), &call_config)
            .await
            .unwrap();
        assert_eq!(reply.text, expected);
    }

    let looping = LlmSimDriver::new(
        LlmSimConfig::scripted(vec![
            SimTurn::Assistant("a".to_string()),
            SimTurn::Assistant("b".to_string()),
        ])
        .with_on_exhausted(OnExhausted::Loop),
    );
    for expected in ["a", "b", "a"] {
        let reply = looping
            .chat_completion(prompt.clone(), &call_config)
            .await
            .unwrap();
        assert_eq!(reply.text, expected);
    }
}
#[tokio::test]
async fn test_scripted_error_turn() {
let driver = LlmSimDriver::new(LlmSimConfig::scripted(vec![SimTurn::Error(
SimError::RateLimit,
)]));
let err = driver
.chat_completion(vec![user_message("test")], &make_config())
.await
.unwrap_err();
assert!(err.is_rate_limited());
}
#[tokio::test]
async fn test_conditional_tool_calls() {
let weather_call = ToolCall {
id: "call_w".to_string(),
name: "get_weather".to_string(),
arguments: serde_json::json!({}),
};
let search_call = ToolCall {
id: "call_s".to_string(),
name: "search".to_string(),
arguments: serde_json::json!({}),
};
let config = LlmSimConfig {
response: ResponseConfig::Fixed("Response".to_string()),
tool_calls: Some(ToolCallConfig::Conditional {
patterns: vec![
ToolCallPattern::new("weather", vec![weather_call]),
ToolCallPattern::new("search", vec![search_call]),
],
}),
simulate_latency: false,
model_name: "test".to_string(),
response_delay: None,
response_id: None,
};
let driver = LlmSimDriver::new(config);
let r1 = driver
.chat_completion(vec![user_message("What's the weather?")], &make_config())
.await
.unwrap();
let calls1 = r1.tool_calls.expect("Expected weather tool");
assert_eq!(calls1[0].name, "get_weather");
let r2 = driver
.chat_completion(vec![user_message("search for rust")], &make_config())
.await
.unwrap();
let calls2 = r2.tool_calls.expect("Expected search tool");
assert_eq!(calls2[0].name, "search");
let r3 = driver
.chat_completion(vec![user_message("hello world")], &make_config())
.await
.unwrap();
assert!(r3.tool_calls.is_none());
}
/// Streaming yields the full text as deltas followed by a `Done` event with
/// token totals and a model name.
#[tokio::test]
async fn test_streaming() {
    let driver = LlmSimDriver::new(LlmSimConfig::fixed("Hello world test"));
    let call_config = make_config();
    let mut events = driver
        .chat_completion_stream(vec![user_message("test")], &call_config)
        .await
        .unwrap();

    let mut deltas: Vec<String> = Vec::new();
    let mut saw_done = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            LlmStreamEvent::TextDelta(delta) => deltas.push(delta),
            LlmStreamEvent::Done(meta) => {
                saw_done = true;
                assert!(meta.total_tokens.is_some());
                assert!(meta.model.is_some());
            }
            _ => {}
        }
    }

    assert!(saw_done);
    assert!(!deltas.is_empty());
    assert_eq!(deltas.concat(), "Hello world test");
}
#[tokio::test]
async fn test_metadata() {
let driver = LlmSimDriver::new(LlmSimConfig::fixed("Hi").with_model("custom-model"));
let messages = vec![user_message("test")];
let mut config = make_config();
config.model = "request-model".to_string();
let response = driver.chat_completion(messages, &config).await.unwrap();
assert_eq!(response.metadata.model, Some("request-model".to_string()));
assert!(response.metadata.prompt_tokens.is_some());
assert!(response.metadata.completion_tokens.is_some());
}
/// Registering the driver makes `ProviderType::LlmSim` creatable.
#[tokio::test]
async fn test_register_driver() {
    let mut registry = DriverRegistry::new();
    register_driver(&mut registry);
    assert!(registry.has_driver(&ProviderType::LlmSim));

    let provider_config = crate::llm_driver_registry::ProviderConfig::new(ProviderType::LlmSim)
        .with_api_key("fake-key");
    let created = registry.create_driver(&provider_config);
    assert!(created.is_ok());
}
#[tokio::test]
async fn test_empty_response() {
let config = LlmSimConfig {
response: ResponseConfig::Empty,
tool_calls: None,
simulate_latency: false,
model_name: "test".to_string(),
response_delay: None,
response_id: None,
};
let driver = LlmSimDriver::new(config);
let messages = vec![user_message("test")];
let response = driver
.chat_completion(messages, &make_config())
.await
.unwrap();
assert!(response.text.is_empty());
}
/// The `Debug` output names the driver and its latency flag.
#[test]
fn test_driver_debug() {
    let driver = LlmSimDriver::new(LlmSimConfig::fixed("test").with_latency());
    let rendered = format!("{driver:?}");
    for fragment in ["LlmSimDriver", "simulate_latency"] {
        assert!(rendered.contains(fragment));
    }
}
/// The default configuration is a fixed response without tool calls or
/// latency simulation.
#[test]
fn test_default_config() {
    let LlmSimConfig {
        response,
        tool_calls,
        simulate_latency,
        ..
    } = LlmSimConfig::default();

    assert!(matches!(response, ResponseConfig::Fixed(_)));
    assert!(tool_calls.is_none());
    assert!(!simulate_latency);
}
/// Chained builders set tool calls, latency, model name and response delay.
#[test]
fn test_config_builder() {
    let two_seconds = std::time::Duration::from_secs(2);
    let weather_call = ToolCall {
        id: "call_1".to_string(),
        name: "get_weather".to_string(),
        arguments: serde_json::json!({"city": "NYC"}),
    };

    let built = LlmSimConfig::fixed("Result")
        .with_tool_calls(vec![weather_call])
        .with_latency()
        .with_model("gpt-4")
        .with_response_delay(two_seconds);

    assert!(built.tool_calls.is_some());
    assert!(built.simulate_latency);
    assert_eq!(built.model_name, "gpt-4");
    assert_eq!(built.response_delay, Some(std::time::Duration::from_secs(2)));
}
/// `-ttft-<ms>` markers parse to a delay; missing, zero or non-numeric
/// values yield `None`.
#[test]
fn test_parse_ttft_from_model_name() {
    use super::parse_ttft_from_model_name;

    // (model name, expected delay in milliseconds)
    let cases: [(&str, Option<u64>); 5] = [
        ("llmsim-ttft-2000", Some(2000)),
        ("test-ttft-500-extra", Some(500)),
        ("llmsim-model", None),
        ("llmsim-ttft-0", None),
        ("llmsim-ttft-abc", None),
    ];
    for (model_name, expected_ms) in cases {
        assert_eq!(
            parse_ttft_from_model_name(model_name),
            expected_ms.map(std::time::Duration::from_millis)
        );
    }
}
/// Latency is enabled by a `-latency` model name or by the configuration
/// flag, and is otherwise zero.
#[test]
fn test_resolve_latency_profile_from_model_name() {
    let plain_driver = LlmSimDriver::new(LlmSimConfig::fixed("test"));

    let by_model_name = plain_driver.resolve_latency_profile("llmsim-latency");
    assert!(by_model_name.sample_ttft().as_nanos() > 0);

    let instant = plain_driver.resolve_latency_profile("llmsim-default");
    assert_eq!(instant.sample_ttft().as_nanos(), 0);

    let latency_driver = LlmSimDriver::new(LlmSimConfig::fixed("test").with_latency());
    let by_config = latency_driver.resolve_latency_profile("llmsim-default");
    assert!(by_config.sample_ttft().as_nanos() > 0);
}
#[tokio::test]
async fn test_latency_streaming_from_model_name() {
let driver = LlmSimDriver::new(LlmSimConfig::fixed("Hello world"));
let messages = vec![user_message("test")];
let mut config = make_config();
config.model = "llmsim-latency".to_string();
let start = std::time::Instant::now();
let mut stream = driver
.chat_completion_stream(messages, &config)
.await
.unwrap();
let mut text_parts = Vec::new();
let mut got_done = false;
while let Some(event) = stream.next().await {
match event.unwrap() {
LlmStreamEvent::TextDelta(text) => text_parts.push(text),
LlmStreamEvent::Done(meta) => {
got_done = true;
assert_eq!(meta.model, Some("llmsim-latency".to_string()));
}
_ => {}
}
}
assert!(got_done);
assert_eq!(text_parts.join(""), "Hello world");
assert!(
start.elapsed().as_millis() > 0,
"latency simulation should introduce delays"
);
}
#[tokio::test]
async fn test_no_latency_streaming_is_instant() {
let driver = LlmSimDriver::new(LlmSimConfig::fixed("Hello world"));
let messages = vec![user_message("test")];
let mut config = make_config();
config.model = "llmsim-default".to_string();
let start = std::time::Instant::now();
let response = driver.chat_completion(messages, &config).await.unwrap();
let elapsed = start.elapsed();
assert_eq!(response.text, "Hello world");
assert!(
elapsed.as_millis() < 50,
"instant mode should have no delays, took {}ms",
elapsed.as_millis()
);
}
}