//! harn_vm/llm/mock.rs — mock LLM provider, LLM record/replay state, and
//! tool-call recording used by tests.
use std::cell::RefCell;

use super::api::LlmResult;
use crate::orchestration::ToolCallRecord;
5
/// LLM replay mode.
///
/// `Record` captures live responses as fixtures; `Replay` serves them back
/// (see `save_fixture` / `load_fixture`).
// Derives `Eq` alongside `PartialEq` for consistency with `ToolRecordingMode`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LlmReplayMode {
    /// No recording or replay.
    Off,
    /// Save live results as fixtures.
    Record,
    /// Serve previously saved fixtures.
    Replay,
}
13
/// Tool recording mode — mirrors LLM replay for tool call results.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolRecordingMode {
    /// No recording or replay.
    Off,
    /// Capture tool call records (see `record_tool_call`).
    Record,
    /// Serve recorded fixtures (see `find_tool_replay_fixture`).
    Replay,
}
21
/// A queued mock LLM response registered by tests.
pub(crate) struct LlmMock {
    /// Assistant text to return (empty = no text block emitted).
    pub text: String,
    /// Tool-call specs (`{name, arguments}` objects) to emit.
    pub tool_calls: Vec<serde_json::Value>,
    pub match_pattern: Option<String>, // None = FIFO (consumed), Some = glob (reusable)
    /// Reported input tokens; defaults to the last message length when None
    /// (see `build_mock_result`).
    pub input_tokens: Option<i64>,
    /// Reported output tokens; defaults to 30 when None.
    pub output_tokens: Option<i64>,
    /// Optional thinking text copied into the result.
    pub thinking: Option<String>,
    /// Optional stop reason copied into the result.
    pub stop_reason: Option<String>,
    /// Model name reported in the result.
    pub model: String,
}
32
/// One recorded invocation of the mock provider, kept for test assertions.
#[derive(Clone)]
pub(crate) struct LlmMockCall {
    /// Full message list passed to the provider.
    pub messages: Vec<serde_json::Value>,
    /// System prompt, if one was supplied.
    pub system: Option<String>,
    /// Native tool schemas, if any were offered.
    pub tools: Option<Vec<serde_json::Value>>,
}
39
// All mock/record/replay state is thread-local, so each thread (e.g. each
// test thread) gets an isolated copy.
thread_local! {
    // LLM record/replay switch and the fixture directory it reads/writes.
    static LLM_REPLAY_MODE: RefCell<LlmReplayMode> = const { RefCell::new(LlmReplayMode::Off) };
    static LLM_FIXTURE_DIR: RefCell<String> = const { RefCell::new(String::new()) };
    // Tool-call record/replay switch plus its capture and fixture buffers.
    static TOOL_RECORDING_MODE: RefCell<ToolRecordingMode> = const { RefCell::new(ToolRecordingMode::Off) };
    static TOOL_RECORDINGS: RefCell<Vec<ToolCallRecord>> = const { RefCell::new(Vec::new()) };
    static TOOL_REPLAY_FIXTURES: RefCell<Vec<ToolCallRecord>> = const { RefCell::new(Vec::new()) };
    // Queued mock responses and the log of calls made against the provider.
    static LLM_MOCKS: RefCell<Vec<LlmMock>> = const { RefCell::new(Vec::new()) };
    static LLM_MOCK_CALLS: RefCell<Vec<LlmMockCall>> = const { RefCell::new(Vec::new()) };
}
49
50pub(crate) fn push_llm_mock(mock: LlmMock) {
51    LLM_MOCKS.with(|v| v.borrow_mut().push(mock));
52}
53
54pub(crate) fn get_llm_mock_calls() -> Vec<LlmMockCall> {
55    LLM_MOCK_CALLS.with(|v| v.borrow().clone())
56}
57
58pub(crate) fn reset_llm_mock_state() {
59    LLM_MOCKS.with(|v| v.borrow_mut().clear());
60    LLM_MOCK_CALLS.with(|v| v.borrow_mut().clear());
61}
62
63fn record_llm_mock_call(
64    messages: &[serde_json::Value],
65    system: Option<&str>,
66    native_tools: Option<&[serde_json::Value]>,
67) {
68    LLM_MOCK_CALLS.with(|v| {
69        v.borrow_mut().push(LlmMockCall {
70            messages: messages.to_vec(),
71            system: system.map(|s| s.to_string()),
72            tools: native_tools.map(|t| t.to_vec()),
73        });
74    });
75}
76
77/// Build an LlmResult from a matched mock.
78fn build_mock_result(mock: &LlmMock, last_msg_len: usize) -> LlmResult {
79    let mut blocks = Vec::new();
80
81    if !mock.text.is_empty() {
82        blocks.push(serde_json::json!({
83            "type": "output_text",
84            "text": mock.text,
85            "visibility": "public",
86        }));
87    }
88
89    let mut tool_calls = Vec::new();
90    for (i, tc) in mock.tool_calls.iter().enumerate() {
91        let id = format!("mock_call_{}", i + 1);
92        let name = tc.get("name").and_then(|n| n.as_str()).unwrap_or("unknown");
93        let arguments = tc
94            .get("arguments")
95            .cloned()
96            .unwrap_or(serde_json::json!({}));
97        tool_calls.push(serde_json::json!({
98            "id": id,
99            "type": "tool_call",
100            "name": name,
101            "arguments": arguments,
102        }));
103        blocks.push(serde_json::json!({
104            "type": "tool_call",
105            "id": id,
106            "name": name,
107            "arguments": arguments,
108            "visibility": "internal",
109        }));
110    }
111
112    LlmResult {
113        text: mock.text.clone(),
114        tool_calls,
115        input_tokens: mock.input_tokens.unwrap_or(last_msg_len as i64),
116        output_tokens: mock.output_tokens.unwrap_or(30),
117        cache_read_tokens: 0,
118        cache_write_tokens: 0,
119        model: mock.model.clone(),
120        provider: "mock".to_string(),
121        thinking: mock.thinking.clone(),
122        stop_reason: mock.stop_reason.clone(),
123        blocks,
124    }
125}
126
/// Multi-segment glob match: split on `*` and check segments appear in order.
/// Handles `*`, `prefix*`, `*suffix`, `*contains*`, `pre*mid*suf`, etc.
///
/// First segment must anchor at the start, last must anchor at the end,
/// and each middle segment is matched at its earliest occurrence in the
/// remaining text — which is sufficient because the suffix is checked
/// against what is left after the middles are consumed.
fn mock_glob_match(pattern: &str, text: &str) -> bool {
    if pattern == "*" {
        return true;
    }
    if !pattern.contains('*') {
        // No wildcard at all: exact comparison.
        return pattern == text;
    }

    let segments: Vec<&str> = pattern.split('*').collect();
    let last_idx = segments.len() - 1;
    let mut rest = text;
    for (idx, seg) in segments.iter().enumerate() {
        if seg.is_empty() {
            // Adjacent/leading/trailing '*' produce empty segments; skip.
            continue;
        }
        if idx == 0 {
            // Leading segment is anchored at the start of the text.
            match rest.strip_prefix(seg) {
                Some(tail) => rest = tail,
                None => return false,
            }
        } else if idx == last_idx {
            // Trailing segment is anchored at the end of what remains.
            if !rest.ends_with(seg) {
                return false;
            }
            rest = "";
        } else if let Some(pos) = rest.find(seg) {
            // Middle segment: consume through its earliest occurrence.
            rest = &rest[pos + seg.len()..];
        } else {
            return false;
        }
    }
    true
}
161
162/// Try to find and return a matching mock response.
163/// Returns Some(LlmResult) if a mock matched, None to fall through to default.
164fn try_match_mock(last_msg: &str) -> Option<LlmResult> {
165    LLM_MOCKS.with(|mocks| {
166        let mut mocks = mocks.borrow_mut();
167
168        // FIFO: first mock without a match pattern (consumed on use).
169        if let Some(idx) = mocks.iter().position(|m| m.match_pattern.is_none()) {
170            let mock = mocks.remove(idx);
171            return Some(build_mock_result(&mock, last_msg.len()));
172        }
173
174        // Pattern match (last registered wins).
175        for mock in mocks.iter().rev() {
176            if let Some(ref pattern) = mock.match_pattern {
177                if mock_glob_match(pattern, last_msg) {
178                    return Some(build_mock_result(mock, last_msg.len()));
179                }
180            }
181        }
182
183        None
184    })
185}
186
187/// Set LLM replay mode (record/replay) and fixture directory.
188pub fn set_replay_mode(mode: LlmReplayMode, fixture_dir: &str) {
189    LLM_REPLAY_MODE.with(|v| *v.borrow_mut() = mode);
190    LLM_FIXTURE_DIR.with(|v| *v.borrow_mut() = fixture_dir.to_string());
191}
192
193pub(crate) fn get_replay_mode() -> LlmReplayMode {
194    LLM_REPLAY_MODE.with(|v| *v.borrow())
195}
196
197pub(crate) fn get_fixture_dir() -> String {
198    LLM_FIXTURE_DIR.with(|v| v.borrow().clone())
199}
200
201/// Hash a request for fixture file naming using canonical JSON serialization.
202pub(crate) fn fixture_hash(
203    model: &str,
204    messages: &[serde_json::Value],
205    system: Option<&str>,
206) -> String {
207    use std::hash::{Hash, Hasher};
208    let mut hasher = std::collections::hash_map::DefaultHasher::new();
209    model.hash(&mut hasher);
210    // Canonical JSON hashing is stable across Debug-format changes.
211    serde_json::to_string(messages)
212        .unwrap_or_default()
213        .hash(&mut hasher);
214    system.hash(&mut hasher);
215    format!("{:016x}", hasher.finish())
216}
217
218pub(crate) fn save_fixture(hash: &str, result: &LlmResult) {
219    let dir = get_fixture_dir();
220    if dir.is_empty() {
221        return;
222    }
223    let _ = std::fs::create_dir_all(&dir);
224    let path = format!("{dir}/{hash}.json");
225    let json = serde_json::json!({
226        "text": result.text,
227        "tool_calls": result.tool_calls,
228        "input_tokens": result.input_tokens,
229        "output_tokens": result.output_tokens,
230        "model": result.model,
231        "provider": result.provider,
232        "blocks": result.blocks,
233    });
234    let _ = std::fs::write(
235        &path,
236        serde_json::to_string_pretty(&json).unwrap_or_default(),
237    );
238}
239
240pub(crate) fn load_fixture(hash: &str) -> Option<LlmResult> {
241    let dir = get_fixture_dir();
242    if dir.is_empty() {
243        return None;
244    }
245    let path = format!("{dir}/{hash}.json");
246    let content = std::fs::read_to_string(&path).ok()?;
247    let json: serde_json::Value = serde_json::from_str(&content).ok()?;
248    Some(LlmResult {
249        text: json["text"].as_str().unwrap_or("").to_string(),
250        tool_calls: json["tool_calls"].as_array().cloned().unwrap_or_default(),
251        input_tokens: json["input_tokens"].as_i64().unwrap_or(0),
252        output_tokens: json["output_tokens"].as_i64().unwrap_or(0),
253        cache_read_tokens: json["cache_read_tokens"].as_i64().unwrap_or(0),
254        cache_write_tokens: json["cache_write_tokens"].as_i64().unwrap_or(0),
255        model: json["model"].as_str().unwrap_or("").to_string(),
256        provider: json["provider"].as_str().unwrap_or("mock").to_string(),
257        thinking: json["thinking"].as_str().map(|s| s.to_string()),
258        stop_reason: json["stop_reason"].as_str().map(|s| s.to_string()),
259        blocks: json["blocks"].as_array().cloned().unwrap_or_default(),
260    })
261}
262
263/// Generate stub argument values for required parameters in a tool schema.
264/// This makes mock tool calls realistic — a real model would always fill
265/// required fields, so the mock should too.
266fn mock_required_args(tool_schema: &serde_json::Value) -> serde_json::Value {
267    let mut args = serde_json::Map::new();
268    // Anthropic: {name, input_schema: {properties, required}}
269    // OpenAI:    {function: {name, parameters: {properties, required}}}
270    // Harn VM:   {parameters: {name: {type, required}}}  (from tool_define)
271    let input_schema = tool_schema
272        .get("input_schema")
273        .or_else(|| tool_schema.get("inputSchema"))
274        .or_else(|| {
275            tool_schema
276                .get("function")
277                .and_then(|f| f.get("parameters"))
278        })
279        .or_else(|| tool_schema.get("parameters"));
280    let Some(schema) = input_schema else {
281        return serde_json::Value::Object(args);
282    };
283    let required: std::collections::BTreeSet<String> = schema
284        .get("required")
285        .and_then(|r| r.as_array())
286        .map(|arr| {
287            arr.iter()
288                .filter_map(|v| v.as_str().map(|s| s.to_string()))
289                .collect()
290        })
291        .unwrap_or_default();
292    if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) {
293        for (name, prop) in props {
294            if !required.contains(name) {
295                continue;
296            }
297            let ty = prop
298                .get("type")
299                .and_then(|t| t.as_str())
300                .unwrap_or("string");
301            let placeholder = match ty {
302                "integer" => serde_json::json!(0),
303                "number" => serde_json::json!(0.0),
304                "boolean" => serde_json::json!(false),
305                "array" => serde_json::json!([]),
306                "object" => serde_json::json!({}),
307                _ => serde_json::json!(""),
308            };
309            args.insert(name.clone(), placeholder);
310        }
311    }
312    serde_json::Value::Object(args)
313}
314
/// Mock LLM provider -- deterministic responses for testing without API keys.
/// When configurable mocks have been registered via `llm_mock()`, those are
/// checked first (FIFO queue, then pattern matching). Falls through to the
/// default deterministic behavior when no mocks match.
///
/// Fallback order when no registered mock matches:
///   1. If tools are offered, call the *first* tool with placeholder args.
///   2. Otherwise echo a tagged prose response derived from the last message.
pub(crate) fn mock_llm_response(
    messages: &[serde_json::Value],
    system: Option<&str>,
    native_tools: Option<&[serde_json::Value]>,
) -> LlmResult {
    // Every invocation is logged so tests can assert on what was sent.
    record_llm_mock_call(messages, system, native_tools);

    // Pattern matching keys off the last message's string "content";
    // non-string content falls back to "".
    let last_msg = messages
        .last()
        .and_then(|m| m.get("content"))
        .and_then(|c| c.as_str())
        .unwrap_or("");

    if let Some(result) = try_match_mock(last_msg) {
        return result;
    }

    // Generate a mock tool call for the first tool, filling required
    // params with placeholders so the call passes schema validation.
    if let Some(tools) = native_tools {
        if let Some(first_tool) = tools.first() {
            // Tool name may live at the top level (Anthropic-style) or
            // under "function" (OpenAI-style).
            let tool_name = first_tool
                .get("name")
                .or_else(|| first_tool.get("function").and_then(|f| f.get("name")))
                .and_then(|n| n.as_str())
                .unwrap_or("unknown");
            let mock_args = mock_required_args(first_tool);
            return LlmResult {
                text: String::new(),
                tool_calls: vec![serde_json::json!({
                    "id": "mock_call_1",
                    "type": "tool_call",
                    "name": tool_name,
                    "arguments": mock_args
                })],
                // Token accounting is synthetic: input = message length,
                // output = fixed 20.
                input_tokens: last_msg.len() as i64,
                output_tokens: 20,
                cache_read_tokens: 0,
                cache_write_tokens: 0,
                model: "mock".to_string(),
                provider: "mock".to_string(),
                thinking: None,
                stop_reason: None,
                blocks: vec![serde_json::json!({
                    "type": "tool_call",
                    "id": "mock_call_1",
                    "name": tool_name,
                    "arguments": mock_args,
                    "visibility": "internal",
                })],
            };
        }
    }

    // Tagged response: <assistant_prose> + <done> when the system
    // prompt advertises the sentinel (agent_loop compatibility).
    let done_block = if system.is_some_and(|s| s.contains("##DONE##")) {
        "\n<done>##DONE##</done>"
    } else {
        ""
    };

    // Echo a deterministic summary of the prompt (first 100 chars).
    let prose_body = if last_msg.is_empty() {
        "Mock LLM response".to_string()
    } else {
        let word_count = last_msg.split_whitespace().count();
        format!(
            "Mock response to {word_count}-word prompt: {}",
            last_msg.chars().take(100).collect::<String>()
        )
    };
    let response = format!("<assistant_prose>{prose_body}</assistant_prose>{done_block}");

    LlmResult {
        text: response.clone(),
        tool_calls: vec![],
        input_tokens: last_msg.len() as i64,
        output_tokens: 30,
        cache_read_tokens: 0,
        cache_write_tokens: 0,
        model: "mock".to_string(),
        provider: "mock".to_string(),
        thinking: None,
        stop_reason: None,
        blocks: vec![serde_json::json!({
            "type": "output_text",
            "text": response,
            "visibility": "public",
        })],
    }
}
410
411pub fn set_tool_recording_mode(mode: ToolRecordingMode) {
412    TOOL_RECORDING_MODE.with(|v| *v.borrow_mut() = mode);
413}
414
415pub(crate) fn get_tool_recording_mode() -> ToolRecordingMode {
416    TOOL_RECORDING_MODE.with(|v| *v.borrow())
417}
418
419/// Append a tool call record during recording mode.
420pub(crate) fn record_tool_call(record: ToolCallRecord) {
421    TOOL_RECORDINGS.with(|v| v.borrow_mut().push(record));
422}
423
424/// Take all recorded tool calls, leaving the buffer empty.
425pub fn drain_tool_recordings() -> Vec<ToolCallRecord> {
426    TOOL_RECORDINGS.with(|v| std::mem::take(&mut *v.borrow_mut()))
427}
428
429/// Load tool call fixtures for replay mode.
430pub fn load_tool_replay_fixtures(records: Vec<ToolCallRecord>) {
431    TOOL_REPLAY_FIXTURES.with(|v| *v.borrow_mut() = records);
432}
433
434/// Look up a recorded fixture by tool name + args hash.
435pub(crate) fn find_tool_replay_fixture(
436    tool_name: &str,
437    args: &serde_json::Value,
438) -> Option<ToolCallRecord> {
439    let hash = crate::orchestration::tool_fixture_hash(tool_name, args);
440    TOOL_REPLAY_FIXTURES.with(|v| {
441        v.borrow()
442            .iter()
443            .find(|r| r.tool_name == tool_name && r.args_hash == hash)
444            .cloned()
445    })
446}