Skip to main content

rmcp_server_kit/
config.rs

1use std::path::PathBuf;
2
3use serde::Deserialize;
4
5/// Server listener configuration (reusable across MCP projects).
6#[derive(Debug, Deserialize)]
7#[non_exhaustive]
8pub struct ServerConfig {
9    /// Listen address (IP or hostname). Default: `127.0.0.1`.
10    #[serde(default = "default_listen_addr")]
11    pub listen_addr: String,
12    /// Listen TCP port. Default: `8443`.
13    #[serde(default = "default_listen_port")]
14    pub listen_port: u16,
15    /// Path to the TLS certificate (PEM). Required for TLS/mTLS.
16    pub tls_cert_path: Option<PathBuf>,
17    /// Path to the TLS private key (PEM). Required for TLS/mTLS.
18    pub tls_key_path: Option<PathBuf>,
19    /// Graceful shutdown timeout, parsed via `humantime`.
20    #[serde(default = "default_shutdown_timeout")]
21    pub shutdown_timeout: String,
22    /// Per-request timeout, parsed via `humantime`.
23    #[serde(default = "default_request_timeout")]
24    pub request_timeout: String,
25    /// Allowed Origin header values for DNS rebinding protection (MCP spec).
26    /// Requests with an Origin not in this list are rejected with 403.
27    /// Requests without an Origin header are always allowed (non-browser).
28    #[serde(default)]
29    pub allowed_origins: Vec<String>,
30    /// Allow the stdio transport subcommand. Disabled by default because
31    /// stdio mode bypasses auth, RBAC, TLS, and Origin validation.
32    #[serde(default)]
33    pub stdio_enabled: bool,
34    /// Maximum tool invocations per source IP per minute.
35    /// When set, enforced by the RBAC middleware on `tools/call` requests.
36    /// Protects against both abuse and runaway LLM loops.
37    pub tool_rate_limit: Option<u32>,
38    /// Idle timeout for MCP sessions. Sessions with no activity for this
39    /// duration are closed automatically. Default: 20 minutes.
40    #[serde(default = "default_session_idle_timeout")]
41    pub session_idle_timeout: String,
42    /// Interval for SSE keep-alive pings sent to the client. Prevents
43    /// proxies and load balancers from killing idle connections.
44    /// Default: 15 seconds.
45    #[serde(default = "default_sse_keep_alive")]
46    pub sse_keep_alive: String,
47    /// Externally reachable base URL (e.g. `https://mcp.example.com`).
48    /// When set, OAuth metadata endpoints advertise this URL instead of
49    /// the listen address. Required when the server binds to `0.0.0.0`
50    /// behind a reverse proxy or inside a container.
51    pub public_url: Option<String>,
52    /// Enable gzip/br response compression for MCP responses.
53    #[serde(default)]
54    pub compression_enabled: bool,
55    /// Minimum response size (bytes) before compression kicks in.
56    /// Only used when `compression_enabled` is true. Default: 1024.
57    #[serde(default = "default_compression_min_size")]
58    pub compression_min_size: u16,
59    /// Global cap on in-flight HTTP requests. When reached, excess
60    /// requests receive 503 Service Unavailable (via load shedding).
61    pub max_concurrent_requests: Option<usize>,
62    /// Enable `/admin/*` diagnostic endpoints.
63    #[serde(default)]
64    pub admin_enabled: bool,
65    /// RBAC role required to access admin endpoints.
66    #[serde(default = "default_admin_role")]
67    pub admin_role: String,
68    /// Authentication configuration (API keys, mTLS, OAuth).
69    pub auth: Option<crate::auth::AuthConfig>,
70}
71
72impl Default for ServerConfig {
73    fn default() -> Self {
74        Self {
75            listen_addr: default_listen_addr(),
76            listen_port: default_listen_port(),
77            tls_cert_path: None,
78            tls_key_path: None,
79            shutdown_timeout: default_shutdown_timeout(),
80            request_timeout: default_request_timeout(),
81            allowed_origins: Vec::new(),
82            stdio_enabled: false,
83            tool_rate_limit: None,
84            session_idle_timeout: default_session_idle_timeout(),
85            sse_keep_alive: default_sse_keep_alive(),
86            public_url: None,
87            compression_enabled: false,
88            compression_min_size: default_compression_min_size(),
89            max_concurrent_requests: None,
90            admin_enabled: false,
91            admin_role: default_admin_role(),
92            auth: None,
93        }
94    }
95}
96
97/// Observability settings (reusable across MCP projects).
98#[derive(Debug, Deserialize)]
99#[non_exhaustive]
100pub struct ObservabilityConfig {
101    /// `tracing` log level / env filter string (e.g. `info,rmcp_server_kit=debug`).
102    #[serde(default = "default_log_level")]
103    pub log_level: String,
104    /// Log output format: `json` or `text`.
105    #[serde(default = "default_log_format")]
106    pub log_format: String,
107    /// Optional path to an append-only audit log file.
108    pub audit_log_path: Option<PathBuf>,
109    /// Emit inbound HTTP request headers at DEBUG level in transport logs.
110    /// Sensitive headers remain redacted when enabled.
111    #[serde(default)]
112    pub log_request_headers: bool,
113    /// Enable the Prometheus metrics endpoint.
114    #[serde(default)]
115    pub metrics_enabled: bool,
116    /// Bind address for the Prometheus metrics listener.
117    #[serde(default = "default_metrics_bind")]
118    pub metrics_bind: String,
119}
120
121impl Default for ObservabilityConfig {
122    fn default() -> Self {
123        Self {
124            log_level: default_log_level(),
125            log_format: default_log_format(),
126            audit_log_path: None,
127            log_request_headers: false,
128            metrics_enabled: false,
129            metrics_bind: default_metrics_bind(),
130        }
131    }
132}
133
134/// Validate the generic server config fields.
135///
136/// # Errors
137///
138/// Returns `McpxError::Config` on invalid values.
139pub fn validate_server_config(server: &ServerConfig) -> crate::error::Result<()> {
140    use crate::error::McpxError;
141
142    if server.listen_port == 0 {
143        return Err(McpxError::Config("listen_port must be nonzero".into()));
144    }
145
146    match (&server.tls_cert_path, &server.tls_key_path) {
147        (Some(_), None) | (None, Some(_)) => {
148            return Err(McpxError::Config(
149                "tls_cert_path and tls_key_path must both be set or both omitted".into(),
150            ));
151        }
152        _ => {}
153    }
154
155    if let Some(0) = server.max_concurrent_requests {
156        return Err(McpxError::Config(
157            "max_concurrent_requests must be nonzero when set".into(),
158        ));
159    }
160
161    if server.admin_enabled {
162        let auth_enabled = server.auth.as_ref().is_some_and(|a| a.enabled);
163        if !auth_enabled {
164            return Err(McpxError::Config(
165                "admin_enabled=true requires auth to be configured and enabled".into(),
166            ));
167        }
168        if server.admin_role.trim().is_empty() {
169            return Err(McpxError::Config("admin_role must not be empty".into()));
170        }
171    }
172
173    for (field, value) in [
174        ("server.shutdown_timeout", server.shutdown_timeout.as_str()),
175        ("server.request_timeout", server.request_timeout.as_str()),
176        (
177            "server.session_idle_timeout",
178            server.session_idle_timeout.as_str(),
179        ),
180        ("server.sse_keep_alive", server.sse_keep_alive.as_str()),
181    ] {
182        if humantime::parse_duration(value).is_err() {
183            return Err(McpxError::Config(format!(
184                "invalid duration for {field}: {value:?}"
185            )));
186        }
187    }
188
189    Ok(())
190}
191
192/// Validate observability config fields.
193///
194/// # Errors
195///
196/// Returns `McpxError::Config` on invalid values.
197pub fn validate_observability_config(obs: &ObservabilityConfig) -> crate::error::Result<()> {
198    use tracing_subscriber::EnvFilter;
199
200    use crate::error::McpxError;
201
202    if EnvFilter::try_new(&obs.log_level).is_err() {
203        return Err(McpxError::Config(format!(
204            "invalid log_level: {:?} (expected a valid tracing filter directive, e.g. \"info\", \"debug,hyper=warn\")",
205            obs.log_level
206        )));
207    }
208    let valid_formats = ["json", "pretty", "text"];
209    if !valid_formats.contains(&obs.log_format.as_str()) {
210        return Err(McpxError::Config(format!(
211            "invalid log_format: {:?} (expected one of: {valid_formats:?})",
212            obs.log_format
213        )));
214    }
215
216    Ok(())
217}
218
// -- Default value functions --
//
// Referenced by name from the `#[serde(default = "...")]` attributes and
// called by the `Default` impls, so both construction paths share one source
// for each default value.

