1use crate::error::{SageError, SageResult};
11use serde::de::DeserializeOwned;
12use std::cell::RefCell;
13use std::future::Future;
14use std::sync::{Arc, Mutex};
15
// Holds the mock tool registry installed by `with_mock_tools` for the current tokio task.
// The slot only has a value inside a `MOCK_TOOL_REGISTRY.scope(..)` call, which is why
// `try_get_mock` reads it with `try_with` and treats an access error as "no mock".
#[cfg(not(target_arch = "wasm32"))]
tokio::task_local! {
    static MOCK_TOOL_REGISTRY: RefCell<Option<MockToolRegistry>>;
}
24
// wasm32 counterpart of the tokio task-local: a per-thread slot, initially `None`, that
// `with_mock_tools` fills while its future runs. Because it is thread-wide rather than
// task-scoped, any code running on this thread while it is set observes the registry.
#[cfg(target_arch = "wasm32")]
thread_local! {
    static MOCK_TOOL_REGISTRY: RefCell<Option<MockToolRegistry>> = const { RefCell::new(None) };
}
29
30#[cfg(not(target_arch = "wasm32"))]
35pub async fn with_mock_tools<F, R>(registry: MockToolRegistry, f: F) -> R
36where
37 F: Future<Output = R>,
38{
39 MOCK_TOOL_REGISTRY
40 .scope(RefCell::new(Some(registry)), f)
41 .await
42}
43
/// Runs `f` with `registry` installed as the active mock tool registry (wasm32 version).
///
/// wasm32 has no tokio task-locals, so the registry is stored in a thread-local while `f`
/// runs. Whatever registry was installed before (for example by an enclosing
/// `with_mock_tools` call) is put back when `f` completes, and also if this future is
/// dropped before completing. This matches the nesting behaviour of the task-local version.
#[cfg(target_arch = "wasm32")]
pub async fn with_mock_tools<F, R>(registry: MockToolRegistry, f: F) -> R
where
    F: Future<Output = R>,
{
    /// Restores the previously installed registry when dropped.
    struct RestoreOnDrop(Option<MockToolRegistry>);

    impl Drop for RestoreOnDrop {
        fn drop(&mut self) {
            let previous = self.0.take();
            // `try_with` so that dropping during thread-local teardown cannot panic.
            let _ = MOCK_TOOL_REGISTRY.try_with(|cell| {
                *cell.borrow_mut() = previous;
            });
        }
    }

    // Swap our registry in and remember what was there so it can be restored afterwards.
    let previous = MOCK_TOOL_REGISTRY.with(|cell| cell.replace(Some(registry)));
    let _restore = RestoreOnDrop(previous);
    f.await
}
59
60#[cfg(not(target_arch = "wasm32"))]
67pub fn try_get_mock(tool: &str, function: &str) -> Option<MockResponse> {
68 MOCK_TOOL_REGISTRY
69 .try_with(|cell| {
70 cell.borrow_mut()
71 .as_ref()
72 .and_then(|reg| reg.get(tool, function))
73 })
74 .ok()
75 .flatten()
76}
77
/// Pops the next mock response registered for `tool.function` in this thread's registry.
///
/// Returns `None` when no registry is installed or nothing is queued for that function.
#[cfg(target_arch = "wasm32")]
pub fn try_get_mock(tool: &str, function: &str) -> Option<MockResponse> {
    MOCK_TOOL_REGISTRY.with(|slot| {
        let installed = slot.borrow();
        let registry = installed.as_ref()?;
        registry.get(tool, function)
    })
}
87
88#[derive(Debug, Clone)]
90pub enum MockResponse {
91 Value(serde_json::Value),
93 Fail(String),
95}
96
97impl MockResponse {
98 pub fn value<T: serde::Serialize>(value: T) -> Self {
100 Self::Value(serde_json::to_value(value).expect("failed to serialize mock value"))
101 }
102
103 pub fn string(s: impl Into<String>) -> Self {
105 Self::Value(serde_json::Value::String(s.into()))
106 }
107
108 pub fn fail(message: impl Into<String>) -> Self {
110 Self::Fail(message.into())
111 }
112}
113
114#[derive(Debug, Clone, Default)]
119pub struct MockQueue {
120 responses: Arc<Mutex<Vec<MockResponse>>>,
121}
122
123impl MockQueue {
124 pub fn new() -> Self {
126 Self::default()
127 }
128
129 pub fn with_responses(responses: Vec<MockResponse>) -> Self {
131 Self {
132 responses: Arc::new(Mutex::new(responses)),
133 }
134 }
135
136 pub fn push(&self, response: MockResponse) {
138 self.responses.lock().unwrap().push(response);
139 }
140
141 pub fn pop(&self) -> Option<MockResponse> {
145 let mut queue = self.responses.lock().unwrap();
146 if queue.is_empty() {
147 None
148 } else {
149 Some(queue.remove(0))
150 }
151 }
152
153 pub fn is_empty(&self) -> bool {
155 self.responses.lock().unwrap().is_empty()
156 }
157
158 pub fn len(&self) -> usize {
160 self.responses.lock().unwrap().len()
161 }
162}
163
164#[derive(Debug, Clone)]
169pub struct MockLlmClient {
170 queue: MockQueue,
171}
172
173impl MockLlmClient {
174 pub fn new() -> Self {
176 Self {
177 queue: MockQueue::new(),
178 }
179 }
180
181 pub fn with_responses(responses: Vec<MockResponse>) -> Self {
183 Self {
184 queue: MockQueue::with_responses(responses),
185 }
186 }
187
188 pub fn queue(&self) -> &MockQueue {
190 &self.queue
191 }
192
193 pub async fn infer_string(&self, _prompt: &str) -> SageResult<String> {
197 match self.queue.pop() {
198 Some(MockResponse::Value(value)) => {
199 match value {
201 serde_json::Value::String(s) => Ok(s),
202 other => Ok(other.to_string()),
203 }
204 }
205 Some(MockResponse::Fail(msg)) => Err(SageError::Llm(msg)),
206 None => Err(SageError::Llm(
207 "infer called with no mock available (E054)".to_string(),
208 )),
209 }
210 }
211
212 pub async fn infer<T>(&self, _prompt: &str) -> SageResult<T>
216 where
217 T: DeserializeOwned,
218 {
219 match self.queue.pop() {
220 Some(MockResponse::Value(value)) => serde_json::from_value(value)
221 .map_err(|e| SageError::Llm(format!("failed to deserialize mock value: {e}"))),
222 Some(MockResponse::Fail(msg)) => Err(SageError::Llm(msg)),
223 None => Err(SageError::Llm(
224 "infer called with no mock available (E054)".to_string(),
225 )),
226 }
227 }
228
229 pub async fn infer_structured<T>(&self, _prompt: &str, _schema: &str) -> SageResult<T>
233 where
234 T: DeserializeOwned,
235 {
236 match self.queue.pop() {
238 Some(MockResponse::Value(value)) => serde_json::from_value(value)
239 .map_err(|e| SageError::Llm(format!("failed to deserialize mock value: {e}"))),
240 Some(MockResponse::Fail(msg)) => Err(SageError::Llm(msg)),
241 None => Err(SageError::Llm(
242 "infer called with no mock available (E054)".to_string(),
243 )),
244 }
245 }
246}
247
248impl Default for MockLlmClient {
249 fn default() -> Self {
250 Self::new()
251 }
252}
253
254#[derive(Debug, Clone, Default)]
258pub struct MockToolRegistry {
259 mocks: Arc<Mutex<std::collections::HashMap<String, MockQueue>>>,
260}
261
262impl MockToolRegistry {
263 pub fn new() -> Self {
265 Self::default()
266 }
267
268 pub fn register(&self, tool: &str, function: &str, response: MockResponse) {
272 let key = format!("{}.{}", tool, function);
273 let mut mocks = self.mocks.lock().unwrap();
274 mocks.entry(key).or_default().push(response);
275 }
276
277 pub fn get(&self, tool: &str, function: &str) -> Option<MockResponse> {
281 let key = format!("{}.{}", tool, function);
282 let mocks = self.mocks.lock().unwrap();
283 mocks.get(&key).and_then(|q| q.pop())
284 }
285
286 pub fn has_mock(&self, tool: &str, function: &str) -> bool {
288 let key = format!("{}.{}", tool, function);
289 let mocks = self.mocks.lock().unwrap();
290 mocks.get(&key).is_some_and(|q| !q.is_empty())
291 }
292
293 pub async fn call<T>(&self, tool: &str, function: &str) -> SageResult<T>
297 where
298 T: DeserializeOwned,
299 {
300 match self.get(tool, function) {
301 Some(MockResponse::Value(value)) => serde_json::from_value(value).map_err(|e| {
302 SageError::Tool(format!("failed to deserialize mock tool response: {e}"))
303 }),
304 Some(MockResponse::Fail(msg)) => Err(SageError::Tool(msg)),
305 None => Err(SageError::Tool(format!(
306 "no mock registered for {}.{}",
307 tool, function
308 ))),
309 }
310 }
311
312 pub async fn call_string(&self, tool: &str, function: &str) -> SageResult<String> {
314 match self.get(tool, function) {
315 Some(MockResponse::Value(value)) => match value {
316 serde_json::Value::String(s) => Ok(s),
317 other => Ok(other.to_string()),
318 },
319 Some(MockResponse::Fail(msg)) => Err(SageError::Tool(msg)),
320 None => Err(SageError::Tool(format!(
321 "no mock registered for {}.{}",
322 tool, function
323 ))),
324 }
325 }
326}
327
#[cfg(test)]
mod tests {
    //! Unit tests for the mock LLM client, the mock queue, the mock tool registry,
    //! and the task-scoped registry lookup (`with_mock_tools` / `try_get_mock`).
    use super::*;

