Skip to main content

fastmcp_rust/testing/
client.rs

//! Test client for in-process MCP testing.
//!
//! Provides a client wrapper that works with `MemoryTransport` for
//! testing servers without subprocess spawning.
5
6use std::collections::HashMap;
7use std::sync::atomic::{AtomicU64, Ordering};
8
9use asupersync::Cx;
10use fastmcp_core::{McpError, McpResult};
11use fastmcp_protocol::{
12    CallToolParams, CallToolResult, ClientCapabilities, ClientInfo, Content, GetPromptParams,
13    GetPromptResult, InitializeParams, InitializeResult, JsonRpcMessage, JsonRpcRequest,
14    JsonRpcResponse, ListPromptsParams, ListPromptsResult, ListResourceTemplatesParams,
15    ListResourceTemplatesResult, ListResourcesParams, ListResourcesResult, ListToolsParams,
16    ListToolsResult, PROTOCOL_VERSION, Prompt, PromptMessage, ReadResourceParams,
17    ReadResourceResult, RequestId, Resource, ResourceContent, ResourceTemplate, ServerCapabilities,
18    ServerInfo, Tool,
19};
20use fastmcp_transport::Transport;
21use fastmcp_transport::memory::MemoryTransport;
22
23/// Test client for in-process MCP testing.
24///
25/// Unlike the production `Client`, this works with `MemoryTransport` for
26/// fast, in-process testing without subprocess spawning.
27///
28/// # Example
29///
30/// ```ignore
31/// use fastmcp_rust::testing::prelude::*;
32///
33/// let (router, client_transport, server_transport) = TestServer::builder()
34///     .with_tool(my_tool)
35///     .build();
36/// // Run server in a background thread (omitted here). Prefer using the
37/// // higher-level E2E harness helpers in this crate which join threads on drop.
38///
39/// // Create test client
40/// let mut client = TestClient::new(client_transport);
41/// client.initialize().unwrap();
42///
43/// // Test operations
44/// let tools = client.list_tools().unwrap();
45/// assert!(!tools.is_empty());
46/// ```
47pub struct TestClient {
48    /// Transport for communication.
49    transport: MemoryTransport,
50    /// Capability context for cancellation.
51    cx: Cx,
52    /// Client identification info.
53    client_info: ClientInfo,
54    /// Client capabilities.
55    capabilities: ClientCapabilities,
56    /// Server info after initialization.
57    server_info: Option<ServerInfo>,
58    /// Server capabilities after initialization.
59    server_capabilities: Option<ServerCapabilities>,
60    /// Protocol version after initialization.
61    protocol_version: Option<String>,
62    /// Request ID counter.
63    next_id: AtomicU64,
64    /// Whether client has been initialized.
65    initialized: bool,
66}
67
68impl TestClient {
69    /// Creates a new test client with the given transport.
70    ///
71    /// # Example
72    ///
73    /// ```ignore
74    /// let (client_transport, server_transport) = create_memory_transport_pair();
75    /// let client = TestClient::new(client_transport);
76    /// ```
77    #[must_use]
78    pub fn new(transport: MemoryTransport) -> Self {
79        Self {
80            transport,
81            cx: Cx::for_testing(),
82            client_info: ClientInfo {
83                name: "test-client".to_owned(),
84                version: "1.0.0".to_owned(),
85            },
86            capabilities: ClientCapabilities::default(),
87            server_info: None,
88            server_capabilities: None,
89            protocol_version: None,
90            next_id: AtomicU64::new(1),
91            initialized: false,
92        }
93    }
94
95    /// Creates a new test client with custom Cx.
96    #[must_use]
97    pub fn with_cx(transport: MemoryTransport, cx: Cx) -> Self {
98        Self {
99            transport,
100            cx,
101            client_info: ClientInfo {
102                name: "test-client".to_owned(),
103                version: "1.0.0".to_owned(),
104            },
105            capabilities: ClientCapabilities::default(),
106            server_info: None,
107            server_capabilities: None,
108            protocol_version: None,
109            next_id: AtomicU64::new(1),
110            initialized: false,
111        }
112    }
113
114    /// Sets the client info for initialization.
115    #[must_use]
116    pub fn with_client_info(mut self, name: impl Into<String>, version: impl Into<String>) -> Self {
117        self.client_info = ClientInfo {
118            name: name.into(),
119            version: version.into(),
120        };
121        self
122    }
123
124    /// Sets the client capabilities for initialization.
125    #[must_use]
126    pub fn with_capabilities(mut self, capabilities: ClientCapabilities) -> Self {
127        self.capabilities = capabilities;
128        self
129    }
130
131    /// Performs the MCP initialization handshake.
132    ///
133    /// Must be called before any other operations.
134    ///
135    /// # Errors
136    ///
137    /// Returns an error if the initialization fails.
138    pub fn initialize(&mut self) -> McpResult<InitializeResult> {
139        let params = InitializeParams {
140            protocol_version: PROTOCOL_VERSION.to_string(),
141            capabilities: self.capabilities.clone(),
142            client_info: self.client_info.clone(),
143        };
144
145        let result: InitializeResult = self.send_request("initialize", params)?;
146
147        // Store server info
148        self.server_info = Some(result.server_info.clone());
149        self.server_capabilities = Some(result.capabilities.clone());
150        self.protocol_version = Some(result.protocol_version.clone());
151
152        // Send initialized notification
153        self.send_notification("initialized", serde_json::json!({}))?;
154
155        self.initialized = true;
156        Ok(result)
157    }
158
159    /// Returns whether the client has been initialized.
160    #[must_use]
161    pub fn is_initialized(&self) -> bool {
162        self.initialized
163    }
164
165    /// Returns the server info after initialization.
166    #[must_use]
167    pub fn server_info(&self) -> Option<&ServerInfo> {
168        self.server_info.as_ref()
169    }
170
171    /// Returns the server capabilities after initialization.
172    #[must_use]
173    pub fn server_capabilities(&self) -> Option<&ServerCapabilities> {
174        self.server_capabilities.as_ref()
175    }
176
177    /// Returns the protocol version after initialization.
178    #[must_use]
179    pub fn protocol_version(&self) -> Option<&str> {
180        self.protocol_version.as_deref()
181    }
182
183    /// Lists available tools.
184    ///
185    /// # Errors
186    ///
187    /// Returns an error if the request fails.
188    pub fn list_tools(&mut self) -> McpResult<Vec<Tool>> {
189        self.ensure_initialized()?;
190        let params = ListToolsParams::default();
191        let result: ListToolsResult = self.send_request("tools/list", params)?;
192        Ok(result.tools)
193    }
194
195    /// Calls a tool with the given arguments.
196    ///
197    /// # Errors
198    ///
199    /// Returns an error if the tool call fails.
200    pub fn call_tool(
201        &mut self,
202        name: &str,
203        arguments: serde_json::Value,
204    ) -> McpResult<Vec<Content>> {
205        self.ensure_initialized()?;
206        let params = CallToolParams {
207            name: name.to_string(),
208            arguments: Some(arguments),
209            meta: None,
210        };
211        let result: CallToolResult = self.send_request("tools/call", params)?;
212
213        if result.is_error {
214            let error_msg = result
215                .content
216                .first()
217                .and_then(|c| match c {
218                    Content::Text { text } => Some(text.clone()),
219                    _ => None,
220                })
221                .unwrap_or_else(|| "Tool execution failed".to_string());
222            return Err(McpError::tool_error(error_msg));
223        }
224
225        Ok(result.content)
226    }
227
228    /// Lists available resources.
229    ///
230    /// # Errors
231    ///
232    /// Returns an error if the request fails.
233    pub fn list_resources(&mut self) -> McpResult<Vec<Resource>> {
234        self.ensure_initialized()?;
235        let params = ListResourcesParams::default();
236        let result: ListResourcesResult = self.send_request("resources/list", params)?;
237        Ok(result.resources)
238    }
239
240    /// Lists available resource templates.
241    ///
242    /// # Errors
243    ///
244    /// Returns an error if the request fails.
245    pub fn list_resource_templates(&mut self) -> McpResult<Vec<ResourceTemplate>> {
246        self.ensure_initialized()?;
247        let params = ListResourceTemplatesParams::default();
248        let result: ListResourceTemplatesResult =
249            self.send_request("resources/templates/list", params)?;
250        Ok(result.resource_templates)
251    }
252
253    /// Reads a resource by URI.
254    ///
255    /// # Errors
256    ///
257    /// Returns an error if the resource cannot be read.
258    pub fn read_resource(&mut self, uri: &str) -> McpResult<Vec<ResourceContent>> {
259        self.ensure_initialized()?;
260        let params = ReadResourceParams {
261            uri: uri.to_string(),
262            meta: None,
263        };
264        let result: ReadResourceResult = self.send_request("resources/read", params)?;
265        Ok(result.contents)
266    }
267
268    /// Lists available prompts.
269    ///
270    /// # Errors
271    ///
272    /// Returns an error if the request fails.
273    pub fn list_prompts(&mut self) -> McpResult<Vec<Prompt>> {
274        self.ensure_initialized()?;
275        let params = ListPromptsParams::default();
276        let result: ListPromptsResult = self.send_request("prompts/list", params)?;
277        Ok(result.prompts)
278    }
279
280    /// Gets a prompt with the given arguments.
281    ///
282    /// # Errors
283    ///
284    /// Returns an error if the prompt cannot be retrieved.
285    pub fn get_prompt(
286        &mut self,
287        name: &str,
288        arguments: HashMap<String, String>,
289    ) -> McpResult<Vec<PromptMessage>> {
290        self.ensure_initialized()?;
291        let params = GetPromptParams {
292            name: name.to_string(),
293            arguments: if arguments.is_empty() {
294                None
295            } else {
296                Some(arguments)
297            },
298            meta: None,
299        };
300        let result: GetPromptResult = self.send_request("prompts/get", params)?;
301        Ok(result.messages)
302    }
303
304    /// Sends a raw JSON-RPC request and returns the raw response.
305    ///
306    /// Useful for testing custom or non-standard methods.
307    ///
308    /// # Errors
309    ///
310    /// Returns an error if the request fails.
311    pub fn send_raw_request(
312        &mut self,
313        method: &str,
314        params: serde_json::Value,
315    ) -> McpResult<serde_json::Value> {
316        let id = self.next_request_id();
317        #[allow(clippy::cast_possible_wrap)]
318        let request = JsonRpcRequest::new(method, Some(params), id as i64);
319
320        self.transport
321            .send(&self.cx, &JsonRpcMessage::Request(request))
322            .map_err(|e| McpError::internal_error(format!("Transport error: {e:?}")))?;
323
324        #[allow(clippy::cast_possible_wrap)]
325        let response = self.recv_response(&RequestId::Number(id as i64))?;
326
327        if let Some(error) = response.error {
328            return Err(McpError::new(
329                fastmcp_core::McpErrorCode::from(error.code),
330                error.message,
331            ));
332        }
333
334        response
335            .result
336            .ok_or_else(|| McpError::internal_error("No result in response"))
337    }
338
339    /// Closes the client connection.
340    pub fn close(&mut self) {
341        let _ = self.transport.close();
342    }
343
344    /// Returns a reference to the transport for advanced testing.
345    #[must_use]
346    pub fn transport(&self) -> &MemoryTransport {
347        &self.transport
348    }
349
350    /// Returns a mutable reference to the transport for advanced testing.
351    pub fn transport_mut(&mut self) -> &mut MemoryTransport {
352        &mut self.transport
353    }
354
355    /// Sends a raw JSON-RPC request with already-serialized params.
356    ///
357    /// This is intended for advanced E2E tests that need to inject protocol fields
358    /// not covered by the typed helper methods (for example, auth metadata).
359    ///
360    /// # Errors
361    ///
362    /// Returns an error if the request fails or the response contains an error payload.
363    pub fn send_request_json(
364        &mut self,
365        method: &str,
366        params_value: serde_json::Value,
367    ) -> McpResult<serde_json::Value> {
368        self.ensure_initialized()?;
369
370        let id = self.next_request_id();
371        #[allow(clippy::cast_possible_wrap)]
372        let request_id = RequestId::Number(id as i64);
373        #[allow(clippy::cast_possible_wrap)]
374        let request = JsonRpcRequest::new(method, Some(params_value), id as i64);
375
376        self.transport
377            .send(&self.cx, &JsonRpcMessage::Request(request))
378            .map_err(|e| McpError::internal_error(format!("Transport error: {e:?}")))?;
379
380        let response = self.recv_response(&request_id)?;
381
382        if let Some(error) = response.error {
383            return Err(McpError::new(
384                fastmcp_core::McpErrorCode::from(error.code),
385                error.message,
386            ));
387        }
388
389        response
390            .result
391            .ok_or_else(|| McpError::internal_error("No result in response"))
392    }
393
394    // --- Private helpers ---
395
396    fn ensure_initialized(&self) -> McpResult<()> {
397        if !self.initialized {
398            return Err(McpError::internal_error(
399                "Client not initialized. Call initialize() first.",
400            ));
401        }
402        Ok(())
403    }
404
405    fn next_request_id(&self) -> u64 {
406        self.next_id.fetch_add(1, Ordering::SeqCst)
407    }
408
409    fn send_request<P: serde::Serialize, R: serde::de::DeserializeOwned>(
410        &mut self,
411        method: &str,
412        params: P,
413    ) -> McpResult<R> {
414        let id = self.next_request_id();
415        let params_value = serde_json::to_value(params)
416            .map_err(|e| McpError::internal_error(format!("Failed to serialize params: {e}")))?;
417
418        #[allow(clippy::cast_possible_wrap)]
419        let request_id = RequestId::Number(id as i64);
420        #[allow(clippy::cast_possible_wrap)]
421        let request = JsonRpcRequest::new(method, Some(params_value), id as i64);
422
423        self.transport
424            .send(&self.cx, &JsonRpcMessage::Request(request))
425            .map_err(|e| McpError::internal_error(format!("Transport error: {e:?}")))?;
426
427        let response = self.recv_response(&request_id)?;
428
429        if let Some(error) = response.error {
430            return Err(McpError::new(
431                fastmcp_core::McpErrorCode::from(error.code),
432                error.message,
433            ));
434        }
435
436        let result = response
437            .result
438            .ok_or_else(|| McpError::internal_error("No result in response"))?;
439
440        serde_json::from_value(result)
441            .map_err(|e| McpError::internal_error(format!("Failed to deserialize response: {e}")))
442    }
443
444    fn send_notification<P: serde::Serialize>(&mut self, method: &str, params: P) -> McpResult<()> {
445        let params_value = serde_json::to_value(params)
446            .map_err(|e| McpError::internal_error(format!("Failed to serialize params: {e}")))?;
447
448        let request = JsonRpcRequest {
449            jsonrpc: std::borrow::Cow::Borrowed(fastmcp_protocol::JSONRPC_VERSION),
450            method: method.to_string(),
451            params: Some(params_value),
452            id: None,
453        };
454
455        self.transport
456            .send(&self.cx, &JsonRpcMessage::Request(request))
457            .map_err(|e| McpError::internal_error(format!("Transport error: {e:?}")))?;
458
459        Ok(())
460    }
461
462    fn recv_response(
463        &mut self,
464        expected_id: &RequestId,
465    ) -> McpResult<fastmcp_protocol::JsonRpcResponse> {
466        loop {
467            let message = self
468                .transport
469                .recv(&self.cx)
470                .map_err(|e| McpError::internal_error(format!("Transport error: {e:?}")))?;
471
472            match message {
473                JsonRpcMessage::Response(response) => {
474                    if let Some(ref id) = response.id {
475                        if id != expected_id {
476                            continue;
477                        }
478                    }
479                    return Ok(response);
480                }
481                JsonRpcMessage::Request(request) => {
482                    // Notifications don't require a response.
483                    let Some(id) = request.id.clone() else {
484                        continue;
485                    };
486
487                    // This test client does not implement server-initiated protocols.
488                    // To avoid deadlocks (server awaiting a response), respond with MethodNotFound.
489                    let err = McpError::method_not_found(&request.method);
490                    let response = JsonRpcResponse::error(Some(id), err.into());
491                    self.transport
492                        .send(&self.cx, &JsonRpcMessage::Response(response))
493                        .map_err(|e| McpError::internal_error(format!("Transport error: {e:?}")))?;
494
495                    continue;
496                }
497            }
498        }
499    }
500}
501
502impl std::fmt::Debug for TestClient {
503    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
504        f.debug_struct("TestClient")
505            .field("client_info", &self.client_info)
506            .field("initialized", &self.initialized)
507            .field("server_info", &self.server_info)
508            .finish_non_exhaustive()
509    }
510}
511
#[cfg(test)]
mod tests {
    use super::*;
    use fastmcp_transport::memory::create_memory_transport_pair;

