// mcp_utils/client/config.rs

1use super::variables::{VarError, expand_env_vars};
2use futures::future::BoxFuture;
3use rmcp::{RoleServer, service::DynService, transport::streamable_http_client::StreamableHttpClientTransportConfig};
4use schemars::JsonSchema;
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use std::collections::{BTreeMap, HashMap};
8use std::fmt::{Debug, Formatter};
9use std::path::Path;
10
/// Top-level MCP configuration: server definitions keyed by server name.
///
/// Deserialization accepts the map under either `"servers"` or `"mcpServers"`; serialization
/// always writes `"servers"`.
#[derive(Debug, Clone, Default, Deserialize, Serialize, JsonSchema)]
pub struct McpConfig {
    /// Server definitions keyed by server name.
    // A BTreeMap iterates in key order, so `into_servers` yields servers sorted by name.
    #[serde(alias = "mcpServers")]
    pub servers: BTreeMap<String, McpServerConfig>,
}
16
/// A single server entry in an `McpConfig`.
///
/// The entry's `"type"` field selects the variant: `"stdio"` (or no `"type"` at all), `"http"`,
/// `"sse"`, or `"in-memory"`.
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, PartialEq)]
// `untagged`: serde tries the variants in declaration order and keeps the first one that
// deserializes. Every variant denies unknown fields and requires its own `type` value (only
// Stdio's may be omitted), so the `type` tag effectively decides the match. A misspelled type
// therefore surfaces as serde's generic untagged-enum mismatch error, not a field-level error.
#[serde(untagged)]
pub enum McpServerConfig {
    /// `"type": "stdio"` (optional): a server run from a local command.
    Stdio(StdioServerConfig),
    /// `"type": "http"`: a remote server reached by URL.
    Http(HttpServerConfig),
    /// `"type": "sse"`: a remote server reached by URL, resolved to the same streamable-HTTP
    /// transport as `http`.
    Sse(SseServerConfig),
    /// `"type": "in-memory"`: an in-process server built by a registered `ServerFactory`.
    InMemory(InMemoryServerConfig),
}
25
/// A server run from a local command (`"type": "stdio"`, which may be omitted).
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct StdioServerConfig {
    /// The `"type"` tag; defaults to `"stdio"` when absent.
    #[serde(rename = "type", default)]
    pub type_: StdioType,

    /// Command to run. Environment-variable references are expanded when the server is resolved.
    pub command: String,

    /// Arguments for `command`; each is expanded like `command`. Defaults to none.
    #[serde(default)]
    pub args: Vec<String>,

    /// Environment variables for the command. Values are expanded; names are used verbatim.
    #[serde(default)]
    pub env: HashMap<String, String>,

    /// The server's `proxy` flag. Defaults to `false` and is omitted from serialized output
    /// when `false`.
    #[serde(default, skip_serializing_if = "is_false")]
    pub proxy: bool,
}
43
/// A remote server reached by URL (`"type": "http"`).
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct HttpServerConfig {
    /// Required `"type": "http"` tag.
    #[serde(rename = "type")]
    pub type_: HttpType,

    /// Server URL. Environment-variable references are expanded when the server is resolved.
    pub url: String,

    /// HTTP headers. Only `Authorization` is currently used (as the transport's auth header);
    /// other entries are ignored.
    #[serde(default)]
    pub headers: HashMap<String, String>,

    /// The server's `proxy` flag. Defaults to `false` and is omitted from serialized output
    /// when `false`.
    #[serde(default, skip_serializing_if = "is_false")]
    pub proxy: bool,
}
58
/// A remote server reached by URL (`"type": "sse"`).
///
/// It is resolved to the same streamable-HTTP client transport as `http` entries.
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct SseServerConfig {
    /// Required `"type": "sse"` tag.
    #[serde(rename = "type")]
    pub type_: SseType,

    /// Server URL. Environment-variable references are expanded when the server is resolved.
    pub url: String,

    /// HTTP headers. Only `Authorization` is currently used (as the transport's auth header);
    /// other entries are ignored.
    #[serde(default)]
    pub headers: HashMap<String, String>,

    /// The server's `proxy` flag. Defaults to `false` and is omitted from serialized output
    /// when `false`.
    #[serde(default, skip_serializing_if = "is_false")]
    pub proxy: bool,
}
73
/// An in-process server (`"type": "in-memory"`), built by the `ServerFactory` registered under
/// this entry's server name.
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct InMemoryServerConfig {
    /// Required `"type": "in-memory"` tag.
    #[serde(rename = "type")]
    pub type_: InMemoryType,

    /// Arguments handed to the factory after environment-variable expansion. Defaults to none.
    #[serde(default)]
    pub args: Vec<String>,

    /// Optional JSON value handed to the factory unchanged.
    // NOTE(review): unlike `proxy`, an absent `input` is serialized back out as `"input": null`.
    #[serde(default)]
    pub input: Option<Value>,

    /// The server's `proxy` flag. Defaults to `false` and is omitted from serialized output
    /// when `false`.
    #[serde(default, skip_serializing_if = "is_false")]
    pub proxy: bool,
}
89
90#[derive(Debug, Clone, Copy, Default, Deserialize, Serialize, JsonSchema, PartialEq)]
91pub enum StdioType {
92    #[default]
93    #[serde(rename = "stdio")]
94    Stdio,
95}
96
97#[derive(Debug, Clone, Copy, Deserialize, Serialize, JsonSchema, PartialEq)]
98pub enum HttpType {
99    #[serde(rename = "http")]
100    Http,
101}
102
103#[derive(Debug, Clone, Copy, Deserialize, Serialize, JsonSchema, PartialEq)]
104pub enum SseType {
105    #[serde(rename = "sse")]
106    Sse,
107}
108
109#[derive(Debug, Clone, Copy, Deserialize, Serialize, JsonSchema, PartialEq)]
110pub enum InMemoryType {
111    #[serde(rename = "in-memory")]
112    InMemory,
113}
114
115pub struct McpServer {
116    pub name: String,
117    pub transport: McpTransport,
118    pub proxy: bool,
119}
120
121pub enum McpTransport {
122    Stdio { command: String, args: Vec<String>, env: HashMap<String, String> },
123    Http { config: StreamableHttpClientTransportConfig },
124    InMemory { server: Box<dyn DynService<RoleServer>> },
125}
126
127impl McpServer {
128    pub fn new(name: impl Into<String>, transport: McpTransport, proxy: bool) -> Self {
129        Self { name: name.into(), transport, proxy }
130    }
131}
132
133impl Debug for McpServer {
134    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
135        f.debug_struct("McpServer")
136            .field("name", &self.name)
137            .field("transport", &self.transport)
138            .field("proxy", &self.proxy)
139            .finish()
140    }
141}
142
143impl Debug for McpTransport {
144    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
145        match self {
146            McpTransport::Stdio { command, args, env } => {
147                f.debug_struct("Stdio").field("command", command).field("args", args).field("env", env).finish()
148            }
149            McpTransport::Http { config } => f.debug_struct("Http").field("config", config).finish(),
150            McpTransport::InMemory { .. } => f.debug_struct("InMemory").field("server", &"<DynService>").finish(),
151        }
152    }
153}
154
155pub type ServerFactory =
156    Box<dyn Fn(Vec<String>, Option<Value>) -> BoxFuture<'static, Box<dyn DynService<RoleServer>>> + Send + Sync>;
157
/// Errors produced while loading an [`McpConfig`] or resolving it into [`McpServer`]s.
#[derive(Debug, thiserror::Error)]
pub enum ParseError {
    /// A config file could not be read.
    #[error("Failed to read config file: {0}")]
    IoError(#[from] std::io::Error),

    /// The text was not valid JSON, or it did not match the config structure.
    // Note: serde_json reports both cases through the same error type, so the message says
    // "Invalid JSON" even for well-formed JSON with the wrong shape.
    #[error("Invalid JSON: {0}")]
    JsonError(#[from] serde_json::Error),

    /// `expand_env_vars` failed on a command, argument, env value, URL, header, or in-memory
    /// argument.
    #[error("Variable expansion failed: {0}")]
    VarError(#[from] VarError),

    /// An `in-memory` entry had no factory registered under its server name (carried here).
    #[error("InMemory server factory '{0}' not registered")]
    FactoryNotFound(String),

    /// An invalid nested config was found in a `tool-proxy` entry.
    // NOTE(review): not constructed anywhere in this file, and the `proxy` server type is
    // rejected by `McpConfig::from_json`; possibly dead — confirm no other module uses it.
    #[error("Invalid nested config in tool-proxy: {0}")]
    InvalidNestedConfig(String),
}
175
176impl McpConfig {
177    pub fn from_json_file(path: impl AsRef<Path>) -> Result<Self, ParseError> {
178        let content = std::fs::read_to_string(path)?;
179        Self::from_json(&content)
180    }
181
182    pub fn from_json_files<T: AsRef<Path>>(paths: &[T]) -> Result<Self, ParseError> {
183        let mut merged = BTreeMap::new();
184        for path in paths {
185            let raw = Self::from_json_file(path)?;
186            merged.extend(raw.servers);
187        }
188        Ok(Self { servers: merged })
189    }
190
191    pub fn from_json(json: &str) -> Result<Self, ParseError> {
192        Ok(serde_json::from_str(json)?)
193    }
194
195    pub async fn into_servers(self, factories: &HashMap<String, ServerFactory>) -> Result<Vec<McpServer>, ParseError> {
196        self.into_servers_with_proxy(factories, false).await
197    }
198
199    pub async fn into_servers_with_proxy(
200        self,
201        factories: &HashMap<String, ServerFactory>,
202        force_proxy: bool,
203    ) -> Result<Vec<McpServer>, ParseError> {
204        let mut servers = Vec::with_capacity(self.servers.len());
205        for (name, config) in self.servers {
206            servers.push(config.into_server(name, factories, force_proxy).await?);
207        }
208        Ok(servers)
209    }
210
211    pub fn mark_all_proxy(&mut self) {
212        for server in self.servers.values_mut() {
213            server.set_proxy(true);
214        }
215    }
216}
217
218impl McpServerConfig {
219    pub fn proxy(&self) -> bool {
220        match self {
221            McpServerConfig::Stdio(config) => config.proxy,
222            McpServerConfig::Http(config) => config.proxy,
223            McpServerConfig::Sse(config) => config.proxy,
224            McpServerConfig::InMemory(config) => config.proxy,
225        }
226    }
227
228    pub fn set_proxy(&mut self, value: bool) {
229        match self {
230            McpServerConfig::Stdio(config) => config.proxy = value,
231            McpServerConfig::Http(config) => config.proxy = value,
232            McpServerConfig::Sse(config) => config.proxy = value,
233            McpServerConfig::InMemory(config) => config.proxy = value,
234        }
235    }
236
237    pub async fn into_server(
238        self,
239        name: String,
240        factories: &HashMap<String, ServerFactory>,
241        force_proxy: bool,
242    ) -> Result<McpServer, ParseError> {
243        let proxy = force_proxy || self.proxy();
244        let transport = self.into_transport(name.clone(), factories).await?;
245        Ok(McpServer::new(name, transport, proxy))
246    }
247
248    async fn into_transport(
249        self,
250        name: String,
251        factories: &HashMap<String, ServerFactory>,
252    ) -> Result<McpTransport, ParseError> {
253        match self {
254            McpServerConfig::Stdio(StdioServerConfig { command, args, env, .. }) => Ok(McpTransport::Stdio {
255                command: expand_env_vars(&command)?,
256                args: args.into_iter().map(|a| expand_env_vars(&a)).collect::<Result<Vec<_>, _>>()?,
257                env: env
258                    .into_iter()
259                    .map(|(k, v)| Ok((k, expand_env_vars(&v)?)))
260                    .collect::<Result<HashMap<_, _>, VarError>>()?,
261            }),
262
263            McpServerConfig::Http(HttpServerConfig { url, headers, .. })
264            | McpServerConfig::Sse(SseServerConfig { url, headers, .. }) => {
265                let auth_header = headers.get("Authorization").map(|v| expand_env_vars(v)).transpose()?;
266                let mut config = StreamableHttpClientTransportConfig::with_uri(expand_env_vars(&url)?);
267                if let Some(auth) = auth_header {
268                    config = config.auth_header(auth);
269                }
270                Ok(McpTransport::Http { config })
271            }
272
273            McpServerConfig::InMemory(InMemoryServerConfig { args, input, .. }) => {
274                let server_factory = factories.get(&name).ok_or_else(|| ParseError::FactoryNotFound(name.clone()))?;
275                let expanded_args =
276                    args.into_iter().map(|a| expand_env_vars(&a)).collect::<Result<Vec<_>, VarError>>()?;
277                let server = server_factory(expanded_args, input).await;
278                Ok(McpTransport::InMemory { server })
279            }
280        }
281    }
282}
283
/// `skip_serializing_if` predicate: true when the flag is unset, so `proxy: false` is omitted.
// serde calls `skip_serializing_if` predicates with a reference to the field, so the `&bool`
// signature is required. That is why clippy's lint about passing small `Copy` values by
// reference is silenced here.
#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_false(value: &bool) -> bool {
    !value
}
288
#[cfg(test)]
mod tests {
    // Unit tests for config parsing, multi-file merging, proxy flags, and server resolution.
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// Writes `json` to `dir/name` and returns the resulting path.
    fn write_config(dir: &Path, name: &str, json: &str) -> std::path::PathBuf {
        let path = dir.join(name);
        fs::write(&path, json).unwrap();
        path
    }

    /// A config with one stdio server named `coding` that runs `command`.
    fn stdio_config(command: &str) -> String {
        format!(r#"{{"servers": {{"coding": {{"type": "stdio", "command": "{command}"}}}}}}"#)
    }

    #[test]
    fn from_json_accepts_mcp_servers_key() {
        let config = McpConfig::from_json(r#"{"mcpServers": {"alpha": {"type": "stdio", "command": "a"}}}"#).unwrap();
        assert_eq!(config.servers.len(), 1);
        assert!(config.servers.contains_key("alpha"));
    }

    #[test]
    fn from_json_defaults_missing_type_to_stdio() {
        let config = McpConfig::from_json(
            r#"{"mcpServers": {"devtools": {"command": "npx", "args": ["-y", "chrome-devtools-mcp"]}}}"#,
        )
        .unwrap();
        match config.servers.get("devtools").unwrap() {
            McpServerConfig::Stdio(StdioServerConfig { command, args, proxy, .. }) => {
                assert_eq!(command, "npx");
                assert_eq!(args, &["-y", "chrome-devtools-mcp"]);
                assert!(!proxy);
            }
            other => panic!("expected Stdio server, got {other:?}"),
        }
    }

    #[test]
    fn from_json_selects_variant_by_type_tag() {
        let config = McpConfig::from_json(
            r#"{"servers": {
                "remote": {"type": "http", "url": "https://example.com/mcp", "headers": {"Authorization": "Bearer x"}},
                "legacy": {"type": "sse", "url": "https://example.com/sse"},
                "local": {"type": "in-memory", "args": ["--flag"]}
            }}"#,
        )
        .unwrap();
        assert!(matches!(
            config.servers.get("remote"),
            Some(McpServerConfig::Http(HttpServerConfig { url, headers, .. }))
                if url == "https://example.com/mcp"
                    && headers.get("Authorization").map(String::as_str) == Some("Bearer x")
        ));
        assert!(matches!(config.servers.get("legacy"), Some(McpServerConfig::Sse(_))));
        assert!(matches!(
            config.servers.get("local"),
            Some(McpServerConfig::InMemory(InMemoryServerConfig { args, input: None, .. }))
                if args == &["--flag"]
        ));
    }

    #[test]
    fn from_json_accepts_server_proxy_true() {
        let config =
            McpConfig::from_json(r#"{"servers": {"playwright": {"type": "stdio", "command": "npx", "proxy": true}}}"#)
                .unwrap();
        assert!(config.servers.get("playwright").unwrap().proxy());
    }

    #[test]
    fn from_json_rejects_proxy_server_type() {
        let result = McpConfig::from_json(r#"{"servers":{"tools":{"type":"proxy","servers":{}}}}"#);
        assert!(result.is_err());
    }

    #[test]
    fn false_proxy_omits_during_serialization() {
        let config =
            McpConfig::from_json(r#"{"servers": {"coding": {"type": "stdio", "command": "a", "proxy": false}}}"#)
                .unwrap();
        let serialized = serde_json::to_string(&config).unwrap();
        assert!(!serialized.contains("proxy"));
    }

    #[test]
    fn true_proxy_serializes() {
        let config =
            McpConfig::from_json(r#"{"servers": {"coding": {"type": "stdio", "command": "a", "proxy": true}}}"#)
                .unwrap();
        let serialized = serde_json::to_string(&config).unwrap();
        assert!(serialized.contains("proxy"));
    }

    #[test]
    fn from_json_rejects_unknown_type() {
        let result = McpConfig::from_json(r#"{"servers": {"bad": {"type": "htp", "url": "https://example.com"}}}"#);
        assert!(result.is_err());
    }

    #[test]
    fn from_json_files_empty_returns_empty_servers() {
        let result = McpConfig::from_json_files::<&str>(&[]).unwrap();
        assert!(result.servers.is_empty());
    }

    #[test]
    fn from_json_files_single_file_matches_from_json_file() {
        let dir = tempdir().unwrap();
        let path = write_config(dir.path(), "a.json", &stdio_config("ls"));

        let single = McpConfig::from_json_file(&path).unwrap();
        let multi = McpConfig::from_json_files(&[&path]).unwrap();

        // Compare full entries, not just counts, so a lossy merge would be caught.
        assert_eq!(single.servers, multi.servers);
        assert!(multi.servers.contains_key("coding"));
    }

    #[test]
    fn from_json_files_merges_disjoint_servers() {
        let dir = tempdir().unwrap();
        let a = write_config(dir.path(), "a.json", r#"{"servers": {"alpha": {"type": "stdio", "command": "a"}}}"#);
        let b = write_config(dir.path(), "b.json", r#"{"servers": {"beta": {"type": "stdio", "command": "b"}}}"#);

        let merged = McpConfig::from_json_files(&[a, b]).unwrap();
        assert_eq!(merged.servers.len(), 2);
        assert!(merged.servers.contains_key("alpha"));
        assert!(merged.servers.contains_key("beta"));
    }

    #[test]
    fn from_json_files_last_file_wins_on_collision_including_proxy() {
        let dir = tempdir().unwrap();
        let a = write_config(
            dir.path(),
            "a.json",
            r#"{"servers":{"coding":{"type":"stdio","command":"from_a","proxy":true}}}"#,
        );
        let b = write_config(dir.path(), "b.json", r#"{"servers":{"coding":{"type":"stdio","command":"from_b"}}}"#);

        let merged_ab = McpConfig::from_json_files(&[&a, &b]).unwrap();
        match merged_ab.servers.get("coding").unwrap() {
            McpServerConfig::Stdio(StdioServerConfig { command, proxy, .. }) => {
                assert_eq!(command, "from_b");
                assert!(!proxy);
            }
            other => panic!("expected Stdio, got {other:?}"),
        }

        let merged_ba = McpConfig::from_json_files(&[&b, &a]).unwrap();
        match merged_ba.servers.get("coding").unwrap() {
            McpServerConfig::Stdio(StdioServerConfig { command, proxy, .. }) => {
                assert_eq!(command, "from_a");
                assert!(*proxy);
            }
            other => panic!("expected Stdio, got {other:?}"),
        }
    }

    #[test]
    fn mark_all_proxy_sets_every_server() {
        let mut config = McpConfig::from_json(
            r#"{"servers":{"a":{"type":"stdio","command":"a"},"b":{"type":"http","url":"https://example.com"}}}"#,
        )
        .unwrap();
        config.mark_all_proxy();
        assert!(config.servers.values().all(McpServerConfig::proxy));
    }

    #[test]
    fn from_json_files_propagates_io_error_on_missing_file() {
        let dir = tempdir().unwrap();
        let missing = dir.path().join("does-not-exist.json");
        let result = McpConfig::from_json_files(&[missing]);
        assert!(matches!(result, Err(ParseError::IoError(_))));
    }

    #[test]
    fn from_json_files_propagates_json_error_on_invalid_file() {
        let dir = tempdir().unwrap();
        let bad = write_config(dir.path(), "bad.json", "not valid json");
        let result = McpConfig::from_json_files(&[bad]);
        assert!(matches!(result, Err(ParseError::JsonError(_))));
    }

    #[tokio::test]
    async fn into_servers_preserves_proxy_flags() {
        let json = r#"{
            "servers": {
                "github": {"type": "stdio", "command": "g"},
                "playwright": {"type": "stdio", "command": "p", "proxy": true}
            }
        }"#;
        let config = McpConfig::from_json(json).unwrap();
        let servers = config.into_servers(&HashMap::new()).await.unwrap();

        assert_eq!(servers.len(), 2);
        assert!(!servers.iter().find(|s| s.name == "github").unwrap().proxy);
        assert!(servers.iter().find(|s| s.name == "playwright").unwrap().proxy);
    }

    #[tokio::test]
    async fn into_servers_with_proxy_forces_proxy_flags() {
        let config =
            McpConfig::from_json(r#"{"servers":{"github":{"type":"stdio","command":"g","proxy":false}}}"#).unwrap();
        let servers = config.into_servers_with_proxy(&HashMap::new(), true).await.unwrap();
        assert!(servers[0].proxy);
    }

    #[tokio::test]
    async fn into_servers_reports_missing_in_memory_factory() {
        let config = McpConfig::from_json(r#"{"servers":{"mem":{"type":"in-memory"}}}"#).unwrap();
        let result = config.into_servers(&HashMap::new()).await;
        assert!(matches!(&result, Err(ParseError::FactoryNotFound(name)) if name == "mem"));
    }
}