1use std::path::PathBuf;
2
3use serde::Deserialize;
4
/// Server / transport settings, deserialized with serde.
///
/// Every field either has a serde default or is an `Option`, so an empty
/// table deserializes successfully. Cross-field constraints are not enforced
/// at deserialization time; they are checked by [`validate_server_config`].
#[derive(Debug, Deserialize)]
#[non_exhaustive]
pub struct ServerConfig {
    /// Listen address. Defaults to `"127.0.0.1"`.
    #[serde(default = "default_listen_addr")]
    pub listen_addr: String,
    /// Listen port. Defaults to `8443`; validation rejects `0`.
    #[serde(default = "default_listen_port")]
    pub listen_port: u16,
    /// TLS certificate path. Validation requires it to be set together with
    /// `tls_key_path` (both or neither).
    pub tls_cert_path: Option<PathBuf>,
    /// TLS private-key path. Must be paired with `tls_cert_path`.
    pub tls_key_path: Option<PathBuf>,
    /// TLS handshake timeout as a `humantime` duration string (e.g. `"10s"`).
    /// Defaults to `"10s"`; validation rejects unparsable and zero values.
    #[serde(default = "default_tls_handshake_timeout")]
    pub tls_handshake_timeout: String,
    /// Upper bound on concurrent TLS handshakes. Defaults to `256`;
    /// validation rejects `0`.
    #[serde(default = "default_max_concurrent_tls_handshakes")]
    pub max_concurrent_tls_handshakes: usize,
    /// Shutdown timeout as a `humantime` duration string. Defaults to `"30s"`.
    #[serde(default = "default_shutdown_timeout")]
    pub shutdown_timeout: String,
    /// Request timeout as a `humantime` duration string. Defaults to `"120s"`.
    #[serde(default = "default_request_timeout")]
    pub request_timeout: String,
    /// Allowed origins. Empty by default.
    /// NOTE(review): whether an empty list means "allow none" or "no origin
    /// check" is decided by the transport, not visible here — confirm.
    #[serde(default)]
    pub allowed_origins: Vec<String>,
    /// Enables the stdio transport. Defaults to `false`.
    #[serde(default)]
    pub stdio_enabled: bool,
    /// Rate limit for tool calls. `None` leaves it unset.
    /// NOTE(review): the time window/unit is defined by the consumer — confirm.
    pub tool_rate_limit: Option<u32>,
    /// Burst size for `tool_rate_limit`. Validation requires it to be
    /// non-zero and only set when `tool_rate_limit` is set.
    pub tool_rate_limit_burst: Option<u32>,
    /// Rate limit for extra (non-tool) routes. Validation rejects `Some(0)`.
    pub extra_route_rate_limit: Option<u32>,
    /// Burst size for `extra_route_rate_limit`. Validation requires it to be
    /// non-zero and only set when `extra_route_rate_limit` is set.
    pub extra_route_rate_limit_burst: Option<u32>,
    /// Paths exempt from the extra-route rate limit. Each entry must start
    /// with `/`, and a non-empty list requires `extra_route_rate_limit`.
    #[serde(default)]
    pub extra_route_rate_limit_exempt_paths: Vec<String>,
    /// Trusted proxies, each either a CIDR network or a bare IP address.
    #[serde(default)]
    pub trusted_proxies: Vec<String>,
    /// Forwarded-header mode. Validation requires a non-empty
    /// `trusted_proxies` list when this is set.
    pub forwarded_header: Option<crate::transport::ForwardedHeaderMode>,
    /// Session idle timeout as a `humantime` duration string.
    /// Defaults to `"20m"`.
    #[serde(default = "default_session_idle_timeout")]
    pub session_idle_timeout: String,
    /// SSE keep-alive interval as a `humantime` duration string.
    /// Defaults to `"15s"`.
    #[serde(default = "default_sse_keep_alive")]
    pub sse_keep_alive: String,
    /// Public URL of the server. Optional; not validated here.
    pub public_url: Option<String>,
    /// Enables response compression. Defaults to `false`.
    #[serde(default)]
    pub compression_enabled: bool,
    /// Minimum size for compression. Defaults to `1024`
    /// (presumably bytes — confirm against the compression layer).
    #[serde(default = "default_compression_min_size")]
    pub compression_min_size: u16,
    /// Cap on concurrent requests. Validation rejects `Some(0)`.
    pub max_concurrent_requests: Option<usize>,
    /// Enables the admin surface. Validation requires `auth` to be present
    /// and enabled when this is `true`.
    #[serde(default)]
    pub admin_enabled: bool,
    /// Role required for admin access. Defaults to `"admin"`; must not be
    /// blank when `admin_enabled` is `true`.
    #[serde(default = "default_admin_role")]
    pub admin_role: String,
    /// Authentication settings. `None` when not configured.
    pub auth: Option<crate::auth::AuthConfig>,
}
115
116impl Default for ServerConfig {
117 fn default() -> Self {
118 Self {
119 listen_addr: default_listen_addr(),
120 listen_port: default_listen_port(),
121 tls_cert_path: None,
122 tls_key_path: None,
123 tls_handshake_timeout: default_tls_handshake_timeout(),
124 max_concurrent_tls_handshakes: default_max_concurrent_tls_handshakes(),
125 shutdown_timeout: default_shutdown_timeout(),
126 request_timeout: default_request_timeout(),
127 allowed_origins: Vec::new(),
128 stdio_enabled: false,
129 tool_rate_limit: None,
130 tool_rate_limit_burst: None,
131 extra_route_rate_limit: None,
132 extra_route_rate_limit_burst: None,
133 extra_route_rate_limit_exempt_paths: Vec::new(),
134 trusted_proxies: Vec::new(),
135 forwarded_header: None,
136 session_idle_timeout: default_session_idle_timeout(),
137 sse_keep_alive: default_sse_keep_alive(),
138 public_url: None,
139 compression_enabled: false,
140 compression_min_size: default_compression_min_size(),
141 max_concurrent_requests: None,
142 admin_enabled: false,
143 admin_role: default_admin_role(),
144 auth: None,
145 }
146 }
147}
148
/// Logging, audit and metrics settings, deserialized with serde.
///
/// Every field has a default, so an empty table deserializes successfully.
/// `log_level` and `log_format` are checked by
/// [`validate_observability_config`].
#[derive(Debug, Deserialize)]
#[non_exhaustive]
pub struct ObservabilityConfig {
    /// Tracing filter directive (e.g. `"info"`, `"debug,hyper=warn"`).
    /// Defaults to `"info,rmcp=warn"`; must be accepted by
    /// `tracing_subscriber::EnvFilter`.
    #[serde(default = "default_log_level")]
    pub log_level: String,
    /// Log output format: one of `"json"`, `"pretty"` or `"text"`.
    /// Defaults to `"pretty"`.
    #[serde(default = "default_log_format")]
    pub log_format: String,
    /// Audit log file path. Optional; not validated here.
    pub audit_log_path: Option<PathBuf>,
    /// Whether request headers are logged. Defaults to `false`.
    #[serde(default)]
    pub log_request_headers: bool,
    /// Enables the metrics endpoint. Defaults to `false`.
    #[serde(default)]
    pub metrics_enabled: bool,
    /// Metrics bind address. Defaults to `"127.0.0.1:9090"`; not validated here.
    #[serde(default = "default_metrics_bind")]
    pub metrics_bind: String,
}
172
173impl Default for ObservabilityConfig {
174 fn default() -> Self {
175 Self {
176 log_level: default_log_level(),
177 log_format: default_log_format(),
178 audit_log_path: None,
179 log_request_headers: false,
180 metrics_enabled: false,
181 metrics_bind: default_metrics_bind(),
182 }
183 }
184}
185
186pub fn validate_server_config(server: &ServerConfig) -> crate::error::Result<()> {
192 use crate::error::McpxError;
193
194 if server.listen_port == 0 {
195 return Err(McpxError::Config("listen_port must be nonzero".into()));
196 }
197
198 match (&server.tls_cert_path, &server.tls_key_path) {
199 (Some(_), None) | (None, Some(_)) => {
200 return Err(McpxError::Config(
201 "tls_cert_path and tls_key_path must both be set or both omitted".into(),
202 ));
203 }
204 _ => {}
205 }
206
207 if server.max_concurrent_requests == Some(0) {
208 return Err(McpxError::Config(
209 "max_concurrent_requests must be nonzero when set".into(),
210 ));
211 }
212
213 if server.extra_route_rate_limit == Some(0) {
214 return Err(McpxError::Config(
215 "server.extra_route_rate_limit must be greater than zero".into(),
216 ));
217 }
218
219 validate_rate_limit_knobs(server)?;
220 validate_trusted_forwarder_config(server)?;
221
222 if server.admin_enabled {
223 let auth_enabled = server.auth.as_ref().is_some_and(|a| a.enabled);
224 if !auth_enabled {
225 return Err(McpxError::Config(
226 "admin_enabled=true requires auth to be configured and enabled".into(),
227 ));
228 }
229 if server.admin_role.trim().is_empty() {
230 return Err(McpxError::Config("admin_role must not be empty".into()));
231 }
232 }
233
234 for (field, value) in [
235 ("server.shutdown_timeout", server.shutdown_timeout.as_str()),
236 ("server.request_timeout", server.request_timeout.as_str()),
237 (
238 "server.session_idle_timeout",
239 server.session_idle_timeout.as_str(),
240 ),
241 ("server.sse_keep_alive", server.sse_keep_alive.as_str()),
242 (
243 "server.tls_handshake_timeout",
244 server.tls_handshake_timeout.as_str(),
245 ),
246 ] {
247 if humantime::parse_duration(value).is_err() {
248 return Err(McpxError::Config(format!(
249 "invalid duration for {field}: {value:?}"
250 )));
251 }
252 }
253
254 if humantime::parse_duration(&server.tls_handshake_timeout)
258 .is_ok_and(|d| d == std::time::Duration::ZERO)
259 {
260 return Err(McpxError::Config(
261 "server.tls_handshake_timeout must be greater than zero".into(),
262 ));
263 }
264
265 if server.max_concurrent_tls_handshakes == 0 {
269 return Err(McpxError::Config(
270 "server.max_concurrent_tls_handshakes must be greater than zero".into(),
271 ));
272 }
273
274 Ok(())
275}
276
277fn validate_rate_limit_knobs(server: &ServerConfig) -> crate::error::Result<()> {
281 use crate::error::McpxError;
282
283 if server.tool_rate_limit_burst == Some(0) {
284 return Err(McpxError::Config(
285 "server.tool_rate_limit_burst must be greater than zero".into(),
286 ));
287 }
288 if server.extra_route_rate_limit_burst == Some(0) {
289 return Err(McpxError::Config(
290 "server.extra_route_rate_limit_burst must be greater than zero".into(),
291 ));
292 }
293 if server.tool_rate_limit_burst.is_some() && server.tool_rate_limit.is_none() {
294 return Err(McpxError::Config(
295 "server.tool_rate_limit_burst requires server.tool_rate_limit".into(),
296 ));
297 }
298 if server.extra_route_rate_limit_burst.is_some() && server.extra_route_rate_limit.is_none() {
299 return Err(McpxError::Config(
300 "server.extra_route_rate_limit_burst requires server.extra_route_rate_limit".into(),
301 ));
302 }
303 if !server.extra_route_rate_limit_exempt_paths.is_empty()
304 && server.extra_route_rate_limit.is_none()
305 {
306 return Err(McpxError::Config(
307 "server.extra_route_rate_limit_exempt_paths requires server.extra_route_rate_limit"
308 .into(),
309 ));
310 }
311 for path in &server.extra_route_rate_limit_exempt_paths {
312 if path.is_empty() || !path.starts_with('/') {
313 return Err(McpxError::Config(format!(
314 "server.extra_route_rate_limit_exempt_paths entries must be non-empty and start with '/': {path:?}"
315 )));
316 }
317 }
318 if let Some(rl) = server.auth.as_ref().and_then(|a| a.rate_limit.as_ref()) {
319 if rl.burst == Some(0) {
320 return Err(McpxError::Config(
321 "auth.rate_limit.burst must be greater than zero".into(),
322 ));
323 }
324 if rl.pre_auth_burst == Some(0) {
325 return Err(McpxError::Config(
326 "auth.rate_limit.pre_auth_burst must be greater than zero".into(),
327 ));
328 }
329 }
330 Ok(())
331}
332
333fn validate_trusted_forwarder_config(server: &ServerConfig) -> crate::error::Result<()> {
336 use crate::error::McpxError;
337
338 for entry in &server.trusted_proxies {
339 let is_net = entry.parse::<ipnet::IpNet>().is_ok();
340 let is_ip = entry.parse::<std::net::IpAddr>().is_ok();
341 if !(is_net || is_ip) {
342 return Err(McpxError::Config(format!(
343 "server.trusted_proxies entry {entry:?} is neither a CIDR nor an IP address"
344 )));
345 }
346 }
347 if server.forwarded_header.is_some() && server.trusted_proxies.is_empty() {
348 return Err(McpxError::Config(
349 "server.forwarded_header requires server.trusted_proxies to be nonempty".into(),
350 ));
351 }
352 Ok(())
353}
354
355pub fn validate_observability_config(obs: &ObservabilityConfig) -> crate::error::Result<()> {
361 use tracing_subscriber::EnvFilter;
362
363 use crate::error::McpxError;
364
365 if EnvFilter::try_new(&obs.log_level).is_err() {
366 return Err(McpxError::Config(format!(
367 "invalid log_level: {:?} (expected a valid tracing filter directive, e.g. \"info\", \"debug,hyper=warn\")",
368 obs.log_level
369 )));
370 }
371 let valid_formats = ["json", "pretty", "text"];
372 if !valid_formats.contains(&obs.log_format.as_str()) {
373 return Err(McpxError::Config(format!(
374 "invalid log_format: {:?} (expected one of: {valid_formats:?})",
375 obs.log_format
376 )));
377 }
378
379 Ok(())
380}
381
// Default-value helpers referenced by the `#[serde(default = "...")]`
// attributes and by the `Default` impls. The numeric helpers are `const fn`
// so they are consistent with `default_max_concurrent_tls_handshakes` and can
// be used in constant contexts.

