Skip to main content

fastmcp_rust/testing/
client.rs

1//! Test client for in-process MCP testing.
2//!
3//! Provides a client wrapper that works with MemoryTransport for
4//! testing servers without subprocess spawning.
5
6use std::collections::HashMap;
7use std::sync::atomic::{AtomicU64, Ordering};
8
9use asupersync::Cx;
10use fastmcp_core::{McpError, McpResult};
11use fastmcp_protocol::{
12    CallToolParams, CallToolResult, ClientCapabilities, ClientInfo, Content, GetPromptParams,
13    GetPromptResult, InitializeParams, InitializeResult, JsonRpcMessage, JsonRpcRequest,
14    ListPromptsParams, ListPromptsResult, ListResourceTemplatesParams, ListResourceTemplatesResult,
15    ListResourcesParams, ListResourcesResult, ListToolsParams, ListToolsResult, PROTOCOL_VERSION,
16    Prompt, PromptMessage, ReadResourceParams, ReadResourceResult, RequestId, Resource,
17    ResourceContent, ResourceTemplate, ServerCapabilities, ServerInfo, Tool,
18};
19use fastmcp_transport::Transport;
20use fastmcp_transport::memory::MemoryTransport;
21
22/// Test client for in-process MCP testing.
23///
24/// Unlike the production `Client`, this works with `MemoryTransport` for
25/// fast, in-process testing without subprocess spawning.
26///
27/// # Example
28///
29/// ```ignore
30/// use fastmcp_rust::testing::prelude::*;
31///
32/// let (router, client_transport, server_transport) = TestServer::builder()
33///     .with_tool(my_tool)
34///     .build();
35///
36/// // Run server in background thread
37/// std::thread::spawn(move || {
38///     // server loop with server_transport
39/// });
40///
41/// // Create test client
42/// let mut client = TestClient::new(client_transport);
43/// client.initialize().unwrap();
44///
45/// // Test operations
46/// let tools = client.list_tools().unwrap();
47/// assert!(!tools.is_empty());
48/// ```
49pub struct TestClient {
50    /// Transport for communication.
51    transport: MemoryTransport,
52    /// Capability context for cancellation.
53    cx: Cx,
54    /// Client identification info.
55    client_info: ClientInfo,
56    /// Client capabilities.
57    capabilities: ClientCapabilities,
58    /// Server info after initialization.
59    server_info: Option<ServerInfo>,
60    /// Server capabilities after initialization.
61    server_capabilities: Option<ServerCapabilities>,
62    /// Protocol version after initialization.
63    protocol_version: Option<String>,
64    /// Request ID counter.
65    next_id: AtomicU64,
66    /// Whether client has been initialized.
67    initialized: bool,
68}
69
70impl TestClient {
71    /// Creates a new test client with the given transport.
72    ///
73    /// # Example
74    ///
75    /// ```ignore
76    /// let (client_transport, server_transport) = create_memory_transport_pair();
77    /// let client = TestClient::new(client_transport);
78    /// ```
79    #[must_use]
80    pub fn new(transport: MemoryTransport) -> Self {
81        Self {
82            transport,
83            cx: Cx::for_testing(),
84            client_info: ClientInfo {
85                name: "test-client".to_owned(),
86                version: "1.0.0".to_owned(),
87            },
88            capabilities: ClientCapabilities::default(),
89            server_info: None,
90            server_capabilities: None,
91            protocol_version: None,
92            next_id: AtomicU64::new(1),
93            initialized: false,
94        }
95    }
96
97    /// Creates a new test client with custom Cx.
98    #[must_use]
99    pub fn with_cx(transport: MemoryTransport, cx: Cx) -> Self {
100        Self {
101            transport,
102            cx,
103            client_info: ClientInfo {
104                name: "test-client".to_owned(),
105                version: "1.0.0".to_owned(),
106            },
107            capabilities: ClientCapabilities::default(),
108            server_info: None,
109            server_capabilities: None,
110            protocol_version: None,
111            next_id: AtomicU64::new(1),
112            initialized: false,
113        }
114    }
115
116    /// Sets the client info for initialization.
117    #[must_use]
118    pub fn with_client_info(mut self, name: impl Into<String>, version: impl Into<String>) -> Self {
119        self.client_info = ClientInfo {
120            name: name.into(),
121            version: version.into(),
122        };
123        self
124    }
125
126    /// Sets the client capabilities for initialization.
127    #[must_use]
128    pub fn with_capabilities(mut self, capabilities: ClientCapabilities) -> Self {
129        self.capabilities = capabilities;
130        self
131    }
132
133    /// Performs the MCP initialization handshake.
134    ///
135    /// Must be called before any other operations.
136    ///
137    /// # Errors
138    ///
139    /// Returns an error if the initialization fails.
140    pub fn initialize(&mut self) -> McpResult<InitializeResult> {
141        let params = InitializeParams {
142            protocol_version: PROTOCOL_VERSION.to_string(),
143            capabilities: self.capabilities.clone(),
144            client_info: self.client_info.clone(),
145        };
146
147        let result: InitializeResult = self.send_request("initialize", params)?;
148
149        // Store server info
150        self.server_info = Some(result.server_info.clone());
151        self.server_capabilities = Some(result.capabilities.clone());
152        self.protocol_version = Some(result.protocol_version.clone());
153
154        // Send initialized notification
155        self.send_notification("initialized", serde_json::json!({}))?;
156
157        self.initialized = true;
158        Ok(result)
159    }
160
161    /// Returns whether the client has been initialized.
162    #[must_use]
163    pub fn is_initialized(&self) -> bool {
164        self.initialized
165    }
166
167    /// Returns the server info after initialization.
168    #[must_use]
169    pub fn server_info(&self) -> Option<&ServerInfo> {
170        self.server_info.as_ref()
171    }
172
173    /// Returns the server capabilities after initialization.
174    #[must_use]
175    pub fn server_capabilities(&self) -> Option<&ServerCapabilities> {
176        self.server_capabilities.as_ref()
177    }
178
179    /// Returns the protocol version after initialization.
180    #[must_use]
181    pub fn protocol_version(&self) -> Option<&str> {
182        self.protocol_version.as_deref()
183    }
184
185    /// Lists available tools.
186    ///
187    /// # Errors
188    ///
189    /// Returns an error if the request fails.
190    pub fn list_tools(&mut self) -> McpResult<Vec<Tool>> {
191        self.ensure_initialized()?;
192        let params = ListToolsParams::default();
193        let result: ListToolsResult = self.send_request("tools/list", params)?;
194        Ok(result.tools)
195    }
196
197    /// Calls a tool with the given arguments.
198    ///
199    /// # Errors
200    ///
201    /// Returns an error if the tool call fails.
202    pub fn call_tool(
203        &mut self,
204        name: &str,
205        arguments: serde_json::Value,
206    ) -> McpResult<Vec<Content>> {
207        self.ensure_initialized()?;
208        let params = CallToolParams {
209            name: name.to_string(),
210            arguments: Some(arguments),
211            meta: None,
212        };
213        let result: CallToolResult = self.send_request("tools/call", params)?;
214
215        if result.is_error {
216            let error_msg = result
217                .content
218                .first()
219                .and_then(|c| match c {
220                    Content::Text { text } => Some(text.clone()),
221                    _ => None,
222                })
223                .unwrap_or_else(|| "Tool execution failed".to_string());
224            return Err(McpError::tool_error(error_msg));
225        }
226
227        Ok(result.content)
228    }
229
230    /// Lists available resources.
231    ///
232    /// # Errors
233    ///
234    /// Returns an error if the request fails.
235    pub fn list_resources(&mut self) -> McpResult<Vec<Resource>> {
236        self.ensure_initialized()?;
237        let params = ListResourcesParams::default();
238        let result: ListResourcesResult = self.send_request("resources/list", params)?;
239        Ok(result.resources)
240    }
241
242    /// Lists available resource templates.
243    ///
244    /// # Errors
245    ///
246    /// Returns an error if the request fails.
247    pub fn list_resource_templates(&mut self) -> McpResult<Vec<ResourceTemplate>> {
248        self.ensure_initialized()?;
249        let params = ListResourceTemplatesParams::default();
250        let result: ListResourceTemplatesResult =
251            self.send_request("resources/templates/list", params)?;
252        Ok(result.resource_templates)
253    }
254
255    /// Reads a resource by URI.
256    ///
257    /// # Errors
258    ///
259    /// Returns an error if the resource cannot be read.
260    pub fn read_resource(&mut self, uri: &str) -> McpResult<Vec<ResourceContent>> {
261        self.ensure_initialized()?;
262        let params = ReadResourceParams {
263            uri: uri.to_string(),
264            meta: None,
265        };
266        let result: ReadResourceResult = self.send_request("resources/read", params)?;
267        Ok(result.contents)
268    }
269
270    /// Lists available prompts.
271    ///
272    /// # Errors
273    ///
274    /// Returns an error if the request fails.
275    pub fn list_prompts(&mut self) -> McpResult<Vec<Prompt>> {
276        self.ensure_initialized()?;
277        let params = ListPromptsParams::default();
278        let result: ListPromptsResult = self.send_request("prompts/list", params)?;
279        Ok(result.prompts)
280    }
281
282    /// Gets a prompt with the given arguments.
283    ///
284    /// # Errors
285    ///
286    /// Returns an error if the prompt cannot be retrieved.
287    pub fn get_prompt(
288        &mut self,
289        name: &str,
290        arguments: HashMap<String, String>,
291    ) -> McpResult<Vec<PromptMessage>> {
292        self.ensure_initialized()?;
293        let params = GetPromptParams {
294            name: name.to_string(),
295            arguments: if arguments.is_empty() {
296                None
297            } else {
298                Some(arguments)
299            },
300            meta: None,
301        };
302        let result: GetPromptResult = self.send_request("prompts/get", params)?;
303        Ok(result.messages)
304    }
305
306    /// Sends a raw JSON-RPC request and returns the raw response.
307    ///
308    /// Useful for testing custom or non-standard methods.
309    ///
310    /// # Errors
311    ///
312    /// Returns an error if the request fails.
313    pub fn send_raw_request(
314        &mut self,
315        method: &str,
316        params: serde_json::Value,
317    ) -> McpResult<serde_json::Value> {
318        let id = self.next_request_id();
319        #[allow(clippy::cast_possible_wrap)]
320        let request = JsonRpcRequest::new(method, Some(params), id as i64);
321
322        self.transport
323            .send(&self.cx, &JsonRpcMessage::Request(request))
324            .map_err(|e| McpError::internal_error(format!("Transport error: {e:?}")))?;
325
326        #[allow(clippy::cast_possible_wrap)]
327        let response = self.recv_response(&RequestId::Number(id as i64))?;
328
329        if let Some(error) = response.error {
330            return Err(McpError::new(
331                fastmcp_core::McpErrorCode::from(error.code),
332                error.message,
333            ));
334        }
335
336        response
337            .result
338            .ok_or_else(|| McpError::internal_error("No result in response"))
339    }
340
341    /// Closes the client connection.
342    pub fn close(mut self) {
343        let _ = self.transport.close();
344    }
345
346    /// Returns a reference to the transport for advanced testing.
347    #[must_use]
348    pub fn transport(&self) -> &MemoryTransport {
349        &self.transport
350    }
351
352    /// Returns a mutable reference to the transport for advanced testing.
353    pub fn transport_mut(&mut self) -> &mut MemoryTransport {
354        &mut self.transport
355    }
356
357    // --- Private helpers ---
358
359    fn ensure_initialized(&self) -> McpResult<()> {
360        if !self.initialized {
361            return Err(McpError::internal_error(
362                "Client not initialized. Call initialize() first.",
363            ));
364        }
365        Ok(())
366    }
367
368    fn next_request_id(&self) -> u64 {
369        self.next_id.fetch_add(1, Ordering::SeqCst)
370    }
371
372    fn send_request<P: serde::Serialize, R: serde::de::DeserializeOwned>(
373        &mut self,
374        method: &str,
375        params: P,
376    ) -> McpResult<R> {
377        let id = self.next_request_id();
378        let params_value = serde_json::to_value(params)
379            .map_err(|e| McpError::internal_error(format!("Failed to serialize params: {e}")))?;
380
381        #[allow(clippy::cast_possible_wrap)]
382        let request_id = RequestId::Number(id as i64);
383        #[allow(clippy::cast_possible_wrap)]
384        let request = JsonRpcRequest::new(method, Some(params_value), id as i64);
385
386        self.transport
387            .send(&self.cx, &JsonRpcMessage::Request(request))
388            .map_err(|e| McpError::internal_error(format!("Transport error: {e:?}")))?;
389
390        let response = self.recv_response(&request_id)?;
391
392        if let Some(error) = response.error {
393            return Err(McpError::new(
394                fastmcp_core::McpErrorCode::from(error.code),
395                error.message,
396            ));
397        }
398
399        let result = response
400            .result
401            .ok_or_else(|| McpError::internal_error("No result in response"))?;
402
403        serde_json::from_value(result)
404            .map_err(|e| McpError::internal_error(format!("Failed to deserialize response: {e}")))
405    }
406
407    fn send_notification<P: serde::Serialize>(&mut self, method: &str, params: P) -> McpResult<()> {
408        let params_value = serde_json::to_value(params)
409            .map_err(|e| McpError::internal_error(format!("Failed to serialize params: {e}")))?;
410
411        let request = JsonRpcRequest {
412            jsonrpc: std::borrow::Cow::Borrowed(fastmcp_protocol::JSONRPC_VERSION),
413            method: method.to_string(),
414            params: Some(params_value),
415            id: None,
416        };
417
418        self.transport
419            .send(&self.cx, &JsonRpcMessage::Request(request))
420            .map_err(|e| McpError::internal_error(format!("Transport error: {e:?}")))?;
421
422        Ok(())
423    }
424
425    fn recv_response(
426        &mut self,
427        expected_id: &RequestId,
428    ) -> McpResult<fastmcp_protocol::JsonRpcResponse> {
429        loop {
430            let message = self
431                .transport
432                .recv(&self.cx)
433                .map_err(|e| McpError::internal_error(format!("Transport error: {e:?}")))?;
434
435            match message {
436                JsonRpcMessage::Response(response) => {
437                    if let Some(ref id) = response.id {
438                        if id != expected_id {
439                            continue;
440                        }
441                    }
442                    return Ok(response);
443                }
444                JsonRpcMessage::Request(_request) => {
445                    // Ignore server-initiated requests for now
446                    // (notifications, progress updates, etc.)
447                    continue;
448                }
449            }
450        }
451    }
452}
453
454impl std::fmt::Debug for TestClient {
455    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
456        f.debug_struct("TestClient")
457            .field("client_info", &self.client_info)
458            .field("initialized", &self.initialized)
459            .field("server_info", &self.server_info)
460            .finish_non_exhaustive()
461    }
462}
463
#[cfg(test)]
mod tests {
    use super::*;
    use fastmcp_transport::memory::create_memory_transport_pair;

