Skip to main content

forge_config/
lib.rs

1#![warn(missing_docs)]
2
3//! # forge-config
4//!
5//! Configuration loading for the Forge Code Mode MCP Gateway.
6//!
7//! Supports TOML configuration files with environment variable expansion.
8//!
9//! ## Example
10//!
11//! ```toml
12//! [servers.narsil]
13//! command = "narsil-mcp"
14//! args = ["--repos", "."]
15//! transport = "stdio"
16//!
17//! [servers.github]
18//! url = "https://mcp.github.com/mcp"
19//! transport = "sse"
20//! headers = { Authorization = "Bearer ${GITHUB_TOKEN}" }
21//!
22//! [sandbox]
23//! timeout_secs = 5
24//! max_heap_mb = 64
25//! max_concurrent = 8
26//! max_tool_calls = 50
27//! ```
28
29use std::collections::HashMap;
30use std::path::Path;
31
32use serde::Deserialize;
33use thiserror::Error;
34
/// Errors returned while loading or validating a [`ForgeConfig`].
///
/// `Io` and `Parse` convert automatically (via `#[from]`), so `?` can be used on
/// `std::io::Error` and `toml::de::Error` in functions returning this type.
#[derive(Debug, Error)]
pub enum ConfigError {
    /// The config file could not be read (wraps the underlying I/O error).
    #[error("failed to read config file: {0}")]
    Io(#[from] std::io::Error),

    /// The text is not valid TOML, or it does not match the expected schema
    /// (e.g. a missing required field or a value of the wrong type).
    #[error("failed to parse config: {0}")]
    Parse(#[from] toml::de::Error),

    /// The document parsed but violates a semantic rule checked after parsing;
    /// the payload is a human-readable description of the problem.
    #[error("invalid config: {0}")]
    Invalid(String),
}
50
/// Top-level Forge configuration, deserialized from a TOML document.
///
/// Every section is optional: an empty document yields no servers, no groups
/// and all-`None` sandbox overrides.
#[derive(Debug, Clone, Deserialize)]
pub struct ForgeConfig {
    /// Downstream MCP server configurations, keyed by server name
    /// (the `[servers.<name>]` tables).
    #[serde(default)]
    pub servers: HashMap<String, ServerConfig>,

    /// Sandbox execution settings from the `[sandbox]` table.
    #[serde(default)]
    pub sandbox: SandboxOverrides,

    /// Server group definitions for cross-server data flow policies, keyed by
    /// group name (the `[groups.<name>]` tables).
    #[serde(default)]
    pub groups: HashMap<String, GroupConfig>,
}
66
/// Configuration for a server group (a `[groups.<name>]` table).
#[derive(Debug, Clone, Deserialize)]
pub struct GroupConfig {
    /// Names of the servers belonging to this group. Each name must be a key of
    /// `ForgeConfig::servers`, and a server may appear in at most one group;
    /// both rules are checked when the config is loaded.
    pub servers: Vec<String>,

    /// Isolation mode: "strict" (no cross-group data flow) or "open" (unrestricted).
    /// Defaults to "open" when omitted; any other value is rejected at load time.
    #[serde(default = "default_isolation")]
    pub isolation: String,
}
77
/// Serde default for `GroupConfig::isolation`: a group is unrestricted unless
/// the config says otherwise.
fn default_isolation() -> String {
    String::from("open")
}
81
/// Configuration for a single downstream MCP server (a `[servers.<name>]` table).
///
/// Only `transport` is required by the parser. Which connection fields must also
/// be present depends on the transport and is checked at load time: `stdio`
/// needs `command`, `sse` needs `url`.
#[derive(Debug, Clone, Deserialize)]
pub struct ServerConfig {
    /// Transport type: "stdio" or "sse". Any other value is rejected at load time.
    pub transport: String,

    /// Command to execute (stdio transport; required for it).
    #[serde(default)]
    pub command: Option<String>,

    /// Command arguments (stdio transport). Empty when omitted.
    #[serde(default)]
    pub args: Vec<String>,

    /// Server URL (sse transport; required for it).
    #[serde(default)]
    pub url: Option<String>,

    /// HTTP headers (sse transport). Empty when omitted.
    #[serde(default)]
    pub headers: HashMap<String, String>,

    /// Server description (optional, for manifest).
    #[serde(default)]
    pub description: Option<String>,

    /// Per-server timeout in seconds for individual tool calls; `None` when omitted.
    #[serde(default)]
    pub timeout_secs: Option<u64>,

    /// Enable circuit breaker for this server; `None` when omitted.
    #[serde(default)]
    pub circuit_breaker: Option<bool>,

    /// Number of consecutive failures before opening the circuit (default: 3).
    /// This crate stores `None` when omitted; the default is presumably applied
    /// by the consumer of this config.
    #[serde(default)]
    pub failure_threshold: Option<u32>,

    /// Seconds to wait before probing a tripped circuit (default: 30).
    /// This crate stores `None` when omitted; the default is presumably applied
    /// by the consumer of this config.
    #[serde(default)]
    pub recovery_timeout_secs: Option<u64>,
}
124
125/// Sandbox configuration overrides.
126#[derive(Debug, Clone, Default, Deserialize)]
127pub struct SandboxOverrides {
128    /// Execution timeout in seconds.
129    #[serde(default)]
130    pub timeout_secs: Option<u64>,
131
132    /// Maximum V8 heap size in megabytes.
133    #[serde(default)]
134    pub max_heap_mb: Option<usize>,
135
136    /// Maximum concurrent sandbox executions.
137    #[serde(default)]
138    pub max_concurrent: Option<usize>,
139
140    /// Maximum tool calls per execution.
141    #[serde(default)]
142    pub max_tool_calls: Option<usize>,
143
144    /// Execution mode: "in_process" (default) or "child_process".
145    #[serde(default)]
146    pub execution_mode: Option<String>,
147}
148
149impl ForgeConfig {
150    /// Parse a config from a TOML string.
151    pub fn from_toml(toml_str: &str) -> Result<Self, ConfigError> {
152        let config: ForgeConfig = toml::from_str(toml_str)?;
153        config.validate()?;
154        Ok(config)
155    }
156
157    /// Load config from a file path.
158    pub fn from_file(path: &Path) -> Result<Self, ConfigError> {
159        let content = std::fs::read_to_string(path)?;
160        Self::from_toml(&content)
161    }
162
163    /// Parse a config from a TOML string, expanding `${ENV_VAR}` references.
164    pub fn from_toml_with_env(toml_str: &str) -> Result<Self, ConfigError> {
165        let expanded = expand_env_vars(toml_str);
166        Self::from_toml(&expanded)
167    }
168
169    /// Load config from a file path, expanding environment variables.
170    pub fn from_file_with_env(path: &Path) -> Result<Self, ConfigError> {
171        let content = std::fs::read_to_string(path)?;
172        Self::from_toml_with_env(&content)
173    }
174
175    fn validate(&self) -> Result<(), ConfigError> {
176        for (name, server) in &self.servers {
177            match server.transport.as_str() {
178                "stdio" => {
179                    if server.command.is_none() {
180                        return Err(ConfigError::Invalid(format!(
181                            "server '{}': stdio transport requires 'command'",
182                            name
183                        )));
184                    }
185                }
186                "sse" => {
187                    if server.url.is_none() {
188                        return Err(ConfigError::Invalid(format!(
189                            "server '{}': sse transport requires 'url'",
190                            name
191                        )));
192                    }
193                }
194                other => {
195                    return Err(ConfigError::Invalid(format!(
196                        "server '{}': unsupported transport '{}', supported: stdio, sse",
197                        name, other
198                    )));
199                }
200            }
201        }
202
203        // Validate groups
204        let mut seen_servers: HashMap<&str, &str> = HashMap::new();
205        for (group_name, group_config) in &self.groups {
206            // Validate isolation mode
207            match group_config.isolation.as_str() {
208                "strict" | "open" => {}
209                other => {
210                    return Err(ConfigError::Invalid(format!(
211                        "group '{}': unsupported isolation '{}', supported: strict, open",
212                        group_name, other
213                    )));
214                }
215            }
216
217            for server_ref in &group_config.servers {
218                // Check server exists
219                if !self.servers.contains_key(server_ref) {
220                    return Err(ConfigError::Invalid(format!(
221                        "group '{}': references unknown server '{}'",
222                        group_name, server_ref
223                    )));
224                }
225                // Check no server in multiple groups
226                if let Some(existing_group) = seen_servers.get(server_ref.as_str()) {
227                    return Err(ConfigError::Invalid(format!(
228                        "server '{}' is in multiple groups: '{}' and '{}'",
229                        server_ref, existing_group, group_name
230                    )));
231                }
232                seen_servers.insert(server_ref, group_name);
233            }
234        }
235
236        Ok(())
237    }
238}
239
/// Expand `${ENV_VAR}` patterns in a string using environment variables.
///
/// Rules:
/// - `${NAME}` is replaced by the value of `NAME` when it is set and is valid Unicode.
/// - Otherwise the placeholder is kept verbatim. This covers unset variables,
///   non-Unicode values, and names that cannot be variable names (empty, or
///   containing `=` or NUL).
/// - The name runs up to the first `}`; nested placeholders are not supported.
/// - A `${` with no closing `}` is copied through unchanged.
/// - A `$` not followed by `{` is copied literally; there is no escape syntax.
fn expand_env_vars(input: &str) -> String {
    let mut result = String::with_capacity(input.len());
    let mut rest = input;

    while let Some(open) = rest.find("${") {
        result.push_str(&rest[..open]);
        let after_open = &rest[open + 2..];

        match after_open.find('}') {
            Some(close) => {
                let var_name = &after_open[..close];
                // `=` and NUL can never appear in a variable name, and how the OS
                // lookup treats them is platform-dependent, so never pass them on.
                let is_valid_name =
                    !var_name.is_empty() && !var_name.contains(|c: char| c == '=' || c == '\0');
                let value = if is_valid_name {
                    std::env::var(var_name).ok()
                } else {
                    None
                };
                match value {
                    Some(value) => result.push_str(&value),
                    // Leave the placeholder (both braces included) if not found.
                    None => result.push_str(&rest[open..open + 2 + close + 1]),
                }
                rest = &after_open[close + 1..];
            }
            None => {
                // Unterminated `${...`: copy the remainder verbatim rather than
                // inventing a closing brace or expanding a partial reference.
                result.push_str(&rest[open..]);
                rest = "";
            }
        }
    }

    result.push_str(rest);
    result
}
269
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn config_parses_minimal_toml() {
        let toml = r#"
            [servers.narsil]
            command = "narsil-mcp"
            transport = "stdio"
        "#;

        let config = ForgeConfig::from_toml(toml).unwrap();
        assert_eq!(config.servers.len(), 1);
        let narsil = &config.servers["narsil"];
        assert_eq!(narsil.transport, "stdio");
        assert_eq!(narsil.command.as_deref(), Some("narsil-mcp"));
    }

    #[test]
    fn config_parses_sse_server() {
        let toml = r#"
            [servers.github]
            url = "https://mcp.github.com/sse"
            transport = "sse"
        "#;

        let config = ForgeConfig::from_toml(toml).unwrap();
        let github = &config.servers["github"];
        assert_eq!(github.transport, "sse");
        assert_eq!(github.url.as_deref(), Some("https://mcp.github.com/sse"));
    }

    #[test]
    fn config_parses_sandbox_overrides() {
        let toml = r#"
            [sandbox]
            timeout_secs = 10
            max_heap_mb = 128
            max_concurrent = 4
            max_tool_calls = 100
        "#;

        let config = ForgeConfig::from_toml(toml).unwrap();
        assert_eq!(config.sandbox.timeout_secs, Some(10));
        assert_eq!(config.sandbox.max_heap_mb, Some(128));
        assert_eq!(config.sandbox.max_concurrent, Some(4));
        assert_eq!(config.sandbox.max_tool_calls, Some(100));
    }

    #[test]
    fn config_expands_environment_variables() {
        std::env::set_var("FORGE_TEST_TOKEN", "secret123");
        let toml = r#"
            [servers.github]
            url = "https://mcp.github.com/sse"
            transport = "sse"
            headers = { Authorization = "Bearer ${FORGE_TEST_TOKEN}" }
        "#;

        let result = ForgeConfig::from_toml_with_env(toml);
        // Remove the variable before asserting so a failure does not leak it.
        std::env::remove_var("FORGE_TEST_TOKEN");

        let config = result.unwrap();
        let github = &config.servers["github"];
        assert_eq!(
            github.headers.get("Authorization").unwrap(),
            "Bearer secret123"
        );
    }

    #[test]
    fn config_rejects_invalid_transport() {
        let toml = r#"
            [servers.test]
            command = "test"
            transport = "grpc"
        "#;

        let err = ForgeConfig::from_toml(toml).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("grpc"),
            "error should mention the transport: {msg}"
        );
        assert!(
            msg.contains("stdio"),
            "error should mention supported transports: {msg}"
        );
    }

