matrixcode_core/matrixrpc/
config.rs1use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::path::{Path, PathBuf};
8use thiserror::Error;
9
10use super::service::{ExtensionService, TransportConfig, TransportType};
11
/// File name looked up by [`MatrixRpcConfig::load_default`] in each search location.
pub const CONFIG_FILE_NAME: &str = "matrixrpc.toml";

15#[derive(Debug, Error)]
17pub enum ConfigError {
18 #[error("IO error: {0}")]
20 Io(#[from] std::io::Error),
21
22 #[error("TOML parse error: {0}")]
24 Toml(#[from] toml::de::Error),
25
26 #[error("Configuration validation error: {0}")]
28 Validation(String),
29
30 #[error("Service '{0}' not found in configuration")]
32 ServiceNotFound(String),
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct MatrixRpcConfig {
38 #[serde(default)]
40 pub global: GlobalConfig,
41
42 #[serde(default)]
44 pub services: HashMap<String, ServiceDefinition>,
45}
46
47impl MatrixRpcConfig {
48 pub fn new() -> Self {
50 Self {
51 global: GlobalConfig::default(),
52 services: HashMap::new(),
53 }
54 }
55
56 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, ConfigError> {
58 let content = std::fs::read_to_string(path.as_ref())?;
59 let config: Self = toml::from_str(&content)?;
60 config.validate()?;
61 Ok(config)
62 }
63
64 pub fn load_default() -> Result<Self, ConfigError> {
66 let candidates = vec![
71 PathBuf::from(CONFIG_FILE_NAME),
72 PathBuf::from(".matrix").join(CONFIG_FILE_NAME),
73 dirs::config_dir()
74 .map(|p| p.join("matrix").join(CONFIG_FILE_NAME))
75 .unwrap_or_default(),
76 ];
77
78 for path in candidates {
79 if path.exists() {
80 return Self::load(&path);
81 }
82 }
83
84 Ok(Self::new())
86 }
87
88 pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), ConfigError> {
90 let content = toml::to_string_pretty(self)
91 .map_err(|e| ConfigError::Validation(e.to_string()))?;
92 std::fs::write(path.as_ref(), content)?;
93 Ok(())
94 }
95
96 pub fn validate(&self) -> Result<(), ConfigError> {
98 for (name, service) in &self.services {
99 if service.command.is_none() && service.address.is_none() {
100 return Err(ConfigError::Validation(format!(
101 "Service '{}' must have either 'command' or 'address' configured",
102 name
103 )));
104 }
105 }
106 Ok(())
107 }
108
109 pub fn get_service(&self, name: &str) -> Option<&ServiceDefinition> {
111 self.services.get(name)
112 }
113
114 pub fn add_service(&mut self, name: impl Into<String>, service: ServiceDefinition) {
116 self.services.insert(name.into(), service);
117 }
118
119 pub fn create_service(&self, name: &str) -> Result<ExtensionService, ConfigError> {
121 let def = self
122 .get_service(name)
123 .ok_or_else(|| ConfigError::ServiceNotFound(name.to_string()))?;
124
125 let transport = def.to_transport_config();
126
127 let mut service = ExtensionService::new(name, &def.version);
128 service = service.description(&def.description);
129 service = service.transport(transport);
130
131 for cap in &def.capabilities {
132 service = service.capability(super::service::Capability::new(cap));
133 }
134
135 Ok(service)
136 }
137
138 pub fn service_names(&self) -> Vec<&String> {
140 self.services.keys().collect()
141 }
142}
143
144impl Default for MatrixRpcConfig {
145 fn default() -> Self {
146 Self::new()
147 }
148}
149
150#[derive(Debug, Clone, Serialize, Deserialize)]
152pub struct GlobalConfig {
153 #[serde(default = "default_timeout_secs")]
155 pub timeout_secs: u64,
156
157 #[serde(default = "default_heartbeat_interval")]
159 pub heartbeat_interval_secs: u64,
160
161 #[serde(default = "default_max_retries")]
163 pub max_retries: u32,
164
165 #[serde(default)]
167 pub debug: bool,
168
169 #[serde(default = "default_log_level")]
171 pub log_level: String,
172}
173
/// Fallback timeout, in seconds, for global and per-service settings.
fn default_timeout_secs() -> u64 {
    30
}

/// Fallback heartbeat interval, in seconds, for global and per-service settings.
fn default_heartbeat_interval() -> u64 {
    30
}

/// Fallback retry limit for global and per-service settings.
fn default_max_retries() -> u32 {
    3
}

/// Fallback log level name for the `[global]` table.
fn default_log_level() -> String {
    String::from("info")
}

/// Serde helper for boolean fields that should default to `true`.
fn default_true() -> bool {
    true
}

194impl Default for GlobalConfig {
195 fn default() -> Self {
196 Self {
197 timeout_secs: default_timeout_secs(),
198 heartbeat_interval_secs: default_heartbeat_interval(),
199 max_retries: default_max_retries(),
200 debug: false,
201 log_level: default_log_level(),
202 }
203 }
204}
205
206#[derive(Debug, Clone, Serialize, Deserialize, Default)]
208pub struct ServiceDefinition {
209 #[serde(default)]
211 pub version: String,
212
213 #[serde(default)]
215 pub description: String,
216
217 #[serde(rename = "type", default)]
219 pub transport_type: ServiceTransportType,
220
221 #[serde(default)]
223 pub command: Option<String>,
224
225 #[serde(default)]
227 pub args: Vec<String>,
228
229 #[serde(default)]
231 pub env: HashMap<String, String>,
232
233 #[serde(default)]
235 pub cwd: Option<String>,
236
237 #[serde(default)]
239 pub address: Option<String>,
240
241 #[serde(default)]
243 pub port: Option<u16>,
244
245 #[serde(default)]
247 pub timeout_secs: Option<u64>,
248
249 #[serde(default = "default_true")]
251 pub auto_reconnect: bool,
252
253 #[serde(default)]
255 pub max_retries: Option<u32>,
256
257 #[serde(default)]
259 pub heartbeat_interval_secs: Option<u64>,
260
261 #[serde(default)]
263 pub capabilities: Vec<String>,
264
265 #[serde(default)]
267 pub config: HashMap<String, serde_json::Value>,
268}
269
270impl ServiceDefinition {
271 pub fn stdio(command: impl Into<String>) -> Self {
273 Self {
274 version: String::new(),
275 description: String::new(),
276 transport_type: ServiceTransportType::Stdio,
277 command: Some(command.into()),
278 args: Vec::new(),
279 env: HashMap::new(),
280 cwd: None,
281 address: None,
282 port: None,
283 timeout_secs: None,
284 auto_reconnect: true,
285 max_retries: None,
286 heartbeat_interval_secs: None,
287 capabilities: Vec::new(),
288 config: HashMap::new(),
289 }
290 }
291
292 pub fn tcp(address: impl Into<String>, port: u16) -> Self {
294 Self {
295 version: String::new(),
296 description: String::new(),
297 transport_type: ServiceTransportType::Tcp,
298 command: None,
299 args: Vec::new(),
300 env: HashMap::new(),
301 cwd: None,
302 address: Some(address.into()),
303 port: Some(port),
304 timeout_secs: None,
305 auto_reconnect: true,
306 max_retries: None,
307 heartbeat_interval_secs: None,
308 capabilities: Vec::new(),
309 config: HashMap::new(),
310 }
311 }
312
313 pub fn arg(mut self, arg: impl Into<String>) -> Self {
315 self.args.push(arg.into());
316 self
317 }
318
319 pub fn env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
321 self.env.insert(key.into(), value.into());
322 self
323 }
324
325 pub fn version(mut self, version: impl Into<String>) -> Self {
327 self.version = version.into();
328 self
329 }
330
331 pub fn description(mut self, desc: impl Into<String>) -> Self {
333 self.description = desc.into();
334 self
335 }
336
337 pub fn capability(mut self, cap: impl Into<String>) -> Self {
339 self.capabilities.push(cap.into());
340 self
341 }
342
343 pub fn to_transport_config(&self) -> TransportConfig {
345 TransportConfig {
346 transport_type: self.transport_type.into(),
347 address: self.address.clone(),
348 port: self.port,
349 command: self.command.clone(),
350 args: self.args.clone(),
351 env: self.env.clone(),
352 cwd: self.cwd.clone(),
353 timeout_secs: self.timeout_secs.unwrap_or(default_timeout_secs()),
354 auto_reconnect: self.auto_reconnect,
355 max_retries: self.max_retries.unwrap_or(default_max_retries()),
356 heartbeat_interval_secs: self.heartbeat_interval_secs
357 .unwrap_or(default_heartbeat_interval()),
358 }
359 }
360}
361
362#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
364#[serde(rename_all = "lowercase")]
365pub enum ServiceTransportType {
366 Stdio,
368 Tcp,
370 WebSocket,
372 #[cfg(unix)]
374 Unix,
375}
376
377impl Default for ServiceTransportType {
378 fn default() -> Self {
379 Self::Stdio
380 }
381}
382
383impl From<ServiceTransportType> for TransportType {
384 fn from(value: ServiceTransportType) -> Self {
385 match value {
386 ServiceTransportType::Stdio => TransportType::Stdio,
387 ServiceTransportType::Tcp => TransportType::Tcp,
388 ServiceTransportType::WebSocket => TransportType::WebSocket,
389 #[cfg(unix)]
390 ServiceTransportType::Unix => TransportType::Unix,
391 }
392 }
393}
394
395impl From<TransportType> for ServiceTransportType {
396 fn from(value: TransportType) -> Self {
397 match value {
398 TransportType::Stdio => ServiceTransportType::Stdio,
399 TransportType::Tcp => ServiceTransportType::Tcp,
400 TransportType::WebSocket => ServiceTransportType::WebSocket,
401 #[cfg(unix)]
402 TransportType::Unix => ServiceTransportType::Unix,
403 }
404 }
405}
406
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = MatrixRpcConfig::new();
        assert!(config.services.is_empty());
        assert_eq!(config.global.timeout_secs, 30);
    }

    #[test]
    fn test_service_definition_stdio() {
        let def = ServiceDefinition::stdio("my-server")
            .arg("--port")
            .arg("8080")
            .env("DEBUG", "1")
            .version("1.0.0");

        assert_eq!(def.command, Some("my-server".to_string()));
        assert_eq!(def.args, vec!["--port", "8080"]);
        assert_eq!(def.env.get("DEBUG"), Some(&"1".to_string()));
        assert_eq!(def.version, "1.0.0");
    }

    #[test]
    fn test_service_definition_tcp() {
        let def = ServiceDefinition::tcp("localhost", 8080)
            .version("2.0.0")
            .capability("tools");

        assert_eq!(def.address, Some("localhost".to_string()));
        assert_eq!(def.port, Some(8080));
        assert_eq!(def.version, "2.0.0");
        assert!(def.capabilities.contains(&"tools".to_string()));
    }

    #[test]
    fn test_parse_toml() {
        let toml = r#"
[global]
timeout_secs = 60
debug = true

[services.my-server]
version = "1.0.0"
description = "My test server"
type = "stdio"
command = "my-server"
args = ["--verbose"]

[services.my-server.env]
DEBUG = "1"

[services.tcp-server]
type = "tcp"
address = "127.0.0.1"
port = 9000
"#;

        let config: MatrixRpcConfig = toml::from_str(toml).unwrap();
        assert_eq!(config.global.timeout_secs, 60);
        assert!(config.global.debug);
        assert!(config.services.contains_key("my-server"));
        assert!(config.services.contains_key("tcp-server"));

        let my_server = &config.services["my-server"];
        assert_eq!(my_server.version, "1.0.0");
        assert_eq!(my_server.command, Some("my-server".to_string()));
        assert_eq!(my_server.env.get("DEBUG"), Some(&"1".to_string()));

        let tcp_server = &config.services["tcp-server"];
        assert_eq!(tcp_server.address, Some("127.0.0.1".to_string()));
        assert_eq!(tcp_server.port, Some(9000));
    }

    #[test]
    fn test_parse_toml_applies_field_defaults() {
        let src = r#"
[services.minimal]
command = "srv"
"#;

        let config: MatrixRpcConfig = toml::from_str(src).unwrap();
        assert_eq!(config.global.timeout_secs, 30);
        assert_eq!(config.global.heartbeat_interval_secs, 30);
        assert_eq!(config.global.max_retries, 3);
        assert!(!config.global.debug);
        assert_eq!(config.global.log_level, "info");

        let minimal = &config.services["minimal"];
        assert_eq!(minimal.transport_type, ServiceTransportType::Stdio);
        assert!(minimal.auto_reconnect);
        assert!(minimal.args.is_empty());
        assert_eq!(minimal.timeout_secs, None);
    }

    #[test]
    fn test_config_validation() {
        let mut config = MatrixRpcConfig::new();

        config.add_service("valid", ServiceDefinition::stdio("server"));
        assert!(config.validate().is_ok());

        config.services.insert(
            "invalid".to_string(),
            ServiceDefinition {
                command: None,
                address: None,
                ..Default::default()
            },
        );
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_validation_error_names_service() {
        let src = r#"
[services.broken]
version = "1.0.0"
"#;

        let config: MatrixRpcConfig = toml::from_str(src).unwrap();
        assert!(matches!(
            config.validate(),
            Err(ConfigError::Validation(msg)) if msg.contains("'broken'")
        ));
    }

    #[test]
    fn test_create_service() {
        let mut config = MatrixRpcConfig::new();
        config.add_service(
            "test",
            ServiceDefinition::stdio("test-server")
                .version("1.0.0")
                .capability("tools"),
        );

        let service = config.create_service("test").unwrap();
        assert_eq!(service.name, "test");
        assert_eq!(service.version, "1.0.0");
        assert!(service.has_capability("tools"));
    }

    #[test]
    fn test_create_service_unknown_name() {
        let config = MatrixRpcConfig::new();
        assert!(matches!(
            config.create_service("missing"),
            Err(ConfigError::ServiceNotFound(name)) if name == "missing"
        ));
    }
}