Skip to main content

albert_runtime/
mcp_client.rs

1use std::collections::BTreeMap;
2
3use crate::config::{McpOAuthConfig, McpServerConfig, ScopedMcpServerConfig};
4use crate::mcp::{mcp_server_signature, mcp_tool_prefix, normalize_name_for_mcp};
5
6#[derive(Debug, Clone, PartialEq, Eq)]
7pub enum McpClientTransport {
8    Stdio(McpStdioTransport),
9    Sse(McpRemoteTransport),
10    Http(McpRemoteTransport),
11    WebSocket(McpRemoteTransport),
12    Sdk(McpSdkTransport),
13    TernlangAiProxy(McpTernlangAiProxyTransport),
14}
15
/// Command line and environment for a stdio-based MCP server.
///
/// The values are copied verbatim from the stdio server config. Presumably
/// the command is spawned as a child process by the caller; this type does not
/// launch anything itself.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct McpStdioTransport {
    /// Program to execute.
    pub command: String,
    /// Arguments passed to `command`, in order.
    pub args: Vec<String>,
    /// Extra environment variables (name -> value). Values may be secrets,
    /// and the derived `Debug` output prints them.
    pub env: BTreeMap<String, String>,
}
22
23#[derive(Debug, Clone, PartialEq, Eq)]
24pub struct McpRemoteTransport {
25    pub url: String,
26    pub headers: BTreeMap<String, String>,
27    pub headers_helper: Option<String>,
28    pub auth: McpClientAuth,
29}
30
/// Transport for an SDK-registered server, which is identified only by its name.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct McpSdkTransport {
    /// Name copied from the SDK server config.
    pub name: String,
}
35
/// Transport for a server reached through a Ternlang AI proxy.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct McpTernlangAiProxyTransport {
    /// Proxy endpoint URL, copied from the config.
    pub url: String,
    /// Identifier copied from the config. NOTE(review): presumably selects
    /// the target server behind the proxy; confirm.
    pub id: String,
}
41
42#[derive(Debug, Clone, PartialEq, Eq)]
43pub enum McpClientAuth {
44    None,
45    OAuth(McpOAuthConfig),
46}
47
48#[derive(Debug, Clone, PartialEq, Eq)]
49pub struct McpClientBootstrap {
50    pub server_name: String,
51    pub normalized_name: String,
52    pub tool_prefix: String,
53    pub signature: Option<String>,
54    pub transport: McpClientTransport,
55}
56
57impl McpClientBootstrap {
58    #[must_use]
59    pub fn from_scoped_config(server_name: &str, config: &ScopedMcpServerConfig) -> Self {
60        Self {
61            server_name: server_name.to_string(),
62            normalized_name: normalize_name_for_mcp(server_name),
63            tool_prefix: mcp_tool_prefix(server_name),
64            signature: mcp_server_signature(&config.config),
65            transport: McpClientTransport::from_config(&config.config),
66        }
67    }
68}
69
70impl McpClientTransport {
71    #[must_use]
72    pub fn from_config(config: &McpServerConfig) -> Self {
73        match config {
74            McpServerConfig::Stdio(config) => Self::Stdio(McpStdioTransport {
75                command: config.command.clone(),
76                args: config.args.clone(),
77                env: config.env.clone(),
78            }),
79            McpServerConfig::Sse(config) => Self::Sse(McpRemoteTransport {
80                url: config.url.clone(),
81                headers: config.headers.clone(),
82                headers_helper: config.headers_helper.clone(),
83                auth: McpClientAuth::from_oauth(config.oauth.clone()),
84            }),
85            McpServerConfig::Http(config) => Self::Http(McpRemoteTransport {
86                url: config.url.clone(),
87                headers: config.headers.clone(),
88                headers_helper: config.headers_helper.clone(),
89                auth: McpClientAuth::from_oauth(config.oauth.clone()),
90            }),
91            McpServerConfig::Ws(config) => Self::WebSocket(McpRemoteTransport {
92                url: config.url.clone(),
93                headers: config.headers.clone(),
94                headers_helper: config.headers_helper.clone(),
95                auth: McpClientAuth::None,
96            }),
97            McpServerConfig::Sdk(config) => Self::Sdk(McpSdkTransport {
98                name: config.name.clone(),
99            }),
100            McpServerConfig::TernlangAiProxy(config) => {
101                Self::TernlangAiProxy(McpTernlangAiProxyTransport {
102                    url: config.url.clone(),
103                    id: config.id.clone(),
104                })
105            }
106        }
107    }
108}
109
110impl McpClientAuth {
111    #[must_use]
112    pub fn from_oauth(oauth: Option<McpOAuthConfig>) -> Self {
113        oauth.map_or(Self::None, Self::OAuth)
114    }
115
116    #[must_use]
117    pub const fn requires_user_auth(&self) -> bool {
118        matches!(self, Self::OAuth(_))
119    }
120}
121
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use crate::config::{
        ConfigSource, McpOAuthConfig, McpRemoteServerConfig, McpSdkServerConfig, McpServerConfig,
        McpStdioServerConfig, McpWebSocketServerConfig, ScopedMcpServerConfig,
    };

    use super::{McpClientAuth, McpClientBootstrap, McpClientTransport};

    /// Wraps a server config together with the scope it was loaded from.
    fn scoped(scope: ConfigSource, config: McpServerConfig) -> ScopedMcpServerConfig {
        ScopedMcpServerConfig { scope, config }
    }

    #[test]
    fn bootstraps_stdio_servers_into_transport_targets() {
        let stdio = McpStdioServerConfig {
            command: "uvx".to_string(),
            args: vec!["mcp-server".to_string()],
            env: BTreeMap::from([("TOKEN".to_string(), "secret".to_string())]),
        };
        let config = scoped(ConfigSource::User, McpServerConfig::Stdio(stdio));

        let bootstrap = McpClientBootstrap::from_scoped_config("stdio-server", &config);
        assert_eq!(bootstrap.normalized_name, "stdio-server");
        assert_eq!(bootstrap.tool_prefix, "mcp__stdio-server__");
        assert_eq!(
            bootstrap.signature.as_deref(),
            Some("stdio:[uvx|mcp-server]")
        );

        let transport = match bootstrap.transport {
            McpClientTransport::Stdio(transport) => transport,
            other => panic!("expected stdio transport, got {other:?}"),
        };
        assert_eq!(transport.command, "uvx");
        assert_eq!(transport.args, vec!["mcp-server"]);
        assert_eq!(
            transport.env.get("TOKEN").map(String::as_str),
            Some("secret")
        );
    }

    #[test]
    fn bootstraps_remote_servers_with_oauth_auth() {
        let oauth = McpOAuthConfig {
            client_id: Some("client-id".to_string()),
            callback_port: Some(7777),
            auth_server_metadata_url: Some(
                "https://issuer.example/.well-known/oauth-authorization-server".to_string(),
            ),
            xaa: Some(true),
        };
        let remote = McpRemoteServerConfig {
            url: "https://vendor.example/mcp".to_string(),
            headers: BTreeMap::from([("X-Test".to_string(), "1".to_string())]),
            headers_helper: Some("helper.sh".to_string()),
            oauth: Some(oauth),
        };
        let config = scoped(ConfigSource::Project, McpServerConfig::Http(remote));

        let bootstrap = McpClientBootstrap::from_scoped_config("remote server", &config);
        assert_eq!(bootstrap.normalized_name, "remote_server");

        let transport = match bootstrap.transport {
            McpClientTransport::Http(transport) => transport,
            other => panic!("expected http transport, got {other:?}"),
        };
        assert_eq!(transport.url, "https://vendor.example/mcp");
        assert_eq!(transport.headers_helper.as_deref(), Some("helper.sh"));
        assert!(transport.auth.requires_user_auth());

        let oauth = match transport.auth {
            McpClientAuth::OAuth(oauth) => oauth,
            other @ McpClientAuth::None => panic!("expected oauth auth, got {other:?}"),
        };
        assert_eq!(oauth.client_id.as_deref(), Some("client-id"));
    }

    #[test]
    fn bootstraps_websocket_and_sdk_transports_without_oauth() {
        let ws = scoped(
            ConfigSource::Local,
            McpServerConfig::Ws(McpWebSocketServerConfig {
                url: "wss://vendor.example/mcp".to_string(),
                headers: BTreeMap::new(),
                headers_helper: None,
            }),
        );
        let sdk = scoped(
            ConfigSource::Local,
            McpServerConfig::Sdk(McpSdkServerConfig {
                name: "sdk-server".to_string(),
            }),
        );

        let ws_bootstrap = McpClientBootstrap::from_scoped_config("ws server", &ws);
        let ws_transport = match ws_bootstrap.transport {
            McpClientTransport::WebSocket(transport) => transport,
            other => panic!("expected websocket transport, got {other:?}"),
        };
        assert_eq!(ws_transport.url, "wss://vendor.example/mcp");
        assert!(!ws_transport.auth.requires_user_auth());

        let sdk_bootstrap = McpClientBootstrap::from_scoped_config("sdk server", &sdk);
        assert_eq!(sdk_bootstrap.signature, None);
        let sdk_transport = match sdk_bootstrap.transport {
            McpClientTransport::Sdk(transport) => transport,
            other => panic!("expected sdk transport, got {other:?}"),
        };
        assert_eq!(sdk_transport.name, "sdk-server");
    }
}