use crate::security::{EnvSecurityError, EnvSecurityValidator};
use regex::Regex;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, OnceLock, RwLock};
use std::time::Instant;
/// Default maximum number of injections admitted per rate-limit window.
const RATE_LIMIT_MAX_REQUESTS: usize = 100;
/// Default length of a rate-limit window, in seconds.
const RATE_LIMIT_WINDOW_SECONDS: u64 = 60;

/// Fixed-window rate limiter guarding configuration injection.
///
/// The counter and window start live behind `Arc`s, so clones of a limiter
/// share the same state.
#[derive(Debug, Clone)]
pub struct InjectionRateLimiter {
    /// Requests admitted in the current window.
    window_counter: Arc<AtomicU64>,
    /// Start of the current window in seconds on the limiter's monotonic clock
    /// (see `now_secs`); `0` means no window has been opened yet.
    window_start: Arc<AtomicU64>,
    /// Maximum requests admitted per window.
    max_requests: usize,
    /// Window length in seconds.
    window_seconds: u64,
    /// When `false`, every request is admitted.
    rate_limiting_enabled: bool,
}

impl Default for InjectionRateLimiter {
    fn default() -> Self {
        Self::new()
    }
}

impl InjectionRateLimiter {
    /// Creates a limiter admitting `RATE_LIMIT_MAX_REQUESTS` requests per
    /// `RATE_LIMIT_WINDOW_SECONDS`-second window.
    pub fn new() -> Self {
        Self::with_limits(RATE_LIMIT_MAX_REQUESTS, RATE_LIMIT_WINDOW_SECONDS)
    }

    /// Creates a limiter that admits every request.
    pub fn disabled() -> Self {
        Self {
            window_counter: Arc::new(AtomicU64::new(0)),
            window_start: Arc::new(AtomicU64::new(0)),
            max_requests: usize::MAX,
            window_seconds: u64::MAX,
            rate_limiting_enabled: false,
        }
    }

    /// Creates a limiter admitting `max_requests` per `window_seconds`-second window.
    pub fn with_limits(max_requests: usize, window_seconds: u64) -> Self {
        Self {
            window_counter: Arc::new(AtomicU64::new(0)),
            window_start: Arc::new(AtomicU64::new(0)),
            max_requests,
            window_seconds,
            rate_limiting_enabled: true,
        }
    }

    /// Whole seconds elapsed on a process-wide monotonic clock, plus one.
    ///
    /// The `+ 1` keeps the result non-zero so `0` can keep meaning "no window
    /// started". (`Instant::now().elapsed()` would always be ~0, which is why a
    /// fixed epoch is needed.)
    fn now_secs() -> u64 {
        static EPOCH: OnceLock<Instant> = OnceLock::new();
        EPOCH
            .get_or_init(Instant::now)
            .elapsed()
            .as_secs()
            .saturating_add(1)
    }

    /// Records one request against the limiter.
    ///
    /// Returns `Ok(())` if the request is admitted. Otherwise returns
    /// `Err(retry_after)`, the number of seconds until the current window ends.
    /// A disabled limiter always returns `Ok`. A limiter with `max_requests == 0`
    /// always returns `Err(window_seconds)`.
    ///
    /// Bookkeeping is lock-free and therefore approximate under contention: a
    /// window rollover racing with other callers may admit slightly more than
    /// `max_requests`.
    pub fn check_rate_limit(&self) -> Result<(), u64> {
        if !self.rate_limiting_enabled {
            return Ok(());
        }
        if self.max_requests == 0 {
            return Err(self.window_seconds);
        }
        loop {
            let now_secs = Self::now_secs();
            let window_start = self.window_start.load(Ordering::SeqCst);
            // Another thread may have opened a window with a slightly later
            // timestamp than ours; saturate instead of underflowing.
            let elapsed = now_secs.saturating_sub(window_start);

            if window_start == 0 || elapsed >= self.window_seconds {
                // First request ever, or the previous window expired. Only one
                // caller wins the CAS; losers retry against the new window.
                if self
                    .window_start
                    .compare_exchange(window_start, now_secs, Ordering::SeqCst, Ordering::SeqCst)
                    .is_ok()
                {
                    // This request is the first one counted in the new window.
                    self.window_counter.store(1, Ordering::SeqCst);
                    return Ok(());
                }
                continue;
            }

            let admitted_before = self.window_counter.fetch_add(1, Ordering::SeqCst);
            // usize is at most 64 bits on supported targets, so this cast is lossless.
            if admitted_before >= self.max_requests as u64 {
                // Undo the speculative increment so rejected calls do not inflate the count.
                self.window_counter.fetch_sub(1, Ordering::SeqCst);
                // Here elapsed < window_seconds, so this is at least 1.
                return Err(self.window_seconds - elapsed);
            }
            return Ok(());
        }
    }

