mls_rs_crypto_rustcrypto/
aead.rs1extern crate aead as rc_aead;
6
7use core::fmt::Debug;
8
9use aes_gcm::{Aes128Gcm, Aes256Gcm, KeyInit};
10use chacha20poly1305::ChaCha20Poly1305;
11use mls_rs_core::{crypto::CipherSuite, error::IntoAnyError};
12use mls_rs_crypto_traits::{AeadId, AeadType, AES_TAG_LEN};
13use rc_aead::{generic_array::GenericArray, Payload};
14
15use alloc::vec::Vec;
16
17#[derive(Debug)]
18#[cfg_attr(feature = "std", derive(thiserror::Error))]
19pub enum AeadError {
20 #[cfg_attr(feature = "std", error("Rc AEAD Error"))]
21 RcAeadError(rc_aead::Error),
22 #[cfg_attr(
23 feature = "std",
24 error("AEAD ciphertext of length {0} is too short to fit the tag")
25 )]
26 InvalidCipherLen(usize),
27 #[cfg_attr(feature = "std", error("encrypted message cannot be empty"))]
28 EmptyPlaintext,
29 #[cfg_attr(
30 feature = "std",
31 error("AEAD key of invalid length {0}. Expected length {1}")
32 )]
33 InvalidKeyLen(usize, usize),
34 #[cfg_attr(feature = "std", error("unsupported cipher suite"))]
35 UnsupportedCipherSuite,
36}
37
38impl From<rc_aead::Error> for AeadError {
39 fn from(value: rc_aead::Error) -> Self {
40 AeadError::RcAeadError(value)
41 }
42}
43
44impl IntoAnyError for AeadError {
45 #[cfg(feature = "std")]
46 fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
47 Ok(self.into())
48 }
49}
50#[derive(Clone, Copy, Debug, Eq, PartialEq)]
51pub struct Aead(AeadId);
52
53impl Aead {
54 pub fn new(cipher_suite: CipherSuite) -> Option<Self> {
55 AeadId::new(cipher_suite).map(Self)
56 }
57}
58
59#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
60#[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
61#[cfg_attr(
62 all(not(target_arch = "wasm32"), mls_build_async),
63 maybe_async::must_be_async
64)]
65impl AeadType for Aead {
66 type Error = AeadError;
67
68 #[allow(clippy::needless_lifetimes)]
69 async fn seal<'a>(
70 &self,
71 key: &[u8],
72 data: &[u8],
73 aad: Option<&'a [u8]>,
74 nonce: &[u8],
75 ) -> Result<Vec<u8>, AeadError> {
76 (!data.is_empty())
77 .then_some(())
78 .ok_or(AeadError::EmptyPlaintext)?;
79
80 (key.len() == self.key_size())
81 .then_some(())
82 .ok_or_else(|| AeadError::InvalidKeyLen(key.len(), self.key_size()))?;
83
84 match self.0 {
85 AeadId::Aes128Gcm => {
86 let cipher = Aes128Gcm::new(GenericArray::from_slice(key));
87 encrypt_aead_trait(cipher, data, aad, nonce)
88 }
89 AeadId::Aes256Gcm => {
90 let cipher = Aes256Gcm::new(GenericArray::from_slice(key));
91 encrypt_aead_trait(cipher, data, aad, nonce)
92 }
93 AeadId::Chacha20Poly1305 => {
94 let cipher = ChaCha20Poly1305::new(GenericArray::from_slice(key));
95 encrypt_aead_trait(cipher, data, aad, nonce)
96 }
97 _ => Err(AeadError::UnsupportedCipherSuite),
98 }
99 }
100
101 #[allow(clippy::needless_lifetimes)]
102 async fn open<'a>(
103 &self,
104 key: &[u8],
105 ciphertext: &[u8],
106 aad: Option<&'a [u8]>,
107 nonce: &[u8],
108 ) -> Result<Vec<u8>, AeadError> {
109 (ciphertext.len() > AES_TAG_LEN)
110 .then_some(())
111 .ok_or(AeadError::InvalidCipherLen(ciphertext.len()))?;
112
113 (key.len() == self.key_size())
114 .then_some(())
115 .ok_or_else(|| AeadError::InvalidKeyLen(key.len(), self.key_size()))?;
116
117 match self.0 {
118 AeadId::Aes128Gcm => {
119 let cipher = Aes128Gcm::new(GenericArray::from_slice(key));
120 decrypt_aead_trait(cipher, ciphertext, aad, nonce)
121 }
122 AeadId::Aes256Gcm => {
123 let cipher = Aes256Gcm::new(GenericArray::from_slice(key));
124 decrypt_aead_trait(cipher, ciphertext, aad, nonce)
125 }
126 AeadId::Chacha20Poly1305 => {
127 let cipher = ChaCha20Poly1305::new(GenericArray::from_slice(key));
128 decrypt_aead_trait(cipher, ciphertext, aad, nonce)
129 }
130 _ => Err(AeadError::UnsupportedCipherSuite),
131 }
132 }
133
134 #[inline(always)]
135 fn key_size(&self) -> usize {
136 self.0.key_size()
137 }
138
139 fn nonce_size(&self) -> usize {
140 self.0.nonce_size()
141 }
142
143 fn aead_id(&self) -> u16 {
144 self.0 as u16
145 }
146}
147
148fn encrypt_aead_trait(
149 cipher: impl rc_aead::Aead,
150 data: &[u8],
151 aad: Option<&[u8]>,
152 nonce: &[u8],
153) -> Result<Vec<u8>, AeadError> {
154 let payload = Payload {
155 msg: data,
156 aad: aad.unwrap_or_default(),
157 };
158
159 Ok(cipher.encrypt(GenericArray::from_slice(nonce), payload)?)
160}
161
162fn decrypt_aead_trait(
163 cipher: impl rc_aead::Aead,
164 ciphertext: &[u8],
165 aad: Option<&[u8]>,
166 nonce: &[u8],
167) -> Result<Vec<u8>, AeadError> {
168 let payload = Payload {
169 msg: ciphertext,
170 aad: aad.unwrap_or_default(),
171 };
172
173 Ok(cipher.decrypt(GenericArray::from_slice(nonce), payload)?)
174}
175
#[cfg(all(not(mls_build_async), test))]
mod test {
    use mls_rs_core::crypto::CipherSuite;
    use mls_rs_crypto_traits::{AeadType, AES_TAG_LEN};

