1use std::cell::RefCell;
2
3use super::api::LlmResult;
4use crate::orchestration::ToolCallRecord;
5use crate::value::{ErrorCategory, VmError};
6
/// How LLM calls interact with on-disk fixtures on the current thread.
///
/// The mode is installed with `set_replay_mode` together with the fixture
/// directory. `Eq` is derived alongside `PartialEq` so this enum matches the
/// other mode enums in this file (all variants are field-less, so equality is
/// total).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LlmReplayMode {
    /// No fixture recording or replay.
    Off,
    /// Record mode.
    Record,
    /// Replay mode.
    Replay,
}
14
/// Recording/replay mode for tool calls on the current thread.
///
/// Stored via `set_tool_recording_mode` and read back with
/// `get_tool_recording_mode`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolRecordingMode {
    /// Neither recording nor replaying tool calls.
    Off,
    /// Record mode.
    Record,
    /// Replay mode.
    Replay,
}
22
/// State of the CLI-driven LLM mock layer for the current thread.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CliLlmMockMode {
    /// CLI mocks are not in use.
    Off,
    /// Mocks installed by `install_cli_llm_mocks` are served; an unmatched
    /// prompt is reported as an error by `mock_llm_response`.
    Replay,
    /// Results passed to `record_cli_llm_result` are captured.
    Record,
}
29
30#[derive(Clone)]
35pub struct MockError {
36 pub category: ErrorCategory,
37 pub message: String,
38 pub retry_after_ms: Option<u64>,
44}
45
46#[derive(Clone)]
47pub struct LlmMock {
48 pub text: String,
49 pub tool_calls: Vec<serde_json::Value>,
50 pub match_pattern: Option<String>, pub consume_on_match: bool,
52 pub input_tokens: Option<i64>,
53 pub output_tokens: Option<i64>,
54 pub cache_read_tokens: Option<i64>,
55 pub cache_write_tokens: Option<i64>,
56 pub thinking: Option<String>,
57 pub stop_reason: Option<String>,
58 pub model: String,
59 pub provider: Option<String>,
60 pub blocks: Option<Vec<serde_json::Value>>,
61 pub error: Option<MockError>,
64}
65
66#[derive(Clone)]
67pub(crate) struct LlmMockCall {
68 pub messages: Vec<serde_json::Value>,
69 pub system: Option<String>,
70 pub tools: Option<Vec<serde_json::Value>>,
71}
72
73thread_local! {
74 static LLM_REPLAY_MODE: RefCell<LlmReplayMode> = const { RefCell::new(LlmReplayMode::Off) };
75 static LLM_FIXTURE_DIR: RefCell<String> = const { RefCell::new(String::new()) };
76 static TOOL_RECORDING_MODE: RefCell<ToolRecordingMode> = const { RefCell::new(ToolRecordingMode::Off) };
77 static TOOL_RECORDINGS: RefCell<Vec<ToolCallRecord>> = const { RefCell::new(Vec::new()) };
78 static TOOL_REPLAY_FIXTURES: RefCell<Vec<ToolCallRecord>> = const { RefCell::new(Vec::new()) };
79 static LLM_MOCKS: RefCell<Vec<LlmMock>> = const { RefCell::new(Vec::new()) };
80 static CLI_LLM_MOCK_MODE: RefCell<CliLlmMockMode> = const { RefCell::new(CliLlmMockMode::Off) };
81 static CLI_LLM_MOCKS: RefCell<Vec<LlmMock>> = const { RefCell::new(Vec::new()) };
82 static CLI_LLM_RECORDINGS: RefCell<Vec<LlmMock>> = const { RefCell::new(Vec::new()) };
83 static LLM_MOCK_CALLS: RefCell<Vec<LlmMockCall>> = const { RefCell::new(Vec::new()) };
84 static LLM_MOCK_SCOPES: RefCell<Vec<(Vec<LlmMock>, Vec<LlmMockCall>)>> =
85 const { RefCell::new(Vec::new()) };
86}
87
88pub(crate) fn push_llm_mock(mock: LlmMock) {
89 LLM_MOCKS.with(|v| v.borrow_mut().push(mock));
90}
91
92pub(crate) fn get_llm_mock_calls() -> Vec<LlmMockCall> {
93 LLM_MOCK_CALLS.with(|v| v.borrow().clone())
94}
95
96pub(crate) fn reset_llm_mock_state() {
97 LLM_MOCKS.with(|v| v.borrow_mut().clear());
98 CLI_LLM_MOCK_MODE.with(|v| *v.borrow_mut() = CliLlmMockMode::Off);
99 CLI_LLM_MOCKS.with(|v| v.borrow_mut().clear());
100 CLI_LLM_RECORDINGS.with(|v| v.borrow_mut().clear());
101 LLM_MOCK_CALLS.with(|v| v.borrow_mut().clear());
102 LLM_MOCK_SCOPES.with(|v| v.borrow_mut().clear());
103}
104
105pub(crate) fn push_llm_mock_scope() {
110 let mocks = LLM_MOCKS.with(|v| std::mem::take(&mut *v.borrow_mut()));
111 let calls = LLM_MOCK_CALLS.with(|v| std::mem::take(&mut *v.borrow_mut()));
112 LLM_MOCK_SCOPES.with(|v| v.borrow_mut().push((mocks, calls)));
113}
114
115pub(crate) fn pop_llm_mock_scope() -> bool {
121 let entry = LLM_MOCK_SCOPES.with(|v| v.borrow_mut().pop());
122 match entry {
123 Some((mocks, calls)) => {
124 LLM_MOCKS.with(|v| *v.borrow_mut() = mocks);
125 LLM_MOCK_CALLS.with(|v| *v.borrow_mut() = calls);
126 true
127 }
128 None => false,
129 }
130}
131
132pub fn clear_cli_llm_mock_mode() {
133 CLI_LLM_MOCK_MODE.with(|v| *v.borrow_mut() = CliLlmMockMode::Off);
134 CLI_LLM_MOCKS.with(|v| v.borrow_mut().clear());
135 CLI_LLM_RECORDINGS.with(|v| v.borrow_mut().clear());
136}
137
138pub fn install_cli_llm_mocks(mocks: Vec<LlmMock>) {
139 CLI_LLM_MOCK_MODE.with(|v| *v.borrow_mut() = CliLlmMockMode::Replay);
140 CLI_LLM_MOCKS.with(|v| *v.borrow_mut() = mocks);
141 CLI_LLM_RECORDINGS.with(|v| v.borrow_mut().clear());
142}
143
144pub fn enable_cli_llm_mock_recording() {
145 CLI_LLM_MOCK_MODE.with(|v| *v.borrow_mut() = CliLlmMockMode::Record);
146 CLI_LLM_MOCKS.with(|v| v.borrow_mut().clear());
147 CLI_LLM_RECORDINGS.with(|v| v.borrow_mut().clear());
148}
149
150pub fn take_cli_llm_recordings() -> Vec<LlmMock> {
151 CLI_LLM_RECORDINGS.with(|v| std::mem::take(&mut *v.borrow_mut()))
152}
153
154pub(crate) fn cli_llm_mock_replay_active() -> bool {
155 CLI_LLM_MOCK_MODE.with(|v| *v.borrow() == CliLlmMockMode::Replay)
156}
157
158fn record_llm_mock_call(
159 messages: &[serde_json::Value],
160 system: Option<&str>,
161 native_tools: Option<&[serde_json::Value]>,
162) {
163 LLM_MOCK_CALLS.with(|v| {
164 v.borrow_mut().push(LlmMockCall {
165 messages: messages.to_vec(),
166 system: system.map(|s| s.to_string()),
167 tools: native_tools.map(|t| t.to_vec()),
168 });
169 });
170}
171
172fn build_mock_result(mock: &LlmMock, last_msg_len: usize) -> LlmResult {
174 let (tool_calls, blocks) = if let Some(blocks) = &mock.blocks {
175 (mock.tool_calls.clone(), blocks.clone())
176 } else {
177 let mut blocks = Vec::new();
178
179 if !mock.text.is_empty() {
180 blocks.push(serde_json::json!({
181 "type": "output_text",
182 "text": mock.text,
183 "visibility": "public",
184 }));
185 }
186
187 let mut tool_calls = Vec::new();
188 for (i, tc) in mock.tool_calls.iter().enumerate() {
189 let id = format!("mock_call_{}", i + 1);
190 let name = tc.get("name").and_then(|n| n.as_str()).unwrap_or("unknown");
191 let arguments = tc
192 .get("arguments")
193 .cloned()
194 .unwrap_or(serde_json::json!({}));
195 tool_calls.push(serde_json::json!({
196 "id": id,
197 "type": "tool_call",
198 "name": name,
199 "arguments": arguments,
200 }));
201 blocks.push(serde_json::json!({
202 "type": "tool_call",
203 "id": id,
204 "name": name,
205 "arguments": arguments,
206 "visibility": "internal",
207 }));
208 }
209
210 (tool_calls, blocks)
211 };
212
213 LlmResult {
214 text: mock.text.clone(),
215 tool_calls,
216 input_tokens: mock.input_tokens.unwrap_or(last_msg_len as i64),
217 output_tokens: mock.output_tokens.unwrap_or(30),
218 cache_read_tokens: mock.cache_read_tokens.unwrap_or(0),
219 cache_write_tokens: mock.cache_write_tokens.unwrap_or(0),
220 model: mock.model.clone(),
221 provider: mock.provider.clone().unwrap_or_else(|| "mock".to_string()),
222 thinking: mock.thinking.clone(),
223 stop_reason: mock.stop_reason.clone(),
224 blocks,
225 }
226}
227
/// Minimal glob matcher where `*` matches any (possibly empty) run of
/// characters and every other character matches literally.
///
/// A pattern without `*` must equal `text` exactly. Otherwise the piece
/// before the first `*` must be a prefix of `text`, the piece after the last
/// `*` must be a suffix of what remains, and every piece in between must
/// occur, in order, in the text between them.
fn mock_glob_match(pattern: &str, text: &str) -> bool {
    // No wildcard at all: plain equality.
    let Some((head, after_head)) = pattern.split_once('*') else {
        return pattern == text;
    };
    // Separate the trailing literal from any middle literals.
    let (middle, tail) = after_head.rsplit_once('*').unwrap_or(("", after_head));

    // An empty head strips nothing and always succeeds.
    let Some(mut rest) = text.strip_prefix(head) else {
        return false;
    };

    // Each middle literal is located left-to-right; the leftmost hit leaves
    // the most text available for the pieces that follow.
    for piece in middle.split('*').filter(|piece| !piece.is_empty()) {
        match rest.find(piece) {
            Some(at) => rest = &rest[at + piece.len()..],
            None => return false,
        }
    }

    // An empty tail is a suffix of everything.
    rest.ends_with(tail)
}
262
263fn collect_mock_match_strings(value: &serde_json::Value, out: &mut Vec<String>) {
264 match value {
265 serde_json::Value::String(text) if !text.is_empty() => out.push(text.clone()),
266 serde_json::Value::String(_) => {}
267 serde_json::Value::Array(items) => {
268 for item in items {
269 collect_mock_match_strings(item, out);
270 }
271 }
272 serde_json::Value::Object(map) => {
273 for value in map.values() {
274 collect_mock_match_strings(value, out);
275 }
276 }
277 _ => {}
278 }
279}
280
281fn mock_match_text(messages: &[serde_json::Value]) -> String {
282 let mut parts = Vec::new();
283 for message in messages {
284 collect_mock_match_strings(message, &mut parts);
285 }
286 parts.join("\n")
287}
288
289fn mock_last_prompt_text(messages: &[serde_json::Value]) -> String {
290 for message in messages.iter().rev() {
291 let Some(content) = message.get("content") else {
292 continue;
293 };
294 let mut parts = Vec::new();
295 collect_mock_match_strings(content, &mut parts);
296 let text = parts.join("\n");
297 if !text.trim().is_empty() {
298 return text;
299 }
300 }
301 String::new()
302}
303
304fn mock_error_to_vm_error(err: &MockError) -> VmError {
308 let message = match err.retry_after_ms {
314 Some(ms) => {
315 let secs = (ms as f64 / 1000.0).max(0.0);
316 let sep = if err.message.is_empty() || err.message.ends_with('\n') {
317 ""
318 } else {
319 "\n"
320 };
321 format!("{}{sep}retry-after: {secs}\n", err.message)
322 }
323 None => err.message.clone(),
324 };
325 VmError::CategorizedError {
326 message,
327 category: err.category.clone(),
328 }
329}
330
331fn try_match_mock_queue(
335 mocks: &mut Vec<LlmMock>,
336 match_text: &str,
337) -> Option<Result<LlmResult, VmError>> {
338 if let Some(idx) = mocks.iter().position(|m| m.match_pattern.is_none()) {
339 let mock = mocks.remove(idx);
340 return Some(match &mock.error {
341 Some(err) => Err(mock_error_to_vm_error(err)),
342 None => Ok(build_mock_result(&mock, match_text.len())),
343 });
344 }
345
346 for idx in 0..mocks.len() {
347 let mock = &mocks[idx];
348 if let Some(ref pattern) = mock.match_pattern {
349 if mock_glob_match(pattern, match_text) {
350 if mock.consume_on_match {
351 let mock = mocks.remove(idx);
352 return Some(match &mock.error {
353 Some(err) => Err(mock_error_to_vm_error(err)),
354 None => Ok(build_mock_result(&mock, match_text.len())),
355 });
356 }
357 return Some(match &mock.error {
358 Some(err) => Err(mock_error_to_vm_error(err)),
359 None => Ok(build_mock_result(mock, match_text.len())),
360 });
361 }
362 }
363 }
364
365 None
366}
367
368fn try_match_builtin_mock(match_text: &str) -> Option<Result<LlmResult, VmError>> {
369 LLM_MOCKS.with(|mocks| try_match_mock_queue(&mut mocks.borrow_mut(), match_text))
370}
371
372fn try_match_cli_mock(match_text: &str) -> Option<Result<LlmResult, VmError>> {
373 CLI_LLM_MOCKS.with(|mocks| try_match_mock_queue(&mut mocks.borrow_mut(), match_text))
374}
375
376pub(crate) fn record_cli_llm_result(result: &LlmResult) {
377 if !CLI_LLM_MOCK_MODE.with(|mode| *mode.borrow() == CliLlmMockMode::Record) {
378 return;
379 }
380 CLI_LLM_RECORDINGS.with(|recordings| {
381 recordings.borrow_mut().push(LlmMock {
382 text: result.text.clone(),
383 tool_calls: result.tool_calls.clone(),
384 match_pattern: None,
385 consume_on_match: false,
386 input_tokens: Some(result.input_tokens),
387 output_tokens: Some(result.output_tokens),
388 cache_read_tokens: Some(result.cache_read_tokens),
389 cache_write_tokens: Some(result.cache_write_tokens),
390 thinking: result.thinking.clone(),
391 stop_reason: result.stop_reason.clone(),
392 model: result.model.clone(),
393 provider: Some(result.provider.clone()),
394 blocks: Some(result.blocks.clone()),
395 error: None,
396 });
397 });
398}
399
400fn unmatched_cli_prompt_error(match_text: &str) -> VmError {
401 let mut snippet: String = match_text.chars().take(200).collect();
402 if match_text.chars().count() > 200 {
403 snippet.push_str("...");
404 }
405 VmError::Runtime(format!("No --llm-mock fixture matched prompt: {snippet:?}"))
406}
407
408pub fn set_replay_mode(mode: LlmReplayMode, fixture_dir: &str) {
410 LLM_REPLAY_MODE.with(|v| *v.borrow_mut() = mode);
411 LLM_FIXTURE_DIR.with(|v| *v.borrow_mut() = fixture_dir.to_string());
412}
413
414pub(crate) fn get_replay_mode() -> LlmReplayMode {
415 LLM_REPLAY_MODE.with(|v| *v.borrow())
416}
417
418pub(crate) fn get_fixture_dir() -> String {
419 LLM_FIXTURE_DIR.with(|v| v.borrow().clone())
420}
421
422pub(crate) fn fixture_hash(
424 model: &str,
425 messages: &[serde_json::Value],
426 system: Option<&str>,
427) -> String {
428 use std::hash::{Hash, Hasher};
429 let mut hasher = std::collections::hash_map::DefaultHasher::new();
430 model.hash(&mut hasher);
431 serde_json::to_string(messages)
433 .unwrap_or_default()
434 .hash(&mut hasher);
435 system.hash(&mut hasher);
436 format!("{:016x}", hasher.finish())
437}
438
439pub(crate) fn save_fixture(hash: &str, result: &LlmResult) {
440 let dir = get_fixture_dir();
441 if dir.is_empty() {
442 return;
443 }
444 let _ = std::fs::create_dir_all(&dir);
445 let path = format!("{dir}/{hash}.json");
446 let json = serde_json::json!({
447 "text": result.text,
448 "tool_calls": result.tool_calls,
449 "input_tokens": result.input_tokens,
450 "output_tokens": result.output_tokens,
451 "model": result.model,
452 "provider": result.provider,
453 "blocks": result.blocks,
454 });
455 let _ = std::fs::write(
456 &path,
457 serde_json::to_string_pretty(&json).unwrap_or_default(),
458 );
459}
460
461pub(crate) fn load_fixture(hash: &str) -> Option<LlmResult> {
462 let dir = get_fixture_dir();
463 if dir.is_empty() {
464 return None;
465 }
466 let path = format!("{dir}/{hash}.json");
467 let content = std::fs::read_to_string(&path).ok()?;
468 let json: serde_json::Value = serde_json::from_str(&content).ok()?;
469 Some(LlmResult {
470 text: json["text"].as_str().unwrap_or("").to_string(),
471 tool_calls: json["tool_calls"].as_array().cloned().unwrap_or_default(),
472 input_tokens: json["input_tokens"].as_i64().unwrap_or(0),
473 output_tokens: json["output_tokens"].as_i64().unwrap_or(0),
474 cache_read_tokens: json["cache_read_tokens"].as_i64().unwrap_or(0),
475 cache_write_tokens: json["cache_write_tokens"].as_i64().unwrap_or(0),
476 model: json["model"].as_str().unwrap_or("").to_string(),
477 provider: json["provider"].as_str().unwrap_or("mock").to_string(),
478 thinking: json["thinking"].as_str().map(|s| s.to_string()),
479 stop_reason: json["stop_reason"].as_str().map(|s| s.to_string()),
480 blocks: json["blocks"].as_array().cloned().unwrap_or_default(),
481 })
482}
483
484fn mock_required_args(tool_schema: &serde_json::Value) -> serde_json::Value {
488 let mut args = serde_json::Map::new();
489 let input_schema = tool_schema
493 .get("input_schema")
494 .or_else(|| tool_schema.get("inputSchema"))
495 .or_else(|| {
496 tool_schema
497 .get("function")
498 .and_then(|f| f.get("parameters"))
499 })
500 .or_else(|| tool_schema.get("parameters"));
501 let Some(schema) = input_schema else {
502 return serde_json::Value::Object(args);
503 };
504 let required: std::collections::BTreeSet<String> = schema
505 .get("required")
506 .and_then(|r| r.as_array())
507 .map(|arr| {
508 arr.iter()
509 .filter_map(|v| v.as_str().map(|s| s.to_string()))
510 .collect()
511 })
512 .unwrap_or_default();
513 if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) {
514 for (name, prop) in props {
515 if !required.contains(name) {
516 continue;
517 }
518 let ty = prop
519 .get("type")
520 .and_then(|t| t.as_str())
521 .unwrap_or("string");
522 let placeholder = match ty {
523 "integer" => serde_json::json!(0),
524 "number" => serde_json::json!(0.0),
525 "boolean" => serde_json::json!(false),
526 "array" => serde_json::json!([]),
527 "object" => serde_json::json!({}),
528 _ => serde_json::json!(""),
529 };
530 args.insert(name.clone(), placeholder);
531 }
532 }
533 serde_json::Value::Object(args)
534}
535
536pub(crate) fn mock_llm_response(
541 messages: &[serde_json::Value],
542 system: Option<&str>,
543 native_tools: Option<&[serde_json::Value]>,
544) -> Result<LlmResult, VmError> {
545 record_llm_mock_call(messages, system, native_tools);
546
547 let match_text = mock_match_text(messages);
548 let prompt_text = mock_last_prompt_text(messages);
549
550 if let Some(matched) = try_match_cli_mock(&match_text) {
551 return matched;
552 }
553
554 if let Some(matched) = try_match_builtin_mock(&match_text) {
555 return matched;
556 }
557
558 if cli_llm_mock_replay_active() {
559 return Err(unmatched_cli_prompt_error(&match_text));
560 }
561
562 if let Some(tools) = native_tools {
565 if let Some(first_tool) = tools.first() {
566 let tool_name = first_tool
567 .get("name")
568 .or_else(|| first_tool.get("function").and_then(|f| f.get("name")))
569 .and_then(|n| n.as_str())
570 .unwrap_or("unknown");
571 let mock_args = mock_required_args(first_tool);
572 return Ok(LlmResult {
573 text: String::new(),
574 tool_calls: vec![serde_json::json!({
575 "id": "mock_call_1",
576 "type": "tool_call",
577 "name": tool_name,
578 "arguments": mock_args
579 })],
580 input_tokens: prompt_text.len() as i64,
581 output_tokens: 20,
582 cache_read_tokens: 0,
583 cache_write_tokens: 0,
584 model: "mock".to_string(),
585 provider: "mock".to_string(),
586 thinking: None,
587 stop_reason: None,
588 blocks: vec![serde_json::json!({
589 "type": "tool_call",
590 "id": "mock_call_1",
591 "name": tool_name,
592 "arguments": mock_args,
593 "visibility": "internal",
594 })],
595 });
596 }
597 }
598
599 let tagged_done = system.is_some_and(|s| s.contains("<done>"));
604
605 let prose_body = if prompt_text.is_empty() {
606 "Mock LLM response".to_string()
607 } else {
608 let word_count = prompt_text.split_whitespace().count();
609 format!(
610 "Mock response to {word_count}-word prompt: {}",
611 prompt_text.chars().take(100).collect::<String>()
612 )
613 };
614 let response = if tagged_done {
615 format!("<assistant_prose>{prose_body}</assistant_prose>\n<done>##DONE##</done>")
616 } else {
617 prose_body
618 };
619
620 Ok(LlmResult {
621 text: response.clone(),
622 tool_calls: vec![],
623 input_tokens: prompt_text.len() as i64,
624 output_tokens: 30,
625 cache_read_tokens: 0,
626 cache_write_tokens: 0,
627 model: "mock".to_string(),
628 provider: "mock".to_string(),
629 thinking: None,
630 stop_reason: None,
631 blocks: vec![serde_json::json!({
632 "type": "output_text",
633 "text": response,
634 "visibility": "public",
635 })],
636 })
637}
638
639pub fn set_tool_recording_mode(mode: ToolRecordingMode) {
640 TOOL_RECORDING_MODE.with(|v| *v.borrow_mut() = mode);
641}
642
643pub(crate) fn get_tool_recording_mode() -> ToolRecordingMode {
644 TOOL_RECORDING_MODE.with(|v| *v.borrow())
645}
646
647pub(crate) fn record_tool_call(record: ToolCallRecord) {
649 TOOL_RECORDINGS.with(|v| v.borrow_mut().push(record));
650}
651
652pub fn drain_tool_recordings() -> Vec<ToolCallRecord> {
654 TOOL_RECORDINGS.with(|v| std::mem::take(&mut *v.borrow_mut()))
655}
656
657pub fn load_tool_replay_fixtures(records: Vec<ToolCallRecord>) {
659 TOOL_REPLAY_FIXTURES.with(|v| *v.borrow_mut() = records);
660}
661
662pub(crate) fn find_tool_replay_fixture(
664 tool_name: &str,
665 args: &serde_json::Value,
666) -> Option<ToolCallRecord> {
667 let hash = crate::orchestration::tool_fixture_hash(tool_name, args);
668 TOOL_REPLAY_FIXTURES.with(|v| {
669 v.borrow()
670 .iter()
671 .find(|r| r.tool_name == tool_name && r.args_hash == hash)
672 .cloned()
673 })
674}