Skip to main content

bitrouter_core/routers/
upstream.rs

1//! Upstream connection configuration types.
2//!
3//! Transport-neutral data types describing how to connect to upstream tool
4//! servers and agents. Used by both `bitrouter-config` (YAML parsing) and
5//! protocol crates (`bitrouter-mcp`, `bitrouter-a2a`) at runtime.
6
7use std::collections::HashMap;
8
9use serde::de::Deserializer;
10use serde::{Deserialize, Serialize};
11
12use super::admin::{ParamRestrictions, ToolFilter};
13
14// ── Tool server config ──────────────────────────────────────────────
15
16/// Configuration for a single upstream tool server.
17///
18/// Supports two YAML formats:
19///
20/// **Nested** (explicit transport):
21/// ```yaml
22/// - name: my-server
23///   transport:
24///     type: stdio
25///     command: npx
26///     args: ["-y", "server"]
27/// ```
28///
29/// **Flat** (inferred transport — `command` implies stdio, `url` implies http):
30/// ```yaml
31/// - name: my-server
32///   command: npx
33///   args: ["-y", "server"]
34/// ```
35#[derive(Debug, Clone, Serialize)]
36pub struct ToolServerConfig {
37    pub name: String,
38    pub transport: ToolServerTransport,
39    /// When `true`, this server is also exposed as a standalone Streamable HTTP
40    /// endpoint at `POST /mcp/{name}` and `GET /mcp/{name}/sse`, in addition to
41    /// participating in the aggregated `POST /mcp` registry.
42    #[serde(default)]
43    pub bridge: bool,
44    #[serde(default, skip_serializing_if = "Option::is_none")]
45    pub tool_filter: Option<ToolFilter>,
46    #[serde(default)]
47    pub param_restrictions: ParamRestrictions,
48}
49
50impl<'de> Deserialize<'de> for ToolServerConfig {
51    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
52    where
53        D: Deserializer<'de>,
54    {
55        /// Helper that accepts both nested and flat transport layouts.
56        #[derive(Deserialize)]
57        struct Raw {
58            name: String,
59
60            // ── Nested format ──
61            #[serde(default)]
62            transport: Option<ToolServerTransport>,
63
64            // ── Flat stdio fields ──
65            #[serde(default)]
66            command: Option<String>,
67            #[serde(default)]
68            args: Vec<String>,
69            #[serde(default)]
70            env: HashMap<String, String>,
71
72            // ── Flat http fields ──
73            #[serde(default)]
74            url: Option<String>,
75            #[serde(default)]
76            headers: HashMap<String, String>,
77
78            // ── Bridge flag ──
79            #[serde(default)]
80            bridge: bool,
81
82            // ── Common fields ──
83            #[serde(default)]
84            tool_filter: Option<ToolFilter>,
85            #[serde(default)]
86            param_restrictions: ParamRestrictions,
87        }
88
89        let raw = Raw::deserialize(deserializer)?;
90
91        let transport = if let Some(t) = raw.transport {
92            t
93        } else if let Some(command) = raw.command {
94            ToolServerTransport::Stdio {
95                command,
96                args: raw.args,
97                env: raw.env,
98            }
99        } else if let Some(url) = raw.url {
100            ToolServerTransport::Http {
101                url,
102                headers: raw.headers,
103            }
104        } else {
105            return Err(serde::de::Error::custom(
106                "mcp_servers entry must have `transport`, `command` (stdio), or `url` (http)",
107            ));
108        };
109
110        Ok(ToolServerConfig {
111            name: raw.name,
112            transport,
113            bridge: raw.bridge,
114            tool_filter: raw.tool_filter,
115            param_restrictions: raw.param_restrictions,
116        })
117    }
118}
119
120impl ToolServerConfig {
121    /// Validate this configuration, returning an error if it is invalid.
122    pub fn validate(&self) -> Result<(), String> {
123        if self.name.is_empty() {
124            return Err("server name must not be empty".into());
125        }
126        if self.name.contains('/') {
127            return Err(format!("server name '{}' must not contain '/'", self.name));
128        }
129        if self.name == "sse" {
130            return Err("server name 'sse' is reserved".into());
131        }
132        match &self.transport {
133            ToolServerTransport::Stdio { command, .. } => {
134                if command.is_empty() {
135                    return Err(format!(
136                        "server '{}': stdio command must not be empty",
137                        self.name
138                    ));
139                }
140            }
141            ToolServerTransport::Http { url, .. } => {
142                if url.is_empty() {
143                    return Err(format!(
144                        "server '{}': http url must not be empty",
145                        self.name
146                    ));
147                }
148            }
149        }
150        Ok(())
151    }
152}
153
154/// Transport type for connecting to an upstream tool server.
155#[derive(Debug, Clone, Serialize, Deserialize)]
156#[serde(tag = "type", rename_all = "lowercase")]
157pub enum ToolServerTransport {
158    Stdio {
159        command: String,
160        #[serde(default)]
161        args: Vec<String>,
162        #[serde(default)]
163        env: HashMap<String, String>,
164    },
165    Http {
166        url: String,
167        #[serde(default)]
168        headers: HashMap<String, String>,
169    },
170}
171
172/// Named groups of tool servers for access control convenience.
173///
174/// Groups resolve at keygen time — JWT claims stay concrete server patterns.
175#[derive(Debug, Clone, Default, Serialize, Deserialize)]
176pub struct ToolServerAccessGroups {
177    #[serde(flatten)]
178    groups: HashMap<String, Vec<String>>,
179}
180
181impl ToolServerAccessGroups {
182    /// Expand patterns that reference group names into concrete server patterns.
183    ///
184    /// For each input pattern, split on first `/`:
185    /// - If the prefix matches a group name, expand to one pattern per server
186    ///   in the group, preserving the suffix.
187    /// - If the prefix matches a group name and there is no suffix (bare group name),
188    ///   expand to `"server/*"` for each server in the group.
189    /// - Non-group patterns pass through unchanged.
190    pub fn expand_patterns(&self, patterns: &[String]) -> Vec<String> {
191        let mut result = Vec::new();
192        for pattern in patterns {
193            if let Some((prefix, suffix)) = pattern.split_once('/') {
194                if let Some(servers) = self.groups.get(prefix) {
195                    for server in servers {
196                        result.push(format!("{server}/{suffix}"));
197                    }
198                } else {
199                    result.push(pattern.clone());
200                }
201            } else if let Some(servers) = self.groups.get(pattern.as_str()) {
202                for server in servers {
203                    result.push(format!("{server}/*"));
204                }
205            } else {
206                result.push(pattern.clone());
207            }
208        }
209        result
210    }
211
212    /// Check if a group name exists.
213    pub fn contains(&self, name: &str) -> bool {
214        self.groups.contains_key(name)
215    }
216
217    /// Get the servers in a group.
218    pub fn servers(&self, name: &str) -> Option<&[String]> {
219        self.groups.get(name).map(|v| v.as_slice())
220    }
221
222    /// Return all groups as a map.
223    pub fn as_map(&self) -> &HashMap<String, Vec<String>> {
224        &self.groups
225    }
226}
227
228// ── Agent config ────────────────────────────────────────────────────
229
230/// Configuration for an upstream agent to proxy.
231#[derive(Debug, Clone, Serialize, Deserialize)]
232pub struct AgentConfig {
233    /// Display name for this upstream agent.
234    pub name: String,
235
236    /// Base URL of the upstream agent (used for discovery).
237    pub url: String,
238
239    /// Optional HTTP headers to send to upstream (e.g., auth tokens).
240    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
241    pub headers: HashMap<String, String>,
242
243    /// Optional card discovery path override.
244    #[serde(default, skip_serializing_if = "Option::is_none")]
245    pub card_path: Option<String>,
246}
247
248impl AgentConfig {
249    /// Validate the configuration.
250    pub fn validate(&self) -> Result<(), String> {
251        if self.name.is_empty() {
252            return Err("agent name cannot be empty".to_string());
253        }
254        if self.name.contains('/') {
255            return Err(format!("agent name '{}' cannot contain '/'", self.name));
256        }
257        if self.url.is_empty() {
258            return Err("agent URL cannot be empty".to_string());
259        }
260        Ok(())
261    }
262
263    /// Get the discovery URL for this agent.
264    pub fn discovery_url(&self) -> String {
265        let base = self.url.trim_end_matches('/');
266        let path = self
267            .card_path
268            .as_deref()
269            .unwrap_or("/.well-known/agent-card.json");
270        format!("{base}{path}")
271    }
272}
273
#[cfg(test)]
mod tests {
    use super::*;

