1use std::{
4 collections::HashMap,
5 path::{Path, PathBuf},
6};
7
8use mcpkit_rs_policy::Policy as PolicyConfig;
9use serde::{Deserialize, Serialize};
10
11pub mod defaults;
12pub mod error;
13pub mod loader;
14pub mod validation;
15
16pub use error::{ConfigError, Result};
17pub use loader::ConfigLoader;
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21#[serde(rename_all = "snake_case")]
22pub struct Config {
23 pub version: String,
25
26 pub metadata: Option<Metadata>,
28
29 pub server: ServerConfig,
31
32 pub transport: TransportConfig,
34
35 pub policy: Option<PolicyConfig>,
37
38 pub runtime: RuntimeConfig,
40
41 pub mcp: McpConfig,
43
44 pub distribution: Option<DistributionConfig>,
46
47 #[serde(default)]
49 pub extensions: HashMap<String, serde_json::Value>,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct Metadata {
55 pub name: Option<String>,
56 pub description: Option<String>,
57 pub author: Option<String>,
58 pub created_at: Option<String>,
59 pub modified_at: Option<String>,
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct ServerConfig {
65 pub name: String,
67
68 pub version: String,
70
71 pub description: Option<String>,
73
74 pub bind: String,
76
77 pub port: u16,
79
80 pub max_connections: Option<usize>,
82
83 pub request_timeout: Option<u64>,
85
86 #[serde(default)]
88 pub debug: bool,
89
90 pub log_level: Option<String>,
92}
93
94#[derive(Debug, Clone, Serialize)]
96#[serde(rename_all = "snake_case")]
97pub struct TransportConfig {
98 #[serde(rename = "type")]
100 pub transport_type: TransportType,
101
102 pub settings: TransportSettings,
104}
105
106#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
108#[serde(rename_all = "snake_case")]
109pub enum TransportType {
110 Stdio,
111 Http,
112 WebSocket,
113 Grpc,
114}
115
116#[derive(Debug, Clone, Serialize, Deserialize)]
118#[serde(untagged)]
119pub enum TransportSettings {
120 Stdio(StdioSettings),
121 Http(HttpSettings),
122 WebSocket(WebSocketSettings),
123 Grpc(GrpcSettings),
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize)]
128pub struct StdioSettings {
129 pub buffer_size: Option<usize>,
130}
131
132#[derive(Debug, Clone, Serialize, Deserialize)]
134pub struct HttpSettings {
135 pub cors_enabled: Option<bool>,
136 pub cors_origins: Option<Vec<String>>,
137 pub max_body_size: Option<usize>,
138 pub compression: Option<bool>,
139 pub tls: Option<TlsConfig>,
140}
141
142#[derive(Debug, Clone, Serialize, Deserialize)]
144pub struct WebSocketSettings {
145 pub ping_interval: Option<u64>,
146 pub max_frame_size: Option<usize>,
147 pub compression: Option<bool>,
148}
149
150#[derive(Debug, Clone, Serialize, Deserialize)]
152pub struct GrpcSettings {
153 pub reflection: Option<bool>,
154 pub max_message_size: Option<usize>,
155 pub tls: Option<TlsConfig>,
156}
157
158#[derive(Debug, Clone, Serialize, Deserialize)]
160pub struct TlsConfig {
161 pub cert_file: PathBuf,
162 pub key_file: PathBuf,
163 pub ca_file: Option<PathBuf>,
164 pub verify_client: Option<bool>,
165}
166
167#[derive(Debug, Clone, Serialize, Deserialize)]
169pub struct RuntimeConfig {
170 #[serde(rename = "type")]
172 pub runtime_type: RuntimeType,
173
174 pub wasm: Option<WasmConfig>,
176
177 pub limits: Option<ResourceLimits>,
179}
180
181#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
183#[serde(rename_all = "snake_case")]
184pub enum RuntimeType {
185 Native,
186 Wasmtime,
187 #[serde(rename = "wasmedge")]
188 WasmEdge,
189}
190
191#[derive(Debug, Clone, Serialize, Deserialize)]
193pub struct WasmConfig {
194 pub module_path: Option<PathBuf>,
196
197 pub fuel: Option<u64>,
199
200 pub memory_pages: Option<u32>,
202
203 pub cache: Option<bool>,
205
206 pub cache_dir: Option<PathBuf>,
208}
209
210#[derive(Debug, Clone, Serialize, Deserialize)]
212pub struct ResourceLimits {
213 pub cpu: Option<String>,
214 pub memory: Option<String>,
215 pub execution_time: Option<String>,
216 pub max_requests_per_minute: Option<u32>,
217}
218
219#[derive(Debug, Clone, Serialize, Deserialize)]
221pub struct McpConfig {
222 pub protocol_version: String,
224
225 pub tools: Option<Vec<ToolConfig>>,
227
228 pub prompts: Option<Vec<PromptConfig>>,
230
231 pub resources: Option<Vec<ResourceConfig>>,
233
234 pub capabilities: Option<McpCapabilities>,
236}
237
238#[derive(Debug, Clone, Serialize, Deserialize)]
240pub struct ToolConfig {
241 pub name: String,
242 pub description: String,
243 pub input_schema: serde_json::Value,
244 pub handler: Option<String>,
245}
246
247#[derive(Debug, Clone, Serialize, Deserialize)]
249pub struct PromptConfig {
250 pub name: String,
251 pub description: String,
252 pub arguments: Option<Vec<PromptArgument>>,
253}
254
255#[derive(Debug, Clone, Serialize, Deserialize)]
257pub struct PromptArgument {
258 pub name: String,
259 pub description: Option<String>,
260 pub required: bool,
261 pub default: Option<serde_json::Value>,
262}
263
264#[derive(Debug, Clone, Serialize, Deserialize)]
266pub struct ResourceConfig {
267 pub name: String,
268 pub uri: String,
269 pub description: Option<String>,
270 pub mime_type: Option<String>,
271}
272
273#[derive(Debug, Clone, Serialize, Deserialize)]
275#[serde(untagged)]
276pub enum McpCapabilities {
277 List(Vec<String>),
278 Struct {
279 tools: Option<bool>,
280 prompts: Option<bool>,
281 resources: Option<bool>,
282 logging: Option<bool>,
283 experimental: Option<HashMap<String, bool>>,
284 },
285}
286
287impl McpCapabilities {
288 pub fn has_tools(&self) -> bool {
289 match self {
290 McpCapabilities::List(caps) => caps.contains(&"tools".to_string()),
291 McpCapabilities::Struct { tools, .. } => tools.unwrap_or(false),
292 }
293 }
294
295 pub fn has_prompts(&self) -> bool {
296 match self {
297 McpCapabilities::List(caps) => caps.contains(&"prompts".to_string()),
298 McpCapabilities::Struct { prompts, .. } => prompts.unwrap_or(false),
299 }
300 }
301
302 pub fn has_resources(&self) -> bool {
303 match self {
304 McpCapabilities::List(caps) => caps.contains(&"resources".to_string()),
305 McpCapabilities::Struct { resources, .. } => resources.unwrap_or(false),
306 }
307 }
308
309 pub fn has_logging(&self) -> bool {
310 match self {
311 McpCapabilities::List(caps) => caps.contains(&"logging".to_string()),
312 McpCapabilities::Struct { logging, .. } => logging.unwrap_or(false),
313 }
314 }
315}
316
317#[derive(Debug, Clone, Serialize, Deserialize)]
319pub struct DistributionConfig {
320 pub registry: String,
322
323 pub version: Option<String>,
325
326 #[serde(default)]
328 pub tags: Vec<String>,
329
330 pub metadata: Option<BundleMetadata>,
332
333 #[serde(default)]
335 pub include: Vec<String>,
336
337 pub auth: Option<RegistryAuth>,
339}
340
341#[derive(Debug, Clone, Serialize, Deserialize)]
343pub struct BundleMetadata {
344 #[serde(default)]
346 pub authors: Vec<String>,
347
348 pub license: Option<String>,
350
351 pub repository: Option<String>,
353
354 #[serde(default)]
356 pub keywords: Vec<String>,
357
358 pub homepage: Option<String>,
360
361 pub documentation: Option<String>,
363}
364
365#[derive(Debug, Clone, Serialize, Deserialize)]
367pub struct RegistryAuth {
368 pub username: Option<String>,
370
371 pub password: Option<String>,
373
374 pub auth_file: Option<PathBuf>,
376
377 #[serde(default)]
379 pub use_keychain: bool,
380}
381
382#[cfg(test)]
383mod tests;
384
385impl<'de> Deserialize<'de> for TransportConfig {
387 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
388 where
389 D: serde::Deserializer<'de>,
390 {
391 use std::fmt;
392
393 use serde::de::{self, MapAccess, Visitor};
394
395 #[derive(Deserialize)]
396 #[serde(field_identifier, rename_all = "snake_case")]
397 enum Field {
398 #[serde(rename = "type")]
399 Type,
400 Settings,
401 }
402
403 struct TransportConfigVisitor;
404
405 impl<'de> Visitor<'de> for TransportConfigVisitor {
406 type Value = TransportConfig;
407
408 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
409 formatter.write_str("struct TransportConfig")
410 }
411
412 fn visit_map<A>(self, mut map: A) -> std::result::Result<TransportConfig, A::Error>
413 where
414 A: MapAccess<'de>,
415 {
416 let mut transport_type: Option<TransportType> = None;
417 let mut settings_value: Option<serde_json::Value> = None;
418
419 while let Some(key) = map.next_key()? {
420 match key {
421 Field::Type => {
422 if transport_type.is_some() {
423 return Err(de::Error::duplicate_field("type"));
424 }
425 transport_type = Some(map.next_value()?);
426 }
427 Field::Settings => {
428 if settings_value.is_some() {
429 return Err(de::Error::duplicate_field("settings"));
430 }
431 settings_value = Some(map.next_value()?);
432 }
433 }
434 }
435
436 let transport_type =
437 transport_type.ok_or_else(|| de::Error::missing_field("type"))?;
438 let settings_value =
439 settings_value.ok_or_else(|| de::Error::missing_field("settings"))?;
440
441 let settings = match transport_type {
443 TransportType::Stdio => {
444 let stdio_settings: StdioSettings = serde_json::from_value(settings_value)
445 .map_err(|e| {
446 de::Error::custom(format!("Invalid stdio settings: {}", e))
447 })?;
448 TransportSettings::Stdio(stdio_settings)
449 }
450 TransportType::Http => {
451 let http_settings: HttpSettings = serde_json::from_value(settings_value)
452 .map_err(|e| {
453 de::Error::custom(format!("Invalid HTTP settings: {}", e))
454 })?;
455 TransportSettings::Http(http_settings)
456 }
457 TransportType::WebSocket => {
458 let ws_settings: WebSocketSettings = serde_json::from_value(settings_value)
459 .map_err(|e| {
460 de::Error::custom(format!("Invalid WebSocket settings: {}", e))
461 })?;
462 TransportSettings::WebSocket(ws_settings)
463 }
464 TransportType::Grpc => {
465 let grpc_settings: GrpcSettings = serde_json::from_value(settings_value)
466 .map_err(|e| {
467 de::Error::custom(format!("Invalid gRPC settings: {}", e))
468 })?;
469 TransportSettings::Grpc(grpc_settings)
470 }
471 };
472
473 Ok(TransportConfig {
474 transport_type,
475 settings,
476 })
477 }
478 }
479
480 const FIELDS: &[&str] = &["type", "settings"];
481 deserializer.deserialize_struct("TransportConfig", FIELDS, TransportConfigVisitor)
482 }
483}
484
485impl Config {
486 pub fn from_yaml_file<P: AsRef<Path>>(path: P) -> Result<Self> {
488 let contents = std::fs::read_to_string(path)?;
489 Self::from_yaml(&contents)
490 }
491
492 pub fn from_yaml(yaml: &str) -> Result<Self> {
494 let config: Config = serde_yaml::from_str(yaml)?;
495 config.validate()?;
496 Ok(config)
497 }
498
499 pub fn from_json_file<P: AsRef<Path>>(path: P) -> Result<Self> {
501 let contents = std::fs::read_to_string(path)?;
502 Self::from_json(&contents)
503 }
504
505 pub fn from_json(json: &str) -> Result<Self> {
507 let config: Config = serde_json::from_str(json)?;
508 config.validate()?;
509 Ok(config)
510 }
511
512 pub fn to_yaml_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
514 let yaml = serde_yaml::to_string(self)?;
515 std::fs::write(path, yaml)?;
516 Ok(())
517 }
518
519 pub fn to_json_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
521 let json = serde_json::to_string_pretty(self)?;
522 std::fs::write(path, json)?;
523 Ok(())
524 }
525
526 pub fn validate(&self) -> Result<()> {
528 validation::validate_config(self)
529 }
530
531 pub fn merge(&mut self, other: Config) -> Result<()> {
533 if other.version != self.version {
534 return Err(ConfigError::VersionMismatch {
535 expected: self.version.clone(),
536 found: other.version,
537 });
538 }
539
540 if let Some(metadata) = other.metadata {
541 self.metadata = Some(metadata);
542 }
543
544 self.server = other.server;
545 self.transport = other.transport;
546
547 if let Some(policy) = other.policy {
548 self.policy = Some(policy);
549 }
550
551 self.runtime = other.runtime;
552 self.mcp = other.mcp;
553
554 if let Some(distribution) = other.distribution {
555 self.distribution = Some(distribution);
556 }
557
558 self.extensions.extend(other.extensions);
559
560 Ok(())
561 }
562}
563
564impl Default for Config {
565 fn default() -> Self {
566 defaults::default_config()
567 }
568}