1use crate::{
7 jose::{Algorithm, Encode, EncryptionAlgorithm, Header, Set, Type, Unset},
8 Error, ErrorKind, Result, ResultExt as _,
9};
10use azure_core::{base64, Bytes};
11use azure_security_keyvault_keys::models::KeyOperationResult;
12use openssl::{
13 rand,
14 symm::{self, Cipher},
15};
16use std::{marker::PhantomData, str::FromStr};
17
/// A JSON Web Encryption (JWE) object.
///
/// Each byte field holds the base64url-decoded contents of one segment of the
/// compact serialization (`header.cek.iv.ciphertext.tag`).
#[derive(Debug)]
pub struct Jwe {
    /// Protected header; its encoded form is used as the AEAD additional
    /// authenticated data during encryption and decryption.
    header: Header,
    /// Content encryption key as produced by the key-wrapping callback,
    /// i.e. the wrapped (encrypted) CEK rather than the raw key.
    cek: Bytes,
    /// Initialization vector used by the content cipher.
    iv: Bytes,
    /// Encrypted payload.
    ciphertext: Bytes,
    /// Authentication tag produced by the AEAD content cipher.
    tag: Bytes,
}
27
28impl Jwe {
29 pub fn encryptor() -> JweEncryptor<Unset, Unset> {
31 JweEncryptor::default()
32 }
33
34 pub async fn decrypt<F>(self, unwrap_key: F) -> Result<Bytes>
36 where
37 F: AsyncFn(&str, &Algorithm, &[u8]) -> Result<WrapKeyResult>,
38 {
39 if self.header.typ != Type::JWE {
40 return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
41 format!("expected JWE, got {}", self.header.typ)
42 }));
43 }
44
45 let key_id = self
47 .header
48 .kid
49 .as_deref()
50 .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected kid"))?;
51 let result = unwrap_key(key_id, &self.header.alg, &self.cek).await?;
52
53 let enc = self
54 .header
55 .enc
56 .as_ref()
57 .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected enc"))?;
58 let cipher: Cipher = enc.try_into()?;
59 let aad = self.header.encode()?;
60
61 let plaintext: Bytes = symm::decrypt_aead(
62 cipher,
63 &result.cek,
64 Some(&self.iv),
65 aad.as_bytes(),
66 &self.ciphertext,
67 &self.tag,
68 )?
69 .into();
70
71 Ok(plaintext)
72 }
73
74 pub fn kid(&self) -> Option<&str> {
76 self.header.kid.as_deref()
77 }
78}
79
80impl Encode for Jwe {
81 fn decode(value: &str) -> Result<Self> {
82 value.parse()
83 }
84
85 fn encode(&self) -> Result<String> {
86 Ok([
87 self.header.encode()?,
88 base64::encode_url_safe(&self.cek),
89 base64::encode_url_safe(&self.iv),
90 base64::encode_url_safe(&self.ciphertext),
91 base64::encode_url_safe(&self.tag),
92 ]
93 .join("."))
94 }
95}
96
97impl FromStr for Jwe {
98 type Err = Error;
99 fn from_str(s: &str) -> Result<Self> {
100 const PARTS_ERROR: &str = "JWE must have exactly 5 parts separated by periods";
101
102 fn is_base64url_char(c: char) -> bool {
103 c.is_ascii_alphanumeric() || c == '-' || c == '_'
104 }
105
106 let mut parts = [0usize; 6];
107 let mut current_part_start = 0;
108 for (i, c) in s.char_indices() {
109 if c == '.' {
110 if current_part_start >= 5 {
111 return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
112 PARTS_ERROR
113 }));
114 }
115
116 parts[current_part_start + 1] = i + 1;
117 current_part_start += 1;
118 } else if !is_base64url_char(c) {
119 return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
120 "invalid character in JWE compact serialization"
121 }));
122 }
123 }
124
125 if current_part_start != 4 {
126 return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
127 PARTS_ERROR
128 }));
129 }
130
131 parts[5] = s.len() + 1;
132 let header = &s[parts[0]..parts[1] - 1];
133 let cek = &s[parts[1]..parts[2] - 1];
134 let iv = &s[parts[2]..parts[3] - 1];
135 let ciphertext = &s[parts[3]..parts[4] - 1];
136 let tag = &s[parts[4]..parts[5] - 1];
137
138 let header =
139 Header::decode(header).with_context_fn(ErrorKind::InvalidData, || "invalid header")?;
140 let cek = base64::decode_url_safe(cek)
141 .with_context_fn(ErrorKind::InvalidData, || "invalid cek")?
142 .into();
143 let iv = base64::decode_url_safe(iv)
144 .with_context_fn(ErrorKind::InvalidData, || "invalid iv")?
145 .into();
146 let ciphertext = base64::decode_url_safe(ciphertext)
147 .with_context_fn(ErrorKind::InvalidData, || "invalid ciphertext")?
148 .into();
149 let tag = base64::decode_url_safe(tag)
150 .with_context_fn(ErrorKind::InvalidData, || "invalid tag")?
151 .into();
152
153 Ok(Jwe {
154 header,
155 cek,
156 iv,
157 ciphertext,
158 tag,
159 })
160 }
161}
162
163#[derive(Debug)]
167pub struct JweEncryptor<C, K> {
168 alg: Option<Algorithm>,
169 enc: Option<EncryptionAlgorithm>,
170 kid: Option<String>,
171 cek: Option<Bytes>,
172 iv: Option<Bytes>,
173 plaintext: Option<Bytes>,
174 phantom: PhantomData<(C, K)>,
175}
176
177impl<C, K> JweEncryptor<C, K> {
178 pub fn alg(self, alg: Algorithm) -> Self {
180 Self {
181 alg: Some(alg),
182 ..self
183 }
184 }
185
186 pub fn enc(self, enc: EncryptionAlgorithm) -> Self {
188 Self {
189 enc: Some(enc),
190 ..self
191 }
192 }
193
194 pub fn cek(self, cek: &[u8]) -> Self {
196 Self {
197 cek: Some(Bytes::copy_from_slice(cek)),
198 ..self
199 }
200 }
201
202 pub fn iv(self, iv: &[u8]) -> Self {
204 Self {
205 iv: Some(Bytes::copy_from_slice(iv)),
206 ..self
207 }
208 }
209}
210
211impl<K> JweEncryptor<Unset, K> {
212 pub fn plaintext(self, plaintext: &[u8]) -> JweEncryptor<Set, K> {
214 JweEncryptor::<Set, K> {
215 plaintext: Some(Bytes::copy_from_slice(plaintext)),
216 alg: self.alg,
217 enc: self.enc,
218 kid: self.kid,
219 cek: self.cek,
220 iv: self.iv,
221 phantom: PhantomData,
222 }
223 }
224
225 pub fn plaintext_str(self, plaintext: impl AsRef<str>) -> JweEncryptor<Set, K> {
227 JweEncryptor::plaintext(self, plaintext.as_ref().as_bytes())
228 }
229}
230
231impl<C> JweEncryptor<C, Unset> {
232 pub fn kid(self, kid: impl Into<String>) -> JweEncryptor<C, Set> {
234 JweEncryptor::<C, Set> {
235 kid: Some(kid.into()),
236 alg: self.alg,
237 enc: self.enc,
238 cek: self.cek,
239 iv: self.iv,
240 plaintext: self.plaintext,
241 phantom: PhantomData,
242 }
243 }
244}
245
246impl JweEncryptor<Set, Set> {
247 pub async fn encrypt<F>(self, wrap_key: F) -> Result<Jwe>
249 where
250 F: AsyncFn(&str, &Algorithm, &[u8]) -> Result<WrapKeyResult>,
251 {
252 let enc = &self.enc.unwrap_or(EncryptionAlgorithm::A128GCM);
254 let cipher: Cipher = enc.try_into()?;
255
256 let cek = match self.cek {
258 Some(v) if v.len() == cipher.key_len() => v,
259 Some(v) => {
260 return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
261 format!(
262 "require key size of {} bytes, got {}",
263 cipher.key_len(),
264 v.len()
265 )
266 }));
267 }
268 None => {
269 let mut buf = [0; 32];
271 rand::rand_bytes(&mut buf)?;
272 Bytes::copy_from_slice(&buf[0..cipher.key_len()])
273 }
274 };
275
276 let kid = self
277 .kid
278 .as_deref()
279 .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected kid"))?;
280 let alg = self.alg.unwrap_or(Algorithm::RSA_OAEP);
281
282 let result = wrap_key(kid, &alg, &cek).await?;
284
285 let header = Header {
286 alg,
287 enc: Some(enc.clone()),
288 kid: Some(result.kid),
289 typ: super::Type::JWE,
290 };
291 let aad = header.encode()?;
292
293 let iv_len = cipher.iv_len().ok_or_else(|| {
295 Error::with_message(
296 ErrorKind::InvalidData,
297 format!("expected iv length for cipher {}", &enc),
298 )
299 })?;
300 let iv = match self.iv {
301 Some(v) if v.len() == iv_len => v,
302 Some(v) => {
303 return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
304 format!("require iv size of {} bytes, got {}", iv_len, v.len())
305 }));
306 }
307 None => {
308 let mut buf = [0; 12];
310 rand::rand_bytes(&mut buf)?;
311 Bytes::copy_from_slice(&buf[0..iv_len])
312 }
313 };
314
315 let plaintext = self.plaintext.expect("expected plaintext");
316 let mut tag = [0; 16];
317 let ciphertext: Bytes = symm::encrypt_aead(
318 cipher,
319 &cek,
320 Some(&iv),
321 aad.as_bytes(),
322 &plaintext,
323 &mut tag,
324 )?
325 .into();
326
327 Ok(Jwe {
328 header,
329 cek: result.cek,
330 iv,
331 ciphertext,
332 tag: Bytes::copy_from_slice(&tag),
333 })
334 }
335}
336
337impl<C, K> Default for JweEncryptor<C, K> {
338 fn default() -> Self {
339 Self {
340 alg: None,
341 enc: None,
342 kid: None,
343 cek: None,
344 iv: None,
345 plaintext: None,
346 phantom: PhantomData,
347 }
348 }
349}
350
351impl TryFrom<EncryptionAlgorithm> for Cipher {
352 type Error = Error;
353 fn try_from(value: EncryptionAlgorithm) -> Result<Self> {
354 (&value).try_into()
355 }
356}
357
358impl TryFrom<&EncryptionAlgorithm> for Cipher {
359 type Error = Error;
360 fn try_from(value: &EncryptionAlgorithm) -> Result<Cipher> {
361 match value {
362 EncryptionAlgorithm::A128GCM => Ok(Cipher::aes_128_gcm()),
363 EncryptionAlgorithm::A192GCM => Ok(Cipher::aes_192_gcm()),
364 EncryptionAlgorithm::A256GCM => Ok(Cipher::aes_256_gcm()),
365 EncryptionAlgorithm::Other(value) => {
366 Err(Error::with_message_fn(ErrorKind::InvalidData, || {
367 format!("unsupported encryption algorithm {value}")
368 }))
369 }
370 }
371 }
372}
373
374impl TryFrom<&Algorithm> for azure_security_keyvault_keys::models::EncryptionAlgorithm {
375 type Error = Error;
376 fn try_from(value: &Algorithm) -> Result<Self> {
377 match value {
378 Algorithm::RSA1_5 => Ok(Self::Rsa1_5),
379 Algorithm::RSA_OAEP => Ok(Self::RsaOaep),
380 Algorithm::RSA_OAEP_256 => Ok(Self::RsaOaep256),
381 Algorithm::Other(s) => Err(Error::with_message_fn(ErrorKind::InvalidData, || {
382 format!("unsupported algorithm {s}")
383 })),
384 }
385 }
386}
387
/// Value returned by the key-wrapping callback passed to `JweEncryptor::encrypt`
/// and by the key-unwrapping callback passed to `Jwe::decrypt`.
#[derive(Debug)]
pub struct WrapKeyResult {
    /// Key ID reported by the key operation; `encrypt` records it as the
    /// JWE header's `kid`.
    pub kid: String,