/// Default for `ServerConfig::listen_addr`: the IPv4 loopback address.
fn default_listen_addr() -> String {
    "127.0.0.1".into()
}

/// Default for `ServerConfig::listen_port`.
const fn default_listen_port() -> u16 {
    8443
}

/// Default for `ServerConfig::shutdown_timeout` (humantime string).
fn default_shutdown_timeout() -> String {
    "30s".into()
}

/// Default for `ServerConfig::request_timeout` (humantime string).
fn default_request_timeout() -> String {
    "120s".into()
}

/// Default for `ObservabilityConfig::log_level` (tracing filter directive).
fn default_log_level() -> String {
    "info,rmcp=warn".into()
}

/// Default for `ObservabilityConfig::log_format`.
fn default_log_format() -> String {
    "pretty".into()
}

/// Default for `ObservabilityConfig::metrics_bind`.
fn default_metrics_bind() -> String {
    "127.0.0.1:9090".into()
}

/// Default for `ServerConfig::session_idle_timeout` (humantime string).
fn default_session_idle_timeout() -> String {
    "20m".into()
}

/// Default for `ServerConfig::tls_handshake_timeout` (humantime string).
fn default_tls_handshake_timeout() -> String {
    "10s".into()
}

/// Default for `ServerConfig::max_concurrent_tls_handshakes`.
const fn default_max_concurrent_tls_handshakes() -> usize {
    256
}

