use once_cell::sync::Lazy;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
/// Tuning knobs for the environment-variable cache.
#[derive(Debug, Clone)]
pub struct EnvCacheConfig {
    /// How long a cached lookup is served before it is treated as expired.
    pub ttl: Duration,
    /// Whether lookups go through the cache at all.
    pub enabled: bool,
    /// Maximum number of entries kept in the cache at once.
    pub max_entries: usize,
}

impl Default for EnvCacheConfig {
    /// 30-second TTL, caching enabled, at most 50 entries.
    fn default() -> Self {
        let ttl = Duration::from_secs(30);
        let max_entries = 50;
        Self {
            enabled: true,
            ttl,
            max_entries,
        }
    }
}
/// One cached lookup result plus the bookkeeping needed for expiry and diagnostics.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// The looked-up value; `None` records a lookup that produced no value.
    value: Option<String>,
    /// Creation time of this entry; compared against the TTL and used to pick
    /// the oldest entry for eviction.
    cached_at: Instant,
    /// Number of times the entry has been served through `access`.
    access_count: u64,
}

impl CacheEntry {
    /// Wraps `value` in a fresh entry stamped with the current time and zero accesses.
    fn new(value: Option<String>) -> Self {
        let cached_at = Instant::now();
        CacheEntry {
            access_count: 0,
            cached_at,
            value,
        }
    }

    /// Returns `true` once strictly more than `ttl` has elapsed since creation.
    fn is_expired(&self, ttl: Duration) -> bool {
        let age = self.cached_at.elapsed();
        age > ttl
    }

    /// Records one access and hands back a reference to the stored value.
    fn access(&mut self) -> &Option<String> {
        self.access_count += 1;
        &self.value
    }
}
/// Snapshot of cache counters; all fields start at zero via `Default`.
#[derive(Debug, Clone, Default)]
pub struct EnvCacheStats {
    /// Lookups answered from a fresh cached entry.
    pub hits: u64,
    /// Lookups that had to read the environment.
    pub misses: u64,
    /// Entries dropped because their TTL had elapsed.
    pub ttl_evictions: u64,
    /// Number of entries currently held.
    pub current_entries: usize,
    /// Configured capacity, copied in for reporting.
    #[allow(dead_code)]
    pub max_entries: usize,
}

impl EnvCacheStats {
    /// Fraction of lookups that were hits, in `[0.0, 1.0]`; `0.0` before any lookup.
    #[allow(dead_code)]
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => self.hits as f64 / lookups as f64,
        }
    }
}
/// Cache for reads of a fixed allow-list of environment variables.
///
/// Entries and statistics live behind separate `RwLock`s inside `Arc`s so the
/// cache can be shared between threads.
pub struct EnvironmentCache {
    /// Variable name -> cached lookup result (an unset variable is cached as `None`).
    cache: Arc<RwLock<HashMap<String, CacheEntry>>>,
    /// TTL, enable flag and size bound.
    config: EnvCacheConfig,
    /// Hit/miss/eviction counters and the current-size gauge.
    stats: Arc<RwLock<EnvCacheStats>>,
    /// Names that may be read through the cache; everything else is blocked.
    safe_variables: std::collections::HashSet<&'static str>,
}
impl EnvironmentCache {
    /// Creates a cache with `EnvCacheConfig::default()` settings.
    pub fn new() -> Self {
        Self::with_config(EnvCacheConfig::default())
    }

    /// Creates a cache with the given configuration.
    ///
    /// The allow-list of readable variables is built in and is not affected by
    /// `config`.
    pub fn with_config(config: EnvCacheConfig) -> Self {
        let stats = EnvCacheStats {
            max_entries: config.max_entries,
            ..Default::default()
        };
        // Only these names may be read through `get_env_var`; any other name is
        // logged and reported as unset.
        let safe_variables = std::collections::HashSet::from([
            "HOME",
            "USER",
            "LOGNAME",
            "USERNAME",
            "SSH_AUTH_SOCK",
            "SSH_CONNECTION",
            "SSH_CLIENT",
            "SSH_TTY",
            "LANG",
            "LC_ALL",
            "LC_CTYPE",
            "LC_MESSAGES",
            "TMPDIR",
            "TEMP",
            "TMP",
            "TERM",
            "COLORTERM",
        ]);
        Self {
            cache: Arc::new(RwLock::new(HashMap::new())),
            config,
            stats: Arc::new(RwLock::new(stats)),
            safe_variables,
        }
    }

