Skip to main content

lexe_crypto/
aes.rs

//! Securely encrypt and decrypt blobs, usually for remote storage.
2//!
3//! ## Design Considerations
4//!
//! * AES-256-GCM uses 12-byte (96-bit) nonces, i.e. 2^96 possible values.
6//! * For a given key, any nonce reuse is catastrophic w/ AES-GCM.
//! * Synthetic nonce / nonce-misuse-resistant schemes like AES-GCM-SIV aren't
//!   available in [`ring`] or have undesirable properties (multiple passes, max
//!   2^32 encryptions).
10//! * [`ring`] doesn't support XChaCha20-Poly1305, which would let us use a
11//!   larger nonce.
12//! * We need to use [`ring`] b/c TLS. We don't want to depend on other crypto
13//!   libraries b/c attack surface and binary bloat.
14//! * For our particular use case, we don't particularly care about
15//!   single-message (key, nonce) wear-out, since our messages aren't
16//!   particularly large (at most a few MiB).
17//! * We also don't have access to a consistent monotonic counter b/c
18//!   distributed system and adversarial host, so we use random nonces.
19//! * If we had a reliable counter, we could safely encrypt ~2^64 messages
20//!   before key wear-out, which would be sufficient for us to simplify and just
21//!   use one key for all encryptions.
22//! * For a given key and perfectly random nonces, with nonce collision
23//!   probability = 2^-32 (standard NIST bound), we can expect key wear-out
24//!   after 2^32 encryptions.
25//!
26//! ## Design
27//!
28//! This scheme is inspired by "Derive Key Mode" described in
//! [(2017) Gueron-Lindell](https://eprint.iacr.org/2017/702.pdf).
30//! "Derive Key Mode" uses a long-term "master key" (see `AesMasterKey`), which
31//! isn't used to encrypt data; rather, it's used to derive per-message keys
32//! from a large random key-id, sampled per message (see `KeyId`).
33//!
//! In our case, we use a 32-byte (256-bit) key id to derive each per-message
35//! `EncryptKey`/`DecryptKey`, which gives us plenty of breathing room as far as
36//! safety bounds are concerned.
37//!
38//! For the AAD, taking a single `&[u8]` would require the caller to allocate
39//! and canonically serialize (length-prefixes, etc...) when there are multiple
40//! things to bind. Then, we would need to copy+allocate again in order to bind
41//! the `version`, `key-id`, and user AAD. To avoid this the user passes the AAD
42//! as a list of segments (like fields of a struct). For more info, see `Aad`.
43//!
44//! We use an AES-256-GCM nonce of all zeroes, since keys are single-use and
45//! 256 bits of security are Good Enough^tm.
46//!
47//! The scheme in simplified pseudo-code, encryption only:
48//!
49//! ```text
50//! master-key := (secret derived from user's root seed)
51//!
52//! Aad(version, key-id, user-aad: &[&[u8]]) :=
53//! 1. return bcs::to_bytes({ version, key-id, user-aad })
54//!
55//! Encrypt(master-key, user-aad: &[&[u8]], plaintext) :=
56//! 1. version := 0_u8
57//! 2. key-id := random 32-byte value
58//! 3. aad := Aad(version, key-id, user-aad)
59//! 4. encrypt-key := HKDF-Extract-Expand(
60//!         ikm=master-key,
61//!         salt=array::pad::<32>("LEXE-REALM::AesMasterKey"),
62//!         info=key-id,
63//!         out-len=32 bytes,
64//!    )
65//! 5. (ciphertext, tag) := AES-256-GCM(encrypt-key, nonce=[0; 12], aad, plaintext)
66//! 6. output := version || key-id || ciphertext || tag
67//! 7. return output
68//! ```
69//!
70//! ## References
71//!
//! * [(2017) Gueron-Lindell](https://eprint.iacr.org/2017/702.pdf) ([video](https://www.youtube.com/watch?v=WEJ451rmhk4))
73//!
74//! This paper, "Better Bounds for Block Cipher Modes of Operation via
75//! Nonce-Based Key Derivation", shows how "Derive Key Mode" significantly
76//! improves the security bounds over the standard long-lived key approach.
77//!
78//! * [(2020) Cryptographic Wear-out for Symmetric Encryption](https://soatok.blog/2020/12/24/cryptographic-wear-out-for-symmetric-encryption/)
79//!
80//! This article describes symmetric security bounds nicely.
81
82use std::fmt;
83
84use lexe_std::array;
85use ref_cast::RefCast;
86use ring::{
87    aead::{self, BoundKey},
88    hkdf,
89};
90use serde_core::ser::{Serialize, SerializeStruct, Serializer};
91
92use crate::rng::{Crng, RngExt};
93
/// Length in bytes of the serialized version prefix.
const VERSION_LEN: usize = 1;