/// Default for `ServerConfig::admin_role`.
fn default_admin_role() -> String {
    "admin".into()
}

/// Default for `ServerConfig::compression_min_size`.
const fn default_compression_min_size() -> u16 {
    1024
}

/// Default for `ServerConfig::sse_keep_alive` (humantime string).
fn default_sse_keep_alive() -> String {
    "15s".into()
}
423
#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        clippy::panic,
        clippy::indexing_slicing,
        clippy::unwrap_in_result,
        clippy::print_stdout,
        clippy::print_stderr,
        reason = "test-only relaxations; production code uses ? and tracing"
    )]
    use super::*;

    // --- Defaults -----------------------------------------------------------

    #[test]
    fn server_config_defaults() {
        let cfg = ServerConfig::default();
        assert_eq!(cfg.listen_addr, "127.0.0.1");
        assert_eq!(cfg.listen_port, 8443);
        assert!(cfg.tls_cert_path.is_none());
        assert!(cfg.tls_key_path.is_none());
        assert_eq!(cfg.tls_handshake_timeout, "10s");
        assert_eq!(cfg.max_concurrent_tls_handshakes, 256);
        assert_eq!(cfg.shutdown_timeout, "30s");
        assert_eq!(cfg.request_timeout, "120s");
        assert!(cfg.allowed_origins.is_empty());
        assert!(!cfg.stdio_enabled);
        assert!(cfg.tool_rate_limit.is_none());
        assert_eq!(cfg.session_idle_timeout, "20m");
        assert_eq!(cfg.sse_keep_alive, "15s");
        assert!(cfg.public_url.is_none());
        assert!(!cfg.compression_enabled);
        assert_eq!(cfg.compression_min_size, 1024);
        assert!(!cfg.admin_enabled);
        assert_eq!(cfg.admin_role, "admin");
        assert!(cfg.auth.is_none());
    }

    #[test]
    fn observability_config_defaults() {
        let cfg = ObservabilityConfig::default();
        assert_eq!(cfg.log_level, "info,rmcp=warn");
        assert_eq!(cfg.log_format, "pretty");
        assert!(cfg.audit_log_path.is_none());
        assert!(!cfg.log_request_headers);
        assert!(!cfg.metrics_enabled);
        assert_eq!(cfg.metrics_bind, "127.0.0.1:9090");
    }

    // --- Server validation --------------------------------------------------

    #[test]
    fn valid_server_config_passes() {
        let cfg = ServerConfig::default();
        assert!(validate_server_config(&cfg).is_ok());
    }

    #[test]
    fn zero_port_rejected() {
        let cfg = ServerConfig {
            listen_port: 0,
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("listen_port"));
    }

    #[test]
    fn zero_max_concurrent_requests_rejected() {
        let cfg = ServerConfig {
            max_concurrent_requests: Some(0),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("max_concurrent_requests"));
    }

    #[test]
    fn zero_extra_route_rate_limit_rejected() {
        let cfg = ServerConfig {
            extra_route_rate_limit: Some(0),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("extra_route_rate_limit"));
    }

    #[test]
    fn zero_burst_knobs_rejected() {
        let cfg = ServerConfig {
            tool_rate_limit: Some(10),
            tool_rate_limit_burst: Some(0),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("tool_rate_limit_burst"));

        let cfg = ServerConfig {
            extra_route_rate_limit: Some(10),
            extra_route_rate_limit_burst: Some(0),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("extra_route_rate_limit_burst"));
    }

    #[test]
    fn orphan_burst_knobs_rejected() {
        let cfg = ServerConfig {
            tool_rate_limit_burst: Some(5),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("requires server.tool_rate_limit"));

        let cfg = ServerConfig {
            extra_route_rate_limit_burst: Some(5),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(
            err.to_string()
                .contains("requires server.extra_route_rate_limit")
        );
    }

    #[test]
    fn exempt_paths_toml_roundtrip_and_validation() {
        let cfg: ServerConfig = toml::from_str(
            r#"
            extra_route_rate_limit = 60
            extra_route_rate_limit_exempt_paths = ["/.well-known/oauth-authorization-server"]
            "#,
        )
        .unwrap();
        assert_eq!(
            cfg.extra_route_rate_limit_exempt_paths,
            vec!["/.well-known/oauth-authorization-server".to_owned()]
        );
        assert!(validate_server_config(&cfg).is_ok());
    }

    #[test]
    fn orphan_exempt_paths_rejected() {
        let cfg = ServerConfig {
            extra_route_rate_limit_exempt_paths: vec!["/ok".into()],
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(
            err.to_string()
                .contains("requires server.extra_route_rate_limit")
        );
    }

    #[test]
    fn malformed_exempt_paths_rejected() {
        for bad in ["", "no-slash"] {
            let cfg = ServerConfig {
                extra_route_rate_limit: Some(10),
                extra_route_rate_limit_exempt_paths: vec![bad.into()],
                ..ServerConfig::default()
            };
            let err = validate_server_config(&cfg).unwrap_err();
            assert!(
                err.to_string()
                    .contains("must be non-empty and start with '/'"),
                "entry {bad:?}: {err}"
            );
        }
    }

    #[test]
    fn bad_trusted_proxy_entry_rejected() {
        let cfg = ServerConfig {
            trusted_proxies: vec!["not-a-cidr".into()],
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("trusted_proxies"));
    }

    #[test]
    fn cidr_and_bare_ip_proxy_entries_accepted() {
        let cfg = ServerConfig {
            trusted_proxies: vec!["10.0.0.0/8".into(), "192.0.2.1".into()],
            ..ServerConfig::default()
        };
        assert!(validate_server_config(&cfg).is_ok());
    }

    #[test]
    fn forwarded_header_without_proxies_rejected() {
        let cfg = ServerConfig {
            forwarded_header: Some(crate::transport::ForwardedHeaderMode::Forwarded),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("requires server.trusted_proxies"));
    }

    #[test]
    fn zero_auth_bursts_rejected() {
        let auth = crate::auth::AuthConfig::with_keys(vec![])
            .with_rate_limit(crate::auth::RateLimitConfig::new(10).with_burst(0));
        let cfg = ServerConfig {
            auth: Some(auth),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("rate_limit.burst"));

        let auth = crate::auth::AuthConfig::with_keys(vec![])
            .with_rate_limit(crate::auth::RateLimitConfig::new(10).with_pre_auth_burst(0));
        let cfg = ServerConfig {
            auth: Some(auth),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("pre_auth_burst"));
    }

    #[test]
    fn admin_without_auth_rejected() {
        let cfg = ServerConfig {
            admin_enabled: true,
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("admin_enabled"));
    }

    #[test]
    fn tls_cert_without_key_rejected() {
        let cfg = ServerConfig {
            tls_cert_path: Some("/tmp/cert.pem".into()),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("tls_cert_path"));
    }

    #[test]
    fn tls_key_without_cert_rejected() {
        let cfg = ServerConfig {
            tls_key_path: Some("/tmp/key.pem".into()),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("tls_cert_path"));
    }

    #[test]
    fn tls_both_set_passes() {
        let cfg = ServerConfig {
            tls_cert_path: Some("/tmp/cert.pem".into()),
            tls_key_path: Some("/tmp/key.pem".into()),
            ..ServerConfig::default()
        };
        assert!(validate_server_config(&cfg).is_ok());
    }

    #[test]
    fn invalid_tls_handshake_timeout_rejected() {
        let cfg = ServerConfig {
            tls_handshake_timeout: "not-a-duration".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("tls_handshake_timeout"));
    }

    #[test]
    fn zero_tls_handshake_timeout_rejected() {
        let cfg = ServerConfig {
            tls_handshake_timeout: "0s".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("tls_handshake_timeout"));
    }

    #[test]
    fn zero_max_concurrent_tls_handshakes_rejected() {
        let cfg = ServerConfig {
            max_concurrent_tls_handshakes: 0,
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("max_concurrent_tls_handshakes"));
    }

    #[test]
    fn invalid_shutdown_timeout_rejected() {
        let cfg = ServerConfig {
            shutdown_timeout: "not-a-duration".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("shutdown_timeout"));
    }

    #[test]
    fn invalid_request_timeout_rejected() {
        let cfg = ServerConfig {
            request_timeout: "xyz".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("request_timeout"));
    }

    #[test]
    fn invalid_session_idle_timeout_rejected() {
        let cfg = ServerConfig {
            session_idle_timeout: "not-a-duration".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("session_idle_timeout"));
    }

    #[test]
    fn invalid_sse_keep_alive_rejected() {
        let cfg = ServerConfig {
            sse_keep_alive: "bogus".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("sse_keep_alive"));
    }

    // --- Observability validation -------------------------------------------

    #[test]
    fn valid_observability_config_passes() {
        let cfg = ObservabilityConfig::default();
        assert!(validate_observability_config(&cfg).is_ok());
    }

    #[test]
    fn invalid_log_level_rejected() {
        let cfg = ObservabilityConfig {
            log_level: "[invalid".into(),
            ..ObservabilityConfig::default()
        };
        let err = validate_observability_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("log_level"));
    }

    #[test]
    fn invalid_log_format_rejected() {
        let cfg = ObservabilityConfig {
            log_format: "yaml".into(),
            ..ObservabilityConfig::default()
        };
        let err = validate_observability_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("log_format"));
    }

    #[test]
    fn all_valid_log_levels_accepted() {
        for level in &[
            "trace",
            "debug",
            "info",
            "warn",
            "error",
            "info,rmcp=warn",
            "debug,hyper=error",
        ] {
            let cfg = ObservabilityConfig {
                log_level: (*level).into(),
                ..ObservabilityConfig::default()
            };
            assert!(
                validate_observability_config(&cfg).is_ok(),
                "level {level} should be valid"
            );
        }
    }

    #[test]
    fn all_log_formats_accepted() {
        for fmt in &["json", "pretty", "text"] {
            let cfg = ObservabilityConfig {
                log_format: (*fmt).into(),
                ..ObservabilityConfig::default()
            };
            assert!(
                validate_observability_config(&cfg).is_ok(),
                "format {fmt} should be valid"
            );
        }
    }

    // --- Deserialization ----------------------------------------------------

    #[test]
    fn server_config_deserialize_defaults() {
        let cfg: ServerConfig = toml::from_str("").unwrap();
        assert_eq!(cfg.listen_port, 8443);
        assert_eq!(cfg.listen_addr, "127.0.0.1");
        assert_eq!(cfg.tls_handshake_timeout, "10s");
        assert_eq!(cfg.max_concurrent_tls_handshakes, 256);
    }

    #[test]
    fn observability_config_deserialize_defaults() {
        let cfg: ObservabilityConfig = toml::from_str("").unwrap();
        assert_eq!(cfg.log_level, "info,rmcp=warn");
        assert_eq!(cfg.log_format, "pretty");
        assert!(!cfg.log_request_headers);
        assert!(!cfg.metrics_enabled);
    }
}