1use std::collections::HashMap;
48use std::str::FromStr;
49use std::sync::{Arc, RwLock};
50
51use crate::{ForgeError, Result};
52
/// Read-only source of environment variables.
///
/// The `Send + Sync` supertraits allow a provider to be shared between
/// threads.
pub trait EnvProvider: Send + Sync {
    /// Looks up `key`, returning its value or `None` when the provider has no
    /// value for it.
    fn get(&self, key: &str) -> Option<String>;

    /// Reports whether the provider has a value for `key`.
    ///
    /// The default implementation performs a full [`get`](Self::get) and only
    /// inspects whether a value came back.
    fn contains(&self, key: &str) -> bool {
        let looked_up = self.get(key);
        looked_up.is_some()
    }
}
66
67#[derive(Debug, Clone, Default)]
69pub struct RealEnvProvider;
70
71impl RealEnvProvider {
72 pub fn new() -> Self {
74 Self
75 }
76}
77
78impl EnvProvider for RealEnvProvider {
79 fn get(&self, key: &str) -> Option<String> {
80 std::env::var(key).ok()
81 }
82}
83
/// In-memory [`EnvProvider`] for tests.
///
/// Values are served from an internal map, and every key looked up through
/// the [`EnvProvider`] implementation is appended to an access log that tests
/// can inspect afterwards.
///
/// The access log lives behind an `Arc`, so values produced by `Clone` share
/// one log with the original, while each clone gets its own copy of the
/// variable map.
#[derive(Debug, Clone, Default)]
pub struct MockEnvProvider {
    /// Variables served by this mock.
    vars: HashMap<String, String>,
    /// Keys that were looked up, in call order, duplicates included.
    accessed: Arc<RwLock<Vec<String>>>,
}

impl MockEnvProvider {
    /// Creates a mock with no variables and an empty access log.
    pub fn new() -> Self {
        Self::with_vars(HashMap::new())
    }

    /// Creates a mock that serves `vars`, starting with an empty access log.
    pub fn with_vars(vars: HashMap<String, String>) -> Self {
        Self {
            vars,
            accessed: Arc::new(RwLock::new(Vec::new())),
        }
    }

    /// Sets `key` to `value`, replacing any previous value.
    pub fn set(&mut self, key: impl Into<String>, value: impl Into<String>) {
        self.vars.insert(key.into(), value.into());
    }

    /// Removes `key`; does nothing when it is not set.
    pub fn remove(&mut self, key: &str) {
        self.vars.remove(key);
    }

    /// Returns the full variable map.
    pub fn all(&self) -> &HashMap<String, String> {
        &self.vars
    }

    /// Returns a snapshot of the access log.
    ///
    /// # Panics
    ///
    /// Panics if the access-log lock is poisoned.
    pub fn accessed_keys(&self) -> Vec<String> {
        self.accessed
            .read()
            .expect("env accessed lock poisoned")
            .clone()
    }

    /// Reports whether `key` appears in the access log.
    ///
    /// # Panics
    ///
    /// Panics if the access-log lock is poisoned.
    pub fn was_accessed(&self, key: &str) -> bool {
        // Compare against the borrowed key directly; no temporary `String`
        // is needed just to scan the log.
        self.accessed
            .read()
            .expect("env accessed lock poisoned")
            .iter()
            .any(|seen| seen == key)
    }

    /// Empties the access log. The variable map is left untouched.
    ///
    /// # Panics
    ///
    /// Panics if the access-log lock is poisoned.
    pub fn clear_accessed(&self) {
        self.accessed
            .write()
            .expect("env accessed lock poisoned")
            .clear();
    }

    /// Asserts that `key` was looked up.
    ///
    /// # Panics
    ///
    /// Panics when `key` is not in the access log; the message lists the keys
    /// that were accessed. Thanks to `#[track_caller]` the panic location is
    /// the call site of this helper.
    #[track_caller]
    pub fn assert_accessed(&self, key: &str) {
        assert!(
            self.was_accessed(key),
            "Expected env var '{}' to be accessed, but it wasn't. Accessed keys: {:?}",
            key,
            self.accessed_keys()
        );
    }

