use aes_gcm::{Aes256Gcm, KeyInit, aead::Aead};
use zeroize::{Zeroize, ZeroizeOnDrop};
use crate::CryptoError;
/// Length in bytes of an AES-256 key.
pub const KEY_LENGTH: usize = 32;
/// Length in bytes of an AES-GCM nonce (96 bits).
pub const NONCE_LENGTH: usize = 12;
/// Length in bytes of the AES-GCM authentication tag.
pub const TAG_LENGTH: usize = 16;
/// Serialized size of a wrapped key: the nonce, followed by the encrypted key
/// and its authentication tag.
pub const WRAPPED_KEY_LENGTH: usize = NONCE_LENGTH + (KEY_LENGTH + TAG_LENGTH);
#[derive(Zeroize, ZeroizeOnDrop)]
pub struct EncryptionKey([u8; KEY_LENGTH]);
impl EncryptionKey {
pub(crate) fn from_random_bytes(bytes: [u8; KEY_LENGTH]) -> Self {
assert!(
bytes.iter().any(|&b| b != 0),
"EncryptionKey random bytes are all zeros - RNG failure or bug"
);
Self(bytes)
}
pub fn from_bytes(bytes: &[u8; KEY_LENGTH]) -> Self {
assert!(
bytes.iter().any(|&b| b != 0),
"EncryptionKey bytes are all zeros - corrupted or uninitialized key material"
);
Self(*bytes)
}
pub fn to_bytes(&self) -> [u8; KEY_LENGTH] {
self.0
}
#[allow(clippy::let_and_return)]
pub fn generate() -> Self {
let random_bytes: [u8; KEY_LENGTH] = generate_random();
let key = Self::from_random_bytes(random_bytes);
kimberlite_properties::always!(
key.0.iter().any(|&b| b != 0),
"crypto.encryption_key_not_all_zeros",
"EncryptionKey must never be all-zeros after generation"
);
key
}
}
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct Nonce([u8; NONCE_LENGTH]);
impl Nonce {
pub(crate) fn from_random_bytes(bytes: [u8; NONCE_LENGTH]) -> Self {
assert!(
bytes.iter().any(|&b| b != 0),
"Nonce random bytes are all zeros - RNG failure or bug"
);
Self(bytes)
}
pub fn from_position(position: u64) -> Self {
let mut nonce = [0u8; NONCE_LENGTH];
nonce[..8].copy_from_slice(&position.to_le_bytes());
Self(nonce)
}
pub fn from_bytes(bytes: [u8; NONCE_LENGTH]) -> Self {
Self(bytes)
}
pub fn to_bytes(&self) -> [u8; NONCE_LENGTH] {
self.0
}
pub fn generate_random() -> Self {
let random_bytes: [u8; NONCE_LENGTH] = generate_random();
Self::from_random_bytes(random_bytes)
}
}
/// AES-GCM output: the encrypted bytes followed by the authentication tag.
///
/// Always holds at least [`TAG_LENGTH`] bytes. Values built in this module come
/// from the cipher, and [`Ciphertext::from_bytes`] enforces the bound for
/// externally supplied data.
#[derive(Clone, PartialEq, Eq)]
pub struct Ciphertext(Vec<u8>);

impl Ciphertext {
    /// Wraps serialized ciphertext bytes (e.g. read back from storage).
    ///
    /// # Panics
    ///
    /// Panics if `bytes` is shorter than [`TAG_LENGTH`].
    pub fn from_bytes(bytes: Vec<u8>) -> Self {
        let len = bytes.len();
        assert!(
            len >= TAG_LENGTH,
            "ciphertext too short: must be at least {} bytes for auth tag, got {}",
            TAG_LENGTH,
            len
        );
        Self(bytes)
    }

    /// Borrows the serialized bytes (ciphertext body followed by the tag).
    pub fn to_bytes(&self) -> &[u8] {
        self.0.as_slice()
    }

    /// Total length in bytes, tag included.
    pub fn len(&self) -> usize {
        self.to_bytes().len()
    }

    /// Whether the byte buffer is empty.
    ///
    /// Given the length invariant above this is always `false`. It presumably
    /// exists to pair with `len` (clippy `len_without_is_empty`).
    pub fn is_empty(&self) -> bool {
        self.to_bytes().is_empty()
    }
}
pub struct WrappedKey {
nonce: Nonce,
ciphertext: Ciphertext,
}
impl WrappedKey {
pub fn new(wrapping_key: &EncryptionKey, key_to_wrap: &[u8; KEY_LENGTH]) -> Self {
assert!(
key_to_wrap.iter().any(|&b| b != 0),
"key_to_wrap is all zeros - corrupted or uninitialized key material"
);
let nonce = Nonce::generate_random();
let ciphertext = encrypt(wrapping_key, &nonce, key_to_wrap);
assert_eq!(
ciphertext.len(),
KEY_LENGTH + TAG_LENGTH,
"wrapped ciphertext has unexpected length: expected {}, got {}",
KEY_LENGTH + TAG_LENGTH,
ciphertext.len()
);
Self { nonce, ciphertext }
}
pub fn unwrap_key(
&self,
wrapping_key: &EncryptionKey,
) -> Result<[u8; KEY_LENGTH], CryptoError> {
let decrypted = decrypt(wrapping_key, &self.nonce, &self.ciphertext)?;
assert_eq!(
decrypted.len(),
KEY_LENGTH,
"unwrapped key has unexpected length: expected {}, got {}",
KEY_LENGTH,
decrypted.len()
);
decrypted
.try_into()
.map_err(|_| CryptoError::DecryptionError)
}
pub fn to_bytes(&self) -> [u8; WRAPPED_KEY_LENGTH] {
let mut bytes = [0u8; WRAPPED_KEY_LENGTH];
bytes[..NONCE_LENGTH].copy_from_slice(&self.nonce.to_bytes());
bytes[NONCE_LENGTH..].copy_from_slice(self.ciphertext.to_bytes());
assert!(
bytes.iter().any(|&b| b != 0),
"serialized wrapped key is all zeros - encryption bug"
);
bytes
}
pub fn from_bytes(bytes: &[u8; WRAPPED_KEY_LENGTH]) -> Self {
assert!(
bytes.iter().any(|&b| b != 0),
"wrapped key bytes are all zeros - corrupted or uninitialized storage"
);
let mut nonce_bytes = [0u8; NONCE_LENGTH];
nonce_bytes.copy_from_slice(&bytes[..NONCE_LENGTH]);
let ciphertext = Ciphertext::from_bytes(bytes[NONCE_LENGTH..].to_vec());
Self {
nonce: Nonce::from_bytes(nonce_bytes),
ciphertext,
}
}
}
/// Source of the root key that wraps and unwraps key-encryption keys (KEKs).
///
/// `InMemoryMasterKey` in this module is an in-process implementation.
pub trait MasterKeyProvider {
    /// Encrypts `kek_bytes` under the master key.
    fn wrap_kek(&self, kek_bytes: &[u8; KEY_LENGTH]) -> WrappedKey;