    /// Output of the key operation: the wrapped CEK when encrypting (stored in
    /// the JWE), or the unwrapped CEK when decrypting (used as the cipher key).
    pub cek: Bytes,
}
397
398impl TryFrom<KeyOperationResult> for WrapKeyResult {
399 type Error = Error;
400 fn try_from(value: KeyOperationResult) -> Result<Self> {
401 Ok(Self {
402 kid: value
403 .kid
404 .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected kid"))?,
405 cek: value
406 .result
407 .map(Into::into)
408 .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected CEK"))?,
409 })
410 }
411}
412
#[cfg(test)]
mod tests {
    use super::*;
    use azure_core::Bytes;

    // Compact serializations with too few or too many segments are rejected
    // with the part-count error message.
    #[test]
    fn decode_invalid() {
        assert!(
            matches!(Jwe::decode("1.2.3.4"), Err(err) if err.message() == Some("JWE must have exactly 5 parts separated by periods"))
        );
        assert!(
            matches!(Jwe::decode("1.2.3.4.5.6"), Err(err) if err.message() == Some("JWE must have exactly 5 parts separated by periods"))
        );
    }

    // Encoding a hand-built JWE yields a known compact serialization, and
    // decoding that string restores every header field and segment.
    #[test]
    fn encode_decode_roundtrip() {
        let jwe = Jwe {
            header: Header {
                alg: crate::jose::Algorithm::RSA_OAEP_256,
                enc: Some(crate::jose::EncryptionAlgorithm::A128GCM),
                kid: Some("test-key-id".to_string()),
                typ: crate::jose::Type::JWE,
            },
            cek: Bytes::from_static(&[0x12, 0x34, 0x56, 0x78]),
            iv: Bytes::from_static(&[0x9a, 0xbc, 0xde, 0xf0]),
            ciphertext: Bytes::from_static(&[0x01, 0x23, 0x45, 0x67]),
            tag: Bytes::from_static(&[0x89, 0xab, 0xcd, 0xef]),
        };

        // The last four segments are the unpadded base64url encodings of the
        // byte arrays above (e.g. 12 34 56 78 -> "EjRWeA").
        const EXPECTED: &str = "eyJhbGciOiJSU0EtT0FFUC0yNTYiLCJlbmMiOiJBMTI4R0NNIiwia2lkIjoidGVzdC1rZXktaWQiLCJ0eXAiOiJKV0UifQ.EjRWeA.mrze8A.ASNFZw.iavN7w";

        let encoded = jwe.encode().expect("encode should succeed");
        assert_eq!(encoded, EXPECTED);

        let decoded = Jwe::decode(&encoded).expect("decode should succeed");
        assert_eq!(decoded.header.alg, crate::jose::Algorithm::RSA_OAEP_256);
        assert_eq!(
            decoded.header.enc,
            Some(crate::jose::EncryptionAlgorithm::A128GCM)
        );
        assert_eq!(decoded.header.kid, Some("test-key-id".to_string()));
        assert_eq!(decoded.header.typ, crate::jose::Type::JWE);
        assert_eq!(decoded.cek, Bytes::from_static(&[0x12, 0x34, 0x56, 0x78]));
        assert_eq!(decoded.iv, Bytes::from_static(&[0x9a, 0xbc, 0xde, 0xf0]));
        assert_eq!(
            decoded.ciphertext,
            Bytes::from_static(&[0x01, 0x23, 0x45, 0x67])
        );
        assert_eq!(decoded.tag, Bytes::from_static(&[0x89, 0xab, 0xcd, 0xef]));
    }