/// Length in bytes of a serialized [`KeyId`].
const KEY_ID_LEN: usize = 32;

/// Length in bytes of an AES-256-GCM authentication tag.
const TAG_LEN: usize = 16;

/// Fixed per-message overhead: version byte, key id, and tag.
const OVERHEAD_LEN: usize = VERSION_LEN + KEY_ID_LEN + TAG_LEN;

/// Returns the total size of an encrypted blob,
/// `[version] || [key_id] || [ciphertext] || [tag]`, for a plaintext of
/// `plaintext_len` bytes.
///
/// The plaintext is sealed in place, so the ciphertext is exactly as long as
/// the plaintext; everything else is fixed overhead.
pub const fn encrypted_len(plaintext_len: usize) -> usize {
    plaintext_len + OVERHEAD_LEN
}
108
/// The `AesMasterKey` is used to derive unique single-use encrypt keys for
/// encrypting or decrypting a blob.
///
/// `RootSeed` -- derive("vfs master key") --> `AesMasterKey`
///
/// The master key never encrypts data itself. Each message gets its own
/// AES-256-GCM key, derived with HKDF-Expand from this key and a per-message
/// random key id.
// We store the salted+extracted PRK directly to avoid recomputing it every
// time we encrypt something.
pub struct AesMasterKey(hkdf::Prk);
116
117/// `KeyId` is the value used to derive the single-use message
118/// encryption/decryption key from the [`AesMasterKey`] HKDF.
119///
120/// As explained in the module docs, AES-GCM nonces are too small (12-bytes), so
121/// we use what is effectively a synthetic nonce scheme by deriving single-use
122/// keys from a larger pool of entropy (2^32 bits) for each separate encryption.
123#[derive(RefCast)]
124#[repr(transparent)]
125struct KeyId([u8; 32]);
126
/// `Aad` is canonically serialized and then passed to AES-256-GCM as the `aad`
/// (additional authenticated data) parameter.
///
/// It serves to:
///
/// 1. bind the protocol version
/// 2. bind the encryption key (via the key id)
/// 3. bind the user-provided additional authenticated data segments, including
///    the number of segments, and the lengths of each segment.
///
/// The byte encoding comes from the hand-written `Serialize` impl run through
/// BCS. Both `encrypt` and `decrypt` rebuild it independently, so it must stay
/// stable for existing ciphertexts to remain decryptable.
struct Aad<'data, 'aad> {
    // Ciphertext format version (also written as the first output byte).
    version: u8,
    // Random per-message key id that the encryption key is derived from.
    key_id: &'data KeyId,
    // Caller-provided AAD segments, bound in order.
    aad: &'aad [&'aad [u8]],
}
141
/// A single-use AES-256-GCM sealing key derived from an [`AesMasterKey`] and
/// a [`KeyId`]; its nonce sequence is a [`ZeroNonce`].
struct EncryptKey(aead::SealingKey<ZeroNonce>);
143
/// A single-use AES-256-GCM opening key derived from an [`AesMasterKey`] and
/// a [`KeyId`]; its nonce sequence is a [`ZeroNonce`].
struct DecryptKey(aead::OpeningKey<ZeroNonce>);
145
/// A single-use, all-zero nonce that panics if used to encrypt or decrypt data
/// more than once (for a particular instance).
///
/// A fixed all-zero nonce is acceptable only because every key it is paired
/// with is freshly derived from a random key id and used once.
// Holds `Some` until `advance` hands the nonce out, then `None`.
struct ZeroNonce(Option<aead::Nonce>);
149
/// Error returned by [`AesMasterKey::decrypt`].
///
/// This is a unit struct, so it carries no detail about which check failed
/// (length, version, or authentication).
#[derive(Clone, Debug)]
pub struct DecryptError;

