use zeroize::Zeroizing;
use cachekit_core::ZeroKnowledgeEncryptor;
use crate::error::CachekitError;
/// Version tag pushed as the first byte of every AAD buffer built by
/// `EncryptionLayer::build_aad`.
const AAD_VERSION: u8 = 0x03;
/// Per-tenant encryption wrapper around `ZeroKnowledgeEncryptor`.
///
/// Built by `EncryptionLayer::new`, which derives a tenant-specific key from a
/// master key and a tenant id. `encrypt` / `decrypt` then pass that key, plus
/// AAD assembled by `build_aad`, to the encryptor.
pub struct EncryptionLayer {
/// Encryptor whose `encrypt_aes_gcm` / `decrypt_aes_gcm` methods do the actual work.
encryptor: ZeroKnowledgeEncryptor,
/// The `encryption_key` returned by `derive_tenant_keys`, held in `zeroize::Zeroizing`
/// (per the `zeroize` crate docs, the wrapped value is zeroed when dropped).
derived_key: Zeroizing<[u8; 32]>,
/// Tenant identifier; `new` requires it to be 1..=255 bytes and `build_aad`
/// embeds it in every AAD buffer.
tenant_id: String,
}
impl EncryptionLayer {
    /// Creates an encryption layer bound to a single tenant.
    ///
    /// The arguments are validated first. A tenant key set is then derived from
    /// `master_key_bytes` and `tenant_id` through `derive_tenant_keys`, and its
    /// `encryption_key` is stored (wrapped in `Zeroizing`) for later calls.
    ///
    /// # Errors
    ///
    /// Returns `CachekitError::Encryption` when:
    /// - `master_key_bytes` is shorter than 32 bytes,
    /// - `tenant_id` is empty or longer than 255 bytes,
    /// - key derivation fails, or
    /// - the encryptor cannot be constructed.
    pub fn new(master_key_bytes: &[u8], tenant_id: &str) -> Result<Self, CachekitError> {
        Self::validate_inputs(master_key_bytes, tenant_id)?;

        // Derivation runs before encryptor construction so that a derivation
        // failure is the one reported when both would fail.
        let keys = cachekit_core::encryption::key_derivation::derive_tenant_keys(
            master_key_bytes,
            tenant_id,
        )
        .map_err(|e| CachekitError::Encryption(format!("key derivation failed: {e}")))?;

        let encryptor = ZeroKnowledgeEncryptor::new()
            .map_err(|e| CachekitError::Encryption(format!("encryptor init failed: {e}")))?;

        let derived_key = Zeroizing::new(keys.encryption_key);
        Ok(Self {
            encryptor,
            derived_key,
            tenant_id: tenant_id.to_owned(),
        })
    }

    /// Checks the constructor arguments, reporting the first violated rule.
    ///
    /// Rules are evaluated in a fixed order: master key length, then empty
    /// tenant id, then over-long tenant id.
    fn validate_inputs(master_key_bytes: &[u8], tenant_id: &str) -> Result<(), CachekitError> {
        let problem = if master_key_bytes.len() < 32 {
            format!(
                "master key must be at least 32 bytes; got {}",
                master_key_bytes.len()
            )
        } else if tenant_id.is_empty() {
            "tenant_id must not be empty".to_owned()
        } else if tenant_id.len() > 255 {
            format!(
                "tenant_id must be at most 255 bytes; got {}",
                tenant_id.len()
            )
        } else {
            return Ok(());
        };
        Err(CachekitError::Encryption(problem))
    }

    /// Encrypts `plaintext` with the tenant key.
    ///
    /// The AAD passed to the encryptor is `build_aad(cache_key, false)`, so it
    /// covers the tenant id and `cache_key`.
    ///
    /// # Errors
    ///
    /// Returns `CachekitError::Encryption` if `encrypt_aes_gcm` fails.
    pub fn encrypt(&self, plaintext: &[u8], cache_key: &str) -> Result<Vec<u8>, CachekitError> {
        let aad = self.build_aad(cache_key, false);
        let key = &*self.derived_key;
        let outcome = self.encryptor.encrypt_aes_gcm(plaintext, key, &aad);
        outcome.map_err(|e| CachekitError::Encryption(format!("encrypt failed: {e}")))
    }

