1use std::collections::HashMap;
8use std::str::FromStr;
9use std::sync::{Arc, RwLock};
10
11use crate::{ForgeError, Result};
12
/// Source of environment-variable values.
///
/// Abstracting lookups behind this trait lets production code read the real
/// process environment ([`RealEnvProvider`]) while tests substitute an
/// in-memory [`MockEnvProvider`]. Implementors must be `Send + Sync` so a
/// provider can be shared as an `Arc<dyn EnvProvider>`.
pub trait EnvProvider
where
    Self: Send + Sync,
{
    /// Returns the value stored under `key`, or `None` if there is none.
    fn get(&self, key: &str) -> Option<String>;

    /// Returns `true` when [`get`](EnvProvider::get) yields a value for `key`.
    ///
    /// The default is defined in terms of `get`, so any side effects of an
    /// implementor's `get` (for example access tracking) apply here as well.
    fn contains(&self, key: &str) -> bool {
        self.get(key).is_some()
    }
}
21
22#[derive(Debug, Clone, Default)]
24pub struct RealEnvProvider;
25
26impl RealEnvProvider {
27 pub fn new() -> Self {
28 Self
29 }
30
31 pub fn shared() -> Arc<dyn EnvProvider> {
33 static INSTANCE: std::sync::OnceLock<Arc<dyn EnvProvider>> = std::sync::OnceLock::new();
34 Arc::clone(INSTANCE.get_or_init(|| Arc::new(Self)))
35 }
36}
37
38impl EnvProvider for RealEnvProvider {
39 fn get(&self, key: &str) -> Option<String> {
40 std::env::var(key).ok()
41 }
42}
43
/// In-memory [`EnvProvider`] for tests.
///
/// Variables live in a private map, and every read made through
/// [`EnvProvider::get`] is appended to an access log. Tests can then check
/// which keys the code under test looked up, including keys that were missing.
///
/// Cloning copies the variable map but *shares* the access log, because the
/// log lives behind an `Arc`. Reads made through any clone are visible from
/// all of them.
#[derive(Debug, Clone, Default)]
pub struct MockEnvProvider {
    // Variables this mock reports as set.
    vars: HashMap<String, String>,
    // Every key looked up, in call order, duplicates included.
    accessed: Arc<RwLock<Vec<String>>>,
}

impl MockEnvProvider {
    /// Creates a mock with no variables and an empty access log.
    pub fn new() -> Self {
        Self::with_vars(HashMap::new())
    }

    /// Creates a mock pre-populated with `vars` and an empty access log.
    pub fn with_vars(vars: HashMap<String, String>) -> Self {
        Self {
            vars,
            accessed: Arc::new(RwLock::new(Vec::new())),
        }
    }

    /// Sets `key` to `value`, overwriting any previous value.
    pub fn set(&mut self, key: impl Into<String>, value: impl Into<String>) {
        self.vars.insert(key.into(), value.into());
    }

    /// Removes `key` if present. The access log is left untouched.
    pub fn remove(&mut self, key: &str) {
        self.vars.remove(key);
    }

    /// Returns all currently configured variables without recording an access.
    pub fn all(&self) -> &HashMap<String, String> {
        &self.vars
    }

    /// Returns a snapshot of the access log in call order.
    ///
    /// # Panics
    ///
    /// Panics if the access-log lock is poisoned.
    pub fn accessed_keys(&self) -> Vec<String> {
        self.accessed
            .read()
            .expect("env accessed lock poisoned")
            .clone()
    }

    /// Returns `true` if `key` appears anywhere in the access log.
    ///
    /// # Panics
    ///
    /// Panics if the access-log lock is poisoned.
    pub fn was_accessed(&self, key: &str) -> bool {
        // Compare borrowed strings so a membership check does not allocate.
        self.accessed
            .read()
            .expect("env accessed lock poisoned")
            .iter()
            .any(|logged| logged.as_str() == key)
    }

