1use std::cell::RefCell;
2
3use super::api::LlmResult;
4use crate::orchestration::ToolCallRecord;
5
/// Fixture mode for LLM calls, stored per thread by [`set_replay_mode`] and
/// read back with [`get_replay_mode`].
///
/// The variants are interpreted by callers outside this module; judging by
/// the names, `Record` pairs with [`save_fixture`] and `Replay` with
/// [`load_fixture`] (NOTE(review): confirm against the call sites).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LlmReplayMode {
    /// No fixture handling (the per-thread default).
    Off,
    /// Recording mode.
    Record,
    /// Replay mode.
    Replay,
}
13
/// Per-thread mode for recording or replaying tool calls.
///
/// Set with [`set_tool_recording_mode`] and read back with
/// [`get_tool_recording_mode`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ToolRecordingMode {
    /// No recording or replay (the per-thread default).
    Off,
    /// Tool calls are intended to be captured via [`record_tool_call`].
    Record,
    /// Tool calls are intended to be served from fixtures loaded with
    /// [`load_tool_replay_fixtures`].
    Replay,
}
21
22pub(crate) struct LlmMock {
25 pub text: String,
26 pub tool_calls: Vec<serde_json::Value>,
27 pub match_pattern: Option<String>, pub input_tokens: Option<i64>,
29 pub output_tokens: Option<i64>,
30 pub thinking: Option<String>,
31 pub stop_reason: Option<String>,
32 pub model: String,
33}
34
35#[derive(Clone)]
36pub(crate) struct LlmMockCall {
37 pub messages: Vec<serde_json::Value>,
38 pub system: Option<String>,
39 pub tools: Option<Vec<serde_json::Value>>,
40}
41
42thread_local! {
43 static LLM_REPLAY_MODE: RefCell<LlmReplayMode> = const { RefCell::new(LlmReplayMode::Off) };
44 static LLM_FIXTURE_DIR: RefCell<String> = const { RefCell::new(String::new()) };
45 static TOOL_RECORDING_MODE: RefCell<ToolRecordingMode> = const { RefCell::new(ToolRecordingMode::Off) };
46 static TOOL_RECORDINGS: RefCell<Vec<ToolCallRecord>> = const { RefCell::new(Vec::new()) };
47 static TOOL_REPLAY_FIXTURES: RefCell<Vec<ToolCallRecord>> = const { RefCell::new(Vec::new()) };
48 static LLM_MOCKS: RefCell<Vec<LlmMock>> = const { RefCell::new(Vec::new()) };
49 static LLM_MOCK_CALLS: RefCell<Vec<LlmMockCall>> = const { RefCell::new(Vec::new()) };
50}
51
52pub(crate) fn push_llm_mock(mock: LlmMock) {
53 LLM_MOCKS.with(|v| v.borrow_mut().push(mock));
54}
55
56pub(crate) fn get_llm_mock_calls() -> Vec<LlmMockCall> {
57 LLM_MOCK_CALLS.with(|v| v.borrow().clone())
58}
59
60pub(crate) fn reset_llm_mock_state() {
61 LLM_MOCKS.with(|v| v.borrow_mut().clear());
62 LLM_MOCK_CALLS.with(|v| v.borrow_mut().clear());
63}
64
65fn record_llm_mock_call(
66 messages: &[serde_json::Value],
67 system: Option<&str>,
68 native_tools: Option<&[serde_json::Value]>,
69) {
70 LLM_MOCK_CALLS.with(|v| {
71 v.borrow_mut().push(LlmMockCall {
72 messages: messages.to_vec(),
73 system: system.map(|s| s.to_string()),
74 tools: native_tools.map(|t| t.to_vec()),
75 });
76 });
77}
78
79fn build_mock_result(mock: &LlmMock, last_msg_len: usize) -> LlmResult {
81 let mut blocks = Vec::new();
82
83 if !mock.text.is_empty() {
85 blocks.push(serde_json::json!({
86 "type": "output_text",
87 "text": mock.text,
88 "visibility": "public",
89 }));
90 }
91
92 let mut tool_calls = Vec::new();
94 for (i, tc) in mock.tool_calls.iter().enumerate() {
95 let id = format!("mock_call_{}", i + 1);
96 let name = tc.get("name").and_then(|n| n.as_str()).unwrap_or("unknown");
97 let arguments = tc
98 .get("arguments")
99 .cloned()
100 .unwrap_or(serde_json::json!({}));
101 tool_calls.push(serde_json::json!({
102 "id": id,
103 "type": "tool_call",
104 "name": name,
105 "arguments": arguments,
106 }));
107 blocks.push(serde_json::json!({
108 "type": "tool_call",
109 "id": id,
110 "name": name,
111 "arguments": arguments,
112 "visibility": "internal",
113 }));
114 }
115
116 LlmResult {
117 text: mock.text.clone(),
118 tool_calls,
119 input_tokens: mock.input_tokens.unwrap_or(last_msg_len as i64),
120 output_tokens: mock.output_tokens.unwrap_or(30),
121 cache_read_tokens: 0,
122 cache_write_tokens: 0,
123 model: mock.model.clone(),
124 provider: "mock".to_string(),
125 thinking: mock.thinking.clone(),
126 stop_reason: mock.stop_reason.clone(),
127 blocks,
128 }
129}
130
/// Matches `text` against a glob `pattern`, where `*` matches any (possibly
/// empty) run of characters and every other character matches literally.
///
/// A pattern without `*` requires exact equality. Otherwise the pattern is
/// checked one segment at a time:
/// - the segment before the first `*` must be a prefix of `text`;
/// - the segment after the last `*` must be a suffix of what remains;
/// - each segment in between is located left to right at its earliest
///   occurrence.
fn mock_glob_match(pattern: &str, text: &str) -> bool {
    if !pattern.contains('*') {
        return pattern == text;
    }

    let segments: Vec<&str> = pattern.split('*').collect();
    let last = segments.len() - 1;
    let mut rest = text;

    for (idx, segment) in segments.iter().copied().enumerate() {
        // Empty segments come from leading, trailing or doubled `*`.
        if segment.is_empty() {
            continue;
        }
        rest = if idx == 0 {
            match rest.strip_prefix(segment) {
                Some(after) => after,
                None => return false,
            }
        } else if idx == last {
            if !rest.ends_with(segment) {
                return false;
            }
            ""
        } else {
            match rest.find(segment) {
                Some(pos) => &rest[pos + segment.len()..],
                None => return false,
            }
        };
    }

    true
}
165
166fn try_match_mock(last_msg: &str) -> Option<LlmResult> {
169 LLM_MOCKS.with(|mocks| {
170 let mut mocks = mocks.borrow_mut();
171
172 if let Some(idx) = mocks.iter().position(|m| m.match_pattern.is_none()) {
174 let mock = mocks.remove(idx);
175 return Some(build_mock_result(&mock, last_msg.len()));
176 }
177
178 for mock in mocks.iter().rev() {
180 if let Some(ref pattern) = mock.match_pattern {
181 if mock_glob_match(pattern, last_msg) {
182 return Some(build_mock_result(mock, last_msg.len()));
183 }
184 }
185 }
186
187 None
188 })
189}
190
191pub fn set_replay_mode(mode: LlmReplayMode, fixture_dir: &str) {
193 LLM_REPLAY_MODE.with(|v| *v.borrow_mut() = mode);
194 LLM_FIXTURE_DIR.with(|v| *v.borrow_mut() = fixture_dir.to_string());
195}
196
197pub(crate) fn get_replay_mode() -> LlmReplayMode {
198 LLM_REPLAY_MODE.with(|v| *v.borrow())
199}
200
201pub(crate) fn get_fixture_dir() -> String {
202 LLM_FIXTURE_DIR.with(|v| v.borrow().clone())
203}
204
205pub(crate) fn fixture_hash(
207 model: &str,
208 messages: &[serde_json::Value],
209 system: Option<&str>,
210) -> String {
211 use std::hash::{Hash, Hasher};
212 let mut hasher = std::collections::hash_map::DefaultHasher::new();
213 model.hash(&mut hasher);
214 serde_json::to_string(messages)
216 .unwrap_or_default()
217 .hash(&mut hasher);
218 system.hash(&mut hasher);
219 format!("{:016x}", hasher.finish())
220}
221
222pub(crate) fn save_fixture(hash: &str, result: &LlmResult) {
223 let dir = get_fixture_dir();
224 if dir.is_empty() {
225 return;
226 }
227 let _ = std::fs::create_dir_all(&dir);
228 let path = format!("{dir}/{hash}.json");
229 let json = serde_json::json!({
230 "text": result.text,
231 "tool_calls": result.tool_calls,
232 "input_tokens": result.input_tokens,
233 "output_tokens": result.output_tokens,
234 "model": result.model,
235 "provider": result.provider,
236 "blocks": result.blocks,
237 });
238 let _ = std::fs::write(
239 &path,
240 serde_json::to_string_pretty(&json).unwrap_or_default(),
241 );
242}
243
244pub(crate) fn load_fixture(hash: &str) -> Option<LlmResult> {
245 let dir = get_fixture_dir();
246 if dir.is_empty() {
247 return None;
248 }
249 let path = format!("{dir}/{hash}.json");
250 let content = std::fs::read_to_string(&path).ok()?;
251 let json: serde_json::Value = serde_json::from_str(&content).ok()?;
252 Some(LlmResult {
253 text: json["text"].as_str().unwrap_or("").to_string(),
254 tool_calls: json["tool_calls"].as_array().cloned().unwrap_or_default(),
255 input_tokens: json["input_tokens"].as_i64().unwrap_or(0),
256 output_tokens: json["output_tokens"].as_i64().unwrap_or(0),
257 cache_read_tokens: json["cache_read_tokens"].as_i64().unwrap_or(0),
258 cache_write_tokens: json["cache_write_tokens"].as_i64().unwrap_or(0),
259 model: json["model"].as_str().unwrap_or("").to_string(),
260 provider: json["provider"].as_str().unwrap_or("mock").to_string(),
261 thinking: json["thinking"].as_str().map(|s| s.to_string()),
262 stop_reason: json["stop_reason"].as_str().map(|s| s.to_string()),
263 blocks: json["blocks"].as_array().cloned().unwrap_or_default(),
264 })
265}
266
267fn mock_required_args(tool_schema: &serde_json::Value) -> serde_json::Value {
271 let mut args = serde_json::Map::new();
272 let input_schema = tool_schema
276 .get("input_schema")
277 .or_else(|| tool_schema.get("inputSchema"))
278 .or_else(|| {
279 tool_schema
280 .get("function")
281 .and_then(|f| f.get("parameters"))
282 })
283 .or_else(|| tool_schema.get("parameters"));
284 let Some(schema) = input_schema else {
285 return serde_json::Value::Object(args);
286 };
287 let required: std::collections::BTreeSet<String> = schema
288 .get("required")
289 .and_then(|r| r.as_array())
290 .map(|arr| {
291 arr.iter()
292 .filter_map(|v| v.as_str().map(|s| s.to_string()))
293 .collect()
294 })
295 .unwrap_or_default();
296 if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) {
297 for (name, prop) in props {
298 if !required.contains(name) {
299 continue;
300 }
301 let ty = prop
302 .get("type")
303 .and_then(|t| t.as_str())
304 .unwrap_or("string");
305 let placeholder = match ty {
306 "integer" => serde_json::json!(0),
307 "number" => serde_json::json!(0.0),
308 "boolean" => serde_json::json!(false),
309 "array" => serde_json::json!([]),
310 "object" => serde_json::json!({}),
311 _ => serde_json::json!(""),
312 };
313 args.insert(name.clone(), placeholder);
314 }
315 }
316 serde_json::Value::Object(args)
317}
318
319pub(crate) fn mock_llm_response(
324 messages: &[serde_json::Value],
325 system: Option<&str>,
326 native_tools: Option<&[serde_json::Value]>,
327) -> LlmResult {
328 record_llm_mock_call(messages, system, native_tools);
330
331 let last_msg = messages
333 .last()
334 .and_then(|m| m.get("content"))
335 .and_then(|c| c.as_str())
336 .unwrap_or("");
337
338 if let Some(result) = try_match_mock(last_msg) {
340 return result;
341 }
342
343 if let Some(tools) = native_tools {
347 if let Some(first_tool) = tools.first() {
348 let tool_name = first_tool
349 .get("name")
350 .or_else(|| first_tool.get("function").and_then(|f| f.get("name")))
351 .and_then(|n| n.as_str())
352 .unwrap_or("unknown");
353 let mock_args = mock_required_args(first_tool);
354 return LlmResult {
355 text: String::new(),
356 tool_calls: vec![serde_json::json!({
357 "id": "mock_call_1",
358 "type": "tool_call",
359 "name": tool_name,
360 "arguments": mock_args
361 })],
362 input_tokens: last_msg.len() as i64,
363 output_tokens: 20,
364 cache_read_tokens: 0,
365 cache_write_tokens: 0,
366 model: "mock".to_string(),
367 provider: "mock".to_string(),
368 thinking: None,
369 stop_reason: None,
370 blocks: vec![serde_json::json!({
371 "type": "tool_call",
372 "id": "mock_call_1",
373 "name": tool_name,
374 "arguments": mock_args,
375 "visibility": "internal",
376 })],
377 };
378 }
379 }
380
381 let done_sentinel = if system.is_some_and(|s| s.contains("##DONE##")) {
384 " ##DONE##"
385 } else {
386 ""
387 };
388
389 let response = if last_msg.is_empty() {
390 format!("Mock LLM response{done_sentinel}")
391 } else {
392 let word_count = last_msg.split_whitespace().count();
393 format!(
394 "Mock response to {word_count}-word prompt: {}{done_sentinel}",
395 last_msg.chars().take(100).collect::<String>()
396 )
397 };
398
399 LlmResult {
400 text: response.clone(),
401 tool_calls: vec![],
402 input_tokens: last_msg.len() as i64,
403 output_tokens: 30,
404 cache_read_tokens: 0,
405 cache_write_tokens: 0,
406 model: "mock".to_string(),
407 provider: "mock".to_string(),
408 thinking: None,
409 stop_reason: None,
410 blocks: vec![serde_json::json!({
411 "type": "output_text",
412 "text": response,
413 "visibility": "public",
414 })],
415 }
416}
417
418pub fn set_tool_recording_mode(mode: ToolRecordingMode) {
421 TOOL_RECORDING_MODE.with(|v| *v.borrow_mut() = mode);
422}
423
424pub(crate) fn get_tool_recording_mode() -> ToolRecordingMode {
425 TOOL_RECORDING_MODE.with(|v| *v.borrow())
426}
427
428pub(crate) fn record_tool_call(record: ToolCallRecord) {
430 TOOL_RECORDINGS.with(|v| v.borrow_mut().push(record));
431}
432
433pub fn drain_tool_recordings() -> Vec<ToolCallRecord> {
435 TOOL_RECORDINGS.with(|v| std::mem::take(&mut *v.borrow_mut()))
436}
437
438pub fn load_tool_replay_fixtures(records: Vec<ToolCallRecord>) {
440 TOOL_REPLAY_FIXTURES.with(|v| *v.borrow_mut() = records);
441}
442
443pub(crate) fn find_tool_replay_fixture(
445 tool_name: &str,
446 args: &serde_json::Value,
447) -> Option<ToolCallRecord> {
448 let hash = crate::orchestration::tool_fixture_hash(tool_name, args);
449 TOOL_REPLAY_FIXTURES.with(|v| {
450 v.borrow()
451 .iter()
452 .find(|r| r.tool_name == tool_name && r.args_hash == hash)
453 .cloned()
454 })
455}