use crate::{ZoeyError, Result};
use aes_gcm::{
aead::{Aead, KeyInit, OsRng},
Aes256Gcm, Nonce,
};
use argon2::password_hash::SaltString;
use argon2::{Argon2, PasswordHasher};
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
use rand::RngCore;
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
/// Derives a 32-byte symmetric key from `password` and `salt` using Argon2
/// with its default parameters.
///
/// The raw `salt` bytes are base64-encoded into a `SaltString` before being
/// handed to the hasher, and the first 32 bytes of the hash output become the
/// key.
///
/// # Errors
///
/// Returns an error when the salt cannot be encoded, when Argon2 fails, when
/// no hash output is produced, or when the output is shorter than 32 bytes.
fn derive_key(password: &str, salt: &[u8]) -> Result<[u8; 32]> {
    let argon2 = Argon2::default();
    let salt_string = SaltString::encode_b64(salt)
        .map_err(|e| ZoeyError::other(format!("Failed to encode salt: {}", e)))?;
    let password_hash = argon2
        .hash_password(password.as_bytes(), &salt_string)
        .map_err(|e| ZoeyError::other(format!("Failed to derive key: {}", e)))?;
    let hash_output = password_hash
        .hash
        .ok_or_else(|| ZoeyError::other("No hash produced"))?;

    // Checked slicing: a short output becomes an error instead of a panic.
    let key_bytes = hash_output
        .as_bytes()
        .get(..32)
        .ok_or_else(|| ZoeyError::other("Derived hash is shorter than 32 bytes"))?;

    let mut key = [0u8; 32];
    key.copy_from_slice(key_bytes);
    Ok(key)
}
/// Encrypts `value` with AES-256-GCM under a key derived from the password
/// `key`.
///
/// A fresh random 16-byte salt and 12-byte nonce are generated on every call.
/// The returned string is the base64 encoding of `salt || nonce || ciphertext`,
/// which is the layout [`decrypt_secret`] expects.
///
/// # Errors
///
/// Returns an error if key derivation or the encryption itself fails.
pub fn encrypt_secret(value: &str, key: &str) -> Result<String> {
    const SALT_LEN: usize = 16;
    const NONCE_LEN: usize = 12;

    // Per-message salt for the key derivation.
    let mut salt = [0u8; SALT_LEN];
    OsRng.fill_bytes(&mut salt);
    let key_material: [u8; 32] = derive_key(key, &salt)?;
    let cipher = Aes256Gcm::new(&key_material.into());

    // Per-message nonce for the cipher.
    let mut nonce_bytes = [0u8; NONCE_LEN];
    OsRng.fill_bytes(&mut nonce_bytes);

    let sealed = cipher
        .encrypt(Nonce::from_slice(&nonce_bytes), value.as_bytes())
        .map_err(|e| ZoeyError::other(format!("Encryption failed: {}", e)))?;

    // Envelope layout: salt, then nonce, then ciphertext.
    let envelope: Vec<u8> = [&salt[..], &nonce_bytes[..], &sealed[..]].concat();
    Ok(BASE64.encode(&envelope))
}
/// Decrypts a string produced by [`encrypt_secret`] using the password `key`.
///
/// The input is base64-decoded and split into a 16-byte salt, a 12-byte nonce
/// and the remaining ciphertext. The key is re-derived from `key` and the
/// salt before the AES-256-GCM decryption is attempted.
///
/// # Errors
///
/// Returns an error if the input is not valid base64, is shorter than
/// 44 bytes once decoded, fails key derivation or decryption, or decrypts to
/// bytes that are not valid UTF-8.
pub fn decrypt_secret(encrypted: &str, key: &str) -> Result<String> {
    const SALT_LEN: usize = 16;
    const NONCE_LEN: usize = 12;
    // Smallest ciphertext accepted after the salt and nonce.
    const MIN_CIPHERTEXT_LEN: usize = 16;

    let raw = BASE64
        .decode(encrypted)
        .map_err(|e| ZoeyError::other(format!("Failed to decode base64: {}", e)))?;

    if raw.len() < SALT_LEN + NONCE_LEN + MIN_CIPHERTEXT_LEN {
        return Err(ZoeyError::other("Invalid encrypted data: too short"));
    }

    // The length check above guarantees both splits are in bounds.
    let (salt, rest) = raw.split_at(SALT_LEN);
    let (nonce_bytes, ciphertext) = rest.split_at(NONCE_LEN);

    let key_material: [u8; 32] = derive_key(key, salt)?;
    let cipher = Aes256Gcm::new(&key_material.into());

    let plaintext = cipher
        .decrypt(Nonce::from_slice(nonce_bytes), ciphertext)
        .map_err(|e| ZoeyError::other(format!("Decryption failed: {}", e)))?;

    String::from_utf8(plaintext)
        .map_err(|e| ZoeyError::other(format!("Invalid UTF-8 in decrypted data: {}", e)))
}
/// Checks that `input` is within the size limit and free of disallowed
/// control characters.
///
/// The checks run in this order:
/// 1. `input.len()` (length in bytes) must not exceed `max_length`.
/// 2. `input` must not contain a NUL character.
/// 3. `input` must not contain any control character other than `\n`, `\t`
///    or `\r`.
///
/// # Errors
///
/// Returns a validation error describing the first check that failed.
pub fn validate_input(input: &str, max_length: usize) -> Result<()> {
    let byte_len = input.len();
    if byte_len > max_length {
        return Err(ZoeyError::validation(format!(
            "Input too long: {} > {}",
            byte_len, max_length
        )));
    }

    // NUL gets its own message, so it is tested before other control characters.
    if input.contains('\0') {
        return Err(ZoeyError::validation("Input contains null bytes"));
    }

    let has_forbidden_control = input
        .chars()
        .any(|ch| ch.is_control() && !matches!(ch, '\n' | '\t' | '\r'));

    if has_forbidden_control {
        Err(ZoeyError::validation(
            "Input contains invalid control characters",
        ))
    } else {
        Ok(())
    }
}
/// Returns a copy of `input` with control characters removed.
///
/// Newline (`\n`), tab (`\t`) and carriage return (`\r`) are kept; every other
/// character for which `char::is_control` is true is dropped.
pub fn sanitize_input(input: &str) -> String {
    let mut cleaned = String::with_capacity(input.len());
    for ch in input.chars() {
        let keep = matches!(ch, '\n' | '\t' | '\r') || !ch.is_control();
        if keep {
            cleaned.push(ch);
        }
    }
    cleaned
}
/// Longest key, in bytes, that `RateLimiter::check` accepts; longer keys are
/// rejected outright.
const MAX_RATE_LIMIT_KEY_LENGTH: usize = 256;
/// Number of distinct keys the limiter tracks before `RateLimiter::check`
/// starts evicting expired entries and, failing that, rejecting new keys.
const MAX_TRACKED_KEYS: usize = 100_000;
/// Per-key request limiter: each key may record at most `max_requests`
/// accepted calls whose timestamps are younger than `window`.
pub struct RateLimiter {
// Accepted-request timestamps per key, behind a shared lock.
limits: Arc<RwLock<HashMap<String, Vec<Instant>>>>,
// Age at which a recorded timestamp stops counting against the key.
window: Duration,
// Maximum number of unexpired timestamps a key may hold.
max_requests: usize,
}
impl RateLimiter {
    /// Creates a limiter that allows `max_requests` accepted calls per key
    /// within `window`.
    pub fn new(window: Duration, max_requests: usize) -> Self {
        let limits = Arc::new(RwLock::new(HashMap::new()));
        Self {
            limits,
            window,
            max_requests,
        }
    }