    /// Empties the access log. Clones share the log, so they are cleared too.
    ///
    /// # Panics
    ///
    /// Panics if the access-log lock is poisoned.
    pub fn clear_accessed(&self) {
        self.accessed
            .write()
            .expect("env accessed lock poisoned")
            .clear();
    }

    /// Asserts that `key` was looked up at least once.
    ///
    /// `#[track_caller]` makes a failure point at the calling test, not here.
    ///
    /// # Panics
    ///
    /// Panics, listing the access log, if `key` was never looked up.
    #[track_caller]
    pub fn assert_accessed(&self, key: &str) {
        assert!(
            self.was_accessed(key),
            "Expected env var '{}' to be accessed, but it wasn't. Accessed keys: {:?}",
            key,
            self.accessed_keys()
        );
    }

    /// Asserts that `key` was never looked up.
    ///
    /// `#[track_caller]` makes a failure point at the calling test, not here.
    ///
    /// # Panics
    ///
    /// Panics, listing the access log, if `key` was looked up.
    #[track_caller]
    pub fn assert_not_accessed(&self, key: &str) {
        assert!(
            !self.was_accessed(key),
            "Expected env var '{}' to NOT be accessed, but it was. Accessed keys: {:?}",
            key,
            self.accessed_keys()
        );
    }
}
119
120impl EnvProvider for MockEnvProvider {
121 fn get(&self, key: &str) -> Option<String> {
122 self.accessed
123 .write()
124 .expect("env accessed lock poisoned")
125 .push(key.to_string());
126 self.vars.get(key).cloned()
127 }
128}
129
130pub trait EnvAccess {
132 fn env_provider(&self) -> &dyn EnvProvider;
133
134 fn env(&self, key: &str) -> Option<String> {
135 self.env_provider().get(key)
136 }
137
138 fn env_or(&self, key: &str, default: &str) -> String {
139 self.env_provider()
140 .get(key)
141 .unwrap_or_else(|| default.to_string())
142 }
143
144 fn env_require(&self, key: &str) -> Result<String> {
145 self.env_provider().get(key).ok_or_else(|| {
146 ForgeError::config(format!("Required environment variable '{}' not set", key))
147 })
148 }
149
150 fn env_parse<T: FromStr>(&self, key: &str) -> Result<T>
151 where
152 T::Err: std::fmt::Display,
153 {
154 let value = self.env_require(key)?;
155 value.parse().map_err(|e: T::Err| {
156 ForgeError::config(format!(
157 "Failed to parse env var '{}' value '{}': {}",
158 key, value, e
159 ))
160 })
161 }
162
163 fn env_parse_or<T: FromStr>(&self, key: &str, default: T) -> Result<T>
165 where
166 T::Err: std::fmt::Display,
167 {
168 match self.env_provider().get(key) {
169 Some(value) => value.parse().map_err(|e: T::Err| {
170 ForgeError::config(format!(
171 "Failed to parse env var '{}' value '{}': {}",
172 key, value, e
173 ))
174 }),
175 None => Ok(default),
176 }
177 }
178
179 fn env_contains(&self, key: &str) -> bool {
180 self.env_provider().contains(key)
181 }
182}
183
#[cfg(test)]
// Test code may unwrap, index and panic freely. `unsafe_code` is allowed only
// for `EnvVarGuard`, which mutates the process environment through
// `std::env::set_var` / `remove_var` (both `unsafe fn` as of Rust 2024).
#[allow(
    clippy::unwrap_used,
    clippy::indexing_slicing,
    clippy::panic,
    unsafe_code
)]
mod tests {
    use super::*;

    /// Sets a process environment variable and removes it again on drop.
    ///
    /// Cleanup lives in `Drop`, so it still runs when an assertion in the test
    /// panics. Otherwise a failing test would leave the variable set for every
    /// later test in the same process.
    struct EnvVarGuard {
        key: &'static str,
    }

