use std::collections::HashMap;
use std::str::FromStr;
use std::sync::{Arc, RwLock};
use crate::{ForgeError, Result};
/// Source of environment-variable values.
///
/// The `Send + Sync` supertraits let a single provider be shared across threads.
pub trait EnvProvider: Send + Sync {
    /// Returns the value stored under `key`, or `None` when there is none.
    fn get(&self, key: &str) -> Option<String>;

    /// Returns `true` when [`EnvProvider::get`] yields a value for `key`.
    ///
    /// The default performs a full `get` and discards the value, so any side effects of
    /// `get` (such as access recording in a mock provider) also happen here.
    fn contains(&self, key: &str) -> bool {
        let found = self.get(key);
        found.is_some()
    }
}
/// [`EnvProvider`] that reads from the real process environment.
///
/// Stateless: every instance is interchangeable.
#[derive(Clone, Debug, Default)]
pub struct RealEnvProvider;
impl RealEnvProvider {
    /// Creates a provider backed by the real process environment.
    ///
    /// Equivalent to `RealEnvProvider::default()`. Being a `const fn`, it can also
    /// initialise `const` and `static` items.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
impl EnvProvider for RealEnvProvider {
    /// Reads `key` from the process environment.
    ///
    /// A variable that is unset and one whose value is not valid Unicode both come back
    /// as `None`; callers cannot tell the two apart.
    fn get(&self, key: &str) -> Option<String> {
        std::env::var_os(key).and_then(|raw| raw.into_string().ok())
    }
}
/// In-memory [`EnvProvider`] for tests that also records every key looked up.
///
/// Cloning copies the variable map but clones the `Arc` around the access log. Lookups made
/// through a clone therefore appear in the original's log, while later `set`/`remove` calls
/// on one copy are not visible to the other.
#[derive(Clone, Debug, Default)]
pub struct MockEnvProvider {
    // Variables visible to lookups.
    vars: HashMap<String, String>,
    // Keys passed to `get`, in call order, duplicates included.
    accessed: Arc<RwLock<Vec<String>>>,
}
impl MockEnvProvider {
pub fn new() -> Self {
Self {
vars: HashMap::new(),
accessed: Arc::new(RwLock::new(Vec::new())),
}
}
pub fn with_vars(vars: HashMap<String, String>) -> Self {
Self {
vars,
accessed: Arc::new(RwLock::new(Vec::new())),
}
}
pub fn set(&mut self, key: impl Into<String>, value: impl Into<String>) {
self.vars.insert(key.into(), value.into());
}
pub fn remove(&mut self, key: &str) {
self.vars.remove(key);
}
pub fn all(&self) -> &HashMap<String, String> {
&self.vars
}
pub fn accessed_keys(&self) -> Vec<String> {
self.accessed
.read()
.expect("env accessed lock poisoned")
.clone()
}
pub fn was_accessed(&self, key: &str) -> bool {
self.accessed
.read()
.expect("env accessed lock poisoned")
.contains(&key.to_string())
}
pub fn clear_accessed(&self) {
self.accessed
.write()
.expect("env accessed lock poisoned")
.clear();
}
pub fn assert_accessed(&self, key: &str) {
assert!(
self.was_accessed(key),
"Expected env var '{}' to be accessed, but it wasn't. Accessed keys: {:?}",
key,
self.accessed_keys()
);
}
pub fn assert_not_accessed(&self, key: &str) {
assert!(
!self.was_accessed(key),
"Expected env var '{}' to NOT be accessed, but it was",
key
);
}
}
impl EnvProvider for MockEnvProvider {
    /// Logs `key` in the access log, then returns a copy of its value if one is set.
    ///
    /// The key is logged even when no value exists, so misses are observable too.
    ///
    /// # Panics
    ///
    /// Panics if the access-log lock is poisoned.
    fn get(&self, key: &str) -> Option<String> {
        {
            let mut log = self.accessed.write().expect("env accessed lock poisoned");
            log.push(String::from(key));
        } // write guard released before the map lookup
        self.vars.get(key).cloned()
    }
}
pub trait EnvAccess {
fn env_provider(&self) -> &dyn EnvProvider;
fn env(&self, key: &str) -> Option<String> {
self.env_provider().get(key)
}
fn env_or(&self, key: &str, default: &str) -> String {
self.env_provider()
.get(key)
.unwrap_or_else(|| default.to_string())
}
fn env_require(&self, key: &str) -> Result<String> {
self.env_provider().get(key).ok_or_else(|| {
ForgeError::Config(format!("Required environment variable '{}' not set", key))
})
}
fn env_parse<T: FromStr>(&self, key: &str) -> Result<T>
where
T::Err: std::fmt::Display,
{
let value = self.env_require(key)?;
value.parse().map_err(|e: T::Err| {
ForgeError::Config(format!(
"Failed to parse env var '{}' value '{}': {}",
key, value, e
))
})
}
fn env_parse_or<T: FromStr>(&self, key: &str, default: T) -> Result<T>
where
T::Err: std::fmt::Display,
{
match self.env_provider().get(key) {
Some(value) => value.parse().map_err(|e: T::Err| {
ForgeError::Config(format!(
"Failed to parse env var '{}' value '{}': {}",
key, value, e
))
}),
None => Ok(default),
}
}
fn env_contains(&self, key: &str) -> bool {
self.env_provider().contains(key)
}
}
#[cfg(test)]
// `unwrap`/indexing keep test assertions terse. `unsafe_code` is required because
// `std::env::set_var`/`remove_var` are `unsafe` functions (see `test_real_env_provider`).
#[allow(clippy::unwrap_used, clippy::indexing_slicing, unsafe_code)]
mod tests {
    use super::*;

