Skip to main content

sage_runtime/
mock.rs

//! Mock infrastructure for the Sage testing framework (RFC-0012).
//!
//! This module provides:
//! - `MockResponse` - a canned success value or failure message
//! - `MockQueue` - a thread-safe FIFO queue of mock responses
//! - `MockLlmClient` - a mock implementation of LLM inference
//! - `MockToolRegistry` - per-`tool.function` queues of mock tool responses
//! - A scoped mock context (`with_mock_tools` / `try_get_mock`) for
//!   intercepting tool calls in tests: task-local on native targets,
//!   thread-local on WASM
9
10use crate::error::{SageError, SageResult};
11use serde::de::DeserializeOwned;
12use std::cell::RefCell;
13use std::future::Future;
14use std::sync::{Arc, Mutex};
15
// Storage for the mock tool registry that `with_mock_tools` installs and
// `try_get_mock` reads.
// On native: a tokio task-local, so each scoped future gets its own value.
// On WASM: a thread-local, since the browser is single-threaded.

#[cfg(not(target_arch = "wasm32"))]
tokio::task_local! {
    static MOCK_TOOL_REGISTRY: RefCell<Option<MockToolRegistry>>;
}

// Starts as `None` (no mock context); `with_mock_tools` fills it in while
// its future runs.
#[cfg(target_arch = "wasm32")]
thread_local! {
    static MOCK_TOOL_REGISTRY: RefCell<Option<MockToolRegistry>> = const { RefCell::new(None) };
}
29
30/// Run a future with a mock tool registry in scope.
31///
32/// All tool calls made during the execution of the future will check
33/// the registry for mocks before making real calls.
34#[cfg(not(target_arch = "wasm32"))]
35pub async fn with_mock_tools<F, R>(registry: MockToolRegistry, f: F) -> R
36where
37    F: Future<Output = R>,
38{
39    MOCK_TOOL_REGISTRY
40        .scope(RefCell::new(Some(registry)), f)
41        .await
42}
43
/// Run a future with a mock tool registry in scope (WASM variant).
///
/// Installs `registry` in the thread-local slot for the duration of `f` and
/// then restores whatever was installed before. Restoring (rather than
/// clearing to `None`) keeps nested calls working like the native
/// task-local `scope`. The restore also runs if `f` panics or this future is
/// dropped mid-await.
///
/// Unlike the native variant, the registry lives in a thread-local, so every
/// future polled on this thread while `f` is pending can see it.
#[cfg(target_arch = "wasm32")]
pub async fn with_mock_tools<F, R>(registry: MockToolRegistry, f: F) -> R
where
    F: Future<Output = R>,
{
    /// Puts the previously installed registry back into the slot when dropped.
    struct RestoreOnDrop(Option<MockToolRegistry>);

    impl Drop for RestoreOnDrop {
        fn drop(&mut self) {
            let previous = self.0.take();
            // `try_with` rather than `with`: never panic inside `drop`, which
            // could abort the process if we are already unwinding.
            let _ = MOCK_TOOL_REGISTRY.try_with(|cell| *cell.borrow_mut() = previous);
        }
    }

    let previous = MOCK_TOOL_REGISTRY.with(|cell| cell.borrow_mut().replace(registry));
    let _restore = RestoreOnDrop(previous);
    f.await
}
59
60/// Try to get a mock response for a tool function call.
61///
62/// Returns `Some(response)` if a mock is registered and available,
63/// `None` if no mock is registered or if called outside a mock context.
64///
65/// This is called by tool clients to intercept calls during tests.
66#[cfg(not(target_arch = "wasm32"))]
67pub fn try_get_mock(tool: &str, function: &str) -> Option<MockResponse> {
68    MOCK_TOOL_REGISTRY
69        .try_with(|cell| {
70            cell.borrow_mut()
71                .as_ref()
72                .and_then(|reg| reg.get(tool, function))
73        })
74        .ok()
75        .flatten()
76}
77
/// Try to get a mock response for a tool function call (WASM variant).
///
/// Yields `None` when no registry is installed in the thread-local slot, or
/// when the installed registry has nothing queued for `tool.function`.
#[cfg(target_arch = "wasm32")]
pub fn try_get_mock(tool: &str, function: &str) -> Option<MockResponse> {
    MOCK_TOOL_REGISTRY.with(|slot| {
        let guard = slot.borrow();
        let registry = guard.as_ref()?;
        registry.get(tool, function)
    })
}
87
88/// A mock response for an `infer` call.
89#[derive(Debug, Clone)]
90pub enum MockResponse {
91    /// A successful response with the given value.
92    Value(serde_json::Value),
93    /// A failure response with the given error message.
94    Fail(String),
95}
96
97impl MockResponse {
98    /// Create a successful mock response from a JSON-serializable value.
99    pub fn value<T: serde::Serialize>(value: T) -> Self {
100        Self::Value(serde_json::to_value(value).expect("failed to serialize mock value"))
101    }
102
103    /// Create a successful mock response from a string.
104    pub fn string(s: impl Into<String>) -> Self {
105        Self::Value(serde_json::Value::String(s.into()))
106    }
107
108    /// Create a failure mock response.
109    pub fn fail(message: impl Into<String>) -> Self {
110        Self::Fail(message.into())
111    }
112}
113
114/// A thread-safe queue of mock responses.
115///
116/// Mock responses are consumed in order - the first `infer` call gets
117/// the first mock, the second gets the second, etc.
118#[derive(Debug, Clone, Default)]
119pub struct MockQueue {
120    responses: Arc<Mutex<Vec<MockResponse>>>,
121}
122
123impl MockQueue {
124    /// Create a new empty mock queue.
125    pub fn new() -> Self {
126        Self::default()
127    }
128
129    /// Create a mock queue with the given responses.
130    pub fn with_responses(responses: Vec<MockResponse>) -> Self {
131        Self {
132            responses: Arc::new(Mutex::new(responses)),
133        }
134    }
135
136    /// Add a mock response to the queue.
137    pub fn push(&self, response: MockResponse) {
138        self.responses.lock().unwrap().push(response);
139    }
140
141    /// Pop the next mock response from the queue.
142    ///
143    /// Returns `None` if the queue is empty.
144    pub fn pop(&self) -> Option<MockResponse> {
145        let mut queue = self.responses.lock().unwrap();
146        if queue.is_empty() {
147            None
148        } else {
149            Some(queue.remove(0))
150        }
151    }
152
153    /// Check if the queue is empty.
154    pub fn is_empty(&self) -> bool {
155        self.responses.lock().unwrap().is_empty()
156    }
157
158    /// Get the number of remaining mock responses.
159    pub fn len(&self) -> usize {
160        self.responses.lock().unwrap().len()
161    }
162}
163
164/// Mock LLM client for testing.
165///
166/// This client uses a `MockQueue` to return pre-configured responses
167/// instead of making real API calls.
168#[derive(Debug, Clone)]
169pub struct MockLlmClient {
170    queue: MockQueue,
171}
172
173impl MockLlmClient {
174    /// Create a new mock client with an empty queue.
175    pub fn new() -> Self {
176        Self {
177            queue: MockQueue::new(),
178        }
179    }
180
181    /// Create a mock client with the given responses.
182    pub fn with_responses(responses: Vec<MockResponse>) -> Self {
183        Self {
184            queue: MockQueue::with_responses(responses),
185        }
186    }
187
188    /// Get a reference to the mock queue for adding responses.
189    pub fn queue(&self) -> &MockQueue {
190        &self.queue
191    }
192
193    /// Call the mock LLM with a prompt and return the raw string response.
194    ///
195    /// Returns an error if no mock responses are queued.
196    pub async fn infer_string(&self, _prompt: &str) -> SageResult<String> {
197        match self.queue.pop() {
198            Some(MockResponse::Value(value)) => {
199                // Convert JSON value to string
200                match value {
201                    serde_json::Value::String(s) => Ok(s),
202                    other => Ok(other.to_string()),
203                }
204            }
205            Some(MockResponse::Fail(msg)) => Err(SageError::Llm(msg)),
206            None => Err(SageError::Llm(
207                "infer called with no mock available (E054)".to_string(),
208            )),
209        }
210    }
211
212    /// Call the mock LLM with a prompt and parse the response as the given type.
213    ///
214    /// Returns an error if no mock responses are queued.
215    pub async fn infer<T>(&self, _prompt: &str) -> SageResult<T>
216    where
217        T: DeserializeOwned,
218    {
219        match self.queue.pop() {
220            Some(MockResponse::Value(value)) => serde_json::from_value(value)
221                .map_err(|e| SageError::Llm(format!("failed to deserialize mock value: {e}"))),
222            Some(MockResponse::Fail(msg)) => Err(SageError::Llm(msg)),
223            None => Err(SageError::Llm(
224                "infer called with no mock available (E054)".to_string(),
225            )),
226        }
227    }
228
229    /// Call the mock LLM with schema-injected prompt for structured output.
230    ///
231    /// Returns an error if no mock responses are queued.
232    pub async fn infer_structured<T>(&self, _prompt: &str, _schema: &str) -> SageResult<T>
233    where
234        T: DeserializeOwned,
235    {
236        // Same as infer - the schema is ignored for mocks
237        match self.queue.pop() {
238            Some(MockResponse::Value(value)) => serde_json::from_value(value)
239                .map_err(|e| SageError::Llm(format!("failed to deserialize mock value: {e}"))),
240            Some(MockResponse::Fail(msg)) => Err(SageError::Llm(msg)),
241            None => Err(SageError::Llm(
242                "infer called with no mock available (E054)".to_string(),
243            )),
244        }
245    }
246}
247
248impl Default for MockLlmClient {
249    fn default() -> Self {
250        Self::new()
251    }
252}
253
254/// Mock registry for tool calls.
255///
256/// Stores mock responses for specific tool.function combinations.
257#[derive(Debug, Clone, Default)]
258pub struct MockToolRegistry {
259    mocks: Arc<Mutex<std::collections::HashMap<String, MockQueue>>>,
260}
261
262impl MockToolRegistry {
263    /// Create a new empty mock registry.
264    pub fn new() -> Self {
265        Self::default()
266    }
267
268    /// Register a mock response for a tool function.
269    ///
270    /// The key is in the format "ToolName.function_name".
271    pub fn register(&self, tool: &str, function: &str, response: MockResponse) {
272        let key = format!("{}.{}", tool, function);
273        let mut mocks = self.mocks.lock().unwrap();
274        mocks.entry(key).or_default().push(response);
275    }
276
277    /// Get the next mock response for a tool function.
278    ///
279    /// Returns `None` if no mock is registered for this function.
280    pub fn get(&self, tool: &str, function: &str) -> Option<MockResponse> {
281        let key = format!("{}.{}", tool, function);
282        let mocks = self.mocks.lock().unwrap();
283        mocks.get(&key).and_then(|q| q.pop())
284    }
285
286    /// Check if a mock is registered for a tool function.
287    pub fn has_mock(&self, tool: &str, function: &str) -> bool {
288        let key = format!("{}.{}", tool, function);
289        let mocks = self.mocks.lock().unwrap();
290        mocks.get(&key).is_some_and(|q| !q.is_empty())
291    }
292
293    /// Call a mocked tool function and return the result.
294    ///
295    /// Returns an error if no mock is registered.
296    pub async fn call<T>(&self, tool: &str, function: &str) -> SageResult<T>
297    where
298        T: DeserializeOwned,
299    {
300        match self.get(tool, function) {
301            Some(MockResponse::Value(value)) => serde_json::from_value(value).map_err(|e| {
302                SageError::Tool(format!("failed to deserialize mock tool response: {e}"))
303            }),
304            Some(MockResponse::Fail(msg)) => Err(SageError::Tool(msg)),
305            None => Err(SageError::Tool(format!(
306                "no mock registered for {}.{}",
307                tool, function
308            ))),
309        }
310    }
311
312    /// Call a mocked tool function and return the raw string.
313    pub async fn call_string(&self, tool: &str, function: &str) -> SageResult<String> {
314        match self.get(tool, function) {
315            Some(MockResponse::Value(value)) => match value {
316                serde_json::Value::String(s) => Ok(s),
317                other => Ok(other.to_string()),
318            },
319            Some(MockResponse::Fail(msg)) => Err(SageError::Tool(msg)),
320            None => Err(SageError::Tool(format!(
321                "no mock registered for {}.{}",
322                tool, function
323            ))),
324        }
325    }
326}
327
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn mock_infer_string_returns_value() {
        let client = MockLlmClient::with_responses(vec![MockResponse::string("hello world")]);
        let result = client.infer_string("test").await.unwrap();
        assert_eq!(result, "hello world");
    }

    #[tokio::test]
    async fn mock_infer_string_returns_fail() {
        let client = MockLlmClient::with_responses(vec![MockResponse::fail("test error")]);
        let result = client.infer_string("test").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("test error"));
    }

    #[tokio::test]
    async fn mock_infer_string_stringifies_non_string_value() {
        let client =
            MockLlmClient::with_responses(vec![MockResponse::value(serde_json::json!({ "a": 1 }))]);
        assert_eq!(client.infer_string("test").await.unwrap(), r#"{"a":1}"#);
    }

    #[tokio::test]
    async fn mock_infer_empty_queue_returns_error() {
        let client = MockLlmClient::new();
        let result = client.infer_string("test").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("E054"));
    }

    #[tokio::test]
    async fn mock_queue_fifo_order() {
        let client = MockLlmClient::with_responses(vec![
            MockResponse::string("first"),
            MockResponse::string("second"),
            MockResponse::string("third"),
        ]);

        assert_eq!(client.infer_string("a").await.unwrap(), "first");
        assert_eq!(client.infer_string("b").await.unwrap(), "second");
        assert_eq!(client.infer_string("c").await.unwrap(), "third");
        assert!(client.infer_string("d").await.is_err());
    }

    #[tokio::test]
    async fn mock_infer_typed_value() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct Person {
            name: String,
            age: i32,
        }

        let client = MockLlmClient::with_responses(vec![MockResponse::value(
            serde_json::json!({ "name": "Ward", "age": 42 }),
        )]);

        let person: Person = client.infer("test").await.unwrap();
        assert_eq!(person.name, "Ward");
        assert_eq!(person.age, 42);
    }

    #[tokio::test]
    async fn mock_infer_reports_deserialize_error() {
        let client = MockLlmClient::with_responses(vec![MockResponse::string("not a number")]);
        let result: Result<i32, _> = client.infer("test").await;
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("failed to deserialize mock value"));
    }

    #[test]
    fn mock_queue_push_pop_len() {
        let queue = MockQueue::new();
        assert!(queue.is_empty());

        queue.push(MockResponse::string("a"));
        queue.push(MockResponse::fail("b"));
        assert_eq!(queue.len(), 2);

        assert!(matches!(queue.pop(), Some(MockResponse::Value(_))));
        assert!(matches!(queue.pop(), Some(MockResponse::Fail(m)) if m == "b"));
        assert!(queue.pop().is_none());
        assert!(queue.is_empty());
    }

    #[test]
    fn mock_queue_thread_safe() {
        use std::thread;

        let queue = MockQueue::with_responses(vec![
            MockResponse::string("1"),
            MockResponse::string("2"),
            MockResponse::string("3"),
        ]);

        let queue_clone = queue.clone();
        let handle = thread::spawn(move || {
            queue_clone.pop();
            queue_clone.pop();
        });

        handle.join().unwrap();
        assert_eq!(queue.len(), 1);
    }

    #[tokio::test]
    async fn mock_infer_structured() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct Summary {
            text: String,
            confidence: f64,
        }

        let client = MockLlmClient::with_responses(vec![MockResponse::value(serde_json::json!({
            "text": "A summary",
            "confidence": 0.95
        }))]);

        let summary: Summary = client
            .infer_structured("summarize", "schema")
            .await
            .unwrap();
        assert_eq!(summary.text, "A summary");
        assert!((summary.confidence - 0.95).abs() < 0.001);
    }

    #[tokio::test]
    async fn mock_tool_registry_basic() {
        let registry = MockToolRegistry::new();

        // Register a mock
        registry.register("Http", "get", MockResponse::string("mocked response"));

        // Should have mock
        assert!(registry.has_mock("Http", "get"));

        // Call and get result
        let result: String = registry.call("Http", "get").await.unwrap();
        assert_eq!(result, "mocked response");

        // Queue should be empty now
        assert!(!registry.has_mock("Http", "get"));
    }

    #[tokio::test]
    async fn mock_tool_registry_multiple() {
        let registry = MockToolRegistry::new();

        // Register multiple mocks for same function
        registry.register("Http", "get", MockResponse::string("first"));
        registry.register("Http", "get", MockResponse::string("second"));

        // Should get them in order
        let r1: String = registry.call("Http", "get").await.unwrap();
        let r2: String = registry.call("Http", "get").await.unwrap();

        assert_eq!(r1, "first");
        assert_eq!(r2, "second");
    }

    #[tokio::test]
    async fn mock_tool_registry_fail() {
        let registry = MockToolRegistry::new();
        registry.register("Http", "get", MockResponse::fail("network error"));

        let result: Result<String, _> = registry.call("Http", "get").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("network error"));
    }

    #[tokio::test]
    async fn mock_tool_registry_no_mock() {
        let registry = MockToolRegistry::new();

        let result: Result<String, _> = registry.call("Http", "get").await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("no mock registered"));
    }

    #[tokio::test]
    async fn mock_tool_call_string_stringifies_non_string_value() {
        let registry = MockToolRegistry::new();
        registry.register("Calc", "add", MockResponse::value(42));
        assert_eq!(registry.call_string("Calc", "add").await.unwrap(), "42");
    }

    #[tokio::test]
    async fn try_get_mock_outside_scope_returns_none() {
        assert!(try_get_mock("Http", "get").is_none());
    }

    #[tokio::test]
    async fn with_mock_tools_scopes_registry() {
        let registry = MockToolRegistry::new();
        registry.register("Http", "get", MockResponse::string("scoped"));

        let hit = with_mock_tools(registry, async { try_get_mock("Http", "get") }).await;
        match hit {
            Some(MockResponse::Value(serde_json::Value::String(s))) => assert_eq!(s, "scoped"),
            other => panic!("unexpected mock response: {other:?}"),
        }

        // Once the scope has ended the registry is no longer visible.
        assert!(try_get_mock("Http", "get").is_none());
    }
}