    /// Reads an allow-listed environment variable, serving it from the cache when a
    /// fresh entry exists.
    ///
    /// Returns `Ok(None)` when the variable is unset or its value is not valid
    /// Unicode (per `std::env::var(..).ok()`). It also returns `Ok(None)` when
    /// `var_name` is not on the allow-list, which additionally emits a `warn` log.
    /// Missing values are cached too, so repeated lookups of an unset variable
    /// count as hits. When caching is disabled, the environment is read directly
    /// and no statistics are recorded.
    ///
    /// # Errors
    /// The current implementation never returns `Err`.
    ///
    /// # Panics
    /// Panics if an internal lock has been poisoned.
    pub fn get_env_var(&self, var_name: &str) -> Result<Option<String>, anyhow::Error> {
        if !self.safe_variables.contains(var_name) {
            if self.config.enabled {
                tracing::warn!(
                    "Blocked access to non-whitelisted environment variable '{}'",
                    var_name
                );
            } else {
                tracing::warn!(
                    "Blocked access to non-whitelisted environment variable '{}' (cache disabled)",
                    var_name
                );
            }
            return Ok(None);
        }
        if !self.config.enabled {
            return Ok(std::env::var(var_name).ok());
        }
        if let Some(value) = self.try_get_cached(var_name)? {
            return Ok(value);
        }
        let value = std::env::var(var_name).ok();
        self.put(var_name.to_string(), value.clone());
        self.stats.write().unwrap().misses += 1;
        tracing::trace!("Environment variable cache miss: {}", var_name);
        Ok(value)
    }

    /// Looks up a fresh cache entry.
    ///
    /// Returns `Ok(Some(value))` on a hit; `value` is itself `None` for a cached
    /// unset variable. Returns `Ok(None)` when there is no entry or the entry has
    /// expired. An expired entry is removed and counted as a TTL eviction.
    fn try_get_cached(&self, var_name: &str) -> Result<Option<Option<String>>, anyhow::Error> {
        // Lock order throughout this type: `cache` first, then `stats`.
        let mut cache = self.cache.write().unwrap();
        if let Some(entry) = cache.get_mut(var_name) {
            if entry.is_expired(self.config.ttl) {
                tracing::trace!("Environment variable cache entry expired: {}", var_name);
                cache.remove(var_name);
                let mut stats = self.stats.write().unwrap();
                stats.ttl_evictions += 1;
                // Keep the size gauge in sync with the map, as `remove` and
                // `maintain` do.
                stats.current_entries = cache.len();
                return Ok(None);
            }
            let value = entry.access().clone();
            self.stats.write().unwrap().hits += 1;
            tracing::trace!("Environment variable cache hit: {}", var_name);
            return Ok(Some(value));
        }
        Ok(None)
    }

    /// Inserts or replaces the entry for `var_name`.
    ///
    /// Inserting a *new* key into a full cache first evicts the entry with the
    /// oldest `cached_at`. Replacing an existing key never evicts anything. A
    /// `max_entries` of 0 means nothing is stored.
    fn put(&self, var_name: String, value: Option<String>) {
        let max_entries = self.config.max_entries;
        if max_entries == 0 {
            // A zero-capacity cache stores nothing; inserting would exceed the bound.
            return;
        }
        let mut cache = self.cache.write().unwrap();
        // Only a new key can grow the map. Two threads that miss on the same key
        // both call `put`; the second insert is a replacement and must not evict
        // an unrelated entry.
        if !cache.contains_key(&var_name) && cache.len() >= max_entries {
            if let Some(oldest_key) = cache
                .iter()
                .min_by_key(|(_, entry)| entry.cached_at)
                .map(|(k, _)| k.clone())
            {
                cache.remove(&oldest_key);
                tracing::debug!(
                    "Evicted environment variable from cache due to size limit: {}",
                    oldest_key
                );
            }
        }
        cache.insert(var_name.clone(), CacheEntry::new(value));
        self.stats.write().unwrap().current_entries = cache.len();
        tracing::trace!("Environment variable cached: {}", var_name);
    }

