mls_rs_crypto_openssl/
aead.rs1use std::{fmt::Debug, ops::Deref};
6
7use mls_rs_core::{crypto::CipherSuite, error::IntoAnyError};
8use mls_rs_crypto_traits::{AeadId, AeadType, AES_TAG_LEN};
9use openssl::symm::{decrypt_aead, encrypt_aead, Cipher};
10use thiserror::Error;
11
12#[derive(Debug, Error)]
13pub enum AeadError {
14 #[error(transparent)]
15 OpensslError(#[from] openssl::error::ErrorStack),
16 #[error("AEAD ciphertext of length {0} is too short to fit the tag")]
17 InvalidCipherLen(usize),
18 #[error("encrypted message cannot be empty")]
19 EmptyPlaintext,
20 #[error("unsupported cipher suite")]
21 UnsupportedCipherSuite,
22}
23
24impl IntoAnyError for AeadError {
25 fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
26 Ok(self.into())
27 }
28}
29
30#[derive(Clone)]
31pub struct Aead {
32 cipher: Cipher,
33 aead_id: AeadId,
34}
35
36impl Debug for Aead {
37 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38 write!(f, "Aead with aead_id {:?}", self.aead_id)
39 }
40}
41
42impl Deref for Aead {
43 type Target = Cipher;
44
45 fn deref(&self) -> &Self::Target {
46 &self.cipher
47 }
48}
49
50impl Aead {
51 pub fn new(cipher_suite: CipherSuite) -> Option<Self> {
52 let aead_id = AeadId::new(cipher_suite)?;
53
54 let cipher = match aead_id {
55 AeadId::Aes128Gcm => Some(Cipher::aes_128_gcm()),
56 AeadId::Aes256Gcm => Some(Cipher::aes_256_gcm()),
57 AeadId::Chacha20Poly1305 => Some(Cipher::chacha20_poly1305()),
58 _ => None,
59 };
60
61 cipher.map(|cipher| Self { cipher, aead_id })
62 }
63}
64
65#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
66#[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
67#[cfg_attr(
68 all(not(target_arch = "wasm32"), mls_build_async),
69 maybe_async::must_be_async
70)]
71impl AeadType for Aead {
72 type Error = AeadError;
73
74 #[allow(clippy::needless_lifetimes)]
75 async fn seal<'a>(
76 &self,
77 key: &[u8],
78 data: &[u8],
79 aad: Option<&'a [u8]>,
80 nonce: &[u8],
81 ) -> Result<Vec<u8>, AeadError> {
82 (!data.is_empty())
83 .then_some(())
84 .ok_or(AeadError::EmptyPlaintext)?;
85
86 let mut tag = [0u8; AES_TAG_LEN];
87 let aad = aad.unwrap_or_default();
88
89 let ciphertext = encrypt_aead(self.cipher, key, Some(nonce), aad, data, &mut tag)?;
90
91 Ok([&ciphertext, &tag as &[u8]].concat())
93 }
94
95 #[allow(clippy::needless_lifetimes)]
96 async fn open<'a>(
97 &self,
98 key: &[u8],
99 ciphertext: &[u8],
100 aad: Option<&'a [u8]>,
101 nonce: &[u8],
102 ) -> Result<Vec<u8>, AeadError> {
103 (ciphertext.len() > AES_TAG_LEN)
104 .then_some(())
105 .ok_or(AeadError::InvalidCipherLen(ciphertext.len()))?;
106
107 let (data, tag) = ciphertext.split_at(ciphertext.len() - AES_TAG_LEN);
108 let aad = aad.unwrap_or_default();
109
110 decrypt_aead(self.cipher, key, Some(nonce), aad, data, tag).map_err(Into::into)
111 }
112
113 fn key_size(&self) -> usize {
114 self.key_len()
115 }
116
117 fn nonce_size(&self) -> usize {
118 self.iv_len()
119 .expect("The ciphersuite's AEAD algorithm must support nonce-based encryption.")
120 }
121
122 fn aead_id(&self) -> u16 {
123 self.aead_id as u16
124 }
125}
126
#[cfg(all(not(mls_build_async), test))]
mod test {
    use mls_rs_core::crypto::CipherSuite;
    use mls_rs_crypto_traits::{AeadType, AES_TAG_LEN};

