1use std::collections::HashMap;
22use std::path::PathBuf;
23use std::process::{Child, Command, Stdio};
24
/// Owner of a spawned subprocess that tears it down unless explicitly released.
///
/// While the guard still holds the child, dropping the guard kills the process
/// and then waits on it. [`ChildGuard::disarm`] hands the child back to the
/// caller, after which dropping the guard does nothing.
struct ChildGuard(Option<Child>);

impl ChildGuard {
    /// Takes ownership of `child`. The child is killed if the guard is dropped
    /// before [`ChildGuard::disarm`] is called.
    fn new(child: Child) -> Self {
        ChildGuard(Some(child))
    }

    /// Releases the child from the guard without killing it.
    ///
    /// # Panics
    ///
    /// Panics if the child was already taken out of the guard. Because this
    /// method consumes `self`, that state should not be reachable.
    fn disarm(mut self) -> Child {
        let released = self.0.take();
        released.expect("ChildGuard already disarmed")
    }
}

impl Drop for ChildGuard {
    /// Best-effort teardown of a still-guarded child.
    ///
    /// Errors from `kill` (e.g. the process has already exited) and from
    /// `wait` are ignored, since `drop` has no way to report them.
    fn drop(&mut self) {
        if let Some(child) = self.0.as_mut() {
            let _ = child.kill();
            // Wait so the killed process is reaped rather than left behind
            // (a zombie, on Unix).
            let _ = child.wait();
        }
    }
}
49
50use asupersync::Cx;
51use fastmcp_core::{McpError, McpResult};
52use fastmcp_protocol::{
53 ClientCapabilities, ClientInfo, InitializeParams, InitializeResult, JsonRpcMessage,
54 JsonRpcRequest, PROTOCOL_VERSION,
55};
56use fastmcp_transport::{StdioTransport, Transport};
57
58use crate::{Client, ClientSession};
59
60#[derive(Debug, Clone)]
65pub struct ClientBuilder {
66 client_info: ClientInfo,
68 timeout_ms: u64,
70 max_retries: u32,
72 retry_delay_ms: u64,
74 working_dir: Option<PathBuf>,
76 env_vars: HashMap<String, String>,
78 inherit_env: bool,
80 capabilities: ClientCapabilities,
82 auto_initialize: bool,
84}
85
86impl ClientBuilder {
87 #[must_use]
97 pub fn new() -> Self {
98 Self {
99 client_info: ClientInfo {
100 name: "fastmcp-client".to_owned(),
101 version: env!("CARGO_PKG_VERSION").to_owned(),
102 },
103 timeout_ms: 30_000,
104 max_retries: 0,
105 retry_delay_ms: 1_000,
106 working_dir: None,
107 env_vars: HashMap::new(),
108 inherit_env: true,
109 capabilities: ClientCapabilities::default(),
110 auto_initialize: false,
111 }
112 }
113
114 #[must_use]
118 pub fn client_info(mut self, name: impl Into<String>, version: impl Into<String>) -> Self {
119 self.client_info = ClientInfo {
120 name: name.into(),
121 version: version.into(),
122 };
123 self
124 }
125
126 #[must_use]
131 pub fn timeout_ms(mut self, timeout: u64) -> Self {
132 self.timeout_ms = timeout;
133 self
134 }
135
136 #[must_use]
141 pub fn max_retries(mut self, retries: u32) -> Self {
142 self.max_retries = retries;
143 self
144 }
145
146 #[must_use]
150 pub fn retry_delay_ms(mut self, delay: u64) -> Self {
151 self.retry_delay_ms = delay;
152 self
153 }
154
155 #[must_use]
159 pub fn working_dir(mut self, path: impl Into<PathBuf>) -> Self {
160 self.working_dir = Some(path.into());
161 self
162 }
163
164 #[must_use]
168 pub fn env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
169 self.env_vars.insert(key.into(), value.into());
170 self
171 }
172
173 #[must_use]
175 pub fn envs<I, K, V>(mut self, vars: I) -> Self
176 where
177 I: IntoIterator<Item = (K, V)>,
178 K: Into<String>,
179 V: Into<String>,
180 {
181 for (key, value) in vars {
182 self.env_vars.insert(key.into(), value.into());
183 }
184 self
185 }
186
187 #[must_use]
194 pub fn inherit_env(mut self, inherit: bool) -> Self {
195 self.inherit_env = inherit;
196 self
197 }
198
199 #[must_use]
201 pub fn capabilities(mut self, capabilities: ClientCapabilities) -> Self {
202 self.capabilities = capabilities;
203 self
204 }
205
206 #[must_use]
226 pub fn auto_initialize(mut self, enabled: bool) -> Self {
227 self.auto_initialize = enabled;
228 self
229 }
230
231 pub fn connect_stdio(self, command: &str, args: &[&str]) -> McpResult<Client> {
248 self.connect_stdio_with_cx(command, args, &Cx::for_testing())
249 }
250
251 pub fn connect_stdio_with_cx(self, command: &str, args: &[&str], cx: &Cx) -> McpResult<Client> {
256 let mut last_error = None;
257 let attempts = self.max_retries + 1;
258
259 for attempt in 0..attempts {
260 if attempt > 0 {
261 std::thread::sleep(std::time::Duration::from_millis(self.retry_delay_ms));
263 }
264
265 match self.try_connect(command, args, cx) {
266 Ok(client) => return Ok(client),
267 Err(e) => {
268 last_error = Some(e);
269 }
270 }
271 }
272
273 Err(last_error.unwrap_or_else(|| McpError::internal_error("Connection failed")))
275 }
276
277 fn try_connect(&self, command: &str, args: &[&str], cx: &Cx) -> McpResult<Client> {
279 let mut cmd = Command::new(command);
281 cmd.args(args)
282 .stdin(Stdio::piped())
283 .stdout(Stdio::piped())
284 .stderr(Stdio::inherit());
285
286 if let Some(ref dir) = self.working_dir {
288 cmd.current_dir(dir);
289 }
290
291 if !self.inherit_env {
293 cmd.env_clear();
294 }
295 for (key, value) in &self.env_vars {
296 cmd.env(key, value);
297 }
298
299 let mut child = cmd
301 .spawn()
302 .map_err(|e| McpError::internal_error(format!("Failed to spawn subprocess: {e}")))?;
303
304 let stdin = child
306 .stdin
307 .take()
308 .ok_or_else(|| McpError::internal_error("Failed to get subprocess stdin"))?;
309 let stdout = child
310 .stdout
311 .take()
312 .ok_or_else(|| McpError::internal_error("Failed to get subprocess stdout"))?;
313
314 let transport = StdioTransport::new(stdout, stdin);
316
317 if self.auto_initialize {
318 Ok(self.create_uninitialized_client(child, transport, cx))
320 } else {
321 self.initialize_client(child, transport, cx)
323 }
324 }
325
326 fn create_uninitialized_client(
328 &self,
329 child: Child,
330 transport: StdioTransport<std::process::ChildStdout, std::process::ChildStdin>,
331 cx: &Cx,
332 ) -> Client {
333 let session = ClientSession::new(
335 self.client_info.clone(),
336 self.capabilities.clone(),
337 fastmcp_protocol::ServerInfo {
338 name: String::new(),
339 version: String::new(),
340 },
341 fastmcp_protocol::ServerCapabilities::default(),
342 String::new(),
343 );
344
345 Client::from_parts_uninitialized(child, transport, cx.clone(), session, self.timeout_ms)
346 }
347
348 fn initialize_client(
350 &self,
351 child: Child,
352 mut transport: StdioTransport<std::process::ChildStdout, std::process::ChildStdin>,
353 cx: &Cx,
354 ) -> McpResult<Client> {
355 let child_guard = ChildGuard::new(child);
358
359 let init_params = InitializeParams {
361 protocol_version: PROTOCOL_VERSION.to_string(),
362 capabilities: self.capabilities.clone(),
363 client_info: self.client_info.clone(),
364 };
365
366 let init_request = JsonRpcRequest::new(
367 "initialize",
368 Some(serde_json::to_value(&init_params).map_err(|e| {
369 McpError::internal_error(format!("Failed to serialize params: {e}"))
370 })?),
371 1i64,
372 );
373
374 transport
375 .send(cx, &JsonRpcMessage::Request(init_request))
376 .map_err(|e| McpError::internal_error(format!("Failed to send initialize: {e}")))?;
377
378 let response = loop {
380 let msg = transport.recv(cx).map_err(|e| {
381 McpError::internal_error(format!("Failed to receive response: {e}"))
382 })?;
383
384 match msg {
385 JsonRpcMessage::Response(resp) => break resp,
386 JsonRpcMessage::Request(_) => {
387 }
389 }
390 };
391
392 if let Some(error) = response.error {
394 return Err(McpError::new(
395 fastmcp_core::McpErrorCode::Custom(error.code),
396 error.message,
397 ));
398 }
399
400 let result_value = response
402 .result
403 .ok_or_else(|| McpError::internal_error("No result in initialize response"))?;
404
405 let init_result: InitializeResult = serde_json::from_value(result_value).map_err(|e| {
406 McpError::internal_error(format!("Failed to parse initialize result: {e}"))
407 })?;
408
409 let initialized_request = JsonRpcRequest {
411 jsonrpc: std::borrow::Cow::Borrowed(fastmcp_protocol::JSONRPC_VERSION),
412 method: "initialized".to_string(),
413 params: Some(serde_json::json!({})),
414 id: None,
415 };
416
417 transport
418 .send(cx, &JsonRpcMessage::Request(initialized_request))
419 .map_err(|e| McpError::internal_error(format!("Failed to send initialized: {e}")))?;
420
421 let session = ClientSession::new(
423 self.client_info.clone(),
424 self.capabilities.clone(),
425 init_result.server_info,
426 init_result.capabilities,
427 init_result.protocol_version,
428 );
429
430 Ok(Client::from_parts(
432 child_guard.disarm(),
433 transport,
434 cx.clone(),
435 session,
436 self.timeout_ms,
437 ))
438 }
439}
440
441impl Default for ClientBuilder {
442 fn default() -> Self {
443 Self::new()
444 }
445}
446
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_builder_defaults() {
        let builder = ClientBuilder::new();
        assert_eq!(builder.client_info.name, "fastmcp-client");
        assert_eq!(builder.timeout_ms, 30_000);
        assert_eq!(builder.max_retries, 0);
        assert_eq!(builder.retry_delay_ms, 1_000);
        assert!(builder.inherit_env);
        assert!(builder.working_dir.is_none());
        assert!(builder.env_vars.is_empty());
        assert!(!builder.auto_initialize);
    }

    #[test]
    fn test_builder_fluent_api() {
        let builder = ClientBuilder::new()
            .client_info("test-client", "2.0.0")
            .timeout_ms(60_000)
            .max_retries(3)
            .retry_delay_ms(500)
            .working_dir("/tmp")
            .env("FOO", "bar")
            .env("BAZ", "qux")
            .inherit_env(false);

        assert_eq!(builder.client_info.name, "test-client");
        assert_eq!(builder.client_info.version, "2.0.0");
        assert_eq!(builder.timeout_ms, 60_000);
        assert_eq!(builder.max_retries, 3);
        assert_eq!(builder.retry_delay_ms, 500);
        assert_eq!(builder.working_dir, Some(PathBuf::from("/tmp")));
        assert_eq!(builder.env_vars.get("FOO"), Some(&"bar".to_string()));
        assert_eq!(builder.env_vars.get("BAZ"), Some(&"qux".to_string()));
        assert!(!builder.inherit_env);
    }

    #[test]
    fn test_builder_envs() {
        let vars = [("KEY1", "value1"), ("KEY2", "value2")];
        let builder = ClientBuilder::new().envs(vars);

        assert_eq!(builder.env_vars.get("KEY1"), Some(&"value1".to_string()));
        assert_eq!(builder.env_vars.get("KEY2"), Some(&"value2".to_string()));
    }

    #[test]
    fn test_builder_clone() {
        let builder1 = ClientBuilder::new()
            .client_info("test", "1.0")
            .timeout_ms(5000);

        let builder2 = builder1.clone();

        assert_eq!(builder2.client_info.name, "test");
        assert_eq!(builder2.timeout_ms, 5000);
    }

    #[test]
    fn test_builder_auto_initialize() {
        let builder = ClientBuilder::new().auto_initialize(true);
        assert!(builder.auto_initialize);

        let builder = ClientBuilder::new().auto_initialize(false);
        assert!(!builder.auto_initialize);
    }

    #[test]
    fn test_builder_capabilities() {
        let caps = ClientCapabilities {
            sampling: Some(fastmcp_protocol::SamplingCapability {}),
            elicitation: None,
            roots: None,
        };
        let builder = ClientBuilder::new().capabilities(caps);
        assert!(builder.capabilities.sampling.is_some());
        assert!(builder.capabilities.elicitation.is_none());
        assert!(builder.capabilities.roots.is_none());
    }

    #[test]
    fn test_builder_default_trait() {
        let builder = ClientBuilder::default();
        assert_eq!(builder.client_info.name, "fastmcp-client");
        assert_eq!(builder.timeout_ms, 30_000);
        assert_eq!(builder.max_retries, 0);
        assert!(!builder.auto_initialize);
    }

    #[test]
    fn test_builder_env_override() {
        let builder = ClientBuilder::new()
            .env("KEY", "first")
            .env("KEY", "second");
        assert_eq!(builder.env_vars.get("KEY"), Some(&"second".to_string()));
    }

    #[test]
    fn test_builder_envs_combined_with_env() {
        let builder = ClientBuilder::new()
            .env("A", "1")
            .envs([("B", "2"), ("C", "3")])
            .env("D", "4");
        assert_eq!(builder.env_vars.len(), 4);
        assert_eq!(builder.env_vars.get("A"), Some(&"1".to_string()));
        assert_eq!(builder.env_vars.get("B"), Some(&"2".to_string()));
        assert_eq!(builder.env_vars.get("C"), Some(&"3".to_string()));
        assert_eq!(builder.env_vars.get("D"), Some(&"4".to_string()));
    }

    /// Exercises the retry loop end to end: every attempt fails at spawn, and
    /// the final error is surfaced instead of a panic or a hang.
    #[test]
    fn test_connect_stdio_reports_spawn_failure_after_retries() {
        let result = ClientBuilder::new()
            .max_retries(2)
            .retry_delay_ms(0)
            .connect_stdio("fastmcp-client-test-binary-that-does-not-exist", &[]);
        assert!(result.is_err());
    }
}