    /// Drops every cached entry and resets the `current_entries` gauge.
    ///
    /// Hit/miss/eviction counters are left untouched.
    #[allow(dead_code)]
    pub fn clear(&self) {
        let mut cache = self.cache.write().unwrap();
        cache.clear();
        self.stats.write().unwrap().current_entries = 0;
    }

    /// Removes one entry and returns its cached value.
    ///
    /// Returns `None` both when no entry exists and when the entry cached an
    /// unset variable.
    #[allow(dead_code)]
    pub fn remove(&self, var_name: &str) -> Option<String> {
        let mut cache = self.cache.write().unwrap();
        let entry = cache.remove(var_name)?;
        self.stats.write().unwrap().current_entries = cache.len();
        entry.value
    }

    /// Returns a copy of the current statistics.
    #[allow(dead_code)]
    pub fn stats(&self) -> EnvCacheStats {
        self.stats.read().unwrap().clone()
    }

    /// Returns the configuration this cache was built with.
    #[allow(dead_code)]
    pub fn config(&self) -> &EnvCacheConfig {
        &self.config
    }

    /// Purges all expired entries and returns how many were removed.
    ///
    /// Does nothing and returns 0 when caching is disabled.
    #[allow(dead_code)]
    pub fn maintain(&self) -> usize {
        if !self.config.enabled {
            return 0;
        }
        let ttl = self.config.ttl;
        let mut cache = self.cache.write().unwrap();
        let before = cache.len();
        cache.retain(|_, entry| !entry.is_expired(ttl));
        let removed_count = before - cache.len();
        {
            let mut stats = self.stats.write().unwrap();
            stats.ttl_evictions += removed_count as u64;
            stats.current_entries = cache.len();
        }
        if removed_count > 0 {
            tracing::debug!(
                "Environment cache maintenance: removed {} expired entries",
                removed_count
            );
        }
        removed_count
    }

    /// Empties the cache so the next lookups re-read the environment.
    #[allow(dead_code)]
    pub fn refresh(&self) {
        self.clear();
        tracing::debug!("Environment variable cache refreshed");
    }

    /// Returns a human-readable description of each cached entry, keyed by
    /// variable name.
    ///
    /// The description covers status, age, access count and whether a value is
    /// present. Values themselves are not included.
    #[allow(dead_code)]
    pub fn debug_info(&self) -> HashMap<String, String> {
        let cache = self.cache.read().unwrap();
        let mut info = HashMap::new();
        for (key, entry) in cache.iter() {
            let age = entry.cached_at.elapsed();
            let is_expired = entry.is_expired(self.config.ttl);
            let has_value = entry.value.is_some();
            let status = if is_expired { "EXPIRED" } else { "VALID" };
            info.insert(
                key.clone(),
                format!(
                    "Status: {}, Age: {:?}, Accesses: {}, Has value: {}",
                    status, age, entry.access_count, has_value
                ),
            );
        }
        info
    }

    /// Returns `true` if `var_name` is on the allow-list.
    #[allow(dead_code)]
    pub fn is_safe_variable(&self, var_name: &str) -> bool {
        self.safe_variables.contains(var_name)
    }

