1use std::{borrow::Cow, fmt};
83
84use lexe_std::array;
85use ref_cast::RefCast;
86use ring::{
87 aead::{self, BoundKey},
88 hkdf,
89};
90use serde_core::ser::{Serialize, SerializeStruct, Serializer};
91
92use crate::rng::{Crng, RngExt};
93
/// Size in bytes of the format-version field that starts every ciphertext.
const VERSION_LEN: usize = 1;

/// Size in bytes of the per-message key id that follows the version byte.
const KEY_ID_LEN: usize = 32;

/// Size in bytes of the AEAD authentication tag that ends every ciphertext.
const TAG_LEN: usize = 16;

/// Computes the length of the encrypted output for a plaintext of
/// `plaintext_len` bytes.
///
/// The framing adds a constant overhead on top of the plaintext: the version
/// byte, the key id, and the authentication tag.
pub const fn encrypted_len(plaintext_len: usize) -> usize {
    // Everything that is not plaintext, summed once at compile time.
    const OVERHEAD_LEN: usize = VERSION_LEN + KEY_ID_LEN + TAG_LEN;
    plaintext_len + OVERHEAD_LEN
}
108
109pub struct AesMasterKey(hkdf::Prk);
116
117#[derive(RefCast)]
124#[repr(transparent)]
125struct KeyId([u8; 32]);
126
127struct Aad<'data, 'aad> {
137 version: u8,
138 key_id: &'data KeyId,
139 aad: &'aad [&'aad [u8]],
140}
141
142struct EncryptKey(aead::SealingKey<ZeroNonce>);
143
144struct DecryptKey(aead::OpeningKey<ZeroNonce>);
145
146struct ZeroNonce(Option<aead::Nonce>);
149
/// Error returned when a ciphertext cannot be decrypted.
///
/// Carries a human-readable reason, either a static message or a formatted
/// one.
#[derive(Clone, Debug)]
pub struct DecryptError(Cow<'static, str>);

impl std::error::Error for DecryptError {}

impl fmt::Display for DecryptError {
    /// Renders the error as `Decrypt error: <reason>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(reason) = self;
        f.write_str("Decrypt error: ")?;
        f.write_str(reason)
    }
}
160
161impl fmt::Debug for AesMasterKey {
162 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
163 f.write_str("AesMasterKey(..)")
164 }
165}
166
167impl AesMasterKey {
168 const HKDF_SALT: [u8; 32] = array::pad(*b"LEXE-REALM::AesMasterKey");
169
170 pub fn new(root_seed_derived_secret: &[u8; 32]) -> Self {
171 Self(
172 hkdf::Salt::new(hkdf::HKDF_SHA256, &Self::HKDF_SALT)
173 .extract(root_seed_derived_secret),
174 )
175 }
176
177 fn derive_unbound_key(&self, key_id: &KeyId) -> aead::UnboundKey {
178 aead::UnboundKey::from(
179 self.0
180 .expand(&[key_id.as_slice()], &aead::AES_256_GCM)
181 .expect("This should never fail"),
182 )
183 }
184
185 fn derive_encrypt_key(&self, key_id: &KeyId) -> EncryptKey {
186 let nonce = ZeroNonce::new();
187 let key = aead::SealingKey::new(self.derive_unbound_key(key_id), nonce);
188 EncryptKey(key)
189 }
190
191 fn derive_decrypt_key(&self, key_id: &KeyId) -> DecryptKey {
192 let nonce = ZeroNonce::new();
193 let key = aead::OpeningKey::new(self.derive_unbound_key(key_id), nonce);
194 DecryptKey(key)
195 }
196
197 pub fn encrypt<R: Crng>(
198 &self,
199 rng: &mut R,
200 aad: &[&[u8]],
201 data_size_hint: Option<usize>,
204 write_data_cb: &dyn Fn(&mut Vec<u8>),
207 ) -> Vec<u8> {
208 let version = 0;
209 let key_id = KeyId::from_rng(rng);
210
211 let aad = Aad {
212 version,
213 key_id: &key_id,
214 aad,
215 }
216 .serialize();
217
218 let approx_encrypted_len = encrypted_len(data_size_hint.unwrap_or(0));
220 let mut data = Vec::with_capacity(approx_encrypted_len);
221
222 data.push(version);
225 data.extend_from_slice(key_id.as_slice());
226 let plaintext_offset = data.len();
227
228 write_data_cb(&mut data);
231
232 self.derive_encrypt_key(&key_id).encrypt_in_place(
235 aad.as_slice(),
236 &mut data,
237 plaintext_offset,
238 );
239
240 data
243 }
244
245 pub fn decrypt(
246 &self,
247 aad: &[&[u8]],
248 mut data: Vec<u8>,
249 ) -> Result<Vec<u8>, DecryptError> {
250 const MIN_DATA_LEN: usize = encrypted_len(0 );
253 if data.len() < MIN_DATA_LEN {
254 return Err(DecryptError(Cow::Borrowed(
255 "ciphertext too short to contain version, key_id, and tag",
256 )));
257 }
258
259 let (version, key_id) = {
261 let (version, data) = data
262 .split_first_chunk::<VERSION_LEN>()
263 .expect("data.len() checked above");
264 let (key_id, _) = data
265 .split_first_chunk::<KEY_ID_LEN>()
266 .expect("data.len() checked above");
267 (version[0], key_id)
268 };
269
270 if version != 0 {
271 return Err(DecryptError(Cow::Owned(format!(
272 "unsupported version: {version}"
273 ))));
274 }
275 let key_id = KeyId::from_ref(key_id);
276 let decrypt_key = self.derive_decrypt_key(key_id);
277
278 let aad = Aad {
279 version,
280 key_id,
281 aad,
282 }
283 .serialize();
284
285 let ciphertext_and_tag_offset = VERSION_LEN + KEY_ID_LEN;
286 decrypt_key.decrypt_in_place(
287 &aad,
288 &mut data,
289 ciphertext_and_tag_offset,
290 )?;
291
292 Ok(data)
295 }
296}
297
298impl EncryptKey {
299 fn encrypt_in_place(
303 mut self,
304 aad: &[u8],
305 data: &mut Vec<u8>,
306 plaintext_offset: usize,
307 ) {
308 assert!(plaintext_offset <= data.len());
309
310 let aad = aead::Aad::from(aad);
311 let tag = self
312 .0
313 .seal_in_place_separate_tag(aad, &mut data[plaintext_offset..])
314 .expect(
315 "Cannot encrypt more than ~4 GiB at once (should never happen)",
316 );
317 data.extend_from_slice(tag.as_ref());
318 }
319}
320
321impl DecryptKey {
322 fn decrypt_in_place(
326 mut self,
327 aad: &[u8],
328 data: &mut Vec<u8>,
329 ciphertext_and_tag_offset: usize,
330 ) -> Result<(), DecryptError> {
331 let aad = aead::Aad::from(aad);
334
335 let plaintext_ref = self
336 .0
337 .open_within(aad, data, ciphertext_and_tag_offset..)
338 .map_err(|_| "AEAD open failed: ciphertext, tag, or AAD corrupted")
339 .map_err(|msg| DecryptError(Cow::Borrowed(msg)))?;
340 let plaintext_len = plaintext_ref.len();
341
342 data.truncate(plaintext_len);
345
346 Ok(())
347 }
348}
349
350impl KeyId {
351 #[inline]
352 const fn from_ref(arr: &[u8; 32]) -> &Self {
353 lexe_std::const_utils::const_ref_cast(arr)
354 }
355
356 #[inline]
357 fn as_slice(&self) -> &[u8] {
358 self.0.as_slice()
359 }
360
361 fn from_rng<R: Crng>(rng: &mut R) -> Self {
362 Self(rng.gen_bytes())
363 }
364}
365
366impl Serialize for KeyId {
367 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
368 where
369 S: Serializer,
370 {
371 self.0.serialize(serializer)
372 }
373}
374
375impl Aad<'_, '_> {
376 fn serialize(&self) -> Vec<u8> {
377 let len = bcs::serialized_size(self)
378 .expect("Serializing the AAD should never fail");
379
380 let mut out = Vec::with_capacity(len);
381 bcs::serialize_into(&mut out, self)
382 .expect("Serializing the AAD should never fail");
383 out
384 }
385}
386
387impl Serialize for Aad<'_, '_> {
388 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
389 where
390 S: Serializer,
391 {
392 let mut fields = serializer.serialize_struct("Aad", 3)?;
393 fields.serialize_field("version", &self.version)?;
394 fields.serialize_field("key_id", self.key_id)?;
395 fields.serialize_field("aad", self.aad)?;
396 fields.end()
397 }
398}
399
400impl ZeroNonce {
401 fn new() -> Self {
402 Self(Some(aead::Nonce::assume_unique_for_key([0u8; 12])))
403 }
404}
405
406impl aead::NonceSequence for ZeroNonce {
407 fn advance(&mut self) -> Result<aead::Nonce, ring::error::Unspecified> {
408 Ok(self.0.take().expect(
409 "We somehow encrypted / decrypted more than once with the same key",
410 ))
411 }
412}
413
/// Test helper: derives an [`AesMasterKey`] from a seed sampled from `rng`.
///
/// The 32-byte seed is run through HKDF-SHA256 (extract, then expand with the
/// `"vfs master key"` info) and the resulting 32 bytes are handed to
/// [`AesMasterKey::new`].
#[cfg(any(test, feature = "test-utils"))]
pub(crate) fn derive_key(rng: &mut crate::rng::FastRng) -> AesMasterKey {
    /// Asks HKDF-expand for 32 bytes of output.
    struct OkmLength;
    impl hkdf::KeyType for OkmLength {
        fn len(&self) -> usize {
            32
        }
    }

    const HKDF_SALT: [u8; 32] = array::pad(*b"LEXE-REALM::RootSeed");

    let seed: [u8; 32] = rng.gen_bytes();
    let mut key_seed = [0u8; 32];

    let prk = hkdf::Salt::new(hkdf::HKDF_SHA256, &HKDF_SALT).extract(&seed);
    let info: [&[u8]; 1] = [b"vfs master key"];
    let okm = prk.expand(&info, OkmLength).unwrap();
    okm.fill(&mut key_seed).unwrap();

    AesMasterKey::new(&key_seed)
}
435
/// `proptest` support for generating arbitrary master keys.
#[cfg(any(test, feature = "test-utils"))]
mod arbitrary_impl {
    use proptest::{
        arbitrary::{Arbitrary, any},
        strategy::{BoxedStrategy, Strategy},
    };