impl std::error::Error for DecryptError {}

impl fmt::Display for DecryptError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "decrypt error: ciphertext or metadata may be corrupted")
    }
}
160
161impl fmt::Debug for AesMasterKey {
162    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
163        f.write_str("AesMasterKey(..)")
164    }
165}
166
167impl AesMasterKey {
168    const HKDF_SALT: [u8; 32] = array::pad(*b"LEXE-REALM::AesMasterKey");
169
170    pub fn new(root_seed_derived_secret: &[u8; 32]) -> Self {
171        Self(
172            hkdf::Salt::new(hkdf::HKDF_SHA256, &Self::HKDF_SALT)
173                .extract(root_seed_derived_secret),
174        )
175    }
176
177    fn derive_unbound_key(&self, key_id: &KeyId) -> aead::UnboundKey {
178        aead::UnboundKey::from(
179            self.0
180                .expand(&[key_id.as_slice()], &aead::AES_256_GCM)
181                .expect("This should never fail"),
182        )
183    }
184
185    fn derive_encrypt_key(&self, key_id: &KeyId) -> EncryptKey {
186        let nonce = ZeroNonce::new();
187        let key = aead::SealingKey::new(self.derive_unbound_key(key_id), nonce);
188        EncryptKey(key)
189    }
190
191    fn derive_decrypt_key(&self, key_id: &KeyId) -> DecryptKey {
192        let nonce = ZeroNonce::new();
193        let key = aead::OpeningKey::new(self.derive_unbound_key(key_id), nonce);
194        DecryptKey(key)
195    }
196
197    pub fn encrypt<R: Crng>(
198        &self,
199        rng: &mut R,
200        aad: &[&[u8]],
201        // A size hint so we can possibly avoid reallocing. If you don't know
202        // how long the plaintext will be, just set this to None.
203        data_size_hint: Option<usize>,
204        // This closure should write the object to the provided &mut Vec<u8>.
205        // See tests as well as node / lsp `encrypt_*` for examples.
206        write_data_cb: &dyn Fn(&mut Vec<u8>),
207    ) -> Vec<u8> {
208        let version = 0;
209        let key_id = KeyId::from_rng(rng);
210
211        let aad = Aad {
212            version,
213            key_id: &key_id,
214            aad,
215        }
216        .serialize();
217
218        // reserve enough capacity for at least version, key_id, and tag
219        let approx_encrypted_len = encrypted_len(data_size_hint.unwrap_or(0));
220        let mut data = Vec::with_capacity(approx_encrypted_len);
221
222        // data := ""
223
224        data.push(version);
225        data.extend_from_slice(key_id.as_slice());
226        let plaintext_offset = data.len();
227
228        // data := [version] || [key_id]
229
230        write_data_cb(&mut data);
231
232        // data := [version] || [key_id] || [plaintext]
233
234        self.derive_encrypt_key(&key_id).encrypt_in_place(
235            aad.as_slice(),
236            &mut data,
237            plaintext_offset,
238        );
239
240        // data := [version] || [key_id] || [ciphertext] || [tag]
241
242        data
243    }
244
245    pub fn decrypt(
246        &self,
247        aad: &[&[u8]],
248        mut data: Vec<u8>,
249    ) -> Result<Vec<u8>, DecryptError> {
250        // data := [version] || [key_id] || [ciphertext] || [tag]
251
252        const MIN_DATA_LEN: usize = encrypted_len(0 /* plaintext len */);
253        if data.len() < MIN_DATA_LEN {
254            return Err(DecryptError);
255        }
256
257        // parse out version and key_id w/o advancing `data`
258        let (version, key_id) = {
259            let (version, data) = data
260                .split_first_chunk::<VERSION_LEN>()
261                .expect("data.len() checked above");
262            let (key_id, _) = data
263                .split_first_chunk::<KEY_ID_LEN>()
264                .expect("data.len() checked above");
265            (version[0], key_id)
266        };
267
268        if version != 0 {
269            return Err(DecryptError);
270        }
271        let key_id = KeyId::from_ref(key_id);
272        let decrypt_key = self.derive_decrypt_key(key_id);
273
274        let aad = Aad {
275            version,
276            key_id,
277            aad,
278        }
279        .serialize();
280
281        let ciphertext_and_tag_offset = VERSION_LEN + KEY_ID_LEN;
282        decrypt_key.decrypt_in_place(
283            &aad,
284            &mut data,
285            ciphertext_and_tag_offset,
286        )?;
287
288        // data := [plaintext]
289
290        Ok(data)
291    }
292}
293
294impl EncryptKey {
295    // aad := additional authenticated data (e.g. protocol transcripts)
296    // data := [version] || [key_id] || [plaintext]
297    // plaintext_offset := starting index of `[plaintext]` in `data`
298    fn encrypt_in_place(
299        mut self,
300        aad: &[u8],
301        data: &mut Vec<u8>,
302        plaintext_offset: usize,
303    ) {
304        assert!(plaintext_offset <= data.len());
305
306        let aad = aead::Aad::from(aad);
307        let tag = self
308            .0
309            .seal_in_place_separate_tag(aad, &mut data[plaintext_offset..])
310            .expect(
311                "Cannot encrypt more than ~4 GiB at once (should never happen)",
312            );
313        data.extend_from_slice(tag.as_ref());
314    }
315}
316
317impl DecryptKey {
318    // aad := additional authenticated data (e.g. protocol transcripts)
319    // data := [version] || [key_id] || [ciphertext] || [tag]
320    // ciphertext_and_tag_offset := starting index of `[ciphertext] || [tag]`
321    fn decrypt_in_place(
322        mut self,
323        aad: &[u8],
324        data: &mut Vec<u8>,
325        ciphertext_and_tag_offset: usize,
326    ) -> Result<(), DecryptError> {
327        // `open_within` will shift the decrypted plaintext to the start of
328        // `data`.
329        let aad = aead::Aad::from(aad);
330
331        let plaintext_ref = self
332            .0
333            .open_within(aad, data, ciphertext_and_tag_offset..)
334            .map_err(|_| DecryptError)?;
335        let plaintext_len = plaintext_ref.len();
336
337        // decrypting happens in-place. set the length of the now decrypted
338        // plaintext blob.
339        data.truncate(plaintext_len);
340
341        Ok(())
342    }
343}
344
345impl KeyId {
346    #[inline]
347    const fn from_ref(arr: &[u8; 32]) -> &Self {
348        lexe_std::const_utils::const_ref_cast(arr)
349    }
350
351    #[inline]
352    fn as_slice(&self) -> &[u8] {
353        self.0.as_slice()
354    }
355
356    fn from_rng<R: Crng>(rng: &mut R) -> Self {
357        Self(rng.gen_bytes())
358    }
359}
360
361impl Serialize for KeyId {
362    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
363    where
364        S: Serializer,
365    {
366        self.0.serialize(serializer)
367    }
368}
369
370impl Aad<'_, '_> {
371    fn serialize(&self) -> Vec<u8> {
372        let len = bcs::serialized_size(self)
373            .expect("Serializing the AAD should never fail");
374
375        let mut out = Vec::with_capacity(len);
376        bcs::serialize_into(&mut out, self)
377            .expect("Serializing the AAD should never fail");
378        out
379    }
380}
381
/// Serializes as a 3-field struct in the order `version`, `key_id`, `aad`.
///
/// Under BCS this is the concatenation of the fields, with `aad` encoded as a
/// length-prefixed list of length-prefixed byte strings (pinned by
/// `test_aad_compat`). Reordering these calls changes the AAD bytes and
/// breaks decryption of existing ciphertexts.
impl Serialize for Aad<'_, '_> {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let mut fields = serializer.serialize_struct("Aad", 3)?;
        fields.serialize_field("version", &self.version)?;
        fields.serialize_field("key_id", self.key_id)?;
        fields.serialize_field("aad", self.aad)?;
        fields.end()
    }
}
394
395impl ZeroNonce {
396    fn new() -> Self {
397        Self(Some(aead::Nonce::assume_unique_for_key([0u8; 12])))
398    }
399}
400
401impl aead::NonceSequence for ZeroNonce {
402    fn advance(&mut self) -> Result<aead::Nonce, ring::error::Unspecified> {
403        Ok(self.0.take().expect(
404            "We somehow encrypted / decrypted more than once with the same key",
405        ))
406    }
407}
408
// See `lexe_common::root_seed::RootSeed::derive_vfs_master_key`
/// Test helper: derives an [`AesMasterKey`] from a random 32-byte root seed
/// drawn from `rng`.
///
/// Runs HKDF-SHA256 with the `LEXE-REALM::RootSeed` salt and the
/// `vfs master key` info, then passes the result to [`AesMasterKey::new`].
#[cfg(any(test, feature = "test-utils"))]
pub(crate) fn derive_key(rng: &mut crate::rng::FastRng) -> AesMasterKey {
    /// Requests exactly 32 bytes of HKDF output.
    struct Len32;

    impl hkdf::KeyType for Len32 {
        fn len(&self) -> usize {
            32
        }
    }

    const HKDF_SALT: [u8; 32] = array::pad(*b"LEXE-REALM::RootSeed");
    const HKDF_INFO: &[&[u8]] = &[b"vfs master key"];

    let root_seed: [u8; 32] = rng.gen_bytes();
    let salt = hkdf::Salt::new(hkdf::HKDF_SHA256, &HKDF_SALT);
    let prk = salt.extract(&root_seed);

    let mut master_secret = [0u8; 32];
    prk.expand(HKDF_INFO, Len32)
        .unwrap()
        .fill(&mut master_secret)
        .unwrap();
    AesMasterKey::new(&master_secret)
}
430
#[cfg(any(test, feature = "test-utils"))]
mod arbitrary_impl {
    use proptest::{
        arbitrary::{Arbitrary, any},
        strategy::{BoxedStrategy, Strategy},
    };

