1use std::collections::HashMap;
8
9use serde::de::Deserializer;
10use serde::{Deserialize, Serialize};
11
12use super::admin::{ParamRestrictions, ToolFilter};
13
14#[derive(Debug, Clone, Serialize)]
36pub struct ToolServerConfig {
37 pub name: String,
38 pub transport: ToolServerTransport,
39 #[serde(default)]
43 pub bridge: bool,
44 #[serde(default, skip_serializing_if = "Option::is_none")]
45 pub tool_filter: Option<ToolFilter>,
46 #[serde(default)]
47 pub param_restrictions: ParamRestrictions,
48}
49
50impl<'de> Deserialize<'de> for ToolServerConfig {
51 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
52 where
53 D: Deserializer<'de>,
54 {
55 #[derive(Deserialize)]
57 struct Raw {
58 name: String,
59
60 #[serde(default)]
62 transport: Option<ToolServerTransport>,
63
64 #[serde(default)]
66 command: Option<String>,
67 #[serde(default)]
68 args: Vec<String>,
69 #[serde(default)]
70 env: HashMap<String, String>,
71
72 #[serde(default)]
74 url: Option<String>,
75 #[serde(default)]
76 headers: HashMap<String, String>,
77
78 #[serde(default)]
80 bridge: bool,
81
82 #[serde(default)]
84 tool_filter: Option<ToolFilter>,
85 #[serde(default)]
86 param_restrictions: ParamRestrictions,
87 }
88
89 let raw = Raw::deserialize(deserializer)?;
90
91 let transport = if let Some(t) = raw.transport {
92 t
93 } else if let Some(command) = raw.command {
94 ToolServerTransport::Stdio {
95 command,
96 args: raw.args,
97 env: raw.env,
98 }
99 } else if let Some(url) = raw.url {
100 ToolServerTransport::Http {
101 url,
102 headers: raw.headers,
103 }
104 } else {
105 return Err(serde::de::Error::custom(
106 "mcp_servers entry must have `transport`, `command` (stdio), or `url` (http)",
107 ));
108 };
109
110 Ok(ToolServerConfig {
111 name: raw.name,
112 transport,
113 bridge: raw.bridge,
114 tool_filter: raw.tool_filter,
115 param_restrictions: raw.param_restrictions,
116 })
117 }
118}
119
120impl ToolServerConfig {
121 pub fn validate(&self) -> Result<(), String> {
123 if self.name.is_empty() {
124 return Err("server name must not be empty".into());
125 }
126 if self.name.contains('/') {
127 return Err(format!("server name '{}' must not contain '/'", self.name));
128 }
129 if self.name == "sse" {
130 return Err("server name 'sse' is reserved".into());
131 }
132 match &self.transport {
133 ToolServerTransport::Stdio { command, .. } => {
134 if command.is_empty() {
135 return Err(format!(
136 "server '{}': stdio command must not be empty",
137 self.name
138 ));
139 }
140 }
141 ToolServerTransport::Http { url, .. } => {
142 if url.is_empty() {
143 return Err(format!(
144 "server '{}': http url must not be empty",
145 self.name
146 ));
147 }
148 }
149 }
150 Ok(())
151 }
152}
153
154#[derive(Debug, Clone, Serialize, Deserialize)]
156#[serde(tag = "type", rename_all = "lowercase")]
157pub enum ToolServerTransport {
158 Stdio {
159 command: String,
160 #[serde(default)]
161 args: Vec<String>,
162 #[serde(default)]
163 env: HashMap<String, String>,
164 },
165 Http {
166 url: String,
167 #[serde(default)]
168 headers: HashMap<String, String>,
169 },
170}
171
172#[derive(Debug, Clone, Default, Serialize, Deserialize)]
176pub struct ToolServerAccessGroups {
177 #[serde(flatten)]
178 groups: HashMap<String, Vec<String>>,
179}
180
181impl ToolServerAccessGroups {
182 pub fn expand_patterns(&self, patterns: &[String]) -> Vec<String> {
191 let mut result = Vec::new();
192 for pattern in patterns {
193 if let Some((prefix, suffix)) = pattern.split_once('/') {
194 if let Some(servers) = self.groups.get(prefix) {
195 for server in servers {
196 result.push(format!("{server}/{suffix}"));
197 }
198 } else {
199 result.push(pattern.clone());
200 }
201 } else if let Some(servers) = self.groups.get(pattern.as_str()) {
202 for server in servers {
203 result.push(format!("{server}/*"));
204 }
205 } else {
206 result.push(pattern.clone());
207 }
208 }
209 result
210 }
211
212 pub fn contains(&self, name: &str) -> bool {
214 self.groups.contains_key(name)
215 }
216
217 pub fn servers(&self, name: &str) -> Option<&[String]> {
219 self.groups.get(name).map(|v| v.as_slice())
220 }
221
222 pub fn as_map(&self) -> &HashMap<String, Vec<String>> {
224 &self.groups
225 }
226}
227
228#[derive(Debug, Clone, Serialize, Deserialize)]
232pub struct AgentConfig {
233 pub name: String,
235
236 pub url: String,
238
239 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
241 pub headers: HashMap<String, String>,
242
243 #[serde(default, skip_serializing_if = "Option::is_none")]
245 pub card_path: Option<String>,
246}
247
248impl AgentConfig {
249 pub fn validate(&self) -> Result<(), String> {
251 if self.name.is_empty() {
252 return Err("agent name cannot be empty".to_string());
253 }
254 if self.name.contains('/') {
255 return Err(format!("agent name '{}' cannot contain '/'", self.name));
256 }
257 if self.url.is_empty() {
258 return Err("agent URL cannot be empty".to_string());
259 }
260 Ok(())
261 }
262
263 pub fn discovery_url(&self) -> String {
265 let base = self.url.trim_end_matches('/');
266 let path = self
267 .card_path
268 .as_deref()
269 .unwrap_or("/.well-known/agent-card.json");
270 format!("{base}{path}")
271 }
272}
273
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal stdio server config for validation tests.
    fn test_stdio_config(name: &str, command: &str) -> ToolServerConfig {
        ToolServerConfig {
            name: name.into(),
            transport: ToolServerTransport::Stdio {
                command: command.into(),
                args: vec![],
                env: HashMap::new(),
            },
            bridge: false,
            tool_filter: None,
            param_restrictions: ParamRestrictions::default(),
        }
    }

    /// Builds a minimal http server config for validation tests.
    fn test_http_config(name: &str, url: &str) -> ToolServerConfig {
        ToolServerConfig {
            name: name.into(),
            transport: ToolServerTransport::Http {
                url: url.into(),
                headers: HashMap::new(),
            },
            bridge: false,
            tool_filter: None,
            param_restrictions: ParamRestrictions::default(),
        }
    }

    /// Builds an agent config without extra headers.
    fn test_agent_config(name: &str, url: &str, card_path: Option<&str>) -> AgentConfig {
        AgentConfig {
            name: name.into(),
            url: url.into(),
            headers: HashMap::new(),
            card_path: card_path.map(str::to_string),
        }
    }

    // --- ToolServerConfig::validate ---------------------------------------

    #[test]
    fn validate_rejects_empty_name() {
        assert!(test_stdio_config("", "echo").validate().is_err());
    }

    #[test]
    fn validate_rejects_slash_in_name() {
        assert!(test_stdio_config("a/b", "echo").validate().is_err());
    }

    #[test]
    fn validate_rejects_reserved_name_sse() {
        assert!(test_stdio_config("sse", "echo").validate().is_err());
    }

    #[test]
    fn validate_rejects_empty_command() {
        assert!(test_stdio_config("test", "").validate().is_err());
    }

    #[test]
    fn validate_rejects_empty_url() {
        assert!(test_http_config("test", "").validate().is_err());
    }

    #[test]
    fn validate_accepts_valid_stdio() {
        assert!(test_stdio_config("my-server", "npx").validate().is_ok());
    }

    #[test]
    fn validate_accepts_valid_http() {
        assert!(test_http_config("remote", "http://localhost:3000/mcp")
            .validate()
            .is_ok());
    }

    // --- ToolServerConfig serde --------------------------------------------

    #[test]
    fn serde_roundtrip_stdio() {
        let config = ToolServerConfig {
            name: "test".into(),
            transport: ToolServerTransport::Stdio {
                command: "npx".into(),
                args: vec!["-y".into(), "server".into()],
                env: HashMap::from([("KEY".into(), "VAL".into())]),
            },
            bridge: false,
            tool_filter: Some(ToolFilter {
                allow: Some(vec!["tool1".into()]),
                deny: None,
            }),
            param_restrictions: ParamRestrictions::default(),
        };
        let json = serde_json::to_string(&config).expect("serialize");
        let parsed: ToolServerConfig = serde_json::from_str(&json).expect("deserialize");

        assert_eq!(parsed.name, "test");
        assert!(!parsed.bridge);
        match &parsed.transport {
            ToolServerTransport::Stdio { command, args, env } => {
                assert_eq!(command, "npx");
                assert_eq!(args, &["-y", "server"]);
                assert_eq!(env.get("KEY").map(String::as_str), Some("VAL"));
            }
            other => panic!("expected Stdio transport, got {other:?}"),
        }
        let filter = parsed
            .tool_filter
            .as_ref()
            .expect("tool_filter should survive the roundtrip");
        assert_eq!(filter.allow.as_ref().map(Vec::len), Some(1));
        assert!(filter.deny.is_none());
    }

    #[test]
    fn serde_roundtrip_http() {
        let config = ToolServerConfig {
            name: "remote".into(),
            transport: ToolServerTransport::Http {
                url: "http://localhost:3000/mcp".into(),
                headers: HashMap::from([("Authorization".into(), "Bearer tok".into())]),
            },
            bridge: true,
            tool_filter: None,
            param_restrictions: ParamRestrictions::default(),
        };
        let json = serde_json::to_string(&config).expect("serialize");
        let parsed: ToolServerConfig = serde_json::from_str(&json).expect("deserialize");

        assert_eq!(parsed.name, "remote");
        assert!(parsed.bridge);
        assert!(parsed.tool_filter.is_none());
        match &parsed.transport {
            ToolServerTransport::Http { url, headers } => {
                assert_eq!(url, "http://localhost:3000/mcp");
                assert_eq!(
                    headers.get("Authorization").map(String::as_str),
                    Some("Bearer tok")
                );
            }
            other => panic!("expected Http transport, got {other:?}"),
        }
    }

    #[test]
    fn deserialize_flat_stdio() {
        let json = r#"{
            "name": "fs",
            "command": "npx",
            "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]
        }"#;
        let config: ToolServerConfig = serde_json::from_str(json).expect("deserialize flat stdio");
        assert_eq!(config.name, "fs");
        match &config.transport {
            ToolServerTransport::Stdio { command, args, .. } => {
                assert_eq!(command, "npx");
                assert_eq!(args.len(), 3);
            }
            other => panic!("expected Stdio transport, got {other:?}"),
        }
    }

    #[test]
    fn deserialize_flat_http() {
        let json = r#"{
            "name": "remote",
            "url": "http://localhost:3000/mcp",
            "headers": {"Authorization": "Bearer tok"}
        }"#;
        let config: ToolServerConfig = serde_json::from_str(json).expect("deserialize flat http");
        assert_eq!(config.name, "remote");
        match &config.transport {
            ToolServerTransport::Http { url, headers } => {
                assert_eq!(url, "http://localhost:3000/mcp");
                assert_eq!(
                    headers.get("Authorization").map(String::as_str),
                    Some("Bearer tok")
                );
            }
            other => panic!("expected Http transport, got {other:?}"),
        }
    }

    #[test]
    fn deserialize_nested_still_works() {
        let json = r#"{
            "name": "test",
            "transport": {
                "type": "stdio",
                "command": "echo",
                "args": ["hello"]
            }
        }"#;
        let config: ToolServerConfig =
            serde_json::from_str(json).expect("deserialize nested transport");
        assert_eq!(config.name, "test");
        match &config.transport {
            ToolServerTransport::Stdio { command, args, .. } => {
                assert_eq!(command, "echo");
                assert_eq!(args, &["hello"]);
            }
            other => panic!("expected Stdio transport, got {other:?}"),
        }
    }

    #[test]
    fn deserialize_nested_http() {
        let json = r#"{
            "name": "remote",
            "transport": {
                "type": "http",
                "url": "http://localhost:3000/mcp"
            }
        }"#;
        let config: ToolServerConfig =
            serde_json::from_str(json).expect("deserialize nested http transport");
        match &config.transport {
            ToolServerTransport::Http { url, headers } => {
                assert_eq!(url, "http://localhost:3000/mcp");
                assert!(headers.is_empty());
            }
            other => panic!("expected Http transport, got {other:?}"),
        }
    }

    #[test]
    fn deserialize_rejects_missing_transport() {
        let json = r#"{"name": "bad"}"#;
        let result = serde_json::from_str::<ToolServerConfig>(json);
        assert!(result.is_err());
    }

    #[test]
    fn deserialize_bridge_flag() {
        let json = r#"{
            "name": "my-tools",
            "command": "my-mcp-server",
            "bridge": true
        }"#;
        let config: ToolServerConfig = serde_json::from_str(json).expect("deserialize bridge flag");
        assert!(config.bridge);
    }

    #[test]
    fn deserialize_bridge_defaults_to_false() {
        let json = r#"{
            "name": "my-tools",
            "command": "my-mcp-server"
        }"#;
        let config: ToolServerConfig =
            serde_json::from_str(json).expect("deserialize without bridge flag");
        assert!(!config.bridge);
    }

    // --- ToolServerAccessGroups --------------------------------------------

    #[test]
    fn access_groups_expand_patterns() {
        let groups = ToolServerAccessGroups {
            groups: HashMap::from([
                ("dev_tools".into(), vec!["github".into(), "jira".into()]),
                ("comms".into(), vec!["slack".into(), "email".into()]),
            ]),
        };
        let mut expanded = groups.expand_patterns(&["dev_tools/*".into()]);
        expanded.sort();
        assert_eq!(expanded, vec!["github/*", "jira/*"]);
    }

    #[test]
    fn access_groups_bare_name_expands_to_wildcard() {
        let groups = ToolServerAccessGroups {
            groups: HashMap::from([("dev_tools".into(), vec!["github".into(), "jira".into()])]),
        };
        let mut expanded = groups.expand_patterns(&["dev_tools".into()]);
        expanded.sort();
        assert_eq!(expanded, vec!["github/*", "jira/*"]);
    }

    #[test]
    fn access_groups_non_group_passthrough() {
        let groups = ToolServerAccessGroups::default();
        let expanded = groups.expand_patterns(&["direct_server/tool".into()]);
        assert_eq!(expanded, vec!["direct_server/tool"]);
    }

    #[test]
    fn access_groups_serde_roundtrip() {
        let json = r#"{
            "dev_tools": ["github", "jira"],
            "comms": ["slack"]
        }"#;
        // `expect`, not `unwrap_or_default`: a parse failure must fail loudly
        // rather than surface as a confusing "group missing" assertion.
        let groups: ToolServerAccessGroups =
            serde_json::from_str(json).expect("deserialize access groups");
        assert!(groups.contains("dev_tools"));
        assert_eq!(
            groups.servers("dev_tools").map(|s: &[String]| s.len()),
            Some(2)
        );

        let reserialized = serde_json::to_string(&groups).expect("serialize access groups");
        let reparsed: ToolServerAccessGroups =
            serde_json::from_str(&reserialized).expect("re-deserialize access groups");
        assert_eq!(reparsed.as_map(), groups.as_map());
    }

    // --- AgentConfig --------------------------------------------------------

    #[test]
    fn agent_validate_rejects_empty_name() {
        assert!(test_agent_config("", "http://localhost", None)
            .validate()
            .is_err());
    }

    #[test]
    fn agent_validate_rejects_slash_in_name() {
        assert!(test_agent_config("my/agent", "http://localhost", None)
            .validate()
            .is_err());
    }

    #[test]
    fn agent_validate_rejects_empty_url() {
        assert!(test_agent_config("agent", "", None).validate().is_err());
    }

    #[test]
    fn agent_validate_accepts_valid() {
        assert!(test_agent_config("test-agent", "http://localhost:9000", None)
            .validate()
            .is_ok());
    }

    #[test]
    fn agent_discovery_url_default_path() {
        assert_eq!(
            test_agent_config("agent", "https://agent.example.com", None).discovery_url(),
            "https://agent.example.com/.well-known/agent-card.json"
        );
    }

    #[test]
    fn agent_discovery_url_custom_path() {
        let config = test_agent_config(
            "agent",
            "https://agent.example.com/",
            Some("/custom/card.json"),
        );
        assert_eq!(
            config.discovery_url(),
            "https://agent.example.com/custom/card.json"
        );
    }

    #[test]
    fn agent_serde_round_trip() {
        let cfg = AgentConfig {
            name: "my-agent".to_string(),
            url: "https://agent.example.com".to_string(),
            headers: HashMap::from([("Authorization".into(), "Bearer tok".into())]),
            card_path: Some("/custom/card.json".to_string()),
        };
        let json = serde_json::to_string(&cfg).expect("serialize");
        let parsed: AgentConfig = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(parsed.name, "my-agent");
        assert_eq!(parsed.url, "https://agent.example.com");
        assert_eq!(
            parsed.headers.get("Authorization").map(String::as_str),
            Some("Bearer tok")
        );
        assert_eq!(parsed.card_path.as_deref(), Some("/custom/card.json"));
    }
}