1use std::marker::PhantomData;
4use std::sync::Mutex;
5
6use aes_gcm::aead::{Aead, Payload};
7use aes_gcm::aes::cipher::consts::U16;
8use aes_gcm::aes::Aes128;
9use aes_gcm::{Aes128Gcm, Aes256Gcm, AesGcm, Key, KeyInit, KeySizeUser};
10use aes_gcm_siv::{Aes128GcmSiv, Aes256GcmSiv};
11use async_trait::async_trait;
12use rand_chacha::ChaChaRng;
13
14use crate::errors::{KeyDecryptionError, KeyGenerationError};
15use crate::key_provider::{DataKey, KeyProvider};
16use crate::safe_rng::SafeRng;
17
18type Nonce = aes_gcm::Nonce<U16>;
21const NONCE_SIZE: usize = 16;
22
23#[derive(Debug)]
33struct EncryptedSimpleKey<'a> {
34 version: u8,
38 nonce: &'a Nonce,
39 key: &'a [u8],
40}
41
42impl<'a> EncryptedSimpleKey<'a> {
43 fn from_slice(bytes: &'a [u8]) -> Result<Self, KeyDecryptionError> {
45 if bytes.len() < 1 + NONCE_SIZE {
46 return Err(KeyDecryptionError::Other(format!(
47 "Slice was too small to load an EncryptedSimpleKey. Received: {}",
48 bytes.len()
49 )));
50 }
51
52 let version = bytes[0];
53
54 let nonce: &'a Nonce = Nonce::from_slice(&bytes[1..1 + NONCE_SIZE]);
55 let key: &'a [u8] = &bytes[1 + NONCE_SIZE..];
56
57 Ok(Self {
58 version,
59 nonce,
60 key,
61 })
62 }
63
64 fn to_vec(&self) -> Vec<u8> {
66 let mut output = Vec::with_capacity(1 + self.nonce.len() + self.key.len());
67 output.push(self.version);
68 output.extend_from_slice(self.nonce);
69 output.extend_from_slice(self.key);
70 output
71 }
72}
73
74pub struct SimpleKeyProvider<S: KeySizeUser = Aes128Gcm, R: SafeRng = ChaChaRng> {
75 cipher: AesGcm<Aes128, U16>,
76 rng: Mutex<R>,
77 phantom_data: PhantomData<S>,
78}
79
80impl<S: KeySizeUser, R: SafeRng> SimpleKeyProvider<S, R> {
81 pub fn init(kek: [u8; 16]) -> Self {
82 let key: &Key<Aes128> = Key::<Aes128>::from_slice(&kek);
83
84 Self {
85 cipher: AesGcm::<Aes128, U16>::new(key),
86 rng: Mutex::new(R::from_entropy()),
87 phantom_data: PhantomData,
88 }
89 }
90}
91
/// Implements [`KeyProvider`] for `SimpleKeyProvider<$name, R>`.
///
/// `$name` is the data-key cipher (it fixes the data-key size). The key-encryption cipher is
/// always the AES-128-GCM instance built by `SimpleKeyProvider::init`.
///
/// Encrypted data keys are stored as `version (1 byte) || nonce || AES-GCM output`. The
/// version byte is prepended to the caller's AAD, so it is authenticated along with the key.
macro_rules! define_simple_key_provider_impl {
    ($name:ty) => {
        #[async_trait]
        impl<R: SafeRng> KeyProvider for SimpleKeyProvider<$name, R> {
            type Cipher = $name;

            /// Decrypts a data key previously produced by `generate_data_key`.
            ///
            /// # Errors
            /// - `encrypted_key` is shorter than the version byte plus the nonce;
            /// - AES-GCM decryption fails (wrong KEK, tampered version/nonce/ciphertext,
            ///   or an `aad` different from the one used at generation time);
            /// - the decrypted key's length does not match `$name`'s key size, e.g. a key
            ///   generated by a provider for a different cipher under the same KEK.
            async fn decrypt_data_key(
                &self,
                encrypted_key: &[u8],
                aad: Option<&str>,
            ) -> Result<Key<$name>, KeyDecryptionError> {
                let decoded_key = EncryptedSimpleKey::from_slice(encrypted_key)?;

                // The version byte is always the first AAD byte, mirroring generate_data_key.
                let aad = match aad {
                    Some(a) => [&[decoded_key.version], a.as_bytes()].concat(),
                    None => vec![decoded_key.version],
                };
                let data_key = self.cipher.decrypt(
                    decoded_key.nonce,
                    Payload {
                        msg: decoded_key.key,
                        aad: &aad,
                    },
                )?;

                // Authentication only proves the bytes were wrapped under this KEK, not that
                // they were generated for this cipher. Copying into the fixed-size key would
                // panic on a length mismatch, so reject it as an error instead.
                let mut key: Key<$name> = Default::default();
                if data_key.len() != key.len() {
                    return Err(KeyDecryptionError::Other(format!(
                        "Decrypted data key has length {} but the cipher requires {} bytes",
                        data_key.len(),
                        key.len()
                    )));
                }
                key.copy_from_slice(&data_key);

                Ok(key)
            }

            /// Generates a fresh random data key for `$name` and encrypts it under the KEK.
            ///
            /// `_bytes` is ignored: the data-key length is fixed by `$name`.
            ///
            /// # Errors
            /// Fails if the RNG cannot fill the key or nonce, or if encryption fails.
            async fn generate_data_key(
                &self,
                _bytes: usize,
                aad: Option<&str>,
            ) -> Result<DataKey<$name>, KeyGenerationError> {
                let version = 1;

                // Scope the RNG lock to this block so the guard is dropped right after use.
                // A poisoned lock is recovered rather than propagated as a panic.
                let (data_key, nonce) = {
                    let mut data_key: Key<$name> = Default::default();
                    let mut nonce: Nonce = Default::default();
                    let mut rng = self.rng.lock().unwrap_or_else(|e| e.into_inner());
                    rng.try_fill_bytes(&mut data_key)?;
                    rng.try_fill_bytes(&mut nonce)?;

                    (data_key, nonce)
                };

                // Bind the format version into the authenticated data.
                let aad = match aad {
                    Some(a) => [&[version], a.as_bytes()].concat(),
                    None => vec![version],
                };

                let payload = Payload {
                    msg: &data_key,
                    aad: &aad,
                };

                let ciphertext = self.cipher.encrypt(&nonce, payload)?;

                let encrypted_key = EncryptedSimpleKey {
                    version,
                    key: &ciphertext,
                    nonce: &nonce,
                };

                Ok(DataKey {
                    key: data_key,
                    encrypted_key: encrypted_key.to_vec(),
                    key_id: String::from("simplekey"),
                })
            }
        }
    };
}
164
165define_simple_key_provider_impl!(Aes128Gcm);
166define_simple_key_provider_impl!(Aes256Gcm);
167define_simple_key_provider_impl!(Aes128GcmSiv);
168define_simple_key_provider_impl!(Aes256GcmSiv);
169
#[cfg(test)]
mod tests {
    use aes_gcm::{Aes128Gcm, Aes256Gcm, KeySizeUser};
    use aes_gcm_siv::{Aes128GcmSiv, Aes256GcmSiv};

