just_mcp_lib/
environment.rs

1use dotenvy;
2use snafu::prelude::*;
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5
/// Names of the MCP-related environment variables this module knows about.
///
/// `McpEnvironment::get_environment_info` reports each of these that is set,
/// under the key `mcp_<name in lowercase>`.
pub const MCP_ENVIRONMENT_VARIABLES: &[&str] = &[
    "MCP_SERVER_NAME",
    "MCP_SERVER_VERSION",
    "MCP_LOG_LEVEL",
    "MCP_CONFIG_PATH",
    "MCP_DATA_DIR",
    "MCP_TEMP_DIR",
    "MCP_MAX_MESSAGE_SIZE",
    "MCP_TIMEOUT_SECONDS",
];
17
18#[derive(Debug, Clone)]
19pub struct McpEnvironment {
20    pub variables: HashMap<String, String>,
21    pub sources: Vec<EnvironmentSource>,
22    pub snapshot: Option<HashMap<String, String>>,
23}
24
/// Where a group of variables in an [`McpEnvironment`] came from.
#[derive(Clone, Debug)]
pub enum EnvironmentSource {
    /// A `.env` file, loaded through `McpEnvironment::load_env_file`.
    EnvFile(PathBuf),
    /// The current process environment (`std::env`).
    ProcessEnv,
    /// A named server configuration applied with `McpEnvironment::set_server_config`.
    ServerConfig(String),
    /// Explicit key/value pairs applied with `McpEnvironment::set_custom`;
    /// holds the map that was applied.
    Custom(HashMap<String, String>),
}
32
33#[derive(Debug, Snafu)]
34pub enum EnvironmentError {
35    #[snafu(display("Failed to load .env file {}: {}", path.display(), source))]
36    EnvFileLoad {
37        path: PathBuf,
38        source: dotenvy::Error,
39    },
40
41    #[snafu(display("Missing required MCP environment variable: {}", var_name))]
42    MissingMcpVariable { var_name: String },
43
44    #[snafu(display("Invalid MCP environment configuration: {}", message))]
45    InvalidMcpConfig { message: String },
46
47    #[snafu(display("MCP environment validation failed: {}", message))]
48    McpValidationFailed { message: String },
49
50    #[snafu(display("Environment snapshot error: {}", message))]
51    SnapshotError { message: String },
52}
53
54pub type Result<T> = std::result::Result<T, EnvironmentError>;
55
56impl McpEnvironment {
57    pub fn new() -> Self {
58        McpEnvironment {
59            variables: HashMap::new(),
60            sources: Vec::new(),
61            snapshot: None,
62        }
63    }
64
65    pub fn with_process_env() -> Self {
66        let mut env = McpEnvironment::new();
67        env.load_process_env();
68        env
69    }
70
71    pub fn load_process_env(&mut self) {
72        for (key, value) in std::env::vars() {
73            self.variables.insert(key, value);
74        }
75        self.sources.push(EnvironmentSource::ProcessEnv);
76    }
77
78    pub fn load_env_file<P: AsRef<Path>>(&mut self, path: P) -> Result<()> {
79        let path = path.as_ref();
80
81        // Use dotenvy to load the .env file
82        match dotenvy::from_path(path) {
83            Ok(_) => {
84                // Reload process environment to pick up the new variables
85                self.load_process_env();
86                self.sources
87                    .push(EnvironmentSource::EnvFile(path.to_path_buf()));
88                Ok(())
89            }
90            Err(e) => Err(EnvironmentError::EnvFileLoad {
91                path: path.to_path_buf(),
92                source: e,
93            }),
94        }
95    }
96
97    pub fn set_server_config(&mut self, config_name: String, vars: HashMap<String, String>) {
98        for (key, value) in &vars {
99            self.variables.insert(key.clone(), value.clone());
100        }
101        self.sources
102            .push(EnvironmentSource::ServerConfig(config_name));
103    }
104
105    pub fn set_custom(&mut self, vars: HashMap<String, String>) {
106        for (key, value) in &vars {
107            self.variables.insert(key.clone(), value.clone());
108        }
109        self.sources.push(EnvironmentSource::Custom(vars));
110    }
111
112    pub fn get(&self, key: &str) -> Option<&String> {
113        self.variables.get(key)
114    }
115
116    pub fn set(&mut self, key: String, value: String) {
117        self.variables.insert(key, value);
118    }
119
120    /// Create a snapshot of the current environment state
121    pub fn create_snapshot(&mut self) {
122        self.snapshot = Some(self.variables.clone());
123    }
124
125    /// Restore environment from snapshot
126    pub fn restore_from_snapshot(&mut self) -> Result<()> {
127        match &self.snapshot {
128            Some(snapshot) => {
129                self.variables = snapshot.clone();
130                Ok(())
131            }
132            None => Err(EnvironmentError::SnapshotError {
133                message: "No snapshot available to restore from".to_string(),
134            }),
135        }
136    }
137
138    /// Clear the snapshot
139    pub fn clear_snapshot(&mut self) {
140        self.snapshot = None;
141    }
142
143    /// Get environment info for MCP introspection
144    pub fn get_environment_info(&self) -> HashMap<String, String> {
145        let mut info = HashMap::new();
146
147        // Add source information
148        info.insert("source_count".to_string(), self.sources.len().to_string());
149        info.insert(
150            "variable_count".to_string(),
151            self.variables.len().to_string(),
152        );
153        info.insert(
154            "has_snapshot".to_string(),
155            self.snapshot.is_some().to_string(),
156        );
157
158        // Add source types
159        let source_types: Vec<String> = self
160            .sources
161            .iter()
162            .map(|s| match s {
163                EnvironmentSource::ProcessEnv => "ProcessEnv".to_string(),
164                EnvironmentSource::EnvFile(path) => format!("EnvFile({})", path.display()),
165                EnvironmentSource::ServerConfig(name) => format!("ServerConfig({name})"),
166                EnvironmentSource::Custom(_) => "Custom".to_string(),
167            })
168            .collect();
169        info.insert("sources".to_string(), source_types.join(", "));
170
171        // Add MCP-specific variables if present
172        for mcp_var in MCP_ENVIRONMENT_VARIABLES {
173            if let Some(value) = self.variables.get(*mcp_var) {
174                info.insert(format!("mcp_{}", mcp_var.to_lowercase()), value.clone());
175            }
176        }
177
178        info
179    }
180
181    pub fn expand_variables(&self, text: &str) -> Result<String> {
182        let mut result = text.to_string();
183
184        // Handle ${VAR} and $VAR syntax
185        let mut changed = true;
186        let mut iterations = 0;
187        const MAX_ITERATIONS: usize = 10; // Prevent infinite loops
188
189        while changed && iterations < MAX_ITERATIONS {
190            changed = false;
191            iterations += 1;
192
193            // Handle ${VAR} syntax
194            while let Some(start) = result.find("${") {
195                if let Some(end) = result[start..].find('}') {
196                    let var_name = &result[start + 2..start + end];
197                    let replacement = self.variables.get(var_name).cloned().unwrap_or_else(|| {
198                        // Check system environment as fallback
199                        std::env::var(var_name).unwrap_or_default()
200                    });
201
202                    result.replace_range(start..start + end + 1, &replacement);
203                    changed = true;
204                } else {
205                    break;
206                }
207            }
208
209            // Handle $VAR syntax (simple variable names)
210            let mut pos = 0;
211            while let Some(dollar_pos) = result[pos..].find('$') {
212                let abs_pos = pos + dollar_pos;
213
214                // Skip if it's ${VAR} syntax (already handled above)
215                if result.chars().nth(abs_pos + 1) == Some('{') {
216                    pos = abs_pos + 1;
217                    continue;
218                }
219
220                // Extract variable name (alphanumeric + underscore)
221                let var_start = abs_pos + 1;
222                let var_end = result[var_start..]
223                    .chars()
224                    .take_while(|c| c.is_alphanumeric() || *c == '_')
225                    .count()
226                    + var_start;
227
228                if var_end > var_start {
229                    let var_name = &result[var_start..var_end];
230                    let replacement = self.variables.get(var_name).cloned().unwrap_or_else(|| {
231                        // Check system environment as fallback
232                        std::env::var(var_name).unwrap_or_default()
233                    });
234
235                    result.replace_range(abs_pos..var_end, &replacement);
236                    changed = true;
237                    pos = abs_pos + replacement.len();
238                } else {
239                    pos = abs_pos + 1;
240                }
241            }
242        }
243
244        if iterations >= MAX_ITERATIONS {
245            return Err(EnvironmentError::InvalidMcpConfig {
246                message: "Too many variable expansion iterations - possible circular reference"
247                    .to_string(),
248            });
249        }
250
251        Ok(result)
252    }
253}
254
255impl Default for McpEnvironment {
256    fn default() -> Self {
257        Self::new()
258    }
259}
260
261/// Load MCP environment from multiple sources
262pub fn load_mcp_environment(sources: &[EnvironmentSource]) -> Result<McpEnvironment> {
263    let mut env = McpEnvironment::new();
264
265    for source in sources {
266        match source {
267            EnvironmentSource::ProcessEnv => {
268                env.load_process_env();
269            }
270            EnvironmentSource::EnvFile(path) => {
271                env.load_env_file(path)?;
272            }
273            EnvironmentSource::Custom(vars) => {
274                env.set_custom(vars.clone());
275            }
276            EnvironmentSource::ServerConfig(config_name) => {
277                // Load server-specific configuration
278                // This could be expanded to load from config files
279                let mut config_vars = HashMap::new();
280                config_vars.insert("MCP_SERVER_CONFIG".to_string(), config_name.clone());
281                env.set_server_config(config_name.clone(), config_vars);
282            }
283        }
284    }
285
286    Ok(env)
287}
288
289/// Validate MCP environment has required variables
290pub fn validate_mcp_environment(environment: &McpEnvironment, requirements: &[&str]) -> Result<()> {
291    let missing_vars: Vec<&str> = requirements
292        .iter()
293        .filter(|&var| environment.get(var).is_none())
294        .copied()
295        .collect();
296
297    if !missing_vars.is_empty() {
298        return Err(EnvironmentError::McpValidationFailed {
299            message: format!("Missing required variables: {}", missing_vars.join(", ")),
300        });
301    }
302
303    Ok(())
304}
305
306/// Get environment info for MCP introspection
307pub fn get_environment_info() -> HashMap<String, String> {
308    let env = McpEnvironment::with_process_env();
309    env.get_environment_info()
310}
311
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an owned `HashMap<String, String>` from string-slice pairs.
    fn string_map(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_string(), value.to_string()))
            .collect()
    }

    /// Build an environment whose variables are set via `McpEnvironment::set`
    /// (so no sources are recorded).
    fn env_with(pairs: &[(&str, &str)]) -> McpEnvironment {
        let mut env = McpEnvironment::new();
        for &(key, value) in pairs {
            env.set(key.to_string(), value.to_string());
        }
        env
    }

    #[test]
    fn test_mcp_environment_new() {
        let env = McpEnvironment::new();
        assert!(env.variables.is_empty());
        assert!(env.sources.is_empty());
        assert!(env.snapshot.is_none());
    }

    #[test]
    fn test_mcp_environment_set_get() {
        let env = env_with(&[("MCP_SERVER_NAME", "just-mcp")]);

        assert_eq!(env.get("MCP_SERVER_NAME").map(String::as_str), Some("just-mcp"));
        assert!(env.get("NONEXISTENT").is_none());
    }

    #[test]
    fn test_mcp_environment_snapshot() {
        let mut env = env_with(&[("MCP_SERVER_NAME", "just-mcp")]);

        env.create_snapshot();
        assert!(env.snapshot.is_some());

        env.set("MCP_SERVER_NAME".to_string(), "modified".to_string());
        assert_eq!(env.get("MCP_SERVER_NAME").map(String::as_str), Some("modified"));

        env.restore_from_snapshot()
            .expect("a snapshot was just created");
        assert_eq!(env.get("MCP_SERVER_NAME").map(String::as_str), Some("just-mcp"));

        env.clear_snapshot();
        assert!(env.snapshot.is_none());

        // With the snapshot gone, restoring must fail.
        assert!(env.restore_from_snapshot().is_err());
    }

    #[test]
    fn test_mcp_environment_server_config() {
        let mut env = McpEnvironment::new();
        env.set_server_config(
            "production".to_string(),
            string_map(&[("MCP_LOG_LEVEL", "debug"), ("MCP_TIMEOUT_SECONDS", "30")]),
        );

        assert_eq!(env.get("MCP_LOG_LEVEL").map(String::as_str), Some("debug"));
        assert_eq!(env.get("MCP_TIMEOUT_SECONDS").map(String::as_str), Some("30"));
        assert_eq!(env.sources.len(), 1);
    }

    #[test]
    fn test_mcp_environment_info() {
        let mut env = env_with(&[("MCP_SERVER_NAME", "just-mcp"), ("MCP_LOG_LEVEL", "info")]);
        env.set_custom(string_map(&[("CUSTOM_VAR", "custom_value")]));

        let info = env.get_environment_info();

        assert_eq!(info.get("variable_count").map(String::as_str), Some("3"));
        assert_eq!(info.get("source_count").map(String::as_str), Some("1"));
        assert_eq!(info.get("has_snapshot").map(String::as_str), Some("false"));
        assert_eq!(
            info.get("mcp_mcp_server_name").map(String::as_str),
            Some("just-mcp")
        );
        assert_eq!(info.get("mcp_mcp_log_level").map(String::as_str), Some("info"));
    }

    #[test]
    fn test_mcp_variable_expansion() {
        let env = env_with(&[("MCP_SERVER_NAME", "just-mcp"), ("MCP_LOG_LEVEL", "debug")]);

        let braced = env.expand_variables("Server: ${MCP_SERVER_NAME} (${MCP_LOG_LEVEL})");
        assert_eq!(braced.unwrap(), "Server: just-mcp (debug)");

        let bare = env.expand_variables("$MCP_SERVER_NAME running");
        assert_eq!(bare.unwrap(), "just-mcp running");
    }

    #[test]
    fn test_validate_mcp_environment() {
        let env = env_with(&[("MCP_SERVER_NAME", "just-mcp"), ("MCP_LOG_LEVEL", "info")]);

        // Every required variable is present.
        assert!(validate_mcp_environment(&env, &["MCP_SERVER_NAME", "MCP_LOG_LEVEL"]).is_ok());

        // One required variable is missing and must be named in the error.
        let err = validate_mcp_environment(&env, &["MCP_SERVER_NAME", "MCP_MISSING_VAR"])
            .expect_err("MCP_MISSING_VAR is not set");
        assert!(err.to_string().contains("MCP_MISSING_VAR"));
    }

    #[test]
    fn test_load_mcp_environment_multiple_sources() {
        let sources = [
            EnvironmentSource::ProcessEnv,
            EnvironmentSource::Custom(string_map(&[
                ("MCP_SERVER_NAME", "just-mcp"),
                ("CUSTOM_VAR", "custom_value"),
            ])),
            EnvironmentSource::ServerConfig("production".to_string()),
        ];

        let env = load_mcp_environment(&sources).unwrap();

        assert_eq!(env.get("MCP_SERVER_NAME").map(String::as_str), Some("just-mcp"));
        assert_eq!(env.get("CUSTOM_VAR").map(String::as_str), Some("custom_value"));
        assert_eq!(
            env.get("MCP_SERVER_CONFIG").map(String::as_str),
            Some("production")
        );
        assert_eq!(env.sources.len(), 3);
    }

    #[test]
    fn test_get_environment_info_function() {
        // Reads the real process environment, so only the presence of keys is checked.
        let info = get_environment_info();

        for key in ["source_count", "variable_count", "has_snapshot", "sources"] {
            assert!(info.contains_key(key), "missing key {key}");
        }
    }
}