Skip to main content

turbomcp_cli/
transport.rs

1//! Transport factory and auto-detection
2
3use crate::cli::{Connection, TransportKind};
4use crate::error::{CliError, CliResult};
5use std::collections::HashMap;
6use std::time::Duration;
7use turbomcp_client::Client;
8use turbomcp_protocol::types::Tool;
9
10#[cfg(feature = "stdio")]
11use turbomcp_transport::child_process::{ChildProcessConfig, ChildProcessTransport};
12
13#[cfg(feature = "tcp")]
14use turbomcp_transport::tcp::TcpTransportBuilder;
15
16#[cfg(feature = "unix")]
17use turbomcp_transport::unix::UnixTransportBuilder;
18
19#[cfg(feature = "http")]
20use turbomcp_transport::streamable_http_client::{
21    StreamableHttpClientConfig, StreamableHttpClientTransport,
22};
23
24#[cfg(feature = "websocket")]
25use turbomcp_transport::{WebSocketBidirectionalConfig, WebSocketBidirectionalTransport};
26
/// Transport-agnostic MCP client handed to the CLI executor.
///
/// `turbomcp_client::Client` is generic over its transport, so every supported
/// transport produces a different concrete client type. This wrapper hides that
/// difference behind a single type so callers never name a transport.
pub struct UnifiedClient {
    inner: ClientInner,
}