    /// Decrypts a KEK previously produced by [`MasterKeyProvider::wrap_kek`].
    ///
    /// # Errors
    ///
    /// Returns a [`CryptoError`] when unwrapping fails (for example, the
    /// wrapped bytes belong to a different master key or were tampered with).
    fn unwrap_kek(
        &self,
        wrapped: &WrappedKey,
    ) -> Result<[u8; KEY_LENGTH], CryptoError>;
}
#[derive(Zeroize, ZeroizeOnDrop)]
pub struct InMemoryMasterKey(EncryptionKey);
impl InMemoryMasterKey {
pub(crate) fn from_random_bytes(bytes: [u8; KEY_LENGTH]) -> Self {
Self(EncryptionKey::from_random_bytes(bytes))
}
pub fn from_bytes(bytes: &[u8; KEY_LENGTH]) -> Self {
assert!(
bytes.iter().any(|&b| b != 0),
"master key bytes are all zeros - corrupted or uninitialized key material"
);
Self(EncryptionKey::from_bytes(bytes))
}
pub fn to_bytes(&self) -> [u8; KEY_LENGTH] {
self.0.to_bytes()
}
pub fn generate() -> Self {
let random_bytes: [u8; KEY_LENGTH] = generate_random();
Self::from_random_bytes(random_bytes)
}
}
impl MasterKeyProvider for InMemoryMasterKey {
    /// Wraps `kek_bytes` under the in-memory master key.
    ///
    /// # Panics
    ///
    /// Panics if `kek_bytes` is all zeros.
    fn wrap_kek(&self, kek_bytes: &[u8; KEY_LENGTH]) -> WrappedKey {
        let is_nonzero = kek_bytes.iter().any(|&b| b != 0);
        assert!(
            is_nonzero,
            "KEK bytes are all zeros - corrupted or uninitialized key material"
        );
        WrappedKey::new(&self.0, kek_bytes)
    }

    /// Unwraps a KEK with the in-memory master key.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`WrappedKey::unwrap_key`].
    ///
    /// # Panics
    ///
    /// Panics if the unwrapped KEK is all zeros.
    fn unwrap_kek(&self, wrapped: &WrappedKey) -> Result<[u8; KEY_LENGTH], CryptoError> {
        wrapped.unwrap_key(&self.0).map(|kek_bytes| {
            assert!(
                !kek_bytes.iter().all(|&b| b == 0),
                "unwrapped KEK is all zeros - decryption produced invalid key"
            );
            kek_bytes
        })
    }
}
/// A key-encryption key (KEK).
///
/// A KEK is stored wrapped by a [`MasterKeyProvider`] and in turn wraps
/// data-encryption keys. The key bytes are wiped when the value is dropped
/// (`ZeroizeOnDrop`).
#[derive(Zeroize, ZeroizeOnDrop)]
pub struct KeyEncryptionKey(EncryptionKey);

impl KeyEncryptionKey {
    /// Builds a KEK from fresh random bytes and wraps it under `master`.
    ///
    /// # Panics
    ///
    /// Panics if `random_bytes` are all zeros (RNG failure or bug).
    pub(crate) fn from_random_bytes_and_wrap(
        random_bytes: [u8; KEY_LENGTH],
        master: &impl MasterKeyProvider,
    ) -> (Self, WrappedKey) {
        let key = EncryptionKey::from_random_bytes(random_bytes);
        // Borrow the key bytes directly. `to_bytes()` would leave an extra,
        // never-wiped copy of the KEK on the stack.
        let wrapped = master.wrap_kek(&key.0);
        (Self(key), wrapped)
    }

    /// Restores a KEK from its wrapped form using `master`.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`MasterKeyProvider::unwrap_kek`], e.g. when
    /// `wrapped` was produced by a different master key.
    ///
    /// # Panics
    ///
    /// Panics if the unwrapped key is all zeros.
    pub fn restore(
        master: &impl MasterKeyProvider,
        wrapped: &WrappedKey,
    ) -> Result<Self, CryptoError> {
        let mut key_bytes = master.unwrap_kek(wrapped)?;
        assert!(
            key_bytes.iter().any(|&b| b != 0),
            "restored KEK is all zeros - decryption produced invalid key"
        );
        let key = EncryptionKey::from_bytes(&key_bytes);
        // `EncryptionKey` now owns its own copy (wiped on drop), so wipe this
        // temporary too rather than leaving KEK bytes on the stack.
        key_bytes.zeroize();
        Ok(Self(key))
    }

    /// Wraps a data-encryption key under this KEK.
    ///
    /// # Panics
    ///
    /// Panics if `dek_bytes` is all zeros.
    pub fn wrap_dek(&self, dek_bytes: &[u8; KEY_LENGTH]) -> WrappedKey {
        assert!(
            dek_bytes.iter().any(|&b| b != 0),
            "DEK bytes are all zeros - corrupted or uninitialized key material"
        );
        WrappedKey::new(&self.0, dek_bytes)
    }

    /// Unwraps a data-encryption key previously wrapped by this KEK.
    ///
    /// The returned array is a plain copy; the caller is responsible for wiping it.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`WrappedKey::unwrap_key`], e.g. when `wrapped`
    /// belongs to a different KEK.
    ///
    /// # Panics
    ///
    /// Panics if the unwrapped DEK is all zeros.
    pub fn unwrap_dek(&self, wrapped: &WrappedKey) -> Result<[u8; KEY_LENGTH], CryptoError> {
        let dek_bytes = wrapped.unwrap_key(&self.0)?;
        assert!(
            dek_bytes.iter().any(|&b| b != 0),
            "unwrapped DEK is all zeros - decryption produced invalid key"
        );
        Ok(dek_bytes)
    }

    /// Generates a fresh random KEK and wraps it under `master`.
    pub fn generate_and_wrap(master: &impl MasterKeyProvider) -> (Self, WrappedKey) {
        let random_bytes: [u8; KEY_LENGTH] = generate_random();
        Self::from_random_bytes_and_wrap(random_bytes, master)
    }
}
/// A data-encryption key (DEK): the key actually used with [`encrypt`] and [`decrypt`].
///
/// A DEK is stored wrapped by a [`KeyEncryptionKey`]. The key bytes are wiped
/// when the value is dropped (`ZeroizeOnDrop`).
#[derive(Zeroize, ZeroizeOnDrop)]
pub struct DataEncryptionKey(EncryptionKey);

