// agent_runtime/llm/mock.rs
1#[cfg(test)]
2use crate::llm::types::ChatMessage;
3use crate::llm::types::{ChatRequest, ChatResponse, FunctionCall, ToolCall, Usage};
4use crate::llm::{ChatClient, LlmError};
5use async_trait::async_trait;
6#[cfg(test)]
7use serde_json::json;
8use serde_json::Value;
9use std::sync::{Arc, Mutex};
10use tokio::sync::mpsc;
/// Mock LLM client for exercising agent logic without a real provider.
///
/// Supports:
/// - Predefined responses
/// - Tool call simulation
/// - Streaming simulation
/// - Call tracking
/// - Error injection
#[derive(Clone)]
pub struct MockLlmClient {
    // Scripted responses, consumed front-to-back; Arc-shared so clones drain the same queue.
    responses: Arc<Mutex<Vec<MockResponse>>>,
    // Every ChatRequest received, in call order (used by call_count/get_calls/last_call).
    calls: Arc<Mutex<Vec<ChatRequest>>>,
    error_on_call: Arc<Mutex<Option<usize>>>, // Fail on nth call (0-indexed)
}
24
/// Mock response configuration
#[derive(Clone, Debug)]
pub struct MockResponse {
    // Text returned as ChatResponse::content (empty for pure tool-call turns).
    pub content: String,
    // Tool invocations to surface; empty means a plain text response.
    pub tool_calls: Vec<MockToolCall>,
    // Copied into ChatResponse::finish_reason (e.g. "stop" or "tool_calls").
    pub finish_reason: String,
}
32
/// A simulated tool invocation: the tool's name plus its JSON arguments.
#[derive(Clone, Debug)]
pub struct MockToolCall {
    pub name: String,
    // Serialized to a JSON string for FunctionCall::arguments when the mock replies.
    pub arguments: Value,
}
38
39impl Default for MockLlmClient {
40    fn default() -> Self {
41        Self::new()
42    }
43}
44
45impl MockLlmClient {
46    /// Create a new empty mock client (call with_response() to add responses)
47    pub fn new() -> Self {
48        Self {
49            responses: Arc::new(Mutex::new(Vec::new())),
50            calls: Arc::new(Mutex::new(Vec::new())),
51            error_on_call: Arc::new(Mutex::new(None)),
52        }
53    }
54
55    /// Create a new mock client with simple text responses
56    pub fn with_responses_vec(responses: Vec<&str>) -> Self {
57        Self {
58            responses: Arc::new(Mutex::new(
59                responses.iter().map(|r| MockResponse::text(r)).collect(),
60            )),
61            calls: Arc::new(Mutex::new(Vec::new())),
62            error_on_call: Arc::new(Mutex::new(None)),
63        }
64    }
65
66    /// Create a mock client with detailed responses
67    pub fn from_mock_responses(responses: Vec<MockResponse>) -> Self {
68        Self {
69            responses: Arc::new(Mutex::new(responses)),
70            calls: Arc::new(Mutex::new(Vec::new())),
71            error_on_call: Arc::new(Mutex::new(None)),
72        }
73    }
74
75    /// Create a mock client that calls a specific tool
76    pub fn from_tool_call(tool_name: &str, args: Value) -> Self {
77        let response = MockResponse::with_tool_call(tool_name, args);
78        Self::from_mock_responses(vec![response])
79    }
80
81    /// Create a mock client that calls a tool then responds with text
82    pub fn with_tool_then_text(tool_name: &str, args: Value, final_response: &str) -> Self {
83        Self::from_mock_responses(vec![
84            MockResponse::with_tool_call(tool_name, args),
85            MockResponse::text(final_response),
86        ])
87    }
88
89    /// Add a text response to the chain
90    pub fn with_response(self, text: &str) -> Self {
91        self.responses
92            .lock()
93            .unwrap()
94            .push(MockResponse::text(text));
95        self
96    }
97
98    /// Add a tool call response to the chain
99    pub fn with_tool_call(self, tool_name: &str, args: Value) -> Self {
100        self.responses
101            .lock()
102            .unwrap()
103            .push(MockResponse::with_tool_call(tool_name, args));
104        self
105    }
106
107    /// Set the client to error on a specific call index
108    pub fn error_on_call(self, call_index: usize) -> Self {
109        *self.error_on_call.lock().unwrap() = Some(call_index);
110        self
111    }
112
113    /// Set the client to fail on the nth call (0-indexed)
114    pub fn fail_on_call(&self, call_index: usize) {
115        *self.error_on_call.lock().unwrap() = Some(call_index);
116    }
117
118    /// Get the number of calls made
119    pub fn call_count(&self) -> usize {
120        self.calls.lock().unwrap().len()
121    }
122
123    /// Get a copy of all calls made
124    pub fn get_calls(&self) -> Vec<ChatRequest> {
125        self.calls.lock().unwrap().clone()
126    }
127
128    /// Get the last call made
129    pub fn last_call(&self) -> Option<ChatRequest> {
130        self.calls.lock().unwrap().last().cloned()
131    }
132
133    /// Clear call history
134    pub fn clear_calls(&self) {
135        self.calls.lock().unwrap().clear();
136    }
137}
138
139impl MockResponse {
140    /// Simple text response
141    pub fn text(content: &str) -> Self {
142        Self {
143            content: content.to_string(),
144            tool_calls: vec![],
145            finish_reason: "stop".to_string(),
146        }
147    }
148
149    /// Response with a tool call
150    pub fn with_tool_call(tool_name: &str, arguments: Value) -> Self {
151        Self {
152            content: String::new(),
153            tool_calls: vec![MockToolCall {
154                name: tool_name.to_string(),
155                arguments,
156            }],
157            finish_reason: "tool_calls".to_string(),
158        }
159    }
160
161    /// Response with multiple tool calls
162    pub fn with_tool_calls(tool_calls: Vec<(&str, Value)>) -> Self {
163        Self {
164            content: String::new(),
165            tool_calls: tool_calls
166                .into_iter()
167                .map(|(name, args)| MockToolCall {
168                    name: name.to_string(),
169                    arguments: args,
170                })
171                .collect(),
172            finish_reason: "tool_calls".to_string(),
173        }
174    }
175}
176
177#[async_trait]
178impl ChatClient for MockLlmClient {
179    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, LlmError> {
180        // Record the call
181        self.calls.lock().unwrap().push(request.clone());
182
183        // Check if we should fail on this call
184        let call_index = self.calls.lock().unwrap().len() - 1;
185        if let Some(fail_index) = *self.error_on_call.lock().unwrap() {
186            if call_index == fail_index {
187                return Err(LlmError::NetworkError("Mock network error".to_string()));
188            }
189        }
190
191        // Get the next response
192        let mut responses = self.responses.lock().unwrap();
193        if responses.is_empty() {
194            return Ok(ChatResponse {
195                content: "No more mock responses available".to_string(),
196                model: "mock-model".to_string(),
197                tool_calls: None,
198                finish_reason: Some("stop".to_string()),
199                usage: Some(Usage {
200                    prompt_tokens: 10,
201                    completion_tokens: 5,
202                    total_tokens: 15,
203                }),
204            });
205        }
206
207        let mock_response = responses.remove(0);
208
209        // Convert mock tool calls to actual tool calls
210        let tool_calls = if mock_response.tool_calls.is_empty() {
211            None
212        } else {
213            Some(
214                mock_response
215                    .tool_calls
216                    .iter()
217                    .enumerate()
218                    .map(|(i, tc)| ToolCall {
219                        id: format!("call_{}", i),
220                        r#type: "function".to_string(),
221                        function: FunctionCall {
222                            name: tc.name.clone(),
223                            arguments: serde_json::to_string(&tc.arguments).unwrap(),
224                        },
225                    })
226                    .collect(),
227            )
228        };
229
230        Ok(ChatResponse {
231            content: mock_response.content,
232            model: "mock-model".to_string(),
233            tool_calls,
234            finish_reason: Some(mock_response.finish_reason),
235            usage: Some(Usage {
236                prompt_tokens: 10,
237                completion_tokens: 5,
238                total_tokens: 15,
239            }),
240        })
241    }
242
243    async fn chat_stream(
244        &self,
245        request: ChatRequest,
246        tx: mpsc::Sender<String>,
247    ) -> Result<ChatResponse, LlmError> {
248        // For streaming, just send the response in chunks
249        let response = self.chat(request).await?;
250
251        // Simulate streaming by sending content word by word
252        for word in response.content.split_whitespace() {
253            let _ = tx.send(format!("{} ", word)).await;
254        }
255
256        Ok(response)
257    }
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_mock_client_simple_response() {
        let mock = MockLlmClient::with_responses_vec(vec!["Hello, world!"]);
        let req = ChatRequest::new(vec![ChatMessage::user("Hi")]);

        let reply = mock.chat(req).await.unwrap();
        assert_eq!(reply.content, "Hello, world!");
        assert_eq!(mock.call_count(), 1);
    }

    #[tokio::test]
    async fn test_mock_client_multiple_responses() {
        let mock = MockLlmClient::with_responses_vec(vec!["First", "Second", "Third"]);
        let req = ChatRequest::new(vec![ChatMessage::user("Hi")]);

        // Responses are consumed in FIFO order.
        for expected in ["First", "Second", "Third"] {
            let reply = mock.chat(req.clone()).await.unwrap();
            assert_eq!(reply.content, expected);
        }

        assert_eq!(mock.call_count(), 3);
    }

    #[tokio::test]
    async fn test_mock_client_tool_call() {
        let mock = MockLlmClient::from_tool_call(
            "calculator",
            json!({"operation": "add", "a": 5, "b": 3}),
        );
        let req = ChatRequest::new(vec![ChatMessage::user("What is 5 + 3?")]);

        let reply = mock.chat(req).await.unwrap();
        let tool_calls = reply.tool_calls.expect("expected a tool call");
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].function.name, "calculator");

        // Arguments round-trip through the JSON string representation.
        let args: Value = serde_json::from_str(&tool_calls[0].function.arguments).unwrap();
        assert_eq!(args["operation"], "add");
        assert_eq!(args["a"], 5);
        assert_eq!(args["b"], 3);
    }

    #[tokio::test]
    async fn test_mock_client_error_injection() {
        let mock = MockLlmClient::with_responses_vec(vec!["First", "Second", "Third"]);
        mock.fail_on_call(1); // second call (0-indexed) must fail

        let req = ChatRequest::new(vec![ChatMessage::user("Hi")]);

        assert!(mock.chat(req.clone()).await.is_ok()); // call 0 succeeds
        assert!(mock.chat(req.clone()).await.is_err()); // call 1 fails
        assert!(mock.chat(req.clone()).await.is_ok()); // call 2 succeeds
    }

    #[tokio::test]
    async fn test_mock_client_call_tracking() {
        let mock = MockLlmClient::with_responses_vec(vec!["Response 1", "Response 2"]);

        mock.chat(ChatRequest::new(vec![ChatMessage::user("Question 1")]))
            .await
            .unwrap();
        mock.chat(ChatRequest::new(vec![ChatMessage::user("Question 2")]))
            .await
            .unwrap();

        let recorded = mock.get_calls();
        assert_eq!(recorded.len(), 2);
        assert_eq!(recorded[0].messages[0].content, "Question 1");
        assert_eq!(recorded[1].messages[0].content, "Question 2");
    }
}
350}