Skip to main content

lexe_crypto/
aes.rs

//! Securely encrypt and decrypt blobs, usually for remote storage.
2//!
3//! ## Design Considerations
4//!
//! * AES-256-GCM uses 12-byte (96-bit) nonces.
6//! * For a given key, any nonce reuse is catastrophic w/ AES-GCM.
//! * Synthetic-nonce / nonce-misuse-resistant schemes like AES-GCM-SIV aren't
//!   available in [`ring`] or have undesirable properties (multiple passes, max
//!   2^32 encryptions).
10//! * [`ring`] doesn't support XChaCha20-Poly1305, which would let us use a
11//!   larger nonce.
12//! * We need to use [`ring`] b/c TLS. We don't want to depend on other crypto
13//!   libraries b/c attack surface and binary bloat.
14//! * For our particular use case, we don't particularly care about
15//!   single-message (key, nonce) wear-out, since our messages aren't
16//!   particularly large (at most a few MiB).
17//! * We also don't have access to a consistent monotonic counter b/c
18//!   distributed system and adversarial host, so we use random nonces.
19//! * If we had a reliable counter, we could safely encrypt ~2^64 messages
20//!   before key wear-out, which would be sufficient for us to simplify and just
21//!   use one key for all encryptions.
22//! * For a given key and perfectly random nonces, with nonce collision
23//!   probability = 2^-32 (standard NIST bound), we can expect key wear-out
24//!   after 2^32 encryptions.
25//!
26//! ## Design
27//!
//! This scheme is inspired by "Derive Key Mode" described in
//! [(2017) Gueron-Lindell](https://eprint.iacr.org/2017/702.pdf).
30//! "Derive Key Mode" uses a long-term "master key" (see `AesMasterKey`), which
31//! isn't used to encrypt data; rather, it's used to derive per-message keys
32//! from a large random key-id, sampled per message (see `KeyId`).
33//!
//! In our case, we use a 32-byte (256-bit) key id to derive each per-message
35//! `EncryptKey`/`DecryptKey`, which gives us plenty of breathing room as far as
36//! safety bounds are concerned.
37//!
38//! For the AAD, taking a single `&[u8]` would require the caller to allocate
39//! and canonically serialize (length-prefixes, etc...) when there are multiple
40//! things to bind. Then, we would need to copy+allocate again in order to bind
41//! the `version`, `key-id`, and user AAD. To avoid this the user passes the AAD
42//! as a list of segments (like fields of a struct). For more info, see `Aad`.
43//!
44//! We use an AES-256-GCM nonce of all zeroes, since keys are single-use and
45//! 256 bits of security are Good Enough^tm.
46//!
47//! The scheme in simplified pseudo-code, encryption only:
48//!
49//! ```text
50//! master-key := (secret derived from user's root seed)
51//!
52//! Aad(version, key-id, user-aad: &[&[u8]]) :=
53//! 1. return bcs::to_bytes({ version, key-id, user-aad })
54//!
55//! Encrypt(master-key, user-aad: &[&[u8]], plaintext) :=
56//! 1. version := 0_u8
57//! 2. key-id := random 32-byte value
58//! 3. aad := Aad(version, key-id, user-aad)
59//! 4. encrypt-key := HKDF-Extract-Expand(
60//!         ikm=master-key,
61//!         salt=array::pad::<32>("LEXE-REALM::AesMasterKey"),
62//!         info=key-id,
63//!         out-len=32 bytes,
64//!    )
65//! 5. (ciphertext, tag) := AES-256-GCM(encrypt-key, nonce=[0; 12], aad, plaintext)
66//! 6. output := version || key-id || ciphertext || tag
67//! 7. return output
68//! ```
69//!
70//! ## References
71//!
//! * [(2017) Gueron-Lindell](https://eprint.iacr.org/2017/702.pdf) ([video](https://www.youtube.com/watch?v=WEJ451rmhk4))
73//!
74//! This paper, "Better Bounds for Block Cipher Modes of Operation via
75//! Nonce-Based Key Derivation", shows how "Derive Key Mode" significantly
76//! improves the security bounds over the standard long-lived key approach.
77//!
78//! * [(2020) Cryptographic Wear-out for Symmetric Encryption](https://soatok.blog/2020/12/24/cryptographic-wear-out-for-symmetric-encryption/)
79//!
80//! This article describes symmetric security bounds nicely.
81
82use std::{borrow::Cow, fmt};
83
84use lexe_std::array;
85use ref_cast::RefCast;
86use ring::{
87    aead::{self, BoundKey},
88    hkdf,
89};
90use serde_core::ser::{Serialize, SerializeStruct, Serializer};
91
92use crate::rng::{Crng, RngExt};
93
/// Number of bytes used to encode the format version at the start of every
/// ciphertext (a single `u8`).
const VERSION_LEN: usize = std::mem::size_of::<u8>();

