mcpkit_testing/
client.rs

1//! Mock client for testing MCP servers.
2//!
3//! This module provides a mock client that can be used to test
4//! MCP server implementations.
5
6use mcpkit_core::capability::{ClientCapabilities, ClientInfo, ServerCapabilities, ServerInfo};
7use mcpkit_core::error::McpError;
8use mcpkit_core::protocol::{Notification, Request, RequestId, Response};
9use mcpkit_core::types::{
10    CallToolResult, GetPromptResult, Prompt, Resource, ResourceContents, Tool,
11};
12use std::collections::HashMap;
13use std::sync::RwLock;
14use std::sync::atomic::{AtomicU64, Ordering};
15
/// A mock MCP client for testing servers.
///
/// The mock client builds protocol requests with sequentially assigned IDs
/// and keeps a log of every request, response, and notification handed to
/// its `record_*` methods, so tests can verify server behavior afterwards.
///
/// All mutable state sits behind [`RwLock`]s (or an atomic counter), which is
/// why the recording and query methods only take `&self`.
#[derive(Debug)]
pub struct MockClient {
    /// Client info, sent as `clientInfo` in the `initialize` request.
    info: ClientInfo,
    /// Client capabilities, sent as `capabilities` in the `initialize` request.
    capabilities: ClientCapabilities,
    /// ID that will be assigned to the next request created by this client.
    next_id: AtomicU64,
    /// Recorded requests that have no recorded response yet, keyed by request
    /// ID; the value is the request's method name.
    pending: RwLock<HashMap<RequestId, String>>,
    /// Every request passed to `record_request`, in recording order.
    requests: RwLock<Vec<Request>>,
    /// Every response passed to `record_response`, in recording order.
    responses: RwLock<Vec<Response>>,
    /// Every notification passed to `record_notification_sent`.
    notifications_sent: RwLock<Vec<Notification>>,
    /// Every notification passed to `record_notification_received`.
    notifications_received: RwLock<Vec<Notification>>,
    /// Server info; `None` until an initialize response containing
    /// `serverInfo` has been processed.
    server_info: RwLock<Option<ServerInfo>>,
    /// Server capabilities; `None` until an initialize response containing
    /// `capabilities` has been processed.
    server_capabilities: RwLock<Option<ServerCapabilities>>,
}
43
44impl Default for MockClient {
45    fn default() -> Self {
46        Self::new()
47    }
48}
49
50impl MockClient {
51    /// Create a new mock client.
52    #[must_use]
53    pub fn new() -> Self {
54        Self {
55            info: ClientInfo::new("mock-client", "1.0.0"),
56            capabilities: ClientCapabilities::new(),
57            next_id: AtomicU64::new(1),
58            pending: RwLock::new(HashMap::new()),
59            requests: RwLock::new(Vec::new()),
60            responses: RwLock::new(Vec::new()),
61            notifications_sent: RwLock::new(Vec::new()),
62            notifications_received: RwLock::new(Vec::new()),
63            server_info: RwLock::new(None),
64            server_capabilities: RwLock::new(None),
65        }
66    }
67
68    /// Create a mock client with custom info.
69    pub fn with_info(mut self, name: impl Into<String>, version: impl Into<String>) -> Self {
70        self.info = ClientInfo::new(name, version);
71        self
72    }
73
74    /// Create a mock client with custom capabilities.
75    #[must_use]
76    pub fn with_capabilities(mut self, capabilities: ClientCapabilities) -> Self {
77        self.capabilities = capabilities;
78        self
79    }
80
81    /// Get the client info.
82    #[must_use]
83    pub fn info(&self) -> &ClientInfo {
84        &self.info
85    }
86
87    /// Get the client capabilities.
88    #[must_use]
89    pub fn capabilities(&self) -> &ClientCapabilities {
90        &self.capabilities
91    }
92
93    /// Get the server info (after initialization).
94    #[must_use]
95    pub fn server_info(&self) -> Option<ServerInfo> {
96        self.server_info.read().ok()?.clone()
97    }
98
99    /// Get the server capabilities (after initialization).
100    #[must_use]
101    pub fn server_capabilities(&self) -> Option<ServerCapabilities> {
102        self.server_capabilities.read().ok()?.clone()
103    }
104
105    /// Create an initialize request.
106    #[must_use]
107    pub fn create_initialize_request(&self) -> Request {
108        let id = self.next_request_id();
109        Request::new("initialize", id).params(serde_json::json!({
110            "protocolVersion": mcpkit_core::PROTOCOL_VERSION,
111            "capabilities": self.capabilities,
112            "clientInfo": self.info
113        }))
114    }
115
116    /// Process an initialize response.
117    pub fn process_initialize_response(&self, response: &Response) -> Result<(), McpError> {
118        if let Some(error) = &response.error {
119            return Err(McpError::InternalMessage {
120                message: format!("Initialize failed: {}", error.message),
121            });
122        }
123
124        if let Some(result) = &response.result {
125            // Extract server info
126            if let Some(server_info) = result.get("serverInfo") {
127                let info: ServerInfo = serde_json::from_value(server_info.clone())?;
128                if let Ok(mut lock) = self.server_info.write() {
129                    *lock = Some(info);
130                }
131            }
132
133            // Extract capabilities
134            if let Some(caps) = result.get("capabilities") {
135                let capabilities: ServerCapabilities = serde_json::from_value(caps.clone())?;
136                if let Ok(mut lock) = self.server_capabilities.write() {
137                    *lock = Some(capabilities);
138                }
139            }
140        }
141
142        Ok(())
143    }
144
145    /// Create an initialized notification.
146    #[must_use]
147    pub fn create_initialized_notification(&self) -> Notification {
148        Notification::new("initialized")
149    }
150
151    /// Create a tools/list request.
152    #[must_use]
153    pub fn create_list_tools_request(&self) -> Request {
154        let id = self.next_request_id();
155        Request::new("tools/list", id)
156    }
157
158    /// Create a tools/call request.
159    #[must_use]
160    pub fn create_call_tool_request(&self, name: &str, arguments: serde_json::Value) -> Request {
161        let id = self.next_request_id();
162        Request::new("tools/call", id).params(serde_json::json!({
163            "name": name,
164            "arguments": arguments
165        }))
166    }
167
168    /// Create a resources/list request.
169    #[must_use]
170    pub fn create_list_resources_request(&self) -> Request {
171        let id = self.next_request_id();
172        Request::new("resources/list", id)
173    }
174
175    /// Create a resources/read request.
176    #[must_use]
177    pub fn create_read_resource_request(&self, uri: &str) -> Request {
178        let id = self.next_request_id();
179        Request::new("resources/read", id).params(serde_json::json!({
180            "uri": uri
181        }))
182    }
183
184    /// Create a prompts/list request.
185    #[must_use]
186    pub fn create_list_prompts_request(&self) -> Request {
187        let id = self.next_request_id();
188        Request::new("prompts/list", id)
189    }
190
191    /// Create a prompts/get request.
192    pub fn create_get_prompt_request(
193        &self,
194        name: &str,
195        arguments: Option<serde_json::Map<String, serde_json::Value>>,
196    ) -> Request {
197        let id = self.next_request_id();
198        let mut params = serde_json::json!({ "name": name });
199        if let Some(args) = arguments {
200            params["arguments"] = serde_json::Value::Object(args);
201        }
202        Request::new("prompts/get", id).params(params)
203    }
204
205    /// Create a ping request.
206    #[must_use]
207    pub fn create_ping_request(&self) -> Request {
208        let id = self.next_request_id();
209        Request::new("ping", id)
210    }
211
212    /// Record a request.
213    pub fn record_request(&self, request: Request) {
214        if let Ok(mut pending) = self.pending.write() {
215            pending.insert(request.id.clone(), request.method.to_string());
216        }
217        if let Ok(mut requests) = self.requests.write() {
218            requests.push(request);
219        }
220    }
221
222    /// Record a response.
223    pub fn record_response(&self, response: Response) {
224        if let Ok(mut pending) = self.pending.write() {
225            pending.remove(&response.id);
226        }
227        if let Ok(mut responses) = self.responses.write() {
228            responses.push(response);
229        }
230    }
231
232    /// Record a sent notification.
233    pub fn record_notification_sent(&self, notification: Notification) {
234        if let Ok(mut notifications) = self.notifications_sent.write() {
235            notifications.push(notification);
236        }
237    }
238
239    /// Record a received notification.
240    pub fn record_notification_received(&self, notification: Notification) {
241        if let Ok(mut notifications) = self.notifications_received.write() {
242            notifications.push(notification);
243        }
244    }
245
246    /// Get all recorded requests.
247    #[must_use]
248    pub fn requests(&self) -> Vec<Request> {
249        self.requests.read().map(|r| r.clone()).unwrap_or_default()
250    }
251
252    /// Get all recorded responses.
253    #[must_use]
254    pub fn responses(&self) -> Vec<Response> {
255        self.responses.read().map(|r| r.clone()).unwrap_or_default()
256    }
257
258    /// Get all sent notifications.
259    #[must_use]
260    pub fn notifications_sent(&self) -> Vec<Notification> {
261        self.notifications_sent
262            .read()
263            .map(|n| n.clone())
264            .unwrap_or_default()
265    }
266
267    /// Get all received notifications.
268    #[must_use]
269    pub fn notifications_received(&self) -> Vec<Notification> {
270        self.notifications_received
271            .read()
272            .map(|n| n.clone())
273            .unwrap_or_default()
274    }
275
276    /// Get pending request count.
277    #[must_use]
278    pub fn pending_count(&self) -> usize {
279        self.pending.read().map(|p| p.len()).unwrap_or(0)
280    }
281
282    /// Get the total request count.
283    #[must_use]
284    pub fn request_count(&self) -> usize {
285        self.requests.read().map(|r| r.len()).unwrap_or(0)
286    }
287
288    /// Get the total response count.
289    #[must_use]
290    pub fn response_count(&self) -> usize {
291        self.responses.read().map(|r| r.len()).unwrap_or(0)
292    }
293
294    /// Clear all recorded data.
295    pub fn clear(&self) {
296        if let Ok(mut pending) = self.pending.write() {
297            pending.clear();
298        }
299        if let Ok(mut requests) = self.requests.write() {
300            requests.clear();
301        }
302        if let Ok(mut responses) = self.responses.write() {
303            responses.clear();
304        }
305        if let Ok(mut notifications) = self.notifications_sent.write() {
306            notifications.clear();
307        }
308        if let Ok(mut notifications) = self.notifications_received.write() {
309            notifications.clear();
310        }
311    }
312
313    /// Parse a tool list response.
314    pub fn parse_tool_list(&self, response: &Response) -> Result<Vec<Tool>, McpError> {
315        if let Some(error) = &response.error {
316            return Err(McpError::InternalMessage {
317                message: error.message.clone(),
318            });
319        }
320
321        let result = response
322            .result
323            .as_ref()
324            .ok_or_else(|| McpError::InternalMessage {
325                message: "No result in response".to_string(),
326            })?;
327
328        let tools = result
329            .get("tools")
330            .ok_or_else(|| McpError::InternalMessage {
331                message: "No tools in response".to_string(),
332            })?;
333
334        Ok(serde_json::from_value(tools.clone())?)
335    }
336
337    /// Parse a tool call response.
338    pub fn parse_tool_call(&self, response: &Response) -> Result<CallToolResult, McpError> {
339        if let Some(error) = &response.error {
340            return Err(McpError::InternalMessage {
341                message: error.message.clone(),
342            });
343        }
344
345        let result = response
346            .result
347            .as_ref()
348            .ok_or_else(|| McpError::InternalMessage {
349                message: "No result in response".to_string(),
350            })?;
351
352        Ok(serde_json::from_value(result.clone())?)
353    }
354
355    /// Parse a resource list response.
356    pub fn parse_resource_list(&self, response: &Response) -> Result<Vec<Resource>, McpError> {
357        if let Some(error) = &response.error {
358            return Err(McpError::InternalMessage {
359                message: error.message.clone(),
360            });
361        }
362
363        let result = response
364            .result
365            .as_ref()
366            .ok_or_else(|| McpError::InternalMessage {
367                message: "No result in response".to_string(),
368            })?;
369
370        let resources = result
371            .get("resources")
372            .ok_or_else(|| McpError::InternalMessage {
373                message: "No resources in response".to_string(),
374            })?;
375
376        Ok(serde_json::from_value(resources.clone())?)
377    }
378
379    /// Parse a resource read response.
380    pub fn parse_resource_read(
381        &self,
382        response: &Response,
383    ) -> Result<Vec<ResourceContents>, McpError> {
384        if let Some(error) = &response.error {
385            return Err(McpError::InternalMessage {
386                message: error.message.clone(),
387            });
388        }
389
390        let result = response
391            .result
392            .as_ref()
393            .ok_or_else(|| McpError::InternalMessage {
394                message: "No result in response".to_string(),
395            })?;
396
397        let contents = result
398            .get("contents")
399            .ok_or_else(|| McpError::InternalMessage {
400                message: "No contents in response".to_string(),
401            })?;
402
403        Ok(serde_json::from_value(contents.clone())?)
404    }
405
406    /// Parse a prompt list response.
407    pub fn parse_prompt_list(&self, response: &Response) -> Result<Vec<Prompt>, McpError> {
408        if let Some(error) = &response.error {
409            return Err(McpError::InternalMessage {
410                message: error.message.clone(),
411            });
412        }
413
414        let result = response
415            .result
416            .as_ref()
417            .ok_or_else(|| McpError::InternalMessage {
418                message: "No result in response".to_string(),
419            })?;
420
421        let prompts = result
422            .get("prompts")
423            .ok_or_else(|| McpError::InternalMessage {
424                message: "No prompts in response".to_string(),
425            })?;
426
427        Ok(serde_json::from_value(prompts.clone())?)
428    }
429
430    /// Parse a prompt get response.
431    pub fn parse_prompt_get(&self, response: &Response) -> Result<GetPromptResult, McpError> {
432        if let Some(error) = &response.error {
433            return Err(McpError::InternalMessage {
434                message: error.message.clone(),
435            });
436        }
437
438        let result = response
439            .result
440            .as_ref()
441            .ok_or_else(|| McpError::InternalMessage {
442                message: "No result in response".to_string(),
443            })?;
444
445        Ok(serde_json::from_value(result.clone())?)
446    }
447
448    fn next_request_id(&self) -> RequestId {
449        RequestId::from(self.next_id.fetch_add(1, Ordering::SeqCst))
450    }
451}
452
453impl Clone for MockClient {
454    fn clone(&self) -> Self {
455        Self {
456            info: self.info.clone(),
457            capabilities: self.capabilities.clone(),
458            next_id: AtomicU64::new(self.next_id.load(Ordering::SeqCst)),
459            pending: RwLock::new(self.pending.read().map(|p| p.clone()).unwrap_or_default()),
460            requests: RwLock::new(self.requests.read().map(|r| r.clone()).unwrap_or_default()),
461            responses: RwLock::new(self.responses.read().map(|r| r.clone()).unwrap_or_default()),
462            notifications_sent: RwLock::new(
463                self.notifications_sent
464                    .read()
465                    .map(|n| n.clone())
466                    .unwrap_or_default(),
467            ),
468            notifications_received: RwLock::new(
469                self.notifications_received
470                    .read()
471                    .map(|n| n.clone())
472                    .unwrap_or_default(),
473            ),
474            server_info: RwLock::new(self.server_info.read().ok().and_then(|s| s.clone())),
475            server_capabilities: RwLock::new(
476                self.server_capabilities.read().ok().and_then(|s| s.clone()),
477            ),
478        }
479    }
480}
481
#[cfg(test)]
mod tests {
    use super::*;

