1use std::collections::HashMap;
22use std::path::PathBuf;
23use std::process::{Child, Command, Stdio};
24use std::time::Duration;
25
/// Owns a spawned subprocess and terminates it when dropped, unless ownership
/// is handed back with [`ChildGuard::disarm`].
///
/// Wrapping a freshly spawned child in this guard means every early `?`
/// return from a fallible setup step kills and reaps the process instead of
/// leaving it running.
struct ChildGuard(Option<Child>);

impl ChildGuard {
    /// Takes ownership of `child`; it will be killed and reaped on drop.
    fn new(child: Child) -> Self {
        ChildGuard(Some(child))
    }

    /// Releases the child without killing it and returns it to the caller.
    ///
    /// # Panics
    ///
    /// Panics if the guard no longer holds a child. `disarm` consumes the
    /// guard, so reaching that state indicates a bug.
    fn disarm(mut self) -> Child {
        match self.0.take() {
            Some(child) => child,
            None => panic!("ChildGuard already disarmed"),
        }
    }
}

impl Drop for ChildGuard {
    /// Best-effort cleanup: kill the child (if still owned), then wait on it.
    fn drop(&mut self) {
        let Some(mut child) = self.0.take() else {
            // Disarmed: ownership was handed back, nothing to clean up.
            return;
        };
        // Errors are deliberately ignored: `kill` fails if the child already
        // exited, and a destructor has no way to report a failed `wait`.
        let _ = child.kill();
        // Reap the process so it does not linger as a zombie.
        let _ = child.wait();
    }
}
50
51use asupersync::Cx;
52use fastmcp_core::{McpError, McpResult};
53use fastmcp_protocol::{
54 ClientCapabilities, ClientInfo, InitializeParams, InitializeResult, JsonRpcMessage,
55 JsonRpcRequest, PROTOCOL_VERSION,
56};
57use fastmcp_transport::{StdioTransport, Transport};
58
59use crate::{Client, ClientSession};
60
/// Builder that configures and launches an MCP server subprocess and connects
/// a [`Client`] to it over stdio.
///
/// Create one with [`ClientBuilder::new`] (or `Default`), chain the
/// configuration methods, then call [`ClientBuilder::connect_stdio`] or
/// [`ClientBuilder::connect_stdio_with_cx`].
#[derive(Debug, Clone)]
pub struct ClientBuilder {
    /// Name and version sent to the server in the `initialize` request.
    client_info: ClientInfo,
    /// Timeout in milliseconds handed to the constructed client.
    timeout_ms: u64,
    /// Number of extra connection attempts made after the first one fails.
    max_retries: u32,
    /// Delay in milliseconds slept before each retry attempt.
    retry_delay_ms: u64,
    /// Working directory for the spawned subprocess; `None` leaves it unchanged.
    working_dir: Option<PathBuf>,
    /// Extra environment variables set on the spawned subprocess.
    env_vars: HashMap<String, String>,
    /// Whether the subprocess inherits this process's environment.
    inherit_env: bool,
    /// Capabilities advertised to the server during initialization.
    capabilities: ClientCapabilities,
    /// When `true`, connecting skips the handshake and returns an
    /// uninitialized client.
    auto_initialize: bool,
}
86
87impl ClientBuilder {
88 #[must_use]
98 pub fn new() -> Self {
99 Self {
100 client_info: ClientInfo {
101 name: "fastmcp-client".to_owned(),
102 version: env!("CARGO_PKG_VERSION").to_owned(),
103 },
104 timeout_ms: 30_000,
105 max_retries: 0,
106 retry_delay_ms: 1_000,
107 working_dir: None,
108 env_vars: HashMap::new(),
109 inherit_env: true,
110 capabilities: ClientCapabilities::default(),
111 auto_initialize: false,
112 }
113 }
114
115 #[must_use]
119 pub fn client_info(mut self, name: impl Into<String>, version: impl Into<String>) -> Self {
120 self.client_info = ClientInfo {
121 name: name.into(),
122 version: version.into(),
123 };
124 self
125 }
126
127 #[must_use]
132 pub fn timeout_ms(mut self, timeout: u64) -> Self {
133 self.timeout_ms = timeout;
134 self
135 }
136
137 #[must_use]
142 pub fn max_retries(mut self, retries: u32) -> Self {
143 self.max_retries = retries;
144 self
145 }
146
147 #[must_use]
151 pub fn retry_delay_ms(mut self, delay: u64) -> Self {
152 self.retry_delay_ms = delay;
153 self
154 }
155
156 #[must_use]
160 pub fn working_dir(mut self, path: impl Into<PathBuf>) -> Self {
161 self.working_dir = Some(path.into());
162 self
163 }
164
165 #[must_use]
169 pub fn env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
170 self.env_vars.insert(key.into(), value.into());
171 self
172 }
173
174 #[must_use]
176 pub fn envs<I, K, V>(mut self, vars: I) -> Self
177 where
178 I: IntoIterator<Item = (K, V)>,
179 K: Into<String>,
180 V: Into<String>,
181 {
182 for (key, value) in vars {
183 self.env_vars.insert(key.into(), value.into());
184 }
185 self
186 }
187
188 #[must_use]
195 pub fn inherit_env(mut self, inherit: bool) -> Self {
196 self.inherit_env = inherit;
197 self
198 }
199
200 #[must_use]
202 pub fn capabilities(mut self, capabilities: ClientCapabilities) -> Self {
203 self.capabilities = capabilities;
204 self
205 }
206
207 #[must_use]
227 pub fn auto_initialize(mut self, enabled: bool) -> Self {
228 self.auto_initialize = enabled;
229 self
230 }
231
232 pub fn connect_stdio(self, command: &str, args: &[&str]) -> McpResult<Client> {
249 self.connect_stdio_with_cx(command, args, &Cx::for_request())
250 }
251
252 pub fn connect_stdio_with_cx(self, command: &str, args: &[&str], cx: &Cx) -> McpResult<Client> {
257 let mut last_error = None;
258 let attempts = u64::from(self.max_retries) + 1;
260
261 for attempt in 0..attempts {
262 if cx.checkpoint().is_err() {
264 return Err(McpError::request_cancelled());
265 }
266
267 if attempt > 0 {
268 let mut remaining_ms = self.retry_delay_ms;
271 while remaining_ms > 0 {
272 if cx.checkpoint().is_err() {
273 return Err(McpError::request_cancelled());
274 }
275
276 let sleep_ms = remaining_ms.min(25);
277 std::thread::sleep(Duration::from_millis(sleep_ms));
278 remaining_ms = remaining_ms.saturating_sub(sleep_ms);
279 }
280 }
281
282 match self.try_connect(command, args, cx) {
283 Ok(client) => return Ok(client),
284 Err(e) => {
285 last_error = Some(e);
286 }
287 }
288 }
289
290 Err(last_error.unwrap_or_else(|| McpError::internal_error("Connection failed")))
292 }
293
294 fn try_connect(&self, command: &str, args: &[&str], cx: &Cx) -> McpResult<Client> {
296 let mut cmd = Command::new(command);
298 cmd.args(args)
299 .stdin(Stdio::piped())
300 .stdout(Stdio::piped())
301 .stderr(Stdio::inherit());
302
303 if let Some(ref dir) = self.working_dir {
305 cmd.current_dir(dir);
306 }
307
308 if !self.inherit_env {
310 cmd.env_clear();
311 }
312 for (key, value) in &self.env_vars {
313 cmd.env(key, value);
314 }
315
316 let mut child = cmd
318 .spawn()
319 .map_err(|e| McpError::internal_error(format!("Failed to spawn subprocess: {e}")))?;
320
321 let stdin = child
323 .stdin
324 .take()
325 .ok_or_else(|| McpError::internal_error("Failed to get subprocess stdin"))?;
326 let stdout = child
327 .stdout
328 .take()
329 .ok_or_else(|| McpError::internal_error("Failed to get subprocess stdout"))?;
330
331 let transport = StdioTransport::new(stdout, stdin);
333
334 if self.auto_initialize {
335 Ok(self.create_uninitialized_client(child, transport, cx))
337 } else {
338 self.initialize_client(child, transport, cx)
340 }
341 }
342
343 fn create_uninitialized_client(
345 &self,
346 child: Child,
347 transport: StdioTransport<std::process::ChildStdout, std::process::ChildStdin>,
348 cx: &Cx,
349 ) -> Client {
350 let session = ClientSession::new(
352 self.client_info.clone(),
353 self.capabilities.clone(),
354 fastmcp_protocol::ServerInfo {
355 name: String::new(),
356 version: String::new(),
357 },
358 fastmcp_protocol::ServerCapabilities::default(),
359 String::new(),
360 );
361
362 Client::from_parts_uninitialized(child, transport, cx.clone(), session, self.timeout_ms)
363 }
364
365 fn initialize_client(
367 &self,
368 child: Child,
369 mut transport: StdioTransport<std::process::ChildStdout, std::process::ChildStdin>,
370 cx: &Cx,
371 ) -> McpResult<Client> {
372 let child_guard = ChildGuard::new(child);
375
376 let init_params = InitializeParams {
378 protocol_version: PROTOCOL_VERSION.to_string(),
379 capabilities: self.capabilities.clone(),
380 client_info: self.client_info.clone(),
381 };
382
383 let init_request = JsonRpcRequest::new(
384 "initialize",
385 Some(serde_json::to_value(&init_params).map_err(|e| {
386 McpError::internal_error(format!("Failed to serialize params: {e}"))
387 })?),
388 1i64,
389 );
390
391 transport
392 .send(cx, &JsonRpcMessage::Request(init_request))
393 .map_err(|e| McpError::internal_error(format!("Failed to send initialize: {e}")))?;
394
395 let response = loop {
397 let msg = transport.recv(cx).map_err(|e| {
398 McpError::internal_error(format!("Failed to receive response: {e}"))
399 })?;
400
401 match msg {
402 JsonRpcMessage::Response(resp) => break resp,
403 JsonRpcMessage::Request(_) => {
404 }
406 }
407 };
408
409 if let Some(error) = response.error {
411 return Err(McpError::new(
412 fastmcp_core::McpErrorCode::Custom(error.code),
413 error.message,
414 ));
415 }
416
417 let result_value = response
419 .result
420 .ok_or_else(|| McpError::internal_error("No result in initialize response"))?;
421
422 let init_result: InitializeResult = serde_json::from_value(result_value).map_err(|e| {
423 McpError::internal_error(format!("Failed to parse initialize result: {e}"))
424 })?;
425
426 let initialized_request = JsonRpcRequest {
428 jsonrpc: std::borrow::Cow::Borrowed(fastmcp_protocol::JSONRPC_VERSION),
429 method: "initialized".to_string(),
430 params: Some(serde_json::json!({})),
431 id: None,
432 };
433
434 transport
435 .send(cx, &JsonRpcMessage::Request(initialized_request))
436 .map_err(|e| McpError::internal_error(format!("Failed to send initialized: {e}")))?;
437
438 let session = ClientSession::new(
440 self.client_info.clone(),
441 self.capabilities.clone(),
442 init_result.server_info,
443 init_result.capabilities,
444 init_result.protocol_version,
445 );
446
447 Ok(Client::from_parts(
449 child_guard.disarm(),
450 transport,
451 cx.clone(),
452 session,
453 self.timeout_ms,
454 ))
455 }
456}
457
458impl Default for ClientBuilder {
459 fn default() -> Self {
460 Self::new()
461 }
462}
463
#[cfg(test)]
mod tests {
    use super::*;
    use fastmcp_core::McpErrorCode;