/// The concrete client behind a [`UnifiedClient`].
///
/// There is one variant per transport, and each variant only exists when the cargo
/// feature enabling that transport is compiled in.
enum ClientInner {
    /// Child-process (STDIO) transport.
    #[cfg(feature = "stdio")]
    Stdio(Client<ChildProcessTransport>),
    /// Raw TCP socket transport.
    #[cfg(feature = "tcp")]
    Tcp(Client<turbomcp_transport::tcp::TcpTransport>),
    /// Unix-domain socket transport.
    #[cfg(feature = "unix")]
    Unix(Client<turbomcp_transport::unix::UnixTransport>),
    /// Streamable HTTP client transport.
    #[cfg(feature = "http")]
    Http(Client<StreamableHttpClientTransport>),
    /// Bidirectional WebSocket transport.
    #[cfg(feature = "websocket")]
    WebSocket(Client<WebSocketBidirectionalTransport>),
}
44
45impl UnifiedClient {
46    pub async fn initialize(&self) -> CliResult<turbomcp_client::InitializeResult> {
47        match &self.inner {
48            #[cfg(feature = "stdio")]
49            ClientInner::Stdio(client) => Ok(client.initialize().await?),
50            #[cfg(feature = "tcp")]
51            ClientInner::Tcp(client) => Ok(client.initialize().await?),
52            #[cfg(feature = "unix")]
53            ClientInner::Unix(client) => Ok(client.initialize().await?),
54            #[cfg(feature = "http")]
55            ClientInner::Http(client) => Ok(client.initialize().await?),
56            #[cfg(feature = "websocket")]
57            ClientInner::WebSocket(client) => Ok(client.initialize().await?),
58        }
59    }
60
61    pub async fn list_tools(&self) -> CliResult<Vec<Tool>> {
62        match &self.inner {
63            #[cfg(feature = "stdio")]
64            ClientInner::Stdio(client) => Ok(client.list_tools().await?),
65            #[cfg(feature = "tcp")]
66            ClientInner::Tcp(client) => Ok(client.list_tools().await?),
67            #[cfg(feature = "unix")]
68            ClientInner::Unix(client) => Ok(client.list_tools().await?),
69            #[cfg(feature = "http")]
70            ClientInner::Http(client) => Ok(client.list_tools().await?),
71            #[cfg(feature = "websocket")]
72            ClientInner::WebSocket(client) => Ok(client.list_tools().await?),
73        }
74    }
75
76    pub async fn call_tool(
77        &self,
78        name: &str,
79        arguments: Option<HashMap<String, serde_json::Value>>,
80    ) -> CliResult<serde_json::Value> {
81        let result = match &self.inner {
82            #[cfg(feature = "stdio")]
83            ClientInner::Stdio(client) => client.call_tool(name, arguments, None).await?,
84            #[cfg(feature = "tcp")]
85            ClientInner::Tcp(client) => client.call_tool(name, arguments, None).await?,
86            #[cfg(feature = "unix")]
87            ClientInner::Unix(client) => client.call_tool(name, arguments, None).await?,
88            #[cfg(feature = "http")]
89            ClientInner::Http(client) => client.call_tool(name, arguments, None).await?,
90            #[cfg(feature = "websocket")]
91            ClientInner::WebSocket(client) => client.call_tool(name, arguments, None).await?,
92        };
93
94        // Serialize CallToolResult to JSON for CLI display
95        Ok(serde_json::to_value(result)?)
96    }
97
98    pub async fn list_resources(&self) -> CliResult<Vec<turbomcp_protocol::types::Resource>> {
99        match &self.inner {
100            #[cfg(feature = "stdio")]
101            ClientInner::Stdio(client) => Ok(client.list_resources().await?),
102            #[cfg(feature = "tcp")]
103            ClientInner::Tcp(client) => Ok(client.list_resources().await?),
104            #[cfg(feature = "unix")]
105            ClientInner::Unix(client) => Ok(client.list_resources().await?),
106            #[cfg(feature = "http")]
107            ClientInner::Http(client) => Ok(client.list_resources().await?),
108            #[cfg(feature = "websocket")]
109            ClientInner::WebSocket(client) => Ok(client.list_resources().await?),
110        }
111    }
112
113    pub async fn read_resource(
114        &self,
115        uri: &str,
116    ) -> CliResult<turbomcp_protocol::types::ReadResourceResult> {
117        match &self.inner {
118            #[cfg(feature = "stdio")]
119            ClientInner::Stdio(client) => Ok(client.read_resource(uri).await?),
120            #[cfg(feature = "tcp")]
121            ClientInner::Tcp(client) => Ok(client.read_resource(uri).await?),
122            #[cfg(feature = "unix")]
123            ClientInner::Unix(client) => Ok(client.read_resource(uri).await?),
124            #[cfg(feature = "http")]
125            ClientInner::Http(client) => Ok(client.read_resource(uri).await?),
126            #[cfg(feature = "websocket")]
127            ClientInner::WebSocket(client) => Ok(client.read_resource(uri).await?),
128        }
129    }
130
131    pub async fn list_resource_templates(&self) -> CliResult<Vec<String>> {
132        match &self.inner {
133            #[cfg(feature = "stdio")]
134            ClientInner::Stdio(client) => Ok(client.list_resource_templates().await?),
135            #[cfg(feature = "tcp")]
136            ClientInner::Tcp(client) => Ok(client.list_resource_templates().await?),
137            #[cfg(feature = "unix")]
138            ClientInner::Unix(client) => Ok(client.list_resource_templates().await?),
139            #[cfg(feature = "http")]
140            ClientInner::Http(client) => Ok(client.list_resource_templates().await?),
141            #[cfg(feature = "websocket")]
142            ClientInner::WebSocket(client) => Ok(client.list_resource_templates().await?),
143        }
144    }
145
146    pub async fn subscribe(&self, uri: &str) -> CliResult<turbomcp_protocol::types::EmptyResult> {
147        match &self.inner {
148            #[cfg(feature = "stdio")]
149            ClientInner::Stdio(client) => Ok(client.subscribe(uri).await?),
150            #[cfg(feature = "tcp")]
151            ClientInner::Tcp(client) => Ok(client.subscribe(uri).await?),
152            #[cfg(feature = "unix")]
153            ClientInner::Unix(client) => Ok(client.subscribe(uri).await?),
154            #[cfg(feature = "http")]
155            ClientInner::Http(client) => Ok(client.subscribe(uri).await?),
156            #[cfg(feature = "websocket")]
157            ClientInner::WebSocket(client) => Ok(client.subscribe(uri).await?),
158        }
159    }
160
161    pub async fn unsubscribe(&self, uri: &str) -> CliResult<turbomcp_protocol::types::EmptyResult> {
162        match &self.inner {
163            #[cfg(feature = "stdio")]
164            ClientInner::Stdio(client) => Ok(client.unsubscribe(uri).await?),
165            #[cfg(feature = "tcp")]
166            ClientInner::Tcp(client) => Ok(client.unsubscribe(uri).await?),
167            #[cfg(feature = "unix")]
168            ClientInner::Unix(client) => Ok(client.unsubscribe(uri).await?),
169            #[cfg(feature = "http")]
170            ClientInner::Http(client) => Ok(client.unsubscribe(uri).await?),
171            #[cfg(feature = "websocket")]
172            ClientInner::WebSocket(client) => Ok(client.unsubscribe(uri).await?),
173        }
174    }
175
176    pub async fn list_prompts(&self) -> CliResult<Vec<turbomcp_protocol::types::Prompt>> {
177        match &self.inner {
178            #[cfg(feature = "stdio")]
179            ClientInner::Stdio(client) => Ok(client.list_prompts().await?),
180            #[cfg(feature = "tcp")]
181            ClientInner::Tcp(client) => Ok(client.list_prompts().await?),
182            #[cfg(feature = "unix")]
183            ClientInner::Unix(client) => Ok(client.list_prompts().await?),
184            #[cfg(feature = "http")]
185            ClientInner::Http(client) => Ok(client.list_prompts().await?),
186            #[cfg(feature = "websocket")]
187            ClientInner::WebSocket(client) => Ok(client.list_prompts().await?),
188        }
189    }
190
191    pub async fn get_prompt(
192        &self,
193        name: &str,
194        arguments: Option<HashMap<String, serde_json::Value>>,
195    ) -> CliResult<turbomcp_protocol::types::GetPromptResult> {
196        match &self.inner {
197            #[cfg(feature = "stdio")]
198            ClientInner::Stdio(client) => Ok(client.get_prompt(name, arguments).await?),
199            #[cfg(feature = "tcp")]
200            ClientInner::Tcp(client) => Ok(client.get_prompt(name, arguments).await?),
201            #[cfg(feature = "unix")]
202            ClientInner::Unix(client) => Ok(client.get_prompt(name, arguments).await?),
203            #[cfg(feature = "http")]
204            ClientInner::Http(client) => Ok(client.get_prompt(name, arguments).await?),
205            #[cfg(feature = "websocket")]
206            ClientInner::WebSocket(client) => Ok(client.get_prompt(name, arguments).await?),
207        }
208    }
209
210    pub async fn complete_prompt(
211        &self,
212        prompt_name: &str,
213        argument_name: &str,
214        argument_value: &str,
215        context: Option<turbomcp_protocol::types::CompletionContext>,
216    ) -> CliResult<turbomcp_protocol::types::CompletionResponse> {
217        match &self.inner {
218            #[cfg(feature = "stdio")]
219            ClientInner::Stdio(client) => Ok(client
220                .complete_prompt(prompt_name, argument_name, argument_value, context)
221                .await?),
222            #[cfg(feature = "tcp")]
223            ClientInner::Tcp(client) => Ok(client
224                .complete_prompt(prompt_name, argument_name, argument_value, context)
225                .await?),
226            #[cfg(feature = "unix")]
227            ClientInner::Unix(client) => Ok(client
228                .complete_prompt(prompt_name, argument_name, argument_value, context)
229                .await?),
230            #[cfg(feature = "http")]
231            ClientInner::Http(client) => Ok(client
232                .complete_prompt(prompt_name, argument_name, argument_value, context)
233                .await?),
234            #[cfg(feature = "websocket")]
235            ClientInner::WebSocket(client) => Ok(client
236                .complete_prompt(prompt_name, argument_name, argument_value, context)
237                .await?),
238        }
239    }
240
241    pub async fn complete_resource(
242        &self,
243        resource_uri: &str,
244        argument_name: &str,
245        argument_value: &str,
246        context: Option<turbomcp_protocol::types::CompletionContext>,
247    ) -> CliResult<turbomcp_protocol::types::CompletionResponse> {
248        match &self.inner {
249            #[cfg(feature = "stdio")]
250            ClientInner::Stdio(client) => Ok(client
251                .complete_resource(resource_uri, argument_name, argument_value, context)
252                .await?),
253            #[cfg(feature = "tcp")]
254            ClientInner::Tcp(client) => Ok(client
255                .complete_resource(resource_uri, argument_name, argument_value, context)
256                .await?),
257            #[cfg(feature = "unix")]
258            ClientInner::Unix(client) => Ok(client
259                .complete_resource(resource_uri, argument_name, argument_value, context)
260                .await?),
261            #[cfg(feature = "http")]
262            ClientInner::Http(client) => Ok(client
263                .complete_resource(resource_uri, argument_name, argument_value, context)
264                .await?),
265            #[cfg(feature = "websocket")]
266            ClientInner::WebSocket(client) => Ok(client
267                .complete_resource(resource_uri, argument_name, argument_value, context)
268                .await?),
269        }
270    }
271
272    pub async fn ping(&self) -> CliResult<()> {
273        match &self.inner {
274            #[cfg(feature = "stdio")]
275            ClientInner::Stdio(client) => {
276                client.ping().await?;
277                Ok(())
278            }
279            #[cfg(feature = "tcp")]
280            ClientInner::Tcp(client) => {
281                client.ping().await?;
282                Ok(())
283            }
284            #[cfg(feature = "unix")]
285            ClientInner::Unix(client) => {
286                client.ping().await?;
287                Ok(())
288            }
289            #[cfg(feature = "http")]
290            ClientInner::Http(client) => {
291                client.ping().await?;
292                Ok(())
293            }
294            #[cfg(feature = "websocket")]
295            ClientInner::WebSocket(client) => {
296                client.ping().await?;
297                Ok(())
298            }
299        }
300    }
301
302    pub async fn set_log_level(&self, level: turbomcp_protocol::types::LogLevel) -> CliResult<()> {
303        match &self.inner {
304            #[cfg(feature = "stdio")]
305            ClientInner::Stdio(client) => {
306                client.set_log_level(level).await?;
307                Ok(())
308            }
309            #[cfg(feature = "tcp")]
310            ClientInner::Tcp(client) => {
311                client.set_log_level(level).await?;
312                Ok(())
313            }
314            #[cfg(feature = "unix")]
315            ClientInner::Unix(client) => {
316                client.set_log_level(level).await?;
317                Ok(())
318            }
319            #[cfg(feature = "http")]
320            ClientInner::Http(client) => {
321                client.set_log_level(level).await?;
322                Ok(())
323            }
324            #[cfg(feature = "websocket")]
325            ClientInner::WebSocket(client) => {
326                client.set_log_level(level).await?;
327                Ok(())
328            }
329        }
330    }
331}
332
333/// Create a unified client that hides transport type complexity from the executor
334pub async fn create_client(conn: &Connection) -> CliResult<UnifiedClient> {
335    let transport_kind = determine_transport(conn);
336
337    // The --auth / MCP_AUTH bearer token is only consumed by the HTTP transport.
338    // Warn (without echoing the token) when the user supplies it for a transport
339    // that has no notion of authentication so they don't assume it was sent.
340    if conn.auth.is_some() && !matches!(transport_kind, TransportKind::Http | TransportKind::Ws) {
341        eprintln!(
342            "Warning: --auth is currently only honored by the HTTP transport; ignoring for {:?}.",
343            transport_kind
344        );
345    }
346
347    match transport_kind {
348        #[cfg(feature = "stdio")]
349        TransportKind::Stdio => {
350            let transport = create_stdio_transport(conn)?;
351            Ok(UnifiedClient {
352                inner: ClientInner::Stdio(Client::new(transport)),
353            })
354        }
355        #[cfg(not(feature = "stdio"))]
356        TransportKind::Stdio => {
357            Err(CliError::NotSupported(
358                "STDIO transport is not enabled (missing 'stdio' feature)".to_string(),
359            ))
360        }
361        #[cfg(feature = "http")]
362        TransportKind::Http => {
363            let transport = create_http_transport(conn).await?;
364            Ok(UnifiedClient {
365                inner: ClientInner::Http(Client::new(transport)),
366            })
367        }
368        #[cfg(not(feature = "http"))]
369        TransportKind::Http => {
370            Err(CliError::NotSupported(
371                "HTTP transport is not enabled. Rebuild with --features http or --features all"
372                    .to_string(),
373            ))
374        }
375        #[cfg(feature = "websocket")]
376        TransportKind::Ws => {
377            let transport = create_websocket_transport(conn).await?;
378            Ok(UnifiedClient {
379                inner: ClientInner::WebSocket(Client::new(transport)),
380            })
381        }
382        #[cfg(not(feature = "websocket"))]
383        TransportKind::Ws => {
384            Err(CliError::NotSupported(
385                "WebSocket transport is not enabled. Rebuild with --features websocket or --features all"
386                    .to_string(),
387            ))
388        }
389        #[cfg(feature = "tcp")]
390        TransportKind::Tcp => {
391            let transport = create_tcp_transport(conn).await?;
392            Ok(UnifiedClient {
393                inner: ClientInner::Tcp(Client::new(transport)),
394            })
395        }
396        #[cfg(not(feature = "tcp"))]
397        TransportKind::Tcp => {
398            Err(CliError::NotSupported(
399                "TCP transport is not enabled (missing 'tcp' feature)".to_string(),
400            ))
401        }
402        #[cfg(feature = "unix")]
403        TransportKind::Unix => {
404            let transport = create_unix_transport(conn).await?;
405            Ok(UnifiedClient {
406                inner: ClientInner::Unix(Client::new(transport)),
407            })
408        }
409        #[cfg(not(feature = "unix"))]
410        TransportKind::Unix => {
411            Err(CliError::NotSupported(
412                "Unix socket transport is not enabled (missing 'unix' feature)".to_string(),
413            ))
414        }
415    }
416}
417
418/// Determine transport type from connection config
419pub fn determine_transport(conn: &Connection) -> TransportKind {
420    // Use explicit transport if provided
421    if let Some(transport) = &conn.transport {
422        return transport.clone();
423    }
424
425    // Auto-detect based on URL/command patterns
426    let url = &conn.url;
427
428    if conn.command.is_some() {
429        return TransportKind::Stdio;
430    }
431
432    if url.starts_with("tcp://") {
433        return TransportKind::Tcp;
434    }
435
436    if url.starts_with("unix://") || url.starts_with("/") {
437        return TransportKind::Unix;
438    }
439
440    if url.starts_with("ws://") || url.starts_with("wss://") {
441        return TransportKind::Ws;
442    }
443
444    if url.starts_with("http://") || url.starts_with("https://") {
445        return TransportKind::Http;
446    }
447
448    // Default to STDIO for executable paths
449    TransportKind::Stdio
450}
451
/// Create a child-process (STDIO) transport from `conn`.
///
/// The command line is taken from `--command` when present, otherwise from `--url`.
/// It is tokenized with shell-style quoting and escaping (via `shell_words`), so
/// quoted paths containing spaces and `bash -c "..."` wrappers stay intact. The
/// first token is the program to spawn and the rest are its arguments.
///
/// # Errors
///
/// Returns `CliError::InvalidArguments` if the command line has invalid quoting or
/// contains no tokens at all.
#[cfg(feature = "stdio")]
fn create_stdio_transport(conn: &Connection) -> CliResult<ChildProcessTransport> {
    // Track which flag supplied the command line so errors point the user at it.
    let (command_str, source_flag) = match conn.command.as_deref() {
        Some(command) => (command, "--command"),
        None => (conn.url.as_str(), "--url"),
    };

    // Honor shell quoting/escaping; `split_whitespace` would fragment quoted tokens.
    let parts = shell_words::split(command_str)
        .map_err(|e| CliError::InvalidArguments(format!("Invalid {source_flag} quoting: {e}")))?;

    // Move tokens out of the vector instead of cloning them: the first is the
    // program, everything after it is an argument.
    let mut tokens = parts.into_iter();
    let command = tokens.next().ok_or_else(|| {
        CliError::InvalidArguments("No command specified for STDIO transport".to_string())
    })?;
    let args: Vec<String> = tokens.collect();

    let config = ChildProcessConfig {
        command,
        args,
        working_directory: None,
        environment: None,
        // Reuse the connection timeout as the child's startup timeout.
        startup_timeout: Duration::from_secs(conn.timeout),
        shutdown_timeout: Duration::from_secs(5),
        max_message_size: 10 * 1024 * 1024, // 10MB
        buffer_size: 8192,                  // 8KB buffer
        kill_on_drop: true,                 // Kill process when client is dropped
    };

    Ok(ChildProcessTransport::new(config))
}
487
/// Create a TCP transport from a `tcp://host:port` URL in `conn.url`.
///
/// `host` may be an IP literal (`127.0.0.1`, `[::1]`) or a hostname such as
/// `localhost`. Hostnames are resolved and the first address returned is used.
///
/// # Errors
///
/// Returns `CliError::InvalidArguments` in either of these cases:
/// * the URL lacks the `tcp://` scheme;
/// * the address cannot be parsed or resolved.
#[cfg(feature = "tcp")]
async fn create_tcp_transport(
    conn: &Connection,
) -> CliResult<turbomcp_transport::tcp::TcpTransport> {
    let url = &conn.url;

    let addr_str = url
        .strip_prefix("tcp://")
        .ok_or_else(|| CliError::InvalidArguments(format!("Invalid TCP URL: {}", url)))?;

    let socket_addr = resolve_tcp_addr(addr_str)?;

    Ok(TcpTransportBuilder::new().remote_addr(socket_addr).build())
}

