1use crate::{
7 jose::{Algorithm, Encode, EncryptionAlgorithm, Header, Set, Type, Unset},
8 Error, ErrorKind, Result, ResultExt as _,
9};
10use aws_lc_rs::{aead, rand};
11use azure_core::{base64, Bytes};
12use azure_security_keyvault_keys::models::KeyOperationResult;
13use std::{marker::PhantomData, str::FromStr};
14
/// A JSON Web Encryption (JWE) object, decoded from or ready to encode into the five-part
/// compact serialization: protected header, wrapped content encryption key (CEK),
/// initialization vector, ciphertext, and authentication tag.
#[derive(Debug)]
pub struct Jwe {
    /// Protected header. Its encoded form is used as the AEAD additional authenticated data.
    header: Header,
    /// Wrapped (encrypted) CEK; passed to the unwrap callback together with `header.alg`.
    cek: Bytes,
    /// Initialization vector (nonce) for the content encryption.
    iv: Bytes,
    /// Encrypted content, stored without the authentication tag.
    ciphertext: Bytes,
    /// AEAD authentication tag, split off from the sealed ciphertext.
    tag: Bytes,
}
24
25impl Jwe {
26 pub fn encryptor() -> JweEncryptor<Unset, Unset> {
28 JweEncryptor::default()
29 }
30
31 pub async fn decrypt<F>(self, unwrap_key: F) -> Result<Bytes>
33 where
34 F: AsyncFn(&str, &Algorithm, &[u8]) -> Result<WrapKeyResult>,
35 {
36 if self.header.typ != Type::JWE {
37 return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
38 format!("expected JWE, got {}", self.header.typ)
39 }));
40 }
41
42 let key_id = self
44 .header
45 .kid
46 .as_deref()
47 .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected kid"))?;
48 let result = unwrap_key(key_id, &self.header.alg, &self.cek).await?;
49
50 let enc = self
51 .header
52 .enc
53 .as_ref()
54 .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected enc"))?;
55 let alg: &'static aead::Algorithm = enc.try_into()?;
56 let aad = self.header.encode()?;
57
58 let key = aead::LessSafeKey::new(
63 aead::UnboundKey::new(alg, &result.cek)
64 .map_err(|_| Error::with_message(ErrorKind::InvalidData, "invalid CEK"))?,
65 );
66 let nonce = aead::Nonce::try_assume_unique_for_key(&self.iv)
67 .map_err(|_| Error::with_message(ErrorKind::InvalidData, "invalid IV"))?;
68 let mut buf = self.ciphertext.to_vec();
70 buf.extend_from_slice(&self.tag);
71 let plaintext = key
72 .open_in_place(nonce, aead::Aad::from(aad.as_bytes()), &mut buf)
73 .map_err(|_| Error::with_message(ErrorKind::Other, "decryption failed"))?;
74 let plaintext = Bytes::copy_from_slice(plaintext);
75
76 Ok(plaintext)
77 }
78
79 pub fn kid(&self) -> Option<&str> {
81 self.header.kid.as_deref()
82 }
83}
84
85impl Encode for Jwe {
86 fn decode(value: &str) -> Result<Self> {
87 value.parse()
88 }
89
90 fn encode(&self) -> Result<String> {
91 Ok([
92 self.header.encode()?,
93 base64::encode_url_safe(&self.cek),
94 base64::encode_url_safe(&self.iv),
95 base64::encode_url_safe(&self.ciphertext),
96 base64::encode_url_safe(&self.tag),
97 ]
98 .join("."))
99 }
100}
101
102impl FromStr for Jwe {
103 type Err = Error;
104 fn from_str(s: &str) -> Result<Self> {
105 const PARTS_ERROR: &str = "JWE must have exactly 5 parts separated by periods";
106
107 fn is_base64url_char(c: char) -> bool {
108 c.is_ascii_alphanumeric() || c == '-' || c == '_'
109 }
110
111 let mut parts = [0usize; 6];
112 let mut current_part_start = 0;
113 for (i, c) in s.char_indices() {
114 if c == '.' {
115 if current_part_start >= 5 {
116 return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
117 PARTS_ERROR
118 }));
119 }
120
121 parts[current_part_start + 1] = i + 1;
122 current_part_start += 1;
123 } else if !is_base64url_char(c) {
124 return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
125 "invalid character in JWE compact serialization"
126 }));
127 }
128 }
129
130 if current_part_start != 4 {
131 return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
132 PARTS_ERROR
133 }));
134 }
135
136 parts[5] = s.len() + 1;
137 let header = &s[parts[0]..parts[1] - 1];
138 let cek = &s[parts[1]..parts[2] - 1];
139 let iv = &s[parts[2]..parts[3] - 1];
140 let ciphertext = &s[parts[3]..parts[4] - 1];
141 let tag = &s[parts[4]..parts[5] - 1];
142
143 let header =
144 Header::decode(header).with_context_fn(ErrorKind::InvalidData, || "invalid header")?;
145 let cek = base64::decode_url_safe(cek)
146 .with_context_fn(ErrorKind::InvalidData, || "invalid cek")?
147 .into();
148 let iv = base64::decode_url_safe(iv)
149 .with_context_fn(ErrorKind::InvalidData, || "invalid iv")?
150 .into();
151 let ciphertext = base64::decode_url_safe(ciphertext)
152 .with_context_fn(ErrorKind::InvalidData, || "invalid ciphertext")?
153 .into();
154 let tag = base64::decode_url_safe(tag)
155 .with_context_fn(ErrorKind::InvalidData, || "invalid tag")?
156 .into();
157
158 Ok(Jwe {
159 header,
160 cek,
161 iv,
162 ciphertext,
163 tag,
164 })
165 }
166}
167
168#[derive(Debug)]
172pub struct JweEncryptor<C, K> {
173 alg: Option<Algorithm>,
174 enc: Option<EncryptionAlgorithm>,
175 kid: Option<String>,
176 cek: Option<Bytes>,
177 iv: Option<Bytes>,
178 plaintext: Option<Bytes>,
179 phantom: PhantomData<(C, K)>,
180}
181
182impl<C, K> JweEncryptor<C, K> {
183 pub fn alg(self, alg: Algorithm) -> Self {
185 Self {
186 alg: Some(alg),
187 ..self
188 }
189 }
190
191 pub fn enc(self, enc: EncryptionAlgorithm) -> Self {
193 Self {
194 enc: Some(enc),
195 ..self
196 }
197 }
198
199 pub fn cek(self, cek: &[u8]) -> Self {
201 Self {
202 cek: Some(Bytes::copy_from_slice(cek)),
203 ..self
204 }
205 }
206
207 pub fn iv(self, iv: &[u8]) -> Self {
209 Self {
210 iv: Some(Bytes::copy_from_slice(iv)),
211 ..self
212 }
213 }
214}
215
216impl<K> JweEncryptor<Unset, K> {
217 pub fn plaintext(self, plaintext: &[u8]) -> JweEncryptor<Set, K> {
219 JweEncryptor::<Set, K> {
220 plaintext: Some(Bytes::copy_from_slice(plaintext)),
221 alg: self.alg,
222 enc: self.enc,
223 kid: self.kid,
224 cek: self.cek,
225 iv: self.iv,
226 phantom: PhantomData,
227 }
228 }
229
230 pub fn plaintext_str(self, plaintext: impl AsRef<str>) -> JweEncryptor<Set, K> {
232 JweEncryptor::plaintext(self, plaintext.as_ref().as_bytes())
233 }
234}
235
236impl<C> JweEncryptor<C, Unset> {
237 pub fn kid(self, kid: impl Into<String>) -> JweEncryptor<C, Set> {
239 JweEncryptor::<C, Set> {
240 kid: Some(kid.into()),
241 alg: self.alg,
242 enc: self.enc,
243 cek: self.cek,
244 iv: self.iv,
245 plaintext: self.plaintext,
246 phantom: PhantomData,
247 }
248 }
249}
250
251impl JweEncryptor<Set, Set> {
252 pub async fn encrypt<F>(self, wrap_key: F) -> Result<Jwe>
254 where
255 F: AsyncFn(&str, &Algorithm, &[u8]) -> Result<WrapKeyResult>,
256 {
257 let enc = &self.enc.unwrap_or(EncryptionAlgorithm::A128GCM);
259 let cipher: &'static aead::Algorithm = enc.try_into()?;
260
261 let cek = match self.cek {
263 Some(v) if v.len() == cipher.key_len() => v,
264 Some(v) => {
265 return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
266 format!(
267 "require key size of {} bytes, got {}",
268 cipher.key_len(),
269 v.len()
270 )
271 }));
272 }
273 None => {
274 let mut buf = [0; 32];
276 rand::fill(&mut buf)?;
277 Bytes::copy_from_slice(&buf[0..cipher.key_len()])
278 }
279 };
280
281 let kid = self
282 .kid
283 .as_deref()
284 .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected kid"))?;
285 let alg = self.alg.unwrap_or(Algorithm::RSA_OAEP);
286
287 let result = wrap_key(kid, &alg, &cek).await?;
289
290 let header = Header {
291 alg,
292 enc: Some(enc.clone()),
293 kid: Some(result.kid),
294 typ: super::Type::JWE,
295 };
296 let aad = header.encode()?;
297
298 let iv = match self.iv {
300 Some(v) if v.len() == aead::NONCE_LEN => v,
301 Some(v) => {
302 return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
303 format!(
304 "require iv size of {} bytes, got {}",
305 aead::NONCE_LEN,
306 v.len()
307 )
308 }));
309 }
310 None => {
311 let mut buf = [0u8; aead::NONCE_LEN];
312 rand::fill(&mut buf)?;
313 Bytes::copy_from_slice(&buf)
314 }
315 };
316
317 let key = aead::LessSafeKey::new(
322 aead::UnboundKey::new(cipher, &cek)
323 .map_err(|_| Error::with_message(ErrorKind::InvalidData, "invalid CEK"))?,
324 );
325 let nonce = aead::Nonce::try_assume_unique_for_key(&iv)
326 .map_err(|_| Error::with_message(ErrorKind::InvalidData, "invalid IV"))?;
327 let plaintext = self.plaintext.expect("expected plaintext");
328 let mut buf = plaintext.to_vec();
330 key.seal_in_place_append_tag(nonce, aead::Aad::from(aad.as_bytes()), &mut buf)
331 .map_err(|_| Error::with_message(ErrorKind::Other, "encryption failed"))?;
332 let tag = buf.split_off(buf.len() - cipher.tag_len());
334 let ciphertext: Bytes = buf.into();
335
336 Ok(Jwe {
337 header,
338 cek: result.cek,
339 iv,
340 ciphertext,
341 tag: tag.into(),
342 })
343 }
344}
345
346impl<C, K> Default for JweEncryptor<C, K> {
347 fn default() -> Self {
348 Self {
349 alg: None,
350 enc: None,
351 kid: None,
352 cek: None,
353 iv: None,
354 plaintext: None,
355 phantom: PhantomData,
356 }
357 }
358}
359
360impl TryFrom<EncryptionAlgorithm> for &'static aead::Algorithm {
361 type Error = Error;
362 fn try_from(value: EncryptionAlgorithm) -> Result<Self> {
363 (&value).try_into()
364 }
365}
366
367impl TryFrom<&EncryptionAlgorithm> for &'static aead::Algorithm {
368 type Error = Error;
369 fn try_from(value: &EncryptionAlgorithm) -> Result<&'static aead::Algorithm> {
370 match value {
371 EncryptionAlgorithm::A128GCM => Ok(&aead::AES_128_GCM),
372 EncryptionAlgorithm::A192GCM => Ok(&aead::AES_192_GCM),
373 EncryptionAlgorithm::A256GCM => Ok(&aead::AES_256_GCM),
374 EncryptionAlgorithm::Other(value) => {
375 Err(Error::with_message_fn(ErrorKind::InvalidData, || {
376 format!("unsupported encryption algorithm {value}")
377 }))
378 }
379 }
380 }
381}
382
383impl TryFrom<&Algorithm> for azure_security_keyvault_keys::models::EncryptionAlgorithm {
384 type Error = Error;
385 fn try_from(value: &Algorithm) -> Result<Self> {
386 match value {
387 Algorithm::RSA1_5 => Ok(Self::Rsa1_5),
388 Algorithm::RSA_OAEP => Ok(Self::RsaOaep),
389 Algorithm::RSA_OAEP_256 => Ok(Self::RsaOaep256),
390 Algorithm::Other(s) => Err(Error::with_message_fn(ErrorKind::InvalidData, || {
391 format!("unsupported algorithm {s}")
392 })),
393 }
394 }
395}
396
397#[derive(Debug)]
399pub struct WrapKeyResult {
400 pub kid: String,
402
403 pub cek: Bytes,
405}
406
407impl TryFrom<KeyOperationResult> for WrapKeyResult {
408 type Error = Error;
409 fn try_from(value: KeyOperationResult) -> Result<Self> {
410 Ok(Self {
411 kid: value
412 .kid
413 .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected kid"))?,
414 cek: value
415 .result
416 .map(Into::into)
417 .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected CEK"))?,
418 })
419 }
420}
421
#[cfg(test)]
mod tests {
    use super::*;
    use azure_core::Bytes;

