//! MCP configuration types (`bamboo_agent/agent/mcp/config.rs`).

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;

4/// Root MCP configuration
5#[derive(Debug, Clone, Serialize, Deserialize)]
6pub struct McpConfig {
7    #[serde(default = "default_version")]
8    pub version: u32,
9    #[serde(default)]
10    pub servers: Vec<McpServerConfig>,
11}
12
/// Serde default for `McpConfig::version` when the field is absent.
fn default_version() -> u32 {
    const INITIAL_VERSION: u32 = 1;
    INITIAL_VERSION
}

17impl Default for McpConfig {
18    fn default() -> Self {
19        Self {
20            version: 1,
21            servers: Vec::new(),
22        }
23    }
24}
25
26/// Single MCP server configuration
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct McpServerConfig {
29    /// Unique identifier for this server
30    pub id: String,
31    /// Human-readable name
32    #[serde(skip_serializing_if = "Option::is_none")]
33    pub name: Option<String>,
34    /// Whether this server is enabled
35    #[serde(default = "default_true")]
36    pub enabled: bool,
37    /// Transport configuration
38    pub transport: TransportConfig,
39    /// Request timeout in milliseconds
40    #[serde(default = "default_request_timeout")]
41    pub request_timeout_ms: u64,
42    /// Health check interval in milliseconds
43    #[serde(default = "default_healthcheck_interval")]
44    pub healthcheck_interval_ms: u64,
45    /// Reconnection configuration
46    #[serde(default)]
47    pub reconnect: ReconnectConfig,
48    /// List of allowed tools (empty = all allowed)
49    #[serde(default)]
50    pub allowed_tools: Vec<String>,
51    /// List of denied tools
52    #[serde(default)]
53    pub denied_tools: Vec<String>,
54}
55
/// Serde default helper for boolean flags that are on unless configured otherwise.
fn default_true() -> bool {
    true
}

/// Default per-request timeout: 60 seconds, expressed in milliseconds.
fn default_request_timeout() -> u64 {
    60 * 1000
}

/// Default health-check interval: 30 seconds, expressed in milliseconds.
fn default_healthcheck_interval() -> u64 {
    30 * 1000
}

68/// Transport configuration variants
69#[derive(Debug, Clone, Serialize, Deserialize)]
70#[serde(tag = "type", rename_all = "lowercase")]
71pub enum TransportConfig {
72    Stdio(StdioConfig),
73    Sse(SseConfig),
74}
75
76/// Stdio transport configuration
77#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct StdioConfig {
79    /// Command to execute
80    pub command: String,
81    /// Arguments for the command
82    #[serde(default)]
83    pub args: Vec<String>,
84    /// Working directory
85    #[serde(skip_serializing_if = "Option::is_none")]
86    pub cwd: Option<String>,
87    /// Environment variables (plaintext, in-memory only).
88    ///
89    /// Persisted to disk as `env_encrypted` and hydrated on load.
90    #[serde(default, skip_serializing)]
91    pub env: HashMap<String, String>,
92    /// Encrypted environment variables values (nonce:ciphertext), keyed by env var name.
93    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
94    pub env_encrypted: HashMap<String, String>,
95    /// Startup timeout in milliseconds
96    #[serde(default = "default_startup_timeout")]
97    pub startup_timeout_ms: u64,
98}
99
/// Default process startup timeout: 20 seconds, expressed in milliseconds.
fn default_startup_timeout() -> u64 {
    20 * 1000
}

104/// SSE transport configuration
105#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct SseConfig {
107    /// SSE endpoint URL
108    pub url: String,
109    /// Additional headers
110    #[serde(default)]
111    pub headers: Vec<HeaderConfig>,
112    /// Connection timeout in milliseconds
113    #[serde(default = "default_connect_timeout")]
114    pub connect_timeout_ms: u64,
115}
116
/// Default SSE connection timeout: 10 seconds, expressed in milliseconds.
fn default_connect_timeout() -> u64 {
    10 * 1000
}

121/// HTTP header configuration
122#[derive(Debug, Clone, Serialize, Deserialize)]
123pub struct HeaderConfig {
124    pub name: String,
125    /// Header value (plaintext, in-memory only).
126    ///
127    /// Persisted to disk as `value_encrypted` and hydrated on load.
128    #[serde(default, skip_serializing)]
129    pub value: String,
130    /// Encrypted header value (nonce:ciphertext).
131    #[serde(default, skip_serializing_if = "Option::is_none")]
132    pub value_encrypted: Option<String>,
133}
134
135/// Reconnection configuration
136#[derive(Debug, Clone, Serialize, Deserialize)]
137pub struct ReconnectConfig {
138    #[serde(default = "default_true")]
139    pub enabled: bool,
140    /// Initial backoff in milliseconds
141    #[serde(default = "default_initial_backoff")]
142    pub initial_backoff_ms: u64,
143    /// Maximum backoff in milliseconds
144    #[serde(default = "default_max_backoff")]
145    pub max_backoff_ms: u64,
146    /// Maximum reconnection attempts (0 = unlimited)
147    #[serde(default)]
148    pub max_attempts: u32,
149}
150
151impl Default for ReconnectConfig {
152    fn default() -> Self {
153        Self {
154            enabled: true,
155            initial_backoff_ms: 1000,
156            max_backoff_ms: 30000,
157            max_attempts: 0,
158        }
159    }
160}
161
/// Default initial reconnect backoff: 1 second, expressed in milliseconds.
fn default_initial_backoff() -> u64 {
    1_000
}