    /// Asserts that `key` was never looked up.
    ///
    /// # Panics
    ///
    /// Panics when `key` is in the access log. Thanks to `#[track_caller]`
    /// the panic location is the call site of this helper.
    #[track_caller]
    pub fn assert_not_accessed(&self, key: &str) {
        assert!(
            !self.was_accessed(key),
            "Expected env var '{}' to NOT be accessed, but it was",
            key
        );
    }
}
171
172impl EnvProvider for MockEnvProvider {
173 fn get(&self, key: &str) -> Option<String> {
174 self.accessed
176 .write()
177 .expect("env accessed lock poisoned")
178 .push(key.to_string());
179 self.vars.get(key).cloned()
180 }
181}
182
183pub trait EnvAccess {
188 fn env_provider(&self) -> &dyn EnvProvider;
190
191 fn env(&self, key: &str) -> Option<String> {
195 self.env_provider().get(key)
196 }
197
198 fn env_or(&self, key: &str, default: &str) -> String {
202 self.env_provider()
203 .get(key)
204 .unwrap_or_else(|| default.to_string())
205 }
206
207 fn env_require(&self, key: &str) -> Result<String> {
211 self.env_provider().get(key).ok_or_else(|| {
212 ForgeError::Config(format!("Required environment variable '{}' not set", key))
213 })
214 }
215
216 fn env_parse<T: FromStr>(&self, key: &str) -> Result<T>
222 where
223 T::Err: std::fmt::Display,
224 {
225 let value = self.env_require(key)?;
226 value.parse().map_err(|e: T::Err| {
227 ForgeError::Config(format!(
228 "Failed to parse env var '{}' value '{}': {}",
229 key, value, e
230 ))
231 })
232 }
233
234 fn env_parse_or<T: FromStr>(&self, key: &str, default: T) -> Result<T>
239 where
240 T::Err: std::fmt::Display,
241 {
242 match self.env_provider().get(key) {
243 Some(value) => value.parse().map_err(|e: T::Err| {
244 ForgeError::Config(format!(
245 "Failed to parse env var '{}' value '{}': {}",
246 key, value, e
247 ))
248 }),
249 None => Ok(default),
250 }
251 }
252
253 fn env_contains(&self, key: &str) -> bool {
255 self.env_provider().contains(key)
256 }
257}
258
#[cfg(test)]
// Test-only relaxations: the tests call `unwrap()` freely and mutate the real
// process environment inside `unsafe` blocks.
#[allow(clippy::unwrap_used, clippy::indexing_slicing, unsafe_code)]
mod tests {
    use super::*;

    /// `RealEnvProvider` reflects variables set in the process environment.
    #[test]
    fn test_real_env_provider() {
        // SAFETY: `set_var` is only sound while no other thread reads or
        // writes the process environment. NOTE(review): this relies on no
        // other concurrently running test touching the environment — confirm,
        // since the test harness runs tests on multiple threads by default.
        unsafe {
            std::env::set_var("FORGE_TEST_VAR", "test_value");
        }

        let provider = RealEnvProvider::new();
        assert_eq!(
            provider.get("FORGE_TEST_VAR"),
            Some("test_value".to_string())
        );
        assert!(provider.contains("FORGE_TEST_VAR"));
        assert!(provider.get("FORGE_NONEXISTENT_VAR").is_none());

        // SAFETY: same requirement and same caveat as for `set_var` above.
        unsafe {
            std::env::remove_var("FORGE_TEST_VAR");
        }
    }

    /// The mock serves configured values and logs every lookup, including
    /// lookups of keys that are not set.
    #[test]
    fn test_mock_env_provider() {
        let mut provider = MockEnvProvider::new();
        provider.set("API_KEY", "secret123");
        provider.set("TIMEOUT", "30");

        assert_eq!(provider.get("API_KEY"), Some("secret123".to_string()));
        assert_eq!(provider.get("TIMEOUT"), Some("30".to_string()));
        assert!(provider.get("MISSING").is_none());

        assert!(provider.was_accessed("API_KEY"));
        assert!(provider.was_accessed("TIMEOUT"));
        assert!(provider.was_accessed("MISSING"));
        provider.assert_accessed("API_KEY");
    }