    #[test]
    fn decode_invalid() {
        assert!(
            matches!(Jwe::decode("1.2.3.4"), Err(err) if err.message() == Some("JWE must have exactly 5 parts separated by periods"))
        );
        assert!(
            matches!(Jwe::decode("1.2.3.4.5.6"), Err(err) if err.message() == Some("JWE must have exactly 5 parts separated by periods"))
        );
    }

    #[test]
    fn encode_decode_roundtrip() {
        let jwe = Jwe {
            header: Header {
                alg: crate::jose::Algorithm::RSA_OAEP_256,
                enc: Some(crate::jose::EncryptionAlgorithm::A128GCM),
                kid: Some("test-key-id".to_string()),
                typ: crate::jose::Type::JWE,
            },
            cek: Bytes::from_static(&[0x12, 0x34, 0x56, 0x78]),
            iv: Bytes::from_static(&[0x9a, 0xbc, 0xde, 0xf0]),
            ciphertext: Bytes::from_static(&[0x01, 0x23, 0x45, 0x67]),
            tag: Bytes::from_static(&[0x89, 0xab, 0xcd, 0xef]),
        };

        const EXPECTED: &str = "eyJhbGciOiJSU0EtT0FFUC0yNTYiLCJlbmMiOiJBMTI4R0NNIiwia2lkIjoidGVzdC1rZXktaWQiLCJ0eXAiOiJKV0UifQ.EjRWeA.mrze8A.ASNFZw.iavN7w";

        let encoded = jwe.encode().expect("encode should succeed");
        assert_eq!(encoded, EXPECTED);

        let decoded = Jwe::decode(&encoded).expect("decode should succeed");
        assert_eq!(decoded.header.alg, crate::jose::Algorithm::RSA_OAEP_256);
        assert_eq!(
            decoded.header.enc,
            Some(crate::jose::EncryptionAlgorithm::A128GCM)
        );
        assert_eq!(decoded.header.kid, Some("test-key-id".to_string()));
        assert_eq!(decoded.header.typ, crate::jose::Type::JWE);
        assert_eq!(decoded.cek, Bytes::from_static(&[0x12, 0x34, 0x56, 0x78]));
        assert_eq!(decoded.iv, Bytes::from_static(&[0x9a, 0xbc, 0xde, 0xf0]));
        assert_eq!(
            decoded.ciphertext,
            Bytes::from_static(&[0x01, 0x23, 0x45, 0x67])
        );
        assert_eq!(decoded.tag, Bytes::from_static(&[0x89, 0xab, 0xcd, 0xef]));
    }

