cortexai_encryption/
envelope.rs1use crate::error::{CryptoError, CryptoResult};
8use crate::key::{EncryptionKey, KeyRing, VersionedKey};
9use crate::traits::{Cipher, DataEncryptor};
10
11#[cfg(feature = "aes")]
12use crate::aes_cipher::Aes256GcmCipher;
13
14use serde::{de::DeserializeOwned, Serialize};
15
/// Format version written as the first byte of every envelope.
const ENVELOPE_VERSION: u8 = 1;

/// Length of the envelope header: one format-version byte followed by the key version
/// encoded as a little-endian `u32`.
const ENVELOPE_HEADER_SIZE: usize = std::mem::size_of::<u8>() + std::mem::size_of::<u32>();

22pub struct EnvelopeEncryptor {
33 key_ring: KeyRing,
34}
35
36impl EnvelopeEncryptor {
37 pub fn new(key: EncryptionKey) -> Self {
39 let mut key_ring = KeyRing::new();
40 key_ring.add_key(VersionedKey::new(1, key));
41 Self { key_ring }
42 }
43
44 pub fn with_key_ring(key_ring: KeyRing) -> Self {
46 Self { key_ring }
47 }
48
49 pub fn key_ring(&self) -> &KeyRing {
51 &self.key_ring
52 }
53
54 pub fn key_ring_mut(&mut self) -> &mut KeyRing {
56 &mut self.key_ring
57 }
58
59 pub fn rotate_key(&mut self, new_key: EncryptionKey) -> u32 {
61 self.key_ring.rotate(new_key)
62 }
63
64 #[cfg(feature = "aes")]
66 pub fn encrypt(
67 &self,
68 plaintext: &[u8],
69 associated_data: Option<&[u8]>,
70 ) -> CryptoResult<Vec<u8>> {
71 let active = self
72 .key_ring
73 .active_key()
74 .ok_or(CryptoError::KeyNotFound(0))?;
75
76 let cipher = Aes256GcmCipher::new(&active.key)?;
77 let cipher_data = cipher.encrypt(plaintext, associated_data)?;
78
79 let mut envelope = Vec::with_capacity(ENVELOPE_HEADER_SIZE + cipher_data.len());
81 envelope.push(ENVELOPE_VERSION);
82 envelope.extend_from_slice(&active.version.to_le_bytes());
83 envelope.extend_from_slice(&cipher_data);
84
85 Ok(envelope)
86 }
87
88 #[cfg(feature = "aes")]
90 pub fn decrypt(
91 &self,
92 ciphertext: &[u8],
93 associated_data: Option<&[u8]>,
94 ) -> CryptoResult<Vec<u8>> {
95 if ciphertext.len() < ENVELOPE_HEADER_SIZE {
96 return Err(CryptoError::InvalidCiphertext(
97 "envelope too short".to_string(),
98 ));
99 }
100
101 let envelope_version = ciphertext[0];
102 if envelope_version != ENVELOPE_VERSION {
103 return Err(CryptoError::InvalidCiphertext(format!(
104 "unsupported envelope version: {}",
105 envelope_version
106 )));
107 }
108
109 let key_version =
110 u32::from_le_bytes([ciphertext[1], ciphertext[2], ciphertext[3], ciphertext[4]]);
111
112 let versioned_key = self
113 .key_ring
114 .get_key(key_version)
115 .ok_or(CryptoError::KeyNotFound(key_version))?;
116
117 let cipher = Aes256GcmCipher::new(&versioned_key.key)?;
118 let cipher_data = &ciphertext[ENVELOPE_HEADER_SIZE..];
119
120 cipher.decrypt(cipher_data, associated_data)
121 }
122
123 #[cfg(feature = "aes")]
127 pub fn re_encrypt(
128 &self,
129 ciphertext: &[u8],
130 associated_data: Option<&[u8]>,
131 ) -> CryptoResult<Vec<u8>> {
132 let plaintext = self.decrypt(ciphertext, associated_data)?;
133 self.encrypt(&plaintext, associated_data)
134 }
135
136 pub fn uses_active_key(&self, ciphertext: &[u8]) -> CryptoResult<bool> {
138 if ciphertext.len() < ENVELOPE_HEADER_SIZE {
139 return Err(CryptoError::InvalidCiphertext(
140 "envelope too short".to_string(),
141 ));
142 }
143
144 let key_version =
145 u32::from_le_bytes([ciphertext[1], ciphertext[2], ciphertext[3], ciphertext[4]]);
146
147 Ok(self
148 .key_ring
149 .active_key()
150 .map(|k| k.version == key_version)
151 .unwrap_or(false))
152 }
153
154 pub fn get_key_version(&self, ciphertext: &[u8]) -> CryptoResult<u32> {
156 if ciphertext.len() < ENVELOPE_HEADER_SIZE {
157 return Err(CryptoError::InvalidCiphertext(
158 "envelope too short".to_string(),
159 ));
160 }
161
162 Ok(u32::from_le_bytes([
163 ciphertext[1],
164 ciphertext[2],
165 ciphertext[3],
166 ciphertext[4],
167 ]))
168 }
169}
170
/// Structured encryption: values are serialized to JSON with `serde_json` and sealed in a
/// standard envelope without associated data.
#[cfg(feature = "aes")]
impl DataEncryptor for EnvelopeEncryptor {
    fn encrypt_data<T: Serialize>(&self, data: &T) -> CryptoResult<Vec<u8>> {
        self.encrypt(&serde_json::to_vec(data)?, None)
    }

