1use std::path::PathBuf;
2
3use serde::Deserialize;
4
/// Network, transport and admin settings for the server.
///
/// Every field may be omitted from the serialized input: non-`Option` fields
/// fall back to the `default_*` helpers (or the type's `Default`), and
/// `Option` fields fall back to `None`. [`ServerConfig::default()`] produces
/// the same values. Deserialization does not enforce cross-field rules; call
/// [`validate_server_config`] for that.
#[derive(Debug, Deserialize)]
// Other crates cannot build this with a struct literal; they must start from
// `Default` or deserialization.
#[non_exhaustive]
pub struct ServerConfig {
    /// Listen address. Defaults to `"127.0.0.1"`.
    #[serde(default = "default_listen_addr")]
    pub listen_addr: String,
    /// Listen port. Defaults to `8443`; `validate_server_config` rejects `0`.
    #[serde(default = "default_listen_port")]
    pub listen_port: u16,
    /// TLS certificate path. Must be set together with `tls_key_path`
    /// (both set or both omitted), as checked by `validate_server_config`.
    pub tls_cert_path: Option<PathBuf>,
    /// TLS private key path. Must be set together with `tls_cert_path`.
    pub tls_key_path: Option<PathBuf>,
    /// Shutdown timeout as a `humantime` duration string. Defaults to `"30s"`.
    #[serde(default = "default_shutdown_timeout")]
    pub shutdown_timeout: String,
    /// Request timeout as a `humantime` duration string. Defaults to `"120s"`.
    #[serde(default = "default_request_timeout")]
    pub request_timeout: String,
    /// Allowed origins. Empty by default.
    // NOTE(review): presumably used for Origin/CORS checks — confirm at the consumer.
    #[serde(default)]
    pub allowed_origins: Vec<String>,
    /// Whether the stdio transport is enabled. Defaults to `false`.
    #[serde(default)]
    pub stdio_enabled: bool,
    /// Optional tool rate limit; `None` by default and not validated here.
    // NOTE(review): the time window/units are defined by the consumer — not visible here.
    pub tool_rate_limit: Option<u32>,
    /// Session idle timeout as a `humantime` duration string. Defaults to `"20m"`.
    #[serde(default = "default_session_idle_timeout")]
    pub session_idle_timeout: String,
    /// SSE keep-alive interval as a `humantime` duration string. Defaults to `"15s"`.
    #[serde(default = "default_sse_keep_alive")]
    pub sse_keep_alive: String,
    /// Optional externally visible URL; `None` by default and not validated here.
    pub public_url: Option<String>,
    /// Whether response compression is enabled. Defaults to `false`.
    #[serde(default)]
    pub compression_enabled: bool,
    /// Compression size threshold. Defaults to `1024`.
    // NOTE(review): presumably bytes — confirm at the consumer.
    #[serde(default = "default_compression_min_size")]
    pub compression_min_size: u16,
    /// Optional concurrency cap; when set it must be nonzero
    /// (checked by `validate_server_config`).
    pub max_concurrent_requests: Option<usize>,
    /// Whether admin functionality is enabled. Defaults to `false`. When `true`,
    /// `validate_server_config` requires `auth` to be present and enabled.
    #[serde(default)]
    pub admin_enabled: bool,
    /// Role name for admin access. Defaults to `"admin"`; must not be blank when
    /// `admin_enabled` is `true`.
    #[serde(default = "default_admin_role")]
    pub admin_role: String,
    /// Optional authentication settings; `None` by default.
    pub auth: Option<crate::auth::AuthConfig>,
}
71
72impl Default for ServerConfig {
73 fn default() -> Self {
74 Self {
75 listen_addr: default_listen_addr(),
76 listen_port: default_listen_port(),
77 tls_cert_path: None,
78 tls_key_path: None,
79 shutdown_timeout: default_shutdown_timeout(),
80 request_timeout: default_request_timeout(),
81 allowed_origins: Vec::new(),
82 stdio_enabled: false,
83 tool_rate_limit: None,
84 session_idle_timeout: default_session_idle_timeout(),
85 sse_keep_alive: default_sse_keep_alive(),
86 public_url: None,
87 compression_enabled: false,
88 compression_min_size: default_compression_min_size(),
89 max_concurrent_requests: None,
90 admin_enabled: false,
91 admin_role: default_admin_role(),
92 auth: None,
93 }
94 }
95}
96
/// Logging, audit and metrics settings.
///
/// Every field may be omitted from the serialized input; missing fields take
/// the same values as [`ObservabilityConfig::default()`]. Use
/// [`validate_observability_config`] to check `log_level` and `log_format`.
#[derive(Debug, Deserialize)]
// Other crates cannot build this with a struct literal; they must start from
// `Default` or deserialization.
#[non_exhaustive]
pub struct ObservabilityConfig {
    /// `tracing_subscriber::EnvFilter` directive string. Defaults to
    /// `"info,rmcp=warn"`.
    #[serde(default = "default_log_level")]
    pub log_level: String,
    /// Log output format: one of `"json"`, `"pretty"` or `"text"` (checked by
    /// `validate_observability_config`). Defaults to `"pretty"`.
    #[serde(default = "default_log_format")]
    pub log_format: String,
    /// Optional audit log file path; `None` by default.
    pub audit_log_path: Option<PathBuf>,
    /// Whether request headers are logged. Defaults to `false`.
    #[serde(default)]
    pub log_request_headers: bool,
    /// Whether metrics are enabled. Defaults to `false`.
    #[serde(default)]
    pub metrics_enabled: bool,
    /// Metrics bind address. Defaults to `"127.0.0.1:9090"`.
    // NOTE(review): not validated by `validate_observability_config`; a bad value
    // surfaces only where it is bound.
    #[serde(default = "default_metrics_bind")]
    pub metrics_bind: String,
}
120
121impl Default for ObservabilityConfig {
122 fn default() -> Self {
123 Self {
124 log_level: default_log_level(),
125 log_format: default_log_format(),
126 audit_log_path: None,
127 log_request_headers: false,
128 metrics_enabled: false,
129 metrics_bind: default_metrics_bind(),
130 }
131 }
132}
133
134pub fn validate_server_config(server: &ServerConfig) -> crate::error::Result<()> {
140 use crate::error::McpxError;
141
142 if server.listen_port == 0 {
143 return Err(McpxError::Config("listen_port must be nonzero".into()));
144 }
145
146 match (&server.tls_cert_path, &server.tls_key_path) {
147 (Some(_), None) | (None, Some(_)) => {
148 return Err(McpxError::Config(
149 "tls_cert_path and tls_key_path must both be set or both omitted".into(),
150 ));
151 }
152 _ => {}
153 }
154
155 if let Some(0) = server.max_concurrent_requests {
156 return Err(McpxError::Config(
157 "max_concurrent_requests must be nonzero when set".into(),
158 ));
159 }
160
161 if server.admin_enabled {
162 let auth_enabled = server.auth.as_ref().is_some_and(|a| a.enabled);
163 if !auth_enabled {
164 return Err(McpxError::Config(
165 "admin_enabled=true requires auth to be configured and enabled".into(),
166 ));
167 }
168 if server.admin_role.trim().is_empty() {
169 return Err(McpxError::Config("admin_role must not be empty".into()));
170 }
171 }
172
173 for (field, value) in [
174 ("server.shutdown_timeout", server.shutdown_timeout.as_str()),
175 ("server.request_timeout", server.request_timeout.as_str()),
176 (
177 "server.session_idle_timeout",
178 server.session_idle_timeout.as_str(),
179 ),
180 ("server.sse_keep_alive", server.sse_keep_alive.as_str()),
181 ] {
182 if humantime::parse_duration(value).is_err() {
183 return Err(McpxError::Config(format!(
184 "invalid duration for {field}: {value:?}"
185 )));
186 }
187 }
188
189 Ok(())
190}
191
192pub fn validate_observability_config(obs: &ObservabilityConfig) -> crate::error::Result<()> {
198 use tracing_subscriber::EnvFilter;
199
200 use crate::error::McpxError;
201
202 if EnvFilter::try_new(&obs.log_level).is_err() {
203 return Err(McpxError::Config(format!(
204 "invalid log_level: {:?} (expected a valid tracing filter directive, e.g. \"info\", \"debug,hyper=warn\")",
205 obs.log_level
206 )));
207 }
208 let valid_formats = ["json", "pretty", "text"];
209 if !valid_formats.contains(&obs.log_format.as_str()) {
210 return Err(McpxError::Config(format!(
211 "invalid log_format: {:?} (expected one of: {valid_formats:?})",
212 obs.log_format
213 )));
214 }
215
216 Ok(())
217}
218
// Field defaults referenced by the `#[serde(default = "...")]` attributes.
// The `Default` impls of `ServerConfig` and `ObservabilityConfig` call these
// same helpers, so both construction paths agree.