    #[test]
    fn from_str_success() {
        let s = "eyJhbGciOiJSU0EtT0FFUC0yNTYiLCJlbmMiOiJBMTI4R0NNIiwia2lkIjoidGVzdC1rZXktaWQiLCJ0eXAiOiJKV0UifQ.EjRWeA.mrze8A.ASNFZw.iavN7w";
        let jwe = Jwe::from_str(s).expect("should parse valid JWE");
        assert_eq!(jwe.header.alg, Algorithm::RSA_OAEP_256);
        assert_eq!(jwe.header.enc, Some(EncryptionAlgorithm::A128GCM));
        assert_eq!(jwe.header.kid, Some("test-key-id".to_string()));
        assert_eq!(jwe.header.typ, Type::JWE);
    }

    #[test]
    fn from_str_invalid_character() {
        let s = "eyJhbGciOiJSU0EtT0FFUC0yNTYiLCJlbmMiOiJBMTI4R0NNIiwia2lkIjoidGVzdC1rZXktaWQiLCJ0eXAiOiJKV0UifQ.EjRW!eA.mrze8A.ASNFZw.iavN7w";
        let err = Jwe::from_str(s).unwrap_err();
        assert!(matches!(err.kind(), ErrorKind::InvalidData));
        assert_eq!(
            err.message(),
            Some("invalid character in JWE compact serialization")
        );
    }