    #[test]
    fn test_builder_defaults() {
        let builder = ClientBuilder::new();
        assert_eq!(builder.client_info.name, "fastmcp-client");
        assert_eq!(builder.timeout_ms, 30_000);
        assert_eq!(builder.max_retries, 0);
        assert_eq!(builder.retry_delay_ms, 1_000);
        assert!(builder.inherit_env);
        assert!(builder.working_dir.is_none());
        assert!(builder.env_vars.is_empty());
        assert!(!builder.auto_initialize);
    }

    #[test]
    fn test_builder_fluent_api() {
        let builder = ClientBuilder::new()
            .client_info("test-client", "2.0.0")
            .timeout_ms(60_000)
            .max_retries(3)
            .retry_delay_ms(500)
            .working_dir("/tmp")
            .env("FOO", "bar")
            .env("BAZ", "qux")
            .inherit_env(false);

        assert_eq!(builder.client_info.name, "test-client");
        assert_eq!(builder.client_info.version, "2.0.0");
        assert_eq!(builder.timeout_ms, 60_000);
        assert_eq!(builder.max_retries, 3);
        assert_eq!(builder.retry_delay_ms, 500);
        assert_eq!(builder.working_dir, Some(PathBuf::from("/tmp")));
        assert_eq!(builder.env_vars.get("FOO"), Some(&"bar".to_string()));
        assert_eq!(builder.env_vars.get("BAZ"), Some(&"qux".to_string()));
        assert!(!builder.inherit_env);
    }

