mls_rs_crypto_openssl/
aead.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5use std::{fmt::Debug, ops::Deref};
6
7use mls_rs_core::{crypto::CipherSuite, error::IntoAnyError};
8use mls_rs_crypto_traits::{AeadId, AeadType, AES_TAG_LEN};
9use openssl::symm::{decrypt_aead, encrypt_aead, Cipher};
10use thiserror::Error;
11
12#[derive(Debug, Error)]
13pub enum AeadError {
14    #[error(transparent)]
15    OpensslError(#[from] openssl::error::ErrorStack),
16    #[error("AEAD ciphertext of length {0} is too short to fit the tag")]
17    InvalidCipherLen(usize),
18    #[error("encrypted message cannot be empty")]
19    EmptyPlaintext,
20    #[error("unsupported cipher suite")]
21    UnsupportedCipherSuite,
22}
23
24impl IntoAnyError for AeadError {
25    fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
26        Ok(self.into())
27    }
28}
29
30#[derive(Clone)]
31pub struct Aead {
32    cipher: Cipher,
33    aead_id: AeadId,
34}
35
36impl Debug for Aead {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        write!(f, "Aead with aead_id {:?}", self.aead_id)
39    }
40}
41
42impl Deref for Aead {
43    type Target = Cipher;
44
45    fn deref(&self) -> &Self::Target {
46        &self.cipher
47    }
48}
49
50impl Aead {
51    pub fn new(cipher_suite: CipherSuite) -> Option<Self> {
52        let aead_id = AeadId::new(cipher_suite)?;
53
54        let cipher = match aead_id {
55            AeadId::Aes128Gcm => Some(Cipher::aes_128_gcm()),
56            AeadId::Aes256Gcm => Some(Cipher::aes_256_gcm()),
57            AeadId::Chacha20Poly1305 => Some(Cipher::chacha20_poly1305()),
58            _ => None,
59        };
60
61        cipher.map(|cipher| Self { cipher, aead_id })
62    }
63}
64
65#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
66#[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
67#[cfg_attr(
68    all(not(target_arch = "wasm32"), mls_build_async),
69    maybe_async::must_be_async
70)]
71impl AeadType for Aead {
72    type Error = AeadError;
73
74    #[allow(clippy::needless_lifetimes)]
75    async fn seal<'a>(
76        &self,
77        key: &[u8],
78        data: &[u8],
79        aad: Option<&'a [u8]>,
80        nonce: &[u8],
81    ) -> Result<Vec<u8>, AeadError> {
82        (!data.is_empty())
83            .then_some(())
84            .ok_or(AeadError::EmptyPlaintext)?;
85
86        let mut tag = [0u8; AES_TAG_LEN];
87        let aad = aad.unwrap_or_default();
88
89        let ciphertext = encrypt_aead(self.cipher, key, Some(nonce), aad, data, &mut tag)?;
90
91        // Question Is this how this should be done? Or other encodings?
92        Ok([&ciphertext, &tag as &[u8]].concat())
93    }
94
95    #[allow(clippy::needless_lifetimes)]
96    async fn open<'a>(
97        &self,
98        key: &[u8],
99        ciphertext: &[u8],
100        aad: Option<&'a [u8]>,
101        nonce: &[u8],
102    ) -> Result<Vec<u8>, AeadError> {
103        (ciphertext.len() > AES_TAG_LEN)
104            .then_some(())
105            .ok_or(AeadError::InvalidCipherLen(ciphertext.len()))?;
106
107        let (data, tag) = ciphertext.split_at(ciphertext.len() - AES_TAG_LEN);
108        let aad = aad.unwrap_or_default();
109
110        decrypt_aead(self.cipher, key, Some(nonce), aad, data, tag).map_err(Into::into)
111    }
112
113    fn key_size(&self) -> usize {
114        self.key_len()
115    }
116
117    fn nonce_size(&self) -> usize {
118        self.iv_len()
119            .expect("The ciphersuite's AEAD algorithm must support nonce-based encryption.")
120    }
121
122    fn aead_id(&self) -> u16 {
123        self.aead_id as u16
124    }
125}
126
#[cfg(all(not(mls_build_async), test))]
mod test {
    use mls_rs_core::crypto::CipherSuite;
    use mls_rs_crypto_traits::{AeadType, AES_TAG_LEN};

    use super::{Aead, AeadError};

    use assert_matches::assert_matches;

    /// Builds one `Aead` per cipher suite exercised by the tests below.
    fn get_aeads() -> Vec<Aead> {
        let suites = [
            CipherSuite::CURVE25519_AES128,
            CipherSuite::CURVE25519_CHACHA,
            CipherSuite::CURVE448_AES256,
        ];

        let mut aeads = Vec::with_capacity(suites.len());

        for suite in suites {
            aeads.push(Aead::new(suite).unwrap());
        }

        aeads
    }

    /// Sealing with a key one byte too short or one byte too long must fail
    /// inside OpenSSL.
    #[test]
    fn invalid_key() {
        for aead in get_aeads() {
            let nonce = vec![42u8; aead.nonce_size()];
            let expected_len = aead.key_size();

            for bad_len in [expected_len - 1, expected_len + 1] {
                let bad_key = vec![42u8; bad_len];

                assert_matches!(
                    aead.seal(&bad_key, b"top secret", None, &nonce),
                    Err(AeadError::OpensslError(_))
                );
            }
        }
    }

    /// An input consisting of exactly a tag's worth of bytes leaves no room
    /// for ciphertext and must be rejected before reaching OpenSSL.
    #[test]
    fn invalid_ciphertext() {
        let tag_only = [0u8; AES_TAG_LEN];

        for aead in get_aeads() {
            let nonce = vec![42u8; aead.nonce_size()];
            let key = vec![42u8; aead.key_size()];

            assert_matches!(
                aead.open(&key, &tag_only, None, &nonce),
                Err(AeadError::InvalidCipherLen(_))
            );
        }
    }

    /// Opening with different or missing associated data must fail.
    #[test]
    fn aad_mismatch() {
        for aead in get_aeads() {
            let nonce = vec![42u8; aead.nonce_size()];
            let key = vec![42u8; aead.key_size()];

            let sealed = aead.seal(&key, b"message", Some(b"foo"), &nonce).unwrap();

            for wrong_aad in [Some(&b"bar"[..]), None] {
                assert_matches!(
                    aead.open(&key, &sealed, wrong_aad, &nonce),
                    Err(AeadError::OpensslError(_))
                );
            }
        }
    }
}