1use std::path::PathBuf;
2
3use serde::Deserialize;
4
/// Server settings, deserializable with serde.
///
/// Every field either has a serde default or is an `Option`, so an empty
/// document deserializes successfully. The `Default` impl yields the same
/// values. Call [`validate_server_config`] after loading to enforce the rules
/// noted on individual fields.
///
/// Duration-valued fields are stored as strings (e.g. `"30s"`, `"20m"`). They
/// are checked with `humantime::parse_duration` during validation.
///
/// The struct is `#[non_exhaustive]`. Outside this crate, obtain a value from
/// `Default::default()` or deserialization and then assign the public fields.
#[derive(Debug, Deserialize)]
#[non_exhaustive]
pub struct ServerConfig {
    /// Listen address. Default: `"127.0.0.1"`.
    #[serde(default = "default_listen_addr")]
    pub listen_addr: String,
    /// Listen port. Default: `8443`. Validation rejects `0`.
    #[serde(default = "default_listen_port")]
    pub listen_port: u16,
    /// TLS certificate path. Must be set together with `tls_key_path`
    /// (both or neither).
    pub tls_cert_path: Option<PathBuf>,
    /// TLS private key path. Must be set together with `tls_cert_path`
    /// (both or neither).
    pub tls_key_path: Option<PathBuf>,
    /// TLS handshake timeout as a humantime duration string. Default: `"10s"`.
    /// Validation requires it to parse and be greater than zero.
    #[serde(default = "default_tls_handshake_timeout")]
    pub tls_handshake_timeout: String,
    /// Upper bound on concurrent TLS handshakes. Default: `256`.
    /// Validation rejects `0`.
    #[serde(default = "default_max_concurrent_tls_handshakes")]
    pub max_concurrent_tls_handshakes: usize,
    /// Shutdown timeout as a humantime duration string. Default: `"30s"`.
    #[serde(default = "default_shutdown_timeout")]
    pub shutdown_timeout: String,
    /// Request timeout as a humantime duration string. Default: `"120s"`.
    #[serde(default = "default_request_timeout")]
    pub request_timeout: String,
    /// Allowed origins. Default: empty.
    /// NOTE(review): the meaning of an empty list is decided by the consumer
    /// (not visible here) — confirm.
    #[serde(default)]
    pub allowed_origins: Vec<String>,
    /// Enables stdio mode. Default: `false`.
    #[serde(default)]
    pub stdio_enabled: bool,
    /// Optional tool rate limit. Default: `None`. Not validated here.
    /// NOTE(review): the unit (per second/minute?) is defined by the consumer —
    /// confirm.
    pub tool_rate_limit: Option<u32>,
    /// Session idle timeout as a humantime duration string. Default: `"20m"`.
    #[serde(default = "default_session_idle_timeout")]
    pub session_idle_timeout: String,
    /// SSE keep-alive as a humantime duration string. Default: `"15s"`.
    #[serde(default = "default_sse_keep_alive")]
    pub sse_keep_alive: String,
    /// Optional public URL. Default: `None`. Not validated here.
    pub public_url: Option<String>,
    /// Enables response compression. Default: `false`.
    #[serde(default)]
    pub compression_enabled: bool,
    /// Minimum size threshold for compression. Default: `1024`.
    /// NOTE(review): presumably bytes — confirm against the consumer.
    #[serde(default = "default_compression_min_size")]
    pub compression_min_size: u16,
    /// Optional cap on concurrent requests. When set, validation rejects `0`.
    pub max_concurrent_requests: Option<usize>,
    /// Enables admin functionality. Default: `false`.
    /// When `true`, validation requires `auth` to be present with `enabled`
    /// set.
    #[serde(default)]
    pub admin_enabled: bool,
    /// Admin role name. Default: `"admin"`. When `admin_enabled` is true,
    /// validation rejects a blank (whitespace-only) value.
    #[serde(default = "default_admin_role")]
    pub admin_role: String,
    /// Authentication settings. Default: `None`.
    pub auth: Option<crate::auth::AuthConfig>,
}
83
84impl Default for ServerConfig {
85 fn default() -> Self {
86 Self {
87 listen_addr: default_listen_addr(),
88 listen_port: default_listen_port(),
89 tls_cert_path: None,
90 tls_key_path: None,
91 tls_handshake_timeout: default_tls_handshake_timeout(),
92 max_concurrent_tls_handshakes: default_max_concurrent_tls_handshakes(),
93 shutdown_timeout: default_shutdown_timeout(),
94 request_timeout: default_request_timeout(),
95 allowed_origins: Vec::new(),
96 stdio_enabled: false,
97 tool_rate_limit: None,
98 session_idle_timeout: default_session_idle_timeout(),
99 sse_keep_alive: default_sse_keep_alive(),
100 public_url: None,
101 compression_enabled: false,
102 compression_min_size: default_compression_min_size(),
103 max_concurrent_requests: None,
104 admin_enabled: false,
105 admin_role: default_admin_role(),
106 auth: None,
107 }
108 }
109}
110
/// Logging and metrics settings, deserializable with serde.
///
/// Every field either has a serde default or is an `Option`, so an empty
/// document deserializes successfully. The `Default` impl yields the same
/// values. Call [`validate_observability_config`] after loading to check
/// `log_level` and `log_format`.
///
/// The struct is `#[non_exhaustive]`. Outside this crate, obtain a value from
/// `Default::default()` or deserialization and then assign the public fields.
#[derive(Debug, Deserialize)]
#[non_exhaustive]
pub struct ObservabilityConfig {
    /// `tracing_subscriber::EnvFilter` directive string.
    /// Default: `"info,rmcp=warn"`. Validation rejects directives that
    /// `EnvFilter::try_new` cannot parse.
    #[serde(default = "default_log_level")]
    pub log_level: String,
    /// Log output format: one of `"json"`, `"pretty"`, `"text"`.
    /// Default: `"pretty"`.
    #[serde(default = "default_log_format")]
    pub log_format: String,
    /// Optional audit log file path. Default: `None`.
    pub audit_log_path: Option<PathBuf>,
    /// Whether request headers are logged. Default: `false`.
    #[serde(default)]
    pub log_request_headers: bool,
    /// Enables metrics. Default: `false`.
    #[serde(default)]
    pub metrics_enabled: bool,
    /// Metrics bind address. Default: `"127.0.0.1:9090"`. Not validated here.
    #[serde(default = "default_metrics_bind")]
    pub metrics_bind: String,
}
134
135impl Default for ObservabilityConfig {
136 fn default() -> Self {
137 Self {
138 log_level: default_log_level(),
139 log_format: default_log_format(),
140 audit_log_path: None,
141 log_request_headers: false,
142 metrics_enabled: false,
143 metrics_bind: default_metrics_bind(),
144 }
145 }
146}
147
148pub fn validate_server_config(server: &ServerConfig) -> crate::error::Result<()> {
154 use crate::error::McpxError;
155
156 if server.listen_port == 0 {
157 return Err(McpxError::Config("listen_port must be nonzero".into()));
158 }
159
160 match (&server.tls_cert_path, &server.tls_key_path) {
161 (Some(_), None) | (None, Some(_)) => {
162 return Err(McpxError::Config(
163 "tls_cert_path and tls_key_path must both be set or both omitted".into(),
164 ));
165 }
166 _ => {}
167 }
168
169 if let Some(0) = server.max_concurrent_requests {
170 return Err(McpxError::Config(
171 "max_concurrent_requests must be nonzero when set".into(),
172 ));
173 }
174
175 if server.admin_enabled {
176 let auth_enabled = server.auth.as_ref().is_some_and(|a| a.enabled);
177 if !auth_enabled {
178 return Err(McpxError::Config(
179 "admin_enabled=true requires auth to be configured and enabled".into(),
180 ));
181 }
182 if server.admin_role.trim().is_empty() {
183 return Err(McpxError::Config("admin_role must not be empty".into()));
184 }
185 }
186
187 for (field, value) in [
188 ("server.shutdown_timeout", server.shutdown_timeout.as_str()),
189 ("server.request_timeout", server.request_timeout.as_str()),
190 (
191 "server.session_idle_timeout",
192 server.session_idle_timeout.as_str(),
193 ),
194 ("server.sse_keep_alive", server.sse_keep_alive.as_str()),
195 (
196 "server.tls_handshake_timeout",
197 server.tls_handshake_timeout.as_str(),
198 ),
199 ] {
200 if humantime::parse_duration(value).is_err() {
201 return Err(McpxError::Config(format!(
202 "invalid duration for {field}: {value:?}"
203 )));
204 }
205 }
206
207 if humantime::parse_duration(&server.tls_handshake_timeout)
211 .is_ok_and(|d| d == std::time::Duration::ZERO)
212 {
213 return Err(McpxError::Config(
214 "server.tls_handshake_timeout must be greater than zero".into(),
215 ));
216 }
217
218 if server.max_concurrent_tls_handshakes == 0 {
222 return Err(McpxError::Config(
223 "server.max_concurrent_tls_handshakes must be greater than zero".into(),
224 ));
225 }
226
227 Ok(())
228}
229
230pub fn validate_observability_config(obs: &ObservabilityConfig) -> crate::error::Result<()> {
236 use tracing_subscriber::EnvFilter;
237
238 use crate::error::McpxError;
239
240 if EnvFilter::try_new(&obs.log_level).is_err() {
241 return Err(McpxError::Config(format!(
242 "invalid log_level: {:?} (expected a valid tracing filter directive, e.g. \"info\", \"debug,hyper=warn\")",
243 obs.log_level
244 )));
245 }
246 let valid_formats = ["json", "pretty", "text"];
247 if !valid_formats.contains(&obs.log_format.as_str()) {
248 return Err(McpxError::Config(format!(
249 "invalid log_format: {:?} (expected one of: {valid_formats:?})",
250 obs.log_format
251 )));
252 }
253
254 Ok(())
255}
256
// ---- ServerConfig defaults (shared by serde attributes and `Default`) ----

