1use core::fmt::{self, Debug, Formatter};
4
5use aead::{generic_array::ArrayLength, AeadCore, AeadInPlace, KeyInit, KeySizeUser};
6use aes_gcm::{Aes128Gcm, Aes256Gcm};
7use serde::{Deserialize, Serialize};
8use zeroize::Zeroize;
9
10use super::{AesTypes, HasKeyAlg, HasKeyBackend, KeyAlg};
11use crate::{
12 buffer::{ArrayKey, ResizeBuffer, Writer},
13 encrypt::{KeyAeadInPlace, KeyAeadMeta, KeyAeadParams},
14 error::Error,
15 generic_array::{typenum::Unsigned, GenericArray},
16 jwk::{FromJwk, JwkEncoder, JwkParts, ToJwk},
17 kdf::{FromKeyDerivation, FromKeyExchange, KeyDerivation, KeyExchange},
18 random::KeyMaterial,
19 repr::{KeyGen, KeyMeta, KeySecretBytes},
20};
21
// AES-CBC combined with HMAC-SHA2 (e.g. the JWK `alg` "A128CBC-HS256" used by
// `AesKey<A128CbcHs256>` in the tests of this file).
mod cbc_hmac;
pub use cbc_hmac::{A128CbcHs256, A256CbcHs512};

// AES key-wrap algorithm types (presumably JWA A128KW / A256KW — see submodule).
mod key_wrap;
pub use key_wrap::{A128Kw, A256Kw};

/// The JWK `kty` value for all AES keys: an octet-sequence (symmetric) key.
/// Checked on JWK import and written on JWK export.
pub static JWK_KEY_TYPE: &str = "oct";
30
/// Describes one concrete AES-based algorithm that an [`AesKey`] can be used with.
pub trait AesType: 'static {
    /// The length of the raw key material, as a type-level number of bytes.
    type KeySize: ArrayLength<u8>;

    /// The algorithm variant reported by [`HasKeyAlg::algorithm`].
    const ALG_TYPE: AesTypes;
    /// The JWK `alg` value written on export and compared on import.
    const JWK_ALG: &'static str;
}
41
/// Fixed-size secret storage holding the raw key bytes for algorithm `A`.
type KeyType<A> = ArrayKey<<A as AesType>::KeySize>;

/// Shorthand for the AEAD nonce length of key type `A`.
type NonceSize<A> = <A as KeyAeadMeta>::NonceSize;

