1use std::collections::HashMap;
48use std::str::FromStr;
49use std::sync::{Arc, RwLock};
50
51use crate::{ForgeError, Result};
52
/// Source of environment-variable values.
///
/// Lets callers swap the real process environment ([`RealEnvProvider`]) for an
/// in-memory one ([`MockEnvProvider`]). Implementations must be `Send + Sync`.
pub trait EnvProvider: Send + Sync {
    /// Returns the value stored under `key`, or `None` if there is none.
    fn get(&self, key: &str) -> Option<String>;

    /// Reports whether `key` has a value (an empty string counts as a value).
    ///
    /// The default performs a full [`EnvProvider::get`] and discards the
    /// result, so any side effect of `get` also happens here.
    fn contains(&self, key: &str) -> bool {
        let found = self.get(key);
        found.is_some()
    }
}
66
67#[derive(Debug, Clone, Default)]
69pub struct RealEnvProvider;
70
71impl RealEnvProvider {
72 pub fn new() -> Self {
74 Self
75 }
76}
77
78impl EnvProvider for RealEnvProvider {
79 fn get(&self, key: &str) -> Option<String> {
80 std::env::var(key).ok()
81 }
82}
83
84#[derive(Debug, Clone, Default)]
89pub struct MockEnvProvider {
90 vars: HashMap<String, String>,
92 accessed: Arc<RwLock<Vec<String>>>,
94}
95
96impl MockEnvProvider {
97 pub fn new() -> Self {
99 Self {
100 vars: HashMap::new(),
101 accessed: Arc::new(RwLock::new(Vec::new())),
102 }
103 }
104
105 pub fn with_vars(vars: HashMap<String, String>) -> Self {
107 Self {
108 vars,
109 accessed: Arc::new(RwLock::new(Vec::new())),
110 }
111 }
112
113 pub fn set(&mut self, key: impl Into<String>, value: impl Into<String>) {
115 self.vars.insert(key.into(), value.into());
116 }
117
118 pub fn remove(&mut self, key: &str) {
120 self.vars.remove(key);
121 }
122
123 pub fn all(&self) -> &HashMap<String, String> {
125 &self.vars
126 }
127
128 pub fn accessed_keys(&self) -> Vec<String> {
130 self.accessed.read().unwrap().clone()
131 }
132
133 pub fn was_accessed(&self, key: &str) -> bool {
135 self.accessed.read().unwrap().contains(&key.to_string())
136 }
137
138 pub fn clear_accessed(&self) {
140 self.accessed.write().unwrap().clear();
141 }
142
143 pub fn assert_accessed(&self, key: &str) {
145 assert!(
146 self.was_accessed(key),
147 "Expected env var '{}' to be accessed, but it wasn't. Accessed keys: {:?}",
148 key,
149 self.accessed_keys()
150 );
151 }
152
153 pub fn assert_not_accessed(&self, key: &str) {
155 assert!(
156 !self.was_accessed(key),
157 "Expected env var '{}' to NOT be accessed, but it was",
158 key
159 );
160 }
161}
162
163impl EnvProvider for MockEnvProvider {
164 fn get(&self, key: &str) -> Option<String> {
165 self.accessed.write().unwrap().push(key.to_string());
167 self.vars.get(key).cloned()
168 }
169}
170
171pub trait EnvAccess {
176 fn env_provider(&self) -> &dyn EnvProvider;
178
179 fn env(&self, key: &str) -> Option<String> {
183 self.env_provider().get(key)
184 }
185
186 fn env_or(&self, key: &str, default: &str) -> String {
190 self.env_provider()
191 .get(key)
192 .unwrap_or_else(|| default.to_string())
193 }
194
195 fn env_require(&self, key: &str) -> Result<String> {
199 self.env_provider().get(key).ok_or_else(|| {
200 ForgeError::Config(format!("Required environment variable '{}' not set", key))
201 })
202 }
203
204 fn env_parse<T: FromStr>(&self, key: &str) -> Result<T>
210 where
211 T::Err: std::fmt::Display,
212 {
213 let value = self.env_require(key)?;
214 value.parse().map_err(|e: T::Err| {
215 ForgeError::Config(format!(
216 "Failed to parse env var '{}' value '{}': {}",
217 key, value, e
218 ))
219 })
220 }
221
222 fn env_parse_or<T: FromStr>(&self, key: &str, default: T) -> Result<T>
227 where
228 T::Err: std::fmt::Display,
229 {
230 match self.env_provider().get(key) {
231 Some(value) => value.parse().map_err(|e: T::Err| {
232 ForgeError::Config(format!(
233 "Failed to parse env var '{}' value '{}': {}",
234 key, value, e
235 ))
236 }),
237 None => Ok(default),
238 }
239 }
240
241 fn env_contains(&self, key: &str) -> bool {
243 self.env_provider().contains(key)
244 }
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_real_env_provider() {
        // SAFETY: `set_var`/`remove_var` are only sound while no other thread
        // reads or writes the process environment. No other test in this
        // module touches the real environment, and the variable name is
        // unique to this test.
        // NOTE(review): tests elsewhere in the crate may still read the
        // environment concurrently under the parallel test harness; consider a
        // serial guard or `--test-threads=1` if that becomes the case.
        unsafe {
            std::env::set_var("FORGE_TEST_VAR", "test_value");
        }

        let provider = RealEnvProvider::new();
        assert_eq!(
            provider.get("FORGE_TEST_VAR"),
            Some("test_value".to_string())
        );
        assert!(provider.contains("FORGE_TEST_VAR"));
        assert!(provider.get("FORGE_NONEXISTENT_VAR").is_none());

        // SAFETY: same reasoning as the `set_var` call above.
        unsafe {
            std::env::remove_var("FORGE_TEST_VAR");
        }
    }

    #[test]
    fn test_mock_env_provider() {
        let mut provider = MockEnvProvider::new();
        provider.set("API_KEY", "secret123");
        provider.set("TIMEOUT", "30");

        assert_eq!(provider.get("API_KEY"), Some("secret123".to_string()));
        assert_eq!(provider.get("TIMEOUT"), Some("30".to_string()));
        assert!(provider.get("MISSING").is_none());

        // Misses are recorded just like hits.
        assert!(provider.was_accessed("API_KEY"));
        assert!(provider.was_accessed("TIMEOUT"));
        assert!(provider.was_accessed("MISSING"));
        provider.assert_accessed("API_KEY");
    }

    #[test]
    fn test_mock_provider_with_vars() {
        let vars = HashMap::from([
            ("KEY1".to_string(), "value1".to_string()),
            ("KEY2".to_string(), "value2".to_string()),
        ]);
        let provider = MockEnvProvider::with_vars(vars);

        assert_eq!(provider.get("KEY1"), Some("value1".to_string()));
        assert_eq!(provider.get("KEY2"), Some("value2".to_string()));
    }

    #[test]
    fn test_mock_remove() {
        let mut provider = MockEnvProvider::new();
        provider.set("KEY", "value");
        provider.remove("KEY");

        assert!(provider.get("KEY").is_none());
        assert!(provider.all().is_empty());
    }

    #[test]
    fn test_clear_accessed() {
        let mut provider = MockEnvProvider::new();
        provider.set("KEY", "value");

        provider.get("KEY");
        assert!(!provider.accessed_keys().is_empty());

        provider.clear_accessed();
        assert!(provider.accessed_keys().is_empty());
    }

    #[test]
    fn test_accessed_keys_preserve_order_and_duplicates() {
        let provider = MockEnvProvider::new();
        provider.get("B");
        provider.get("A");
        provider.get("B");

        assert_eq!(provider.accessed_keys(), vec!["B", "A", "B"]);
    }

    #[test]
    #[should_panic(expected = "to NOT be accessed")]
    fn test_assert_not_accessed_panics_after_get() {
        let provider = MockEnvProvider::new();
        provider.get("SECRET");
        provider.assert_not_accessed("SECRET");
    }

    struct TestEnvContext {
        provider: MockEnvProvider,
    }

    impl EnvAccess for TestEnvContext {
        fn env_provider(&self) -> &dyn EnvProvider {
            &self.provider
        }
    }

    #[test]
    fn test_env_access_methods() {
        let mut provider = MockEnvProvider::new();
        provider.set("PORT", "8080");
        provider.set("DEBUG", "true");
        provider.set("BAD_NUMBER", "not_a_number");

        let ctx = TestEnvContext { provider };

        assert_eq!(ctx.env("PORT"), Some("8080".to_string()));
        assert!(ctx.env("MISSING").is_none());

        assert_eq!(ctx.env_or("PORT", "3000"), "8080");
        assert_eq!(ctx.env_or("MISSING", "default"), "default");

        assert_eq!(ctx.env_require("PORT").unwrap(), "8080");
        assert!(ctx.env_require("MISSING").is_err());

        let port: u16 = ctx.env_parse("PORT").unwrap();
        assert_eq!(port, 8080);

        let debug: bool = ctx.env_parse("DEBUG").unwrap();
        assert!(debug);

        let bad: Result<u32> = ctx.env_parse("BAD_NUMBER");
        assert!(bad.is_err());

        let port: u16 = ctx.env_parse_or("MISSING", 3000).unwrap();
        assert_eq!(port, 3000);

        // A present value wins over the default...
        let port: u16 = ctx.env_parse_or("PORT", 3000).unwrap();
        assert_eq!(port, 8080);

        // ...and a present-but-unparsable value is an error, not the default.
        let bad: Result<u32> = ctx.env_parse_or("BAD_NUMBER", 0);
        assert!(bad.is_err());

        assert!(ctx.env_contains("PORT"));
        assert!(!ctx.env_contains("MISSING"));

        // Every accessor goes through the provider, so the mock saw the lookups.
        ctx.provider.assert_accessed("PORT");
        ctx.provider.assert_accessed("MISSING");
    }
}