1use dotenvy;
2use snafu::prelude::*;
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5
/// Names of the MCP-specific environment variables this module knows about.
///
/// `McpEnvironment::get_environment_info` looks each of these up and, when a
/// value is present, includes it in the returned summary.
pub const MCP_ENVIRONMENT_VARIABLES: &[&str] = &[
    "MCP_SERVER_NAME",
    "MCP_SERVER_VERSION",
    "MCP_LOG_LEVEL",
    "MCP_CONFIG_PATH",
    "MCP_DATA_DIR",
    "MCP_TEMP_DIR",
    "MCP_MAX_MESSAGE_SIZE",
    "MCP_TIMEOUT_SECONDS",
];
17
/// A set of environment variables together with a record of where they came from.
#[derive(Debug, Clone)]
pub struct McpEnvironment {
    /// Current variable values, keyed by variable name.
    pub variables: HashMap<String, String>,
    /// Sources that have been applied, in the order they were loaded.
    pub sources: Vec<EnvironmentSource>,
    /// Copy of `variables` saved by `create_snapshot`, or `None` when no snapshot exists.
    pub snapshot: Option<HashMap<String, String>>,
}
24
/// Describes one origin of environment variables.
#[derive(Debug, Clone)]
pub enum EnvironmentSource {
    /// A `.env`-style file at the given path.
    EnvFile(PathBuf),
    /// The environment of the running process.
    ProcessEnv,
    /// A server configuration, identified by name.
    ServerConfig(String),
    /// Variables supplied directly by the caller.
    Custom(HashMap<String, String>),
}
32
/// Errors produced while loading, validating, or manipulating an MCP environment.
#[derive(Debug, Snafu)]
pub enum EnvironmentError {
    /// A `.env` file could not be loaded; carries the path and the underlying `dotenvy` error.
    #[snafu(display("Failed to load .env file {}: {}", path.display(), source))]
    EnvFileLoad {
        path: PathBuf,
        source: dotenvy::Error,
    },

    /// A single required MCP variable is absent.
    #[snafu(display("Missing required MCP environment variable: {}", var_name))]
    MissingMcpVariable { var_name: String },

    /// The environment configuration is invalid (also used when variable
    /// expansion exceeds its iteration limit).
    #[snafu(display("Invalid MCP environment configuration: {}", message))]
    InvalidMcpConfig { message: String },

    /// Validation against a list of required variables failed.
    #[snafu(display("MCP environment validation failed: {}", message))]
    McpValidationFailed { message: String },

