1use std::cell::RefCell;
2
3use super::api::LlmResult;
4use crate::orchestration::ToolCallRecord;
5use crate::value::{ErrorCategory, VmError};
6
/// Selects how LLM fixture replay behaves for the current thread.
///
/// The value is stored by `set_replay_mode` and read back by
/// `get_replay_mode`; how each variant is acted upon is decided by the
/// callers of those functions.
///
/// `Eq` is derived alongside `PartialEq` so this enum matches the other mode
/// enums in this file (`ToolRecordingMode`, `CliLlmMockMode`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LlmReplayMode {
    /// Default mode: the thread-local starts in this state.
    Off,
    /// Recording mode.
    Record,
    /// Replay mode.
    Replay,
}
14
/// Selects how tool-call recording behaves for the current thread.
///
/// Stored by `set_tool_recording_mode` and read back by
/// `get_tool_recording_mode`; this file does not itself branch on the value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolRecordingMode {
    /// Default mode: the thread-local starts in this state.
    Off,
    /// Recording mode.
    Record,
    /// Replay mode.
    Replay,
}
22
/// Internal state of the CLI-driven LLM mock layer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CliLlmMockMode {
    /// No CLI mocks are active (initial state, and the state after
    /// `clear_cli_llm_mock_mode` / `reset_llm_mock_state`).
    Off,
    /// Set by `install_cli_llm_mocks`; while active, a prompt that matches no
    /// mock makes `mock_llm_response` return an error.
    Replay,
    /// Set by `enable_cli_llm_mock_recording`; while active,
    /// `record_cli_llm_result` stores each result it is given.
    Record,
}
29
/// Error payload a mock can carry instead of a successful response.
///
/// Converted into `VmError::CategorizedError` by `mock_error_to_vm_error`.
#[derive(Clone)]
pub struct MockError {
    /// Category copied into the resulting `VmError`.
    pub category: ErrorCategory,
    /// Message copied into the resulting `VmError`.
    pub message: String,
}
39
40#[derive(Clone)]
41pub struct LlmMock {
42 pub text: String,
43 pub tool_calls: Vec<serde_json::Value>,
44 pub match_pattern: Option<String>, pub consume_on_match: bool,
46 pub input_tokens: Option<i64>,
47 pub output_tokens: Option<i64>,
48 pub cache_read_tokens: Option<i64>,
49 pub cache_write_tokens: Option<i64>,
50 pub thinking: Option<String>,
51 pub stop_reason: Option<String>,
52 pub model: String,
53 pub provider: Option<String>,
54 pub blocks: Option<Vec<serde_json::Value>>,
55 pub error: Option<MockError>,
58}
59
/// Snapshot of the arguments of one `mock_llm_response` invocation, as
/// captured by `record_llm_mock_call`.
#[derive(Clone)]
pub(crate) struct LlmMockCall {
    /// Copy of the messages passed to the mock.
    pub messages: Vec<serde_json::Value>,
    /// Copy of the system prompt, if one was supplied.
    pub system: Option<String>,
    /// Copy of the native tool schemas, if any were supplied.
    pub tools: Option<Vec<serde_json::Value>>,
}
66
// Per-thread mock/replay state. Every cell is thread-local, so state set on
// one thread is not visible from another.
thread_local! {
    // Replay mode written by `set_replay_mode`.
    static LLM_REPLAY_MODE: RefCell<LlmReplayMode> = const { RefCell::new(LlmReplayMode::Off) };
    // Fixture directory used by `save_fixture` / `load_fixture`; empty disables both.
    static LLM_FIXTURE_DIR: RefCell<String> = const { RefCell::new(String::new()) };
    // Tool recording mode written by `set_tool_recording_mode`.
    static TOOL_RECORDING_MODE: RefCell<ToolRecordingMode> = const { RefCell::new(ToolRecordingMode::Off) };
    // Records appended by `record_tool_call`, handed out by `drain_tool_recordings`.
    static TOOL_RECORDINGS: RefCell<Vec<ToolCallRecord>> = const { RefCell::new(Vec::new()) };
    // Fixtures installed by `load_tool_replay_fixtures`, searched by `find_tool_replay_fixture`.
    static TOOL_REPLAY_FIXTURES: RefCell<Vec<ToolCallRecord>> = const { RefCell::new(Vec::new()) };
    // Mock queue fed by `push_llm_mock`.
    static LLM_MOCKS: RefCell<Vec<LlmMock>> = const { RefCell::new(Vec::new()) };
    // Current CLI mock mode (off / replay / record).
    static CLI_LLM_MOCK_MODE: RefCell<CliLlmMockMode> = const { RefCell::new(CliLlmMockMode::Off) };
    // Mock queue installed by `install_cli_llm_mocks`; consulted before LLM_MOCKS.
    static CLI_LLM_MOCKS: RefCell<Vec<LlmMock>> = const { RefCell::new(Vec::new()) };
    // Results captured by `record_cli_llm_result` while in record mode.
    static CLI_LLM_RECORDINGS: RefCell<Vec<LlmMock>> = const { RefCell::new(Vec::new()) };
    // Log of every call made to `mock_llm_response`.
    static LLM_MOCK_CALLS: RefCell<Vec<LlmMockCall>> = const { RefCell::new(Vec::new()) };
}
79
80pub(crate) fn push_llm_mock(mock: LlmMock) {
81 LLM_MOCKS.with(|v| v.borrow_mut().push(mock));
82}
83
84pub(crate) fn get_llm_mock_calls() -> Vec<LlmMockCall> {
85 LLM_MOCK_CALLS.with(|v| v.borrow().clone())
86}
87
88pub(crate) fn reset_llm_mock_state() {
89 LLM_MOCKS.with(|v| v.borrow_mut().clear());
90 CLI_LLM_MOCK_MODE.with(|v| *v.borrow_mut() = CliLlmMockMode::Off);
91 CLI_LLM_MOCKS.with(|v| v.borrow_mut().clear());
92 CLI_LLM_RECORDINGS.with(|v| v.borrow_mut().clear());
93 LLM_MOCK_CALLS.with(|v| v.borrow_mut().clear());
94}
95
96pub fn clear_cli_llm_mock_mode() {
97 CLI_LLM_MOCK_MODE.with(|v| *v.borrow_mut() = CliLlmMockMode::Off);
98 CLI_LLM_MOCKS.with(|v| v.borrow_mut().clear());
99 CLI_LLM_RECORDINGS.with(|v| v.borrow_mut().clear());
100}
101
102pub fn install_cli_llm_mocks(mocks: Vec<LlmMock>) {
103 CLI_LLM_MOCK_MODE.with(|v| *v.borrow_mut() = CliLlmMockMode::Replay);
104 CLI_LLM_MOCKS.with(|v| *v.borrow_mut() = mocks);
105 CLI_LLM_RECORDINGS.with(|v| v.borrow_mut().clear());
106}
107
108pub fn enable_cli_llm_mock_recording() {
109 CLI_LLM_MOCK_MODE.with(|v| *v.borrow_mut() = CliLlmMockMode::Record);
110 CLI_LLM_MOCKS.with(|v| v.borrow_mut().clear());
111 CLI_LLM_RECORDINGS.with(|v| v.borrow_mut().clear());
112}
113
114pub fn take_cli_llm_recordings() -> Vec<LlmMock> {
115 CLI_LLM_RECORDINGS.with(|v| std::mem::take(&mut *v.borrow_mut()))
116}
117
118pub(crate) fn cli_llm_mock_replay_active() -> bool {
119 CLI_LLM_MOCK_MODE.with(|v| *v.borrow() == CliLlmMockMode::Replay)
120}
121
122fn record_llm_mock_call(
123 messages: &[serde_json::Value],
124 system: Option<&str>,
125 native_tools: Option<&[serde_json::Value]>,
126) {
127 LLM_MOCK_CALLS.with(|v| {
128 v.borrow_mut().push(LlmMockCall {
129 messages: messages.to_vec(),
130 system: system.map(|s| s.to_string()),
131 tools: native_tools.map(|t| t.to_vec()),
132 });
133 });
134}
135
136fn build_mock_result(mock: &LlmMock, last_msg_len: usize) -> LlmResult {
138 let (tool_calls, blocks) = if let Some(blocks) = &mock.blocks {
139 (mock.tool_calls.clone(), blocks.clone())
140 } else {
141 let mut blocks = Vec::new();
142
143 if !mock.text.is_empty() {
144 blocks.push(serde_json::json!({
145 "type": "output_text",
146 "text": mock.text,
147 "visibility": "public",
148 }));
149 }
150
151 let mut tool_calls = Vec::new();
152 for (i, tc) in mock.tool_calls.iter().enumerate() {
153 let id = format!("mock_call_{}", i + 1);
154 let name = tc.get("name").and_then(|n| n.as_str()).unwrap_or("unknown");
155 let arguments = tc
156 .get("arguments")
157 .cloned()
158 .unwrap_or(serde_json::json!({}));
159 tool_calls.push(serde_json::json!({
160 "id": id,
161 "type": "tool_call",
162 "name": name,
163 "arguments": arguments,
164 }));
165 blocks.push(serde_json::json!({
166 "type": "tool_call",
167 "id": id,
168 "name": name,
169 "arguments": arguments,
170 "visibility": "internal",
171 }));
172 }
173
174 (tool_calls, blocks)
175 };
176
177 LlmResult {
178 text: mock.text.clone(),
179 tool_calls,
180 input_tokens: mock.input_tokens.unwrap_or(last_msg_len as i64),
181 output_tokens: mock.output_tokens.unwrap_or(30),
182 cache_read_tokens: mock.cache_read_tokens.unwrap_or(0),
183 cache_write_tokens: mock.cache_write_tokens.unwrap_or(0),
184 model: mock.model.clone(),
185 provider: mock.provider.clone().unwrap_or_else(|| "mock".to_string()),
186 thinking: mock.thinking.clone(),
187 stop_reason: mock.stop_reason.clone(),
188 blocks,
189 }
190}
191
/// Matches `text` against a glob `pattern` in which `*` is the only
/// wildcard and stands for any (possibly empty) run of characters.
///
/// A pattern without `*` must equal `text` exactly. Otherwise the literal
/// before the first `*` must be a prefix, the literal after the last `*`
/// must be a suffix of what remains, and every literal in between must occur
/// in order (leftmost occurrence is taken).
fn mock_glob_match(pattern: &str, text: &str) -> bool {
    if pattern == "*" {
        return true;
    }
    if !pattern.contains('*') {
        return pattern == text;
    }

    let segments: Vec<&str> = pattern.split('*').collect();
    let last_index = segments.len() - 1;
    let mut rest = text;

    for (index, segment) in segments.iter().copied().enumerate() {
        // Empty segments come from leading, trailing or doubled `*`.
        if segment.is_empty() {
            continue;
        }
        let advanced = if index == 0 {
            rest.strip_prefix(segment)
        } else if index == last_index {
            // The final literal is anchored to the end; nothing is left after it.
            if rest.ends_with(segment) {
                Some("")
            } else {
                None
            }
        } else {
            rest.find(segment).map(|at| &rest[at + segment.len()..])
        };
        match advanced {
            Some(tail) => rest = tail,
            None => return false,
        }
    }
    true
}
226
227fn collect_mock_match_strings(value: &serde_json::Value, out: &mut Vec<String>) {
228 match value {
229 serde_json::Value::String(text) if !text.is_empty() => out.push(text.clone()),
230 serde_json::Value::String(_) => {}
231 serde_json::Value::Array(items) => {
232 for item in items {
233 collect_mock_match_strings(item, out);
234 }
235 }
236 serde_json::Value::Object(map) => {
237 for value in map.values() {
238 collect_mock_match_strings(value, out);
239 }
240 }
241 _ => {}
242 }
243}
244
245fn mock_match_text(messages: &[serde_json::Value]) -> String {
246 let mut parts = Vec::new();
247 for message in messages {
248 collect_mock_match_strings(message, &mut parts);
249 }
250 parts.join("\n")
251}
252
253fn mock_last_prompt_text(messages: &[serde_json::Value]) -> String {
254 for message in messages.iter().rev() {
255 let Some(content) = message.get("content") else {
256 continue;
257 };
258 let mut parts = Vec::new();
259 collect_mock_match_strings(content, &mut parts);
260 let text = parts.join("\n");
261 if !text.trim().is_empty() {
262 return text;
263 }
264 }
265 String::new()
266}
267
268fn mock_error_to_vm_error(err: &MockError) -> VmError {
272 VmError::CategorizedError {
273 message: err.message.clone(),
274 category: err.category.clone(),
275 }
276}
277
278fn try_match_mock_queue(
282 mocks: &mut Vec<LlmMock>,
283 match_text: &str,
284) -> Option<Result<LlmResult, VmError>> {
285 if let Some(idx) = mocks.iter().position(|m| m.match_pattern.is_none()) {
286 let mock = mocks.remove(idx);
287 return Some(match &mock.error {
288 Some(err) => Err(mock_error_to_vm_error(err)),
289 None => Ok(build_mock_result(&mock, match_text.len())),
290 });
291 }
292
293 for idx in 0..mocks.len() {
294 let mock = &mocks[idx];
295 if let Some(ref pattern) = mock.match_pattern {
296 if mock_glob_match(pattern, match_text) {
297 if mock.consume_on_match {
298 let mock = mocks.remove(idx);
299 return Some(match &mock.error {
300 Some(err) => Err(mock_error_to_vm_error(err)),
301 None => Ok(build_mock_result(&mock, match_text.len())),
302 });
303 }
304 return Some(match &mock.error {
305 Some(err) => Err(mock_error_to_vm_error(err)),
306 None => Ok(build_mock_result(mock, match_text.len())),
307 });
308 }
309 }
310 }
311
312 None
313}
314
315fn try_match_builtin_mock(match_text: &str) -> Option<Result<LlmResult, VmError>> {
316 LLM_MOCKS.with(|mocks| try_match_mock_queue(&mut mocks.borrow_mut(), match_text))
317}
318
319fn try_match_cli_mock(match_text: &str) -> Option<Result<LlmResult, VmError>> {
320 CLI_LLM_MOCKS.with(|mocks| try_match_mock_queue(&mut mocks.borrow_mut(), match_text))
321}
322
323pub(crate) fn record_cli_llm_result(result: &LlmResult) {
324 if !CLI_LLM_MOCK_MODE.with(|mode| *mode.borrow() == CliLlmMockMode::Record) {
325 return;
326 }
327 CLI_LLM_RECORDINGS.with(|recordings| {
328 recordings.borrow_mut().push(LlmMock {
329 text: result.text.clone(),
330 tool_calls: result.tool_calls.clone(),
331 match_pattern: None,
332 consume_on_match: false,
333 input_tokens: Some(result.input_tokens),
334 output_tokens: Some(result.output_tokens),
335 cache_read_tokens: Some(result.cache_read_tokens),
336 cache_write_tokens: Some(result.cache_write_tokens),
337 thinking: result.thinking.clone(),
338 stop_reason: result.stop_reason.clone(),
339 model: result.model.clone(),
340 provider: Some(result.provider.clone()),
341 blocks: Some(result.blocks.clone()),
342 error: None,
343 });
344 });
345}
346
347fn unmatched_cli_prompt_error(match_text: &str) -> VmError {
348 let mut snippet: String = match_text.chars().take(200).collect();
349 if match_text.chars().count() > 200 {
350 snippet.push_str("...");
351 }
352 VmError::Runtime(format!("No --llm-mock fixture matched prompt: {snippet:?}"))
353}
354
355pub fn set_replay_mode(mode: LlmReplayMode, fixture_dir: &str) {
357 LLM_REPLAY_MODE.with(|v| *v.borrow_mut() = mode);
358 LLM_FIXTURE_DIR.with(|v| *v.borrow_mut() = fixture_dir.to_string());
359}
360
361pub(crate) fn get_replay_mode() -> LlmReplayMode {
362 LLM_REPLAY_MODE.with(|v| *v.borrow())
363}
364
365pub(crate) fn get_fixture_dir() -> String {
366 LLM_FIXTURE_DIR.with(|v| v.borrow().clone())
367}
368
369pub(crate) fn fixture_hash(
371 model: &str,
372 messages: &[serde_json::Value],
373 system: Option<&str>,
374) -> String {
375 use std::hash::{Hash, Hasher};
376 let mut hasher = std::collections::hash_map::DefaultHasher::new();
377 model.hash(&mut hasher);
378 serde_json::to_string(messages)
380 .unwrap_or_default()
381 .hash(&mut hasher);
382 system.hash(&mut hasher);
383 format!("{:016x}", hasher.finish())
384}
385
386pub(crate) fn save_fixture(hash: &str, result: &LlmResult) {
387 let dir = get_fixture_dir();
388 if dir.is_empty() {
389 return;
390 }
391 let _ = std::fs::create_dir_all(&dir);
392 let path = format!("{dir}/{hash}.json");
393 let json = serde_json::json!({
394 "text": result.text,
395 "tool_calls": result.tool_calls,
396 "input_tokens": result.input_tokens,
397 "output_tokens": result.output_tokens,
398 "model": result.model,
399 "provider": result.provider,
400 "blocks": result.blocks,
401 });
402 let _ = std::fs::write(
403 &path,
404 serde_json::to_string_pretty(&json).unwrap_or_default(),
405 );
406}
407
408pub(crate) fn load_fixture(hash: &str) -> Option<LlmResult> {
409 let dir = get_fixture_dir();
410 if dir.is_empty() {
411 return None;
412 }
413 let path = format!("{dir}/{hash}.json");
414 let content = std::fs::read_to_string(&path).ok()?;
415 let json: serde_json::Value = serde_json::from_str(&content).ok()?;
416 Some(LlmResult {
417 text: json["text"].as_str().unwrap_or("").to_string(),
418 tool_calls: json["tool_calls"].as_array().cloned().unwrap_or_default(),
419 input_tokens: json["input_tokens"].as_i64().unwrap_or(0),
420 output_tokens: json["output_tokens"].as_i64().unwrap_or(0),
421 cache_read_tokens: json["cache_read_tokens"].as_i64().unwrap_or(0),
422 cache_write_tokens: json["cache_write_tokens"].as_i64().unwrap_or(0),
423 model: json["model"].as_str().unwrap_or("").to_string(),
424 provider: json["provider"].as_str().unwrap_or("mock").to_string(),
425 thinking: json["thinking"].as_str().map(|s| s.to_string()),
426 stop_reason: json["stop_reason"].as_str().map(|s| s.to_string()),
427 blocks: json["blocks"].as_array().cloned().unwrap_or_default(),
428 })
429}
430
431fn mock_required_args(tool_schema: &serde_json::Value) -> serde_json::Value {
435 let mut args = serde_json::Map::new();
436 let input_schema = tool_schema
440 .get("input_schema")
441 .or_else(|| tool_schema.get("inputSchema"))
442 .or_else(|| {
443 tool_schema
444 .get("function")
445 .and_then(|f| f.get("parameters"))
446 })
447 .or_else(|| tool_schema.get("parameters"));
448 let Some(schema) = input_schema else {
449 return serde_json::Value::Object(args);
450 };
451 let required: std::collections::BTreeSet<String> = schema
452 .get("required")
453 .and_then(|r| r.as_array())
454 .map(|arr| {
455 arr.iter()
456 .filter_map(|v| v.as_str().map(|s| s.to_string()))
457 .collect()
458 })
459 .unwrap_or_default();
460 if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) {
461 for (name, prop) in props {
462 if !required.contains(name) {
463 continue;
464 }
465 let ty = prop
466 .get("type")
467 .and_then(|t| t.as_str())
468 .unwrap_or("string");
469 let placeholder = match ty {
470 "integer" => serde_json::json!(0),
471 "number" => serde_json::json!(0.0),
472 "boolean" => serde_json::json!(false),
473 "array" => serde_json::json!([]),
474 "object" => serde_json::json!({}),
475 _ => serde_json::json!(""),
476 };
477 args.insert(name.clone(), placeholder);
478 }
479 }
480 serde_json::Value::Object(args)
481}
482
483pub(crate) fn mock_llm_response(
488 messages: &[serde_json::Value],
489 system: Option<&str>,
490 native_tools: Option<&[serde_json::Value]>,
491) -> Result<LlmResult, VmError> {
492 record_llm_mock_call(messages, system, native_tools);
493
494 let match_text = mock_match_text(messages);
495 let prompt_text = mock_last_prompt_text(messages);
496
497 if let Some(matched) = try_match_cli_mock(&match_text) {
498 return matched;
499 }
500
501 if let Some(matched) = try_match_builtin_mock(&match_text) {
502 return matched;
503 }
504
505 if cli_llm_mock_replay_active() {
506 return Err(unmatched_cli_prompt_error(&match_text));
507 }
508
509 if let Some(tools) = native_tools {
512 if let Some(first_tool) = tools.first() {
513 let tool_name = first_tool
514 .get("name")
515 .or_else(|| first_tool.get("function").and_then(|f| f.get("name")))
516 .and_then(|n| n.as_str())
517 .unwrap_or("unknown");
518 let mock_args = mock_required_args(first_tool);
519 return Ok(LlmResult {
520 text: String::new(),
521 tool_calls: vec![serde_json::json!({
522 "id": "mock_call_1",
523 "type": "tool_call",
524 "name": tool_name,
525 "arguments": mock_args
526 })],
527 input_tokens: prompt_text.len() as i64,
528 output_tokens: 20,
529 cache_read_tokens: 0,
530 cache_write_tokens: 0,
531 model: "mock".to_string(),
532 provider: "mock".to_string(),
533 thinking: None,
534 stop_reason: None,
535 blocks: vec![serde_json::json!({
536 "type": "tool_call",
537 "id": "mock_call_1",
538 "name": tool_name,
539 "arguments": mock_args,
540 "visibility": "internal",
541 })],
542 });
543 }
544 }
545
546 let done_block = if system.is_some_and(|s| s.contains("##DONE##")) {
549 "\n<done>##DONE##</done>"
550 } else {
551 ""
552 };
553
554 let prose_body = if prompt_text.is_empty() {
555 "Mock LLM response".to_string()
556 } else {
557 let word_count = prompt_text.split_whitespace().count();
558 format!(
559 "Mock response to {word_count}-word prompt: {}",
560 prompt_text.chars().take(100).collect::<String>()
561 )
562 };
563 let response = format!("<assistant_prose>{prose_body}</assistant_prose>{done_block}");
564
565 Ok(LlmResult {
566 text: response.clone(),
567 tool_calls: vec![],
568 input_tokens: prompt_text.len() as i64,
569 output_tokens: 30,
570 cache_read_tokens: 0,
571 cache_write_tokens: 0,
572 model: "mock".to_string(),
573 provider: "mock".to_string(),
574 thinking: None,
575 stop_reason: None,
576 blocks: vec![serde_json::json!({
577 "type": "output_text",
578 "text": response,
579 "visibility": "public",
580 })],
581 })
582}
583
584pub fn set_tool_recording_mode(mode: ToolRecordingMode) {
585 TOOL_RECORDING_MODE.with(|v| *v.borrow_mut() = mode);
586}
587
588pub(crate) fn get_tool_recording_mode() -> ToolRecordingMode {
589 TOOL_RECORDING_MODE.with(|v| *v.borrow())
590}
591
592pub(crate) fn record_tool_call(record: ToolCallRecord) {
594 TOOL_RECORDINGS.with(|v| v.borrow_mut().push(record));
595}
596
597pub fn drain_tool_recordings() -> Vec<ToolCallRecord> {
599 TOOL_RECORDINGS.with(|v| std::mem::take(&mut *v.borrow_mut()))
600}
601
602pub fn load_tool_replay_fixtures(records: Vec<ToolCallRecord>) {
604 TOOL_REPLAY_FIXTURES.with(|v| *v.borrow_mut() = records);
605}
606
607pub(crate) fn find_tool_replay_fixture(
609 tool_name: &str,
610 args: &serde_json::Value,
611) -> Option<ToolCallRecord> {
612 let hash = crate::orchestration::tool_fixture_hash(tool_name, args);
613 TOOL_REPLAY_FIXTURES.with(|v| {
614 v.borrow()
615 .iter()
616 .find(|r| r.tool_name == tool_name && r.args_hash == hash)
617 .cloned()
618 })
619}