//! MatrixRPC configuration.
//!
//! Loading, validation, and saving of `matrixrpc.toml` configuration files.
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::path::{Path, PathBuf};
8use thiserror::Error;
9
10use super::service::{ExtensionService, TransportConfig, TransportType};
11
/// File name that [`MatrixRpcConfig::load_default`] looks for in each search location.
pub const CONFIG_FILE_NAME: &str = "matrixrpc.toml";
14
15/// Configuration error
16#[derive(Debug, Error)]
17pub enum ConfigError {
18    /// IO error
19    #[error("IO error: {0}")]
20    Io(#[from] std::io::Error),
21
22    /// TOML parse error
23    #[error("TOML parse error: {0}")]
24    Toml(#[from] toml::de::Error),
25
26    /// Validation error
27    #[error("Configuration validation error: {0}")]
28    Validation(String),
29
30    /// Service not found
31    #[error("Service '{0}' not found in configuration")]
32    ServiceNotFound(String),
33}
34
35/// Root configuration structure for matrixrpc.toml
36#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct MatrixRpcConfig {
38    /// Global settings
39    #[serde(default)]
40    pub global: GlobalConfig,
41
42    /// Service definitions
43    #[serde(default)]
44    pub services: HashMap<String, ServiceDefinition>,
45}
46
47impl MatrixRpcConfig {
48    /// Create an empty configuration
49    pub fn new() -> Self {
50        Self {
51            global: GlobalConfig::default(),
52            services: HashMap::new(),
53        }
54    }
55
56    /// Load configuration from a file
57    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, ConfigError> {
58        let content = std::fs::read_to_string(path.as_ref())?;
59        let config: Self = toml::from_str(&content)?;
60        config.validate()?;
61        Ok(config)
62    }
63
64    /// Load configuration from the default location
65    pub fn load_default() -> Result<Self, ConfigError> {
66        // Try in order:
67        // 1. ./matrixrpc.toml
68        // 2. ./.matrix/matrixrpc.toml
69        // 3. ~/.matrix/matrixrpc.toml
70        let candidates = vec![
71            PathBuf::from(CONFIG_FILE_NAME),
72            PathBuf::from(".matrix").join(CONFIG_FILE_NAME),
73            dirs::config_dir()
74                .map(|p| p.join("matrix").join(CONFIG_FILE_NAME))
75                .unwrap_or_default(),
76        ];
77
78        for path in candidates {
79            if path.exists() {
80                return Self::load(&path);
81            }
82        }
83
84        // Return empty config if no file found
85        Ok(Self::new())
86    }
87
88    /// Save configuration to a file
89    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), ConfigError> {
90        let content = toml::to_string_pretty(self)
91            .map_err(|e| ConfigError::Validation(e.to_string()))?;
92        std::fs::write(path.as_ref(), content)?;
93        Ok(())
94    }
95
96    /// Validate the configuration
97    pub fn validate(&self) -> Result<(), ConfigError> {
98        for (name, service) in &self.services {
99            if service.command.is_none() && service.address.is_none() {
100                return Err(ConfigError::Validation(format!(
101                    "Service '{}' must have either 'command' or 'address' configured",
102                    name
103                )));
104            }
105        }
106        Ok(())
107    }
108
109    /// Get a service definition by name
110    pub fn get_service(&self, name: &str) -> Option<&ServiceDefinition> {
111        self.services.get(name)
112    }
113
114    /// Add a service definition
115    pub fn add_service(&mut self, name: impl Into<String>, service: ServiceDefinition) {
116        self.services.insert(name.into(), service);
117    }
118
119    /// Convert service definition to ExtensionService
120    pub fn create_service(&self, name: &str) -> Result<ExtensionService, ConfigError> {
121        let def = self
122            .get_service(name)
123            .ok_or_else(|| ConfigError::ServiceNotFound(name.to_string()))?;
124
125        let transport = def.to_transport_config();
126
127        let mut service = ExtensionService::new(name, &def.version);
128        service = service.description(&def.description);
129        service = service.transport(transport);
130
131        for cap in &def.capabilities {
132            service = service.capability(super::service::Capability::new(cap));
133        }
134
135        Ok(service)
136    }
137
138    /// List all service names
139    pub fn service_names(&self) -> Vec<&String> {
140        self.services.keys().collect()
141    }
142}
143
144impl Default for MatrixRpcConfig {
145    fn default() -> Self {
146        Self::new()
147    }
148}
149
150/// Global configuration settings
151#[derive(Debug, Clone, Serialize, Deserialize)]
152pub struct GlobalConfig {
153    /// Default connection timeout in seconds
154    #[serde(default = "default_timeout_secs")]
155    pub timeout_secs: u64,
156
157    /// Default heartbeat interval in seconds
158    #[serde(default = "default_heartbeat_interval")]
159    pub heartbeat_interval_secs: u64,
160
161    /// Maximum retry attempts for connections
162    #[serde(default = "default_max_retries")]
163    pub max_retries: u32,
164
165    /// Enable debug logging
166    #[serde(default)]
167    pub debug: bool,
168
169    /// Log level (trace, debug, info, warn, error)
170    #[serde(default = "default_log_level")]
171    pub log_level: String,
172}
173
/// Built-in connection timeout, in seconds.
fn default_timeout_secs() -> u64 {
    30_u64
}
177
/// Built-in heartbeat interval, in seconds.
fn default_heartbeat_interval() -> u64 {
    30_u64
}
181
/// Built-in maximum number of connection retry attempts.
fn default_max_retries() -> u32 {
    3_u32
}
185
/// Built-in log level name.
fn default_log_level() -> String {
    String::from("info")
}
189
/// Serde default helper for boolean fields that are enabled unless explicitly
/// turned off.
fn default_true() -> bool {
    true
}
193
194impl Default for GlobalConfig {
195    fn default() -> Self {
196        Self {
197            timeout_secs: default_timeout_secs(),
198            heartbeat_interval_secs: default_heartbeat_interval(),
199            max_retries: default_max_retries(),
200            debug: false,
201            log_level: default_log_level(),
202        }
203    }
204}
205
206/// Service definition in configuration
207#[derive(Debug, Clone, Serialize, Deserialize, Default)]
208pub struct ServiceDefinition {
209    /// Service version
210    #[serde(default)]
211    pub version: String,
212
213    /// Service description
214    #[serde(default)]
215    pub description: String,
216
217    /// Transport type
218    #[serde(rename = "type", default)]
219    pub transport_type: ServiceTransportType,
220
221    /// Command to execute (for stdio transport)
222    #[serde(default)]
223    pub command: Option<String>,
224
225    /// Command arguments
226    #[serde(default)]
227    pub args: Vec<String>,
228
229    /// Environment variables
230    #[serde(default)]
231    pub env: HashMap<String, String>,
232
233    /// Working directory
234    #[serde(default)]
235    pub cwd: Option<String>,
236
237    /// Network address (for TCP/WebSocket transport)
238    #[serde(default)]
239    pub address: Option<String>,
240
241    /// Port number (for TCP transport)
242    #[serde(default)]
243    pub port: Option<u16>,
244
245    /// Connection timeout in seconds
246    #[serde(default)]
247    pub timeout_secs: Option<u64>,
248
249    /// Enable auto-reconnect
250    #[serde(default = "default_true")]
251    pub auto_reconnect: bool,
252
253    /// Maximum retry attempts
254    #[serde(default)]
255    pub max_retries: Option<u32>,
256
257    /// Heartbeat interval in seconds
258    #[serde(default)]
259    pub heartbeat_interval_secs: Option<u64>,
260
261    /// Capabilities provided by this service
262    #[serde(default)]
263    pub capabilities: Vec<String>,
264
265    /// Service-specific configuration
266    #[serde(default)]
267    pub config: HashMap<String, serde_json::Value>,
268}
269
270impl ServiceDefinition {
271    /// Create a new stdio service definition
272    pub fn stdio(command: impl Into<String>) -> Self {
273        Self {
274            version: String::new(),
275            description: String::new(),
276            transport_type: ServiceTransportType::Stdio,
277            command: Some(command.into()),
278            args: Vec::new(),
279            env: HashMap::new(),
280            cwd: None,
281            address: None,
282            port: None,
283            timeout_secs: None,
284            auto_reconnect: true,
285            max_retries: None,
286            heartbeat_interval_secs: None,
287            capabilities: Vec::new(),
288            config: HashMap::new(),
289        }
290    }
291
292    /// Create a new TCP service definition
293    pub fn tcp(address: impl Into<String>, port: u16) -> Self {
294        Self {
295            version: String::new(),
296            description: String::new(),
297            transport_type: ServiceTransportType::Tcp,
298            command: None,
299            args: Vec::new(),
300            env: HashMap::new(),
301            cwd: None,
302            address: Some(address.into()),
303            port: Some(port),
304            timeout_secs: None,
305            auto_reconnect: true,
306            max_retries: None,
307            heartbeat_interval_secs: None,
308            capabilities: Vec::new(),
309            config: HashMap::new(),
310        }
311    }
312
313    /// Add a command argument
314    pub fn arg(mut self, arg: impl Into<String>) -> Self {
315        self.args.push(arg.into());
316        self
317    }
318
319    /// Add an environment variable
320    pub fn env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
321        self.env.insert(key.into(), value.into());
322        self
323    }
324
325    /// Set version
326    pub fn version(mut self, version: impl Into<String>) -> Self {
327        self.version = version.into();
328        self
329    }
330
331    /// Set description
332    pub fn description(mut self, desc: impl Into<String>) -> Self {
333        self.description = desc.into();
334        self
335    }
336
337    /// Add a capability
338    pub fn capability(mut self, cap: impl Into<String>) -> Self {
339        self.capabilities.push(cap.into());
340        self
341    }
342
343    /// Convert to TransportConfig
344    pub fn to_transport_config(&self) -> TransportConfig {
345        TransportConfig {
346            transport_type: self.transport_type.into(),
347            address: self.address.clone(),
348            port: self.port,
349            command: self.command.clone(),
350            args: self.args.clone(),
351            env: self.env.clone(),
352            cwd: self.cwd.clone(),
353            timeout_secs: self.timeout_secs.unwrap_or(default_timeout_secs()),
354            auto_reconnect: self.auto_reconnect,
355            max_retries: self.max_retries.unwrap_or(default_max_retries()),
356            heartbeat_interval_secs: self.heartbeat_interval_secs
357                .unwrap_or(default_heartbeat_interval()),
358        }
359    }
360}
361
362/// Service transport type (for TOML serialization)
363#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
364#[serde(rename_all = "lowercase")]
365pub enum ServiceTransportType {
366    /// Stdio transport
367    Stdio,
368    /// TCP transport
369    Tcp,
370    /// WebSocket transport
371    WebSocket,
372    /// Unix socket transport (Unix only)
373    #[cfg(unix)]
374    Unix,
375}
376
377impl Default for ServiceTransportType {
378    fn default() -> Self {
379        Self::Stdio
380    }
381}
382
383impl From<ServiceTransportType> for TransportType {
384    fn from(value: ServiceTransportType) -> Self {
385        match value {
386            ServiceTransportType::Stdio => TransportType::Stdio,
387            ServiceTransportType::Tcp => TransportType::Tcp,
388            ServiceTransportType::WebSocket => TransportType::WebSocket,
389            #[cfg(unix)]
390            ServiceTransportType::Unix => TransportType::Unix,
391        }
392    }
393}
394
395impl From<TransportType> for ServiceTransportType {
396    fn from(value: TransportType) -> Self {
397        match value {
398            TransportType::Stdio => ServiceTransportType::Stdio,
399            TransportType::Tcp => ServiceTransportType::Tcp,
400            TransportType::WebSocket => ServiceTransportType::WebSocket,
401            #[cfg(unix)]
402            TransportType::Unix => ServiceTransportType::Unix,
403        }
404    }
405}
406
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let cfg = MatrixRpcConfig::new();
        assert_eq!(cfg.global.timeout_secs, 30);
        assert!(cfg.services.is_empty());
    }

    #[test]
    fn test_service_definition_stdio() {
        let def = ServiceDefinition::stdio("my-server")
            .arg("--port")
            .arg("8080")
            .env("DEBUG", "1")
            .version("1.0.0");

        assert_eq!(def.command.as_deref(), Some("my-server"));
        assert_eq!(def.args, ["--port", "8080"]);
        assert_eq!(def.env.get("DEBUG").map(String::as_str), Some("1"));
        assert_eq!(def.version, "1.0.0");
    }

    #[test]
    fn test_service_definition_tcp() {
        let def = ServiceDefinition::tcp("localhost", 8080)
            .version("2.0.0")
            .capability("tools");

        assert_eq!(def.address.as_deref(), Some("localhost"));
        assert_eq!(def.port, Some(8080));
        assert_eq!(def.version, "2.0.0");
        assert!(def.capabilities.iter().any(|c| c == "tools"));
    }

    #[test]
    fn test_parse_toml() {
        let input = r#"
[global]
timeout_secs = 60
debug = true

[services.my-server]
version = "1.0.0"
description = "My test server"
type = "stdio"
command = "my-server"
args = ["--verbose"]

[services.my-server.env]
DEBUG = "1"

[services.tcp-server]
type = "tcp"
address = "127.0.0.1"
port = 9000
"#;

        let parsed: MatrixRpcConfig = toml::from_str(input).unwrap();
        assert_eq!(parsed.global.timeout_secs, 60);
        assert!(parsed.global.debug);
        assert!(parsed.services.contains_key("my-server"));
        assert!(parsed.services.contains_key("tcp-server"));

        let stdio = &parsed.services["my-server"];
        assert_eq!(stdio.version, "1.0.0");
        assert_eq!(stdio.command.as_deref(), Some("my-server"));
        assert_eq!(stdio.env.get("DEBUG").map(String::as_str), Some("1"));

        let tcp = &parsed.services["tcp-server"];
        assert_eq!(tcp.address.as_deref(), Some("127.0.0.1"));
        assert_eq!(tcp.port, Some(9000));
    }

    #[test]
    fn test_config_validation() {
        let mut cfg = MatrixRpcConfig::new();

        // A stdio service with a command is valid.
        cfg.add_service("valid", ServiceDefinition::stdio("server"));
        assert!(cfg.validate().is_ok());

        // A service with neither `command` nor `address` must be rejected.
        let unreachable = ServiceDefinition {
            command: None,
            address: None,
            ..Default::default()
        };
        cfg.services.insert("invalid".to_string(), unreachable);
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_create_service() {
        let mut cfg = MatrixRpcConfig::new();
        let def = ServiceDefinition::stdio("test-server")
            .version("1.0.0")
            .capability("tools");
        cfg.add_service("test", def);

        let built = cfg.create_service("test").unwrap();
        assert_eq!(built.name, "test");
        assert_eq!(built.version, "1.0.0");
        assert!(built.has_capability("tools"));
    }
}