// harn_vm/llm/mock.rs — mock LLM provider plus record/replay state for tests.
1use std::cell::RefCell;
2
3use super::api::LlmResult;
4use crate::orchestration::ToolCallRecord;
5use crate::value::{ErrorCategory, VmError};
6
/// LLM replay mode.
///
/// Controls whether provider responses are recorded to the fixture
/// directory, replayed from it, or passed through untouched
/// (see `set_replay_mode` / `save_fixture` / `load_fixture`).
// `Eq` added for consistency with `ToolRecordingMode`; the derive is
// valid (no data-carrying variants) and backward compatible.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LlmReplayMode {
    /// No recording or replay.
    Off,
    /// Record each response as a fixture.
    Record,
    /// Serve responses from previously recorded fixtures.
    Replay,
}
14
/// Tool recording mode — mirrors LLM replay for tool call results.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolRecordingMode {
    /// No recording or replay of tool calls.
    Off,
    /// Capture tool call results via `record_tool_call`.
    Record,
    /// Serve tool results from fixtures loaded with `load_tool_replay_fixtures`.
    Replay,
}
22
/// Categorized error injected by a mock. When present, the mock
/// short-circuits the provider call and surfaces as
/// `VmError::CategorizedError`, so `llm_call` throws and
/// `llm_call_safe` populates its `error` envelope.
#[derive(Clone)]
pub(crate) struct MockError {
    // Category carried into the VmError so classification/retry behave
    // like a real provider failure.
    pub category: ErrorCategory,
    // Human-readable message carried into the VmError.
    pub message: String,
}
32
/// A single registered mock response. Matched against the last message
/// either FIFO (`match_pattern == None`, consumed on use) or by glob
/// pattern (reusable; last registered wins) — see `try_match_mock`.
pub(crate) struct LlmMock {
    // Assistant text; becomes an `output_text` block when non-empty.
    pub text: String,
    // Tool calls to synthesize, each a `{name, arguments}` object.
    pub tool_calls: Vec<serde_json::Value>,
    pub match_pattern: Option<String>, // None = FIFO (consumed), Some = glob (reusable)
    // Token counts; `build_mock_result` substitutes defaults when None
    // (prompt length for input, 30 for output).
    pub input_tokens: Option<i64>,
    pub output_tokens: Option<i64>,
    pub thinking: Option<String>,
    pub stop_reason: Option<String>,
    // Model name echoed into the resulting `LlmResult`.
    pub model: String,
    /// When `Some`, this mock synthesizes an error instead of an
    /// `LlmResult`. `text`/`tool_calls` are ignored for error mocks.
    pub error: Option<MockError>,
}
46
/// One recorded invocation of the mock provider, captured so tests can
/// inspect exactly what was sent (via `get_llm_mock_calls`).
#[derive(Clone)]
pub(crate) struct LlmMockCall {
    pub messages: Vec<serde_json::Value>,
    pub system: Option<String>,
    pub tools: Option<Vec<serde_json::Value>>,
}
53
// Per-thread mock/replay state. NOTE(review): thread-local implies each
// VM instance is expected to stay on one thread — confirm with callers.
thread_local! {
    // LLM record/replay: current mode and the fixture directory.
    static LLM_REPLAY_MODE: RefCell<LlmReplayMode> = const { RefCell::new(LlmReplayMode::Off) };
    static LLM_FIXTURE_DIR: RefCell<String> = const { RefCell::new(String::new()) };
    // Tool-call record/replay: mode, captured records, loaded fixtures.
    static TOOL_RECORDING_MODE: RefCell<ToolRecordingMode> = const { RefCell::new(ToolRecordingMode::Off) };
    static TOOL_RECORDINGS: RefCell<Vec<ToolCallRecord>> = const { RefCell::new(Vec::new()) };
    static TOOL_REPLAY_FIXTURES: RefCell<Vec<ToolCallRecord>> = const { RefCell::new(Vec::new()) };
    // Registered mocks and the log of calls made against the provider.
    static LLM_MOCKS: RefCell<Vec<LlmMock>> = const { RefCell::new(Vec::new()) };
    static LLM_MOCK_CALLS: RefCell<Vec<LlmMockCall>> = const { RefCell::new(Vec::new()) };
}
63
64pub(crate) fn push_llm_mock(mock: LlmMock) {
65    LLM_MOCKS.with(|v| v.borrow_mut().push(mock));
66}
67
68pub(crate) fn get_llm_mock_calls() -> Vec<LlmMockCall> {
69    LLM_MOCK_CALLS.with(|v| v.borrow().clone())
70}
71
72pub(crate) fn reset_llm_mock_state() {
73    LLM_MOCKS.with(|v| v.borrow_mut().clear());
74    LLM_MOCK_CALLS.with(|v| v.borrow_mut().clear());
75}
76
77fn record_llm_mock_call(
78    messages: &[serde_json::Value],
79    system: Option<&str>,
80    native_tools: Option<&[serde_json::Value]>,
81) {
82    LLM_MOCK_CALLS.with(|v| {
83        v.borrow_mut().push(LlmMockCall {
84            messages: messages.to_vec(),
85            system: system.map(|s| s.to_string()),
86            tools: native_tools.map(|t| t.to_vec()),
87        });
88    });
89}
90
91/// Build an LlmResult from a matched mock.
92fn build_mock_result(mock: &LlmMock, last_msg_len: usize) -> LlmResult {
93    let mut blocks = Vec::new();
94
95    if !mock.text.is_empty() {
96        blocks.push(serde_json::json!({
97            "type": "output_text",
98            "text": mock.text,
99            "visibility": "public",
100        }));
101    }
102
103    let mut tool_calls = Vec::new();
104    for (i, tc) in mock.tool_calls.iter().enumerate() {
105        let id = format!("mock_call_{}", i + 1);
106        let name = tc.get("name").and_then(|n| n.as_str()).unwrap_or("unknown");
107        let arguments = tc
108            .get("arguments")
109            .cloned()
110            .unwrap_or(serde_json::json!({}));
111        tool_calls.push(serde_json::json!({
112            "id": id,
113            "type": "tool_call",
114            "name": name,
115            "arguments": arguments,
116        }));
117        blocks.push(serde_json::json!({
118            "type": "tool_call",
119            "id": id,
120            "name": name,
121            "arguments": arguments,
122            "visibility": "internal",
123        }));
124    }
125
126    LlmResult {
127        text: mock.text.clone(),
128        tool_calls,
129        input_tokens: mock.input_tokens.unwrap_or(last_msg_len as i64),
130        output_tokens: mock.output_tokens.unwrap_or(30),
131        cache_read_tokens: 0,
132        cache_write_tokens: 0,
133        model: mock.model.clone(),
134        provider: "mock".to_string(),
135        thinking: mock.thinking.clone(),
136        stop_reason: mock.stop_reason.clone(),
137        blocks,
138    }
139}
140
/// Multi-segment glob match: split on `*` and check segments appear in order.
/// Handles `*`, `prefix*`, `*suffix`, `*contains*`, `pre*mid*suf`, etc.
fn mock_glob_match(pattern: &str, text: &str) -> bool {
    if pattern == "*" {
        return true;
    }
    if !pattern.contains('*') {
        // No wildcard at all: exact match only.
        return pattern == text;
    }
    let segments: Vec<&str> = pattern.split('*').collect();
    let last_idx = segments.len() - 1;
    let mut rest = text;
    for (idx, seg) in segments.iter().copied().enumerate() {
        // Empty segments come from leading/trailing/adjacent stars.
        if seg.is_empty() {
            continue;
        }
        if idx == 0 {
            // First segment is anchored at the start.
            match rest.strip_prefix(seg) {
                Some(tail) => rest = tail,
                None => return false,
            }
        } else if idx == last_idx {
            // Last segment is anchored at the end of what remains.
            if !rest.ends_with(seg) {
                return false;
            }
            rest = "";
        } else if let Some(pos) = rest.find(seg) {
            // Middle segments: leftmost occurrence, then continue after it.
            rest = &rest[pos + seg.len()..];
        } else {
            return false;
        }
    }
    true
}
175
176/// Convert a mock's `error` payload into the `VmError` that the
177/// provider path would have raised, so classification, retry, and
178/// `error_category` all behave identically to a real failure.
179fn mock_error_to_vm_error(err: &MockError) -> VmError {
180    VmError::CategorizedError {
181        message: err.message.clone(),
182        category: err.category.clone(),
183    }
184}
185
186/// Try to find and return a matching mock response. Returns
187/// `Some(Ok(LlmResult))` on a text/tool_call match, `Some(Err(VmError))`
188/// on an error-mock match, and `None` to fall through to default.
189fn try_match_mock(last_msg: &str) -> Option<Result<LlmResult, VmError>> {
190    LLM_MOCKS.with(|mocks| {
191        let mut mocks = mocks.borrow_mut();
192
193        // FIFO: first mock without a match pattern (consumed on use).
194        if let Some(idx) = mocks.iter().position(|m| m.match_pattern.is_none()) {
195            let mock = mocks.remove(idx);
196            return Some(match &mock.error {
197                Some(err) => Err(mock_error_to_vm_error(err)),
198                None => Ok(build_mock_result(&mock, last_msg.len())),
199            });
200        }
201
202        // Pattern match (last registered wins).
203        for mock in mocks.iter().rev() {
204            if let Some(ref pattern) = mock.match_pattern {
205                if mock_glob_match(pattern, last_msg) {
206                    return Some(match &mock.error {
207                        Some(err) => Err(mock_error_to_vm_error(err)),
208                        None => Ok(build_mock_result(mock, last_msg.len())),
209                    });
210                }
211            }
212        }
213
214        None
215    })
216}
217
218/// Set LLM replay mode (record/replay) and fixture directory.
219pub fn set_replay_mode(mode: LlmReplayMode, fixture_dir: &str) {
220    LLM_REPLAY_MODE.with(|v| *v.borrow_mut() = mode);
221    LLM_FIXTURE_DIR.with(|v| *v.borrow_mut() = fixture_dir.to_string());
222}
223
224pub(crate) fn get_replay_mode() -> LlmReplayMode {
225    LLM_REPLAY_MODE.with(|v| *v.borrow())
226}
227
228pub(crate) fn get_fixture_dir() -> String {
229    LLM_FIXTURE_DIR.with(|v| v.borrow().clone())
230}
231
232/// Hash a request for fixture file naming using canonical JSON serialization.
233pub(crate) fn fixture_hash(
234    model: &str,
235    messages: &[serde_json::Value],
236    system: Option<&str>,
237) -> String {
238    use std::hash::{Hash, Hasher};
239    let mut hasher = std::collections::hash_map::DefaultHasher::new();
240    model.hash(&mut hasher);
241    // Canonical JSON hashing is stable across Debug-format changes.
242    serde_json::to_string(messages)
243        .unwrap_or_default()
244        .hash(&mut hasher);
245    system.hash(&mut hasher);
246    format!("{:016x}", hasher.finish())
247}
248
249pub(crate) fn save_fixture(hash: &str, result: &LlmResult) {
250    let dir = get_fixture_dir();
251    if dir.is_empty() {
252        return;
253    }
254    let _ = std::fs::create_dir_all(&dir);
255    let path = format!("{dir}/{hash}.json");
256    let json = serde_json::json!({
257        "text": result.text,
258        "tool_calls": result.tool_calls,
259        "input_tokens": result.input_tokens,
260        "output_tokens": result.output_tokens,
261        "model": result.model,
262        "provider": result.provider,
263        "blocks": result.blocks,
264    });
265    let _ = std::fs::write(
266        &path,
267        serde_json::to_string_pretty(&json).unwrap_or_default(),
268    );
269}
270
271pub(crate) fn load_fixture(hash: &str) -> Option<LlmResult> {
272    let dir = get_fixture_dir();
273    if dir.is_empty() {
274        return None;
275    }
276    let path = format!("{dir}/{hash}.json");
277    let content = std::fs::read_to_string(&path).ok()?;
278    let json: serde_json::Value = serde_json::from_str(&content).ok()?;
279    Some(LlmResult {
280        text: json["text"].as_str().unwrap_or("").to_string(),
281        tool_calls: json["tool_calls"].as_array().cloned().unwrap_or_default(),
282        input_tokens: json["input_tokens"].as_i64().unwrap_or(0),
283        output_tokens: json["output_tokens"].as_i64().unwrap_or(0),
284        cache_read_tokens: json["cache_read_tokens"].as_i64().unwrap_or(0),
285        cache_write_tokens: json["cache_write_tokens"].as_i64().unwrap_or(0),
286        model: json["model"].as_str().unwrap_or("").to_string(),
287        provider: json["provider"].as_str().unwrap_or("mock").to_string(),
288        thinking: json["thinking"].as_str().map(|s| s.to_string()),
289        stop_reason: json["stop_reason"].as_str().map(|s| s.to_string()),
290        blocks: json["blocks"].as_array().cloned().unwrap_or_default(),
291    })
292}
293
294/// Generate stub argument values for required parameters in a tool schema.
295/// This makes mock tool calls realistic — a real model would always fill
296/// required fields, so the mock should too.
297fn mock_required_args(tool_schema: &serde_json::Value) -> serde_json::Value {
298    let mut args = serde_json::Map::new();
299    // Anthropic: {name, input_schema: {properties, required}}
300    // OpenAI:    {function: {name, parameters: {properties, required}}}
301    // Harn VM:   {parameters: {name: {type, required}}}  (from tool_define)
302    let input_schema = tool_schema
303        .get("input_schema")
304        .or_else(|| tool_schema.get("inputSchema"))
305        .or_else(|| {
306            tool_schema
307                .get("function")
308                .and_then(|f| f.get("parameters"))
309        })
310        .or_else(|| tool_schema.get("parameters"));
311    let Some(schema) = input_schema else {
312        return serde_json::Value::Object(args);
313    };
314    let required: std::collections::BTreeSet<String> = schema
315        .get("required")
316        .and_then(|r| r.as_array())
317        .map(|arr| {
318            arr.iter()
319                .filter_map(|v| v.as_str().map(|s| s.to_string()))
320                .collect()
321        })
322        .unwrap_or_default();
323    if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) {
324        for (name, prop) in props {
325            if !required.contains(name) {
326                continue;
327            }
328            let ty = prop
329                .get("type")
330                .and_then(|t| t.as_str())
331                .unwrap_or("string");
332            let placeholder = match ty {
333                "integer" => serde_json::json!(0),
334                "number" => serde_json::json!(0.0),
335                "boolean" => serde_json::json!(false),
336                "array" => serde_json::json!([]),
337                "object" => serde_json::json!({}),
338                _ => serde_json::json!(""),
339            };
340            args.insert(name.clone(), placeholder);
341        }
342    }
343    serde_json::Value::Object(args)
344}
345
/// Mock LLM provider -- deterministic responses for testing without API keys.
/// When configurable mocks have been registered via `llm_mock()`, those are
/// checked first (FIFO queue, then pattern matching). Falls through to the
/// default deterministic behavior when no mocks match:
///   1. native tools present -> synthesize a call to the first tool;
///   2. otherwise -> tagged prose echoing the last message.
pub(crate) fn mock_llm_response(
    messages: &[serde_json::Value],
    system: Option<&str>,
    native_tools: Option<&[serde_json::Value]>,
) -> Result<LlmResult, VmError> {
    // Every invocation is logged so tests can assert on what was sent.
    record_llm_mock_call(messages, system, native_tools);

    // Only plain-string `content` is matched; structured content yields "".
    let last_msg = messages
        .last()
        .and_then(|m| m.get("content"))
        .and_then(|c| c.as_str())
        .unwrap_or("");

    if let Some(matched) = try_match_mock(last_msg) {
        return matched;
    }

    // Generate a mock tool call for the first tool, filling required
    // params with placeholders so the call passes schema validation.
    if let Some(tools) = native_tools {
        if let Some(first_tool) = tools.first() {
            // Tool name may be top-level (Anthropic) or nested under
            // `function` (OpenAI).
            let tool_name = first_tool
                .get("name")
                .or_else(|| first_tool.get("function").and_then(|f| f.get("name")))
                .and_then(|n| n.as_str())
                .unwrap_or("unknown");
            let mock_args = mock_required_args(first_tool);
            return Ok(LlmResult {
                text: String::new(),
                tool_calls: vec![serde_json::json!({
                    "id": "mock_call_1",
                    "type": "tool_call",
                    "name": tool_name,
                    "arguments": mock_args
                })],
                // Deterministic token accounting: input scales with the
                // prompt length, output is a fixed constant.
                input_tokens: last_msg.len() as i64,
                output_tokens: 20,
                cache_read_tokens: 0,
                cache_write_tokens: 0,
                model: "mock".to_string(),
                provider: "mock".to_string(),
                thinking: None,
                stop_reason: None,
                blocks: vec![serde_json::json!({
                    "type": "tool_call",
                    "id": "mock_call_1",
                    "name": tool_name,
                    "arguments": mock_args,
                    "visibility": "internal",
                })],
            });
        }
    }

    // Tagged response: <assistant_prose> + <done> when the system
    // prompt advertises the sentinel (agent_loop compatibility).
    let done_block = if system.is_some_and(|s| s.contains("##DONE##")) {
        "\n<done>##DONE##</done>"
    } else {
        ""
    };

    let prose_body = if last_msg.is_empty() {
        "Mock LLM response".to_string()
    } else {
        // Echo at most 100 chars of the prompt so output stays bounded
        // while still depending deterministically on the input.
        let word_count = last_msg.split_whitespace().count();
        format!(
            "Mock response to {word_count}-word prompt: {}",
            last_msg.chars().take(100).collect::<String>()
        )
    };
    let response = format!("<assistant_prose>{prose_body}</assistant_prose>{done_block}");

    Ok(LlmResult {
        text: response.clone(),
        tool_calls: vec![],
        input_tokens: last_msg.len() as i64,
        output_tokens: 30,
        cache_read_tokens: 0,
        cache_write_tokens: 0,
        model: "mock".to_string(),
        provider: "mock".to_string(),
        thinking: None,
        stop_reason: None,
        blocks: vec![serde_json::json!({
            "type": "output_text",
            "text": response,
            "visibility": "public",
        })],
    })
}
441
442pub fn set_tool_recording_mode(mode: ToolRecordingMode) {
443    TOOL_RECORDING_MODE.with(|v| *v.borrow_mut() = mode);
444}
445
446pub(crate) fn get_tool_recording_mode() -> ToolRecordingMode {
447    TOOL_RECORDING_MODE.with(|v| *v.borrow())
448}
449
450/// Append a tool call record during recording mode.
451pub(crate) fn record_tool_call(record: ToolCallRecord) {
452    TOOL_RECORDINGS.with(|v| v.borrow_mut().push(record));
453}
454
455/// Take all recorded tool calls, leaving the buffer empty.
456pub fn drain_tool_recordings() -> Vec<ToolCallRecord> {
457    TOOL_RECORDINGS.with(|v| std::mem::take(&mut *v.borrow_mut()))
458}
459
460/// Load tool call fixtures for replay mode.
461pub fn load_tool_replay_fixtures(records: Vec<ToolCallRecord>) {
462    TOOL_REPLAY_FIXTURES.with(|v| *v.borrow_mut() = records);
463}
464
465/// Look up a recorded fixture by tool name + args hash.
466pub(crate) fn find_tool_replay_fixture(
467    tool_name: &str,
468    args: &serde_json::Value,
469) -> Option<ToolCallRecord> {
470    let hash = crate::orchestration::tool_fixture_hash(tool_name, args);
471    TOOL_REPLAY_FIXTURES.with(|v| {
472        v.borrow()
473            .iter()
474            .find(|r| r.tool_name == tool_name && r.args_hash == hash)
475            .cloned()
476    })
477}