jwt_simple/algorithms/jwe/
content.rs

1//! Content encryption algorithms for JWE.
2//!
3//! This module implements the content encryption algorithms specified in RFC 7518.
4//! Currently supported: A256GCM, A128GCM.
5
6#[cfg(any(feature = "pure-rust", target_arch = "wasm32", target_arch = "wasm64"))]
7use superboring as boring;
8
9use boring::symm::{Cipher, Crypter, Mode};
10use rand::RngCore;
11use zeroize::Zeroize;
12
13use crate::error::*;
14
/// Identifies which JWE content encryption ("enc") algorithm is in use.
///
/// Both variants are AES in Galois/Counter Mode and differ only in key length.
/// [`ContentEncryption::A256GCM`] is the value produced by `Default`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ContentEncryption {
    /// AES-GCM with a 256-bit key; the recommended choice and the default.
    #[default]
    A256GCM,
    /// AES-GCM with a 128-bit key.
    A128GCM,
}
24
25impl ContentEncryption {
26    /// Get the JWE "enc" header value for this algorithm.
27    pub fn alg_name(&self) -> &'static str {
28        match self {
29            ContentEncryption::A256GCM => "A256GCM",
30            ContentEncryption::A128GCM => "A128GCM",
31        }
32    }
33
34    /// Parse a content encryption algorithm from its JWE name.
35    pub fn from_alg_name(name: &str) -> Result<Self, Error> {
36        match name {
37            "A256GCM" => Ok(ContentEncryption::A256GCM),
38            "A128GCM" => Ok(ContentEncryption::A128GCM),
39            _ => bail!(JWTError::UnsupportedContentEncryption(name.to_string())),
40        }
41    }
42
43    /// Get the required key size in bytes.
44    pub fn key_size(&self) -> usize {
45        match self {
46            ContentEncryption::A256GCM => 32,
47            ContentEncryption::A128GCM => 16,
48        }
49    }
50
51    /// Get the IV size in bytes.
52    pub fn iv_size(&self) -> usize {
53        12 // GCM uses 96-bit IV
54    }
55
56    /// Get the authentication tag size in bytes.
57    pub fn tag_size(&self) -> usize {
58        16 // GCM uses 128-bit tag
59    }
60
61    /// Generate a random Content Encryption Key (CEK) for this algorithm.
62    pub fn generate_cek(&self) -> Vec<u8> {
63        let mut cek = vec![0u8; self.key_size()];
64        rand::thread_rng().fill_bytes(&mut cek);
65        cek
66    }
67
68    /// Generate a random IV for this algorithm.
69    pub fn generate_iv(&self) -> Vec<u8> {
70        let mut iv = vec![0u8; self.iv_size()];
71        rand::thread_rng().fill_bytes(&mut iv);
72        iv
73    }
74
75    fn cipher(&self) -> Cipher {
76        match self {
77            ContentEncryption::A256GCM => Cipher::aes_256_gcm(),
78            ContentEncryption::A128GCM => Cipher::aes_128_gcm(),
79        }
80    }
81
82    /// Encrypt plaintext using the content encryption algorithm.
83    ///
84    /// Returns (ciphertext, authentication_tag).
85    pub fn encrypt(
86        &self,
87        cek: &[u8],
88        iv: &[u8],
89        aad: &[u8],
90        plaintext: &[u8],
91    ) -> Result<(Vec<u8>, Vec<u8>), Error> {
92        ensure!(cek.len() == self.key_size(), JWTError::InvalidEncryptionKey);
93        ensure!(iv.len() == self.iv_size(), JWTError::InvalidIV);
94
95        let cipher = self.cipher();
96
97        let mut crypter = Crypter::new(cipher, Mode::Encrypt, cek, Some(iv))?;
98        crypter.aad_update(aad)?;
99
100        let mut ciphertext = vec![0u8; plaintext.len() + cipher.block_size()];
101        let mut count = crypter.update(plaintext, &mut ciphertext)?;
102        count += crypter.finalize(&mut ciphertext[count..])?;
103        ciphertext.truncate(count);
104
105        let mut tag = vec![0u8; self.tag_size()];
106        crypter.get_tag(&mut tag)?;
107
108        Ok((ciphertext, tag))
109    }
110
111    /// Decrypt ciphertext using the content encryption algorithm.
112    ///
113    /// Returns the plaintext.
114    pub fn decrypt(
115        &self,
116        cek: &[u8],
117        iv: &[u8],
118        aad: &[u8],
119        ciphertext: &[u8],
120        tag: &[u8],
121    ) -> Result<Vec<u8>, Error> {
122        ensure!(cek.len() == self.key_size(), JWTError::InvalidEncryptionKey);
123        ensure!(iv.len() == self.iv_size(), JWTError::InvalidIV);
124        ensure!(tag.len() == self.tag_size(), JWTError::InvalidJWEAuthTag);
125
126        let cipher = self.cipher();
127
128        let mut crypter = Crypter::new(cipher, Mode::Decrypt, cek, Some(iv))?;
129        crypter.aad_update(aad)?;
130        crypter.set_tag(tag)?;
131
132        let mut plaintext = vec![0u8; ciphertext.len() + cipher.block_size()];
133        let mut count = crypter.update(ciphertext, &mut plaintext)?;
134        count += crypter
135            .finalize(&mut plaintext[count..])
136            .map_err(|_| JWTError::DecryptionFailed)?;
137        plaintext.truncate(count);
138
139        Ok(plaintext)
140    }
141}
142
/// Owned Content Encryption Key (CEK) material.
///
/// The key bytes are wiped by this type's `Drop` implementation; clones own their
/// own copy of the bytes.
#[derive(Clone)]
pub struct CEK {
    /// Raw key bytes.
    key: Vec<u8>,
}