    use super::{EncryptedSimpleKey, Nonce, NONCE_SIZE};
    use crate::{KeyProvider, SimpleKeyProvider};

    /// Builds a provider for cipher `S` with a fixed, non-secret test KEK.
    fn create_provider<S: KeySizeUser>() -> SimpleKeyProvider<S>
    where
        SimpleKeyProvider<S>: KeyProvider<Cipher = S>,
    {
        SimpleKeyProvider::init([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16])
    }

    /// Round-trips a generated data key through decryption, with and without AAD.
    async fn test_generate_decrypt_data_key<S: KeySizeUser, K: KeyProvider<Cipher = S>>(
        provider: K,
    ) {
        let data_key = provider.generate_data_key(0, None).await.unwrap();
        let decrypted_data_key = provider
            .decrypt_data_key(&data_key.encrypted_key, None)
            .await
            .unwrap();

        assert_eq!(data_key.key, decrypted_data_key);

        let data_key = provider.generate_data_key(0, Some("abcde")).await.unwrap();
        let decrypted_data_key = provider
            .decrypt_data_key(&data_key.encrypted_key, Some("abcde"))
            .await
            .unwrap();

        assert_eq!(data_key.key, decrypted_data_key);
    }

    #[tokio::test]
    async fn test_generate_decrypt_data_key_128_gcm() {
        let provider: SimpleKeyProvider<Aes128Gcm> = create_provider();
        test_generate_decrypt_data_key(provider).await;

        // Also exercise the trait-object path.
        let provider: SimpleKeyProvider<Aes128Gcm> = create_provider();
        let provider: Box<dyn KeyProvider<Cipher = Aes128Gcm>> = Box::new(provider);
        test_generate_decrypt_data_key(provider).await;
    }