    #[test]
    fn config_rejects_stdio_without_command() {
        let toml = r#"
            [servers.test]
            transport = "stdio"
        "#;

        let err = ForgeConfig::from_toml(toml).unwrap_err();
        assert!(err.to_string().contains("command"));
    }

    #[test]
    fn config_rejects_sse_without_url() {
        let toml = r#"
            [servers.test]
            transport = "sse"
        "#;

        let err = ForgeConfig::from_toml(toml).unwrap_err();
        assert!(err.to_string().contains("url"));
    }

    #[test]
    fn config_loads_from_file() {
        // Per-process directory so concurrent test runs cannot clobber or
        // delete each other's fixture.
        let dir = std::env::temp_dir().join(format!("forge-config-test-{}", std::process::id()));
        std::fs::create_dir_all(&dir).unwrap();
        let path = dir.join("forge.toml");
        std::fs::write(
            &path,
            r#"
            [servers.test]
            command = "test-server"
            transport = "stdio"
        "#,
        )
        .unwrap();

        let result = ForgeConfig::from_file(&path);
        // Clean up before asserting so a failure does not leave the directory behind.
        std::fs::remove_dir_all(&dir).ok();

        let config = result.unwrap();
        assert_eq!(config.servers.len(), 1);
        assert_eq!(
            config.servers["test"].command.as_deref(),
            Some("test-server")
        );
    }