    impl EnvVarGuard {
        fn set(key: &'static str, value: &str) -> Self {
            // SAFETY: `set_var` is only fully sound when no other thread reads
            // or writes the environment concurrently. libtest runs tests on
            // parallel threads by default, so that is not strictly guaranteed.
            // Each test uses a variable name unique to it, and the residual
            // risk is accepted for test-only code.
            unsafe {
                std::env::set_var(key, value);
            }
            Self { key }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: same caveat as in `EnvVarGuard::set`.
            unsafe {
                std::env::remove_var(self.key);
            }
        }
    }

    #[test]
    fn test_real_env_provider() {
        let _guard = EnvVarGuard::set("FORGE_TEST_VAR", "test_value");

        let provider = RealEnvProvider::new();
        assert_eq!(
            provider.get("FORGE_TEST_VAR"),
            Some("test_value".to_string())
        );
        assert!(provider.contains("FORGE_TEST_VAR"));
        assert!(provider.get("FORGE_NONEXISTENT_VAR").is_none());
    }

    #[test]
    fn test_mock_env_provider() {
        let mut provider = MockEnvProvider::new();
        provider.set("API_KEY", "secret123");
        provider.set("TIMEOUT", "30");

        assert_eq!(provider.get("API_KEY"), Some("secret123".to_string()));
        assert_eq!(provider.get("TIMEOUT"), Some("30".to_string()));
        assert!(provider.get("MISSING").is_none());

        assert!(provider.was_accessed("API_KEY"));
        assert!(provider.was_accessed("TIMEOUT"));
        assert!(provider.was_accessed("MISSING"));

        provider.assert_accessed("API_KEY");
    }

    #[test]
    fn test_mock_provider_with_vars() {
        let vars = HashMap::from([
            ("KEY1".to_string(), "value1".to_string()),
            ("KEY2".to_string(), "value2".to_string()),
        ]);
        let provider = MockEnvProvider::with_vars(vars);

        assert_eq!(provider.get("KEY1"), Some("value1".to_string()));
        assert_eq!(provider.get("KEY2"), Some("value2".to_string()));
    }

    #[test]
    fn test_clear_accessed() {
        let mut provider = MockEnvProvider::new();
        provider.set("KEY", "value");

        provider.get("KEY");
        assert!(!provider.accessed_keys().is_empty());

        provider.clear_accessed();
        assert!(provider.accessed_keys().is_empty());
    }

    struct TestEnvContext {
        provider: MockEnvProvider,
    }

    impl EnvAccess for TestEnvContext {
        fn env_provider(&self) -> &dyn EnvProvider {
            &self.provider
        }
    }

    #[test]
    fn test_env_access_methods() {
        let mut provider = MockEnvProvider::new();
        provider.set("PORT", "8080");
        provider.set("DEBUG", "true");
        provider.set("BAD_NUMBER", "not_a_number");

        let ctx = TestEnvContext { provider };

        assert_eq!(ctx.env("PORT"), Some("8080".to_string()));
        assert!(ctx.env("MISSING").is_none());

        assert_eq!(ctx.env_or("PORT", "3000"), "8080");
        assert_eq!(ctx.env_or("MISSING", "default"), "default");

        assert_eq!(ctx.env_require("PORT").unwrap(), "8080");
        assert!(ctx.env_require("MISSING").is_err());

        let port: u16 = ctx.env_parse("PORT").unwrap();
        assert_eq!(port, 8080);

        let debug: bool = ctx.env_parse("DEBUG").unwrap();
        assert!(debug);

        let bad: Result<u32> = ctx.env_parse("BAD_NUMBER");
        assert!(bad.is_err());

        let port: u16 = ctx.env_parse_or("MISSING", 3000).unwrap();
        assert_eq!(port, 3000);

        assert!(ctx.env_contains("PORT"));
        assert!(!ctx.env_contains("MISSING"));
    }

    #[test]
    fn mock_remove_drops_var_but_does_not_clear_access_history() {
        let mut provider = MockEnvProvider::new();
        provider.set("TOKEN", "abc");
        let _ = provider.get("TOKEN");
        provider.remove("TOKEN");

        assert!(provider.get("TOKEN").is_none());
        assert!(provider.was_accessed("TOKEN"));
    }