    // Parsing the serialization from `encode_decode_roundtrip` yields the
    // expected header fields.
    #[test]
    fn from_str_success() {
        let s = "eyJhbGciOiJSU0EtT0FFUC0yNTYiLCJlbmMiOiJBMTI4R0NNIiwia2lkIjoidGVzdC1rZXktaWQiLCJ0eXAiOiJKV0UifQ.EjRWeA.mrze8A.ASNFZw.iavN7w";
        let jwe = Jwe::from_str(s).expect("should parse valid JWE");
        assert_eq!(jwe.header.alg, Algorithm::RSA_OAEP_256);
        assert_eq!(jwe.header.enc, Some(EncryptionAlgorithm::A128GCM));
        assert_eq!(jwe.header.kid, Some("test-key-id".to_string()));
        assert_eq!(jwe.header.typ, Type::JWE);
    }

    // A '!' inside the CEK segment is outside the base64url alphabet.
    #[test]
    fn from_str_invalid_character() {
        let s = "eyJhbGciOiJSU0EtT0FFUC0yNTYiLCJlbmMiOiJBMTI4R0NNIiwia2lkIjoidGVzdC1rZXktaWQiLCJ0eXAiOiJKV0UifQ.EjRW!eA.mrze8A.ASNFZw.iavN7w";
        let err = Jwe::from_str(s).unwrap_err();
        assert!(matches!(err.kind(), ErrorKind::InvalidData));
        assert_eq!(
            err.message(),
            Some("invalid character in JWE compact serialization")
        );
    }