    #[test]
    fn config_uses_defaults_when_absent() {
        let toml = r#"
            [servers.test]
            command = "test"
            transport = "stdio"
        "#;

        let config = ForgeConfig::from_toml(toml).unwrap();
        assert!(config.sandbox.timeout_secs.is_none());
        assert!(config.sandbox.max_heap_mb.is_none());
        assert!(config.sandbox.max_concurrent.is_none());
        assert!(config.sandbox.max_tool_calls.is_none());
    }

    #[test]
    fn config_parses_full_example() {
        let toml = r#"
            [servers.narsil]
            command = "narsil-mcp"
            args = ["--repos", ".", "--streaming"]
            transport = "stdio"
            description = "Code intelligence"

            [servers.github]
            url = "https://mcp.github.com/sse"
            transport = "sse"
            headers = { Authorization = "Bearer token123" }

            [sandbox]
            timeout_secs = 5
            max_heap_mb = 64
            max_concurrent = 8
            max_tool_calls = 50
        "#;

        let config = ForgeConfig::from_toml(toml).unwrap();
        assert_eq!(config.servers.len(), 2);

        let narsil = &config.servers["narsil"];
        assert_eq!(narsil.command.as_deref(), Some("narsil-mcp"));
        assert_eq!(narsil.args, vec!["--repos", ".", "--streaming"]);
        assert_eq!(narsil.description.as_deref(), Some("Code intelligence"));

        let github = &config.servers["github"];
        assert_eq!(github.url.as_deref(), Some("https://mcp.github.com/sse"));
        assert_eq!(
            github.headers.get("Authorization").unwrap(),
            "Bearer token123"
        );

        assert_eq!(config.sandbox.timeout_secs, Some(5));
    }