    /// Returns `(admitted_in_window, seconds_remaining, usage_percent)`.
    ///
    /// If no window is active (never started, or expired), reports zero usage
    /// and a full `window_seconds` remaining. When `max_requests` is 0 the
    /// limiter rejects everything, so usage is reported as 100%.
    pub fn usage_stats(&self) -> (usize, u64, f64) {
        let window_start = self.window_start.load(Ordering::SeqCst);
        let elapsed = Self::now_secs().saturating_sub(window_start);
        let window_active = window_start != 0 && elapsed < self.window_seconds;

        let (count, remaining) = if window_active {
            let counter = self.window_counter.load(Ordering::SeqCst);
            (counter as usize, self.window_seconds - elapsed)
        } else {
            (0, self.window_seconds)
        };

        let usage_percent = if self.max_requests == 0 {
            100.0
        } else {
            (count as f64 / self.max_requests as f64) * 100.0
        };
        (count, remaining, usage_percent)
    }
}
/// Process-wide rate limiter consulted by `ConfigInjector::inject` in non-test
/// builds; lazily initialized with `InjectionRateLimiter::new` on first use.
pub static GLOBAL_RATE_LIMITER: OnceLock<InjectionRateLimiter> = OnceLock::new();
/// Rate limiter consulted by `ConfigInjector::inject` in test builds; it is
/// initialized with `InjectionRateLimiter::disabled`, so tests are not throttled.
#[cfg(test)]
pub static TEST_RATE_LIMITER: OnceLock<InjectionRateLimiter> = OnceLock::new();
/// Optional process-wide shared injector slot.
// NOTE(review): not read or initialized anywhere in this file; presumably set
// up by callers elsewhere — confirm. `ConfigInjector` already keeps its state
// behind `Arc<RwLock<..>>`, so the outer `RwLock` may be redundant.
pub static GLOBAL_INJECTOR: OnceLock<Arc<RwLock<ConfigInjector>>> = OnceLock::new();
/// Store of validated configuration values plus a log of successful injections.
///
/// Values and history are held behind `Arc<RwLock<..>>`, so clones of an
/// injector share the same underlying data.
#[derive(Debug, Clone)]
pub struct ConfigInjector {
    /// Injected key/value pairs.
    values: Arc<RwLock<HashMap<String, String>>>,
    /// Validator applied to every name and value before it is stored.
    validator: EnvSecurityValidator,
    /// Patterns matched against key names to decide whether values are masked.
    sensitive_patterns: Vec<Regex>,
    /// Records of successful injections; emptied only by `clear`.
    injection_history: Arc<RwLock<Vec<InjectionRecord>>>,
}
/// Equivalent to `ConfigInjector::new()`.
impl Default for ConfigInjector {
    fn default() -> Self {
        Self::new()
    }
}
impl ConfigInjector {
    /// Creates an empty injector that validates with `EnvSecurityValidator::new()`.
    pub fn new() -> Self {
        Self::with_validator(EnvSecurityValidator::new())
    }

    /// Creates an empty injector that validates every name and value with `validator`.
    pub fn with_validator(validator: EnvSecurityValidator) -> Self {
        Self {
            values: Arc::new(RwLock::new(HashMap::new())),
            validator,
            sensitive_patterns: Self::default_sensitive_patterns(),
            injection_history: Arc::new(RwLock::new(Vec::new())),
        }
    }

