1use core::fmt::{self, Debug, Formatter};
4
5use aead::{AeadCore, AeadInPlace, KeyInit, KeySizeUser};
6use chacha20poly1305::{ChaCha20Poly1305, XChaCha20Poly1305};
7use serde::{Deserialize, Serialize};
8use zeroize::Zeroize;
9
10use super::{Chacha20Types, HasKeyAlg, HasKeyBackend, KeyAlg};
11use crate::{
12 buffer::{ArrayKey, ResizeBuffer, Writer},
13 encrypt::{KeyAeadInPlace, KeyAeadMeta, KeyAeadParams},
14 error::Error,
15 generic_array::{typenum::Unsigned, GenericArray},
16 jwk::{FromJwk, JwkEncoder, JwkParts, ToJwk},
17 kdf::{FromKeyDerivation, FromKeyExchange, KeyDerivation, KeyExchange},
18 random::KeyMaterial,
19 repr::{KeyGen, KeyMeta, KeySecretBytes},
20};
21
22pub static JWK_KEY_TYPE: &str = "oct";
24
25pub trait Chacha20Type: 'static {
27 type Aead: KeyInit + AeadCore + AeadInPlace;
29
30 const ALG_TYPE: Chacha20Types;
32 const JWK_ALG: &'static str;
34}
35
36#[derive(Debug)]
38pub struct C20P;
39
40impl Chacha20Type for C20P {
41 type Aead = ChaCha20Poly1305;
42
43 const ALG_TYPE: Chacha20Types = Chacha20Types::C20P;
44 const JWK_ALG: &'static str = "C20P";
45}
46
47#[derive(Debug)]
49pub struct XC20P;
50
51impl Chacha20Type for XC20P {
52 type Aead = XChaCha20Poly1305;
53
54 const ALG_TYPE: Chacha20Types = Chacha20Types::XC20P;
55 const JWK_ALG: &'static str = "XC20P";
56}
57
58type KeyType<A> = ArrayKey<<<A as Chacha20Type>::Aead as KeySizeUser>::KeySize>;
59
60type NonceSize<A> = <<A as Chacha20Type>::Aead as AeadCore>::NonceSize;
61
62type TagSize<A> = <<A as Chacha20Type>::Aead as AeadCore>::TagSize;
63
64#[derive(Serialize, Deserialize, Zeroize)]
66#[serde(
67 transparent,
68 bound(
69 deserialize = "KeyType<T>: for<'a> Deserialize<'a>",
70 serialize = "KeyType<T>: Serialize"
71 )
72)]
73pub struct Chacha20Key<T: Chacha20Type>(KeyType<T>);
75
76impl<T: Chacha20Type> Chacha20Key<T> {
77 pub const KEY_LENGTH: usize = KeyType::<T>::SIZE;
79 pub const NONCE_LENGTH: usize = NonceSize::<T>::USIZE;
81 pub const TAG_LENGTH: usize = TagSize::<T>::USIZE;
83}
84
85impl<T: Chacha20Type> Clone for Chacha20Key<T> {
86 fn clone(&self) -> Self {
87 Self(self.0.clone())
88 }
89}
90
91impl<T: Chacha20Type> Debug for Chacha20Key<T> {
92 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
93 f.debug_struct("Chacha20Key")
94 .field("alg", &T::JWK_ALG)
95 .field("key", &self.0)
96 .finish()
97 }
98}
99
100impl<T: Chacha20Type> PartialEq for Chacha20Key<T> {
101 fn eq(&self, other: &Self) -> bool {
102 other.0 == self.0
103 }
104}
105
106impl<T: Chacha20Type> Eq for Chacha20Key<T> {}
107
108impl<T: Chacha20Type> HasKeyBackend for Chacha20Key<T> {}
109
110impl<T: Chacha20Type> HasKeyAlg for Chacha20Key<T> {
111 fn algorithm(&self) -> KeyAlg {
112 KeyAlg::Chacha20(T::ALG_TYPE)
113 }
114}
115
116impl<T: Chacha20Type> KeyMeta for Chacha20Key<T> {
117 type KeySize = <T::Aead as KeySizeUser>::KeySize;
118}
119
120impl<T: Chacha20Type> KeyGen for Chacha20Key<T> {
121 fn generate(rng: impl KeyMaterial) -> Result<Self, Error> {
122 Ok(Chacha20Key(KeyType::<T>::generate(rng)))
123 }
124}
125
126impl<T: Chacha20Type> KeySecretBytes for Chacha20Key<T> {
127 fn from_secret_bytes(key: &[u8]) -> Result<Self, Error> {
128 if key.len() != KeyType::<T>::SIZE {
129 return Err(err_msg!(InvalidKeyData));
130 }
131 Ok(Self(KeyType::<T>::from_slice(key)))
132 }
133
134 fn with_secret_bytes<O>(&self, f: impl FnOnce(Option<&[u8]>) -> O) -> O {
135 f(Some(self.0.as_ref()))
136 }
137}
138
139impl<T: Chacha20Type> FromKeyDerivation for Chacha20Key<T> {
140 fn from_key_derivation<D: KeyDerivation>(mut derive: D) -> Result<Self, Error>
141 where
142 Self: Sized,
143 {
144 Ok(Self(KeyType::<T>::try_new_with(|arr| {
145 derive.derive_key_bytes(arr)
146 })?))
147 }
148}
149
150impl<T: Chacha20Type> KeyAeadMeta for Chacha20Key<T> {
151 type NonceSize = NonceSize<T>;
152 type TagSize = TagSize<T>;
153}
154
155impl<T: Chacha20Type> KeyAeadInPlace for Chacha20Key<T> {
156 fn encrypt_in_place(
158 &self,
159 buffer: &mut dyn ResizeBuffer,
160 nonce: &[u8],
161 aad: &[u8],
162 ) -> Result<usize, Error> {
163 if nonce.len() != NonceSize::<T>::USIZE {
164 return Err(err_msg!(InvalidNonce));
165 }
166 let nonce = GenericArray::from_slice(nonce);
167 let chacha = T::Aead::new(self.0.as_ref());
168 let tag = chacha
169 .encrypt_in_place_detached(nonce, aad, buffer.as_mut())
170 .map_err(|_| err_msg!(Encryption, "AEAD encryption error"))?;
171 let ctext_len = buffer.as_ref().len();
172 buffer.buffer_write(&tag[..])?;
173 Ok(ctext_len)
174 }
175
176 fn decrypt_in_place(
178 &self,
179 buffer: &mut dyn ResizeBuffer,
180 nonce: &[u8],
181 aad: &[u8],
182 ) -> Result<(), Error> {
183 if nonce.len() != NonceSize::<T>::USIZE {
184 return Err(err_msg!(InvalidNonce));
185 }
186 let nonce = GenericArray::from_slice(nonce);
187 let buf_len = buffer.as_ref().len();
188 if buf_len < TagSize::<T>::USIZE {
189 return Err(err_msg!(Invalid, "Invalid size for encrypted data"));
190 }
191 let tag_start = buf_len - TagSize::<T>::USIZE;
192 let mut tag = GenericArray::default();
193 tag.clone_from_slice(&buffer.as_ref()[tag_start..]);
194 let chacha = T::Aead::new(self.0.as_ref());
195 chacha
196 .decrypt_in_place_detached(nonce, aad, &mut buffer.as_mut()[..tag_start], &tag)
197 .map_err(|_| err_msg!(Encryption, "AEAD decryption error"))?;
198 buffer.buffer_resize(tag_start)?;
199 Ok(())
200 }
201
202 fn aead_params(&self) -> KeyAeadParams {
203 KeyAeadParams {
204 nonce_length: NonceSize::<T>::USIZE,
205 tag_length: TagSize::<T>::USIZE,
206 }
207 }
208}
209
210impl<T: Chacha20Type> FromJwk for Chacha20Key<T> {
211 fn from_jwk_parts(jwk: JwkParts<'_>) -> Result<Self, Error> {
212 if jwk.kty != JWK_KEY_TYPE {
213 return Err(err_msg!(InvalidKeyData, "Unsupported key type"));
214 }
215 if jwk.alg.is_some() && jwk.alg != T::JWK_ALG {
216 return Err(err_msg!(InvalidKeyData, "Unsupported key algorithm"));
217 }
218 Ok(Self(ArrayKey::try_new_with(|buf| {
219 if jwk.k.decode_base64(buf)? != buf.len() {
220 Err(err_msg!(InvalidKeyData))
221 } else {
222 Ok(())
223 }
224 })?))
225 }
226}
227
228impl<T: Chacha20Type> ToJwk for Chacha20Key<T> {
229 fn encode_jwk(&self, enc: &mut dyn JwkEncoder) -> Result<(), Error> {
230 if enc.is_public() {
231 return Err(err_msg!(Unsupported, "Cannot export as a public key"));
232 }
233 if !enc.is_thumbprint() {
234 enc.add_str("alg", T::JWK_ALG)?;
235 }
236 enc.add_as_base64("k", self.0.as_ref())?;
237 enc.add_str("kty", JWK_KEY_TYPE)?;
238 Ok(())
239 }
240}
241
242impl<Lhs, Rhs, T> FromKeyExchange<Lhs, Rhs> for Chacha20Key<T>
244where
245 Lhs: KeyExchange<Rhs> + ?Sized,
246 Rhs: ?Sized,
247 T: Chacha20Type,
248{
249 fn from_key_exchange(lhs: &Lhs, rhs: &Rhs) -> Result<Self, Error> {
250 Ok(Self(KeyType::<T>::try_new_with(|arr| {
251 let mut buf = Writer::from_slice(arr);
252 lhs.write_key_exchange(rhs, &mut buf)?;
253 if buf.position() != Self::KEY_LENGTH {
254 return Err(err_msg!(Usage, "Invalid length for key exchange output"));
255 }
256 Ok(())
257 })?))
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use crate::buffer::SecretBytes;
    use crate::repr::ToSecretBytes;

    /// In-place encryption appends a `TAG_LENGTH`-byte tag and changes the
    /// data. Decrypting with the same nonce restores the plaintext.
    #[test]
    fn encrypt_round_trip() {
        fn round_trip<T: Chacha20Type>() {
            let plaintext = b"hello";
            let key = Chacha20Key::<T>::random().unwrap();
            let mut buf = SecretBytes::from_slice(plaintext);
            let nonce = Chacha20Key::<T>::random_nonce();

            key.encrypt_in_place(&mut buf, &nonce, &[]).unwrap();
            let expected_len = plaintext.len() + Chacha20Key::<T>::TAG_LENGTH;
            assert_eq!(buf.len(), expected_len);
            assert_ne!(&buf[..], plaintext);

            key.decrypt_in_place(&mut buf, &nonce, &[]).unwrap();
            assert_eq!(&buf[..], plaintext);
        }

        round_trip::<C20P>();
        round_trip::<XC20P>();
    }

    /// An XC20P JWK loads as an `AnyKey`, downcasts to the concrete key type,
    /// and exports back to a secret JWK.
    #[cfg(feature = "any_key")]
    #[test]
    fn jwk_any_compat() {
        use crate::alg::{any::AnyKey, Chacha20Types, KeyAlg};
        use alloc::boxed::Box;

        let jwk = r#"
            {"alg": "XC20P",
            "k": "IateWalmifmgIAtA6XhbPVKPmjBUiwrs3p0ePHpMivU",
            "kty": "oct"}
        "#;
        let any_key = Box::<AnyKey>::from_jwk(jwk).expect("Error decoding ChaCha key JWK");
        assert_eq!(any_key.algorithm(), KeyAlg::Chacha20(Chacha20Types::XC20P));

        let concrete = any_key
            .downcast_ref::<Chacha20Key<XC20P>>()
            .expect("Error downcasting ChaCha key");
        let _ = concrete
            .to_jwk_secret(None)
            .expect("Error converting key to JWK");
    }

    /// CBOR serialization of a key decodes to exactly its secret bytes.
    #[test]
    fn serialize_round_trip() {
        fn check<T: Chacha20Type>() {
            let key = Chacha20Key::<T>::random().unwrap();
            let secret = key.to_secret_bytes().unwrap();

            let mut encoded = vec![];
            ciborium::into_writer(&key, &mut encoded).unwrap();
            let decoded: alloc::vec::Vec<u8> = ciborium::from_reader(&encoded[..]).unwrap();
            assert_eq!(decoded, secret.as_ref());
        }

        check::<C20P>();
        check::<XC20P>();
    }
}