Skip to main content

turbomcp_cli/
transport.rs

1//! Transport factory and auto-detection
2
3use crate::cli::{Connection, TransportKind};
4use crate::error::{CliError, CliResult};
5use std::collections::HashMap;
6use std::time::Duration;
7use turbomcp_client::Client;
8use turbomcp_protocol::types::Tool;
9
10#[cfg(feature = "stdio")]
11use turbomcp_transport::child_process::{ChildProcessConfig, ChildProcessTransport};
12
13#[cfg(feature = "tcp")]
14use turbomcp_transport::tcp::TcpTransportBuilder;
15
16#[cfg(feature = "unix")]
17use turbomcp_transport::unix::UnixTransportBuilder;
18
19#[cfg(feature = "http")]
20use turbomcp_transport::streamable_http_client::{
21    StreamableHttpClientConfig, StreamableHttpClientTransport,
22};
23
24#[cfg(feature = "websocket")]
25use turbomcp_transport::{WebSocketBidirectionalConfig, WebSocketBidirectionalTransport};
26
27/// Wrapper for unified client operations, hiding transport implementation details
28pub struct UnifiedClient {
29    inner: ClientInner,
30}
31
32enum ClientInner {
33    #[cfg(feature = "stdio")]
34    Stdio(Client<ChildProcessTransport>),
35    #[cfg(feature = "tcp")]
36    Tcp(Client<turbomcp_transport::tcp::TcpTransport>),
37    #[cfg(feature = "unix")]
38    Unix(Client<turbomcp_transport::unix::UnixTransport>),
39    #[cfg(feature = "http")]
40    Http(Client<StreamableHttpClientTransport>),
41    #[cfg(feature = "websocket")]
42    WebSocket(Client<WebSocketBidirectionalTransport>),
43}
44
45impl UnifiedClient {
46    pub async fn initialize(&self) -> CliResult<turbomcp_client::InitializeResult> {
47        match &self.inner {
48            #[cfg(feature = "stdio")]
49            ClientInner::Stdio(client) => Ok(client.initialize().await?),
50            #[cfg(feature = "tcp")]
51            ClientInner::Tcp(client) => Ok(client.initialize().await?),
52            #[cfg(feature = "unix")]
53            ClientInner::Unix(client) => Ok(client.initialize().await?),
54            #[cfg(feature = "http")]
55            ClientInner::Http(client) => Ok(client.initialize().await?),
56            #[cfg(feature = "websocket")]
57            ClientInner::WebSocket(client) => Ok(client.initialize().await?),
58        }
59    }
60
61    pub async fn list_tools(&self) -> CliResult<Vec<Tool>> {
62        match &self.inner {
63            #[cfg(feature = "stdio")]
64            ClientInner::Stdio(client) => Ok(client.list_tools().await?),
65            #[cfg(feature = "tcp")]
66            ClientInner::Tcp(client) => Ok(client.list_tools().await?),
67            #[cfg(feature = "unix")]
68            ClientInner::Unix(client) => Ok(client.list_tools().await?),
69            #[cfg(feature = "http")]
70            ClientInner::Http(client) => Ok(client.list_tools().await?),
71            #[cfg(feature = "websocket")]
72            ClientInner::WebSocket(client) => Ok(client.list_tools().await?),
73        }
74    }
75
76    pub async fn call_tool(
77        &self,
78        name: &str,
79        arguments: Option<HashMap<String, serde_json::Value>>,
80    ) -> CliResult<serde_json::Value> {
81        let result = match &self.inner {
82            #[cfg(feature = "stdio")]
83            ClientInner::Stdio(client) => client.call_tool(name, arguments, None).await?,
84            #[cfg(feature = "tcp")]
85            ClientInner::Tcp(client) => client.call_tool(name, arguments, None).await?,
86            #[cfg(feature = "unix")]
87            ClientInner::Unix(client) => client.call_tool(name, arguments, None).await?,
88            #[cfg(feature = "http")]
89            ClientInner::Http(client) => client.call_tool(name, arguments, None).await?,
90            #[cfg(feature = "websocket")]
91            ClientInner::WebSocket(client) => client.call_tool(name, arguments, None).await?,
92        };
93
94        // Serialize CallToolResult to JSON for CLI display
95        Ok(serde_json::to_value(result)?)
96    }
97
98    pub async fn list_resources(&self) -> CliResult<Vec<turbomcp_protocol::types::Resource>> {
99        match &self.inner {
100            #[cfg(feature = "stdio")]
101            ClientInner::Stdio(client) => Ok(client.list_resources().await?),
102            #[cfg(feature = "tcp")]
103            ClientInner::Tcp(client) => Ok(client.list_resources().await?),
104            #[cfg(feature = "unix")]
105            ClientInner::Unix(client) => Ok(client.list_resources().await?),
106            #[cfg(feature = "http")]
107            ClientInner::Http(client) => Ok(client.list_resources().await?),
108            #[cfg(feature = "websocket")]
109            ClientInner::WebSocket(client) => Ok(client.list_resources().await?),
110        }
111    }
112
113    pub async fn read_resource(
114        &self,
115        uri: &str,
116    ) -> CliResult<turbomcp_protocol::types::ReadResourceResult> {
117        match &self.inner {
118            #[cfg(feature = "stdio")]
119            ClientInner::Stdio(client) => Ok(client.read_resource(uri).await?),
120            #[cfg(feature = "tcp")]
121            ClientInner::Tcp(client) => Ok(client.read_resource(uri).await?),
122            #[cfg(feature = "unix")]
123            ClientInner::Unix(client) => Ok(client.read_resource(uri).await?),
124            #[cfg(feature = "http")]
125            ClientInner::Http(client) => Ok(client.read_resource(uri).await?),
126            #[cfg(feature = "websocket")]
127            ClientInner::WebSocket(client) => Ok(client.read_resource(uri).await?),
128        }
129    }
130
131    pub async fn list_resource_templates(
132        &self,
133    ) -> CliResult<Vec<turbomcp_protocol::types::ResourceTemplate>> {
134        match &self.inner {
135            #[cfg(feature = "stdio")]
136            ClientInner::Stdio(client) => Ok(client.list_resource_templates().await?),
137            #[cfg(feature = "tcp")]
138            ClientInner::Tcp(client) => Ok(client.list_resource_templates().await?),
139            #[cfg(feature = "unix")]
140            ClientInner::Unix(client) => Ok(client.list_resource_templates().await?),
141            #[cfg(feature = "http")]
142            ClientInner::Http(client) => Ok(client.list_resource_templates().await?),
143            #[cfg(feature = "websocket")]
144            ClientInner::WebSocket(client) => Ok(client.list_resource_templates().await?),
145        }
146    }
147
148    pub async fn subscribe(&self, uri: &str) -> CliResult<turbomcp_protocol::types::EmptyResult> {
149        match &self.inner {
150            #[cfg(feature = "stdio")]
151            ClientInner::Stdio(client) => Ok(client.subscribe(uri).await?),
152            #[cfg(feature = "tcp")]
153            ClientInner::Tcp(client) => Ok(client.subscribe(uri).await?),
154            #[cfg(feature = "unix")]
155            ClientInner::Unix(client) => Ok(client.subscribe(uri).await?),
156            #[cfg(feature = "http")]
157            ClientInner::Http(client) => Ok(client.subscribe(uri).await?),
158            #[cfg(feature = "websocket")]
159            ClientInner::WebSocket(client) => Ok(client.subscribe(uri).await?),
160        }
161    }
162
163    pub async fn unsubscribe(&self, uri: &str) -> CliResult<turbomcp_protocol::types::EmptyResult> {
164        match &self.inner {
165            #[cfg(feature = "stdio")]
166            ClientInner::Stdio(client) => Ok(client.unsubscribe(uri).await?),
167            #[cfg(feature = "tcp")]
168            ClientInner::Tcp(client) => Ok(client.unsubscribe(uri).await?),
169            #[cfg(feature = "unix")]
170            ClientInner::Unix(client) => Ok(client.unsubscribe(uri).await?),
171            #[cfg(feature = "http")]
172            ClientInner::Http(client) => Ok(client.unsubscribe(uri).await?),
173            #[cfg(feature = "websocket")]
174            ClientInner::WebSocket(client) => Ok(client.unsubscribe(uri).await?),
175        }
176    }
177
178    pub async fn list_prompts(&self) -> CliResult<Vec<turbomcp_protocol::types::Prompt>> {
179        match &self.inner {
180            #[cfg(feature = "stdio")]
181            ClientInner::Stdio(client) => Ok(client.list_prompts().await?),
182            #[cfg(feature = "tcp")]
183            ClientInner::Tcp(client) => Ok(client.list_prompts().await?),
184            #[cfg(feature = "unix")]
185            ClientInner::Unix(client) => Ok(client.list_prompts().await?),
186            #[cfg(feature = "http")]
187            ClientInner::Http(client) => Ok(client.list_prompts().await?),
188            #[cfg(feature = "websocket")]
189            ClientInner::WebSocket(client) => Ok(client.list_prompts().await?),
190        }
191    }
192
193    pub async fn get_prompt(
194        &self,
195        name: &str,
196        arguments: Option<HashMap<String, serde_json::Value>>,
197    ) -> CliResult<turbomcp_protocol::types::GetPromptResult> {
198        match &self.inner {
199            #[cfg(feature = "stdio")]
200            ClientInner::Stdio(client) => Ok(client.get_prompt(name, arguments).await?),
201            #[cfg(feature = "tcp")]
202            ClientInner::Tcp(client) => Ok(client.get_prompt(name, arguments).await?),
203            #[cfg(feature = "unix")]
204            ClientInner::Unix(client) => Ok(client.get_prompt(name, arguments).await?),
205            #[cfg(feature = "http")]
206            ClientInner::Http(client) => Ok(client.get_prompt(name, arguments).await?),
207            #[cfg(feature = "websocket")]
208            ClientInner::WebSocket(client) => Ok(client.get_prompt(name, arguments).await?),
209        }
210    }
211
212    pub async fn complete_prompt(
213        &self,
214        prompt_name: &str,
215        argument_name: &str,
216        argument_value: &str,
217        context: Option<turbomcp_protocol::types::CompletionContext>,
218    ) -> CliResult<turbomcp_protocol::types::CompletionResponse> {
219        match &self.inner {
220            #[cfg(feature = "stdio")]
221            ClientInner::Stdio(client) => Ok(client
222                .complete_prompt(prompt_name, argument_name, argument_value, context)
223                .await?),
224            #[cfg(feature = "tcp")]
225            ClientInner::Tcp(client) => Ok(client
226                .complete_prompt(prompt_name, argument_name, argument_value, context)
227                .await?),
228            #[cfg(feature = "unix")]
229            ClientInner::Unix(client) => Ok(client
230                .complete_prompt(prompt_name, argument_name, argument_value, context)
231                .await?),
232            #[cfg(feature = "http")]
233            ClientInner::Http(client) => Ok(client
234                .complete_prompt(prompt_name, argument_name, argument_value, context)
235                .await?),
236            #[cfg(feature = "websocket")]
237            ClientInner::WebSocket(client) => Ok(client
238                .complete_prompt(prompt_name, argument_name, argument_value, context)
239                .await?),
240        }
241    }
242
243    pub async fn complete_resource(
244        &self,
245        resource_uri: &str,
246        argument_name: &str,
247        argument_value: &str,
248        context: Option<turbomcp_protocol::types::CompletionContext>,
249    ) -> CliResult<turbomcp_protocol::types::CompletionResponse> {
250        match &self.inner {
251            #[cfg(feature = "stdio")]
252            ClientInner::Stdio(client) => Ok(client
253                .complete_resource(resource_uri, argument_name, argument_value, context)
254                .await?),
255            #[cfg(feature = "tcp")]
256            ClientInner::Tcp(client) => Ok(client
257                .complete_resource(resource_uri, argument_name, argument_value, context)
258                .await?),
259            #[cfg(feature = "unix")]
260            ClientInner::Unix(client) => Ok(client
261                .complete_resource(resource_uri, argument_name, argument_value, context)
262                .await?),
263            #[cfg(feature = "http")]
264            ClientInner::Http(client) => Ok(client
265                .complete_resource(resource_uri, argument_name, argument_value, context)
266                .await?),
267            #[cfg(feature = "websocket")]
268            ClientInner::WebSocket(client) => Ok(client
269                .complete_resource(resource_uri, argument_name, argument_value, context)
270                .await?),
271        }
272    }
273
274    pub async fn ping(&self) -> CliResult<()> {
275        match &self.inner {
276            #[cfg(feature = "stdio")]
277            ClientInner::Stdio(client) => {
278                client.ping().await?;
279                Ok(())
280            }
281            #[cfg(feature = "tcp")]
282            ClientInner::Tcp(client) => {
283                client.ping().await?;
284                Ok(())
285            }
286            #[cfg(feature = "unix")]
287            ClientInner::Unix(client) => {
288                client.ping().await?;
289                Ok(())
290            }
291            #[cfg(feature = "http")]
292            ClientInner::Http(client) => {
293                client.ping().await?;
294                Ok(())
295            }
296            #[cfg(feature = "websocket")]
297            ClientInner::WebSocket(client) => {
298                client.ping().await?;
299                Ok(())
300            }
301        }
302    }
303
304    pub async fn set_log_level(&self, level: turbomcp_protocol::types::LogLevel) -> CliResult<()> {
305        match &self.inner {
306            #[cfg(feature = "stdio")]
307            ClientInner::Stdio(client) => {
308                client.set_log_level(level).await?;
309                Ok(())
310            }
311            #[cfg(feature = "tcp")]
312            ClientInner::Tcp(client) => {
313                client.set_log_level(level).await?;
314                Ok(())
315            }
316            #[cfg(feature = "unix")]
317            ClientInner::Unix(client) => {
318                client.set_log_level(level).await?;
319                Ok(())
320            }
321            #[cfg(feature = "http")]
322            ClientInner::Http(client) => {
323                client.set_log_level(level).await?;
324                Ok(())
325            }
326            #[cfg(feature = "websocket")]
327            ClientInner::WebSocket(client) => {
328                client.set_log_level(level).await?;
329                Ok(())
330            }
331        }
332    }
333}
334
335/// Create a unified client that hides transport type complexity from the executor
336pub async fn create_client(conn: &Connection) -> CliResult<UnifiedClient> {
337    let transport_kind = determine_transport(conn);
338
339    // The --auth / MCP_AUTH bearer token is only consumed by the HTTP transport.
340    // Warn (without echoing the token) when the user supplies it for a transport
341    // that has no notion of authentication so they don't assume it was sent.
342    if conn.auth.is_some() && !matches!(transport_kind, TransportKind::Http | TransportKind::Ws) {
343        eprintln!(
344            "Warning: --auth is currently only honored by the HTTP transport; ignoring for {:?}.",
345            transport_kind
346        );
347    }
348
349    match transport_kind {
350        #[cfg(feature = "stdio")]
351        TransportKind::Stdio => {
352            let transport = create_stdio_transport(conn)?;
353            Ok(UnifiedClient {
354                inner: ClientInner::Stdio(Client::new(transport)),
355            })
356        }
357        #[cfg(not(feature = "stdio"))]
358        TransportKind::Stdio => {
359            Err(CliError::NotSupported(
360                "STDIO transport is not enabled (missing 'stdio' feature)".to_string(),
361            ))
362        }
363        #[cfg(feature = "http")]
364        TransportKind::Http => {
365            let transport = create_http_transport(conn).await?;
366            Ok(UnifiedClient {
367                inner: ClientInner::Http(Client::new(transport)),
368            })
369        }
370        #[cfg(not(feature = "http"))]
371        TransportKind::Http => {
372            Err(CliError::NotSupported(
373                "HTTP transport is not enabled. Rebuild with --features http or --features all"
374                    .to_string(),
375            ))
376        }
377        #[cfg(feature = "websocket")]
378        TransportKind::Ws => {
379            let transport = create_websocket_transport(conn).await?;
380            Ok(UnifiedClient {
381                inner: ClientInner::WebSocket(Client::new(transport)),
382            })
383        }
384        #[cfg(not(feature = "websocket"))]
385        TransportKind::Ws => {
386            Err(CliError::NotSupported(
387                "WebSocket transport is not enabled. Rebuild with --features websocket or --features all"
388                    .to_string(),
389            ))
390        }
391        #[cfg(feature = "tcp")]
392        TransportKind::Tcp => {
393            let transport = create_tcp_transport(conn).await?;
394            Ok(UnifiedClient {
395                inner: ClientInner::Tcp(Client::new(transport)),
396            })
397        }
398        #[cfg(not(feature = "tcp"))]
399        TransportKind::Tcp => {
400            Err(CliError::NotSupported(
401                "TCP transport is not enabled (missing 'tcp' feature)".to_string(),
402            ))
403        }
404        #[cfg(feature = "unix")]
405        TransportKind::Unix => {
406            let transport = create_unix_transport(conn).await?;
407            Ok(UnifiedClient {
408                inner: ClientInner::Unix(Client::new(transport)),
409            })
410        }
411        #[cfg(not(feature = "unix"))]
412        TransportKind::Unix => {
413            Err(CliError::NotSupported(
414                "Unix socket transport is not enabled (missing 'unix' feature)".to_string(),
415            ))
416        }
417    }
418}
419
420/// Determine transport type from connection config
421pub fn determine_transport(conn: &Connection) -> TransportKind {
422    // Use explicit transport if provided
423    if let Some(transport) = &conn.transport {
424        return transport.clone();
425    }
426
427    // Auto-detect based on URL/command patterns
428    let url = &conn.url;
429
430    if conn.command.is_some() {
431        return TransportKind::Stdio;
432    }
433
434    if url.starts_with("tcp://") {
435        return TransportKind::Tcp;
436    }
437
438    if url.starts_with("unix://") || url.starts_with("/") {
439        return TransportKind::Unix;
440    }
441
442    if url.starts_with("ws://") || url.starts_with("wss://") {
443        return TransportKind::Ws;
444    }
445
446    if url.starts_with("http://") || url.starts_with("https://") {
447        return TransportKind::Http;
448    }
449
450    // Default to STDIO for executable paths
451    TransportKind::Stdio
452}
453
454/// Create STDIO transport from connection
455#[cfg(feature = "stdio")]
456fn create_stdio_transport(conn: &Connection) -> CliResult<ChildProcessTransport> {
457    // Use --command if provided, otherwise use --url
458    let command_str = conn.command.as_deref().unwrap_or(&conn.url);
459
460    // Honor shell quoting/escaping so paths with spaces and `bash -c "..."`
461    // wrappers parse correctly. `split_whitespace` would fragment them.
462    let parts = shell_words::split(command_str)
463        .map_err(|e| CliError::InvalidArguments(format!("Invalid --command quoting: {e}")))?;
464    if parts.is_empty() {
465        return Err(CliError::InvalidArguments(
466            "No command specified for STDIO transport".to_string(),
467        ));
468    }
469
470    let command = parts[0].clone();
471    let args: Vec<String> = parts[1..].to_vec();
472
473    // Create config
474    let config = ChildProcessConfig {
475        command,
476        args,
477        working_directory: None,
478        environment: None,
479        startup_timeout: Duration::from_secs(conn.timeout),
480        shutdown_timeout: Duration::from_secs(5),
481        max_message_size: 10 * 1024 * 1024, // 10MB
482        buffer_size: 8192,                  // 8KB buffer
483        kill_on_drop: true,                 // Kill process when client is dropped
484    };
485
486    // Create transport
487    Ok(ChildProcessTransport::new(config))
488}
489
490/// Create TCP transport from connection
491#[cfg(feature = "tcp")]
492async fn create_tcp_transport(
493    conn: &Connection,
494) -> CliResult<turbomcp_transport::tcp::TcpTransport> {
495    let url = &conn.url;
496
497    // Parse TCP URL
498    let addr_str = url
499        .strip_prefix("tcp://")
500        .ok_or_else(|| CliError::InvalidArguments(format!("Invalid TCP URL: {}", url)))?;
501
502    // Parse into SocketAddr
503    let socket_addr: std::net::SocketAddr = addr_str.parse().map_err(|e| {
504        CliError::InvalidArguments(format!("Invalid address '{}': {}", addr_str, e))
505    })?;
506
507    let transport = TcpTransportBuilder::new().remote_addr(socket_addr).build();
508
509    Ok(transport)
510}
511
512/// Create Unix socket transport from connection
513#[cfg(feature = "unix")]
514async fn create_unix_transport(
515    conn: &Connection,
516) -> CliResult<turbomcp_transport::unix::UnixTransport> {
517    let path = conn.url.strip_prefix("unix://").unwrap_or(&conn.url);
518
519    let transport = UnixTransportBuilder::new_client().socket_path(path).build();
520
521    Ok(transport)
522}
523
524/// Create HTTP transport from connection
525#[cfg(feature = "http")]
526async fn create_http_transport(conn: &Connection) -> CliResult<StreamableHttpClientTransport> {
527    let url = &conn.url;
528
529    // Parse HTTP URL (remove http:// or https://)
530    let base_url = if let Some(stripped) = url.strip_prefix("https://") {
531        format!("https://{}", stripped)
532    } else if let Some(stripped) = url.strip_prefix("http://") {
533        format!("http://{}", stripped)
534    } else {
535        url.clone()
536    };
537
538    let config = StreamableHttpClientConfig {
539        base_url,
540        endpoint_path: "/mcp".to_string(),
541        timeout: Duration::from_secs(conn.timeout),
542        auth_token: conn.auth.clone(),
543        ..Default::default()
544    };
545
546    StreamableHttpClientTransport::new(config).map_err(|e| {
547        crate::CliError::Transport(turbomcp_protocol::Error::transport(format!(
548            "Failed to build HTTP transport: {e}"
549        )))
550    })
551}
552
553/// Create WebSocket transport from connection
554#[cfg(feature = "websocket")]
555async fn create_websocket_transport(
556    conn: &Connection,
557) -> CliResult<WebSocketBidirectionalTransport> {
558    let url = &conn.url;
559
560    // Validate URL is a proper WebSocket URL
561    if !url.starts_with("ws://") && !url.starts_with("wss://") {
562        return Err(CliError::InvalidArguments(format!(
563            "Invalid WebSocket URL: {} (must start with ws:// or wss://)",
564            url
565        )));
566    }
567
568    let config = WebSocketBidirectionalConfig::client(url.clone());
569
570    WebSocketBidirectionalTransport::new(config)
571        .await
572        .map_err(|e| CliError::ConnectionFailed(e.to_string()))
573}
574
575#[cfg(test)]
576mod tests {
577    use super::*;
578
579    #[test]
580    fn test_determine_transport() {
581        // STDIO detection
582        let conn = Connection {
583            transport: None,
584            url: "./my-server".to_string(),
585            command: None,
586            auth: None,
587            timeout: 30,
588        };
589        assert_eq!(determine_transport(&conn), TransportKind::Stdio);
590
591        // Command override
592        let conn = Connection {
593            transport: None,
594            url: "http://localhost".to_string(),
595            command: Some("python server.py".to_string()),
596            auth: None,
597            timeout: 30,
598        };
599        assert_eq!(determine_transport(&conn), TransportKind::Stdio);
600
601        // TCP detection
602        let conn = Connection {
603            transport: None,
604            url: "tcp://localhost:8080".to_string(),
605            command: None,
606            auth: None,
607            timeout: 30,
608        };
609        assert_eq!(determine_transport(&conn), TransportKind::Tcp);
610
611        // Unix detection
612        let conn = Connection {
613            transport: None,
614            url: "/tmp/mcp.sock".to_string(),
615            command: None,
616            auth: None,
617            timeout: 30,
618        };
619        assert_eq!(determine_transport(&conn), TransportKind::Unix);
620
621        // Explicit override
622        let conn = Connection {
623            transport: Some(TransportKind::Tcp),
624            url: "http://localhost".to_string(),
625            command: None,
626            auth: None,
627            timeout: 30,
628        };
629        assert_eq!(determine_transport(&conn), TransportKind::Tcp);
630    }
631}