    /// Case-insensitive key-name patterns that mark a field as sensitive.
    ///
    /// The regexes are compiled once per process and cloned into each injector
    /// instead of being recompiled on every construction.
    fn default_sensitive_patterns() -> Vec<Regex> {
        static PATTERNS: OnceLock<Vec<Regex>> = OnceLock::new();
        PATTERNS
            .get_or_init(|| {
                // Constant patterns: a compile failure here is a programming bug.
                vec![
                    Regex::new(r"(?i)(secret|password|token|key|auth|credential)").unwrap(),
                    Regex::new(r"(?i)(api_key|access_token|refresh_token)").unwrap(),
                    Regex::new(r"(?i)(private_key|public_key)").unwrap(),
                    Regex::new(r"(?i)(database_url|connection_string)").unwrap(),
                ]
            })
            .clone()
    }

    /// Current wall-clock time in whole seconds since the UNIX epoch
    /// (0 if the system clock reads before the epoch).
    fn now_unix_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    /// Validates `name`/`value`, stores the pair, and appends a history record.
    ///
    /// # Errors
    ///
    /// * [`ConfigInjectionError::RateLimited`] if the rate limiter rejects the
    ///   call. The limiter is checked before validation, so rejected inputs also
    ///   consume quota.
    /// * [`ConfigInjectionError::SecurityValidation`] if the validator rejects
    ///   the name or value.
    /// * [`ConfigInjectionError::PoisonedLock`] if an internal lock is poisoned.
    pub fn inject(&self, name: &str, value: &str) -> Result<(), ConfigInjectionError> {
        // Test builds use a separate, disabled limiter so the suite is never throttled.
        #[cfg(test)]
        let limiter = TEST_RATE_LIMITER.get_or_init(InjectionRateLimiter::disabled);
        #[cfg(not(test))]
        let limiter = GLOBAL_RATE_LIMITER.get_or_init(InjectionRateLimiter::new);
        limiter
            .check_rate_limit()
            .map_err(|retry_after_seconds| ConfigInjectionError::RateLimited {
                retry_after_seconds,
            })?;

        self.validator.validate_env_name(name, Some(value))?;
        self.validator.validate_env_value(value)?;
        let is_sensitive = self.is_sensitive_field(name);

        // Each write guard is a temporary dropped at the end of its statement,
        // so the two locks are never held at the same time here.
        self.values
            .write()
            .map_err(|_| ConfigInjectionError::PoisonedLock)?
            .insert(name.to_string(), value.to_string());
        self.injection_history
            .write()
            .map_err(|_| ConfigInjectionError::PoisonedLock)?
            .push(InjectionRecord {
                name: name.to_string(),
                timestamp: Self::now_unix_secs(),
                is_sensitive,
            });
        Ok(())
    }

    /// Validates every entry in `config` and stores the ones that pass.
    ///
    /// Returns `(stored_names, failures)`, where each failure is
    /// `(name, error message)`. Both lists follow `HashMap` iteration order,
    /// which is unspecified.
    ///
    /// NOTE(review): unlike [`inject`](Self::inject), this does not consult the
    /// rate limiter — confirm whether batch injection is meant to bypass it.
    ///
    /// # Errors
    ///
    /// [`ConfigInjectionError::PoisonedLock`] if either lock is poisoned. Both
    /// locks are acquired before anything is written, so nothing is stored then.
    // The nested tuple types are part of the public signature; an alias would
    // add indirection without changing anything for callers.
    #[allow(clippy::type_complexity)]
    pub fn inject_all(
        &self,
        config: &HashMap<String, String>,
    ) -> Result<(Vec<String>, Vec<(String, String)>), ConfigInjectionError> {
        let mut failures = Vec::new();
        let mut valid_injections: Vec<(String, String, bool)> = Vec::with_capacity(config.len());
        for (name, value) in config {
            match self.validator.validate_env_name(name, Some(value)) {
                Ok(_) => match self.validator.validate_env_value(value) {
                    Ok(_) => {
                        let is_sensitive = self.is_sensitive_field(name);
                        valid_injections.push((name.clone(), value.clone(), is_sensitive));
                    }
                    Err(e) => failures.push((name.clone(), e.to_string())),
                },
                Err(e) => failures.push((name.clone(), e.to_string())),
            }
        }

        let mut success = Vec::with_capacity(valid_injections.len());
        if !valid_injections.is_empty() {
            // Lock order (values, then history) matches `clear`.
            let mut values = self
                .values
                .write()
                .map_err(|_| ConfigInjectionError::PoisonedLock)?;
            let mut history = self
                .injection_history
                .write()
                .map_err(|_| ConfigInjectionError::PoisonedLock)?;
            let timestamp = Self::now_unix_secs();
            for (name, value, is_sensitive) in valid_injections {
                values.insert(name.clone(), value);
                history.push(InjectionRecord {
                    name: name.clone(),
                    timestamp,
                    is_sensitive,
                });
                success.push(name);
            }
        }
        Ok((success, failures))
    }

