1use std::collections::HashMap;
7use std::sync::atomic::{AtomicU64, Ordering};
8
9use asupersync::Cx;
10use fastmcp_core::{McpError, McpResult};
11use fastmcp_protocol::{
12 CallToolParams, CallToolResult, ClientCapabilities, ClientInfo, Content, GetPromptParams,
13 GetPromptResult, InitializeParams, InitializeResult, JsonRpcMessage, JsonRpcRequest,
14 ListPromptsParams, ListPromptsResult, ListResourceTemplatesParams, ListResourceTemplatesResult,
15 ListResourcesParams, ListResourcesResult, ListToolsParams, ListToolsResult, PROTOCOL_VERSION,
16 Prompt, PromptMessage, ReadResourceParams, ReadResourceResult, RequestId, Resource,
17 ResourceContent, ResourceTemplate, ServerCapabilities, ServerInfo, Tool,
18};
19use fastmcp_transport::Transport;
20use fastmcp_transport::memory::MemoryTransport;
21
/// In-process MCP client used to drive a server from tests.
///
/// It exchanges JSON-RPC messages over a [`MemoryTransport`] and records the state
/// returned by the server during the `initialize` handshake.
pub struct TestClient {
    /// Transport over which requests and notifications are sent and responses received.
    transport: MemoryTransport,
    /// Context passed to every transport `send` / `recv` call.
    cx: Cx,
    /// Client identity sent in the `initialize` request.
    client_info: ClientInfo,
    /// Capabilities advertised in the `initialize` request.
    capabilities: ClientCapabilities,
    /// Server info from the `initialize` response; `None` until that response arrives.
    server_info: Option<ServerInfo>,
    /// Server capabilities from the `initialize` response; `None` until then.
    server_capabilities: Option<ServerCapabilities>,
    /// Protocol version reported by the server; `None` until initialization.
    protocol_version: Option<String>,
    /// Source of JSON-RPC request IDs (the constructors start it at 1).
    next_id: AtomicU64,
    /// Set once the initialize request and `initialized` notification both succeeded.
    initialized: bool,
}
69
70impl TestClient {
71 #[must_use]
80 pub fn new(transport: MemoryTransport) -> Self {
81 Self {
82 transport,
83 cx: Cx::for_testing(),
84 client_info: ClientInfo {
85 name: "test-client".to_owned(),
86 version: "1.0.0".to_owned(),
87 },
88 capabilities: ClientCapabilities::default(),
89 server_info: None,
90 server_capabilities: None,
91 protocol_version: None,
92 next_id: AtomicU64::new(1),
93 initialized: false,
94 }
95 }
96
97 #[must_use]
99 pub fn with_cx(transport: MemoryTransport, cx: Cx) -> Self {
100 Self {
101 transport,
102 cx,
103 client_info: ClientInfo {
104 name: "test-client".to_owned(),
105 version: "1.0.0".to_owned(),
106 },
107 capabilities: ClientCapabilities::default(),
108 server_info: None,
109 server_capabilities: None,
110 protocol_version: None,
111 next_id: AtomicU64::new(1),
112 initialized: false,
113 }
114 }
115
116 #[must_use]
118 pub fn with_client_info(mut self, name: impl Into<String>, version: impl Into<String>) -> Self {
119 self.client_info = ClientInfo {
120 name: name.into(),
121 version: version.into(),
122 };
123 self
124 }
125
126 #[must_use]
128 pub fn with_capabilities(mut self, capabilities: ClientCapabilities) -> Self {
129 self.capabilities = capabilities;
130 self
131 }
132
133 pub fn initialize(&mut self) -> McpResult<InitializeResult> {
141 let params = InitializeParams {
142 protocol_version: PROTOCOL_VERSION.to_string(),
143 capabilities: self.capabilities.clone(),
144 client_info: self.client_info.clone(),
145 };
146
147 let result: InitializeResult = self.send_request("initialize", params)?;
148
149 self.server_info = Some(result.server_info.clone());
151 self.server_capabilities = Some(result.capabilities.clone());
152 self.protocol_version = Some(result.protocol_version.clone());
153
154 self.send_notification("initialized", serde_json::json!({}))?;
156
157 self.initialized = true;
158 Ok(result)
159 }
160
161 #[must_use]
163 pub fn is_initialized(&self) -> bool {
164 self.initialized
165 }
166
167 #[must_use]
169 pub fn server_info(&self) -> Option<&ServerInfo> {
170 self.server_info.as_ref()
171 }
172
173 #[must_use]
175 pub fn server_capabilities(&self) -> Option<&ServerCapabilities> {
176 self.server_capabilities.as_ref()
177 }
178
179 #[must_use]
181 pub fn protocol_version(&self) -> Option<&str> {
182 self.protocol_version.as_deref()
183 }
184
185 pub fn list_tools(&mut self) -> McpResult<Vec<Tool>> {
191 self.ensure_initialized()?;
192 let params = ListToolsParams::default();
193 let result: ListToolsResult = self.send_request("tools/list", params)?;
194 Ok(result.tools)
195 }
196
197 pub fn call_tool(
203 &mut self,
204 name: &str,
205 arguments: serde_json::Value,
206 ) -> McpResult<Vec<Content>> {
207 self.ensure_initialized()?;
208 let params = CallToolParams {
209 name: name.to_string(),
210 arguments: Some(arguments),
211 meta: None,
212 };
213 let result: CallToolResult = self.send_request("tools/call", params)?;
214
215 if result.is_error {
216 let error_msg = result
217 .content
218 .first()
219 .and_then(|c| match c {
220 Content::Text { text } => Some(text.clone()),
221 _ => None,
222 })
223 .unwrap_or_else(|| "Tool execution failed".to_string());
224 return Err(McpError::tool_error(error_msg));
225 }
226
227 Ok(result.content)
228 }
229
230 pub fn list_resources(&mut self) -> McpResult<Vec<Resource>> {
236 self.ensure_initialized()?;
237 let params = ListResourcesParams::default();
238 let result: ListResourcesResult = self.send_request("resources/list", params)?;
239 Ok(result.resources)
240 }
241
242 pub fn list_resource_templates(&mut self) -> McpResult<Vec<ResourceTemplate>> {
248 self.ensure_initialized()?;
249 let params = ListResourceTemplatesParams::default();
250 let result: ListResourceTemplatesResult =
251 self.send_request("resources/templates/list", params)?;
252 Ok(result.resource_templates)
253 }
254
255 pub fn read_resource(&mut self, uri: &str) -> McpResult<Vec<ResourceContent>> {
261 self.ensure_initialized()?;
262 let params = ReadResourceParams {
263 uri: uri.to_string(),
264 meta: None,
265 };
266 let result: ReadResourceResult = self.send_request("resources/read", params)?;
267 Ok(result.contents)
268 }
269
270 pub fn list_prompts(&mut self) -> McpResult<Vec<Prompt>> {
276 self.ensure_initialized()?;
277 let params = ListPromptsParams::default();
278 let result: ListPromptsResult = self.send_request("prompts/list", params)?;
279 Ok(result.prompts)
280 }
281
282 pub fn get_prompt(
288 &mut self,
289 name: &str,
290 arguments: HashMap<String, String>,
291 ) -> McpResult<Vec<PromptMessage>> {
292 self.ensure_initialized()?;
293 let params = GetPromptParams {
294 name: name.to_string(),
295 arguments: if arguments.is_empty() {
296 None
297 } else {
298 Some(arguments)
299 },
300 meta: None,
301 };
302 let result: GetPromptResult = self.send_request("prompts/get", params)?;
303 Ok(result.messages)
304 }
305
306 pub fn send_raw_request(
314 &mut self,
315 method: &str,
316 params: serde_json::Value,
317 ) -> McpResult<serde_json::Value> {
318 let id = self.next_request_id();
319 #[allow(clippy::cast_possible_wrap)]
320 let request = JsonRpcRequest::new(method, Some(params), id as i64);
321
322 self.transport
323 .send(&self.cx, &JsonRpcMessage::Request(request))
324 .map_err(|e| McpError::internal_error(format!("Transport error: {e:?}")))?;
325
326 #[allow(clippy::cast_possible_wrap)]
327 let response = self.recv_response(&RequestId::Number(id as i64))?;
328
329 if let Some(error) = response.error {
330 return Err(McpError::new(
331 fastmcp_core::McpErrorCode::from(error.code),
332 error.message,
333 ));
334 }
335
336 response
337 .result
338 .ok_or_else(|| McpError::internal_error("No result in response"))
339 }
340
341 pub fn close(mut self) {
343 let _ = self.transport.close();
344 }
345
346 #[must_use]
348 pub fn transport(&self) -> &MemoryTransport {
349 &self.transport
350 }
351
352 pub fn transport_mut(&mut self) -> &mut MemoryTransport {
354 &mut self.transport
355 }
356
357 fn ensure_initialized(&self) -> McpResult<()> {
360 if !self.initialized {
361 return Err(McpError::internal_error(
362 "Client not initialized. Call initialize() first.",
363 ));
364 }
365 Ok(())
366 }
367
368 fn next_request_id(&self) -> u64 {
369 self.next_id.fetch_add(1, Ordering::SeqCst)
370 }
371
372 fn send_request<P: serde::Serialize, R: serde::de::DeserializeOwned>(
373 &mut self,
374 method: &str,
375 params: P,
376 ) -> McpResult<R> {
377 let id = self.next_request_id();
378 let params_value = serde_json::to_value(params)
379 .map_err(|e| McpError::internal_error(format!("Failed to serialize params: {e}")))?;
380
381 #[allow(clippy::cast_possible_wrap)]
382 let request_id = RequestId::Number(id as i64);
383 #[allow(clippy::cast_possible_wrap)]
384 let request = JsonRpcRequest::new(method, Some(params_value), id as i64);
385
386 self.transport
387 .send(&self.cx, &JsonRpcMessage::Request(request))
388 .map_err(|e| McpError::internal_error(format!("Transport error: {e:?}")))?;
389
390 let response = self.recv_response(&request_id)?;
391
392 if let Some(error) = response.error {
393 return Err(McpError::new(
394 fastmcp_core::McpErrorCode::from(error.code),
395 error.message,
396 ));
397 }
398
399 let result = response
400 .result
401 .ok_or_else(|| McpError::internal_error("No result in response"))?;
402
403 serde_json::from_value(result)
404 .map_err(|e| McpError::internal_error(format!("Failed to deserialize response: {e}")))
405 }
406
407 fn send_notification<P: serde::Serialize>(&mut self, method: &str, params: P) -> McpResult<()> {
408 let params_value = serde_json::to_value(params)
409 .map_err(|e| McpError::internal_error(format!("Failed to serialize params: {e}")))?;
410
411 let request = JsonRpcRequest {
412 jsonrpc: std::borrow::Cow::Borrowed(fastmcp_protocol::JSONRPC_VERSION),
413 method: method.to_string(),
414 params: Some(params_value),
415 id: None,
416 };
417
418 self.transport
419 .send(&self.cx, &JsonRpcMessage::Request(request))
420 .map_err(|e| McpError::internal_error(format!("Transport error: {e:?}")))?;
421
422 Ok(())
423 }
424
425 fn recv_response(
426 &mut self,
427 expected_id: &RequestId,
428 ) -> McpResult<fastmcp_protocol::JsonRpcResponse> {
429 loop {
430 let message = self
431 .transport
432 .recv(&self.cx)
433 .map_err(|e| McpError::internal_error(format!("Transport error: {e:?}")))?;
434
435 match message {
436 JsonRpcMessage::Response(response) => {
437 if let Some(ref id) = response.id {
438 if id != expected_id {
439 continue;
440 }
441 }
442 return Ok(response);
443 }
444 JsonRpcMessage::Request(_request) => {
445 continue;
448 }
449 }
450 }
451 }
452}
453
454impl std::fmt::Debug for TestClient {
455 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
456 f.debug_struct("TestClient")
457 .field("client_info", &self.client_info)
458 .field("initialized", &self.initialized)
459 .field("server_info", &self.server_info)
460 .finish_non_exhaustive()
461 }
462}
463
#[cfg(test)]
mod tests {
    use super::*;
    use fastmcp_transport::memory::create_memory_transport_pair;

    /// A freshly constructed client must not report itself as initialized.
    #[test]
    fn test_client_creation() {
        let (local, _peer) = create_memory_transport_pair();
        let fresh = TestClient::new(local);
        assert!(!fresh.is_initialized());
    }

    /// `with_client_info` overrides both the name and the version.
    #[test]
    fn test_client_with_info() {
        let (local, _peer) = create_memory_transport_pair();
        let customised = TestClient::new(local).with_client_info("my-client", "2.0.0");
        let info = &customised.client_info;
        assert_eq!(info.name.as_str(), "my-client");
        assert_eq!(info.version.as_str(), "2.0.0");
    }

    /// Typed requests are rejected until the handshake has run.
    #[test]
    fn test_not_initialized_error() {
        let (local, _peer) = create_memory_transport_pair();
        let mut pending = TestClient::new(local);
        assert!(matches!(pending.list_tools(), Err(_)));
    }
}