1use std::collections::HashMap;
7use std::sync::atomic::{AtomicU64, Ordering};
8
9use asupersync::Cx;
10use fastmcp_core::{McpError, McpResult};
11use fastmcp_protocol::{
12 CallToolParams, CallToolResult, ClientCapabilities, ClientInfo, Content, GetPromptParams,
13 GetPromptResult, InitializeParams, InitializeResult, JsonRpcMessage, JsonRpcRequest,
14 JsonRpcResponse, ListPromptsParams, ListPromptsResult, ListResourceTemplatesParams,
15 ListResourceTemplatesResult, ListResourcesParams, ListResourcesResult, ListToolsParams,
16 ListToolsResult, PROTOCOL_VERSION, Prompt, PromptMessage, ReadResourceParams,
17 ReadResourceResult, RequestId, Resource, ResourceContent, ResourceTemplate, ServerCapabilities,
18 ServerInfo, Tool,
19};
20use fastmcp_transport::Transport;
21use fastmcp_transport::memory::MemoryTransport;
22
/// Synchronous MCP client for driving a server from tests over a [`MemoryTransport`].
///
/// Requests are sent as JSON-RPC messages and the client then receives messages until
/// the response with the matching id arrives. Apart from the raw-request escape hatch,
/// the request helpers return an error unless [`TestClient::initialize`] has completed.
pub struct TestClient {
    /// Transport over which all JSON-RPC traffic is sent and received.
    transport: MemoryTransport,
    /// Context passed to every transport `send`/`recv` call.
    cx: Cx,
    /// Name/version this client reports in the `initialize` request.
    client_info: ClientInfo,
    /// Capabilities this client advertises in the `initialize` request.
    capabilities: ClientCapabilities,
    /// Server info taken from the `initialize` response; `None` until then.
    server_info: Option<ServerInfo>,
    /// Server capabilities taken from the `initialize` response; `None` until then.
    server_capabilities: Option<ServerCapabilities>,
    /// Protocol version returned by the server in the `initialize` response.
    protocol_version: Option<String>,
    /// Next JSON-RPC request id; starts at 1 and is incremented once per request.
    next_id: AtomicU64,
    /// Set to `true` only after the `initialize` request and the follow-up
    /// notification have both been sent successfully.
    initialized: bool,
}
67
68impl TestClient {
69 #[must_use]
78 pub fn new(transport: MemoryTransport) -> Self {
79 Self {
80 transport,
81 cx: Cx::for_testing(),
82 client_info: ClientInfo {
83 name: "test-client".to_owned(),
84 version: "1.0.0".to_owned(),
85 },
86 capabilities: ClientCapabilities::default(),
87 server_info: None,
88 server_capabilities: None,
89 protocol_version: None,
90 next_id: AtomicU64::new(1),
91 initialized: false,
92 }
93 }
94
95 #[must_use]
97 pub fn with_cx(transport: MemoryTransport, cx: Cx) -> Self {
98 Self {
99 transport,
100 cx,
101 client_info: ClientInfo {
102 name: "test-client".to_owned(),
103 version: "1.0.0".to_owned(),
104 },
105 capabilities: ClientCapabilities::default(),
106 server_info: None,
107 server_capabilities: None,
108 protocol_version: None,
109 next_id: AtomicU64::new(1),
110 initialized: false,
111 }
112 }
113
114 #[must_use]
116 pub fn with_client_info(mut self, name: impl Into<String>, version: impl Into<String>) -> Self {
117 self.client_info = ClientInfo {
118 name: name.into(),
119 version: version.into(),
120 };
121 self
122 }
123
124 #[must_use]
126 pub fn with_capabilities(mut self, capabilities: ClientCapabilities) -> Self {
127 self.capabilities = capabilities;
128 self
129 }
130
131 pub fn initialize(&mut self) -> McpResult<InitializeResult> {
139 let params = InitializeParams {
140 protocol_version: PROTOCOL_VERSION.to_string(),
141 capabilities: self.capabilities.clone(),
142 client_info: self.client_info.clone(),
143 };
144
145 let result: InitializeResult = self.send_request("initialize", params)?;
146
147 self.server_info = Some(result.server_info.clone());
149 self.server_capabilities = Some(result.capabilities.clone());
150 self.protocol_version = Some(result.protocol_version.clone());
151
152 self.send_notification("initialized", serde_json::json!({}))?;
154
155 self.initialized = true;
156 Ok(result)
157 }
158
159 #[must_use]
161 pub fn is_initialized(&self) -> bool {
162 self.initialized
163 }
164
165 #[must_use]
167 pub fn server_info(&self) -> Option<&ServerInfo> {
168 self.server_info.as_ref()
169 }
170
171 #[must_use]
173 pub fn server_capabilities(&self) -> Option<&ServerCapabilities> {
174 self.server_capabilities.as_ref()
175 }
176
177 #[must_use]
179 pub fn protocol_version(&self) -> Option<&str> {
180 self.protocol_version.as_deref()
181 }
182
183 pub fn list_tools(&mut self) -> McpResult<Vec<Tool>> {
189 self.ensure_initialized()?;
190 let params = ListToolsParams::default();
191 let result: ListToolsResult = self.send_request("tools/list", params)?;
192 Ok(result.tools)
193 }
194
195 pub fn call_tool(
201 &mut self,
202 name: &str,
203 arguments: serde_json::Value,
204 ) -> McpResult<Vec<Content>> {
205 self.ensure_initialized()?;
206 let params = CallToolParams {
207 name: name.to_string(),
208 arguments: Some(arguments),
209 meta: None,
210 };
211 let result: CallToolResult = self.send_request("tools/call", params)?;
212
213 if result.is_error {
214 let error_msg = result
215 .content
216 .first()
217 .and_then(|c| match c {
218 Content::Text { text } => Some(text.clone()),
219 _ => None,
220 })
221 .unwrap_or_else(|| "Tool execution failed".to_string());
222 return Err(McpError::tool_error(error_msg));
223 }
224
225 Ok(result.content)
226 }
227
228 pub fn list_resources(&mut self) -> McpResult<Vec<Resource>> {
234 self.ensure_initialized()?;
235 let params = ListResourcesParams::default();
236 let result: ListResourcesResult = self.send_request("resources/list", params)?;
237 Ok(result.resources)
238 }
239
240 pub fn list_resource_templates(&mut self) -> McpResult<Vec<ResourceTemplate>> {
246 self.ensure_initialized()?;
247 let params = ListResourceTemplatesParams::default();
248 let result: ListResourceTemplatesResult =
249 self.send_request("resources/templates/list", params)?;
250 Ok(result.resource_templates)
251 }
252
253 pub fn read_resource(&mut self, uri: &str) -> McpResult<Vec<ResourceContent>> {
259 self.ensure_initialized()?;
260 let params = ReadResourceParams {
261 uri: uri.to_string(),
262 meta: None,
263 };
264 let result: ReadResourceResult = self.send_request("resources/read", params)?;
265 Ok(result.contents)
266 }
267
268 pub fn list_prompts(&mut self) -> McpResult<Vec<Prompt>> {
274 self.ensure_initialized()?;
275 let params = ListPromptsParams::default();
276 let result: ListPromptsResult = self.send_request("prompts/list", params)?;
277 Ok(result.prompts)
278 }
279
280 pub fn get_prompt(
286 &mut self,
287 name: &str,
288 arguments: HashMap<String, String>,
289 ) -> McpResult<Vec<PromptMessage>> {
290 self.ensure_initialized()?;
291 let params = GetPromptParams {
292 name: name.to_string(),
293 arguments: if arguments.is_empty() {
294 None
295 } else {
296 Some(arguments)
297 },
298 meta: None,
299 };
300 let result: GetPromptResult = self.send_request("prompts/get", params)?;
301 Ok(result.messages)
302 }
303
304 pub fn send_raw_request(
312 &mut self,
313 method: &str,
314 params: serde_json::Value,
315 ) -> McpResult<serde_json::Value> {
316 let id = self.next_request_id();
317 #[allow(clippy::cast_possible_wrap)]
318 let request = JsonRpcRequest::new(method, Some(params), id as i64);
319
320 self.transport
321 .send(&self.cx, &JsonRpcMessage::Request(request))
322 .map_err(|e| McpError::internal_error(format!("Transport error: {e:?}")))?;
323
324 #[allow(clippy::cast_possible_wrap)]
325 let response = self.recv_response(&RequestId::Number(id as i64))?;
326
327 if let Some(error) = response.error {
328 return Err(McpError::new(
329 fastmcp_core::McpErrorCode::from(error.code),
330 error.message,
331 ));
332 }
333
334 response
335 .result
336 .ok_or_else(|| McpError::internal_error("No result in response"))
337 }
338
339 pub fn close(&mut self) {
341 let _ = self.transport.close();
342 }
343
344 #[must_use]
346 pub fn transport(&self) -> &MemoryTransport {
347 &self.transport
348 }
349
350 pub fn transport_mut(&mut self) -> &mut MemoryTransport {
352 &mut self.transport
353 }
354
355 pub fn send_request_json(
364 &mut self,
365 method: &str,
366 params_value: serde_json::Value,
367 ) -> McpResult<serde_json::Value> {
368 self.ensure_initialized()?;
369
370 let id = self.next_request_id();
371 #[allow(clippy::cast_possible_wrap)]
372 let request_id = RequestId::Number(id as i64);
373 #[allow(clippy::cast_possible_wrap)]
374 let request = JsonRpcRequest::new(method, Some(params_value), id as i64);
375
376 self.transport
377 .send(&self.cx, &JsonRpcMessage::Request(request))
378 .map_err(|e| McpError::internal_error(format!("Transport error: {e:?}")))?;
379
380 let response = self.recv_response(&request_id)?;
381
382 if let Some(error) = response.error {
383 return Err(McpError::new(
384 fastmcp_core::McpErrorCode::from(error.code),
385 error.message,
386 ));
387 }
388
389 response
390 .result
391 .ok_or_else(|| McpError::internal_error("No result in response"))
392 }
393
394 fn ensure_initialized(&self) -> McpResult<()> {
397 if !self.initialized {
398 return Err(McpError::internal_error(
399 "Client not initialized. Call initialize() first.",
400 ));
401 }
402 Ok(())
403 }
404
405 fn next_request_id(&self) -> u64 {
406 self.next_id.fetch_add(1, Ordering::SeqCst)
407 }
408
409 fn send_request<P: serde::Serialize, R: serde::de::DeserializeOwned>(
410 &mut self,
411 method: &str,
412 params: P,
413 ) -> McpResult<R> {
414 let id = self.next_request_id();
415 let params_value = serde_json::to_value(params)
416 .map_err(|e| McpError::internal_error(format!("Failed to serialize params: {e}")))?;
417
418 #[allow(clippy::cast_possible_wrap)]
419 let request_id = RequestId::Number(id as i64);
420 #[allow(clippy::cast_possible_wrap)]
421 let request = JsonRpcRequest::new(method, Some(params_value), id as i64);
422
423 self.transport
424 .send(&self.cx, &JsonRpcMessage::Request(request))
425 .map_err(|e| McpError::internal_error(format!("Transport error: {e:?}")))?;
426
427 let response = self.recv_response(&request_id)?;
428
429 if let Some(error) = response.error {
430 return Err(McpError::new(
431 fastmcp_core::McpErrorCode::from(error.code),
432 error.message,
433 ));
434 }
435
436 let result = response
437 .result
438 .ok_or_else(|| McpError::internal_error("No result in response"))?;
439
440 serde_json::from_value(result)
441 .map_err(|e| McpError::internal_error(format!("Failed to deserialize response: {e}")))
442 }
443
444 fn send_notification<P: serde::Serialize>(&mut self, method: &str, params: P) -> McpResult<()> {
445 let params_value = serde_json::to_value(params)
446 .map_err(|e| McpError::internal_error(format!("Failed to serialize params: {e}")))?;
447
448 let request = JsonRpcRequest {
449 jsonrpc: std::borrow::Cow::Borrowed(fastmcp_protocol::JSONRPC_VERSION),
450 method: method.to_string(),
451 params: Some(params_value),
452 id: None,
453 };
454
455 self.transport
456 .send(&self.cx, &JsonRpcMessage::Request(request))
457 .map_err(|e| McpError::internal_error(format!("Transport error: {e:?}")))?;
458
459 Ok(())
460 }
461
462 fn recv_response(
463 &mut self,
464 expected_id: &RequestId,
465 ) -> McpResult<fastmcp_protocol::JsonRpcResponse> {
466 loop {
467 let message = self
468 .transport
469 .recv(&self.cx)
470 .map_err(|e| McpError::internal_error(format!("Transport error: {e:?}")))?;
471
472 match message {
473 JsonRpcMessage::Response(response) => {
474 if let Some(ref id) = response.id {
475 if id != expected_id {
476 continue;
477 }
478 }
479 return Ok(response);
480 }
481 JsonRpcMessage::Request(request) => {
482 let Some(id) = request.id.clone() else {
484 continue;
485 };
486
487 let err = McpError::method_not_found(&request.method);
490 let response = JsonRpcResponse::error(Some(id), err.into());
491 self.transport
492 .send(&self.cx, &JsonRpcMessage::Response(response))
493 .map_err(|e| McpError::internal_error(format!("Transport error: {e:?}")))?;
494
495 continue;
496 }
497 }
498 }
499 }
500}
501
502impl std::fmt::Debug for TestClient {
503 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
504 f.debug_struct("TestClient")
505 .field("client_info", &self.client_info)
506 .field("initialized", &self.initialized)
507 .field("server_info", &self.server_info)
508 .finish_non_exhaustive()
509 }
510}
511
#[cfg(test)]
mod tests {
    use super::*;
    use fastmcp_transport::memory::create_memory_transport_pair;