/// Shorthand for the AEAD authentication tag length of key type `A`.
type TagSize<A> = <A as KeyAeadMeta>::TagSize;
47
/// An AES symmetric key for the algorithm selected by the type parameter `T`.
///
/// Serializes transparently as the inner key bytes. The explicit serde
/// `bound`s replace serde's default inferred bounds, which would otherwise
/// demand `T: Serialize`/`Deserialize` even though `T` is never serialized.
/// `Zeroize` is derived so the key material can be wiped.
#[derive(Serialize, Deserialize, Zeroize)]
#[serde(
    transparent,
    bound(
        deserialize = "KeyType<T>: for<'a> Deserialize<'a>",
        serialize = "KeyType<T>: Serialize"
    )
)]
pub struct AesKey<T: AesType>(KeyType<T>);
59
60impl<T: AesType> Clone for AesKey<T> {
61 fn clone(&self) -> Self {
62 Self(self.0.clone())
63 }
64}
65
66impl<T: AesType> Debug for AesKey<T> {
67 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
68 f.debug_struct("AesKey")
69 .field("alg", &T::JWK_ALG)
70 .field("key", &self.0)
71 .finish()
72 }
73}
74
75impl<T: AesType> PartialEq for AesKey<T> {
76 fn eq(&self, other: &Self) -> bool {
77 other.0 == self.0
78 }
79}
80
81impl<T: AesType> Eq for AesKey<T> {}
82
83impl<T: AesType> HasKeyBackend for AesKey<T> {}
84
85impl<T: AesType> HasKeyAlg for AesKey<T> {
86 fn algorithm(&self) -> KeyAlg {
87 KeyAlg::Aes(T::ALG_TYPE)
88 }
89}
90
91impl<T: AesType> KeyMeta for AesKey<T> {
92 type KeySize = T::KeySize;
93}
94
95impl<T: AesType> KeyGen for AesKey<T> {
96 fn generate(rng: impl KeyMaterial) -> Result<Self, Error> {
97 Ok(AesKey(KeyType::<T>::generate(rng)))
98 }
99}
100
101impl<T: AesType> KeySecretBytes for AesKey<T> {
102 fn from_secret_bytes(key: &[u8]) -> Result<Self, Error> {
103 if key.len() != KeyType::<T>::SIZE {
104 return Err(err_msg!(InvalidKeyData));
105 }
106 Ok(Self(KeyType::<T>::from_slice(key)))
107 }
108
109 fn with_secret_bytes<O>(&self, f: impl FnOnce(Option<&[u8]>) -> O) -> O {
110 f(Some(self.0.as_ref()))
111 }
112}
113
114impl<T: AesType> FromKeyDerivation for AesKey<T> {
115 fn from_key_derivation<D: KeyDerivation>(mut derive: D) -> Result<Self, Error>
116 where
117 Self: Sized,
118 {
119 Ok(Self(KeyType::<T>::try_new_with(|arr| {
120 derive.derive_key_bytes(arr)
121 })?))
122 }
123}
124
125impl<T: AesType> FromJwk for AesKey<T> {
126 fn from_jwk_parts(jwk: JwkParts<'_>) -> Result<Self, Error> {
127 if jwk.kty != JWK_KEY_TYPE {
128 return Err(err_msg!(InvalidKeyData, "Unsupported key type"));
129 }
130 if jwk.alg.is_some() && jwk.alg != T::JWK_ALG {
131 return Err(err_msg!(InvalidKeyData, "Unsupported key algorithm"));
132 }
133 Ok(Self(ArrayKey::try_new_with(|buf| {
134 if jwk.k.decode_base64(buf)? != buf.len() {
135 Err(err_msg!(InvalidKeyData))
136 } else {
137 Ok(())
138 }
139 })?))
140 }
141}
142
143impl<T: AesType> ToJwk for AesKey<T> {
144 fn encode_jwk(&self, enc: &mut dyn JwkEncoder) -> Result<(), Error> {
145 if enc.is_public() {
146 return Err(err_msg!(Unsupported, "Cannot export as a public key"));
147 }
148 if !enc.is_thumbprint() {
149 enc.add_str("alg", T::JWK_ALG)?;
150 }
151 enc.add_as_base64("k", self.0.as_ref())?;
152 enc.add_str("kty", JWK_KEY_TYPE)?;
153 Ok(())
154 }
155}
156
157impl<Lhs, Rhs, T> FromKeyExchange<Lhs, Rhs> for AesKey<T>
159where
160 Lhs: KeyExchange<Rhs> + ?Sized,
161 Rhs: ?Sized,
162 T: AesType,
163{
164 fn from_key_exchange(lhs: &Lhs, rhs: &Rhs) -> Result<Self, Error> {
165 Ok(Self(KeyType::<T>::try_new_with(|arr| {
166 let mut buf = Writer::from_slice(arr);
167 lhs.write_key_exchange(rhs, &mut buf)?;
168 if buf.position() != T::KeySize::USIZE {
169 return Err(err_msg!(Usage, "Invalid length for key exchange output"));
170 }
171 Ok(())
172 })?))
173 }
174}
175
/// AES-128-GCM authenticated encryption (JWK `alg` "A128GCM").
pub type A128Gcm = Aes128Gcm;

impl AesType for A128Gcm {
    // The key size is taken from the cipher's own `KeySizeUser` impl.
    type KeySize = <Self as KeySizeUser>::KeySize;