    /// A snapshot operation could not be carried out (e.g. nothing to restore).
    #[snafu(display("Environment snapshot error: {}", message))]
    SnapshotError { message: String },
}
53
/// Result alias whose error type is [`EnvironmentError`].
pub type Result<T> = std::result::Result<T, EnvironmentError>;
55
56impl McpEnvironment {
57 pub fn new() -> Self {
58 McpEnvironment {
59 variables: HashMap::new(),
60 sources: Vec::new(),
61 snapshot: None,
62 }
63 }
64
65 pub fn with_process_env() -> Self {
66 let mut env = McpEnvironment::new();
67 env.load_process_env();
68 env
69 }
70
71 pub fn load_process_env(&mut self) {
72 for (key, value) in std::env::vars() {
73 self.variables.insert(key, value);
74 }
75 self.sources.push(EnvironmentSource::ProcessEnv);
76 }
77
78 pub fn load_env_file<P: AsRef<Path>>(&mut self, path: P) -> Result<()> {
79 let path = path.as_ref();
80
81 match dotenvy::from_path(path) {
83 Ok(_) => {
84 self.load_process_env();
86 self.sources
87 .push(EnvironmentSource::EnvFile(path.to_path_buf()));
88 Ok(())
89 }
90 Err(e) => Err(EnvironmentError::EnvFileLoad {
91 path: path.to_path_buf(),
92 source: e,
93 }),
94 }
95 }
96
97 pub fn set_server_config(&mut self, config_name: String, vars: HashMap<String, String>) {
98 for (key, value) in &vars {
99 self.variables.insert(key.clone(), value.clone());
100 }
101 self.sources
102 .push(EnvironmentSource::ServerConfig(config_name));
103 }
104
105 pub fn set_custom(&mut self, vars: HashMap<String, String>) {
106 for (key, value) in &vars {
107 self.variables.insert(key.clone(), value.clone());
108 }
109 self.sources.push(EnvironmentSource::Custom(vars));
110 }
111
112 pub fn get(&self, key: &str) -> Option<&String> {
113 self.variables.get(key)
114 }
115
116 pub fn set(&mut self, key: String, value: String) {
117 self.variables.insert(key, value);
118 }
119
120 pub fn create_snapshot(&mut self) {
122 self.snapshot = Some(self.variables.clone());
123 }
124
125 pub fn restore_from_snapshot(&mut self) -> Result<()> {
127 match &self.snapshot {
128 Some(snapshot) => {
129 self.variables = snapshot.clone();
130 Ok(())
131 }
132 None => Err(EnvironmentError::SnapshotError {
133 message: "No snapshot available to restore from".to_string(),
134 }),
135 }
136 }
137
138 pub fn clear_snapshot(&mut self) {
140 self.snapshot = None;
141 }
142
143 pub fn get_environment_info(&self) -> HashMap<String, String> {
145 let mut info = HashMap::new();
146
147 info.insert("source_count".to_string(), self.sources.len().to_string());
149 info.insert(
150 "variable_count".to_string(),
151 self.variables.len().to_string(),
152 );
153 info.insert(
154 "has_snapshot".to_string(),
155 self.snapshot.is_some().to_string(),
156 );
157
158 let source_types: Vec<String> = self
160 .sources
161 .iter()
162 .map(|s| match s {
163 EnvironmentSource::ProcessEnv => "ProcessEnv".to_string(),
164 EnvironmentSource::EnvFile(path) => format!("EnvFile({})", path.display()),
165 EnvironmentSource::ServerConfig(name) => format!("ServerConfig({name})"),
166 EnvironmentSource::Custom(_) => "Custom".to_string(),
167 })
168 .collect();
169 info.insert("sources".to_string(), source_types.join(", "));
170
171 for mcp_var in MCP_ENVIRONMENT_VARIABLES {
173 if let Some(value) = self.variables.get(*mcp_var) {
174 info.insert(format!("mcp_{}", mcp_var.to_lowercase()), value.clone());
175 }
176 }
177
178 info
179 }
180
181 pub fn expand_variables(&self, text: &str) -> Result<String> {
182 let mut result = text.to_string();
183
184 let mut changed = true;
186 let mut iterations = 0;
187 const MAX_ITERATIONS: usize = 10; while changed && iterations < MAX_ITERATIONS {
190 changed = false;
191 iterations += 1;
192
193 while let Some(start) = result.find("${") {
195 if let Some(end) = result[start..].find('}') {
196 let var_name = &result[start + 2..start + end];
197 let replacement = self.variables.get(var_name).cloned().unwrap_or_else(|| {
198 std::env::var(var_name).unwrap_or_default()
200 });
201
202 result.replace_range(start..start + end + 1, &replacement);
203 changed = true;
204 } else {
205 break;
206 }
207 }
208
209 let mut pos = 0;
211 while let Some(dollar_pos) = result[pos..].find('$') {
212 let abs_pos = pos + dollar_pos;
213
214 if result.chars().nth(abs_pos + 1) == Some('{') {
216 pos = abs_pos + 1;
217 continue;
218 }
219
220 let var_start = abs_pos + 1;
222 let var_end = result[var_start..]
223 .chars()
224 .take_while(|c| c.is_alphanumeric() || *c == '_')
225 .count()
226 + var_start;
227
228 if var_end > var_start {
229 let var_name = &result[var_start..var_end];
230 let replacement = self.variables.get(var_name).cloned().unwrap_or_else(|| {
231 std::env::var(var_name).unwrap_or_default()
233 });
234
235 result.replace_range(abs_pos..var_end, &replacement);
236 changed = true;
237 pos = abs_pos + replacement.len();
238 } else {
239 pos = abs_pos + 1;
240 }
241 }
242 }
243
244 if iterations >= MAX_ITERATIONS {
245 return Err(EnvironmentError::InvalidMcpConfig {
246 message: "Too many variable expansion iterations - possible circular reference"
247 .to_string(),
248 });
249 }
250
251 Ok(result)
252 }
253}
254
255impl Default for McpEnvironment {
256 fn default() -> Self {
257 Self::new()
258 }
259}
260
261pub fn load_mcp_environment(sources: &[EnvironmentSource]) -> Result<McpEnvironment> {
263 let mut env = McpEnvironment::new();
264
265 for source in sources {
266 match source {
267 EnvironmentSource::ProcessEnv => {
268 env.load_process_env();
269 }
270 EnvironmentSource::EnvFile(path) => {
271 env.load_env_file(path)?;
272 }
273 EnvironmentSource::Custom(vars) => {
274 env.set_custom(vars.clone());
275 }
276 EnvironmentSource::ServerConfig(config_name) => {
277 let mut config_vars = HashMap::new();
280 config_vars.insert("MCP_SERVER_CONFIG".to_string(), config_name.clone());
281 env.set_server_config(config_name.clone(), config_vars);
282 }
283 }
284 }
285
286 Ok(env)
287}
288
289pub fn validate_mcp_environment(environment: &McpEnvironment, requirements: &[&str]) -> Result<()> {
291 let missing_vars: Vec<&str> = requirements
292 .iter()
293 .filter(|&var| environment.get(var).is_none())
294 .copied()
295 .collect();
296
297 if !missing_vars.is_empty() {
298 return Err(EnvironmentError::McpValidationFailed {
299 message: format!("Missing required variables: {}", missing_vars.join(", ")),
300 });
301 }
302
303 Ok(())
304}
305
306pub fn get_environment_info() -> HashMap<String, String> {
308 let env = McpEnvironment::with_process_env();
309 env.get_environment_info()
310}
311
#[cfg(test)]
mod tests {
    use super::*;

