use aes_gcm::{
aead::{Aead, KeyInit},
Aes256Gcm, Nonce,
};
use hkdf::Hkdf;
use rand::RngCore;
use sha2::{Digest, Sha256};
/// A 32-byte (256-bit) key for AES-256-GCM. In CHK mode the same type also holds
/// the SHA-256 content hash that callers keep in place of a key (see `encrypt_chk`).
pub type EncryptionKey = [u8; 32];
/// Length in bytes of the AES-GCM nonce; `encrypt` prepends a random nonce of this size.
const NONCE_SIZE: usize = 12;
/// Length in bytes of the AES-GCM authentication tag carried by every ciphertext.
const TAG_SIZE: usize = 16;
/// HKDF salt used by `derive_key` when turning a content hash into a CHK encryption key.
const CHK_SALT: &[u8] = b"hashtree-chk";
/// Errors returned by the encryption, decryption and key-handling functions in this module.
#[derive(Debug, thiserror::Error)]
pub enum CryptoError {
/// Cipher construction or AES-GCM encryption failed; carries the underlying error text.
#[error("Encryption failed: {0}")]
EncryptionFailed(String),
/// Cipher construction or AES-GCM decryption failed (e.g. wrong key or modified
/// ciphertext); carries the underlying error text.
#[error("Decryption failed: {0}")]
DecryptionFailed(String),
/// The input is shorter than the minimum framing: the tag for CHK data, or
/// nonce plus tag for `encrypt` output.
#[error("Encrypted data too short")]
DataTooShort,
/// Returned by `key_from_hex` for any malformed input. This includes invalid hex
/// characters and odd-length strings, not only a decoded length other than 32 bytes.
#[error("Invalid key length")]
InvalidKeyLength,
/// HKDF expansion failed while deriving a CHK encryption key.
#[error("Key derivation failed")]
KeyDerivationFailed,
}
/// Derives the AES-256-GCM key used for CHK encryption from a content hash.
///
/// Runs HKDF-SHA256 with [`CHK_SALT`] as salt, `content_hash` as input keying
/// material and `b"encryption-key"` as the info string, expanding to 32 bytes.
///
/// # Errors
///
/// Returns [`CryptoError::KeyDerivationFailed`] if HKDF expansion fails.
fn derive_key(content_hash: &[u8; 32]) -> Result<[u8; 32], CryptoError> {
    let hkdf = Hkdf::<Sha256>::new(Some(CHK_SALT), content_hash);
    let mut okm = [0u8; 32];
    match hkdf.expand(b"encryption-key", &mut okm) {
        Ok(()) => Ok(okm),
        Err(_) => Err(CryptoError::KeyDerivationFailed),
    }
}
/// Returns a fresh random 32-byte key, filled from `rand::thread_rng()`.
pub fn generate_key() -> EncryptionKey {
    let mut rng = rand::thread_rng();
    let mut bytes: EncryptionKey = [0u8; 32];
    rng.fill_bytes(&mut bytes);
    bytes
}
/// Computes the SHA-256 digest of `data` as a 32-byte array.
///
/// In CHK mode this digest is the key that [`encrypt_chk`] hands back to callers
/// and [`decrypt_chk`] accepts.
pub fn content_hash(data: &[u8]) -> EncryptionKey {
    Sha256::digest(data).into()
}
/// Encrypts `plaintext` with content-hash keying (CHK).
///
/// The SHA-256 of the plaintext is returned as the key. The actual AES-256-GCM key
/// is derived from that hash via [`derive_key`]. Because the key depends only on
/// the content, a fixed all-zero nonce is used, and identical plaintexts always
/// produce identical ciphertexts. The output has no nonce prefix; its length is
/// [`encrypted_size_chk`] of the plaintext length.
///
/// # Errors
///
/// Returns [`CryptoError::KeyDerivationFailed`] if key derivation fails, or
/// [`CryptoError::EncryptionFailed`] if cipher setup or encryption fails.
pub fn encrypt_chk(plaintext: &[u8]) -> Result<(Vec<u8>, EncryptionKey), CryptoError> {
    let hash_key = content_hash(plaintext);
    let cipher = derive_key(&hash_key).and_then(|aes_key| {
        Aes256Gcm::new_from_slice(&aes_key)
            .map_err(|e| CryptoError::EncryptionFailed(e.to_string()))
    })?;
    // A constant nonce is acceptable here only because each distinct plaintext
    // yields a distinct key, so a (key, nonce) pair never covers two messages.
    let nonce_bytes = [0u8; NONCE_SIZE];
    let sealed = cipher
        .encrypt(Nonce::from_slice(&nonce_bytes), plaintext)
        .map_err(|e| CryptoError::EncryptionFailed(e.to_string()))?;
    Ok((sealed, hash_key))
}
/// Decrypts data produced by [`encrypt_chk`].
///
/// `key` is the content hash returned by `encrypt_chk`. The AES-256-GCM key is
/// re-derived from it, and the same all-zero nonce is used.
///
/// # Errors
///
/// - [`CryptoError::DataTooShort`] if `ciphertext` is shorter than the tag.
/// - [`CryptoError::KeyDerivationFailed`] if key derivation fails.
/// - [`CryptoError::DecryptionFailed`] if cipher setup or decryption fails,
///   e.g. because the key is wrong or the data was modified.
pub fn decrypt_chk(ciphertext: &[u8], key: &EncryptionKey) -> Result<Vec<u8>, CryptoError> {
    if TAG_SIZE > ciphertext.len() {
        return Err(CryptoError::DataTooShort);
    }
    let cipher = derive_key(key).and_then(|aes_key| {
        Aes256Gcm::new_from_slice(&aes_key)
            .map_err(|e| CryptoError::DecryptionFailed(e.to_string()))
    })?;
    let nonce_bytes = [0u8; NONCE_SIZE];
    let opened = cipher.decrypt(Nonce::from_slice(&nonce_bytes), ciphertext);
    opened.map_err(|e| CryptoError::DecryptionFailed(e.to_string()))
}
/// Encrypts `plaintext` under `key` with AES-256-GCM and a freshly drawn random nonce.
///
/// The nonce is prepended to the AES-GCM output, so the result is
/// `NONCE_SIZE` bytes of nonce followed by the ciphertext. Its total length is
/// [`encrypted_size`] of the plaintext length. Two calls with the same input
/// produce different outputs.
///
/// # Errors
///
/// Returns [`CryptoError::EncryptionFailed`] if cipher setup or encryption fails.
pub fn encrypt(plaintext: &[u8], key: &EncryptionKey) -> Result<Vec<u8>, CryptoError> {
    let cipher = Aes256Gcm::new_from_slice(key)
        .map_err(|e| CryptoError::EncryptionFailed(e.to_string()))?;
    let mut nonce_bytes = [0u8; NONCE_SIZE];
    rand::thread_rng().fill_bytes(&mut nonce_bytes);
    let body = cipher
        .encrypt(Nonce::from_slice(&nonce_bytes), plaintext)
        .map_err(|e| CryptoError::EncryptionFailed(e.to_string()))?;
    Ok([nonce_bytes.as_slice(), body.as_slice()].concat())
}
/// Decrypts data produced by [`encrypt`]: a `NONCE_SIZE`-byte nonce followed by
/// the AES-256-GCM ciphertext.
///
/// # Errors
///
/// - [`CryptoError::DataTooShort`] if `encrypted` is shorter than nonce plus tag.
/// - [`CryptoError::DecryptionFailed`] if cipher setup or decryption fails,
///   e.g. because the key is wrong or the data was modified.
pub fn decrypt(encrypted: &[u8], key: &EncryptionKey) -> Result<Vec<u8>, CryptoError> {
    if encrypted.len() < NONCE_SIZE + TAG_SIZE {
        return Err(CryptoError::DataTooShort);
    }
    let (nonce_bytes, sealed) = encrypted.split_at(NONCE_SIZE);
    let cipher = Aes256Gcm::new_from_slice(key)
        .map_err(|e| CryptoError::DecryptionFailed(e.to_string()))?;
    cipher
        .decrypt(Nonce::from_slice(nonce_bytes), sealed)
        .map_err(|e| CryptoError::DecryptionFailed(e.to_string()))
}
/// Returns `true` if `data` is at least as long as the smallest possible output
/// of [`encrypt`] (nonce plus tag).
///
/// This is only a length check against the non-CHK framing. It does not prove
/// that the data is ciphertext.
pub fn could_be_encrypted(data: &[u8]) -> bool {
    let minimum = NONCE_SIZE + TAG_SIZE;
    minimum <= data.len()
}
/// Length of [`encrypt`]'s output for a plaintext of `plaintext_size` bytes:
/// the plaintext length plus nonce and tag overhead.
pub fn encrypted_size(plaintext_size: usize) -> usize {
    let overhead = NONCE_SIZE + TAG_SIZE;
    plaintext_size + overhead
}
/// Length of [`encrypt_chk`]'s output for a plaintext of `plaintext_size` bytes:
/// the plaintext length plus the tag, with no nonce prefix.
pub fn encrypted_size_chk(plaintext_size: usize) -> usize {
    TAG_SIZE + plaintext_size
}
/// Inverse of [`encrypted_size`]: the plaintext length for an [`encrypt`] output of
/// `encrypted_size` bytes. Returns 0 when the input is not longer than nonce plus tag.
///
/// This assumes the non-CHK framing. For [`encrypt_chk`] output, only the tag
/// should be subtracted.
pub fn plaintext_size(encrypted_size: usize) -> usize {
    let overhead = NONCE_SIZE + TAG_SIZE;
    encrypted_size.saturating_sub(overhead)
}
/// Encodes `key` as a hexadecimal string (two characters per byte, 64 in total).
///
/// [`key_from_hex`] parses the result back into the same key.
pub fn key_to_hex(key: &EncryptionKey) -> String {
    hex::encode(&key[..])
}
/// Parses a 64-character hexadecimal string into an [`EncryptionKey`].
///
/// Decodes directly into the fixed-size key buffer, so no intermediate `Vec` is allocated.
///
/// # Errors
///
/// Returns [`CryptoError::InvalidKeyLength`] for any malformed input. Wrong length,
/// odd length and non-hex characters are all reported through this one variant.
pub fn key_from_hex(hex_str: &str) -> Result<EncryptionKey, CryptoError> {
    let mut key: EncryptionKey = [0u8; 32];
    hex::decode_to_slice(hex_str, &mut key).map_err(|_| CryptoError::InvalidKeyLength)?;
    Ok(key)
}
/// Unit tests covering both the CHK (deterministic) and the random-nonce
/// encryption paths, plus the size helpers and hex key round-tripping.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_chk_encrypt_decrypt() {
        let plaintext = b"Hello, World!";
        let (ciphertext, key) = encrypt_chk(plaintext).unwrap();
        let decrypted = decrypt_chk(&ciphertext, &key).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_chk_key_is_content_hash() {
        // The CHK "key" handed to callers is exactly the SHA-256 of the plaintext.
        let plaintext = b"Addressable content";
        let (_, key) = encrypt_chk(plaintext).unwrap();
        assert_eq!(key, content_hash(plaintext));
    }

    #[test]
    fn test_chk_deterministic() {
        let plaintext = b"Same content produces same ciphertext";
        let (ciphertext1, key1) = encrypt_chk(plaintext).unwrap();
        let (ciphertext2, key2) = encrypt_chk(plaintext).unwrap();
        assert_eq!(key1, key2);
        assert_eq!(ciphertext1, ciphertext2);
    }

    #[test]
    fn test_chk_different_content() {
        let (ciphertext1, key1) = encrypt_chk(b"Content A").unwrap();
        let (ciphertext2, key2) = encrypt_chk(b"Content B").unwrap();
        assert_ne!(key1, key2);
        assert_ne!(ciphertext1, ciphertext2);
    }

    #[test]
    fn test_chk_wrong_key_fails() {
        let (ciphertext, _key) = encrypt_chk(b"Secret data").unwrap();
        let wrong_key = generate_key();
        let result = decrypt_chk(&ciphertext, &wrong_key);
        assert!(result.is_err());
    }

    #[test]
    fn test_non_chk_encrypt_decrypt() {
        let key = generate_key();
        let plaintext = b"Hello, World!";
        let encrypted = encrypt(plaintext, &key).unwrap();
        let decrypted = decrypt(&encrypted, &key).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_non_chk_random_nonce() {
        let key = generate_key();
        let plaintext = b"Same content";
        let encrypted1 = encrypt(plaintext, &key).unwrap();
        let encrypted2 = encrypt(plaintext, &key).unwrap();
        assert_ne!(encrypted1, encrypted2);
        assert_eq!(decrypt(&encrypted1, &key).unwrap(), plaintext);
        assert_eq!(decrypt(&encrypted2, &key).unwrap(), plaintext);
    }

    #[test]
    fn test_non_chk_wrong_key_fails() {
        let encrypted = encrypt(b"Secret data", &generate_key()).unwrap();
        let wrong_key = generate_key();
        assert!(decrypt(&encrypted, &wrong_key).is_err());
    }

    #[test]
    fn test_non_chk_tampered_data_fails() {
        let key = generate_key();
        let original = encrypt(b"Important data", &key).unwrap();

        // Flip a bit in the first ciphertext byte (just past the nonce).
        let mut body_tampered = original.clone();
        body_tampered[NONCE_SIZE] ^= 0x01;
        assert!(decrypt(&body_tampered, &key).is_err());

        // Flip a bit in the nonce itself.
        let mut nonce_tampered = original;
        nonce_tampered[0] ^= 0x01;
        assert!(decrypt(&nonce_tampered, &key).is_err());
    }

    #[test]
    fn test_empty_data() {
        let (ciphertext, key) = encrypt_chk(b"").unwrap();
        let decrypted = decrypt_chk(&ciphertext, &key).unwrap();
        assert_eq!(decrypted, b"");
    }

    #[test]
    fn test_non_chk_empty_data() {
        let key = generate_key();
        let encrypted = encrypt(b"", &key).unwrap();
        assert_eq!(encrypted.len(), encrypted_size(0));
        assert!(decrypt(&encrypted, &key).unwrap().is_empty());
    }

    #[test]
    fn test_large_data() {
        let plaintext = vec![0u8; 1024 * 1024];
        let (ciphertext, key) = encrypt_chk(&plaintext).unwrap();
        let decrypted = decrypt_chk(&ciphertext, &key).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_short_input_rejected() {
        let key = generate_key();
        assert!(matches!(
            decrypt_chk(&[0u8; TAG_SIZE - 1], &key),
            Err(CryptoError::DataTooShort)
        ));
        assert!(matches!(
            decrypt(&[0u8; NONCE_SIZE + TAG_SIZE - 1], &key),
            Err(CryptoError::DataTooShort)
        ));
    }

    #[test]
    fn test_key_hex_roundtrip() {
        let key = generate_key();
        let hex_str = key_to_hex(&key);
        let key2 = key_from_hex(&hex_str).unwrap();
        assert_eq!(key, key2);
    }

    #[test]
    fn test_key_from_hex_rejects_malformed() {
        // Too short, too long, odd length, and non-hex characters.
        let malformed = [
            "abcd".to_string(),
            "00".repeat(33),
            "0".repeat(63),
            "zz".repeat(32),
        ];
        for bad in &malformed {
            assert!(
                matches!(key_from_hex(bad), Err(CryptoError::InvalidKeyLength)),
                "accepted malformed key hex: {}",
                bad
            );
        }
    }

    #[test]
    fn test_encrypted_size_chk() {
        let plaintext = b"Test data";
        let (ciphertext, _) = encrypt_chk(plaintext).unwrap();
        assert_eq!(ciphertext.len(), encrypted_size_chk(plaintext.len()));
    }

    #[test]
    fn test_encrypted_size_non_chk() {
        let key = generate_key();
        let plaintext = b"Test data";
        let encrypted = encrypt(plaintext, &key).unwrap();
        assert_eq!(encrypted.len(), encrypted_size(plaintext.len()));
        assert_eq!(plaintext_size(encrypted.len()), plaintext.len());
        assert_eq!(plaintext_size(0), 0);
        assert!(could_be_encrypted(&encrypted));
        assert!(!could_be_encrypted(&[0u8; NONCE_SIZE + TAG_SIZE - 1]));
    }

    #[test]
    fn test_tampered_data_fails() {
        let (mut ciphertext, key) = encrypt_chk(b"Important data").unwrap();
        if let Some(byte) = ciphertext.last_mut() {
            *byte ^= 0xFF;
        }
        let result = decrypt_chk(&ciphertext, &key);
        assert!(result.is_err());
    }
}