askar_crypto/alg/aes/mod.rs

1//! AES key representations with AEAD support
2
3use core::fmt::{self, Debug, Formatter};
4
5use aead::{generic_array::ArrayLength, AeadCore, AeadInPlace, KeyInit, KeySizeUser};
6use aes_gcm::{Aes128Gcm, Aes256Gcm};
7use serde::{Deserialize, Serialize};
8use zeroize::Zeroize;
9
10use super::{AesTypes, HasKeyAlg, HasKeyBackend, KeyAlg};
11use crate::{
12    buffer::{ArrayKey, ResizeBuffer, Writer},
13    encrypt::{KeyAeadInPlace, KeyAeadMeta, KeyAeadParams},
14    error::Error,
15    generic_array::{typenum::Unsigned, GenericArray},
16    jwk::{FromJwk, JwkEncoder, JwkParts, ToJwk},
17    kdf::{FromKeyDerivation, FromKeyExchange, KeyDerivation, KeyExchange},
18    random::KeyMaterial,
19    repr::{KeyGen, KeyMeta, KeySecretBytes},
20};
21
22mod cbc_hmac;
23pub use cbc_hmac::{A128CbcHs256, A256CbcHs512};
24
25mod key_wrap;
26pub use key_wrap::{A128Kw, A256Kw};
27
/// JWK `kty` parameter value for symmetric keys ("octet sequence"),
/// written when encoding an AES key and required when decoding one.
pub static JWK_KEY_TYPE: &str = "oct";
30
31/// Trait implemented by supported AES authenticated encryption algorithms
32pub trait AesType: 'static {
33    /// The size of the key secret bytes
34    type KeySize: ArrayLength<u8>;
35
36    /// The associated algorithm type
37    const ALG_TYPE: AesTypes;
38    /// The associated JWK algorithm name
39    const JWK_ALG: &'static str;
40}
41
42type KeyType<A> = ArrayKey<<A as AesType>::KeySize>;
43
44type NonceSize<A> = <A as KeyAeadMeta>::NonceSize;
45
46type TagSize<A> = <A as KeyAeadMeta>::TagSize;
47
48/// An AES symmetric encryption key
49#[derive(Serialize, Deserialize, Zeroize)]
50#[serde(
51    transparent,
52    bound(
53        deserialize = "KeyType<T>: for<'a> Deserialize<'a>",
54        serialize = "KeyType<T>: Serialize"
55    )
56)]
57// SECURITY: ArrayKey is zeroized on drop
58pub struct AesKey<T: AesType>(KeyType<T>);
59
60impl<T: AesType> Clone for AesKey<T> {
61    fn clone(&self) -> Self {
62        Self(self.0.clone())
63    }
64}
65
66impl<T: AesType> Debug for AesKey<T> {
67    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
68        f.debug_struct("AesKey")
69            .field("alg", &T::JWK_ALG)
70            .field("key", &self.0)
71            .finish()
72    }
73}
74
75impl<T: AesType> PartialEq for AesKey<T> {
76    fn eq(&self, other: &Self) -> bool {
77        other.0 == self.0
78    }
79}
80
81impl<T: AesType> Eq for AesKey<T> {}
82
83impl<T: AesType> HasKeyBackend for AesKey<T> {}
84
85impl<T: AesType> HasKeyAlg for AesKey<T> {
86    fn algorithm(&self) -> KeyAlg {
87        KeyAlg::Aes(T::ALG_TYPE)
88    }
89}
90
91impl<T: AesType> KeyMeta for AesKey<T> {
92    type KeySize = T::KeySize;
93}
94
95impl<T: AesType> KeyGen for AesKey<T> {
96    fn generate(rng: impl KeyMaterial) -> Result<Self, Error> {
97        Ok(AesKey(KeyType::<T>::generate(rng)))
98    }
99}
100
101impl<T: AesType> KeySecretBytes for AesKey<T> {
102    fn from_secret_bytes(key: &[u8]) -> Result<Self, Error> {
103        if key.len() != KeyType::<T>::SIZE {
104            return Err(err_msg!(InvalidKeyData));
105        }
106        Ok(Self(KeyType::<T>::from_slice(key)))
107    }
108
109    fn with_secret_bytes<O>(&self, f: impl FnOnce(Option<&[u8]>) -> O) -> O {
110        f(Some(self.0.as_ref()))
111    }
112}
113
114impl<T: AesType> FromKeyDerivation for AesKey<T> {
115    fn from_key_derivation<D: KeyDerivation>(mut derive: D) -> Result<Self, Error>
116    where
117        Self: Sized,
118    {
119        Ok(Self(KeyType::<T>::try_new_with(|arr| {
120            derive.derive_key_bytes(arr)
121        })?))
122    }
123}
124
125impl<T: AesType> FromJwk for AesKey<T> {
126    fn from_jwk_parts(jwk: JwkParts<'_>) -> Result<Self, Error> {
127        if jwk.kty != JWK_KEY_TYPE {
128            return Err(err_msg!(InvalidKeyData, "Unsupported key type"));
129        }
130        if jwk.alg.is_some() && jwk.alg != T::JWK_ALG {
131            return Err(err_msg!(InvalidKeyData, "Unsupported key algorithm"));
132        }
133        Ok(Self(ArrayKey::try_new_with(|buf| {
134            if jwk.k.decode_base64(buf)? != buf.len() {
135                Err(err_msg!(InvalidKeyData))
136            } else {
137                Ok(())
138            }
139        })?))
140    }
141}
142
143impl<T: AesType> ToJwk for AesKey<T> {
144    fn encode_jwk(&self, enc: &mut dyn JwkEncoder) -> Result<(), Error> {
145        if enc.is_public() {
146            return Err(err_msg!(Unsupported, "Cannot export as a public key"));
147        }
148        if !enc.is_thumbprint() {
149            enc.add_str("alg", T::JWK_ALG)?;
150        }
151        enc.add_as_base64("k", self.0.as_ref())?;
152        enc.add_str("kty", JWK_KEY_TYPE)?;
153        Ok(())
154    }
155}
156
157// for direct key agreement (not used currently)
158impl<Lhs, Rhs, T> FromKeyExchange<Lhs, Rhs> for AesKey<T>
159where
160    Lhs: KeyExchange<Rhs> + ?Sized,
161    Rhs: ?Sized,
162    T: AesType,
163{
164    fn from_key_exchange(lhs: &Lhs, rhs: &Rhs) -> Result<Self, Error> {
165        Ok(Self(KeyType::<T>::try_new_with(|arr| {
166            let mut buf = Writer::from_slice(arr);
167            lhs.write_key_exchange(rhs, &mut buf)?;
168            if buf.position() != T::KeySize::USIZE {
169                return Err(err_msg!(Usage, "Invalid length for key exchange output"));
170            }
171            Ok(())
172        })?))
173    }
174}
175
176/// 128 bit AES-GCM
177pub type A128Gcm = Aes128Gcm;
178
179impl AesType for A128Gcm {
180    type KeySize = <Self as KeySizeUser>::KeySize;
181
182    const ALG_TYPE: AesTypes = AesTypes::A128Gcm;
183    const JWK_ALG: &'static str = "A128GCM";
184}
185
186/// 256 bit AES-GCM
187pub type A256Gcm = Aes256Gcm;
188
189impl AesType for A256Gcm {
190    type KeySize = <Self as KeySizeUser>::KeySize;
191
192    const ALG_TYPE: AesTypes = AesTypes::A256Gcm;
193    const JWK_ALG: &'static str = "A256GCM";
194}
195
196// generic implementation applying to AesGcm
197impl<T: AeadCore + AesType> KeyAeadMeta for AesKey<T> {
198    type NonceSize = <T as AeadCore>::NonceSize;
199    type TagSize = <T as AeadCore>::TagSize;
200}
201
202// generic implementation applying to AesGcm
203impl<T> KeyAeadInPlace for AesKey<T>
204where
205    T: KeyInit + AeadInPlace + AesType<KeySize = <T as KeySizeUser>::KeySize>,
206{
207    /// Encrypt a secret value in place, appending the verification tag
208    fn encrypt_in_place(
209        &self,
210        buffer: &mut dyn ResizeBuffer,
211        nonce: &[u8],
212        aad: &[u8],
213    ) -> Result<usize, Error> {
214        if nonce.len() != T::NonceSize::USIZE {
215            return Err(err_msg!(InvalidNonce));
216        }
217        let enc = <T as KeyInit>::new(self.0.as_ref());
218        let tag = enc
219            .encrypt_in_place_detached(GenericArray::from_slice(nonce), aad, buffer.as_mut())
220            .map_err(|_| err_msg!(Encryption, "AEAD encryption error"))?;
221        let ctext_len = buffer.as_ref().len();
222        buffer.buffer_write(&tag[..])?;
223        Ok(ctext_len)
224    }
225
226    /// Decrypt an encrypted (verification tag appended) value in place
227    fn decrypt_in_place(
228        &self,
229        buffer: &mut dyn ResizeBuffer,
230        nonce: &[u8],
231        aad: &[u8],
232    ) -> Result<(), Error> {
233        if nonce.len() != T::NonceSize::USIZE {
234            return Err(err_msg!(InvalidNonce));
235        }
236        let buf_len = buffer.as_ref().len();
237        if buf_len < T::TagSize::USIZE {
238            return Err(err_msg!(Encryption, "Invalid size for encrypted data"));
239        }
240        let tag_start = buf_len - T::TagSize::USIZE;
241        let mut tag = GenericArray::default();
242        tag.clone_from_slice(&buffer.as_ref()[tag_start..]);
243        let enc = <T as KeyInit>::new(self.0.as_ref());
244        enc.decrypt_in_place_detached(
245            GenericArray::from_slice(nonce),
246            aad,
247            &mut buffer.as_mut()[..tag_start],
248            &tag,
249        )
250        .map_err(|_| err_msg!(Encryption, "AEAD decryption error"))?;
251        buffer.buffer_resize(tag_start)?;
252        Ok(())
253    }
254
255    fn aead_params(&self) -> KeyAeadParams {
256        KeyAeadParams {
257            nonce_length: T::NonceSize::USIZE,
258            tag_length: T::TagSize::USIZE,
259        }
260    }
261}
262
#[cfg(test)]
mod tests {
    use super::*;
    use crate::buffer::SecretBytes;
    use crate::repr::ToSecretBytes;