    /// Looks up `key` and exposes the value as `&str` for terse comparisons.
    fn value_of<'a>(env: &'a McpEnvironment, key: &str) -> Option<&'a str> {
        env.get(key).map(String::as_str)
    }

    /// Builds an environment from `(name, value)` pairs via `set`.
    fn env_with(pairs: &[(&str, &str)]) -> McpEnvironment {
        let mut env = McpEnvironment::new();
        for (name, value) in pairs {
            env.set(name.to_string(), value.to_string());
        }
        env
    }

    // A fresh environment holds nothing at all.
    #[test]
    fn test_mcp_environment_new() {
        let fresh = McpEnvironment::new();
        assert_eq!(fresh.variables.len(), 0);
        assert_eq!(fresh.sources.len(), 0);
        assert!(fresh.snapshot.is_none());
    }

    // `set` followed by `get` round-trips; unknown keys yield `None`.
    #[test]
    fn test_mcp_environment_set_get() {
        let env = env_with(&[("MCP_SERVER_NAME", "just-mcp")]);

        assert_eq!(value_of(&env, "MCP_SERVER_NAME"), Some("just-mcp"));
        assert_eq!(value_of(&env, "NONEXISTENT"), None);
    }

    // Snapshot lifecycle: create, modify, restore, clear, restore-fails.
    #[test]
    fn test_mcp_environment_snapshot() {
        let mut env = env_with(&[("MCP_SERVER_NAME", "just-mcp")]);

        env.create_snapshot();
        assert!(env.snapshot.is_some());

        env.set("MCP_SERVER_NAME".to_string(), "modified".to_string());
        assert_eq!(value_of(&env, "MCP_SERVER_NAME"), Some("modified"));

        env.restore_from_snapshot().unwrap();
        assert_eq!(value_of(&env, "MCP_SERVER_NAME"), Some("just-mcp"));

        env.clear_snapshot();
        assert!(env.snapshot.is_none());

        assert!(env.restore_from_snapshot().is_err());
    }

    // Server-config variables are merged and counted as one source.
    #[test]
    fn test_mcp_environment_server_config() {
        let mut env = McpEnvironment::new();

        let config_vars = HashMap::from([
            ("MCP_LOG_LEVEL".to_string(), "debug".to_string()),
            ("MCP_TIMEOUT_SECONDS".to_string(), "30".to_string()),
        ]);
        env.set_server_config("production".to_string(), config_vars);

        assert_eq!(value_of(&env, "MCP_LOG_LEVEL"), Some("debug"));
        assert_eq!(value_of(&env, "MCP_TIMEOUT_SECONDS"), Some("30"));
        assert_eq!(env.sources.len(), 1);
    }

    // The info map reports counts, snapshot state and known MCP variables.
    #[test]
    fn test_mcp_environment_info() {
        let mut env = env_with(&[("MCP_SERVER_NAME", "just-mcp"), ("MCP_LOG_LEVEL", "info")]);
        env.set_custom(HashMap::from([(
            "CUSTOM_VAR".to_string(),
            "custom_value".to_string(),
        )]));

        let info = env.get_environment_info();
        let field = |key: &str| info.get(key).map(String::as_str);

        assert_eq!(field("variable_count"), Some("3"));
        assert_eq!(field("source_count"), Some("1"));
        assert_eq!(field("has_snapshot"), Some("false"));
        assert_eq!(field("mcp_mcp_server_name"), Some("just-mcp"));
        assert_eq!(field("mcp_mcp_log_level"), Some("info"));
    }

    // Both `${NAME}` and `$NAME` forms are expanded.
    #[test]
    fn test_mcp_variable_expansion() {
        let env = env_with(&[("MCP_SERVER_NAME", "just-mcp"), ("MCP_LOG_LEVEL", "debug")]);

        let braced = env
            .expand_variables("Server: ${MCP_SERVER_NAME} (${MCP_LOG_LEVEL})")
            .unwrap();
        assert_eq!(braced, "Server: just-mcp (debug)");

        let bare = env.expand_variables("$MCP_SERVER_NAME running").unwrap();
        assert_eq!(bare, "just-mcp running");
    }

    // Validation passes when all names exist and names the missing ones otherwise.
    #[test]
    fn test_validate_mcp_environment() {
        let env = env_with(&[("MCP_SERVER_NAME", "just-mcp"), ("MCP_LOG_LEVEL", "info")]);

        assert!(validate_mcp_environment(&env, &["MCP_SERVER_NAME", "MCP_LOG_LEVEL"]).is_ok());

        let outcome = validate_mcp_environment(&env, &["MCP_SERVER_NAME", "MCP_MISSING_VAR"]);
        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().to_string().contains("MCP_MISSING_VAR"));
    }

    // Sources are applied in order and each one is recorded.
    #[test]
    fn test_load_mcp_environment_multiple_sources() {
        let custom_vars = HashMap::from([
            ("MCP_SERVER_NAME".to_string(), "just-mcp".to_string()),
            ("CUSTOM_VAR".to_string(), "custom_value".to_string()),
        ]);

        let sources = vec![
            EnvironmentSource::ProcessEnv,
            EnvironmentSource::Custom(custom_vars),
            EnvironmentSource::ServerConfig("production".to_string()),
        ];

        let env = load_mcp_environment(&sources).unwrap();

        assert_eq!(value_of(&env, "MCP_SERVER_NAME"), Some("just-mcp"));
        assert_eq!(value_of(&env, "CUSTOM_VAR"), Some("custom_value"));
        assert_eq!(value_of(&env, "MCP_SERVER_CONFIG"), Some("production"));
        assert_eq!(env.sources.len(), 3);
    }

    // The free function always reports the four fixed summary keys.
    #[test]
    fn test_get_environment_info_function() {
        let info = get_environment_info();

        for key in ["source_count", "variable_count", "has_snapshot", "sources"] {
            assert!(info.contains_key(key));
        }
    }
}