envelopers/
simple_key_provider.rs

//! A simple [`KeyProvider`] implementation that encrypts data keys under a fixed AES-128
//! key-encryption key (KEK).
2
3use std::marker::PhantomData;
4use std::sync::Mutex;
5
6use aes_gcm::aead::{Aead, Payload};
7use aes_gcm::aes::cipher::consts::U16;
8use aes_gcm::aes::Aes128;
9use aes_gcm::{Aes128Gcm, Aes256Gcm, AesGcm, Key, KeyInit, KeySizeUser};
10use aes_gcm_siv::{Aes128GcmSiv, Aes256GcmSiv};
11use async_trait::async_trait;
12use rand_chacha::ChaChaRng;
13
14use crate::errors::{KeyDecryptionError, KeyGenerationError};
15use crate::key_provider::{DataKey, KeyProvider};
16use crate::safe_rng::SafeRng;
17
18// EncryptedSimpleKey relies on this size being constant. If this ever needs to be changed a new
19// version of EncryptedSimpleKey needs to be created.
20type Nonce = aes_gcm::Nonce<U16>;
21const NONCE_SIZE: usize = 16;
22
23/// A decoded intermediate representation of an encrypted simple key
24///
25/// The encoded version of the encrypted simple key looks like so:
26///
27/// | Pos  | Data                   |
28/// | -----|------------------------|
29/// | 0    | Version tag (1 byte)   |
30/// | 1-17 | Nonce       (16 bytes) |
31/// | 17-  | Encrypted key          |
32#[derive(Debug)]
33struct EncryptedSimpleKey<'a> {
34    // Keep a version tag on the key just in case the format gets changed
35    // This could have other header information - but it should be safe since we're only expected
36    // keys created by the SimpleKeyProvider.
37    version: u8,
38    nonce: &'a Nonce,
39    key: &'a [u8],
40}
41
42impl<'a> EncryptedSimpleKey<'a> {
43    /// Decode an [`EncryptedSimpleKey`] from a slice following its encoded representation
44    fn from_slice(bytes: &'a [u8]) -> Result<Self, KeyDecryptionError> {
45        if bytes.len() < 1 + NONCE_SIZE {
46            return Err(KeyDecryptionError::Other(format!(
47                "Slice was too small to load an EncryptedSimpleKey. Received: {}",
48                bytes.len()
49            )));
50        }
51
52        let version = bytes[0];
53
54        let nonce: &'a Nonce = Nonce::from_slice(&bytes[1..1 + NONCE_SIZE]);
55        let key: &'a [u8] = &bytes[1 + NONCE_SIZE..];
56
57        Ok(Self {
58            version,
59            nonce,
60            key,
61        })
62    }
63
64    /// Encode an [`EncryptedSimpleKey`] as bytes
65    fn to_vec(&self) -> Vec<u8> {
66        let mut output = Vec::with_capacity(1 + self.nonce.len() + self.key.len());
67        output.push(self.version);
68        output.extend_from_slice(self.nonce);
69        output.extend_from_slice(self.key);
70        output
71    }
72}
73
74pub struct SimpleKeyProvider<S: KeySizeUser = Aes128Gcm, R: SafeRng = ChaChaRng> {
75    cipher: AesGcm<Aes128, U16>,
76    rng: Mutex<R>,
77    phantom_data: PhantomData<S>,
78}
79
80impl<S: KeySizeUser, R: SafeRng> SimpleKeyProvider<S, R> {
81    pub fn init(kek: [u8; 16]) -> Self {
82        let key: &Key<Aes128> = Key::<Aes128>::from_slice(&kek);
83
84        Self {
85            cipher: AesGcm::<Aes128, U16>::new(key),
86            rng: Mutex::new(R::from_entropy()),
87            phantom_data: PhantomData,
88        }
89    }
90}
91
92macro_rules! define_simple_key_provider_impl {
93    ($name:ty) => {
94        #[async_trait]
95        impl<R: SafeRng> KeyProvider for SimpleKeyProvider<$name, R> {
96            type Cipher = $name;
97
98            async fn decrypt_data_key(
99                &self,
100                encrypted_key: &[u8],
101                aad: Option<&str>,
102            ) -> Result<Key<$name>, KeyDecryptionError> {
103                let decoded_key = EncryptedSimpleKey::from_slice(encrypted_key)?;
104
105                let aad = match aad {
106                    Some(a) => [&[decoded_key.version], a.as_bytes()].concat(),
107                    None => vec![decoded_key.version],
108                };
109                let data_key = self.cipher.decrypt(
110                    decoded_key.nonce,
111                    Payload {
112                        msg: decoded_key.key,
113                        aad: &aad,
114                    },
115                )?;
116
117                return Ok(*Key::<$name>::from_slice(&data_key));
118            }
119
120            async fn generate_data_key(
121                &self,
122                _bytes: usize,
123                aad: Option<&str>,
124            ) -> Result<DataKey<$name>, KeyGenerationError> {
125                let version = 1;
126
127                let (data_key, nonce) = {
128                    let mut data_key: Key<$name> = Default::default();
129                    let mut nonce: Nonce = Default::default();
130                    let mut rng = self.rng.lock().unwrap_or_else(|e| e.into_inner());
131                    rng.try_fill_bytes(&mut data_key)?;
132                    rng.try_fill_bytes(&mut nonce)?;
133
134                    (data_key, nonce)
135                };
136
137                let aad = match aad {
138                    Some(a) => [&[version], a.as_bytes()].concat(),
139                    None => vec![version],
140                };
141
142                let payload = Payload {
143                    msg: &data_key,
144                    aad: &aad,
145                };
146
147                let ciphertext = self.cipher.encrypt(&nonce, payload)?;
148
149                let encrypted_key = EncryptedSimpleKey {
150                    version,
151                    key: &ciphertext,
152                    nonce: &nonce,
153                };
154
155                return Ok(DataKey {
156                    key: data_key,
157                    encrypted_key: encrypted_key.to_vec(),
158                    key_id: String::from("simplekey"),
159                });
160            }
161        }
162    };
163}
164
165define_simple_key_provider_impl!(Aes128Gcm);
166define_simple_key_provider_impl!(Aes256Gcm);
167define_simple_key_provider_impl!(Aes128GcmSiv);
168define_simple_key_provider_impl!(Aes256GcmSiv);
169
#[cfg(test)]
mod tests {
    use aes_gcm::{Aes128Gcm, Aes256Gcm, KeySizeUser};
    use aes_gcm_siv::{Aes128GcmSiv, Aes256GcmSiv};