impl DataEncryptionKey {
    /// Builds a DEK from fresh random bytes and wraps it under `kek`.
    ///
    /// # Panics
    ///
    /// Panics if `random_bytes` are all zeros (RNG failure or bug).
    pub(crate) fn from_random_bytes_and_wrap(
        random_bytes: [u8; KEY_LENGTH],
        kek: &KeyEncryptionKey,
    ) -> (Self, WrappedKey) {
        let encryption_key = EncryptionKey::from_random_bytes(random_bytes);
        // Borrow the key bytes directly. `to_bytes()` would leave an extra,
        // never-wiped copy of the DEK on the stack.
        let wrapped = kek.wrap_dek(&encryption_key.0);
        (Self(encryption_key), wrapped)
    }

    /// Restores a DEK from its wrapped form using `kek`.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`KeyEncryptionKey::unwrap_dek`], e.g. when
    /// `wrapped` was produced under a different KEK or was tampered with.
    ///
    /// # Panics
    ///
    /// Panics if the unwrapped key is all zeros.
    pub fn restore(kek: &KeyEncryptionKey, wrapped: &WrappedKey) -> Result<Self, CryptoError> {
        let mut key_bytes = kek.unwrap_dek(wrapped)?;
        assert!(
            key_bytes.iter().any(|&b| b != 0),
            "restored DEK is all zeros - decryption produced invalid key"
        );
        let encryption_key = EncryptionKey::from_bytes(&key_bytes);
        // `EncryptionKey` now owns its own copy (wiped on drop), so wipe this
        // temporary too rather than leaving DEK bytes on the stack.
        key_bytes.zeroize();
        Ok(Self(encryption_key))
    }

    /// Borrows the underlying key for use with [`encrypt`] / [`decrypt`].
    pub fn encryption_key(&self) -> &EncryptionKey {
        &self.0
    }

    /// Generates a fresh random DEK and wraps it under `kek`.
    pub fn generate_and_wrap(kek: &KeyEncryptionKey) -> (Self, WrappedKey) {
        let random_bytes: [u8; KEY_LENGTH] = generate_random();
        Self::from_random_bytes_and_wrap(random_bytes, kek)
    }

    /// Destroys this DEK and returns `SHA-256(key || shred_nonce)`.
    ///
    /// `self` is consumed and dropped, which wipes the key. The same key and
    /// nonce always produce the same digest. Presumably the digest serves as a
    /// record that this particular key was shredded; confirm with callers.
    pub fn shred(self, shred_nonce: &[u8; 32]) -> [u8; 32] {
        use sha2::{Digest, Sha256};
        let mut hasher = <Sha256 as Digest>::new();
        // Hash the key bytes in place. `to_bytes()` would make an extra copy
        // that outlives the wipe performed by `drop(self)`.
        hasher.update(&self.0.0);
        hasher.update(shred_nonce);
        let digest: [u8; 32] = hasher.finalize().into();
        drop(self);
        digest
    }
}
/// An AES-256-GCM cipher initialized once from an [`EncryptionKey`] and reused
/// across calls.
///
/// The free [`encrypt`] and [`decrypt`] rebuild the cipher on every call. This
/// type builds it once in [`CachedCipher::new`]. Unlike the free [`decrypt`],
/// [`CachedCipher::decrypt`] does not run the re-encryption round-trip property
/// check.
// NOTE(review): `Aes256Gcm` holds the expanded key schedule. Whether it is wiped
// on drop depends on the `aes`/`aes-gcm` crates' `zeroize` feature; verify.
pub struct CachedCipher {
    cipher: Aes256Gcm,
}

impl CachedCipher {
    /// Initializes the cipher from `key`.
    pub fn new(key: &EncryptionKey) -> Self {
        Self {
            cipher: Aes256Gcm::new_from_slice(&key.0).expect("KEY_LENGTH is always valid"),
        }
    }

    /// Encrypts `plaintext` under `nonce`. Output is `plaintext.len() + TAG_LENGTH`
    /// bytes, with the tag appended.
    ///
    /// # Panics
    ///
    /// Panics if `plaintext` exceeds `MAX_PLAINTEXT_LENGTH` bytes.
    pub fn encrypt(&self, nonce: &Nonce, plaintext: &[u8]) -> Ciphertext {
        assert!(
            plaintext.len() <= MAX_PLAINTEXT_LENGTH,
            "plaintext exceeds {} byte sanity limit, got {} bytes",
            MAX_PLAINTEXT_LENGTH,
            plaintext.len()
        );
        let nonce_array = nonce.0.into();
        let data = self
            .cipher
            .encrypt(&nonce_array, plaintext)
            .expect("AES-GCM encryption cannot fail with valid inputs");
        // Report both lengths, matching the free `encrypt`, so a failure is diagnosable.
        let expected_len = plaintext.len() + TAG_LENGTH;
        assert_eq!(
            data.len(),
            expected_len,
            "ciphertext length mismatch: expected {}, got {}",
            expected_len,
            data.len()
        );
        Ciphertext(data)
    }