    use super::{Aead, AeadError};

    use assert_matches::assert_matches;

    /// One `Aead` per algorithm handled by `Aead::new`:
    /// AES-128-GCM, ChaCha20-Poly1305 and AES-256-GCM.
    fn get_aeads() -> Vec<Aead> {
        [
            CipherSuite::CURVE25519_AES128,
            CipherSuite::CURVE25519_CHACHA,
            CipherSuite::CURVE448_AES256,
        ]
        .into_iter()
        .map(|v| Aead::new(v).unwrap())
        .collect()
    }

    /// Sealing followed by opening with the same key, nonce and AAD must
    /// return the original plaintext, and the sealed output must be exactly
    /// one tag longer than the plaintext.
    #[test]
    fn round_trip() {
        for aead in get_aeads() {
            let key = vec![42u8; aead.key_size()];
            let nonce = vec![42u8; aead.nonce_size()];
            let plaintext = b"message";

            for aad in [None, Some(&b"foo"[..])] {
                let ciphertext = aead.seal(&key, plaintext, aad, &nonce).unwrap();

                assert_eq!(ciphertext.len(), plaintext.len() + AES_TAG_LEN);

                let opened = aead.open(&key, &ciphertext, aad, &nonce).unwrap();

                assert_eq!(opened, plaintext);
            }
        }
    }

    /// `seal` must refuse to encrypt an empty message.
    #[test]
    fn empty_plaintext() {
        for aead in get_aeads() {
            let key = vec![42u8; aead.key_size()];
            let nonce = vec![42u8; aead.nonce_size()];

            assert_matches!(
                aead.seal(&key, b"", None, &nonce),
                Err(AeadError::EmptyPlaintext)
            );
        }
    }

    /// Flipping a bit of the authentication tag must make `open` fail.
    #[test]
    fn tampered_tag() {
        for aead in get_aeads() {
            let key = vec![42u8; aead.key_size()];
            let nonce = vec![42u8; aead.nonce_size()];

            let mut ciphertext = aead.seal(&key, b"message", None, &nonce).unwrap();

            // The tag occupies the trailing bytes of the sealed output.
            let last = ciphertext.len() - 1;
            ciphertext[last] ^= 0x01;

            assert_matches!(
                aead.open(&key, &ciphertext, None, &nonce),
                Err(AeadError::OpensslError(_))
            );
        }
    }

    /// Keys that are one byte too short or too long are rejected by OpenSSL.
    #[test]
    fn invalid_key() {
        for aead in get_aeads() {
            let nonce = vec![42u8; aead.nonce_size()];
            let data = b"top secret";

            let too_short = vec![42u8; aead.key_size() - 1];

            assert_matches!(
                aead.seal(&too_short, data, None, &nonce),
                Err(AeadError::OpensslError(_))
            );

            let too_long = vec![42u8; aead.key_size() + 1];

            assert_matches!(
                aead.seal(&too_long, data, None, &nonce),
                Err(AeadError::OpensslError(_))
            );
        }
    }

    /// A ciphertext that is exactly one tag long (no payload) is rejected
    /// before OpenSSL is invoked.
    #[test]
    fn invalid_ciphertext() {
        for aead in get_aeads() {
            let key = vec![42u8; aead.key_size()];
            let nonce = vec![42u8; aead.nonce_size()];

            let too_short = [0u8; AES_TAG_LEN];

            assert_matches!(
                aead.open(&key, &too_short, None, &nonce),
                Err(AeadError::InvalidCipherLen(_))
            );
        }
    }

    /// Opening with different or missing AAD must fail authentication.
    #[test]
    fn aad_mismatch() {
        for aead in get_aeads() {
            let key = vec![42u8; aead.key_size()];
            let nonce = vec![42u8; aead.nonce_size()];

            let ciphertext = aead.seal(&key, b"message", Some(b"foo"), &nonce).unwrap();

            assert_matches!(
                aead.open(&key, &ciphertext, Some(b"bar"), &nonce),
                Err(AeadError::OpensslError(_))
            );

            assert_matches!(
                aead.open(&key, &ciphertext, None, &nonce),
                Err(AeadError::OpensslError(_))
            );
        }
    }
}