use aes_gcm::{
aead::{Aead, KeyInit, OsRng},
Aes256Gcm, Key, Nonce,
};
use anyhow::{anyhow, Result};
use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::path::Path;
/// Length in bytes of an AES-GCM nonce (the standard 96-bit nonce).
const NONCE_SIZE: usize = 96 / 8;
/// Length in bytes of an AES-256 key (256 bits).
const KEY_SIZE: usize = 256 / 8;
/// Self-describing envelope for one encrypted payload, serialized as JSON by
/// `EncryptionManager::encrypt_string` and parsed by `EncryptionManager::decrypt_string`.
///
/// Field names and order form the on-disk format; do not rename or reorder them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncryptedData {
    /// Standard-alphabet base64 of the AEAD output (for AES-GCM this includes the tag).
    pub ciphertext: String,
    /// Standard-alphabet base64 of the nonce; `decrypt` requires it to decode to
    /// `NONCE_SIZE` bytes.
    pub nonce: String,
    /// AEAD algorithm identifier; `encrypt` writes `"AES-256-GCM"`.
    pub algorithm: String,
    /// Parameters describing how the encryption key was derived from the password.
    pub kdf: KeyDerivation,
    /// Envelope format version; `decrypt` accepts only version 1.
    pub version: u8,
}
/// Key-derivation metadata embedded in an [`EncryptedData`] envelope.
///
/// Field names and order form the on-disk format; do not rename or reorder them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeyDerivation {
    /// KDF identifier; `EncryptionManager::encrypt` writes `"PBKDF2-HMAC-SHA256"`.
    pub algorithm: String,
    /// Base64-encoded KDF salt, or an empty string when the writer did not record it.
    pub salt: String,
    /// Iteration count for iteration-based KDFs such as PBKDF2.
    pub iterations: Option<u32>,
    /// Always `None` as written by `EncryptionManager::encrypt`; presumably reserved
    /// for a memory-hard KDF (e.g. Argon2) memory cost — TODO confirm intended units.
    pub memory: Option<u32>,
    /// Always `None` as written by `EncryptionManager::encrypt`; presumably reserved
    /// for a memory-hard KDF time cost — TODO confirm.
    pub time: Option<u32>,
}
pub struct EncryptionManager {
key: Key<Aes256Gcm>,
enabled: bool,
}
impl EncryptionManager {
pub fn new(password: &str, salt: &[u8]) -> Result<Self> {
let key = Self::derive_key(password, salt)?;
Ok(Self { key, enabled: true })
}
pub fn disabled() -> Self {
Self {
key: Key::<Aes256Gcm>::default(),
enabled: false,
}
}
pub fn is_enabled(&self) -> bool {
self.enabled
}
fn derive_key(password: &str, salt: &[u8]) -> Result<Key<Aes256Gcm>> {
let mut key = [0u8; KEY_SIZE];
pbkdf2::pbkdf2_hmac::<sha2::Sha256>(password.as_bytes(), salt, 100_000, &mut key);
Ok(*Key::<Aes256Gcm>::from_slice(&key))
}
pub fn encrypt(&self, plaintext: &[u8]) -> Result<EncryptedData> {
if !self.enabled {
return Err(anyhow!("Encryption is not enabled"));
}
let cipher = Aes256Gcm::new(&self.key);
let nonce_bytes: [u8; NONCE_SIZE] = rand::random();
let nonce = Nonce::from_slice(&nonce_bytes);
let ciphertext = cipher
.encrypt(nonce, plaintext)
.map_err(|e| anyhow!("Encryption failed: {}", e))?;
Ok(EncryptedData {
ciphertext: BASE64.encode(&ciphertext),
nonce: BASE64.encode(nonce_bytes),
algorithm: "AES-256-GCM".to_string(),
kdf: KeyDerivation {
algorithm: "PBKDF2-HMAC-SHA256".to_string(),
salt: String::new(), iterations: Some(100_000),
memory: None,
time: None,
},
version: 1,
})
}
pub fn decrypt(&self, encrypted: &EncryptedData) -> Result<Vec<u8>> {
if !self.enabled {
return Err(anyhow!("Encryption is not enabled"));
}
if encrypted.version != 1 {
return Err(anyhow!(
"Unsupported encryption version: {}",
encrypted.version
));
}
let cipher = Aes256Gcm::new(&self.key);
let ciphertext = BASE64
.decode(&encrypted.ciphertext)
.map_err(|e| anyhow!("Invalid ciphertext encoding: {}", e))?;
let nonce_bytes = BASE64
.decode(&encrypted.nonce)
.map_err(|e| anyhow!("Invalid nonce encoding: {}", e))?;
if nonce_bytes.len() != NONCE_SIZE {
return Err(anyhow!("Invalid nonce size"));
}
let nonce = Nonce::from_slice(&nonce_bytes);
let plaintext = cipher
.decrypt(nonce, ciphertext.as_ref())
.map_err(|e| anyhow!("Decryption failed: {}", e))?;
Ok(plaintext)
}
pub fn encrypt_string(&self, plaintext: &str) -> Result<String> {
let encrypted = self.encrypt(plaintext.as_bytes())?;
Ok(serde_json::to_string(&encrypted)?)
}
pub fn decrypt_string(&self, encrypted_json: &str) -> Result<String> {
let encrypted: EncryptedData = serde_json::from_str(encrypted_json)?;
let plaintext = self.decrypt(&encrypted)?;
String::from_utf8(plaintext).map_err(|e| anyhow!("Invalid UTF-8: {}", e))
}
}
/// Encrypts a messages payload into a JSON envelope when `manager` is enabled.
///
/// A disabled manager leaves the payload untouched and returns a copy of it.
///
/// # Errors
/// Propagates failures from `EncryptionManager::encrypt_string`.
pub fn encrypt_messages(manager: &EncryptionManager, messages_json: &str) -> Result<String> {
    if manager.is_enabled() {
        manager.encrypt_string(messages_json)
    } else {
        Ok(messages_json.to_string())
    }
}
/// Decrypts a stored messages payload if it looks like an encrypted JSON envelope.
///
/// When `manager` is disabled, or the input does not look like an envelope (a JSON
/// object mentioning a `"ciphertext"` key), the input is returned unchanged —
/// presumably so plaintext stored before encryption was enabled stays readable.
///
/// # Errors
/// Returns an error if the input looks like an envelope but cannot be parsed,
/// decrypted, or decoded as UTF-8.
pub fn decrypt_messages(manager: &EncryptionManager, encrypted_messages: &str) -> Result<String> {
    if !manager.is_enabled() {
        return Ok(encrypted_messages.to_string());
    }
    // JSON allows leading whitespace and `serde_json` accepts it, so ignore it while
    // sniffing; otherwise a padded envelope would be handed back as if it were plaintext.
    let looks_encrypted = encrypted_messages.trim_start().starts_with('{')
        && encrypted_messages.contains("\"ciphertext\"");
    if looks_encrypted {
        manager.decrypt_string(encrypted_messages)
    } else {
        Ok(encrypted_messages.to_string())
    }
}
/// Persisted encryption settings: whether encryption is on, the salt, and a password
/// verifier used to check a password before attempting decryption.
///
/// Field names and order form the on-disk JSON format; do not rename or reorder them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncryptionConfig {
    /// Whether encryption is enabled; when `false`, `verify_password` accepts anything.
    pub enabled: bool,
    /// Standard-alphabet base64 of the random salt (empty when disabled).
    pub salt: String,
    /// Password verifier produced by the private `hash_password` helper and checked
    /// by `verify_password` (empty when disabled).
    pub password_hash: String,
}
impl EncryptionConfig {
pub fn new(password: &str) -> Self {
let salt: [u8; 32] = rand::random();
let password_hash = Self::hash_password(password, &salt);
Self {
enabled: true,
salt: BASE64.encode(salt),
password_hash,
}
}
pub fn disabled() -> Self {
Self {
enabled: false,
salt: String::new(),
password_hash: String::new(),
}
}
fn hash_password(password: &str, salt: &[u8]) -> String {
let mut hasher = Sha256::new();
hasher.update(password.as_bytes());
hasher.update(salt);
hasher.update(b"verification");
BASE64.encode(hasher.finalize())
}
pub fn verify_password(&self, password: &str) -> bool {
if !self.enabled {
return true;
}
if let Ok(salt) = BASE64.decode(&self.salt) {
let hash = Self::hash_password(password, &salt);
hash == self.password_hash
} else {
false
}
}
pub fn get_salt(&self) -> Result<Vec<u8>> {
BASE64
.decode(&self.salt)
.map_err(|e| anyhow!("Invalid salt: {}", e))
}
pub fn load(path: &Path) -> Result<Self> {
let content = std::fs::read_to_string(path)?;
Ok(serde_json::from_str(&content)?)
}
pub fn save(&self, path: &Path) -> Result<()> {
let content = serde_json::to_string_pretty(self)?;
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent)?;
}
std::fs::write(path, content)?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encrypt_decrypt() {
        let password = "test_password_123";
        let salt = b"test_salt_12345678901234";
        let manager = EncryptionManager::new(password, salt).unwrap();
        let plaintext = "Hello, encrypted world!";
        let encrypted = manager.encrypt(plaintext.as_bytes()).unwrap();
        assert!(!encrypted.ciphertext.is_empty());
        assert!(!encrypted.nonce.is_empty());
        assert_eq!(encrypted.algorithm, "AES-256-GCM");
        let decrypted = manager.decrypt(&encrypted).unwrap();
        assert_eq!(String::from_utf8(decrypted).unwrap(), plaintext);
    }

    #[test]
    fn test_encrypt_decrypt_string() {
        let password = "secure_password";
        let salt = b"random_salt_value_here";
        let manager = EncryptionManager::new(password, salt).unwrap();
        let original = r#"{"role": "user", "content": "Secret message"}"#;
        let encrypted = manager.encrypt_string(original).unwrap();
        let decrypted = manager.decrypt_string(&encrypted).unwrap();
        assert_eq!(decrypted, original);
    }

    #[test]
    fn test_password_verification() {
        let config = EncryptionConfig::new("my_password");
        assert!(config.verify_password("my_password"));
        assert!(!config.verify_password("wrong_password"));
    }

    #[test]
    fn test_disabled_encryption() {
        let manager = EncryptionManager::disabled();
        assert!(!manager.is_enabled());
        let config = EncryptionConfig::disabled();
        assert!(!config.enabled);
        assert!(config.verify_password("any_password"));
    }

    #[test]
    fn test_disabled_manager_rejects_operations() {
        let manager = EncryptionManager::disabled();
        assert!(manager.encrypt(b"data").is_err());
        assert!(manager.encrypt_string("data").is_err());
    }

    #[test]
    fn test_wrong_password_fails_to_decrypt() {
        let salt = b"shared_salt_for_both_keys";
        let writer = EncryptionManager::new("correct", salt).unwrap();
        let reader = EncryptionManager::new("incorrect", salt).unwrap();
        let encrypted = writer.encrypt(b"secret").unwrap();
        assert!(reader.decrypt(&encrypted).is_err());
    }

    #[test]
    fn test_tampered_ciphertext_is_rejected_and_nonces_differ() {
        let manager = EncryptionManager::new("pw", b"salt_salt_salt_salt").unwrap();
        let first = manager.encrypt(b"integrity matters").unwrap();
        let second = manager.encrypt(b"integrity matters").unwrap();
        // Fresh nonce per message => identical plaintexts encrypt differently.
        assert_ne!(first.nonce, second.nonce);
        assert_ne!(first.ciphertext, second.ciphertext);

        let mut tampered = first.clone();
        let mut raw = BASE64.decode(&tampered.ciphertext).unwrap();
        raw[0] ^= 0x01;
        tampered.ciphertext = BASE64.encode(&raw);
        assert!(manager.decrypt(&tampered).is_err());
        // The untouched envelope still decrypts.
        assert_eq!(manager.decrypt(&first).unwrap(), b"integrity matters");
    }

    #[test]
    fn test_messages_roundtrip_and_plaintext_passthrough() {
        let manager = EncryptionManager::new("pw", b"message_salt_value").unwrap();
        let messages = r#"[{"role":"user","content":"hi"}]"#;
        let stored = encrypt_messages(&manager, messages).unwrap();
        assert_ne!(stored, messages);
        assert_eq!(decrypt_messages(&manager, &stored).unwrap(), messages);
        // Plaintext that does not look like an envelope passes through untouched.
        assert_eq!(decrypt_messages(&manager, messages).unwrap(), messages);
        // A disabled manager leaves payloads as-is in both directions.
        let disabled = EncryptionManager::disabled();
        assert_eq!(encrypt_messages(&disabled, messages).unwrap(), messages);
        assert_eq!(decrypt_messages(&disabled, &stored).unwrap(), stored);
    }

    #[test]
    fn test_config_rejects_corrupt_salt() {
        let mut config = EncryptionConfig::new("pw");
        config.salt = "not base64!".to_string();
        assert!(!config.verify_password("pw"));
        assert!(config.get_salt().is_err());
    }
}