    #[tokio::test]
    async fn test_generate_decrypt_data_key_256_gcm() {
        let provider: SimpleKeyProvider<Aes256Gcm> = create_provider();
        test_generate_decrypt_data_key(provider).await;

        let provider: SimpleKeyProvider<Aes256Gcm> = create_provider();
        let provider: Box<dyn KeyProvider<Cipher = Aes256Gcm>> = Box::new(provider);
        test_generate_decrypt_data_key(provider).await;
    }

    #[tokio::test]
    async fn test_generate_decrypt_data_key_128_gcm_siv() {
        let provider: SimpleKeyProvider<Aes128GcmSiv> = create_provider();
        test_generate_decrypt_data_key(provider).await;

        let provider: SimpleKeyProvider<Aes128GcmSiv> = create_provider();
        let provider: Box<dyn KeyProvider<Cipher = Aes128GcmSiv>> = Box::new(provider);
        test_generate_decrypt_data_key(provider).await;
    }

    #[tokio::test]
    async fn test_generate_decrypt_data_key_256_gcm_siv() {
        let provider: SimpleKeyProvider<Aes256GcmSiv> = create_provider();
        test_generate_decrypt_data_key(provider).await;

        let provider: SimpleKeyProvider<Aes256GcmSiv> = create_provider();
        let provider: Box<dyn KeyProvider<Cipher = Aes256GcmSiv>> = Box::new(provider);
        test_generate_decrypt_data_key(provider).await;
    }

    #[tokio::test]
    async fn test_fails_on_invalid_data_key() {
        let first: SimpleKeyProvider<Aes128Gcm> = SimpleKeyProvider::init([0; 16]);
        let second: SimpleKeyProvider<Aes128Gcm> = SimpleKeyProvider::init([1; 16]);

        let data_key = first.generate_data_key(0, None).await.unwrap();

        assert_eq!(
            second
                .decrypt_data_key(&data_key.encrypted_key, None)
                .await
                .map_err(|e| e.to_string())
                .expect_err("Decrypting data key succeeded"),
            "failed to decrypt key"
        );
    }