    /// A freshly built client has not performed the handshake.
    #[test]
    fn test_client_creation() {
        let (client_half, _server_half) = create_memory_transport_pair();
        let fresh = TestClient::new(client_half);
        assert!(!fresh.is_initialized());
    }

    /// `with_client_info` overrides both the name and the version.
    #[test]
    fn test_client_with_info() {
        let (client_half, _server_half) = create_memory_transport_pair();
        let customized = TestClient::new(client_half).with_client_info("my-client", "2.0.0");
        let info = &customized.client_info;
        assert_eq!(info.name, "my-client");
        assert_eq!(info.version, "2.0.0");
    }

    /// Typed helpers are rejected before `initialize()`.
    #[test]
    fn test_not_initialized_error() {
        let (client_half, _server_half) = create_memory_transport_pair();
        let mut uninitialized = TestClient::new(client_half);
        let outcome = uninitialized.list_tools();
        assert!(outcome.is_err());
    }

    // =========================================================================
    // Additional coverage tests (bd-8zle)
    // =========================================================================

    /// Construction with an explicit context also starts uninitialized.
    #[test]
    fn with_cx_sets_custom_cx() {
        let (client_half, _server_half) = create_memory_transport_pair();
        let custom_cx = Cx::for_testing();
        let built = TestClient::with_cx(client_half, custom_cx);
        assert!(!built.is_initialized());
    }

