Skip to main content

rmcp_server_kit/
config.rs

1use std::path::PathBuf;
2
3use serde::Deserialize;
4
5/// Server listener configuration (reusable across MCP projects).
6#[derive(Debug, Deserialize)]
7#[non_exhaustive]
8pub struct ServerConfig {
9    /// Listen address (IP or hostname). Default: `127.0.0.1`.
10    #[serde(default = "default_listen_addr")]
11    pub listen_addr: String,
12    /// Listen TCP port. Default: `8443`.
13    #[serde(default = "default_listen_port")]
14    pub listen_port: u16,
15    /// Path to the TLS certificate (PEM). Required for TLS/mTLS.
16    pub tls_cert_path: Option<PathBuf>,
17    /// Path to the TLS private key (PEM). Required for TLS/mTLS.
18    pub tls_key_path: Option<PathBuf>,
19    /// Graceful shutdown timeout, parsed via `humantime`.
20    #[serde(default = "default_shutdown_timeout")]
21    pub shutdown_timeout: String,
22    /// Per-request timeout, parsed via `humantime`.
23    #[serde(default = "default_request_timeout")]
24    pub request_timeout: String,
25    /// Allowed Origin header values for DNS rebinding protection (MCP spec).
26    /// Requests with an Origin not in this list are rejected with 403.
27    /// Requests without an Origin header are always allowed (non-browser).
28    #[serde(default)]
29    pub allowed_origins: Vec<String>,
30    /// Allow the stdio transport subcommand. Disabled by default because
31    /// stdio mode bypasses auth, RBAC, TLS, and Origin validation.
32    #[serde(default)]
33    pub stdio_enabled: bool,
34    /// Maximum tool invocations per source IP per minute.
35    /// When set, enforced by the RBAC middleware on `tools/call` requests.
36    /// Protects against both abuse and runaway LLM loops.
37    pub tool_rate_limit: Option<u32>,
38    /// Idle timeout for MCP sessions. Sessions with no activity for this
39    /// duration are closed automatically. Default: 20 minutes.
40    #[serde(default = "default_session_idle_timeout")]
41    pub session_idle_timeout: String,
42    /// Interval for SSE keep-alive pings sent to the client. Prevents
43    /// proxies and load balancers from killing idle connections.
44    /// Default: 15 seconds.
45    #[serde(default = "default_sse_keep_alive")]
46    pub sse_keep_alive: String,
47    /// Externally reachable base URL (e.g. `https://mcp.example.com`).
48    /// When set, OAuth metadata endpoints advertise this URL instead of
49    /// the listen address. Required when the server binds to `0.0.0.0`
50    /// behind a reverse proxy or inside a container.
51    pub public_url: Option<String>,
52    /// Enable gzip/br response compression for MCP responses.
53    #[serde(default)]
54    pub compression_enabled: bool,
55    /// Minimum response size (bytes) before compression kicks in.
56    /// Only used when `compression_enabled` is true. Default: 1024.
57    #[serde(default = "default_compression_min_size")]
58    pub compression_min_size: u16,
59    /// Global cap on in-flight HTTP requests. When reached, excess
60    /// requests receive 503 Service Unavailable (via load shedding).
61    pub max_concurrent_requests: Option<usize>,
62    /// Enable `/admin/*` diagnostic endpoints.
63    #[serde(default)]
64    pub admin_enabled: bool,
65    /// RBAC role required to access admin endpoints.
66    #[serde(default = "default_admin_role")]
67    pub admin_role: String,
68    /// Authentication configuration (API keys, mTLS, OAuth).
69    pub auth: Option<crate::auth::AuthConfig>,
70}
71
72impl Default for ServerConfig {
73    fn default() -> Self {
74        Self {
75            listen_addr: default_listen_addr(),
76            listen_port: default_listen_port(),
77            tls_cert_path: None,
78            tls_key_path: None,
79            shutdown_timeout: default_shutdown_timeout(),
80            request_timeout: default_request_timeout(),
81            allowed_origins: Vec::new(),
82            stdio_enabled: false,
83            tool_rate_limit: None,
84            session_idle_timeout: default_session_idle_timeout(),
85            sse_keep_alive: default_sse_keep_alive(),
86            public_url: None,
87            compression_enabled: false,
88            compression_min_size: default_compression_min_size(),
89            max_concurrent_requests: None,
90            admin_enabled: false,
91            admin_role: default_admin_role(),
92            auth: None,
93        }
94    }
95}
96
97/// Observability settings (reusable across MCP projects).
98#[derive(Debug, Deserialize)]
99#[non_exhaustive]
100pub struct ObservabilityConfig {
101    /// `tracing` log level / env filter string (e.g. `info,rmcp_server_kit=debug`).
102    #[serde(default = "default_log_level")]
103    pub log_level: String,
104    /// Log output format: `json` or `text`.
105    #[serde(default = "default_log_format")]
106    pub log_format: String,
107    /// Optional path to an append-only audit log file.
108    pub audit_log_path: Option<PathBuf>,
109    /// Emit inbound HTTP request headers at DEBUG level in transport logs.
110    /// Sensitive headers remain redacted when enabled.
111    #[serde(default)]
112    pub log_request_headers: bool,
113    /// Enable the Prometheus metrics endpoint.
114    #[serde(default)]
115    pub metrics_enabled: bool,
116    /// Bind address for the Prometheus metrics listener.
117    #[serde(default = "default_metrics_bind")]
118    pub metrics_bind: String,
119}
120
121impl Default for ObservabilityConfig {
122    fn default() -> Self {
123        Self {
124            log_level: default_log_level(),
125            log_format: default_log_format(),
126            audit_log_path: None,
127            log_request_headers: false,
128            metrics_enabled: false,
129            metrics_bind: default_metrics_bind(),
130        }
131    }
132}
133
134/// Validate the generic server config fields.
135///
136/// # Errors
137///
138/// Returns `McpxError::Config` on invalid values.
139pub fn validate_server_config(server: &ServerConfig) -> crate::error::Result<()> {
140    use crate::error::McpxError;
141
142    if server.listen_port == 0 {
143        return Err(McpxError::Config("listen_port must be nonzero".into()));
144    }
145
146    match (&server.tls_cert_path, &server.tls_key_path) {
147        (Some(_), None) | (None, Some(_)) => {
148            return Err(McpxError::Config(
149                "tls_cert_path and tls_key_path must both be set or both omitted".into(),
150            ));
151        }
152        _ => {}
153    }
154
155    if let Some(0) = server.max_concurrent_requests {
156        return Err(McpxError::Config(
157            "max_concurrent_requests must be nonzero when set".into(),
158        ));
159    }
160
161    if server.admin_enabled {
162        let auth_enabled = server.auth.as_ref().is_some_and(|a| a.enabled);
163        if !auth_enabled {
164            return Err(McpxError::Config(
165                "admin_enabled=true requires auth to be configured and enabled".into(),
166            ));
167        }
168        if server.admin_role.trim().is_empty() {
169            return Err(McpxError::Config("admin_role must not be empty".into()));
170        }
171    }
172
173    for (field, value) in [
174        ("server.shutdown_timeout", server.shutdown_timeout.as_str()),
175        ("server.request_timeout", server.request_timeout.as_str()),
176        (
177            "server.session_idle_timeout",
178            server.session_idle_timeout.as_str(),
179        ),
180        ("server.sse_keep_alive", server.sse_keep_alive.as_str()),
181    ] {
182        if humantime::parse_duration(value).is_err() {
183            return Err(McpxError::Config(format!(
184                "invalid duration for {field}: {value:?}"
185            )));
186        }
187    }
188
189    Ok(())
190}
191
192/// Validate observability config fields.
193///
194/// # Errors
195///
196/// Returns `McpxError::Config` on invalid values.
197pub fn validate_observability_config(obs: &ObservabilityConfig) -> crate::error::Result<()> {
198    use tracing_subscriber::EnvFilter;
199
200    use crate::error::McpxError;
201
202    if EnvFilter::try_new(&obs.log_level).is_err() {
203        return Err(McpxError::Config(format!(
204            "invalid log_level: {:?} (expected a valid tracing filter directive, e.g. \"info\", \"debug,hyper=warn\")",
205            obs.log_level
206        )));
207    }
208    let valid_formats = ["json", "pretty", "text"];
209    if !valid_formats.contains(&obs.log_format.as_str()) {
210        return Err(McpxError::Config(format!(
211            "invalid log_format: {:?} (expected one of: {valid_formats:?})",
212            obs.log_format
213        )));
214    }
215
216    Ok(())
217}
218
// - Default value functions -
//
// Referenced by name from the `#[serde(default = "...")]` attributes and from
// the `Default` impls, so both sources of defaults share one definition.