    #[test]
    fn encrypt_round_trip() {
        fn test_encrypt<T>()
        where
            T: AesType,
            AesKey<T>: KeyAeadInPlace + KeyAeadMeta,
        {
            let input = b"hello";
            let aad = b"additional data";
            let key = AesKey::<T>::random().unwrap();
            let mut buffer = SecretBytes::from_slice(input);
            let params = key.aead_params();
            let pad_len = key.aead_padding(input.len());
            let nonce = AesKey::<T>::random_nonce();
            key.encrypt_in_place(&mut buffer, &nonce, aad).unwrap();
            let enc_len = buffer.len();
            assert_eq!(enc_len, input.len() + pad_len + params.tag_length);
            assert_ne!(&buffer[..], input);
            let mut dec = buffer.clone();
            key.decrypt_in_place(&mut dec, &nonce, aad).unwrap();
            assert_eq!(&dec[..], input);

            // test tag validation
            buffer.as_mut()[enc_len - 1] = buffer.as_mut()[enc_len - 1].wrapping_add(1);
            assert!(key.decrypt_in_place(&mut buffer, &nonce, aad).is_err());
        }
        test_encrypt::<A128Gcm>();
        test_encrypt::<A256Gcm>();
        test_encrypt::<A128CbcHs256>();
        test_encrypt::<A256CbcHs512>();
    }

