1use std::cell::RefCell;
2
3use super::api::LlmResult;
4use crate::orchestration::ToolCallRecord;
5use crate::value::{ErrorCategory, VmError};
6
/// Controls whether LLM calls are recorded to or replayed from fixtures.
///
/// `Off` performs normal calls, `Record` captures responses into the fixture
/// directory, and `Replay` serves responses from previously saved fixtures.
// `Eq` added for consistency with `ToolRecordingMode`/`CliLlmMockMode`
// (deriving `PartialEq` without `Eq` on a field-less enum is an oversight).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LlmReplayMode {
    Off,
    Record,
    Replay,
}
14
/// Controls whether tool-call executions are captured or replayed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolRecordingMode {
    /// Tools execute normally with no capture.
    Off,
    /// Tool calls run for real and their results are logged via `record_tool_call`.
    Record,
    /// Tool calls are served from fixtures loaded via `load_tool_replay_fixtures`.
    Replay,
}
22
/// Tracks how the CLI-level `--llm-mock` machinery is configured.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CliLlmMockMode {
    /// No CLI mocking active.
    Off,
    /// Serve responses from installed mocks; unmatched prompts become errors.
    Replay,
    /// Capture each LLM result (via `record_cli_llm_result`) for later
    /// export through `take_cli_llm_recordings`.
    Record,
}
29
/// A scripted error outcome for a mocked LLM call.
#[derive(Clone)]
pub struct MockError {
    /// Category used to build the resulting `VmError::CategorizedError`.
    pub category: ErrorCategory,
    /// Human-readable error text.
    pub message: String,
    /// Optional retry hint; appended to the message as "retry-after: <secs>".
    pub retry_after_ms: Option<u64>,
}
45
46#[derive(Clone)]
47pub struct LlmMock {
48 pub text: String,
49 pub tool_calls: Vec<serde_json::Value>,
50 pub match_pattern: Option<String>, pub consume_on_match: bool,
52 pub input_tokens: Option<i64>,
53 pub output_tokens: Option<i64>,
54 pub cache_read_tokens: Option<i64>,
55 pub cache_write_tokens: Option<i64>,
56 pub thinking: Option<String>,
57 pub stop_reason: Option<String>,
58 pub model: String,
59 pub provider: Option<String>,
60 pub blocks: Option<Vec<serde_json::Value>>,
61 pub error: Option<MockError>,
64}
65
/// A captured record of one call into the mock LLM backend.
#[derive(Clone)]
pub(crate) struct LlmMockCall {
    /// The message array exactly as passed to the mock.
    pub messages: Vec<serde_json::Value>,
    /// Optional system prompt supplied with the call.
    pub system: Option<String>,
    /// Optional native tool schemas offered on the call.
    pub tools: Option<Vec<serde_json::Value>>,
}
72
// Per-thread mock/replay state. Thread-local storage keeps concurrently
// running test threads fully isolated from one another.
thread_local! {
    // Record/replay mode for LLM calls.
    static LLM_REPLAY_MODE: RefCell<LlmReplayMode> = const { RefCell::new(LlmReplayMode::Off) };
    // Directory for LLM fixtures; empty string means disabled.
    static LLM_FIXTURE_DIR: RefCell<String> = const { RefCell::new(String::new()) };
    // Record/replay mode for tool calls.
    static TOOL_RECORDING_MODE: RefCell<ToolRecordingMode> = const { RefCell::new(ToolRecordingMode::Off) };
    // Tool calls captured while recording.
    static TOOL_RECORDINGS: RefCell<Vec<ToolCallRecord>> = const { RefCell::new(Vec::new()) };
    // Fixtures searched during tool replay.
    static TOOL_REPLAY_FIXTURES: RefCell<Vec<ToolCallRecord>> = const { RefCell::new(Vec::new()) };
    // Script-installed (built-in) mock queue.
    static LLM_MOCKS: RefCell<Vec<LlmMock>> = const { RefCell::new(Vec::new()) };
    // CLI `--llm-mock` mode.
    static CLI_LLM_MOCK_MODE: RefCell<CliLlmMockMode> = const { RefCell::new(CliLlmMockMode::Off) };
    // CLI-installed mock queue.
    static CLI_LLM_MOCKS: RefCell<Vec<LlmMock>> = const { RefCell::new(Vec::new()) };
    // LLM results captured while CLI recording is on.
    static CLI_LLM_RECORDINGS: RefCell<Vec<LlmMock>> = const { RefCell::new(Vec::new()) };
    // Log of every call made into the mock backend.
    static LLM_MOCK_CALLS: RefCell<Vec<LlmMockCall>> = const { RefCell::new(Vec::new()) };
}
85
86pub(crate) fn push_llm_mock(mock: LlmMock) {
87 LLM_MOCKS.with(|v| v.borrow_mut().push(mock));
88}
89
90pub(crate) fn get_llm_mock_calls() -> Vec<LlmMockCall> {
91 LLM_MOCK_CALLS.with(|v| v.borrow().clone())
92}
93
94pub(crate) fn reset_llm_mock_state() {
95 LLM_MOCKS.with(|v| v.borrow_mut().clear());
96 CLI_LLM_MOCK_MODE.with(|v| *v.borrow_mut() = CliLlmMockMode::Off);
97 CLI_LLM_MOCKS.with(|v| v.borrow_mut().clear());
98 CLI_LLM_RECORDINGS.with(|v| v.borrow_mut().clear());
99 LLM_MOCK_CALLS.with(|v| v.borrow_mut().clear());
100}
101
102pub fn clear_cli_llm_mock_mode() {
103 CLI_LLM_MOCK_MODE.with(|v| *v.borrow_mut() = CliLlmMockMode::Off);
104 CLI_LLM_MOCKS.with(|v| v.borrow_mut().clear());
105 CLI_LLM_RECORDINGS.with(|v| v.borrow_mut().clear());
106}
107
108pub fn install_cli_llm_mocks(mocks: Vec<LlmMock>) {
109 CLI_LLM_MOCK_MODE.with(|v| *v.borrow_mut() = CliLlmMockMode::Replay);
110 CLI_LLM_MOCKS.with(|v| *v.borrow_mut() = mocks);
111 CLI_LLM_RECORDINGS.with(|v| v.borrow_mut().clear());
112}
113
114pub fn enable_cli_llm_mock_recording() {
115 CLI_LLM_MOCK_MODE.with(|v| *v.borrow_mut() = CliLlmMockMode::Record);
116 CLI_LLM_MOCKS.with(|v| v.borrow_mut().clear());
117 CLI_LLM_RECORDINGS.with(|v| v.borrow_mut().clear());
118}
119
120pub fn take_cli_llm_recordings() -> Vec<LlmMock> {
121 CLI_LLM_RECORDINGS.with(|v| std::mem::take(&mut *v.borrow_mut()))
122}
123
124pub(crate) fn cli_llm_mock_replay_active() -> bool {
125 CLI_LLM_MOCK_MODE.with(|v| *v.borrow() == CliLlmMockMode::Replay)
126}
127
128fn record_llm_mock_call(
129 messages: &[serde_json::Value],
130 system: Option<&str>,
131 native_tools: Option<&[serde_json::Value]>,
132) {
133 LLM_MOCK_CALLS.with(|v| {
134 v.borrow_mut().push(LlmMockCall {
135 messages: messages.to_vec(),
136 system: system.map(|s| s.to_string()),
137 tools: native_tools.map(|t| t.to_vec()),
138 });
139 });
140}
141
/// Build an `LlmResult` from a scripted mock.
///
/// If the mock carries pre-built `blocks`, both blocks and tool calls are
/// passed through verbatim (recorded mocks already store normalized tool
/// calls). Otherwise an `output_text` block plus normalized tool-call
/// blocks are synthesized from `text` and `tool_calls`.
fn build_mock_result(mock: &LlmMock, last_msg_len: usize) -> LlmResult {
    let (tool_calls, blocks) = if let Some(blocks) = &mock.blocks {
        (mock.tool_calls.clone(), blocks.clone())
    } else {
        let mut blocks = Vec::new();

        if !mock.text.is_empty() {
            blocks.push(serde_json::json!({
                "type": "output_text",
                "text": mock.text,
                "visibility": "public",
            }));
        }

        // Normalize each scripted {name, arguments} pair into a full
        // tool-call record with a synthetic sequential id; the same data is
        // mirrored into an internal content block.
        let mut tool_calls = Vec::new();
        for (i, tc) in mock.tool_calls.iter().enumerate() {
            let id = format!("mock_call_{}", i + 1);
            let name = tc.get("name").and_then(|n| n.as_str()).unwrap_or("unknown");
            let arguments = tc
                .get("arguments")
                .cloned()
                .unwrap_or(serde_json::json!({}));
            tool_calls.push(serde_json::json!({
                "id": id,
                "type": "tool_call",
                "name": name,
                "arguments": arguments,
            }));
            blocks.push(serde_json::json!({
                "type": "tool_call",
                "id": id,
                "name": name,
                "arguments": arguments,
                "visibility": "internal",
            }));
        }

        (tool_calls, blocks)
    };

    LlmResult {
        text: mock.text.clone(),
        tool_calls,
        // Token counts fall back to synthetic values: the caller-supplied
        // prompt length for input, a flat 30 for output, zero for caches.
        input_tokens: mock.input_tokens.unwrap_or(last_msg_len as i64),
        output_tokens: mock.output_tokens.unwrap_or(30),
        cache_read_tokens: mock.cache_read_tokens.unwrap_or(0),
        cache_write_tokens: mock.cache_write_tokens.unwrap_or(0),
        model: mock.model.clone(),
        provider: mock.provider.clone().unwrap_or_else(|| "mock".to_string()),
        thinking: mock.thinking.clone(),
        stop_reason: mock.stop_reason.clone(),
        blocks,
    }
}
197
/// Minimal glob matcher supporting only the `*` wildcard.
///
/// `*` alone matches anything; a pattern without `*` must match exactly.
/// Otherwise the first literal segment is anchored at the start, the last
/// at the end, and interior segments match greedily left-to-right.
fn mock_glob_match(pattern: &str, text: &str) -> bool {
    if pattern == "*" {
        return true;
    }
    if !pattern.contains('*') {
        return pattern == text;
    }
    let segments: Vec<&str> = pattern.split('*').collect();
    let last = segments.len() - 1;
    let mut haystack = text;
    for (i, segment) in segments.iter().enumerate() {
        if segment.is_empty() {
            // Adjacent/leading/trailing '*' produce empty segments; skip them.
            continue;
        }
        if i == 0 {
            // First segment is anchored at the beginning.
            let Some(rest) = haystack.strip_prefix(segment) else {
                return false;
            };
            haystack = rest;
        } else if i == last {
            // Last segment is anchored at the end of what remains.
            if !haystack.ends_with(segment) {
                return false;
            }
            haystack = "";
        } else {
            // Interior segments match at their leftmost occurrence.
            let Some(pos) = haystack.find(segment) else {
                return false;
            };
            haystack = &haystack[pos + segment.len()..];
        }
    }
    true
}
232
233fn collect_mock_match_strings(value: &serde_json::Value, out: &mut Vec<String>) {
234 match value {
235 serde_json::Value::String(text) if !text.is_empty() => out.push(text.clone()),
236 serde_json::Value::String(_) => {}
237 serde_json::Value::Array(items) => {
238 for item in items {
239 collect_mock_match_strings(item, out);
240 }
241 }
242 serde_json::Value::Object(map) => {
243 for value in map.values() {
244 collect_mock_match_strings(value, out);
245 }
246 }
247 _ => {}
248 }
249}
250
251fn mock_match_text(messages: &[serde_json::Value]) -> String {
252 let mut parts = Vec::new();
253 for message in messages {
254 collect_mock_match_strings(message, &mut parts);
255 }
256 parts.join("\n")
257}
258
259fn mock_last_prompt_text(messages: &[serde_json::Value]) -> String {
260 for message in messages.iter().rev() {
261 let Some(content) = message.get("content") else {
262 continue;
263 };
264 let mut parts = Vec::new();
265 collect_mock_match_strings(content, &mut parts);
266 let text = parts.join("\n");
267 if !text.trim().is_empty() {
268 return text;
269 }
270 }
271 String::new()
272}
273
274fn mock_error_to_vm_error(err: &MockError) -> VmError {
278 let message = match err.retry_after_ms {
284 Some(ms) => {
285 let secs = (ms as f64 / 1000.0).max(0.0);
286 let sep = if err.message.is_empty() || err.message.ends_with('\n') {
287 ""
288 } else {
289 "\n"
290 };
291 format!("{}{sep}retry-after: {secs}\n", err.message)
292 }
293 None => err.message.clone(),
294 };
295 VmError::CategorizedError {
296 message,
297 category: err.category.clone(),
298 }
299}
300
/// Try to satisfy a call from a queue of mocks.
///
/// Precedence: the first pattern-less mock in the queue always wins and is
/// consumed; only when none exists are pattern mocks tried in queue order.
/// A pattern mock is removed only when `consume_on_match` is set, so
/// non-consuming mocks can answer repeatedly. A mock with `error` set
/// yields `Err` instead of a result. Returns `None` when nothing matched.
fn try_match_mock_queue(
    mocks: &mut Vec<LlmMock>,
    match_text: &str,
) -> Option<Result<LlmResult, VmError>> {
    if let Some(idx) = mocks.iter().position(|m| m.match_pattern.is_none()) {
        let mock = mocks.remove(idx);
        return Some(match &mock.error {
            Some(err) => Err(mock_error_to_vm_error(err)),
            None => Ok(build_mock_result(&mock, match_text.len())),
        });
    }

    for idx in 0..mocks.len() {
        let mock = &mocks[idx];
        if let Some(ref pattern) = mock.match_pattern {
            if mock_glob_match(pattern, match_text) {
                if mock.consume_on_match {
                    // Re-bind by removing so the mock is used at most once.
                    let mock = mocks.remove(idx);
                    return Some(match &mock.error {
                        Some(err) => Err(mock_error_to_vm_error(err)),
                        None => Ok(build_mock_result(&mock, match_text.len())),
                    });
                }
                return Some(match &mock.error {
                    Some(err) => Err(mock_error_to_vm_error(err)),
                    None => Ok(build_mock_result(mock, match_text.len())),
                });
            }
        }
    }

    None
}
337
338fn try_match_builtin_mock(match_text: &str) -> Option<Result<LlmResult, VmError>> {
339 LLM_MOCKS.with(|mocks| try_match_mock_queue(&mut mocks.borrow_mut(), match_text))
340}
341
342fn try_match_cli_mock(match_text: &str) -> Option<Result<LlmResult, VmError>> {
343 CLI_LLM_MOCKS.with(|mocks| try_match_mock_queue(&mut mocks.borrow_mut(), match_text))
344}
345
346pub(crate) fn record_cli_llm_result(result: &LlmResult) {
347 if !CLI_LLM_MOCK_MODE.with(|mode| *mode.borrow() == CliLlmMockMode::Record) {
348 return;
349 }
350 CLI_LLM_RECORDINGS.with(|recordings| {
351 recordings.borrow_mut().push(LlmMock {
352 text: result.text.clone(),
353 tool_calls: result.tool_calls.clone(),
354 match_pattern: None,
355 consume_on_match: false,
356 input_tokens: Some(result.input_tokens),
357 output_tokens: Some(result.output_tokens),
358 cache_read_tokens: Some(result.cache_read_tokens),
359 cache_write_tokens: Some(result.cache_write_tokens),
360 thinking: result.thinking.clone(),
361 stop_reason: result.stop_reason.clone(),
362 model: result.model.clone(),
363 provider: Some(result.provider.clone()),
364 blocks: Some(result.blocks.clone()),
365 error: None,
366 });
367 });
368}
369
370fn unmatched_cli_prompt_error(match_text: &str) -> VmError {
371 let mut snippet: String = match_text.chars().take(200).collect();
372 if match_text.chars().count() > 200 {
373 snippet.push_str("...");
374 }
375 VmError::Runtime(format!("No --llm-mock fixture matched prompt: {snippet:?}"))
376}
377
378pub fn set_replay_mode(mode: LlmReplayMode, fixture_dir: &str) {
380 LLM_REPLAY_MODE.with(|v| *v.borrow_mut() = mode);
381 LLM_FIXTURE_DIR.with(|v| *v.borrow_mut() = fixture_dir.to_string());
382}
383
384pub(crate) fn get_replay_mode() -> LlmReplayMode {
385 LLM_REPLAY_MODE.with(|v| *v.borrow())
386}
387
388pub(crate) fn get_fixture_dir() -> String {
389 LLM_FIXTURE_DIR.with(|v| v.borrow().clone())
390}
391
/// Hash a request (model + serialized messages + system prompt) into a
/// 16-hex-digit fixture key.
///
/// NOTE(review): uses `DefaultHasher`, whose output is not guaranteed stable
/// across Rust releases — recorded fixtures may need re-recording after a
/// toolchain upgrade; confirm this is acceptable.
pub(crate) fn fixture_hash(
    model: &str,
    messages: &[serde_json::Value],
    system: Option<&str>,
) -> String {
    use std::hash::{Hash, Hasher};
    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    model.hash(&mut hasher);
    // A serialization failure hashes the empty string instead of erroring.
    serde_json::to_string(messages)
        .unwrap_or_default()
        .hash(&mut hasher);
    system.hash(&mut hasher);
    format!("{:016x}", hasher.finish())
}
408
409pub(crate) fn save_fixture(hash: &str, result: &LlmResult) {
410 let dir = get_fixture_dir();
411 if dir.is_empty() {
412 return;
413 }
414 let _ = std::fs::create_dir_all(&dir);
415 let path = format!("{dir}/{hash}.json");
416 let json = serde_json::json!({
417 "text": result.text,
418 "tool_calls": result.tool_calls,
419 "input_tokens": result.input_tokens,
420 "output_tokens": result.output_tokens,
421 "model": result.model,
422 "provider": result.provider,
423 "blocks": result.blocks,
424 });
425 let _ = std::fs::write(
426 &path,
427 serde_json::to_string_pretty(&json).unwrap_or_default(),
428 );
429}
430
/// Load a previously saved fixture by request hash.
///
/// Returns `None` when no fixture directory is configured, the file is
/// missing, or it fails to parse. Missing fields fall back to defaults
/// (0 tokens, empty strings, "mock" provider), so older fixtures that lack
/// newer fields remain loadable.
pub(crate) fn load_fixture(hash: &str) -> Option<LlmResult> {
    let dir = get_fixture_dir();
    if dir.is_empty() {
        return None;
    }
    let path = format!("{dir}/{hash}.json");
    let content = std::fs::read_to_string(&path).ok()?;
    let json: serde_json::Value = serde_json::from_str(&content).ok()?;
    Some(LlmResult {
        text: json["text"].as_str().unwrap_or("").to_string(),
        tool_calls: json["tool_calls"].as_array().cloned().unwrap_or_default(),
        input_tokens: json["input_tokens"].as_i64().unwrap_or(0),
        output_tokens: json["output_tokens"].as_i64().unwrap_or(0),
        cache_read_tokens: json["cache_read_tokens"].as_i64().unwrap_or(0),
        cache_write_tokens: json["cache_write_tokens"].as_i64().unwrap_or(0),
        model: json["model"].as_str().unwrap_or("").to_string(),
        provider: json["provider"].as_str().unwrap_or("mock").to_string(),
        thinking: json["thinking"].as_str().map(|s| s.to_string()),
        stop_reason: json["stop_reason"].as_str().map(|s| s.to_string()),
        blocks: json["blocks"].as_array().cloned().unwrap_or_default(),
    })
}
453
/// Build a placeholder arguments object for a tool from its JSON schema.
///
/// Only properties listed under `required` are filled in, each with a
/// neutral placeholder of its declared type (0, 0.0, false, [], {}, or "").
/// Returns an empty object when no schema can be located.
fn mock_required_args(tool_schema: &serde_json::Value) -> serde_json::Value {
    let mut args = serde_json::Map::new();
    // Accept the schema under any of the common provider layouts:
    // `input_schema`, camelCase `inputSchema`, nested `function.parameters`,
    // or bare `parameters`.
    let input_schema = tool_schema
        .get("input_schema")
        .or_else(|| tool_schema.get("inputSchema"))
        .or_else(|| {
            tool_schema
                .get("function")
                .and_then(|f| f.get("parameters"))
        })
        .or_else(|| tool_schema.get("parameters"));
    let Some(schema) = input_schema else {
        return serde_json::Value::Object(args);
    };
    let required: std::collections::BTreeSet<String> = schema
        .get("required")
        .and_then(|r| r.as_array())
        .map(|arr| {
            arr.iter()
                .filter_map(|v| v.as_str().map(|s| s.to_string()))
                .collect()
        })
        .unwrap_or_default();
    if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) {
        for (name, prop) in props {
            if !required.contains(name) {
                continue;
            }
            // A missing or non-string "type" is treated as "string".
            let ty = prop
                .get("type")
                .and_then(|t| t.as_str())
                .unwrap_or("string");
            let placeholder = match ty {
                "integer" => serde_json::json!(0),
                "number" => serde_json::json!(0.0),
                "boolean" => serde_json::json!(false),
                "array" => serde_json::json!([]),
                "object" => serde_json::json!({}),
                _ => serde_json::json!(""),
            };
            args.insert(name.clone(), placeholder);
        }
    }
    serde_json::Value::Object(args)
}
505
/// Entry point of the mock LLM backend.
///
/// Resolution order: CLI-installed mocks first, then script-installed
/// mocks, then a hard error if CLI replay is active, then a synthesized
/// call to the first offered native tool, and finally a generic prose
/// response derived from the last prompt.
pub(crate) fn mock_llm_response(
    messages: &[serde_json::Value],
    system: Option<&str>,
    native_tools: Option<&[serde_json::Value]>,
) -> Result<LlmResult, VmError> {
    // Every call is logged so tests can inspect what was sent.
    record_llm_mock_call(messages, system, native_tools);

    let match_text = mock_match_text(messages);
    let prompt_text = mock_last_prompt_text(messages);

    if let Some(matched) = try_match_cli_mock(&match_text) {
        return matched;
    }

    if let Some(matched) = try_match_builtin_mock(&match_text) {
        return matched;
    }

    // In CLI replay mode an unmatched prompt is an error rather than a
    // fallthrough to the synthesized responses below.
    if cli_llm_mock_replay_active() {
        return Err(unmatched_cli_prompt_error(&match_text));
    }

    // When native tools are offered, synthesize a call to the first tool
    // with placeholder values for its required arguments.
    if let Some(tools) = native_tools {
        if let Some(first_tool) = tools.first() {
            // Tool name may live at the top level or under `function.name`.
            let tool_name = first_tool
                .get("name")
                .or_else(|| first_tool.get("function").and_then(|f| f.get("name")))
                .and_then(|n| n.as_str())
                .unwrap_or("unknown");
            let mock_args = mock_required_args(first_tool);
            return Ok(LlmResult {
                text: String::new(),
                tool_calls: vec![serde_json::json!({
                    "id": "mock_call_1",
                    "type": "tool_call",
                    "name": tool_name,
                    "arguments": mock_args
                })],
                input_tokens: prompt_text.len() as i64,
                output_tokens: 20,
                cache_read_tokens: 0,
                cache_write_tokens: 0,
                model: "mock".to_string(),
                provider: "mock".to_string(),
                thinking: None,
                stop_reason: None,
                blocks: vec![serde_json::json!({
                    "type": "tool_call",
                    "id": "mock_call_1",
                    "name": tool_name,
                    "arguments": mock_args,
                    "visibility": "internal",
                })],
            });
        }
    }

    // System prompts that mention a <done> tag get the done marker appended
    // so loops waiting for it can terminate.
    let tagged_done = system.is_some_and(|s| s.contains("<done>"));

    let prose_body = if prompt_text.is_empty() {
        "Mock LLM response".to_string()
    } else {
        let word_count = prompt_text.split_whitespace().count();
        // Echo back a 100-char prefix of the prompt for traceability.
        format!(
            "Mock response to {word_count}-word prompt: {}",
            prompt_text.chars().take(100).collect::<String>()
        )
    };
    let response = if tagged_done {
        format!("<assistant_prose>{prose_body}</assistant_prose>\n<done>##DONE##</done>")
    } else {
        prose_body
    };

    Ok(LlmResult {
        text: response.clone(),
        tool_calls: vec![],
        input_tokens: prompt_text.len() as i64,
        output_tokens: 30,
        cache_read_tokens: 0,
        cache_write_tokens: 0,
        model: "mock".to_string(),
        provider: "mock".to_string(),
        thinking: None,
        stop_reason: None,
        blocks: vec![serde_json::json!({
            "type": "output_text",
            "text": response,
            "visibility": "public",
        })],
    })
}
608
609pub fn set_tool_recording_mode(mode: ToolRecordingMode) {
610 TOOL_RECORDING_MODE.with(|v| *v.borrow_mut() = mode);
611}
612
613pub(crate) fn get_tool_recording_mode() -> ToolRecordingMode {
614 TOOL_RECORDING_MODE.with(|v| *v.borrow())
615}
616
617pub(crate) fn record_tool_call(record: ToolCallRecord) {
619 TOOL_RECORDINGS.with(|v| v.borrow_mut().push(record));
620}
621
622pub fn drain_tool_recordings() -> Vec<ToolCallRecord> {
624 TOOL_RECORDINGS.with(|v| std::mem::take(&mut *v.borrow_mut()))
625}
626
627pub fn load_tool_replay_fixtures(records: Vec<ToolCallRecord>) {
629 TOOL_REPLAY_FIXTURES.with(|v| *v.borrow_mut() = records);
630}
631
632pub(crate) fn find_tool_replay_fixture(
634 tool_name: &str,
635 args: &serde_json::Value,
636) -> Option<ToolCallRecord> {
637 let hash = crate::orchestration::tool_fixture_hash(tool_name, args);
638 TOOL_REPLAY_FIXTURES.with(|v| {
639 v.borrow()
640 .iter()
641 .find(|r| r.tool_name == tool_name && r.args_hash == hash)
642 .cloned()
643 })
644}