    #[test]
    fn test_builder_envs() {
        let vars = [("KEY1", "value1"), ("KEY2", "value2")];
        let builder = ClientBuilder::new().envs(vars);

        assert_eq!(builder.env_vars.get("KEY1"), Some(&"value1".to_string()));
        assert_eq!(builder.env_vars.get("KEY2"), Some(&"value2".to_string()));
    }

    #[test]
    fn test_builder_clone() {
        let builder1 = ClientBuilder::new()
            .client_info("test", "1.0")
            .timeout_ms(5000);

        let builder2 = builder1.clone();

        assert_eq!(builder2.client_info.name, "test");
        assert_eq!(builder2.timeout_ms, 5000);
    }

    #[test]
    fn test_builder_auto_initialize() {
        let builder = ClientBuilder::new().auto_initialize(true);
        assert!(builder.auto_initialize);

        let builder = ClientBuilder::new().auto_initialize(false);
        assert!(!builder.auto_initialize);
    }

    #[test]
    fn test_builder_capabilities() {
        let caps = ClientCapabilities {
            sampling: Some(fastmcp_protocol::SamplingCapability {}),
            elicitation: None,
            roots: None,
        };
        let builder = ClientBuilder::new().capabilities(caps);
        assert!(builder.capabilities.sampling.is_some());
        assert!(builder.capabilities.elicitation.is_none());
        assert!(builder.capabilities.roots.is_none());
    }

