1use std::cell::RefCell;
2
3use super::api::LlmResult;
4use crate::orchestration::ToolCallRecord;
5use crate::value::{ErrorCategory, VmError};
6
/// Replay mode for LLM calls, stored per thread.
///
/// Set with `set_replay_mode` and read with `get_replay_mode`. The
/// thread-local starting value is `Off`.
///
/// NOTE(review): the code that acts on `Record`/`Replay` is not in this file.
/// Presumably `Record` persists responses via `save_fixture` and `Replay`
/// serves them via `load_fixture` — confirm at the call sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LlmReplayMode {
    /// Recording and replay disabled (the thread-local default).
    Off,
    /// Recording mode.
    Record,
    /// Replay mode.
    Replay,
}
14
/// Mode switch for tool-call recording and replay, stored per thread.
///
/// Set with `set_tool_recording_mode` and read with
/// `get_tool_recording_mode`. The thread-local starting value is `Off`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ToolRecordingMode {
    /// Neither recording nor replaying tool calls (the default).
    Off,
    /// Recording mode; recorded calls accumulate via `record_tool_call`.
    Record,
    /// Replay mode; fixtures are looked up via `find_tool_replay_fixture`.
    Replay,
}
22
23#[derive(Clone)]
28pub(crate) struct MockError {
29 pub category: ErrorCategory,
30 pub message: String,
31}
32
33pub(crate) struct LlmMock {
34 pub text: String,
35 pub tool_calls: Vec<serde_json::Value>,
36 pub match_pattern: Option<String>, pub input_tokens: Option<i64>,
38 pub output_tokens: Option<i64>,
39 pub thinking: Option<String>,
40 pub stop_reason: Option<String>,
41 pub model: String,
42 pub error: Option<MockError>,
45}
46
47#[derive(Clone)]
48pub(crate) struct LlmMockCall {
49 pub messages: Vec<serde_json::Value>,
50 pub system: Option<String>,
51 pub tools: Option<Vec<serde_json::Value>>,
52}
53
54thread_local! {
55 static LLM_REPLAY_MODE: RefCell<LlmReplayMode> = const { RefCell::new(LlmReplayMode::Off) };
56 static LLM_FIXTURE_DIR: RefCell<String> = const { RefCell::new(String::new()) };
57 static TOOL_RECORDING_MODE: RefCell<ToolRecordingMode> = const { RefCell::new(ToolRecordingMode::Off) };
58 static TOOL_RECORDINGS: RefCell<Vec<ToolCallRecord>> = const { RefCell::new(Vec::new()) };
59 static TOOL_REPLAY_FIXTURES: RefCell<Vec<ToolCallRecord>> = const { RefCell::new(Vec::new()) };
60 static LLM_MOCKS: RefCell<Vec<LlmMock>> = const { RefCell::new(Vec::new()) };
61 static LLM_MOCK_CALLS: RefCell<Vec<LlmMockCall>> = const { RefCell::new(Vec::new()) };
62}
63
64pub(crate) fn push_llm_mock(mock: LlmMock) {
65 LLM_MOCKS.with(|v| v.borrow_mut().push(mock));
66}
67
68pub(crate) fn get_llm_mock_calls() -> Vec<LlmMockCall> {
69 LLM_MOCK_CALLS.with(|v| v.borrow().clone())
70}
71
72pub(crate) fn reset_llm_mock_state() {
73 LLM_MOCKS.with(|v| v.borrow_mut().clear());
74 LLM_MOCK_CALLS.with(|v| v.borrow_mut().clear());
75}
76
77fn record_llm_mock_call(
78 messages: &[serde_json::Value],
79 system: Option<&str>,
80 native_tools: Option<&[serde_json::Value]>,
81) {
82 LLM_MOCK_CALLS.with(|v| {
83 v.borrow_mut().push(LlmMockCall {
84 messages: messages.to_vec(),
85 system: system.map(|s| s.to_string()),
86 tools: native_tools.map(|t| t.to_vec()),
87 });
88 });
89}
90
91fn build_mock_result(mock: &LlmMock, last_msg_len: usize) -> LlmResult {
93 let mut blocks = Vec::new();
94
95 if !mock.text.is_empty() {
96 blocks.push(serde_json::json!({
97 "type": "output_text",
98 "text": mock.text,
99 "visibility": "public",
100 }));
101 }
102
103 let mut tool_calls = Vec::new();
104 for (i, tc) in mock.tool_calls.iter().enumerate() {
105 let id = format!("mock_call_{}", i + 1);
106 let name = tc.get("name").and_then(|n| n.as_str()).unwrap_or("unknown");
107 let arguments = tc
108 .get("arguments")
109 .cloned()
110 .unwrap_or(serde_json::json!({}));
111 tool_calls.push(serde_json::json!({
112 "id": id,
113 "type": "tool_call",
114 "name": name,
115 "arguments": arguments,
116 }));
117 blocks.push(serde_json::json!({
118 "type": "tool_call",
119 "id": id,
120 "name": name,
121 "arguments": arguments,
122 "visibility": "internal",
123 }));
124 }
125
126 LlmResult {
127 text: mock.text.clone(),
128 tool_calls,
129 input_tokens: mock.input_tokens.unwrap_or(last_msg_len as i64),
130 output_tokens: mock.output_tokens.unwrap_or(30),
131 cache_read_tokens: 0,
132 cache_write_tokens: 0,
133 model: mock.model.clone(),
134 provider: "mock".to_string(),
135 thinking: mock.thinking.clone(),
136 stop_reason: mock.stop_reason.clone(),
137 blocks,
138 }
139}
140
/// Matches `text` against a glob `pattern` where `*` matches any run of
/// characters, including none. No other wildcard syntax is recognized.
///
/// Matching rules:
/// - A pattern without `*` must equal `text` exactly.
/// - Otherwise, the segment before the first `*` must be a prefix of `text`.
/// - The segment after the last `*` must be a suffix of what remains.
/// - Each segment in between must occur, in order, at its leftmost position.
fn mock_glob_match(pattern: &str, text: &str) -> bool {
    if pattern == "*" {
        return true;
    }
    if !pattern.contains('*') {
        return pattern == text;
    }

    let segments: Vec<&str> = pattern.split('*').collect();
    let final_index = segments.len() - 1;
    let mut rest = text;

    for (index, segment) in segments.into_iter().enumerate() {
        if segment.is_empty() {
            continue;
        }
        rest = if index == 0 {
            // Anchored at the start of the text.
            let Some(after) = rest.strip_prefix(segment) else {
                return false;
            };
            after
        } else if index == final_index {
            // Anchored at the end of whatever has not been consumed yet.
            if !rest.ends_with(segment) {
                return false;
            }
            ""
        } else {
            // Floating segment: consume through its leftmost occurrence.
            let Some(at) = rest.find(segment) else {
                return false;
            };
            &rest[at + segment.len()..]
        };
    }

    true
}
175
176fn mock_error_to_vm_error(err: &MockError) -> VmError {
180 VmError::CategorizedError {
181 message: err.message.clone(),
182 category: err.category.clone(),
183 }
184}
185
186fn try_match_mock(last_msg: &str) -> Option<Result<LlmResult, VmError>> {
190 LLM_MOCKS.with(|mocks| {
191 let mut mocks = mocks.borrow_mut();
192
193 if let Some(idx) = mocks.iter().position(|m| m.match_pattern.is_none()) {
195 let mock = mocks.remove(idx);
196 return Some(match &mock.error {
197 Some(err) => Err(mock_error_to_vm_error(err)),
198 None => Ok(build_mock_result(&mock, last_msg.len())),
199 });
200 }
201
202 for mock in mocks.iter().rev() {
204 if let Some(ref pattern) = mock.match_pattern {
205 if mock_glob_match(pattern, last_msg) {
206 return Some(match &mock.error {
207 Some(err) => Err(mock_error_to_vm_error(err)),
208 None => Ok(build_mock_result(mock, last_msg.len())),
209 });
210 }
211 }
212 }
213
214 None
215 })
216}
217
218pub fn set_replay_mode(mode: LlmReplayMode, fixture_dir: &str) {
220 LLM_REPLAY_MODE.with(|v| *v.borrow_mut() = mode);
221 LLM_FIXTURE_DIR.with(|v| *v.borrow_mut() = fixture_dir.to_string());
222}
223
224pub(crate) fn get_replay_mode() -> LlmReplayMode {
225 LLM_REPLAY_MODE.with(|v| *v.borrow())
226}
227
228pub(crate) fn get_fixture_dir() -> String {
229 LLM_FIXTURE_DIR.with(|v| v.borrow().clone())
230}
231
232pub(crate) fn fixture_hash(
234 model: &str,
235 messages: &[serde_json::Value],
236 system: Option<&str>,
237) -> String {
238 use std::hash::{Hash, Hasher};
239 let mut hasher = std::collections::hash_map::DefaultHasher::new();
240 model.hash(&mut hasher);
241 serde_json::to_string(messages)
243 .unwrap_or_default()
244 .hash(&mut hasher);
245 system.hash(&mut hasher);
246 format!("{:016x}", hasher.finish())
247}
248
249pub(crate) fn save_fixture(hash: &str, result: &LlmResult) {
250 let dir = get_fixture_dir();
251 if dir.is_empty() {
252 return;
253 }
254 let _ = std::fs::create_dir_all(&dir);
255 let path = format!("{dir}/{hash}.json");
256 let json = serde_json::json!({
257 "text": result.text,
258 "tool_calls": result.tool_calls,
259 "input_tokens": result.input_tokens,
260 "output_tokens": result.output_tokens,
261 "model": result.model,
262 "provider": result.provider,
263 "blocks": result.blocks,
264 });
265 let _ = std::fs::write(
266 &path,
267 serde_json::to_string_pretty(&json).unwrap_or_default(),
268 );
269}
270
271pub(crate) fn load_fixture(hash: &str) -> Option<LlmResult> {
272 let dir = get_fixture_dir();
273 if dir.is_empty() {
274 return None;
275 }
276 let path = format!("{dir}/{hash}.json");
277 let content = std::fs::read_to_string(&path).ok()?;
278 let json: serde_json::Value = serde_json::from_str(&content).ok()?;
279 Some(LlmResult {
280 text: json["text"].as_str().unwrap_or("").to_string(),
281 tool_calls: json["tool_calls"].as_array().cloned().unwrap_or_default(),
282 input_tokens: json["input_tokens"].as_i64().unwrap_or(0),
283 output_tokens: json["output_tokens"].as_i64().unwrap_or(0),
284 cache_read_tokens: json["cache_read_tokens"].as_i64().unwrap_or(0),
285 cache_write_tokens: json["cache_write_tokens"].as_i64().unwrap_or(0),
286 model: json["model"].as_str().unwrap_or("").to_string(),
287 provider: json["provider"].as_str().unwrap_or("mock").to_string(),
288 thinking: json["thinking"].as_str().map(|s| s.to_string()),
289 stop_reason: json["stop_reason"].as_str().map(|s| s.to_string()),
290 blocks: json["blocks"].as_array().cloned().unwrap_or_default(),
291 })
292}
293
294fn mock_required_args(tool_schema: &serde_json::Value) -> serde_json::Value {
298 let mut args = serde_json::Map::new();
299 let input_schema = tool_schema
303 .get("input_schema")
304 .or_else(|| tool_schema.get("inputSchema"))
305 .or_else(|| {
306 tool_schema
307 .get("function")
308 .and_then(|f| f.get("parameters"))
309 })
310 .or_else(|| tool_schema.get("parameters"));
311 let Some(schema) = input_schema else {
312 return serde_json::Value::Object(args);
313 };
314 let required: std::collections::BTreeSet<String> = schema
315 .get("required")
316 .and_then(|r| r.as_array())
317 .map(|arr| {
318 arr.iter()
319 .filter_map(|v| v.as_str().map(|s| s.to_string()))
320 .collect()
321 })
322 .unwrap_or_default();
323 if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) {
324 for (name, prop) in props {
325 if !required.contains(name) {
326 continue;
327 }
328 let ty = prop
329 .get("type")
330 .and_then(|t| t.as_str())
331 .unwrap_or("string");
332 let placeholder = match ty {
333 "integer" => serde_json::json!(0),
334 "number" => serde_json::json!(0.0),
335 "boolean" => serde_json::json!(false),
336 "array" => serde_json::json!([]),
337 "object" => serde_json::json!({}),
338 _ => serde_json::json!(""),
339 };
340 args.insert(name.clone(), placeholder);
341 }
342 }
343 serde_json::Value::Object(args)
344}
345
346pub(crate) fn mock_llm_response(
351 messages: &[serde_json::Value],
352 system: Option<&str>,
353 native_tools: Option<&[serde_json::Value]>,
354) -> Result<LlmResult, VmError> {
355 record_llm_mock_call(messages, system, native_tools);
356
357 let last_msg = messages
358 .last()
359 .and_then(|m| m.get("content"))
360 .and_then(|c| c.as_str())
361 .unwrap_or("");
362
363 if let Some(matched) = try_match_mock(last_msg) {
364 return matched;
365 }
366
367 if let Some(tools) = native_tools {
370 if let Some(first_tool) = tools.first() {
371 let tool_name = first_tool
372 .get("name")
373 .or_else(|| first_tool.get("function").and_then(|f| f.get("name")))
374 .and_then(|n| n.as_str())
375 .unwrap_or("unknown");
376 let mock_args = mock_required_args(first_tool);
377 return Ok(LlmResult {
378 text: String::new(),
379 tool_calls: vec![serde_json::json!({
380 "id": "mock_call_1",
381 "type": "tool_call",
382 "name": tool_name,
383 "arguments": mock_args
384 })],
385 input_tokens: last_msg.len() as i64,
386 output_tokens: 20,
387 cache_read_tokens: 0,
388 cache_write_tokens: 0,
389 model: "mock".to_string(),
390 provider: "mock".to_string(),
391 thinking: None,
392 stop_reason: None,
393 blocks: vec![serde_json::json!({
394 "type": "tool_call",
395 "id": "mock_call_1",
396 "name": tool_name,
397 "arguments": mock_args,
398 "visibility": "internal",
399 })],
400 });
401 }
402 }
403
404 let done_block = if system.is_some_and(|s| s.contains("##DONE##")) {
407 "\n<done>##DONE##</done>"
408 } else {
409 ""
410 };
411
412 let prose_body = if last_msg.is_empty() {
413 "Mock LLM response".to_string()
414 } else {
415 let word_count = last_msg.split_whitespace().count();
416 format!(
417 "Mock response to {word_count}-word prompt: {}",
418 last_msg.chars().take(100).collect::<String>()
419 )
420 };
421 let response = format!("<assistant_prose>{prose_body}</assistant_prose>{done_block}");
422
423 Ok(LlmResult {
424 text: response.clone(),
425 tool_calls: vec![],
426 input_tokens: last_msg.len() as i64,
427 output_tokens: 30,
428 cache_read_tokens: 0,
429 cache_write_tokens: 0,
430 model: "mock".to_string(),
431 provider: "mock".to_string(),
432 thinking: None,
433 stop_reason: None,
434 blocks: vec![serde_json::json!({
435 "type": "output_text",
436 "text": response,
437 "visibility": "public",
438 })],
439 })
440}
441
442pub fn set_tool_recording_mode(mode: ToolRecordingMode) {
443 TOOL_RECORDING_MODE.with(|v| *v.borrow_mut() = mode);
444}
445
446pub(crate) fn get_tool_recording_mode() -> ToolRecordingMode {
447 TOOL_RECORDING_MODE.with(|v| *v.borrow())
448}
449
450pub(crate) fn record_tool_call(record: ToolCallRecord) {
452 TOOL_RECORDINGS.with(|v| v.borrow_mut().push(record));
453}
454
455pub fn drain_tool_recordings() -> Vec<ToolCallRecord> {
457 TOOL_RECORDINGS.with(|v| std::mem::take(&mut *v.borrow_mut()))
458}
459
460pub fn load_tool_replay_fixtures(records: Vec<ToolCallRecord>) {
462 TOOL_REPLAY_FIXTURES.with(|v| *v.borrow_mut() = records);
463}
464
465pub(crate) fn find_tool_replay_fixture(
467 tool_name: &str,
468 args: &serde_json::Value,
469) -> Option<ToolCallRecord> {
470 let hash = crate::orchestration::tool_fixture_hash(tool_name, args);
471 TOOL_REPLAY_FIXTURES.with(|v| {
472 v.borrow()
473 .iter()
474 .find(|r| r.tool_name == tool_name && r.args_hash == hash)
475 .cloned()
476 })
477}