    use super::{Aead, AeadError};

    use assert_matches::assert_matches;

    use alloc::vec;
    use alloc::vec::Vec;

    /// One `Aead` per supported algorithm: AES-128-GCM, ChaCha20-Poly1305 and
    /// AES-256-GCM.
    fn get_aeads() -> Vec<Aead> {
        let suites = [
            CipherSuite::CURVE25519_AES128,
            CipherSuite::CURVE25519_CHACHA,
            CipherSuite::CURVE448_AES256,
        ];

        suites
            .into_iter()
            .map(|suite| Aead::new(suite).expect("test cipher suite should have an AEAD"))
            .collect()
    }

    /// Keys one byte shorter or longer than required are rejected by `seal`.
    #[test]
    fn invalid_key() {
        for aead in get_aeads() {
            let nonce = vec![42u8; aead.nonce_size()];
            let plaintext = b"top secret";

            for bad_len in [aead.key_size() - 1, aead.key_size() + 1] {
                let bad_key = vec![42u8; bad_len];

                assert_matches!(
                    aead.seal(&bad_key, plaintext, None, &nonce),
                    Err(AeadError::InvalidKeyLen(_, _))
                );
            }
        }
    }

    /// A ciphertext no longer than the tag is rejected by `open`.
    #[test]
    fn invalid_ciphertext() {
        for aead in get_aeads() {
            let key = vec![42u8; aead.key_size()];
            let nonce = vec![42u8; aead.nonce_size()];
            let tag_only = [0u8; AES_TAG_LEN];

            assert_matches!(
                aead.open(&key, &tag_only, None, &nonce),
                Err(AeadError::InvalidCipherLen(_))
            );
        }
    }

    /// Opening with different or missing AAD fails authentication.
    #[test]
    fn aad_mismatch() {
        for aead in get_aeads() {
            let key = vec![42u8; aead.key_size()];
            let nonce = vec![42u8; aead.nonce_size()];

            let sealed = aead
                .seal(&key, b"message", Some(b"foo"), &nonce)
                .unwrap();

            let wrong_aad = aead.open(&key, &sealed, Some(b"bar"), &nonce);
            assert_matches!(wrong_aad, Err(AeadError::RcAeadError(_)));

            let missing_aad = aead.open(&key, &sealed, None, &nonce);
            assert_matches!(missing_aad, Err(AeadError::RcAeadError(_)));
        }
    }
}