1use std::path::PathBuf;
2
3use serde::Deserialize;
4
/// HTTP/TLS server settings, deserialized with serde.
///
/// Keys missing from the input fall back to:
/// - the `default_*` helper named in the field's `#[serde(default = ...)]` attribute;
/// - `Default::default()` for fields marked with a bare `#[serde(default)]`;
/// - `None` for `Option` fields.
///
/// [`ServerConfig::default`] produces the same values.
///
/// Deserialization only checks types. Semantic checks (nonzero port and limits, TLS
/// path pairing, parseable durations, admin/auth coupling) live in
/// [`validate_server_config`].
#[derive(Debug, Deserialize)]
#[non_exhaustive]
pub struct ServerConfig {
    /// Address to listen on. Defaults to `"127.0.0.1"`.
    #[serde(default = "default_listen_addr")]
    pub listen_addr: String,
    /// Port to listen on. Defaults to `8443`; `validate_server_config` rejects `0`.
    #[serde(default = "default_listen_port")]
    pub listen_port: u16,
    /// Optional TLS certificate path. Must be set together with `tls_key_path`
    /// (both or neither), as enforced by `validate_server_config`.
    pub tls_cert_path: Option<PathBuf>,
    /// Optional TLS private-key path. Must be set together with `tls_cert_path`.
    pub tls_key_path: Option<PathBuf>,
    /// TLS handshake timeout as a `humantime` duration string. Defaults to `"10s"`.
    /// It must parse with `humantime::parse_duration` and be greater than zero.
    #[serde(default = "default_tls_handshake_timeout")]
    pub tls_handshake_timeout: String,
    /// Upper bound on concurrent TLS handshakes. Defaults to `256` and must be nonzero.
    #[serde(default = "default_max_concurrent_tls_handshakes")]
    pub max_concurrent_tls_handshakes: usize,
    /// Shutdown timeout as a `humantime` duration string. Defaults to `"30s"`.
    #[serde(default = "default_shutdown_timeout")]
    pub shutdown_timeout: String,
    /// Request timeout as a `humantime` duration string. Defaults to `"120s"`.
    #[serde(default = "default_request_timeout")]
    pub request_timeout: String,
    /// Allowed origins. Defaults to an empty list.
    /// NOTE(review): presumably consulted for `Origin` / CORS checks; confirm at the use site.
    #[serde(default)]
    pub allowed_origins: Vec<String>,
    /// Whether the stdio transport is enabled. Defaults to `false`.
    #[serde(default)]
    pub stdio_enabled: bool,
    /// Optional rate limit, presumably for tool invocations.
    /// NOTE(review): the unit and window are not visible here; confirm at the use site.
    pub tool_rate_limit: Option<u32>,
    /// Optional burst size paired with `tool_rate_limit`. `validate_server_config`
    /// rejects `0` and rejects this field when `tool_rate_limit` is unset.
    pub tool_rate_limit_burst: Option<u32>,
    /// Optional rate limit for extra routes. `validate_server_config` rejects `0`.
    pub extra_route_rate_limit: Option<u32>,
    /// Optional burst size paired with `extra_route_rate_limit`. `validate_server_config`
    /// rejects `0` and rejects this field when `extra_route_rate_limit` is unset.
    pub extra_route_rate_limit_burst: Option<u32>,
    /// Session idle timeout as a `humantime` duration string. Defaults to `"20m"`.
    #[serde(default = "default_session_idle_timeout")]
    pub session_idle_timeout: String,
    /// SSE keep-alive interval as a `humantime` duration string. Defaults to `"15s"`.
    #[serde(default = "default_sse_keep_alive")]
    pub sse_keep_alive: String,
    /// Optional public URL of the server. It is not validated by `validate_server_config`.
    pub public_url: Option<String>,
    /// Whether response compression is enabled. Defaults to `false`.
    #[serde(default)]
    pub compression_enabled: bool,
    /// Minimum size before compression applies. Defaults to `1024`.
    /// NOTE(review): the unit is presumably bytes; confirm.
    #[serde(default = "default_compression_min_size")]
    pub compression_min_size: u16,
    /// Optional cap on concurrent requests. Must be nonzero when set.
    pub max_concurrent_requests: Option<usize>,
    /// Whether admin endpoints are enabled. Defaults to `false`. When `true`,
    /// `validate_server_config` requires `auth` to be present with `enabled` set.
    #[serde(default)]
    pub admin_enabled: bool,
    /// Role required for admin access. Defaults to `"admin"`. When `admin_enabled`
    /// is `true`, it must not be blank after trimming.
    #[serde(default = "default_admin_role")]
    pub admin_role: String,
    /// Optional authentication settings. When present, `validate_server_config`
    /// rejects a zero `rate_limit.burst` or `rate_limit.pre_auth_burst`.
    pub auth: Option<crate::auth::AuthConfig>,
}
97
98impl Default for ServerConfig {
99 fn default() -> Self {
100 Self {
101 listen_addr: default_listen_addr(),
102 listen_port: default_listen_port(),
103 tls_cert_path: None,
104 tls_key_path: None,
105 tls_handshake_timeout: default_tls_handshake_timeout(),
106 max_concurrent_tls_handshakes: default_max_concurrent_tls_handshakes(),
107 shutdown_timeout: default_shutdown_timeout(),
108 request_timeout: default_request_timeout(),
109 allowed_origins: Vec::new(),
110 stdio_enabled: false,
111 tool_rate_limit: None,
112 tool_rate_limit_burst: None,
113 extra_route_rate_limit: None,
114 extra_route_rate_limit_burst: None,
115 session_idle_timeout: default_session_idle_timeout(),
116 sse_keep_alive: default_sse_keep_alive(),
117 public_url: None,
118 compression_enabled: false,
119 compression_min_size: default_compression_min_size(),
120 max_concurrent_requests: None,
121 admin_enabled: false,
122 admin_role: default_admin_role(),
123 auth: None,
124 }
125 }
126}
127
/// Logging, audit and metrics settings, deserialized with serde.
///
/// Missing keys fall back to the `default_*` helpers named in the field attributes,
/// to `false` for bare `#[serde(default)]` flags, and to `None` for `audit_log_path`.
/// [`ObservabilityConfig::default`] produces the same values.
///
/// Use [`validate_observability_config`] to check `log_level` and `log_format`.
#[derive(Debug, Deserialize)]
#[non_exhaustive]
pub struct ObservabilityConfig {
    /// `tracing_subscriber::EnvFilter` directive string. Defaults to `"info,rmcp=warn"`.
    /// It must be accepted by `EnvFilter::try_new`.
    #[serde(default = "default_log_level")]
    pub log_level: String,
    /// Log output format: one of `"json"`, `"pretty"` or `"text"` (case-sensitive).
    /// Defaults to `"pretty"`.
    #[serde(default = "default_log_format")]
    pub log_format: String,
    /// Optional audit log file path. It is not validated here.
    pub audit_log_path: Option<PathBuf>,
    /// Whether request headers are logged. Defaults to `false`.
    /// NOTE(review): confirm at the use site whether sensitive headers are redacted.
    #[serde(default)]
    pub log_request_headers: bool,
    /// Whether metrics are enabled. Defaults to `false`.
    #[serde(default)]
    pub metrics_enabled: bool,
    /// Metrics bind address. Defaults to `"127.0.0.1:9090"`.
    /// It is not validated by `validate_observability_config`.
    #[serde(default = "default_metrics_bind")]
    pub metrics_bind: String,
}
151
152impl Default for ObservabilityConfig {
153 fn default() -> Self {
154 Self {
155 log_level: default_log_level(),
156 log_format: default_log_format(),
157 audit_log_path: None,
158 log_request_headers: false,
159 metrics_enabled: false,
160 metrics_bind: default_metrics_bind(),
161 }
162 }
163}
164
165pub fn validate_server_config(server: &ServerConfig) -> crate::error::Result<()> {
171 use crate::error::McpxError;
172
173 if server.listen_port == 0 {
174 return Err(McpxError::Config("listen_port must be nonzero".into()));
175 }
176
177 match (&server.tls_cert_path, &server.tls_key_path) {
178 (Some(_), None) | (None, Some(_)) => {
179 return Err(McpxError::Config(
180 "tls_cert_path and tls_key_path must both be set or both omitted".into(),
181 ));
182 }
183 _ => {}
184 }
185
186 if let Some(0) = server.max_concurrent_requests {
187 return Err(McpxError::Config(
188 "max_concurrent_requests must be nonzero when set".into(),
189 ));
190 }
191
192 if let Some(0) = server.extra_route_rate_limit {
193 return Err(McpxError::Config(
194 "server.extra_route_rate_limit must be greater than zero".into(),
195 ));
196 }
197
198 if let Some(0) = server.tool_rate_limit_burst {
199 return Err(McpxError::Config(
200 "server.tool_rate_limit_burst must be greater than zero".into(),
201 ));
202 }
203 if let Some(0) = server.extra_route_rate_limit_burst {
204 return Err(McpxError::Config(
205 "server.extra_route_rate_limit_burst must be greater than zero".into(),
206 ));
207 }
208 if server.tool_rate_limit_burst.is_some() && server.tool_rate_limit.is_none() {
209 return Err(McpxError::Config(
210 "server.tool_rate_limit_burst requires server.tool_rate_limit".into(),
211 ));
212 }
213 if server.extra_route_rate_limit_burst.is_some() && server.extra_route_rate_limit.is_none() {
214 return Err(McpxError::Config(
215 "server.extra_route_rate_limit_burst requires server.extra_route_rate_limit".into(),
216 ));
217 }
218 if let Some(rl) = server.auth.as_ref().and_then(|a| a.rate_limit.as_ref()) {
219 if rl.burst == Some(0) {
220 return Err(McpxError::Config(
221 "auth.rate_limit.burst must be greater than zero".into(),
222 ));
223 }
224 if rl.pre_auth_burst == Some(0) {
225 return Err(McpxError::Config(
226 "auth.rate_limit.pre_auth_burst must be greater than zero".into(),
227 ));
228 }
229 }
230
231 if server.admin_enabled {
232 let auth_enabled = server.auth.as_ref().is_some_and(|a| a.enabled);
233 if !auth_enabled {
234 return Err(McpxError::Config(
235 "admin_enabled=true requires auth to be configured and enabled".into(),
236 ));
237 }
238 if server.admin_role.trim().is_empty() {
239 return Err(McpxError::Config("admin_role must not be empty".into()));
240 }
241 }
242
243 for (field, value) in [
244 ("server.shutdown_timeout", server.shutdown_timeout.as_str()),
245 ("server.request_timeout", server.request_timeout.as_str()),
246 (
247 "server.session_idle_timeout",
248 server.session_idle_timeout.as_str(),
249 ),
250 ("server.sse_keep_alive", server.sse_keep_alive.as_str()),
251 (
252 "server.tls_handshake_timeout",
253 server.tls_handshake_timeout.as_str(),
254 ),
255 ] {
256 if humantime::parse_duration(value).is_err() {
257 return Err(McpxError::Config(format!(
258 "invalid duration for {field}: {value:?}"
259 )));
260 }
261 }
262
263 if humantime::parse_duration(&server.tls_handshake_timeout)
267 .is_ok_and(|d| d == std::time::Duration::ZERO)
268 {
269 return Err(McpxError::Config(
270 "server.tls_handshake_timeout must be greater than zero".into(),
271 ));
272 }
273
274 if server.max_concurrent_tls_handshakes == 0 {
278 return Err(McpxError::Config(
279 "server.max_concurrent_tls_handshakes must be greater than zero".into(),
280 ));
281 }
282
283 Ok(())
284}
285
286pub fn validate_observability_config(obs: &ObservabilityConfig) -> crate::error::Result<()> {
292 use tracing_subscriber::EnvFilter;
293
294 use crate::error::McpxError;
295
296 if EnvFilter::try_new(&obs.log_level).is_err() {
297 return Err(McpxError::Config(format!(
298 "invalid log_level: {:?} (expected a valid tracing filter directive, e.g. \"info\", \"debug,hyper=warn\")",
299 obs.log_level
300 )));
301 }
302 let valid_formats = ["json", "pretty", "text"];
303 if !valid_formats.contains(&obs.log_format.as_str()) {
304 return Err(McpxError::Config(format!(
305 "invalid log_format: {:?} (expected one of: {valid_formats:?})",
306 obs.log_format
307 )));
308 }
309
310 Ok(())
311}
312
/// Serde default for `ServerConfig::listen_addr`: the IPv4 loopback address.
fn default_listen_addr() -> String {
    String::from("127.0.0.1")
}
/// Serde default for `ServerConfig::listen_port`.
fn default_listen_port() -> u16 {
    8443
}
/// Serde default for `ServerConfig::shutdown_timeout` (humantime syntax).
fn default_shutdown_timeout() -> String {
    String::from("30s")
}
/// Serde default for `ServerConfig::request_timeout` (humantime syntax).
fn default_request_timeout() -> String {
    String::from("120s")
}
/// Serde default for `ObservabilityConfig::log_level`: an `EnvFilter` directive
/// with `info` as the default and `warn` for the `rmcp` target.
fn default_log_level() -> String {
    String::from("info,rmcp=warn")
}
/// Serde default for `ObservabilityConfig::log_format`.
fn default_log_format() -> String {
    String::from("pretty")
}
/// Serde default for `ObservabilityConfig::metrics_bind` (loopback, port 9090).
fn default_metrics_bind() -> String {
    String::from("127.0.0.1:9090")
}
/// Serde default for `ServerConfig::session_idle_timeout` (humantime syntax).
fn default_session_idle_timeout() -> String {
    String::from("20m")
}
/// Serde default for `ServerConfig::tls_handshake_timeout` (humantime syntax).
fn default_tls_handshake_timeout() -> String {
    String::from("10s")
}
/// Serde default for `ServerConfig::max_concurrent_tls_handshakes`.
const fn default_max_concurrent_tls_handshakes() -> usize {
    256
}
/// Serde default for `ServerConfig::admin_role`.
fn default_admin_role() -> String {
    String::from("admin")
}
/// Serde default for `ServerConfig::compression_min_size`.
fn default_compression_min_size() -> u16 {
    1024
}
/// Serde default for `ServerConfig::sse_keep_alive` (humantime syntax).
fn default_sse_keep_alive() -> String {
    String::from("15s")
}
354
#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        clippy::panic,
        clippy::indexing_slicing,
        clippy::unwrap_in_result,
        clippy::print_stdout,
        clippy::print_stderr,
        reason = "test-only relaxations; production code uses ? and tracing"
    )]
    use super::*;

    /// Asserts that `cfg` fails server validation with an error mentioning `needle`.
    fn assert_server_rejects(cfg: &ServerConfig, needle: &str) {
        let msg = validate_server_config(cfg).unwrap_err().to_string();
        assert!(msg.contains(needle), "expected {needle:?} in {msg:?}");
    }

    /// Asserts that `cfg` fails observability validation with an error mentioning `needle`.
    fn assert_observability_rejects(cfg: &ObservabilityConfig, needle: &str) {
        let msg = validate_observability_config(cfg).unwrap_err().to_string();
        assert!(msg.contains(needle), "expected {needle:?} in {msg:?}");
    }

    // ---- defaults -------------------------------------------------------

    #[test]
    fn server_config_defaults() {
        let cfg = ServerConfig::default();
        assert_eq!(cfg.listen_addr, "127.0.0.1");
        assert_eq!(cfg.listen_port, 8443);
        assert!(cfg.tls_cert_path.is_none() && cfg.tls_key_path.is_none());
        assert_eq!(
            [
                cfg.shutdown_timeout.as_str(),
                cfg.request_timeout.as_str(),
                cfg.session_idle_timeout.as_str(),
                cfg.sse_keep_alive.as_str(),
            ],
            ["30s", "120s", "20m", "15s"]
        );
        assert!(cfg.allowed_origins.is_empty());
        assert!(!cfg.stdio_enabled);
        assert_eq!(cfg.tool_rate_limit, None);
        assert_eq!(cfg.public_url, None);
    }

    #[test]
    fn observability_config_defaults() {
        let cfg = ObservabilityConfig::default();
        assert_eq!(
            (cfg.log_level.as_str(), cfg.log_format.as_str()),
            ("info,rmcp=warn", "pretty")
        );
        assert!(cfg.audit_log_path.is_none());
        assert!(!cfg.log_request_headers && !cfg.metrics_enabled);
        assert_eq!(cfg.metrics_bind, "127.0.0.1:9090");
    }

    // ---- server validation ----------------------------------------------

    #[test]
    fn valid_server_config_passes() {
        assert!(validate_server_config(&ServerConfig::default()).is_ok());
    }

    #[test]
    fn zero_port_rejected() {
        let cfg = ServerConfig {
            listen_port: 0,
            ..Default::default()
        };
        assert_server_rejects(&cfg, "listen_port");
    }

    #[test]
    fn zero_extra_route_rate_limit_rejected() {
        let cfg = ServerConfig {
            extra_route_rate_limit: Some(0),
            ..Default::default()
        };
        assert_server_rejects(&cfg, "extra_route_rate_limit");
    }

    #[test]
    fn zero_burst_knobs_rejected() {
        let tool = ServerConfig {
            tool_rate_limit: Some(10),
            tool_rate_limit_burst: Some(0),
            ..Default::default()
        };
        assert_server_rejects(&tool, "tool_rate_limit_burst");

        let extra = ServerConfig {
            extra_route_rate_limit: Some(10),
            extra_route_rate_limit_burst: Some(0),
            ..Default::default()
        };
        assert_server_rejects(&extra, "extra_route_rate_limit_burst");
    }

    #[test]
    fn orphan_burst_knobs_rejected() {
        let tool = ServerConfig {
            tool_rate_limit_burst: Some(5),
            ..Default::default()
        };
        assert_server_rejects(&tool, "requires server.tool_rate_limit");

        let extra = ServerConfig {
            extra_route_rate_limit_burst: Some(5),
            ..Default::default()
        };
        assert_server_rejects(&extra, "requires server.extra_route_rate_limit");
    }

    #[test]
    fn zero_auth_bursts_rejected() {
        let zero_burst = crate::auth::RateLimitConfig::new(10).with_burst(0);
        let cfg = ServerConfig {
            auth: Some(crate::auth::AuthConfig::with_keys(vec![]).with_rate_limit(zero_burst)),
            ..Default::default()
        };
        assert_server_rejects(&cfg, "rate_limit.burst");

        let zero_pre_auth = crate::auth::RateLimitConfig::new(10).with_pre_auth_burst(0);
        let cfg = ServerConfig {
            auth: Some(crate::auth::AuthConfig::with_keys(vec![]).with_rate_limit(zero_pre_auth)),
            ..Default::default()
        };
        assert_server_rejects(&cfg, "pre_auth_burst");
    }

    #[test]
    fn tls_cert_without_key_rejected() {
        let cfg = ServerConfig {
            tls_cert_path: Some("/tmp/cert.pem".into()),
            ..Default::default()
        };
        assert_server_rejects(&cfg, "tls_cert_path");
    }

    #[test]
    fn tls_key_without_cert_rejected() {
        let cfg = ServerConfig {
            tls_key_path: Some("/tmp/key.pem".into()),
            ..Default::default()
        };
        assert_server_rejects(&cfg, "tls_cert_path");
    }

    #[test]
    fn tls_both_set_passes() {
        let cfg = ServerConfig {
            tls_cert_path: Some("/tmp/cert.pem".into()),
            tls_key_path: Some("/tmp/key.pem".into()),
            ..Default::default()
        };
        assert!(validate_server_config(&cfg).is_ok());
    }

    #[test]
    fn invalid_tls_handshake_timeout_rejected() {
        let cfg = ServerConfig {
            tls_handshake_timeout: "not-a-duration".into(),
            ..Default::default()
        };
        assert_server_rejects(&cfg, "tls_handshake_timeout");
    }

    #[test]
    fn zero_tls_handshake_timeout_rejected() {
        let cfg = ServerConfig {
            tls_handshake_timeout: "0s".into(),
            ..Default::default()
        };
        assert_server_rejects(&cfg, "tls_handshake_timeout");
    }

    #[test]
    fn zero_max_concurrent_tls_handshakes_rejected() {
        let cfg = ServerConfig {
            max_concurrent_tls_handshakes: 0,
            ..Default::default()
        };
        assert_server_rejects(&cfg, "max_concurrent_tls_handshakes");
    }

    #[test]
    fn invalid_shutdown_timeout_rejected() {
        let cfg = ServerConfig {
            shutdown_timeout: "not-a-duration".into(),
            ..Default::default()
        };
        assert_server_rejects(&cfg, "shutdown_timeout");
    }

    #[test]
    fn invalid_request_timeout_rejected() {
        let cfg = ServerConfig {
            request_timeout: "xyz".into(),
            ..Default::default()
        };
        assert_server_rejects(&cfg, "request_timeout");
    }

    // ---- observability validation ---------------------------------------

    #[test]
    fn valid_observability_config_passes() {
        assert!(validate_observability_config(&ObservabilityConfig::default()).is_ok());
    }

    #[test]
    fn invalid_log_level_rejected() {
        let cfg = ObservabilityConfig {
            log_level: "[invalid".into(),
            ..Default::default()
        };
        assert_observability_rejects(&cfg, "log_level");
    }

    #[test]
    fn invalid_log_format_rejected() {
        let cfg = ObservabilityConfig {
            log_format: "yaml".into(),
            ..Default::default()
        };
        assert_observability_rejects(&cfg, "log_format");
    }

    #[test]
    fn all_valid_log_levels_accepted() {
        let levels = [
            "trace",
            "debug",
            "info",
            "warn",
            "error",
            "info,rmcp=warn",
            "debug,hyper=error",
        ];
        for level in levels {
            let cfg = ObservabilityConfig {
                log_level: level.into(),
                ..Default::default()
            };
            assert!(
                validate_observability_config(&cfg).is_ok(),
                "level {level} should be valid"
            );
        }
    }

    #[test]
    fn all_log_formats_accepted() {
        for fmt in ["json", "pretty", "text"] {
            let cfg = ObservabilityConfig {
                log_format: fmt.into(),
                ..Default::default()
            };
            assert!(
                validate_observability_config(&cfg).is_ok(),
                "format {fmt} should be valid"
            );
        }
    }

    // ---- serde defaults -------------------------------------------------

    #[test]
    fn server_config_deserialize_defaults() {
        let cfg: ServerConfig = toml::from_str("").expect("empty TOML should deserialize");
        assert_eq!(cfg.listen_port, 8443);
        assert_eq!(cfg.listen_addr, "127.0.0.1");
        assert_eq!(cfg.tls_handshake_timeout, "10s");
        assert_eq!(cfg.max_concurrent_tls_handshakes, 256);
    }

    #[test]
    fn observability_config_deserialize_defaults() {
        let cfg: ObservabilityConfig =
            toml::from_str("").expect("empty TOML should deserialize");
        assert_eq!(cfg.log_level, "info,rmcp=warn");
        assert_eq!(cfg.log_format, "pretty");
        assert!(!cfg.log_request_headers);
        assert!(!cfg.metrics_enabled);
    }
}
650}