Skip to main content

cachekit/
encryption.rs

1//! Zero-knowledge encryption layer using AES-256-GCM with AAD v0x03 format.
2//!
3//! Wraps `cachekit_core::ZeroKnowledgeEncryptor` with tenant key derivation
4//! and cache-key-bound Additional Authenticated Data (AAD). The AAD binding
5//! prevents ciphertext substitution attacks within the same tenant (CVSS 8.5).
6//!
7//! # AAD v0x03 Format
8//!
9//! ```text
10//! [version(0x03)][len(4)][tenant_id][len(4)][cache_key][len(4)][format][len(4)][compressed]
11//! ```
12//!
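//! Every field after the version byte is length-prefixed with a 4-byte
//! big-endian u32 so field boundaries are unambiguous; without the prefixes,
//! the pairs `("ab", "c")` and `("a", "bc")` would produce identical AAD bytes.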
15
16use zeroize::Zeroizing;
17
18use cachekit_core::ZeroKnowledgeEncryptor;
19
20use crate::error::CachekitError;
21
/// Leading byte of every AAD buffer, identifying the v0x03 layout.
///
/// Decryption rebuilds the AAD with this same byte, so changing it would make
/// ciphertext written under the old value fail authentication.
const AAD_VERSION: u8 = 3;
24
/// Zero-knowledge encryption layer with per-tenant key derivation.
///
/// Holds a derived encryption key (zeroized on drop) and the
/// `ZeroKnowledgeEncryptor` from cachekit-core for AES-256-GCM operations.
/// Every ciphertext is bound to this layer's tenant and to its cache key
/// through the AAD produced by `build_aad`, so a ciphertext moved to a
/// different cache key fails to decrypt.
///
/// L1 stores **ciphertext**, not plaintext — the zero-knowledge property
/// is preserved across all cache layers.
pub struct EncryptionLayer {
    /// AES-GCM primitive from cachekit-core. It is constructed without key
    /// material; the tenant key is passed explicitly on each encrypt/decrypt.
    encryptor: ZeroKnowledgeEncryptor,
    /// Per-tenant 32-byte key taken from `derive_tenant_keys(..).encryption_key`.
    /// Wrapped in `Zeroizing` so the bytes are wiped when the layer is dropped.
    derived_key: Zeroizing<[u8; 32]>,
    /// Tenant identifier, validated in `new()` to be non-empty and at most
    /// 255 bytes; written into every AAD this layer builds.
    tenant_id: String,
}
37
38impl EncryptionLayer {
39    /// Create a new encryption layer with HKDF-derived tenant keys.
40    ///
41    /// # Arguments
42    /// * `master_key_bytes` — Raw master key (minimum 32 bytes for AES-256)
43    /// * `tenant_id` — Tenant identifier for cryptographic isolation
44    ///
45    /// # Errors
46    /// - Master key too short (< 32 bytes)
47    /// - HKDF derivation failure
48    /// - Encryptor initialization failure
49    pub fn new(master_key_bytes: &[u8], tenant_id: &str) -> Result<Self, CachekitError> {
50        if master_key_bytes.len() < 32 {
51            return Err(CachekitError::Encryption(format!(
52                "master key must be at least 32 bytes; got {}",
53                master_key_bytes.len()
54            )));
55        }
56        if tenant_id.is_empty() {
57            return Err(CachekitError::Encryption(
58                "tenant_id must not be empty".to_owned(),
59            ));
60        }
61        if tenant_id.len() > 255 {
62            return Err(CachekitError::Encryption(format!(
63                "tenant_id must be at most 255 bytes; got {}",
64                tenant_id.len()
65            )));
66        }
67
68        let tenant_keys = cachekit_core::encryption::key_derivation::derive_tenant_keys(
69            master_key_bytes,
70            tenant_id,
71        )
72        .map_err(|e| CachekitError::Encryption(format!("key derivation failed: {e}")))?;
73
74        let encryptor = ZeroKnowledgeEncryptor::new()
75            .map_err(|e| CachekitError::Encryption(format!("encryptor init failed: {e}")))?;
76
77        Ok(Self {
78            encryptor,
79            derived_key: Zeroizing::new(tenant_keys.encryption_key),
80            tenant_id: tenant_id.to_owned(),
81        })
82    }
83
84    /// Encrypt plaintext with AAD bound to the cache key.
85    ///
86    /// Output format: `[nonce(12)][ciphertext + auth_tag(16)]`
87    pub fn encrypt(&self, plaintext: &[u8], cache_key: &str) -> Result<Vec<u8>, CachekitError> {
88        let aad = self.build_aad(cache_key, false);
89        self.encryptor
90            .encrypt_aes_gcm(plaintext, &*self.derived_key, &aad)
91            .map_err(|e| CachekitError::Encryption(format!("encrypt failed: {e}")))
92    }
93
94    /// Decrypt ciphertext with AAD bound to the cache key.
95    ///
96    /// Returns the original plaintext. Fails if the cache key does not match
97    /// the one used during encryption (ciphertext substitution protection).
98    pub fn decrypt(&self, ciphertext: &[u8], cache_key: &str) -> Result<Vec<u8>, CachekitError> {
99        let aad = self.build_aad(cache_key, false);
100        self.encryptor
101            .decrypt_aes_gcm(ciphertext, &*self.derived_key, &aad)
102            .map_err(|e| CachekitError::Encryption(format!("decrypt failed: {e}")))
103    }
104
105    /// Return the tenant ID used for key derivation.
106    pub fn tenant_id(&self) -> &str {
107        &self.tenant_id
108    }
109
110    /// Build AAD v0x03 for a given cache key and compression flag.
111    ///
112    /// Format: `[0x03][len][tenant_id][len][cache_key][len]["msgpack"][len]["True"/"False"]`
113    ///
114    /// All lengths are 4-byte big-endian u32 to prevent boundary-confusion attacks.
115    pub fn build_aad(&self, cache_key: &str, compressed: bool) -> Vec<u8> {
116        let format_str = b"msgpack";
117        let compressed_str = if compressed {
118            b"True" as &[u8]
119        } else {
120            b"False"
121        };
122
123        let tenant_bytes = self.tenant_id.as_bytes();
124        let key_bytes = cache_key.as_bytes();
125
126        // Pre-allocate: version(1) + 4 length fields(16) + data
127        let capacity =
128            1 + 16 + tenant_bytes.len() + key_bytes.len() + format_str.len() + compressed_str.len();
129        let mut aad = Vec::with_capacity(capacity);
130
131        aad.push(AAD_VERSION);
132
133        // All components are bounded: tenant_id <= 255 (validated in new()),
134        // cache_key <= 1024 (validated by client), format/compressed are constants.
135        // Safe to use len_u32 helper which saturates on overflow.
136
137        // tenant_id
138        aad.extend_from_slice(&len_u32(tenant_bytes.len()).to_be_bytes());
139        aad.extend_from_slice(tenant_bytes);
140
141        // cache_key
142        aad.extend_from_slice(&len_u32(key_bytes.len()).to_be_bytes());
143        aad.extend_from_slice(key_bytes);
144
145        // format
146        aad.extend_from_slice(&len_u32(format_str.len()).to_be_bytes());
147        aad.extend_from_slice(format_str);
148
149        // compressed flag
150        aad.extend_from_slice(&len_u32(compressed_str.len()).to_be_bytes());
151        aad.extend_from_slice(compressed_str);
152
153        aad
154    }
155}
156
/// Convert a `usize` length to the `u32` used in AAD length prefixes.
///
/// Saturates at `u32::MAX` instead of panicking. Components are expected to
/// stay far below that bound (tenant IDs are capped at 255 bytes by
/// `EncryptionLayer::new`, cache keys at 1024 bytes by the client), so
/// saturation should never occur in practice.
fn len_u32(len: usize) -> u32 {
    // Checked conversion rather than an `as` cast, so there is no truncation
    // lint to suppress here.
    u32::try_from(len).unwrap_or(u32::MAX)
}
163
164impl std::fmt::Debug for EncryptionLayer {
165    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
166        f.debug_struct("EncryptionLayer")
167            .field("tenant_id", &self.tenant_id)
168            .field("derived_key", &"[REDACTED]")
169            .finish()
170    }
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    const TEST_MASTER_KEY: &[u8] = b"test_master_key_32_bytes_long!!!";
    const TEST_TENANT: &str = "test-tenant";

    #[test]
    fn roundtrip_encrypt_decrypt() {
        let layer = EncryptionLayer::new(TEST_MASTER_KEY, TEST_TENANT).unwrap();
        let plaintext = b"hello, zero-knowledge world";

        let ciphertext = layer.encrypt(plaintext, "my:key").unwrap();
        let decrypted = layer.decrypt(&ciphertext, "my:key").unwrap();

        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn wrong_cache_key_fails_decryption() {
        let layer = EncryptionLayer::new(TEST_MASTER_KEY, TEST_TENANT).unwrap();
        let ciphertext = layer.encrypt(b"secret", "key:a").unwrap();

        let result = layer.decrypt(&ciphertext, "key:b");
        assert!(result.is_err(), "decryption with wrong cache key must fail");
    }

    #[test]
    fn tampered_ciphertext_fails_decryption() {
        let layer = EncryptionLayer::new(TEST_MASTER_KEY, TEST_TENANT).unwrap();
        let mut ciphertext = layer.encrypt(b"secret", "key:a").unwrap();

        // Flip one bit in the trailing auth tag; GCM must reject the result.
        let last = ciphertext.len() - 1;
        ciphertext[last] ^= 0x01;

        assert!(
            layer.decrypt(&ciphertext, "key:a").is_err(),
            "tampered ciphertext must not decrypt"
        );
    }

    #[test]
    fn different_tenants_produce_different_ciphertext() {
        let layer_a = EncryptionLayer::new(TEST_MASTER_KEY, "tenant-a").unwrap();
        let layer_b = EncryptionLayer::new(TEST_MASTER_KEY, "tenant-b").unwrap();

        let ct_a = layer_a.encrypt(b"same data", "same:key").unwrap();
        let ct_b = layer_b.encrypt(b"same data", "same:key").unwrap();

        // Nonces differ, so ciphertext differs, but also keys differ
        assert_ne!(ct_a, ct_b);

        // Cross-tenant decryption must fail
        assert!(layer_b.decrypt(&ct_a, "same:key").is_err());
    }

    #[test]
    fn master_key_too_short() {
        let result = EncryptionLayer::new(b"short", "tenant");
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("at least 32 bytes"), "got: {msg}");
    }

    #[test]
    fn empty_tenant_id_rejected() {
        let msg = EncryptionLayer::new(TEST_MASTER_KEY, "")
            .unwrap_err()
            .to_string();
        assert!(msg.contains("must not be empty"), "got: {msg}");
    }

    #[test]
    fn tenant_id_too_long_rejected() {
        let long_tenant = "t".repeat(256);
        let msg = EncryptionLayer::new(TEST_MASTER_KEY, &long_tenant)
            .unwrap_err()
            .to_string();
        assert!(msg.contains("at most 255 bytes"), "got: {msg}");
    }

    #[test]
    fn aad_v03_format() {
        let layer = EncryptionLayer::new(TEST_MASTER_KEY, TEST_TENANT).unwrap();
        let aad = layer.build_aad("user:42", false);

        // Version byte
        assert_eq!(aad[0], 0x03);

        // tenant_id length (4 bytes BE) + tenant_id
        let tenant_len = u32::from_be_bytes(aad[1..5].try_into().unwrap()) as usize;
        assert_eq!(tenant_len, TEST_TENANT.len());
        assert_eq!(&aad[5..5 + tenant_len], TEST_TENANT.as_bytes());

        // cache_key length + cache_key
        let offset = 5 + tenant_len;
        let key_len = u32::from_be_bytes(aad[offset..offset + 4].try_into().unwrap()) as usize;
        assert_eq!(key_len, 7); // "user:42"
        assert_eq!(&aad[offset + 4..offset + 4 + key_len], b"user:42");

        // format length + format
        let offset = offset + 4 + key_len;
        let fmt_len = u32::from_be_bytes(aad[offset..offset + 4].try_into().unwrap()) as usize;
        assert_eq!(&aad[offset + 4..offset + 4 + fmt_len], b"msgpack");

        // compressed length + compressed
        let offset = offset + 4 + fmt_len;
        let comp_len = u32::from_be_bytes(aad[offset..offset + 4].try_into().unwrap()) as usize;
        assert_eq!(&aad[offset + 4..offset + 4 + comp_len], b"False");
    }

    #[test]
    fn aad_length_matches_layout() {
        let layer = EncryptionLayer::new(TEST_MASTER_KEY, TEST_TENANT).unwrap();
        let aad = layer.build_aad("user:42", true);

        // version(1) + 4 length prefixes(16) + tenant + key + "msgpack" + "True"
        let expected = 1 + 16 + TEST_TENANT.len() + "user:42".len() + 7 + 4;
        assert_eq!(aad.len(), expected);
    }

    #[test]
    fn aad_compressed_flag() {
        let layer = EncryptionLayer::new(TEST_MASTER_KEY, TEST_TENANT).unwrap();
        let aad_false = layer.build_aad("k", false);
        let aad_true = layer.build_aad("k", true);

        assert_ne!(aad_false, aad_true);
        // "True" is at the end
        assert!(aad_true.ends_with(b"True"));
        assert!(aad_false.ends_with(b"False"));
    }

    #[test]
    fn debug_redacts_key() {
        let layer = EncryptionLayer::new(TEST_MASTER_KEY, TEST_TENANT).unwrap();
        let debug = format!("{layer:?}");
        assert!(debug.contains("[REDACTED]"));
        assert!(!debug.contains("test_master_key"));
    }
}