    /// `with_capabilities` stores the supplied capabilities.
    #[test]
    fn with_capabilities_sets_capabilities() {
        let (client_half, _server_half) = create_memory_transport_pair();
        let sampling_caps = ClientCapabilities {
            sampling: Some(fastmcp_protocol::SamplingCapability {}),
            ..Default::default()
        };
        let built = TestClient::new(client_half).with_capabilities(sampling_caps);
        assert!(built.capabilities.sampling.is_some());
    }

    /// Server details are unavailable until the handshake has run.
    #[test]
    fn pre_init_getters_return_none() {
        let (client_half, _server_half) = create_memory_transport_pair();
        let fresh = TestClient::new(client_half);
        assert!(fresh.server_info().is_none());
        assert!(fresh.server_capabilities().is_none());
        assert!(fresh.protocol_version().is_none());
    }

    /// The manual `Debug` impl mentions the type, client name and init flag.
    #[test]
    fn debug_output_includes_key_fields() {
        let (client_half, _server_half) = create_memory_transport_pair();
        let fresh = TestClient::new(client_half);
        let rendered = format!("{fresh:?}");
        assert!(rendered.contains("TestClient"));
        assert!(rendered.contains("test-client"));
        assert!(rendered.contains("initialized"));
    }

    /// Both transport accessors can be called.
    #[test]
    fn transport_accessors() {
        let (client_half, _server_half) = create_memory_transport_pair();
        let mut fresh = TestClient::new(client_half);
        // Shared borrow first...
        let _ = fresh.transport();
        // ...then the exclusive borrow.
        let _ = fresh.transport_mut();
    }

    /// Closing a never-used client is harmless.
    #[test]
    fn close_does_not_panic() {
        let (client_half, _server_half) = create_memory_transport_pair();
        let mut fresh = TestClient::new(client_half);
        fresh.close();
    }

    /// Request ids start at 1 and grow by one per call.
    #[test]
    fn request_id_auto_increments() {
        let (client_half, _server_half) = create_memory_transport_pair();
        let fresh = TestClient::new(client_half);
        let first = fresh.next_request_id();
        let second = fresh.next_request_id();
        let third = fresh.next_request_id();
        assert_eq!(first, 1);
        assert_eq!(second, 2);
        assert_eq!(third, 3);
    }

    /// The pre-init failure explains that the client is not initialized.
    #[test]
    fn ensure_initialized_error_message() {
        let (client_half, _server_half) = create_memory_transport_pair();
        let mut uninitialized = TestClient::new(client_half);
        let failure = uninitialized.list_tools().unwrap_err();
        let msg = failure.to_string();
        assert!(msg.contains("not initialized"), "error was: {msg}");
    }
}