Skip to main content

mcp_utils/client/
config.rs

1use futures::future::BoxFuture;
2use rmcp::{RoleServer, service::DynService, transport::streamable_http_client::StreamableHttpClientTransportConfig};
3use schemars::JsonSchema;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::collections::{BTreeMap, HashMap};
7use std::fmt::{Debug, Formatter};
8use std::path::Path;
9use utils::is_false;
10use utils::variables::{VarError, Vars};
11
12#[derive(Debug, Clone, Default, Deserialize, Serialize, JsonSchema)]
13pub struct McpConfig {
14    #[serde(alias = "mcpServers")]
15    pub servers: BTreeMap<String, McpServerConfig>,
16}
17
18#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, PartialEq)]
19#[serde(untagged)]
20pub enum McpServerConfig {
21    Stdio(StdioServerConfig),
22    Http(HttpServerConfig),
23    Sse(SseServerConfig),
24    InMemory(InMemoryServerConfig),
25}
26
27#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, PartialEq)]
28#[serde(deny_unknown_fields)]
29pub struct StdioServerConfig {
30    #[serde(rename = "type", default)]
31    pub type_: StdioType,
32
33    pub command: String,
34
35    #[serde(default)]
36    pub args: Vec<String>,
37
38    #[serde(default)]
39    pub env: HashMap<String, String>,
40
41    #[serde(default, skip_serializing_if = "is_false")]
42    pub proxy: bool,
43}
44
45#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, PartialEq)]
46#[serde(deny_unknown_fields)]
47pub struct HttpServerConfig {
48    #[serde(rename = "type")]
49    pub type_: HttpType,
50
51    pub url: String,
52
53    #[serde(default)]
54    pub headers: HashMap<String, String>,
55
56    #[serde(default, skip_serializing_if = "is_false")]
57    pub proxy: bool,
58}
59
60#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, PartialEq)]
61#[serde(deny_unknown_fields)]
62pub struct SseServerConfig {
63    #[serde(rename = "type")]
64    pub type_: SseType,
65
66    pub url: String,
67
68    #[serde(default)]
69    pub headers: HashMap<String, String>,
70
71    #[serde(default, skip_serializing_if = "is_false")]
72    pub proxy: bool,
73}
74
75#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, PartialEq)]
76#[serde(deny_unknown_fields)]
77pub struct InMemoryServerConfig {
78    #[serde(rename = "type")]
79    pub type_: InMemoryType,
80
81    #[serde(default)]
82    pub args: Vec<String>,
83
84    #[serde(default)]
85    pub input: Option<Value>,
86
87    #[serde(default, skip_serializing_if = "is_false")]
88    pub proxy: bool,
89}
90
91#[derive(Debug, Clone, Copy, Default, Deserialize, Serialize, JsonSchema, PartialEq)]
92pub enum StdioType {
93    #[default]
94    #[serde(rename = "stdio")]
95    Stdio,
96}
97
98#[derive(Debug, Clone, Copy, Deserialize, Serialize, JsonSchema, PartialEq)]
99pub enum HttpType {
100    #[serde(rename = "http")]
101    Http,
102}
103
104#[derive(Debug, Clone, Copy, Deserialize, Serialize, JsonSchema, PartialEq)]
105pub enum SseType {
106    #[serde(rename = "sse")]
107    Sse,
108}
109
110#[derive(Debug, Clone, Copy, Deserialize, Serialize, JsonSchema, PartialEq)]
111pub enum InMemoryType {
112    #[serde(rename = "in-memory")]
113    InMemory,
114}
115
116pub struct McpServer {
117    pub name: String,
118    pub transport: McpTransport,
119    pub proxy: bool,
120}
121
122pub enum McpTransport {
123    Stdio { command: String, args: Vec<String>, env: HashMap<String, String> },
124    Http { config: StreamableHttpClientTransportConfig },
125    InMemory { server: Box<dyn DynService<RoleServer>> },
126}
127
128impl McpServer {
129    pub fn new(name: impl Into<String>, transport: McpTransport, proxy: bool) -> Self {
130        Self { name: name.into(), transport, proxy }
131    }
132
133    /// Clone this server config. Fails for [`McpTransport::InMemory`], whose
134    /// boxed service cannot be duplicated and so cannot be shared across
135    /// independently-spawned MCP managers.
136    pub fn try_clone(&self) -> Result<Self, McpServerCloneError> {
137        let transport = match &self.transport {
138            McpTransport::Stdio { command, args, env } => {
139                McpTransport::Stdio { command: command.clone(), args: args.clone(), env: env.clone() }
140            }
141            McpTransport::Http { config } => McpTransport::Http { config: config.clone() },
142            McpTransport::InMemory { .. } => return Err(McpServerCloneError(self.name.clone())),
143        };
144        Ok(Self { name: self.name.clone(), transport, proxy: self.proxy })
145    }
146}
147
148#[derive(Debug, thiserror::Error)]
149#[error("in-memory MCP server `{0}` cannot be cloned across runtimes")]
150pub struct McpServerCloneError(pub String);
151
152impl Debug for McpServer {
153    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
154        f.debug_struct("McpServer")
155            .field("name", &self.name)
156            .field("transport", &self.transport)
157            .field("proxy", &self.proxy)
158            .finish()
159    }
160}
161
162impl Debug for McpTransport {
163    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
164        match self {
165            McpTransport::Stdio { command, args, env } => {
166                f.debug_struct("Stdio").field("command", command).field("args", args).field("env", env).finish()
167            }
168            McpTransport::Http { config } => f.debug_struct("Http").field("config", config).finish(),
169            McpTransport::InMemory { .. } => f.debug_struct("InMemory").field("server", &"<DynService>").finish(),
170        }
171    }
172}
173
174pub type ServerFactory =
175    Box<dyn Fn(Vec<String>, Option<Value>) -> BoxFuture<'static, Box<dyn DynService<RoleServer>>> + Send + Sync>;
176
177#[derive(Debug, thiserror::Error)]
178pub enum ParseError {
179    #[error("Failed to read config file: {0}")]
180    IoError(#[from] std::io::Error),
181
182    #[error("Invalid JSON: {0}")]
183    JsonError(#[from] serde_json::Error),
184
185    #[error("Variable expansion failed: {0}")]
186    VarError(#[from] VarError),
187
188    #[error("InMemory server factory '{0}' not registered")]
189    FactoryNotFound(String),
190
191    #[error("Invalid nested config in tool-proxy: {0}")]
192    InvalidNestedConfig(String),
193}
194
195impl McpConfig {
196    pub fn new(servers: BTreeMap<String, McpServerConfig>) -> Self {
197        Self { servers }
198    }
199
200    pub fn from_json_file(path: impl AsRef<Path>) -> Result<Self, ParseError> {
201        let content = std::fs::read_to_string(path)?;
202        Self::from_json(&content)
203    }
204
205    pub fn from_json_files<T: AsRef<Path>>(paths: &[T]) -> Result<Self, ParseError> {
206        let mut merged = BTreeMap::new();
207        for path in paths {
208            let raw = Self::from_json_file(path)?;
209            merged.extend(raw.servers);
210        }
211        Ok(Self::new(merged))
212    }
213
214    pub fn from_json(json: &str) -> Result<Self, ParseError> {
215        Ok(serde_json::from_str(json)?)
216    }
217
218    pub async fn into_servers(
219        self,
220        factories: &HashMap<String, ServerFactory>,
221        vars: &Vars,
222    ) -> Result<Vec<McpServer>, ParseError> {
223        self.into_servers_with_proxy(factories, vars, false).await
224    }
225
226    pub async fn into_servers_with_proxy(
227        self,
228        factories: &HashMap<String, ServerFactory>,
229        vars: &Vars,
230        force_proxy: bool,
231    ) -> Result<Vec<McpServer>, ParseError> {
232        let mut servers = Vec::with_capacity(self.servers.len());
233        for (name, config) in self.servers {
234            servers.push(config.into_server(name, factories, vars, force_proxy).await?);
235        }
236        Ok(servers)
237    }
238
239    pub fn mark_all_proxy(&mut self) {
240        for server in self.servers.values_mut() {
241            server.set_proxy(true);
242        }
243    }
244}
245
246impl McpServerConfig {
247    pub fn proxy(&self) -> bool {
248        match self {
249            McpServerConfig::Stdio(config) => config.proxy,
250            McpServerConfig::Http(config) => config.proxy,
251            McpServerConfig::Sse(config) => config.proxy,
252            McpServerConfig::InMemory(config) => config.proxy,
253        }
254    }
255
256    pub fn set_proxy(&mut self, value: bool) {
257        match self {
258            McpServerConfig::Stdio(config) => config.proxy = value,
259            McpServerConfig::Http(config) => config.proxy = value,
260            McpServerConfig::Sse(config) => config.proxy = value,
261            McpServerConfig::InMemory(config) => config.proxy = value,
262        }
263    }
264
265    pub async fn into_server(
266        self,
267        name: String,
268        factories: &HashMap<String, ServerFactory>,
269        vars: &Vars,
270        force_proxy: bool,
271    ) -> Result<McpServer, ParseError> {
272        let proxy = force_proxy || self.proxy();
273        let transport = self.into_transport(name.clone(), factories, vars).await?;
274        Ok(McpServer::new(name, transport, proxy))
275    }
276
277    async fn into_transport(
278        self,
279        name: String,
280        factories: &HashMap<String, ServerFactory>,
281        vars: &Vars,
282    ) -> Result<McpTransport, ParseError> {
283        match self {
284            McpServerConfig::Stdio(StdioServerConfig { command, args, env, .. }) => Ok(McpTransport::Stdio {
285                command: vars.expand(&command)?,
286                args: args.into_iter().map(|a| vars.expand(&a)).collect::<Result<Vec<_>, _>>()?,
287                env: env
288                    .into_iter()
289                    .map(|(k, v)| Ok((k, vars.expand(&v)?)))
290                    .collect::<Result<HashMap<_, _>, VarError>>()?,
291            }),
292
293            McpServerConfig::Http(HttpServerConfig { url, headers, .. })
294            | McpServerConfig::Sse(SseServerConfig { url, headers, .. }) => {
295                let auth_header = headers.get("Authorization").map(|v| vars.expand(v)).transpose()?;
296                let mut config = StreamableHttpClientTransportConfig::with_uri(vars.expand(&url)?);
297                if let Some(auth) = auth_header {
298                    config = config.auth_header(auth);
299                }
300                Ok(McpTransport::Http { config })
301            }
302
303            McpServerConfig::InMemory(InMemoryServerConfig { args, input, .. }) => {
304                let server_factory = factories.get(&name).ok_or_else(|| ParseError::FactoryNotFound(name.clone()))?;
305                let expanded_args = args.into_iter().map(|a| vars.expand(&a)).collect::<Result<Vec<_>, VarError>>()?;
306                let server = server_factory(expanded_args, input).await;
307                Ok(McpTransport::InMemory { server })
308            }
309        }
310    }
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// Writes `json` to `dir/name` and returns the path.
    fn write_config(dir: &Path, name: &str, json: &str) -> std::path::PathBuf {
        let path = dir.join(name);
        fs::write(&path, json).unwrap();
        path
    }

    /// A config with a single stdio server named `coding`.
    fn stdio_config(command: &str) -> String {
        format!(r#"{{"servers": {{"coding": {{"type": "stdio", "command": "{command}"}}}}}}"#)
    }

    #[test]
    fn from_json_accepts_mcp_servers_key() {
        let config = McpConfig::from_json(r#"{"mcpServers": {"alpha": {"type": "stdio", "command": "a"}}}"#).unwrap();
        assert_eq!(config.servers.len(), 1);
        assert!(config.servers.contains_key("alpha"));
    }

    #[test]
    fn from_json_defaults_missing_type_to_stdio() {
        let config = McpConfig::from_json(
            r#"{"mcpServers": {"devtools": {"command": "npx", "args": ["-y", "chrome-devtools-mcp"]}}}"#,
        )
        .unwrap();
        match config.servers.get("devtools").unwrap() {
            McpServerConfig::Stdio(StdioServerConfig { command, args, proxy, .. }) => {
                assert_eq!(command, "npx");
                assert_eq!(args, &["-y", "chrome-devtools-mcp"]);
                assert!(!proxy);
            }
            other => panic!("expected Stdio server, got {other:?}"),
        }
    }

    #[test]
    fn from_json_accepts_server_proxy_true() {
        let config =
            McpConfig::from_json(r#"{"servers": {"playwright": {"type": "stdio", "command": "npx", "proxy": true}}}"#)
                .unwrap();
        assert!(config.servers.get("playwright").unwrap().proxy());
    }

    #[test]
    fn from_json_rejects_proxy_server_type() {
        let result = McpConfig::from_json(r#"{"servers":{"tools":{"type":"proxy","servers":{}}}}"#);
        assert!(result.is_err());
    }

    #[test]
    fn false_proxy_omits_during_serialization() {
        let config =
            McpConfig::from_json(r#"{"servers": {"coding": {"type": "stdio", "command": "a", "proxy": false}}}"#)
                .unwrap();
        let serialized = serde_json::to_string(&config).unwrap();
        assert!(!serialized.contains("proxy"));
    }

    #[test]
    fn true_proxy_serializes() {
        let config =
            McpConfig::from_json(r#"{"servers": {"coding": {"type": "stdio", "command": "a", "proxy": true}}}"#)
                .unwrap();
        let serialized = serde_json::to_string(&config).unwrap();
        // Check the value, not just the key, so a `"proxy":false` regression is caught.
        assert!(serialized.contains(r#""proxy":true"#), "got {serialized}");
    }

    #[test]
    fn serialization_round_trips_every_transport() {
        let json = r#"{"servers":{
            "local":{"type":"stdio","command":"a","args":["x"],"env":{"K":"V"},"proxy":true},
            "remote":{"type":"http","url":"https://example.com","headers":{"Authorization":"t"}},
            "legacy":{"type":"sse","url":"https://example.com/sse"},
            "mem":{"type":"in-memory","args":["y"],"input":{"k":1}}
        }}"#;
        let config = McpConfig::from_json(json).unwrap();
        let reparsed = McpConfig::from_json(&serde_json::to_string(&config).unwrap()).unwrap();
        assert_eq!(config.servers, reparsed.servers);
    }

    #[test]
    fn from_json_rejects_unknown_type() {
        let result = McpConfig::from_json(r#"{"servers": {"bad": {"type": "htp", "url": "https://example.com"}}}"#);
        assert!(result.is_err());
    }

    #[test]
    fn from_json_files_empty_returns_empty_servers() {
        let result = McpConfig::from_json_files::<&str>(&[]).unwrap();
        assert!(result.servers.is_empty());
    }

    #[test]
    fn from_json_files_single_file_matches_from_json_file() {
        let dir = tempdir().unwrap();
        let path = write_config(dir.path(), "a.json", &stdio_config("ls"));

        let single = McpConfig::from_json_file(&path).unwrap();
        let multi = McpConfig::from_json_files(&[&path]).unwrap();

        // Compare contents, not just sizes.
        assert_eq!(single.servers, multi.servers);
        assert!(multi.servers.contains_key("coding"));
    }

    #[test]
    fn from_json_files_merges_disjoint_servers() {
        let dir = tempdir().unwrap();
        let a = write_config(dir.path(), "a.json", r#"{"servers": {"alpha": {"type": "stdio", "command": "a"}}}"#);
        let b = write_config(dir.path(), "b.json", r#"{"servers": {"beta": {"type": "stdio", "command": "b"}}}"#);

        let merged = McpConfig::from_json_files(&[a, b]).unwrap();
        assert_eq!(merged.servers.len(), 2);
        assert!(merged.servers.contains_key("alpha"));
        assert!(merged.servers.contains_key("beta"));
    }

    #[test]
    fn from_json_files_last_file_wins_on_collision_including_proxy() {
        let dir = tempdir().unwrap();
        let a = write_config(
            dir.path(),
            "a.json",
            r#"{"servers":{"coding":{"type":"stdio","command":"from_a","proxy":true}}}"#,
        );
        let b = write_config(dir.path(), "b.json", r#"{"servers":{"coding":{"type":"stdio","command":"from_b"}}}"#);

        let merged_ab = McpConfig::from_json_files(&[&a, &b]).unwrap();
        match merged_ab.servers.get("coding").unwrap() {
            McpServerConfig::Stdio(StdioServerConfig { command, proxy, .. }) => {
                assert_eq!(command, "from_b");
                assert!(!proxy);
            }
            other => panic!("expected Stdio, got {other:?}"),
        }

        let merged_ba = McpConfig::from_json_files(&[&b, &a]).unwrap();
        match merged_ba.servers.get("coding").unwrap() {
            McpServerConfig::Stdio(StdioServerConfig { command, proxy, .. }) => {
                assert_eq!(command, "from_a");
                assert!(*proxy);
            }
            other => panic!("expected Stdio, got {other:?}"),
        }
    }

    #[test]
    fn mark_all_proxy_sets_every_server() {
        let mut config = McpConfig::from_json(
            r#"{"servers":{"a":{"type":"stdio","command":"a"},"b":{"type":"http","url":"https://example.com"}}}"#,
        )
        .unwrap();
        config.mark_all_proxy();
        assert!(config.servers.values().all(McpServerConfig::proxy));
    }

    #[test]
    fn from_json_files_propagates_io_error_on_missing_file() {
        let dir = tempdir().unwrap();
        let missing = dir.path().join("does-not-exist.json");
        let result = McpConfig::from_json_files(&[missing]);
        assert!(matches!(result, Err(ParseError::IoError(_))));
    }

    #[test]
    fn from_json_files_propagates_json_error_on_invalid_file() {
        let dir = tempdir().unwrap();
        let bad = write_config(dir.path(), "bad.json", "not valid json");
        let result = McpConfig::from_json_files(&[bad]);
        assert!(matches!(result, Err(ParseError::JsonError(_))));
    }

    #[tokio::test]
    async fn into_servers_preserves_proxy_flags() {
        let json = r#"{
            "servers": {
                "github": {"type": "stdio", "command": "g"},
                "playwright": {"type": "stdio", "command": "p", "proxy": true}
            }
        }"#;
        let config = McpConfig::from_json(json).unwrap();
        let servers = config.into_servers(&HashMap::new(), &Vars::new()).await.unwrap();

        assert_eq!(servers.len(), 2);
        assert!(!servers.iter().find(|s| s.name == "github").unwrap().proxy);
        assert!(servers.iter().find(|s| s.name == "playwright").unwrap().proxy);
    }

    #[tokio::test]
    async fn into_servers_with_proxy_forces_proxy_flags() {
        let config =
            McpConfig::from_json(r#"{"servers":{"github":{"type":"stdio","command":"g","proxy":false}}}"#).unwrap();
        let servers = config.into_servers_with_proxy(&HashMap::new(), &Vars::new(), true).await.unwrap();
        assert!(servers[0].proxy);
    }

    #[tokio::test]
    async fn into_servers_reports_missing_in_memory_factory() {
        let config = McpConfig::from_json(r#"{"servers":{"mem":{"type":"in-memory"}}}"#).unwrap();
        let result = config.into_servers(&HashMap::new(), &Vars::new()).await;
        assert!(matches!(result, Err(ParseError::FactoryNotFound(ref name)) if name.as_str() == "mem"));
    }

    #[tokio::test]
    async fn into_servers_maps_sse_to_http_transport() {
        let config =
            McpConfig::from_json(r#"{"servers":{"legacy":{"type":"sse","url":"https://example.com/sse"}}}"#).unwrap();
        let servers = config.into_servers(&HashMap::new(), &Vars::new()).await.unwrap();
        assert!(matches!(servers[0].transport, McpTransport::Http { .. }));
    }

    #[tokio::test]
    async fn into_transport_expands_workspace_var_in_stdio_args() {
        let config = McpConfig::from_json(
            r#"{"servers":{"coding":{"type":"stdio","command":"server","args":["--root","${WORKSPACE}/src"]}}}"#,
        )
        .unwrap();
        let vars = Vars::new().with("WORKSPACE", "/workspace");
        let servers = config.into_servers(&HashMap::new(), &vars).await.unwrap();

        match &servers[0].transport {
            McpTransport::Stdio { args, .. } => {
                assert_eq!(args, &["--root", "/workspace/src"]);
            }
            other => panic!("expected Stdio transport, got {other:?}"),
        }
    }
}