1use std::fmt;
83
84use lexe_std::array;
85use ref_cast::RefCast;
86use ring::{
87 aead::{self, BoundKey},
88 hkdf,
89};
90use serde_core::ser::{Serialize, SerializeStruct, Serializer};
91
92use crate::rng::{Crng, RngExt};
93
/// Size in bytes of the format-version prefix at the start of every
/// encrypted blob.
const VERSION_LEN: usize = 1;

/// Size in bytes of the random per-ciphertext key id that follows the
/// version byte.
const KEY_ID_LEN: usize = 32;

/// Size in bytes of the AES-GCM authentication tag appended at the end.
const TAG_LEN: usize = 16;

/// Returns the total encrypted size for a `plaintext_len`-byte plaintext.
///
/// Layout: version byte + key id + ciphertext (same length as the
/// plaintext) + authentication tag.
pub const fn encrypted_len(plaintext_len: usize) -> usize {
    let overhead = VERSION_LEN + KEY_ID_LEN + TAG_LEN;
    overhead + plaintext_len
}
108
/// Master key from which a fresh AES-256-GCM key is derived for each
/// ciphertext.
///
/// Wraps an HKDF-SHA256 pseudorandom key (PRK) produced by
/// `AesMasterKey::new`. Its `Debug` impl is redacted.
pub struct AesMasterKey(hkdf::Prk);
116
117#[derive(RefCast)]
124#[repr(transparent)]
125struct KeyId([u8; 32]);
126
/// Additional authenticated data bound into every AEAD seal/open.
///
/// It is BCS-encoded (see `Aad::serialize`) and passed to ring as the AAD.
/// The encoding must stay stable because existing ciphertexts authenticate
/// against these exact bytes.
struct Aad<'data, 'aad> {
    /// The same version byte that is written at the start of the ciphertext.
    version: u8,
    /// The key id that is written in the ciphertext header.
    key_id: &'data KeyId,
    /// Caller-supplied context slices, authenticated in order.
    aad: &'aad [&'aad [u8]],
}
141
/// Single-use AEAD sealing key. Its `ZeroNonce` sequence supplies exactly
/// one nonce, so it can encrypt only once.
struct EncryptKey(aead::SealingKey<ZeroNonce>);

/// Single-use AEAD opening key. Its `ZeroNonce` sequence supplies exactly
/// one nonce, so it can decrypt only once.
struct DecryptKey(aead::OpeningKey<ZeroNonce>);

/// Nonce sequence that yields the all-zero nonce once and panics if asked
/// again.
///
/// Using a constant nonce relies on every key being used only once. Each
/// encryption derives its key from a freshly drawn random `KeyId`, so a
/// zero-nonce reuse under one key would require a key-id collision.
struct ZeroNonce(Option<aead::Nonce>);
149
/// Error returned when a ciphertext cannot be decrypted.
///
/// This is a unit error. A too-short input, an unknown version byte and an
/// AEAD authentication failure (wrong key, wrong AAD, tampered bytes) are
/// all reported as this same value.
///
/// `Copy`/`PartialEq`/`Eq` are derived so callers can cheaply pass it around
/// and compare results, e.g. `assert_eq!(res, Err(DecryptError))`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct DecryptError;

impl std::error::Error for DecryptError {}

impl fmt::Display for DecryptError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("decrypt error: ciphertext or metadata may be corrupted")
    }
}
160
161impl fmt::Debug for AesMasterKey {
162 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
163 f.write_str("AesMasterKey(..)")
164 }
165}
166
impl AesMasterKey {
    /// HKDF salt used when extracting the master PRK. It domain-separates
    /// this key from other keys derived from the same secret.
    /// NOTE(review): `array::pad` presumably zero-pads the label to 32
    /// bytes; confirm in `lexe_std::array`.
    const HKDF_SALT: [u8; 32] = array::pad(*b"LEXE-REALM::AesMasterKey");