    use super::*;
    use crate::rng::FastRng;

    impl Arbitrary for AesMasterKey {
        type Parameters = ();
        type Strategy = BoxedStrategy<Self>;

        /// Generates a key by drawing an arbitrary `FastRng` and feeding it
        /// to `derive_key`.
        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            let rng_strategy = any::<FastRng>();
            let key_strategy =
                rng_strategy.prop_map(|mut rng| derive_key(&mut rng));
            key_strategy.boxed()
        }
    }
}
456
#[cfg(test)]
mod test {
    use lexe_hex::hex;
    use proptest::{
        arbitrary::any, collection::vec, prop_assert, prop_assert_eq, proptest,
    };

    use super::*;
    use crate::rng::FastRng;

    /// Pins the serialized AAD bytes so the format cannot drift unnoticed.
    #[test]
    fn test_aad_compat() {
        // No caller-supplied segments: version, key id, then a zero-length
        // sequence.
        let actual = Aad {
            version: 0,
            key_id: KeyId::from_ref(&[0x69; 32]),
            aad: &[],
        }
        .serialize();
        let expected = hex::decode(concat!(
            "00",
            "6969696969696969696969696969696969696969696969696969696969696969",
            "00",
        ))
        .unwrap();
        assert_eq!(actual, expected);

        // Two segments: sequence length, then each segment length-prefixed.
        let actual = Aad {
            version: 0,
            key_id: KeyId::from_ref(&[0x42; 32]),
            aad: &[b"aaaaaaaa".as_slice(), b"0123456789".as_slice()],
        }
        .serialize();
        let expected = hex::decode(concat!(
            "00",
            "4242424242424242424242424242424242424242424242424242424242424242",
            "02",
            "08",
            "6161616161616161",
            "0a",
            "30313233343536373839",
        ))
        .unwrap();
        assert_eq!(actual, expected);
    }