    /// Decrypts and authenticates `ciphertext` under `nonce`.
    ///
    /// # Errors
    ///
    /// Returns [`CryptoError::DecryptionError`] if authentication fails (wrong
    /// key or nonce, or tampered data).
    pub fn decrypt(&self, nonce: &Nonce, ciphertext: &Ciphertext) -> Result<Vec<u8>, CryptoError> {
        let ciphertext_len = ciphertext.0.len();
        assert!(
            ciphertext_len >= TAG_LENGTH,
            "ciphertext too short: {ciphertext_len} bytes, need at least {TAG_LENGTH}"
        );
        let nonce_array = nonce.0.into();
        let plaintext = self
            .cipher
            .decrypt(&nonce_array, ciphertext.0.as_slice())
            .map_err(|_| CryptoError::DecryptionError)?;
        // Report both lengths, matching the free `decrypt`, so a failure is diagnosable.
        let expected_len = ciphertext_len - TAG_LENGTH;
        assert_eq!(
            plaintext.len(),
            expected_len,
            "plaintext length mismatch: expected {}, got {}",
            expected_len,
            plaintext.len()
        );
        Ok(plaintext)
    }
}
/// Upper bound, in bytes, on a single plaintext accepted by `encrypt` and
/// `CachedCipher::encrypt` (64 MiB).
///
/// This is a sanity limit enforced with `assert!`, not an AES-GCM limit. It is
/// used by both encrypt paths, so no `dead_code` allowance is needed.
const MAX_PLAINTEXT_LENGTH: usize = 64 * 1024 * 1024;
/// Encrypts `plaintext` with AES-256-GCM under `key` and `nonce`.
///
/// Returns the ciphertext with the authentication tag appended, so the output is
/// `plaintext.len() + TAG_LENGTH` bytes. Identical key, nonce and plaintext
/// always produce identical output.
///
/// # Panics
///
/// Panics if `plaintext` exceeds `MAX_PLAINTEXT_LENGTH` bytes.
pub fn encrypt(key: &EncryptionKey, nonce: &Nonce, plaintext: &[u8]) -> Ciphertext {
    let plaintext_len = plaintext.len();
    assert!(
        plaintext_len <= MAX_PLAINTEXT_LENGTH,
        "plaintext exceeds {} byte sanity limit, got {} bytes",
        MAX_PLAINTEXT_LENGTH,
        plaintext_len
    );

    let cipher = Aes256Gcm::new_from_slice(&key.0).expect("KEY_LENGTH is always valid");
    let gcm_nonce = nonce.0.into();
    let sealed = cipher
        .encrypt(&gcm_nonce, plaintext)
        .expect("AES-GCM encryption cannot fail with valid inputs");

    let expected_len = plaintext_len + TAG_LENGTH;
    assert_eq!(
        sealed.len(),
        expected_len,
        "ciphertext length mismatch: expected {}, got {}",
        expected_len,
        sealed.len()
    );
    Ciphertext(sealed)
}
/// Decrypts and authenticates `ciphertext` with AES-256-GCM under `key` and `nonce`.
///
/// # Errors
///
/// Returns [`CryptoError::DecryptionError`] if authentication fails (wrong key
/// or nonce, or tampered/truncated data).
///
/// # Panics
///
/// Panics if `ciphertext` is shorter than `TAG_LENGTH`, or if the decrypted
/// length does not equal `ciphertext.len() - TAG_LENGTH`.
pub fn decrypt(
    key: &EncryptionKey,
    nonce: &Nonce,
    ciphertext: &Ciphertext,
) -> Result<Vec<u8>, CryptoError> {
    let ciphertext_len = ciphertext.0.len();
    assert!(
        ciphertext_len >= TAG_LENGTH,
        "ciphertext too short: {ciphertext_len} bytes, need at least {TAG_LENGTH}"
    );

    let cipher = Aes256Gcm::new_from_slice(&key.0).expect("KEY_LENGTH is always valid");
    let gcm_nonce = nonce.0.into();
    let plaintext = match cipher.decrypt(&gcm_nonce, ciphertext.0.as_slice()) {
        Ok(opened) => opened,
        Err(_) => return Err(CryptoError::DecryptionError),
    };

    let expected_len = ciphertext_len - TAG_LENGTH;
    assert_eq!(
        plaintext.len(),
        expected_len,
        "plaintext length mismatch: expected {}, got {}",
        expected_len,
        plaintext.len()
    );

    // NOTE(review): if `always!` evaluates its condition in a given build, a
    // ciphertext longer than MAX_PLAINTEXT_LENGTH + TAG_LENGTH decrypts fine.
    // The re-encryption below then panics on `encrypt`'s size assertion. Confirm
    // this is acceptable for the builds where the property check is active.
    kimberlite_properties::always!(
        {
            let re_encrypted = encrypt(key, nonce, &plaintext);
            re_encrypted.to_bytes() == ciphertext.to_bytes()
        },
        "crypto.encrypt_decrypt_roundtrip",
        "re-encrypting decrypted plaintext must produce identical ciphertext"
    );
    Ok(plaintext)
}
/// Returns `N` bytes from the system random source (via `getrandom`).
///
/// # Panics
///
/// Panics if the random source fails; there is no safe fallback for key or
/// nonce material.
fn generate_random<const N: usize>() -> [u8; N] {
    let mut buf = [0u8; N];
    getrandom::fill(buf.as_mut_slice()).expect("CSPRNG failure");
    buf
}
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    #[test]
    fn encrypt_decrypt_roundtrip() {
        let key = EncryptionKey::generate();
        let nonce = Nonce::from_position(42);
        let plaintext = b"sensitive tenant data";
        let ciphertext = encrypt(&key, &nonce, plaintext);
        let decrypted = decrypt(&key, &nonce, &ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn encrypt_decrypt_empty_plaintext() {
        let key = EncryptionKey::generate();
        let nonce = Nonce::from_position(0);
        let plaintext = b"";
        let ciphertext = encrypt(&key, &nonce, plaintext);
        let decrypted = decrypt(&key, &nonce, &ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
        // An empty plaintext still carries the full authentication tag.
        assert_eq!(ciphertext.len(), TAG_LENGTH);
    }

    #[test]
    fn ciphertext_length_is_plaintext_plus_tag() {
        let key = EncryptionKey::generate();
        let nonce = Nonce::from_position(100);
        let plaintext = b"hello world";
        let ciphertext = encrypt(&key, &nonce, plaintext);
        assert_eq!(ciphertext.len(), plaintext.len() + TAG_LENGTH);
    }

    #[test]
    fn wrong_key_fails_decryption() {
        let key1 = EncryptionKey::generate();
        let key2 = EncryptionKey::generate();
        let nonce = Nonce::from_position(42);
        let plaintext = b"secret message";
        let ciphertext = encrypt(&key1, &nonce, plaintext);
        let result = decrypt(&key2, &nonce, &ciphertext);
        assert!(result.is_err());
    }

    #[test]
    fn wrong_nonce_fails_decryption() {
        let key = EncryptionKey::generate();
        let nonce1 = Nonce::from_position(42);
        let nonce2 = Nonce::from_position(43);
        let plaintext = b"secret message";
        let ciphertext = encrypt(&key, &nonce1, plaintext);
        let result = decrypt(&key, &nonce2, &ciphertext);
        assert!(result.is_err());
    }

    #[test]
    fn tampered_ciphertext_fails_decryption() {
        let key = EncryptionKey::generate();
        let nonce = Nonce::from_position(42);
        let plaintext = b"secret message";
        let ciphertext = encrypt(&key, &nonce, plaintext);
        let mut tampered_bytes = ciphertext.to_bytes().to_vec();
        // Flip a bit in the ciphertext body (first byte).
        tampered_bytes[0] ^= 0x01;
        let tampered = Ciphertext::from_bytes(tampered_bytes);
        let result = decrypt(&key, &nonce, &tampered);
        assert!(result.is_err());
    }

    #[test]
    fn tampered_tag_fails_decryption() {
        let key = EncryptionKey::generate();
        let nonce = Nonce::from_position(42);
        let plaintext = b"secret message";
        let ciphertext = encrypt(&key, &nonce, plaintext);
        let mut tampered_bytes = ciphertext.to_bytes().to_vec();
        let len = tampered_bytes.len();
        // Flip a bit in the authentication tag (last byte).
        tampered_bytes[len - 1] ^= 0x01;
        let tampered = Ciphertext::from_bytes(tampered_bytes);
        let result = decrypt(&key, &nonce, &tampered);
        assert!(result.is_err());
    }

    #[test]
    fn nonce_from_position_layout() {
        let nonce = Nonce::from_position(0x0102_0304_0506_0708);
        let bytes = nonce.to_bytes();
        assert_eq!(bytes[0], 0x08);
        assert_eq!(bytes[1], 0x07);
        assert_eq!(bytes[2], 0x06);
        assert_eq!(bytes[3], 0x05);
        assert_eq!(bytes[4], 0x04);
        assert_eq!(bytes[5], 0x03);
        assert_eq!(bytes[6], 0x02);
        assert_eq!(bytes[7], 0x01);
        assert_eq!(bytes[8], 0x00);
        assert_eq!(bytes[9], 0x00);
        assert_eq!(bytes[10], 0x00);
        assert_eq!(bytes[11], 0x00);
    }

    #[test]
    fn nonce_position_zero_is_valid() {
        let key = EncryptionKey::generate();
        let nonce = Nonce::from_position(0);
        let plaintext = b"first record";
        let ciphertext = encrypt(&key, &nonce, plaintext);
        let decrypted = decrypt(&key, &nonce, &ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn encryption_key_roundtrip() {
        let original = EncryptionKey::generate();
        let bytes = original.to_bytes();
        let restored = EncryptionKey::from_bytes(&bytes);
        let nonce = Nonce::from_position(1);
        let plaintext = b"test";
        let ct1 = encrypt(&original, &nonce, plaintext);
        let ct2 = encrypt(&restored, &nonce, plaintext);
        assert_eq!(ct1.to_bytes(), ct2.to_bytes());
    }

    #[test]
    fn ciphertext_roundtrip() {
        let key = EncryptionKey::generate();
        let nonce = Nonce::from_position(999);
        let plaintext = b"data to serialize";
        let ciphertext = encrypt(&key, &nonce, plaintext);
        let bytes = ciphertext.to_bytes().to_vec();
        let restored = Ciphertext::from_bytes(bytes);
        let decrypted = decrypt(&key, &nonce, &restored).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn different_positions_produce_different_ciphertexts() {
        let key = EncryptionKey::generate();
        let plaintext = b"same plaintext";
        let ct1 = encrypt(&key, &Nonce::from_position(1), plaintext);
        let ct2 = encrypt(&key, &Nonce::from_position(2), plaintext);
        assert_ne!(ct1.to_bytes(), ct2.to_bytes());
    }

    #[test]
    fn encryption_is_deterministic() {
        let key = EncryptionKey::generate();
        let nonce = Nonce::from_position(42);
        let plaintext = b"deterministic test";
        let ct1 = encrypt(&key, &nonce, plaintext);
        let ct2 = encrypt(&key, &nonce, plaintext);
        assert_eq!(ct1.to_bytes(), ct2.to_bytes());
    }

    #[test]
    fn wrap_unwrap_roundtrip() {
        let wrapping_key = EncryptionKey::generate();
        let original_key: [u8; KEY_LENGTH] = generate_random();
        let wrapped = WrappedKey::new(&wrapping_key, &original_key);
        let unwrapped = wrapped.unwrap_key(&wrapping_key).unwrap();
        assert_eq!(original_key, unwrapped);
    }

    #[test]
    fn wrapped_key_serialization_roundtrip() {
        let wrapping_key = EncryptionKey::generate();
        let original_key: [u8; KEY_LENGTH] = generate_random();
        let wrapped = WrappedKey::new(&wrapping_key, &original_key);
        let bytes = wrapped.to_bytes();
        let restored = WrappedKey::from_bytes(&bytes);
        let unwrapped = restored.unwrap_key(&wrapping_key).unwrap();
        assert_eq!(original_key, unwrapped);
    }

    #[test]
    fn wrapped_key_has_correct_length() {
        let wrapping_key = EncryptionKey::generate();
        let key_to_wrap: [u8; KEY_LENGTH] = generate_random();
        let wrapped = WrappedKey::new(&wrapping_key, &key_to_wrap);
        let bytes = wrapped.to_bytes();
        assert_eq!(bytes.len(), WRAPPED_KEY_LENGTH);
        assert_eq!(bytes.len(), NONCE_LENGTH + KEY_LENGTH + TAG_LENGTH);
    }

    #[test]
    fn wrong_wrapping_key_fails_unwrap() {
        let key1 = EncryptionKey::generate();
        let key2 = EncryptionKey::generate();
        let original: [u8; KEY_LENGTH] = generate_random();
        let wrapped = WrappedKey::new(&key1, &original);
        let result = wrapped.unwrap_key(&key2);
        assert!(result.is_err());
    }

    #[test]
    fn tampered_wrapped_key_fails_unwrap() {
        let wrapping_key = EncryptionKey::generate();
        let original: [u8; KEY_LENGTH] = generate_random();
        let wrapped = WrappedKey::new(&wrapping_key, &original);
        let mut bytes = wrapped.to_bytes();
        // First byte after the nonce is the start of the encrypted key.
        bytes[NONCE_LENGTH] ^= 0x01;
        let tampered = WrappedKey::from_bytes(&bytes);
        let result = tampered.unwrap_key(&wrapping_key);
        assert!(result.is_err());
    }

    #[test]
    fn different_keys_produce_different_wrapped_output() {
        let wrapping_key = EncryptionKey::generate();
        let key1: [u8; KEY_LENGTH] = generate_random();
        let key2: [u8; KEY_LENGTH] = generate_random();
        let wrapped1 = WrappedKey::new(&wrapping_key, &key1);
        let wrapped2 = WrappedKey::new(&wrapping_key, &key2);
        assert_ne!(wrapped1.to_bytes(), wrapped2.to_bytes());
    }

    #[test]
    fn same_key_wrapped_twice_differs_due_to_random_nonce() {
        let wrapping_key = EncryptionKey::generate();
        let key_to_wrap: [u8; KEY_LENGTH] = generate_random();
        let wrapped1 = WrappedKey::new(&wrapping_key, &key_to_wrap);
        let wrapped2 = WrappedKey::new(&wrapping_key, &key_to_wrap);
        assert_ne!(wrapped1.to_bytes(), wrapped2.to_bytes());
        let unwrapped1 = wrapped1.unwrap_key(&wrapping_key).unwrap();
        let unwrapped2 = wrapped2.unwrap_key(&wrapping_key).unwrap();
        assert_eq!(unwrapped1, unwrapped2);
        assert_eq!(unwrapped1, key_to_wrap);
    }

    #[test]
    fn master_key_generate_and_restore() {
        let master = InMemoryMasterKey::generate();
        let bytes = master.to_bytes();
        let restored = InMemoryMasterKey::from_bytes(&bytes);
        let kek_bytes: [u8; KEY_LENGTH] = generate_random();
        let wrapped1 = master.wrap_kek(&kek_bytes);
        let wrapped2 = restored.wrap_kek(&kek_bytes);
        let unwrapped1 = master.unwrap_kek(&wrapped1).unwrap();
        let unwrapped2 = restored.unwrap_kek(&wrapped2).unwrap();
        assert_eq!(unwrapped1, kek_bytes);
        assert_eq!(unwrapped2, kek_bytes);
    }

    #[test]
    fn kek_generate_and_restore() {
        let master = InMemoryMasterKey::generate();
        let (kek, wrapped_kek) = KeyEncryptionKey::generate_and_wrap(&master);
        let restored_kek = KeyEncryptionKey::restore(&master, &wrapped_kek).unwrap();
        let dek_bytes: [u8; KEY_LENGTH] = generate_random();
        let wrapped1 = kek.wrap_dek(&dek_bytes);
        let wrapped2 = restored_kek.wrap_dek(&dek_bytes);
        // Cross-unwrap: each KEK must open what the other one wrapped.
        let unwrapped1 = kek.unwrap_dek(&wrapped2).unwrap();
        let unwrapped2 = restored_kek.unwrap_dek(&wrapped1).unwrap();
        assert_eq!(unwrapped1, dek_bytes);
        assert_eq!(unwrapped2, dek_bytes);
    }

    #[test]
    fn dek_generate_and_restore() {
        let master = InMemoryMasterKey::generate();
        let (kek, _) = KeyEncryptionKey::generate_and_wrap(&master);
        let (dek, wrapped_dek) = DataEncryptionKey::generate_and_wrap(&kek);
        let restored_dek = DataEncryptionKey::restore(&kek, &wrapped_dek).unwrap();
        let nonce = Nonce::from_position(42);
        let plaintext = b"secret tenant data";
        let ciphertext = encrypt(dek.encryption_key(), &nonce, plaintext);
        let decrypted = decrypt(restored_dek.encryption_key(), &nonce, &ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn full_key_hierarchy_roundtrip() {
        let master = InMemoryMasterKey::generate();
        let (kek_acme, wrapped_kek_acme) = KeyEncryptionKey::generate_and_wrap(&master);
        let (dek_seg0, wrapped_dek_seg0) = DataEncryptionKey::generate_and_wrap(&kek_acme);
        let nonce = Nonce::from_position(0);
        let plaintext = b"acme's sensitive record";
        let ciphertext = encrypt(dek_seg0.encryption_key(), &nonce, plaintext);
        let kek = KeyEncryptionKey::restore(&master, &wrapped_kek_acme).unwrap();
        let dek = DataEncryptionKey::restore(&kek, &wrapped_dek_seg0).unwrap();
        let decrypted = decrypt(dek.encryption_key(), &nonce, &ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn wrong_master_key_fails_kek_restore() {
        let master1 = InMemoryMasterKey::generate();
        let master2 = InMemoryMasterKey::generate();
        let (_, wrapped_kek) = KeyEncryptionKey::generate_and_wrap(&master1);
        let result = KeyEncryptionKey::restore(&master2, &wrapped_kek);
        assert!(result.is_err());
    }

    #[test]
    fn wrong_kek_fails_dek_restore() {
        let master = InMemoryMasterKey::generate();
        let (kek1, _) = KeyEncryptionKey::generate_and_wrap(&master);
        let (kek2, _) = KeyEncryptionKey::generate_and_wrap(&master);
        let (_, wrapped_dek) = DataEncryptionKey::generate_and_wrap(&kek1);
        let result = DataEncryptionKey::restore(&kek2, &wrapped_dek);
        assert!(result.is_err());
    }

    #[test]
    fn tenant_isolation_via_kek() {
        let master = InMemoryMasterKey::generate();
        let (kek_tenant_a, _) = KeyEncryptionKey::generate_and_wrap(&master);
        let (kek_tenant_b, _) = KeyEncryptionKey::generate_and_wrap(&master);
        let (dek_a, wrapped_dek_a) = DataEncryptionKey::generate_and_wrap(&kek_tenant_a);
        let nonce = Nonce::from_position(0);
        let _ciphertext_a = encrypt(dek_a.encryption_key(), &nonce, b"tenant A secret");
        let result = DataEncryptionKey::restore(&kek_tenant_b, &wrapped_dek_a);
        assert!(result.is_err());
    }

    #[test]
    fn wrapped_kek_serialization_roundtrip() {
        let master = InMemoryMasterKey::generate();
        let (_, wrapped_kek) = KeyEncryptionKey::generate_and_wrap(&master);
        let bytes = wrapped_kek.to_bytes();
        let restored_wrapped = WrappedKey::from_bytes(&bytes);
        let kek = KeyEncryptionKey::restore(&master, &restored_wrapped).unwrap();
        let dek_bytes: [u8; KEY_LENGTH] = generate_random();
        let dek_wrapped = kek.wrap_dek(&dek_bytes);
        let unwrapped = kek.unwrap_dek(&dek_wrapped).unwrap();
        assert_eq!(unwrapped, dek_bytes);
    }

    #[test]
    fn dek_encryption_key_reference() {
        let master = InMemoryMasterKey::generate();
        let (kek, _) = KeyEncryptionKey::generate_and_wrap(&master);
        let (dek, _) = DataEncryptionKey::generate_and_wrap(&kek);
        let key_ref = dek.encryption_key();
        let nonce = Nonce::from_position(1);
        let ciphertext = encrypt(key_ref, &nonce, b"test");
        let plaintext = decrypt(key_ref, &nonce, &ciphertext).unwrap();
        assert_eq!(plaintext, b"test");
    }

    proptest! {
        #[test]
        fn prop_encrypt_decrypt_roundtrip(plaintext in prop::collection::vec(any::<u8>(), 0..10000)) {
            let key = EncryptionKey::generate();
            let nonce = Nonce::from_position(42);
            let ciphertext = encrypt(&key, &nonce, &plaintext);
            let decrypted = decrypt(&key, &nonce, &ciphertext)
                .expect("decryption should succeed for valid ciphertext");
            prop_assert_eq!(decrypted, plaintext);
        }

        #[test]
        fn prop_different_nonces_different_ciphertext(
            positions in (any::<u64>(), any::<u64>()).prop_filter("positions must differ", |(p1, p2)| p1 != p2),
            plaintext in prop::collection::vec(any::<u8>(), 1..1000),
        ) {
            let (position1, position2) = positions;
            let key = EncryptionKey::generate();
            let nonce1 = Nonce::from_position(position1);
            let nonce2 = Nonce::from_position(position2);
            let ct1 = encrypt(&key, &nonce1, &plaintext);
            let ct2 = encrypt(&key, &nonce2, &plaintext);
            prop_assert_ne!(ct1.to_bytes(), ct2.to_bytes());
        }

        #[test]
        fn prop_ciphertext_length(plaintext in prop::collection::vec(any::<u8>(), 0..10000)) {
            let key = EncryptionKey::generate();
            let nonce = Nonce::from_position(100);
            let ciphertext = encrypt(&key, &nonce, &plaintext);
            prop_assert_eq!(ciphertext.to_bytes().len(), plaintext.len() + TAG_LENGTH);
        }

        #[test]
        fn prop_wrong_key_fails(plaintext in prop::collection::vec(any::<u8>(), 1..1000)) {
            let key1 = EncryptionKey::generate();
            let key2 = EncryptionKey::generate();
            let nonce = Nonce::from_position(42);
            let ciphertext = encrypt(&key1, &nonce, &plaintext);
            let result = decrypt(&key2, &nonce, &ciphertext);
            prop_assert!(result.is_err(), "decryption with wrong key must fail");
        }

        #[test]
        fn prop_wrong_nonce_fails(
            positions in (any::<u64>(), any::<u64>()).prop_filter("positions must differ", |(p1, p2)| p1 != p2),
            plaintext in prop::collection::vec(any::<u8>(), 1..1000),
        ) {
            let (position1, position2) = positions;
            let key = EncryptionKey::generate();
            let nonce1 = Nonce::from_position(position1);
            let nonce2 = Nonce::from_position(position2);
            let ciphertext = encrypt(&key, &nonce1, &plaintext);
            let result = decrypt(&key, &nonce2, &ciphertext);
            prop_assert!(result.is_err(), "decryption with wrong nonce must fail");
        }

        #[test]
        fn prop_tampered_ciphertext_fails(
            plaintext in prop::collection::vec(any::<u8>(), 1..1000),
            bit_position in 0usize..8000,
        ) {
            let key = EncryptionKey::generate();
            let nonce = Nonce::from_position(42);
            let ciphertext = encrypt(&key, &nonce, &plaintext);
            let mut tampered = ciphertext.to_bytes().to_vec();
            // Map the sampled bit onto one that exists, so every case flips a
            // real bit instead of silently skipping. The length is never below
            // TAG_LENGTH, so the modulus is non-zero.
            let bit_position = bit_position % (tampered.len() * 8);
            let byte_idx = bit_position / 8;
            let bit_idx = bit_position % 8;
            tampered[byte_idx] ^= 1 << bit_idx;
            let tampered_ct = Ciphertext::from_bytes(tampered);
            let result = decrypt(&key, &nonce, &tampered_ct);
            prop_assert!(result.is_err(), "tampered ciphertext must fail authentication");
        }

        #[test]
        fn prop_wrap_unwrap_roundtrip(key_bytes in prop::array::uniform32(any::<u8>())) {
            // `WrappedKey::new` rejects an all-zero key by panicking; skip that input.
            prop_assume!(key_bytes.iter().any(|&b| b != 0));
            let wrapping_key = EncryptionKey::generate();
            let wrapped = WrappedKey::new(&wrapping_key, &key_bytes);
            let unwrapped = wrapped.unwrap_key(&wrapping_key)
                .expect("unwrapping with correct key should succeed");
            prop_assert_eq!(unwrapped, key_bytes);
        }

        #[test]
        fn prop_encryption_deterministic(
            position in any::<u64>(),
            plaintext in prop::collection::vec(any::<u8>(), 0..1000),
        ) {
            let key = EncryptionKey::generate();
            let nonce = Nonce::from_position(position);
            let ct1 = encrypt(&key, &nonce, &plaintext);
            let ct2 = encrypt(&key, &nonce, &plaintext);
            prop_assert_eq!(ct1.to_bytes(), ct2.to_bytes());
        }

        #[test]
        fn prop_key_serialization_roundtrip(plaintext in prop::collection::vec(any::<u8>(), 1..1000)) {
            let original = EncryptionKey::generate();
            let bytes = original.to_bytes();
            let restored = EncryptionKey::from_bytes(&bytes);
            let nonce = Nonce::from_position(1);
            let ct1 = encrypt(&original, &nonce, &plaintext);
            let ct2 = encrypt(&restored, &nonce, &plaintext);
            prop_assert_eq!(ct1.to_bytes(), ct2.to_bytes());
            let decrypted1 = decrypt(&original, &nonce, &ct1).unwrap();
            let decrypted2 = decrypt(&restored, &nonce, &ct2).unwrap();
            prop_assert_eq!(&decrypted1[..], &plaintext[..]);
            prop_assert_eq!(&decrypted2[..], &plaintext[..]);
        }

        #[test]
        fn prop_nonce_position_injective(
            positions in (any::<u64>(), any::<u64>()).prop_filter("positions must differ", |(p1, p2)| p1 != p2),
        ) {
            let (pos1, pos2) = positions;
            let nonce1 = Nonce::from_position(pos1);
            let nonce2 = Nonce::from_position(pos2);
            prop_assert_ne!(nonce1.to_bytes(), nonce2.to_bytes());
        }
    }

    #[test]
    fn truncated_ciphertext_fails() {
        let key = EncryptionKey::generate();
        let nonce = Nonce::from_position(42);
        let plaintext = b"hello world";
        let ciphertext = encrypt(&key, &nonce, plaintext);
        let ct_bytes = ciphertext.to_bytes();
        assert!(
            ct_bytes.len() > TAG_LENGTH,
            "test requires a non-empty plaintext so truncation removes data"
        );
        // Keeping only TAG_LENGTH bytes is a valid `Ciphertext` length, but the
        // body/tag split no longer matches, so authentication must fail.
        let truncated = Ciphertext::from_bytes(ct_bytes[..TAG_LENGTH].to_vec());
        let result = decrypt(&key, &nonce, &truncated);
        assert!(result.is_err(), "truncated ciphertext must fail");
    }

    #[test]
    fn corrupted_tag_fails() {
        let key = EncryptionKey::generate();
        let nonce = Nonce::from_position(42);
        let plaintext = b"authenticated encryption test";
        let ciphertext = encrypt(&key, &nonce, plaintext);
        let mut ct_bytes = ciphertext.to_bytes().to_vec();
        let last_idx = ct_bytes.len() - 1;
        ct_bytes[last_idx] = ct_bytes[last_idx].wrapping_add(1);
        let corrupted = Ciphertext::from_bytes(ct_bytes);
        let result = decrypt(&key, &nonce, &corrupted);
        assert!(
            result.is_err(),
            "corrupted tag must cause authentication failure"
        );
    }

    #[test]
    fn maximum_position_nonce() {
        let key = EncryptionKey::generate();
        let max_position = u64::MAX;
        let nonce = Nonce::from_position(max_position);
        let plaintext = b"test at max position";
        let ciphertext = encrypt(&key, &nonce, plaintext);
        let decrypted = decrypt(&key, &nonce, &ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn large_plaintext_encryption() {
        let key = EncryptionKey::generate();
        let nonce = Nonce::from_position(1);
        let plaintext = vec![0xAB; 1_024 * 1_024];
        let ciphertext = encrypt(&key, &nonce, &plaintext);
        let decrypted = decrypt(&key, &nonce, &ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
        assert_eq!(ciphertext.to_bytes().len(), plaintext.len() + TAG_LENGTH);
    }

    #[test]
    fn wrapped_key_wrong_length_from_bytes() {
        // A fixed-size array cannot have the wrong length. Instead, feed a
        // correctly sized buffer of garbage and expect unwrap to fail. The last
        // byte is non-zero so the all-zeros guard in `from_bytes` does not fire.
        let mut corrupted = [0u8; WRAPPED_KEY_LENGTH];
        corrupted[WRAPPED_KEY_LENGTH - 1] = 0xFF;
        let wrapped = WrappedKey::from_bytes(&corrupted);
        let wrapping_key = EncryptionKey::generate();
        let result = wrapped.unwrap_key(&wrapping_key);
        assert!(result.is_err(), "corrupted wrapped key should fail unwrap");
    }

    #[test]
    #[should_panic(expected = "EncryptionKey bytes are all zeros")]
    fn encryption_key_all_zeros_panics() {
        let _key = EncryptionKey::from_bytes(&[0u8; KEY_LENGTH]);
    }

    #[test]
    #[should_panic(expected = "EncryptionKey random bytes are all zeros")]
    fn encryption_key_from_random_bytes_all_zeros_panics() {
        let key_bytes = [0u8; KEY_LENGTH];
        let key = EncryptionKey::from_random_bytes(key_bytes);
        drop(key);
    }

    #[test]
    fn ciphertext_serialization_preserves_authentication() {
        let key = EncryptionKey::generate();
        let nonce = Nonce::from_position(123);
        let plaintext = b"serialization test";
        let ciphertext = encrypt(&key, &nonce, plaintext);
        let bytes = ciphertext.to_bytes().to_vec();
        let restored = Ciphertext::from_bytes(bytes);
        let decrypted = decrypt(&key, &nonce, &restored).unwrap();
        assert_eq!(decrypted, plaintext);
        let mut tampered_bytes = restored.to_bytes().to_vec();
        tampered_bytes[0] ^= 0x01;
        let tampered = Ciphertext::from_bytes(tampered_bytes);
        let result = decrypt(&key, &nonce, &tampered);
        assert!(
            result.is_err(),
            "tampered deserialized ciphertext must fail"
        );
    }

    #[test]
    fn multiple_wrapped_keys_independent() {
        let wrapping_key = EncryptionKey::generate();
        let key1: [u8; KEY_LENGTH] = generate_random();
        let key2: [u8; KEY_LENGTH] = generate_random();
        let wrapped1 = WrappedKey::new(&wrapping_key, &key1);
        let wrapped2 = WrappedKey::new(&wrapping_key, &key2);
        assert_eq!(wrapped1.unwrap_key(&wrapping_key).unwrap(), key1);
        assert_eq!(wrapped2.unwrap_key(&wrapping_key).unwrap(), key2);
        assert_ne!(wrapped1.to_bytes(), wrapped2.to_bytes());
    }

    #[test]
    fn nonce_reserves_upper_bytes() {
        for position in [0u64, 1, 42, u64::MAX / 2, u64::MAX] {
            let nonce = Nonce::from_position(position);
            let bytes = nonce.to_bytes();
            assert_eq!(bytes[8], 0, "byte 8 must be reserved (zero)");
            assert_eq!(bytes[9], 0, "byte 9 must be reserved (zero)");
            assert_eq!(bytes[10], 0, "byte 10 must be reserved (zero)");
            assert_eq!(bytes[11], 0, "byte 11 must be reserved (zero)");
        }
    }

    #[test]
    fn kek_dek_hierarchy_isolation() {
        let master = InMemoryMasterKey::generate();
        let (kek1, _) = KeyEncryptionKey::generate_and_wrap(&master);
        let (kek2, _) = KeyEncryptionKey::generate_and_wrap(&master);
        let (_dek1, wrapped_dek1) = DataEncryptionKey::generate_and_wrap(&kek1);
        let (_dek2, wrapped_dek2) = DataEncryptionKey::generate_and_wrap(&kek2);
        assert!(DataEncryptionKey::restore(&kek1, &wrapped_dek1).is_ok());
        assert!(DataEncryptionKey::restore(&kek1, &wrapped_dek2).is_err());
        assert!(DataEncryptionKey::restore(&kek2, &wrapped_dek2).is_ok());
        assert!(DataEncryptionKey::restore(&kek2, &wrapped_dek1).is_err());
    }

    #[test]
    fn encryption_key_zeroize_on_drop() {
        // Reading memory after drop is undefined behaviour. Instead, exercise the
        // same `Zeroize` impl that `ZeroizeOnDrop` invokes, and check that it
        // wipes the key in place.
        let mut key = EncryptionKey::generate();
        assert_ne!(key.to_bytes(), [0u8; KEY_LENGTH]);
        zeroize::Zeroize::zeroize(&mut key);
        assert_eq!(key.to_bytes(), [0u8; KEY_LENGTH]);
    }

    #[test]
    fn dek_shred_is_deterministic_in_key_and_nonce() {
        let master = InMemoryMasterKey::generate();
        let (kek, _) = KeyEncryptionKey::generate_and_wrap(&master);
        let (dek1, wrapped) = DataEncryptionKey::generate_and_wrap(&kek);
        let dek2 = DataEncryptionKey::restore(&kek, &wrapped).unwrap();
        let nonce = [0xAB; 32];
        let d1 = dek1.shred(&nonce);
        let d2 = dek2.shred(&nonce);
        assert_eq!(d1, d2, "same key + same nonce → same shred digest");
    }

    #[test]
    fn dek_shred_differs_on_different_keys() {
        let master = InMemoryMasterKey::generate();
        let (kek, _) = KeyEncryptionKey::generate_and_wrap(&master);
        let (dek_a, _) = DataEncryptionKey::generate_and_wrap(&kek);
        let (dek_b, _) = DataEncryptionKey::generate_and_wrap(&kek);
        let nonce = [0xCD; 32];
        let da = dek_a.shred(&nonce);
        let db = dek_b.shred(&nonce);
        assert_ne!(da, db, "different keys must produce different digests");
    }

    #[test]
    fn dek_shred_differs_on_different_nonces() {
        let master = InMemoryMasterKey::generate();
        let (kek, _) = KeyEncryptionKey::generate_and_wrap(&master);
        let (dek_a, wrapped_a) = DataEncryptionKey::generate_and_wrap(&kek);
        let dek_b = DataEncryptionKey::restore(&kek, &wrapped_a).unwrap();
        let d1 = dek_a.shred(&[1u8; 32]);
        let d2 = dek_b.shred(&[2u8; 32]);
        assert_ne!(d1, d2, "same key + different nonces → different digests");
    }

    #[test]
    fn dek_shred_digest_differs_from_key_bytes() {
        let master = InMemoryMasterKey::generate();
        let (kek, _) = KeyEncryptionKey::generate_and_wrap(&master);
        let (dek, wrapped) = DataEncryptionKey::generate_and_wrap(&kek);
        let key_bytes = dek.encryption_key().to_bytes();
        let digest = DataEncryptionKey::restore(&kek, &wrapped).unwrap().shred(&[0u8; 32]);
        assert_ne!(digest, key_bytes, "digest must not leak key bytes");
    }
}