jwt_simple/
jwe_token.rs

1//! JWE token building and parsing.
2
3use ct_codecs::{Base64UrlSafeNoPadding, Decoder, Encoder};
4use serde::{de::DeserializeOwned, Serialize};
5
6use crate::algorithms::jwe::content::{ContentEncryption, CEK};
7use crate::claims::*;
8use crate::common::VerificationOptions;
9use crate::error::*;
10use crate::jwe_header::JWEHeader;
11
/// Default upper bound on the length of a token's base64url-encoded header
/// segment.
///
/// `JWEToken::decrypt` falls back to this value when
/// `DecryptionOptions::max_header_length` is `None`, and
/// `JWEToken::decode_metadata` always enforces it.
pub const MAX_JWE_HEADER_LENGTH: usize = 8 * 1024;
13
14/// Options for JWE encryption.
15#[derive(Clone, Debug, Default)]
16pub struct EncryptionOptions {
17    /// Content encryption algorithm (default: A256GCM)
18    pub content_encryption: ContentEncryption,
19    /// Content type header
20    pub content_type: Option<String>,
21    /// Key ID
22    pub key_id: Option<String>,
23}
24
25/// Options for JWE decryption.
26#[derive(Clone, Debug, Default)]
27pub struct DecryptionOptions {
28    /// Maximum token length to accept
29    pub max_token_length: Option<usize>,
30    /// Maximum header length to accept
31    pub max_header_length: Option<usize>,
32    /// Required key ID
33    pub required_key_id: Option<String>,
34    /// Options for validating claims after decryption
35    pub claim_options: Option<VerificationOptions>,
36}
37
38/// JWE token metadata extracted from the header (before decryption).
39#[derive(Debug, Clone)]
40pub struct JWETokenMetadata {
41    header: JWEHeader,
42}
43
44impl JWETokenMetadata {
45    /// The key management algorithm.
46    pub fn algorithm(&self) -> &str {
47        &self.header.algorithm
48    }
49
50    /// The content encryption algorithm.
51    pub fn encryption(&self) -> &str {
52        &self.header.encryption
53    }
54
55    /// The key ID (if present).
56    pub fn key_id(&self) -> Option<&str> {
57        self.header.key_id.as_deref()
58    }
59
60    /// The content type (if present).
61    pub fn content_type(&self) -> Option<&str> {
62        self.header.content_type.as_deref()
63    }
64
65    /// Get the full header.
66    pub fn header(&self) -> &JWEHeader {
67        &self.header
68    }
69}
70
71/// Utilities for working with JWE tokens.
72pub struct JWEToken;
73
74impl JWEToken {
75    /// Build a JWE token.
76    ///
77    /// This function is called by key management implementations to create
78    /// the final JWE compact serialization.
79    ///
80    /// # Arguments
81    /// * `header` - The JWE header
82    /// * `encrypted_key` - The encrypted CEK (or empty for direct key agreement)
83    /// * `iv` - The initialization vector
84    /// * `ciphertext` - The encrypted content
85    /// * `tag` - The authentication tag
86    pub fn build(
87        header: &JWEHeader,
88        encrypted_key: &[u8],
89        iv: &[u8],
90        ciphertext: &[u8],
91        tag: &[u8],
92    ) -> Result<String, Error> {
93        let header_json = serde_json::to_string(header)?;
94        let header_b64 = Base64UrlSafeNoPadding::encode_to_string(&header_json)?;
95        let encrypted_key_b64 = Base64UrlSafeNoPadding::encode_to_string(encrypted_key)?;
96        let iv_b64 = Base64UrlSafeNoPadding::encode_to_string(iv)?;
97        let ciphertext_b64 = Base64UrlSafeNoPadding::encode_to_string(ciphertext)?;
98        let tag_b64 = Base64UrlSafeNoPadding::encode_to_string(tag)?;
99
100        Ok(format!(
101            "{}.{}.{}.{}.{}",
102            header_b64, encrypted_key_b64, iv_b64, ciphertext_b64, tag_b64
103        ))
104    }
105
106    /// Build a JWE token from claims.
107    ///
108    /// This is a helper that serializes claims to JSON before encryption.
109    pub fn build_from_claims<KeyWrapFn, CustomClaims: Serialize>(
110        header: &JWEHeader,
111        claims: &JWTClaims<CustomClaims>,
112        content_encryption: ContentEncryption,
113        key_wrap_fn: KeyWrapFn,
114    ) -> Result<String, Error>
115    where
116        KeyWrapFn: FnOnce(&[u8]) -> Result<Vec<u8>, Error>,
117    {
118        // Serialize claims to JSON
119        let claims_json = serde_json::to_string(claims)?;
120        let plaintext = claims_json.as_bytes();
121
122        // Generate CEK and IV
123        let cek = CEK::new(content_encryption.generate_cek());
124        let iv = content_encryption.generate_iv();
125
126        // Wrap the CEK
127        let encrypted_key = key_wrap_fn(cek.as_bytes())?;
128
129        // Build the AAD (ASCII bytes of the base64url-encoded header)
130        let header_json = serde_json::to_string(header)?;
131        let header_b64 = Base64UrlSafeNoPadding::encode_to_string(&header_json)?;
132        let aad = header_b64.as_bytes();
133
134        // Encrypt the plaintext
135        let (ciphertext, tag) = content_encryption.encrypt(cek.as_bytes(), &iv, aad, plaintext)?;
136        drop(cek); // Zeroize CEK immediately after use
137
138        // Build the final token
139        let encrypted_key_b64 = Base64UrlSafeNoPadding::encode_to_string(&encrypted_key)?;
140        let iv_b64 = Base64UrlSafeNoPadding::encode_to_string(&iv)?;
141        let ciphertext_b64 = Base64UrlSafeNoPadding::encode_to_string(&ciphertext)?;
142        let tag_b64 = Base64UrlSafeNoPadding::encode_to_string(&tag)?;
143
144        Ok(format!(
145            "{}.{}.{}.{}.{}",
146            header_b64, encrypted_key_b64, iv_b64, ciphertext_b64, tag_b64
147        ))
148    }
149
150    /// Parse and decrypt a JWE token.
151    ///
152    /// This function is called by key management implementations to decrypt
153    /// a JWE token and return the claims.
154    pub fn decrypt<KeyUnwrapFn, CustomClaims: DeserializeOwned>(
155        expected_alg: &str,
156        token: &str,
157        options: Option<DecryptionOptions>,
158        key_unwrap_fn: KeyUnwrapFn,
159    ) -> Result<JWTClaims<CustomClaims>, Error>
160    where
161        KeyUnwrapFn: FnOnce(&JWEHeader, &[u8]) -> Result<Vec<u8>, Error>,
162    {
163        let options = options.unwrap_or_default();
164
165        // Check token length
166        if let Some(max_len) = options.max_token_length {
167            ensure!(token.len() <= max_len, JWTError::TokenTooLong);
168        }
169
170        // Split into 5 parts
171        let parts: Vec<&str> = token.split('.').collect();
172        ensure!(parts.len() == 5, JWTError::InvalidJWEFormat);
173
174        let header_b64 = parts[0];
175        let encrypted_key_b64 = parts[1];
176        let iv_b64 = parts[2];
177        let ciphertext_b64 = parts[3];
178        let tag_b64 = parts[4];
179
180        // Check header length
181        let max_header_len = options.max_header_length.unwrap_or(MAX_JWE_HEADER_LENGTH);
182        ensure!(header_b64.len() <= max_header_len, JWTError::HeaderTooLarge);
183
184        // Decode header
185        let header_bytes = Base64UrlSafeNoPadding::decode_to_vec(header_b64, None)?;
186        let header: JWEHeader = serde_json::from_slice(&header_bytes)?;
187
188        // Validate critical header - RFC 7516 requires rejecting tokens with
189        // unrecognized critical extensions
190        if let Some(ref crit) = header.critical {
191            if !crit.is_empty() {
192                // We don't support any critical extensions
193                bail!(JWTError::UnknownCriticalExtension);
194            }
195        }
196
197        // Validate algorithm
198        ensure!(
199            header.algorithm == expected_alg,
200            JWTError::AlgorithmMismatch
201        );
202
203        // Validate key ID if required
204        if let Some(required_key_id) = &options.required_key_id {
205            if let Some(key_id) = &header.key_id {
206                ensure!(key_id == required_key_id, JWTError::KeyIdentifierMismatch);
207            } else {
208                bail!(JWTError::MissingJWTKeyIdentifier);
209            }
210        }
211
212        // Decode the encrypted key, IV, ciphertext, and tag
213        let encrypted_key = Base64UrlSafeNoPadding::decode_to_vec(encrypted_key_b64, None)?;
214        let iv = Base64UrlSafeNoPadding::decode_to_vec(iv_b64, None)?;
215        let ciphertext = Base64UrlSafeNoPadding::decode_to_vec(ciphertext_b64, None)?;
216        let tag = Base64UrlSafeNoPadding::decode_to_vec(tag_b64, None)?;
217
218        // Get the content encryption algorithm
219        let content_encryption = ContentEncryption::from_alg_name(&header.encryption)?;
220
221        // Unwrap the CEK
222        let cek = CEK::new(key_unwrap_fn(&header, &encrypted_key)?);
223
224        // The AAD is the ASCII bytes of the base64url-encoded header
225        let aad = header_b64.as_bytes();
226
227        // Decrypt the ciphertext
228        let plaintext = content_encryption.decrypt(cek.as_bytes(), &iv, aad, &ciphertext, &tag)?;
229        drop(cek); // Zeroize CEK immediately after use
230
231        // Parse the claims
232        let claims: JWTClaims<CustomClaims> = serde_json::from_slice(&plaintext)?;
233
234        // Validate claims if options provided
235        if let Some(claim_options) = &options.claim_options {
236            claims.validate(claim_options)?;
237        }
238
239        Ok(claims)
240    }
241
242    /// Decode JWE token metadata without decrypting.
243    ///
244    /// This allows inspection of the header to determine which key to use
245    /// for decryption.
246    pub fn decode_metadata(token: &str) -> Result<JWETokenMetadata, Error> {
247        let mut parts = token.split('.');
248        let header_b64 = parts.next().ok_or(JWTError::InvalidJWEFormat)?;
249
250        ensure!(
251            header_b64.len() <= MAX_JWE_HEADER_LENGTH,
252            JWTError::HeaderTooLarge
253        );
254
255        let header_bytes = Base64UrlSafeNoPadding::decode_to_vec(header_b64, None)?;
256        let header: JWEHeader = serde_json::from_slice(&header_bytes)?;
257
258        Ok(JWETokenMetadata { header })
259    }
260}