1use crate::error::{SageError, SageResult};
10use serde::de::DeserializeOwned;
11use std::sync::{Arc, Mutex};
12
13#[derive(Debug, Clone)]
15pub enum MockResponse {
16 Value(serde_json::Value),
18 Fail(String),
20}
21
22impl MockResponse {
23 pub fn value<T: serde::Serialize>(value: T) -> Self {
25 Self::Value(serde_json::to_value(value).expect("failed to serialize mock value"))
26 }
27
28 pub fn string(s: impl Into<String>) -> Self {
30 Self::Value(serde_json::Value::String(s.into()))
31 }
32
33 pub fn fail(message: impl Into<String>) -> Self {
35 Self::Fail(message.into())
36 }
37}
38
39#[derive(Debug, Clone, Default)]
44pub struct MockQueue {
45 responses: Arc<Mutex<Vec<MockResponse>>>,
46}
47
48impl MockQueue {
49 pub fn new() -> Self {
51 Self::default()
52 }
53
54 pub fn with_responses(responses: Vec<MockResponse>) -> Self {
56 Self {
57 responses: Arc::new(Mutex::new(responses)),
58 }
59 }
60
61 pub fn push(&self, response: MockResponse) {
63 self.responses.lock().unwrap().push(response);
64 }
65
66 pub fn pop(&self) -> Option<MockResponse> {
70 let mut queue = self.responses.lock().unwrap();
71 if queue.is_empty() {
72 None
73 } else {
74 Some(queue.remove(0))
75 }
76 }
77
78 pub fn is_empty(&self) -> bool {
80 self.responses.lock().unwrap().is_empty()
81 }
82
83 pub fn len(&self) -> usize {
85 self.responses.lock().unwrap().len()
86 }
87}
88
89#[derive(Debug, Clone)]
94pub struct MockLlmClient {
95 queue: MockQueue,
96}
97
98impl MockLlmClient {
99 pub fn new() -> Self {
101 Self {
102 queue: MockQueue::new(),
103 }
104 }
105
106 pub fn with_responses(responses: Vec<MockResponse>) -> Self {
108 Self {
109 queue: MockQueue::with_responses(responses),
110 }
111 }
112
113 pub fn queue(&self) -> &MockQueue {
115 &self.queue
116 }
117
118 pub async fn infer_string(&self, _prompt: &str) -> SageResult<String> {
122 match self.queue.pop() {
123 Some(MockResponse::Value(value)) => {
124 match value {
126 serde_json::Value::String(s) => Ok(s),
127 other => Ok(other.to_string()),
128 }
129 }
130 Some(MockResponse::Fail(msg)) => Err(SageError::Llm(msg)),
131 None => Err(SageError::Llm(
132 "infer called with no mock available (E054)".to_string(),
133 )),
134 }
135 }
136
137 pub async fn infer<T>(&self, _prompt: &str) -> SageResult<T>
141 where
142 T: DeserializeOwned,
143 {
144 match self.queue.pop() {
145 Some(MockResponse::Value(value)) => serde_json::from_value(value)
146 .map_err(|e| SageError::Llm(format!("failed to deserialize mock value: {e}"))),
147 Some(MockResponse::Fail(msg)) => Err(SageError::Llm(msg)),
148 None => Err(SageError::Llm(
149 "infer called with no mock available (E054)".to_string(),
150 )),
151 }
152 }
153
154 pub async fn infer_structured<T>(&self, _prompt: &str, _schema: &str) -> SageResult<T>
158 where
159 T: DeserializeOwned,
160 {
161 match self.queue.pop() {
163 Some(MockResponse::Value(value)) => serde_json::from_value(value)
164 .map_err(|e| SageError::Llm(format!("failed to deserialize mock value: {e}"))),
165 Some(MockResponse::Fail(msg)) => Err(SageError::Llm(msg)),
166 None => Err(SageError::Llm(
167 "infer called with no mock available (E054)".to_string(),
168 )),
169 }
170 }
171}
172
173impl Default for MockLlmClient {
174 fn default() -> Self {
175 Self::new()
176 }
177}
178
179#[derive(Debug, Clone, Default)]
183pub struct MockToolRegistry {
184 mocks: Arc<Mutex<std::collections::HashMap<String, MockQueue>>>,
185}
186
187impl MockToolRegistry {
188 pub fn new() -> Self {
190 Self::default()
191 }
192
193 pub fn register(&self, tool: &str, function: &str, response: MockResponse) {
197 let key = format!("{}.{}", tool, function);
198 let mut mocks = self.mocks.lock().unwrap();
199 mocks
200 .entry(key)
201 .or_insert_with(MockQueue::new)
202 .push(response);
203 }
204
205 pub fn get(&self, tool: &str, function: &str) -> Option<MockResponse> {
209 let key = format!("{}.{}", tool, function);
210 let mocks = self.mocks.lock().unwrap();
211 mocks.get(&key).and_then(|q| q.pop())
212 }
213
214 pub fn has_mock(&self, tool: &str, function: &str) -> bool {
216 let key = format!("{}.{}", tool, function);
217 let mocks = self.mocks.lock().unwrap();
218 mocks.get(&key).is_some_and(|q| !q.is_empty())
219 }
220
221 pub async fn call<T>(&self, tool: &str, function: &str) -> SageResult<T>
225 where
226 T: DeserializeOwned,
227 {
228 match self.get(tool, function) {
229 Some(MockResponse::Value(value)) => serde_json::from_value(value).map_err(|e| {
230 SageError::Tool(format!("failed to deserialize mock tool response: {e}"))
231 }),
232 Some(MockResponse::Fail(msg)) => Err(SageError::Tool(msg)),
233 None => Err(SageError::Tool(format!(
234 "no mock registered for {}.{}",
235 tool, function
236 ))),
237 }
238 }
239
240 pub async fn call_string(&self, tool: &str, function: &str) -> SageResult<String> {
242 match self.get(tool, function) {
243 Some(MockResponse::Value(value)) => match value {
244 serde_json::Value::String(s) => Ok(s),
245 other => Ok(other.to_string()),
246 },
247 Some(MockResponse::Fail(msg)) => Err(SageError::Tool(msg)),
248 None => Err(SageError::Tool(format!(
249 "no mock registered for {}.{}",
250 tool, function
251 ))),
252 }
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn mock_infer_string_returns_value() {
        let client = MockLlmClient::with_responses(vec![MockResponse::string("hello world")]);
        let result = client.infer_string("test").await.unwrap();
        assert_eq!(result, "hello world");
    }

    #[tokio::test]
    async fn mock_infer_string_returns_fail() {
        let client = MockLlmClient::with_responses(vec![MockResponse::fail("test error")]);
        let result = client.infer_string("test").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("test error"));
    }

    #[tokio::test]
    async fn mock_infer_empty_queue_returns_error() {
        let client = MockLlmClient::new();
        let result = client.infer_string("test").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("E054"));
    }

    #[tokio::test]
    async fn mock_queue_fifo_order() {
        let client = MockLlmClient::with_responses(vec![
            MockResponse::string("first"),
            MockResponse::string("second"),
            MockResponse::string("third"),
        ]);

        assert_eq!(client.infer_string("a").await.unwrap(), "first");
        assert_eq!(client.infer_string("b").await.unwrap(), "second");
        assert_eq!(client.infer_string("c").await.unwrap(), "third");
        assert!(client.infer_string("d").await.is_err());
    }

    #[tokio::test]
    async fn mock_infer_string_stringifies_non_string_value() {
        // Non-string JSON values are rendered as compact JSON text.
        let client = MockLlmClient::with_responses(vec![MockResponse::value(42)]);
        assert_eq!(client.infer_string("test").await.unwrap(), "42");
    }

    #[tokio::test]
    async fn mock_infer_typed_value() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct Person {
            name: String,
            age: i32,
        }

        let client = MockLlmClient::with_responses(vec![MockResponse::value(
            serde_json::json!({ "name": "Ward", "age": 42 }),
        )]);

        let person: Person = client.infer("test").await.unwrap();
        assert_eq!(person.name, "Ward");
        assert_eq!(person.age, 42);
    }

    #[tokio::test]
    async fn mock_infer_reports_deserialize_failure() {
        let client = MockLlmClient::with_responses(vec![MockResponse::string("not a number")]);
        let result: Result<i32, _> = client.infer("test").await;
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("failed to deserialize mock value"));
    }

    #[tokio::test]
    async fn mock_client_queue_accepts_late_pushes() {
        let client = MockLlmClient::new();
        client.queue().push(MockResponse::string("late"));
        assert_eq!(client.queue().len(), 1);
        assert_eq!(client.infer_string("test").await.unwrap(), "late");
        assert!(client.queue().is_empty());
    }

    #[test]
    fn mock_queue_thread_safe() {
        use std::thread;

        let queue = MockQueue::with_responses(vec![
            MockResponse::string("1"),
            MockResponse::string("2"),
            MockResponse::string("3"),
        ]);

        let queue_clone = queue.clone();
        let handle = thread::spawn(move || {
            queue_clone.pop();
            queue_clone.pop();
        });

        handle.join().unwrap();
        assert_eq!(queue.len(), 1);
    }

    #[test]
    fn mock_queue_clones_share_state() {
        let queue = MockQueue::new();
        let other = queue.clone();
        other.push(MockResponse::string("shared"));
        assert_eq!(queue.len(), 1);
        assert!(!queue.is_empty());
        assert!(queue.pop().is_some());
        assert!(other.is_empty());
        assert!(other.pop().is_none());
    }

    #[tokio::test]
    async fn mock_infer_structured() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct Summary {
            text: String,
            confidence: f64,
        }

        let client = MockLlmClient::with_responses(vec![MockResponse::value(serde_json::json!({
            "text": "A summary",
            "confidence": 0.95
        }))]);

        let summary: Summary = client
            .infer_structured("summarize", "schema")
            .await
            .unwrap();
        assert_eq!(summary.text, "A summary");
        assert!((summary.confidence - 0.95).abs() < 0.001);
    }

    #[tokio::test]
    async fn mock_tool_registry_basic() {
        let registry = MockToolRegistry::new();

        registry.register("Http", "get", MockResponse::string("mocked response"));

        assert!(registry.has_mock("Http", "get"));

        let result: String = registry.call("Http", "get").await.unwrap();
        assert_eq!(result, "mocked response");

        assert!(!registry.has_mock("Http", "get"));
    }

    #[tokio::test]
    async fn mock_tool_registry_multiple() {
        let registry = MockToolRegistry::new();

        registry.register("Http", "get", MockResponse::string("first"));
        registry.register("Http", "get", MockResponse::string("second"));

        let r1: String = registry.call("Http", "get").await.unwrap();
        let r2: String = registry.call("Http", "get").await.unwrap();

        assert_eq!(r1, "first");
        assert_eq!(r2, "second");
    }

    #[tokio::test]
    async fn mock_tool_registry_fail() {
        let registry = MockToolRegistry::new();
        registry.register("Http", "get", MockResponse::fail("network error"));

        let result: Result<String, _> = registry.call("Http", "get").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("network error"));
    }

    #[tokio::test]
    async fn mock_tool_registry_no_mock() {
        let registry = MockToolRegistry::new();

        let result: Result<String, _> = registry.call("Http", "get").await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("no mock registered"));
    }

    #[tokio::test]
    async fn mock_tool_registry_call_string_stringifies_json() {
        let registry = MockToolRegistry::new();
        registry.register(
            "Http",
            "get",
            MockResponse::value(serde_json::json!({ "status": 200 })),
        );
        registry.register("Http", "get", MockResponse::string("plain"));

        assert_eq!(
            registry.call_string("Http", "get").await.unwrap(),
            r#"{"status":200}"#
        );
        assert_eq!(registry.call_string("Http", "get").await.unwrap(), "plain");
        assert!(registry.call_string("Http", "get").await.is_err());
    }

    #[test]
    fn mock_tool_registry_keys_are_per_function() {
        let registry = MockToolRegistry::new();
        registry.register("Http", "get", MockResponse::string("x"));

        assert!(registry.has_mock("Http", "get"));
        assert!(!registry.has_mock("Http", "post"));
        assert!(!registry.has_mock("Fs", "get"));
        assert!(registry.get("Http", "post").is_none());
        assert!(registry.get("Http", "get").is_some());
    }
}