    #[tokio::test]
    async fn test_fails_on_invalid_nonce() {
        let provider: SimpleKeyProvider<Aes128Gcm> = SimpleKeyProvider::init([0; 16]);

        let mut data_key = provider.generate_data_key(0, None).await.unwrap();

        assert!(provider
            .decrypt_data_key(&data_key.encrypted_key, None)
            .await
            .is_ok());

        // Overwrite the nonce (bytes 1..=NONCE_SIZE) with a different value.
        data_key.encrypted_key[1..1 + NONCE_SIZE]
            .copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]);

        assert_eq!(
            provider
                .decrypt_data_key(&data_key.encrypted_key, None)
                .await
                .map_err(|e| e.to_string())
                .expect_err("Decrypting data key succeeded"),
            "failed to decrypt key"
        );
    }

    #[tokio::test]
    async fn test_fails_on_invalid_version() {
        let provider: SimpleKeyProvider<Aes128Gcm> = SimpleKeyProvider::init([0; 16]);

        let mut data_key = provider.generate_data_key(0, None).await.unwrap();

        assert!(provider
            .decrypt_data_key(&data_key.encrypted_key, None)
            .await
            .is_ok());

        // The version byte is part of the AAD, so changing it must break authentication.
        data_key.encrypted_key[0] = 5;

        assert_eq!(
            provider
                .decrypt_data_key(&data_key.encrypted_key, None)
                .await
                .map_err(|e| e.to_string())
                .expect_err("Decrypting data key succeeded"),
            "failed to decrypt key"
        );
    }

    #[tokio::test]
    async fn test_fails_on_invalid_aad() {
        let provider: SimpleKeyProvider<Aes128Gcm> = SimpleKeyProvider::init([0; 16]);

        let data_key = provider.generate_data_key(0, Some("abcdef")).await.unwrap();

        assert!(provider
            .decrypt_data_key(&data_key.encrypted_key, Some("abcdef"))
            .await
            .is_ok());

        assert_eq!(
            provider
                .decrypt_data_key(&data_key.encrypted_key, Some("ghijk"))
                .await
                .map_err(|e| e.to_string())
                .expect_err("Decrypting data key succeeded"),
            "failed to decrypt key"
        );

        assert_eq!(
            provider
                .decrypt_data_key(&data_key.encrypted_key, None)
                .await
                .map_err(|e| e.to_string())
                .expect_err("Decrypting data key succeeded"),
            "failed to decrypt key"
        );
    }

    #[tokio::test]
    async fn test_fails_on_missing_ciphertext() {
        let provider: SimpleKeyProvider<Aes128Gcm> = SimpleKeyProvider::init([0; 16]);

        let mut data_key = provider.generate_data_key(0, None).await.unwrap();

        // Keep only version + nonce: parsing succeeds, but there is nothing to decrypt.
        data_key.encrypted_key.truncate(1 + NONCE_SIZE);

        assert!(provider
            .decrypt_data_key(&data_key.encrypted_key, None)
            .await
            .is_err());
    }

    #[test]
    fn test_load_encrypted_key_from_slice() {
        let slice: Vec<u8> = vec![
            1, // version
            1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, // nonce
            1, 2, 3, 4, 5, 6, // encrypted key
        ];

        let key = EncryptedSimpleKey::from_slice(&slice).unwrap();

        assert_eq!(key.version, 1);
        assert_eq!(
            key.nonce,
            Nonce::from_slice(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16])
        );
        assert_eq!(key.key, &[1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_fails_on_tiny_slice() {
        let slice: Vec<u8> = vec![5, 1, 2, 3, 4, 5, 6];

        let err =
            EncryptedSimpleKey::from_slice(&slice).expect_err("Encrypted key decode succeeded");

        assert_eq!(
            err.to_string(),
            "Slice was too small to load an EncryptedSimpleKey. Received: 7"
        );
    }

    #[test]
    fn test_serialize_key() {
        let version = 1;
        let nonce = Nonce::from_slice(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]);
        let key = &[1, 2, 3, 4, 5, 6];

        let encrypted_key = EncryptedSimpleKey {
            version,
            nonce,
            key,
        };

        let bytes = encrypted_key.to_vec();

        assert_eq!(
            bytes,
            vec![
                1, // version
                1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, // nonce
                1, 2, 3, 4, 5, 6, // encrypted key
            ]
        );
    }
}