1use ct_codecs::{Base64UrlSafeNoPadding, Decoder, Encoder};
4use serde::{de::DeserializeOwned, Serialize};
5
6use crate::algorithms::jwe::content::{ContentEncryption, CEK};
7use crate::claims::*;
8use crate::common::VerificationOptions;
9use crate::error::*;
10use crate::jwe_header::JWEHeader;
11
/// Default upper bound (8 KiB) on the length of the base64url-encoded JWE
/// header segment.
///
/// [`JWEToken::decrypt`] falls back to this value when
/// `DecryptionOptions::max_header_length` is `None`, and
/// [`JWEToken::decode_metadata`] always enforces it. The comparison is made
/// against the *encoded* header text, before base64url decoding.
pub const MAX_JWE_HEADER_LENGTH: usize = 8 * 1024;
13
14#[derive(Clone, Debug, Default)]
16pub struct EncryptionOptions {
17 pub content_encryption: ContentEncryption,
19 pub content_type: Option<String>,
21 pub key_id: Option<String>,
23}
24
25#[derive(Clone, Debug, Default)]
27pub struct DecryptionOptions {
28 pub max_token_length: Option<usize>,
30 pub max_header_length: Option<usize>,
32 pub required_key_id: Option<String>,
34 pub claim_options: Option<VerificationOptions>,
36}
37
38#[derive(Debug, Clone)]
40pub struct JWETokenMetadata {
41 header: JWEHeader,
42}
43
44impl JWETokenMetadata {
45 pub fn algorithm(&self) -> &str {
47 &self.header.algorithm
48 }
49
50 pub fn encryption(&self) -> &str {
52 &self.header.encryption
53 }
54
55 pub fn key_id(&self) -> Option<&str> {
57 self.header.key_id.as_deref()
58 }
59
60 pub fn content_type(&self) -> Option<&str> {
62 self.header.content_type.as_deref()
63 }
64
65 pub fn header(&self) -> &JWEHeader {
67 &self.header
68 }
69}
70
71pub struct JWEToken;
73
74impl JWEToken {
75 pub fn build(
87 header: &JWEHeader,
88 encrypted_key: &[u8],
89 iv: &[u8],
90 ciphertext: &[u8],
91 tag: &[u8],
92 ) -> Result<String, Error> {
93 let header_json = serde_json::to_string(header)?;
94 let header_b64 = Base64UrlSafeNoPadding::encode_to_string(&header_json)?;
95 let encrypted_key_b64 = Base64UrlSafeNoPadding::encode_to_string(encrypted_key)?;
96 let iv_b64 = Base64UrlSafeNoPadding::encode_to_string(iv)?;
97 let ciphertext_b64 = Base64UrlSafeNoPadding::encode_to_string(ciphertext)?;
98 let tag_b64 = Base64UrlSafeNoPadding::encode_to_string(tag)?;
99
100 Ok(format!(
101 "{}.{}.{}.{}.{}",
102 header_b64, encrypted_key_b64, iv_b64, ciphertext_b64, tag_b64
103 ))
104 }
105
106 pub fn build_from_claims<KeyWrapFn, CustomClaims: Serialize>(
110 header: &JWEHeader,
111 claims: &JWTClaims<CustomClaims>,
112 content_encryption: ContentEncryption,
113 key_wrap_fn: KeyWrapFn,
114 ) -> Result<String, Error>
115 where
116 KeyWrapFn: FnOnce(&[u8]) -> Result<Vec<u8>, Error>,
117 {
118 let claims_json = serde_json::to_string(claims)?;
120 let plaintext = claims_json.as_bytes();
121
122 let cek = CEK::new(content_encryption.generate_cek());
124 let iv = content_encryption.generate_iv();
125
126 let encrypted_key = key_wrap_fn(cek.as_bytes())?;
128
129 let header_json = serde_json::to_string(header)?;
131 let header_b64 = Base64UrlSafeNoPadding::encode_to_string(&header_json)?;
132 let aad = header_b64.as_bytes();
133
134 let (ciphertext, tag) = content_encryption.encrypt(cek.as_bytes(), &iv, aad, plaintext)?;
136 drop(cek); let encrypted_key_b64 = Base64UrlSafeNoPadding::encode_to_string(&encrypted_key)?;
140 let iv_b64 = Base64UrlSafeNoPadding::encode_to_string(&iv)?;
141 let ciphertext_b64 = Base64UrlSafeNoPadding::encode_to_string(&ciphertext)?;
142 let tag_b64 = Base64UrlSafeNoPadding::encode_to_string(&tag)?;
143
144 Ok(format!(
145 "{}.{}.{}.{}.{}",
146 header_b64, encrypted_key_b64, iv_b64, ciphertext_b64, tag_b64
147 ))
148 }
149
150 pub fn decrypt<KeyUnwrapFn, CustomClaims: DeserializeOwned>(
155 expected_alg: &str,
156 token: &str,
157 options: Option<DecryptionOptions>,
158 key_unwrap_fn: KeyUnwrapFn,
159 ) -> Result<JWTClaims<CustomClaims>, Error>
160 where
161 KeyUnwrapFn: FnOnce(&JWEHeader, &[u8]) -> Result<Vec<u8>, Error>,
162 {
163 let options = options.unwrap_or_default();
164
165 if let Some(max_len) = options.max_token_length {
167 ensure!(token.len() <= max_len, JWTError::TokenTooLong);
168 }
169
170 let parts: Vec<&str> = token.split('.').collect();
172 ensure!(parts.len() == 5, JWTError::InvalidJWEFormat);
173
174 let header_b64 = parts[0];
175 let encrypted_key_b64 = parts[1];
176 let iv_b64 = parts[2];
177 let ciphertext_b64 = parts[3];
178 let tag_b64 = parts[4];
179
180 let max_header_len = options.max_header_length.unwrap_or(MAX_JWE_HEADER_LENGTH);
182 ensure!(header_b64.len() <= max_header_len, JWTError::HeaderTooLarge);
183
184 let header_bytes = Base64UrlSafeNoPadding::decode_to_vec(header_b64, None)?;
186 let header: JWEHeader = serde_json::from_slice(&header_bytes)?;
187
188 if let Some(ref crit) = header.critical {
191 if !crit.is_empty() {
192 bail!(JWTError::UnknownCriticalExtension);
194 }
195 }
196
197 ensure!(
199 header.algorithm == expected_alg,
200 JWTError::AlgorithmMismatch
201 );
202
203 if let Some(required_key_id) = &options.required_key_id {
205 if let Some(key_id) = &header.key_id {
206 ensure!(key_id == required_key_id, JWTError::KeyIdentifierMismatch);
207 } else {
208 bail!(JWTError::MissingJWTKeyIdentifier);
209 }
210 }
211
212 let encrypted_key = Base64UrlSafeNoPadding::decode_to_vec(encrypted_key_b64, None)?;
214 let iv = Base64UrlSafeNoPadding::decode_to_vec(iv_b64, None)?;
215 let ciphertext = Base64UrlSafeNoPadding::decode_to_vec(ciphertext_b64, None)?;
216 let tag = Base64UrlSafeNoPadding::decode_to_vec(tag_b64, None)?;
217
218 let content_encryption = ContentEncryption::from_alg_name(&header.encryption)?;
220
221 let cek = CEK::new(key_unwrap_fn(&header, &encrypted_key)?);
223
224 let aad = header_b64.as_bytes();
226
227 let plaintext = content_encryption.decrypt(cek.as_bytes(), &iv, aad, &ciphertext, &tag)?;
229 drop(cek); let claims: JWTClaims<CustomClaims> = serde_json::from_slice(&plaintext)?;
233
234 if let Some(claim_options) = &options.claim_options {
236 claims.validate(claim_options)?;
237 }
238
239 Ok(claims)
240 }
241
242 pub fn decode_metadata(token: &str) -> Result<JWETokenMetadata, Error> {
247 let mut parts = token.split('.');
248 let header_b64 = parts.next().ok_or(JWTError::InvalidJWEFormat)?;
249
250 ensure!(
251 header_b64.len() <= MAX_JWE_HEADER_LENGTH,
252 JWTError::HeaderTooLarge
253 );
254
255 let header_bytes = Base64UrlSafeNoPadding::decode_to_vec(header_b64, None)?;
256 let header: JWEHeader = serde_json::from_slice(&header_bytes)?;
257
258 Ok(JWETokenMetadata { header })
259 }
260}