1use std::path::PathBuf;
2
3use serde::Deserialize;
4
/// Network, transport, and access-control settings for the server.
///
/// Every field has a default (see the `default_*` helpers and the `Default`
/// impl), so an empty TOML table deserializes successfully. Duration-valued
/// fields are kept as strings and checked with `humantime` by
/// `validate_server_config`.
#[derive(Debug, Deserialize)]
#[non_exhaustive]
pub struct ServerConfig {
    /// Listen address. Defaults to `"127.0.0.1"`.
    #[serde(default = "default_listen_addr")]
    pub listen_addr: String,
    /// Listen port; must be nonzero. Defaults to `8443`.
    #[serde(default = "default_listen_port")]
    pub listen_port: u16,
    /// TLS certificate path. Must be set together with `tls_key_path`
    /// (both set or both omitted).
    pub tls_cert_path: Option<PathBuf>,
    /// TLS private-key path. Must be set together with `tls_cert_path`.
    pub tls_key_path: Option<PathBuf>,
    /// `humantime` duration string (default `"10s"`); must parse and be
    /// greater than zero.
    #[serde(default = "default_tls_handshake_timeout")]
    pub tls_handshake_timeout: String,
    /// Maximum number of concurrent TLS handshakes (default `256`); must be
    /// nonzero.
    #[serde(default = "default_max_concurrent_tls_handshakes")]
    pub max_concurrent_tls_handshakes: usize,
    /// `humantime` duration string (default `"30s"`).
    #[serde(default = "default_shutdown_timeout")]
    pub shutdown_timeout: String,
    /// `humantime` duration string (default `"120s"`).
    #[serde(default = "default_request_timeout")]
    pub request_timeout: String,
    /// Allowed origins. Empty by default.
    #[serde(default)]
    pub allowed_origins: Vec<String>,
    /// Whether the stdio transport is enabled. `false` by default.
    #[serde(default)]
    pub stdio_enabled: bool,
    /// Optional tool rate limit.
    /// NOTE(review): the unit/window is not visible in this file — confirm.
    pub tool_rate_limit: Option<u32>,
    /// Burst size for `tool_rate_limit`; must be nonzero and requires
    /// `tool_rate_limit` to be set.
    pub tool_rate_limit_burst: Option<u32>,
    /// Optional rate limit for extra routes; must be nonzero when set.
    pub extra_route_rate_limit: Option<u32>,
    /// Burst size for `extra_route_rate_limit`; must be nonzero and requires
    /// `extra_route_rate_limit` to be set.
    pub extra_route_rate_limit_burst: Option<u32>,
    /// Trusted proxies. Each entry must parse as a CIDR block
    /// (`"10.0.0.0/8"`) or a bare IP address (`"192.0.2.1"`).
    #[serde(default)]
    pub trusted_proxies: Vec<String>,
    /// Which forwarding header to honour; requires a non-empty
    /// `trusted_proxies`.
    pub forwarded_header: Option<crate::transport::ForwardedHeaderMode>,
    /// `humantime` duration string (default `"20m"`).
    #[serde(default = "default_session_idle_timeout")]
    pub session_idle_timeout: String,
    /// `humantime` duration string (default `"15s"`).
    #[serde(default = "default_sse_keep_alive")]
    pub sse_keep_alive: String,
    /// Optional public URL. Not validated by `validate_server_config`.
    pub public_url: Option<String>,
    /// Whether response compression is enabled. `false` by default.
    #[serde(default)]
    pub compression_enabled: bool,
    /// Compression threshold (default `1024`).
    /// NOTE(review): presumably a size in bytes — confirm.
    #[serde(default = "default_compression_min_size")]
    pub compression_min_size: u16,
    /// Optional cap on concurrent requests; must be nonzero when set.
    pub max_concurrent_requests: Option<usize>,
    /// Whether admin endpoints are enabled. Requires `auth` to be present
    /// with `enabled = true`. `false` by default.
    #[serde(default)]
    pub admin_enabled: bool,
    /// Admin role name (default `"admin"`); must not be blank when
    /// `admin_enabled` is set.
    #[serde(default = "default_admin_role")]
    pub admin_role: String,
    /// Optional authentication settings; its `rate_limit` burst knobs are
    /// also checked by `validate_server_config`.
    pub auth: Option<crate::auth::AuthConfig>,
}
108
109impl Default for ServerConfig {
110 fn default() -> Self {
111 Self {
112 listen_addr: default_listen_addr(),
113 listen_port: default_listen_port(),
114 tls_cert_path: None,
115 tls_key_path: None,
116 tls_handshake_timeout: default_tls_handshake_timeout(),
117 max_concurrent_tls_handshakes: default_max_concurrent_tls_handshakes(),
118 shutdown_timeout: default_shutdown_timeout(),
119 request_timeout: default_request_timeout(),
120 allowed_origins: Vec::new(),
121 stdio_enabled: false,
122 tool_rate_limit: None,
123 tool_rate_limit_burst: None,
124 extra_route_rate_limit: None,
125 extra_route_rate_limit_burst: None,
126 trusted_proxies: Vec::new(),
127 forwarded_header: None,
128 session_idle_timeout: default_session_idle_timeout(),
129 sse_keep_alive: default_sse_keep_alive(),
130 public_url: None,
131 compression_enabled: false,
132 compression_min_size: default_compression_min_size(),
133 max_concurrent_requests: None,
134 admin_enabled: false,
135 admin_role: default_admin_role(),
136 auth: None,
137 }
138 }
139}
140
/// Logging, audit, and metrics settings.
///
/// Every field has a default, so an empty TOML table deserializes
/// successfully. `log_level` and `log_format` are checked by
/// `validate_observability_config`.
#[derive(Debug, Deserialize)]
#[non_exhaustive]
pub struct ObservabilityConfig {
    /// `tracing_subscriber::EnvFilter` directive string
    /// (default `"info,rmcp=warn"`).
    #[serde(default = "default_log_level")]
    pub log_level: String,
    /// Log output format: one of `"json"`, `"pretty"`, `"text"`
    /// (default `"pretty"`).
    #[serde(default = "default_log_format")]
    pub log_format: String,
    /// Optional audit-log file path. `None` by default.
    pub audit_log_path: Option<PathBuf>,
    /// Whether request headers are logged. `false` by default.
    #[serde(default)]
    pub log_request_headers: bool,
    /// Whether metrics are enabled. `false` by default.
    #[serde(default)]
    pub metrics_enabled: bool,
    /// Metrics bind address (default `"127.0.0.1:9090"`).
    /// Not checked by `validate_observability_config`.
    #[serde(default = "default_metrics_bind")]
    pub metrics_bind: String,
}
164
165impl Default for ObservabilityConfig {
166 fn default() -> Self {
167 Self {
168 log_level: default_log_level(),
169 log_format: default_log_format(),
170 audit_log_path: None,
171 log_request_headers: false,
172 metrics_enabled: false,
173 metrics_bind: default_metrics_bind(),
174 }
175 }
176}
177
178pub fn validate_server_config(server: &ServerConfig) -> crate::error::Result<()> {
184 use crate::error::McpxError;
185
186 if server.listen_port == 0 {
187 return Err(McpxError::Config("listen_port must be nonzero".into()));
188 }
189
190 match (&server.tls_cert_path, &server.tls_key_path) {
191 (Some(_), None) | (None, Some(_)) => {
192 return Err(McpxError::Config(
193 "tls_cert_path and tls_key_path must both be set or both omitted".into(),
194 ));
195 }
196 _ => {}
197 }
198
199 if let Some(0) = server.max_concurrent_requests {
200 return Err(McpxError::Config(
201 "max_concurrent_requests must be nonzero when set".into(),
202 ));
203 }
204
205 if let Some(0) = server.extra_route_rate_limit {
206 return Err(McpxError::Config(
207 "server.extra_route_rate_limit must be greater than zero".into(),
208 ));
209 }
210
211 validate_rate_limit_knobs(server)?;
212 validate_trusted_forwarder_config(server)?;
213
214 if server.admin_enabled {
215 let auth_enabled = server.auth.as_ref().is_some_and(|a| a.enabled);
216 if !auth_enabled {
217 return Err(McpxError::Config(
218 "admin_enabled=true requires auth to be configured and enabled".into(),
219 ));
220 }
221 if server.admin_role.trim().is_empty() {
222 return Err(McpxError::Config("admin_role must not be empty".into()));
223 }
224 }
225
226 for (field, value) in [
227 ("server.shutdown_timeout", server.shutdown_timeout.as_str()),
228 ("server.request_timeout", server.request_timeout.as_str()),
229 (
230 "server.session_idle_timeout",
231 server.session_idle_timeout.as_str(),
232 ),
233 ("server.sse_keep_alive", server.sse_keep_alive.as_str()),
234 (
235 "server.tls_handshake_timeout",
236 server.tls_handshake_timeout.as_str(),
237 ),
238 ] {
239 if humantime::parse_duration(value).is_err() {
240 return Err(McpxError::Config(format!(
241 "invalid duration for {field}: {value:?}"
242 )));
243 }
244 }
245
246 if humantime::parse_duration(&server.tls_handshake_timeout)
250 .is_ok_and(|d| d == std::time::Duration::ZERO)
251 {
252 return Err(McpxError::Config(
253 "server.tls_handshake_timeout must be greater than zero".into(),
254 ));
255 }
256
257 if server.max_concurrent_tls_handshakes == 0 {
261 return Err(McpxError::Config(
262 "server.max_concurrent_tls_handshakes must be greater than zero".into(),
263 ));
264 }
265
266 Ok(())
267}
268
269fn validate_rate_limit_knobs(server: &ServerConfig) -> crate::error::Result<()> {
273 use crate::error::McpxError;
274
275 if let Some(0) = server.tool_rate_limit_burst {
276 return Err(McpxError::Config(
277 "server.tool_rate_limit_burst must be greater than zero".into(),
278 ));
279 }
280 if let Some(0) = server.extra_route_rate_limit_burst {
281 return Err(McpxError::Config(
282 "server.extra_route_rate_limit_burst must be greater than zero".into(),
283 ));
284 }
285 if server.tool_rate_limit_burst.is_some() && server.tool_rate_limit.is_none() {
286 return Err(McpxError::Config(
287 "server.tool_rate_limit_burst requires server.tool_rate_limit".into(),
288 ));
289 }
290 if server.extra_route_rate_limit_burst.is_some() && server.extra_route_rate_limit.is_none() {
291 return Err(McpxError::Config(
292 "server.extra_route_rate_limit_burst requires server.extra_route_rate_limit".into(),
293 ));
294 }
295 if let Some(rl) = server.auth.as_ref().and_then(|a| a.rate_limit.as_ref()) {
296 if rl.burst == Some(0) {
297 return Err(McpxError::Config(
298 "auth.rate_limit.burst must be greater than zero".into(),
299 ));
300 }
301 if rl.pre_auth_burst == Some(0) {
302 return Err(McpxError::Config(
303 "auth.rate_limit.pre_auth_burst must be greater than zero".into(),
304 ));
305 }
306 }
307 Ok(())
308}
309
310fn validate_trusted_forwarder_config(server: &ServerConfig) -> crate::error::Result<()> {
313 use crate::error::McpxError;
314
315 for entry in &server.trusted_proxies {
316 let is_net = entry.parse::<ipnet::IpNet>().is_ok();
317 let is_ip = entry.parse::<std::net::IpAddr>().is_ok();
318 if !(is_net || is_ip) {
319 return Err(McpxError::Config(format!(
320 "server.trusted_proxies entry {entry:?} is neither a CIDR nor an IP address"
321 )));
322 }
323 }
324 if server.forwarded_header.is_some() && server.trusted_proxies.is_empty() {
325 return Err(McpxError::Config(
326 "server.forwarded_header requires server.trusted_proxies to be nonempty".into(),
327 ));
328 }
329 Ok(())
330}
331
332pub fn validate_observability_config(obs: &ObservabilityConfig) -> crate::error::Result<()> {
338 use tracing_subscriber::EnvFilter;
339
340 use crate::error::McpxError;
341
342 if EnvFilter::try_new(&obs.log_level).is_err() {
343 return Err(McpxError::Config(format!(
344 "invalid log_level: {:?} (expected a valid tracing filter directive, e.g. \"info\", \"debug,hyper=warn\")",
345 obs.log_level
346 )));
347 }
348 let valid_formats = ["json", "pretty", "text"];
349 if !valid_formats.contains(&obs.log_format.as_str()) {
350 return Err(McpxError::Config(format!(
351 "invalid log_format: {:?} (expected one of: {valid_formats:?})",
352 obs.log_format
353 )));
354 }
355
356 Ok(())
357}
358
/// Default `server.listen_addr`: the IPv4 loopback address.
fn default_listen_addr() -> String {
    "127.0.0.1".into()
}