/// Number of bytes used to encode a serialized [`KeyId`] (256 bits).
const KEY_ID_LEN: usize = 256 / 8;

/// Number of bytes in an AES-256-GCM authentication tag (128 bits).
const TAG_LEN: usize = 128 / 8;
102
103/// The length of the final encrypted ciphertext + version byte + key_id + tag
104/// given an input plaintext length.
105pub const fn encrypted_len(plaintext_len: usize) -> usize {
106    VERSION_LEN + KEY_ID_LEN + plaintext_len + TAG_LEN
107}
108
109/// The `AesMasterKey` is used to derive unique single-use encrypt keys for
110/// encrypting or decrypting a blob.
111///
112/// `RootSeed` -- derive("vfs master key") --> `AesMasterKey`
113// We store the salted+extracted PRK directly to avoid recomputing it every
114// time we encrypt something.
115pub struct AesMasterKey(hkdf::Prk);
116
117/// `KeyId` is the value used to derive the single-use message
118/// encryption/decryption key from the [`AesMasterKey`] HKDF.
119///
120/// As explained in the module docs, AES-GCM nonces are too small (12-bytes), so
121/// we use what is effectively a synthetic nonce scheme by deriving single-use
122/// keys from a larger pool of entropy (2^32 bits) for each separate encryption.
123#[derive(RefCast)]
124#[repr(transparent)]
125struct KeyId([u8; 32]);
126
127/// `Aad` is canonically serialized and then passed to AES-256-GCM as the `aad`
128/// (additional authenticated data) parameter.
129///
130/// It serves to:
131///
132/// 1. bind the protocol version
133/// 2. bind the encryption key (via the key id)
134/// 3. bind the user-provided additional authenticated data segments, including
135///    the number of segments, and the lengths of each segment.
136struct Aad<'data, 'aad> {
137    version: u8,
138    key_id: &'data KeyId,
139    aad: &'aad [&'aad [u8]],
140}
141
142struct EncryptKey(aead::SealingKey<ZeroNonce>);
143
144struct DecryptKey(aead::OpeningKey<ZeroNonce>);
145
146/// A single-use, all-zero nonce that panics if used to encrypt or decrypt data
147/// more than once (for a particular instance).
148struct ZeroNonce(Option<aead::Nonce>);
149
/// Error returned by [`AesMasterKey::decrypt`] when a blob cannot be parsed or
/// authenticated. The payload is a human-readable reason.
#[derive(Clone, Debug)]
pub struct DecryptError(Cow<'static, str>);

impl std::error::Error for DecryptError {}

impl fmt::Display for DecryptError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Decrypt error: ")?;
        f.write_str(&self.0)
    }
}
160
161impl fmt::Debug for AesMasterKey {
162    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
163        f.write_str("AesMasterKey(..)")
164    }
165}
166
167impl AesMasterKey {
168    const HKDF_SALT: [u8; 32] = array::pad(*b"LEXE-REALM::AesMasterKey");
169
170    pub fn new(root_seed_derived_secret: &[u8; 32]) -> Self {
171        Self(
172            hkdf::Salt::new(hkdf::HKDF_SHA256, &Self::HKDF_SALT)
173                .extract(root_seed_derived_secret),
174        )
175    }
176
177    fn derive_unbound_key(&self, key_id: &KeyId) -> aead::UnboundKey {
178        aead::UnboundKey::from(
179            self.0
180                .expand(&[key_id.as_slice()], &aead::AES_256_GCM)
181                .expect("This should never fail"),
182        )
183    }
184
185    fn derive_encrypt_key(&self, key_id: &KeyId) -> EncryptKey {
186        let nonce = ZeroNonce::new();
187        let key = aead::SealingKey::new(self.derive_unbound_key(key_id), nonce);
188        EncryptKey(key)
189    }
190
191    fn derive_decrypt_key(&self, key_id: &KeyId) -> DecryptKey {
192        let nonce = ZeroNonce::new();
193        let key = aead::OpeningKey::new(self.derive_unbound_key(key_id), nonce);
194        DecryptKey(key)
195    }
196
197    pub fn encrypt<R: Crng>(
198        &self,
199        rng: &mut R,
200        aad: &[&[u8]],
201        // A size hint so we can possibly avoid reallocing. If you don't know
202        // how long the plaintext will be, just set this to None.
203        data_size_hint: Option<usize>,
204        // This closure should write the object to the provided &mut Vec<u8>.
205        // See tests as well as node / lsp `encrypt_*` for examples.
206        write_data_cb: &dyn Fn(&mut Vec<u8>),
207    ) -> Vec<u8> {
208        let version = 0;
209        let key_id = KeyId::from_rng(rng);
210
211        let aad = Aad {
212            version,
213            key_id: &key_id,
214            aad,
215        }
216        .serialize();
217
218        // reserve enough capacity for at least version, key_id, and tag
219        let approx_encrypted_len = encrypted_len(data_size_hint.unwrap_or(0));
220        let mut data = Vec::with_capacity(approx_encrypted_len);
221
222        // data := ""
223
224        data.push(version);
225        data.extend_from_slice(key_id.as_slice());
226        let plaintext_offset = data.len();
227
228        // data := [version] || [key_id]
229
230        write_data_cb(&mut data);
231
232        // data := [version] || [key_id] || [plaintext]
233
234        self.derive_encrypt_key(&key_id).encrypt_in_place(
235            aad.as_slice(),
236            &mut data,
237            plaintext_offset,
238        );
239
240        // data := [version] || [key_id] || [ciphertext] || [tag]
241
242        data
243    }
244
245    pub fn decrypt(
246        &self,
247        aad: &[&[u8]],
248        mut data: Vec<u8>,
249    ) -> Result<Vec<u8>, DecryptError> {
250        // data := [version] || [key_id] || [ciphertext] || [tag]
251
252        const MIN_DATA_LEN: usize = encrypted_len(0 /* plaintext len */);
253        if data.len() < MIN_DATA_LEN {
254            return Err(DecryptError(Cow::Borrowed(
255                "ciphertext too short to contain version, key_id, and tag",
256            )));
257        }
258
259        // parse out version and key_id w/o advancing `data`
260        let (version, key_id) = {
261            let (version, data) = data
262                .split_first_chunk::<VERSION_LEN>()
263                .expect("data.len() checked above");
264            let (key_id, _) = data
265                .split_first_chunk::<KEY_ID_LEN>()
266                .expect("data.len() checked above");
267            (version[0], key_id)
268        };
269
270        if version != 0 {
271            return Err(DecryptError(Cow::Owned(format!(
272                "unsupported version: {version}"
273            ))));
274        }
275        let key_id = KeyId::from_ref(key_id);
276        let decrypt_key = self.derive_decrypt_key(key_id);
277
278        let aad = Aad {
279            version,
280            key_id,
281            aad,
282        }
283        .serialize();
284
285        let ciphertext_and_tag_offset = VERSION_LEN + KEY_ID_LEN;
286        decrypt_key.decrypt_in_place(
287            &aad,
288            &mut data,
289            ciphertext_and_tag_offset,
290        )?;
291
292        // data := [plaintext]
293
294        Ok(data)
295    }
296}
297
298impl EncryptKey {
299    // aad := additional authenticated data (e.g. protocol transcripts)
300    // data := [version] || [key_id] || [plaintext]
301    // plaintext_offset := starting index of `[plaintext]` in `data`
302    fn encrypt_in_place(
303        mut self,
304        aad: &[u8],
305        data: &mut Vec<u8>,
306        plaintext_offset: usize,
307    ) {
308        assert!(plaintext_offset <= data.len());
309
310        let aad = aead::Aad::from(aad);
311        let tag = self
312            .0
313            .seal_in_place_separate_tag(aad, &mut data[plaintext_offset..])
314            .expect(
315                "Cannot encrypt more than ~4 GiB at once (should never happen)",
316            );
317        data.extend_from_slice(tag.as_ref());
318    }
319}
320
321impl DecryptKey {
322    // aad := additional authenticated data (e.g. protocol transcripts)
323    // data := [version] || [key_id] || [ciphertext] || [tag]
324    // ciphertext_and_tag_offset := starting index of `[ciphertext] || [tag]`
325    fn decrypt_in_place(
326        mut self,
327        aad: &[u8],
328        data: &mut Vec<u8>,
329        ciphertext_and_tag_offset: usize,
330    ) -> Result<(), DecryptError> {
331        // `open_within` will shift the decrypted plaintext to the start of
332        // `data`.
333        let aad = aead::Aad::from(aad);
334
335        let plaintext_ref = self
336            .0
337            .open_within(aad, data, ciphertext_and_tag_offset..)
338            .map_err(|_| "AEAD open failed: ciphertext, tag, or AAD corrupted")
339            .map_err(|msg| DecryptError(Cow::Borrowed(msg)))?;
340        let plaintext_len = plaintext_ref.len();
341
342        // decrypting happens in-place. set the length of the now decrypted
343        // plaintext blob.
344        data.truncate(plaintext_len);
345
346        Ok(())
347    }
348}
349
350impl KeyId {
351    #[inline]
352    const fn from_ref(arr: &[u8; 32]) -> &Self {
353        lexe_std::const_utils::const_ref_cast(arr)
354    }
355
356    #[inline]
357    fn as_slice(&self) -> &[u8] {
358        self.0.as_slice()
359    }
360
361    fn from_rng<R: Crng>(rng: &mut R) -> Self {
362        Self(rng.gen_bytes())
363    }
364}
365
366impl Serialize for KeyId {
367    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
368    where
369        S: Serializer,
370    {
371        self.0.serialize(serializer)
372    }
373}
374
375impl Aad<'_, '_> {
376    fn serialize(&self) -> Vec<u8> {
377        let len = bcs::serialized_size(self)
378            .expect("Serializing the AAD should never fail");
379
380        let mut out = Vec::with_capacity(len);
381        bcs::serialize_into(&mut out, self)
382            .expect("Serializing the AAD should never fail");
383        out
384    }
385}
386
387impl Serialize for Aad<'_, '_> {
388    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
389    where
390        S: Serializer,
391    {
392        let mut fields = serializer.serialize_struct("Aad", 3)?;
393        fields.serialize_field("version", &self.version)?;
394        fields.serialize_field("key_id", self.key_id)?;
395        fields.serialize_field("aad", self.aad)?;
396        fields.end()
397    }
398}
399
400impl ZeroNonce {
401    fn new() -> Self {
402        Self(Some(aead::Nonce::assume_unique_for_key([0u8; 12])))
403    }
404}
405
406impl aead::NonceSequence for ZeroNonce {
407    fn advance(&mut self) -> Result<aead::Nonce, ring::error::Unspecified> {
408        Ok(self.0.take().expect(
409            "We somehow encrypted / decrypted more than once with the same key",
410        ))
411    }
412}
413
// See `lexe_common::root_seed::RootSeed::derive_vfs_master_key`
/// Test helper that derives an [`AesMasterKey`] from a freshly sampled random
/// seed. HKDF-SHA256 with the `"LEXE-REALM::RootSeed"` salt and
/// `"vfs master key"` info yields the 32-byte secret passed to
/// [`AesMasterKey::new`].
#[cfg(any(test, feature = "test-utils"))]
pub(crate) fn derive_key(rng: &mut crate::rng::FastRng) -> AesMasterKey {
    /// Requests exactly 32 bytes of HKDF output.
    struct Len32;
    impl hkdf::KeyType for Len32 {
        fn len(&self) -> usize {
            32
        }
    }

    const HKDF_SALT: [u8; 32] = array::pad(*b"LEXE-REALM::RootSeed");

    let seed: [u8; 32] = rng.gen_bytes();
    let prk = hkdf::Salt::new(hkdf::HKDF_SHA256, &HKDF_SALT).extract(&seed);

    let info: [&[u8]; 1] = [b"vfs master key"];
    let okm = prk.expand(&info, Len32).unwrap();

    let mut secret = [0u8; 32];
    okm.fill(&mut secret).unwrap();
    AesMasterKey::new(&secret)
}
435
#[cfg(any(test, feature = "test-utils"))]
mod arbitrary_impl {
    use proptest::{
        arbitrary::{Arbitrary, any},
        strategy::{BoxedStrategy, Strategy},
    };