    /// Returns the allow-listed names in unspecified (hash-set) order.
    #[allow(dead_code)]
    pub fn safe_variables(&self) -> Vec<&'static str> {
        self.safe_variables.iter().copied().collect()
    }
}
impl Default for EnvironmentCache {
fn default() -> Self {
Self::new()
}
}
/// Process-wide environment cache, configured from `BSSH_ENV_CACHE_*` on first use.
///
/// - `BSSH_ENV_CACHE_TTL`: entry lifetime in whole seconds (default 30).
/// - `BSSH_ENV_CACHE_ENABLED`: `false` (any case) or `0` disables caching (default: enabled).
/// - `BSSH_ENV_CACHE_SIZE`: maximum number of cached entries (default 50).
///
/// Leading and trailing whitespace in these values is ignored. A numeric value that
/// is present but unparsable falls back to its default and logs a warning.
pub static GLOBAL_ENV_CACHE: Lazy<EnvironmentCache> = Lazy::new(|| {
    /// Reads `name` as a `T`. Falls back to `default` when the variable is unset
    /// or not valid Unicode, and also when it does not parse (with a warning).
    fn parse_env<T: std::str::FromStr>(name: &str, default: T) -> T {
        match std::env::var(name) {
            Ok(raw) => match raw.trim().parse::<T>() {
                Ok(value) => value,
                Err(_) => {
                    tracing::warn!(
                        "Ignoring invalid value {:?} for {}; using default",
                        raw,
                        name
                    );
                    default
                }
            },
            Err(_) => default,
        }
    }

    let enabled = std::env::var("BSSH_ENV_CACHE_ENABLED")
        .map(|raw| {
            let value = raw.trim().to_lowercase();
            value != "false" && value != "0"
        })
        .unwrap_or(true);
    let config = EnvCacheConfig {
        ttl: Duration::from_secs(parse_env("BSSH_ENV_CACHE_TTL", 30u64)),
        enabled,
        max_entries: parse_env("BSSH_ENV_CACHE_SIZE", 50usize),
    };
    tracing::debug!(
        "Initializing environment variable cache with {} max entries, {:?} TTL, enabled: {}",
        config.max_entries,
        config.ttl,
        config.enabled
    );
    EnvironmentCache::with_config(config)
});
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    #[test]
    fn test_env_cache_config_default() {
        let config = EnvCacheConfig::default();
        assert_eq!(config.ttl, Duration::from_secs(30));
        assert!(config.enabled);
        assert_eq!(config.max_entries, 50);
    }

    #[test]
    fn test_cache_entry_expiration() {
        let mut entry = CacheEntry::new(Some("test".to_string()));
        assert!(!entry.is_expired(Duration::from_secs(60)));
        // `Instant - Duration` may panic when the result is not representable
        // (e.g. shortly after boot), so backdate with `checked_sub` instead.
        if let Some(two_minutes_ago) = Instant::now().checked_sub(Duration::from_secs(120)) {
            entry.cached_at = two_minutes_ago;
            assert!(entry.is_expired(Duration::from_secs(60)));
        }
        // Independent of the clock's origin: an entry older than a tiny TTL is expired.
        let fresh = CacheEntry::new(None);
        std::thread::sleep(Duration::from_millis(5));
        assert!(fresh.is_expired(Duration::from_millis(1)));
    }

    #[test]
    fn test_env_cache_basic_operations() {
        let cache = EnvironmentCache::new();
        // The first lookup is a miss that populates the cache, even when HOME is
        // unset. The second lookup must be served from the cache.
        let first = cache.get_env_var("HOME").unwrap();
        let second = cache.get_env_var("HOME").unwrap();
        assert_eq!(first, second);
        let stats = cache.stats();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.current_entries, 1);
    }

    #[test]
    fn test_env_cache_unsafe_variable_blocked() {
        let cache = EnvironmentCache::new();
        let result = cache.get_env_var("PATH").unwrap();
        assert_eq!(result, None);
        assert!(!cache.is_safe_variable("PATH"));
        assert!(!cache.is_safe_variable("LD_PRELOAD"));
        assert!(cache.is_safe_variable("HOME"));
        assert!(cache.is_safe_variable("USER"));
        // Blocked lookups are not cached and not counted.
        let stats = cache.stats();
        assert_eq!(stats.hits + stats.misses, 0);
        assert_eq!(stats.current_entries, 0);
    }

    #[test]
    fn test_env_cache_ttl_expiration() {
        let config = EnvCacheConfig {
            ttl: Duration::from_millis(50),
            enabled: true,
            max_entries: 10,
        };
        let cache = EnvironmentCache::with_config(config);
        let _result1 = cache.get_env_var("HOME");
        std::thread::sleep(Duration::from_millis(100));
        let _result2 = cache.get_env_var("HOME");
        let stats = cache.stats();
        assert_eq!(stats.ttl_evictions, 1);
        assert_eq!(stats.misses, 2);
        assert_eq!(stats.hits, 0);
    }

    #[test]
    fn test_env_cache_size_limit() {
        let config = EnvCacheConfig {
            ttl: Duration::from_secs(60),
            enabled: true,
            max_entries: 2,
        };
        let cache = EnvironmentCache::with_config(config);
        let _r1 = cache.get_env_var("HOME");
        let _r2 = cache.get_env_var("USER");
        let _r3 = cache.get_env_var("TMPDIR");
        // Three distinct keys into a 2-slot cache: one eviction, still full.
        assert_eq!(cache.stats().current_entries, 2);
    }

    #[test]
    fn test_env_cache_clear_and_refresh() {
        let cache = EnvironmentCache::new();
        let _r1 = cache.get_env_var("HOME");
        assert!(cache.stats().current_entries > 0);
        cache.clear();
        assert_eq!(cache.stats().current_entries, 0);
        let _r2 = cache.get_env_var("HOME");
        assert!(cache.stats().current_entries > 0);
        cache.refresh();
        assert_eq!(cache.stats().current_entries, 0);
    }

    #[test]
    fn test_env_cache_maintenance() {
        let config = EnvCacheConfig {
            ttl: Duration::from_millis(50),
            enabled: true,
            max_entries: 10,
        };
        let cache = EnvironmentCache::with_config(config);
        let _result = cache.get_env_var("HOME");
        assert!(cache.stats().current_entries > 0);
        std::thread::sleep(Duration::from_millis(100));
        let removed = cache.maintain();
        assert_eq!(removed, 1);
        assert_eq!(cache.stats().current_entries, 0);
    }

    #[test]
    fn test_env_cache_disabled() {
        let config = EnvCacheConfig {
            ttl: Duration::from_secs(60),
            enabled: false,
            max_entries: 10,
        };
        let cache = EnvironmentCache::with_config(config);
        let _r1 = cache.get_env_var("HOME");
        let _r2 = cache.get_env_var("HOME");
        // The allow-list is still enforced with caching off.
        assert_eq!(cache.get_env_var("PATH").unwrap(), None);
        let stats = cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.current_entries, 0);
    }

    #[test]
    fn test_env_cache_stats() {
        let cache = EnvironmentCache::new();
        let stats = cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.hit_rate(), 0.0);
        assert_eq!(stats.current_entries, 0);
        assert_eq!(stats.max_entries, 50);
    }

    #[test]
    fn test_env_cache_safe_variables_list() {
        let cache = EnvironmentCache::new();
        let safe_vars = cache.safe_variables();
        assert!(safe_vars.contains(&"HOME"));
        assert!(safe_vars.contains(&"USER"));
        assert!(safe_vars.contains(&"SSH_AUTH_SOCK"));
        assert!(!safe_vars.contains(&"PATH"));
        assert!(!safe_vars.contains(&"LD_PRELOAD"));
    }

    #[test]
    fn test_env_cache_concurrent_access() {
        const THREADS: usize = 10;
        const LOOKUPS_PER_THREAD: usize = 100;
        let cache = Arc::new(EnvironmentCache::new());
        let counter = Arc::new(AtomicUsize::new(0));
        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let cache_clone = Arc::clone(&cache);
                let counter_clone = Arc::clone(&counter);
                std::thread::spawn(move || {
                    for _ in 0..LOOKUPS_PER_THREAD {
                        if cache_clone.get_env_var("HOME").is_ok() {
                            counter_clone.fetch_add(1, Ordering::Relaxed);
                        }
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }
        let total = THREADS * LOOKUPS_PER_THREAD;
        assert_eq!(counter.load(Ordering::Relaxed), total);
        // Every enabled, allow-listed lookup is counted exactly once as a hit or a miss.
        let stats = cache.stats();
        assert_eq!((stats.hits + stats.misses) as usize, total);
    }
}