    fn decrypt_data<T: DeserializeOwned>(&self, ciphertext: &[u8]) -> CryptoResult<T> {
        let json_bytes = self.decrypt(ciphertext, None)?;
        Ok(serde_json::from_slice(&json_bytes)?)
    }
}

#[cfg(all(test, feature = "aes"))]
mod tests {
    use super::*;
    use crate::aes_cipher::Aes256GcmCipher;

    /// Generates a fresh AES-256-GCM key.
    fn new_key() -> EncryptionKey {
        EncryptionKey::generate(Aes256GcmCipher::KEY_SIZE)
    }

    #[test]
    fn test_envelope_encrypt_decrypt() {
        let encryptor = EnvelopeEncryptor::new(new_key());

        let plaintext = b"Secret message";
        let ciphertext = encryptor.encrypt(plaintext, None).unwrap();
        let decrypted = encryptor.decrypt(&ciphertext, None).unwrap();

        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_envelope_with_aad() {
        let encryptor = EnvelopeEncryptor::new(new_key());

        let plaintext = b"Secret message";
        let aad = b"context-data";

        let ciphertext = encryptor.encrypt(plaintext, Some(aad)).unwrap();
        let decrypted = encryptor.decrypt(&ciphertext, Some(aad)).unwrap();

        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_wrong_aad_is_rejected() {
        let encryptor = EnvelopeEncryptor::new(new_key());

        let aad = b"context-data";
        let other_aad = b"other-context";
        let ciphertext = encryptor.encrypt(b"Secret message", Some(aad)).unwrap();

        assert!(encryptor.decrypt(&ciphertext, Some(other_aad)).is_err());
        assert!(encryptor.decrypt(&ciphertext, None).is_err());
    }

    #[test]
    fn test_tampered_payload_is_rejected() {
        let encryptor = EnvelopeEncryptor::new(new_key());

        let mut ciphertext = encryptor.encrypt(b"Secret message", None).unwrap();
        // Flip one bit in the cipher payload (past the header).
        *ciphertext.last_mut().unwrap() ^= 0x01;

        assert!(encryptor.decrypt(&ciphertext, None).is_err());
    }

    #[test]
    fn test_key_rotation() {
        let mut encryptor = EnvelopeEncryptor::new(new_key());

        let plaintext = b"Secret message";
        let ciphertext_v1 = encryptor.encrypt(plaintext, None).unwrap();

        assert_eq!(encryptor.get_key_version(&ciphertext_v1).unwrap(), 1);
        assert!(encryptor.uses_active_key(&ciphertext_v1).unwrap());

        let v2 = encryptor.rotate_key(new_key());
        assert_eq!(v2, 2);

        let ciphertext_v2 = encryptor.encrypt(plaintext, None).unwrap();

        assert_eq!(encryptor.get_key_version(&ciphertext_v2).unwrap(), 2);
        assert!(encryptor.uses_active_key(&ciphertext_v2).unwrap());
        assert!(!encryptor.uses_active_key(&ciphertext_v1).unwrap());

        // Envelopes from before the rotation must still open.
        let decrypted_v1 = encryptor.decrypt(&ciphertext_v1, None).unwrap();
        let decrypted_v2 = encryptor.decrypt(&ciphertext_v2, None).unwrap();

        assert_eq!(plaintext.as_slice(), decrypted_v1.as_slice());
        assert_eq!(plaintext.as_slice(), decrypted_v2.as_slice());
    }

    #[test]
    fn test_re_encrypt() {
        let mut encryptor = EnvelopeEncryptor::new(new_key());

        let plaintext = b"Secret message";
        let ciphertext_v1 = encryptor.encrypt(plaintext, None).unwrap();

        encryptor.rotate_key(new_key());

        let ciphertext_v2 = encryptor.re_encrypt(&ciphertext_v1, None).unwrap();

        assert_eq!(encryptor.get_key_version(&ciphertext_v2).unwrap(), 2);

        let decrypted = encryptor.decrypt(&ciphertext_v2, None).unwrap();
        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_data_encryptor_json() {
        use serde::{Deserialize, Serialize};

        #[derive(Debug, Serialize, Deserialize, PartialEq)]
        struct TestData {
            name: String,
            value: i32,
        }

        let encryptor = EnvelopeEncryptor::new(new_key());

        let data = TestData {
            name: "test".to_string(),
            value: 42,
        };

        let ciphertext = encryptor.encrypt_data(&data).unwrap();
        let decrypted: TestData = encryptor.decrypt_data(&ciphertext).unwrap();

        assert_eq!(data, decrypted);
    }

    #[test]
    fn test_envelope_header_format() {
        let encryptor = EnvelopeEncryptor::new(new_key());

        let ciphertext = encryptor.encrypt(b"Test", None).unwrap();

        assert_eq!(ciphertext[0], ENVELOPE_VERSION);

        let key_version =
            u32::from_le_bytes([ciphertext[1], ciphertext[2], ciphertext[3], ciphertext[4]]);
        assert_eq!(key_version, 1);
    }

    #[test]
    fn test_missing_key_version() {
        let encryptor = EnvelopeEncryptor::new(new_key());

        let mut ciphertext = encryptor.encrypt(b"Test", None).unwrap();

        // Point the header at a key version the ring has never seen.
        ciphertext[1..ENVELOPE_HEADER_SIZE].copy_from_slice(&99u32.to_le_bytes());

        let result = encryptor.decrypt(&ciphertext, None);
        assert!(matches!(result, Err(CryptoError::KeyNotFound(99))));
    }

    #[test]
    fn test_short_envelope_is_rejected() {
        let encryptor = EnvelopeEncryptor::new(new_key());
        let short = [ENVELOPE_VERSION, 1, 0, 0];

        assert!(matches!(
            encryptor.decrypt(&short, None),
            Err(CryptoError::InvalidCiphertext(_))
        ));
        assert!(matches!(
            encryptor.get_key_version(&short),
            Err(CryptoError::InvalidCiphertext(_))
        ));
        assert!(matches!(
            encryptor.uses_active_key(&[]),
            Err(CryptoError::InvalidCiphertext(_))
        ));
    }

    #[test]
    fn test_unsupported_envelope_version_is_rejected() {
        let encryptor = EnvelopeEncryptor::new(new_key());

        let mut ciphertext = encryptor.encrypt(b"Test", None).unwrap();
        ciphertext[0] = ENVELOPE_VERSION.wrapping_add(1);

        assert!(matches!(
            encryptor.decrypt(&ciphertext, None),
            Err(CryptoError::InvalidCiphertext(_))
        ));
    }
}