1use super::variables::{VarError, expand_env_vars};
2use futures::future::BoxFuture;
3use rmcp::{RoleServer, service::DynService, transport::streamable_http_client::StreamableHttpClientTransportConfig};
4use schemars::JsonSchema;
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use std::collections::{BTreeMap, HashMap};
8use std::fmt::{Debug, Formatter};
9use std::path::Path;
10
// Top-level MCP configuration: the configured servers, keyed by server name.
//
// NOTE: plain `//` comments are used here on purpose. `///` doc comments on a type deriving
// `JsonSchema` would be picked up by schemars as schema descriptions and change its output.
#[derive(Debug, Clone, Default, Deserialize, Serialize, JsonSchema)]
pub struct McpConfig {
    // Server name -> server definition. When deserializing, the key `mcpServers` is accepted as
    // an alias for `servers`. A `BTreeMap` keeps iteration ordered by server name.
    #[serde(alias = "mcpServers")]
    pub servers: BTreeMap<String, McpServerConfig>,
}
16
// One server definition. The enum is `untagged`, so serde tries the variants in declaration
// order and keeps the first one that deserializes; the `type` field inside each variant's
// struct (plus `deny_unknown_fields`) is what tells the shapes apart.
//
// NOTE: plain `//` comments are used here on purpose. `///` doc comments on a type deriving
// `JsonSchema` would be picked up by schemars as schema descriptions and change its output.
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, PartialEq)]
#[serde(untagged)]
pub enum McpServerConfig {
    // `type: "stdio"`, or no `type` at all (the stdio type has a default).
    Stdio(StdioServerConfig),
    // `type: "http"`.
    Http(HttpServerConfig),
    // `type: "sse"`.
    Sse(SseServerConfig),
    // `type: "in-memory"`.
    InMemory(InMemoryServerConfig),
}
25
// Settings for a stdio server: a command line plus environment entries.
// Unknown JSON fields are rejected.
//
// NOTE: plain `//` comments are used here on purpose. `///` doc comments on a type deriving
// `JsonSchema` would be picked up by schemars as schema descriptions and change its output.
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct StdioServerConfig {
    // JSON field `type`. Optional on input: falls back to `StdioType::default()` (`"stdio"`).
    #[serde(rename = "type", default)]
    pub type_: StdioType,

    // Program to run. Required.
    pub command: String,

    // Arguments for `command`; empty when omitted.
    #[serde(default)]
    pub args: Vec<String>,

    // Environment entries; empty when omitted.
    #[serde(default)]
    pub env: HashMap<String, String>,

    // Proxy flag for this server. Defaults to `false` and is left out of serialized output
    // while it is `false`.
    #[serde(default, skip_serializing_if = "is_false")]
    pub proxy: bool,
}
43
// Settings for an HTTP server. Unknown JSON fields are rejected.
//
// NOTE: plain `//` comments are used here on purpose. `///` doc comments on a type deriving
// `JsonSchema` would be picked up by schemars as schema descriptions and change its output.
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct HttpServerConfig {
    // JSON field `type`. Required; the only accepted value is `"http"`.
    #[serde(rename = "type")]
    pub type_: HttpType,

    // Server URL. Required.
    pub url: String,

    // HTTP headers; empty when omitted.
    #[serde(default)]
    pub headers: HashMap<String, String>,

    // Proxy flag for this server. Defaults to `false` and is left out of serialized output
    // while it is `false`.
    #[serde(default, skip_serializing_if = "is_false")]
    pub proxy: bool,
}
58
// Settings for an SSE server. Same fields as the HTTP config, but tagged `"sse"`.
// Unknown JSON fields are rejected.
//
// NOTE: plain `//` comments are used here on purpose. `///` doc comments on a type deriving
// `JsonSchema` would be picked up by schemars as schema descriptions and change its output.
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct SseServerConfig {
    // JSON field `type`. Required; the only accepted value is `"sse"`.
    #[serde(rename = "type")]
    pub type_: SseType,

    // Server URL. Required.
    pub url: String,

    // HTTP headers; empty when omitted.
    #[serde(default)]
    pub headers: HashMap<String, String>,

    // Proxy flag for this server. Defaults to `false` and is left out of serialized output
    // while it is `false`.
    #[serde(default, skip_serializing_if = "is_false")]
    pub proxy: bool,
}
73
// Settings for an in-memory server, i.e. one built by a registered `ServerFactory` that is
// looked up by the server's name. Unknown JSON fields are rejected.
//
// NOTE: plain `//` comments are used here on purpose. `///` doc comments on a type deriving
// `JsonSchema` would be picked up by schemars as schema descriptions and change its output.
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct InMemoryServerConfig {
    // JSON field `type`. Required; the only accepted value is `"in-memory"`.
    #[serde(rename = "type")]
    pub type_: InMemoryType,

    // Arguments handed to the factory; empty when omitted.
    #[serde(default)]
    pub args: Vec<String>,

    // Arbitrary JSON handed to the factory as-is; `None` when omitted.
    #[serde(default)]
    pub input: Option<Value>,

    // Proxy flag for this server. Defaults to `false` and is left out of serialized output
    // while it is `false`.
    #[serde(default, skip_serializing_if = "is_false")]
    pub proxy: bool,
}
89
// Single-value tag for stdio configs; (de)serialized as the string `"stdio"`.
// It derives `Default`, which is what lets the `type` field be omitted for stdio servers.
// (Plain `//` comments: `///` docs would be emitted into the derived JSON schema.)
#[derive(Debug, Clone, Copy, Default, Deserialize, Serialize, JsonSchema, PartialEq)]
pub enum StdioType {
    #[default]
    #[serde(rename = "stdio")]
    Stdio,
}
96
// Single-value tag for HTTP configs; (de)serialized as the string `"http"`.
// There is no `Default`, so the `type` field is mandatory for HTTP servers.
// (Plain `//` comments: `///` docs would be emitted into the derived JSON schema.)
#[derive(Debug, Clone, Copy, Deserialize, Serialize, JsonSchema, PartialEq)]
pub enum HttpType {
    #[serde(rename = "http")]
    Http,
}
102
// Single-value tag for SSE configs; (de)serialized as the string `"sse"`.
// There is no `Default`, so the `type` field is mandatory for SSE servers.
// (Plain `//` comments: `///` docs would be emitted into the derived JSON schema.)
#[derive(Debug, Clone, Copy, Deserialize, Serialize, JsonSchema, PartialEq)]
pub enum SseType {
    #[serde(rename = "sse")]
    Sse,
}
108
// Single-value tag for in-memory configs; (de)serialized as the string `"in-memory"`.
// There is no `Default`, so the `type` field is mandatory for in-memory servers.
// (Plain `//` comments: `///` docs would be emitted into the derived JSON schema.)
#[derive(Debug, Clone, Copy, Deserialize, Serialize, JsonSchema, PartialEq)]
pub enum InMemoryType {
    #[serde(rename = "in-memory")]
    InMemory,
}
114
/// A resolved server: the result of turning one [`McpServerConfig`] entry into a transport.
///
/// `Debug` is implemented by hand because [`McpTransport`] holds a boxed trait object.
pub struct McpServer {
    /// Name of the server (the key it had in [`McpConfig::servers`] when built from a config).
    pub name: String,
    /// How to reach the server.
    pub transport: McpTransport,
    /// Effective proxy flag for this server.
    pub proxy: bool,
}
120
/// The resolved way of talking to a server.
///
/// There is no separate SSE variant: both HTTP and SSE configs are resolved to
/// [`McpTransport::Http`].
pub enum McpTransport {
    /// A command line with its arguments and environment entries.
    Stdio { command: String, args: Vec<String>, env: HashMap<String, String> },
    /// A streamable-HTTP client configuration.
    Http { config: StreamableHttpClientTransportConfig },
    /// An already-constructed in-process service.
    InMemory { server: Box<dyn DynService<RoleServer>> },
}
126
127impl McpServer {
128 pub fn new(name: impl Into<String>, transport: McpTransport, proxy: bool) -> Self {
129 Self { name: name.into(), transport, proxy }
130 }
131}
132
133impl Debug for McpServer {
134 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
135 f.debug_struct("McpServer")
136 .field("name", &self.name)
137 .field("transport", &self.transport)
138 .field("proxy", &self.proxy)
139 .finish()
140 }
141}
142
143impl Debug for McpTransport {
144 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
145 match self {
146 McpTransport::Stdio { command, args, env } => {
147 f.debug_struct("Stdio").field("command", command).field("args", args).field("env", env).finish()
148 }
149 McpTransport::Http { config } => f.debug_struct("Http").field("config", config).finish(),
150 McpTransport::InMemory { .. } => f.debug_struct("InMemory").field("server", &"<DynService>").finish(),
151 }
152 }
153}
154
/// Builder for an in-memory server.
///
/// Called with the config's `args` (after environment-variable expansion) and its optional
/// `input` JSON; returns a boxed `'static` future that resolves to the boxed service.
/// The callable must be `Send + Sync`. Factories are looked up by server name.
pub type ServerFactory =
    Box<dyn Fn(Vec<String>, Option<Value>) -> BoxFuture<'static, Box<dyn DynService<RoleServer>>> + Send + Sync>;
