// jwt_simple/algorithms/jwe/content.rs
#[cfg(any(feature = "pure-rust", target_arch = "wasm32", target_arch = "wasm64"))]
7use superboring as boring;
8
9use boring::symm::{Cipher, Crypter, Mode};
10use rand::RngCore;
11use zeroize::Zeroize;
12
13use crate::error::*;
14
/// Content-encryption algorithms available for JWE payloads.
///
/// Both variants are AES-GCM; they differ only in key length. The default is
/// [`ContentEncryption::A256GCM`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ContentEncryption {
    /// AES-GCM with a 256-bit (32-byte) key.
    #[default]
    A256GCM,
    /// AES-GCM with a 128-bit (16-byte) key.
    A128GCM,
}
24
25impl ContentEncryption {
26 pub fn alg_name(&self) -> &'static str {
28 match self {
29 ContentEncryption::A256GCM => "A256GCM",
30 ContentEncryption::A128GCM => "A128GCM",
31 }
32 }
33
34 pub fn from_alg_name(name: &str) -> Result<Self, Error> {
36 match name {
37 "A256GCM" => Ok(ContentEncryption::A256GCM),
38 "A128GCM" => Ok(ContentEncryption::A128GCM),
39 _ => bail!(JWTError::UnsupportedContentEncryption(name.to_string())),
40 }
41 }
42
43 pub fn key_size(&self) -> usize {
45 match self {
46 ContentEncryption::A256GCM => 32,
47 ContentEncryption::A128GCM => 16,
48 }
49 }
50
51 pub fn iv_size(&self) -> usize {
53 12 }
55
56 pub fn tag_size(&self) -> usize {
58 16 }
60
61 pub fn generate_cek(&self) -> Vec<u8> {
63 let mut cek = vec![0u8; self.key_size()];
64 rand::thread_rng().fill_bytes(&mut cek);
65 cek
66 }
67
68 pub fn generate_iv(&self) -> Vec<u8> {
70 let mut iv = vec![0u8; self.iv_size()];
71 rand::thread_rng().fill_bytes(&mut iv);
72 iv
73 }
74
75 fn cipher(&self) -> Cipher {
76 match self {
77 ContentEncryption::A256GCM => Cipher::aes_256_gcm(),
78 ContentEncryption::A128GCM => Cipher::aes_128_gcm(),
79 }
80 }
81
82 pub fn encrypt(
86 &self,
87 cek: &[u8],
88 iv: &[u8],
89 aad: &[u8],
90 plaintext: &[u8],
91 ) -> Result<(Vec<u8>, Vec<u8>), Error> {
92 ensure!(cek.len() == self.key_size(), JWTError::InvalidEncryptionKey);
93 ensure!(iv.len() == self.iv_size(), JWTError::InvalidIV);
94
95 let cipher = self.cipher();
96
97 let mut crypter = Crypter::new(cipher, Mode::Encrypt, cek, Some(iv))?;
98 crypter.aad_update(aad)?;
99
100 let mut ciphertext = vec![0u8; plaintext.len() + cipher.block_size()];
101 let mut count = crypter.update(plaintext, &mut ciphertext)?;
102 count += crypter.finalize(&mut ciphertext[count..])?;
103 ciphertext.truncate(count);
104
105 let mut tag = vec![0u8; self.tag_size()];
106 crypter.get_tag(&mut tag)?;
107
108 Ok((ciphertext, tag))
109 }
110
111 pub fn decrypt(
115 &self,
116 cek: &[u8],
117 iv: &[u8],
118 aad: &[u8],
119 ciphertext: &[u8],
120 tag: &[u8],
121 ) -> Result<Vec<u8>, Error> {
122 ensure!(cek.len() == self.key_size(), JWTError::InvalidEncryptionKey);
123 ensure!(iv.len() == self.iv_size(), JWTError::InvalidIV);
124 ensure!(tag.len() == self.tag_size(), JWTError::InvalidJWEAuthTag);
125
126 let cipher = self.cipher();
127
128 let mut crypter = Crypter::new(cipher, Mode::Decrypt, cek, Some(iv))?;
129 crypter.aad_update(aad)?;
130 crypter.set_tag(tag)?;
131
132 let mut plaintext = vec![0u8; ciphertext.len() + cipher.block_size()];
133 let mut count = crypter.update(ciphertext, &mut plaintext)?;
134 count += crypter
135 .finalize(&mut plaintext[count..])
136 .map_err(|_| JWTError::DecryptionFailed)?;
137 plaintext.truncate(count);
138
139 Ok(plaintext)
140 }
141}
142
/// Owned content-encryption key bytes.
///
/// Cloning produces an independent copy of the key material.
#[derive(Clone)]
pub struct CEK {
    /// Raw key bytes.
    key: Vec<u8>,
}
148
149impl CEK {
150 pub fn new(key: Vec<u8>) -> Self {
152 CEK { key }
153 }
154
155 pub fn as_bytes(&self) -> &[u8] {
157 &self.key
158 }
159}
160
161impl Drop for CEK {
162 fn drop(&mut self) {
163 self.key.zeroize();
164 }
165}
166
167impl AsRef<[u8]> for CEK {
168 fn as_ref(&self) -> &[u8] {
169 &self.key
170 }
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    /// Additional authenticated data shared by every test.
    const AAD: &[u8] = b"additional authenticated data";
    /// Message encrypted by every test.
    const PLAINTEXT: &[u8] = b"Hello, World!";

    /// Everything needed to attempt a decryption of `PLAINTEXT`.
    struct Sealed {
        cek: Vec<u8>,
        iv: Vec<u8>,
        ciphertext: Vec<u8>,
        tag: Vec<u8>,
    }

    /// Encrypts `PLAINTEXT` with a freshly generated key and IV.
    fn seal(enc: ContentEncryption) -> Sealed {
        let cek = enc.generate_cek();
        let iv = enc.generate_iv();
        let (ciphertext, tag) = enc.encrypt(&cek, &iv, AAD, PLAINTEXT).unwrap();
        Sealed {
            cek,
            iv,
            ciphertext,
            tag,
        }
    }

    /// Encrypts then decrypts with matching inputs and checks the plaintext.
    fn assert_roundtrip(enc: ContentEncryption) {
        let sealed = seal(enc);
        let decrypted = enc
            .decrypt(&sealed.cek, &sealed.iv, AAD, &sealed.ciphertext, &sealed.tag)
            .unwrap();

        assert_eq!(decrypted, PLAINTEXT);
    }

    #[test]
    fn test_a256gcm_roundtrip() {
        assert_roundtrip(ContentEncryption::A256GCM);
    }

    #[test]
    fn test_a128gcm_roundtrip() {
        assert_roundtrip(ContentEncryption::A128GCM);
    }

    #[test]
    fn test_tampered_ciphertext_fails() {
        let enc = ContentEncryption::A256GCM;
        let mut sealed = seal(enc);

        // Flip every bit of the first ciphertext byte.
        sealed.ciphertext[0] ^= 0xff;

        let result = enc.decrypt(&sealed.cek, &sealed.iv, AAD, &sealed.ciphertext, &sealed.tag);
        assert!(result.is_err());
    }

    #[test]
    fn test_tampered_aad_fails() {
        let enc = ContentEncryption::A256GCM;
        let sealed = seal(enc);

        let wrong_aad = b"wrong aad";
        let result = enc.decrypt(
            &sealed.cek,
            &sealed.iv,
            wrong_aad,
            &sealed.ciphertext,
            &sealed.tag,
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_wrong_key_fails() {
        let enc = ContentEncryption::A256GCM;
        let sealed = seal(enc);
        // An independently generated key of the correct length.
        let wrong_cek = enc.generate_cek();

        let result = enc.decrypt(&wrong_cek, &sealed.iv, AAD, &sealed.ciphertext, &sealed.tag);
        assert!(result.is_err());
    }
}