1use crate::{
5 jose::{Algorithm, Encode, EncryptionAlgorithm, Header, Set, Type, Unset},
6 Error, ErrorKind, Result, ResultExt as _,
7};
8use azure_core::{base64, Bytes};
9use azure_security_keyvault_keys::models::KeyOperationResult;
10use openssl::{
11 rand,
12 symm::{self, Cipher},
13};
14use std::{marker::PhantomData, str::FromStr};
15
/// A JSON Web Encryption (JWE) object held as its decoded parts.
///
/// [`Encode::encode`] writes it in compact serialization (five `.`-separated segments) and
/// [`Encode::decode`] / `str::parse` read it back.
#[derive(Debug)]
pub struct Jwe {
    /// JOSE header; its encoded form is used as the AEAD additional authenticated data.
    header: Header,
    /// Encrypted key: the CEK as returned by the `wrap_key` callback when encrypting, and
    /// handed to the `unwrap_key` callback when decrypting.
    cek: Bytes,
    /// Initialization vector used by the content cipher.
    iv: Bytes,
    /// Encrypted payload.
    ciphertext: Bytes,
    /// AEAD authentication tag.
    tag: Bytes,
}
25
26impl Jwe {
27 pub fn encryptor() -> JweEncryptor<Unset, Unset> {
28 JweEncryptor::default()
29 }
30
31 pub async fn decrypt<F>(self, unwrap_key: F) -> Result<Bytes>
32 where
33 F: AsyncFn(&str, &Algorithm, &[u8]) -> Result<WrapKeyResult>,
34 {
35 if self.header.typ != Type::JWE {
36 return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
37 format!("expected JWE, got {}", self.header.typ)
38 }));
39 }
40
41 let key_id = self
43 .header
44 .kid
45 .as_deref()
46 .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected kid"))?;
47 let result = unwrap_key(key_id, &self.header.alg, &self.cek).await?;
48
49 let enc = self
50 .header
51 .enc
52 .as_ref()
53 .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected enc"))?;
54 let cipher: Cipher = enc.try_into()?;
55 let aad = self.header.encode()?;
56
57 let plaintext: Bytes = symm::decrypt_aead(
58 cipher,
59 &result.cek,
60 Some(&self.iv),
61 aad.as_bytes(),
62 &self.ciphertext,
63 &self.tag,
64 )?
65 .into();
66
67 Ok(plaintext)
68 }
69
70 pub fn kid(&self) -> Option<&str> {
71 self.header.kid.as_deref()
72 }
73}
74
75impl Encode for Jwe {
76 fn decode(value: &str) -> Result<Self> {
77 value.parse()
78 }
79
80 fn encode(&self) -> Result<String> {
81 Ok([
82 self.header.encode()?,
83 base64::encode_url_safe(&self.cek),
84 base64::encode_url_safe(&self.iv),
85 base64::encode_url_safe(&self.ciphertext),
86 base64::encode_url_safe(&self.tag),
87 ]
88 .join("."))
89 }
90}
91
92impl FromStr for Jwe {
93 type Err = Error;
94 fn from_str(s: &str) -> Result<Self> {
95 const PARTS_ERROR: &str = "JWE must have exactly 5 parts separated by periods";
96
97 fn is_base64url_char(c: char) -> bool {
98 c.is_ascii_alphanumeric() || c == '-' || c == '_'
99 }
100
101 let mut parts = [0usize; 6];
102 let mut current_part_start = 0;
103 for (i, c) in s.char_indices() {
104 if c == '.' {
105 if current_part_start >= 5 {
106 return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
107 PARTS_ERROR
108 }));
109 }
110
111 parts[current_part_start + 1] = i + 1;
112 current_part_start += 1;
113 } else if !is_base64url_char(c) {
114 return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
115 "invalid character in JWE compact serialization"
116 }));
117 }
118 }
119
120 if current_part_start != 4 {
121 return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
122 PARTS_ERROR
123 }));
124 }
125
126 parts[5] = s.len() + 1;
127 let header = &s[parts[0]..parts[1] - 1];
128 let cek = &s[parts[1]..parts[2] - 1];
129 let iv = &s[parts[2]..parts[3] - 1];
130 let ciphertext = &s[parts[3]..parts[4] - 1];
131 let tag = &s[parts[4]..parts[5] - 1];
132
133 let header =
134 Header::decode(header).with_context_fn(ErrorKind::InvalidData, || "invalid header")?;
135 let cek = base64::decode_url_safe(cek)
136 .with_context_fn(ErrorKind::InvalidData, || "invalid cek")?
137 .into();
138 let iv = base64::decode_url_safe(iv)
139 .with_context_fn(ErrorKind::InvalidData, || "invalid iv")?
140 .into();
141 let ciphertext = base64::decode_url_safe(ciphertext)
142 .with_context_fn(ErrorKind::InvalidData, || "invalid ciphertext")?
143 .into();
144 let tag = base64::decode_url_safe(tag)
145 .with_context_fn(ErrorKind::InvalidData, || "invalid tag")?
146 .into();
147
148 Ok(Jwe {
149 header,
150 cek,
151 iv,
152 ciphertext,
153 tag,
154 })
155 }
156}
157
158#[derive(Debug)]
159pub struct JweEncryptor<C, K> {
160 alg: Option<Algorithm>,
161 enc: Option<EncryptionAlgorithm>,
162 kid: Option<String>,
163 cek: Option<Bytes>,
164 iv: Option<Bytes>,
165 plaintext: Option<Bytes>,
166 phantom: PhantomData<(C, K)>,
167}
168
169impl<C, K> JweEncryptor<C, K> {
170 pub fn alg(self, alg: Algorithm) -> Self {
171 Self {
172 alg: Some(alg),
173 ..self
174 }
175 }
176
177 pub fn enc(self, enc: EncryptionAlgorithm) -> Self {
178 Self {
179 enc: Some(enc),
180 ..self
181 }
182 }
183
184 pub fn cek(self, cek: &[u8]) -> Self {
185 Self {
186 cek: Some(Bytes::copy_from_slice(cek)),
187 ..self
188 }
189 }
190
191 pub fn iv(self, iv: &[u8]) -> Self {
192 Self {
193 iv: Some(Bytes::copy_from_slice(iv)),
194 ..self
195 }
196 }
197}
198
199impl<K> JweEncryptor<Unset, K> {
200 pub fn plaintext(self, plaintext: &[u8]) -> JweEncryptor<Set, K> {
201 JweEncryptor::<Set, K> {
202 plaintext: Some(Bytes::copy_from_slice(plaintext)),
203 alg: self.alg,
204 enc: self.enc,
205 kid: self.kid,
206 cek: self.cek,
207 iv: self.iv,
208 phantom: PhantomData,
209 }
210 }
211
212 pub fn plaintext_str(self, plaintext: impl AsRef<str>) -> JweEncryptor<Set, K> {
213 JweEncryptor::plaintext(self, plaintext.as_ref().as_bytes())
214 }
215}
216
217impl<C> JweEncryptor<C, Unset> {
218 pub fn kid(self, kid: impl Into<String>) -> JweEncryptor<C, Set> {
219 JweEncryptor::<C, Set> {
220 kid: Some(kid.into()),
221 alg: self.alg,
222 enc: self.enc,
223 cek: self.cek,
224 iv: self.iv,
225 plaintext: self.plaintext,
226 phantom: PhantomData,
227 }
228 }
229}
230
231impl JweEncryptor<Set, Set> {
232 pub async fn encrypt<F>(self, wrap_key: F) -> Result<Jwe>
233 where
234 F: AsyncFn(&str, &Algorithm, &[u8]) -> Result<WrapKeyResult>,
235 {
236 let enc = &self.enc.unwrap_or(EncryptionAlgorithm::A128GCM);
238 let cipher: Cipher = enc.try_into()?;
239
240 let cek = match self.cek {
242 Some(v) if v.len() == cipher.key_len() => v,
243 Some(v) => {
244 return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
245 format!(
246 "require key size of {} bytes, got {}",
247 cipher.key_len(),
248 v.len()
249 )
250 }));
251 }
252 None => {
253 let mut buf = [0; 32];
255 rand::rand_bytes(&mut buf)?;
256 Bytes::copy_from_slice(&buf[0..cipher.key_len()])
257 }
258 };
259
260 let kid = self
261 .kid
262 .as_deref()
263 .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected kid"))?;
264 let alg = self.alg.unwrap_or(Algorithm::RSA_OAEP);
265
266 let result = wrap_key(kid, &alg, &cek).await?;
268
269 let header = Header {
270 alg,
271 enc: Some(enc.clone()),
272 kid: Some(result.kid),
273 typ: super::Type::JWE,
274 };
275 let aad = header.encode()?;
276
277 let iv_len = cipher.iv_len().ok_or_else(|| {
279 Error::with_message(
280 ErrorKind::InvalidData,
281 format!("expected iv length for cipher {}", &enc),
282 )
283 })?;
284 let iv = match self.iv {
285 Some(v) if v.len() == iv_len => v,
286 Some(v) => {
287 return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
288 format!("require iv size of {} bytes, got {}", iv_len, v.len())
289 }));
290 }
291 None => {
292 let mut buf = [0; 12];
294 rand::rand_bytes(&mut buf)?;
295 Bytes::copy_from_slice(&buf[0..iv_len])
296 }
297 };
298
299 let plaintext = self.plaintext.expect("expected plaintext");
300 let mut tag = [0; 16];
301 let ciphertext: Bytes = symm::encrypt_aead(
302 cipher,
303 &cek,
304 Some(&iv),
305 aad.as_bytes(),
306 &plaintext,
307 &mut tag,
308 )?
309 .into();
310
311 Ok(Jwe {
312 header,
313 cek: result.cek,
314 iv,
315 ciphertext,
316 tag: Bytes::copy_from_slice(&tag),
317 })
318 }
319}
320
321impl<C, K> Default for JweEncryptor<C, K> {
322 fn default() -> Self {
323 Self {
324 alg: None,
325 enc: None,
326 kid: None,
327 cek: None,
328 iv: None,
329 plaintext: None,
330 phantom: PhantomData,
331 }
332 }
333}
334
335impl TryFrom<EncryptionAlgorithm> for Cipher {
336 type Error = Error;
337 fn try_from(value: EncryptionAlgorithm) -> Result<Self> {
338 (&value).try_into()
339 }
340}
341
342impl TryFrom<&EncryptionAlgorithm> for Cipher {
343 type Error = Error;
344 fn try_from(value: &EncryptionAlgorithm) -> Result<Cipher> {
345 match value {
346 EncryptionAlgorithm::A128GCM => Ok(Cipher::aes_128_gcm()),
347 EncryptionAlgorithm::A192GCM => Ok(Cipher::aes_192_gcm()),
348 EncryptionAlgorithm::A256GCM => Ok(Cipher::aes_256_gcm()),
349 EncryptionAlgorithm::Other(value) => {
350 Err(Error::with_message_fn(ErrorKind::InvalidData, || {
351 format!("unsupported encryption algorithm {value}")
352 }))
353 }
354 }
355 }
356}
357
358impl TryFrom<&Algorithm> for azure_security_keyvault_keys::models::EncryptionAlgorithm {
359 type Error = Error;
360 fn try_from(value: &Algorithm) -> Result<Self> {
361 match value {
362 Algorithm::RSA1_5 => Ok(Self::RSA1_5),
363 Algorithm::RSA_OAEP => Ok(Self::RsaOaep),
364 Algorithm::RSA_OAEP_256 => Ok(Self::RsaOAEP256),
365 Algorithm::Other(s) => Err(Error::with_message_fn(ErrorKind::InvalidData, || {
366 format!("unsupported algorithm {s}")
367 })),
368 }
369 }
370}
371
372#[derive(Debug)]
373pub struct WrapKeyResult {
374 pub kid: String,
375 pub cek: Bytes,
376}
377
378impl TryFrom<KeyOperationResult> for WrapKeyResult {
379 type Error = Error;
380 fn try_from(value: KeyOperationResult) -> Result<Self> {
381 Ok(Self {
382 kid: value
383 .kid
384 .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected kid"))?,
385 cek: value
386 .result
387 .map(Into::into)
388 .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected CEK"))?,
389 })
390 }
391}
392
// Unit tests for JWE compact (de)serialization, the `EncryptionAlgorithm` -> `Cipher`
// mapping, and an encrypt/decrypt round trip driven by in-memory key-wrap callbacks.
#[cfg(test)]
mod tests {
    use super::*;
    use azure_core::Bytes;

