1use std::cell::RefCell;
2
3use super::api::LlmResult;
4use crate::orchestration::ToolCallRecord;
5
/// Per-thread LLM record/replay setting, stored by `set_replay_mode` and read
/// back with `get_replay_mode`.
///
/// Derives `Eq` (like `ToolRecordingMode`) since equality on a fieldless enum
/// is total.
// NOTE(review): how each variant is acted on lives outside this module;
// presumably `Record` pairs with `save_fixture` and `Replay` with
// `load_fixture` — confirm at the call sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LlmReplayMode {
    Off,
    Record,
    Replay,
}
13
/// Per-thread tool-call record/replay setting, stored by
/// `set_tool_recording_mode` and read back with `get_tool_recording_mode`.
// NOTE(review): the dispatch on this mode is outside this module; presumably
// `Record` feeds `record_tool_call` and `Replay` consults
// `find_tool_replay_fixture` — confirm at the call sites.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ToolRecordingMode {
    Off,
    Record,
    Replay,
}
21
22pub(crate) struct LlmMock {
23 pub text: String,
24 pub tool_calls: Vec<serde_json::Value>,
25 pub match_pattern: Option<String>, pub input_tokens: Option<i64>,
27 pub output_tokens: Option<i64>,
28 pub thinking: Option<String>,
29 pub stop_reason: Option<String>,
30 pub model: String,
31}
32
33#[derive(Clone)]
34pub(crate) struct LlmMockCall {
35 pub messages: Vec<serde_json::Value>,
36 pub system: Option<String>,
37 pub tools: Option<Vec<serde_json::Value>>,
38}
39
40thread_local! {
41 static LLM_REPLAY_MODE: RefCell<LlmReplayMode> = const { RefCell::new(LlmReplayMode::Off) };
42 static LLM_FIXTURE_DIR: RefCell<String> = const { RefCell::new(String::new()) };
43 static TOOL_RECORDING_MODE: RefCell<ToolRecordingMode> = const { RefCell::new(ToolRecordingMode::Off) };
44 static TOOL_RECORDINGS: RefCell<Vec<ToolCallRecord>> = const { RefCell::new(Vec::new()) };
45 static TOOL_REPLAY_FIXTURES: RefCell<Vec<ToolCallRecord>> = const { RefCell::new(Vec::new()) };
46 static LLM_MOCKS: RefCell<Vec<LlmMock>> = const { RefCell::new(Vec::new()) };
47 static LLM_MOCK_CALLS: RefCell<Vec<LlmMockCall>> = const { RefCell::new(Vec::new()) };
48}
49
50pub(crate) fn push_llm_mock(mock: LlmMock) {
51 LLM_MOCKS.with(|v| v.borrow_mut().push(mock));
52}
53
54pub(crate) fn get_llm_mock_calls() -> Vec<LlmMockCall> {
55 LLM_MOCK_CALLS.with(|v| v.borrow().clone())
56}
57
58pub(crate) fn reset_llm_mock_state() {
59 LLM_MOCKS.with(|v| v.borrow_mut().clear());
60 LLM_MOCK_CALLS.with(|v| v.borrow_mut().clear());
61}
62
63fn record_llm_mock_call(
64 messages: &[serde_json::Value],
65 system: Option<&str>,
66 native_tools: Option<&[serde_json::Value]>,
67) {
68 LLM_MOCK_CALLS.with(|v| {
69 v.borrow_mut().push(LlmMockCall {
70 messages: messages.to_vec(),
71 system: system.map(|s| s.to_string()),
72 tools: native_tools.map(|t| t.to_vec()),
73 });
74 });
75}
76
77fn build_mock_result(mock: &LlmMock, last_msg_len: usize) -> LlmResult {
79 let mut blocks = Vec::new();
80
81 if !mock.text.is_empty() {
82 blocks.push(serde_json::json!({
83 "type": "output_text",
84 "text": mock.text,
85 "visibility": "public",
86 }));
87 }
88
89 let mut tool_calls = Vec::new();
90 for (i, tc) in mock.tool_calls.iter().enumerate() {
91 let id = format!("mock_call_{}", i + 1);
92 let name = tc.get("name").and_then(|n| n.as_str()).unwrap_or("unknown");
93 let arguments = tc
94 .get("arguments")
95 .cloned()
96 .unwrap_or(serde_json::json!({}));
97 tool_calls.push(serde_json::json!({
98 "id": id,
99 "type": "tool_call",
100 "name": name,
101 "arguments": arguments,
102 }));
103 blocks.push(serde_json::json!({
104 "type": "tool_call",
105 "id": id,
106 "name": name,
107 "arguments": arguments,
108 "visibility": "internal",
109 }));
110 }
111
112 LlmResult {
113 text: mock.text.clone(),
114 tool_calls,
115 input_tokens: mock.input_tokens.unwrap_or(last_msg_len as i64),
116 output_tokens: mock.output_tokens.unwrap_or(30),
117 cache_read_tokens: 0,
118 cache_write_tokens: 0,
119 model: mock.model.clone(),
120 provider: "mock".to_string(),
121 thinking: mock.thinking.clone(),
122 stop_reason: mock.stop_reason.clone(),
123 blocks,
124 }
125}
126
/// Match `text` against a minimal glob `pattern` in which `*` stands for any
/// (possibly empty) run of characters.
///
/// No other metacharacters or escapes are supported, and matching is
/// case-sensitive. A pattern without `*` must equal `text` exactly.
fn mock_glob_match(pattern: &str, text: &str) -> bool {
    if pattern == "*" {
        return true;
    }
    if !pattern.contains('*') {
        return pattern == text;
    }

    // At least one `*` exists, so there are >= 2 segments: a literal head,
    // zero or more middle pieces, and a literal tail (either may be empty).
    let segments: Vec<&str> = pattern.split('*').collect();
    let last = segments.len() - 1;
    let (head, middle, tail) = (segments[0], &segments[1..last], segments[last]);

    let Some(mut rest) = text.strip_prefix(head) else {
        return false;
    };
    // Middle pieces are located left-most first; consuming as little as
    // possible leaves the most room for later pieces and the tail.
    for piece in middle.iter().copied().filter(|p| !p.is_empty()) {
        match rest.find(piece) {
            Some(at) => rest = &rest[at + piece.len()..],
            None => return false,
        }
    }
    rest.ends_with(tail)
}
161
162fn try_match_mock(last_msg: &str) -> Option<LlmResult> {
165 LLM_MOCKS.with(|mocks| {
166 let mut mocks = mocks.borrow_mut();
167
168 if let Some(idx) = mocks.iter().position(|m| m.match_pattern.is_none()) {
170 let mock = mocks.remove(idx);
171 return Some(build_mock_result(&mock, last_msg.len()));
172 }
173
174 for mock in mocks.iter().rev() {
176 if let Some(ref pattern) = mock.match_pattern {
177 if mock_glob_match(pattern, last_msg) {
178 return Some(build_mock_result(mock, last_msg.len()));
179 }
180 }
181 }
182
183 None
184 })
185}
186
187pub fn set_replay_mode(mode: LlmReplayMode, fixture_dir: &str) {
189 LLM_REPLAY_MODE.with(|v| *v.borrow_mut() = mode);
190 LLM_FIXTURE_DIR.with(|v| *v.borrow_mut() = fixture_dir.to_string());
191}
192
193pub(crate) fn get_replay_mode() -> LlmReplayMode {
194 LLM_REPLAY_MODE.with(|v| *v.borrow())
195}
196
197pub(crate) fn get_fixture_dir() -> String {
198 LLM_FIXTURE_DIR.with(|v| v.borrow().clone())
199}
200
201pub(crate) fn fixture_hash(
203 model: &str,
204 messages: &[serde_json::Value],
205 system: Option<&str>,
206) -> String {
207 use std::hash::{Hash, Hasher};
208 let mut hasher = std::collections::hash_map::DefaultHasher::new();
209 model.hash(&mut hasher);
210 serde_json::to_string(messages)
212 .unwrap_or_default()
213 .hash(&mut hasher);
214 system.hash(&mut hasher);
215 format!("{:016x}", hasher.finish())
216}
217
218pub(crate) fn save_fixture(hash: &str, result: &LlmResult) {
219 let dir = get_fixture_dir();
220 if dir.is_empty() {
221 return;
222 }
223 let _ = std::fs::create_dir_all(&dir);
224 let path = format!("{dir}/{hash}.json");
225 let json = serde_json::json!({
226 "text": result.text,
227 "tool_calls": result.tool_calls,
228 "input_tokens": result.input_tokens,
229 "output_tokens": result.output_tokens,
230 "model": result.model,
231 "provider": result.provider,
232 "blocks": result.blocks,
233 });
234 let _ = std::fs::write(
235 &path,
236 serde_json::to_string_pretty(&json).unwrap_or_default(),
237 );
238}
239
240pub(crate) fn load_fixture(hash: &str) -> Option<LlmResult> {
241 let dir = get_fixture_dir();
242 if dir.is_empty() {
243 return None;
244 }
245 let path = format!("{dir}/{hash}.json");
246 let content = std::fs::read_to_string(&path).ok()?;
247 let json: serde_json::Value = serde_json::from_str(&content).ok()?;
248 Some(LlmResult {
249 text: json["text"].as_str().unwrap_or("").to_string(),
250 tool_calls: json["tool_calls"].as_array().cloned().unwrap_or_default(),
251 input_tokens: json["input_tokens"].as_i64().unwrap_or(0),
252 output_tokens: json["output_tokens"].as_i64().unwrap_or(0),
253 cache_read_tokens: json["cache_read_tokens"].as_i64().unwrap_or(0),
254 cache_write_tokens: json["cache_write_tokens"].as_i64().unwrap_or(0),
255 model: json["model"].as_str().unwrap_or("").to_string(),
256 provider: json["provider"].as_str().unwrap_or("mock").to_string(),
257 thinking: json["thinking"].as_str().map(|s| s.to_string()),
258 stop_reason: json["stop_reason"].as_str().map(|s| s.to_string()),
259 blocks: json["blocks"].as_array().cloned().unwrap_or_default(),
260 })
261}
262
263fn mock_required_args(tool_schema: &serde_json::Value) -> serde_json::Value {
267 let mut args = serde_json::Map::new();
268 let input_schema = tool_schema
272 .get("input_schema")
273 .or_else(|| tool_schema.get("inputSchema"))
274 .or_else(|| {
275 tool_schema
276 .get("function")
277 .and_then(|f| f.get("parameters"))
278 })
279 .or_else(|| tool_schema.get("parameters"));
280 let Some(schema) = input_schema else {
281 return serde_json::Value::Object(args);
282 };
283 let required: std::collections::BTreeSet<String> = schema
284 .get("required")
285 .and_then(|r| r.as_array())
286 .map(|arr| {
287 arr.iter()
288 .filter_map(|v| v.as_str().map(|s| s.to_string()))
289 .collect()
290 })
291 .unwrap_or_default();
292 if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) {
293 for (name, prop) in props {
294 if !required.contains(name) {
295 continue;
296 }
297 let ty = prop
298 .get("type")
299 .and_then(|t| t.as_str())
300 .unwrap_or("string");
301 let placeholder = match ty {
302 "integer" => serde_json::json!(0),
303 "number" => serde_json::json!(0.0),
304 "boolean" => serde_json::json!(false),
305 "array" => serde_json::json!([]),
306 "object" => serde_json::json!({}),
307 _ => serde_json::json!(""),
308 };
309 args.insert(name.clone(), placeholder);
310 }
311 }
312 serde_json::Value::Object(args)
313}
314
315pub(crate) fn mock_llm_response(
320 messages: &[serde_json::Value],
321 system: Option<&str>,
322 native_tools: Option<&[serde_json::Value]>,
323) -> LlmResult {
324 record_llm_mock_call(messages, system, native_tools);
325
326 let last_msg = messages
327 .last()
328 .and_then(|m| m.get("content"))
329 .and_then(|c| c.as_str())
330 .unwrap_or("");
331
332 if let Some(result) = try_match_mock(last_msg) {
333 return result;
334 }
335
336 if let Some(tools) = native_tools {
339 if let Some(first_tool) = tools.first() {
340 let tool_name = first_tool
341 .get("name")
342 .or_else(|| first_tool.get("function").and_then(|f| f.get("name")))
343 .and_then(|n| n.as_str())
344 .unwrap_or("unknown");
345 let mock_args = mock_required_args(first_tool);
346 return LlmResult {
347 text: String::new(),
348 tool_calls: vec![serde_json::json!({
349 "id": "mock_call_1",
350 "type": "tool_call",
351 "name": tool_name,
352 "arguments": mock_args
353 })],
354 input_tokens: last_msg.len() as i64,
355 output_tokens: 20,
356 cache_read_tokens: 0,
357 cache_write_tokens: 0,
358 model: "mock".to_string(),
359 provider: "mock".to_string(),
360 thinking: None,
361 stop_reason: None,
362 blocks: vec![serde_json::json!({
363 "type": "tool_call",
364 "id": "mock_call_1",
365 "name": tool_name,
366 "arguments": mock_args,
367 "visibility": "internal",
368 })],
369 };
370 }
371 }
372
373 let done_block = if system.is_some_and(|s| s.contains("##DONE##")) {
376 "\n<done>##DONE##</done>"
377 } else {
378 ""
379 };
380
381 let prose_body = if last_msg.is_empty() {
382 "Mock LLM response".to_string()
383 } else {
384 let word_count = last_msg.split_whitespace().count();
385 format!(
386 "Mock response to {word_count}-word prompt: {}",
387 last_msg.chars().take(100).collect::<String>()
388 )
389 };
390 let response = format!("<assistant_prose>{prose_body}</assistant_prose>{done_block}");
391
392 LlmResult {
393 text: response.clone(),
394 tool_calls: vec![],
395 input_tokens: last_msg.len() as i64,
396 output_tokens: 30,
397 cache_read_tokens: 0,
398 cache_write_tokens: 0,
399 model: "mock".to_string(),
400 provider: "mock".to_string(),
401 thinking: None,
402 stop_reason: None,
403 blocks: vec![serde_json::json!({
404 "type": "output_text",
405 "text": response,
406 "visibility": "public",
407 })],
408 }
409}
410
411pub fn set_tool_recording_mode(mode: ToolRecordingMode) {
412 TOOL_RECORDING_MODE.with(|v| *v.borrow_mut() = mode);
413}
414
415pub(crate) fn get_tool_recording_mode() -> ToolRecordingMode {
416 TOOL_RECORDING_MODE.with(|v| *v.borrow())
417}
418
419pub(crate) fn record_tool_call(record: ToolCallRecord) {
421 TOOL_RECORDINGS.with(|v| v.borrow_mut().push(record));
422}
423
424pub fn drain_tool_recordings() -> Vec<ToolCallRecord> {
426 TOOL_RECORDINGS.with(|v| std::mem::take(&mut *v.borrow_mut()))
427}
428
429pub fn load_tool_replay_fixtures(records: Vec<ToolCallRecord>) {
431 TOOL_REPLAY_FIXTURES.with(|v| *v.borrow_mut() = records);
432}
433
434pub(crate) fn find_tool_replay_fixture(
436 tool_name: &str,
437 args: &serde_json::Value,
438) -> Option<ToolCallRecord> {
439 let hash = crate::orchestration::tool_fixture_hash(tool_name, args);
440 TOOL_REPLAY_FIXTURES.with(|v| {
441 v.borrow()
442 .iter()
443 .find(|r| r.tool_name == tool_name && r.args_hash == hash)
444 .cloned()
445 })
446}