    #[test]
    fn test_random() {
        // Encrypt through a fixed-capacity `Writer` instead of a growable
        // buffer, then decrypt in place and check the message is recovered.
        let key = AesKey::<A128CbcHs256>::random().unwrap();
        let nonce = AesKey::<A128CbcHs256>::random_nonce();
        let message = b"hello there";
        let mut buffer = [0u8; 255];
        buffer[0..message.len()].copy_from_slice(&message[..]);
        let mut writer = Writer::from_slice_position(&mut buffer, message.len());
        key.encrypt_in_place(&mut writer, nonce.as_slice(), &[])
            .unwrap();
        // the ciphertext plus the appended tag must be longer than the plaintext
        assert!(writer.position() > message.len());
        key.decrypt_in_place(&mut writer, nonce.as_slice(), &[])
            .unwrap();
        let dec_len = writer.position();
        assert_eq!(&buffer[..dec_len], &message[..]);
    }

    #[cfg(feature = "any_key")]
    #[test]
    fn jwk_any_compat() {
        use crate::alg::{any::AnyKey, AesTypes, KeyAlg};
        use alloc::boxed::Box;

        let test_jwk_compat = r#"
            {"alg": "A128CBC-HS256",
            "k": "6scajSsnjo2fI-wjCCvBC2xNSYyErNyN93CAsyzVVGI",
            "kty": "oct"}
        "#;
        let key = Box::<AnyKey>::from_jwk(test_jwk_compat).expect("Error decoding AES key JWK");
        assert_eq!(key.algorithm(), KeyAlg::Aes(AesTypes::A128CbcHs256));
        let as_aes = key
            .downcast_ref::<AesKey<A128CbcHs256>>()
            .expect("Error downcasting AES key");
        let _ = as_aes
            .to_jwk_secret(None)
            .expect("Error converting key to JWK");
    }

    #[test]
    fn serialize_round_trip() {
        fn test_serialize<T: AesType>() {
            let key = AesKey::<T>::random().unwrap();
            let sk = key.to_secret_bytes().unwrap();
            let mut bytes = alloc::vec::Vec::new();
            ciborium::into_writer(&key, &mut bytes).unwrap();
            let deser: alloc::vec::Vec<u8> = ciborium::from_reader(&bytes[..]).unwrap();
            assert_eq!(deser, sk.as_ref());
        }
        test_serialize::<A128Gcm>();
        test_serialize::<A256Gcm>();
        test_serialize::<A128CbcHs256>();
        test_serialize::<A256CbcHs512>();
        test_serialize::<A128Kw>();
        test_serialize::<A256Kw>();
    }
}