Skip to main content

rmcp_server_kit/
config.rs

1use std::path::PathBuf;
2
3use serde::Deserialize;
4
5/// Server listener configuration (reusable across MCP projects).
6#[derive(Debug, Deserialize)]
7#[non_exhaustive]
8pub struct ServerConfig {
9    /// Listen address (IP or hostname). Default: `127.0.0.1`.
10    #[serde(default = "default_listen_addr")]
11    pub listen_addr: String,
12    /// Listen TCP port. Default: `8443`.
13    #[serde(default = "default_listen_port")]
14    pub listen_port: u16,
15    /// Path to the TLS certificate (PEM). Required for TLS/mTLS.
16    pub tls_cert_path: Option<PathBuf>,
17    /// Path to the TLS private key (PEM). Required for TLS/mTLS.
18    pub tls_key_path: Option<PathBuf>,
19    /// Per-handshake deadline on the TLS accept path, parsed via
20    /// `humantime`. Idle or slow-loris connections are dropped once it
21    /// elapses. Startup-only (not hot-reloadable); ignored unless TLS is
22    /// configured. Default: `10s`.
23    #[serde(default = "default_tls_handshake_timeout")]
24    pub tls_handshake_timeout: String,
25    /// Cap on concurrently in-flight TLS handshakes. At saturation the
26    /// acceptor stops pulling new connections from the kernel backlog
27    /// (backpressure). Startup-only (not hot-reloadable); ignored unless
28    /// TLS is configured. Default: `256`.
29    #[serde(default = "default_max_concurrent_tls_handshakes")]
30    pub max_concurrent_tls_handshakes: usize,
31    /// Graceful shutdown timeout, parsed via `humantime`.
32    #[serde(default = "default_shutdown_timeout")]
33    pub shutdown_timeout: String,
34    /// Per-request timeout, parsed via `humantime`.
35    #[serde(default = "default_request_timeout")]
36    pub request_timeout: String,
37    /// Allowed Origin header values for DNS rebinding protection (MCP spec).
38    /// Requests with an Origin not in this list are rejected with 403.
39    /// Requests without an Origin header are always allowed (non-browser).
40    #[serde(default)]
41    pub allowed_origins: Vec<String>,
42    /// Allow the stdio transport subcommand. Disabled by default because
43    /// stdio mode bypasses auth, RBAC, TLS, and Origin validation.
44    #[serde(default)]
45    pub stdio_enabled: bool,
46    /// Maximum tool invocations per source IP per minute.
47    /// When set, enforced by the RBAC middleware on `tools/call` requests.
48    /// Protects against both abuse and runaway LLM loops.
49    pub tool_rate_limit: Option<u32>,
50    /// Idle timeout for MCP sessions. Sessions with no activity for this
51    /// duration are closed automatically. Default: 20 minutes.
52    #[serde(default = "default_session_idle_timeout")]
53    pub session_idle_timeout: String,
54    /// Interval for SSE keep-alive pings sent to the client. Prevents
55    /// proxies and load balancers from killing idle connections.
56    /// Default: 15 seconds.
57    #[serde(default = "default_sse_keep_alive")]
58    pub sse_keep_alive: String,
59    /// Externally reachable base URL (e.g. `https://mcp.example.com`).
60    /// When set, OAuth metadata endpoints advertise this URL instead of
61    /// the listen address. Required when the server binds to `0.0.0.0`
62    /// behind a reverse proxy or inside a container.
63    pub public_url: Option<String>,
64    /// Enable gzip/br response compression for MCP responses.
65    #[serde(default)]
66    pub compression_enabled: bool,
67    /// Minimum response size (bytes) before compression kicks in.
68    /// Only used when `compression_enabled` is true. Default: 1024.
69    #[serde(default = "default_compression_min_size")]
70    pub compression_min_size: u16,
71    /// Global cap on in-flight HTTP requests. When reached, excess
72    /// requests receive 503 Service Unavailable (via load shedding).
73    pub max_concurrent_requests: Option<usize>,
74    /// Enable `/admin/*` diagnostic endpoints.
75    #[serde(default)]
76    pub admin_enabled: bool,
77    /// RBAC role required to access admin endpoints.
78    #[serde(default = "default_admin_role")]
79    pub admin_role: String,
80    /// Authentication configuration (API keys, mTLS, OAuth).
81    pub auth: Option<crate::auth::AuthConfig>,
82}
83
84impl Default for ServerConfig {
85    fn default() -> Self {
86        Self {
87            listen_addr: default_listen_addr(),
88            listen_port: default_listen_port(),
89            tls_cert_path: None,
90            tls_key_path: None,
91            tls_handshake_timeout: default_tls_handshake_timeout(),
92            max_concurrent_tls_handshakes: default_max_concurrent_tls_handshakes(),
93            shutdown_timeout: default_shutdown_timeout(),
94            request_timeout: default_request_timeout(),
95            allowed_origins: Vec::new(),
96            stdio_enabled: false,
97            tool_rate_limit: None,
98            session_idle_timeout: default_session_idle_timeout(),
99            sse_keep_alive: default_sse_keep_alive(),
100            public_url: None,
101            compression_enabled: false,
102            compression_min_size: default_compression_min_size(),
103            max_concurrent_requests: None,
104            admin_enabled: false,
105            admin_role: default_admin_role(),
106            auth: None,
107        }
108    }
109}
110
111/// Observability settings (reusable across MCP projects).
112#[derive(Debug, Deserialize)]
113#[non_exhaustive]
114pub struct ObservabilityConfig {
115    /// `tracing` log level / env filter string (e.g. `info,rmcp_server_kit=debug`).
116    #[serde(default = "default_log_level")]
117    pub log_level: String,
118    /// Log output format: `json`, `pretty`, or `text` (default: `pretty`).
119    #[serde(default = "default_log_format")]
120    pub log_format: String,
121    /// Optional path to an append-only audit log file.
122    pub audit_log_path: Option<PathBuf>,
123    /// Emit inbound HTTP request headers at DEBUG level in transport logs.
124    /// Sensitive headers remain redacted when enabled.
125    #[serde(default)]
126    pub log_request_headers: bool,
127    /// Enable the Prometheus metrics endpoint.
128    #[serde(default)]
129    pub metrics_enabled: bool,
130    /// Bind address for the Prometheus metrics listener.
131    #[serde(default = "default_metrics_bind")]
132    pub metrics_bind: String,
133}
134
135impl Default for ObservabilityConfig {
136    fn default() -> Self {
137        Self {
138            log_level: default_log_level(),
139            log_format: default_log_format(),
140            audit_log_path: None,
141            log_request_headers: false,
142            metrics_enabled: false,
143            metrics_bind: default_metrics_bind(),
144        }
145    }
146}
147
148/// Validate the generic server config fields.
149///
150/// # Errors
151///
152/// Returns `McpxError::Config` on invalid values.
153pub fn validate_server_config(server: &ServerConfig) -> crate::error::Result<()> {
154    use crate::error::McpxError;
155
156    if server.listen_port == 0 {
157        return Err(McpxError::Config("listen_port must be nonzero".into()));
158    }
159
160    match (&server.tls_cert_path, &server.tls_key_path) {
161        (Some(_), None) | (None, Some(_)) => {
162            return Err(McpxError::Config(
163                "tls_cert_path and tls_key_path must both be set or both omitted".into(),
164            ));
165        }
166        _ => {}
167    }
168
169    if let Some(0) = server.max_concurrent_requests {
170        return Err(McpxError::Config(
171            "max_concurrent_requests must be nonzero when set".into(),
172        ));
173    }
174
175    if server.admin_enabled {
176        let auth_enabled = server.auth.as_ref().is_some_and(|a| a.enabled);
177        if !auth_enabled {
178            return Err(McpxError::Config(
179                "admin_enabled=true requires auth to be configured and enabled".into(),
180            ));
181        }
182        if server.admin_role.trim().is_empty() {
183            return Err(McpxError::Config("admin_role must not be empty".into()));
184        }
185    }
186
187    for (field, value) in [
188        ("server.shutdown_timeout", server.shutdown_timeout.as_str()),
189        ("server.request_timeout", server.request_timeout.as_str()),
190        (
191            "server.session_idle_timeout",
192            server.session_idle_timeout.as_str(),
193        ),
194        ("server.sse_keep_alive", server.sse_keep_alive.as_str()),
195        (
196            "server.tls_handshake_timeout",
197            server.tls_handshake_timeout.as_str(),
198        ),
199    ] {
200        if humantime::parse_duration(value).is_err() {
201            return Err(McpxError::Config(format!(
202                "invalid duration for {field}: {value:?}"
203            )));
204        }
205    }
206
207    // The handshake deadline must be a positive duration: a zero value
208    // would reap every TLS handshake before it could complete. Mirrors
209    // check #11 in `McpServerConfig::check`.
210    if humantime::parse_duration(&server.tls_handshake_timeout)
211        .is_ok_and(|d| d == std::time::Duration::ZERO)
212    {
213        return Err(McpxError::Config(
214            "server.tls_handshake_timeout must be greater than zero".into(),
215        ));
216    }
217
218    // A zero-permit handshake semaphore would never admit a handshake,
219    // deadlocking the TLS accept path. Mirrors check #12 in
220    // `McpServerConfig::check`.
221    if server.max_concurrent_tls_handshakes == 0 {
222        return Err(McpxError::Config(
223            "server.max_concurrent_tls_handshakes must be greater than zero".into(),
224        ));
225    }
226
227    Ok(())
228}
229
230/// Validate observability config fields.
231///
232/// # Errors
233///
234/// Returns `McpxError::Config` on invalid values.
235pub fn validate_observability_config(obs: &ObservabilityConfig) -> crate::error::Result<()> {
236    use tracing_subscriber::EnvFilter;
237
238    use crate::error::McpxError;
239
240    if EnvFilter::try_new(&obs.log_level).is_err() {
241        return Err(McpxError::Config(format!(
242            "invalid log_level: {:?} (expected a valid tracing filter directive, e.g. \"info\", \"debug,hyper=warn\")",
243            obs.log_level
244        )));
245    }
246    let valid_formats = ["json", "pretty", "text"];
247    if !valid_formats.contains(&obs.log_format.as_str()) {
248        return Err(McpxError::Config(format!(
249            "invalid log_format: {:?} (expected one of: {valid_formats:?})",
250            obs.log_format
251        )));
252    }
253
254    Ok(())
255}
256
257// - Default value functions -
258
/// Fallback for `ServerConfig::listen_addr`: `127.0.0.1`.
fn default_listen_addr() -> String {
    String::from("127.0.0.1")
}
/// Fallback for `ServerConfig::listen_port`: `8443`.
///
/// Declared `const` to match `default_max_concurrent_tls_handshakes`, the
/// other integer default in this module; callers are unaffected.
const fn default_listen_port() -> u16 {
    8443
}
/// Fallback for `ServerConfig::shutdown_timeout`: `30s`.
fn default_shutdown_timeout() -> String {
    String::from("30s")
}
/// Fallback for `ServerConfig::request_timeout`: `120s`.
fn default_request_timeout() -> String {
    String::from("120s")
}
/// Fallback for `ObservabilityConfig::log_level`: `info,rmcp=warn`.
fn default_log_level() -> String {
    String::from("info,rmcp=warn")
}
/// Fallback for `ObservabilityConfig::log_format`: `pretty`.
fn default_log_format() -> String {
    String::from("pretty")
}
/// Fallback for `ObservabilityConfig::metrics_bind`: `127.0.0.1:9090`.
fn default_metrics_bind() -> String {
    String::from("127.0.0.1:9090")
}
/// Fallback for `ServerConfig::session_idle_timeout`: `20m`.
fn default_session_idle_timeout() -> String {
    String::from("20m")
}
/// Fallback for `ServerConfig::tls_handshake_timeout`: `10s`.
fn default_tls_handshake_timeout() -> String {
    String::from("10s")
}
/// Fallback for `ServerConfig::max_concurrent_tls_handshakes`: `256`.
const fn default_max_concurrent_tls_handshakes() -> usize {
    256
}
/// Fallback for `ServerConfig::admin_role`: `admin`.
fn default_admin_role() -> String {
    String::from("admin")
}
/// Fallback for `ServerConfig::compression_min_size`: `1024` bytes.
///
/// Declared `const` to match `default_max_concurrent_tls_handshakes`, the
/// other integer default in this module; callers are unaffected.
const fn default_compression_min_size() -> u16 {
    1024
}
/// Fallback for `ServerConfig::sse_keep_alive`: `15s`.
fn default_sse_keep_alive() -> String {
    String::from("15s")
}
298
/// Unit tests for the config defaults, both validators, and serde defaults.
#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        clippy::panic,
        clippy::indexing_slicing,
        clippy::unwrap_in_result,
        clippy::print_stdout,
        clippy::print_stderr,
        reason = "test-only relaxations; production code uses ? and tracing"
    )]
    use super::*;

    /// Runs server validation on `cfg`, which must fail, and returns the
    /// rendered error message.
    fn server_rejection(cfg: &ServerConfig) -> String {
        validate_server_config(cfg).unwrap_err().to_string()
    }

    /// Runs observability validation on `cfg`, which must fail, and returns
    /// the rendered error message.
    fn observability_rejection(cfg: &ObservabilityConfig) -> String {
        validate_observability_config(cfg).unwrap_err().to_string()
    }

    // -- ServerConfig defaults --

    #[test]
    fn server_config_defaults() {
        let cfg = ServerConfig::default();
        assert_eq!(cfg.listen_addr, "127.0.0.1");
        assert_eq!(cfg.listen_port, 8443);
        assert!(cfg.tls_cert_path.is_none());
        assert!(cfg.tls_key_path.is_none());
        assert_eq!(cfg.tls_handshake_timeout, "10s");
        assert_eq!(cfg.max_concurrent_tls_handshakes, 256);
        assert_eq!(cfg.shutdown_timeout, "30s");
        assert_eq!(cfg.request_timeout, "120s");
        assert!(cfg.allowed_origins.is_empty());
        assert!(!cfg.stdio_enabled);
        assert!(cfg.tool_rate_limit.is_none());
        assert_eq!(cfg.session_idle_timeout, "20m");
        assert_eq!(cfg.sse_keep_alive, "15s");
        assert!(cfg.public_url.is_none());
        assert!(!cfg.compression_enabled);
        assert_eq!(cfg.compression_min_size, 1024);
        assert!(cfg.max_concurrent_requests.is_none());
        assert!(!cfg.admin_enabled);
        assert_eq!(cfg.admin_role, "admin");
        assert!(cfg.auth.is_none());
    }

    #[test]
    fn observability_config_defaults() {
        let cfg = ObservabilityConfig::default();
        assert_eq!(cfg.log_level, "info,rmcp=warn");
        assert_eq!(cfg.log_format, "pretty");
        assert!(cfg.audit_log_path.is_none());
        assert!(!cfg.log_request_headers);
        assert!(!cfg.metrics_enabled);
        assert_eq!(cfg.metrics_bind, "127.0.0.1:9090");
    }

    // -- validate_server_config --

    #[test]
    fn valid_server_config_passes() {
        let cfg = ServerConfig::default();
        assert!(validate_server_config(&cfg).is_ok());
    }

    #[test]
    fn zero_port_rejected() {
        let cfg = ServerConfig {
            listen_port: 0,
            ..ServerConfig::default()
        };
        assert!(server_rejection(&cfg).contains("listen_port"));
    }

    #[test]
    fn tls_cert_without_key_rejected() {
        let cfg = ServerConfig {
            tls_cert_path: Some("/tmp/cert.pem".into()),
            ..ServerConfig::default()
        };
        assert!(server_rejection(&cfg).contains("tls_cert_path"));
    }

    #[test]
    fn tls_key_without_cert_rejected() {
        let cfg = ServerConfig {
            tls_key_path: Some("/tmp/key.pem".into()),
            ..ServerConfig::default()
        };
        assert!(server_rejection(&cfg).contains("tls_cert_path"));
    }

    #[test]
    fn tls_both_set_passes() {
        let cfg = ServerConfig {
            tls_cert_path: Some("/tmp/cert.pem".into()),
            tls_key_path: Some("/tmp/key.pem".into()),
            ..ServerConfig::default()
        };
        assert!(validate_server_config(&cfg).is_ok());
    }

    #[test]
    fn zero_max_concurrent_requests_rejected() {
        let cfg = ServerConfig {
            max_concurrent_requests: Some(0),
            ..ServerConfig::default()
        };
        assert!(server_rejection(&cfg).contains("max_concurrent_requests"));
    }

    #[test]
    fn admin_enabled_without_auth_rejected() {
        // The default config carries no auth section at all.
        let cfg = ServerConfig {
            admin_enabled: true,
            ..ServerConfig::default()
        };
        assert!(server_rejection(&cfg).contains("admin_enabled"));
    }

    #[test]
    fn invalid_tls_handshake_timeout_rejected() {
        let cfg = ServerConfig {
            tls_handshake_timeout: "not-a-duration".into(),
            ..ServerConfig::default()
        };
        assert!(server_rejection(&cfg).contains("tls_handshake_timeout"));
    }

    #[test]
    fn zero_tls_handshake_timeout_rejected() {
        let cfg = ServerConfig {
            tls_handshake_timeout: "0s".into(),
            ..ServerConfig::default()
        };
        assert!(server_rejection(&cfg).contains("tls_handshake_timeout"));
    }

    #[test]
    fn zero_max_concurrent_tls_handshakes_rejected() {
        let cfg = ServerConfig {
            max_concurrent_tls_handshakes: 0,
            ..ServerConfig::default()
        };
        assert!(server_rejection(&cfg).contains("max_concurrent_tls_handshakes"));
    }

    #[test]
    fn invalid_shutdown_timeout_rejected() {
        let cfg = ServerConfig {
            shutdown_timeout: "not-a-duration".into(),
            ..ServerConfig::default()
        };
        assert!(server_rejection(&cfg).contains("shutdown_timeout"));
    }

    #[test]
    fn invalid_request_timeout_rejected() {
        let cfg = ServerConfig {
            request_timeout: "xyz".into(),
            ..ServerConfig::default()
        };
        assert!(server_rejection(&cfg).contains("request_timeout"));
    }

    #[test]
    fn invalid_session_idle_timeout_rejected() {
        let cfg = ServerConfig {
            session_idle_timeout: "soon".into(),
            ..ServerConfig::default()
        };
        assert!(server_rejection(&cfg).contains("session_idle_timeout"));
    }

    #[test]
    fn invalid_sse_keep_alive_rejected() {
        let cfg = ServerConfig {
            sse_keep_alive: "often".into(),
            ..ServerConfig::default()
        };
        assert!(server_rejection(&cfg).contains("sse_keep_alive"));
    }

    // -- validate_observability_config --

    #[test]
    fn valid_observability_config_passes() {
        let cfg = ObservabilityConfig::default();
        assert!(validate_observability_config(&cfg).is_ok());
    }

    #[test]
    fn invalid_log_level_rejected() {
        let cfg = ObservabilityConfig {
            log_level: "[invalid".into(),
            ..ObservabilityConfig::default()
        };
        assert!(observability_rejection(&cfg).contains("log_level"));
    }

    #[test]
    fn invalid_log_format_rejected() {
        let cfg = ObservabilityConfig {
            log_format: "yaml".into(),
            ..ObservabilityConfig::default()
        };
        assert!(observability_rejection(&cfg).contains("log_format"));
    }

    #[test]
    fn all_valid_log_levels_accepted() {
        for level in [
            "trace",
            "debug",
            "info",
            "warn",
            "error",
            "info,rmcp=warn",
            "debug,hyper=error",
        ] {
            let cfg = ObservabilityConfig {
                log_level: level.into(),
                ..ObservabilityConfig::default()
            };
            assert!(
                validate_observability_config(&cfg).is_ok(),
                "level {level} should be valid"
            );
        }
    }

    #[test]
    fn all_log_formats_accepted() {
        for fmt in ["json", "pretty", "text"] {
            let cfg = ObservabilityConfig {
                log_format: fmt.into(),
                ..ObservabilityConfig::default()
            };
            assert!(
                validate_observability_config(&cfg).is_ok(),
                "format {fmt} should be valid"
            );
        }
    }

    // -- serde deserialization --

    #[test]
    fn server_config_deserialize_defaults() {
        let cfg: ServerConfig = toml::from_str("").unwrap();
        assert_eq!(cfg.listen_port, 8443);
        assert_eq!(cfg.listen_addr, "127.0.0.1");
        assert_eq!(cfg.tls_handshake_timeout, "10s");
        assert_eq!(cfg.max_concurrent_tls_handshakes, 256);
    }

    #[test]
    fn observability_config_deserialize_defaults() {
        let cfg: ObservabilityConfig = toml::from_str("").unwrap();
        assert_eq!(cfg.log_level, "info,rmcp=warn");
        assert_eq!(cfg.log_format, "pretty");
        assert!(!cfg.log_request_headers);
        assert!(!cfg.metrics_enabled);
    }
}