    use super::*;
    use crate::rng::FastRng;

    /// Generates master keys by running `derive_key` on an arbitrary
    /// `FastRng`.
    impl Arbitrary for AesMasterKey {
        type Parameters = ();
        type Strategy = BoxedStrategy<Self>;

        fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy {
            let rng_strategy = any::<FastRng>();
            rng_strategy
                .prop_map(|mut rng| derive_key(&mut rng))
                .boxed()
        }
    }
}
451
#[cfg(test)]
mod test {
    use lexe_hex::hex;
    use proptest::{
        arbitrary::any, collection::vec, prop_assert, prop_assert_eq, proptest,
    };

    use super::*;
    use crate::rng::FastRng;

    #[test]
    fn test_aad_compat() {
        let aad = Aad {
            version: 0,
            key_id: KeyId::from_ref(&[0x69; 32]),
            aad: &[],
        }
        .serialize();

        let expected_aad = hex::decode(
            "00\
             6969696969696969696969696969696969696969696969696969696969696969\
             00",
        )
        .unwrap();

        assert_eq!(&aad, &expected_aad);

        let aad = Aad {
            version: 0,
            key_id: KeyId::from_ref(&[0x42; 32]),
            aad: &[b"aaaaaaaa".as_slice(), b"0123456789".as_slice()],
        }
        .serialize();

        let expected_aad = hex::decode(
            "00\
             4242424242424242424242424242424242424242424242424242424242424242\
             02\
                08\
                    6161616161616161\
                0a\
                    30313233343536373839",
        )
        .unwrap();
        assert_eq!(&aad, &expected_aad);
    }