    use super::{EncryptedSimpleKey, Nonce, NONCE_SIZE};
    use crate::{KeyProvider, SimpleKeyProvider};

    /// Builds a provider for cipher `S` using a fixed, test-only KEK.
    fn create_provider<S: KeySizeUser>() -> SimpleKeyProvider<S>
    where
        SimpleKeyProvider<S>: KeyProvider<Cipher = S>,
    {
        SimpleKeyProvider::init([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16])
    }

    /// Checks that keys generated by `provider` decrypt back to the same key, both without and
    /// with AAD.
    async fn test_generate_decrypt_data_key<S: KeySizeUser, K: KeyProvider<Cipher = S>>(
        provider: K,
    ) {
        let data_key = provider.generate_data_key(0, None).await.unwrap();
        let decrypted_data_key = provider
            .decrypt_data_key(&data_key.encrypted_key, None)
            .await
            .unwrap();

        assert_eq!(data_key.key, decrypted_data_key);

        // with aad
        let data_key = provider.generate_data_key(0, Some("abcde")).await.unwrap();
        let decrypted_data_key = provider
            .decrypt_data_key(&data_key.encrypted_key, Some("abcde"))
            .await
            .unwrap();

        assert_eq!(data_key.key, decrypted_data_key);
    }

    #[tokio::test]
    async fn test_generate_decrypt_data_key_128_gcm() {
        let provider: SimpleKeyProvider<Aes128Gcm> = create_provider();
        test_generate_decrypt_data_key(provider).await;

        let provider: SimpleKeyProvider<Aes128Gcm> = create_provider();
        let provider: Box<dyn KeyProvider<Cipher = Aes128Gcm>> = Box::new(provider);
        test_generate_decrypt_data_key(provider).await;
    }

