Skip to main content

mcp_proxy/
builder.rs

//! Programmatic proxy builder for library users.
//!
//! Constructs a [`ProxyConfig`] via a fluent API, avoiding the need for
//! TOML files. [`ProxyBuilder::build`] passes the resulting config to
//! [`Proxy::from_config()`], exactly as a TOML-loaded config would be.
//!
//! # Example
//!
//! ```rust,no_run
//! use mcp_proxy::builder::ProxyBuilder;
//!
//! # async fn example() -> anyhow::Result<()> {
//! let proxy = ProxyBuilder::new("my-proxy")
//!     .version("1.0.0")
//!     .listen("0.0.0.0", 9090)
//!     .stdio_backend("files", "npx", &["-y", "@modelcontextprotocol/server-filesystem"])
//!     .http_backend("api", "http://api:8080")
//!     .build()
//!     .await?;
//!
//! // Embed in an existing axum app
//! let (router, _session_handle) = proxy.into_router();
//! # Ok(())
//! # }
//! ```
26
27use std::collections::HashMap;
28use std::time::Duration;
29
30use anyhow::Result;
31
32use crate::Proxy;
33use crate::config::*;
34
35/// Fluent builder for constructing an MCP proxy without TOML config files.
36///
37/// Call [`build()`](Self::build) to connect backends and produce a
38/// ready-to-serve [`Proxy`].
39pub struct ProxyBuilder {
40    config: ProxyConfig,
41}
42
43impl ProxyBuilder {
44    /// Create a new proxy builder with the given name.
45    ///
46    /// Defaults: version "0.1.0", separator "/", listen 127.0.0.1:8080.
47    pub fn new(name: impl Into<String>) -> Self {
48        Self {
49            config: ProxyConfig {
50                proxy: ProxySettings {
51                    name: name.into(),
52                    version: "0.1.0".to_string(),
53                    separator: "/".to_string(),
54                    listen: ListenConfig {
55                        host: "127.0.0.1".to_string(),
56                        port: 8080,
57                    },
58                    instructions: None,
59                    shutdown_timeout_seconds: 30,
60                    hot_reload: false,
61                    import_backends: None,
62                    rate_limit: None,
63                    tool_discovery: false,
64                    tool_exposure: crate::config::ToolExposure::default(),
65                },
66                backends: Vec::new(),
67                auth: None,
68                performance: PerformanceConfig::default(),
69                security: SecurityConfig::default(),
70                cache: CacheBackendConfig::default(),
71                observability: ObservabilityConfig::default(),
72                composite_tools: Vec::new(),
73            },
74        }
75    }
76
77    /// Set the proxy version (default: "0.1.0").
78    pub fn version(mut self, version: impl Into<String>) -> Self {
79        self.config.proxy.version = version.into();
80        self
81    }
82
83    /// Set the namespace separator (default: "/").
84    pub fn separator(mut self, separator: impl Into<String>) -> Self {
85        self.config.proxy.separator = separator.into();
86        self
87    }
88
89    /// Set the listen address and port (default: 127.0.0.1:8080).
90    pub fn listen(mut self, host: impl Into<String>, port: u16) -> Self {
91        self.config.proxy.listen = ListenConfig {
92            host: host.into(),
93            port,
94        };
95        self
96    }
97
98    /// Set instructions text sent to MCP clients.
99    pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
100        self.config.proxy.instructions = Some(instructions.into());
101        self
102    }
103
104    /// Set the graceful shutdown timeout (default: 30s).
105    pub fn shutdown_timeout(mut self, timeout: Duration) -> Self {
106        self.config.proxy.shutdown_timeout_seconds = timeout.as_secs();
107        self
108    }
109
110    /// Enable hot reload for watching config file changes.
111    pub fn hot_reload(mut self, enabled: bool) -> Self {
112        self.config.proxy.hot_reload = enabled;
113        self
114    }
115
116    /// Set a global rate limit across all backends.
117    pub fn global_rate_limit(mut self, requests: usize, period: Duration) -> Self {
118        self.config.proxy.rate_limit = Some(GlobalRateLimitConfig {
119            requests,
120            period_seconds: period.as_secs(),
121        });
122        self
123    }
124
125    /// Add a stdio backend (subprocess).
126    pub fn stdio_backend(
127        mut self,
128        name: impl Into<String>,
129        command: impl Into<String>,
130        args: &[&str],
131    ) -> Self {
132        self.config.backends.push(BackendConfig {
133            name: name.into(),
134            transport: TransportType::Stdio,
135            command: Some(command.into()),
136            args: args.iter().map(|s| s.to_string()).collect(),
137            url: None,
138            ..default_backend()
139        });
140        self
141    }
142
143    /// Add a stdio backend with environment variables.
144    pub fn stdio_backend_with_env(
145        mut self,
146        name: impl Into<String>,
147        command: impl Into<String>,
148        args: &[&str],
149        env: HashMap<String, String>,
150    ) -> Self {
151        self.config.backends.push(BackendConfig {
152            name: name.into(),
153            transport: TransportType::Stdio,
154            command: Some(command.into()),
155            args: args.iter().map(|s| s.to_string()).collect(),
156            url: None,
157            env,
158            ..default_backend()
159        });
160        self
161    }
162
163    /// Add an HTTP backend.
164    pub fn http_backend(mut self, name: impl Into<String>, url: impl Into<String>) -> Self {
165        self.config.backends.push(BackendConfig {
166            name: name.into(),
167            transport: TransportType::Http,
168            command: None,
169            url: Some(url.into()),
170            ..default_backend()
171        });
172        self
173    }
174
175    /// Add an HTTP backend with a bearer token.
176    pub fn http_backend_with_token(
177        mut self,
178        name: impl Into<String>,
179        url: impl Into<String>,
180        token: impl Into<String>,
181    ) -> Self {
182        self.config.backends.push(BackendConfig {
183            name: name.into(),
184            transport: TransportType::Http,
185            command: None,
186            url: Some(url.into()),
187            bearer_token: Some(token.into()),
188            ..default_backend()
189        });
190        self
191    }
192
193    /// Configure the last added backend with a per-backend modifier.
194    ///
195    /// # Panics
196    ///
197    /// Panics if no backends have been added.
198    pub fn configure_backend(mut self, f: impl FnOnce(&mut BackendConfig)) -> Self {
199        let backend = self
200            .config
201            .backends
202            .last_mut()
203            .expect("configure_backend called with no backends");
204        f(backend);
205        self
206    }
207
208    /// Enable bearer token authentication.
209    /// Enable bearer token authentication.
210    ///
211    /// All tokens in this list have unrestricted access to all tools.
212    /// For per-token tool scoping, use [`scoped_bearer_auth`](Self::scoped_bearer_auth).
213    pub fn bearer_auth(mut self, tokens: Vec<String>) -> Self {
214        self.config.auth = Some(AuthConfig::Bearer {
215            tokens,
216            scoped_tokens: vec![],
217        });
218        self
219    }
220
221    /// Enable bearer token authentication with per-token tool scoping.
222    ///
223    /// Each [`BearerTokenConfig`] can specify
224    /// `allow_tools` or `deny_tools` to restrict which tools that token can access.
225    pub fn scoped_bearer_auth(mut self, scoped_tokens: Vec<BearerTokenConfig>) -> Self {
226        self.config.auth = Some(AuthConfig::Bearer {
227            tokens: vec![],
228            scoped_tokens,
229        });
230        self
231    }
232
233    /// Enable request coalescing.
234    pub fn coalesce_requests(mut self, enabled: bool) -> Self {
235        self.config.performance.coalesce_requests = enabled;
236        self
237    }
238
239    /// Set the maximum argument size for validation.
240    pub fn max_argument_size(mut self, max_bytes: usize) -> Self {
241        self.config.security.max_argument_size = Some(max_bytes);
242        self
243    }
244
245    /// Enable audit logging.
246    pub fn audit_logging(mut self, enabled: bool) -> Self {
247        self.config.observability.audit = enabled;
248        self
249    }
250
251    /// Enable access logging.
252    pub fn access_logging(mut self, enabled: bool) -> Self {
253        self.config.observability.access_log.enabled = enabled;
254        self
255    }
256
257    /// Set the log level (default: "info").
258    pub fn log_level(mut self, level: impl Into<String>) -> Self {
259        self.config.observability.log_level = level.into();
260        self
261    }
262
263    /// Enable structured JSON logging.
264    pub fn json_logs(mut self, enabled: bool) -> Self {
265        self.config.observability.json_logs = enabled;
266        self
267    }
268
269    /// Enable Prometheus metrics.
270    pub fn metrics(mut self, enabled: bool) -> Self {
271        self.config.observability.metrics.enabled = enabled;
272        self
273    }
274
275    /// Set the timeout for the last added backend.
276    ///
277    /// # Panics
278    ///
279    /// Panics if no backends have been added.
280    ///
281    /// # Example
282    ///
283    /// ```rust
284    /// use mcp_proxy::builder::ProxyBuilder;
285    ///
286    /// let config = ProxyBuilder::new("my-proxy")
287    ///     .http_backend("api", "http://api:8080")
288    ///     .timeout(30)
289    ///     .into_config();
290    ///
291    /// assert_eq!(config.backends[0].timeout.as_ref().unwrap().seconds, 30);
292    /// ```
293    pub fn timeout(mut self, seconds: u64) -> Self {
294        let backend = self
295            .config
296            .backends
297            .last_mut()
298            .expect("timeout called with no backends");
299        backend.timeout = Some(TimeoutConfig { seconds });
300        self
301    }
302
303    /// Set the rate limit for the last added backend.
304    ///
305    /// # Panics
306    ///
307    /// Panics if no backends have been added.
308    ///
309    /// # Example
310    ///
311    /// ```rust
312    /// use mcp_proxy::builder::ProxyBuilder;
313    ///
314    /// let config = ProxyBuilder::new("my-proxy")
315    ///     .http_backend("api", "http://api:8080")
316    ///     .rate_limit(100, 1)
317    ///     .into_config();
318    ///
319    /// let rl = config.backends[0].rate_limit.as_ref().unwrap();
320    /// assert_eq!(rl.requests, 100);
321    /// assert_eq!(rl.period_seconds, 1);
322    /// ```
323    pub fn rate_limit(mut self, requests: usize, period_seconds: u64) -> Self {
324        let backend = self
325            .config
326            .backends
327            .last_mut()
328            .expect("rate_limit called with no backends");
329        backend.rate_limit = Some(RateLimitConfig {
330            requests,
331            period_seconds,
332        });
333        self
334    }
335
336    /// Set the circuit breaker for the last added backend.
337    ///
338    /// Uses sensible defaults for other fields: minimum 5 calls,
339    /// 30-second wait duration, and 3 half-open calls.
340    ///
341    /// # Panics
342    ///
343    /// Panics if no backends have been added.
344    ///
345    /// # Example
346    ///
347    /// ```rust
348    /// use mcp_proxy::builder::ProxyBuilder;
349    ///
350    /// let config = ProxyBuilder::new("my-proxy")
351    ///     .http_backend("api", "http://api:8080")
352    ///     .circuit_breaker(0.5)
353    ///     .into_config();
354    ///
355    /// let cb = config.backends[0].circuit_breaker.as_ref().unwrap();
356    /// assert!((cb.failure_rate_threshold - 0.5).abs() < f64::EPSILON);
357    /// ```
358    pub fn circuit_breaker(mut self, failure_rate: f64) -> Self {
359        let backend = self
360            .config
361            .backends
362            .last_mut()
363            .expect("circuit_breaker called with no backends");
364        backend.circuit_breaker = Some(CircuitBreakerConfig {
365            failure_rate_threshold: failure_rate,
366            minimum_calls: 5,
367            wait_duration_seconds: 30,
368            permitted_calls_in_half_open: 3,
369        });
370        self
371    }
372
373    /// Set the tool allowlist for the last added backend.
374    ///
375    /// Only the listed tools will be exposed through the proxy.
376    ///
377    /// # Panics
378    ///
379    /// Panics if no backends have been added.
380    ///
381    /// # Example
382    ///
383    /// ```rust
384    /// use mcp_proxy::builder::ProxyBuilder;
385    ///
386    /// let config = ProxyBuilder::new("my-proxy")
387    ///     .http_backend("api", "http://api:8080")
388    ///     .expose_tools(&["read_file", "list_dir"])
389    ///     .into_config();
390    ///
391    /// assert_eq!(config.backends[0].expose_tools, vec!["read_file", "list_dir"]);
392    /// ```
393    pub fn expose_tools(mut self, tools: &[&str]) -> Self {
394        let backend = self
395            .config
396            .backends
397            .last_mut()
398            .expect("expose_tools called with no backends");
399        backend.expose_tools = tools.iter().map(|s| s.to_string()).collect();
400        self
401    }
402
403    /// Set the tool denylist for the last added backend.
404    ///
405    /// The listed tools will be hidden from clients.
406    ///
407    /// # Panics
408    ///
409    /// Panics if no backends have been added.
410    ///
411    /// # Example
412    ///
413    /// ```rust
414    /// use mcp_proxy::builder::ProxyBuilder;
415    ///
416    /// let config = ProxyBuilder::new("my-proxy")
417    ///     .http_backend("api", "http://api:8080")
418    ///     .hide_tools(&["dangerous_op"])
419    ///     .into_config();
420    ///
421    /// assert_eq!(config.backends[0].hide_tools, vec!["dangerous_op"]);
422    /// ```
423    pub fn hide_tools(mut self, tools: &[&str]) -> Self {
424        let backend = self
425            .config
426            .backends
427            .last_mut()
428            .expect("hide_tools called with no backends");
429        backend.hide_tools = tools.iter().map(|s| s.to_string()).collect();
430        self
431    }
432
433    /// Set the retry policy for the last added backend.
434    ///
435    /// Uses sensible defaults: 100ms initial backoff, 5000ms max backoff,
436    /// no budget limit.
437    ///
438    /// # Panics
439    ///
440    /// Panics if no backends have been added.
441    ///
442    /// # Example
443    ///
444    /// ```rust
445    /// use mcp_proxy::builder::ProxyBuilder;
446    ///
447    /// let config = ProxyBuilder::new("my-proxy")
448    ///     .http_backend("api", "http://api:8080")
449    ///     .retry(3)
450    ///     .into_config();
451    ///
452    /// let retry = config.backends[0].retry.as_ref().unwrap();
453    /// assert_eq!(retry.max_retries, 3);
454    /// ```
455    pub fn retry(mut self, max_retries: u32) -> Self {
456        let backend = self
457            .config
458            .backends
459            .last_mut()
460            .expect("retry called with no backends");
461        backend.retry = Some(RetryConfig {
462            max_retries,
463            initial_backoff_ms: 100,
464            max_backoff_ms: 5000,
465            budget_percent: None,
466            min_retries_per_sec: 10,
467        });
468        self
469    }
470
471    /// Extract the built [`ProxyConfig`] without connecting to backends.
472    ///
473    /// Useful for inspection, serialization, or passing to
474    /// [`Proxy::from_config()`] manually.
475    pub fn into_config(self) -> ProxyConfig {
476        self.config
477    }
478
479    /// Build the proxy: validate config, connect to all backends, and
480    /// construct the middleware stack.
481    pub async fn build(self) -> Result<Proxy> {
482        Proxy::from_config(self.config).await
483    }
484}
485
486/// Default backend config with all optional fields set to `None`/empty.
487fn default_backend() -> BackendConfig {
488    BackendConfig {
489        name: String::new(),
490        transport: TransportType::Stdio,
491        command: None,
492        args: Vec::new(),
493        url: None,
494        env: HashMap::new(),
495        bearer_token: None,
496        forward_auth: false,
497        timeout: None,
498        circuit_breaker: None,
499        rate_limit: None,
500        concurrency: None,
501        retry: None,
502        outlier_detection: None,
503        hedging: None,
504        cache: None,
505        default_args: serde_json::Map::new(),
506        inject_args: Vec::new(),
507        param_overrides: Vec::new(),
508        expose_tools: Vec::new(),
509        hide_tools: Vec::new(),
510        expose_resources: Vec::new(),
511        hide_resources: Vec::new(),
512        expose_prompts: Vec::new(),
513        hide_prompts: Vec::new(),
514        hide_destructive: false,
515        read_only_only: false,
516        failover_for: None,
517        priority: 0,
518        canary_of: None,
519        weight: 100,
520        aliases: Vec::new(),
521        mirror_of: None,
522        mirror_percent: 100,
523    }
524}
525
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_builder_minimal() {
        let config = ProxyBuilder::new("test-proxy").into_config();
        assert_eq!(config.proxy.name, "test-proxy");
        assert_eq!(config.proxy.version, "0.1.0");
        assert_eq!(config.proxy.separator, "/");
        assert_eq!(config.proxy.listen.host, "127.0.0.1");
        assert_eq!(config.proxy.listen.port, 8080);
        assert!(config.backends.is_empty());
    }

    #[test]
    fn test_builder_with_backends() {
        let config = ProxyBuilder::new("test")
            .stdio_backend("files", "npx", &["-y", "@mcp/server-files"])
            .http_backend("api", "http://localhost:8080")
            .into_config();

        assert_eq!(config.backends.len(), 2);
        assert_eq!(config.backends[0].name, "files");
        assert!(matches!(config.backends[0].transport, TransportType::Stdio));
        assert_eq!(config.backends[0].command.as_deref(), Some("npx"));
        assert_eq!(config.backends[1].name, "api");
        assert!(matches!(config.backends[1].transport, TransportType::Http));
        assert_eq!(
            config.backends[1].url.as_deref(),
            Some("http://localhost:8080")
        );
    }

    #[test]
    fn test_builder_stdio_backend_defaults() {
        let config = ProxyBuilder::new("test")
            .stdio_backend("files", "npx", &["-y", "@mcp/server-files"])
            .into_config();

        let files = &config.backends[0];
        assert_eq!(files.args, vec!["-y", "@mcp/server-files"]);
        assert!(files.env.is_empty());
        assert!(files.url.is_none());
    }

    #[test]
    fn test_builder_configure_backend() {
        let config = ProxyBuilder::new("test")
            .http_backend("api", "http://localhost:8080")
            .configure_backend(|b| {
                b.timeout = Some(TimeoutConfig { seconds: 30 });
                b.rate_limit = Some(RateLimitConfig {
                    requests: 100,
                    period_seconds: 1,
                });
                b.hide_tools = vec!["dangerous_op".to_string()];
            })
            .into_config();

        assert!(config.backends[0].timeout.is_some());
        assert!(config.backends[0].rate_limit.is_some());
        assert_eq!(config.backends[0].hide_tools, vec!["dangerous_op"]);
    }

    #[test]
    fn test_builder_auth_and_observability() {
        let config = ProxyBuilder::new("test")
            .bearer_auth(vec!["token1".into(), "token2".into()])
            .audit_logging(true)
            .access_logging(true)
            .metrics(true)
            .json_logs(true)
            .log_level("debug")
            .into_config();

        // Inspect the tokens, not just the presence of an auth section.
        match &config.auth {
            Some(AuthConfig::Bearer {
                tokens,
                scoped_tokens,
            }) => {
                assert_eq!(*tokens, vec!["token1", "token2"]);
                assert!(scoped_tokens.is_empty());
            }
            _ => panic!("expected bearer auth"),
        }
        assert!(config.observability.audit);
        assert!(config.observability.access_log.enabled);
        assert!(config.observability.metrics.enabled);
        assert!(config.observability.json_logs);
        assert_eq!(config.observability.log_level, "debug");
    }

    #[test]
    fn test_builder_global_rate_limit() {
        let config = ProxyBuilder::new("test")
            .global_rate_limit(500, Duration::from_secs(1))
            .into_config();

        let rl = config.proxy.rate_limit.unwrap();
        assert_eq!(rl.requests, 500);
        assert_eq!(rl.period_seconds, 1);
    }

    #[test]
    fn test_builder_all_settings() {
        let config = ProxyBuilder::new("enterprise")
            .version("2.0.0")
            .separator("::")
            .listen("0.0.0.0", 9090)
            .instructions("Enterprise MCP gateway")
            .shutdown_timeout(Duration::from_secs(60))
            .coalesce_requests(true)
            .max_argument_size(1_048_576)
            .into_config();

        assert_eq!(config.proxy.name, "enterprise");
        assert_eq!(config.proxy.version, "2.0.0");
        assert_eq!(config.proxy.separator, "::");
        assert_eq!(config.proxy.listen.host, "0.0.0.0");
        assert_eq!(config.proxy.listen.port, 9090);
        assert_eq!(
            config.proxy.instructions.as_deref(),
            Some("Enterprise MCP gateway")
        );
        assert_eq!(config.proxy.shutdown_timeout_seconds, 60);
        assert!(config.performance.coalesce_requests);
        assert_eq!(config.security.max_argument_size, Some(1_048_576));
    }

    #[test]
    fn test_builder_shutdown_timeout_truncates_subsecond() {
        // The config stores whole seconds; the fractional part is dropped.
        let config = ProxyBuilder::new("test")
            .shutdown_timeout(Duration::from_millis(1500))
            .into_config();

        assert_eq!(config.proxy.shutdown_timeout_seconds, 1);
    }

    #[test]
    fn test_builder_http_backend_with_token() {
        let config = ProxyBuilder::new("test")
            .http_backend_with_token("api", "http://api:8080", "secret")
            .into_config();

        assert_eq!(config.backends[0].bearer_token.as_deref(), Some("secret"));
    }

    #[test]
    fn test_builder_ergonomic_backend_methods() {
        let config = ProxyBuilder::new("test")
            .http_backend("api", "http://api:8080")
            .timeout(30)
            .rate_limit(100, 1)
            .circuit_breaker(0.7)
            .expose_tools(&["read_file", "list_dir"])
            .retry(5)
            .stdio_backend("files", "npx", &["-y", "@mcp/server-files"])
            .hide_tools(&["dangerous_op"])
            .timeout(60)
            .into_config();

        // First backend: api
        let api = &config.backends[0];
        assert_eq!(api.timeout.as_ref().unwrap().seconds, 30);
        let rl = api.rate_limit.as_ref().unwrap();
        assert_eq!(rl.requests, 100);
        assert_eq!(rl.period_seconds, 1);
        let cb = api.circuit_breaker.as_ref().unwrap();
        assert!((cb.failure_rate_threshold - 0.7).abs() < f64::EPSILON);
        assert_eq!(cb.minimum_calls, 5);
        assert_eq!(cb.wait_duration_seconds, 30);
        assert_eq!(cb.permitted_calls_in_half_open, 3);
        assert_eq!(api.expose_tools, vec!["read_file", "list_dir"]);
        let retry = api.retry.as_ref().unwrap();
        assert_eq!(retry.max_retries, 5);
        assert_eq!(retry.initial_backoff_ms, 100);
        assert_eq!(retry.max_backoff_ms, 5000);
        assert!(retry.budget_percent.is_none());

        // Second backend: files
        let files = &config.backends[1];
        assert_eq!(files.hide_tools, vec!["dangerous_op"]);
        assert_eq!(files.timeout.as_ref().unwrap().seconds, 60);
        assert!(files.circuit_breaker.is_none());
        assert!(files.rate_limit.is_none());
    }

    #[test]
    #[should_panic(expected = "timeout called with no backends")]
    fn test_builder_timeout_without_backend_panics() {
        let _ = ProxyBuilder::new("test").timeout(30);
    }

    #[test]
    #[should_panic(expected = "configure_backend called with no backends")]
    fn test_builder_configure_backend_without_backend_panics() {
        let _ = ProxyBuilder::new("test").configure_backend(|_| {});
    }

    #[test]
    fn test_builder_stdio_backend_with_env() {
        let mut env = HashMap::new();
        env.insert("GITHUB_TOKEN".to_string(), "ghp_xxx".to_string());

        let config = ProxyBuilder::new("test")
            .stdio_backend_with_env("github", "npx", &["-y", "@mcp/github"], env)
            .into_config();

        assert_eq!(
            config.backends[0].env.get("GITHUB_TOKEN").unwrap(),
            "ghp_xxx"
        );
    }
}