    // Too few and too many '.'-separated parts must both produce the parts-count error.
    #[test]
    fn decode_invalid() {
        assert!(
            matches!(Jwe::decode("1.2.3.4"), Err(err) if err.message() == Some("JWE must have exactly 5 parts separated by periods"))
        );
        assert!(
            matches!(Jwe::decode("1.2.3.4.5.6"), Err(err) if err.message() == Some("JWE must have exactly 5 parts separated by periods"))
        );
    }

    // Encoding a fixed JWE yields a known compact string, and decoding that string restores
    // every header field and byte segment.
    #[test]
    fn encode_decode_roundtrip() {
        let jwe = Jwe {
            header: Header {
                alg: crate::jose::Algorithm::RSA_OAEP_256,
                enc: Some(crate::jose::EncryptionAlgorithm::A128GCM),
                kid: Some("test-key-id".to_string()),
                typ: crate::jose::Type::JWE,
            },
            cek: Bytes::from_static(&[0x12, 0x34, 0x56, 0x78]),
            iv: Bytes::from_static(&[0x9a, 0xbc, 0xde, 0xf0]),
            ciphertext: Bytes::from_static(&[0x01, 0x23, 0x45, 0x67]),
            tag: Bytes::from_static(&[0x89, 0xab, 0xcd, 0xef]),
        };

        // Encoded header, then the unpadded base64url form of each 4-byte field above
        // (e.g. "EjRWeA" is 0x12 0x34 0x56 0x78).
        const EXPECTED: &str = "eyJhbGciOiJSU0EtT0FFUC0yNTYiLCJlbmMiOiJBMTI4R0NNIiwia2lkIjoidGVzdC1rZXktaWQiLCJ0eXAiOiJKV0UifQ.EjRWeA.mrze8A.ASNFZw.iavN7w";

        let encoded = jwe.encode().expect("encode should succeed");
        assert_eq!(encoded, EXPECTED);

        let decoded = Jwe::decode(&encoded).expect("decode should succeed");
        assert_eq!(decoded.header.alg, crate::jose::Algorithm::RSA_OAEP_256);
        assert_eq!(
            decoded.header.enc,
            Some(crate::jose::EncryptionAlgorithm::A128GCM)
        );
        assert_eq!(decoded.header.kid, Some("test-key-id".to_string()));
        assert_eq!(decoded.header.typ, crate::jose::Type::JWE);
        assert_eq!(decoded.cek, Bytes::from_static(&[0x12, 0x34, 0x56, 0x78]));
        assert_eq!(decoded.iv, Bytes::from_static(&[0x9a, 0xbc, 0xde, 0xf0]));
        assert_eq!(
            decoded.ciphertext,
            Bytes::from_static(&[0x01, 0x23, 0x45, 0x67])
        );
        assert_eq!(decoded.tag, Bytes::from_static(&[0x89, 0xab, 0xcd, 0xef]));
    }