    #[tokio::test]
    async fn test_generate_decrypt_data_key_256_gcm() {
        let provider: SimpleKeyProvider<Aes256Gcm> = create_provider();
        test_generate_decrypt_data_key(provider).await;

        let provider: SimpleKeyProvider<Aes256Gcm> = create_provider();
        let provider: Box<dyn KeyProvider<Cipher = Aes256Gcm>> = Box::new(provider);
        test_generate_decrypt_data_key(provider).await;
    }

    #[tokio::test]
    async fn test_generate_decrypt_data_key_128_gcm_siv() {
        let provider: SimpleKeyProvider<Aes128GcmSiv> = create_provider();
        test_generate_decrypt_data_key(provider).await;

        let provider: SimpleKeyProvider<Aes128GcmSiv> = create_provider();
        let provider: Box<dyn KeyProvider<Cipher = Aes128GcmSiv>> = Box::new(provider);
        test_generate_decrypt_data_key(provider).await;
    }

    #[tokio::test]
    async fn test_generate_decrypt_data_key_256_gcm_siv() {
        let provider: SimpleKeyProvider<Aes256GcmSiv> = create_provider();
        test_generate_decrypt_data_key(provider).await;

        let provider: SimpleKeyProvider<Aes256GcmSiv> = create_provider();
        let provider: Box<dyn KeyProvider<Cipher = Aes256GcmSiv>> = Box::new(provider);
        test_generate_decrypt_data_key(provider).await;
    }

    #[tokio::test]
    async fn test_fails_on_invalid_data_key() {
        let first: SimpleKeyProvider<Aes128Gcm> = SimpleKeyProvider::init([0; 16]);
        let second: SimpleKeyProvider<Aes128Gcm> = SimpleKeyProvider::init([1; 16]);

        let data_key = first.generate_data_key(0, None).await.unwrap();

        assert_eq!(
            second
                .decrypt_data_key(&data_key.encrypted_key, None)
                .await
                .map_err(|e| e.to_string())
                .expect_err("Decrypting data key succeeded"),
            "failed to decrypt key"
        );
    }

