1use std::collections::HashMap;
44use std::str::FromStr;
45use std::sync::{Arc, RwLock};
46
47use crate::{ForgeError, Result};
48
/// Abstraction over where environment variables come from.
///
/// Production code can read the real process environment while tests
/// substitute an in-memory map. The `Send + Sync` supertraits allow a single
/// provider to be shared across threads.
pub trait EnvProvider: Send + Sync {
    /// Looks up `key`, returning `None` when this provider has no value for it.
    fn get(&self, key: &str) -> Option<String>;

    /// Reports whether `key` currently has a value.
    ///
    /// The default implementation performs a full [`EnvProvider::get`] and
    /// discards the value, so any side effect of an implementor's `get`
    /// (for example access recording in a mock) still happens.
    fn contains(&self, key: &str) -> bool {
        let value = self.get(key);
        value.is_some()
    }
}
62
63#[derive(Debug, Clone, Default)]
65pub struct RealEnvProvider;
66
67impl RealEnvProvider {
68 pub fn new() -> Self {
70 Self
71 }
72}
73
74impl EnvProvider for RealEnvProvider {
75 fn get(&self, key: &str) -> Option<String> {
76 std::env::var(key).ok()
77 }
78}
79
80#[derive(Debug, Clone, Default)]
85pub struct MockEnvProvider {
86 vars: HashMap<String, String>,
88 accessed: Arc<RwLock<Vec<String>>>,
90}
91
92impl MockEnvProvider {
93 pub fn new() -> Self {
95 Self {
96 vars: HashMap::new(),
97 accessed: Arc::new(RwLock::new(Vec::new())),
98 }
99 }
100
101 pub fn with_vars(vars: HashMap<String, String>) -> Self {
103 Self {
104 vars,
105 accessed: Arc::new(RwLock::new(Vec::new())),
106 }
107 }
108
109 pub fn set(&mut self, key: impl Into<String>, value: impl Into<String>) {
111 self.vars.insert(key.into(), value.into());
112 }
113
114 pub fn remove(&mut self, key: &str) {
116 self.vars.remove(key);
117 }
118
119 pub fn all(&self) -> &HashMap<String, String> {
121 &self.vars
122 }
123
124 pub fn accessed_keys(&self) -> Vec<String> {
126 self.accessed.read().unwrap().clone()
127 }
128
129 pub fn was_accessed(&self, key: &str) -> bool {
131 self.accessed.read().unwrap().contains(&key.to_string())
132 }
133
134 pub fn clear_accessed(&self) {
136 self.accessed.write().unwrap().clear();
137 }
138
139 pub fn assert_accessed(&self, key: &str) {
141 assert!(
142 self.was_accessed(key),
143 "Expected env var '{}' to be accessed, but it wasn't. Accessed keys: {:?}",
144 key,
145 self.accessed_keys()
146 );
147 }
148
149 pub fn assert_not_accessed(&self, key: &str) {
151 assert!(
152 !self.was_accessed(key),
153 "Expected env var '{}' to NOT be accessed, but it was",
154 key
155 );
156 }
157}
158
159impl EnvProvider for MockEnvProvider {
160 fn get(&self, key: &str) -> Option<String> {
161 self.accessed.write().unwrap().push(key.to_string());
163 self.vars.get(key).cloned()
164 }
165}
166
167pub trait EnvAccess {
172 fn env_provider(&self) -> &dyn EnvProvider;
174
175 fn env(&self, key: &str) -> Option<String> {
179 self.env_provider().get(key)
180 }
181
182 fn env_or(&self, key: &str, default: &str) -> String {
186 self.env_provider()
187 .get(key)
188 .unwrap_or_else(|| default.to_string())
189 }
190
191 fn env_require(&self, key: &str) -> Result<String> {
195 self.env_provider().get(key).ok_or_else(|| {
196 ForgeError::Config(format!("Required environment variable '{}' not set", key))
197 })
198 }
199
200 fn env_parse<T: FromStr>(&self, key: &str) -> Result<T>
206 where
207 T::Err: std::fmt::Display,
208 {
209 let value = self.env_require(key)?;
210 value.parse().map_err(|e: T::Err| {
211 ForgeError::Config(format!(
212 "Failed to parse env var '{}' value '{}': {}",
213 key, value, e
214 ))
215 })
216 }
217
218 fn env_parse_or<T: FromStr>(&self, key: &str, default: T) -> Result<T>
223 where
224 T::Err: std::fmt::Display,
225 {
226 match self.env_provider().get(key) {
227 Some(value) => value.parse().map_err(|e: T::Err| {
228 ForgeError::Config(format!(
229 "Failed to parse env var '{}' value '{}': {}",
230 key, value, e
231 ))
232 }),
233 None => Ok(default),
234 }
235 }
236
237 fn env_contains(&self, key: &str) -> bool {
239 self.env_provider().contains(key)
240 }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_real_env_provider() {
        // SAFETY: `set_var`/`remove_var` are only sound while no other thread
        // reads or writes the process environment. The variable name is unique
        // to this test. NOTE(review): the default harness runs tests in
        // parallel, so this relies on no other test touching the environment
        // at the same time.
        unsafe {
            std::env::set_var("FORGE_TEST_VAR", "test_value");
        }

        let provider = RealEnvProvider::new();
        assert_eq!(
            provider.get("FORGE_TEST_VAR"),
            Some("test_value".to_string())
        );
        assert!(provider.contains("FORGE_TEST_VAR"));
        assert!(provider.get("FORGE_NONEXISTENT_VAR").is_none());

        // SAFETY: same invariant as the `set_var` call above.
        unsafe {
            std::env::remove_var("FORGE_TEST_VAR");
        }
    }

    #[test]
    fn test_mock_env_provider() {
        let mut provider = MockEnvProvider::new();
        provider.set("API_KEY", "secret123");
        provider.set("TIMEOUT", "30");

        assert_eq!(provider.get("API_KEY"), Some("secret123".to_string()));
        assert_eq!(provider.get("TIMEOUT"), Some("30".to_string()));
        assert!(provider.get("MISSING").is_none());

        // Every lookup is recorded in call order, including misses.
        assert!(provider.was_accessed("API_KEY"));
        assert!(provider.was_accessed("TIMEOUT"));
        assert!(provider.was_accessed("MISSING"));
        provider.assert_accessed("API_KEY");
        provider.assert_not_accessed("NEVER_READ");
        assert_eq!(
            provider.accessed_keys(),
            vec!["API_KEY", "TIMEOUT", "MISSING"]
        );
    }

    #[test]
    fn test_mock_provider_with_vars() {
        let vars = HashMap::from([
            ("KEY1".to_string(), "value1".to_string()),
            ("KEY2".to_string(), "value2".to_string()),
        ]);
        let provider = MockEnvProvider::with_vars(vars);

        assert_eq!(provider.get("KEY1"), Some("value1".to_string()));
        assert_eq!(provider.get("KEY2"), Some("value2".to_string()));
    }

    #[test]
    fn test_mock_provider_remove() {
        let mut provider = MockEnvProvider::new();
        provider.set("KEY", "value");
        provider.remove("KEY");

        assert!(provider.get("KEY").is_none());
        assert!(provider.all().is_empty());
    }

    #[test]
    fn test_clear_accessed() {
        let mut provider = MockEnvProvider::new();
        provider.set("KEY", "value");

        provider.get("KEY");
        assert!(!provider.accessed_keys().is_empty());

        provider.clear_accessed();
        assert!(provider.accessed_keys().is_empty());
    }

    /// Minimal [`EnvAccess`] implementor backed by a mock provider.
    struct TestEnvContext {
        provider: MockEnvProvider,
    }

    impl EnvAccess for TestEnvContext {
        fn env_provider(&self) -> &dyn EnvProvider {
            &self.provider
        }
    }

    #[test]
    fn test_env_access_methods() {
        let mut provider = MockEnvProvider::new();
        provider.set("PORT", "8080");
        provider.set("DEBUG", "true");
        provider.set("BAD_NUMBER", "not_a_number");

        let ctx = TestEnvContext { provider };

        assert_eq!(ctx.env("PORT"), Some("8080".to_string()));
        assert!(ctx.env("MISSING").is_none());

        assert_eq!(ctx.env_or("PORT", "3000"), "8080");
        assert_eq!(ctx.env_or("MISSING", "default"), "default");

        assert_eq!(ctx.env_require("PORT").unwrap(), "8080");
        assert!(ctx.env_require("MISSING").is_err());

        let port: u16 = ctx.env_parse("PORT").unwrap();
        assert_eq!(port, 8080);

        let debug: bool = ctx.env_parse("DEBUG").unwrap();
        assert!(debug);

        let bad: Result<u32> = ctx.env_parse("BAD_NUMBER");
        assert!(bad.is_err());

        let port: u16 = ctx.env_parse_or("MISSING", 3000).unwrap();
        assert_eq!(port, 3000);

        // A present value wins over the default.
        let port: u16 = ctx.env_parse_or("PORT", 3000).unwrap();
        assert_eq!(port, 8080);

        // A present but unparseable value is an error, not the default.
        let bad: Result<u32> = ctx.env_parse_or("BAD_NUMBER", 0);
        assert!(bad.is_err());

        assert!(ctx.env_contains("PORT"));
        assert!(!ctx.env_contains("MISSING"));
    }
}