    #[tokio::test]
    async fn mock_infer_string_returns_value() {
        let client = MockLlmClient::with_responses(vec![MockResponse::string("hello world")]);
        let result = client.infer_string("test").await.unwrap();
        assert_eq!(result, "hello world");
    }

    #[tokio::test]
    async fn mock_infer_string_returns_fail() {
        let client = MockLlmClient::with_responses(vec![MockResponse::fail("test error")]);
        let result = client.infer_string("test").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("test error"));
    }

    #[tokio::test]
    async fn mock_infer_empty_queue_returns_error() {
        let client = MockLlmClient::new();
        let result = client.infer_string("test").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("E054"));
    }

    #[tokio::test]
    async fn mock_queue_fifo_order() {
        let client = MockLlmClient::with_responses(vec![
            MockResponse::string("first"),
            MockResponse::string("second"),
            MockResponse::string("third"),
        ]);

        assert_eq!(client.infer_string("a").await.unwrap(), "first");
        assert_eq!(client.infer_string("b").await.unwrap(), "second");
        assert_eq!(client.infer_string("c").await.unwrap(), "third");
        assert!(client.infer_string("d").await.is_err());
    }

    #[tokio::test]
    async fn mock_infer_typed_value() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct Person {
            name: String,
            age: i32,
        }

        let client = MockLlmClient::with_responses(vec![MockResponse::value(
            serde_json::json!({ "name": "Ward", "age": 42 }),
        )]);

        let person: Person = client.infer("test").await.unwrap();
        assert_eq!(person.name, "Ward");
        assert_eq!(person.age, 42);
    }

    #[tokio::test]
    async fn mock_infer_type_mismatch_returns_error() {
        let client = MockLlmClient::with_responses(vec![MockResponse::string("not a number")]);
        let result: Result<i32, _> = client.infer("test").await;
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("failed to deserialize mock value"));
    }

    #[test]
    fn mock_queue_thread_safe() {
        use std::thread;

        let queue = MockQueue::with_responses(vec![
            MockResponse::string("1"),
            MockResponse::string("2"),
            MockResponse::string("3"),
        ]);

        let queue_clone = queue.clone();
        let handle = thread::spawn(move || {
            queue_clone.pop();
            queue_clone.pop();
        });

        handle.join().unwrap();
        assert_eq!(queue.len(), 1);
    }

    #[tokio::test]
    async fn mock_infer_structured() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct Summary {
            text: String,
            confidence: f64,
        }

        let client = MockLlmClient::with_responses(vec![MockResponse::value(serde_json::json!({
            "text": "A summary",
            "confidence": 0.95
        }))]);

        let summary: Summary = client
            .infer_structured("summarize", "schema")
            .await
            .unwrap();
        assert_eq!(summary.text, "A summary");
        assert!((summary.confidence - 0.95).abs() < 0.001);
    }

    #[tokio::test]
    async fn mock_tool_registry_basic() {
        let registry = MockToolRegistry::new();

        registry.register("Http", "get", MockResponse::string("mocked response"));
        assert!(registry.has_mock("Http", "get"));

        let result: String = registry.call("Http", "get").await.unwrap();
        assert_eq!(result, "mocked response");

        // The only queued response has been consumed.
        assert!(!registry.has_mock("Http", "get"));
    }

    #[tokio::test]
    async fn mock_tool_registry_multiple() {
        let registry = MockToolRegistry::new();

        registry.register("Http", "get", MockResponse::string("first"));
        registry.register("Http", "get", MockResponse::string("second"));

        let r1: String = registry.call("Http", "get").await.unwrap();
        let r2: String = registry.call("Http", "get").await.unwrap();

        assert_eq!(r1, "first");
        assert_eq!(r2, "second");
    }

    #[tokio::test]
    async fn mock_tool_registry_fail() {
        let registry = MockToolRegistry::new();
        registry.register("Http", "get", MockResponse::fail("network error"));

        let result: Result<String, _> = registry.call("Http", "get").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("network error"));
    }

    #[tokio::test]
    async fn mock_tool_registry_no_mock() {
        let registry = MockToolRegistry::new();

        let result: Result<String, _> = registry.call("Http", "get").await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("no mock registered"));
    }

    #[tokio::test]
    async fn mock_tool_registry_call_string_stringifies_non_string_values() {
        let registry = MockToolRegistry::new();
        registry.register("Math", "add", MockResponse::value(3));

        let result = registry.call_string("Math", "add").await.unwrap();
        assert_eq!(result, "3");
    }

    #[tokio::test]
    async fn try_get_mock_outside_scope_returns_none() {
        assert!(try_get_mock("Http", "get").is_none());
    }

    #[tokio::test]
    async fn with_mock_tools_scopes_registry() {
        let registry = MockToolRegistry::new();
        registry.register("Http", "get", MockResponse::string("scoped"));

        let inside = with_mock_tools(registry, async { try_get_mock("Http", "get") }).await;
        match inside {
            Some(MockResponse::Value(serde_json::Value::String(s))) => assert_eq!(s, "scoped"),
            other => panic!("unexpected mock response: {other:?}"),
        }

        // Once the scope has ended the registry is no longer reachable.
        assert!(try_get_mock("Http", "get").is_none());
    }

    #[tokio::test]
    async fn nested_with_mock_tools_restores_outer_registry() {
        let outer = MockToolRegistry::new();
        outer.register("Http", "get", MockResponse::string("outer"));

        let (inner_seen, outer_seen) = with_mock_tools(outer, async {
            let inner_seen =
                with_mock_tools(MockToolRegistry::new(), async { try_get_mock("Http", "get") })
                    .await;
            (inner_seen, try_get_mock("Http", "get"))
        })
        .await;

        // The inner (empty) registry shadows the outer one, which is visible again afterwards.
        assert!(inner_seen.is_none());
        assert!(outer_seen.is_some());
    }
}