1use std::cell::RefCell;
2
3use super::api::LlmResult;
4use crate::orchestration::ToolCallRecord;
5
/// Per-thread mode controlling LLM fixture handling during tests.
///
/// `Eq` is derived for consistency with `ToolRecordingMode`, which models
/// the same three states.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LlmReplayMode {
    /// Record/replay disabled.
    Off,
    /// Persist each LLM result as a fixture (see `save_fixture`).
    Record,
    /// Serve LLM results from previously saved fixtures (see `load_fixture`).
    Replay,
}
13
/// Per-thread mode controlling tool-call recording and replay during tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolRecordingMode {
    /// Record/replay disabled.
    Off,
    /// Capture tool calls into the thread-local buffer (see `record_tool_call`).
    Record,
    /// Serve tool results from loaded fixtures (see `find_tool_replay_fixture`).
    Replay,
}
21
// Test-support state. Thread-local (presumably so concurrently running tests
// cannot interfere with each other — confirm against test harness); the
// `const { ... }` initializers avoid lazy-initialization overhead.
thread_local! {
    // Active LLM record/replay mode; defaults to Off.
    static LLM_REPLAY_MODE: RefCell<LlmReplayMode> = const { RefCell::new(LlmReplayMode::Off) };
    // Directory where LLM fixtures are read/written; empty string disables fixture I/O.
    static LLM_FIXTURE_DIR: RefCell<String> = const { RefCell::new(String::new()) };
    // Active tool-call record/replay mode; defaults to Off.
    static TOOL_RECORDING_MODE: RefCell<ToolRecordingMode> = const { RefCell::new(ToolRecordingMode::Off) };
    // Tool calls captured while recording; emptied by `drain_tool_recordings`.
    static TOOL_RECORDINGS: RefCell<Vec<ToolCallRecord>> = const { RefCell::new(Vec::new()) };
    // Fixtures available for replay; searched by `find_tool_replay_fixture`.
    static TOOL_REPLAY_FIXTURES: RefCell<Vec<ToolCallRecord>> = const { RefCell::new(Vec::new()) };
}
29
30pub fn set_replay_mode(mode: LlmReplayMode, fixture_dir: &str) {
32 LLM_REPLAY_MODE.with(|v| *v.borrow_mut() = mode);
33 LLM_FIXTURE_DIR.with(|v| *v.borrow_mut() = fixture_dir.to_string());
34}
35
36pub(crate) fn get_replay_mode() -> LlmReplayMode {
37 LLM_REPLAY_MODE.with(|v| *v.borrow())
38}
39
40pub(crate) fn get_fixture_dir() -> String {
41 LLM_FIXTURE_DIR.with(|v| v.borrow().clone())
42}
43
44pub(crate) fn fixture_hash(
46 model: &str,
47 messages: &[serde_json::Value],
48 system: Option<&str>,
49) -> String {
50 use std::hash::{Hash, Hasher};
51 let mut hasher = std::collections::hash_map::DefaultHasher::new();
52 model.hash(&mut hasher);
53 serde_json::to_string(messages)
55 .unwrap_or_default()
56 .hash(&mut hasher);
57 system.hash(&mut hasher);
58 format!("{:016x}", hasher.finish())
59}
60
61pub(crate) fn save_fixture(hash: &str, result: &LlmResult) {
62 let dir = get_fixture_dir();
63 if dir.is_empty() {
64 return;
65 }
66 let _ = std::fs::create_dir_all(&dir);
67 let path = format!("{dir}/{hash}.json");
68 let json = serde_json::json!({
69 "text": result.text,
70 "tool_calls": result.tool_calls,
71 "input_tokens": result.input_tokens,
72 "output_tokens": result.output_tokens,
73 "model": result.model,
74 "provider": result.provider,
75 "blocks": result.blocks,
76 });
77 let _ = std::fs::write(
78 &path,
79 serde_json::to_string_pretty(&json).unwrap_or_default(),
80 );
81}
82
83pub(crate) fn load_fixture(hash: &str) -> Option<LlmResult> {
84 let dir = get_fixture_dir();
85 if dir.is_empty() {
86 return None;
87 }
88 let path = format!("{dir}/{hash}.json");
89 let content = std::fs::read_to_string(&path).ok()?;
90 let json: serde_json::Value = serde_json::from_str(&content).ok()?;
91 Some(LlmResult {
92 text: json["text"].as_str().unwrap_or("").to_string(),
93 tool_calls: json["tool_calls"].as_array().cloned().unwrap_or_default(),
94 input_tokens: json["input_tokens"].as_i64().unwrap_or(0),
95 output_tokens: json["output_tokens"].as_i64().unwrap_or(0),
96 cache_read_tokens: json["cache_read_tokens"].as_i64().unwrap_or(0),
97 cache_write_tokens: json["cache_write_tokens"].as_i64().unwrap_or(0),
98 model: json["model"].as_str().unwrap_or("").to_string(),
99 provider: json["provider"].as_str().unwrap_or("mock").to_string(),
100 thinking: json["thinking"].as_str().map(|s| s.to_string()),
101 stop_reason: json["stop_reason"].as_str().map(|s| s.to_string()),
102 blocks: json["blocks"].as_array().cloned().unwrap_or_default(),
103 })
104}
105
106fn mock_required_args(tool_schema: &serde_json::Value) -> serde_json::Value {
110 let mut args = serde_json::Map::new();
111 let input_schema = tool_schema
115 .get("input_schema")
116 .or_else(|| tool_schema.get("inputSchema"))
117 .or_else(|| {
118 tool_schema
119 .get("function")
120 .and_then(|f| f.get("parameters"))
121 })
122 .or_else(|| tool_schema.get("parameters"));
123 let Some(schema) = input_schema else {
124 return serde_json::Value::Object(args);
125 };
126 let required: std::collections::BTreeSet<String> = schema
127 .get("required")
128 .and_then(|r| r.as_array())
129 .map(|arr| {
130 arr.iter()
131 .filter_map(|v| v.as_str().map(|s| s.to_string()))
132 .collect()
133 })
134 .unwrap_or_default();
135 if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) {
136 for (name, prop) in props {
137 if !required.contains(name) {
138 continue;
139 }
140 let ty = prop
141 .get("type")
142 .and_then(|t| t.as_str())
143 .unwrap_or("string");
144 let placeholder = match ty {
145 "integer" => serde_json::json!(0),
146 "number" => serde_json::json!(0.0),
147 "boolean" => serde_json::json!(false),
148 "array" => serde_json::json!([]),
149 "object" => serde_json::json!({}),
150 _ => serde_json::json!(""),
151 };
152 args.insert(name.clone(), placeholder);
153 }
154 }
155 serde_json::Value::Object(args)
156}
157
158pub(crate) fn mock_llm_response(
160 messages: &[serde_json::Value],
161 system: Option<&str>,
162 native_tools: Option<&[serde_json::Value]>,
163) -> LlmResult {
164 let last_msg = messages
166 .last()
167 .and_then(|m| m.get("content"))
168 .and_then(|c| c.as_str())
169 .unwrap_or("");
170
171 if let Some(tools) = native_tools {
175 if let Some(first_tool) = tools.first() {
176 let tool_name = first_tool
177 .get("name")
178 .or_else(|| first_tool.get("function").and_then(|f| f.get("name")))
179 .and_then(|n| n.as_str())
180 .unwrap_or("unknown");
181 let mock_args = mock_required_args(first_tool);
182 return LlmResult {
183 text: String::new(),
184 tool_calls: vec![serde_json::json!({
185 "id": "mock_call_1",
186 "type": "tool_call",
187 "name": tool_name,
188 "arguments": mock_args
189 })],
190 input_tokens: last_msg.len() as i64,
191 output_tokens: 20,
192 cache_read_tokens: 0,
193 cache_write_tokens: 0,
194 model: "mock".to_string(),
195 provider: "mock".to_string(),
196 thinking: None,
197 stop_reason: None,
198 blocks: vec![serde_json::json!({
199 "type": "tool_call",
200 "id": "mock_call_1",
201 "name": tool_name,
202 "arguments": mock_args,
203 "visibility": "internal",
204 })],
205 };
206 }
207 }
208
209 let done_sentinel = if system.is_some_and(|s| s.contains("##DONE##")) {
212 " ##DONE##"
213 } else {
214 ""
215 };
216
217 let response = if last_msg.is_empty() {
218 format!("Mock LLM response{done_sentinel}")
219 } else {
220 let word_count = last_msg.split_whitespace().count();
221 format!(
222 "Mock response to {word_count}-word prompt: {}{done_sentinel}",
223 last_msg.chars().take(100).collect::<String>()
224 )
225 };
226
227 LlmResult {
228 text: response.clone(),
229 tool_calls: vec![],
230 input_tokens: last_msg.len() as i64,
231 output_tokens: 30,
232 cache_read_tokens: 0,
233 cache_write_tokens: 0,
234 model: "mock".to_string(),
235 provider: "mock".to_string(),
236 thinking: None,
237 stop_reason: None,
238 blocks: vec![serde_json::json!({
239 "type": "output_text",
240 "text": response,
241 "visibility": "public",
242 })],
243 }
244}
245
246pub fn set_tool_recording_mode(mode: ToolRecordingMode) {
249 TOOL_RECORDING_MODE.with(|v| *v.borrow_mut() = mode);
250}
251
252pub(crate) fn get_tool_recording_mode() -> ToolRecordingMode {
253 TOOL_RECORDING_MODE.with(|v| *v.borrow())
254}
255
256pub(crate) fn record_tool_call(record: ToolCallRecord) {
258 TOOL_RECORDINGS.with(|v| v.borrow_mut().push(record));
259}
260
261pub fn drain_tool_recordings() -> Vec<ToolCallRecord> {
263 TOOL_RECORDINGS.with(|v| std::mem::take(&mut *v.borrow_mut()))
264}
265
266pub fn load_tool_replay_fixtures(records: Vec<ToolCallRecord>) {
268 TOOL_REPLAY_FIXTURES.with(|v| *v.borrow_mut() = records);
269}
270
271pub(crate) fn find_tool_replay_fixture(
273 tool_name: &str,
274 args: &serde_json::Value,
275) -> Option<ToolCallRecord> {
276 let hash = crate::orchestration::tool_fixture_hash(tool_name, args);
277 TOOL_REPLAY_FIXTURES.with(|v| {
278 v.borrow()
279 .iter()
280 .find(|r| r.tool_name == tool_name && r.args_hash == hash)
281 .cloned()
282 })
283}