    #[test]
    fn from_str_too_few_periods() {
        let s = "a.b.c.d";
        let err = Jwe::from_str(s).unwrap_err();
        assert!(matches!(err.kind(), ErrorKind::InvalidData));
        assert_eq!(
            err.message(),
            Some("JWE must have exactly 5 parts separated by periods")
        );
    }

    #[test]
    fn from_str_too_many_periods() {
        let s = "a.b.c.d.e.f";
        let err = Jwe::from_str(s).unwrap_err();
        assert!(matches!(err.kind(), ErrorKind::InvalidData));
        assert_eq!(
            err.message(),
            Some("JWE must have exactly 5 parts separated by periods")
        );
    }

    #[test]
    fn from_str_invalid_header() {
        let s = "Zm9vYmFy.EjRWeA.mrze8A.ASNFZw.iavN7w";
        let err = Jwe::from_str(s).unwrap_err();
        assert!(matches!(err.kind(), ErrorKind::InvalidData));
        assert_eq!(err.message(), Some("invalid header"));
    }

    #[test]
    fn encryption_algorithm_cipher() {
        let cipher: &'static aead::Algorithm = EncryptionAlgorithm::A128GCM
            .try_into()
            .expect("try_into should succeed");
        assert_eq!(cipher.nonce_len(), 12);
        assert_eq!(cipher.key_len(), 16);

        let cipher: &'static aead::Algorithm = EncryptionAlgorithm::A192GCM
            .try_into()
            .expect("try_into should succeed");
        assert_eq!(cipher.nonce_len(), 12);
        assert_eq!(cipher.key_len(), 24);

        let cipher: &'static aead::Algorithm = EncryptionAlgorithm::A256GCM
            .try_into()
            .expect("try_into should succeed");
        assert_eq!(cipher.nonce_len(), 12);
        assert_eq!(cipher.key_len(), 32);
    }

