mls_rs_crypto_rustcrypto/
aead.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5extern crate aead as rc_aead;
6
7use core::fmt::Debug;
8
9use aes_gcm::{Aes128Gcm, Aes256Gcm, KeyInit};
10use chacha20poly1305::ChaCha20Poly1305;
11use mls_rs_core::{crypto::CipherSuite, error::IntoAnyError};
12use mls_rs_crypto_traits::{AeadId, AeadType, AES_TAG_LEN};
13use rc_aead::{generic_array::GenericArray, Payload};
14
15use alloc::vec::Vec;
16
17#[derive(Debug)]
18#[cfg_attr(feature = "std", derive(thiserror::Error))]
19pub enum AeadError {
20    #[cfg_attr(feature = "std", error("Rc AEAD Error"))]
21    RcAeadError(rc_aead::Error),
22    #[cfg_attr(
23        feature = "std",
24        error("AEAD ciphertext of length {0} is too short to fit the tag")
25    )]
26    InvalidCipherLen(usize),
27    #[cfg_attr(feature = "std", error("encrypted message cannot be empty"))]
28    EmptyPlaintext,
29    #[cfg_attr(
30        feature = "std",
31        error("AEAD key of invalid length {0}. Expected length {1}")
32    )]
33    InvalidKeyLen(usize, usize),
34    #[cfg_attr(feature = "std", error("unsupported cipher suite"))]
35    UnsupportedCipherSuite,
36}
37
38impl From<rc_aead::Error> for AeadError {
39    fn from(value: rc_aead::Error) -> Self {
40        AeadError::RcAeadError(value)
41    }
42}
43
44impl IntoAnyError for AeadError {
45    #[cfg(feature = "std")]
46    fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
47        Ok(self.into())
48    }
49}
50#[derive(Clone, Copy, Debug, Eq, PartialEq)]
51pub struct Aead(AeadId);
52
53impl Aead {
54    pub fn new(cipher_suite: CipherSuite) -> Option<Self> {
55        AeadId::new(cipher_suite).map(Self)
56    }
57}
58
59#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
60#[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
61#[cfg_attr(
62    all(not(target_arch = "wasm32"), mls_build_async),
63    maybe_async::must_be_async
64)]
65impl AeadType for Aead {
66    type Error = AeadError;
67
68    #[allow(clippy::needless_lifetimes)]
69    async fn seal<'a>(
70        &self,
71        key: &[u8],
72        data: &[u8],
73        aad: Option<&'a [u8]>,
74        nonce: &[u8],
75    ) -> Result<Vec<u8>, AeadError> {
76        (!data.is_empty())
77            .then_some(())
78            .ok_or(AeadError::EmptyPlaintext)?;
79
80        (key.len() == self.key_size())
81            .then_some(())
82            .ok_or_else(|| AeadError::InvalidKeyLen(key.len(), self.key_size()))?;
83
84        match self.0 {
85            AeadId::Aes128Gcm => {
86                let cipher = Aes128Gcm::new(GenericArray::from_slice(key));
87                encrypt_aead_trait(cipher, data, aad, nonce)
88            }
89            AeadId::Aes256Gcm => {
90                let cipher = Aes256Gcm::new(GenericArray::from_slice(key));
91                encrypt_aead_trait(cipher, data, aad, nonce)
92            }
93            AeadId::Chacha20Poly1305 => {
94                let cipher = ChaCha20Poly1305::new(GenericArray::from_slice(key));
95                encrypt_aead_trait(cipher, data, aad, nonce)
96            }
97            _ => Err(AeadError::UnsupportedCipherSuite),
98        }
99    }
100
101    #[allow(clippy::needless_lifetimes)]
102    async fn open<'a>(
103        &self,
104        key: &[u8],
105        ciphertext: &[u8],
106        aad: Option<&'a [u8]>,
107        nonce: &[u8],
108    ) -> Result<Vec<u8>, AeadError> {
109        (ciphertext.len() > AES_TAG_LEN)
110            .then_some(())
111            .ok_or(AeadError::InvalidCipherLen(ciphertext.len()))?;
112
113        (key.len() == self.key_size())
114            .then_some(())
115            .ok_or_else(|| AeadError::InvalidKeyLen(key.len(), self.key_size()))?;
116
117        match self.0 {
118            AeadId::Aes128Gcm => {
119                let cipher = Aes128Gcm::new(GenericArray::from_slice(key));
120                decrypt_aead_trait(cipher, ciphertext, aad, nonce)
121            }
122            AeadId::Aes256Gcm => {
123                let cipher = Aes256Gcm::new(GenericArray::from_slice(key));
124                decrypt_aead_trait(cipher, ciphertext, aad, nonce)
125            }
126            AeadId::Chacha20Poly1305 => {
127                let cipher = ChaCha20Poly1305::new(GenericArray::from_slice(key));
128                decrypt_aead_trait(cipher, ciphertext, aad, nonce)
129            }
130            _ => Err(AeadError::UnsupportedCipherSuite),
131        }
132    }
133
134    #[inline(always)]
135    fn key_size(&self) -> usize {
136        self.0.key_size()
137    }
138
139    fn nonce_size(&self) -> usize {
140        self.0.nonce_size()
141    }
142
143    fn aead_id(&self) -> u16 {
144        self.0 as u16
145    }
146}
147
148fn encrypt_aead_trait(
149    cipher: impl rc_aead::Aead,
150    data: &[u8],
151    aad: Option<&[u8]>,
152    nonce: &[u8],
153) -> Result<Vec<u8>, AeadError> {
154    let payload = Payload {
155        msg: data,
156        aad: aad.unwrap_or_default(),
157    };
158
159    Ok(cipher.encrypt(GenericArray::from_slice(nonce), payload)?)
160}
161
162fn decrypt_aead_trait(
163    cipher: impl rc_aead::Aead,
164    ciphertext: &[u8],
165    aad: Option<&[u8]>,
166    nonce: &[u8],
167) -> Result<Vec<u8>, AeadError> {
168    let payload = Payload {
169        msg: ciphertext,
170        aad: aad.unwrap_or_default(),
171    };
172
173    Ok(cipher.decrypt(GenericArray::from_slice(nonce), payload)?)
174}
175
// Unit tests for the synchronous build; they call `seal`/`open` directly
// without awaiting.
#[cfg(all(not(mls_build_async), test))]
mod test {
    use mls_rs_core::crypto::CipherSuite;
    use mls_rs_crypto_traits::{AeadType, AES_TAG_LEN};