    /// Decrypts fixed ciphertexts to catch accidental format changes.
    #[test]
    fn test_decrypt_compat() {
        let mut rng = FastRng::from_u64(123);
        let vfs_key = derive_key(&mut rng);

        // Empty plaintext, no AAD: version || key_id || (empty) || tag.
        let ciphertext = hex::decode(concat!(
            "00",
            "b0abd2beab31c1d925c5d8059cf90068eece2c41a3a6e4454d84e36ad6858a01",
            "",
            "0e2d1f6d16e9bb5738de28b4f180f07f",
        ))
        .unwrap();
        let recovered = vfs_key.decrypt(&[], ciphertext).unwrap();
        assert_eq!(recovered.as_slice(), b"");

        // Non-empty plaintext with one AAD segment.
        let aad = b"my context".as_slice();
        let plaintext = b"my cool message".as_slice();

        let ciphertext = hex::decode(concat!(
            "00",
            "c87fea5c4db8c16d3dae5a6ead5ee5985fa7c38721b9624e37772adea6a48aae",
            "22f52c6f08440092338d16e3402eaf",
            "c3972d357e56dad4cc42c6a80da4ac35",
        ))
        .unwrap();
        let recovered = vfs_key.decrypt(&[aad], ciphertext).unwrap();

        assert_eq!(recovered.as_slice(), plaintext);
    }

    /// Encrypt-then-decrypt returns the input, and two encryptions of the
    /// same input differ.
    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        proptest!(|(
            mut rng in any::<FastRng>(),
            aad in vec(vec(any::<u8>(), 0..=16), 0..=4),
            plaintext in vec(any::<u8>(), 0..=256),
        )| {
            let vfs_key = derive_key(&mut rng);

            let aad_ref: Vec<&[u8]> =
                aad.iter().map(Vec::as_slice).collect();

            let write_plaintext = |out: &mut Vec<u8>| {
                out.extend_from_slice(&plaintext);
            };

            let encrypted = vfs_key.encrypt(
                &mut rng,
                &aad_ref,
                Some(plaintext.len()),
                &write_plaintext,
            );

            let decrypted =
                vfs_key.decrypt(&aad_ref, encrypted.clone()).unwrap();
            prop_assert_eq!(&plaintext, &decrypted);

            // Each encryption samples a new key id, so outputs must differ.
            let encrypted2 = vfs_key.encrypt(
                &mut rng,
                &aad_ref,
                None,
                &write_plaintext,
            );

            prop_assert!(encrypted != encrypted2);
        });
    }
}