    /// Acquires the write lock on the per-key table.
    ///
    /// A poisoned lock is logged and recovered from rather than propagated.
    fn get_limits_write(&self) -> std::sync::RwLockWriteGuard<'_, HashMap<String, Vec<Instant>>> {
        match self.limits.write() {
            Ok(guard) => guard,
            Err(poisoned) => {
                tracing::error!("RateLimiter lock was poisoned, recovering");
                poisoned.into_inner()
            }
        }
    }

    /// Records a request for `key` and reports whether it is allowed.
    ///
    /// Returns `false` when the key is longer than
    /// `MAX_RATE_LIMIT_KEY_LENGTH`, when the table is full and no room could
    /// be made for a new key, or when the key already holds `max_requests`
    /// unexpired timestamps. Otherwise the current instant is recorded and
    /// `true` is returned.
    pub fn check(&self, key: &str) -> bool {
        if key.len() > MAX_RATE_LIMIT_KEY_LENGTH {
            tracing::warn!("Rate limit key too long, rejecting request");
            return false;
        }

        let mut limits = self.get_limits_write();
        let now = Instant::now();
        let window = self.window;

        let table_full = limits.len() >= MAX_TRACKED_KEYS;
        if table_full && !limits.contains_key(key) {
            // Make room by dropping up to 1000 keys whose newest timestamp has
            // expired (or that hold no timestamps at all).
            let stale: Vec<String> = limits
                .iter()
                .filter_map(|(tracked, stamps)| {
                    let expired = stamps
                        .last()
                        .map_or(true, |&latest| now.duration_since(latest) >= window);
                    expired.then(|| tracked.clone())
                })
                .take(1000)
                .collect();

            for tracked in &stale {
                limits.remove(tracked);
            }

            if limits.len() >= MAX_TRACKED_KEYS {
                tracing::warn!("Rate limiter at capacity, rejecting new key");
                return false;
            }
        }

        let stamps = limits.entry(key.to_string()).or_default();
        // Forget requests that have aged out of the window before counting.
        stamps.retain(|&t| now.duration_since(t) < window);

        if stamps.len() >= self.max_requests {
            return false;
        }
        stamps.push(now);
        true
    }

    /// Removes all recorded requests for `key`.
    pub fn reset(&self, key: &str) {
        self.get_limits_write().remove(key);
    }

    /// Returns how many more requests `key` may make right now.
    ///
    /// Expired timestamps for the key are pruned as a side effect, which is
    /// why the write lock is taken. Unknown keys report `max_requests`.
    pub fn remaining(&self, key: &str) -> usize {
        let mut limits = self.get_limits_write();
        let now = Instant::now();

        let used = match limits.get_mut(key) {
            Some(stamps) => {
                stamps.retain(|&t| now.duration_since(t) < self.window);
                stamps.len()
            }
            None => 0,
        };

        self.max_requests.saturating_sub(used)
    }
}
/// Hashes `password` together with the caller-supplied `salt`.
///
/// The Argon2 input is the string `"{password}:{salt}"`, and a fresh random
/// Argon2 salt is generated on every call, so hashing the same arguments
/// twice yields different strings. The result is the string form of the
/// Argon2 `PasswordHash`.
///
/// If Argon2 fails, the error is logged and a `SHA256:`-prefixed hex digest of
/// `password` followed by `salt` is returned instead.
///
/// NOTE(review): the SHA-256 fallback is far weaker than Argon2 and is used
/// silently apart from the log line — confirm this downgrade is intended.
pub fn hash_password(password: &str, salt: &str) -> String {
    use argon2::password_hash::SaltString;
    use argon2::{Argon2, PasswordHasher};

    let random_salt = SaltString::generate(&mut OsRng);
    let material = format!("{}:{}", password, salt);
    let hasher = Argon2::default();

    let outcome = hasher.hash_password(material.as_bytes(), &random_salt);
    if let Ok(encoded) = &outcome {
        return encoded.to_string();
    }

    if let Err(e) = outcome {
        tracing::error!("Argon2 hashing failed, using fallback: {}", e);
    }
    let mut digest = Sha256::new();
    for part in [password.as_bytes(), salt.as_bytes()] {
        digest.update(part);
    }
    format!("SHA256:{:x}", digest.finalize())
}
/// Checks `password` and `salt` against a stored `hash` produced by
/// [`hash_password`].
///
/// Two formats are recognised:
/// - hashes starting with `SHA256:` are recomputed as the hex SHA-256 digest
///   of `password` followed by `salt` and compared in constant time;
/// - anything else is parsed as an Argon2 `PasswordHash` and verified against
///   the string `"{password}:{salt}"`.
///
/// Returns `false` if the hash cannot be parsed or does not match.
pub fn verify_password(password: &str, salt: &str, hash: &str) -> bool {
    use argon2::password_hash::PasswordHash;
    use argon2::{Argon2, PasswordVerifier};

    // Compares two byte strings without exiting at the first mismatch, so the
    // time taken does not reveal how long the matching prefix is. Only the
    // lengths, which are not secret here, are compared early.
    fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
        if a.len() != b.len() {
            return false;
        }
        a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
    }

    if hash.starts_with("SHA256:") {
        let mut hasher = Sha256::new();
        hasher.update(password.as_bytes());
        hasher.update(salt.as_bytes());
        let computed = format!("SHA256:{:x}", hasher.finalize());
        return constant_time_eq(computed.as_bytes(), hash.as_bytes());
    }

    let parsed_hash = match PasswordHash::new(hash) {
        Ok(h) => h,
        Err(e) => {
            tracing::warn!("Failed to parse password hash: {}", e);
            return false;
        }
    };

    // Must mirror the "{password}:{salt}" input built by `hash_password`.
    let combined = format!("{}:{}", password, salt);
    Argon2::default()
        .verify_password(combined.as_bytes(), &parsed_hash)
        .is_ok()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Round trip: decrypting an encrypted value returns the original text.
    #[test]
    fn test_encrypt_decrypt() {
        let plaintext = "Hello, World! This is a secret message.";
        let key = "my-secret-password";
        let encrypted = encrypt_secret(plaintext, key).expect("Encryption should succeed");
        assert_ne!(encrypted, plaintext);
        let decrypted = decrypt_secret(&encrypted, key).expect("Decryption should succeed");
        assert_eq!(decrypted, plaintext);
    }

    /// Random salt and nonce make two encryptions of the same input differ.
    #[test]
    fn test_encrypt_different_outputs() {
        let plaintext = "Same message";
        let key = "same-key";
        let encrypted1 = encrypt_secret(plaintext, key).unwrap();
        let encrypted2 = encrypt_secret(plaintext, key).unwrap();
        assert_ne!(encrypted1, encrypted2);
        assert_eq!(decrypt_secret(&encrypted1, key).unwrap(), plaintext);
        assert_eq!(decrypt_secret(&encrypted2, key).unwrap(), plaintext);
    }

    /// Decrypting with a different password must fail.
    #[test]
    fn test_decrypt_wrong_key() {
        let plaintext = "Secret data";
        let key1 = "correct-password";
        let key2 = "wrong-password";
        let encrypted = encrypt_secret(plaintext, key1).unwrap();
        assert!(decrypt_secret(&encrypted, key2).is_err());
    }

    /// Malformed, truncated and tampered inputs are all rejected.
    #[test]
    fn test_decrypt_invalid_data() {
        let key = "some-key";

        // Valid base64, but far too short to hold salt, nonce and tag.
        assert!(decrypt_secret("dGVzdA==", key).is_err());
        // Not base64 at all.
        assert!(decrypt_secret("not-valid-base64!!!", key).is_err());

        let plaintext = "test";
        let encrypted = encrypt_secret(plaintext, key).unwrap();

        // A stray trailing character is rejected.
        let mut corrupted = encrypted.clone();
        corrupted.push('X');
        assert!(decrypt_secret(&corrupted, key).is_err());

        // Genuine tampering: swap one character near the end for another
        // base64 character. The envelope for a 4-byte plaintext is 48 bytes
        // (no padding), so the string stays valid base64 and the failure must
        // come from authenticated decryption.
        let mut tampered: Vec<char> = encrypted.chars().collect();
        let idx = tampered.len() - 4;
        tampered[idx] = if tampered[idx] == 'A' { 'B' } else { 'A' };
        let tampered: String = tampered.into_iter().collect();
        assert_ne!(tampered, encrypted);
        assert!(decrypt_secret(&tampered, key).is_err());
    }

    /// The empty string survives a round trip.
    #[test]
    fn test_encrypt_empty_string() {
        let encrypted = encrypt_secret("", "key").unwrap();
        let decrypted = decrypt_secret(&encrypted, "key").unwrap();
        assert_eq!(decrypted, "");
    }

    /// Multi-byte UTF-8 text survives a round trip.
    #[test]
    fn test_encrypt_unicode() {
        let plaintext = "Hello 世界 🌍 مرحبا";
        let key = "unicode-key";
        let encrypted = encrypt_secret(plaintext, key).unwrap();
        let decrypted = decrypt_secret(&encrypted, key).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    /// Length, NUL and control-character rules of `validate_input`.
    #[test]
    fn test_validate_input() {
        assert!(validate_input("Hello, World!", 100).is_ok());
        assert!(validate_input("A".repeat(1000).as_str(), 100).is_err());
        assert!(validate_input("Hello\0World", 100).is_err());
        // Other control characters are rejected; \n, \t and \r are allowed.
        assert!(validate_input("Hello\x01World", 100).is_err());
        assert!(validate_input("line1\r\n\tline2", 100).is_ok());
    }

    /// Disallowed control characters are stripped by `sanitize_input`.
    #[test]
    fn test_sanitize_input() {
        let input = "Hello\x01World\x02";
        let sanitized = sanitize_input(input);
        assert_eq!(sanitized, "HelloWorld");
    }

    /// Limits are enforced per key, and `reset` clears a key's history.
    #[test]
    fn test_rate_limiter() {
        let limiter = RateLimiter::new(Duration::from_secs(60), 5);
        for _ in 0..5 {
            assert!(limiter.check("user1"));
        }
        assert!(!limiter.check("user1"));
        assert!(limiter.check("user2"));
        limiter.reset("user1");
        assert!(limiter.check("user1"));
    }

    /// `remaining` starts at the maximum and drops after an accepted request.
    #[test]
    fn test_remaining() {
        let limiter = RateLimiter::new(Duration::from_secs(60), 10);
        assert_eq!(limiter.remaining("user1"), 10);
        limiter.check("user1");
        assert_eq!(limiter.remaining("user1"), 9);
    }

    /// Hashes are randomised per call, yet each verifies against its inputs.
    #[test]
    fn test_hash_password() {
        let hash1 = hash_password("password123", "salt");
        let hash2 = hash_password("password123", "salt");
        assert_ne!(hash1, hash2);
        assert!(verify_password("password123", "salt", &hash1));
        assert!(verify_password("password123", "salt", &hash2));
        assert!(!verify_password("different", "salt", &hash1));
    }

    /// A wrong password does not verify.
    #[test]
    fn test_verify_password() {
        let hash = hash_password("secret", "mysalt");
        assert!(verify_password("secret", "mysalt", &hash));
        assert!(!verify_password("wrong", "mysalt", &hash));
    }
}