Skip to main content

rmcp_server_kit/
config.rs

1use std::path::PathBuf;
2
3use serde::Deserialize;
4
5/// Server listener configuration (reusable across MCP projects).
6#[derive(Debug, Deserialize)]
7#[non_exhaustive]
8pub struct ServerConfig {
9    /// Listen address (IP or hostname). Default: `127.0.0.1`.
10    #[serde(default = "default_listen_addr")]
11    pub listen_addr: String,
12    /// Listen TCP port. Default: `8443`.
13    #[serde(default = "default_listen_port")]
14    pub listen_port: u16,
15    /// Path to the TLS certificate (PEM). Required for TLS/mTLS.
16    pub tls_cert_path: Option<PathBuf>,
17    /// Path to the TLS private key (PEM). Required for TLS/mTLS.
18    pub tls_key_path: Option<PathBuf>,
19    /// Per-handshake deadline on the TLS accept path, parsed via
20    /// `humantime`. Idle or slow-loris connections are dropped once it
21    /// elapses. Startup-only (not hot-reloadable); ignored unless TLS is
22    /// configured. Default: `10s`.
23    #[serde(default = "default_tls_handshake_timeout")]
24    pub tls_handshake_timeout: String,
25    /// Cap on concurrently in-flight TLS handshakes. At saturation the
26    /// acceptor stops pulling new connections from the kernel backlog
27    /// (backpressure). Startup-only (not hot-reloadable); ignored unless
28    /// TLS is configured. Default: `256`.
29    #[serde(default = "default_max_concurrent_tls_handshakes")]
30    pub max_concurrent_tls_handshakes: usize,
31    /// Graceful shutdown timeout, parsed via `humantime`.
32    #[serde(default = "default_shutdown_timeout")]
33    pub shutdown_timeout: String,
34    /// Per-request timeout, parsed via `humantime`.
35    #[serde(default = "default_request_timeout")]
36    pub request_timeout: String,
37    /// Allowed Origin header values for DNS rebinding protection (MCP spec).
38    /// Requests with an Origin not in this list are rejected with 403.
39    /// Requests without an Origin header are always allowed (non-browser).
40    #[serde(default)]
41    pub allowed_origins: Vec<String>,
42    /// Allow the stdio transport subcommand. Disabled by default because
43    /// stdio mode bypasses auth, RBAC, TLS, and Origin validation.
44    #[serde(default)]
45    pub stdio_enabled: bool,
46    /// Maximum tool invocations per source IP per minute.
47    /// When set, enforced by the RBAC middleware on `tools/call` requests.
48    /// Protects against both abuse and runaway LLM loops.
49    pub tool_rate_limit: Option<u32>,
50    /// Maximum requests per source IP per minute on application routes
51    /// merged via `McpServerConfig::with_extra_router` (which bypass
52    /// auth/RBAC). Opt-in; must be greater than zero when set.
53    /// Keyed by the direct socket peer — no `X-Forwarded-For`
54    /// interpretation. Startup-only.
55    pub extra_route_rate_limit: Option<u32>,
56    /// Idle timeout for MCP sessions. Sessions with no activity for this
57    /// duration are closed automatically. Default: 20 minutes.
58    #[serde(default = "default_session_idle_timeout")]
59    pub session_idle_timeout: String,
60    /// Interval for SSE keep-alive pings sent to the client. Prevents
61    /// proxies and load balancers from killing idle connections.
62    /// Default: 15 seconds.
63    #[serde(default = "default_sse_keep_alive")]
64    pub sse_keep_alive: String,
65    /// Externally reachable base URL (e.g. `https://mcp.example.com`).
66    /// When set, OAuth metadata endpoints advertise this URL instead of
67    /// the listen address. Required when the server binds to `0.0.0.0`
68    /// behind a reverse proxy or inside a container.
69    pub public_url: Option<String>,
70    /// Enable gzip/br response compression for MCP responses.
71    #[serde(default)]
72    pub compression_enabled: bool,
73    /// Minimum response size (bytes) before compression kicks in.
74    /// Only used when `compression_enabled` is true. Default: 1024.
75    #[serde(default = "default_compression_min_size")]
76    pub compression_min_size: u16,
77    /// Global cap on in-flight HTTP requests. When reached, excess
78    /// requests receive 503 Service Unavailable (via load shedding).
79    pub max_concurrent_requests: Option<usize>,
80    /// Enable `/admin/*` diagnostic endpoints.
81    #[serde(default)]
82    pub admin_enabled: bool,
83    /// RBAC role required to access admin endpoints.
84    #[serde(default = "default_admin_role")]
85    pub admin_role: String,
86    /// Authentication configuration (API keys, mTLS, OAuth).
87    pub auth: Option<crate::auth::AuthConfig>,
88}
89
90impl Default for ServerConfig {
91    fn default() -> Self {
92        Self {
93            listen_addr: default_listen_addr(),
94            listen_port: default_listen_port(),
95            tls_cert_path: None,
96            tls_key_path: None,
97            tls_handshake_timeout: default_tls_handshake_timeout(),
98            max_concurrent_tls_handshakes: default_max_concurrent_tls_handshakes(),
99            shutdown_timeout: default_shutdown_timeout(),
100            request_timeout: default_request_timeout(),
101            allowed_origins: Vec::new(),
102            stdio_enabled: false,
103            tool_rate_limit: None,
104            extra_route_rate_limit: None,
105            session_idle_timeout: default_session_idle_timeout(),
106            sse_keep_alive: default_sse_keep_alive(),
107            public_url: None,
108            compression_enabled: false,
109            compression_min_size: default_compression_min_size(),
110            max_concurrent_requests: None,
111            admin_enabled: false,
112            admin_role: default_admin_role(),
113            auth: None,
114        }
115    }
116}
117
118/// Observability settings (reusable across MCP projects).
119#[derive(Debug, Deserialize)]
120#[non_exhaustive]
121pub struct ObservabilityConfig {
122    /// `tracing` log level / env filter string (e.g. `info,rmcp_server_kit=debug`).
123    #[serde(default = "default_log_level")]
124    pub log_level: String,
125    /// Log output format: `json`, `pretty`, or `text` (default: `pretty`).
126    #[serde(default = "default_log_format")]
127    pub log_format: String,
128    /// Optional path to an append-only audit log file.
129    pub audit_log_path: Option<PathBuf>,
130    /// Emit inbound HTTP request headers at DEBUG level in transport logs.
131    /// Sensitive headers remain redacted when enabled.
132    #[serde(default)]
133    pub log_request_headers: bool,
134    /// Enable the Prometheus metrics endpoint.
135    #[serde(default)]
136    pub metrics_enabled: bool,
137    /// Bind address for the Prometheus metrics listener.
138    #[serde(default = "default_metrics_bind")]
139    pub metrics_bind: String,
140}
141
142impl Default for ObservabilityConfig {
143    fn default() -> Self {
144        Self {
145            log_level: default_log_level(),
146            log_format: default_log_format(),
147            audit_log_path: None,
148            log_request_headers: false,
149            metrics_enabled: false,
150            metrics_bind: default_metrics_bind(),
151        }
152    }
153}
154
155/// Validate the generic server config fields.
156///
157/// # Errors
158///
159/// Returns `McpxError::Config` on invalid values.
160pub fn validate_server_config(server: &ServerConfig) -> crate::error::Result<()> {
161    use crate::error::McpxError;
162
163    if server.listen_port == 0 {
164        return Err(McpxError::Config("listen_port must be nonzero".into()));
165    }
166
167    match (&server.tls_cert_path, &server.tls_key_path) {
168        (Some(_), None) | (None, Some(_)) => {
169            return Err(McpxError::Config(
170                "tls_cert_path and tls_key_path must both be set or both omitted".into(),
171            ));
172        }
173        _ => {}
174    }
175
176    if let Some(0) = server.max_concurrent_requests {
177        return Err(McpxError::Config(
178            "max_concurrent_requests must be nonzero when set".into(),
179        ));
180    }
181
182    if let Some(0) = server.extra_route_rate_limit {
183        return Err(McpxError::Config(
184            "server.extra_route_rate_limit must be greater than zero".into(),
185        ));
186    }
187
188    if server.admin_enabled {
189        let auth_enabled = server.auth.as_ref().is_some_and(|a| a.enabled);
190        if !auth_enabled {
191            return Err(McpxError::Config(
192                "admin_enabled=true requires auth to be configured and enabled".into(),
193            ));
194        }
195        if server.admin_role.trim().is_empty() {
196            return Err(McpxError::Config("admin_role must not be empty".into()));
197        }
198    }
199
200    for (field, value) in [
201        ("server.shutdown_timeout", server.shutdown_timeout.as_str()),
202        ("server.request_timeout", server.request_timeout.as_str()),
203        (
204            "server.session_idle_timeout",
205            server.session_idle_timeout.as_str(),
206        ),
207        ("server.sse_keep_alive", server.sse_keep_alive.as_str()),
208        (
209            "server.tls_handshake_timeout",
210            server.tls_handshake_timeout.as_str(),
211        ),
212    ] {
213        if humantime::parse_duration(value).is_err() {
214            return Err(McpxError::Config(format!(
215                "invalid duration for {field}: {value:?}"
216            )));
217        }
218    }
219
220    // The handshake deadline must be a positive duration: a zero value
221    // would reap every TLS handshake before it could complete. Mirrors
222    // check #11 in `McpServerConfig::check`.
223    if humantime::parse_duration(&server.tls_handshake_timeout)
224        .is_ok_and(|d| d == std::time::Duration::ZERO)
225    {
226        return Err(McpxError::Config(
227            "server.tls_handshake_timeout must be greater than zero".into(),
228        ));
229    }
230
231    // A zero-permit handshake semaphore would never admit a handshake,
232    // deadlocking the TLS accept path. Mirrors check #12 in
233    // `McpServerConfig::check`.
234    if server.max_concurrent_tls_handshakes == 0 {
235        return Err(McpxError::Config(
236            "server.max_concurrent_tls_handshakes must be greater than zero".into(),
237        ));
238    }
239
240    Ok(())
241}
242
243/// Validate observability config fields.
244///
245/// # Errors
246///
247/// Returns `McpxError::Config` on invalid values.
248pub fn validate_observability_config(obs: &ObservabilityConfig) -> crate::error::Result<()> {
249    use tracing_subscriber::EnvFilter;
250
251    use crate::error::McpxError;
252
253    if EnvFilter::try_new(&obs.log_level).is_err() {
254        return Err(McpxError::Config(format!(
255            "invalid log_level: {:?} (expected a valid tracing filter directive, e.g. \"info\", \"debug,hyper=warn\")",
256            obs.log_level
257        )));
258    }
259    let valid_formats = ["json", "pretty", "text"];
260    if !valid_formats.contains(&obs.log_format.as_str()) {
261        return Err(McpxError::Config(format!(
262            "invalid log_format: {:?} (expected one of: {valid_formats:?})",
263            obs.log_format
264        )));
265    }
266
267    Ok(())
268}
269
270// - Default value functions -
271
/// Default for `ServerConfig::listen_addr`: `127.0.0.1`.
fn default_listen_addr() -> String {
    String::from("127.0.0.1")
}
/// Default for `ServerConfig::listen_port`: `8443`.
fn default_listen_port() -> u16 {
    const PORT: u16 = 8443;
    PORT
}
/// Default for `ServerConfig::shutdown_timeout`: `30s`.
fn default_shutdown_timeout() -> String {
    String::from("30s")
}
/// Default for `ServerConfig::request_timeout`: `120s`.
fn default_request_timeout() -> String {
    String::from("120s")
}
/// Default for `ObservabilityConfig::log_level`: `info,rmcp=warn`.
fn default_log_level() -> String {
    String::from("info,rmcp=warn")
}
/// Default for `ObservabilityConfig::log_format`: `pretty`.
fn default_log_format() -> String {
    String::from("pretty")
}
/// Default for `ObservabilityConfig::metrics_bind`: `127.0.0.1:9090`.
fn default_metrics_bind() -> String {
    String::from("127.0.0.1:9090")
}
/// Default for `ServerConfig::session_idle_timeout`: `20m`.
fn default_session_idle_timeout() -> String {
    String::from("20m")
}
/// Default for `ServerConfig::tls_handshake_timeout`: `10s`.
fn default_tls_handshake_timeout() -> String {
    String::from("10s")
}
/// Default for `ServerConfig::max_concurrent_tls_handshakes`: `256`.
const fn default_max_concurrent_tls_handshakes() -> usize {
    const PERMITS: usize = 256;
    PERMITS
}
/// Default for `ServerConfig::admin_role`: `admin`.
fn default_admin_role() -> String {
    String::from("admin")
}
/// Default for `ServerConfig::compression_min_size`: `1024` bytes.
fn default_compression_min_size() -> u16 {
    const MIN_BYTES: u16 = 1024;
    MIN_BYTES
}
/// Default for `ServerConfig::sse_keep_alive`: `15s`.
fn default_sse_keep_alive() -> String {
    String::from("15s")
}
311
#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        clippy::panic,
        clippy::indexing_slicing,
        clippy::unwrap_in_result,
        clippy::print_stdout,
        clippy::print_stderr,
        reason = "test-only relaxations; production code uses ? and tracing"
    )]
    use super::*;

    /// Validates `cfg` and returns the rendered rejection message.
    /// Panics if the config is unexpectedly accepted.
    fn server_rejection(cfg: &ServerConfig) -> String {
        validate_server_config(cfg).unwrap_err().to_string()
    }

    /// Validates `cfg` and returns the rendered rejection message.
    /// Panics if the config is unexpectedly accepted.
    fn observability_rejection(cfg: &ObservabilityConfig) -> String {
        validate_observability_config(cfg).unwrap_err().to_string()
    }

    // -- ServerConfig defaults --

    #[test]
    fn server_config_defaults() {
        let defaults = ServerConfig::default();

        // Listener.
        assert_eq!(defaults.listen_addr, "127.0.0.1");
        assert_eq!(defaults.listen_port, 8443);

        // TLS material is absent unless configured.
        assert!(defaults.tls_cert_path.is_none());
        assert!(defaults.tls_key_path.is_none());

        // Timeouts.
        assert_eq!(defaults.shutdown_timeout, "30s");
        assert_eq!(defaults.request_timeout, "120s");
        assert_eq!(defaults.session_idle_timeout, "20m");
        assert_eq!(defaults.sse_keep_alive, "15s");

        // Opt-in settings.
        assert!(defaults.allowed_origins.is_empty());
        assert!(!defaults.stdio_enabled);
        assert!(defaults.tool_rate_limit.is_none());
        assert!(defaults.public_url.is_none());
    }

    #[test]
    fn observability_config_defaults() {
        let defaults = ObservabilityConfig::default();

        assert_eq!(defaults.log_level, "info,rmcp=warn");
        assert_eq!(defaults.log_format, "pretty");
        assert_eq!(defaults.metrics_bind, "127.0.0.1:9090");
        assert!(defaults.audit_log_path.is_none());
        assert!(!defaults.log_request_headers);
        assert!(!defaults.metrics_enabled);
    }

    // -- validate_server_config --

    #[test]
    fn valid_server_config_passes() {
        let outcome = validate_server_config(&ServerConfig::default());
        assert!(outcome.is_ok());
    }

    #[test]
    fn zero_port_rejected() {
        let message = server_rejection(&ServerConfig {
            listen_port: 0,
            ..ServerConfig::default()
        });
        assert!(message.contains("listen_port"));
    }

    #[test]
    fn zero_extra_route_rate_limit_rejected() {
        let message = server_rejection(&ServerConfig {
            extra_route_rate_limit: Some(0),
            ..ServerConfig::default()
        });
        assert!(message.contains("extra_route_rate_limit"));
    }

    #[test]
    fn tls_cert_without_key_rejected() {
        let message = server_rejection(&ServerConfig {
            tls_cert_path: Some("/tmp/cert.pem".into()),
            ..ServerConfig::default()
        });
        assert!(message.contains("tls_cert_path"));
    }

    #[test]
    fn tls_key_without_cert_rejected() {
        let message = server_rejection(&ServerConfig {
            tls_key_path: Some("/tmp/key.pem".into()),
            ..ServerConfig::default()
        });
        assert!(message.contains("tls_cert_path"));
    }

    #[test]
    fn tls_both_set_passes() {
        let with_tls = ServerConfig {
            tls_cert_path: Some("/tmp/cert.pem".into()),
            tls_key_path: Some("/tmp/key.pem".into()),
            ..ServerConfig::default()
        };
        let outcome = validate_server_config(&with_tls);
        assert!(outcome.is_ok());
    }

    #[test]
    fn invalid_tls_handshake_timeout_rejected() {
        let message = server_rejection(&ServerConfig {
            tls_handshake_timeout: "not-a-duration".into(),
            ..ServerConfig::default()
        });
        assert!(message.contains("tls_handshake_timeout"));
    }

    #[test]
    fn zero_tls_handshake_timeout_rejected() {
        let message = server_rejection(&ServerConfig {
            tls_handshake_timeout: "0s".into(),
            ..ServerConfig::default()
        });
        assert!(message.contains("tls_handshake_timeout"));
    }

    #[test]
    fn zero_max_concurrent_tls_handshakes_rejected() {
        let message = server_rejection(&ServerConfig {
            max_concurrent_tls_handshakes: 0,
            ..ServerConfig::default()
        });
        assert!(message.contains("max_concurrent_tls_handshakes"));
    }

    #[test]
    fn invalid_shutdown_timeout_rejected() {
        let message = server_rejection(&ServerConfig {
            shutdown_timeout: "not-a-duration".into(),
            ..ServerConfig::default()
        });
        assert!(message.contains("shutdown_timeout"));
    }

    #[test]
    fn invalid_request_timeout_rejected() {
        let message = server_rejection(&ServerConfig {
            request_timeout: "xyz".into(),
            ..ServerConfig::default()
        });
        assert!(message.contains("request_timeout"));
    }

    // -- validate_observability_config --

    #[test]
    fn valid_observability_config_passes() {
        let outcome = validate_observability_config(&ObservabilityConfig::default());
        assert!(outcome.is_ok());
    }

    #[test]
    fn invalid_log_level_rejected() {
        let message = observability_rejection(&ObservabilityConfig {
            log_level: "[invalid".into(),
            ..ObservabilityConfig::default()
        });
        assert!(message.contains("log_level"));
    }

    #[test]
    fn invalid_log_format_rejected() {
        let message = observability_rejection(&ObservabilityConfig {
            log_format: "yaml".into(),
            ..ObservabilityConfig::default()
        });
        assert!(message.contains("log_format"));
    }

    #[test]
    fn all_valid_log_levels_accepted() {
        let accepted_levels = [
            "trace",
            "debug",
            "info",
            "warn",
            "error",
            "info,rmcp=warn",
            "debug,hyper=error",
        ];
        for level in accepted_levels {
            let cfg = ObservabilityConfig {
                log_level: level.into(),
                ..ObservabilityConfig::default()
            };
            let outcome = validate_observability_config(&cfg);
            assert!(outcome.is_ok(), "level {level} should be valid");
        }
    }

    #[test]
    fn all_log_formats_accepted() {
        let accepted_formats = ["json", "pretty", "text"];
        for fmt in accepted_formats {
            let cfg = ObservabilityConfig {
                log_format: fmt.into(),
                ..ObservabilityConfig::default()
            };
            let outcome = validate_observability_config(&cfg);
            assert!(outcome.is_ok(), "format {fmt} should be valid");
        }
    }

    // -- serde deserialization --

    #[test]
    fn server_config_deserialize_defaults() {
        // An empty document must fall back to the default functions.
        let parsed = toml::from_str::<ServerConfig>("").unwrap();
        assert_eq!(parsed.listen_port, 8443);
        assert_eq!(parsed.listen_addr, "127.0.0.1");
        assert_eq!(parsed.tls_handshake_timeout, "10s");
        assert_eq!(parsed.max_concurrent_tls_handshakes, 256);
    }

    #[test]
    fn observability_config_deserialize_defaults() {
        // An empty document must fall back to the default functions.
        let parsed = toml::from_str::<ObservabilityConfig>("").unwrap();
        assert_eq!(parsed.log_level, "info,rmcp=warn");
        assert_eq!(parsed.log_format, "pretty");
        assert!(!parsed.log_request_headers);
        assert!(!parsed.metrics_enabled);
    }
}