1#![forbid(unsafe_code)]
36#![allow(dead_code)]
37
38mod builder;
39pub mod mcp_config;
40mod session;
41
42pub use builder::ClientBuilder;
43pub use session::ClientSession;
44
45use std::process::{Child, ChildStdin, ChildStdout, Command, Stdio};
46use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
47use std::time::{Duration, Instant};
48
49use asupersync::Cx;
50use fastmcp_core::{McpError, McpResult};
51use fastmcp_protocol::{
52 CallToolParams, CallToolResult, CancelTaskParams, CancelTaskResult, CancelledParams,
53 ClientCapabilities, ClientInfo, Content, GetPromptParams, GetPromptResult, GetTaskParams,
54 GetTaskResult, InitializeParams, InitializeResult, JsonRpcMessage, JsonRpcRequest,
55 JsonRpcResponse, ListPromptsParams, ListPromptsResult, ListResourceTemplatesParams,
56 ListResourceTemplatesResult, ListResourcesParams, ListResourcesResult, ListTasksParams,
57 ListTasksResult, ListToolsParams, ListToolsResult, LogLevel, LogMessageParams,
58 PROTOCOL_VERSION, ProgressMarker, Prompt, PromptMessage, ReadResourceParams,
59 ReadResourceResult, RequestId, RequestMeta, Resource, ResourceContent, ResourceTemplate,
60 ServerCapabilities, ServerInfo, SetLogLevelParams, SubmitTaskParams, SubmitTaskResult, TaskId,
61 TaskInfo, TaskResult, TaskStatus, Tool,
62};
63
/// Callback invoked for each matching progress notification received while a
/// request is in flight.
///
/// The arguments are, in order: the reported progress value, the optional
/// total, and an optional human-readable message.
pub type ProgressCallback<'a> = &'a mut dyn FnMut(f64, Option<f64>, Option<&str>);
68use fastmcp_transport::{StdioTransport, Transport, TransportError};
69
/// Client-side view of the parameters carried by a `notifications/progress`
/// message.
#[derive(Debug, serde::Deserialize)]
struct ClientProgressParams {
    /// Marker tying the update to the request that asked for progress
    /// (deserialized from the `progressToken` field).
    #[serde(rename = "progressTo\x6ben")]
    marker: ProgressMarker,
    /// Progress value reported by the server.
    progress: f64,
    /// Optional total, when the server reports one.
    total: Option<f64>,
    /// Optional human-readable status message.
    message: Option<String>,
}
78
79fn method_not_found_response(request: &JsonRpcRequest) -> Option<JsonRpcMessage> {
80 let id = request.id.clone()?;
81 let error = McpError::method_not_found(&request.method);
82 let response = JsonRpcResponse::error(Some(id), error.into());
83 Some(JsonRpcMessage::Response(response))
84}
85
/// MCP client that talks to a server running as a child process over stdio.
pub struct Client {
    /// Handle to the spawned server subprocess.
    child: Child,
    /// Transport that reads from the child's stdout and writes to its stdin.
    transport: StdioTransport<ChildStdout, ChildStdin>,
    /// Context handed to every transport send/receive call.
    cx: Cx,
    /// Client/server info, capabilities and protocol version for this connection.
    session: ClientSession,
    /// Counter used to allocate JSON-RPC request ids.
    next_id: AtomicU64,
    /// Response timeout in milliseconds; `0` disables the deadline check.
    timeout_ms: u64,
    /// Set to `true` only by `from_parts_uninitialized`.
    // NOTE(review): never read in this file; presumably marks deferred initialization — confirm.
    #[allow(dead_code)]
    auto_initialize: bool,
    /// Set once the initialize handshake has completed.
    initialized: AtomicBool,
}
109
110impl Client {
111 pub fn stdio(command: &str, args: &[&str]) -> McpResult<Self> {
122 Self::stdio_with_cx(command, args, Cx::for_request())
123 }
124
125 pub fn stdio_with_cx(command: &str, args: &[&str], cx: Cx) -> McpResult<Self> {
127 let mut child = Command::new(command)
129 .args(args)
130 .stdin(Stdio::piped())
131 .stdout(Stdio::piped())
132 .stderr(Stdio::inherit())
133 .spawn()
134 .map_err(|e| McpError::internal_error(format!("Failed to spawn subprocess: {e}")))?;
135
136 let stdin = child
138 .stdin
139 .take()
140 .ok_or_else(|| McpError::internal_error("Failed to get subprocess stdin"))?;
141 let stdout = child
142 .stdout
143 .take()
144 .ok_or_else(|| McpError::internal_error("Failed to get subprocess stdout"))?;
145
146 let transport = StdioTransport::new(stdout, stdin);
148
149 let client_info = ClientInfo {
151 name: "fastmcp-client".to_owned(),
152 version: env!("CARGO_PKG_VERSION").to_owned(),
153 };
154 let client_capabilities = ClientCapabilities::default();
155
156 let mut client = Self {
158 child,
159 transport,
160 cx,
161 session: ClientSession::new(
162 client_info.clone(),
163 client_capabilities.clone(),
164 ServerInfo {
165 name: String::new(),
166 version: String::new(),
167 },
168 ServerCapabilities::default(),
169 String::new(),
170 ),
171 next_id: AtomicU64::new(1),
172 timeout_ms: 30_000, auto_initialize: false,
174 initialized: AtomicBool::new(false),
175 };
176
177 let init_result = client.initialize(client_info, client_capabilities)?;
179
180 client.session = ClientSession::new(
182 client.session.client_info().clone(),
183 client.session.client_capabilities().clone(),
184 init_result.server_info,
185 init_result.capabilities,
186 init_result.protocol_version,
187 );
188
189 client.send_notification("initialized", serde_json::json!({}))?;
191
192 client.initialized.store(true, Ordering::SeqCst);
194
195 Ok(client)
196 }
197
/// Creates a new [`ClientBuilder`] via `ClientBuilder::new()`.
#[must_use]
pub fn builder() -> ClientBuilder {
    ClientBuilder::new()
}
203
204 pub(crate) fn from_parts(
208 child: Child,
209 transport: StdioTransport<ChildStdout, ChildStdin>,
210 cx: Cx,
211 session: ClientSession,
212 timeout_ms: u64,
213 ) -> Self {
214 Self {
215 child,
216 transport,
217 cx,
218 session,
219 next_id: AtomicU64::new(2), timeout_ms,
221 auto_initialize: false,
222 initialized: AtomicBool::new(true), }
224 }
225
226 pub(crate) fn from_parts_uninitialized(
230 child: Child,
231 transport: StdioTransport<ChildStdout, ChildStdin>,
232 cx: Cx,
233 session: ClientSession,
234 timeout_ms: u64,
235 ) -> Self {
236 Self {
237 child,
238 transport,
239 cx,
240 session,
241 next_id: AtomicU64::new(1), timeout_ms,
243 auto_initialize: true,
244 initialized: AtomicBool::new(false),
245 }
246 }
247
248 pub fn ensure_initialized(&mut self) -> McpResult<()> {
260 if self.initialized.load(Ordering::SeqCst) {
262 return Ok(());
263 }
264
265 let client_info = self.session.client_info().clone();
267 let capabilities = self.session.client_capabilities().clone();
268 let init_result = self.initialize(client_info, capabilities)?;
269
270 self.session = ClientSession::new(
272 self.session.client_info().clone(),
273 self.session.client_capabilities().clone(),
274 init_result.server_info,
275 init_result.capabilities,
276 init_result.protocol_version,
277 );
278
279 self.send_notification("initialized", serde_json::json!({}))?;
281
282 self.initialized.store(true, Ordering::SeqCst);
284
285 Ok(())
286 }
287
/// Returns `true` once the initialize handshake has completed.
#[must_use]
pub fn is_initialized(&self) -> bool {
    self.initialized.load(Ordering::SeqCst)
}
293
/// Returns the server info stored in the current session.
#[must_use]
pub fn server_info(&self) -> &ServerInfo {
    self.session.server_info()
}
299
/// Returns the server capabilities stored in the current session.
#[must_use]
pub fn server_capabilities(&self) -> &ServerCapabilities {
    self.session.server_capabilities()
}
305
/// Returns the protocol version stored in the current session.
#[must_use]
pub fn protocol_version(&self) -> &str {
    self.session.protocol_version()
}
311
/// Reserves and returns the next request id, advancing the counter by one.
fn next_request_id(&self) -> u64 {
    self.next_id.fetch_add(1, Ordering::SeqCst)
}
316
317 fn send_request<P: serde::Serialize, R: serde::de::DeserializeOwned>(
319 &mut self,
320 method: &str,
321 params: P,
322 ) -> McpResult<R> {
323 let id = self.next_request_id();
324 let params_value = serde_json::to_value(params)
325 .map_err(|e| McpError::internal_error(format!("Failed to serialize params: {e}")))?;
326
327 #[allow(clippy::cast_possible_wrap)]
328 let (request_id, request) = {
329 let id_i64 = id as i64;
330 (
331 RequestId::Number(id_i64),
332 JsonRpcRequest::new(method, Some(params_value), id_i64),
333 )
334 };
335
336 self.transport
338 .send(&self.cx, &JsonRpcMessage::Request(request))
339 .map_err(transport_error_to_mcp)?;
340
341 let response = self.recv_response(&request_id)?;
343
344 if let Some(error) = response.error {
346 return Err(McpError::new(
347 fastmcp_core::McpErrorCode::from(error.code),
348 error.message,
349 ));
350 }
351
352 let result = response
354 .result
355 .ok_or_else(|| McpError::internal_error("No result in response"))?;
356
357 serde_json::from_value(result)
358 .map_err(|e| McpError::internal_error(format!("Failed to deserialize response: {e}")))
359 }
360
361 fn send_notification<P: serde::Serialize>(&mut self, method: &str, params: P) -> McpResult<()> {
363 let params_value = serde_json::to_value(params)
364 .map_err(|e| McpError::internal_error(format!("Failed to serialize params: {e}")))?;
365
366 let request = JsonRpcRequest {
368 jsonrpc: std::borrow::Cow::Borrowed(fastmcp_protocol::JSONRPC_VERSION),
369 method: method.to_string(),
370 params: Some(params_value),
371 id: None,
372 };
373
374 self.transport
375 .send(&self.cx, &JsonRpcMessage::Request(request))
376 .map_err(transport_error_to_mcp)?;
377
378 Ok(())
379 }
380
381 pub fn cancel_request(
390 &mut self,
391 request_id: impl Into<RequestId>,
392 reason: Option<String>,
393 await_cleanup: bool,
394 ) -> McpResult<()> {
395 let params = CancelledParams {
396 request_id: request_id.into(),
397 reason,
398 await_cleanup: if await_cleanup { Some(true) } else { None },
399 };
400 self.send_notification("notifications/cancelled", params)
401 }
402
403 fn recv_response(
405 &mut self,
406 expected_id: &RequestId,
407 ) -> McpResult<fastmcp_protocol::JsonRpcResponse> {
408 let deadline = if self.timeout_ms > 0 {
410 Some(Instant::now() + Duration::from_millis(self.timeout_ms))
411 } else {
412 None
413 };
414
415 loop {
416 if let Some(deadline) = deadline {
418 if Instant::now() >= deadline {
419 return Err(McpError::internal_error("Request timed out"));
420 }
421 }
422
423 let message = self
424 .transport
425 .recv(&self.cx)
426 .map_err(transport_error_to_mcp)?;
427
428 match message {
429 JsonRpcMessage::Response(response) => {
430 if let Some(ref id) = response.id {
432 if id != expected_id {
433 continue;
436 }
437 }
438 return Ok(response);
439 }
440 JsonRpcMessage::Request(request) => {
441 if request.method == "notifications/message" {
443 if let Some(params) = request.params.as_ref() {
444 if let Ok(message) =
445 serde_json::from_value::<LogMessageParams>(params.clone())
446 {
447 self.emit_log_message(message);
448 }
449 }
450 }
451
452 if let Some(response) = method_not_found_response(&request) {
453 self.transport
454 .send(&self.cx, &response)
455 .map_err(transport_error_to_mcp)?;
456 }
457 }
458 }
459 }
460 }
461
462 fn initialize(
464 &mut self,
465 client_info: ClientInfo,
466 capabilities: ClientCapabilities,
467 ) -> McpResult<InitializeResult> {
468 let params = InitializeParams {
469 protocol_version: PROTOCOL_VERSION.to_string(),
470 capabilities,
471 client_info,
472 };
473
474 self.send_request("initialize", params)
475 }
476
477 pub fn list_tools(&mut self) -> McpResult<Vec<Tool>> {
483 self.ensure_initialized()?;
484 let mut all = Vec::new();
485 let mut cursor: Option<String> = None;
486 let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
487 let mut pages: usize = 0;
488
489 loop {
490 pages += 1;
491 if pages > 10_000 {
492 return Err(McpError::internal_error(
493 "Pagination exceeded 10,000 pages (tools/list)".to_string(),
494 ));
495 }
496 if let Some(cur) = cursor.as_ref() {
497 if !seen.insert(cur.clone()) {
498 return Err(McpError::internal_error(format!(
499 "Pagination cursor repeated (tools/list): {cur}"
500 )));
501 }
502 }
503 let mut params = ListToolsParams::default();
504 params.cursor = cursor.clone();
505 let result: ListToolsResult = self.send_request("tools/list", params)?;
506 all.extend(result.tools);
507 cursor = result.next_cursor;
508 if cursor.is_none() {
509 break;
510 }
511 }
512
513 Ok(all)
514 }
515
516 pub fn call_tool(
522 &mut self,
523 name: &str,
524 arguments: serde_json::Value,
525 ) -> McpResult<Vec<Content>> {
526 self.ensure_initialized()?;
527 let params = CallToolParams {
528 name: name.to_string(),
529 arguments: Some(arguments),
530 meta: None,
531 };
532 let result: CallToolResult = self.send_request("tools/call", params)?;
533
534 if result.is_error {
535 let error_msg = result
537 .content
538 .first()
539 .and_then(|c| match c {
540 Content::Text { text } => Some(text.clone()),
541 _ => None,
542 })
543 .unwrap_or_else(|| "Tool execution failed".to_string());
544 return Err(McpError::tool_error(error_msg));
545 }
546
547 Ok(result.content)
548 }
549
550 pub fn call_tool_with_progress(
565 &mut self,
566 name: &str,
567 arguments: serde_json::Value,
568 on_progress: ProgressCallback<'_>,
569 ) -> McpResult<Vec<Content>> {
570 self.ensure_initialized()?;
571 let request_id = self.next_request_id();
573 #[allow(clippy::cast_possible_wrap)]
574 let progress_marker = ProgressMarker::Number(request_id as i64);
575
576 let params = CallToolParams {
577 name: name.to_string(),
578 arguments: Some(arguments),
579 meta: Some(RequestMeta {
580 progress_marker: Some(progress_marker.clone()),
581 }),
582 };
583
584 let result: CallToolResult = self.send_request_with_progress(
585 "tools/call",
586 params,
587 request_id,
588 &progress_marker,
589 on_progress,
590 )?;
591
592 if result.is_error {
593 let error_msg = result
595 .content
596 .first()
597 .and_then(|c| match c {
598 Content::Text { text } => Some(text.clone()),
599 _ => None,
600 })
601 .unwrap_or_else(|| "Tool execution failed".to_string());
602 return Err(McpError::tool_error(error_msg));
603 }
604
605 Ok(result.content)
606 }
607
608 fn send_request_with_progress<P: serde::Serialize, R: serde::de::DeserializeOwned>(
610 &mut self,
611 method: &str,
612 params: P,
613 request_id: u64,
614 expected_marker: &ProgressMarker,
615 on_progress: ProgressCallback<'_>,
616 ) -> McpResult<R> {
617 let params_value = serde_json::to_value(params)
618 .map_err(|e| McpError::internal_error(format!("Failed to serialize params: {e}")))?;
619
620 #[allow(clippy::cast_possible_wrap)]
621 let request = JsonRpcRequest::new(method, Some(params_value), request_id as i64);
622
623 self.transport
625 .send(&self.cx, &JsonRpcMessage::Request(request))
626 .map_err(transport_error_to_mcp)?;
627
628 let response = self.recv_response_with_progress(expected_marker, on_progress)?;
630
631 if let Some(error) = response.error {
633 return Err(McpError::new(
634 fastmcp_core::McpErrorCode::from(error.code),
635 error.message,
636 ));
637 }
638
639 let result = response
641 .result
642 .ok_or_else(|| McpError::internal_error("No result in response"))?;
643
644 serde_json::from_value(result)
645 .map_err(|e| McpError::internal_error(format!("Failed to deserialize response: {e}")))
646 }
647
648 fn recv_response_with_progress(
650 &mut self,
651 expected_marker: &ProgressMarker,
652 on_progress: ProgressCallback<'_>,
653 ) -> McpResult<fastmcp_protocol::JsonRpcResponse> {
654 let deadline = if self.timeout_ms > 0 {
656 Some(Instant::now() + Duration::from_millis(self.timeout_ms))
657 } else {
658 None
659 };
660
661 loop {
662 if let Some(deadline) = deadline {
664 if Instant::now() >= deadline {
665 return Err(McpError::internal_error("Request timed out"));
666 }
667 }
668
669 let message = self
670 .transport
671 .recv(&self.cx)
672 .map_err(transport_error_to_mcp)?;
673
674 match message {
675 JsonRpcMessage::Response(response) => return Ok(response),
676 JsonRpcMessage::Request(request) => {
677 if request.method == "notifications/progress" {
679 if let Some(params) = request.params.as_ref() {
680 if let Ok(progress) =
681 serde_json::from_value::<ClientProgressParams>(params.clone())
682 {
683 if progress.marker == *expected_marker {
685 on_progress(
686 progress.progress,
687 progress.total,
688 progress.message.as_deref(),
689 );
690 }
691 }
692 }
693 } else if request.method == "notifications/message" {
694 if let Some(params) = request.params.as_ref() {
695 if let Ok(message) =
696 serde_json::from_value::<LogMessageParams>(params.clone())
697 {
698 self.emit_log_message(message);
699 }
700 }
701 }
702
703 if let Some(response) = method_not_found_response(&request) {
704 self.transport
705 .send(&self.cx, &response)
706 .map_err(transport_error_to_mcp)?;
707 }
708 }
710 }
711 }
712 }
713
714 fn emit_log_message(&self, message: LogMessageParams) {
715 let level = match message.level {
716 LogLevel::Debug => log::Level::Debug,
717 LogLevel::Info => log::Level::Info,
718 LogLevel::Warning => log::Level::Warn,
719 LogLevel::Error => log::Level::Error,
720 };
721
722 let target = message.logger.as_deref().unwrap_or("fastmcp_rust::remote");
723 let text = match message.data {
724 serde_json::Value::String(s) => s,
725 other => other.to_string(),
726 };
727
728 log::log!(target: target, level, "{text}");
729 }
730
731 pub fn list_resources(&mut self) -> McpResult<Vec<Resource>> {
737 self.ensure_initialized()?;
738 let mut all = Vec::new();
739 let mut cursor: Option<String> = None;
740 let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
741 let mut pages: usize = 0;
742
743 loop {
744 pages += 1;
745 if pages > 10_000 {
746 return Err(McpError::internal_error(
747 "Pagination exceeded 10,000 pages (resources/list)".to_string(),
748 ));
749 }
750 if let Some(cur) = cursor.as_ref() {
751 if !seen.insert(cur.clone()) {
752 return Err(McpError::internal_error(format!(
753 "Pagination cursor repeated (resources/list): {cur}"
754 )));
755 }
756 }
757 let mut params = ListResourcesParams::default();
758 params.cursor = cursor.clone();
759 let result: ListResourcesResult = self.send_request("resources/list", params)?;
760 all.extend(result.resources);
761 cursor = result.next_cursor;
762 if cursor.is_none() {
763 break;
764 }
765 }
766
767 Ok(all)
768 }
769
770 pub fn list_resource_templates(&mut self) -> McpResult<Vec<ResourceTemplate>> {
776 self.ensure_initialized()?;
777 let mut all = Vec::new();
778 let mut cursor: Option<String> = None;
779 let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
780 let mut pages: usize = 0;
781
782 loop {
783 pages += 1;
784 if pages > 10_000 {
785 return Err(McpError::internal_error(
786 "Pagination exceeded 10,000 pages (resources/templates/list)".to_string(),
787 ));
788 }
789 if let Some(cur) = cursor.as_ref() {
790 if !seen.insert(cur.clone()) {
791 return Err(McpError::internal_error(format!(
792 "Pagination cursor repeated (resources/templates/list): {cur}"
793 )));
794 }
795 }
796 let mut params = ListResourceTemplatesParams::default();
797 params.cursor = cursor.clone();
798 let result: ListResourceTemplatesResult =
799 self.send_request("resources/templates/list", params)?;
800 all.extend(result.resource_templates);
801 cursor = result.next_cursor;
802 if cursor.is_none() {
803 break;
804 }
805 }
806
807 Ok(all)
808 }
809
810 pub fn set_log_level(&mut self, level: LogLevel) -> McpResult<()> {
816 self.ensure_initialized()?;
817 let params = SetLogLevelParams { level };
818 let _: serde_json::Value = self.send_request("logging/setLevel", params)?;
819 Ok(())
820 }
821
822 pub fn read_resource(&mut self, uri: &str) -> McpResult<Vec<ResourceContent>> {
828 self.ensure_initialized()?;
829 let params = ReadResourceParams {
830 uri: uri.to_string(),
831 meta: None,
832 };
833 let result: ReadResourceResult = self.send_request("resources/read", params)?;
834 Ok(result.contents)
835 }
836
837 pub fn list_prompts(&mut self) -> McpResult<Vec<Prompt>> {
843 self.ensure_initialized()?;
844 let mut all = Vec::new();
845 let mut cursor: Option<String> = None;
846 let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
847 let mut pages: usize = 0;
848
849 loop {
850 pages += 1;
851 if pages > 10_000 {
852 return Err(McpError::internal_error(
853 "Pagination exceeded 10,000 pages (prompts/list)".to_string(),
854 ));
855 }
856 if let Some(cur) = cursor.as_ref() {
857 if !seen.insert(cur.clone()) {
858 return Err(McpError::internal_error(format!(
859 "Pagination cursor repeated (prompts/list): {cur}"
860 )));
861 }
862 }
863 let mut params = ListPromptsParams::default();
864 params.cursor = cursor.clone();
865 let result: ListPromptsResult = self.send_request("prompts/list", params)?;
866 all.extend(result.prompts);
867 cursor = result.next_cursor;
868 if cursor.is_none() {
869 break;
870 }
871 }
872
873 Ok(all)
874 }
875
876 pub fn get_prompt(
882 &mut self,
883 name: &str,
884 arguments: std::collections::HashMap<String, String>,
885 ) -> McpResult<Vec<PromptMessage>> {
886 self.ensure_initialized()?;
887 let params = GetPromptParams {
888 name: name.to_string(),
889 arguments: if arguments.is_empty() {
890 None
891 } else {
892 Some(arguments)
893 },
894 meta: None,
895 };
896 let result: GetPromptResult = self.send_request("prompts/get", params)?;
897 Ok(result.messages)
898 }
899
900 pub fn submit_task(
915 &mut self,
916 task_type: &str,
917 input: serde_json::Value,
918 ) -> McpResult<TaskInfo> {
919 self.ensure_initialized()?;
920 let params = SubmitTaskParams {
921 task_type: task_type.to_string(),
922 params: Some(input),
923 };
924 let result: SubmitTaskResult = self.send_request("tasks/submit", params)?;
925 Ok(result.task)
926 }
927
928 pub fn list_tasks(
939 &mut self,
940 status: Option<TaskStatus>,
941 cursor: Option<&str>,
942 limit: Option<u32>,
943 ) -> McpResult<ListTasksResult> {
944 self.ensure_initialized()?;
945 let params = ListTasksParams {
946 cursor: cursor.map(ToString::to_string),
947 limit,
948 status,
949 };
950 self.send_request("tasks/list", params)
951 }
952
953 pub fn list_tasks_all(&mut self, status: Option<TaskStatus>) -> McpResult<Vec<TaskInfo>> {
959 self.ensure_initialized()?;
960 let mut all = Vec::new();
961 let mut cursor: Option<String> = None;
962 let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
963 let mut pages: usize = 0;
964
965 loop {
966 pages += 1;
967 if pages > 10_000 {
968 return Err(McpError::internal_error(
969 "Pagination exceeded 10,000 pages (tasks/list)".to_string(),
970 ));
971 }
972 if let Some(cur) = cursor.as_ref() {
973 if !seen.insert(cur.clone()) {
974 return Err(McpError::internal_error(format!(
975 "Pagination cursor repeated (tasks/list): {cur}"
976 )));
977 }
978 }
979 let result = self.list_tasks(status, cursor.as_deref(), Some(200))?;
980 all.extend(result.tasks);
981 cursor = result.next_cursor;
982 if cursor.is_none() {
983 break;
984 }
985 }
986
987 Ok(all)
988 }
989
990 pub fn get_task(&mut self, task_id: &str) -> McpResult<GetTaskResult> {
1000 self.ensure_initialized()?;
1001 let params = GetTaskParams {
1002 id: TaskId::from_string(task_id),
1003 };
1004 self.send_request("tasks/get", params)
1005 }
1006
/// Cancels the task `task_id` without giving a reason.
///
/// Shorthand for [`Client::cancel_task_with_reason`] with `reason = None`.
///
/// # Errors
///
/// Propagates any error from [`Client::cancel_task_with_reason`].
pub fn cancel_task(&mut self, task_id: &str) -> McpResult<TaskInfo> {
    self.cancel_task_with_reason(task_id, None)
}
1019
1020 pub fn cancel_task_with_reason(
1031 &mut self,
1032 task_id: &str,
1033 reason: Option<&str>,
1034 ) -> McpResult<TaskInfo> {
1035 self.ensure_initialized()?;
1036 let params = CancelTaskParams {
1037 id: TaskId::from_string(task_id),
1038 reason: reason.map(ToString::to_string),
1039 };
1040 let result: CancelTaskResult = self.send_request("tasks/cancel", params)?;
1041 Ok(result.task)
1042 }
1043
1044 pub fn wait_for_task(
1058 &mut self,
1059 task_id: &str,
1060 poll_interval: Duration,
1061 ) -> McpResult<TaskResult> {
1062 loop {
1063 let result = self.get_task(task_id)?;
1064
1065 if result.task.status.is_terminal() {
1067 if let Some(task_result) = result.result {
1069 return Ok(task_result);
1070 }
1071
1072 return Ok(TaskResult {
1074 id: result.task.id,
1075 success: result.task.status == TaskStatus::Completed,
1076 data: None,
1077 error: result.task.error,
1078 });
1079 }
1080
1081 std::thread::sleep(poll_interval);
1083 }
1084 }
1085
1086 pub fn wait_for_task_with_progress<F>(
1100 &mut self,
1101 task_id: &str,
1102 poll_interval: Duration,
1103 mut on_progress: F,
1104 ) -> McpResult<TaskResult>
1105 where
1106 F: FnMut(f64, Option<&str>),
1107 {
1108 loop {
1109 let result = self.get_task(task_id)?;
1110
1111 if let Some(progress) = result.task.progress {
1113 on_progress(progress, result.task.message.as_deref());
1114 }
1115
1116 if result.task.status.is_terminal() {
1118 if let Some(task_result) = result.result {
1120 return Ok(task_result);
1121 }
1122
1123 return Ok(TaskResult {
1125 id: result.task.id,
1126 success: result.task.status == TaskStatus::Completed,
1127 data: None,
1128 error: result.task.error,
1129 });
1130 }
1131
1132 std::thread::sleep(poll_interval);
1134 }
1135 }
1136
/// Shuts the client down, consuming it.
///
/// Closes the transport, then kills the child process and waits for it to
/// exit. Every step is best-effort: errors are ignored.
// NOTE(review): `Drop for Client` repeats the same three steps when `self`
// goes out of scope at the end of this method.
pub fn close(mut self) {
    let _ = self.transport.close();

    let _ = self.child.kill();
    let _ = self.child.wait();
}
1146}
1147
// Best-effort teardown so a dropped client does not leave its server
// subprocess running.
impl Drop for Client {
    fn drop(&mut self) {
        // Close the transport before terminating the process; all errors are
        // ignored because nothing can be done about them during drop.
        let _ = self.transport.close();
        let _ = self.child.kill();
        // Wait after the kill so the child's exit status is collected.
        let _ = self.child.wait();
    }
}
1157
1158fn transport_error_to_mcp(e: TransportError) -> McpError {
1160 match e {
1161 TransportError::Cancelled => McpError::request_cancelled(),
1162 TransportError::Closed => McpError::internal_error("Transport closed"),
1163 TransportError::Timeout => McpError::internal_error("Request timed out"),
1164 TransportError::Io(io_err) => McpError::internal_error(format!("I/O error: {io_err}")),
1165 TransportError::Codec(codec_err) => {
1166 McpError::internal_error(format!("Codec error: {codec_err}"))
1167 }
1168 }
1169}
1170
1171#[cfg(test)]
1172mod tests {
1173 use super::*;
1174 use std::collections::HashMap;
1175 use std::process::{Command, Stdio};
1176
1177 fn make_closed_client(initialized: bool) -> Client {
1178 let rustc = std::env::var("RUSTC").unwrap_or_else(|_| "rustc".to_string());
1179 let mut child = Command::new(rustc)
1180 .arg("--version")
1181 .stdin(Stdio::piped())
1182 .stdout(Stdio::piped())
1183 .stderr(Stdio::null())
1184 .spawn()
1185 .expect("spawn rustc --version");
1186
1187 let stdin = child.stdin.take().expect("child stdin");
1188 let stdout = child.stdout.take().expect("child stdout");
1189 let transport = StdioTransport::new(stdout, stdin);
1190 let session = ClientSession::new(
1191 ClientInfo {
1192 name: "test-client".to_string(),
1193 version: "0.1.0".to_string(),
1194 },
1195 ClientCapabilities::default(),
1196 ServerInfo {
1197 name: "test-server".to_string(),
1198 version: "1.0.0".to_string(),
1199 },
1200 ServerCapabilities::default(),
1201 PROTOCOL_VERSION.to_string(),
1202 );
1203
1204 if initialized {
1205 Client::from_parts(child, transport, Cx::for_request(), session, 100)
1206 } else {
1207 Client::from_parts_uninitialized(child, transport, Cx::for_request(), session, 100)
1208 }
1209 }
1210
// A request with a string id must be answered with a MethodNotFound error
// that echoes that id.
#[test]
fn method_not_found_response_for_request() {
    let request = JsonRpcRequest::new("sampling/createMessage", None, "req-1");

    let response = method_not_found_response(&request);
    assert!(response.is_some());

    let Some(JsonRpcMessage::Response(resp)) = response else {
        panic!("expected a JSON-RPC response message");
    };

    let not_found = i32::from(fastmcp_core::McpErrorCode::MethodNotFound);
    assert!(matches!(resp.error.as_ref(), Some(error) if error.code == not_found));
    assert_eq!(resp.id, Some(RequestId::String("req-1".to_string())));
}
1231
1232 #[test]
1233 fn method_not_found_response_for_notification() {
1234 let request = JsonRpcRequest::notification("notifications/message", None);
1235 let response = method_not_found_response(&request);
1236 assert!(response.is_none());
1237 }
1238
// A request with a numeric id must be answered with a MethodNotFound error
// that echoes the id and names the unknown method.
#[test]
fn method_not_found_response_with_numeric_id() {
    let request = JsonRpcRequest::new("unknown/method", None, 42i64);

    // Fail loudly instead of silently skipping the assertions when the value
    // is not a response message.
    let Some(JsonRpcMessage::Response(resp)) = method_not_found_response(&request) else {
        panic!("expected a JSON-RPC response message");
    };

    assert_eq!(resp.id, Some(RequestId::Number(42)));
    let error = resp.error.as_ref().unwrap();
    assert_eq!(
        error.code,
        i32::from(fastmcp_core::McpErrorCode::MethodNotFound)
    );
    assert!(error.message.contains("unknown/method"));
}
1254
// Request parameters must not affect the reply: the error still names the
// requested method.
#[test]
fn method_not_found_response_with_params() {
    let params = serde_json::json!({"key": "value"});
    let request = JsonRpcRequest::new("roots/list", Some(params), "req-99");

    // Fail loudly instead of silently skipping the assertions when the value
    // is not a response message.
    let Some(JsonRpcMessage::Response(resp)) = method_not_found_response(&request) else {
        panic!("expected a JSON-RPC response message");
    };

    let error = resp.error.as_ref().unwrap();
    assert!(error.message.contains("roots/list"));
}
1266
1267 #[test]
1272 fn transport_error_cancelled_maps_to_request_cancelled() {
1273 let err = transport_error_to_mcp(TransportError::Cancelled);
1274 assert_eq!(err.code, fastmcp_core::McpErrorCode::RequestCancelled);
1275 }
1276
1277 #[test]
1278 fn transport_error_closed_maps_to_internal() {
1279 let err = transport_error_to_mcp(TransportError::Closed);
1280 assert_eq!(err.code, fastmcp_core::McpErrorCode::InternalError);
1281 assert!(err.message.contains("closed"));
1282 }
1283
1284 #[test]
1285 fn transport_error_timeout_maps_to_internal() {
1286 let err = transport_error_to_mcp(TransportError::Timeout);
1287 assert_eq!(err.code, fastmcp_core::McpErrorCode::InternalError);
1288 assert!(err.message.contains("timed out"));
1289 }
1290
1291 #[test]
1292 fn transport_error_io_maps_to_internal() {
1293 let io_err = std::io::Error::new(std::io::ErrorKind::BrokenPipe, "pipe broken");
1294 let err = transport_error_to_mcp(TransportError::Io(io_err));
1295 assert_eq!(err.code, fastmcp_core::McpErrorCode::InternalError);
1296 assert!(err.message.contains("I/O error"));
1297 }
1298
1299 #[test]
1300 fn transport_error_codec_maps_to_internal() {
1301 use fastmcp_transport::CodecError;
1302 let codec_err = CodecError::MessageTooLarge(999_999);
1303 let err = transport_error_to_mcp(TransportError::Codec(codec_err));
1304 assert_eq!(err.code, fastmcp_core::McpErrorCode::InternalError);
1305 assert!(err.message.contains("Codec error"));
1306 }
1307
1308 #[test]
1313 fn client_progress_params_deserialization() {
1314 let json = serde_json::json!({
1315 "progressToken": 42,
1316 "progress": 0.5,
1317 "total": 1.0,
1318 "message": "Halfway done"
1319 });
1320 let params: ClientProgressParams = serde_json::from_value(json).unwrap();
1321 assert_eq!(params.marker, ProgressMarker::Number(42));
1322 assert!((params.progress - 0.5).abs() < f64::EPSILON);
1323 assert!((params.total.unwrap() - 1.0).abs() < f64::EPSILON);
1324 assert_eq!(params.message.as_deref(), Some("Halfway done"));
1325 }
1326
1327 #[test]
1328 fn client_progress_params_minimal() {
1329 let json = serde_json::json!({
1330 "progressToken": "tok-1",
1331 "progress": 0.0
1332 });
1333 let params: ClientProgressParams = serde_json::from_value(json).unwrap();
1334 assert_eq!(params.marker, ProgressMarker::String("tok-1".to_string()));
1335 assert!(params.total.is_none());
1336 assert!(params.message.is_none());
1337 }
1338
1339 #[test]
1340 fn client_from_parts_accessors_and_request_counter() {
1341 let client = make_closed_client(true);
1342 assert!(client.is_initialized());
1343 assert_eq!(client.server_info().name, "test-server");
1344 let caps_json = serde_json::to_value(client.server_capabilities()).expect("caps json");
1345 assert_eq!(caps_json, serde_json::json!({}));
1346 assert_eq!(client.protocol_version(), PROTOCOL_VERSION);
1347 assert_eq!(client.next_request_id(), 2);
1348 assert_eq!(client.next_request_id(), 3);
1349 }
1350
1351 #[test]
1352 fn ensure_initialized_noop_when_already_initialized() {
1353 let mut client = make_closed_client(true);
1354 assert!(client.ensure_initialized().is_ok());
1355 assert!(client.is_initialized());
1356 }
1357
1358 #[test]
1359 fn ensure_initialized_fails_for_uninitialized_closed_transport() {
1360 let mut client = make_closed_client(false);
1361 std::thread::sleep(Duration::from_millis(50));
1362 let err = client
1363 .ensure_initialized()
1364 .expect_err("expected init failure");
1365 assert_eq!(err.code, fastmcp_core::McpErrorCode::InternalError);
1366 assert!(!client.is_initialized());
1367 }
1368
/// Calls the client's request methods one after another on the fixture from
/// `make_closed_client(true)` and expects every asserted call to return `Err`.
///
/// Neither progress callback may have recorded an event by the end.
#[test]
fn client_core_api_methods_error_cleanly_on_closed_transport() {
    let mut closed = make_closed_client(true);
    // NOTE(review): the pause presumably lets the fixture's child process exit
    // before the transport is used; confirm against `make_closed_client`.
    let settle = Duration::from_millis(50);
    std::thread::sleep(settle);

    // The outcome of the cancellation is deliberately not asserted on.
    let _ = closed.cancel_request(7i64, Some("stop".to_string()), true);

    // --- tools ---
    assert!(closed.list_tools().is_err());
    assert!(
        closed
            .call_tool("echo", serde_json::json!({"text": "hi"}))
            .is_err()
    );

    let mut progress_events: Vec<(f64, Option<f64>, Option<String>)> = Vec::new();
    {
        // The recorder borrows `progress_events` only inside this scope.
        let mut record_progress = |done: f64, total: Option<f64>, note: Option<&str>| {
            progress_events.push((done, total, note.map(|text| text.to_string())));
        };
        assert!(
            closed
                .call_tool_with_progress(
                    "echo",
                    serde_json::json!({"text": "hi"}),
                    &mut record_progress
                )
                .is_err()
        );
    }
    assert!(progress_events.is_empty());

    // --- resources, logging, prompts ---
    assert!(closed.list_resources().is_err());
    assert!(closed.list_resource_templates().is_err());
    assert!(closed.set_log_level(LogLevel::Debug).is_err());
    assert!(closed.read_resource("resource://test").is_err());
    assert!(closed.list_prompts().is_err());

    let prompt_args = HashMap::from([("name".to_string(), "world".to_string())]);
    assert!(closed.get_prompt("greeting", prompt_args).is_err());

    // --- tasks ---
    assert!(
        closed
            .submit_task("data_export", serde_json::json!({"batch": 1}))
            .is_err()
    );
    assert!(
        closed
            .list_tasks(Some(TaskStatus::Running), Some("c1"), Some(10))
            .is_err()
    );
    assert!(closed.list_tasks_all(None).is_err());
    assert!(closed.get_task("task-1").is_err());
    assert!(closed.cancel_task("task-1").is_err());
    assert!(
        closed
            .cancel_task_with_reason("task-1", Some("no longer needed"))
            .is_err()
    );
    assert!(
        closed
            .wait_for_task("task-1", Duration::from_millis(1))
            .is_err()
    );

    let mut task_progress = Vec::new();
    {
        // The recorder borrows `task_progress` only inside this scope.
        let mut record_task_progress = |done: f64, note: Option<&str>| {
            task_progress.push((done, note.map(|text| text.to_string())));
        };
        assert!(
            closed
                .wait_for_task_with_progress(
                    "task-1",
                    Duration::from_millis(1),
                    &mut record_task_progress
                )
                .is_err()
        );
    }
    assert!(task_progress.is_empty());
}
1446
/// Calling `close` on the fixture after a short pause must not panic.
#[test]
fn close_handles_already_exited_subprocess() {
    let finished = make_closed_client(true);

    // NOTE(review): the pause presumably lets the fixture's child process exit
    // first, as the test name suggests; confirm against `make_closed_client`.
    let settle = Duration::from_millis(50);
    std::thread::sleep(settle);

    finished.close();
}
1453
/// Smoke test: `Client::builder()` can be called and its result held without
/// panicking. Nothing about the returned value is asserted.
#[test]
fn client_builder_returns_client_builder() {
    let _unused_builder = Client::builder();
}
1463
/// `Client::stdio` with a command that does not exist must return an
/// `InternalError` whose message mentions "spawn".
#[test]
fn client_stdio_fails_for_nonexistent_command() {
    let spawn_attempt = Client::stdio("definitely-not-a-real-command-xyz", &[]);
    assert!(spawn_attempt.is_err());

    // `.err().expect(..)` is used rather than `expect_err`, which would need
    // the success type to implement `Debug`.
    let failure = spawn_attempt.err().expect("should be error");

    assert_eq!(failure.code, fastmcp_core::McpErrorCode::InternalError);
    let mentions_spawn = failure.message.contains("spawn");
    assert!(mentions_spawn);
}
1472
/// `Client::stdio_with_cx` must return an error when handed a context whose
/// cancel flag has already been set.
#[test]
fn client_stdio_with_cx_fails_when_cancelled() {
    let cancelled_cx = Cx::for_request();
    cancelled_cx.set_cancel_requested(true);

    assert!(Client::stdio_with_cx("echo", &["hello"], cancelled_cx).is_err());
}
1481
/// A fixture built with `make_closed_client(false)` reports itself as not
/// initialized.
#[test]
fn uninitialized_client_is_not_initialized() {
    let fresh = make_closed_client(false);
    let initialized = fresh.is_initialized();
    assert!(!initialized);
}
1491
/// Checks the server info exposed by an uninitialized fixture.
///
/// NOTE(review): despite the "is_empty" in the name, the assertions expect the
/// populated values "test-server" / "1.0.0", which presumably come from the
/// `make_closed_client` fixture. Confirm, and consider renaming the test.
#[test]
fn uninitialized_client_server_info_is_empty() {
    let fresh = make_closed_client(false);

    let info = fresh.server_info();
    assert_eq!(info.name, "test-server");
    assert_eq!(info.version, "1.0.0");
}
1498
/// On an uninitialized fixture the first two request ids are 1 and 2.
#[test]
fn uninitialized_client_request_id_starts_at_one() {
    let fresh = make_closed_client(false);

    let first_id = fresh.next_request_id();
    assert_eq!(first_id, 1);

    let second_id = fresh.next_request_id();
    assert_eq!(second_id, 2);
}
1505
/// On an initialized fixture the first two request ids are 2 and 3.
///
/// NOTE(review): id 1 is presumably accounted for by the initialize request;
/// confirm against `make_closed_client`.
#[test]
fn initialized_client_request_id_starts_at_two() {
    let ready = make_closed_client(true);

    let first_id = ready.next_request_id();
    assert_eq!(first_id, 2);

    let second_id = ready.next_request_id();
    assert_eq!(second_id, 3);
}
1513
/// `list_tools` on an uninitialized fixture must fail with `InternalError`.
#[test]
fn uninitialized_client_list_tools_fails_on_init() {
    let mut fresh = make_closed_client(false);

    // NOTE(review): the pause presumably lets the fixture's child process exit
    // before the transport is used; confirm against `make_closed_client`.
    let settle = Duration::from_millis(50);
    std::thread::sleep(settle);

    let listing = fresh.list_tools();
    let failure = listing.expect_err("should fail");
    assert_eq!(failure.code, fastmcp_core::McpErrorCode::InternalError);
}
1525
/// `call_tool` on an uninitialized fixture must fail with `InternalError`.
#[test]
fn uninitialized_client_call_tool_fails_on_init() {
    let mut fresh = make_closed_client(false);

    // NOTE(review): the pause presumably lets the fixture's child process exit
    // before the transport is used; confirm against `make_closed_client`.
    let settle = Duration::from_millis(50);
    std::thread::sleep(settle);

    let call_outcome = fresh.call_tool("echo", serde_json::json!({"text": "hi"}));
    let failure = call_outcome.expect_err("should fail");
    assert_eq!(failure.code, fastmcp_core::McpErrorCode::InternalError);
}
1535
/// `list_resources` on an uninitialized fixture must return an error.
#[test]
fn uninitialized_client_list_resources_fails_on_init() {
    let mut fresh = make_closed_client(false);

    // NOTE(review): the pause presumably lets the fixture's child process exit
    // before the transport is used; confirm against `make_closed_client`.
    let settle = Duration::from_millis(50);
    std::thread::sleep(settle);

    assert!(fresh.list_resources().is_err());
}
1542
/// `list_prompts` on an uninitialized fixture must return an error.
#[test]
fn uninitialized_client_list_prompts_fails_on_init() {
    let mut fresh = make_closed_client(false);

    // NOTE(review): the pause presumably lets the fixture's child process exit
    // before the transport is used; confirm against `make_closed_client`.
    let settle = Duration::from_millis(50);
    std::thread::sleep(settle);

    assert!(fresh.list_prompts().is_err());
}
1549
/// Dropping the fixture after a short pause must not panic.
///
/// Nothing is asserted about the child process itself; the test only checks
/// that the explicit `drop` completes.
#[test]
fn drop_cleans_up_subprocess() {
    let doomed = make_closed_client(true);

    let settle = Duration::from_millis(50);
    std::thread::sleep(settle);

    drop(doomed);
}
1562
/// The derived `Debug` output of `ClientProgressParams` contains the text
/// "progress".
#[test]
fn client_progress_params_debug() {
    let sample = ClientProgressParams {
        marker: ProgressMarker::Number(1),
        progress: 0.5,
        total: Some(1.0),
        message: Some("half".into()),
    };

    let rendered = format!("{sample:?}");
    assert!(rendered.contains("progress"));
}
1574
/// Converting a `TransportError::Io` into an MCP error keeps the text of the
/// underlying I/O error in the resulting message.
#[test]
fn transport_error_to_mcp_preserves_io_details() {
    let transport_failure = TransportError::Io(std::io::Error::new(
        std::io::ErrorKind::NotFound,
        "socket vanished",
    ));

    let converted = transport_error_to_mcp(transport_failure);

    assert!(converted.message.contains("socket vanished"));
}
1581
/// The error built by `method_not_found_response` must name the method that
/// was not found.
#[test]
fn method_not_found_response_error_message_includes_method() {
    let request = JsonRpcRequest::new("totally/custom/method", None, 1i64);

    // `None` here means no reply was built; the helper returns `None` when the
    // request has no id.
    let reply = method_not_found_response(&request)
        .expect("a request with an id should yield a response");

    // Fail loudly on any other variant. A plain `if let` would skip the
    // assertion and let the test pass without checking anything.
    let JsonRpcMessage::Response(resp) = reply else {
        panic!("expected a JsonRpcMessage::Response");
    };

    let error = resp
        .error
        .expect("method-not-found response should carry an error");
    assert!(error.message.contains("totally/custom/method"));
}
1591
/// The fixture's server capabilities advertise no tools, resources or prompts.
#[test]
fn client_server_capabilities_default_is_empty() {
    let ready = make_closed_client(true);
    let advertised = ready.server_capabilities();

    let has_no_tools = advertised.tools.is_none();
    assert!(has_no_tools);

    let has_no_resources = advertised.resources.is_none();
    assert!(has_no_resources);

    let has_no_prompts = advertised.prompts.is_none();
    assert!(has_no_prompts);
}
1601}