    // `FromStr` accepts the same compact string and recovers the header fields.
    #[test]
    fn from_str_success() {
        let s = "eyJhbGciOiJSU0EtT0FFUC0yNTYiLCJlbmMiOiJBMTI4R0NNIiwia2lkIjoidGVzdC1rZXktaWQiLCJ0eXAiOiJKV0UifQ.EjRWeA.mrze8A.ASNFZw.iavN7w";
        let jwe = Jwe::from_str(s).expect("should parse valid JWE");
        assert_eq!(jwe.header.alg, Algorithm::RSA_OAEP_256);
        assert_eq!(jwe.header.enc, Some(EncryptionAlgorithm::A128GCM));
        assert_eq!(jwe.header.kid, Some("test-key-id".to_string()));
        assert_eq!(jwe.header.typ, Type::JWE);
    }

    // '!' is outside the base64url alphabet, so parsing fails with the invalid-character error.
    #[test]
    fn from_str_invalid_character() {
        let s = "eyJhbGciOiJSU0EtT0FFUC0yNTYiLCJlbmMiOiJBMTI4R0NNIiwia2lkIjoidGVzdC1rZXktaWQiLCJ0eXAiOiJKV0UifQ.EjRW!eA.mrze8A.ASNFZw.iavN7w";
        let err = Jwe::from_str(s).unwrap_err();
        assert!(matches!(err.kind(), ErrorKind::InvalidData));
        assert_eq!(
            err.message(),
            Some("invalid character in JWE compact serialization")
        );
    }