    /// Builds the master key by running HKDF-SHA256 `extract` over
    /// `root_seed_derived_secret`, salted with [`Self::HKDF_SALT`].
    ///
    /// NOTE(review): the name suggests the secret comes from the user's root
    /// seed; confirm against callers.
    pub fn new(root_seed_derived_secret: &[u8; 32]) -> Self {
        Self(
            hkdf::Salt::new(hkdf::HKDF_SHA256, &Self::HKDF_SALT)
                .extract(root_seed_derived_secret),
        )
    }

    /// Derives the AES-256-GCM key for one ciphertext by HKDF-expanding the
    /// master PRK with `key_id` as the `info` input.
    fn derive_unbound_key(&self, key_id: &KeyId) -> aead::UnboundKey {
        aead::UnboundKey::from(
            self.0
                // Per ring's docs, `expand` fails only when the requested
                // output length is too large. A 32-byte AES-256 key is far
                // below that limit.
                .expand(&[key_id.as_slice()], &aead::AES_256_GCM)
                .expect("This should never fail"),
        )
    }

    /// Derives a single-use sealing key for `key_id`. Its nonce sequence
    /// yields one all-zero nonce.
    fn derive_encrypt_key(&self, key_id: &KeyId) -> EncryptKey {
        let nonce = ZeroNonce::new();
        let key = aead::SealingKey::new(self.derive_unbound_key(key_id), nonce);
        EncryptKey(key)
    }

    /// Derives a single-use opening key for `key_id`. Its nonce sequence
    /// yields one all-zero nonce.
    fn derive_decrypt_key(&self, key_id: &KeyId) -> DecryptKey {
        let nonce = ZeroNonce::new();
        let key = aead::OpeningKey::new(self.derive_unbound_key(key_id), nonce);
        DecryptKey(key)
    }

    /// Encrypts the bytes that `write_data_cb` appends, under a key derived
    /// from a freshly drawn random key id.
    ///
    /// Output layout:
    /// `[version: 1 byte = 0][key_id: 32 bytes][ciphertext][tag: 16 bytes]`.
    /// The version and key id are authenticated through the AAD, not
    /// encrypted.
    ///
    /// - `rng`: source of the random `KeyId` for this ciphertext.
    /// - `aad`: additional authenticated data. The same slices, in the same
    ///   order, must be passed to [`Self::decrypt`].
    /// - `data_size_hint`: expected plaintext length. Used only to pre-size
    ///   the output buffer.
    /// - `write_data_cb`: appends the plaintext to the buffer it receives.
    ///   The buffer already holds the version/key-id header, so the callback
    ///   should only append.
    ///
    /// # Panics
    ///
    /// If the callback shrinks the buffer below the header, or if ring
    /// refuses to seal the plaintext (attributed to inputs over ~4 GiB).
    pub fn encrypt<R: Crng>(
        &self,
        rng: &mut R,
        aad: &[&[u8]],
        data_size_hint: Option<usize>,
        write_data_cb: &dyn Fn(&mut Vec<u8>),
    ) -> Vec<u8> {
        // Version 0 is the only format `decrypt` accepts.
        let version = 0;
        // A fresh random key id per call means a fresh derived key per call.
        // This is what makes the constant all-zero nonce safe.
        let key_id = KeyId::from_rng(rng);

        let aad = Aad {
            version,
            key_id: &key_id,
            aad,
        }
        .serialize();

        let approx_encrypted_len = encrypted_len(data_size_hint.unwrap_or(0));
        let mut data = Vec::with_capacity(approx_encrypted_len);

        // Write the cleartext header: version byte, then key id.
        data.push(version);
        data.extend_from_slice(key_id.as_slice());
        let plaintext_offset = data.len();

        // NOTE(review): nothing verifies that the callback leaves the header
        // bytes intact; overwriting them would yield an undecryptable blob.
        write_data_cb(&mut data);

        // Encrypt everything after the header in place and append the tag.
        self.derive_encrypt_key(&key_id).encrypt_in_place(
            aad.as_slice(),
            &mut data,
            plaintext_offset,
        );

        data
    }