/// Turn `host:port` into a concrete `SocketAddr`.
///
/// IP literals are parsed directly, with no I/O. Anything else (e.g.
/// `localhost:8080`, which `SocketAddr::from_str` rejects) falls back to
/// `ToSocketAddrs`, which performs a blocking name lookup. That is tolerable for
/// one-shot CLI connection setup but should not be moved onto a hot async path.
#[cfg(feature = "tcp")]
fn resolve_tcp_addr(addr_str: &str) -> CliResult<std::net::SocketAddr> {
    use std::net::{SocketAddr, ToSocketAddrs};

    if let Ok(addr) = addr_str.parse::<SocketAddr>() {
        return Ok(addr);
    }

    let mut candidates = addr_str.to_socket_addrs().map_err(|e| {
        CliError::InvalidArguments(format!("Invalid address '{}': {}", addr_str, e))
    })?;

    candidates.next().ok_or_else(|| {
        CliError::InvalidArguments(format!(
            "Address '{}' did not resolve to any socket address",
            addr_str
        ))
    })
}
509
/// Create a Unix-domain-socket client transport from `conn`.
///
/// `conn.url` may be either a `unix://`-prefixed URL or a bare filesystem path. A
/// leading `unix://` is removed, and the remainder is used as the socket path.
#[cfg(feature = "unix")]
async fn create_unix_transport(
    conn: &Connection,
) -> CliResult<turbomcp_transport::unix::UnixTransport> {
    let socket_path: &str = match conn.url.strip_prefix("unix://") {
        Some(without_scheme) => without_scheme,
        None => &conn.url,
    };

    Ok(UnixTransportBuilder::new_client()
        .socket_path(socket_path)
        .build())
}
521
/// Create a Streamable HTTP client transport from `conn`.
///
/// `conn.url` is used as-is as the base URL, and requests are directed at the fixed
/// `/mcp` endpoint path. The connection timeout and the optional `--auth` bearer
/// token are forwarded into the transport config; all other settings use the
/// config's `Default` values.
///
/// # Errors
///
/// Returns `CliError::Transport` if the transport rejects the config.
#[cfg(feature = "http")]
async fn create_http_transport(conn: &Connection) -> CliResult<StreamableHttpClientTransport> {
    // No scheme rewriting is needed: the URL (including its `http://` or
    // `https://` prefix) is handed to the transport unchanged.
    let config = StreamableHttpClientConfig {
        base_url: conn.url.clone(),
        endpoint_path: "/mcp".to_string(),
        timeout: Duration::from_secs(conn.timeout),
        auth_token: conn.auth.clone(),
        ..Default::default()
    };

    StreamableHttpClientTransport::new(config).map_err(|e| {
        CliError::Transport(turbomcp_protocol::Error::transport(format!(
            "Failed to build HTTP transport: {e}"
        )))
    })
}
550
/// Create a bidirectional WebSocket client transport for `conn.url`.
///
/// Only the URL is used; `conn.auth` and `conn.timeout` are not forwarded to the
/// WebSocket config.
///
/// # Errors
///
/// * `CliError::InvalidArguments` if the URL does not start with `ws://` or `wss://`.
/// * `CliError::ConnectionFailed` if the async transport constructor fails.
#[cfg(feature = "websocket")]
async fn create_websocket_transport(
    conn: &Connection,
) -> CliResult<WebSocketBidirectionalTransport> {
    let url = &conn.url;
    let has_ws_scheme = ["ws://", "wss://"]
        .iter()
        .any(|scheme| url.starts_with(*scheme));

    if !has_ws_scheme {
        return Err(CliError::InvalidArguments(format!(
            "Invalid WebSocket URL: {} (must start with ws:// or wss://)",
            url
        )));
    }

    let ws_config = WebSocketBidirectionalConfig::client(url.clone());
    WebSocketBidirectionalTransport::new(ws_config)
        .await
        .map_err(|err| CliError::ConnectionFailed(err.to_string()))
}
572
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `Connection` with the fields transport detection inspects.
    ///
    /// `auth` and `timeout` are fixed because detection ignores them.
    fn conn(transport: Option<TransportKind>, url: &str, command: Option<&str>) -> Connection {
        Connection {
            transport,
            url: url.to_string(),
            command: command.map(str::to_string),
            auth: None,
            timeout: 30,
        }
    }

    #[test]
    fn test_determine_transport() {
        // STDIO detection
        assert_eq!(
            determine_transport(&conn(None, "./my-server", None)),
            TransportKind::Stdio
        );

        // Command override
        assert_eq!(
            determine_transport(&conn(None, "http://localhost", Some("python server.py"))),
            TransportKind::Stdio
        );

        // TCP detection
        assert_eq!(
            determine_transport(&conn(None, "tcp://localhost:8080", None)),
            TransportKind::Tcp
        );

        // Unix detection
        assert_eq!(
            determine_transport(&conn(None, "/tmp/mcp.sock", None)),
            TransportKind::Unix
        );

        // Explicit override
        assert_eq!(
            determine_transport(&conn(Some(TransportKind::Tcp), "http://localhost", None)),
            TransportKind::Tcp
        );
    }

    #[test]
    fn test_determine_transport_url_schemes() {
        let cases = [
            ("ws://localhost:8080", TransportKind::Ws),
            ("wss://example.com/mcp", TransportKind::Ws),
            ("http://localhost:8080", TransportKind::Http),
            ("https://example.com/mcp", TransportKind::Http),
            ("unix:///tmp/mcp.sock", TransportKind::Unix),
            ("python", TransportKind::Stdio),
        ];
        for (url, expected) in cases.iter() {
            assert_eq!(
                determine_transport(&conn(None, url, None)),
                *expected,
                "url: {}",
                url
            );
        }
    }

    #[test]
    fn test_explicit_transport_beats_command() {
        // An explicit transport choice wins even when a command is also supplied.
        assert_eq!(
            determine_transport(&conn(
                Some(TransportKind::Http),
                "./server",
                Some("python server.py")
            )),
            TransportKind::Http
        );
    }
}