    // Four segments (three separators) are rejected.
    #[test]
    fn from_str_too_few_periods() {
        let s = "a.b.c.d";
        let err = Jwe::from_str(s).unwrap_err();
        assert!(matches!(err.kind(), ErrorKind::InvalidData));
        assert_eq!(
            err.message(),
            Some("JWE must have exactly 5 parts separated by periods")
        );
    }

    // Six segments (five separators) are rejected.
    #[test]
    fn from_str_too_many_periods() {
        let s = "a.b.c.d.e.f";
        let err = Jwe::from_str(s).unwrap_err();
        assert!(matches!(err.kind(), ErrorKind::InvalidData));
        assert_eq!(
            err.message(),
            Some("JWE must have exactly 5 parts separated by periods")
        );
    }

    // The header segment decodes to the text "foobar", which is not a valid
    // header, so parsing fails with the "invalid header" context.
    #[test]
    fn from_str_invalid_header() {
        let s = "Zm9vYmFy.EjRWeA.mrze8A.ASNFZw.iavN7w";
        let err = Jwe::from_str(s).unwrap_err();
        assert!(matches!(err.kind(), ErrorKind::InvalidData));
        assert_eq!(err.message(), Some("invalid header"));
    }

    // Each supported `enc` maps to an AES-GCM cipher with a 12-byte IV and
    // the matching key size.
    #[test]
    fn encryption_algorithm_cipher() {
        let cipher: Cipher = EncryptionAlgorithm::A128GCM
            .try_into()
            .expect("try_into should succeed");
        assert_eq!(cipher.iv_len(), Some(12));
        assert_eq!(cipher.key_len(), 16);

        let cipher: Cipher = EncryptionAlgorithm::A192GCM
            .try_into()
            .expect("try_into should succeed");
        assert_eq!(cipher.iv_len(), Some(12));
        assert_eq!(cipher.key_len(), 24);

        let cipher: Cipher = EncryptionAlgorithm::A256GCM
            .try_into()
            .expect("try_into should succeed");
        assert_eq!(cipher.iv_len(), Some(12));
        assert_eq!(cipher.key_len(), 32);
    }