    #[tokio::test]
    async fn encrypt_decrypt_roundtrip() {
        let kid = "key-name";
        let alg = Algorithm::RSA_OAEP;
        let enc = EncryptionAlgorithm::A128GCM;
        let plaintext = b"Hello, world!";

        let wrap_key = async |key_id: &str, wrap_alg: &Algorithm, cek: &[u8]| {
            assert_eq!(key_id, kid);
            assert_eq!(wrap_alg, &alg);
            Ok(crate::jose::jwe::WrapKeyResult {
                kid: "key-name/key-version".into(),
                cek: Bytes::copy_from_slice(cek),
            })
        };

        let unwrap_key = async |key_id: &str, wrap_alg: &Algorithm, cek: &[u8]| {
            assert_eq!(key_id, "key-name/key-version");
            assert_eq!(wrap_alg, &alg);
            Ok(crate::jose::jwe::WrapKeyResult {
                kid: "key-name/key-version".into(),
                cek: Bytes::copy_from_slice(cek),
            })
        };

        let jwe = Jwe::encryptor()
            .alg(alg.clone())
            .enc(enc)
            .kid(kid)
            .plaintext(plaintext)
            .encrypt(wrap_key)
            .await
            .expect("encryption should succeed");

        let decrypted = jwe
            .decrypt(unwrap_key)
            .await
            .expect("decryption should succeed");
        assert_eq!(decrypted, plaintext.as_ref());
    }

    // AES-GCM must reject a ciphertext that was modified after sealing.
    #[tokio::test]
    async fn decrypt_detects_tampered_ciphertext() {
        let wrap_key = async |_: &str, _: &Algorithm, cek: &[u8]| {
            Ok(WrapKeyResult {
                kid: "key-name/key-version".into(),
                cek: Bytes::copy_from_slice(cek),
            })
        };
        let unwrap_key = async |_: &str, _: &Algorithm, cek: &[u8]| {
            Ok(WrapKeyResult {
                kid: "key-name/key-version".into(),
                cek: Bytes::copy_from_slice(cek),
            })
        };

        let jwe = Jwe::encryptor()
            .kid("key-name")
            .plaintext(b"secret")
            .encrypt(wrap_key)
            .await
            .expect("encryption should succeed");

        let mut ciphertext = jwe.ciphertext.to_vec();
        ciphertext[0] ^= 0x01;
        let tampered = Jwe {
            ciphertext: ciphertext.into(),
            ..jwe
        };

        let err = tampered
            .decrypt(unwrap_key)
            .await
            .expect_err("tampered ciphertext must not decrypt");
        assert!(matches!(err.kind(), ErrorKind::Other));
        assert_eq!(err.message(), Some("decryption failed"));
    }

    // A CEK whose length does not match the content encryption algorithm is rejected before
    // the wrap callback runs.
    #[tokio::test]
    async fn encrypt_rejects_wrong_cek_size() {
        let wrap_key = async |_: &str, _: &Algorithm, _: &[u8]| {
            Err(Error::with_message(
                ErrorKind::Other,
                "wrap_key must not be called",
            ))
        };

        let err = Jwe::encryptor()
            .enc(EncryptionAlgorithm::A128GCM)
            .cek(&[0u8; 4])
            .kid("key-name")
            .plaintext(b"secret")
            .encrypt(wrap_key)
            .await
            .expect_err("a 4-byte CEK is invalid for A128GCM");
        assert!(matches!(err.kind(), ErrorKind::InvalidData));
        assert_eq!(err.message(), Some("require key size of 16 bytes, got 4"));
    }
}