    /// Builds a client on one end of a fresh in-memory pair. The peer end is returned
    /// so the test can keep it alive until the test finishes.
    fn client_and_peer() -> (TestClient, impl Sized) {
        let (near, far) = create_memory_transport_pair();
        (TestClient::new(near), far)
    }

    #[test]
    fn test_client_creation() {
        let (client, _peer) = client_and_peer();
        assert!(!client.is_initialized());
    }

    #[test]
    fn test_client_with_info() {
        let (near, _far) = create_memory_transport_pair();
        let client = TestClient::new(near).with_client_info("my-client", "2.0.0");
        let info = &client.client_info;
        assert_eq!(info.name.as_str(), "my-client");
        assert_eq!(info.version.as_str(), "2.0.0");
    }

    #[test]
    fn test_not_initialized_error() {
        let (mut client, _peer) = client_and_peer();
        assert!(client.list_tools().is_err());
    }

    #[test]
    fn with_cx_sets_custom_cx() {
        let (near, _far) = create_memory_transport_pair();
        let client = TestClient::with_cx(near, Cx::for_testing());
        assert!(!client.is_initialized());
    }

    #[test]
    fn with_capabilities_sets_capabilities() {
        let (near, _far) = create_memory_transport_pair();
        let client = TestClient::new(near).with_capabilities(ClientCapabilities {
            sampling: Some(fastmcp_protocol::SamplingCapability {}),
            ..Default::default()
        });
        assert!(client.capabilities.sampling.is_some());
    }