    /// Authenticates and decrypts a blob produced by [`Self::encrypt`].
    ///
    /// Returns the plaintext, reusing `data`'s allocation. `aad` must match
    /// the slices given to `encrypt`.
    ///
    /// # Errors
    ///
    /// Returns [`DecryptError`] in three cases:
    /// - `data` is shorter than the minimum encrypted length.
    /// - The version byte is not 0.
    /// - AEAD authentication fails (wrong master key, mismatched `aad`, or
    ///   modified bytes).
    pub fn decrypt(
        &self,
        aad: &[&[u8]],
        mut data: Vec<u8>,
    ) -> Result<Vec<u8>, DecryptError> {
        // Smallest valid blob: header + tag around an empty plaintext.
        const MIN_DATA_LEN: usize = encrypted_len(0 /* plaintext_len */);
        if data.len() < MIN_DATA_LEN {
            return Err(DecryptError);
        }

        // Parse the cleartext header. The length check above guarantees
        // both splits succeed.
        let (version, key_id) = {
            let (version, data) = data
                .split_first_chunk::<VERSION_LEN>()
                .expect("data.len() checked above");
            let (key_id, _) = data
                .split_first_chunk::<KEY_ID_LEN>()
                .expect("data.len() checked above");
            (version[0], key_id)
        };

        if version != 0 {
            return Err(DecryptError);
        }
        let key_id = KeyId::from_ref(key_id);
        let decrypt_key = self.derive_decrypt_key(key_id);

        // Rebuild the exact AAD used at encryption time. Because the header
        // is part of it, any tampering with the header fails authentication.
        let aad = Aad {
            version,
            key_id,
            aad,
        }
        .serialize();

        // Ciphertext and tag start right after the header.
        let ciphertext_and_tag_offset = VERSION_LEN + KEY_ID_LEN;
        decrypt_key.decrypt_in_place(
            &aad,
            &mut data,
            ciphertext_and_tag_offset,
        )?;

        Ok(data)
    }
}
293
294impl EncryptKey {
295 fn encrypt_in_place(
299 mut self,
300 aad: &[u8],
301 data: &mut Vec<u8>,
302 plaintext_offset: usize,
303 ) {
304 assert!(plaintext_offset <= data.len());
305
306 let aad = aead::Aad::from(aad);
307 let tag = self
308 .0
309 .seal_in_place_separate_tag(aad, &mut data[plaintext_offset..])
310 .expect(
311 "Cannot encrypt more than ~4 GiB at once (should never happen)",
312 );
313 data.extend_from_slice(tag.as_ref());
314 }
315}
316
317impl DecryptKey {
318 fn decrypt_in_place(
322 mut self,
323 aad: &[u8],
324 data: &mut Vec<u8>,
325 ciphertext_and_tag_offset: usize,
326 ) -> Result<(), DecryptError> {
327 let aad = aead::Aad::from(aad);
330
331 let plaintext_ref = self
332 .0
333 .open_within(aad, data, ciphertext_and_tag_offset..)
334 .map_err(|_| DecryptError)?;
335 let plaintext_len = plaintext_ref.len();
336
337 data.truncate(plaintext_len);
340
341 Ok(())
342 }
343}
344
345impl KeyId {
346 #[inline]
347 const fn from_ref(arr: &[u8; 32]) -> &Self {
348 lexe_std::const_utils::const_ref_cast(arr)
349 }
350
351 #[inline]
352 fn as_slice(&self) -> &[u8] {
353 self.0.as_slice()
354 }
355
356 fn from_rng<R: Crng>(rng: &mut R) -> Self {
357 Self(rng.gen_bytes())
358 }
359}
360
361impl Serialize for KeyId {
362 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
363 where
364 S: Serializer,
365 {
366 self.0.serialize(serializer)
367 }
368}
369
370impl Aad<'_, '_> {
371 fn serialize(&self) -> Vec<u8> {
372 let len = bcs::serialized_size(self)
373 .expect("Serializing the AAD should never fail");
374
375 let mut out = Vec::with_capacity(len);
376 bcs::serialize_into(&mut out, self)
377 .expect("Serializing the AAD should never fail");
378 out
379 }
380}
381
382impl Serialize for Aad<'_, '_> {
383 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
384 where
385 S: Serializer,
386 {
387 let mut fields = serializer.serialize_struct("Aad", 3)?;
388 fields.serialize_field("version", &self.version)?;
389 fields.serialize_field("key_id", self.key_id)?;
390 fields.serialize_field("aad", self.aad)?;
391 fields.end()
392 }
393}
394
395impl ZeroNonce {
396 fn new() -> Self {
397 Self(Some(aead::Nonce::assume_unique_for_key([0u8; 12])))
398 }
399}
400
401impl aead::NonceSequence for ZeroNonce {
402 fn advance(&mut self) -> Result<aead::Nonce, ring::error::Unspecified> {
403 Ok(self.0.take().expect(
404 "We somehow encrypted / decrypted more than once with the same key",
405 ))
406 }
407}
408
/// Test helper: deterministically derives an [`AesMasterKey`] from 32 bytes
/// drawn from `rng`.
///
/// The bytes go through HKDF-SHA256 (salt label `"LEXE-REALM::RootSeed"`,
/// info `"vfs master key"`). The resulting 32-byte secret is handed to
/// [`AesMasterKey::new`].
#[cfg(any(test, feature = "test-utils"))]
pub(crate) fn derive_key(rng: &mut crate::rng::FastRng) -> AesMasterKey {
    /// Asks HKDF for exactly 32 bytes of output keying material.
    struct OkmLength;
    impl hkdf::KeyType for OkmLength {
        fn len(&self) -> usize {
            32
        }
    }

    const HKDF_SALT: [u8; 32] = array::pad(*b"LEXE-REALM::RootSeed");

    let seed: [u8; 32] = rng.gen_bytes();
    let prk = hkdf::Salt::new(hkdf::HKDF_SHA256, &HKDF_SALT).extract(&seed);
    let info: [&[u8]; 1] = [b"vfs master key"];
    let okm = prk.expand(&info, OkmLength).unwrap();

    let mut key_seed = [0u8; 32];
    okm.fill(&mut key_seed).unwrap();
    AesMasterKey::new(&key_seed)
}
430
/// `proptest` support: lets property tests take an arbitrary
/// `AesMasterKey`.
#[cfg(any(test, feature = "test-utils"))]
mod arbitrary_impl {
    use proptest::{
        arbitrary::{Arbitrary, any},
        strategy::{BoxedStrategy, Strategy},
    };