/// Default `server.listen_port`.
const fn default_listen_port() -> u16 {
    8443
}

/// Default `server.shutdown_timeout` (`humantime` duration string).
fn default_shutdown_timeout() -> String {
    "30s".into()
}

/// Default `server.request_timeout` (`humantime` duration string).
fn default_request_timeout() -> String {
    "120s".into()
}

/// Default `observability.log_level`: an `EnvFilter` directive of `info`
/// overall with the `rmcp` target lowered to `warn`.
fn default_log_level() -> String {
    "info,rmcp=warn".into()
}

/// Default `observability.log_format`; one of the formats accepted by
/// `validate_observability_config`.
fn default_log_format() -> String {
    "pretty".into()
}

/// Default `observability.metrics_bind` address.
fn default_metrics_bind() -> String {
    "127.0.0.1:9090".into()
}

/// Default `server.session_idle_timeout` (`humantime` duration string).
fn default_session_idle_timeout() -> String {
    "20m".into()
}

/// Default `server.tls_handshake_timeout` (`humantime` duration string).
fn default_tls_handshake_timeout() -> String {
    "10s".into()
}

/// Default `server.max_concurrent_tls_handshakes`.
const fn default_max_concurrent_tls_handshakes() -> usize {
    256
}

/// Default `server.admin_role`.
fn default_admin_role() -> String {
    "admin".into()
}