    #[test]
    fn pre_init_getters_return_none() {
        let (client, _peer) = client_and_peer();
        assert!(client.server_info().is_none());
        assert!(client.server_capabilities().is_none());
        assert!(client.protocol_version().is_none());
    }

    #[test]
    fn debug_output_includes_key_fields() {
        let (client, _peer) = client_and_peer();
        let rendered = format!("{client:?}");
        for needle in ["TestClient", "test-client", "initialized"].iter() {
            assert!(rendered.contains(*needle), "missing {needle:?} in {rendered}");
        }
    }

    #[test]
    fn transport_accessors() {
        let (mut client, _peer) = client_and_peer();
        let _: &MemoryTransport = client.transport();
        let _: &mut MemoryTransport = client.transport_mut();
    }

    #[test]
    fn close_does_not_panic() {
        let (mut client, _peer) = client_and_peer();
        client.close();
    }

    #[test]
    fn request_id_auto_increments() {
        let (client, _peer) = client_and_peer();
        let ids: Vec<u64> = (0..3).map(|_| client.next_request_id()).collect();
        assert_eq!(ids, vec![1, 2, 3]);
    }

    #[test]
    fn ensure_initialized_error_message() {
        let (mut client, _peer) = client_and_peer();
        let msg = client.list_tools().unwrap_err().to_string();
        assert!(msg.contains("not initialized"), "error was: {msg}");
    }
}