    use super::*;
    use crate::rng::FastRng;

    /// Generates keys by drawing an arbitrary `FastRng` and passing it to
    /// `derive_key`.
    impl Arbitrary for AesMasterKey {
        type Parameters = ();
        type Strategy = BoxedStrategy<Self>;

        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            let rng_strategy = any::<FastRng>();
            let key_strategy = rng_strategy.prop_map(|rng| {
                let mut rng = rng;
                derive_key(&mut rng)
            });
            key_strategy.boxed()
        }
    }
}
451
/// Tests for the encryption format: AAD encoding stability, decryption of
/// fixed ciphertexts, and encrypt/decrypt round-trips.
#[cfg(test)]
mod test {
    use lexe_hex::hex;
    use proptest::{
        arbitrary::any, collection::vec, prop_assert, prop_assert_eq, proptest,
    };

    use super::*;
    use crate::rng::FastRng;

    /// Pins the exact BCS bytes of [`Aad`]. Ciphertexts authenticate against
    /// these bytes, so any change here is a breaking format change.
    #[test]
    fn test_aad_compat() {
        // Case 1: no caller-supplied AAD slices.
        let aad = Aad {
            version: 0,
            key_id: KeyId::from_ref(&[0x69; 32]),
            aad: &[],
        }
        .serialize();

        // version (00) || key_id (32 x 0x69) || aad sequence length (00)
        let expected_aad = hex::decode(
            "00\
             6969696969696969696969696969696969696969696969696969696969696969\
             00",
        )
        .unwrap();

        assert_eq!(&aad, &expected_aad);

        // Case 2: two AAD slices of different lengths.
        let aad = Aad {
            version: 0,
            key_id: KeyId::from_ref(&[0x42; 32]),
            aad: &[b"aaaaaaaa".as_slice(), b"0123456789".as_slice()],
        }
        .serialize();

        // version (00) || key_id (32 x 0x42) || sequence length (02)
        // || length 8 (08) + "aaaaaaaa" || length 10 (0a) + "0123456789"
        let expected_aad = hex::decode(
            "00\
             4242424242424242424242424242424242424242424242424242424242424242\
             02\
             08\
             6161616161616161\
             0a\
             30313233343536373839",
        )
        .unwrap();
        assert_eq!(&aad, &expected_aad);
    }