    /// Decrypts `ciphertext` with the tenant key.
    ///
    /// The AAD is rebuilt exactly as in `encrypt` (`build_aad(cache_key, false)`),
    /// so callers must supply the `cache_key` that was used for encryption.
    ///
    /// # Errors
    ///
    /// Returns `CachekitError::Encryption` if `decrypt_aes_gcm` fails.
    pub fn decrypt(&self, ciphertext: &[u8], cache_key: &str) -> Result<Vec<u8>, CachekitError> {
        let aad = self.build_aad(cache_key, false);
        let key = &*self.derived_key;
        let outcome = self.encryptor.decrypt_aes_gcm(ciphertext, key, &aad);
        outcome.map_err(|e| CachekitError::Encryption(format!("decrypt failed: {e}")))
    }

    /// Returns the tenant id this layer was created for.
    pub fn tenant_id(&self) -> &str {
        self.tenant_id.as_str()
    }

    /// Builds the additional authenticated data for `cache_key`.
    ///
    /// Layout: one `AAD_VERSION` byte followed by four fields, each written as
    /// a 4-byte big-endian length and then the raw bytes, in this order:
    ///
    /// 1. tenant id
    /// 2. `cache_key`
    /// 3. the format tag `msgpack`
    /// 4. `True` or `False`, depending on `compressed`
    pub fn build_aad(&self, cache_key: &str, compressed: bool) -> Vec<u8> {
        const FORMAT_TAG: &[u8] = b"msgpack";
        const LEN_PREFIX_BYTES: usize = 4;

        let compressed_tag: &[u8] = if compressed { b"True" } else { b"False" };

        // Field order is part of the AAD format; do not reorder.
        let fields: [&[u8]; 4] = [
            self.tenant_id.as_bytes(),
            cache_key.as_bytes(),
            FORMAT_TAG,
            compressed_tag,
        ];

        let payload_len: usize = fields.iter().map(|field| field.len()).sum();
        let mut aad = Vec::with_capacity(1 + LEN_PREFIX_BYTES * fields.len() + payload_len);

        aad.push(AAD_VERSION);
        for field in fields {
            aad.extend_from_slice(&len_u32(field.len()).to_be_bytes());
            aad.extend_from_slice(field);
        }
        aad
    }
}
/// Converts a byte length into the `u32` used for AAD length prefixes.
///
/// The conversion is checked: a length that does not fit in a `u32` saturates
/// to `u32::MAX` rather than wrapping. Because no `as` cast is involved, no
/// truncation-lint suppression is needed on this function.
fn len_u32(len: usize) -> u32 {
    u32::try_from(len).unwrap_or(u32::MAX)
}
/// Debug formatting that never prints key material: `tenant_id` is shown
/// as-is, `derived_key` is replaced by the literal `[REDACTED]`, and the
/// encryptor is left out.
impl std::fmt::Debug for EncryptionLayer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("EncryptionLayer");
        builder.field("tenant_id", &self.tenant_id);
        // A fixed placeholder stands in for the key bytes.
        builder.field("derived_key", &"[REDACTED]");
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Exactly 32 bytes: the shortest master key `EncryptionLayer::new` accepts.
    const TEST_MASTER_KEY: &[u8] = b"test_master_key_32_bytes_long!!!";
    const TEST_TENANT: &str = "test-tenant";

    /// Data encrypted under a cache key decrypts back to the same bytes.
    #[test]
    fn roundtrip_encrypt_decrypt() {
        let layer = EncryptionLayer::new(TEST_MASTER_KEY, TEST_TENANT).unwrap();
        let plaintext = b"hello, zero-knowledge world";
        let ciphertext = layer.encrypt(plaintext, "my:key").unwrap();
        let decrypted = layer.decrypt(&ciphertext, "my:key").unwrap();
        assert_eq!(decrypted, plaintext);
    }

    /// The cache key is part of the AAD, so a different key must not decrypt.
    #[test]
    fn wrong_cache_key_fails_decryption() {
        let layer = EncryptionLayer::new(TEST_MASTER_KEY, TEST_TENANT).unwrap();
        let ciphertext = layer.encrypt(b"secret", "key:a").unwrap();
        let result = layer.decrypt(&ciphertext, "key:b");
        assert!(result.is_err(), "decryption with wrong cache key must fail");
    }

    /// Two tenants sharing a master key cannot read each other's data.
    #[test]
    fn different_tenants_produce_different_ciphertext() {
        let layer_a = EncryptionLayer::new(TEST_MASTER_KEY, "tenant-a").unwrap();
        let layer_b = EncryptionLayer::new(TEST_MASTER_KEY, "tenant-b").unwrap();
        let ct_a = layer_a.encrypt(b"same data", "same:key").unwrap();
        let ct_b = layer_b.encrypt(b"same data", "same:key").unwrap();
        assert_ne!(ct_a, ct_b);
        assert!(layer_b.decrypt(&ct_a, "same:key").is_err());
    }

    /// Master keys shorter than 32 bytes are rejected with a descriptive error.
    #[test]
    fn master_key_too_short() {
        let result = EncryptionLayer::new(b"short", "tenant");
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("at least 32 bytes"), "got: {msg}");
    }

    /// An empty tenant id is rejected before any key derivation happens.
    #[test]
    fn empty_tenant_id_is_rejected() {
        let result = EncryptionLayer::new(TEST_MASTER_KEY, "");
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("must not be empty"), "got: {msg}");
    }

    /// Tenant ids longer than 255 bytes are rejected.
    #[test]
    fn oversized_tenant_id_is_rejected() {
        let tenant = "t".repeat(256);
        let result = EncryptionLayer::new(TEST_MASTER_KEY, &tenant);
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("at most 255 bytes"), "got: {msg}");
    }

    /// Walks the AAD field by field: version byte, then four length-prefixed
    /// fields (tenant id, cache key, format tag, compressed flag), and nothing after.
    #[test]
    fn aad_v03_format() {
        let layer = EncryptionLayer::new(TEST_MASTER_KEY, TEST_TENANT).unwrap();
        let aad = layer.build_aad("user:42", false);
        assert_eq!(aad[0], 0x03);

        let tenant_len = u32::from_be_bytes(aad[1..5].try_into().unwrap()) as usize;
        assert_eq!(tenant_len, TEST_TENANT.len());
        assert_eq!(&aad[5..5 + tenant_len], TEST_TENANT.as_bytes());

        let offset = 5 + tenant_len;
        let key_len = u32::from_be_bytes(aad[offset..offset + 4].try_into().unwrap()) as usize;
        assert_eq!(key_len, 7);
        assert_eq!(&aad[offset + 4..offset + 4 + key_len], b"user:42");

        let offset = offset + 4 + key_len;
        let fmt_len = u32::from_be_bytes(aad[offset..offset + 4].try_into().unwrap()) as usize;
        assert_eq!(fmt_len, 7);
        assert_eq!(&aad[offset + 4..offset + 4 + fmt_len], b"msgpack");

        let offset = offset + 4 + fmt_len;
        let comp_len = u32::from_be_bytes(aad[offset..offset + 4].try_into().unwrap()) as usize;
        assert_eq!(comp_len, 5);
        assert_eq!(&aad[offset + 4..offset + 4 + comp_len], b"False");

        // The compressed flag is the last field; no trailing bytes may follow.
        assert_eq!(aad.len(), offset + 4 + comp_len);
    }

    /// The compressed flag changes the AAD and is its final field.
    #[test]
    fn aad_compressed_flag() {
        let layer = EncryptionLayer::new(TEST_MASTER_KEY, TEST_TENANT).unwrap();
        let aad_false = layer.build_aad("k", false);
        let aad_true = layer.build_aad("k", true);
        assert_ne!(aad_false, aad_true);
        assert!(aad_true.ends_with(b"True"));
        assert!(aad_false.ends_with(b"False"));
    }

    /// `Debug` output shows the redaction marker and no master-key text.
    #[test]
    fn debug_redacts_key() {
        let layer = EncryptionLayer::new(TEST_MASTER_KEY, TEST_TENANT).unwrap();
        let debug = format!("{layer:?}");
        assert!(debug.contains("[REDACTED]"));
        assert!(!debug.contains("test_master_key"));
    }
}