    const ALG_TYPE: AesTypes = AesTypes::A128Gcm;
    const JWK_ALG: &'static str = "A128GCM";
}

/// AES-256-GCM authenticated encryption (JWK `alg` "A256GCM").
pub type A256Gcm = Aes256Gcm;

impl AesType for A256Gcm {
    // The key size is taken from the cipher's own `KeySizeUser` impl.
    type KeySize = <Self as KeySizeUser>::KeySize;

    const ALG_TYPE: AesTypes = AesTypes::A256Gcm;
    const JWK_ALG: &'static str = "A256GCM";
}
195
/// AEAD sizing for AES keys whose algorithm type is itself an AEAD cipher:
/// nonce and tag lengths come straight from the cipher's [`AeadCore`] impl.
impl<T: AeadCore + AesType> KeyAeadMeta for AesKey<T> {
    type NonceSize = <T as AeadCore>::NonceSize;
    type TagSize = <T as AeadCore>::TagSize;
}
201
202impl<T> KeyAeadInPlace for AesKey<T>
204where
205 T: KeyInit + AeadInPlace + AesType<KeySize = <T as KeySizeUser>::KeySize>,
206{
207 fn encrypt_in_place(
209 &self,
210 buffer: &mut dyn ResizeBuffer,
211 nonce: &[u8],
212 aad: &[u8],
213 ) -> Result<usize, Error> {
214 if nonce.len() != T::NonceSize::USIZE {
215 return Err(err_msg!(InvalidNonce));
216 }
217 let enc = <T as KeyInit>::new(self.0.as_ref());
218 let tag = enc
219 .encrypt_in_place_detached(GenericArray::from_slice(nonce), aad, buffer.as_mut())
220 .map_err(|_| err_msg!(Encryption, "AEAD encryption error"))?;
221 let ctext_len = buffer.as_ref().len();
222 buffer.buffer_write(&tag[..])?;
223 Ok(ctext_len)
224 }
225
226 fn decrypt_in_place(
228 &self,
229 buffer: &mut dyn ResizeBuffer,
230 nonce: &[u8],
231 aad: &[u8],
232 ) -> Result<(), Error> {
233 if nonce.len() != T::NonceSize::USIZE {
234 return Err(err_msg!(InvalidNonce));
235 }
236 let buf_len = buffer.as_ref().len();
237 if buf_len < T::TagSize::USIZE {
238 return Err(err_msg!(Encryption, "Invalid size for encrypted data"));
239 }
240 let tag_start = buf_len - T::TagSize::USIZE;
241 let mut tag = GenericArray::default();
242 tag.clone_from_slice(&buffer.as_ref()[tag_start..]);
243 let enc = <T as KeyInit>::new(self.0.as_ref());
244 enc.decrypt_in_place_detached(
245 GenericArray::from_slice(nonce),
246 aad,
247 &mut buffer.as_mut()[..tag_start],
248 &tag,
249 )
250 .map_err(|_| err_msg!(Encryption, "AEAD decryption error"))?;
251 buffer.buffer_resize(tag_start)?;
252 Ok(())
253 }
254
255 fn aead_params(&self) -> KeyAeadParams {
256 KeyAeadParams {
257 nonce_length: T::NonceSize::USIZE,
258 tag_length: T::TagSize::USIZE,
259 }
260 }
261}
262
#[cfg(test)]
mod tests {
    use super::*;
    use crate::buffer::SecretBytes;
    use crate::repr::ToSecretBytes;