    // ── ToolServerConfig tests ──────────────────────────────────────

    /// Stdio-backed config with no args, env, tool filter, or restrictions.
    fn test_stdio_config(name: &str, command: &str) -> ToolServerConfig {
        ToolServerConfig {
            name: name.into(),
            transport: ToolServerTransport::Stdio {
                command: command.into(),
                args: vec![],
                env: HashMap::new(),
            },
            bridge: false,
            tool_filter: None,
            param_restrictions: ParamRestrictions::default(),
        }
    }

    /// Http-backed config with no headers, tool filter, or restrictions.
    fn test_http_config(name: &str, url: &str) -> ToolServerConfig {
        ToolServerConfig {
            name: name.into(),
            transport: ToolServerTransport::Http {
                url: url.into(),
                headers: HashMap::new(),
            },
            bridge: false,
            tool_filter: None,
            param_restrictions: ParamRestrictions::default(),
        }
    }

    /// Serialize to JSON and parse it back through the custom `Deserialize` impl.
    fn json_roundtrip(config: &ToolServerConfig) -> ToolServerConfig {
        let json = serde_json::to_string(config).expect("serialize");
        serde_json::from_str(&json).expect("deserialize")
    }

    #[test]
    fn validate_rejects_empty_name() {
        assert!(test_stdio_config("", "echo").validate().is_err());
    }