    #[test]
    fn config_empty_servers_is_valid() {
        let toml = "";
        let config = ForgeConfig::from_toml(toml).unwrap();
        assert!(config.servers.is_empty());
    }

    #[test]
    fn env_var_expansion_preserves_unresolved() {
        let result = expand_env_vars("prefix ${DEFINITELY_NOT_SET_12345} suffix");
        assert_eq!(result, "prefix ${DEFINITELY_NOT_SET_12345} suffix");
    }

    #[test]
    fn env_var_expansion_handles_no_vars() {
        let result = expand_env_vars("no variables here");
        assert_eq!(result, "no variables here");
    }

    #[test]
    fn env_var_expansion_keeps_lone_dollar() {
        // Only the braced `${NAME}` form is expanded.
        let result = expand_env_vars("cost: $5 and $HOME");
        assert_eq!(result, "cost: $5 and $HOME");
    }

    #[test]
    fn config_parses_execution_mode_child_process() {
        let toml = r#"
            [sandbox]
            execution_mode = "child_process"
        "#;

        let config = ForgeConfig::from_toml(toml).unwrap();
        assert_eq!(
            config.sandbox.execution_mode.as_deref(),
            Some("child_process")
        );
    }

    #[test]
    fn config_parses_groups() {
        let toml = r#"
            [servers.vault]
            command = "vault-mcp"
            transport = "stdio"

            [servers.slack]
            command = "slack-mcp"
            transport = "stdio"

            [groups.internal]
            servers = ["vault"]
            isolation = "strict"

            [groups.external]
            servers = ["slack"]
            isolation = "open"
        "#;

        let config = ForgeConfig::from_toml(toml).unwrap();
        assert_eq!(config.groups.len(), 2);
        assert_eq!(config.groups["internal"].isolation, "strict");
        assert_eq!(config.groups["external"].servers, vec!["slack"]);
    }