/// Default upper bound on the reconnect backoff: 30 seconds, in milliseconds.
fn default_max_backoff() -> u64 {
    30 * 1000
}

#[cfg(test)]
mod tests {
    // Unit tests for MCP configuration defaults, deserialization, and the
    // guarantee that plaintext secrets are never serialized.
    use super::*;

    #[test]
    fn test_mcp_config_default() {
        let config = McpConfig::default();
        assert_eq!(config.version, 1);
        assert!(config.servers.is_empty());
    }

    #[test]
    fn test_mcp_config_deserialization() {
        let json = r#"{"version": 2, "servers": []}"#;
        let config: McpConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.version, 2);
        assert!(config.servers.is_empty());
    }

    #[test]
    fn test_mcp_config_default_version() {
        let json = r#"{"servers": []}"#;
        let config: McpConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.version, 1);
    }

    #[test]
    fn test_mcp_server_config_minimal() {
        let json = r#"{
            "id": "test-server",
            "transport": {
                "type": "stdio",
                "command": "node"
            }
        }"#;
        let config: McpServerConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.id, "test-server");
        assert!(config.enabled); // default
        assert_eq!(config.request_timeout_ms, 60000); // default
        assert_eq!(config.healthcheck_interval_ms, 30000); // default
        assert!(config.reconnect.enabled); // default
        assert!(config.allowed_tools.is_empty());
        assert!(config.denied_tools.is_empty());
    }

    #[test]
    fn test_mcp_server_config_full() {
        let json = r#"{
            "id": "test-server",
            "name": "Test Server",
            "enabled": false,
            "transport": {
                "type": "stdio",
                "command": "node",
                "args": ["server.js"],
                "cwd": "/app",
                "env": {"NODE_ENV": "production"},
                "startup_timeout_ms": 30000
            },
            "request_timeout_ms": 120000,
            "healthcheck_interval_ms": 60000,
            "reconnect": {
                "enabled": true,
                "initial_backoff_ms": 2000,
                "max_backoff_ms": 60000,
                "max_attempts": 5
            },
            "allowed_tools": ["tool1", "tool2"],
            "denied_tools": ["tool3"]
        }"#;
        let config: McpServerConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.id, "test-server");
        assert_eq!(config.name, Some("Test Server".to_string()));
        assert!(!config.enabled);
        assert_eq!(config.request_timeout_ms, 120000);
        assert_eq!(config.healthcheck_interval_ms, 60000);
        assert!(config.reconnect.enabled);
        assert_eq!(config.reconnect.initial_backoff_ms, 2000);
        assert_eq!(config.reconnect.max_backoff_ms, 60000);
        assert_eq!(config.reconnect.max_attempts, 5);
        assert_eq!(config.allowed_tools, vec!["tool1", "tool2"]);
        assert_eq!(config.denied_tools, vec!["tool3"]);
    }

    #[test]
    fn test_stdio_config() {
        let json = r#"{
            "type": "stdio",
            "command": "python",
            "args": ["-m", "server"],
            "cwd": "/home/user",
            "env": {"DEBUG": "1"},
            "startup_timeout_ms": 15000
        }"#;
        let config: TransportConfig = serde_json::from_str(json).unwrap();
        match config {
            TransportConfig::Stdio(stdio) => {
                assert_eq!(stdio.command, "python");
                assert_eq!(stdio.args, vec!["-m", "server"]);
                assert_eq!(stdio.cwd, Some("/home/user".to_string()));
                assert_eq!(stdio.env.get("DEBUG"), Some(&"1".to_string()));
                assert_eq!(stdio.startup_timeout_ms, 15000);
            }
            _ => panic!("Expected Stdio transport"),
        }
    }

    #[test]
    fn test_stdio_config_minimal() {
        let json = r#"{
            "type": "stdio",
            "command": "node"
        }"#;
        let config: TransportConfig = serde_json::from_str(json).unwrap();
        match config {
            TransportConfig::Stdio(stdio) => {
                assert_eq!(stdio.command, "node");
                assert!(stdio.args.is_empty());
                assert!(stdio.cwd.is_none());
                assert!(stdio.env.is_empty());
                assert_eq!(stdio.startup_timeout_ms, 20000); // default
            }
            _ => panic!("Expected Stdio transport"),
        }
    }

    #[test]
    fn test_sse_config() {
        let json = r#"{
            "type": "sse",
            "url": "http://localhost:8080/sse",
            "headers": [
                {"name": "Authorization", "value": "Bearer token123"}
            ],
            "connect_timeout_ms": 5000
        }"#;
        let config: TransportConfig = serde_json::from_str(json).unwrap();
        match config {
            TransportConfig::Sse(sse) => {
                assert_eq!(sse.url, "http://localhost:8080/sse");
                assert_eq!(sse.headers.len(), 1);
                assert_eq!(sse.headers[0].name, "Authorization");
                assert_eq!(sse.headers[0].value, "Bearer token123");
                assert_eq!(sse.connect_timeout_ms, 5000);
            }
            _ => panic!("Expected SSE transport"),
        }
    }

    #[test]
    fn test_sse_config_minimal() {
        let json = r#"{
            "type": "sse",
            "url": "http://localhost:8080/sse"
        }"#;
        let config: TransportConfig = serde_json::from_str(json).unwrap();
        match config {
            TransportConfig::Sse(sse) => {
                assert_eq!(sse.url, "http://localhost:8080/sse");
                assert!(sse.headers.is_empty());
                assert_eq!(sse.connect_timeout_ms, 10000); // default
            }
            _ => panic!("Expected SSE transport"),
        }
    }

    #[test]
    fn test_reconnect_config_default() {
        let config = ReconnectConfig::default();
        assert!(config.enabled);
        assert_eq!(config.initial_backoff_ms, 1000);
        assert_eq!(config.max_backoff_ms, 30000);
        assert_eq!(config.max_attempts, 0); // unlimited
    }

    #[test]
    fn test_reconnect_config_unlimited_attempts() {
        let json = r#"{
            "enabled": true,
            "initial_backoff_ms": 500,
            "max_backoff_ms": 10000
        }"#;
        let config: ReconnectConfig = serde_json::from_str(json).unwrap();
        assert!(config.enabled);
        assert_eq!(config.initial_backoff_ms, 500);
        assert_eq!(config.max_backoff_ms, 10000);
        assert_eq!(config.max_attempts, 0);
    }

    #[test]
    fn test_reconnect_config_disabled() {
        let json = r#"{"enabled": false}"#;
        let config: ReconnectConfig = serde_json::from_str(json).unwrap();
        assert!(!config.enabled);
    }

    /// `Default` must agree with what serde fills in for an empty document,
    /// otherwise programmatic and file-loaded configs behave differently.
    #[test]
    fn test_defaults_match_serde_defaults() {
        let from_empty: ReconnectConfig = serde_json::from_str("{}").unwrap();
        let default = ReconnectConfig::default();
        assert_eq!(from_empty.enabled, default.enabled);
        assert_eq!(from_empty.initial_backoff_ms, default.initial_backoff_ms);
        assert_eq!(from_empty.max_backoff_ms, default.max_backoff_ms);
        assert_eq!(from_empty.max_attempts, default.max_attempts);

        let config: McpConfig = serde_json::from_str("{}").unwrap();
        assert_eq!(config.version, McpConfig::default().version);
        assert!(config.servers.is_empty());
    }

    #[test]
    fn test_header_config() {
        let header = HeaderConfig {
            name: "Content-Type".to_string(),
            value: "application/json".to_string(),
            value_encrypted: None,
        };
        assert_eq!(header.name, "Content-Type");
        assert_eq!(header.value, "application/json");
    }

    /// The plaintext header value must never be written out; only the
    /// encrypted form is persisted.
    #[test]
    fn test_header_plaintext_value_not_serialized() {
        let header = HeaderConfig {
            name: "Authorization".to_string(),
            value: "Bearer secret-token".to_string(),
            value_encrypted: Some("nonce:ciphertext".to_string()),
        };
        let json = serde_json::to_string(&header).unwrap();
        assert!(!json.contains("secret-token"));
        assert!(json.contains(r#""value_encrypted":"nonce:ciphertext""#));
    }

    /// Plaintext env values read from a document must not survive a
    /// serialize round-trip.
    #[test]
    fn test_stdio_plaintext_env_not_serialized() {
        let json = r#"{"type": "stdio", "command": "node", "env": {"API_KEY": "s3cr3t"}}"#;
        let config: TransportConfig = serde_json::from_str(json).unwrap();
        let out = serde_json::to_string(&config).unwrap();
        assert!(!out.contains("s3cr3t"));
        assert!(!out.contains(r#""env""#));
    }

    #[test]
    fn test_full_mcp_config() {
        let json = r#"{
            "version": 1,
            "servers": [
                {
                    "id": "fs-server",
                    "transport": {
                        "type": "stdio",
                        "command": "mcp-server-filesystem"
                    }
                },
                {
                    "id": "web-server",
                    "transport": {
                        "type": "sse",
                        "url": "http://localhost:3000/sse"
                    }
                }
            ]
        }"#;
        let config: McpConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.servers.len(), 2);
        assert_eq!(config.servers[0].id, "fs-server");
        assert_eq!(config.servers[1].id, "web-server");
    }

    #[test]
    fn test_server_config_disabled() {
        let json = r#"{
            "id": "disabled-server",
            "enabled": false,
            "transport": {"type": "stdio", "command": "node"}
        }"#;
        let config: McpServerConfig = serde_json::from_str(json).unwrap();
        assert!(!config.enabled);
    }
}