157
/// Errors produced while loading a config or resolving it into servers.
#[derive(Debug, thiserror::Error)]
pub enum ParseError {
    /// Reading a config file failed.
    #[error("Failed to read config file: {0}")]
    IoError(#[from] std::io::Error),

    /// The config text could not be deserialized from JSON.
    #[error("Invalid JSON: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Expanding environment variables in a config value failed.
    #[error("Variable expansion failed: {0}")]
    VarError(#[from] VarError),

    /// An in-memory server had no factory registered under its name; carries that name.
    #[error("InMemory server factory '{0}' not registered")]
    FactoryNotFound(String),

    /// NOTE(review): not constructed anywhere in the code visible in this file; presumably
    /// raised by tool-proxy handling elsewhere. Confirm before documenting further.
    #[error("Invalid nested config in tool-proxy: {0}")]
    InvalidNestedConfig(String),
}
175
176impl McpConfig {
177 pub fn from_json_file(path: impl AsRef<Path>) -> Result<Self, ParseError> {
178 let content = std::fs::read_to_string(path)?;
179 Self::from_json(&content)
180 }
181
182 pub fn from_json_files<T: AsRef<Path>>(paths: &[T]) -> Result<Self, ParseError> {
183 let mut merged = BTreeMap::new();
184 for path in paths {
185 let raw = Self::from_json_file(path)?;
186 merged.extend(raw.servers);
187 }
188 Ok(Self { servers: merged })
189 }
190
191 pub fn from_json(json: &str) -> Result<Self, ParseError> {
192 Ok(serde_json::from_str(json)?)
193 }
194
195 pub async fn into_servers(self, factories: &HashMap<String, ServerFactory>) -> Result<Vec<McpServer>, ParseError> {
196 self.into_servers_with_proxy(factories, false).await
197 }
198
199 pub async fn into_servers_with_proxy(
200 self,
201 factories: &HashMap<String, ServerFactory>,
202 force_proxy: bool,
203 ) -> Result<Vec<McpServer>, ParseError> {
204 let mut servers = Vec::with_capacity(self.servers.len());
205 for (name, config) in self.servers {
206 servers.push(config.into_server(name, factories, force_proxy).await?);
207 }
208 Ok(servers)
209 }
210
211 pub fn mark_all_proxy(&mut self) {
212 for server in self.servers.values_mut() {
213 server.set_proxy(true);
214 }
215 }
216}
217
218impl McpServerConfig {
219 pub fn proxy(&self) -> bool {
220 match self {
221 McpServerConfig::Stdio(config) => config.proxy,
222 McpServerConfig::Http(config) => config.proxy,
223 McpServerConfig::Sse(config) => config.proxy,
224 McpServerConfig::InMemory(config) => config.proxy,
225 }
226 }
227
228 pub fn set_proxy(&mut self, value: bool) {
229 match self {
230 McpServerConfig::Stdio(config) => config.proxy = value,
231 McpServerConfig::Http(config) => config.proxy = value,
232 McpServerConfig::Sse(config) => config.proxy = value,
233 McpServerConfig::InMemory(config) => config.proxy = value,
234 }
235 }
236
237 pub async fn into_server(
238 self,
239 name: String,
240 factories: &HashMap<String, ServerFactory>,
241 force_proxy: bool,
242 ) -> Result<McpServer, ParseError> {
243 let proxy = force_proxy || self.proxy();
244 let transport = self.into_transport(name.clone(), factories).await?;
245 Ok(McpServer::new(name, transport, proxy))
246 }
247
248 async fn into_transport(
249 self,
250 name: String,
251 factories: &HashMap<String, ServerFactory>,
252 ) -> Result<McpTransport, ParseError> {
253 match self {
254 McpServerConfig::Stdio(StdioServerConfig { command, args, env, .. }) => Ok(McpTransport::Stdio {
255 command: expand_env_vars(&command)?,
256 args: args.into_iter().map(|a| expand_env_vars(&a)).collect::<Result<Vec<_>, _>>()?,
257 env: env
258 .into_iter()
259 .map(|(k, v)| Ok((k, expand_env_vars(&v)?)))
260 .collect::<Result<HashMap<_, _>, VarError>>()?,
261 }),
262
263 McpServerConfig::Http(HttpServerConfig { url, headers, .. })
264 | McpServerConfig::Sse(SseServerConfig { url, headers, .. }) => {
265 let auth_header = headers.get("Authorization").map(|v| expand_env_vars(v)).transpose()?;
266 let mut config = StreamableHttpClientTransportConfig::with_uri(expand_env_vars(&url)?);
267 if let Some(auth) = auth_header {
268 config = config.auth_header(auth);
269 }
270 Ok(McpTransport::Http { config })
271 }
272
273 McpServerConfig::InMemory(InMemoryServerConfig { args, input, .. }) => {
274 let server_factory = factories.get(&name).ok_or_else(|| ParseError::FactoryNotFound(name.clone()))?;
275 let expanded_args =
276 args.into_iter().map(|a| expand_env_vars(&a)).collect::<Result<Vec<_>, VarError>>()?;
277 let server = server_factory(expanded_args, input).await;
278 Ok(McpTransport::InMemory { server })
279 }
280 }
281 }
282}
283
/// Returns `true` when `value` is `false`.
///
/// Used as a serde `skip_serializing_if` predicate so that `proxy: false` is left out of
/// serialized configs.
// serde calls `skip_serializing_if` functions with a reference to the field, so the `&bool`
// parameter is required even though `bool` is `Copy`.
#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_false(value: &bool) -> bool {
    matches!(*value, false)
}
288
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// Writes `json` to `dir/name` and returns the full path of the new file.
    fn write_config(dir: &Path, name: &str, json: &str) -> std::path::PathBuf {
        let path = dir.join(name);
        fs::write(&path, json).unwrap();
        path
    }