    /// Returns the raw value for `name`, or `None` if absent or the lock is poisoned.
    pub fn get(&self, name: &str) -> Option<String> {
        let values = self.values.read().ok()?;
        values.get(name).cloned()
    }

    /// Like [`get`](Self::get), but masks the value when `name` looks sensitive.
    pub fn get_safe(&self, name: &str) -> Option<String> {
        self.get(name).map(|value| {
            if self.is_sensitive_field(name) {
                Self::mask_value(&value)
            } else {
                value
            }
        })
    }

    /// Returns a snapshot copy of every stored key/value pair.
    ///
    /// # Errors
    ///
    /// [`ConfigInjectionError::PoisonedLock`] if the values lock is poisoned.
    pub fn get_all(&self) -> Result<HashMap<String, String>, ConfigInjectionError> {
        let values = self
            .values
            .read()
            .map_err(|_| ConfigInjectionError::PoisonedLock)?;
        Ok((*values).clone())
    }

    /// Returns a snapshot of every pair, with sensitive values masked.
    ///
    /// # Errors
    ///
    /// [`ConfigInjectionError::PoisonedLock`] if the values lock is poisoned.
    pub fn get_all_safe(&self) -> Result<HashMap<String, String>, ConfigInjectionError> {
        let values = self
            .values
            .read()
            .map_err(|_| ConfigInjectionError::PoisonedLock)?;
        Ok(values
            .iter()
            .map(|(k, v)| {
                let safe_value = if self.is_sensitive_field(k) {
                    Self::mask_value(v)
                } else {
                    v.clone()
                };
                (k.clone(), safe_value)
            })
            .collect())
    }

    /// Reports whether `name` has a stored value.
    ///
    /// # Errors
    ///
    /// [`ConfigInjectionError::PoisonedLock`] if the values lock is poisoned.
    pub fn contains(&self, name: &str) -> Result<bool, ConfigInjectionError> {
        let values = self
            .values
            .read()
            .map_err(|_| ConfigInjectionError::PoisonedLock)?;
        Ok(values.contains_key(name))
    }

    /// Removes `name` and returns its previous value. History is left untouched.
    ///
    /// # Errors
    ///
    /// [`ConfigInjectionError::PoisonedLock`] if the values lock is poisoned.
    pub fn remove(&self, name: &str) -> Result<Option<String>, ConfigInjectionError> {
        let mut values = self
            .values
            .write()
            .map_err(|_| ConfigInjectionError::PoisonedLock)?;
        Ok(values.remove(name))
    }

    /// Removes every stored value and clears the injection history.
    ///
    /// # Errors
    ///
    /// [`ConfigInjectionError::PoisonedLock`] if either lock is poisoned. If only
    /// the history lock is poisoned, the values have already been cleared.
    pub fn clear(&self) -> Result<(), ConfigInjectionError> {
        let mut values = self
            .values
            .write()
            .map_err(|_| ConfigInjectionError::PoisonedLock)?;
        values.clear();
        let mut history = self
            .injection_history
            .write()
            .map_err(|_| ConfigInjectionError::PoisonedLock)?;
        history.clear();
        Ok(())
    }

    /// Returns a copy of the injection history, oldest first.
    ///
    /// # Errors
    ///
    /// [`ConfigInjectionError::PoisonedLock`] if the history lock is poisoned.
    pub fn get_injection_history(&self) -> Result<Vec<InjectionRecord>, ConfigInjectionError> {
        let history = self
            .injection_history
            .read()
            .map_err(|_| ConfigInjectionError::PoisonedLock)?;
        Ok(history.clone())
    }