/// Default listen address: IPv4 loopback.
fn default_listen_addr() -> String {
    String::from("127.0.0.1")
}

/// Default listen port.
const fn default_listen_port() -> u16 {
    8443
}

/// Default shutdown timeout (humantime string).
fn default_shutdown_timeout() -> String {
    String::from("30s")
}

/// Default request timeout (humantime string).
fn default_request_timeout() -> String {
    String::from("120s")
}

/// Default session idle timeout (humantime string).
fn default_session_idle_timeout() -> String {
    String::from("20m")
}

/// Default TLS handshake timeout (humantime string).
fn default_tls_handshake_timeout() -> String {
    String::from("10s")
}

/// Default cap on concurrent TLS handshakes.
const fn default_max_concurrent_tls_handshakes() -> usize {
    256
}

/// Default admin role name.
fn default_admin_role() -> String {
    String::from("admin")
}

/// Default minimum size threshold for compression.
const fn default_compression_min_size() -> u16 {
    1024
}

/// Default SSE keep-alive (humantime string).
fn default_sse_keep_alive() -> String {
    String::from("15s")
}

// ---- ObservabilityConfig defaults ----

/// Default `EnvFilter` directive: `info` globally, `warn` for `rmcp`.
fn default_log_level() -> String {
    String::from("info,rmcp=warn")
}

