Skip to main content

org_mcp_server/
config.rs

1use config::{Config as ConfigRs, ConfigError};
2use org_core::{
3    LoggingConfig, OrgConfig, OrgModeError,
4    config::{build_config_with_file_and_env, load_logging_config, load_org_config},
5};
6use serde::{Deserialize, Serialize};
7use std::path::PathBuf;
8
9/// MCP server-specific configuration
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct ServerConfig {
12    #[serde(default = "default_max_connections")]
13    pub max_connections: usize,
14}
15
16impl Default for ServerConfig {
17    fn default() -> Self {
18        Self {
19            max_connections: default_max_connections(),
20        }
21    }
22}
23
24/// Complete MCP server application configuration
25#[derive(Debug, Clone)]
26pub struct ServerAppConfig {
27    pub org: OrgConfig,
28    pub server: ServerConfig,
29    pub logging: LoggingConfig,
30}
31
32impl ServerAppConfig {
33    /// Load server configuration from file and environment with CLI argument overrides
34    pub fn load(
35        config_file: Option<String>,
36        root_directory: Option<String>,
37        log_level: Option<String>,
38    ) -> Result<Self, OrgModeError> {
39        let org = load_org_config(config_file.as_deref(), root_directory.as_deref())?;
40        let server = Self::load_server_config(config_file.as_deref())?;
41        let logging = load_logging_config(config_file.as_deref(), log_level.as_deref())?;
42
43        Ok(Self {
44            org,
45            server,
46            logging,
47        })
48    }
49
50    fn load_server_config(config_file: Option<&str>) -> Result<ServerConfig, OrgModeError> {
51        let builder = ConfigRs::builder().set_default(
52            "server.max_connections",
53            default_max_connections().to_string(),
54        )?;
55
56        let config = build_config_with_file_and_env(config_file, builder)?;
57
58        let server_config: ServerConfig = config.get("server").map_err(|e: ConfigError| {
59            OrgModeError::ConfigError(format!("Failed to deserialize server config: {e}"))
60        })?;
61
62        Ok(server_config)
63    }
64
65    /// Save the configuration to a file
66    pub fn save_to_file(&self, path: &PathBuf) -> Result<(), OrgModeError> {
67        #[derive(Serialize)]
68        struct SavedConfig<'a> {
69            org: &'a OrgConfig,
70            server: &'a ServerConfig,
71            logging: &'a LoggingConfig,
72        }
73
74        if let Some(parent) = path.parent() {
75            std::fs::create_dir_all(parent).map_err(|e| {
76                OrgModeError::ConfigError(format!("Failed to create config directory: {e}"))
77            })?;
78        }
79
80        let saved = SavedConfig {
81            org: &self.org,
82            server: &self.server,
83            logging: &self.logging,
84        };
85
86        let content = toml::to_string_pretty(&saved)
87            .map_err(|e| OrgModeError::ConfigError(format!("Failed to serialize config: {e}")))?;
88
89        std::fs::write(path, content)
90            .map_err(|e| OrgModeError::ConfigError(format!("Failed to write config file: {e}")))?;
91
92        Ok(())
93    }
94}
95
/// Default value for `ServerConfig::max_connections`.
///
/// Used both as the serde field default and by `ServerConfig::default`.
fn default_max_connections() -> usize {
    const DEFAULT_MAX_CONNECTIONS: usize = 10;
    DEFAULT_MAX_CONNECTIONS
}
99
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn test_default_server_config() {
        let config = ServerConfig::default();
        assert_eq!(config.max_connections, 10);
    }

    #[test]
    fn test_server_config_missing_field_uses_default() {
        // An empty `server` table must still deserialize thanks to `#[serde(default)]`.
        let config: ServerConfig =
            toml::from_str("").expect("empty table should deserialize into ServerConfig");
        assert_eq!(config.max_connections, default_max_connections());
    }

    #[test]
    fn test_server_config_toml_round_trip() {
        let original = ServerConfig {
            max_connections: 42,
        };
        let text = toml::to_string(&original).expect("server config should serialize");
        let parsed: ServerConfig = toml::from_str(&text).expect("server config should parse");
        assert_eq!(parsed.max_connections, original.max_connections);
    }

    #[test]
    #[serial_test::serial]
    fn test_load_from_file() {
        let temp_dir = tempdir().unwrap();
        let config_path = temp_dir.path().join("config.toml");

        // Forward slashes keep the path valid inside a TOML basic string on Windows.
        let path_str = temp_dir.path().to_str().unwrap().replace('\\', "/");
        let test_config = format!(
            r#"
[org]
org_directory = "{}"

[server]
max_connections = 20

[logging]
level = "debug"
"#,
            path_str
        );

        std::fs::write(&config_path, test_config).unwrap();

        let config =
            ServerAppConfig::load(Some(config_path.to_str().unwrap().to_string()), None, None)
                .unwrap();

        assert_eq!(config.org.org_directory, path_str);
        assert_eq!(config.server.max_connections, 20);
        assert_eq!(config.logging.level, "debug");
    }

    #[test]
    #[serial_test::serial]
    fn test_cli_overrides() {
        let temp_dir = tempdir().unwrap();
        let config_path = temp_dir.path().join("config.toml");

        let test_config = format!(
            r#"
[org]
org_directory = "{}"

[server]
max_connections = 5
"#,
            temp_dir.path().to_str().unwrap().replace('\\', "/")
        );

        std::fs::write(&config_path, test_config).unwrap();

        let override_dir = tempdir().unwrap();
        let config = ServerAppConfig::load(
            Some(config_path.to_str().unwrap().to_string()),
            Some(override_dir.path().to_str().unwrap().to_string()),
            Some("trace".to_string()),
        )
        .unwrap();

        assert_eq!(
            config.org.org_directory,
            override_dir.path().to_str().unwrap()
        );
        assert_eq!(config.logging.level, "trace");
        // CLI overrides only touch the org directory and log level; the file's server
        // section must still be honoured.
        assert_eq!(config.server.max_connections, 5);
    }

    #[test]
    fn test_save_to_file() {
        /// Minimal view of a saved file, used to check it parses back as TOML.
        #[derive(Deserialize)]
        struct SavedServerSection {
            server: ServerConfig,
        }

        let temp_dir = tempdir().unwrap();
        let save_path = temp_dir.path().join("saved_config.toml");

        let config_dir = tempdir().unwrap();
        let config = ServerAppConfig {
            org: OrgConfig {
                org_directory: config_dir.path().to_str().unwrap().to_string(),
                org_default_notes_file: "test.org".to_string(),
                org_agenda_files: vec!["agenda.org".to_string()],
                org_agenda_text_search_extra_files: vec![],
                org_todo_keywords: vec!["TODO".to_string(), "|".to_string(), "DONE".to_string()],
            },
            server: ServerConfig {
                max_connections: 25,
            },
            logging: LoggingConfig {
                level: "warn".to_string(),
                file: "/tmp/server.log".to_string(),
            },
        };

        config
            .save_to_file(&save_path)
            .expect("saving config should succeed");

        assert!(save_path.exists());

        let content = std::fs::read_to_string(&save_path).unwrap();
        assert!(content.contains("max_connections = 25"));
        assert!(content.contains("level = \"warn\""));
        assert!(content.contains("org_default_notes_file = \"test.org\""));

        // Beyond substring checks, the file must be valid TOML with an intact server table.
        let parsed: SavedServerSection =
            toml::from_str(&content).expect("saved config should be valid TOML");
        assert_eq!(parsed.server.max_connections, 25);
    }

    #[test]
    fn test_save_to_file_creates_parent_directory() {
        let temp_dir = tempdir().unwrap();
        let nested_path = temp_dir
            .path()
            .join("nested")
            .join("dirs")
            .join("config.toml");

        let config_dir = tempdir().unwrap();
        let config = ServerAppConfig {
            org: OrgConfig {
                org_directory: config_dir.path().to_str().unwrap().to_string(),
                ..OrgConfig::default()
            },
            server: ServerConfig::default(),
            logging: LoggingConfig::default(),
        };

        config
            .save_to_file(&nested_path)
            .expect("saving into a missing nested directory should succeed");
        assert!(nested_path.exists());
        assert!(nested_path.parent().unwrap().exists());
    }

    #[test]
    #[serial_test::serial]
    #[cfg_attr(
        target_os = "windows",
        ignore = "Environment variable handling unreliable in Windows tests"
    )]
    fn test_env_var_server_override() {
        use temp_env::with_vars;

        let temp_dir = tempdir().unwrap();
        let temp_dir_path = temp_dir.path().to_str().unwrap();

        with_vars(
            [
                ("ORG_ORG__ORG_DIRECTORY", Some(temp_dir_path)),
                ("ORG_SERVER__MAX_CONNECTIONS", Some("50")),
            ],
            || {
                let config = ServerAppConfig::load(None, None, None).unwrap();
                assert_eq!(config.org.org_directory, temp_dir_path);
                assert_eq!(config.server.max_connections, 50);
            },
        );
    }

    #[test]
    #[serial_test::serial]
    fn test_load_server_config_extension_fallback() {
        let temp_dir = tempdir().unwrap();
        let config_dir = temp_dir.path().join(".config");
        std::fs::create_dir_all(&config_dir).unwrap();

        let yaml_config = r#"
server:
  max_connections: 15
org:
  org_directory: "/tmp"
logging:
  level: "info"
"#;

        let yaml_path = config_dir.join("config.yaml");
        std::fs::write(&yaml_path, yaml_config).unwrap();

        // The path is passed without an extension; the loader is expected to find config.yaml.
        let org_dir = tempdir().unwrap();
        let config = ServerAppConfig::load(
            Some(config_dir.join("config").to_str().unwrap().to_string()),
            Some(org_dir.path().to_str().unwrap().to_string()),
            None,
        )
        .expect("extension fallback should locate config.yaml");

        assert_eq!(config.server.max_connections, 15);
    }
}
287        assert!(config.is_ok());
288        let config = config.unwrap();
289        assert_eq!(config.server.max_connections, 15);
290    }
291}