Skip to main content

sage_runtime/
mock.rs

//! Mock infrastructure for the Sage testing framework (RFC-0012).
//!
//! This module provides:
//! - `MockResponse` - a single mocked outcome: either a JSON value or a failure message
//! - `MockQueue` - a thread-safe FIFO queue of mock responses
//! - `MockLlmClient` - a mock implementation of LLM inference backed by a `MockQueue`
//! - `MockToolRegistry` - per-`Tool.function` queues of mock tool-call responses

9use crate::error::{SageError, SageResult};
10use serde::de::DeserializeOwned;
11use std::sync::{Arc, Mutex};
12
13/// A mock response for an `infer` call.
14#[derive(Debug, Clone)]
15pub enum MockResponse {
16    /// A successful response with the given value.
17    Value(serde_json::Value),
18    /// A failure response with the given error message.
19    Fail(String),
20}
21
22impl MockResponse {
23    /// Create a successful mock response from a JSON-serializable value.
24    pub fn value<T: serde::Serialize>(value: T) -> Self {
25        Self::Value(serde_json::to_value(value).expect("failed to serialize mock value"))
26    }
27
28    /// Create a successful mock response from a string.
29    pub fn string(s: impl Into<String>) -> Self {
30        Self::Value(serde_json::Value::String(s.into()))
31    }
32
33    /// Create a failure mock response.
34    pub fn fail(message: impl Into<String>) -> Self {
35        Self::Fail(message.into())
36    }
37}
38
39/// A thread-safe queue of mock responses.
40///
41/// Mock responses are consumed in order - the first `infer` call gets
42/// the first mock, the second gets the second, etc.
43#[derive(Debug, Clone, Default)]
44pub struct MockQueue {
45    responses: Arc<Mutex<Vec<MockResponse>>>,
46}
47
48impl MockQueue {
49    /// Create a new empty mock queue.
50    pub fn new() -> Self {
51        Self::default()
52    }
53
54    /// Create a mock queue with the given responses.
55    pub fn with_responses(responses: Vec<MockResponse>) -> Self {
56        Self {
57            responses: Arc::new(Mutex::new(responses)),
58        }
59    }
60
61    /// Add a mock response to the queue.
62    pub fn push(&self, response: MockResponse) {
63        self.responses.lock().unwrap().push(response);
64    }
65
66    /// Pop the next mock response from the queue.
67    ///
68    /// Returns `None` if the queue is empty.
69    pub fn pop(&self) -> Option<MockResponse> {
70        let mut queue = self.responses.lock().unwrap();
71        if queue.is_empty() {
72            None
73        } else {
74            Some(queue.remove(0))
75        }
76    }
77
78    /// Check if the queue is empty.
79    pub fn is_empty(&self) -> bool {
80        self.responses.lock().unwrap().is_empty()
81    }
82
83    /// Get the number of remaining mock responses.
84    pub fn len(&self) -> usize {
85        self.responses.lock().unwrap().len()
86    }
87}
88
89/// Mock LLM client for testing.
90///
91/// This client uses a `MockQueue` to return pre-configured responses
92/// instead of making real API calls.
93#[derive(Debug, Clone)]
94pub struct MockLlmClient {
95    queue: MockQueue,
96}
97
98impl MockLlmClient {
99    /// Create a new mock client with an empty queue.
100    pub fn new() -> Self {
101        Self {
102            queue: MockQueue::new(),
103        }
104    }
105
106    /// Create a mock client with the given responses.
107    pub fn with_responses(responses: Vec<MockResponse>) -> Self {
108        Self {
109            queue: MockQueue::with_responses(responses),
110        }
111    }
112
113    /// Get a reference to the mock queue for adding responses.
114    pub fn queue(&self) -> &MockQueue {
115        &self.queue
116    }
117
118    /// Call the mock LLM with a prompt and return the raw string response.
119    ///
120    /// Returns an error if no mock responses are queued.
121    pub async fn infer_string(&self, _prompt: &str) -> SageResult<String> {
122        match self.queue.pop() {
123            Some(MockResponse::Value(value)) => {
124                // Convert JSON value to string
125                match value {
126                    serde_json::Value::String(s) => Ok(s),
127                    other => Ok(other.to_string()),
128                }
129            }
130            Some(MockResponse::Fail(msg)) => Err(SageError::Llm(msg)),
131            None => Err(SageError::Llm(
132                "infer called with no mock available (E054)".to_string(),
133            )),
134        }
135    }
136
137    /// Call the mock LLM with a prompt and parse the response as the given type.
138    ///
139    /// Returns an error if no mock responses are queued.
140    pub async fn infer<T>(&self, _prompt: &str) -> SageResult<T>
141    where
142        T: DeserializeOwned,
143    {
144        match self.queue.pop() {
145            Some(MockResponse::Value(value)) => serde_json::from_value(value)
146                .map_err(|e| SageError::Llm(format!("failed to deserialize mock value: {e}"))),
147            Some(MockResponse::Fail(msg)) => Err(SageError::Llm(msg)),
148            None => Err(SageError::Llm(
149                "infer called with no mock available (E054)".to_string(),
150            )),
151        }
152    }
153
154    /// Call the mock LLM with schema-injected prompt for structured output.
155    ///
156    /// Returns an error if no mock responses are queued.
157    pub async fn infer_structured<T>(&self, _prompt: &str, _schema: &str) -> SageResult<T>
158    where
159        T: DeserializeOwned,
160    {
161        // Same as infer - the schema is ignored for mocks
162        match self.queue.pop() {
163            Some(MockResponse::Value(value)) => serde_json::from_value(value)
164                .map_err(|e| SageError::Llm(format!("failed to deserialize mock value: {e}"))),
165            Some(MockResponse::Fail(msg)) => Err(SageError::Llm(msg)),
166            None => Err(SageError::Llm(
167                "infer called with no mock available (E054)".to_string(),
168            )),
169        }
170    }
171}
172
173impl Default for MockLlmClient {
174    fn default() -> Self {
175        Self::new()
176    }
177}
178
179/// Mock registry for tool calls.
180///
181/// Stores mock responses for specific tool.function combinations.
182#[derive(Debug, Clone, Default)]
183pub struct MockToolRegistry {
184    mocks: Arc<Mutex<std::collections::HashMap<String, MockQueue>>>,
185}
186
187impl MockToolRegistry {
188    /// Create a new empty mock registry.
189    pub fn new() -> Self {
190        Self::default()
191    }
192
193    /// Register a mock response for a tool function.
194    ///
195    /// The key is in the format "ToolName.function_name".
196    pub fn register(&self, tool: &str, function: &str, response: MockResponse) {
197        let key = format!("{}.{}", tool, function);
198        let mut mocks = self.mocks.lock().unwrap();
199        mocks
200            .entry(key)
201            .or_insert_with(MockQueue::new)
202            .push(response);
203    }
204
205    /// Get the next mock response for a tool function.
206    ///
207    /// Returns `None` if no mock is registered for this function.
208    pub fn get(&self, tool: &str, function: &str) -> Option<MockResponse> {
209        let key = format!("{}.{}", tool, function);
210        let mocks = self.mocks.lock().unwrap();
211        mocks.get(&key).and_then(|q| q.pop())
212    }
213
214    /// Check if a mock is registered for a tool function.
215    pub fn has_mock(&self, tool: &str, function: &str) -> bool {
216        let key = format!("{}.{}", tool, function);
217        let mocks = self.mocks.lock().unwrap();
218        mocks.get(&key).is_some_and(|q| !q.is_empty())
219    }
220
221    /// Call a mocked tool function and return the result.
222    ///
223    /// Returns an error if no mock is registered.
224    pub async fn call<T>(&self, tool: &str, function: &str) -> SageResult<T>
225    where
226        T: DeserializeOwned,
227    {
228        match self.get(tool, function) {
229            Some(MockResponse::Value(value)) => serde_json::from_value(value).map_err(|e| {
230                SageError::Tool(format!("failed to deserialize mock tool response: {e}"))
231            }),
232            Some(MockResponse::Fail(msg)) => Err(SageError::Tool(msg)),
233            None => Err(SageError::Tool(format!(
234                "no mock registered for {}.{}",
235                tool, function
236            ))),
237        }
238    }
239
240    /// Call a mocked tool function and return the raw string.
241    pub async fn call_string(&self, tool: &str, function: &str) -> SageResult<String> {
242        match self.get(tool, function) {
243            Some(MockResponse::Value(value)) => match value {
244                serde_json::Value::String(s) => Ok(s),
245                other => Ok(other.to_string()),
246            },
247            Some(MockResponse::Fail(msg)) => Err(SageError::Tool(msg)),
248            None => Err(SageError::Tool(format!(
249                "no mock registered for {}.{}",
250                tool, function
251            ))),
252        }
253    }
254}
255
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn mock_infer_string_returns_value() {
        let client = MockLlmClient::with_responses(vec![MockResponse::string("hello world")]);
        let result = client.infer_string("test").await.unwrap();
        assert_eq!(result, "hello world");
    }

    #[tokio::test]
    async fn mock_infer_string_returns_fail() {
        let client = MockLlmClient::with_responses(vec![MockResponse::fail("test error")]);
        let result = client.infer_string("test").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("test error"));
    }

    #[tokio::test]
    async fn mock_infer_empty_queue_returns_error() {
        let client = MockLlmClient::new();
        let result = client.infer_string("test").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("E054"));
    }

    #[tokio::test]
    async fn mock_infer_string_renders_non_string_values_as_json() {
        let client = MockLlmClient::with_responses(vec![
            MockResponse::value(42),
            MockResponse::value(serde_json::json!({ "a": [1, 2] })),
        ]);

        assert_eq!(client.infer_string("a").await.unwrap(), "42");
        assert_eq!(client.infer_string("b").await.unwrap(), r#"{"a":[1,2]}"#);
    }

    #[tokio::test]
    async fn mock_queue_fifo_order() {
        let client = MockLlmClient::with_responses(vec![
            MockResponse::string("first"),
            MockResponse::string("second"),
            MockResponse::string("third"),
        ]);

        assert_eq!(client.infer_string("a").await.unwrap(), "first");
        assert_eq!(client.infer_string("b").await.unwrap(), "second");
        assert_eq!(client.infer_string("c").await.unwrap(), "third");
        assert!(client.infer_string("d").await.is_err());
    }

    #[tokio::test]
    async fn mock_infer_typed_value() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct Person {
            name: String,
            age: i32,
        }

        let client = MockLlmClient::with_responses(vec![MockResponse::value(
            serde_json::json!({ "name": "Ward", "age": 42 }),
        )]);

        let person: Person = client.infer("test").await.unwrap();
        assert_eq!(person.name, "Ward");
        assert_eq!(person.age, 42);
    }

    #[tokio::test]
    async fn mock_infer_type_mismatch_returns_error() {
        let client = MockLlmClient::with_responses(vec![MockResponse::string("not a number")]);

        let result: Result<i64, _> = client.infer("test").await;
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("failed to deserialize mock value"));
    }

    #[tokio::test]
    async fn mock_llm_client_clones_share_queue() {
        let client = MockLlmClient::new();
        let clone = client.clone();

        clone.queue().push(MockResponse::string("shared"));

        assert_eq!(client.infer_string("x").await.unwrap(), "shared");
        assert!(clone.queue().is_empty());
    }

    #[test]
    fn mock_queue_thread_safe() {
        use std::thread;

        let queue = MockQueue::with_responses(vec![
            MockResponse::string("1"),
            MockResponse::string("2"),
            MockResponse::string("3"),
        ]);

        let queue_clone = queue.clone();
        let handle = thread::spawn(move || {
            queue_clone.pop();
            queue_clone.pop();
        });

        handle.join().unwrap();
        assert_eq!(queue.len(), 1);
    }

    #[test]
    fn mock_queue_push_pop_len() {
        let queue = MockQueue::new();
        assert!(queue.is_empty());
        assert_eq!(queue.len(), 0);

        queue.push(MockResponse::fail("boom"));
        queue.push(MockResponse::string("ok"));
        assert_eq!(queue.len(), 2);
        assert!(!queue.is_empty());

        assert!(matches!(queue.pop(), Some(MockResponse::Fail(m)) if m == "boom"));
        assert!(matches!(
            queue.pop(),
            Some(MockResponse::Value(serde_json::Value::String(s))) if s == "ok"
        ));
        assert!(queue.pop().is_none());
        assert!(queue.is_empty());
    }

    #[tokio::test]
    async fn mock_infer_structured() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct Summary {
            text: String,
            confidence: f64,
        }

        let client = MockLlmClient::with_responses(vec![MockResponse::value(serde_json::json!({
            "text": "A summary",
            "confidence": 0.95
        }))]);

        let summary: Summary = client
            .infer_structured("summarize", "schema")
            .await
            .unwrap();
        assert_eq!(summary.text, "A summary");
        assert!((summary.confidence - 0.95).abs() < 0.001);
    }

    #[tokio::test]
    async fn mock_tool_registry_basic() {
        let registry = MockToolRegistry::new();

        // Register a mock
        registry.register("Http", "get", MockResponse::string("mocked response"));

        // Should have mock
        assert!(registry.has_mock("Http", "get"));

        // Call and get result
        let result: String = registry.call("Http", "get").await.unwrap();
        assert_eq!(result, "mocked response");

        // Queue should be empty now
        assert!(!registry.has_mock("Http", "get"));
    }

    #[tokio::test]
    async fn mock_tool_registry_multiple() {
        let registry = MockToolRegistry::new();

        // Register multiple mocks for same function
        registry.register("Http", "get", MockResponse::string("first"));
        registry.register("Http", "get", MockResponse::string("second"));

        // Should get them in order
        let r1: String = registry.call("Http", "get").await.unwrap();
        let r2: String = registry.call("Http", "get").await.unwrap();

        assert_eq!(r1, "first");
        assert_eq!(r2, "second");
    }

    #[tokio::test]
    async fn mock_tool_registry_fail() {
        let registry = MockToolRegistry::new();
        registry.register("Http", "get", MockResponse::fail("network error"));

        let result: Result<String, _> = registry.call("Http", "get").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("network error"));
    }

    #[tokio::test]
    async fn mock_tool_registry_no_mock() {
        let registry = MockToolRegistry::new();

        let result: Result<String, _> = registry.call("Http", "get").await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("no mock registered"));
    }

    #[tokio::test]
    async fn mock_tool_registry_exhausted_returns_error() {
        let registry = MockToolRegistry::new();
        registry.register("Http", "get", MockResponse::string("only"));

        let _: String = registry.call("Http", "get").await.unwrap();
        let result: Result<String, _> = registry.call("Http", "get").await;
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("no mock registered"));
    }

    #[tokio::test]
    async fn mock_tool_registry_call_string() {
        let registry = MockToolRegistry::new();
        registry.register("Http", "get", MockResponse::string("plain"));
        registry.register(
            "Http",
            "get",
            MockResponse::value(serde_json::json!({ "ok": true })),
        );
        registry.register("Http", "get", MockResponse::fail("timeout"));

        assert_eq!(registry.call_string("Http", "get").await.unwrap(), "plain");
        assert_eq!(
            registry.call_string("Http", "get").await.unwrap(),
            r#"{"ok":true}"#
        );
        assert!(registry
            .call_string("Http", "get")
            .await
            .unwrap_err()
            .to_string()
            .contains("timeout"));
    }

    #[test]
    fn mock_tool_registry_keys_are_independent() {
        let registry = MockToolRegistry::new();
        registry.register("Http", "get", MockResponse::string("x"));

        assert!(registry.has_mock("Http", "get"));
        assert!(!registry.has_mock("Http", "post"));
        assert!(!registry.has_mock("Fs", "get"));
        assert!(registry.get("Http", "post").is_none());
    }
}