    /// Reports whether `name` matches any sensitive-name pattern.
    fn is_sensitive_field(&self, name: &str) -> bool {
        let name_lower = name.to_lowercase();
        self.sensitive_patterns
            .iter()
            .any(|pattern| pattern.is_match(&name_lower))
    }

    /// Masks `value` for display.
    ///
    /// Values of up to 4 characters are fully masked. Longer values reveal their
    /// first 1 (5–7 chars) or 2 (8+ chars) characters. Lengths are counted in
    /// `char`s, so multi-byte UTF-8 input can never cause a slicing panic.
    fn mask_value(value: &str) -> String {
        let char_count = value.chars().count();
        if char_count <= 4 {
            return "*".repeat(char_count);
        }
        let visible = std::cmp::min(2, char_count / 4);
        let mut masked: String = value.chars().take(visible).collect();
        masked.push_str(&"*".repeat(char_count - visible));
        masked
    }

    /// Number of stored values.
    ///
    /// # Errors
    ///
    /// [`ConfigInjectionError::PoisonedLock`] if the values lock is poisoned.
    pub fn len(&self) -> Result<usize, ConfigInjectionError> {
        let values = self
            .values
            .read()
            .map_err(|_| ConfigInjectionError::PoisonedLock)?;
        Ok(values.len())
    }

    /// Reports whether no values are stored.
    ///
    /// # Errors
    ///
    /// [`ConfigInjectionError::PoisonedLock`] if the values lock is poisoned.
    pub fn is_empty(&self) -> Result<bool, ConfigInjectionError> {
        Ok(self.len()? == 0)
    }

    /// The validator applied to injected names and values.
    pub fn validator(&self) -> &EnvSecurityValidator {
        &self.validator
    }
}
/// One entry in a `ConfigInjector`'s injection history.
#[derive(Debug, Clone)]
pub struct InjectionRecord {
    /// Name of the injected key (the value itself is not recorded).
    pub name: String,
    /// Wall-clock time of the injection in whole seconds since the UNIX epoch
    /// (0 if the system clock read earlier than the epoch).
    pub timestamp: u64,
    /// Whether the key name matched one of the injector's sensitive-name patterns.
    pub is_sensitive: bool,
}
/// Errors returned by `ConfigInjector` and `EnvironmentConfig`.
#[derive(Debug, Clone, PartialEq)]
pub enum ConfigInjectionError {
    /// The `EnvSecurityValidator` rejected a name or value.
    SecurityValidation(EnvSecurityError),
    /// An internal `RwLock` was poisoned (a thread panicked while holding it).
    PoisonedLock,
    /// In this file: returned by `EnvironmentConfig::get_required` when the key
    /// is missing; the payload is the key name.
    InvalidName(String),
    /// In this file: returned by `EnvironmentConfig::get_required` when the value
    /// fails to parse; the payload is `"<name>: <parse error>"`.
    InvalidValue(String),
    /// The injection rate limit was exceeded; retry after this many seconds.
    RateLimited { retry_after_seconds: u64 },
}
impl std::fmt::Display for ConfigInjectionError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ConfigInjectionError::SecurityValidation(e) => {
write!(f, "Security validation failed: {}", e)
}
ConfigInjectionError::PoisonedLock => {
write!(f, "Configuration lock poisoned")
}
ConfigInjectionError::InvalidName(name) => {
write!(f, "Invalid configuration name: {}", name)
}
ConfigInjectionError::InvalidValue(value) => {
write!(f, "Invalid configuration value: {}", value)
}
ConfigInjectionError::RateLimited {
retry_after_seconds,
} => {
write!(
f,
"Rate limited. Retry after {} seconds",
retry_after_seconds
)
}
}
}
}
impl std::error::Error for ConfigInjectionError {}
impl From<EnvSecurityError> for ConfigInjectionError {
fn from(e: EnvSecurityError) -> Self {
ConfigInjectionError::SecurityValidation(e)
}
}
#[derive(Debug, Clone)]
pub struct EnvironmentConfig<'a> {
injector: &'a ConfigInjector,
}
impl<'a> EnvironmentConfig<'a> {
pub fn from_injector(injector: &'a ConfigInjector) -> Self {
Self { injector }
}
pub fn get<T>(&self, name: &str, default: T) -> T
where
T: std::str::FromStr,
<T as std::str::FromStr>::Err: std::fmt::Display,
{
let value = self.injector.get(name);
match value {
Some(v) => v.parse().ok().unwrap_or(default),
None => default,
}
}
pub fn get_required<T>(&self, name: &str) -> Result<T, ConfigInjectionError>
where
T: std::str::FromStr,
<T as std::str::FromStr>::Err: std::fmt::Display,
{
let value = self
.injector
.get(name)
.ok_or_else(|| ConfigInjectionError::InvalidName(name.to_string()))?;
value
.parse()
.map_err(|e| ConfigInjectionError::InvalidValue(format!("{}: {}", name, e)))
}
pub fn get_string(&self, name: &str, default: &str) -> String {
self.injector
.get(name)
.unwrap_or_else(|| default.to_string())
}
pub fn get_bool(&self, name: &str, default: bool) -> bool {
self.get::<bool>(name, default)
}
pub fn get_number<T>(&self, name: &str, default: T) -> T
where
T: std::str::FromStr + std::clone::Clone,
<T as std::str::FromStr>::Err: std::fmt::Display,
{
self.get::<T>(name, default)
}
}
/// Best-effort bulk-injection macros.
///
/// Both macros are `#[macro_export]`ed, so they live at the crate root
/// (e.g. `crate::safe_inject!`) rather than under this module's path.
pub mod macros {
    /// Calls `$injector.inject(name, value)` for each `name => value` pair.
    ///
    /// Any `Err` (validation, rate limiting, poisoned lock) is discarded, so call
    /// `inject` directly when individual failures matter. A trailing comma is
    /// accepted.
    #[macro_export]
    macro_rules! safe_inject {
        ($injector:expr, { $($name:expr => $value:expr),+ $(,)? }) => {
            $(
                let _ = $injector.inject($name, $value);
            )+
        };
    }