    #[test]
    fn test_builder_default_trait() {
        let builder = ClientBuilder::default();
        assert_eq!(builder.client_info.name, "fastmcp-client");
        assert_eq!(builder.timeout_ms, 30_000);
        assert_eq!(builder.max_retries, 0);
        assert!(!builder.auto_initialize);
    }

    #[test]
    fn test_builder_env_override() {
        let builder = ClientBuilder::new()
            .env("KEY", "first")
            .env("KEY", "second");
        assert_eq!(builder.env_vars.get("KEY"), Some(&"second".to_string()));
    }

    #[test]
    fn test_builder_envs_combined_with_env() {
        let builder = ClientBuilder::new()
            .env("A", "1")
            .envs([("B", "2"), ("C", "3")])
            .env("D", "4");
        assert_eq!(builder.env_vars.len(), 4);
        assert_eq!(builder.env_vars.get("A"), Some(&"1".to_string()));
        assert_eq!(builder.env_vars.get("B"), Some(&"2".to_string()));
        assert_eq!(builder.env_vars.get("C"), Some(&"3".to_string()));
        assert_eq!(builder.env_vars.get("D"), Some(&"4".to_string()));
    }

    #[test]
    fn test_connect_stdio_with_cx_respects_cancellation_during_retries() {
        let cx = Cx::for_request();
        cx.set_cancel_requested(true);
        let result = ClientBuilder::new()
            .max_retries(2)
            .retry_delay_ms(100)
            .connect_stdio_with_cx("definitely-not-a-real-command", &[], &cx);

        assert!(
            result.is_err(),
            "cancelled context should abort before retry attempts"
        );
        let err = result.err().expect("error result");
        assert_eq!(err.code, McpErrorCode::RequestCancelled);
    }

