1use std::path::PathBuf;
2
3use serde::Deserialize;
4
/// Server configuration: listener, TLS, timeouts, limits, admin and auth settings.
///
/// Missing keys fall back to defaults during deserialization:
/// - fields annotated with `#[serde(default = "...")]` call the named `default_*` function;
/// - fields with a bare `#[serde(default)]` use the type's `Default` (`false`, empty `Vec`);
/// - `Option` fields become `None`.
///
/// Cross-field and value constraints are not enforced by deserialization; call
/// `validate_server_config` after loading.
#[derive(Debug, Deserialize)]
#[non_exhaustive]
pub struct ServerConfig {
    /// Listen address. Default: `"127.0.0.1"` (IPv4 loopback).
    #[serde(default = "default_listen_addr")]
    pub listen_addr: String,
    /// Listen port. Default: `8443`. Validation rejects `0`.
    #[serde(default = "default_listen_port")]
    pub listen_port: u16,
    /// TLS certificate path. Validation requires it to be set together with
    /// `tls_key_path` (both set or both omitted).
    pub tls_cert_path: Option<PathBuf>,
    /// TLS private key path. Validation requires it to be set together with
    /// `tls_cert_path` (both set or both omitted).
    pub tls_key_path: Option<PathBuf>,
    /// TLS handshake timeout as a `humantime` duration string. Default: `"10s"`.
    /// Validation requires it to parse and to be non-zero.
    #[serde(default = "default_tls_handshake_timeout")]
    pub tls_handshake_timeout: String,
    /// Upper bound on concurrent TLS handshakes. Default: `256`. Validation rejects `0`.
    #[serde(default = "default_max_concurrent_tls_handshakes")]
    pub max_concurrent_tls_handshakes: usize,
    /// Shutdown timeout as a `humantime` duration string. Default: `"30s"`.
    /// Validation requires it to parse.
    #[serde(default = "default_shutdown_timeout")]
    pub shutdown_timeout: String,
    /// Request timeout as a `humantime` duration string. Default: `"120s"`.
    /// Validation requires it to parse.
    #[serde(default = "default_request_timeout")]
    pub request_timeout: String,
    /// Allowed origins. Default: empty.
    // NOTE(review): whether an empty list means "allow all" or "allow none" is decided
    // by the consumer of this field, which is not visible here — confirm.
    #[serde(default)]
    pub allowed_origins: Vec<String>,
    /// Stdio mode flag. Default: `false`.
    #[serde(default)]
    pub stdio_enabled: bool,
    /// Tool rate limit; `None` when absent.
    // NOTE(review): unlike `extra_route_rate_limit`, `Some(0)` is NOT rejected by
    // `validate_server_config`; units and the meaning of 0 are not visible here.
    pub tool_rate_limit: Option<u32>,
    /// Rate limit for extra routes; `None` when absent. Validation rejects `Some(0)`.
    pub extra_route_rate_limit: Option<u32>,
    /// Session idle timeout as a `humantime` duration string. Default: `"20m"`.
    /// Validation requires it to parse.
    #[serde(default = "default_session_idle_timeout")]
    pub session_idle_timeout: String,
    /// SSE keep-alive interval as a `humantime` duration string. Default: `"15s"`.
    /// Validation requires it to parse.
    #[serde(default = "default_sse_keep_alive")]
    pub sse_keep_alive: String,
    /// Public URL; `None` when absent. Not checked by `validate_server_config`.
    pub public_url: Option<String>,
    /// Compression toggle. Default: `false`.
    #[serde(default)]
    pub compression_enabled: bool,
    /// Compression size threshold. Default: `1024` (presumably bytes — confirm).
    #[serde(default = "default_compression_min_size")]
    pub compression_min_size: u16,
    /// Concurrent request limit; `None` when absent. Validation rejects `Some(0)`.
    pub max_concurrent_requests: Option<usize>,
    /// Admin toggle. Default: `false`. When `true`, validation requires `auth` to be
    /// present with `enabled == true`.
    #[serde(default)]
    pub admin_enabled: bool,
    /// Role name used for admin access. Default: `"admin"`. When `admin_enabled` is
    /// `true`, validation rejects an empty or whitespace-only value.
    #[serde(default = "default_admin_role")]
    pub admin_role: String,
    /// Authentication settings; `None` when absent.
    pub auth: Option<crate::auth::AuthConfig>,
}
89
90impl Default for ServerConfig {
91 fn default() -> Self {
92 Self {
93 listen_addr: default_listen_addr(),
94 listen_port: default_listen_port(),
95 tls_cert_path: None,
96 tls_key_path: None,
97 tls_handshake_timeout: default_tls_handshake_timeout(),
98 max_concurrent_tls_handshakes: default_max_concurrent_tls_handshakes(),
99 shutdown_timeout: default_shutdown_timeout(),
100 request_timeout: default_request_timeout(),
101 allowed_origins: Vec::new(),
102 stdio_enabled: false,
103 tool_rate_limit: None,
104 extra_route_rate_limit: None,
105 session_idle_timeout: default_session_idle_timeout(),
106 sse_keep_alive: default_sse_keep_alive(),
107 public_url: None,
108 compression_enabled: false,
109 compression_min_size: default_compression_min_size(),
110 max_concurrent_requests: None,
111 admin_enabled: false,
112 admin_role: default_admin_role(),
113 auth: None,
114 }
115 }
116}
117
/// Logging, audit and metrics configuration.
///
/// Missing keys fall back to the `default_*` functions or to the field type's
/// `Default`; `Option` fields become `None`. Call `validate_observability_config`
/// after loading to check `log_level` and `log_format`.
#[derive(Debug, Deserialize)]
#[non_exhaustive]
pub struct ObservabilityConfig {
    /// `tracing` filter directive (checked with `EnvFilter::try_new` during validation).
    /// Default: `"info,rmcp=warn"`.
    #[serde(default = "default_log_level")]
    pub log_level: String,
    /// Log output format; validation accepts `"json"`, `"pretty"` or `"text"`.
    /// Default: `"pretty"`.
    #[serde(default = "default_log_format")]
    pub log_format: String,
    /// Audit log file path; `None` when absent.
    pub audit_log_path: Option<PathBuf>,
    /// Whether request headers are logged. Default: `false`.
    #[serde(default)]
    pub log_request_headers: bool,
    /// Metrics toggle. Default: `false`.
    #[serde(default)]
    pub metrics_enabled: bool,
    /// Metrics bind address. Default: `"127.0.0.1:9090"`.
    /// Not checked by `validate_observability_config`.
    #[serde(default = "default_metrics_bind")]
    pub metrics_bind: String,
}
141
142impl Default for ObservabilityConfig {
143 fn default() -> Self {
144 Self {
145 log_level: default_log_level(),
146 log_format: default_log_format(),
147 audit_log_path: None,
148 log_request_headers: false,
149 metrics_enabled: false,
150 metrics_bind: default_metrics_bind(),
151 }
152 }
153}
154
155pub fn validate_server_config(server: &ServerConfig) -> crate::error::Result<()> {
161 use crate::error::McpxError;
162
163 if server.listen_port == 0 {
164 return Err(McpxError::Config("listen_port must be nonzero".into()));
165 }
166
167 match (&server.tls_cert_path, &server.tls_key_path) {
168 (Some(_), None) | (None, Some(_)) => {
169 return Err(McpxError::Config(
170 "tls_cert_path and tls_key_path must both be set or both omitted".into(),
171 ));
172 }
173 _ => {}
174 }
175
176 if let Some(0) = server.max_concurrent_requests {
177 return Err(McpxError::Config(
178 "max_concurrent_requests must be nonzero when set".into(),
179 ));
180 }
181
182 if let Some(0) = server.extra_route_rate_limit {
183 return Err(McpxError::Config(
184 "server.extra_route_rate_limit must be greater than zero".into(),
185 ));
186 }
187
188 if server.admin_enabled {
189 let auth_enabled = server.auth.as_ref().is_some_and(|a| a.enabled);
190 if !auth_enabled {
191 return Err(McpxError::Config(
192 "admin_enabled=true requires auth to be configured and enabled".into(),
193 ));
194 }
195 if server.admin_role.trim().is_empty() {
196 return Err(McpxError::Config("admin_role must not be empty".into()));
197 }
198 }
199
200 for (field, value) in [
201 ("server.shutdown_timeout", server.shutdown_timeout.as_str()),
202 ("server.request_timeout", server.request_timeout.as_str()),
203 (
204 "server.session_idle_timeout",
205 server.session_idle_timeout.as_str(),
206 ),
207 ("server.sse_keep_alive", server.sse_keep_alive.as_str()),
208 (
209 "server.tls_handshake_timeout",
210 server.tls_handshake_timeout.as_str(),
211 ),
212 ] {
213 if humantime::parse_duration(value).is_err() {
214 return Err(McpxError::Config(format!(
215 "invalid duration for {field}: {value:?}"
216 )));
217 }
218 }
219
220 if humantime::parse_duration(&server.tls_handshake_timeout)
224 .is_ok_and(|d| d == std::time::Duration::ZERO)
225 {
226 return Err(McpxError::Config(
227 "server.tls_handshake_timeout must be greater than zero".into(),
228 ));
229 }
230
231 if server.max_concurrent_tls_handshakes == 0 {
235 return Err(McpxError::Config(
236 "server.max_concurrent_tls_handshakes must be greater than zero".into(),
237 ));
238 }
239
240 Ok(())
241}
242
243pub fn validate_observability_config(obs: &ObservabilityConfig) -> crate::error::Result<()> {
249 use tracing_subscriber::EnvFilter;
250
251 use crate::error::McpxError;
252
253 if EnvFilter::try_new(&obs.log_level).is_err() {
254 return Err(McpxError::Config(format!(
255 "invalid log_level: {:?} (expected a valid tracing filter directive, e.g. \"info\", \"debug,hyper=warn\")",
256 obs.log_level
257 )));
258 }
259 let valid_formats = ["json", "pretty", "text"];
260 if !valid_formats.contains(&obs.log_format.as_str()) {
261 return Err(McpxError::Config(format!(
262 "invalid log_format: {:?} (expected one of: {valid_formats:?})",
263 obs.log_format
264 )));
265 }
266
267 Ok(())
268}
269
/// Default for `ServerConfig::listen_addr`.
fn default_listen_addr() -> String {
    // IPv4 loopback: the listener is only reachable from the local host unless an
    // operator explicitly configures a wider address.
    "127.0.0.1".into()
}
/// Default for `ServerConfig::listen_port`.
const fn default_listen_port() -> u16 {
    8443
}
/// Default for `ServerConfig::shutdown_timeout` (humantime duration string).
fn default_shutdown_timeout() -> String {
    "30s".into()
}
/// Default for `ServerConfig::request_timeout` (humantime duration string).
fn default_request_timeout() -> String {
    "120s".into()
}
/// Default for `ObservabilityConfig::log_level` (a `tracing` filter directive).
fn default_log_level() -> String {
    "info,rmcp=warn".into()
}
/// Default for `ObservabilityConfig::log_format`.
fn default_log_format() -> String {
    "pretty".into()
}
/// Default for `ObservabilityConfig::metrics_bind` (loopback only).
fn default_metrics_bind() -> String {
    "127.0.0.1:9090".into()
}
/// Default for `ServerConfig::session_idle_timeout` (humantime duration string).
fn default_session_idle_timeout() -> String {
    "20m".into()
}
/// Default for `ServerConfig::tls_handshake_timeout` (humantime duration string).
fn default_tls_handshake_timeout() -> String {
    "10s".into()
}
/// Default for `ServerConfig::max_concurrent_tls_handshakes`.
const fn default_max_concurrent_tls_handshakes() -> usize {
    256
}
/// Default for `ServerConfig::admin_role`.
fn default_admin_role() -> String {
    "admin".into()
}
/// Default for `ServerConfig::compression_min_size`.
const fn default_compression_min_size() -> u16 {
    1024
}
/// Default for `ServerConfig::sse_keep_alive` (humantime duration string).
fn default_sse_keep_alive() -> String {
    "15s".into()
}
311
#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        clippy::panic,
        clippy::indexing_slicing,
        clippy::unwrap_in_result,
        clippy::print_stdout,
        clippy::print_stderr,
        reason = "test-only relaxations; production code uses ? and tracing"
    )]
    use super::*;

    // ── Defaults ────────────────────────────────────────────────────────

    #[test]
    fn server_config_defaults() {
        let cfg = ServerConfig::default();
        assert_eq!(cfg.listen_addr, "127.0.0.1");
        assert_eq!(cfg.listen_port, 8443);
        assert!(cfg.tls_cert_path.is_none());
        assert!(cfg.tls_key_path.is_none());
        assert_eq!(cfg.tls_handshake_timeout, "10s");
        assert_eq!(cfg.max_concurrent_tls_handshakes, 256);
        assert_eq!(cfg.shutdown_timeout, "30s");
        assert_eq!(cfg.request_timeout, "120s");
        assert!(cfg.allowed_origins.is_empty());
        assert!(!cfg.stdio_enabled);
        assert!(cfg.tool_rate_limit.is_none());
        assert!(cfg.extra_route_rate_limit.is_none());
        assert_eq!(cfg.session_idle_timeout, "20m");
        assert_eq!(cfg.sse_keep_alive, "15s");
        assert!(cfg.public_url.is_none());
        assert!(!cfg.compression_enabled);
        assert_eq!(cfg.compression_min_size, 1024);
        assert!(cfg.max_concurrent_requests.is_none());
        assert!(!cfg.admin_enabled);
        assert_eq!(cfg.admin_role, "admin");
        assert!(cfg.auth.is_none());
    }

    #[test]
    fn observability_config_defaults() {
        let cfg = ObservabilityConfig::default();
        assert_eq!(cfg.log_level, "info,rmcp=warn");
        assert_eq!(cfg.log_format, "pretty");
        assert!(cfg.audit_log_path.is_none());
        assert!(!cfg.log_request_headers);
        assert!(!cfg.metrics_enabled);
        assert_eq!(cfg.metrics_bind, "127.0.0.1:9090");
    }

    // ── Server validation ───────────────────────────────────────────────

    #[test]
    fn valid_server_config_passes() {
        let cfg = ServerConfig::default();
        assert!(validate_server_config(&cfg).is_ok());
    }

    #[test]
    fn zero_port_rejected() {
        let cfg = ServerConfig {
            listen_port: 0,
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("listen_port"));
    }

    #[test]
    fn zero_extra_route_rate_limit_rejected() {
        let cfg = ServerConfig {
            extra_route_rate_limit: Some(0),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("extra_route_rate_limit"));
    }

    #[test]
    fn zero_max_concurrent_requests_rejected() {
        let cfg = ServerConfig {
            max_concurrent_requests: Some(0),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("max_concurrent_requests"));
    }

    #[test]
    fn nonzero_limits_pass() {
        let cfg = ServerConfig {
            max_concurrent_requests: Some(1),
            extra_route_rate_limit: Some(1),
            ..ServerConfig::default()
        };
        assert!(validate_server_config(&cfg).is_ok());
    }

    #[test]
    fn admin_enabled_without_auth_rejected() {
        let cfg = ServerConfig {
            admin_enabled: true,
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("admin_enabled"));
    }

    #[test]
    fn tls_cert_without_key_rejected() {
        let cfg = ServerConfig {
            tls_cert_path: Some("/tmp/cert.pem".into()),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("tls_cert_path"));
    }

    #[test]
    fn tls_key_without_cert_rejected() {
        let cfg = ServerConfig {
            tls_key_path: Some("/tmp/key.pem".into()),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("tls_cert_path"));
    }

    #[test]
    fn tls_both_set_passes() {
        let cfg = ServerConfig {
            tls_cert_path: Some("/tmp/cert.pem".into()),
            tls_key_path: Some("/tmp/key.pem".into()),
            ..ServerConfig::default()
        };
        assert!(validate_server_config(&cfg).is_ok());
    }

    #[test]
    fn invalid_tls_handshake_timeout_rejected() {
        let cfg = ServerConfig {
            tls_handshake_timeout: "not-a-duration".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("tls_handshake_timeout"));
    }

    #[test]
    fn zero_tls_handshake_timeout_rejected() {
        let cfg = ServerConfig {
            tls_handshake_timeout: "0s".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("tls_handshake_timeout"));
    }

    #[test]
    fn zero_max_concurrent_tls_handshakes_rejected() {
        let cfg = ServerConfig {
            max_concurrent_tls_handshakes: 0,
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("max_concurrent_tls_handshakes"));
    }

    #[test]
    fn invalid_shutdown_timeout_rejected() {
        let cfg = ServerConfig {
            shutdown_timeout: "not-a-duration".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("shutdown_timeout"));
    }

    #[test]
    fn invalid_request_timeout_rejected() {
        let cfg = ServerConfig {
            request_timeout: "xyz".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("request_timeout"));
    }

    #[test]
    fn invalid_session_idle_timeout_rejected() {
        let cfg = ServerConfig {
            session_idle_timeout: "forever".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("session_idle_timeout"));
    }

    #[test]
    fn invalid_sse_keep_alive_rejected() {
        let cfg = ServerConfig {
            sse_keep_alive: "often".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("sse_keep_alive"));
    }

    // ── Observability validation ────────────────────────────────────────

    #[test]
    fn valid_observability_config_passes() {
        let cfg = ObservabilityConfig::default();
        assert!(validate_observability_config(&cfg).is_ok());
    }

    #[test]
    fn invalid_log_level_rejected() {
        let cfg = ObservabilityConfig {
            log_level: "[invalid".into(),
            ..ObservabilityConfig::default()
        };
        let err = validate_observability_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("log_level"));
    }

    #[test]
    fn invalid_log_format_rejected() {
        let cfg = ObservabilityConfig {
            log_format: "yaml".into(),
            ..ObservabilityConfig::default()
        };
        let err = validate_observability_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("log_format"));
    }

    #[test]
    fn all_valid_log_levels_accepted() {
        for level in &[
            "trace",
            "debug",
            "info",
            "warn",
            "error",
            "info,rmcp=warn",
            "debug,hyper=error",
        ] {
            let cfg = ObservabilityConfig {
                log_level: (*level).into(),
                ..ObservabilityConfig::default()
            };
            assert!(
                validate_observability_config(&cfg).is_ok(),
                "level {level} should be valid"
            );
        }
    }

    #[test]
    fn all_log_formats_accepted() {
        for fmt in &["json", "pretty", "text"] {
            let cfg = ObservabilityConfig {
                log_format: (*fmt).into(),
                ..ObservabilityConfig::default()
            };
            assert!(
                validate_observability_config(&cfg).is_ok(),
                "format {fmt} should be valid"
            );
        }
    }

    // ── Deserialization ─────────────────────────────────────────────────

    #[test]
    fn server_config_deserialize_defaults() {
        let cfg: ServerConfig = toml::from_str("").unwrap();
        assert_eq!(cfg.listen_port, 8443);
        assert_eq!(cfg.listen_addr, "127.0.0.1");
        assert_eq!(cfg.tls_handshake_timeout, "10s");
        assert_eq!(cfg.max_concurrent_tls_handshakes, 256);
    }

    #[test]
    fn observability_config_deserialize_defaults() {
        let cfg: ObservabilityConfig = toml::from_str("").unwrap();
        assert_eq!(cfg.log_level, "info,rmcp=warn");
        assert_eq!(cfg.log_format, "pretty");
        assert!(!cfg.log_request_headers);
        assert!(!cfg.metrics_enabled);
    }

    /// Guards against drift between the serde field defaults and the `Default`
    /// impl. The structs do not implement `PartialEq`, so their `Debug` renderings
    /// are compared instead (this covers every field).
    #[test]
    fn server_config_deserialized_empty_matches_default() {
        let from_toml: ServerConfig = toml::from_str("").unwrap();
        assert_eq!(
            format!("{from_toml:?}"),
            format!("{:?}", ServerConfig::default())
        );
    }

    #[test]
    fn observability_config_deserialized_empty_matches_default() {
        let from_toml: ObservabilityConfig = toml::from_str("").unwrap();
        assert_eq!(
            format!("{from_toml:?}"),
            format!("{:?}", ObservabilityConfig::default())
        );
    }

    #[test]
    fn server_config_partial_override_keeps_other_defaults() {
        let cfg: ServerConfig = toml::from_str("listen_port = 9000").unwrap();
        assert_eq!(cfg.listen_port, 9000);
        assert_eq!(cfg.listen_addr, "127.0.0.1");
        assert_eq!(cfg.request_timeout, "120s");
        assert!(validate_server_config(&cfg).is_ok());
    }
}
547}