    use super::{Aead, AeadError};

    use assert_matches::assert_matches;

    use alloc::vec;
    use alloc::vec::Vec;

    /// One `Aead` per cipher suite exercised by these tests.
    fn get_aeads() -> Vec<Aead> {
        [
            CipherSuite::CURVE25519_AES128,
            CipherSuite::CURVE25519_CHACHA,
            CipherSuite::CURVE448_AES256,
        ]
        .into_iter()
        .map(|cs| Aead::new(cs).unwrap())
        .collect()
    }

    #[test]
    fn invalid_key() {
        for aead in get_aeads() {
            let nonce = vec![42u8; aead.nonce_size()];
            let data = b"top secret";
            // Long enough to get past the ciphertext-length check in `open`,
            // so the key-length check is what rejects it.
            let ciphertext = vec![0u8; AES_TAG_LEN + 1];

            let too_short = vec![42u8; aead.key_size() - 1];

            assert_matches!(
                aead.seal(&too_short, data, None, &nonce),
                Err(AeadError::InvalidKeyLen(_, _))
            );

            assert_matches!(
                aead.open(&too_short, &ciphertext, None, &nonce),
                Err(AeadError::InvalidKeyLen(_, _))
            );

            let too_long = vec![42u8; aead.key_size() + 1];

            assert_matches!(
                aead.seal(&too_long, data, None, &nonce),
                Err(AeadError::InvalidKeyLen(_, _))
            );

            assert_matches!(
                aead.open(&too_long, &ciphertext, None, &nonce),
                Err(AeadError::InvalidKeyLen(_, _))
            );
        }
    }

    #[test]
    fn invalid_ciphertext() {
        for aead in get_aeads() {
            let key = vec![42u8; aead.key_size()];
            let nonce = vec![42u8; aead.nonce_size()];

            let too_short = [0u8; AES_TAG_LEN];

            assert_matches!(
                aead.open(&key, &too_short, None, &nonce),
                Err(AeadError::InvalidCipherLen(_))
            );
        }
    }

    #[test]
    fn empty_plaintext() {
        for aead in get_aeads() {
            let key = vec![42u8; aead.key_size()];
            let nonce = vec![42u8; aead.nonce_size()];

            assert_matches!(
                aead.seal(&key, b"", None, &nonce),
                Err(AeadError::EmptyPlaintext)
            );
        }
    }

    #[test]
    fn round_trip() {
        for aead in get_aeads() {
            let key = vec![42u8; aead.key_size()];
            let nonce = vec![42u8; aead.nonce_size()];
            let message = b"message";

            // With associated data.
            let ciphertext = aead.seal(&key, message, Some(b"foo"), &nonce).unwrap();
            let opened = aead.open(&key, &ciphertext, Some(b"foo"), &nonce).unwrap();
            assert_eq!(&opened[..], &message[..]);

            // Without associated data.
            let ciphertext = aead.seal(&key, message, None, &nonce).unwrap();
            let opened = aead.open(&key, &ciphertext, None, &nonce).unwrap();
            assert_eq!(&opened[..], &message[..]);
        }
    }

    #[test]
    fn aad_mismatch() {
        for aead in get_aeads() {
            let key = vec![42u8; aead.key_size()];
            let nonce = vec![42u8; aead.nonce_size()];

            let ciphertext = aead.seal(&key, b"message", Some(b"foo"), &nonce).unwrap();

            assert_matches!(
                aead.open(&key, &ciphertext, Some(b"bar"), &nonce),
                Err(AeadError::RcAeadError(_))
            );

            assert_matches!(
                aead.open(&key, &ciphertext, None, &nonce),
                Err(AeadError::RcAeadError(_))
            );
        }
    }
}
253        }
254    }
255}