    // End-to-end encrypt/decrypt. Both callbacks return the CEK unchanged
    // (an identity "wrap"), so no real key service is needed. The versioned
    // key ID returned by `wrap_key` is expected to reach `unwrap_key` via the header.
    #[tokio::test]
    async fn encrypt_decrypt_roundtrip() {
        let kid = "key-name";
        let alg = Algorithm::RSA_OAEP;
        let enc = EncryptionAlgorithm::A128GCM;
        let plaintext = b"Hello, world!";

        let wrap_key = async |key_id: &str, wrap_alg: &Algorithm, cek: &[u8]| {
            assert_eq!(key_id, kid);
            assert_eq!(wrap_alg, &alg);
            Ok(crate::jose::jwe::WrapKeyResult {
                kid: "key-name/key-version".into(),
                cek: Bytes::copy_from_slice(cek),
            })
        };

        let unwrap_key = async |key_id: &str, wrap_alg: &Algorithm, cek: &[u8]| {
            assert_eq!(key_id, "key-name/key-version");
            assert_eq!(wrap_alg, &alg);
            Ok(crate::jose::jwe::WrapKeyResult {
                kid: "key-name/key-version".into(),
                cek: Bytes::copy_from_slice(cek),
            })
        };

        let jwe = Jwe::encryptor()
            .alg(alg.clone())
            .enc(enc)
            .kid(kid)
            .plaintext(plaintext)
            .encrypt(wrap_key)
            .await
            .expect("encryption should succeed");

        let decrypted = jwe
            .decrypt(unwrap_key)
            .await
            .expect("decryption should succeed");
        assert_eq!(decrypted, plaintext.as_ref());
    }
}