    /// Decrypts hard-coded ciphertexts under a key derived from a fixed RNG
    /// seed. Guards against format or key-derivation changes that would make
    /// existing ciphertexts unreadable.
    #[test]
    fn test_decrypt_compat() {
        let mut rng = FastRng::from_u64(123);
        let vfs_key = derive_key(&mut rng);

        // Empty plaintext, no AAD:
        // version (00) || key_id (32 bytes) || (no ciphertext) || tag
        let encrypted = hex::decode(
            "00\
             b0abd2beab31c1d925c5d8059cf90068eece2c41a3a6e4454d84e36ad6858a01\
             \
             0e2d1f6d16e9bb5738de28b4f180f07f",
        )
        .unwrap();

        let decrypted = vfs_key.decrypt(&[], encrypted).unwrap();
        assert_eq!(decrypted.as_slice(), b"");

        // 15-byte plaintext with one AAD slice.
        let aad = b"my context".as_slice();
        let plaintext = b"my cool message".as_slice();

        // version (00) || key_id (32 bytes) || ciphertext (15 bytes) || tag
        let encrypted = hex::decode(
            "00\
             c87fea5c4db8c16d3dae5a6ead5ee5985fa7c38721b9624e37772adea6a48aae\
             22f52c6f08440092338d16e3402eaf\
             c3972d357e56dad4cc42c6a80da4ac35",
        )
        .unwrap();

        let decrypted = vfs_key.decrypt(&[aad], encrypted).unwrap();

        assert_eq!(decrypted.as_slice(), plaintext);
    }

    /// Property test: encrypt followed by decrypt returns the original
    /// plaintext for random keys, AAD and plaintexts. Two encryptions of the
    /// same input must differ, since each draws a fresh random key id.
    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        proptest!(|(
            mut rng in any::<FastRng>(),
            aad in vec(vec(any::<u8>(), 0..=16), 0..=4),
            plaintext in vec(any::<u8>(), 0..=256),
        )| {
            let vfs_key = derive_key(&mut rng);

            // Borrow the owned AAD vectors as the `&[&[u8]]` the API expects.
            let aad_ref = aad
                .iter()
                .map(|x| x.as_slice())
                .collect::<Vec<_>>();

            let encrypted = vfs_key.encrypt(&mut rng, &aad_ref, Some(plaintext.len()), &|out: &mut Vec<u8>| {
                out.extend_from_slice(&plaintext);
            });

            let decrypted = vfs_key.decrypt(&aad_ref, encrypted.clone()).unwrap();
            prop_assert_eq!(&plaintext, &decrypted);

            // Same key, AAD and plaintext (no size hint this time): the
            // output must still differ because the key id is re-drawn.
            let encrypted2 = vfs_key.encrypt(&mut rng, &aad_ref, None, &|out: &mut Vec<u8>| {
                out.extend_from_slice(&plaintext);
            });

            prop_assert!(encrypted != encrypted2);
        });
    }
}