1use crate::error::{SageError, SageResult};
11use serde::de::DeserializeOwned;
12use std::cell::RefCell;
13use std::future::Future;
14use std::sync::{Arc, Mutex};
15
tokio::task_local! {
    /// Task-scoped mock tool registry read by [`try_get_mock`].
    ///
    /// A value is installed for the duration of a future by [`with_mock_tools`]; outside such
    /// a scope the key is unset and `try_with` fails (which `try_get_mock` maps to `None`).
    static MOCK_TOOL_REGISTRY: RefCell<Option<MockToolRegistry>>;
}
21
22pub async fn with_mock_tools<F, R>(registry: MockToolRegistry, f: F) -> R
37where
38 F: Future<Output = R>,
39{
40 MOCK_TOOL_REGISTRY
41 .scope(RefCell::new(Some(registry)), f)
42 .await
43}
44
45pub fn try_get_mock(tool: &str, function: &str) -> Option<MockResponse> {
52 MOCK_TOOL_REGISTRY
53 .try_with(|cell| {
54 cell.borrow_mut()
55 .as_ref()
56 .and_then(|reg| reg.get(tool, function))
57 })
58 .ok()
59 .flatten()
60}
61
62#[derive(Debug, Clone)]
64pub enum MockResponse {
65 Value(serde_json::Value),
67 Fail(String),
69}
70
71impl MockResponse {
72 pub fn value<T: serde::Serialize>(value: T) -> Self {
74 Self::Value(serde_json::to_value(value).expect("failed to serialize mock value"))
75 }
76
77 pub fn string(s: impl Into<String>) -> Self {
79 Self::Value(serde_json::Value::String(s.into()))
80 }
81
82 pub fn fail(message: impl Into<String>) -> Self {
84 Self::Fail(message.into())
85 }
86}
87
88#[derive(Debug, Clone, Default)]
93pub struct MockQueue {
94 responses: Arc<Mutex<Vec<MockResponse>>>,
95}
96
97impl MockQueue {
98 pub fn new() -> Self {
100 Self::default()
101 }
102
103 pub fn with_responses(responses: Vec<MockResponse>) -> Self {
105 Self {
106 responses: Arc::new(Mutex::new(responses)),
107 }
108 }
109
110 pub fn push(&self, response: MockResponse) {
112 self.responses.lock().unwrap().push(response);
113 }
114
115 pub fn pop(&self) -> Option<MockResponse> {
119 let mut queue = self.responses.lock().unwrap();
120 if queue.is_empty() {
121 None
122 } else {
123 Some(queue.remove(0))
124 }
125 }
126
127 pub fn is_empty(&self) -> bool {
129 self.responses.lock().unwrap().is_empty()
130 }
131
132 pub fn len(&self) -> usize {
134 self.responses.lock().unwrap().len()
135 }
136}
137
138#[derive(Debug, Clone)]
143pub struct MockLlmClient {
144 queue: MockQueue,
145}
146
147impl MockLlmClient {
148 pub fn new() -> Self {
150 Self {
151 queue: MockQueue::new(),
152 }
153 }
154
155 pub fn with_responses(responses: Vec<MockResponse>) -> Self {
157 Self {
158 queue: MockQueue::with_responses(responses),
159 }
160 }
161
162 pub fn queue(&self) -> &MockQueue {
164 &self.queue
165 }
166
167 pub async fn infer_string(&self, _prompt: &str) -> SageResult<String> {
171 match self.queue.pop() {
172 Some(MockResponse::Value(value)) => {
173 match value {
175 serde_json::Value::String(s) => Ok(s),
176 other => Ok(other.to_string()),
177 }
178 }
179 Some(MockResponse::Fail(msg)) => Err(SageError::Llm(msg)),
180 None => Err(SageError::Llm(
181 "infer called with no mock available (E054)".to_string(),
182 )),
183 }
184 }
185
186 pub async fn infer<T>(&self, _prompt: &str) -> SageResult<T>
190 where
191 T: DeserializeOwned,
192 {
193 match self.queue.pop() {
194 Some(MockResponse::Value(value)) => serde_json::from_value(value)
195 .map_err(|e| SageError::Llm(format!("failed to deserialize mock value: {e}"))),
196 Some(MockResponse::Fail(msg)) => Err(SageError::Llm(msg)),
197 None => Err(SageError::Llm(
198 "infer called with no mock available (E054)".to_string(),
199 )),
200 }
201 }
202
203 pub async fn infer_structured<T>(&self, _prompt: &str, _schema: &str) -> SageResult<T>
207 where
208 T: DeserializeOwned,
209 {
210 match self.queue.pop() {
212 Some(MockResponse::Value(value)) => serde_json::from_value(value)
213 .map_err(|e| SageError::Llm(format!("failed to deserialize mock value: {e}"))),
214 Some(MockResponse::Fail(msg)) => Err(SageError::Llm(msg)),
215 None => Err(SageError::Llm(
216 "infer called with no mock available (E054)".to_string(),
217 )),
218 }
219 }
220}
221
222impl Default for MockLlmClient {
223 fn default() -> Self {
224 Self::new()
225 }
226}
227
228#[derive(Debug, Clone, Default)]
232pub struct MockToolRegistry {
233 mocks: Arc<Mutex<std::collections::HashMap<String, MockQueue>>>,
234}
235
236impl MockToolRegistry {
237 pub fn new() -> Self {
239 Self::default()
240 }
241
242 pub fn register(&self, tool: &str, function: &str, response: MockResponse) {
246 let key = format!("{}.{}", tool, function);
247 let mut mocks = self.mocks.lock().unwrap();
248 mocks
249 .entry(key)
250 .or_insert_with(MockQueue::new)
251 .push(response);
252 }
253
254 pub fn get(&self, tool: &str, function: &str) -> Option<MockResponse> {
258 let key = format!("{}.{}", tool, function);
259 let mocks = self.mocks.lock().unwrap();
260 mocks.get(&key).and_then(|q| q.pop())
261 }
262
263 pub fn has_mock(&self, tool: &str, function: &str) -> bool {
265 let key = format!("{}.{}", tool, function);
266 let mocks = self.mocks.lock().unwrap();
267 mocks.get(&key).is_some_and(|q| !q.is_empty())
268 }
269
270 pub async fn call<T>(&self, tool: &str, function: &str) -> SageResult<T>
274 where
275 T: DeserializeOwned,
276 {
277 match self.get(tool, function) {
278 Some(MockResponse::Value(value)) => serde_json::from_value(value).map_err(|e| {
279 SageError::Tool(format!("failed to deserialize mock tool response: {e}"))
280 }),
281 Some(MockResponse::Fail(msg)) => Err(SageError::Tool(msg)),
282 None => Err(SageError::Tool(format!(
283 "no mock registered for {}.{}",
284 tool, function
285 ))),
286 }
287 }
288
289 pub async fn call_string(&self, tool: &str, function: &str) -> SageResult<String> {
291 match self.get(tool, function) {
292 Some(MockResponse::Value(value)) => match value {
293 serde_json::Value::String(s) => Ok(s),
294 other => Ok(other.to_string()),
295 },
296 Some(MockResponse::Fail(msg)) => Err(SageError::Tool(msg)),
297 None => Err(SageError::Tool(format!(
298 "no mock registered for {}.{}",
299 tool, function
300 ))),
301 }
302 }
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn mock_infer_string_returns_value() {
        let client = MockLlmClient::with_responses(vec![MockResponse::string("hello world")]);
        let result = client.infer_string("test").await.unwrap();
        assert_eq!(result, "hello world");
    }

    #[tokio::test]
    async fn mock_infer_string_returns_fail() {
        let client = MockLlmClient::with_responses(vec![MockResponse::fail("test error")]);
        let result = client.infer_string("test").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("test error"));
    }

    #[tokio::test]
    async fn mock_infer_string_serializes_non_string_values() {
        let client = MockLlmClient::with_responses(vec![MockResponse::value(
            serde_json::json!({ "a": 1 }),
        )]);
        assert_eq!(client.infer_string("test").await.unwrap(), r#"{"a":1}"#);
    }

    #[tokio::test]
    async fn mock_infer_empty_queue_returns_error() {
        let client = MockLlmClient::new();
        let result = client.infer_string("test").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("E054"));
    }

    #[tokio::test]
    async fn mock_queue_fifo_order() {
        let client = MockLlmClient::with_responses(vec![
            MockResponse::string("first"),
            MockResponse::string("second"),
            MockResponse::string("third"),
        ]);

        assert_eq!(client.infer_string("a").await.unwrap(), "first");
        assert_eq!(client.infer_string("b").await.unwrap(), "second");
        assert_eq!(client.infer_string("c").await.unwrap(), "third");
        assert!(client.infer_string("d").await.is_err());
    }

    #[tokio::test]
    async fn mock_infer_typed_value() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct Person {
            name: String,
            age: i32,
        }

        let client = MockLlmClient::with_responses(vec![MockResponse::value(
            serde_json::json!({ "name": "Ward", "age": 42 }),
        )]);

        let person: Person = client.infer("test").await.unwrap();
        assert_eq!(person.name, "Ward");
        assert_eq!(person.age, 42);
    }

    #[tokio::test]
    async fn mock_infer_reports_deserialize_failure() {
        let client = MockLlmClient::with_responses(vec![MockResponse::string("not a number")]);
        let result: Result<i32, _> = client.infer("test").await;
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("failed to deserialize mock value"));
    }

    #[test]
    fn mock_queue_thread_safe() {
        use std::thread;

        let queue = MockQueue::with_responses(vec![
            MockResponse::string("1"),
            MockResponse::string("2"),
            MockResponse::string("3"),
        ]);

        let queue_clone = queue.clone();
        let handle = thread::spawn(move || {
            queue_clone.pop();
            queue_clone.pop();
        });

        handle.join().unwrap();
        assert_eq!(queue.len(), 1);
    }

    #[test]
    fn mock_queue_len_and_is_empty() {
        let queue = MockQueue::new();
        assert!(queue.is_empty());
        assert_eq!(queue.len(), 0);

        queue.push(MockResponse::string("x"));
        assert!(!queue.is_empty());
        assert_eq!(queue.len(), 1);

        assert!(queue.pop().is_some());
        assert!(queue.is_empty());
        assert!(queue.pop().is_none());
    }

    #[tokio::test]
    async fn mock_infer_structured() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct Summary {
            text: String,
            confidence: f64,
        }

        let client = MockLlmClient::with_responses(vec![MockResponse::value(serde_json::json!({
            "text": "A summary",
            "confidence": 0.95
        }))]);

        let summary: Summary = client
            .infer_structured("summarize", "schema")
            .await
            .unwrap();
        assert_eq!(summary.text, "A summary");
        assert!((summary.confidence - 0.95).abs() < 0.001);
    }

    #[tokio::test]
    async fn mock_tool_registry_basic() {
        let registry = MockToolRegistry::new();

        registry.register("Http", "get", MockResponse::string("mocked response"));

        assert!(registry.has_mock("Http", "get"));

        let result: String = registry.call("Http", "get").await.unwrap();
        assert_eq!(result, "mocked response");

        assert!(!registry.has_mock("Http", "get"));
    }

    #[tokio::test]
    async fn mock_tool_registry_multiple() {
        let registry = MockToolRegistry::new();

        registry.register("Http", "get", MockResponse::string("first"));
        registry.register("Http", "get", MockResponse::string("second"));

        let r1: String = registry.call("Http", "get").await.unwrap();
        let r2: String = registry.call("Http", "get").await.unwrap();

        assert_eq!(r1, "first");
        assert_eq!(r2, "second");
    }

    #[tokio::test]
    async fn mock_tool_registry_fail() {
        let registry = MockToolRegistry::new();
        registry.register("Http", "get", MockResponse::fail("network error"));

        let result: Result<String, _> = registry.call("Http", "get").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("network error"));
    }

    #[tokio::test]
    async fn mock_tool_registry_no_mock() {
        let registry = MockToolRegistry::new();

        let result: Result<String, _> = registry.call("Http", "get").await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("no mock registered"));
    }

    #[tokio::test]
    async fn mock_tool_registry_call_string_serializes_non_strings() {
        let registry = MockToolRegistry::new();
        registry.register("Math", "add", MockResponse::value(42));
        registry.register("Http", "get", MockResponse::string("plain"));

        assert_eq!(registry.call_string("Math", "add").await.unwrap(), "42");
        assert_eq!(registry.call_string("Http", "get").await.unwrap(), "plain");
        assert!(registry.call_string("Http", "get").await.is_err());
    }

    #[tokio::test]
    async fn try_get_mock_only_sees_scoped_registry() {
        // Outside any `with_mock_tools` scope there is no registry to consult.
        assert!(try_get_mock("Http", "get").is_none());

        let registry = MockToolRegistry::new();
        registry.register("Http", "get", MockResponse::string("scoped"));

        let (first, second) = with_mock_tools(registry, async {
            (try_get_mock("Http", "get"), try_get_mock("Http", "get"))
        })
        .await;

        let Some(MockResponse::Value(value)) = first else {
            panic!("expected a queued mock value");
        };
        assert_eq!(value, serde_json::json!("scoped"));
        assert!(second.is_none());
        assert!(try_get_mock("Http", "get").is_none());
    }
}