use aes_gcm::{
aead::{Aead, KeyInit},
Aes256Gcm, Nonce,
};
use argon2::{password_hash::SaltString, Argon2, PasswordHasher};
use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
use rand::RngCore;
use super::CloudError;
/// Length in bytes of the symmetric key accepted by [`encrypt_data`] and [`decrypt_data`]
/// (the key is handed to `Aes256Gcm`, i.e. AES-256).
pub const KEY_SIZE: usize = 32;
/// Length in bytes of the random nonce that [`encrypt_data`] prepends to every ciphertext.
pub const NONCE_SIZE: usize = 12;
/// Length in bytes of salts produced by [`generate_salt`].
// NOTE(review): `dead_code` is allowed presumably because only tests use this today —
// confirm and drop the attribute once production code references it.
#[allow(dead_code)]
pub const SALT_SIZE: usize = 16;
/// Derives a [`KEY_SIZE`]-byte symmetric key from `passphrase` and `salt` using Argon2
/// with the `argon2` crate's default parameters.
///
/// Derivation is deterministic: the same passphrase and salt always yield the same key.
///
/// # Errors
///
/// Returns [`CloudError::EncryptionError`] if:
/// - `salt` cannot be encoded as a `SaltString`;
/// - Argon2 hashing fails;
/// - the hash carries no output, or its output is shorter than [`KEY_SIZE`].
pub fn derive_key(passphrase: &str, salt: &[u8]) -> Result<Vec<u8>, CloudError> {
    let encoded_salt = SaltString::encode_b64(salt)
        .map_err(|err| CloudError::EncryptionError(format!("Invalid salt: {err}")))?;

    let password_hash = Argon2::default()
        .hash_password(passphrase.as_bytes(), &encoded_salt)
        .map_err(|err| CloudError::EncryptionError(format!("Key derivation failed: {err}")))?;

    let Some(output) = password_hash.hash else {
        return Err(CloudError::EncryptionError("No hash output".to_string()));
    };

    // Keep the first KEY_SIZE bytes of the raw hash output; anything shorter is rejected.
    match output.as_bytes().get(..KEY_SIZE) {
        Some(key) => Ok(key.to_vec()),
        None => Err(CloudError::EncryptionError(
            "Derived key too short".to_string(),
        )),
    }
}
/// Returns a freshly generated random salt of [`SALT_SIZE`] bytes, suitable for
/// passing to [`derive_key`].
// NOTE(review): `dead_code` is allowed presumably because only tests call this today —
// confirm and drop the attribute once production code uses it.
#[allow(dead_code)]
pub fn generate_salt() -> Vec<u8> {
    let mut rng = rand::thread_rng();
    let mut salt_bytes = vec![0u8; SALT_SIZE];
    rng.fill_bytes(salt_bytes.as_mut_slice());
    salt_bytes
}
/// Encrypts `data` with AES-256-GCM under `key`, using a freshly generated random nonce.
///
/// The returned buffer is `nonce || ciphertext`. The first [`NONCE_SIZE`] bytes hold the
/// nonce in the clear; pass the whole buffer unchanged to [`decrypt_data`].
///
/// # Errors
///
/// Returns [`CloudError::EncryptionError`] if `key` is not exactly [`KEY_SIZE`] bytes, or if
/// the cipher cannot be constructed or reports an encryption failure.
pub fn encrypt_data(data: &[u8], key: &[u8]) -> Result<Vec<u8>, CloudError> {
    if key.len() != KEY_SIZE {
        return Err(CloudError::EncryptionError(format!(
            "Invalid key size: expected {KEY_SIZE}, got {}",
            key.len()
        )));
    }

    let cipher = Aes256Gcm::new_from_slice(key)
        .map_err(|err| CloudError::EncryptionError(format!("Cipher creation failed: {err}")))?;

    // Each message gets its own random 96-bit (NONCE_SIZE-byte) nonce, which is stored
    // ahead of the ciphertext so decryption can recover it.
    let mut nonce_bytes = [0u8; NONCE_SIZE];
    rand::thread_rng().fill_bytes(&mut nonce_bytes);

    let sealed = cipher
        .encrypt(Nonce::from_slice(&nonce_bytes), data)
        .map_err(|err| CloudError::EncryptionError(format!("Encryption failed: {err}")))?;

    Ok([nonce_bytes.as_slice(), sealed.as_slice()].concat())
}
/// Decrypts and authenticates a buffer produced by [`encrypt_data`].
///
/// `data` must be `nonce || ciphertext`. The first [`NONCE_SIZE`] bytes are the nonce; the
/// remainder is the AES-GCM output (ciphertext followed by its authentication tag). `key`
/// must be the same [`KEY_SIZE`]-byte key that was used for encryption.
///
/// # Errors
///
/// Returns [`CloudError::EncryptionError`] if:
/// - `key` is not exactly [`KEY_SIZE`] bytes;
/// - `data` is too short to hold both a nonce and an authentication tag;
/// - decryption fails (wrong key, or corrupted or tampered data).
pub fn decrypt_data(data: &[u8], key: &[u8]) -> Result<Vec<u8>, CloudError> {
    // Length of the authentication tag AES-GCM appends to every ciphertext. Even an empty
    // plaintext therefore encrypts to at least NONCE_SIZE + TAG_SIZE bytes.
    const TAG_SIZE: usize = 16;

    if key.len() != KEY_SIZE {
        return Err(CloudError::EncryptionError(format!(
            "Invalid key size: expected {KEY_SIZE}, got {}",
            key.len()
        )));
    }
    // Reject inputs that cannot possibly be valid up front, with a precise message, instead
    // of letting a nonce-only buffer reach the cipher and fail as a generic decryption error.
    if data.len() < NONCE_SIZE + TAG_SIZE {
        return Err(CloudError::EncryptionError(
            "Encrypted data too short".to_string(),
        ));
    }
    let (nonce_bytes, ciphertext) = data.split_at(NONCE_SIZE);
    let nonce = Nonce::from_slice(nonce_bytes);
    let cipher = Aes256Gcm::new_from_slice(key)
        .map_err(|e| CloudError::EncryptionError(format!("Cipher creation failed: {e}")))?;
    cipher
        .decrypt(nonce, ciphertext)
        .map_err(|e| CloudError::EncryptionError(format!("Decryption failed: {e}")))
}
pub fn encode_base64(data: &[u8]) -> String {
BASE64.encode(data)
}
pub fn decode_base64(data: &str) -> Result<Vec<u8>, CloudError> {
BASE64
.decode(data)
.map_err(|e| CloudError::EncryptionError(format!("Base64 decode failed: {e}")))
}
pub fn encode_key_hex(key: &[u8]) -> String {
hex::encode(key)
}
pub fn decode_key_hex(hex_str: &str) -> Result<Vec<u8>, CloudError> {
hex::decode(hex_str).map_err(|e| CloudError::EncryptionError(format!("Hex decode failed: {e}")))
}
/// Minimal hexadecimal encoding/decoding used for key serialization.
mod hex {
    /// Lowercase hex digits, indexed by nibble value.
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    /// Encodes `data` as a lowercase hex string, two characters per byte.
    pub fn encode(data: &[u8]) -> String {
        // Pre-size the output and push digits directly, instead of allocating a temporary
        // `String` per byte via `format!`.
        let mut out = String::with_capacity(data.len() * 2);
        for &byte in data {
            out.push(char::from(DIGITS[usize::from(byte >> 4)]));
            out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
        }
        out
    }