    // Four parts (three separators) is rejected with the parts-count error.
    #[test]
    fn from_str_too_few_periods() {
        let s = "a.b.c.d";
        let err = Jwe::from_str(s).unwrap_err();
        assert!(matches!(err.kind(), ErrorKind::InvalidData));
        assert_eq!(
            err.message(),
            Some("JWE must have exactly 5 parts separated by periods")
        );
    }

    // Six parts (five separators) is rejected with the parts-count error.
    #[test]
    fn from_str_too_many_periods() {
        let s = "a.b.c.d.e.f";
        let err = Jwe::from_str(s).unwrap_err();
        assert!(matches!(err.kind(), ErrorKind::InvalidData));
        assert_eq!(
            err.message(),
            Some("JWE must have exactly 5 parts separated by periods")
        );
    }

    // "Zm9vYmFy" is valid base64url (for "foobar") but not a valid header, so the header
    // decode error is reported with the "invalid header" context.
    #[test]
    fn from_str_invalid_header() {
        let s = "Zm9vYmFy.EjRWeA.mrze8A.ASNFZw.iavN7w";
        let err = Jwe::from_str(s).unwrap_err();
        assert!(matches!(err.kind(), ErrorKind::InvalidData));
        assert_eq!(err.message(), Some("invalid header"));
    }