    #[tokio::test]
    async fn test_fails_on_invalid_nonce() {
        let provider: SimpleKeyProvider<Aes128Gcm> = SimpleKeyProvider::init([0; 16]);

        let mut data_key = provider.generate_data_key(0, None).await.unwrap();

        // Decrypts data key fine
        assert!(provider
            .decrypt_data_key(&data_key.encrypted_key, None)
            .await
            .is_ok());

        // Replace the nonce (bytes right after the version tag) with a nonsense one
        data_key.encrypted_key[1..1 + NONCE_SIZE]
            .clone_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]);

        assert_eq!(
            provider
                .decrypt_data_key(&data_key.encrypted_key, None)
                .await
                .map_err(|e| e.to_string())
                .expect_err("Decrypting data key succeeded"),
            "failed to decrypt key"
        );
    }

    #[tokio::test]
    async fn test_fails_on_invalid_version() {
        let provider: SimpleKeyProvider<Aes128Gcm> = SimpleKeyProvider::init([0; 16]);

        let mut data_key = provider.generate_data_key(0, None).await.unwrap();

        // Decrypts data key fine
        assert!(provider
            .decrypt_data_key(&data_key.encrypted_key, None)
            .await
            .is_ok());

        // Replace key version with invalid one
        data_key.encrypted_key[0] = 5;

        assert_eq!(
            provider
                .decrypt_data_key(&data_key.encrypted_key, None)
                .await
                .map_err(|e| e.to_string())
                .expect_err("Decrypting data key succeeded"),
            "failed to decrypt key"
        );
    }

    #[tokio::test]
    async fn test_fails_on_invalid_aad() {
        let provider: SimpleKeyProvider<Aes128Gcm> = SimpleKeyProvider::init([0; 16]);

        let data_key = provider.generate_data_key(0, Some("abcdef")).await.unwrap();

        // Decrypts data key fine
        assert!(provider
            .decrypt_data_key(&data_key.encrypted_key, Some("abcdef"))
            .await
            .is_ok());

        // Fails on invalid aad
        assert_eq!(
            provider
                .decrypt_data_key(&data_key.encrypted_key, Some("ghijk"))
                .await
                .map_err(|e| e.to_string())
                .expect_err("Decrypting data key succeeded"),
            "failed to decrypt key"
        );

        // Fails on missing aad
        assert_eq!(
            provider
                .decrypt_data_key(&data_key.encrypted_key, None)
                .await
                .map_err(|e| e.to_string())
                .expect_err("Decrypting data key succeeded"),
            "failed to decrypt key"
        );
    }

    #[test]
    fn test_load_encrypted_key_from_slice() {
        let slice: Vec<u8> = vec![
            1, // version
            1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, // nonce
            1, 2, 3, 4, 5, 6, // encrypted key (size is unknown)
        ];

        let key = EncryptedSimpleKey::from_slice(&slice).unwrap();

        assert_eq!(key.version, 1);
        assert_eq!(
            key.nonce,
            Nonce::from_slice(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16])
        );
        assert_eq!(key.key, &[1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_load_minimum_size_slice() {
        // A version byte plus a full nonce and nothing else decodes with an empty ciphertext.
        let slice = [1u8; 1 + NONCE_SIZE];
        let key = EncryptedSimpleKey::from_slice(&slice).unwrap();
        assert_eq!(key.version, 1);
        assert!(key.key.is_empty());

        // One byte short of a full nonce is rejected.
        let err = EncryptedSimpleKey::from_slice(&slice[..NONCE_SIZE])
            .expect_err("Encrypted key decode succeeded");
        assert_eq!(
            err.to_string(),
            "Slice was too small to load an EncryptedSimpleKey. Received: 16"
        );
    }

    #[test]
    fn test_fails_on_tiny_slice() {
        let slice: Vec<u8> = vec![5, 1, 2, 3, 4, 5, 6];

        let err =
            EncryptedSimpleKey::from_slice(&slice).expect_err("Encrypted key decode succeeded");

        assert_eq!(
            err.to_string(),
            "Slice was too small to load an EncryptedSimpleKey. Received: 7"
        );
    }

    #[test]
    fn test_serialize_key() {
        let version = 1;
        let nonce = Nonce::from_slice(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]);
        let key = &[1, 2, 3, 4, 5, 6];

        let encrypted_key = EncryptedSimpleKey {
            version,
            nonce,
            key,
        };

        let bytes = encrypted_key.to_vec();

        assert_eq!(
            bytes,
            vec![
                1, // version
                1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, // nonce
                1, 2, 3, 4, 5, 6, // encrypted key (size is unknown)
            ]
        );
    }

    #[test]
    fn test_serialize_roundtrip() {
        let nonce = Nonce::from_slice(&[16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]);
        let original = EncryptedSimpleKey {
            version: 3,
            nonce,
            key: &[9, 8, 7],
        };

        let bytes = original.to_vec();
        let decoded = EncryptedSimpleKey::from_slice(&bytes).unwrap();

        assert_eq!(decoded.version, 3);
        assert_eq!(decoded.nonce, nonce);
        assert_eq!(decoded.key, &[9, 8, 7]);
    }
}