/// Default `server.compression_min_size`.
/// NOTE(review): presumably a size in bytes — confirm at the use site.
const fn default_compression_min_size() -> u16 {
    1024
}

/// Default `server.sse_keep_alive` (`humantime` duration string).
fn default_sse_keep_alive() -> String {
    "15s".into()
}
400
#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        clippy::panic,
        clippy::indexing_slicing,
        clippy::unwrap_in_result,
        clippy::print_stdout,
        clippy::print_stderr,
        reason = "test-only relaxations; production code uses ? and tracing"
    )]
    use super::*;

    // ---- defaults -------------------------------------------------------

    #[test]
    fn server_config_defaults() {
        let cfg = ServerConfig::default();
        assert_eq!(cfg.listen_addr, "127.0.0.1");
        assert_eq!(cfg.listen_port, 8443);
        assert!(cfg.tls_cert_path.is_none());
        assert!(cfg.tls_key_path.is_none());
        assert_eq!(cfg.tls_handshake_timeout, "10s");
        assert_eq!(cfg.max_concurrent_tls_handshakes, 256);
        assert_eq!(cfg.shutdown_timeout, "30s");
        assert_eq!(cfg.request_timeout, "120s");
        assert!(cfg.allowed_origins.is_empty());
        assert!(!cfg.stdio_enabled);
        assert!(cfg.tool_rate_limit.is_none());
        assert!(cfg.trusted_proxies.is_empty());
        assert!(cfg.forwarded_header.is_none());
        assert_eq!(cfg.session_idle_timeout, "20m");
        assert_eq!(cfg.sse_keep_alive, "15s");
        assert!(cfg.public_url.is_none());
        assert!(!cfg.compression_enabled);
        assert_eq!(cfg.compression_min_size, 1024);
        assert!(cfg.max_concurrent_requests.is_none());
        assert!(!cfg.admin_enabled);
        assert_eq!(cfg.admin_role, "admin");
        assert!(cfg.auth.is_none());
    }

    #[test]
    fn observability_config_defaults() {
        let cfg = ObservabilityConfig::default();
        assert_eq!(cfg.log_level, "info,rmcp=warn");
        assert_eq!(cfg.log_format, "pretty");
        assert!(cfg.audit_log_path.is_none());
        assert!(!cfg.log_request_headers);
        assert!(!cfg.metrics_enabled);
        assert_eq!(cfg.metrics_bind, "127.0.0.1:9090");
    }

    // ---- server validation ----------------------------------------------

    #[test]
    fn valid_server_config_passes() {
        let cfg = ServerConfig::default();
        assert!(validate_server_config(&cfg).is_ok());
    }

    #[test]
    fn zero_port_rejected() {
        let cfg = ServerConfig {
            listen_port: 0,
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("listen_port"));
    }

    #[test]
    fn zero_max_concurrent_requests_rejected() {
        let cfg = ServerConfig {
            max_concurrent_requests: Some(0),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("max_concurrent_requests"));
    }

    #[test]
    fn zero_extra_route_rate_limit_rejected() {
        let cfg = ServerConfig {
            extra_route_rate_limit: Some(0),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("extra_route_rate_limit"));
    }

    #[test]
    fn zero_burst_knobs_rejected() {
        let cfg = ServerConfig {
            tool_rate_limit: Some(10),
            tool_rate_limit_burst: Some(0),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("tool_rate_limit_burst"));

        let cfg = ServerConfig {
            extra_route_rate_limit: Some(10),
            extra_route_rate_limit_burst: Some(0),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("extra_route_rate_limit_burst"));
    }

    #[test]
    fn orphan_burst_knobs_rejected() {
        let cfg = ServerConfig {
            tool_rate_limit_burst: Some(5),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("requires server.tool_rate_limit"));

        let cfg = ServerConfig {
            extra_route_rate_limit_burst: Some(5),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(
            err.to_string()
                .contains("requires server.extra_route_rate_limit")
        );
    }

    #[test]
    fn bursts_with_base_rates_pass() {
        let cfg = ServerConfig {
            tool_rate_limit: Some(10),
            tool_rate_limit_burst: Some(5),
            extra_route_rate_limit: Some(10),
            extra_route_rate_limit_burst: Some(5),
            ..ServerConfig::default()
        };
        assert!(validate_server_config(&cfg).is_ok());
    }

    #[test]
    fn bad_trusted_proxy_entry_rejected() {
        let cfg = ServerConfig {
            trusted_proxies: vec!["not-a-cidr".into()],
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("trusted_proxies"));
    }

    #[test]
    fn cidr_and_bare_ip_proxy_entries_accepted() {
        let cfg = ServerConfig {
            trusted_proxies: vec!["10.0.0.0/8".into(), "192.0.2.1".into()],
            ..ServerConfig::default()
        };
        assert!(validate_server_config(&cfg).is_ok());
    }

    #[test]
    fn ipv6_proxy_entries_accepted() {
        let cfg = ServerConfig {
            trusted_proxies: vec!["fd00::/8".into(), "::1".into()],
            ..ServerConfig::default()
        };
        assert!(validate_server_config(&cfg).is_ok());
    }

    #[test]
    fn forwarded_header_without_proxies_rejected() {
        let cfg = ServerConfig {
            forwarded_header: Some(crate::transport::ForwardedHeaderMode::Forwarded),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("requires server.trusted_proxies"));
    }

    #[test]
    fn forwarded_header_with_proxies_passes() {
        let cfg = ServerConfig {
            forwarded_header: Some(crate::transport::ForwardedHeaderMode::Forwarded),
            trusted_proxies: vec!["10.0.0.0/8".into()],
            ..ServerConfig::default()
        };
        assert!(validate_server_config(&cfg).is_ok());
    }

    #[test]
    fn zero_auth_bursts_rejected() {
        let auth = crate::auth::AuthConfig::with_keys(vec![])
            .with_rate_limit(crate::auth::RateLimitConfig::new(10).with_burst(0));
        let cfg = ServerConfig {
            auth: Some(auth),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("rate_limit.burst"));

        let auth = crate::auth::AuthConfig::with_keys(vec![])
            .with_rate_limit(crate::auth::RateLimitConfig::new(10).with_pre_auth_burst(0));
        let cfg = ServerConfig {
            auth: Some(auth),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("pre_auth_burst"));
    }

    #[test]
    fn admin_enabled_without_auth_rejected() {
        let cfg = ServerConfig {
            admin_enabled: true,
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("admin_enabled"));
    }

    #[test]
    fn tls_cert_without_key_rejected() {
        let cfg = ServerConfig {
            tls_cert_path: Some("/tmp/cert.pem".into()),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("tls_cert_path"));
    }

    #[test]
    fn tls_key_without_cert_rejected() {
        let cfg = ServerConfig {
            tls_key_path: Some("/tmp/key.pem".into()),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("tls_cert_path"));
    }

    #[test]
    fn tls_both_set_passes() {
        let cfg = ServerConfig {
            tls_cert_path: Some("/tmp/cert.pem".into()),
            tls_key_path: Some("/tmp/key.pem".into()),
            ..ServerConfig::default()
        };
        assert!(validate_server_config(&cfg).is_ok());
    }

    #[test]
    fn invalid_tls_handshake_timeout_rejected() {
        let cfg = ServerConfig {
            tls_handshake_timeout: "not-a-duration".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("tls_handshake_timeout"));
    }

    #[test]
    fn zero_tls_handshake_timeout_rejected() {
        let cfg = ServerConfig {
            tls_handshake_timeout: "0s".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("tls_handshake_timeout"));
    }

    #[test]
    fn zero_max_concurrent_tls_handshakes_rejected() {
        let cfg = ServerConfig {
            max_concurrent_tls_handshakes: 0,
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("max_concurrent_tls_handshakes"));
    }

    #[test]
    fn invalid_shutdown_timeout_rejected() {
        let cfg = ServerConfig {
            shutdown_timeout: "not-a-duration".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("shutdown_timeout"));
    }

    #[test]
    fn invalid_request_timeout_rejected() {
        let cfg = ServerConfig {
            request_timeout: "xyz".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("request_timeout"));
    }

    #[test]
    fn invalid_session_idle_timeout_rejected() {
        let cfg = ServerConfig {
            session_idle_timeout: "forever".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("session_idle_timeout"));
    }

    #[test]
    fn invalid_sse_keep_alive_rejected() {
        let cfg = ServerConfig {
            sse_keep_alive: "soon".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("sse_keep_alive"));
    }

    // ---- observability validation ---------------------------------------

    #[test]
    fn valid_observability_config_passes() {
        let cfg = ObservabilityConfig::default();
        assert!(validate_observability_config(&cfg).is_ok());
    }

    #[test]
    fn invalid_log_level_rejected() {
        let cfg = ObservabilityConfig {
            log_level: "[invalid".into(),
            ..ObservabilityConfig::default()
        };
        let err = validate_observability_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("log_level"));
    }

    #[test]
    fn invalid_log_format_rejected() {
        let cfg = ObservabilityConfig {
            log_format: "yaml".into(),
            ..ObservabilityConfig::default()
        };
        let err = validate_observability_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("log_format"));
    }

    #[test]
    fn all_valid_log_levels_accepted() {
        for level in &[
            "trace",
            "debug",
            "info",
            "warn",
            "error",
            "info,rmcp=warn",
            "debug,hyper=error",
        ] {
            let cfg = ObservabilityConfig {
                log_level: (*level).into(),
                ..ObservabilityConfig::default()
            };
            assert!(
                validate_observability_config(&cfg).is_ok(),
                "level {level} should be valid"
            );
        }
    }

    #[test]
    fn all_log_formats_accepted() {
        for fmt in &["json", "pretty", "text"] {
            let cfg = ObservabilityConfig {
                log_format: (*fmt).into(),
                ..ObservabilityConfig::default()
            };
            assert!(
                validate_observability_config(&cfg).is_ok(),
                "format {fmt} should be valid"
            );
        }
    }

    // ---- deserialization ------------------------------------------------

    #[test]
    fn server_config_deserialize_defaults() {
        let cfg: ServerConfig = toml::from_str("").unwrap();
        assert_eq!(cfg.listen_port, 8443);
        assert_eq!(cfg.listen_addr, "127.0.0.1");
        assert_eq!(cfg.tls_handshake_timeout, "10s");
        assert_eq!(cfg.max_concurrent_tls_handshakes, 256);
    }

    #[test]
    fn observability_config_deserialize_defaults() {
        let cfg: ObservabilityConfig = toml::from_str("").unwrap();
        assert_eq!(cfg.log_level, "info,rmcp=warn");
        assert_eq!(cfg.log_format, "pretty");
        assert!(!cfg.log_request_headers);
        assert!(!cfg.metrics_enabled);
        assert_eq!(cfg.metrics_bind, "127.0.0.1:9090");
    }
}