    use super::*;
    use crate::rng::FastRng;

    /// Generates master keys by running `derive_key` on an arbitrary
    /// `FastRng`.
    impl Arbitrary for AesMasterKey {
        type Parameters = ();
        type Strategy = BoxedStrategy<Self>;

        fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy {
            let key_from_rng = |mut rng: FastRng| derive_key(&mut rng);
            any::<FastRng>().prop_map(key_from_rng).boxed()
        }
    }
}
456
#[cfg(test)]
mod test {
    use lexe_hex::hex;
    use proptest::{
        arbitrary::any, collection::vec, prop_assert, prop_assert_eq, proptest,
    };

    use super::*;
    use crate::rng::FastRng;

    /// Pins the canonical BCS encoding of `Aad`, so the AAD bytes bound into
    /// ciphertexts can't change silently.
    #[test]
    fn test_aad_compat() {
        let aad = Aad {
            version: 0,
            key_id: KeyId::from_ref(&[0x69; 32]),
            aad: &[],
        }
        .serialize();

        let expected_aad = hex::decode(
            "00\
             6969696969696969696969696969696969696969696969696969696969696969\
             00",
        )
        .unwrap();

        assert_eq!(&aad, &expected_aad);

        let aad = Aad {
            version: 0,
            key_id: KeyId::from_ref(&[0x42; 32]),
            aad: &[b"aaaaaaaa".as_slice(), b"0123456789".as_slice()],
        }
        .serialize();

        let expected_aad = hex::decode(
            "00\
             4242424242424242424242424242424242424242424242424242424242424242\
             02\
                08\
                    6161616161616161\
                0a\
                    30313233343536373839",
        )
        .unwrap();
        assert_eq!(&aad, &expected_aad);
    }