    #[test]
    fn encrypt_round_trip() {
        /// Encrypts and decrypts a short message with a fresh random key of
        /// type `T`, then checks that a corrupted ciphertext is rejected.
        fn test_encrypt<T>()
        where
            T: AesType,
            AesKey<T>: KeyAeadInPlace + KeyAeadMeta,
        {
            let input = b"hello";
            let aad = b"additional data";
            let key = AesKey::<T>::random().unwrap();
            let mut buffer = SecretBytes::from_slice(input);
            let params = key.aead_params();
            let pad_len = key.aead_padding(input.len());
            let nonce = AesKey::<T>::random_nonce();
            key.encrypt_in_place(&mut buffer, &nonce, aad).unwrap();
            let enc_len = buffer.len();
            // Output = message + any block padding + appended tag.
            assert_eq!(enc_len, input.len() + pad_len + params.tag_length);
            assert_ne!(&buffer[..], input);
            let mut dec = buffer.clone();
            key.decrypt_in_place(&mut dec, &nonce, aad).unwrap();
            assert_eq!(&dec[..], input);

            // Tamper with the final byte (part of the tag): decryption must fail.
            buffer.as_mut()[enc_len - 1] = buffer.as_mut()[enc_len - 1].wrapping_add(1);
            assert!(key.decrypt_in_place(&mut buffer, &nonce, aad).is_err());
        }
        test_encrypt::<A128Gcm>();
        test_encrypt::<A256Gcm>();
        test_encrypt::<A128CbcHs256>();
        test_encrypt::<A256CbcHs512>();
    }

    #[test]
    fn test_random() {
        // Encrypts into a fixed-capacity `Writer` (rather than a growable
        // buffer), then checks the output length and that decryption restores
        // the original message.
        let key = AesKey::<A128CbcHs256>::random().unwrap();
        let nonce = AesKey::<A128CbcHs256>::random_nonce();
        let message = b"hello there";
        let mut buffer = [0u8; 255];
        buffer[0..message.len()].copy_from_slice(&message[..]);
        let mut writer = Writer::from_slice_position(&mut buffer, message.len());
        key.encrypt_in_place(&mut writer, nonce.as_slice(), &[])
            .unwrap();
        let expected_len =
            message.len() + key.aead_padding(message.len()) + key.aead_params().tag_length;
        assert_eq!(writer.position(), expected_len);
        key.decrypt_in_place(&mut writer, nonce.as_slice(), &[])
            .unwrap();
        assert_eq!(AsRef::<[u8]>::as_ref(&writer), &message[..]);
    }

    #[cfg(feature = "any_key")]
    #[test]
    fn jwk_any_compat() {
        // A JWK decoded through the type-erased `AnyKey` must downcast back to
        // the concrete AES key type and re-export successfully.
        use crate::alg::{any::AnyKey, AesTypes, KeyAlg};
        use alloc::boxed::Box;

        let test_jwk_compat = r#"
            {"alg": "A128CBC-HS256",
            "k": "6scajSsnjo2fI-wjCCvBC2xNSYyErNyN93CAsyzVVGI",
            "kty": "oct"}
        "#;
        let key = Box::<AnyKey>::from_jwk(test_jwk_compat).expect("Error decoding AES key JWK");
        assert_eq!(key.algorithm(), KeyAlg::Aes(AesTypes::A128CbcHs256));
        let as_aes = key
            .downcast_ref::<AesKey<A128CbcHs256>>()
            .expect("Error downcasting AES key");
        let _ = as_aes
            .to_jwk_secret(None)
            .expect("Error converting key to JWK");
    }

    #[test]
    fn serialize_round_trip() {
        /// CBOR-encodes a random key and checks it decodes as a plain byte
        /// string equal to the key's raw secret bytes (transparent serde).
        fn test_serialize<T: AesType>() {
            let key = AesKey::<T>::random().unwrap();
            let sk = key.to_secret_bytes().unwrap();
            let mut bytes = alloc::vec::Vec::new();
            ciborium::into_writer(&key, &mut bytes).unwrap();
            let deser: alloc::vec::Vec<u8> = ciborium::from_reader(&bytes[..]).unwrap();
            assert_eq!(deser, sk.as_ref());
        }
        test_serialize::<A128Gcm>();
        test_serialize::<A256Gcm>();
        test_serialize::<A128CbcHs256>();
        test_serialize::<A256CbcHs512>();
        test_serialize::<A128Kw>();
        test_serialize::<A256Kw>();
    }
}