/// Default listen address: the IPv4 loopback address.
fn default_listen_addr() -> String {
    String::from("127.0.0.1")
}

/// Default listen port.
fn default_listen_port() -> u16 {
    8443
}

/// Default graceful shutdown timeout (humantime syntax).
fn default_shutdown_timeout() -> String {
    String::from("30s")
}

/// Default per-request timeout (humantime syntax).
fn default_request_timeout() -> String {
    String::from("120s")
}

/// Default tracing filter: `info` overall, `warn` for the `rmcp` target.
fn default_log_level() -> String {
    String::from("info,rmcp=warn")
}

/// Default log output format.
fn default_log_format() -> String {
    String::from("pretty")
}

/// Default bind address for the metrics listener.
fn default_metrics_bind() -> String {
    String::from("127.0.0.1:9090")
}

/// Default MCP session idle timeout (humantime syntax).
fn default_session_idle_timeout() -> String {
    String::from("20m")
}

/// Default RBAC role for admin endpoints.
fn default_admin_role() -> String {
    String::from("admin")
}

/// Default minimum response size, in bytes, before compression applies.
fn default_compression_min_size() -> u16 {
    1024
}

/// Default SSE keep-alive interval (humantime syntax).
fn default_sse_keep_alive() -> String {
    String::from("15s")
}
254
#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        clippy::panic,
        clippy::indexing_slicing,
        clippy::unwrap_in_result,
        clippy::print_stdout,
        clippy::print_stderr
    )]
    use super::*;

    // -- ServerConfig defaults --

    #[test]
    fn server_config_defaults() {
        let cfg = ServerConfig::default();
        assert_eq!(cfg.listen_addr, "127.0.0.1");
        assert_eq!(cfg.listen_port, 8443);
        assert!(cfg.tls_cert_path.is_none());
        assert!(cfg.tls_key_path.is_none());
        assert_eq!(cfg.shutdown_timeout, "30s");
        assert_eq!(cfg.request_timeout, "120s");
        assert!(cfg.allowed_origins.is_empty());
        assert!(!cfg.stdio_enabled);
        assert!(cfg.tool_rate_limit.is_none());
        assert_eq!(cfg.session_idle_timeout, "20m");
        assert_eq!(cfg.sse_keep_alive, "15s");
        assert!(cfg.public_url.is_none());
        assert!(!cfg.compression_enabled);
        assert_eq!(cfg.compression_min_size, 1024);
        assert!(cfg.max_concurrent_requests.is_none());
        assert!(!cfg.admin_enabled);
        assert_eq!(cfg.admin_role, "admin");
        assert!(cfg.auth.is_none());
    }

    #[test]
    fn observability_config_defaults() {
        let cfg = ObservabilityConfig::default();
        assert_eq!(cfg.log_level, "info,rmcp=warn");
        assert_eq!(cfg.log_format, "pretty");
        assert!(cfg.audit_log_path.is_none());
        assert!(!cfg.log_request_headers);
        assert!(!cfg.metrics_enabled);
        assert_eq!(cfg.metrics_bind, "127.0.0.1:9090");
    }

    // -- validate_server_config --

    #[test]
    fn valid_server_config_passes() {
        let cfg = ServerConfig::default();
        assert!(validate_server_config(&cfg).is_ok());
    }

    #[test]
    fn zero_port_rejected() {
        let cfg = ServerConfig {
            listen_port: 0,
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("listen_port"));
    }

    #[test]
    fn tls_cert_without_key_rejected() {
        let cfg = ServerConfig {
            tls_cert_path: Some("/tmp/cert.pem".into()),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("tls_cert_path"));
    }

    #[test]
    fn tls_key_without_cert_rejected() {
        let cfg = ServerConfig {
            tls_key_path: Some("/tmp/key.pem".into()),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("tls_cert_path"));
    }

    #[test]
    fn tls_both_set_passes() {
        let cfg = ServerConfig {
            tls_cert_path: Some("/tmp/cert.pem".into()),
            tls_key_path: Some("/tmp/key.pem".into()),
            ..ServerConfig::default()
        };
        assert!(validate_server_config(&cfg).is_ok());
    }

    #[test]
    fn zero_max_concurrent_requests_rejected() {
        let cfg = ServerConfig {
            max_concurrent_requests: Some(0),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("max_concurrent_requests"));
    }

    #[test]
    fn nonzero_max_concurrent_requests_passes() {
        let cfg = ServerConfig {
            max_concurrent_requests: Some(64),
            ..ServerConfig::default()
        };
        assert!(validate_server_config(&cfg).is_ok());
    }

    #[test]
    fn admin_enabled_without_auth_rejected() {
        let cfg = ServerConfig {
            admin_enabled: true,
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("admin_enabled"));
    }

    #[test]
    fn invalid_shutdown_timeout_rejected() {
        let cfg = ServerConfig {
            shutdown_timeout: "not-a-duration".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("shutdown_timeout"));
    }

    #[test]
    fn invalid_request_timeout_rejected() {
        let cfg = ServerConfig {
            request_timeout: "xyz".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("request_timeout"));
    }

    #[test]
    fn invalid_session_idle_timeout_rejected() {
        let cfg = ServerConfig {
            session_idle_timeout: "forever".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("session_idle_timeout"));
    }

    #[test]
    fn invalid_sse_keep_alive_rejected() {
        let cfg = ServerConfig {
            sse_keep_alive: "often".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("sse_keep_alive"));
    }

    // -- validate_observability_config --

    #[test]
    fn valid_observability_config_passes() {
        let cfg = ObservabilityConfig::default();
        assert!(validate_observability_config(&cfg).is_ok());
    }

    #[test]
    fn invalid_log_level_rejected() {
        let cfg = ObservabilityConfig {
            log_level: "[invalid".into(),
            ..ObservabilityConfig::default()
        };
        let err = validate_observability_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("log_level"));
    }

    #[test]
    fn invalid_log_format_rejected() {
        let cfg = ObservabilityConfig {
            log_format: "yaml".into(),
            ..ObservabilityConfig::default()
        };
        let err = validate_observability_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("log_format"));
    }

    #[test]
    fn all_valid_log_levels_accepted() {
        for level in &[
            "trace",
            "debug",
            "info",
            "warn",
            "error",
            "info,rmcp=warn",
            "debug,hyper=error",
        ] {
            let cfg = ObservabilityConfig {
                log_level: (*level).into(),
                ..ObservabilityConfig::default()
            };
            assert!(
                validate_observability_config(&cfg).is_ok(),
                "level {level} should be valid"
            );
        }
    }

    #[test]
    fn all_log_formats_accepted() {
        // Must stay in sync with the accepted list in
        // `validate_observability_config`.
        for fmt in &["json", "pretty", "text"] {
            let cfg = ObservabilityConfig {
                log_format: (*fmt).into(),
                ..ObservabilityConfig::default()
            };
            assert!(
                validate_observability_config(&cfg).is_ok(),
                "format {fmt} should be valid"
            );
        }
    }

    // -- serde deserialization --

    #[test]
    fn server_config_deserialize_defaults() {
        let cfg: ServerConfig = toml::from_str("").unwrap();
        assert_eq!(cfg.listen_port, 8443);
        assert_eq!(cfg.listen_addr, "127.0.0.1");
    }

    #[test]
    fn server_config_deserialize_explicit_values() {
        let cfg: ServerConfig = toml::from_str(
            "listen_addr = \"0.0.0.0\"\n\
             listen_port = 9000\n\
             allowed_origins = [\"https://a.example\"]\n\
             compression_enabled = true\n\
             tool_rate_limit = 30\n",
        )
        .unwrap();
        assert_eq!(cfg.listen_addr, "0.0.0.0");
        assert_eq!(cfg.listen_port, 9000);
        assert_eq!(cfg.allowed_origins, vec!["https://a.example".to_owned()]);
        assert!(cfg.compression_enabled);
        assert_eq!(cfg.tool_rate_limit, Some(30));
        // Unspecified fields still take their defaults.
        assert_eq!(cfg.request_timeout, "120s");
        assert!(cfg.auth.is_none());
    }

    #[test]
    fn observability_config_deserialize_defaults() {
        let cfg: ObservabilityConfig = toml::from_str("").unwrap();
        assert_eq!(cfg.log_level, "info,rmcp=warn");
        assert_eq!(cfg.log_format, "pretty");
        assert!(!cfg.log_request_headers);
        assert!(!cfg.metrics_enabled);
    }
}
447}