1use mcpkit_core::capability::{ClientCapabilities, ClientInfo, ServerCapabilities, ServerInfo};
7use mcpkit_core::error::McpError;
8use mcpkit_core::protocol::{Notification, Request, RequestId, Response};
9use mcpkit_core::types::{
10 CallToolResult, GetPromptResult, Prompt, Resource, ResourceContents, Tool,
11};
12use std::collections::HashMap;
13use std::sync::RwLock;
14use std::sync::atomic::{AtomicU64, Ordering};
15
16#[derive(Debug)]
21pub struct MockClient {
22 info: ClientInfo,
24 capabilities: ClientCapabilities,
26 next_id: AtomicU64,
28 pending: RwLock<HashMap<RequestId, String>>,
30 requests: RwLock<Vec<Request>>,
32 responses: RwLock<Vec<Response>>,
34 notifications_sent: RwLock<Vec<Notification>>,
36 notifications_received: RwLock<Vec<Notification>>,
38 server_info: RwLock<Option<ServerInfo>>,
40 server_capabilities: RwLock<Option<ServerCapabilities>>,
42}
43
44impl Default for MockClient {
45 fn default() -> Self {
46 Self::new()
47 }
48}
49
50impl MockClient {
51 #[must_use]
53 pub fn new() -> Self {
54 Self {
55 info: ClientInfo::new("mock-client", "1.0.0"),
56 capabilities: ClientCapabilities::new(),
57 next_id: AtomicU64::new(1),
58 pending: RwLock::new(HashMap::new()),
59 requests: RwLock::new(Vec::new()),
60 responses: RwLock::new(Vec::new()),
61 notifications_sent: RwLock::new(Vec::new()),
62 notifications_received: RwLock::new(Vec::new()),
63 server_info: RwLock::new(None),
64 server_capabilities: RwLock::new(None),
65 }
66 }
67
68 pub fn with_info(mut self, name: impl Into<String>, version: impl Into<String>) -> Self {
70 self.info = ClientInfo::new(name, version);
71 self
72 }
73
74 #[must_use]
76 pub fn with_capabilities(mut self, capabilities: ClientCapabilities) -> Self {
77 self.capabilities = capabilities;
78 self
79 }
80
81 #[must_use]
83 pub fn info(&self) -> &ClientInfo {
84 &self.info
85 }
86
87 #[must_use]
89 pub fn capabilities(&self) -> &ClientCapabilities {
90 &self.capabilities
91 }
92
93 #[must_use]
95 pub fn server_info(&self) -> Option<ServerInfo> {
96 self.server_info.read().ok()?.clone()
97 }
98
99 #[must_use]
101 pub fn server_capabilities(&self) -> Option<ServerCapabilities> {
102 self.server_capabilities.read().ok()?.clone()
103 }
104
105 #[must_use]
107 pub fn create_initialize_request(&self) -> Request {
108 let id = self.next_request_id();
109 Request::new("initialize", id).params(serde_json::json!({
110 "protocolVersion": mcpkit_core::PROTOCOL_VERSION,
111 "capabilities": self.capabilities,
112 "clientInfo": self.info
113 }))
114 }
115
116 pub fn process_initialize_response(&self, response: &Response) -> Result<(), McpError> {
118 if let Some(error) = &response.error {
119 return Err(McpError::InternalMessage {
120 message: format!("Initialize failed: {}", error.message),
121 });
122 }
123
124 if let Some(result) = &response.result {
125 if let Some(server_info) = result.get("serverInfo") {
127 let info: ServerInfo = serde_json::from_value(server_info.clone())?;
128 if let Ok(mut lock) = self.server_info.write() {
129 *lock = Some(info);
130 }
131 }
132
133 if let Some(caps) = result.get("capabilities") {
135 let capabilities: ServerCapabilities = serde_json::from_value(caps.clone())?;
136 if let Ok(mut lock) = self.server_capabilities.write() {
137 *lock = Some(capabilities);
138 }
139 }
140 }
141
142 Ok(())
143 }
144
145 #[must_use]
147 pub fn create_initialized_notification(&self) -> Notification {
148 Notification::new("initialized")
149 }
150
151 #[must_use]
153 pub fn create_list_tools_request(&self) -> Request {
154 let id = self.next_request_id();
155 Request::new("tools/list", id)
156 }
157
158 #[must_use]
160 pub fn create_call_tool_request(&self, name: &str, arguments: serde_json::Value) -> Request {
161 let id = self.next_request_id();
162 Request::new("tools/call", id).params(serde_json::json!({
163 "name": name,
164 "arguments": arguments
165 }))
166 }
167
168 #[must_use]
170 pub fn create_list_resources_request(&self) -> Request {
171 let id = self.next_request_id();
172 Request::new("resources/list", id)
173 }
174
175 #[must_use]
177 pub fn create_read_resource_request(&self, uri: &str) -> Request {
178 let id = self.next_request_id();
179 Request::new("resources/read", id).params(serde_json::json!({
180 "uri": uri
181 }))
182 }
183
184 #[must_use]
186 pub fn create_list_prompts_request(&self) -> Request {
187 let id = self.next_request_id();
188 Request::new("prompts/list", id)
189 }
190
191 pub fn create_get_prompt_request(
193 &self,
194 name: &str,
195 arguments: Option<serde_json::Map<String, serde_json::Value>>,
196 ) -> Request {
197 let id = self.next_request_id();
198 let mut params = serde_json::json!({ "name": name });
199 if let Some(args) = arguments {
200 params["arguments"] = serde_json::Value::Object(args);
201 }
202 Request::new("prompts/get", id).params(params)
203 }
204
205 #[must_use]
207 pub fn create_ping_request(&self) -> Request {
208 let id = self.next_request_id();
209 Request::new("ping", id)
210 }
211
212 pub fn record_request(&self, request: Request) {
214 if let Ok(mut pending) = self.pending.write() {
215 pending.insert(request.id.clone(), request.method.to_string());
216 }
217 if let Ok(mut requests) = self.requests.write() {
218 requests.push(request);
219 }
220 }
221
222 pub fn record_response(&self, response: Response) {
224 if let Ok(mut pending) = self.pending.write() {
225 pending.remove(&response.id);
226 }
227 if let Ok(mut responses) = self.responses.write() {
228 responses.push(response);
229 }
230 }
231
232 pub fn record_notification_sent(&self, notification: Notification) {
234 if let Ok(mut notifications) = self.notifications_sent.write() {
235 notifications.push(notification);
236 }
237 }
238
239 pub fn record_notification_received(&self, notification: Notification) {
241 if let Ok(mut notifications) = self.notifications_received.write() {
242 notifications.push(notification);
243 }
244 }
245
246 #[must_use]
248 pub fn requests(&self) -> Vec<Request> {
249 self.requests.read().map(|r| r.clone()).unwrap_or_default()
250 }
251
252 #[must_use]
254 pub fn responses(&self) -> Vec<Response> {
255 self.responses.read().map(|r| r.clone()).unwrap_or_default()
256 }
257
258 #[must_use]
260 pub fn notifications_sent(&self) -> Vec<Notification> {
261 self.notifications_sent
262 .read()
263 .map(|n| n.clone())
264 .unwrap_or_default()
265 }
266
267 #[must_use]
269 pub fn notifications_received(&self) -> Vec<Notification> {
270 self.notifications_received
271 .read()
272 .map(|n| n.clone())
273 .unwrap_or_default()
274 }
275
276 #[must_use]
278 pub fn pending_count(&self) -> usize {
279 self.pending.read().map(|p| p.len()).unwrap_or(0)
280 }
281
282 #[must_use]
284 pub fn request_count(&self) -> usize {
285 self.requests.read().map(|r| r.len()).unwrap_or(0)
286 }
287
288 #[must_use]
290 pub fn response_count(&self) -> usize {
291 self.responses.read().map(|r| r.len()).unwrap_or(0)
292 }
293
294 pub fn clear(&self) {
296 if let Ok(mut pending) = self.pending.write() {
297 pending.clear();
298 }
299 if let Ok(mut requests) = self.requests.write() {
300 requests.clear();
301 }
302 if let Ok(mut responses) = self.responses.write() {
303 responses.clear();
304 }
305 if let Ok(mut notifications) = self.notifications_sent.write() {
306 notifications.clear();
307 }
308 if let Ok(mut notifications) = self.notifications_received.write() {
309 notifications.clear();
310 }
311 }
312
313 pub fn parse_tool_list(&self, response: &Response) -> Result<Vec<Tool>, McpError> {
315 if let Some(error) = &response.error {
316 return Err(McpError::InternalMessage {
317 message: error.message.clone(),
318 });
319 }
320
321 let result = response
322 .result
323 .as_ref()
324 .ok_or_else(|| McpError::InternalMessage {
325 message: "No result in response".to_string(),
326 })?;
327
328 let tools = result
329 .get("tools")
330 .ok_or_else(|| McpError::InternalMessage {
331 message: "No tools in response".to_string(),
332 })?;
333
334 Ok(serde_json::from_value(tools.clone())?)
335 }
336
337 pub fn parse_tool_call(&self, response: &Response) -> Result<CallToolResult, McpError> {
339 if let Some(error) = &response.error {
340 return Err(McpError::InternalMessage {
341 message: error.message.clone(),
342 });
343 }
344
345 let result = response
346 .result
347 .as_ref()
348 .ok_or_else(|| McpError::InternalMessage {
349 message: "No result in response".to_string(),
350 })?;
351
352 Ok(serde_json::from_value(result.clone())?)
353 }
354
355 pub fn parse_resource_list(&self, response: &Response) -> Result<Vec<Resource>, McpError> {
357 if let Some(error) = &response.error {
358 return Err(McpError::InternalMessage {
359 message: error.message.clone(),
360 });
361 }
362
363 let result = response
364 .result
365 .as_ref()
366 .ok_or_else(|| McpError::InternalMessage {
367 message: "No result in response".to_string(),
368 })?;
369
370 let resources = result
371 .get("resources")
372 .ok_or_else(|| McpError::InternalMessage {
373 message: "No resources in response".to_string(),
374 })?;
375
376 Ok(serde_json::from_value(resources.clone())?)
377 }
378
379 pub fn parse_resource_read(
381 &self,
382 response: &Response,
383 ) -> Result<Vec<ResourceContents>, McpError> {
384 if let Some(error) = &response.error {
385 return Err(McpError::InternalMessage {
386 message: error.message.clone(),
387 });
388 }
389
390 let result = response
391 .result
392 .as_ref()
393 .ok_or_else(|| McpError::InternalMessage {
394 message: "No result in response".to_string(),
395 })?;
396
397 let contents = result
398 .get("contents")
399 .ok_or_else(|| McpError::InternalMessage {
400 message: "No contents in response".to_string(),
401 })?;
402
403 Ok(serde_json::from_value(contents.clone())?)
404 }
405
406 pub fn parse_prompt_list(&self, response: &Response) -> Result<Vec<Prompt>, McpError> {
408 if let Some(error) = &response.error {
409 return Err(McpError::InternalMessage {
410 message: error.message.clone(),
411 });
412 }
413
414 let result = response
415 .result
416 .as_ref()
417 .ok_or_else(|| McpError::InternalMessage {
418 message: "No result in response".to_string(),
419 })?;
420
421 let prompts = result
422 .get("prompts")
423 .ok_or_else(|| McpError::InternalMessage {
424 message: "No prompts in response".to_string(),
425 })?;
426
427 Ok(serde_json::from_value(prompts.clone())?)
428 }
429
430 pub fn parse_prompt_get(&self, response: &Response) -> Result<GetPromptResult, McpError> {
432 if let Some(error) = &response.error {
433 return Err(McpError::InternalMessage {
434 message: error.message.clone(),
435 });
436 }
437
438 let result = response
439 .result
440 .as_ref()
441 .ok_or_else(|| McpError::InternalMessage {
442 message: "No result in response".to_string(),
443 })?;
444
445 Ok(serde_json::from_value(result.clone())?)
446 }
447
448 fn next_request_id(&self) -> RequestId {
449 RequestId::from(self.next_id.fetch_add(1, Ordering::SeqCst))
450 }
451}
452
453impl Clone for MockClient {
454 fn clone(&self) -> Self {
455 Self {
456 info: self.info.clone(),
457 capabilities: self.capabilities.clone(),
458 next_id: AtomicU64::new(self.next_id.load(Ordering::SeqCst)),
459 pending: RwLock::new(self.pending.read().map(|p| p.clone()).unwrap_or_default()),
460 requests: RwLock::new(self.requests.read().map(|r| r.clone()).unwrap_or_default()),
461 responses: RwLock::new(self.responses.read().map(|r| r.clone()).unwrap_or_default()),
462 notifications_sent: RwLock::new(
463 self.notifications_sent
464 .read()
465 .map(|n| n.clone())
466 .unwrap_or_default(),
467 ),
468 notifications_received: RwLock::new(
469 self.notifications_received
470 .read()
471 .map(|n| n.clone())
472 .unwrap_or_default(),
473 ),
474 server_info: RwLock::new(self.server_info.read().ok().and_then(|s| s.clone())),
475 server_capabilities: RwLock::new(
476 self.server_capabilities.read().ok().and_then(|s| s.clone()),
477 ),
478 }
479 }
480}
481
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `payload` in a successful response addressed to request id 1.
    fn success_with(payload: serde_json::Value) -> Response {
        Response::success(RequestId::from(1), payload)
    }

    #[test]
    fn test_mock_client_creation() {
        let client = MockClient::new().with_info("test-client", "2.0.0");
        let info = client.info();

        assert_eq!(info.name, "test-client");
        assert_eq!(info.version, "2.0.0");
    }

    #[test]
    fn test_create_requests() {
        let client = MockClient::new();

        // Requests are built in a fixed order so ids are allocated predictably.
        let cases = [
            (client.create_initialize_request(), "initialize"),
            (client.create_ping_request(), "ping"),
            (client.create_list_tools_request(), "tools/list"),
            (
                client.create_call_tool_request("test", serde_json::json!({})),
                "tools/call",
            ),
        ];

        for (request, expected_method) in cases {
            assert_eq!(request.method.as_ref(), expected_method);
        }
    }

    #[test]
    fn test_record_requests() {
        let client = MockClient::new();

        client.record_request(client.create_ping_request());
        assert_eq!(client.request_count(), 1);
        assert_eq!(client.pending_count(), 1);

        // The ping above is the first request built, so it received id 1.
        client.record_response(success_with(serde_json::json!({})));
        assert_eq!(client.response_count(), 1);
        assert_eq!(client.pending_count(), 0);
    }

    #[test]
    fn test_parse_tool_list() {
        let client = MockClient::new();
        let response = success_with(serde_json::json!({
            "tools": [
                {"name": "test", "inputSchema": {"type": "object"}}
            ]
        }));

        let tools = client.parse_tool_list(&response).unwrap();

        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "test");
    }

    #[test]
    fn test_parse_resource_list() {
        let client = MockClient::new();
        let response = success_with(serde_json::json!({
            "resources": [
                {"uri": "test://resource", "name": "Test"}
            ]
        }));

        let resources = client.parse_resource_list(&response).unwrap();

        assert_eq!(resources.len(), 1);
        assert_eq!(resources[0].uri, "test://resource");
    }
}