    /// `with_vars` seeds the mock from an existing map.
    #[test]
    fn test_mock_provider_with_vars() {
        let vars = HashMap::from([
            ("KEY1".to_string(), "value1".to_string()),
            ("KEY2".to_string(), "value2".to_string()),
        ]);
        let provider = MockEnvProvider::with_vars(vars);

        assert_eq!(provider.get("KEY1"), Some("value1".to_string()));
        assert_eq!(provider.get("KEY2"), Some("value2".to_string()));
    }

    /// `clear_accessed` empties the access log.
    #[test]
    fn test_clear_accessed() {
        let mut provider = MockEnvProvider::new();
        provider.set("KEY", "value");

        provider.get("KEY");
        assert!(!provider.accessed_keys().is_empty());

        provider.clear_accessed();
        assert!(provider.accessed_keys().is_empty());
    }

    /// `remove` drops a variable, and keys that were never looked up are
    /// reported as not accessed.
    #[test]
    fn test_mock_remove_and_not_accessed() {
        let mut provider = MockEnvProvider::new();
        provider.set("KEY", "value");
        provider.set("OTHER", "x");

        provider.remove("KEY");
        assert!(provider.get("KEY").is_none());
        assert!(!provider.all().contains_key("KEY"));
        assert_eq!(provider.all().len(), 1);

        provider.assert_accessed("KEY");
        provider.assert_not_accessed("OTHER");
        assert!(!provider.was_accessed("OTHER"));
    }

    /// Minimal `EnvAccess` implementor backed by a mock provider.
    struct TestEnvContext {
        provider: MockEnvProvider,
    }

    impl EnvAccess for TestEnvContext {
        fn env_provider(&self) -> &dyn EnvProvider {
            &self.provider
        }
    }

    /// Exercises every default method of `EnvAccess`.
    #[test]
    fn test_env_access_methods() {
        let mut provider = MockEnvProvider::new();
        provider.set("PORT", "8080");
        provider.set("DEBUG", "true");
        provider.set("BAD_NUMBER", "not_a_number");

        let ctx = TestEnvContext { provider };

        assert_eq!(ctx.env("PORT"), Some("8080".to_string()));
        assert!(ctx.env("MISSING").is_none());

        assert_eq!(ctx.env_or("PORT", "3000"), "8080");
        assert_eq!(ctx.env_or("MISSING", "default"), "default");

        assert_eq!(ctx.env_require("PORT").unwrap(), "8080");
        assert!(ctx.env_require("MISSING").is_err());

        let port: u16 = ctx.env_parse("PORT").unwrap();
        assert_eq!(port, 8080);

        let debug: bool = ctx.env_parse("DEBUG").unwrap();
        assert!(debug);

        let bad: Result<u32> = ctx.env_parse("BAD_NUMBER");
        assert!(bad.is_err());

        let port: u16 = ctx.env_parse_or("MISSING", 3000).unwrap();
        assert_eq!(port, 3000);

        assert!(ctx.env_contains("PORT"));
        assert!(!ctx.env_contains("MISSING"));
    }

    /// `env_parse_or` prefers a present value over the default and reports a
    /// parse failure instead of silently falling back to the default.
    #[test]
    fn test_env_parse_or_present_and_invalid() {
        let mut provider = MockEnvProvider::new();
        provider.set("PORT", "8080");
        provider.set("BAD_NUMBER", "not_a_number");

        let ctx = TestEnvContext { provider };

        assert_eq!(ctx.env_parse_or::<u16>("PORT", 3000).unwrap(), 8080);
        assert!(ctx.env_parse_or::<u32>("BAD_NUMBER", 7).is_err());
    }

    /// Lookups made through `EnvAccess` are recorded by the mock provider.
    #[test]
    fn test_env_access_records_lookups() {
        let mut provider = MockEnvProvider::new();
        provider.set("PORT", "8080");

        let ctx = TestEnvContext { provider };
        ctx.provider.assert_not_accessed("PORT");

        assert_eq!(ctx.env("PORT"), Some("8080".to_string()));
        ctx.provider.assert_accessed("PORT");
        ctx.provider.assert_not_accessed("MISSING");
    }
}