    #[test]
    fn validate_rejects_slash_in_name() {
        assert!(test_stdio_config("a/b", "echo").validate().is_err());
    }

    #[test]
    fn validate_rejects_reserved_name_sse() {
        assert!(test_stdio_config("sse", "echo").validate().is_err());
    }

    #[test]
    fn validate_rejects_empty_command() {
        assert!(test_stdio_config("test", "").validate().is_err());
    }

    #[test]
    fn validate_rejects_empty_url() {
        assert!(test_http_config("test", "").validate().is_err());
    }

    #[test]
    fn validate_accepts_valid_stdio() {
        assert!(test_stdio_config("my-server", "npx").validate().is_ok());
    }

    #[test]
    fn validate_accepts_valid_http() {
        let config = test_http_config("remote", "http://localhost:3000/mcp");
        assert!(config.validate().is_ok());
    }

    #[test]
    fn serde_roundtrip_stdio() {
        let config = ToolServerConfig {
            name: "test".into(),
            transport: ToolServerTransport::Stdio {
                command: "npx".into(),
                args: vec!["-y".into(), "server".into()],
                env: HashMap::from([("KEY".into(), "VAL".into())]),
            },
            bridge: false,
            tool_filter: Some(ToolFilter {
                allow: Some(vec!["tool1".into()]),
                deny: None,
            }),
            param_restrictions: ParamRestrictions::default(),
        };
        let parsed = json_roundtrip(&config);
        assert_eq!(parsed.name, "test");
        assert!(!parsed.bridge);
        match &parsed.transport {
            ToolServerTransport::Stdio { command, args, env } => {
                assert_eq!(command, "npx");
                assert_eq!(args, &["-y", "server"]);
                assert_eq!(env.get("KEY").map(String::as_str), Some("VAL"));
            }
            other => panic!("expected Stdio transport, got {other:?}"),
        }
        let allow_len = parsed
            .tool_filter
            .as_ref()
            .and_then(|filter| filter.allow.as_ref())
            .map(Vec::len);
        assert_eq!(allow_len, Some(1));
    }