    #[test]
    fn test_connect_stdio_with_cx_max_retries_does_not_overflow() {
        let cx = Cx::for_request();
        cx.set_cancel_requested(true);

        let result = ClientBuilder::new()
            .max_retries(u32::MAX)
            .retry_delay_ms(1)
            .connect_stdio_with_cx("definitely-not-a-real-command", &[], &cx);

        assert!(
            result.is_err(),
            "cancelled context should return an error, not panic from retry overflow"
        );
        let err = result.err().expect("error result");
        assert_eq!(err.code, McpErrorCode::RequestCancelled);
    }

    #[test]
    fn builder_debug_includes_client_info() {
        let builder = ClientBuilder::new().client_info("dbg-test", "0.1");
        let debug = format!("{:?}", builder);
        assert!(debug.contains("dbg-test"));
        assert!(debug.contains("0.1"));
    }

    #[test]
    fn connect_stdio_nonexistent_command_fails() {
        let result = ClientBuilder::new()
            .max_retries(0)
            .connect_stdio("fastmcp_nonexistent_binary_xyz", &["--version"]);
        assert!(result.is_err());
    }

    #[test]
    fn builder_working_dir_last_wins() {
        let builder = ClientBuilder::new()
            .working_dir("/first")
            .working_dir("/second");
        assert_eq!(builder.working_dir, Some(PathBuf::from("/second")));
    }

    /// Spawns the Unix `true` utility, so this only runs on Unix targets.
    #[test]
    #[cfg(unix)]
    fn child_guard_disarm_returns_child() {
        let child = Command::new("true")
            .stdin(Stdio::null())
            .stdout(Stdio::null())
            .stderr(Stdio::null())
            .spawn()
            .expect("failed to spawn 'true'");
        let guard = ChildGuard::new(child);
        let mut returned = guard.disarm();
        // A disarmed guard must not have killed the child: it exits normally.
        let status = returned.wait().expect("wait failed");
        assert!(status.success());
    }

    /// Observes process exit through `/proc/<pid>`, which only exists on
    /// Linux. On other targets the assertion would pass vacuously.
    #[test]
    #[cfg(target_os = "linux")]
    fn child_guard_drop_kills_child() {
        let child = Command::new("sleep")
            .arg("60")
            .stdin(Stdio::null())
            .stdout(Stdio::null())
            .stderr(Stdio::null())
            .spawn()
            .expect("failed to spawn 'sleep'");
        let pid = child.id();
        {
            let _guard = ChildGuard::new(child);
        }
        // The guard's drop killed and waited on the child, so it has been
        // reaped and its /proc entry is gone.
        let proc_path = format!("/proc/{}/status", pid);
        assert!(
            !std::path::Path::new(&proc_path).exists(),
            "process should no longer exist after drop"
        );
    }

    #[test]
    fn builder_capabilities_default_is_empty() {
        let builder = ClientBuilder::new();
        assert!(builder.capabilities.sampling.is_none());
        assert!(builder.capabilities.elicitation.is_none());
        assert!(builder.capabilities.roots.is_none());
    }

    #[test]
    fn connect_stdio_spawn_failure_error_message() {
        let result = ClientBuilder::new()
            .max_retries(0)
            .connect_stdio("fastmcp_no_such_binary_abc123", &[]);
        match result {
            Err(err) => assert!(
                err.message.contains("spawn"),
                "error should mention spawn failure: {}",
                err.message
            ),
            Ok(_) => panic!("expected spawn to fail"),
        }
    }
}