impl CEK {
    /// Wraps the given bytes as a CEK, taking ownership of the buffer.
    pub fn new(key: Vec<u8>) -> Self {
        Self { key }
    }

    /// Borrows the raw key bytes.
    pub fn as_bytes(&self) -> &[u8] {
        self.key.as_slice()
    }
}
160
161impl Drop for CEK {
162    fn drop(&mut self) {
163        self.key.zeroize();
164    }
165}
166
167impl AsRef<[u8]> for CEK {
168    fn as_ref(&self) -> &[u8] {
169        &self.key
170    }
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    /// Additional authenticated data shared by every test.
    const AAD: &[u8] = b"additional authenticated data";
    /// Message encrypted by every test.
    const PLAINTEXT: &[u8] = b"Hello, World!";

    /// Encrypts `PLAINTEXT` with a fresh key and IV, decrypts it with the same inputs,
    /// and checks that the original message comes back.
    fn assert_roundtrip(enc: ContentEncryption) {
        let key = enc.generate_cek();
        let nonce = enc.generate_iv();

        let (sealed, tag) = enc.encrypt(&key, &nonce, AAD, PLAINTEXT).unwrap();
        let opened = enc.decrypt(&key, &nonce, AAD, &sealed, &tag).unwrap();

        assert_eq!(opened, PLAINTEXT);
    }

    #[test]
    fn test_a256gcm_roundtrip() {
        assert_roundtrip(ContentEncryption::A256GCM);
    }

    #[test]
    fn test_a128gcm_roundtrip() {
        assert_roundtrip(ContentEncryption::A128GCM);
    }

    #[test]
    fn test_tampered_ciphertext_fails() {
        let enc = ContentEncryption::A256GCM;
        let key = enc.generate_cek();
        let nonce = enc.generate_iv();

        let (mut sealed, tag) = enc.encrypt(&key, &nonce, AAD, PLAINTEXT).unwrap();

        // Invert every bit of the first ciphertext byte.
        sealed[0] = !sealed[0];

        assert!(enc.decrypt(&key, &nonce, AAD, &sealed, &tag).is_err());
    }

    #[test]
    fn test_tampered_aad_fails() {
        let enc = ContentEncryption::A256GCM;
        let key = enc.generate_cek();
        let nonce = enc.generate_iv();

        let (sealed, tag) = enc.encrypt(&key, &nonce, AAD, PLAINTEXT).unwrap();

        // Decrypt with AAD that differs from what was authenticated.
        let mismatched_aad: &[u8] = b"wrong aad";
        assert!(enc
            .decrypt(&key, &nonce, mismatched_aad, &sealed, &tag)
            .is_err());
    }

    #[test]
    fn test_wrong_key_fails() {
        let enc = ContentEncryption::A256GCM;
        let key = enc.generate_cek();
        let other_key = enc.generate_cek();
        let nonce = enc.generate_iv();

        let (sealed, tag) = enc.encrypt(&key, &nonce, AAD, PLAINTEXT).unwrap();

        // A different, independently generated key must not decrypt the message.
        assert!(enc.decrypt(&other_key, &nonce, AAD, &sealed, &tag).is_err());
    }
}
251}