Skip to main content

rmcp_server_kit/
config.rs

1use std::path::PathBuf;
2
3use serde::Deserialize;
4
5/// Server listener configuration (reusable across MCP projects).
6#[derive(Debug, Deserialize)]
7#[non_exhaustive]
8pub struct ServerConfig {
9    /// Listen address (IP or hostname). Default: `127.0.0.1`.
10    #[serde(default = "default_listen_addr")]
11    pub listen_addr: String,
12    /// Listen TCP port. Default: `8443`.
13    #[serde(default = "default_listen_port")]
14    pub listen_port: u16,
15    /// Path to the TLS certificate (PEM). Required for TLS/mTLS.
16    pub tls_cert_path: Option<PathBuf>,
17    /// Path to the TLS private key (PEM). Required for TLS/mTLS.
18    pub tls_key_path: Option<PathBuf>,
19    /// Per-handshake deadline on the TLS accept path, parsed via
20    /// `humantime`. Idle or slow-loris connections are dropped once it
21    /// elapses. Startup-only (not hot-reloadable); ignored unless TLS is
22    /// configured. Default: `10s`.
23    #[serde(default = "default_tls_handshake_timeout")]
24    pub tls_handshake_timeout: String,
25    /// Cap on concurrently in-flight TLS handshakes. At saturation the
26    /// acceptor stops pulling new connections from the kernel backlog
27    /// (backpressure). Startup-only (not hot-reloadable); ignored unless
28    /// TLS is configured. Default: `256`.
29    #[serde(default = "default_max_concurrent_tls_handshakes")]
30    pub max_concurrent_tls_handshakes: usize,
31    /// Graceful shutdown timeout, parsed via `humantime`.
32    #[serde(default = "default_shutdown_timeout")]
33    pub shutdown_timeout: String,
34    /// Per-request timeout, parsed via `humantime`.
35    #[serde(default = "default_request_timeout")]
36    pub request_timeout: String,
37    /// Allowed Origin header values for DNS rebinding protection (MCP spec).
38    /// Requests with an Origin not in this list are rejected with 403.
39    /// Requests without an Origin header are always allowed (non-browser).
40    #[serde(default)]
41    pub allowed_origins: Vec<String>,
42    /// Allow the stdio transport subcommand. Disabled by default because
43    /// stdio mode bypasses auth, RBAC, TLS, and Origin validation.
44    #[serde(default)]
45    pub stdio_enabled: bool,
46    /// Maximum tool invocations per source IP per minute.
47    /// When set, enforced by the RBAC middleware on `tools/call` requests.
48    /// Protects against both abuse and runaway LLM loops.
49    pub tool_rate_limit: Option<u32>,
50    /// Burst capacity for the tool rate limiter (bucket size; sustained
51    /// rate stays `tool_rate_limit`). Requires `tool_rate_limit`; must
52    /// be greater than zero.
53    pub tool_rate_limit_burst: Option<u32>,
54    /// Maximum requests per source IP per minute on application routes
55    /// merged via `McpServerConfig::with_extra_router` (which bypass
56    /// auth/RBAC). Opt-in; must be greater than zero when set.
57    /// Keyed by the direct socket peer — no `X-Forwarded-For`
58    /// interpretation. Startup-only.
59    pub extra_route_rate_limit: Option<u32>,
60    /// Burst capacity for the extra-route rate limiter (bucket size;
61    /// sustained rate stays `extra_route_rate_limit`). Requires
62    /// `extra_route_rate_limit`; must be greater than zero.
63    pub extra_route_rate_limit_burst: Option<u32>,
64    /// Idle timeout for MCP sessions. Sessions with no activity for this
65    /// duration are closed automatically. Default: 20 minutes.
66    #[serde(default = "default_session_idle_timeout")]
67    pub session_idle_timeout: String,
68    /// Interval for SSE keep-alive pings sent to the client. Prevents
69    /// proxies and load balancers from killing idle connections.
70    /// Default: 15 seconds.
71    #[serde(default = "default_sse_keep_alive")]
72    pub sse_keep_alive: String,
73    /// Externally reachable base URL (e.g. `https://mcp.example.com`).
74    /// When set, OAuth metadata endpoints advertise this URL instead of
75    /// the listen address. Required when the server binds to `0.0.0.0`
76    /// behind a reverse proxy or inside a container.
77    pub public_url: Option<String>,
78    /// Enable gzip/br response compression for MCP responses.
79    #[serde(default)]
80    pub compression_enabled: bool,
81    /// Minimum response size (bytes) before compression kicks in.
82    /// Only used when `compression_enabled` is true. Default: 1024.
83    #[serde(default = "default_compression_min_size")]
84    pub compression_min_size: u16,
85    /// Global cap on in-flight HTTP requests. When reached, excess
86    /// requests receive 503 Service Unavailable (via load shedding).
87    pub max_concurrent_requests: Option<usize>,
88    /// Enable `/admin/*` diagnostic endpoints.
89    #[serde(default)]
90    pub admin_enabled: bool,
91    /// RBAC role required to access admin endpoints.
92    #[serde(default = "default_admin_role")]
93    pub admin_role: String,
94    /// Authentication configuration (API keys, mTLS, OAuth).
95    pub auth: Option<crate::auth::AuthConfig>,
96}
97
98impl Default for ServerConfig {
99    fn default() -> Self {
100        Self {
101            listen_addr: default_listen_addr(),
102            listen_port: default_listen_port(),
103            tls_cert_path: None,
104            tls_key_path: None,
105            tls_handshake_timeout: default_tls_handshake_timeout(),
106            max_concurrent_tls_handshakes: default_max_concurrent_tls_handshakes(),
107            shutdown_timeout: default_shutdown_timeout(),
108            request_timeout: default_request_timeout(),
109            allowed_origins: Vec::new(),
110            stdio_enabled: false,
111            tool_rate_limit: None,
112            tool_rate_limit_burst: None,
113            extra_route_rate_limit: None,
114            extra_route_rate_limit_burst: None,
115            session_idle_timeout: default_session_idle_timeout(),
116            sse_keep_alive: default_sse_keep_alive(),
117            public_url: None,
118            compression_enabled: false,
119            compression_min_size: default_compression_min_size(),
120            max_concurrent_requests: None,
121            admin_enabled: false,
122            admin_role: default_admin_role(),
123            auth: None,
124        }
125    }
126}
127
128/// Observability settings (reusable across MCP projects).
129#[derive(Debug, Deserialize)]
130#[non_exhaustive]
131pub struct ObservabilityConfig {
132    /// `tracing` log level / env filter string (e.g. `info,rmcp_server_kit=debug`).
133    #[serde(default = "default_log_level")]
134    pub log_level: String,
135    /// Log output format: `json`, `pretty`, or `text` (default: `pretty`).
136    #[serde(default = "default_log_format")]
137    pub log_format: String,
138    /// Optional path to an append-only audit log file.
139    pub audit_log_path: Option<PathBuf>,
140    /// Emit inbound HTTP request headers at DEBUG level in transport logs.
141    /// Sensitive headers remain redacted when enabled.
142    #[serde(default)]
143    pub log_request_headers: bool,
144    /// Enable the Prometheus metrics endpoint.
145    #[serde(default)]
146    pub metrics_enabled: bool,
147    /// Bind address for the Prometheus metrics listener.
148    #[serde(default = "default_metrics_bind")]
149    pub metrics_bind: String,
150}
151
152impl Default for ObservabilityConfig {
153    fn default() -> Self {
154        Self {
155            log_level: default_log_level(),
156            log_format: default_log_format(),
157            audit_log_path: None,
158            log_request_headers: false,
159            metrics_enabled: false,
160            metrics_bind: default_metrics_bind(),
161        }
162    }
163}
164
165/// Validate the generic server config fields.
166///
167/// # Errors
168///
169/// Returns `McpxError::Config` on invalid values.
170pub fn validate_server_config(server: &ServerConfig) -> crate::error::Result<()> {
171    use crate::error::McpxError;
172
173    if server.listen_port == 0 {
174        return Err(McpxError::Config("listen_port must be nonzero".into()));
175    }
176
177    match (&server.tls_cert_path, &server.tls_key_path) {
178        (Some(_), None) | (None, Some(_)) => {
179            return Err(McpxError::Config(
180                "tls_cert_path and tls_key_path must both be set or both omitted".into(),
181            ));
182        }
183        _ => {}
184    }
185
186    if let Some(0) = server.max_concurrent_requests {
187        return Err(McpxError::Config(
188            "max_concurrent_requests must be nonzero when set".into(),
189        ));
190    }
191
192    if let Some(0) = server.extra_route_rate_limit {
193        return Err(McpxError::Config(
194            "server.extra_route_rate_limit must be greater than zero".into(),
195        ));
196    }
197
198    if let Some(0) = server.tool_rate_limit_burst {
199        return Err(McpxError::Config(
200            "server.tool_rate_limit_burst must be greater than zero".into(),
201        ));
202    }
203    if let Some(0) = server.extra_route_rate_limit_burst {
204        return Err(McpxError::Config(
205            "server.extra_route_rate_limit_burst must be greater than zero".into(),
206        ));
207    }
208    if server.tool_rate_limit_burst.is_some() && server.tool_rate_limit.is_none() {
209        return Err(McpxError::Config(
210            "server.tool_rate_limit_burst requires server.tool_rate_limit".into(),
211        ));
212    }
213    if server.extra_route_rate_limit_burst.is_some() && server.extra_route_rate_limit.is_none() {
214        return Err(McpxError::Config(
215            "server.extra_route_rate_limit_burst requires server.extra_route_rate_limit".into(),
216        ));
217    }
218    if let Some(rl) = server.auth.as_ref().and_then(|a| a.rate_limit.as_ref()) {
219        if rl.burst == Some(0) {
220            return Err(McpxError::Config(
221                "auth.rate_limit.burst must be greater than zero".into(),
222            ));
223        }
224        if rl.pre_auth_burst == Some(0) {
225            return Err(McpxError::Config(
226                "auth.rate_limit.pre_auth_burst must be greater than zero".into(),
227            ));
228        }
229    }
230
231    if server.admin_enabled {
232        let auth_enabled = server.auth.as_ref().is_some_and(|a| a.enabled);
233        if !auth_enabled {
234            return Err(McpxError::Config(
235                "admin_enabled=true requires auth to be configured and enabled".into(),
236            ));
237        }
238        if server.admin_role.trim().is_empty() {
239            return Err(McpxError::Config("admin_role must not be empty".into()));
240        }
241    }
242
243    for (field, value) in [
244        ("server.shutdown_timeout", server.shutdown_timeout.as_str()),
245        ("server.request_timeout", server.request_timeout.as_str()),
246        (
247            "server.session_idle_timeout",
248            server.session_idle_timeout.as_str(),
249        ),
250        ("server.sse_keep_alive", server.sse_keep_alive.as_str()),
251        (
252            "server.tls_handshake_timeout",
253            server.tls_handshake_timeout.as_str(),
254        ),
255    ] {
256        if humantime::parse_duration(value).is_err() {
257            return Err(McpxError::Config(format!(
258                "invalid duration for {field}: {value:?}"
259            )));
260        }
261    }
262
263    // The handshake deadline must be a positive duration: a zero value
264    // would reap every TLS handshake before it could complete. Mirrors
265    // check #11 in `McpServerConfig::check`.
266    if humantime::parse_duration(&server.tls_handshake_timeout)
267        .is_ok_and(|d| d == std::time::Duration::ZERO)
268    {
269        return Err(McpxError::Config(
270            "server.tls_handshake_timeout must be greater than zero".into(),
271        ));
272    }
273
274    // A zero-permit handshake semaphore would never admit a handshake,
275    // deadlocking the TLS accept path. Mirrors check #12 in
276    // `McpServerConfig::check`.
277    if server.max_concurrent_tls_handshakes == 0 {
278        return Err(McpxError::Config(
279            "server.max_concurrent_tls_handshakes must be greater than zero".into(),
280        ));
281    }
282
283    Ok(())
284}
285
286/// Validate observability config fields.
287///
288/// # Errors
289///
290/// Returns `McpxError::Config` on invalid values.
291pub fn validate_observability_config(obs: &ObservabilityConfig) -> crate::error::Result<()> {
292    use tracing_subscriber::EnvFilter;
293
294    use crate::error::McpxError;
295
296    if EnvFilter::try_new(&obs.log_level).is_err() {
297        return Err(McpxError::Config(format!(
298            "invalid log_level: {:?} (expected a valid tracing filter directive, e.g. \"info\", \"debug,hyper=warn\")",
299            obs.log_level
300        )));
301    }
302    let valid_formats = ["json", "pretty", "text"];
303    if !valid_formats.contains(&obs.log_format.as_str()) {
304        return Err(McpxError::Config(format!(
305            "invalid log_format: {:?} (expected one of: {valid_formats:?})",
306            obs.log_format
307        )));
308    }
309
310    Ok(())
311}
312
313// - Default value functions -
314
/// Default `listen_addr`: loopback only.
fn default_listen_addr() -> String {
    "127.0.0.1".into()
}
/// Default `listen_port`.
const fn default_listen_port() -> u16 {
    8443
}
/// Default `shutdown_timeout` (humantime string).
fn default_shutdown_timeout() -> String {
    "30s".into()
}
/// Default `request_timeout` (humantime string).
fn default_request_timeout() -> String {
    "120s".into()
}
/// Default `log_level` env-filter directive.
fn default_log_level() -> String {
    "info,rmcp=warn".into()
}
/// Default `log_format`.
fn default_log_format() -> String {
    "pretty".into()
}
/// Default `metrics_bind` address: loopback only.
fn default_metrics_bind() -> String {
    "127.0.0.1:9090".into()
}
/// Default `session_idle_timeout` (humantime string).
fn default_session_idle_timeout() -> String {
    "20m".into()
}
/// Default `tls_handshake_timeout` (humantime string).
fn default_tls_handshake_timeout() -> String {
    "10s".into()
}
/// Default `max_concurrent_tls_handshakes`.
const fn default_max_concurrent_tls_handshakes() -> usize {
    256
}
/// Default `admin_role`.
fn default_admin_role() -> String {
    "admin".into()
}
/// Default `compression_min_size` in bytes.
const fn default_compression_min_size() -> u16 {
    1024
}
/// Default `sse_keep_alive` (humantime string).
fn default_sse_keep_alive() -> String {
    "15s".into()
}
354
#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        clippy::panic,
        clippy::indexing_slicing,
        clippy::unwrap_in_result,
        clippy::print_stdout,
        clippy::print_stderr,
        reason = "test-only relaxations; production code uses ? and tracing"
    )]
    use super::*;

    // -- ServerConfig defaults --

    #[test]
    fn server_config_defaults() {
        let cfg = ServerConfig::default();
        assert_eq!(cfg.listen_addr, "127.0.0.1");
        assert_eq!(cfg.listen_port, 8443);
        assert!(cfg.tls_cert_path.is_none());
        assert!(cfg.tls_key_path.is_none());
        assert_eq!(cfg.tls_handshake_timeout, "10s");
        assert_eq!(cfg.max_concurrent_tls_handshakes, 256);
        assert_eq!(cfg.shutdown_timeout, "30s");
        assert_eq!(cfg.request_timeout, "120s");
        assert!(cfg.allowed_origins.is_empty());
        assert!(!cfg.stdio_enabled);
        assert!(cfg.tool_rate_limit.is_none());
        assert!(cfg.tool_rate_limit_burst.is_none());
        assert!(cfg.extra_route_rate_limit.is_none());
        assert!(cfg.extra_route_rate_limit_burst.is_none());
        assert_eq!(cfg.session_idle_timeout, "20m");
        assert_eq!(cfg.sse_keep_alive, "15s");
        assert!(cfg.public_url.is_none());
        assert!(!cfg.compression_enabled);
        assert_eq!(cfg.compression_min_size, 1024);
        assert!(cfg.max_concurrent_requests.is_none());
        assert!(!cfg.admin_enabled);
        assert_eq!(cfg.admin_role, "admin");
        assert!(cfg.auth.is_none());
    }

    #[test]
    fn observability_config_defaults() {
        let cfg = ObservabilityConfig::default();
        assert_eq!(cfg.log_level, "info,rmcp=warn");
        assert_eq!(cfg.log_format, "pretty");
        assert!(cfg.audit_log_path.is_none());
        assert!(!cfg.log_request_headers);
        assert!(!cfg.metrics_enabled);
        assert_eq!(cfg.metrics_bind, "127.0.0.1:9090");
    }

    // -- validate_server_config --

    #[test]
    fn valid_server_config_passes() {
        let cfg = ServerConfig::default();
        assert!(validate_server_config(&cfg).is_ok());
    }

    #[test]
    fn zero_port_rejected() {
        let cfg = ServerConfig {
            listen_port: 0,
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("listen_port"));
    }

    #[test]
    fn zero_max_concurrent_requests_rejected() {
        let cfg = ServerConfig {
            max_concurrent_requests: Some(0),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("max_concurrent_requests"));
    }

    #[test]
    fn zero_extra_route_rate_limit_rejected() {
        let cfg = ServerConfig {
            extra_route_rate_limit: Some(0),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("extra_route_rate_limit"));
    }

    #[test]
    fn zero_burst_knobs_rejected() {
        let cfg = ServerConfig {
            tool_rate_limit: Some(10),
            tool_rate_limit_burst: Some(0),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("tool_rate_limit_burst"));

        let cfg = ServerConfig {
            extra_route_rate_limit: Some(10),
            extra_route_rate_limit_burst: Some(0),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("extra_route_rate_limit_burst"));
    }

    #[test]
    fn orphan_burst_knobs_rejected() {
        let cfg = ServerConfig {
            tool_rate_limit_burst: Some(5),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("requires server.tool_rate_limit"));

        let cfg = ServerConfig {
            extra_route_rate_limit_burst: Some(5),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(
            err.to_string()
                .contains("requires server.extra_route_rate_limit")
        );
    }

    #[test]
    fn rate_limits_with_bursts_pass() {
        let cfg = ServerConfig {
            tool_rate_limit: Some(60),
            tool_rate_limit_burst: Some(10),
            extra_route_rate_limit: Some(120),
            extra_route_rate_limit_burst: Some(20),
            ..ServerConfig::default()
        };
        assert!(validate_server_config(&cfg).is_ok());
    }

    #[test]
    fn zero_auth_bursts_rejected() {
        let auth = crate::auth::AuthConfig::with_keys(vec![])
            .with_rate_limit(crate::auth::RateLimitConfig::new(10).with_burst(0));
        let cfg = ServerConfig {
            auth: Some(auth),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("rate_limit.burst"));

        let auth = crate::auth::AuthConfig::with_keys(vec![])
            .with_rate_limit(crate::auth::RateLimitConfig::new(10).with_pre_auth_burst(0));
        let cfg = ServerConfig {
            auth: Some(auth),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("pre_auth_burst"));
    }

    #[test]
    fn admin_without_auth_rejected() {
        let cfg = ServerConfig {
            admin_enabled: true,
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("admin_enabled"));
    }

    #[test]
    fn tls_cert_without_key_rejected() {
        let cfg = ServerConfig {
            tls_cert_path: Some("/tmp/cert.pem".into()),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("tls_cert_path"));
    }

    #[test]
    fn tls_key_without_cert_rejected() {
        let cfg = ServerConfig {
            tls_key_path: Some("/tmp/key.pem".into()),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("tls_cert_path"));
    }

    #[test]
    fn tls_both_set_passes() {
        let cfg = ServerConfig {
            tls_cert_path: Some("/tmp/cert.pem".into()),
            tls_key_path: Some("/tmp/key.pem".into()),
            ..ServerConfig::default()
        };
        assert!(validate_server_config(&cfg).is_ok());
    }

    #[test]
    fn invalid_tls_handshake_timeout_rejected() {
        let cfg = ServerConfig {
            tls_handshake_timeout: "not-a-duration".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("tls_handshake_timeout"));
    }

    #[test]
    fn zero_tls_handshake_timeout_rejected() {
        let cfg = ServerConfig {
            tls_handshake_timeout: "0s".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("tls_handshake_timeout"));
    }

    #[test]
    fn zero_max_concurrent_tls_handshakes_rejected() {
        let cfg = ServerConfig {
            max_concurrent_tls_handshakes: 0,
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("max_concurrent_tls_handshakes"));
    }

    #[test]
    fn invalid_shutdown_timeout_rejected() {
        let cfg = ServerConfig {
            shutdown_timeout: "not-a-duration".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("shutdown_timeout"));
    }

    #[test]
    fn invalid_request_timeout_rejected() {
        let cfg = ServerConfig {
            request_timeout: "xyz".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("request_timeout"));
    }

    #[test]
    fn invalid_session_idle_timeout_rejected() {
        let cfg = ServerConfig {
            session_idle_timeout: "forever".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("session_idle_timeout"));
    }

    #[test]
    fn invalid_sse_keep_alive_rejected() {
        let cfg = ServerConfig {
            sse_keep_alive: "often".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("sse_keep_alive"));
    }

    // -- validate_observability_config --

    #[test]
    fn valid_observability_config_passes() {
        let cfg = ObservabilityConfig::default();
        assert!(validate_observability_config(&cfg).is_ok());
    }

    #[test]
    fn invalid_log_level_rejected() {
        let cfg = ObservabilityConfig {
            log_level: "[invalid".into(),
            ..ObservabilityConfig::default()
        };
        let err = validate_observability_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("log_level"));
    }

    #[test]
    fn invalid_log_format_rejected() {
        let cfg = ObservabilityConfig {
            log_format: "yaml".into(),
            ..ObservabilityConfig::default()
        };
        let err = validate_observability_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("log_format"));
    }

    #[test]
    fn all_valid_log_levels_accepted() {
        for level in &[
            "trace",
            "debug",
            "info",
            "warn",
            "error",
            "info,rmcp=warn",
            "debug,hyper=error",
        ] {
            let cfg = ObservabilityConfig {
                log_level: (*level).into(),
                ..ObservabilityConfig::default()
            };
            assert!(
                validate_observability_config(&cfg).is_ok(),
                "level {level} should be valid"
            );
        }
    }

    #[test]
    fn all_log_formats_accepted() {
        for fmt in &["json", "pretty", "text"] {
            let cfg = ObservabilityConfig {
                log_format: (*fmt).into(),
                ..ObservabilityConfig::default()
            };
            assert!(
                validate_observability_config(&cfg).is_ok(),
                "format {fmt} should be valid"
            );
        }
    }

    // -- serde deserialization --

    #[test]
    fn server_config_deserialize_defaults() {
        let cfg: ServerConfig = toml::from_str("").unwrap();
        assert_eq!(cfg.listen_port, 8443);
        assert_eq!(cfg.listen_addr, "127.0.0.1");
        assert_eq!(cfg.tls_handshake_timeout, "10s");
        assert_eq!(cfg.max_concurrent_tls_handshakes, 256);
    }

    #[test]
    fn server_config_deserialize_overrides() {
        let cfg: ServerConfig = toml::from_str(
            r#"
            listen_addr = "0.0.0.0"
            listen_port = 9443
            allowed_origins = ["https://app.example.com"]
            tool_rate_limit = 30
            "#,
        )
        .unwrap();
        assert_eq!(cfg.listen_addr, "0.0.0.0");
        assert_eq!(cfg.listen_port, 9443);
        assert_eq!(
            cfg.allowed_origins,
            vec!["https://app.example.com".to_owned()]
        );
        assert_eq!(cfg.tool_rate_limit, Some(30));
        // Fields absent from the document still fall back to serde defaults.
        assert_eq!(cfg.request_timeout, "120s");
        assert!(validate_server_config(&cfg).is_ok());
    }

    #[test]
    fn observability_config_deserialize_defaults() {
        let cfg: ObservabilityConfig = toml::from_str("").unwrap();
        assert_eq!(cfg.log_level, "info,rmcp=warn");
        assert_eq!(cfg.log_format, "pretty");
        assert!(!cfg.log_request_headers);
        assert!(!cfg.metrics_enabled);
    }
}