    /// `with_info` replaces the default client name and version.
    #[test]
    fn test_mock_client_creation() {
        let client = MockClient::new().with_info("test-client", "2.0.0");

        assert_eq!(client.info().name, "test-client");
        assert_eq!(client.info().version, "2.0.0");
    }

    /// Each request factory uses the expected protocol method name.
    #[test]
    fn test_create_requests() {
        let client = MockClient::new();

        let init = client.create_initialize_request();
        assert_eq!(init.method.as_ref(), "initialize");

        let ping = client.create_ping_request();
        assert_eq!(ping.method.as_ref(), "ping");

        let list_tools = client.create_list_tools_request();
        assert_eq!(list_tools.method.as_ref(), "tools/list");

        let call_tool = client.create_call_tool_request("test", serde_json::json!({}));
        assert_eq!(call_tool.method.as_ref(), "tools/call");
    }

    /// Recording a request makes it pending; recording the matching response
    /// clears it again.
    #[test]
    fn test_record_requests() {
        let client = MockClient::new();

        let request = client.create_ping_request();
        // Answer with the request's own ID rather than a hard-coded value, so
        // the test does not depend on where the ID counter starts.
        let request_id = request.id.clone();
        client.record_request(request);

        assert_eq!(client.request_count(), 1);
        assert_eq!(client.pending_count(), 1);

        let response = Response::success(request_id, serde_json::json!({}));
        client.record_response(response);

        assert_eq!(client.response_count(), 1);
        assert_eq!(client.pending_count(), 0);
    }

    /// A well-formed `tools/list` result deserializes into its tools.
    #[test]
    fn test_parse_tool_list() {
        let client = MockClient::new();

        let response = Response::success(
            RequestId::from(1),
            serde_json::json!({
                "tools": [
                    {"name": "test", "inputSchema": {"type": "object"}}
                ]
            }),
        );

        let tools = client.parse_tool_list(&response).unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "test");
    }

    /// A well-formed `resources/list` result deserializes into its resources.
    #[test]
    fn test_parse_resource_list() {
        let client = MockClient::new();

        let response = Response::success(
            RequestId::from(1),
            serde_json::json!({
                "resources": [
                    {"uri": "test://resource", "name": "Test"}
                ]
            }),
        );

        let resources = client.parse_resource_list(&response).unwrap();
        assert_eq!(resources.len(), 1);
        assert_eq!(resources[0].uri, "test://resource");
    }
}
563}