// --- ServerConfig ---

fn default_listen_addr() -> String {
    String::from("127.0.0.1")
}

fn default_listen_port() -> u16 {
    8443
}

fn default_shutdown_timeout() -> String {
    String::from("30s")
}

fn default_request_timeout() -> String {
    String::from("120s")
}

fn default_session_idle_timeout() -> String {
    String::from("20m")
}

fn default_sse_keep_alive() -> String {
    String::from("15s")
}

fn default_compression_min_size() -> u16 {
    1024
}

fn default_admin_role() -> String {
    String::from("admin")
}

// --- ObservabilityConfig ---

fn default_log_level() -> String {
    String::from("info,rmcp=warn")
}

fn default_log_format() -> String {
    String::from("pretty")
}

fn default_metrics_bind() -> String {
    String::from("127.0.0.1:9090")
}
254
#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        clippy::panic,
        clippy::indexing_slicing,
        clippy::unwrap_in_result,
        clippy::print_stdout,
        clippy::print_stderr
    )]
    use super::*;

    // ---- ServerConfig defaults ----

    #[test]
    fn server_config_defaults() {
        let cfg = ServerConfig::default();
        assert_eq!(cfg.listen_addr, "127.0.0.1");
        assert_eq!(cfg.listen_port, 8443);
        assert!(cfg.tls_cert_path.is_none());
        assert!(cfg.tls_key_path.is_none());
        assert_eq!(cfg.shutdown_timeout, "30s");
        assert_eq!(cfg.request_timeout, "120s");
        assert!(cfg.allowed_origins.is_empty());
        assert!(!cfg.stdio_enabled);
        assert!(cfg.tool_rate_limit.is_none());
        assert_eq!(cfg.session_idle_timeout, "20m");
        assert_eq!(cfg.sse_keep_alive, "15s");
        assert!(cfg.public_url.is_none());
        // Fields previously not covered by this test.
        assert!(!cfg.compression_enabled);
        assert_eq!(cfg.compression_min_size, 1024);
        assert!(cfg.max_concurrent_requests.is_none());
        assert!(!cfg.admin_enabled);
        assert_eq!(cfg.admin_role, "admin");
        assert!(cfg.auth.is_none());
    }

    #[test]
    fn observability_config_defaults() {
        let cfg = ObservabilityConfig::default();
        assert_eq!(cfg.log_level, "info,rmcp=warn");
        assert_eq!(cfg.log_format, "pretty");
        assert!(cfg.audit_log_path.is_none());
        assert!(!cfg.log_request_headers);
        assert!(!cfg.metrics_enabled);
        assert_eq!(cfg.metrics_bind, "127.0.0.1:9090");
    }

    // ---- validate_server_config ----

    #[test]
    fn valid_server_config_passes() {
        let cfg = ServerConfig::default();
        assert!(validate_server_config(&cfg).is_ok());
    }

    #[test]
    fn zero_port_rejected() {
        let cfg = ServerConfig {
            listen_port: 0,
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("listen_port"));
    }

    #[test]
    fn tls_cert_without_key_rejected() {
        let cfg = ServerConfig {
            tls_cert_path: Some("/tmp/cert.pem".into()),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("tls_cert_path"));
    }

    #[test]
    fn tls_key_without_cert_rejected() {
        let cfg = ServerConfig {
            tls_key_path: Some("/tmp/key.pem".into()),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("tls_cert_path"));
    }

    #[test]
    fn tls_both_set_passes() {
        let cfg = ServerConfig {
            tls_cert_path: Some("/tmp/cert.pem".into()),
            tls_key_path: Some("/tmp/key.pem".into()),
            ..ServerConfig::default()
        };
        assert!(validate_server_config(&cfg).is_ok());
    }

    #[test]
    fn zero_max_concurrent_requests_rejected() {
        let cfg = ServerConfig {
            max_concurrent_requests: Some(0),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("max_concurrent_requests"));
    }

    #[test]
    fn nonzero_max_concurrent_requests_passes() {
        let cfg = ServerConfig {
            max_concurrent_requests: Some(8),
            ..ServerConfig::default()
        };
        assert!(validate_server_config(&cfg).is_ok());
    }

    #[test]
    fn admin_without_auth_rejected() {
        // Default config has `auth: None`, so enabling admin must fail.
        let cfg = ServerConfig {
            admin_enabled: true,
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("admin_enabled"));
    }

    #[test]
    fn invalid_shutdown_timeout_rejected() {
        let cfg = ServerConfig {
            shutdown_timeout: "not-a-duration".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("shutdown_timeout"));
    }

    #[test]
    fn invalid_request_timeout_rejected() {
        let cfg = ServerConfig {
            request_timeout: "xyz".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("request_timeout"));
    }

    #[test]
    fn invalid_session_idle_timeout_rejected() {
        let cfg = ServerConfig {
            session_idle_timeout: "forever".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("session_idle_timeout"));
    }

    #[test]
    fn invalid_sse_keep_alive_rejected() {
        let cfg = ServerConfig {
            sse_keep_alive: "often".into(),
            ..ServerConfig::default()
        };
        let err = validate_server_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("sse_keep_alive"));
    }

    // ---- validate_observability_config ----

    #[test]
    fn valid_observability_config_passes() {
        let cfg = ObservabilityConfig::default();
        assert!(validate_observability_config(&cfg).is_ok());
    }

    #[test]
    fn invalid_log_level_rejected() {
        let cfg = ObservabilityConfig {
            log_level: "[invalid".into(),
            ..ObservabilityConfig::default()
        };
        let err = validate_observability_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("log_level"));
    }

    #[test]
    fn invalid_log_format_rejected() {
        let cfg = ObservabilityConfig {
            log_format: "yaml".into(),
            ..ObservabilityConfig::default()
        };
        let err = validate_observability_config(&cfg).unwrap_err();
        assert!(err.to_string().contains("log_format"));
    }

    #[test]
    fn all_valid_log_levels_accepted() {
        for level in &[
            "trace",
            "debug",
            "info",
            "warn",
            "error",
            "info,rmcp=warn",
            "debug,hyper=error",
        ] {
            let cfg = ObservabilityConfig {
                log_level: (*level).into(),
                ..ObservabilityConfig::default()
            };
            assert!(
                validate_observability_config(&cfg).is_ok(),
                "level {level} should be valid"
            );
        }
    }

    #[test]
    fn all_log_formats_accepted() {
        // Must mirror the list accepted by `validate_observability_config`.
        for fmt in &["json", "pretty", "text"] {
            let cfg = ObservabilityConfig {
                log_format: (*fmt).into(),
                ..ObservabilityConfig::default()
            };
            assert!(
                validate_observability_config(&cfg).is_ok(),
                "format {fmt} should be valid"
            );
        }
    }

    // ---- Deserialization ----

    #[test]
    fn server_config_deserialize_defaults() {
        let cfg: ServerConfig = toml::from_str("").unwrap();
        assert_eq!(cfg.listen_port, 8443);
        assert_eq!(cfg.listen_addr, "127.0.0.1");
    }

    #[test]
    fn server_config_deserialize_overrides() {
        let cfg: ServerConfig =
            toml::from_str("listen_port = 9000\nadmin_role = \"ops\"").unwrap();
        assert_eq!(cfg.listen_port, 9000);
        assert_eq!(cfg.admin_role, "ops");
        // Unspecified fields still take their defaults.
        assert_eq!(cfg.listen_addr, "127.0.0.1");
    }

    #[test]
    fn observability_config_deserialize_defaults() {
        let cfg: ObservabilityConfig = toml::from_str("").unwrap();
        assert_eq!(cfg.log_level, "info,rmcp=warn");
        assert_eq!(cfg.log_format, "pretty");
        assert!(!cfg.log_request_headers);
        assert!(!cfg.metrics_enabled);
    }
}