use crate::error::ConfigError;
use aes_gcm::{
    aead::{Aead, AeadCore, KeyInit, OsRng},
    Aes256Gcm, Key, Nonce,
};
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
use lru::LruCache;
use std::env;
use std::num::NonZero;
use std::sync::Mutex;
use zeroize::{ZeroizeOnDrop, Zeroizing};
/// Default number of nonces remembered per instance by [`ConfigEncryption::new`].
const MAX_NONCE_CACHE_SIZE: usize = 50_000;
/// Owned 256-bit key material that is wiped from memory when dropped.
///
/// Only `ZeroizeOnDrop` is derived: the type is intentionally neither `Debug` nor
/// `Clone`, so the key cannot be printed or duplicated through those traits.
#[derive(ZeroizeOnDrop)]
pub struct SecureKey([u8; 32]);

impl SecureKey {
    /// Takes ownership of `key_bytes` as the key material.
    pub fn new(key_bytes: [u8; 32]) -> Self {
        SecureKey(key_bytes)
    }

    /// Borrows the raw 32 key bytes.
    pub fn as_bytes(&self) -> &[u8; 32] {
        let SecureKey(bytes) = self;
        bytes
    }

    /// Produces an `aes_gcm` key holding a copy of these bytes.
    ///
    /// NOTE(review): the returned `Key` is a plain value copy and is not zeroized
    /// when it is dropped.
    pub fn to_aes_key(&self) -> Key<Aes256Gcm> {
        let raw: &[u8] = self.as_bytes();
        let aes_key: &Key<Aes256Gcm> = Key::<Aes256Gcm>::from_slice(raw);
        *aes_key
    }
}
pub struct ConfigEncryption {
key: SecureKey,
nonce_cache: Mutex<LruCache<Vec<u8>, ()>>,
max_nonce_cache_size: usize,
}
impl ConfigEncryption {
pub fn new(key_bytes: [u8; 32]) -> Self {
Self::with_cache_size(key_bytes, MAX_NONCE_CACHE_SIZE)
}
pub fn with_cache_size(key_bytes: [u8; 32], cache_size: usize) -> Self {
let key = SecureKey::new(key_bytes);
Self {
key,
nonce_cache: Mutex::new(LruCache::new(
#[allow(clippy::incompatible_msrv)]
NonZero::new(cache_size).expect("cache_size must be > 0"),
)),
max_nonce_cache_size: cache_size,
}
}
pub fn from_env() -> Result<Self, ConfigError> {
let key_str = env::var("CONFERS_ENCRYPTION_KEY")
.or_else(|_| env::var("CONFERS_KEY"))
.map_err(|_| {
ConfigError::FormatDetectionFailed(
"CONFERS_ENCRYPTION_KEY (or CONFERS_KEY) not found".to_string(),
)
})?;
if !key_str
.chars()
.all(|c| c.is_ascii_alphanumeric() || c == '+' || c == '/' || c == '=')
{
return Err(ConfigError::FormatDetectionFailed(
"Invalid base64 key format: contains invalid characters".to_string(),
));
}
let key_bytes = BASE64.decode(&key_str).map_err(|e| {
ConfigError::FormatDetectionFailed(format!("Invalid base64 key: {}", e))
})?;
if key_bytes.len() != 32 {
return Err(ConfigError::FormatDetectionFailed(format!(
"Key must be 32 bytes (256 bits), got {} bytes",
key_bytes.len()
)));
}
if key_bytes.iter().all(|&b| b == 0) {
return Err(ConfigError::FormatDetectionFailed(
"Weak key: all zeros".to_string(),
));
}
if key_bytes.windows(2).all(|w| w[0] == w[1]) {
return Err(ConfigError::FormatDetectionFailed(
"Weak key: all bytes are identical".to_string(),
));
}
if is_sequential_pattern(&key_bytes) {
return Err(ConfigError::FormatDetectionFailed(
"Weak key: sequential pattern detected".to_string(),
));
}
if is_repeating_pattern(&key_bytes) {
return Err(ConfigError::FormatDetectionFailed(
"Weak key: repeating pattern detected".to_string(),
));
}
let entropy = calculate_entropy(&key_bytes);
if entropy < 4.0 {
return Err(ConfigError::FormatDetectionFailed(format!(
"Weak key: insufficient entropy ({} bits per byte, minimum 4.0)",
entropy
)));
}
let mut key = [0u8; 32];
key.copy_from_slice(&key_bytes);
Ok(Self::new(key))
}
pub fn encrypt(&self, plaintext: &str) -> Result<String, ConfigError> {
let cipher = Aes256Gcm::new(&self.key.to_aes_key());
let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
let nonce_bytes: Vec<u8> = nonce.to_vec();
{
let mut cache = self
.nonce_cache
.lock()
.map_err(|_| ConfigError::RuntimeError("Nonce cache lock poisoned".to_string()))?;
if cache.contains(&nonce_bytes) {
return Err(ConfigError::FormatDetectionFailed(
"Nonce reuse detected - cryptographic attack prevented".to_string(),
));
}
cache.put(nonce_bytes.clone(), ());
let usage = cache.len() as f64 / self.max_nonce_cache_size as f64;
if usage > 0.8 {
#[cfg(feature = "tracing")]
tracing::warn!(
"Nonce cache is {:.0}% full ({} entries). Consider increasing cache size or rotating keys more frequently.",
usage * 100.0,
cache.len()
);
}
}
let ciphertext = cipher
.encrypt(&nonce, plaintext.as_bytes())
.map_err(|e| ConfigError::FormatDetectionFailed(format!("Encryption error: {}", e)))?;
let nonce_b64 = BASE64.encode(nonce.as_slice());
let ct_b64 = BASE64.encode(ciphertext);
Ok(format!("enc:AES256GCM:{}:{}", nonce_b64, ct_b64))
}
pub fn decrypt(&self, encrypted_value: &str) -> Result<String, ConfigError> {
if !encrypted_value.starts_with("enc:AES256GCM:") {
return Ok(encrypted_value.to_string());
}
let parts: Vec<&str> = encrypted_value.split(':').collect();
if parts.len() != 4 {
return Err(ConfigError::FormatDetectionFailed(
"Invalid encrypted value format".to_string(),
));
}
let nonce_b64 = parts[2];
let ct_b64 = parts[3];
let nonce_bytes = BASE64.decode(nonce_b64).map_err(|e| {
ConfigError::FormatDetectionFailed(format!("Invalid Nonce base64: {}", e))
})?;
{
let cache = self
.nonce_cache
.lock()
.map_err(|_| ConfigError::RuntimeError("Nonce cache lock poisoned".to_string()))?;
if cache.contains(&nonce_bytes) {
return Err(ConfigError::FormatDetectionFailed(
"Nonce reuse detected - cryptographic attack prevented".to_string(),
));
}
}
let ciphertext = BASE64.decode(ct_b64).map_err(|e| {
ConfigError::FormatDetectionFailed(format!("Invalid ciphertext base64: {}", e))
})?;
let nonce = Nonce::from_slice(&nonce_bytes);
let cipher = Aes256Gcm::new(&self.key.to_aes_key());
let plaintext_bytes = cipher
.decrypt(nonce, ciphertext.as_ref())
.map_err(|e| ConfigError::FormatDetectionFailed(format!("Decryption error: {}", e)))?;
let plaintext = String::from_utf8(plaintext_bytes)
.map_err(|e| ConfigError::FormatDetectionFailed(format!("Invalid UTF-8: {}", e)))?;
Ok(plaintext)
}
pub fn nonce_cache_size(&self) -> usize {
self.nonce_cache
.lock()
.map(|cache| cache.len())
.unwrap_or(0)
}
pub fn cache_usage_percent(&self) -> f64 {
let size = self.nonce_cache_size();
(size as f64 / MAX_NONCE_CACHE_SIZE as f64) * 100.0
}
pub fn is_cache_near_full(&self, threshold: f64) -> bool {
self.cache_usage_percent() > threshold
}
pub fn cache_stats(&self) -> CacheStats {
CacheStats {
current_size: self.nonce_cache_size(),
max_size: self.max_nonce_cache_size,
usage_percent: self.cache_usage_percent(),
is_near_full: self.is_cache_near_full(80.0),
}
}
pub fn clear_nonce_cache(&self) {
if let Ok(mut cache) = self.nonce_cache.lock() {
cache.clear();
}
}
}
/// Point-in-time snapshot of a [`ConfigEncryption`] nonce cache.
#[derive(Debug, Clone, PartialEq)]
pub struct CacheStats {
    /// Number of nonces held in the cache when the snapshot was taken.
    pub current_size: usize,
    /// Capacity the nonce cache was configured with.
    pub max_size: usize,
    /// Cache usage, in percent.
    pub usage_percent: f64,
    /// Whether cache usage was above 80 percent.
    pub is_near_full: bool,
}
/// Shannon entropy of the byte-value distribution in `data`, in bits per byte.
///
/// Empty input yields `0.0`.
fn calculate_entropy(data: &[u8]) -> f64 {
    let mut histogram = [0usize; 256];
    data.iter().for_each(|&b| histogram[usize::from(b)] += 1);
    let total = data.len() as f64;
    histogram
        .iter()
        .filter(|&&n| n != 0)
        .map(|&n| n as f64 / total)
        .fold(0.0, |acc, p| acc - p * p.log2())
}
/// Returns `true` when more than half as many adjacent byte pairs as there are bytes step
/// by +1, or (counted separately) by -1. Steps wrap at the byte boundary, e.g.
/// `0xFF -> 0x00` counts as ascending.
///
/// Inputs shorter than two bytes have no adjacent pairs and are never sequential. The
/// previous `0..key.len() - 1` loop underflowed and panicked on empty input.
fn is_sequential_pattern(key: &[u8]) -> bool {
    let (ascending, descending) =
        key.windows(2)
            .fold((0usize, 0usize), |(up, down), pair| {
                if pair[1] == pair[0].wrapping_add(1) {
                    (up + 1, down)
                } else if pair[1] == pair[0].wrapping_sub(1) {
                    (up, down + 1)
                } else {
                    (up, down)
                }
            });
    let half = key.len() / 2;
    ascending > half || descending > half
}
/// Returns `true` if `key` consists of at least two copies of a block of 2, 4, 8 or 16
/// bytes.
///
/// At least two full periods are required. With `key.len() == period` (or an empty key)
/// there is nothing to compare against, and the old loop reported every such input as
/// "repeating".
fn is_repeating_pattern(key: &[u8]) -> bool {
    [2usize, 4, 8, 16].iter().any(|&period| {
        key.len() > period
            && key.len() % period == 0
            && key[period..]
                .chunks_exact(period)
                .all(|chunk| chunk == &key[..period])
    })
}