/// Default log output format.
fn default_log_format() -> String {
    String::from("pretty")
}

/// Default metrics bind address (loopback only).
fn default_metrics_bind() -> String {
    String::from("127.0.0.1:9090")
}
298
#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        clippy::panic,
        clippy::indexing_slicing,
        clippy::unwrap_in_result,
        clippy::print_stdout,
        clippy::print_stderr,
        reason = "test-only relaxations; production code uses ? and tracing"
    )]
    use super::*;

    // ---- Defaults -------------------------------------------------------

    #[test]
    fn server_config_defaults() {
        let cfg = ServerConfig::default();
        assert_eq!(cfg.listen_addr, "127.0.0.1");
        assert_eq!(cfg.listen_port, 8443);
        assert!(cfg.tls_cert_path.is_none());
        assert!(cfg.tls_key_path.is_none());
        assert_eq!(cfg.tls_handshake_timeout, "10s");
        assert_eq!(cfg.max_concurrent_tls_handshakes, 256);
        assert_eq!(cfg.shutdown_timeout, "30s");
        assert_eq!(cfg.request_timeout, "120s");
        assert!(cfg.allowed_origins.is_empty());
        assert!(!cfg.stdio_enabled);
        assert!(cfg.tool_rate_limit.is_none());
        assert_eq!(cfg.session_idle_timeout, "20m");
        assert_eq!(cfg.sse_keep_alive, "15s");
        assert!(cfg.public_url.is_none());
        assert!(!cfg.compression_enabled);
        assert_eq!(cfg.compression_min_size, 1024);
        assert!(cfg.max_concurrent_requests.is_none());
        assert!(!cfg.admin_enabled);
        assert_eq!(cfg.admin_role, "admin");
        assert!(cfg.auth.is_none());
    }

    #[test]
    fn observability_config_defaults() {
        let cfg = ObservabilityConfig::default();
        assert_eq!(cfg.log_level, "info,rmcp=warn");
        assert_eq!(cfg.log_format, "pretty");
        assert!(cfg.audit_log_path.is_none());
        assert!(!cfg.log_request_headers);
        assert!(!cfg.metrics_enabled);
        assert_eq!(cfg.metrics_bind, "127.0.0.1:9090");
    }

    // ---- validate_server_config ----------------------------------------

    #[test]
    fn valid_server_config_passes() {
        let cfg = ServerConfig::default();
        assert!(validate_server_config(&cfg).is_ok());
    }

    #[test]
    fn zero_port_rejected() {
        let cfg = ServerConfig {
            listen_port: 0,
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("listen_port"));
    }

    #[test]
    fn tls_cert_without_key_rejected() {
        let cfg = ServerConfig {
            tls_cert_path: Some("/tmp/cert.pem".into()),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("tls_cert_path"));
    }

    #[test]
    fn tls_key_without_cert_rejected() {
        let cfg = ServerConfig {
            tls_key_path: Some("/tmp/key.pem".into()),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("tls_cert_path"));
    }

    #[test]
    fn tls_both_set_passes() {
        let cfg = ServerConfig {
            tls_cert_path: Some("/tmp/cert.pem".into()),
            tls_key_path: Some("/tmp/key.pem".into()),
            ..ServerConfig::default()
        };
        assert!(validate_server_config(&cfg).is_ok());
    }

    #[test]
    fn zero_max_concurrent_requests_rejected() {
        let cfg = ServerConfig {
            max_concurrent_requests: Some(0),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("max_concurrent_requests"));
    }

    #[test]
    fn nonzero_max_concurrent_requests_passes() {
        let cfg = ServerConfig {
            max_concurrent_requests: Some(64),
            ..ServerConfig::default()
        };
        assert!(validate_server_config(&cfg).is_ok());
    }

    #[test]
    fn admin_without_auth_rejected() {
        // `auth` defaults to `None`, so enabling admin alone must fail.
        let cfg = ServerConfig {
            admin_enabled: true,
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("admin_enabled"));
    }

    #[test]
    fn invalid_tls_handshake_timeout_rejected() {
        let cfg = ServerConfig {
            tls_handshake_timeout: "not-a-duration".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("tls_handshake_timeout"));
    }

    #[test]
    fn zero_tls_handshake_timeout_rejected() {
        let cfg = ServerConfig {
            tls_handshake_timeout: "0s".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("tls_handshake_timeout"));
    }

    #[test]
    fn zero_max_concurrent_tls_handshakes_rejected() {
        let cfg = ServerConfig {
            max_concurrent_tls_handshakes: 0,
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("max_concurrent_tls_handshakes"));
    }

    #[test]
    fn invalid_shutdown_timeout_rejected() {
        let cfg = ServerConfig {
            shutdown_timeout: "not-a-duration".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("shutdown_timeout"));
    }

    #[test]
    fn invalid_request_timeout_rejected() {
        let cfg = ServerConfig {
            request_timeout: "xyz".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("request_timeout"));
    }

    #[test]
    fn invalid_session_idle_timeout_rejected() {
        let cfg = ServerConfig {
            session_idle_timeout: "forever".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("session_idle_timeout"));
    }

    #[test]
    fn invalid_sse_keep_alive_rejected() {
        let cfg = ServerConfig {
            sse_keep_alive: "soon".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("sse_keep_alive"));
    }

    // ---- validate_observability_config ---------------------------------

    #[test]
    fn valid_observability_config_passes() {
        let cfg = ObservabilityConfig::default();
        assert!(validate_observability_config(&cfg).is_ok());
    }

    #[test]
    fn invalid_log_level_rejected() {
        let cfg = ObservabilityConfig {
            log_level: "[invalid".into(),
            ..ObservabilityConfig::default()
        };
        let err = validate_observability_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("log_level"));
    }

    #[test]
    fn invalid_log_format_rejected() {
        let cfg = ObservabilityConfig {
            log_format: "yaml".into(),
            ..ObservabilityConfig::default()
        };
        let err = validate_observability_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("log_format"));
    }

    #[test]
    fn all_valid_log_levels_accepted() {
        for level in &[
            "trace",
            "debug",
            "info",
            "warn",
            "error",
            "info,rmcp=warn",
            "debug,hyper=error",
        ] {
            let cfg = ObservabilityConfig {
                log_level: (*level).into(),
                ..ObservabilityConfig::default()
            };
            assert!(
                validate_observability_config(&cfg).is_ok(),
                "level {level} should be valid"
            );
        }
    }

    #[test]
    fn all_log_formats_accepted() {
        for fmt in &["json", "pretty", "text"] {
            let cfg = ObservabilityConfig {
                log_format: (*fmt).into(),
                ..ObservabilityConfig::default()
            };
            assert!(
                validate_observability_config(&cfg).is_ok(),
                "format {fmt} should be valid"
            );
        }
    }

    // ---- Deserialization -----------------------------------------------

    #[test]
    fn server_config_deserialize_defaults() {
        let cfg: ServerConfig = toml::from_str("").unwrap();
        assert_eq!(cfg.listen_port, 8443);
        assert_eq!(cfg.listen_addr, "127.0.0.1");
        assert_eq!(cfg.tls_handshake_timeout, "10s");
        assert_eq!(cfg.max_concurrent_tls_handshakes, 256);
        assert_eq!(cfg.session_idle_timeout, "20m");
        assert_eq!(cfg.sse_keep_alive, "15s");
        assert_eq!(cfg.compression_min_size, 1024);
        assert_eq!(cfg.admin_role, "admin");
        assert!(cfg.auth.is_none());
    }

    #[test]
    fn server_config_deserialize_overrides() {
        let cfg: ServerConfig = toml::from_str(
            "listen_port = 9000\n\
             shutdown_timeout = \"5s\"\n\
             allowed_origins = [\"https://example.com\"]\n",
        )
        .unwrap();
        assert_eq!(cfg.listen_port, 9000);
        assert_eq!(cfg.shutdown_timeout, "5s");
        assert_eq!(cfg.allowed_origins, vec!["https://example.com".to_owned()]);
        // Fields not mentioned keep their defaults.
        assert_eq!(cfg.listen_addr, "127.0.0.1");
        assert_eq!(cfg.request_timeout, "120s");
    }

    #[test]
    fn observability_config_deserialize_defaults() {
        let cfg: ObservabilityConfig = toml::from_str("").unwrap();
        assert_eq!(cfg.log_level, "info,rmcp=warn");
        assert_eq!(cfg.log_format, "pretty");
        assert!(!cfg.log_request_headers);
        assert!(!cfg.metrics_enabled);
        assert_eq!(cfg.metrics_bind, "127.0.0.1:9090");
        assert!(cfg.audit_log_path.is_none());
    }
}