Skip to main content

harn_vm/llm/
mock.rs

1use std::cell::RefCell;
2
3use super::api::LlmResult;
4use crate::orchestration::ToolCallRecord;
5use crate::value::{ErrorCategory, VmError};
6
/// LLM replay mode.
///
/// `Record` saves provider results as fixtures (see `save_fixture`);
/// `Replay` serves previously saved fixtures (see `load_fixture`).
// `Eq` added for consistency with `ToolRecordingMode`/`CliLlmMockMode`;
// a `Copy` field-less enum trivially satisfies it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LlmReplayMode {
    Off,
    Record,
    Replay,
}
14
/// Tool recording mode — mirrors LLM replay for tool call results.
///
/// `Record` buffers `ToolCallRecord`s via `record_tool_call`; `Replay`
/// serves fixtures loaded with `load_tool_replay_fixtures`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolRecordingMode {
    Off,
    Record,
    Replay,
}
22
/// Mode for CLI-installed mocks. `Replay` serves the `CLI_LLM_MOCKS`
/// queue (unmatched prompts become hard errors); `Record` captures real
/// provider results into `CLI_LLM_RECORDINGS` (see `record_cli_llm_result`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CliLlmMockMode {
    Off,
    Replay,
    Record,
}
29
/// Categorized error injected by a mock. When present, the mock
/// short-circuits the provider call and surfaces as
/// `VmError::CategorizedError`, so `llm_call` throws and
/// `llm_call_safe` populates its `error` envelope.
#[derive(Clone)]
pub struct MockError {
    // Category carried into `VmError::CategorizedError` for
    // classification/retry handling.
    pub category: ErrorCategory,
    // Human-readable message carried into the synthesized error.
    pub message: String,
}
39
/// One mock LLM response. `None` usage fields fall back to deterministic
/// defaults in `build_mock_result`.
#[derive(Clone)]
pub struct LlmMock {
    // Assistant text returned as the result's `text`.
    pub text: String,
    // Tool calls to synthesize; each entry supplies `name`/`arguments`.
    pub tool_calls: Vec<serde_json::Value>,
    pub match_pattern: Option<String>, // None = FIFO (consumed), Some = glob (reusable)
    // For glob mocks only: remove from the queue after the first match.
    pub consume_on_match: bool,
    pub input_tokens: Option<i64>,
    pub output_tokens: Option<i64>,
    pub cache_read_tokens: Option<i64>,
    pub cache_write_tokens: Option<i64>,
    pub thinking: Option<String>,
    pub stop_reason: Option<String>,
    pub model: String,
    // Defaults to "mock" when `None` (see `build_mock_result`).
    pub provider: Option<String>,
    // Pre-built content blocks; when `Some`, used verbatim instead of
    // synthesizing blocks from `text`/`tool_calls`.
    pub blocks: Option<Vec<serde_json::Value>>,
    /// When `Some`, this mock synthesizes an error instead of an
    /// `LlmResult`. `text`/`tool_calls` are ignored for error mocks.
    pub error: Option<MockError>,
}
59
/// Snapshot of one request that reached the mock provider; captured by
/// `record_llm_mock_call` and inspected via `get_llm_mock_calls`.
#[derive(Clone)]
pub(crate) struct LlmMockCall {
    // Conversation messages exactly as passed in.
    pub messages: Vec<serde_json::Value>,
    // System prompt, if one was supplied.
    pub system: Option<String>,
    // Native tool schemas, if any were supplied.
    pub tools: Option<Vec<serde_json::Value>>,
}
66
thread_local! {
    // LLM record/replay state: mode plus the directory fixtures are read
    // from and written to.
    static LLM_REPLAY_MODE: RefCell<LlmReplayMode> = const { RefCell::new(LlmReplayMode::Off) };
    static LLM_FIXTURE_DIR: RefCell<String> = const { RefCell::new(String::new()) };
    // Tool-call record/replay state (mirrors the LLM replay machinery).
    static TOOL_RECORDING_MODE: RefCell<ToolRecordingMode> = const { RefCell::new(ToolRecordingMode::Off) };
    static TOOL_RECORDINGS: RefCell<Vec<ToolCallRecord>> = const { RefCell::new(Vec::new()) };
    static TOOL_REPLAY_FIXTURES: RefCell<Vec<ToolCallRecord>> = const { RefCell::new(Vec::new()) };
    // Mocks registered programmatically via `llm_mock()`.
    static LLM_MOCKS: RefCell<Vec<LlmMock>> = const { RefCell::new(Vec::new()) };
    // Mocks installed from the CLI (`--llm-mock`), plus the buffer used
    // when recording real responses for later replay.
    static CLI_LLM_MOCK_MODE: RefCell<CliLlmMockMode> = const { RefCell::new(CliLlmMockMode::Off) };
    static CLI_LLM_MOCKS: RefCell<Vec<LlmMock>> = const { RefCell::new(Vec::new()) };
    static CLI_LLM_RECORDINGS: RefCell<Vec<LlmMock>> = const { RefCell::new(Vec::new()) };
    // Log of every request that reached the mock provider.
    static LLM_MOCK_CALLS: RefCell<Vec<LlmMockCall>> = const { RefCell::new(Vec::new()) };
}
79
80pub(crate) fn push_llm_mock(mock: LlmMock) {
81    LLM_MOCKS.with(|v| v.borrow_mut().push(mock));
82}
83
84pub(crate) fn get_llm_mock_calls() -> Vec<LlmMockCall> {
85    LLM_MOCK_CALLS.with(|v| v.borrow().clone())
86}
87
88pub(crate) fn reset_llm_mock_state() {
89    LLM_MOCKS.with(|v| v.borrow_mut().clear());
90    CLI_LLM_MOCK_MODE.with(|v| *v.borrow_mut() = CliLlmMockMode::Off);
91    CLI_LLM_MOCKS.with(|v| v.borrow_mut().clear());
92    CLI_LLM_RECORDINGS.with(|v| v.borrow_mut().clear());
93    LLM_MOCK_CALLS.with(|v| v.borrow_mut().clear());
94}
95
96pub fn clear_cli_llm_mock_mode() {
97    CLI_LLM_MOCK_MODE.with(|v| *v.borrow_mut() = CliLlmMockMode::Off);
98    CLI_LLM_MOCKS.with(|v| v.borrow_mut().clear());
99    CLI_LLM_RECORDINGS.with(|v| v.borrow_mut().clear());
100}
101
102pub fn install_cli_llm_mocks(mocks: Vec<LlmMock>) {
103    CLI_LLM_MOCK_MODE.with(|v| *v.borrow_mut() = CliLlmMockMode::Replay);
104    CLI_LLM_MOCKS.with(|v| *v.borrow_mut() = mocks);
105    CLI_LLM_RECORDINGS.with(|v| v.borrow_mut().clear());
106}
107
108pub fn enable_cli_llm_mock_recording() {
109    CLI_LLM_MOCK_MODE.with(|v| *v.borrow_mut() = CliLlmMockMode::Record);
110    CLI_LLM_MOCKS.with(|v| v.borrow_mut().clear());
111    CLI_LLM_RECORDINGS.with(|v| v.borrow_mut().clear());
112}
113
114pub fn take_cli_llm_recordings() -> Vec<LlmMock> {
115    CLI_LLM_RECORDINGS.with(|v| std::mem::take(&mut *v.borrow_mut()))
116}
117
118pub(crate) fn cli_llm_mock_replay_active() -> bool {
119    CLI_LLM_MOCK_MODE.with(|v| *v.borrow() == CliLlmMockMode::Replay)
120}
121
122fn record_llm_mock_call(
123    messages: &[serde_json::Value],
124    system: Option<&str>,
125    native_tools: Option<&[serde_json::Value]>,
126) {
127    LLM_MOCK_CALLS.with(|v| {
128        v.borrow_mut().push(LlmMockCall {
129            messages: messages.to_vec(),
130            system: system.map(|s| s.to_string()),
131            tools: native_tools.map(|t| t.to_vec()),
132        });
133    });
134}
135
136/// Build an LlmResult from a matched mock.
137fn build_mock_result(mock: &LlmMock, last_msg_len: usize) -> LlmResult {
138    let (tool_calls, blocks) = if let Some(blocks) = &mock.blocks {
139        (mock.tool_calls.clone(), blocks.clone())
140    } else {
141        let mut blocks = Vec::new();
142
143        if !mock.text.is_empty() {
144            blocks.push(serde_json::json!({
145                "type": "output_text",
146                "text": mock.text,
147                "visibility": "public",
148            }));
149        }
150
151        let mut tool_calls = Vec::new();
152        for (i, tc) in mock.tool_calls.iter().enumerate() {
153            let id = format!("mock_call_{}", i + 1);
154            let name = tc.get("name").and_then(|n| n.as_str()).unwrap_or("unknown");
155            let arguments = tc
156                .get("arguments")
157                .cloned()
158                .unwrap_or(serde_json::json!({}));
159            tool_calls.push(serde_json::json!({
160                "id": id,
161                "type": "tool_call",
162                "name": name,
163                "arguments": arguments,
164            }));
165            blocks.push(serde_json::json!({
166                "type": "tool_call",
167                "id": id,
168                "name": name,
169                "arguments": arguments,
170                "visibility": "internal",
171            }));
172        }
173
174        (tool_calls, blocks)
175    };
176
177    LlmResult {
178        text: mock.text.clone(),
179        tool_calls,
180        input_tokens: mock.input_tokens.unwrap_or(last_msg_len as i64),
181        output_tokens: mock.output_tokens.unwrap_or(30),
182        cache_read_tokens: mock.cache_read_tokens.unwrap_or(0),
183        cache_write_tokens: mock.cache_write_tokens.unwrap_or(0),
184        model: mock.model.clone(),
185        provider: mock.provider.clone().unwrap_or_else(|| "mock".to_string()),
186        thinking: mock.thinking.clone(),
187        stop_reason: mock.stop_reason.clone(),
188        blocks,
189    }
190}
191
/// Multi-segment glob match: split on `*` and check segments appear in order.
/// Handles `*`, `prefix*`, `*suffix`, `*contains*`, `pre*mid*suf`, etc.
fn mock_glob_match(pattern: &str, text: &str) -> bool {
    // Fast paths: bare wildcard matches anything; no wildcard means equality.
    if pattern == "*" {
        return true;
    }
    if !pattern.contains('*') {
        return pattern == text;
    }
    let segments: Vec<&str> = pattern.split('*').collect();
    let last = segments.len() - 1;
    let mut rest = text;
    for (i, segment) in segments.iter().enumerate() {
        if segment.is_empty() {
            // Leading/trailing/adjacent `*` yield empty segments: skip.
            continue;
        }
        if i == 0 {
            // First segment is anchored at the start of the text.
            match rest.strip_prefix(segment) {
                Some(tail) => rest = tail,
                None => return false,
            }
        } else if i == last {
            // Last segment is anchored at the end of what remains.
            if !rest.ends_with(segment) {
                return false;
            }
            rest = "";
        } else {
            // Middle segments match at the earliest occurrence.
            let Some(pos) = rest.find(segment) else {
                return false;
            };
            rest = &rest[pos + segment.len()..];
        }
    }
    true
}
226
227fn collect_mock_match_strings(value: &serde_json::Value, out: &mut Vec<String>) {
228    match value {
229        serde_json::Value::String(text) if !text.is_empty() => out.push(text.clone()),
230        serde_json::Value::String(_) => {}
231        serde_json::Value::Array(items) => {
232            for item in items {
233                collect_mock_match_strings(item, out);
234            }
235        }
236        serde_json::Value::Object(map) => {
237            for value in map.values() {
238                collect_mock_match_strings(value, out);
239            }
240        }
241        _ => {}
242    }
243}
244
245fn mock_match_text(messages: &[serde_json::Value]) -> String {
246    let mut parts = Vec::new();
247    for message in messages {
248        collect_mock_match_strings(message, &mut parts);
249    }
250    parts.join("\n")
251}
252
253fn mock_last_prompt_text(messages: &[serde_json::Value]) -> String {
254    for message in messages.iter().rev() {
255        let Some(content) = message.get("content") else {
256            continue;
257        };
258        let mut parts = Vec::new();
259        collect_mock_match_strings(content, &mut parts);
260        let text = parts.join("\n");
261        if !text.trim().is_empty() {
262            return text;
263        }
264    }
265    String::new()
266}
267
268/// Convert a mock's `error` payload into the `VmError` that the
269/// provider path would have raised, so classification, retry, and
270/// `error_category` all behave identically to a real failure.
271fn mock_error_to_vm_error(err: &MockError) -> VmError {
272    VmError::CategorizedError {
273        message: err.message.clone(),
274        category: err.category.clone(),
275    }
276}
277
278/// Try to find and return a matching mock response. Returns
279/// `Some(Ok(LlmResult))` on a text/tool_call match, `Some(Err(VmError))`
280/// on an error-mock match, and `None` to fall through to default.
281fn try_match_mock_queue(
282    mocks: &mut Vec<LlmMock>,
283    match_text: &str,
284) -> Option<Result<LlmResult, VmError>> {
285    if let Some(idx) = mocks.iter().position(|m| m.match_pattern.is_none()) {
286        let mock = mocks.remove(idx);
287        return Some(match &mock.error {
288            Some(err) => Err(mock_error_to_vm_error(err)),
289            None => Ok(build_mock_result(&mock, match_text.len())),
290        });
291    }
292
293    for idx in 0..mocks.len() {
294        let mock = &mocks[idx];
295        if let Some(ref pattern) = mock.match_pattern {
296            if mock_glob_match(pattern, match_text) {
297                if mock.consume_on_match {
298                    let mock = mocks.remove(idx);
299                    return Some(match &mock.error {
300                        Some(err) => Err(mock_error_to_vm_error(err)),
301                        None => Ok(build_mock_result(&mock, match_text.len())),
302                    });
303                }
304                return Some(match &mock.error {
305                    Some(err) => Err(mock_error_to_vm_error(err)),
306                    None => Ok(build_mock_result(mock, match_text.len())),
307                });
308            }
309        }
310    }
311
312    None
313}
314
315fn try_match_builtin_mock(match_text: &str) -> Option<Result<LlmResult, VmError>> {
316    LLM_MOCKS.with(|mocks| try_match_mock_queue(&mut mocks.borrow_mut(), match_text))
317}
318
319fn try_match_cli_mock(match_text: &str) -> Option<Result<LlmResult, VmError>> {
320    CLI_LLM_MOCKS.with(|mocks| try_match_mock_queue(&mut mocks.borrow_mut(), match_text))
321}
322
323pub(crate) fn record_cli_llm_result(result: &LlmResult) {
324    if !CLI_LLM_MOCK_MODE.with(|mode| *mode.borrow() == CliLlmMockMode::Record) {
325        return;
326    }
327    CLI_LLM_RECORDINGS.with(|recordings| {
328        recordings.borrow_mut().push(LlmMock {
329            text: result.text.clone(),
330            tool_calls: result.tool_calls.clone(),
331            match_pattern: None,
332            consume_on_match: false,
333            input_tokens: Some(result.input_tokens),
334            output_tokens: Some(result.output_tokens),
335            cache_read_tokens: Some(result.cache_read_tokens),
336            cache_write_tokens: Some(result.cache_write_tokens),
337            thinking: result.thinking.clone(),
338            stop_reason: result.stop_reason.clone(),
339            model: result.model.clone(),
340            provider: Some(result.provider.clone()),
341            blocks: Some(result.blocks.clone()),
342            error: None,
343        });
344    });
345}
346
347fn unmatched_cli_prompt_error(match_text: &str) -> VmError {
348    let mut snippet: String = match_text.chars().take(200).collect();
349    if match_text.chars().count() > 200 {
350        snippet.push_str("...");
351    }
352    VmError::Runtime(format!("No --llm-mock fixture matched prompt: {snippet:?}"))
353}
354
355/// Set LLM replay mode (record/replay) and fixture directory.
356pub fn set_replay_mode(mode: LlmReplayMode, fixture_dir: &str) {
357    LLM_REPLAY_MODE.with(|v| *v.borrow_mut() = mode);
358    LLM_FIXTURE_DIR.with(|v| *v.borrow_mut() = fixture_dir.to_string());
359}
360
361pub(crate) fn get_replay_mode() -> LlmReplayMode {
362    LLM_REPLAY_MODE.with(|v| *v.borrow())
363}
364
365pub(crate) fn get_fixture_dir() -> String {
366    LLM_FIXTURE_DIR.with(|v| v.borrow().clone())
367}
368
369/// Hash a request for fixture file naming using canonical JSON serialization.
370pub(crate) fn fixture_hash(
371    model: &str,
372    messages: &[serde_json::Value],
373    system: Option<&str>,
374) -> String {
375    use std::hash::{Hash, Hasher};
376    let mut hasher = std::collections::hash_map::DefaultHasher::new();
377    model.hash(&mut hasher);
378    // Canonical JSON hashing is stable across Debug-format changes.
379    serde_json::to_string(messages)
380        .unwrap_or_default()
381        .hash(&mut hasher);
382    system.hash(&mut hasher);
383    format!("{:016x}", hasher.finish())
384}
385
386pub(crate) fn save_fixture(hash: &str, result: &LlmResult) {
387    let dir = get_fixture_dir();
388    if dir.is_empty() {
389        return;
390    }
391    let _ = std::fs::create_dir_all(&dir);
392    let path = format!("{dir}/{hash}.json");
393    let json = serde_json::json!({
394        "text": result.text,
395        "tool_calls": result.tool_calls,
396        "input_tokens": result.input_tokens,
397        "output_tokens": result.output_tokens,
398        "model": result.model,
399        "provider": result.provider,
400        "blocks": result.blocks,
401    });
402    let _ = std::fs::write(
403        &path,
404        serde_json::to_string_pretty(&json).unwrap_or_default(),
405    );
406}
407
408pub(crate) fn load_fixture(hash: &str) -> Option<LlmResult> {
409    let dir = get_fixture_dir();
410    if dir.is_empty() {
411        return None;
412    }
413    let path = format!("{dir}/{hash}.json");
414    let content = std::fs::read_to_string(&path).ok()?;
415    let json: serde_json::Value = serde_json::from_str(&content).ok()?;
416    Some(LlmResult {
417        text: json["text"].as_str().unwrap_or("").to_string(),
418        tool_calls: json["tool_calls"].as_array().cloned().unwrap_or_default(),
419        input_tokens: json["input_tokens"].as_i64().unwrap_or(0),
420        output_tokens: json["output_tokens"].as_i64().unwrap_or(0),
421        cache_read_tokens: json["cache_read_tokens"].as_i64().unwrap_or(0),
422        cache_write_tokens: json["cache_write_tokens"].as_i64().unwrap_or(0),
423        model: json["model"].as_str().unwrap_or("").to_string(),
424        provider: json["provider"].as_str().unwrap_or("mock").to_string(),
425        thinking: json["thinking"].as_str().map(|s| s.to_string()),
426        stop_reason: json["stop_reason"].as_str().map(|s| s.to_string()),
427        blocks: json["blocks"].as_array().cloned().unwrap_or_default(),
428    })
429}
430
431/// Generate stub argument values for required parameters in a tool schema.
432/// This makes mock tool calls realistic — a real model would always fill
433/// required fields, so the mock should too.
434fn mock_required_args(tool_schema: &serde_json::Value) -> serde_json::Value {
435    let mut args = serde_json::Map::new();
436    // Anthropic: {name, input_schema: {properties, required}}
437    // OpenAI:    {function: {name, parameters: {properties, required}}}
438    // Harn VM:   {parameters: {name: {type, required}}}  (from tool_define)
439    let input_schema = tool_schema
440        .get("input_schema")
441        .or_else(|| tool_schema.get("inputSchema"))
442        .or_else(|| {
443            tool_schema
444                .get("function")
445                .and_then(|f| f.get("parameters"))
446        })
447        .or_else(|| tool_schema.get("parameters"));
448    let Some(schema) = input_schema else {
449        return serde_json::Value::Object(args);
450    };
451    let required: std::collections::BTreeSet<String> = schema
452        .get("required")
453        .and_then(|r| r.as_array())
454        .map(|arr| {
455            arr.iter()
456                .filter_map(|v| v.as_str().map(|s| s.to_string()))
457                .collect()
458        })
459        .unwrap_or_default();
460    if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) {
461        for (name, prop) in props {
462            if !required.contains(name) {
463                continue;
464            }
465            let ty = prop
466                .get("type")
467                .and_then(|t| t.as_str())
468                .unwrap_or("string");
469            let placeholder = match ty {
470                "integer" => serde_json::json!(0),
471                "number" => serde_json::json!(0.0),
472                "boolean" => serde_json::json!(false),
473                "array" => serde_json::json!([]),
474                "object" => serde_json::json!({}),
475                _ => serde_json::json!(""),
476            };
477            args.insert(name.clone(), placeholder);
478        }
479    }
480    serde_json::Value::Object(args)
481}
482
/// Mock LLM provider -- deterministic responses for testing without API keys.
/// When configurable mocks have been registered via `llm_mock()`, those are
/// checked first (FIFO queue, then pattern matching). Falls through to the
/// default deterministic behavior when no mocks match.
///
/// Precedence: CLI mocks → builtin mocks → hard error (CLI replay only)
/// → synthesized tool call (when native tools exist) → synthesized prose.
pub(crate) fn mock_llm_response(
    messages: &[serde_json::Value],
    system: Option<&str>,
    native_tools: Option<&[serde_json::Value]>,
) -> Result<LlmResult, VmError> {
    // Log every request first — even ones that end up erroring — so test
    // assertions via `get_llm_mock_calls` see exactly what was sent.
    record_llm_mock_call(messages, system, native_tools);

    // `match_text` spans all messages (used for mock glob matching);
    // `prompt_text` is just the latest non-blank message (used for the
    // synthesized default response and input-token count).
    let match_text = mock_match_text(messages);
    let prompt_text = mock_last_prompt_text(messages);

    if let Some(matched) = try_match_cli_mock(&match_text) {
        return matched;
    }

    if let Some(matched) = try_match_builtin_mock(&match_text) {
        return matched;
    }

    // In CLI replay mode an unmatched prompt is a hard error rather than
    // silently falling through to the synthetic defaults below.
    if cli_llm_mock_replay_active() {
        return Err(unmatched_cli_prompt_error(&match_text));
    }

    // Generate a mock tool call for the first tool, filling required
    // params with placeholders so the call passes schema validation.
    if let Some(tools) = native_tools {
        if let Some(first_tool) = tools.first() {
            // Tool name may live at the top level (Anthropic) or under
            // `function` (OpenAI).
            let tool_name = first_tool
                .get("name")
                .or_else(|| first_tool.get("function").and_then(|f| f.get("name")))
                .and_then(|n| n.as_str())
                .unwrap_or("unknown");
            let mock_args = mock_required_args(first_tool);
            return Ok(LlmResult {
                text: String::new(),
                tool_calls: vec![serde_json::json!({
                        "id": "mock_call_1",
                        "type": "tool_call",
                        "name": tool_name,
                "arguments": mock_args
                })],
                input_tokens: prompt_text.len() as i64,
                output_tokens: 20,
                cache_read_tokens: 0,
                cache_write_tokens: 0,
                model: "mock".to_string(),
                provider: "mock".to_string(),
                thinking: None,
                stop_reason: None,
                blocks: vec![serde_json::json!({
                    "type": "tool_call",
                    "id": "mock_call_1",
                    "name": tool_name,
                    "arguments": mock_args,
                    "visibility": "internal",
                })],
            });
        }
    }

    // Preserve the historical auto-complete behavior for tagged text-tool
    // prompts only. Bare `##DONE##` in no-tool/native prompts changes
    // loop semantics by completing runs that used to exhaust budget unless
    // a fixture explicitly returned the sentinel.
    let tagged_done = system.is_some_and(|s| s.contains("<done>"));

    // Default response echoes a word count plus the first 100 chars of
    // the prompt, keeping output deterministic for a given input.
    let prose_body = if prompt_text.is_empty() {
        "Mock LLM response".to_string()
    } else {
        let word_count = prompt_text.split_whitespace().count();
        format!(
            "Mock response to {word_count}-word prompt: {}",
            prompt_text.chars().take(100).collect::<String>()
        )
    };
    let response = if tagged_done {
        format!("<assistant_prose>{prose_body}</assistant_prose>\n<done>##DONE##</done>")
    } else {
        prose_body
    };

    Ok(LlmResult {
        text: response.clone(),
        tool_calls: vec![],
        input_tokens: prompt_text.len() as i64,
        output_tokens: 30,
        cache_read_tokens: 0,
        cache_write_tokens: 0,
        model: "mock".to_string(),
        provider: "mock".to_string(),
        thinking: None,
        stop_reason: None,
        blocks: vec![serde_json::json!({
            "type": "output_text",
            "text": response,
            "visibility": "public",
        })],
    })
}
585
586pub fn set_tool_recording_mode(mode: ToolRecordingMode) {
587    TOOL_RECORDING_MODE.with(|v| *v.borrow_mut() = mode);
588}
589
590pub(crate) fn get_tool_recording_mode() -> ToolRecordingMode {
591    TOOL_RECORDING_MODE.with(|v| *v.borrow())
592}
593
594/// Append a tool call record during recording mode.
595pub(crate) fn record_tool_call(record: ToolCallRecord) {
596    TOOL_RECORDINGS.with(|v| v.borrow_mut().push(record));
597}
598
599/// Take all recorded tool calls, leaving the buffer empty.
600pub fn drain_tool_recordings() -> Vec<ToolCallRecord> {
601    TOOL_RECORDINGS.with(|v| std::mem::take(&mut *v.borrow_mut()))
602}
603
604/// Load tool call fixtures for replay mode.
605pub fn load_tool_replay_fixtures(records: Vec<ToolCallRecord>) {
606    TOOL_REPLAY_FIXTURES.with(|v| *v.borrow_mut() = records);
607}
608
609/// Look up a recorded fixture by tool name + args hash.
610pub(crate) fn find_tool_replay_fixture(
611    tool_name: &str,
612    args: &serde_json::Value,
613) -> Option<ToolCallRecord> {
614    let hash = crate::orchestration::tool_fixture_hash(tool_name, args);
615    TOOL_REPLAY_FIXTURES.with(|v| {
616        v.borrow()
617            .iter()
618            .find(|r| r.tool_name == tool_name && r.args_hash == hash)
619            .cloned()
620    })
621}