    #[test]
    fn serde_roundtrip_http() {
        let config = ToolServerConfig {
            name: "remote".into(),
            transport: ToolServerTransport::Http {
                url: "http://localhost:3000/mcp".into(),
                headers: HashMap::from([("Authorization".into(), "Bearer tok".into())]),
            },
            bridge: false,
            tool_filter: None,
            param_restrictions: ParamRestrictions::default(),
        };
        let parsed = json_roundtrip(&config);
        assert_eq!(parsed.name, "remote");
        assert!(parsed.tool_filter.is_none());
        match &parsed.transport {
            ToolServerTransport::Http { url, headers } => {
                assert_eq!(url, "http://localhost:3000/mcp");
                assert_eq!(
                    headers.get("Authorization").map(String::as_str),
                    Some("Bearer tok")
                );
            }
            other => panic!("expected Http transport, got {other:?}"),
        }
    }

    #[test]
    fn serde_roundtrip_preserves_bridge() {
        let mut config = test_stdio_config("bridged", "my-mcp-server");
        config.bridge = true;
        assert!(json_roundtrip(&config).bridge);
    }

    // ── Flat format deserialization tests ─────────────────────────────

    #[test]
    fn deserialize_flat_stdio() {
        let json = r#"{
            "name": "fs",
            "command": "npx",
            "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]
        }"#;
        let config: ToolServerConfig = serde_json::from_str(json).expect("deserialize flat stdio");
        assert_eq!(config.name, "fs");
        match &config.transport {
            ToolServerTransport::Stdio { command, args, .. } => {
                assert_eq!(command, "npx");
                assert_eq!(args.len(), 3);
            }
            other => panic!("expected Stdio transport, got {other:?}"),
        }
    }

    #[test]
    fn deserialize_flat_http() {
        let json = r#"{
            "name": "remote",
            "url": "http://localhost:3000/mcp",
            "headers": {"Authorization": "Bearer tok"}
        }"#;
        let config: ToolServerConfig = serde_json::from_str(json).expect("deserialize flat http");
        assert_eq!(config.name, "remote");
        match &config.transport {
            ToolServerTransport::Http { url, headers } => {
                assert_eq!(url, "http://localhost:3000/mcp");
                assert_eq!(
                    headers.get("Authorization").map(String::as_str),
                    Some("Bearer tok")
                );
            }
            other => panic!("expected Http transport, got {other:?}"),
        }
    }

    #[test]
    fn deserialize_nested_still_works() {
        let json = r#"{
            "name": "test",
            "transport": {
                "type": "stdio",
                "command": "echo",
                "args": ["hello"]
            }
        }"#;
        let config: ToolServerConfig =
            serde_json::from_str(json).expect("deserialize nested transport");
        assert_eq!(config.name, "test");
        match &config.transport {
            ToolServerTransport::Stdio { command, args, .. } => {
                assert_eq!(command, "echo");
                assert_eq!(args, &["hello"]);
            }
            other => panic!("expected Stdio transport, got {other:?}"),
        }
    }

    #[test]
    fn deserialize_rejects_missing_transport() {
        let json = r#"{"name": "bad"}"#;
        let result = serde_json::from_str::<ToolServerConfig>(json);
        assert!(result.is_err());
    }

    #[test]
    fn deserialize_bridge_flag() {
        let json = r#"{
            "name": "my-tools",
            "command": "my-mcp-server",
            "bridge": true
        }"#;
        let config: ToolServerConfig = serde_json::from_str(json).expect("deserialize bridge flag");
        assert!(config.bridge);
    }

    #[test]
    fn deserialize_bridge_defaults_to_false() {
        let json = r#"{
            "name": "my-tools",
            "command": "my-mcp-server"
        }"#;
        let config: ToolServerConfig =
            serde_json::from_str(json).expect("deserialize without bridge flag");
        assert!(!config.bridge);
    }

    // ── ToolServerAccessGroups tests ────────────────────────────────

    /// Two groups: `dev_tools` = [github, jira], `comms` = [slack, email].
    fn sample_groups() -> ToolServerAccessGroups {
        ToolServerAccessGroups {
            groups: HashMap::from([
                ("dev_tools".into(), vec!["github".into(), "jira".into()]),
                ("comms".into(), vec!["slack".into(), "email".into()]),
            ]),
        }
    }

    #[test]
    fn access_groups_expand_patterns() {
        let mut expanded = sample_groups().expand_patterns(&["dev_tools/*".into()]);
        expanded.sort();
        assert_eq!(expanded, vec!["github/*", "jira/*"]);
    }

    #[test]
    fn access_groups_bare_name_expands_to_wildcard() {
        let mut expanded = sample_groups().expand_patterns(&["dev_tools".into()]);
        expanded.sort();
        assert_eq!(expanded, vec!["github/*", "jira/*"]);
    }

    #[test]
    fn access_groups_expansion_preserves_order_and_suffix() {
        let expanded = sample_groups()
            .expand_patterns(&["dev_tools/read".into(), "direct_server/tool".into()]);
        assert_eq!(expanded, vec!["github/read", "jira/read", "direct_server/tool"]);
    }

    #[test]
    fn access_groups_non_group_passthrough() {
        let groups = ToolServerAccessGroups::default();
        let expanded = groups.expand_patterns(&["direct_server/tool".into()]);
        assert_eq!(expanded, vec!["direct_server/tool"]);
    }

    #[test]
    fn access_groups_bare_non_group_passthrough() {
        let expanded = sample_groups().expand_patterns(&["github".into()]);
        assert_eq!(expanded, vec!["github"]);
    }

    #[test]
    fn access_groups_empty_input() {
        assert!(sample_groups().expand_patterns(&[]).is_empty());
    }

    #[test]
    fn access_groups_serde_roundtrip() {
        let json = r#"{
            "dev_tools": ["github", "jira"],
            "comms": ["slack"]
        }"#;
        let groups: ToolServerAccessGroups =
            serde_json::from_str(json).expect("deserialize access groups");
        assert!(groups.contains("dev_tools"));
        assert_eq!(
            groups.servers("dev_tools").map(|s: &[String]| s.len()),
            Some(2)
        );

        let reserialized = serde_json::to_string(&groups).expect("serialize access groups");
        let reparsed: ToolServerAccessGroups =
            serde_json::from_str(&reserialized).expect("re-deserialize access groups");
        assert_eq!(reparsed.as_map(), groups.as_map());
    }

    // ── AgentConfig tests ───────────────────────────────────────────

    /// Agent config with the given name and url, no headers, default card path.
    fn test_agent(name: &str, url: &str) -> AgentConfig {
        AgentConfig {
            name: name.to_string(),
            url: url.to_string(),
            headers: HashMap::new(),
            card_path: None,
        }
    }

    #[test]
    fn agent_validate_rejects_empty_name() {
        assert!(test_agent("", "http://localhost").validate().is_err());
    }

    #[test]
    fn agent_validate_rejects_slash_in_name() {
        assert!(test_agent("my/agent", "http://localhost").validate().is_err());
    }

    #[test]
    fn agent_validate_rejects_empty_url() {
        assert!(test_agent("agent", "").validate().is_err());
    }

    #[test]
    fn agent_validate_accepts_valid() {
        assert!(test_agent("test-agent", "http://localhost:9000").validate().is_ok());
    }

    #[test]
    fn agent_discovery_url_default_path() {
        let config = test_agent("agent", "https://agent.example.com");
        assert_eq!(
            config.discovery_url(),
            "https://agent.example.com/.well-known/agent-card.json"
        );
    }

    #[test]
    fn agent_discovery_url_custom_path() {
        let config = AgentConfig {
            card_path: Some("/custom/card.json".to_string()),
            ..test_agent("agent", "https://agent.example.com/")
        };
        assert_eq!(
            config.discovery_url(),
            "https://agent.example.com/custom/card.json"
        );
    }

    #[test]
    fn agent_serde_round_trip() {
        let cfg = AgentConfig {
            name: "my-agent".to_string(),
            url: "https://agent.example.com".to_string(),
            headers: HashMap::from([("Authorization".into(), "Bearer tok".into())]),
            card_path: Some("/custom/card.json".to_string()),
        };
        let json = serde_json::to_string(&cfg).expect("serialize");
        let parsed: AgentConfig = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(parsed.name, "my-agent");
        assert_eq!(parsed.url, "https://agent.example.com");
        assert_eq!(
            parsed.headers.get("Authorization").map(String::as_str),
            Some("Bearer tok")
        );
        assert_eq!(parsed.card_path.as_deref(), Some("/custom/card.json"));
    }
}