    // Each AES-GCM variant maps to a cipher with a 12-byte IV and the matching key size.
    #[test]
    fn encryption_algorithm_cipher() {
        let cipher: Cipher = EncryptionAlgorithm::A128GCM
            .try_into()
            .expect("try_into should succeed");
        assert_eq!(cipher.iv_len(), Some(12));
        assert_eq!(cipher.key_len(), 16);

        let cipher: Cipher = EncryptionAlgorithm::A192GCM
            .try_into()
            .expect("try_into should succeed");
        assert_eq!(cipher.iv_len(), Some(12));
        assert_eq!(cipher.key_len(), 24);

        let cipher: Cipher = EncryptionAlgorithm::A256GCM
            .try_into()
            .expect("try_into should succeed");
        assert_eq!(cipher.iv_len(), Some(12));
        assert_eq!(cipher.key_len(), 32);
    }

    // The wrap/unwrap callbacks echo the CEK unchanged, so this exercises the AEAD path end to
    // end. It also checks the kid hand-off: `encrypt` must put the versioned kid returned by
    // `wrap_key` into the header, and `decrypt` must pass that kid to `unwrap_key`.
    #[tokio::test]
    async fn encrypt_decrypt_roundtrip() {
        let kid = "key-name";
        let alg = Algorithm::RSA_OAEP;
        let enc = EncryptionAlgorithm::A128GCM;
        let plaintext = b"Hello, world!";

        let wrap_key = async |key_id: &str, wrap_alg: &Algorithm, cek: &[u8]| {
            assert_eq!(key_id, kid);
            assert_eq!(wrap_alg, &alg);
            Ok(crate::jose::jwe::WrapKeyResult {
                kid: "key-name/key-version".into(),
                cek: Bytes::copy_from_slice(cek),
            })
        };

        let unwrap_key = async |key_id: &str, wrap_alg: &Algorithm, cek: &[u8]| {
            assert_eq!(key_id, "key-name/key-version");
            assert_eq!(wrap_alg, &alg);
            Ok(crate::jose::jwe::WrapKeyResult {
                kid: "key-name/key-version".into(),
                cek: Bytes::copy_from_slice(cek),
            })
        };

        let jwe = Jwe::encryptor()
            .alg(alg.clone())
            .enc(enc)
            .kid(kid)
            .plaintext(plaintext)
            .encrypt(wrap_key)
            .await
            .expect("encryption should succeed");

        let decrypted = jwe
            .decrypt(unwrap_key)
            .await
            .expect("decryption should succeed");
        assert_eq!(decrypted, plaintext.as_ref());
    }
}