use std::cell::RefCell;
use std::collections::BTreeSet;
use super::api::LlmResult;
use crate::orchestration::ToolCallRecord;
use crate::value::{ErrorCategory, VmError};
/// Replay mode for the LLM fixture machinery: `Off` talks to the real
/// backend, `Record` captures responses to fixtures, `Replay` serves them
/// back from disk.
// Added `Eq` for consistency with `ToolRecordingMode`/`CliLlmMockMode`;
// a fieldless enum trivially satisfies full equivalence.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LlmReplayMode {
    Off,
    Record,
    Replay,
}
/// Controls whether tool calls run live (`Off`), run and are captured
/// (`Record`), or are served from previously captured fixtures (`Replay`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolRecordingMode {
    Off,
    Record,
    Replay,
}
/// CLI-level (`--llm-mock`) mock mode. Private: driven only through the
/// `install_cli_llm_mocks` / `enable_cli_llm_mock_recording` /
/// `clear_cli_llm_mock_mode` helpers below.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CliLlmMockMode {
    Off,
    Replay,
    Record,
}
/// An error a mock should yield instead of a successful LLM result.
#[derive(Clone)]
pub struct MockError {
    // Category carried into the resulting `VmError::CategorizedError`.
    pub category: ErrorCategory,
    // Human-readable message; a "retry-after: <secs>" line is appended
    // when `retry_after_ms` is set (see `mock_error_to_vm_error`).
    pub message: String,
    // Optional retry hint, in milliseconds.
    pub retry_after_ms: Option<u64>,
}
/// A single canned LLM response (or error) used by the mock machinery.
///
/// `None` token/metadata fields fall back to synthesized defaults when the
/// mock is materialized into an `LlmResult` (see `build_mock_result`).
#[derive(Clone)]
pub struct LlmMock {
    /// Assistant text of the mocked reply.
    pub text: String,
    /// Tool calls to emit; entries are `{name, arguments}`-shaped JSON.
    pub tool_calls: Vec<serde_json::Value>,
    /// Optional glob pattern: when set, the mock only fires if the prompt
    /// text matches; `None` means "match any prompt" (consumed FIFO).
    // Fix: the two fields below were crammed onto one line (rustfmt
    // violation); split onto separate lines. No semantic change.
    pub match_pattern: Option<String>,
    /// Whether a patterned mock is removed from the queue once it fires.
    pub consume_on_match: bool,
    pub input_tokens: Option<i64>,
    pub output_tokens: Option<i64>,
    pub cache_read_tokens: Option<i64>,
    pub cache_write_tokens: Option<i64>,
    pub thinking: Option<String>,
    pub stop_reason: Option<String>,
    pub model: String,
    pub provider: Option<String>,
    /// Pre-built content blocks; when present they are used verbatim
    /// instead of being synthesized from `text`/`tool_calls`.
    pub blocks: Option<Vec<serde_json::Value>>,
    /// When set, the mock yields this error instead of a result.
    pub error: Option<MockError>,
}
/// Snapshot of one call made into the mock LLM, recorded so tests can
/// assert on what was sent.
#[derive(Clone)]
pub(crate) struct LlmMockCall {
    pub messages: Vec<serde_json::Value>,
    pub system: Option<String>,
    pub tools: Option<Vec<serde_json::Value>>,
    // Serialized `ThinkingConfig`; falls back to `{"mode": "disabled"}`
    // if serialization fails (see `record_llm_mock_call`).
    pub thinking: serde_json::Value,
}
/// Saved snapshot of mock state for one nested scope:
/// (pending mocks, recorded calls, prompt-cache keys).
type LlmMockScope = (Vec<LlmMock>, Vec<LlmMockCall>, BTreeSet<String>);
// All mock/replay state is thread-local: each test thread gets an
// independent copy, so parallel tests cannot interfere with each other.
thread_local! {
    // Fixture replay mode and the directory fixtures live in.
    static LLM_REPLAY_MODE: RefCell<LlmReplayMode> = const { RefCell::new(LlmReplayMode::Off) };
    static LLM_FIXTURE_DIR: RefCell<String> = const { RefCell::new(String::new()) };
    // Tool-call recording/replay state.
    static TOOL_RECORDING_MODE: RefCell<ToolRecordingMode> = const { RefCell::new(ToolRecordingMode::Off) };
    static TOOL_RECORDINGS: RefCell<Vec<ToolCallRecord>> = const { RefCell::new(Vec::new()) };
    static TOOL_REPLAY_FIXTURES: RefCell<Vec<ToolCallRecord>> = const { RefCell::new(Vec::new()) };
    // Builtin (test-installed) mock queue.
    static LLM_MOCKS: RefCell<Vec<LlmMock>> = const { RefCell::new(Vec::new()) };
    // CLI-level (`--llm-mock`) mode, queue, and recordings.
    static CLI_LLM_MOCK_MODE: RefCell<CliLlmMockMode> = const { RefCell::new(CliLlmMockMode::Off) };
    static CLI_LLM_MOCKS: RefCell<Vec<LlmMock>> = const { RefCell::new(Vec::new()) };
    static CLI_LLM_RECORDINGS: RefCell<Vec<LlmMock>> = const { RefCell::new(Vec::new()) };
    // Log of calls made into the mock, plus a simulated prompt cache.
    static LLM_MOCK_CALLS: RefCell<Vec<LlmMockCall>> = const { RefCell::new(Vec::new()) };
    static LLM_PROMPT_CACHE: RefCell<BTreeSet<String>> = const { RefCell::new(BTreeSet::new()) };
    // Stack of saved scopes (see push/pop_llm_mock_scope).
    static LLM_MOCK_SCOPES: RefCell<Vec<LlmMockScope>> = const { RefCell::new(Vec::new()) };
}
/// Append a mock to the builtin (test-installed) mock queue.
pub(crate) fn push_llm_mock(mock: LlmMock) {
    LLM_MOCKS.with(|queue| queue.borrow_mut().push(mock));
}
/// Return a copy of every call recorded against the mock LLM so far.
pub(crate) fn get_llm_mock_calls() -> Vec<LlmMockCall> {
    LLM_MOCK_CALLS.with(|calls| calls.borrow().iter().cloned().collect())
}
/// Reset every piece of mock-related thread-local state to its defaults:
/// CLI mode/queues, builtin mocks, call log, prompt cache, and scope stack.
pub(crate) fn reset_llm_mock_state() {
    CLI_LLM_MOCK_MODE.with(|mode| *mode.borrow_mut() = CliLlmMockMode::Off);
    CLI_LLM_MOCKS.with(|queue| queue.borrow_mut().clear());
    CLI_LLM_RECORDINGS.with(|log| log.borrow_mut().clear());
    LLM_MOCKS.with(|queue| queue.borrow_mut().clear());
    LLM_MOCK_CALLS.with(|log| log.borrow_mut().clear());
    LLM_PROMPT_CACHE.with(|cache| cache.borrow_mut().clear());
    LLM_MOCK_SCOPES.with(|stack| stack.borrow_mut().clear());
}
pub(crate) fn push_llm_mock_scope() {
let mocks = LLM_MOCKS.with(|v| std::mem::take(&mut *v.borrow_mut()));
let calls = LLM_MOCK_CALLS.with(|v| std::mem::take(&mut *v.borrow_mut()));
let cache = LLM_PROMPT_CACHE.with(|v| std::mem::take(&mut *v.borrow_mut()));
LLM_MOCK_SCOPES.with(|v| v.borrow_mut().push((mocks, calls, cache)));
}
/// Restore the most recently pushed mock scope, discarding the nested
/// scope's state. Returns `false` when the scope stack is empty.
pub(crate) fn pop_llm_mock_scope() -> bool {
    let Some((mocks, calls, cache)) = LLM_MOCK_SCOPES.with(|stack| stack.borrow_mut().pop())
    else {
        return false;
    };
    LLM_MOCKS.with(|v| v.replace(mocks));
    LLM_MOCK_CALLS.with(|v| v.replace(calls));
    LLM_PROMPT_CACHE.with(|v| v.replace(cache));
    true
}
/// Turn off CLI mock handling and drop any queued mocks and recordings.
pub fn clear_cli_llm_mock_mode() {
    CLI_LLM_MOCK_MODE.with(|mode| *mode.borrow_mut() = CliLlmMockMode::Off);
    CLI_LLM_MOCKS.with(|queue| queue.borrow_mut().clear());
    CLI_LLM_RECORDINGS.with(|log| log.borrow_mut().clear());
}
/// Enter CLI replay mode with the given mock queue; any previous
/// recordings are discarded.
pub fn install_cli_llm_mocks(mocks: Vec<LlmMock>) {
    CLI_LLM_MOCK_MODE.with(|mode| *mode.borrow_mut() = CliLlmMockMode::Replay);
    CLI_LLM_MOCKS.with(|queue| queue.replace(mocks));
    CLI_LLM_RECORDINGS.with(|log| log.borrow_mut().clear());
}
/// Enter CLI record mode, clearing any stale mock queue and recordings.
pub fn enable_cli_llm_mock_recording() {
    CLI_LLM_MOCK_MODE.with(|mode| *mode.borrow_mut() = CliLlmMockMode::Record);
    CLI_LLM_MOCKS.with(|queue| queue.borrow_mut().clear());
    CLI_LLM_RECORDINGS.with(|log| log.borrow_mut().clear());
}
/// Drain and return everything captured while in CLI record mode.
pub fn take_cli_llm_recordings() -> Vec<LlmMock> {
    CLI_LLM_RECORDINGS.with(|log| log.take())
}
/// True when the CLI is in replay mode (unmatched prompts become errors).
pub(crate) fn cli_llm_mock_replay_active() -> bool {
    CLI_LLM_MOCK_MODE.with(|mode| matches!(*mode.borrow(), CliLlmMockMode::Replay))
}
/// Record one inbound LLM call so tests can later inspect what was sent.
/// `thinking` is stored serialized; if serialization fails it degrades to
/// `{"mode": "disabled"}` rather than panicking.
fn record_llm_mock_call(
    messages: &[serde_json::Value],
    system: Option<&str>,
    native_tools: Option<&[serde_json::Value]>,
    thinking: &super::api::ThinkingConfig,
) {
    let thinking_json = serde_json::to_value(thinking).unwrap_or_else(|_| {
        serde_json::json!({
            "mode": "disabled"
        })
    });
    let call = LlmMockCall {
        messages: messages.to_vec(),
        system: system.map(str::to_string),
        tools: native_tools.map(|tools| tools.to_vec()),
        thinking: thinking_json,
    };
    LLM_MOCK_CALLS.with(|log| log.borrow_mut().push(call));
}
/// Materialize a queued `LlmMock` into a concrete `LlmResult`.
///
/// If the mock carries pre-built `blocks`, both the blocks and the raw
/// `tool_calls` are used verbatim. Otherwise blocks are synthesized: the
/// text becomes an `output_text` block, and each `{name, arguments}` tool
/// call is normalized with a deterministic `mock_call_<n>` id.
///
/// `last_msg_len` supplies a fake input-token count when the mock does not
/// specify one.
fn build_mock_result(mock: &LlmMock, last_msg_len: usize) -> LlmResult {
    let (tool_calls, blocks) = if let Some(blocks) = &mock.blocks {
        // Verbatim path: trust recorded blocks/tool_calls as-is.
        (mock.tool_calls.clone(), blocks.clone())
    } else {
        let mut blocks = Vec::new();
        if !mock.text.is_empty() {
            blocks.push(serde_json::json!({
                "type": "output_text",
                "text": mock.text,
                "visibility": "public",
            }));
        }
        let mut tool_calls = Vec::new();
        for (i, tc) in mock.tool_calls.iter().enumerate() {
            // Deterministic, 1-based ids so tests can reference them.
            let id = format!("mock_call_{}", i + 1);
            let name = tc.get("name").and_then(|n| n.as_str()).unwrap_or("unknown");
            let arguments = tc
                .get("arguments")
                .cloned()
                .unwrap_or(serde_json::json!({}));
            tool_calls.push(serde_json::json!({
                "id": id,
                "type": "tool_call",
                "name": name,
                "arguments": arguments,
            }));
            blocks.push(serde_json::json!({
                "type": "tool_call",
                "id": id,
                "name": name,
                "arguments": arguments,
                "visibility": "internal",
            }));
        }
        (tool_calls, blocks)
    };
    LlmResult {
        text: mock.text.clone(),
        tool_calls,
        // Defaults when unspecified: input tokens approximate the prompt
        // length, output is a fixed 30, cache counters start at zero.
        input_tokens: mock.input_tokens.unwrap_or(last_msg_len as i64),
        output_tokens: mock.output_tokens.unwrap_or(30),
        cache_read_tokens: mock.cache_read_tokens.unwrap_or(0),
        cache_write_tokens: mock.cache_write_tokens.unwrap_or(0),
        model: mock.model.clone(),
        provider: mock.provider.clone().unwrap_or_else(|| "mock".to_string()),
        thinking: mock.thinking.clone(),
        stop_reason: mock.stop_reason.clone(),
        blocks,
    }
}
/// Minimal glob matcher supporting only `*` (matches any run of chars,
/// including empty). `"*"` matches everything; a pattern without `*` must
/// equal `text` exactly. Middle segments match leftmost-first, which is
/// sufficient (and correct) for `*`-only globs.
fn mock_glob_match(pattern: &str, text: &str) -> bool {
    if pattern == "*" {
        return true;
    }
    if !pattern.contains('*') {
        return pattern == text;
    }
    let segments: Vec<&str> = pattern.split('*').collect();
    let last_idx = segments.len() - 1;
    let mut rest = text;
    for (idx, segment) in segments.iter().enumerate() {
        // Empty segments arise from leading/trailing/doubled stars.
        if segment.is_empty() {
            continue;
        }
        if idx == 0 {
            // Anchored prefix.
            match rest.strip_prefix(segment) {
                Some(tail) => rest = tail,
                None => return false,
            }
        } else if idx == last_idx {
            // Anchored suffix of whatever remains.
            if !rest.ends_with(segment) {
                return false;
            }
            rest = "";
        } else if let Some(pos) = rest.find(segment) {
            rest = &rest[pos + segment.len()..];
        } else {
            return false;
        }
    }
    true
}
/// Recursively gather every non-empty string found anywhere inside a JSON
/// value (depth-first), appending each to `out`.
fn collect_mock_match_strings(value: &serde_json::Value, out: &mut Vec<String>) {
    match value {
        serde_json::Value::String(text) => {
            if !text.is_empty() {
                out.push(text.clone());
            }
        }
        serde_json::Value::Array(items) => {
            for item in items {
                collect_mock_match_strings(item, out);
            }
        }
        serde_json::Value::Object(map) => {
            for nested in map.values() {
                collect_mock_match_strings(nested, out);
            }
        }
        // Numbers, booleans, and null contribute nothing.
        _ => {}
    }
}
/// Flatten all strings from every message into one newline-joined blob,
/// which mock `match_pattern` globs are tested against.
fn mock_match_text(messages: &[serde_json::Value]) -> String {
    let mut collected = Vec::new();
    for message in messages {
        collect_mock_match_strings(message, &mut collected);
    }
    collected.join("\n")
}
/// Text of the most recent message that has a non-blank `content` field;
/// empty string when no such message exists.
fn mock_last_prompt_text(messages: &[serde_json::Value]) -> String {
    messages
        .iter()
        .rev()
        .filter_map(|message| message.get("content"))
        .find_map(|content| {
            let mut parts = Vec::new();
            collect_mock_match_strings(content, &mut parts);
            let joined = parts.join("\n");
            // Skip whitespace-only content and keep scanning backwards.
            (!joined.trim().is_empty()).then_some(joined)
        })
        .unwrap_or_default()
}
/// Canonical cache key for the simulated prompt cache: the JSON
/// serialization of (model, system, messages).
fn mock_prompt_cache_key(
    model: &str,
    messages: &[serde_json::Value],
    system: Option<&str>,
) -> String {
    let key = serde_json::json!({
        "model": model,
        "system": system,
        "messages": messages,
    });
    serde_json::to_string(&key).unwrap_or_default()
}
/// Simulate prompt caching: the first time a cache key is seen, its input
/// tokens count as a cache write; subsequent identical prompts count as a
/// cache read. Results that already carry cache token counts (or have no
/// input tokens) are left untouched.
fn apply_mock_prompt_cache(result: &mut LlmResult, cache_key: &str) {
    let already_counted = result.cache_read_tokens > 0 || result.cache_write_tokens > 0;
    if already_counted {
        return;
    }
    let tokens = result.input_tokens.max(0);
    if tokens == 0 {
        return;
    }
    // `insert` returns false when the key was already present — a hit.
    let seen_before =
        LLM_PROMPT_CACHE.with(|cache| !cache.borrow_mut().insert(cache_key.to_string()));
    if seen_before {
        result.cache_read_tokens = tokens;
    } else {
        result.cache_write_tokens = tokens;
    }
}
/// Convert a `MockError` into a categorized `VmError`. When a retry hint
/// is present, a "retry-after: <secs>\n" line is appended (on its own line
/// unless the message is empty or already newline-terminated).
fn mock_error_to_vm_error(err: &MockError) -> VmError {
    let mut message = err.message.clone();
    if let Some(ms) = err.retry_after_ms {
        let secs = (ms as f64 / 1000.0).max(0.0);
        if !(message.is_empty() || message.ends_with('\n')) {
            message.push('\n');
        }
        message.push_str(&format!("retry-after: {secs}\n"));
    }
    VmError::CategorizedError {
        message,
        category: err.category.clone(),
    }
}
/// Map a matched mock to its caller-visible outcome: either the mock's
/// error or a synthesized `LlmResult`.
fn mock_outcome(mock: &LlmMock, match_text: &str) -> Result<LlmResult, VmError> {
    match &mock.error {
        Some(err) => Err(mock_error_to_vm_error(err)),
        None => Ok(build_mock_result(mock, match_text.len())),
    }
}
/// Pop/peek the next applicable mock from `mocks` for this prompt.
///
/// Unconditional mocks (no `match_pattern`) take priority and are consumed
/// FIFO. Otherwise the first patterned mock whose glob matches wins; it is
/// removed from the queue only when `consume_on_match` is set. Returns
/// `None` when nothing applies.
// Refactor: the error-vs-result construction was duplicated three times;
// it now lives in `mock_outcome`. Behavior is unchanged.
fn try_match_mock_queue(
    mocks: &mut Vec<LlmMock>,
    match_text: &str,
) -> Option<Result<LlmResult, VmError>> {
    if let Some(idx) = mocks.iter().position(|m| m.match_pattern.is_none()) {
        let mock = mocks.remove(idx);
        return Some(mock_outcome(&mock, match_text));
    }
    for idx in 0..mocks.len() {
        let Some(pattern) = mocks[idx].match_pattern.as_deref() else {
            continue;
        };
        if !mock_glob_match(pattern, match_text) {
            continue;
        }
        if mocks[idx].consume_on_match {
            let mock = mocks.remove(idx);
            return Some(mock_outcome(&mock, match_text));
        }
        return Some(mock_outcome(&mocks[idx], match_text));
    }
    None
}
/// Try the builtin (test-installed) mock queue against this prompt.
fn try_match_builtin_mock(match_text: &str) -> Option<Result<LlmResult, VmError>> {
    LLM_MOCKS.with(|queue| {
        let mut queue = queue.borrow_mut();
        try_match_mock_queue(&mut queue, match_text)
    })
}
/// Try the CLI (`--llm-mock`) mock queue against this prompt.
fn try_match_cli_mock(match_text: &str) -> Option<Result<LlmResult, VmError>> {
    CLI_LLM_MOCKS.with(|queue| {
        let mut queue = queue.borrow_mut();
        try_match_mock_queue(&mut queue, match_text)
    })
}
/// When CLI record mode is active, capture `result` as a replayable,
/// unconditional mock (no match pattern, all token counts pinned).
pub(crate) fn record_cli_llm_result(result: &LlmResult) {
    let recording =
        CLI_LLM_MOCK_MODE.with(|mode| matches!(*mode.borrow(), CliLlmMockMode::Record));
    if !recording {
        return;
    }
    let snapshot = LlmMock {
        text: result.text.clone(),
        tool_calls: result.tool_calls.clone(),
        match_pattern: None,
        consume_on_match: false,
        input_tokens: Some(result.input_tokens),
        output_tokens: Some(result.output_tokens),
        cache_read_tokens: Some(result.cache_read_tokens),
        cache_write_tokens: Some(result.cache_write_tokens),
        thinking: result.thinking.clone(),
        stop_reason: result.stop_reason.clone(),
        model: result.model.clone(),
        provider: Some(result.provider.clone()),
        blocks: Some(result.blocks.clone()),
        error: None,
    };
    CLI_LLM_RECORDINGS.with(|log| log.borrow_mut().push(snapshot));
}
/// Error raised in CLI replay mode when no fixture matched the prompt.
/// The prompt is truncated to 200 chars (with "..." when cut) for the
/// error message.
fn unmatched_cli_prompt_error(match_text: &str) -> VmError {
    // Single pass over the chars: the old code walked the string twice
    // (once for `take(200)`, once for `chars().count()`).
    let mut chars = match_text.chars();
    let mut snippet: String = chars.by_ref().take(200).collect();
    if chars.next().is_some() {
        snippet.push_str("...");
    }
    VmError::Runtime(format!("No --llm-mock fixture matched prompt: {snippet:?}"))
}
/// Configure fixture replay: the mode and the directory fixtures live in.
pub fn set_replay_mode(mode: LlmReplayMode, fixture_dir: &str) {
    LLM_REPLAY_MODE.with(|cell| cell.replace(mode));
    LLM_FIXTURE_DIR.with(|cell| cell.replace(fixture_dir.to_string()));
}
/// Current fixture replay mode for this thread.
pub(crate) fn get_replay_mode() -> LlmReplayMode {
    LLM_REPLAY_MODE.with(|cell| *cell.borrow())
}
/// Directory fixtures are written to / read from; empty when unset.
pub(crate) fn get_fixture_dir() -> String {
    LLM_FIXTURE_DIR.with(|cell| cell.borrow().clone())
}
/// Hash (model, serialized messages, system) into a 16-hex-digit fixture
/// key used as the fixture filename.
///
/// NOTE(review): `DefaultHasher` output is not guaranteed stable across
/// Rust releases, so recorded fixtures may stop resolving after a
/// toolchain upgrade — confirm this is acceptable.
pub(crate) fn fixture_hash(
    model: &str,
    messages: &[serde_json::Value],
    system: Option<&str>,
) -> String {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};
    let mut hasher = DefaultHasher::new();
    model.hash(&mut hasher);
    let serialized = serde_json::to_string(messages).unwrap_or_default();
    serialized.hash(&mut hasher);
    system.hash(&mut hasher);
    format!("{:016x}", hasher.finish())
}
/// Persist an LLM result as a JSON fixture at `<fixture_dir>/<hash>.json`.
/// Best-effort: every I/O error is deliberately ignored (`let _ =`), since
/// fixture recording must never fail the run. No-op when no fixture
/// directory is configured.
pub(crate) fn save_fixture(hash: &str, result: &LlmResult) {
    let dir = get_fixture_dir();
    if dir.is_empty() {
        return;
    }
    let _ = std::fs::create_dir_all(&dir);
    let path = format!("{dir}/{hash}.json");
    let json = serde_json::json!({
        "text": result.text,
        "tool_calls": result.tool_calls,
        "input_tokens": result.input_tokens,
        "output_tokens": result.output_tokens,
        "cache_read_tokens": result.cache_read_tokens,
        "cache_write_tokens": result.cache_write_tokens,
        // Duplicate of cache_write_tokens under a legacy key, presumably so
        // older fixture readers keep working — see the matching fallback in
        // `load_fixture`.
        "cache_creation_input_tokens": result.cache_write_tokens,
        "model": result.model,
        "provider": result.provider,
        "thinking": result.thinking,
        "stop_reason": result.stop_reason,
        "blocks": result.blocks,
    });
    let _ = std::fs::write(
        &path,
        serde_json::to_string_pretty(&json).unwrap_or_default(),
    );
}
/// Load a previously saved fixture from `<fixture_dir>/<hash>.json`.
/// Returns `None` when no fixture directory is set, the file is missing,
/// or it is not valid JSON; individual missing fields fall back to
/// defaults rather than failing the load.
pub(crate) fn load_fixture(hash: &str) -> Option<LlmResult> {
    let dir = get_fixture_dir();
    if dir.is_empty() {
        return None;
    }
    let path = format!("{dir}/{hash}.json");
    let content = std::fs::read_to_string(&path).ok()?;
    let json: serde_json::Value = serde_json::from_str(&content).ok()?;
    Some(LlmResult {
        text: json["text"].as_str().unwrap_or("").to_string(),
        tool_calls: json["tool_calls"].as_array().cloned().unwrap_or_default(),
        input_tokens: json["input_tokens"].as_i64().unwrap_or(0),
        output_tokens: json["output_tokens"].as_i64().unwrap_or(0),
        cache_read_tokens: json["cache_read_tokens"].as_i64().unwrap_or(0),
        // Fall back to the legacy key written by `save_fixture` for
        // fixtures recorded before the rename.
        cache_write_tokens: json["cache_write_tokens"]
            .as_i64()
            .or_else(|| json["cache_creation_input_tokens"].as_i64())
            .unwrap_or(0),
        model: json["model"].as_str().unwrap_or("").to_string(),
        provider: json["provider"].as_str().unwrap_or("mock").to_string(),
        thinking: json["thinking"].as_str().map(|s| s.to_string()),
        stop_reason: json["stop_reason"].as_str().map(|s| s.to_string()),
        blocks: json["blocks"].as_array().cloned().unwrap_or_default(),
    })
}
/// Build a minimal arguments object for a tool schema: every property
/// listed in `required` gets a type-appropriate placeholder value; other
/// properties are omitted. Returns `{}` when no schema can be located.
fn mock_required_args(tool_schema: &serde_json::Value) -> serde_json::Value {
    let mut args = serde_json::Map::new();
    // Tool schemas come in several shapes; probe the known locations.
    let located = tool_schema
        .get("input_schema")
        .or_else(|| tool_schema.get("inputSchema"))
        .or_else(|| {
            tool_schema
                .get("function")
                .and_then(|f| f.get("parameters"))
        })
        .or_else(|| tool_schema.get("parameters"));
    let Some(schema) = located else {
        return serde_json::Value::Object(args);
    };
    let required: std::collections::BTreeSet<String> = schema
        .get("required")
        .and_then(|r| r.as_array())
        .map(|names| {
            names
                .iter()
                .filter_map(|n| n.as_str().map(str::to_string))
                .collect()
        })
        .unwrap_or_default();
    if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) {
        for (name, prop) in props {
            if !required.contains(name) {
                continue;
            }
            let ty = prop
                .get("type")
                .and_then(|t| t.as_str())
                .unwrap_or("string");
            // Placeholder per declared JSON-schema type; unknown types
            // default to an empty string.
            let placeholder = match ty {
                "integer" => serde_json::json!(0),
                "number" => serde_json::json!(0.0),
                "boolean" => serde_json::json!(false),
                "array" => serde_json::json!([]),
                "object" => serde_json::json!({}),
                _ => serde_json::json!(""),
            };
            args.insert(name.clone(), placeholder);
        }
    }
    serde_json::Value::Object(args)
}
/// Synthesize a deterministic LLM response without hitting any backend.
///
/// Resolution order:
/// 1. record the call for later inspection,
/// 2. CLI-installed mocks, then builtin test mocks,
/// 3. if CLI replay is active and nothing matched, return an error,
/// 4. otherwise fabricate a response: a call to the first offered tool
///    (with placeholder required args) when tools are present, else a
///    canned prose reply derived from the prompt.
pub(crate) fn mock_llm_response(
    messages: &[serde_json::Value],
    system: Option<&str>,
    native_tools: Option<&[serde_json::Value]>,
    thinking: &super::api::ThinkingConfig,
    model: &str,
    cache: bool,
) -> Result<LlmResult, VmError> {
    record_llm_mock_call(messages, system, native_tools, thinking);
    let match_text = mock_match_text(messages);
    let prompt_text = mock_last_prompt_text(messages);
    let cache_key = mock_prompt_cache_key(model, messages, system);
    // CLI mocks take precedence over builtin mocks.
    if let Some(matched) = try_match_cli_mock(&match_text) {
        return matched.map(|mut result| {
            if cache {
                apply_mock_prompt_cache(&mut result, &cache_key);
            }
            result
        });
    }
    if let Some(matched) = try_match_builtin_mock(&match_text) {
        return matched.map(|mut result| {
            if cache {
                apply_mock_prompt_cache(&mut result, &cache_key);
            }
            result
        });
    }
    // In CLI replay mode an unmatched prompt is an error, not a fallback.
    if cli_llm_mock_replay_active() {
        return Err(unmatched_cli_prompt_error(&match_text));
    }
    // Fallback 1: if tools are offered, pretend to call the first one with
    // placeholder values for its required arguments.
    if let Some(tools) = native_tools {
        if let Some(first_tool) = tools.first() {
            let tool_name = first_tool
                .get("name")
                .or_else(|| first_tool.get("function").and_then(|f| f.get("name")))
                .and_then(|n| n.as_str())
                .unwrap_or("unknown");
            let mock_args = mock_required_args(first_tool);
            let mut result = LlmResult {
                text: String::new(),
                tool_calls: vec![serde_json::json!({
                    "id": "mock_call_1",
                    "type": "tool_call",
                    "name": tool_name,
                    "arguments": mock_args
                })],
                // Fake token accounting: input approximates prompt length.
                input_tokens: prompt_text.len() as i64,
                output_tokens: 20,
                cache_read_tokens: 0,
                cache_write_tokens: 0,
                model: model.to_string(),
                provider: "mock".to_string(),
                thinking: None,
                stop_reason: None,
                blocks: vec![serde_json::json!({
                    "type": "tool_call",
                    "id": "mock_call_1",
                    "name": tool_name,
                    "arguments": mock_args,
                    "visibility": "internal",
                })],
            };
            if cache {
                apply_mock_prompt_cache(&mut result, &cache_key);
            }
            return Ok(result);
        }
    }
    // Fallback 2: canned prose. When the system prompt mentions <done>,
    // wrap the prose and emit the ##DONE## completion marker.
    let tagged_done = system.is_some_and(|s| s.contains("<done>"));
    let prose_body = if prompt_text.is_empty() {
        "Mock LLM response".to_string()
    } else {
        let word_count = prompt_text.split_whitespace().count();
        format!(
            "Mock response to {word_count}-word prompt: {}",
            prompt_text.chars().take(100).collect::<String>()
        )
    };
    let response = if tagged_done {
        format!("<assistant_prose>{prose_body}</assistant_prose>\n<done>##DONE##</done>")
    } else {
        prose_body
    };
    let mut result = LlmResult {
        text: response.clone(),
        tool_calls: vec![],
        input_tokens: prompt_text.len() as i64,
        output_tokens: 30,
        cache_read_tokens: 0,
        cache_write_tokens: 0,
        model: model.to_string(),
        provider: "mock".to_string(),
        thinking: None,
        stop_reason: None,
        blocks: vec![serde_json::json!({
            "type": "output_text",
            "text": response,
            "visibility": "public",
        })],
    };
    if cache {
        apply_mock_prompt_cache(&mut result, &cache_key);
    }
    Ok(result)
}
/// Set this thread's tool recording/replay mode.
pub fn set_tool_recording_mode(mode: ToolRecordingMode) {
    TOOL_RECORDING_MODE.with(|cell| cell.replace(mode));
}
/// Current tool recording/replay mode for this thread.
pub(crate) fn get_tool_recording_mode() -> ToolRecordingMode {
    TOOL_RECORDING_MODE.with(|cell| *cell.borrow())
}
/// Append one tool-call record to this thread's recording log.
pub(crate) fn record_tool_call(record: ToolCallRecord) {
    TOOL_RECORDINGS.with(|log| log.borrow_mut().push(record));
}
/// Take all accumulated tool-call recordings, leaving the log empty.
pub fn drain_tool_recordings() -> Vec<ToolCallRecord> {
    TOOL_RECORDINGS.with(|log| log.take())
}
/// Replace this thread's tool replay fixtures with `records`.
pub fn load_tool_replay_fixtures(records: Vec<ToolCallRecord>) {
    TOOL_REPLAY_FIXTURES.with(|fixtures| fixtures.replace(records));
}
/// Look up a replay fixture matching this tool invocation: the tool name
/// must match and the stored args hash must equal the hash of `args`.
pub(crate) fn find_tool_replay_fixture(
    tool_name: &str,
    args: &serde_json::Value,
) -> Option<ToolCallRecord> {
    let expected_hash = crate::orchestration::tool_fixture_hash(tool_name, args);
    TOOL_REPLAY_FIXTURES.with(|fixtures| {
        fixtures
            .borrow()
            .iter()
            .find(|record| record.tool_name == tool_name && record.args_hash == expected_hash)
            .cloned()
    })
}