1#[cfg(test)]
2use crate::llm::types::ChatMessage;
3use crate::llm::types::{ChatRequest, ChatResponse, FunctionCall, ToolCall, Usage};
4use crate::llm::{ChatClient, LlmError};
5use async_trait::async_trait;
6#[cfg(test)]
7use serde_json::json;
8use serde_json::Value;
9use std::sync::{Arc, Mutex};
10use tokio::sync::mpsc;
11#[derive(Clone)]
19pub struct MockLlmClient {
20 responses: Arc<Mutex<Vec<MockResponse>>>,
21 calls: Arc<Mutex<Vec<ChatRequest>>>,
22 error_on_call: Arc<Mutex<Option<usize>>>, }
24
25#[derive(Clone, Debug)]
27pub struct MockResponse {
28 pub content: String,
29 pub tool_calls: Vec<MockToolCall>,
30 pub finish_reason: String,
31}
32
33#[derive(Clone, Debug)]
34pub struct MockToolCall {
35 pub name: String,
36 pub arguments: Value,
37}
38
39impl Default for MockLlmClient {
40 fn default() -> Self {
41 Self::new()
42 }
43}
44
45impl MockLlmClient {
46 pub fn new() -> Self {
48 Self {
49 responses: Arc::new(Mutex::new(Vec::new())),
50 calls: Arc::new(Mutex::new(Vec::new())),
51 error_on_call: Arc::new(Mutex::new(None)),
52 }
53 }
54
55 pub fn with_responses_vec(responses: Vec<&str>) -> Self {
57 Self {
58 responses: Arc::new(Mutex::new(
59 responses.iter().map(|r| MockResponse::text(r)).collect(),
60 )),
61 calls: Arc::new(Mutex::new(Vec::new())),
62 error_on_call: Arc::new(Mutex::new(None)),
63 }
64 }
65
66 pub fn from_mock_responses(responses: Vec<MockResponse>) -> Self {
68 Self {
69 responses: Arc::new(Mutex::new(responses)),
70 calls: Arc::new(Mutex::new(Vec::new())),
71 error_on_call: Arc::new(Mutex::new(None)),
72 }
73 }
74
75 pub fn from_tool_call(tool_name: &str, args: Value) -> Self {
77 let response = MockResponse::with_tool_call(tool_name, args);
78 Self::from_mock_responses(vec![response])
79 }
80
81 pub fn with_tool_then_text(tool_name: &str, args: Value, final_response: &str) -> Self {
83 Self::from_mock_responses(vec![
84 MockResponse::with_tool_call(tool_name, args),
85 MockResponse::text(final_response),
86 ])
87 }
88
89 pub fn with_response(self, text: &str) -> Self {
91 self.responses
92 .lock()
93 .unwrap()
94 .push(MockResponse::text(text));
95 self
96 }
97
98 pub fn with_tool_call(self, tool_name: &str, args: Value) -> Self {
100 self.responses
101 .lock()
102 .unwrap()
103 .push(MockResponse::with_tool_call(tool_name, args));
104 self
105 }
106
107 pub fn error_on_call(self, call_index: usize) -> Self {
109 *self.error_on_call.lock().unwrap() = Some(call_index);
110 self
111 }
112
113 pub fn fail_on_call(&self, call_index: usize) {
115 *self.error_on_call.lock().unwrap() = Some(call_index);
116 }
117
118 pub fn call_count(&self) -> usize {
120 self.calls.lock().unwrap().len()
121 }
122
123 pub fn get_calls(&self) -> Vec<ChatRequest> {
125 self.calls.lock().unwrap().clone()
126 }
127
128 pub fn last_call(&self) -> Option<ChatRequest> {
130 self.calls.lock().unwrap().last().cloned()
131 }
132
133 pub fn clear_calls(&self) {
135 self.calls.lock().unwrap().clear();
136 }
137}
138
139impl MockResponse {
140 pub fn text(content: &str) -> Self {
142 Self {
143 content: content.to_string(),
144 tool_calls: vec![],
145 finish_reason: "stop".to_string(),
146 }
147 }
148
149 pub fn with_tool_call(tool_name: &str, arguments: Value) -> Self {
151 Self {
152 content: String::new(),
153 tool_calls: vec![MockToolCall {
154 name: tool_name.to_string(),
155 arguments,
156 }],
157 finish_reason: "tool_calls".to_string(),
158 }
159 }
160
161 pub fn with_tool_calls(tool_calls: Vec<(&str, Value)>) -> Self {
163 Self {
164 content: String::new(),
165 tool_calls: tool_calls
166 .into_iter()
167 .map(|(name, args)| MockToolCall {
168 name: name.to_string(),
169 arguments: args,
170 })
171 .collect(),
172 finish_reason: "tool_calls".to_string(),
173 }
174 }
175}
176
177#[async_trait]
178impl ChatClient for MockLlmClient {
179 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, LlmError> {
180 self.calls.lock().unwrap().push(request.clone());
182
183 let call_index = self.calls.lock().unwrap().len() - 1;
185 if let Some(fail_index) = *self.error_on_call.lock().unwrap() {
186 if call_index == fail_index {
187 return Err(LlmError::NetworkError("Mock network error".to_string()));
188 }
189 }
190
191 let mut responses = self.responses.lock().unwrap();
193 if responses.is_empty() {
194 return Ok(ChatResponse {
195 content: "No more mock responses available".to_string(),
196 model: "mock-model".to_string(),
197 tool_calls: None,
198 finish_reason: Some("stop".to_string()),
199 usage: Some(Usage {
200 prompt_tokens: 10,
201 completion_tokens: 5,
202 total_tokens: 15,
203 }),
204 });
205 }
206
207 let mock_response = responses.remove(0);
208
209 let tool_calls = if mock_response.tool_calls.is_empty() {
211 None
212 } else {
213 Some(
214 mock_response
215 .tool_calls
216 .iter()
217 .enumerate()
218 .map(|(i, tc)| ToolCall {
219 id: format!("call_{}", i),
220 r#type: "function".to_string(),
221 function: FunctionCall {
222 name: tc.name.clone(),
223 arguments: serde_json::to_string(&tc.arguments).unwrap(),
224 },
225 })
226 .collect(),
227 )
228 };
229
230 Ok(ChatResponse {
231 content: mock_response.content,
232 model: "mock-model".to_string(),
233 tool_calls,
234 finish_reason: Some(mock_response.finish_reason),
235 usage: Some(Usage {
236 prompt_tokens: 10,
237 completion_tokens: 5,
238 total_tokens: 15,
239 }),
240 })
241 }
242
243 async fn chat_stream(
244 &self,
245 request: ChatRequest,
246 tx: mpsc::Sender<String>,
247 ) -> Result<ChatResponse, LlmError> {
248 let response = self.chat(request).await?;
250
251 for word in response.content.split_whitespace() {
253 let _ = tx.send(format!("{} ", word)).await;
254 }
255
256 Ok(response)
257 }
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_mock_client_simple_response() {
        let client = MockLlmClient::with_responses_vec(vec!["Hello, world!"]);

        let request = ChatRequest::new(vec![ChatMessage::user("Hi")]);

        let response = client.chat(request).await.unwrap();
        assert_eq!(response.content, "Hello, world!");
        assert_eq!(client.call_count(), 1);
    }

    #[tokio::test]
    async fn test_mock_client_multiple_responses() {
        let client = MockLlmClient::with_responses_vec(vec!["First", "Second", "Third"]);

        let request = ChatRequest::new(vec![ChatMessage::user("Hi")]);

        let r1 = client.chat(request.clone()).await.unwrap();
        assert_eq!(r1.content, "First");

        let r2 = client.chat(request.clone()).await.unwrap();
        assert_eq!(r2.content, "Second");

        let r3 = client.chat(request.clone()).await.unwrap();
        assert_eq!(r3.content, "Third");

        assert_eq!(client.call_count(), 3);
    }

    #[tokio::test]
    async fn test_mock_client_tool_call() {
        let client = MockLlmClient::from_tool_call(
            "calculator",
            json!({"operation": "add", "a": 5, "b": 3}),
        );

        let request = ChatRequest::new(vec![ChatMessage::user("What is 5 + 3?")]);

        let response = client.chat(request).await.unwrap();
        assert_eq!(response.finish_reason.as_deref(), Some("tool_calls"));
        assert!(response.tool_calls.is_some());

        let tool_calls = response.tool_calls.unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].function.name, "calculator");

        let args: Value = serde_json::from_str(&tool_calls[0].function.arguments).unwrap();
        assert_eq!(args["operation"], "add");
        assert_eq!(args["a"], 5);
        assert_eq!(args["b"], 3);
    }

    #[tokio::test]
    async fn test_mock_client_error_injection() {
        let client = MockLlmClient::with_responses_vec(vec!["First", "Second", "Third"]);
        client.fail_on_call(1);
        let request = ChatRequest::new(vec![ChatMessage::user("Hi")]);

        let r1 = client.chat(request.clone()).await;
        assert_eq!(r1.unwrap().content, "First");

        let r2 = client.chat(request.clone()).await;
        assert!(matches!(r2, Err(LlmError::NetworkError(_))));

        // The injected failure must not consume a queued response.
        let r3 = client.chat(request.clone()).await;
        assert_eq!(r3.unwrap().content, "Second");

        // Failed calls are still recorded.
        assert_eq!(client.call_count(), 3);
    }

    #[tokio::test]
    async fn test_mock_client_builder_error_on_first_call() {
        let client = MockLlmClient::new().with_response("Only").error_on_call(0);
        let request = ChatRequest::new(vec![ChatMessage::user("Hi")]);

        assert!(client.chat(request.clone()).await.is_err());
        assert_eq!(client.chat(request).await.unwrap().content, "Only");
    }

    #[tokio::test]
    async fn test_mock_client_exhausted_queue_fallback() {
        let client = MockLlmClient::new();
        let request = ChatRequest::new(vec![ChatMessage::user("Hi")]);

        let response = client.chat(request).await.unwrap();
        assert_eq!(response.content, "No more mock responses available");
        assert!(response.tool_calls.is_none());
        assert_eq!(response.finish_reason.as_deref(), Some("stop"));
    }

    #[tokio::test]
    async fn test_mock_client_tool_then_text() {
        let client = MockLlmClient::with_tool_then_text("search", json!({"q": "rust"}), "Done");
        let request = ChatRequest::new(vec![ChatMessage::user("Find it")]);

        let first = client.chat(request.clone()).await.unwrap();
        assert_eq!(first.finish_reason.as_deref(), Some("tool_calls"));
        assert_eq!(first.tool_calls.unwrap()[0].function.name, "search");

        let second = client.chat(request).await.unwrap();
        assert_eq!(second.content, "Done");
        assert!(second.tool_calls.is_none());
    }

    #[tokio::test]
    async fn test_mock_client_stream_sends_words() {
        let client = MockLlmClient::with_responses_vec(vec!["Hello streaming world"]);
        let request = ChatRequest::new(vec![ChatMessage::user("Hi")]);
        let (tx, mut rx) = mpsc::channel(16);

        let response = client.chat_stream(request, tx).await.unwrap();

        // `tx` was moved into `chat_stream` and dropped, so this loop terminates.
        let mut chunks = Vec::new();
        while let Some(chunk) = rx.recv().await {
            chunks.push(chunk);
        }
        assert_eq!(chunks, vec!["Hello ", "streaming ", "world "]);
        assert_eq!(response.content, "Hello streaming world");
    }

    #[tokio::test]
    async fn test_mock_client_call_tracking() {
        let client = MockLlmClient::with_responses_vec(vec!["Response 1", "Response 2"]);

        let req1 = ChatRequest::new(vec![ChatMessage::user("Question 1")]);
        let req2 = ChatRequest::new(vec![ChatMessage::user("Question 2")]);

        client.chat(req1).await.unwrap();
        client.chat(req2).await.unwrap();

        let calls = client.get_calls();
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].messages[0].content, "Question 1");
        assert_eq!(calls[1].messages[0].content, "Question 2");

        assert_eq!(client.last_call().unwrap().messages[0].content, "Question 2");

        client.clear_calls();
        assert_eq!(client.call_count(), 0);
        assert!(client.last_call().is_none());
    }
}