1use zeroize::Zeroizing;
17
18use cachekit_core::ZeroKnowledgeEncryptor;
19
20use crate::error::CachekitError;
21
/// Version tag written as the first byte of every AAD produced by `build_aad`.
const AAD_VERSION: u8 = 3;
24
/// Per-tenant encryption layer: holds a tenant-derived key and encrypts/decrypts
/// cache payloads with an AAD that binds the tenant id and cache key.
pub struct EncryptionLayer {
    /// `cachekit_core` encryptor used for the `encrypt_aes_gcm` / `decrypt_aes_gcm` calls.
    encryptor: ZeroKnowledgeEncryptor,
    /// 32-byte tenant encryption key derived from the master key in `new`;
    /// held in `Zeroizing` so the bytes are cleared when the value is dropped.
    derived_key: Zeroizing<[u8; 32]>,
    /// Tenant identifier (1..=255 bytes, validated in `new`); included in every AAD.
    tenant_id: String,
}
37
38impl EncryptionLayer {
39 pub fn new(master_key_bytes: &[u8], tenant_id: &str) -> Result<Self, CachekitError> {
50 if master_key_bytes.len() < 32 {
51 return Err(CachekitError::Encryption(format!(
52 "master key must be at least 32 bytes; got {}",
53 master_key_bytes.len()
54 )));
55 }
56 if tenant_id.is_empty() {
57 return Err(CachekitError::Encryption(
58 "tenant_id must not be empty".to_owned(),
59 ));
60 }
61 if tenant_id.len() > 255 {
62 return Err(CachekitError::Encryption(format!(
63 "tenant_id must be at most 255 bytes; got {}",
64 tenant_id.len()
65 )));
66 }
67
68 let tenant_keys = cachekit_core::encryption::key_derivation::derive_tenant_keys(
69 master_key_bytes,
70 tenant_id,
71 )
72 .map_err(|e| CachekitError::Encryption(format!("key derivation failed: {e}")))?;
73
74 let encryptor = ZeroKnowledgeEncryptor::new()
75 .map_err(|e| CachekitError::Encryption(format!("encryptor init failed: {e}")))?;
76
77 Ok(Self {
78 encryptor,
79 derived_key: Zeroizing::new(tenant_keys.encryption_key),
80 tenant_id: tenant_id.to_owned(),
81 })
82 }
83
84 pub fn encrypt(&self, plaintext: &[u8], cache_key: &str) -> Result<Vec<u8>, CachekitError> {
88 let aad = self.build_aad(cache_key, false);
89 self.encryptor
90 .encrypt_aes_gcm(plaintext, &*self.derived_key, &aad)
91 .map_err(|e| CachekitError::Encryption(format!("encrypt failed: {e}")))
92 }
93
94 pub fn decrypt(&self, ciphertext: &[u8], cache_key: &str) -> Result<Vec<u8>, CachekitError> {
99 let aad = self.build_aad(cache_key, false);
100 self.encryptor
101 .decrypt_aes_gcm(ciphertext, &*self.derived_key, &aad)
102 .map_err(|e| CachekitError::Encryption(format!("decrypt failed: {e}")))
103 }
104
105 pub fn tenant_id(&self) -> &str {
107 &self.tenant_id
108 }
109
110 pub fn build_aad(&self, cache_key: &str, compressed: bool) -> Vec<u8> {
116 let format_str = b"msgpack";
117 let compressed_str = if compressed {
118 b"True" as &[u8]
119 } else {
120 b"False"
121 };
122
123 let tenant_bytes = self.tenant_id.as_bytes();
124 let key_bytes = cache_key.as_bytes();
125
126 let capacity =
128 1 + 16 + tenant_bytes.len() + key_bytes.len() + format_str.len() + compressed_str.len();
129 let mut aad = Vec::with_capacity(capacity);
130
131 aad.push(AAD_VERSION);
132
133 aad.extend_from_slice(&len_u32(tenant_bytes.len()).to_be_bytes());
139 aad.extend_from_slice(tenant_bytes);
140
141 aad.extend_from_slice(&len_u32(key_bytes.len()).to_be_bytes());
143 aad.extend_from_slice(key_bytes);
144
145 aad.extend_from_slice(&len_u32(format_str.len()).to_be_bytes());
147 aad.extend_from_slice(format_str);
148
149 aad.extend_from_slice(&len_u32(compressed_str.len()).to_be_bytes());
151 aad.extend_from_slice(compressed_str);
152
153 aad
154 }
155}
156
/// Converts a byte length into the `u32` used for AAD length prefixes.
///
/// Lengths that do not fit in a `u32` saturate to `u32::MAX` instead of
/// panicking. Only the cache key is not bounded by `EncryptionLayer::new`
/// (tenant ids are capped at 255 bytes and the other fields are short
/// literals), so saturation is reachable only for a cache key over 4 GiB.
// `u32::try_from` is a checked conversion, not an `as` cast, so no
// `clippy::cast_possible_truncation` suppression is needed here.
fn len_u32(len: usize) -> u32 {
    u32::try_from(len).unwrap_or(u32::MAX)
}
163
164impl std::fmt::Debug for EncryptionLayer {
165 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
166 f.debug_struct("EncryptionLayer")
167 .field("tenant_id", &self.tenant_id)
168 .field("derived_key", &"[REDACTED]")
169 .finish()
170 }
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    const TEST_MASTER_KEY: &[u8] = b"test_master_key_32_bytes_long!!!";
    const TEST_TENANT: &str = "test-tenant";

    #[test]
    fn roundtrip_encrypt_decrypt() {
        let layer = EncryptionLayer::new(TEST_MASTER_KEY, TEST_TENANT).unwrap();
        let plaintext = b"hello, zero-knowledge world";

        let ciphertext = layer.encrypt(plaintext, "my:key").unwrap();
        let decrypted = layer.decrypt(&ciphertext, "my:key").unwrap();

        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn wrong_cache_key_fails_decryption() {
        let layer = EncryptionLayer::new(TEST_MASTER_KEY, TEST_TENANT).unwrap();
        let ciphertext = layer.encrypt(b"secret", "key:a").unwrap();

        let result = layer.decrypt(&ciphertext, "key:b");
        assert!(result.is_err(), "decryption with wrong cache key must fail");
    }

    #[test]
    fn different_tenants_produce_different_ciphertext() {
        let layer_a = EncryptionLayer::new(TEST_MASTER_KEY, "tenant-a").unwrap();
        let layer_b = EncryptionLayer::new(TEST_MASTER_KEY, "tenant-b").unwrap();

        let ct_a = layer_a.encrypt(b"same data", "same:key").unwrap();
        let ct_b = layer_b.encrypt(b"same data", "same:key").unwrap();

        assert_ne!(ct_a, ct_b);

        // Ciphertext from one tenant must not decrypt under another tenant.
        assert!(layer_b.decrypt(&ct_a, "same:key").is_err());
    }

    #[test]
    fn master_key_too_short() {
        let result = EncryptionLayer::new(b"short", "tenant");
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("at least 32 bytes"), "got: {msg}");
    }

    #[test]
    fn empty_tenant_rejected() {
        let msg = EncryptionLayer::new(TEST_MASTER_KEY, "")
            .unwrap_err()
            .to_string();
        assert!(msg.contains("must not be empty"), "got: {msg}");
    }

    #[test]
    fn tenant_too_long_rejected() {
        let ascii_tenant = "t".repeat(256);
        let msg = EncryptionLayer::new(TEST_MASTER_KEY, &ascii_tenant)
            .unwrap_err()
            .to_string();
        assert!(msg.contains("at most 255 bytes"), "got: {msg}");

        // The limit is in bytes: 128 two-byte chars is 256 bytes.
        let multibyte_tenant = "é".repeat(128);
        assert_eq!(multibyte_tenant.len(), 256);
        assert!(EncryptionLayer::new(TEST_MASTER_KEY, &multibyte_tenant).is_err());
    }

    #[test]
    fn tenant_id_accessor() {
        let layer = EncryptionLayer::new(TEST_MASTER_KEY, TEST_TENANT).unwrap();
        assert_eq!(layer.tenant_id(), TEST_TENANT);
    }

    #[test]
    fn aad_v03_format() {
        let layer = EncryptionLayer::new(TEST_MASTER_KEY, TEST_TENANT).unwrap();
        let aad = layer.build_aad("user:42", false);

        // Version byte.
        assert_eq!(aad[0], 0x03);

        // Tenant id field.
        let tenant_len = u32::from_be_bytes(aad[1..5].try_into().unwrap()) as usize;
        assert_eq!(tenant_len, TEST_TENANT.len());
        assert_eq!(&aad[5..5 + tenant_len], TEST_TENANT.as_bytes());

        // Cache key field.
        let offset = 5 + tenant_len;
        let key_len = u32::from_be_bytes(aad[offset..offset + 4].try_into().unwrap()) as usize;
        assert_eq!(key_len, 7);
        assert_eq!(&aad[offset + 4..offset + 4 + key_len], b"user:42");

        // Serialization format field.
        let offset = offset + 4 + key_len;
        let fmt_len = u32::from_be_bytes(aad[offset..offset + 4].try_into().unwrap()) as usize;
        assert_eq!(fmt_len, 7);
        assert_eq!(&aad[offset + 4..offset + 4 + fmt_len], b"msgpack");

        // Compression flag field.
        let offset = offset + 4 + fmt_len;
        let comp_len = u32::from_be_bytes(aad[offset..offset + 4].try_into().unwrap()) as usize;
        assert_eq!(comp_len, 5);
        assert_eq!(&aad[offset + 4..offset + 4 + comp_len], b"False");

        // Nothing may follow the last field.
        assert_eq!(aad.len(), offset + 4 + comp_len);
    }

    #[test]
    fn aad_compressed_flag() {
        let layer = EncryptionLayer::new(TEST_MASTER_KEY, TEST_TENANT).unwrap();
        let aad_false = layer.build_aad("k", false);
        let aad_true = layer.build_aad("k", true);

        assert_ne!(aad_false, aad_true);
        assert!(aad_true.ends_with(b"True"));
        assert!(aad_false.ends_with(b"False"));
    }

    #[test]
    fn debug_redacts_key() {
        let layer = EncryptionLayer::new(TEST_MASTER_KEY, TEST_TENANT).unwrap();
        let debug = format!("{layer:?}");
        assert!(debug.contains("[REDACTED]"));
        assert!(!debug.contains("test_master_key"));
    }
}