1use crate::{
5 jose::{Algorithm, Encode, EncryptionAlgorithm, Header, Set, Type, Unset},
6 Error, ErrorKind, Result,
7};
8use azure_core::{base64, Bytes};
9use azure_security_keyvault_keys::models::KeyOperationResult;
10use openssl::{
11 rand,
12 symm::{self, Cipher},
13};
14use std::marker::PhantomData;
15
16#[derive(Debug)]
18pub struct Jwe {
19 header: Header,
20 cek: Bytes,
21 iv: Bytes,
22 ciphertext: Bytes,
23 tag: Bytes,
24}
25
26impl Jwe {
27 pub fn encryptor() -> JweEncryptor<Unset, Unset> {
28 JweEncryptor::default()
29 }
30
31 pub async fn decrypt<F>(self, unwrap_key: F) -> Result<Bytes>
32 where
33 F: AsyncFn(&str, &Algorithm, &[u8]) -> Result<WrapKeyResult>,
34 {
35 if self.header.typ != Type::JWE {
36 return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
37 format!("expected JWE, got {}", self.header.typ)
38 }));
39 }
40
41 let key_id = self
43 .header
44 .kid
45 .as_deref()
46 .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected kid"))?;
47 let result = unwrap_key(key_id, &self.header.alg, &self.cek).await?;
48
49 let enc = self
50 .header
51 .enc
52 .as_ref()
53 .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected enc"))?;
54 let cipher: Cipher = enc.try_into()?;
55 let aad = self.header.encode()?;
56
57 let plaintext: Bytes = symm::decrypt_aead(
58 cipher,
59 &result.cek,
60 Some(&self.iv),
61 aad.as_bytes(),
62 &self.ciphertext,
63 &self.tag,
64 )?
65 .into();
66
67 Ok(plaintext)
68 }
69
70 pub fn kid(&self) -> Option<&str> {
71 self.header.kid.as_deref()
72 }
73}
74
75impl Encode for Jwe {
76 fn decode(value: &str) -> Result<Self> {
77 let parts: Vec<_> = value.split(".").collect();
78 if parts.len() != 5 {
79 return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
80 format!("invalid compact JWE: expected 5 parts, got {}", parts.len())
81 }));
82 }
83
84 Ok(Self {
85 header: Header::decode(parts[0])?,
86 cek: base64::decode_url_safe(parts[1])?.into(),
87 iv: base64::decode_url_safe(parts[2])?.into(),
88 ciphertext: base64::decode_url_safe(parts[3])?.into(),
89 tag: base64::decode_url_safe(parts[4])?.into(),
90 })
91 }
92
93 fn encode(&self) -> Result<String> {
94 Ok([
95 self.header.encode()?,
96 base64::encode_url_safe(&self.cek),
97 base64::encode_url_safe(&self.iv),
98 base64::encode_url_safe(&self.ciphertext),
99 base64::encode_url_safe(&self.tag),
100 ]
101 .join("."))
102 }
103}
104
105#[derive(Debug)]
106pub struct JweEncryptor<C, K> {
107 alg: Option<Algorithm>,
108 enc: Option<EncryptionAlgorithm>,
109 kid: Option<String>,
110 cek: Option<Bytes>,
111 iv: Option<Bytes>,
112 plaintext: Option<Bytes>,
113 phantom: PhantomData<(C, K)>,
114}
115
116impl<C, K> JweEncryptor<C, K> {
117 pub fn alg(self, alg: Algorithm) -> Self {
118 Self {
119 alg: Some(alg),
120 ..self
121 }
122 }
123
124 pub fn enc(self, enc: EncryptionAlgorithm) -> Self {
125 Self {
126 enc: Some(enc),
127 ..self
128 }
129 }
130
131 pub fn cek(self, cek: &[u8]) -> Self {
132 Self {
133 cek: Some(Bytes::copy_from_slice(cek)),
134 ..self
135 }
136 }
137
138 pub fn iv(self, iv: &[u8]) -> Self {
139 Self {
140 iv: Some(Bytes::copy_from_slice(iv)),
141 ..self
142 }
143 }
144}
145
146impl<K> JweEncryptor<Unset, K> {
147 pub fn plaintext(self, plaintext: &[u8]) -> JweEncryptor<Set, K> {
148 JweEncryptor::<Set, K> {
149 plaintext: Some(Bytes::copy_from_slice(plaintext)),
150 alg: self.alg,
151 enc: self.enc,
152 kid: self.kid,
153 cek: self.cek,
154 iv: self.iv,
155 phantom: PhantomData,
156 }
157 }
158
159 pub fn plaintext_str(self, plaintext: impl AsRef<str>) -> JweEncryptor<Set, K> {
160 JweEncryptor::plaintext(self, plaintext.as_ref().as_bytes())
161 }
162}
163
164impl<C> JweEncryptor<C, Unset> {
165 pub fn kid(self, kid: impl Into<String>) -> JweEncryptor<C, Set> {
166 JweEncryptor::<C, Set> {
167 kid: Some(kid.into()),
168 alg: self.alg,
169 enc: self.enc,
170 cek: self.cek,
171 iv: self.iv,
172 plaintext: self.plaintext,
173 phantom: PhantomData,
174 }
175 }
176}
177
178impl JweEncryptor<Set, Set> {
179 pub async fn encrypt<F>(self, wrap_key: F) -> Result<Jwe>
180 where
181 F: AsyncFn(&str, &Algorithm, &[u8]) -> Result<WrapKeyResult>,
182 {
183 let enc = &self.enc.unwrap_or(EncryptionAlgorithm::A128GCM);
185 let cipher: Cipher = enc.try_into()?;
186
187 let cek = match self.cek {
189 Some(v) if v.len() == cipher.key_len() => v,
190 Some(v) => {
191 return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
192 format!(
193 "require key size of {} bytes, got {}",
194 cipher.key_len(),
195 v.len()
196 )
197 }));
198 }
199 None => {
200 let mut buf = [0; 32];
202 rand::rand_bytes(&mut buf)?;
203 Bytes::copy_from_slice(&buf[0..cipher.key_len()])
204 }
205 };
206
207 let kid = self
208 .kid
209 .as_deref()
210 .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected kid"))?;
211 let alg = self.alg.unwrap_or(Algorithm::RSA_OAEP);
212
213 let result = wrap_key(kid, &alg, &cek).await?;
215
216 let header = Header {
217 alg,
218 enc: Some(enc.clone()),
219 kid: Some(result.kid),
220 typ: super::Type::JWE,
221 };
222 let aad = header.encode()?;
223
224 let iv_len = cipher.iv_len().ok_or_else(|| {
226 Error::with_message(
227 ErrorKind::InvalidData,
228 format!("expected iv length for cipher {}", &enc),
229 )
230 })?;
231 let iv = match self.iv {
232 Some(v) if v.len() == iv_len => v,
233 Some(v) => {
234 return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
235 format!("require iv size of {} bytes, got {}", iv_len, v.len())
236 }));
237 }
238 None => {
239 let mut buf = [0; 12];
241 rand::rand_bytes(&mut buf)?;
242 Bytes::copy_from_slice(&buf[0..iv_len])
243 }
244 };
245
246 let plaintext = self.plaintext.expect("expected plaintext");
247 let mut tag = [0; 16];
248 let ciphertext: Bytes = symm::encrypt_aead(
249 cipher,
250 &cek,
251 Some(&iv),
252 aad.as_bytes(),
253 &plaintext,
254 &mut tag,
255 )?
256 .into();
257
258 Ok(Jwe {
259 header,
260 cek: result.cek,
261 iv,
262 ciphertext,
263 tag: Bytes::copy_from_slice(&tag),
264 })
265 }
266}
267
268impl<C, K> Default for JweEncryptor<C, K> {
269 fn default() -> Self {
270 Self {
271 alg: None,
272 enc: None,
273 kid: None,
274 cek: None,
275 iv: None,
276 plaintext: None,
277 phantom: PhantomData,
278 }
279 }
280}
281
282impl TryFrom<EncryptionAlgorithm> for Cipher {
283 type Error = Error;
284 fn try_from(value: EncryptionAlgorithm) -> Result<Self> {
285 (&value).try_into()
286 }
287}
288
289impl TryFrom<&EncryptionAlgorithm> for Cipher {
290 type Error = Error;
291 fn try_from(value: &EncryptionAlgorithm) -> Result<Cipher> {
292 match value {
293 EncryptionAlgorithm::A128GCM => Ok(Cipher::aes_128_gcm()),
294 EncryptionAlgorithm::A192GCM => Ok(Cipher::aes_192_gcm()),
295 EncryptionAlgorithm::A256GCM => Ok(Cipher::aes_256_gcm()),
296 EncryptionAlgorithm::Other(value) => {
297 Err(Error::with_message_fn(ErrorKind::InvalidData, || {
298 format!("unsupported encryption algorithm {value}")
299 }))
300 }
301 }
302 }
303}
304
305impl TryFrom<&Algorithm> for azure_security_keyvault_keys::models::EncryptionAlgorithm {
306 type Error = Error;
307 fn try_from(value: &Algorithm) -> Result<Self> {
308 match value {
309 Algorithm::RSA1_5 => Ok(Self::RSA1_5),
310 Algorithm::RSA_OAEP => Ok(Self::RsaOaep),
311 Algorithm::RSA_OAEP_256 => Ok(Self::RsaOAEP256),
312 Algorithm::Other(s) => Err(Error::with_message_fn(ErrorKind::InvalidData, || {
313 format!("unsupported algorithm {s}")
314 })),
315 }
316 }
317}
318
319#[derive(Debug)]
320pub struct WrapKeyResult {
321 pub kid: String,
322 pub cek: Bytes,
323}
324
325impl TryFrom<KeyOperationResult> for WrapKeyResult {
326 type Error = Error;
327 fn try_from(value: KeyOperationResult) -> Result<Self> {
328 Ok(Self {
329 kid: value
330 .kid
331 .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected kid"))?,
332 cek: value
333 .result
334 .map(Into::into)
335 .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected CEK"))?,
336 })
337 }
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use azure_core::Bytes;

    #[test]
    fn decode_invalid() {
        for (input, count) in [("1.2.3.4", 4), ("1.2.3.4.5.6", 6)] {
            let expected = format!("invalid compact JWE: expected 5 parts, got {count}");
            assert!(
                matches!(Jwe::decode(input), Err(err) if err.message() == Some(expected.as_str()))
            );
        }
    }

    #[test]
    fn encode_decode_roundtrip() {
        const CEK: &[u8] = &[0x12, 0x34, 0x56, 0x78];
        const IV: &[u8] = &[0x9a, 0xbc, 0xde, 0xf0];
        const CIPHERTEXT: &[u8] = &[0x01, 0x23, 0x45, 0x67];
        const TAG: &[u8] = &[0x89, 0xab, 0xcd, 0xef];
        const EXPECTED: &str = "eyJhbGciOiJSU0EtT0FFUC0yNTYiLCJlbmMiOiJBMTI4R0NNIiwia2lkIjoidGVzdC1rZXktaWQiLCJ0eXAiOiJKV0UifQ.EjRWeA.mrze8A.ASNFZw.iavN7w";

        let jwe = Jwe {
            header: Header {
                alg: Algorithm::RSA_OAEP_256,
                enc: Some(EncryptionAlgorithm::A128GCM),
                kid: Some("test-key-id".to_string()),
                typ: Type::JWE,
            },
            cek: Bytes::from_static(CEK),
            iv: Bytes::from_static(IV),
            ciphertext: Bytes::from_static(CIPHERTEXT),
            tag: Bytes::from_static(TAG),
        };

        let encoded = jwe.encode().expect("encode should succeed");
        assert_eq!(encoded, EXPECTED);

        let decoded = Jwe::decode(&encoded).expect("decode should succeed");
        assert_eq!(decoded.header.alg, Algorithm::RSA_OAEP_256);
        assert_eq!(decoded.header.enc, Some(EncryptionAlgorithm::A128GCM));
        assert_eq!(decoded.header.kid, Some("test-key-id".to_string()));
        assert_eq!(decoded.header.typ, Type::JWE);
        assert_eq!(decoded.cek, Bytes::from_static(CEK));
        assert_eq!(decoded.iv, Bytes::from_static(IV));
        assert_eq!(decoded.ciphertext, Bytes::from_static(CIPHERTEXT));
        assert_eq!(decoded.tag, Bytes::from_static(TAG));
    }

    #[test]
    fn encryption_algorithm_cipher() {
        let cases = [
            (EncryptionAlgorithm::A128GCM, 16),
            (EncryptionAlgorithm::A192GCM, 24),
            (EncryptionAlgorithm::A256GCM, 32),
        ];
        for (enc, key_len) in cases {
            let cipher: Cipher = enc.try_into().expect("try_into should succeed");
            assert_eq!(cipher.iv_len(), Some(12));
            assert_eq!(cipher.key_len(), key_len);
        }
    }

    #[tokio::test]
    async fn encrypt_decrypt_roundtrip() {
        let kid = "key-name";
        let alg = Algorithm::RSA_OAEP;
        let enc = EncryptionAlgorithm::A128GCM;
        let plaintext = b"Hello, world!";

        // Identity "wrap": hands the CEK back unchanged and reports a versioned key ID.
        let wrap_key = async |key_id: &str, wrap_alg: &Algorithm, cek: &[u8]| {
            assert_eq!(key_id, kid);
            assert_eq!(wrap_alg, &alg);
            Ok(WrapKeyResult {
                kid: "key-name/key-version".into(),
                cek: Bytes::copy_from_slice(cek),
            })
        };

        // Identity "unwrap": expects the versioned key ID that `wrap_key` reported.
        let unwrap_key = async |key_id: &str, wrap_alg: &Algorithm, cek: &[u8]| {
            assert_eq!(key_id, "key-name/key-version");
            assert_eq!(wrap_alg, &alg);
            Ok(WrapKeyResult {
                kid: "key-name/key-version".into(),
                cek: Bytes::copy_from_slice(cek),
            })
        };

        let jwe = Jwe::encryptor()
            .alg(alg.clone())
            .enc(enc)
            .kid(kid)
            .plaintext(plaintext)
            .encrypt(wrap_key)
            .await
            .expect("encryption should succeed");

        let decrypted = jwe
            .decrypt(unwrap_key)
            .await
            .expect("decryption should succeed");
        assert_eq!(decrypted, plaintext.as_ref());
    }
}