model_context_protocol/
transport_factory.rs1use std::sync::Arc;
8use std::time::Duration;
9
10use crate::client::http::HttpTransportAdapter;
11use crate::client::stdio::StdioTransportAdapter;
12use crate::transport::{
13 McpServerConnectionConfig, McpTransport, McpTransportError, TransportTypeId,
14};
15
/// Factory that builds an [`McpTransport`] from an [`McpServerConnectionConfig`].
///
/// The unit value carries no data; all functionality is exposed through associated
/// functions such as `TransportFactory::create`.
#[derive(Debug, Clone, Copy, Default)]
pub struct TransportFactory;
18
19impl TransportFactory {
20 pub async fn create(
31 config: &McpServerConnectionConfig,
32 ) -> Result<Arc<dyn McpTransport>, McpTransportError> {
33 match config.transport {
34 TransportTypeId::Stdio => Self::create_stdio(config).await,
35 TransportTypeId::Http => Self::create_http(config).await,
36 }
37 }
38
39 async fn create_stdio(
41 config: &McpServerConnectionConfig,
42 ) -> Result<Arc<dyn McpTransport>, McpTransportError> {
43 let command = config.command.as_ref().ok_or_else(|| {
44 McpTransportError::TransportError("Stdio transport requires command".to_string())
45 })?;
46
47 let timeout = Duration::from_secs(config.timeout_secs);
48
49 let transport = StdioTransportAdapter::connect_with_env(
50 command,
51 &config.args,
52 config.env.clone(),
53 Some(config.config.clone()),
54 timeout,
55 )
56 .await?;
57
58 Ok(Arc::new(transport))
59 }
60
61 async fn create_http(
63 config: &McpServerConnectionConfig,
64 ) -> Result<Arc<dyn McpTransport>, McpTransportError> {
65 let url = config.url.as_ref().ok_or_else(|| {
66 McpTransportError::TransportError("HTTP transport requires URL".to_string())
67 })?;
68
69 let timeout = Duration::from_secs(config.timeout_secs);
70 let transport = HttpTransportAdapter::with_timeout(url, timeout)?;
71
72 Ok(Arc::new(transport))
73 }
74
75 pub fn is_supported(transport_type: TransportTypeId) -> bool {
77 matches!(
78 transport_type,
79 TransportTypeId::Stdio | TransportTypeId::Http
80 )
81 }
82
83 pub fn supported_types() -> Vec<TransportTypeId> {
85 vec![TransportTypeId::Stdio, TransportTypeId::Http]
86 }
87}
88
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_supported() {
        assert!(TransportFactory::is_supported(TransportTypeId::Stdio));
        assert!(TransportFactory::is_supported(TransportTypeId::Http));
    }

    #[test]
    fn test_supported_types() {
        let types = TransportFactory::supported_types();
        assert!(types.contains(&TransportTypeId::Stdio));
        assert!(types.contains(&TransportTypeId::Http));
    }

    /// `is_supported` and `supported_types` each hard-code the transport list. This
    /// checks that every type advertised by `supported_types` is also accepted by
    /// `is_supported`, so the two cannot silently drift apart.
    #[test]
    fn test_supported_types_agree_with_is_supported() {
        let types = TransportFactory::supported_types();
        assert!(!types.is_empty());
        for transport_type in types {
            assert!(TransportFactory::is_supported(transport_type));
        }
    }
}