    #[test]
    fn mock_all_returns_currently_configured_vars() {
        let mut provider = MockEnvProvider::new();
        provider.set("A", "1");
        provider.set("B", "2");
        provider.remove("B");

        let all = provider.all();
        assert_eq!(all.len(), 1);
        assert_eq!(all.get("A"), Some(&"1".to_string()));
        assert!(!all.contains_key("B"));
    }

    #[test]
    fn mock_access_log_preserves_duplicate_reads_in_order() {
        let mut provider = MockEnvProvider::new();
        provider.set("X", "1");
        let _ = provider.get("X");
        let _ = provider.get("Y");
        let _ = provider.get("X");

        assert_eq!(
            provider.accessed_keys(),
            vec!["X".to_string(), "Y".to_string(), "X".to_string()]
        );
    }

    #[test]
    fn mock_assert_not_accessed_passes_when_untouched() {
        let provider = MockEnvProvider::new();
        provider.assert_not_accessed("NEVER_READ");
    }

    #[test]
    #[should_panic(expected = "to be accessed")]
    fn mock_assert_accessed_panics_when_key_never_read() {
        let provider = MockEnvProvider::new();
        provider.assert_accessed("NEVER_READ");
    }

    #[test]
    fn mock_clones_share_access_log() {
        let mut provider = MockEnvProvider::new();
        provider.set("K", "v");
        let clone = provider.clone();

        let _ = clone.get("K");

        assert!(provider.was_accessed("K"));
    }

    #[test]
    fn env_require_error_is_config_variant_with_key_name() {
        let provider = MockEnvProvider::new();
        let ctx = TestEnvContext { provider };

        let err = ctx.env_require("STRIPE_API_KEY").unwrap_err();
        match err {
            ForgeError::Config { context: msg, .. } => {
                assert!(
                    msg.contains("STRIPE_API_KEY"),
                    "msg should name the key: {msg}"
                );
                assert!(
                    msg.contains("not set"),
                    "msg should describe failure: {msg}"
                );
            }
            other => panic!("expected ForgeError::Config, got {other:?}"),
        }
    }

    #[test]
    fn env_parse_error_quotes_key_and_value_in_message() {
        let mut provider = MockEnvProvider::new();
        provider.set("PORT", "not_a_port");
        let ctx = TestEnvContext { provider };

        let err: ForgeError = ctx.env_parse::<u16>("PORT").unwrap_err();
        match err {
            ForgeError::Config { context: msg, .. } => {
                assert!(msg.contains("PORT"), "msg should name the key: {msg}");
                assert!(
                    msg.contains("not_a_port"),
                    "msg should show the bad value: {msg}"
                );
            }
            other => panic!("expected ForgeError::Config, got {other:?}"),
        }
    }

    #[test]
    fn env_parse_or_returns_default_when_unset() {
        let provider = MockEnvProvider::new();
        let ctx = TestEnvContext { provider };

        let port: u16 = ctx.env_parse_or("MISSING_PORT", 8080).unwrap();
        assert_eq!(port, 8080);
    }

    #[test]
    fn env_parse_or_propagates_parse_error_when_var_is_set() {
        let mut provider = MockEnvProvider::new();
        provider.set("RETRIES", "lots");
        let ctx = TestEnvContext { provider };

        let err = ctx.env_parse_or::<u32>("RETRIES", 5).unwrap_err();
        match err {
            ForgeError::Config { context: msg, .. } => {
                assert!(msg.contains("RETRIES"));
                assert!(msg.contains("lots"));
            }
            other => panic!("expected ForgeError::Config, got {other:?}"),
        }
    }

    #[test]
    fn real_provider_contains_delegates_to_get() {
        let _guard = EnvVarGuard::set("FORGE_CONTAINS_PROBE", "x");

        let p = RealEnvProvider::new();
        assert!(p.contains("FORGE_CONTAINS_PROBE"));
        assert!(!p.contains("FORGE_DEFINITELY_NOT_SET_XYZ_42"));
    }
}