    #[test]
    fn test_decrypt_compat() {
        let mut rng = FastRng::from_u64(123);
        let vfs_key = derive_key(&mut rng);

        // aad = [], plaintext = ""

        // uncomment to regen
        // let encrypted = vfs_key.encrypt(&mut rng, &[], None, &|_| ());
        // println!("encrypted: {}", hex::display(&encrypted));

        let encrypted = hex::decode(
            // [version] || [key_id] || [ciphertext] || [tag]
            "00\
             b0abd2beab31c1d925c5d8059cf90068eece2c41a3a6e4454d84e36ad6858a01\
             \
             0e2d1f6d16e9bb5738de28b4f180f07f",
        )
        .unwrap();

        let decrypted = vfs_key.decrypt(&[], encrypted).unwrap();
        assert_eq!(decrypted.as_slice(), b"");

        // aad = ["my context"], plaintext = "my cool message"

        let aad = b"my context".as_slice();
        let plaintext = b"my cool message".as_slice();

        // uncomment to regen
        // let encrypted = vfs_key.encrypt(&mut rng, &[aad], None, &|out| {
        //     out.extend_from_slice(plaintext)
        // });
        // println!("encrypted: {}", hex::display(&encrypted));

        let encrypted = hex::decode(
            // [version] || [key_id] || [ciphertext] || [tag]
            "00\
             c87fea5c4db8c16d3dae5a6ead5ee5985fa7c38721b9624e37772adea6a48aae\
             22f52c6f08440092338d16e3402eaf\
             c3972d357e56dad4cc42c6a80da4ac35",
        )
        .unwrap();

        let decrypted = vfs_key.decrypt(&[aad], encrypted).unwrap();

        assert_eq!(decrypted.as_slice(), plaintext);
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        proptest!(|(
            mut rng in any::<FastRng>(),
            aad in vec(vec(any::<u8>(), 0..=16), 0..=4),
            plaintext in vec(any::<u8>(), 0..=256),
        )| {
            let vfs_key = derive_key(&mut rng);

            let aad_ref = aad
                .iter()
                .map(|x| x.as_slice())
                .collect::<Vec<_>>();

            let encrypted = vfs_key.encrypt(&mut rng, &aad_ref, Some(plaintext.len()), &|out: &mut Vec<u8>| {
                out.extend_from_slice(&plaintext);
            });

            let decrypted = vfs_key.decrypt(&aad_ref, encrypted.clone()).unwrap();
            prop_assert_eq!(&plaintext, &decrypted);

            let encrypted2 = vfs_key.encrypt(&mut rng, &aad_ref, None, &|out: &mut Vec<u8>| {
                out.extend_from_slice(&plaintext);
            });

            prop_assert!(encrypted != encrypted2);
        });
    }