// Server listener defaults.

/// Default listen address: IPv4 loopback.
fn default_listen_addr() -> String {
    String::from("127.0.0.1")
}

/// Default listen port.
fn default_listen_port() -> u16 {
    8443
}

/// Default graceful shutdown timeout (humantime string).
fn default_shutdown_timeout() -> String {
    String::from("30s")
}

/// Default per-request timeout (humantime string).
fn default_request_timeout() -> String {
    String::from("120s")
}

/// Default MCP session idle timeout (humantime string).
fn default_session_idle_timeout() -> String {
    String::from("20m")
}

/// Default SSE keep-alive ping interval (humantime string).
fn default_sse_keep_alive() -> String {
    String::from("15s")
}

/// Default RBAC role for admin endpoints.
fn default_admin_role() -> String {
    String::from("admin")
}

/// Default minimum response size (bytes) before compression applies.
fn default_compression_min_size() -> u16 {
    1024
}

// Observability defaults.

/// Default `tracing` filter directive.
fn default_log_level() -> String {
    String::from("info,rmcp=warn")
}

/// Default log output format.
fn default_log_format() -> String {
    String::from("pretty")
}

/// Default Prometheus metrics bind address.
fn default_metrics_bind() -> String {
    String::from("127.0.0.1:9090")
}
254
#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        clippy::panic,
        clippy::indexing_slicing,
        clippy::unwrap_in_result,
        clippy::print_stdout,
        clippy::print_stderr,
        reason = "test-only relaxations; production code uses ? and tracing"
    )]
    use super::*;

    // -- ServerConfig defaults --

    #[test]
    fn server_config_defaults() {
        let cfg = ServerConfig::default();
        assert_eq!(cfg.listen_addr, "127.0.0.1");
        assert_eq!(cfg.listen_port, 8443);
        assert!(cfg.tls_cert_path.is_none());
        assert!(cfg.tls_key_path.is_none());
        assert_eq!(cfg.shutdown_timeout, "30s");
        assert_eq!(cfg.request_timeout, "120s");
        assert!(cfg.allowed_origins.is_empty());
        assert!(!cfg.stdio_enabled);
        assert!(cfg.tool_rate_limit.is_none());
        assert_eq!(cfg.session_idle_timeout, "20m");
        assert_eq!(cfg.sse_keep_alive, "15s");
        assert!(cfg.public_url.is_none());
        assert!(!cfg.compression_enabled);
        assert_eq!(cfg.compression_min_size, 1024);
        assert!(cfg.max_concurrent_requests.is_none());
        assert!(!cfg.admin_enabled);
        assert_eq!(cfg.admin_role, "admin");
        assert!(cfg.auth.is_none());
    }

    #[test]
    fn observability_config_defaults() {
        let cfg = ObservabilityConfig::default();
        assert_eq!(cfg.log_level, "info,rmcp=warn");
        assert_eq!(cfg.log_format, "pretty");
        assert!(cfg.audit_log_path.is_none());
        assert!(!cfg.log_request_headers);
        assert!(!cfg.metrics_enabled);
        assert_eq!(cfg.metrics_bind, "127.0.0.1:9090");
    }

    // -- validate_server_config --

    #[test]
    fn valid_server_config_passes() {
        let cfg = ServerConfig::default();
        assert!(validate_server_config(&cfg).is_ok());
    }

    #[test]
    fn zero_port_rejected() {
        let cfg = ServerConfig {
            listen_port: 0,
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("listen_port"));
    }

    #[test]
    fn tls_cert_without_key_rejected() {
        let cfg = ServerConfig {
            tls_cert_path: Some("/tmp/cert.pem".into()),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("tls_cert_path"));
    }

    #[test]
    fn tls_key_without_cert_rejected() {
        let cfg = ServerConfig {
            tls_key_path: Some("/tmp/key.pem".into()),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("tls_cert_path"));
    }

    #[test]
    fn tls_both_set_passes() {
        let cfg = ServerConfig {
            tls_cert_path: Some("/tmp/cert.pem".into()),
            tls_key_path: Some("/tmp/key.pem".into()),
            ..ServerConfig::default()
        };
        assert!(validate_server_config(&cfg).is_ok());
    }

    #[test]
    fn zero_max_concurrent_requests_rejected() {
        let cfg = ServerConfig {
            max_concurrent_requests: Some(0),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("max_concurrent_requests"));
    }

    #[test]
    fn nonzero_max_concurrent_requests_passes() {
        let cfg = ServerConfig {
            max_concurrent_requests: Some(64),
            ..ServerConfig::default()
        };
        assert!(validate_server_config(&cfg).is_ok());
    }

    #[test]
    fn admin_without_auth_rejected() {
        let cfg = ServerConfig {
            admin_enabled: true,
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("admin_enabled"));
    }

    #[test]
    fn invalid_shutdown_timeout_rejected() {
        let cfg = ServerConfig {
            shutdown_timeout: "not-a-duration".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("shutdown_timeout"));
    }

    #[test]
    fn unitless_duration_rejected() {
        let cfg = ServerConfig {
            shutdown_timeout: "30".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("shutdown_timeout"));
    }

    #[test]
    fn invalid_request_timeout_rejected() {
        let cfg = ServerConfig {
            request_timeout: "xyz".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("request_timeout"));
    }

    #[test]
    fn invalid_session_idle_timeout_rejected() {
        let cfg = ServerConfig {
            session_idle_timeout: "forever".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("session_idle_timeout"));
    }

    #[test]
    fn invalid_sse_keep_alive_rejected() {
        let cfg = ServerConfig {
            sse_keep_alive: "often".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("sse_keep_alive"));
    }

    // -- validate_observability_config --

    #[test]
    fn valid_observability_config_passes() {
        let cfg = ObservabilityConfig::default();
        assert!(validate_observability_config(&cfg).is_ok());
    }

    #[test]
    fn invalid_log_level_rejected() {
        let cfg = ObservabilityConfig {
            log_level: "[invalid".into(),
            ..ObservabilityConfig::default()
        };
        let err = validate_observability_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("log_level"));
    }

    #[test]
    fn invalid_log_format_rejected() {
        let cfg = ObservabilityConfig {
            log_format: "yaml".into(),
            ..ObservabilityConfig::default()
        };
        let err = validate_observability_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("log_format"));
    }

    #[test]
    fn all_valid_log_levels_accepted() {
        for level in &[
            "trace",
            "debug",
            "info",
            "warn",
            "error",
            "info,rmcp=warn",
            "debug,hyper=error",
        ] {
            let cfg = ObservabilityConfig {
                log_level: (*level).into(),
                ..ObservabilityConfig::default()
            };
            assert!(
                validate_observability_config(&cfg).is_ok(),
                "level {level} should be valid"
            );
        }
    }

    #[test]
    fn all_log_formats_accepted() {
        for fmt in &["json", "pretty", "text"] {
            let cfg = ObservabilityConfig {
                log_format: (*fmt).into(),
                ..ObservabilityConfig::default()
            };
            assert!(
                validate_observability_config(&cfg).is_ok(),
                "format {fmt} should be valid"
            );
        }
    }

    // -- serde deserialization --

    #[test]
    fn server_config_deserialize_defaults() {
        let cfg: ServerConfig = toml::from_str("").unwrap();
        assert_eq!(cfg.listen_port, 8443);
        assert_eq!(cfg.listen_addr, "127.0.0.1");
    }

    #[test]
    fn observability_config_deserialize_defaults() {
        let cfg: ObservabilityConfig = toml::from_str("").unwrap();
        assert_eq!(cfg.log_level, "info,rmcp=warn");
        assert_eq!(cfg.log_format, "pretty");
        assert!(!cfg.log_request_headers);
        assert!(!cfg.metrics_enabled);
    }

    // The `Default` impls and the `#[serde(default ...)]` attributes are
    // maintained separately; comparing Debug output catches any drift.

    #[test]
    fn server_config_default_matches_serde_defaults() {
        let from_toml: ServerConfig = toml::from_str("").unwrap();
        assert_eq!(
            format!("{from_toml:?}"),
            format!("{:?}", ServerConfig::default())
        );
    }

    #[test]
    fn observability_config_default_matches_serde_defaults() {
        let from_toml: ObservabilityConfig = toml::from_str("").unwrap();
        assert_eq!(
            format!("{from_toml:?}"),
            format!("{:?}", ObservabilityConfig::default())
        );
    }
}
448}