    /// For each `name`, reads the environment variable `"{prefix}{name}"`.
    /// When it is set, that key/value pair is injected and any error is ignored.
    ///
    /// The key is formatted once per name, so `$prefix` and `$name` are each
    /// evaluated once per entry. A trailing comma is accepted.
    #[macro_export]
    macro_rules! inject_from_env {
        ($injector:expr, $prefix:expr, [$($name:expr),+ $(,)?]) => {
            $(
                {
                    let key = format!("{}{}", $prefix, $name);
                    if let Ok(value) = std::env::var(&key) {
                        let _ = $injector.inject(&key, &value);
                    }
                }
            )+
        };
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_injector_basic() {
        let injector = ConfigInjector::new();
        assert!(injector.inject("APP_PORT", "8080").is_ok());
        assert!(injector.inject("APP_DEBUG", "true").is_ok());
        assert_eq!(injector.get("APP_PORT"), Some("8080".to_string()));
        assert_eq!(injector.get("APP_DEBUG"), Some("true".to_string()));
        assert!(injector.contains("APP_PORT").unwrap());
        assert!(!injector.contains("APP_NONEXISTENT").unwrap());
    }

    #[test]
    fn test_sensitive_field_detection() {
        let injector = ConfigInjector::new();
        assert!(injector.is_sensitive_field("APP_SECRET"));
        assert!(injector.is_sensitive_field("API_TOKEN"));
        assert!(injector.is_sensitive_field("DATABASE_PASSWORD"));
        assert!(!injector.is_sensitive_field("APP_PORT"));
        assert!(!injector.is_sensitive_field("APP_HOST"));
        assert!(!injector.is_sensitive_field("DEBUG_MODE"));
    }

    #[test]
    fn test_safe_retrieval() {
        let validator = EnvSecurityValidator::lenient();
        let injector = ConfigInjector::with_validator(validator);
        injector.inject("APP_SECRET", "my-secret-value").unwrap();
        injector.inject("APP_PORT", "8080").unwrap();
        let secret = injector.get_safe("APP_SECRET").unwrap();
        assert!(secret.contains('*'));
        assert_ne!(secret, "my-secret-value");
        let port = injector.get_safe("APP_PORT").unwrap();
        assert_eq!(port, "8080");
    }

    #[test]
    fn test_mask_value() {
        assert_eq!(ConfigInjector::mask_value(""), "");
        assert_eq!(ConfigInjector::mask_value("abcd"), "****");
        assert_eq!(ConfigInjector::mask_value("abcde"), "a****");
        assert_eq!(ConfigInjector::mask_value("abcdefgh"), "ab******");
    }

    #[test]
    fn test_batch_injection() {
        let validator = EnvSecurityValidator::lenient();
        let injector = ConfigInjector::with_validator(validator);
        let mut config = HashMap::new();
        config.insert("APP_PORT".to_string(), "8080".to_string());
        config.insert("APP_HOST".to_string(), "localhost".to_string());
        config.insert("APP_SECRET".to_string(), "secret".to_string());
        let (success, failures) = injector.inject_all(&config).unwrap();
        assert_eq!(success.len(), 3);
        assert!(failures.is_empty());
        // Every accepted entry must actually be stored.
        assert_eq!(injector.get_all().unwrap(), config);
    }

    #[test]
    fn test_injection_history() {
        let validator = EnvSecurityValidator::lenient();
        let injector = ConfigInjector::with_validator(validator);
        injector.inject("APP_PORT", "8080").unwrap();
        injector.inject("APP_SECRET", "secret").unwrap();
        let history = injector.get_injection_history().unwrap();
        assert_eq!(history.len(), 2);
        assert!(!history[0].is_sensitive);
        assert!(history[1].is_sensitive);
    }

    #[test]
    fn test_environment_config() {
        let injector = ConfigInjector::new();
        injector.inject("APP_PORT", "8080").unwrap();
        injector.inject("APP_DEBUG", "true").unwrap();
        injector.inject("APP_NAME", "test-app").unwrap();
        let config = EnvironmentConfig::from_injector(&injector);
        // Defaults differ from the injected values so these prove parsing happened.
        assert_eq!(config.get::<u16>("APP_PORT", 0), 8080);
        assert!(config.get::<bool>("APP_DEBUG", false));
        assert_eq!(config.get_string("APP_NAME", "default"), "test-app");
        // Missing keys and unparsable values fall back to the default.
        assert_eq!(config.get::<u16>("APP_MISSING", 42), 42);
        assert_eq!(config.get::<u16>("APP_NAME", 7), 7);
        assert_eq!(config.get_string("APP_MISSING", "default"), "default");
    }

    #[test]
    fn test_environment_config_required() {
        let injector = ConfigInjector::new();
        injector.inject("APP_PORT", "8080").unwrap();
        injector.inject("APP_NAME", "test-app").unwrap();
        let config = EnvironmentConfig::from_injector(&injector);
        assert_eq!(config.get_required::<u16>("APP_PORT"), Ok(8080));
        assert_eq!(
            config.get_required::<u16>("APP_MISSING"),
            Err(ConfigInjectionError::InvalidName("APP_MISSING".to_string()))
        );
        assert!(matches!(
            config.get_required::<u16>("APP_NAME"),
            Err(ConfigInjectionError::InvalidValue(_))
        ));
    }

    #[test]
    fn test_clear_and_remove() {
        let injector = ConfigInjector::new();
        injector.inject("APP_PORT", "8080").unwrap();
        injector
            .inject("APP_CONFIG_VALUE", "secret-key-value")
            .unwrap();
        assert_eq!(injector.len().unwrap(), 2);
        let removed = injector.remove("APP_PORT").unwrap();
        assert_eq!(removed, Some("8080".to_string()));
        assert_eq!(injector.len().unwrap(), 1);
        injector.clear().unwrap();
        assert!(injector.is_empty().unwrap());
        assert!(injector.get_injection_history().unwrap().is_empty());
    }

    #[test]
    fn test_validation_failure() {
        let injector = ConfigInjector::new();
        assert!(injector.inject("path", "value").is_err());
        assert!(injector.inject("HOME", "value").is_err());
        assert!(injector.inject("APP_TEST", "hello;world").is_err());
        assert!(injector.inject("APP_TEST", "hello${world}").is_err());
    }
}