    /// Ciphertexts produced by earlier builds must keep decrypting.
    #[test]
    fn test_decrypt_compat() {
        let mut rng = FastRng::from_u64(123);
        let vfs_key = derive_key(&mut rng);

        // aad = [], plaintext = ""

        // uncomment to regen
        // let encrypted = vfs_key.encrypt(&mut rng, &[], None, &|_| ());
        // println!("encrypted: {}", hex::display(&encrypted));

        let encrypted = hex::decode(
            // [version] || [key_id] || [ciphertext] || [tag]
            "00\
             b0abd2beab31c1d925c5d8059cf90068eece2c41a3a6e4454d84e36ad6858a01\
             \
             0e2d1f6d16e9bb5738de28b4f180f07f",
        )
        .unwrap();

        let decrypted = vfs_key.decrypt(&[], encrypted).unwrap();
        assert_eq!(decrypted.as_slice(), b"");

        // aad = ["my context"], plaintext = "my cool message"

        let aad = b"my context".as_slice();
        let plaintext = b"my cool message".as_slice();

        // uncomment to regen
        // #[rustfmt::skip]
        // let encrypted = vfs_key.encrypt(&mut rng, &[aad], None, &|out| {
        //     out.extend_from_slice(plaintext)
        // });
        // println!("encrypted: {}", hex::display(&encrypted));

        let encrypted = hex::decode(
            // [version] || [key_id] || [ciphertext] || [tag]
            "00\
             c87fea5c4db8c16d3dae5a6ead5ee5985fa7c38721b9624e37772adea6a48aae\
             22f52c6f08440092338d16e3402eaf\
             c3972d357e56dad4cc42c6a80da4ac35",
        )
        .unwrap();

        let decrypted = vfs_key.decrypt(&[aad], encrypted).unwrap();

        assert_eq!(decrypted.as_slice(), plaintext);
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        proptest!(|(
            mut rng in any::<FastRng>(),
            aad in vec(vec(any::<u8>(), 0..=16), 0..=4),
            plaintext in vec(any::<u8>(), 0..=256),
        )| {
            let vfs_key = derive_key(&mut rng);

            let aad_ref = aad
                .iter()
                .map(|x| x.as_slice())
                .collect::<Vec<_>>();

            let encrypted = vfs_key.encrypt(&mut rng, &aad_ref, Some(plaintext.len()), &|out: &mut Vec<u8>| {
                out.extend_from_slice(&plaintext);
            });

            let decrypted = vfs_key.decrypt(&aad_ref, encrypted.clone()).unwrap();
            prop_assert_eq!(&plaintext, &decrypted);

            let encrypted2 = vfs_key.encrypt(&mut rng, &aad_ref, None, &|out: &mut Vec<u8>| {
                out.extend_from_slice(&plaintext);
            });

            prop_assert!(encrypted != encrypted2);
        });
    }