    #[test]
    fn test_client_creation() {
        let (client_transport, _server_transport) = create_memory_transport_pair();
        let client = TestClient::new(client_transport);
        assert!(!client.is_initialized());
    }

    /// Before the handshake the client carries its default identity and no server data.
    #[test]
    fn test_defaults_before_initialize() {
        let (client_transport, _server_transport) = create_memory_transport_pair();
        let client = TestClient::new(client_transport);
        assert_eq!(client.client_info.name, "test-client");
        assert_eq!(client.client_info.version, "1.0.0");
        assert!(client.server_info().is_none());
        assert!(client.server_capabilities().is_none());
        assert!(client.protocol_version().is_none());
    }

    /// `with_cx` yields the same defaults as `new`.
    #[test]
    fn test_client_with_cx() {
        let (client_transport, _server_transport) = create_memory_transport_pair();
        let client = TestClient::with_cx(client_transport, Cx::for_testing());
        assert!(!client.is_initialized());
        assert_eq!(client.client_info.name, "test-client");
        assert_eq!(client.client_info.version, "1.0.0");
    }

    #[test]
    fn test_client_with_info() {
        let (client_transport, _server_transport) = create_memory_transport_pair();
        let client = TestClient::new(client_transport).with_client_info("my-client", "2.0.0");
        assert_eq!(client.client_info.name, "my-client");
        assert_eq!(client.client_info.version, "2.0.0");
    }

    /// Every typed helper must refuse to run before `initialize`, without
    /// touching the transport (the server end is never driven here).
    #[test]
    fn test_not_initialized_error() {
        let (client_transport, _server_transport) = create_memory_transport_pair();
        let mut client = TestClient::new(client_transport);
        assert!(client.list_tools().is_err());
        assert!(client.call_tool("echo", serde_json::json!({})).is_err());
        assert!(client.list_resources().is_err());
        assert!(client.list_resource_templates().is_err());
        assert!(client.read_resource("file:///example").is_err());
        assert!(client.list_prompts().is_err());
        assert!(client.get_prompt("greet", HashMap::new()).is_err());
        assert!(!client.is_initialized());
    }

    #[test]
    fn test_debug_output() {
        let (client_transport, _server_transport) = create_memory_transport_pair();
        let client = TestClient::new(client_transport).with_client_info("dbg-client", "0.1.0");
        let rendered = format!("{client:?}");
        assert!(rendered.contains("TestClient"));
        assert!(rendered.contains("dbg-client"));
    }
}