    #[test]
    fn config_group_isolation_defaults_to_open() {
        let toml = r#"
            [servers.test]
            command = "test"
            transport = "stdio"

            [groups.default]
            servers = ["test"]
        "#;

        let config = ForgeConfig::from_toml(toml).unwrap();
        assert_eq!(config.groups["default"].isolation, "open");
    }

    #[test]
    fn config_groups_default_to_empty() {
        let toml = r#"
            [servers.test]
            command = "test"
            transport = "stdio"
        "#;
        let config = ForgeConfig::from_toml(toml).unwrap();
        assert!(config.groups.is_empty());
    }

    #[test]
    fn config_rejects_group_with_unknown_server() {
        let toml = r#"
            [servers.real]
            command = "real"
            transport = "stdio"

            [groups.bad]
            servers = ["nonexistent"]
        "#;
        let err = ForgeConfig::from_toml(toml).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("nonexistent"), "should mention server: {msg}");
        assert!(msg.contains("unknown"), "should say unknown: {msg}");
    }

    #[test]
    fn config_rejects_server_in_multiple_groups() {
        let toml = r#"
            [servers.shared]
            command = "shared"
            transport = "stdio"

            [groups.a]
            servers = ["shared"]

            [groups.b]
            servers = ["shared"]
        "#;
        let err = ForgeConfig::from_toml(toml).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("shared"), "should mention server: {msg}");
        assert!(
            msg.contains("multiple groups"),
            "should say multiple groups: {msg}"
        );
    }

    #[test]
    fn config_rejects_server_listed_twice_in_one_group() {
        let toml = r#"
            [servers.dup]
            command = "dup"
            transport = "stdio"

            [groups.g]
            servers = ["dup", "dup"]
        "#;
        let err = ForgeConfig::from_toml(toml).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("dup"), "should mention server: {msg}");
    }

    #[test]
    fn config_rejects_invalid_isolation_mode() {
        let toml = r#"
            [servers.test]
            command = "test"
            transport = "stdio"

            [groups.bad]
            servers = ["test"]
            isolation = "paranoid"
        "#;
        let err = ForgeConfig::from_toml(toml).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("paranoid"), "should mention mode: {msg}");
    }

    #[test]
    fn config_parses_server_timeout() {
        let toml = r#"
            [servers.slow]
            command = "slow-mcp"
            transport = "stdio"
            timeout_secs = 30
        "#;

        let config = ForgeConfig::from_toml(toml).unwrap();
        assert_eq!(config.servers["slow"].timeout_secs, Some(30));
    }

    #[test]
    fn config_server_timeout_defaults_to_none() {
        let toml = r#"
            [servers.fast]
            command = "fast-mcp"
            transport = "stdio"
        "#;

        let config = ForgeConfig::from_toml(toml).unwrap();
        assert!(config.servers["fast"].timeout_secs.is_none());
    }

    #[test]
    fn config_parses_circuit_breaker() {
        let toml = r#"
            [servers.flaky]
            command = "flaky-mcp"
            transport = "stdio"
            circuit_breaker = true
            failure_threshold = 5
            recovery_timeout_secs = 60
        "#;

        let config = ForgeConfig::from_toml(toml).unwrap();
        let flaky = &config.servers["flaky"];
        assert_eq!(flaky.circuit_breaker, Some(true));
        assert_eq!(flaky.failure_threshold, Some(5));
        assert_eq!(flaky.recovery_timeout_secs, Some(60));
    }

    #[test]
    fn config_circuit_breaker_defaults_to_none() {
        let toml = r#"
            [servers.stable]
            command = "stable-mcp"
            transport = "stdio"
        "#;

        let config = ForgeConfig::from_toml(toml).unwrap();
        let stable = &config.servers["stable"];
        assert!(stable.circuit_breaker.is_none());
        assert!(stable.failure_threshold.is_none());
        assert!(stable.recovery_timeout_secs.is_none());
    }

    #[test]
    fn config_execution_mode_defaults_to_none() {
        let toml = r#"
            [sandbox]
            timeout_secs = 5
        "#;

        let config = ForgeConfig::from_toml(toml).unwrap();
        assert!(config.sandbox.execution_mode.is_none());
    }
}