    /// Every malformed, tampered, or mis-bound ciphertext must be rejected.
    #[test]
    fn test_decrypt_rejects_invalid() {
        let mut rng = FastRng::from_u64(456);
        let vfs_key = derive_key(&mut rng);
        let aad = b"my context".as_slice();
        let plaintext = b"my cool message".as_slice();

        let encrypted =
            vfs_key.encrypt(&mut rng, &[aad], None, &|out: &mut Vec<u8>| {
                out.extend_from_slice(plaintext);
            });
        assert_eq!(encrypted.len(), encrypted_len(plaintext.len()));

        // Sanity check: the untouched ciphertext decrypts.
        let decrypted = vfs_key.decrypt(&[aad], encrypted.clone()).unwrap();
        assert_eq!(decrypted.as_slice(), plaintext);

        // Too short to hold [version] || [key_id] || [tag].
        for len in 0..encrypted_len(0) {
            let short = encrypted[..len].to_vec();
            assert!(vfs_key.decrypt(&[aad], short).is_err());
        }

        // Unsupported version byte.
        let mut bad_version = encrypted.clone();
        bad_version[0] = 1;
        assert!(vfs_key.decrypt(&[aad], bad_version).is_err());

        // Flipping a bit anywhere (header, ciphertext, or tag) must fail.
        for idx in 0..encrypted.len() {
            let mut tampered = encrypted.clone();
            tampered[idx] ^= 0x01;
            assert!(vfs_key.decrypt(&[aad], tampered).is_err());
        }

        // Dropping or appending bytes must fail.
        let truncated = encrypted[..encrypted.len() - 1].to_vec();
        assert!(vfs_key.decrypt(&[aad], truncated).is_err());
        let mut extended = encrypted.clone();
        extended.push(0);
        assert!(vfs_key.decrypt(&[aad], extended).is_err());

        // The AAD must match exactly, including how it is segmented.
        let wrong_aads: [&[&[u8]]; 4] = [
            &[],
            &[b"other context".as_slice()],
            &[b"my ".as_slice(), b"context".as_slice()],
            &[aad, b"".as_slice()],
        ];
        for wrong_aad in wrong_aads {
            assert!(vfs_key.decrypt(wrong_aad, encrypted.clone()).is_err());
        }

        // A different master key cannot open the blob.
        let other_key = derive_key(&mut rng);
        assert!(other_key.decrypt(&[aad], encrypted).is_err());
    }
}