    #[test]
    fn test_real_env_provider() {
        // SAFETY: `set_var` is unsafe because mutating the environment while another thread
        // reads or writes it is undefined behaviour on some platforms, and the test harness
        // runs tests on parallel threads. No other test in this module touches the real
        // process environment. NOTE(review): confirm no other test in the crate does either,
        // or run this test serially.
        unsafe {
            std::env::set_var("FORGE_TEST_VAR", "test_value");
        }
        let provider = RealEnvProvider::new();
        assert_eq!(
            provider.get("FORGE_TEST_VAR"),
            Some("test_value".to_string())
        );
        assert!(provider.contains("FORGE_TEST_VAR"));
        assert!(provider.get("FORGE_NONEXISTENT_VAR").is_none());
        // SAFETY: same reasoning as the `set_var` call at the top of this test.
        unsafe {
            std::env::remove_var("FORGE_TEST_VAR");
        }
    }

    #[test]
    fn test_mock_env_provider() {
        let mut provider = MockEnvProvider::new();
        provider.set("API_KEY", "secret123");
        provider.set("TIMEOUT", "30");
        assert_eq!(provider.get("API_KEY"), Some("secret123".to_string()));
        assert_eq!(provider.get("TIMEOUT"), Some("30".to_string()));
        assert!(provider.get("MISSING").is_none());
        assert!(provider.was_accessed("API_KEY"));
        assert!(provider.was_accessed("TIMEOUT"));
        assert!(provider.was_accessed("MISSING"));
        provider.assert_accessed("API_KEY");
    }

    #[test]
    fn test_mock_provider_with_vars() {
        let vars = HashMap::from([
            ("KEY1".to_string(), "value1".to_string()),
            ("KEY2".to_string(), "value2".to_string()),
        ]);
        let provider = MockEnvProvider::with_vars(vars);
        assert_eq!(provider.get("KEY1"), Some("value1".to_string()));
        assert_eq!(provider.get("KEY2"), Some("value2".to_string()));
    }

    #[test]
    fn test_clear_accessed() {
        let mut provider = MockEnvProvider::new();
        provider.set("KEY", "value");
        provider.get("KEY");
        assert!(!provider.accessed_keys().is_empty());
        provider.clear_accessed();
        assert!(provider.accessed_keys().is_empty());
    }

    #[test]
    #[should_panic(expected = "to NOT be accessed")]
    fn test_assert_not_accessed_panics_when_accessed() {
        let provider = MockEnvProvider::new();
        provider.get("KEY");
        provider.assert_not_accessed("KEY");
    }

    #[test]
    fn test_clone_shares_access_log() {
        // `Clone` clones the `Arc` around the access log, so lookups through a clone are
        // visible from the original.
        let provider = MockEnvProvider::new();
        let cloned = provider.clone();
        cloned.get("KEY");
        assert!(provider.was_accessed("KEY"));
    }

    struct TestEnvContext {
        provider: MockEnvProvider,
    }

    impl EnvAccess for TestEnvContext {
        fn env_provider(&self) -> &dyn EnvProvider {
            &self.provider
        }
    }

    #[test]
    fn test_env_access_methods() {
        let mut provider = MockEnvProvider::new();
        provider.set("PORT", "8080");
        provider.set("DEBUG", "true");
        provider.set("BAD_NUMBER", "not_a_number");
        let ctx = TestEnvContext { provider };
        assert_eq!(ctx.env("PORT"), Some("8080".to_string()));
        assert!(ctx.env("MISSING").is_none());
        assert_eq!(ctx.env_or("PORT", "3000"), "8080");
        assert_eq!(ctx.env_or("MISSING", "default"), "default");
        assert_eq!(ctx.env_require("PORT").unwrap(), "8080");
        assert!(ctx.env_require("MISSING").is_err());
        let port: u16 = ctx.env_parse("PORT").unwrap();
        assert_eq!(port, 8080);
        let debug: bool = ctx.env_parse("DEBUG").unwrap();
        assert!(debug);
        let bad: Result<u32> = ctx.env_parse("BAD_NUMBER");
        assert!(bad.is_err());
        let port: u16 = ctx.env_parse_or("MISSING", 3000).unwrap();
        assert_eq!(port, 3000);
        assert!(ctx.env_contains("PORT"));
        assert!(!ctx.env_contains("MISSING"));
    }

    #[test]
    fn test_env_parse_or_with_present_values() {
        let mut provider = MockEnvProvider::new();
        provider.set("PORT", "8080");
        provider.set("BAD_NUMBER", "not_a_number");
        let ctx = TestEnvContext { provider };
        // A present, valid value wins over the default.
        let port: u16 = ctx.env_parse_or("PORT", 3000).unwrap();
        assert_eq!(port, 8080);
        // A present but unparseable value is an error, not silently defaulted.
        let bad: Result<u32> = ctx.env_parse_or("BAD_NUMBER", 1);
        assert!(bad.is_err());
    }
}