    /// Any single-bit flip, a truncated blob, a mismatched AAD, or a different
    /// master key must all be rejected.
    #[test]
    fn test_decrypt_rejects_tampering() {
        proptest!(|(
            mut rng in any::<FastRng>(),
            aad in vec(vec(any::<u8>(), 0..=16), 0..=4),
            plaintext in vec(any::<u8>(), 0..=64),
            flip_idx in any::<usize>(),
            flip_bit in 0_u8..8,
        )| {
            let vfs_key = derive_key(&mut rng);

            let aad_ref = aad
                .iter()
                .map(|x| x.as_slice())
                .collect::<Vec<_>>();

            let encrypted = vfs_key.encrypt(&mut rng, &aad_ref, None, &|out: &mut Vec<u8>| {
                out.extend_from_slice(&plaintext);
            });

            // flip one bit anywhere (version, key id, ciphertext, or tag)
            let mut flipped = encrypted.clone();
            let idx = flip_idx % flipped.len();
            flipped[idx] ^= 1_u8 << flip_bit;
            prop_assert!(vfs_key.decrypt(&aad_ref, flipped).is_err());

            // drop the final tag byte
            let mut truncated = encrypted.clone();
            truncated.pop();
            prop_assert!(vfs_key.decrypt(&aad_ref, truncated).is_err());

            // bind one extra (empty) AAD segment
            let mut extra_aad = aad_ref.clone();
            extra_aad.push(b"".as_slice());
            prop_assert!(vfs_key.decrypt(&extra_aad, encrypted.clone()).is_err());

            // decrypt under an unrelated master key
            let other_key = derive_key(&mut rng);
            prop_assert!(other_key.decrypt(&aad_ref, encrypted).is_err());
        });
    }

    /// Inputs shorter than the fixed overhead are rejected without panicking.
    #[test]
    fn test_decrypt_rejects_short_input() {
        let mut rng = FastRng::from_u64(123);
        let vfs_key = derive_key(&mut rng);

        for len in [0, 1, VERSION_LEN + KEY_ID_LEN, encrypted_len(0) - 1] {
            assert!(vfs_key.decrypt(&[], vec![0u8; len]).is_err());
        }
    }

    /// A valid blob whose version byte is changed must be rejected.
    #[test]
    fn test_decrypt_rejects_unknown_version() {
        let mut rng = FastRng::from_u64(123);
        let vfs_key = derive_key(&mut rng);

        let mut encrypted =
            vfs_key.encrypt(&mut rng, &[], None, &|out: &mut Vec<u8>| {
                out.extend_from_slice(b"hello");
            });
        assert!(vfs_key.decrypt(&[], encrypted.clone()).is_ok());

        encrypted[0] = 1;
        assert!(vfs_key.decrypt(&[], encrypted).is_err());
    }
}