    /// JSON for a config with a single stdio server named `coding` that runs `command`.
    fn stdio_config(command: &str) -> String {
        format!(r#"{{"servers": {{"coding": {{"type": "stdio", "command": "{command}"}}}}}}"#)
    }

    #[test]
    fn from_json_accepts_mcp_servers_key() {
        let config = McpConfig::from_json(r#"{"mcpServers": {"alpha": {"type": "stdio", "command": "a"}}}"#).unwrap();
        assert_eq!(config.servers.len(), 1);
        assert!(config.servers.contains_key("alpha"));
    }

    #[test]
    fn from_json_defaults_missing_type_to_stdio() {
        let config = McpConfig::from_json(
            r#"{"mcpServers": {"devtools": {"command": "npx", "args": ["-y", "chrome-devtools-mcp"]}}}"#,
        )
        .unwrap();
        match config.servers.get("devtools").unwrap() {
            McpServerConfig::Stdio(StdioServerConfig { command, args, proxy, .. }) => {
                assert_eq!(command, "npx");
                assert_eq!(args, &["-y", "chrome-devtools-mcp"]);
                assert!(!proxy);
            }
            other => panic!("expected Stdio server, got {other:?}"),
        }
    }

    #[test]
    fn from_json_accepts_server_proxy_true() {
        let config =
            McpConfig::from_json(r#"{"servers": {"playwright": {"type": "stdio", "command": "npx", "proxy": true}}}"#)
                .unwrap();
        assert!(config.servers.get("playwright").unwrap().proxy());
    }

    #[test]
    fn from_json_rejects_proxy_server_type() {
        let result = McpConfig::from_json(r#"{"servers":{"tools":{"type":"proxy","servers":{}}}}"#);
        assert!(result.is_err());
    }

    #[test]
    fn false_proxy_omits_during_serialization() {
        let config =
            McpConfig::from_json(r#"{"servers": {"coding": {"type": "stdio", "command": "a", "proxy": false}}}"#)
                .unwrap();
        let serialized = serde_json::to_string(&config).unwrap();
        assert!(!serialized.contains("proxy"));
    }

    #[test]
    fn true_proxy_serializes() {
        let config =
            McpConfig::from_json(r#"{"servers": {"coding": {"type": "stdio", "command": "a", "proxy": true}}}"#)
                .unwrap();
        let serialized = serde_json::to_string(&config).unwrap();
        assert!(serialized.contains("proxy"));
    }

    #[test]
    fn from_json_rejects_unknown_type() {
        let result = McpConfig::from_json(r#"{"servers": {"bad": {"type": "htp", "url": "https://example.com"}}}"#);
        assert!(result.is_err());
    }

    #[test]
    fn from_json_files_empty_returns_empty_servers() {
        let result = McpConfig::from_json_files::<&str>(&[]).unwrap();
        assert!(result.servers.is_empty());
    }

    #[test]
    fn from_json_files_single_file_matches_from_json_file() {
        let dir = tempdir().unwrap();
        let path = write_config(dir.path(), "a.json", &stdio_config("ls"));

        let single = McpConfig::from_json_file(&path).unwrap();
        let multi = McpConfig::from_json_files(&[&path]).unwrap();

        // Compare the full server maps, not just their sizes, so the test lives up to its name.
        assert_eq!(single.servers, multi.servers);
        assert_eq!(single.servers.len(), multi.servers.len());
        assert!(multi.servers.contains_key("coding"));
    }

    #[test]
    fn from_json_files_merges_disjoint_servers() {
        let dir = tempdir().unwrap();
        let a = write_config(dir.path(), "a.json", r#"{"servers": {"alpha": {"type": "stdio", "command": "a"}}}"#);
        let b = write_config(dir.path(), "b.json", r#"{"servers": {"beta": {"type": "stdio", "command": "b"}}}"#);

        let merged = McpConfig::from_json_files(&[a, b]).unwrap();
        assert_eq!(merged.servers.len(), 2);
        assert!(merged.servers.contains_key("alpha"));
        assert!(merged.servers.contains_key("beta"));
    }

    #[test]
    fn from_json_files_last_file_wins_on_collision_including_proxy() {
        let dir = tempdir().unwrap();
        let a = write_config(
            dir.path(),
            "a.json",
            r#"{"servers":{"coding":{"type":"stdio","command":"from_a","proxy":true}}}"#,
        );
        let b = write_config(dir.path(), "b.json", r#"{"servers":{"coding":{"type":"stdio","command":"from_b"}}}"#);

        // The later file replaces the whole entry, so `proxy` from `a` must not leak through.
        let merged_ab = McpConfig::from_json_files(&[&a, &b]).unwrap();
        match merged_ab.servers.get("coding").unwrap() {
            McpServerConfig::Stdio(StdioServerConfig { command, proxy, .. }) => {
                assert_eq!(command, "from_b");
                assert!(!proxy);
            }
            other => panic!("expected Stdio, got {other:?}"),
        }

        let merged_ba = McpConfig::from_json_files(&[&b, &a]).unwrap();
        match merged_ba.servers.get("coding").unwrap() {
            McpServerConfig::Stdio(StdioServerConfig { command, proxy, .. }) => {
                assert_eq!(command, "from_a");
                assert!(*proxy);
            }
            other => panic!("expected Stdio, got {other:?}"),
        }
    }

    #[test]
    fn mark_all_proxy_sets_every_server() {
        let mut config = McpConfig::from_json(
            r#"{"servers":{"a":{"type":"stdio","command":"a"},"b":{"type":"http","url":"https://example.com"}}}"#,
        )
        .unwrap();
        config.mark_all_proxy();
        assert!(config.servers.values().all(McpServerConfig::proxy));
    }

    #[test]
    fn from_json_files_propagates_io_error_on_missing_file() {
        let dir = tempdir().unwrap();
        let missing = dir.path().join("does-not-exist.json");
        let result = McpConfig::from_json_files(&[missing]);
        assert!(matches!(result, Err(ParseError::IoError(_))));
    }

    #[test]
    fn from_json_files_propagates_json_error_on_invalid_file() {
        let dir = tempdir().unwrap();
        let bad = write_config(dir.path(), "bad.json", "not valid json");
        let result = McpConfig::from_json_files(&[bad]);
        assert!(matches!(result, Err(ParseError::JsonError(_))));
    }

    #[tokio::test]
    async fn into_servers_preserves_proxy_flags() {
        let json = r#"{
            "servers": {
                "github": {"type": "stdio", "command": "g"},
                "playwright": {"type": "stdio", "command": "p", "proxy": true}
            }
        }"#;
        let config = McpConfig::from_json(json).unwrap();
        let servers = config.into_servers(&HashMap::new()).await.unwrap();

        assert_eq!(servers.len(), 2);
        assert!(!servers.iter().find(|s| s.name == "github").unwrap().proxy);
        assert!(servers.iter().find(|s| s.name == "playwright").unwrap().proxy);
    }

    #[tokio::test]
    async fn into_servers_with_proxy_forces_proxy_flags() {
        let config =
            McpConfig::from_json(r#"{"servers":{"github":{"type":"stdio","command":"g","proxy":false}}}"#).unwrap();
        let servers = config.into_servers_with_proxy(&HashMap::new(), true).await.unwrap();
        // Check the length first so a missing server fails with a clear message, not an index panic.
        assert_eq!(servers.len(), 1);
        assert!(servers[0].proxy);
    }
}