    /// Decodes a hex string (upper- or lowercase digits) into bytes.
    ///
    /// # Errors
    ///
    /// Returns a description of the problem if `s` has an odd byte length, or if it contains
    /// anything other than the ASCII characters `0-9`, `a-f` and `A-F`.
    pub fn decode(s: &str) -> Result<Vec<u8>, String> {
        // Work on raw bytes. Slicing the `&str` at fixed byte offsets would panic inside a
        // multi-byte UTF-8 character, and any non-ASCII byte is invalid hex anyway.
        let pairs = s.as_bytes().chunks_exact(2);
        if !pairs.remainder().is_empty() {
            return Err("Hex string has odd length".to_string());
        }
        pairs
            .enumerate()
            .map(|(index, pair)| {
                let high = nibble(pair[0], 2 * index)?;
                let low = nibble(pair[1], 2 * index + 1)?;
                Ok((high << 4) | low)
            })
            .collect()
    }

    /// Converts one ASCII hex digit to its 4-bit value; `pos` is its byte offset, for errors.
    fn nibble(digit: u8, pos: usize) -> Result<u8, String> {
        // Explicit matching rejects signs such as the leading `+` that
        // `u8::from_str_radix` would otherwise accept.
        match digit {
            b'0'..=b'9' => Ok(digit - b'0'),
            b'a'..=b'f' => Ok(digit - b'a' + 10),
            b'A'..=b'F' => Ok(digit - b'A' + 10),
            _ => Err(format!("Invalid hex: invalid digit at byte {pos}")),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Size of the authentication tag AES-GCM appends to every ciphertext.
    const TAG_SIZE: usize = 16;

    #[test]
    fn test_generate_salt_length() {
        let salt = generate_salt();
        assert_eq!(salt.len(), SALT_SIZE);
    }

    #[test]
    fn test_generate_salt_randomness() {
        let salt1 = generate_salt();
        let salt2 = generate_salt();
        assert_ne!(salt1, salt2);
    }

    #[test]
    fn test_derive_key_consistency() {
        let passphrase = "test passphrase";
        let salt = generate_salt();
        let key1 = derive_key(passphrase, &salt).unwrap();
        let key2 = derive_key(passphrase, &salt).unwrap();
        assert_eq!(key1, key2);
        assert_eq!(key1.len(), KEY_SIZE);
    }

    #[test]
    fn test_derive_key_different_passphrases() {
        let salt = generate_salt();
        let key1 = derive_key("passphrase1", &salt).unwrap();
        let key2 = derive_key("passphrase2", &salt).unwrap();
        assert_ne!(key1, key2);
    }

    #[test]
    fn test_derive_key_different_salts() {
        let passphrase = "test passphrase";
        let salt1 = generate_salt();
        let salt2 = generate_salt();
        let key1 = derive_key(passphrase, &salt1).unwrap();
        let key2 = derive_key(passphrase, &salt2).unwrap();
        assert_ne!(key1, key2);
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let passphrase = "test passphrase";
        let salt = generate_salt();
        let key = derive_key(passphrase, &salt).unwrap();
        let plaintext = b"Hello, World! This is a test message.";
        let encrypted = encrypt_data(plaintext, &key).unwrap();
        let decrypted = decrypt_data(&encrypted, &key).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypt_output_layout() {
        let salt = generate_salt();
        let key = derive_key("passphrase", &salt).unwrap();
        let plaintext = b"layout check";
        let encrypted = encrypt_data(plaintext, &key).unwrap();
        // nonce || ciphertext || tag
        assert_eq!(encrypted.len(), NONCE_SIZE + plaintext.len() + TAG_SIZE);
    }

    #[test]
    fn test_encrypt_produces_different_ciphertext() {
        let salt = generate_salt();
        let key = derive_key("passphrase", &salt).unwrap();
        let plaintext = b"test data";
        let encrypted1 = encrypt_data(plaintext, &key).unwrap();
        let encrypted2 = encrypt_data(plaintext, &key).unwrap();
        assert_ne!(encrypted1, encrypted2);
    }

    #[test]
    fn test_decrypt_with_wrong_key_fails() {
        let salt = generate_salt();
        let key1 = derive_key("passphrase1", &salt).unwrap();
        let key2 = derive_key("passphrase2", &salt).unwrap();
        let plaintext = b"secret data";
        let encrypted = encrypt_data(plaintext, &key1).unwrap();
        let result = decrypt_data(&encrypted, &key2);
        assert!(result.is_err());
    }

    #[test]
    fn test_decrypt_with_corrupted_data_fails() {
        let salt = generate_salt();
        let key = derive_key("passphrase", &salt).unwrap();
        let plaintext = b"secret data";
        let mut encrypted = encrypt_data(plaintext, &key).unwrap();
        // Flip a byte past the nonce. Index directly so the test fails loudly, rather than
        // silently skipping the corruption, if the output ever gets shorter.
        encrypted[NONCE_SIZE + 5] ^= 0xFF;
        let result = decrypt_data(&encrypted, &key);
        assert!(result.is_err());
    }

    #[test]
    fn test_encrypt_data_invalid_key_size() {
        let short_key = [0u8; 16];
        let result = encrypt_data(b"data", &short_key);
        assert!(result.is_err());
    }

    #[test]
    fn test_decrypt_data_invalid_key_size() {
        let short_key = [0u8; 16];
        let result = decrypt_data(&[0u8; 64], &short_key);
        assert!(result.is_err());
    }

    #[test]
    fn test_decrypt_data_too_short() {
        let salt = generate_salt();
        let key = derive_key("passphrase", &salt).unwrap();
        let short_data = [0u8; 5];
        let result = decrypt_data(&short_data, &key);
        assert!(result.is_err());
    }

    #[test]
    fn test_decrypt_data_nonce_without_tag() {
        let salt = generate_salt();
        let key = derive_key("passphrase", &salt).unwrap();
        // Long enough to hold a nonce, one byte too short to also hold a tag.
        let truncated = [0u8; NONCE_SIZE + TAG_SIZE - 1];
        assert!(decrypt_data(&truncated, &key).is_err());
    }

    #[test]
    fn test_base64_roundtrip() {
        let data = b"test binary data \x00\x01\x02";
        let encoded = encode_base64(data);
        let decoded = decode_base64(&encoded).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_decode_base64_invalid() {
        assert!(decode_base64("not base64!").is_err());
    }

    #[test]
    fn test_hex_roundtrip() {
        let data = vec![0u8, 1, 2, 255, 128, 64];
        let encoded = encode_key_hex(&data);
        let decoded = decode_key_hex(&encoded).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_hex_encode() {
        assert_eq!(hex::encode(&[0, 255, 128]), "00ff80");
    }

    #[test]
    fn test_hex_decode_invalid() {
        // Odd length.
        assert!(hex::decode("xyz").is_err());
        assert!(hex::decode("abc").is_err());
        // Even length with non-hex digits, so the per-digit validation is exercised too.
        assert!(hex::decode("zz").is_err());
        assert!(hex::decode("0g").is_err());
    }

    #[test]
    fn test_decode_key_hex_invalid() {
        assert!(decode_key_hex("zz").is_err());
    }

    #[test]
    fn test_encrypt_empty_data() {
        let salt = generate_salt();
        let key = derive_key("passphrase", &salt).unwrap();
        let plaintext = b"";
        let encrypted = encrypt_data(plaintext, &key).unwrap();
        assert_eq!(encrypted.len(), NONCE_SIZE + TAG_SIZE);
        let decrypted = decrypt_data(&encrypted, &key).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypt_large_data() {
        let salt = generate_salt();
        let key = derive_key("passphrase", &salt).unwrap();
        let plaintext: Vec<u8> = (0..1_000_000).map(|i| (i % 256) as u8).collect();
        let encrypted = encrypt_data(&plaintext, &key).unwrap();
        let decrypted = decrypt_data(&encrypted, &key).unwrap();
        assert_eq!(decrypted, plaintext);
    }
}