Skip to main content

akv_cli/jose/
jwe.rs

// Copyright 2025 Heath Stewart.
// Licensed under the MIT License. See LICENSE.txt in the project root for license information.

//! JSON Web Encryption types.
5
6use crate::{
7    jose::{Algorithm, Encode, EncryptionAlgorithm, Header, Set, Type, Unset},
8    Error, ErrorKind, Result, ResultExt as _,
9};
10use aws_lc_rs::{aead, rand};
11use azure_core::{base64, Bytes};
12use azure_security_keyvault_keys::models::KeyOperationResult;
13use std::{marker::PhantomData, str::FromStr};
14
15/// A JSON Web Encryption (JWE) structure.
16#[derive(Debug)]
17pub struct Jwe {
18    header: Header,
19    cek: Bytes,
20    iv: Bytes,
21    ciphertext: Bytes,
22    tag: Bytes,
23}
24
25impl Jwe {
26    /// Gets a JWE encryptor.
27    pub fn encryptor() -> JweEncryptor<Unset, Unset> {
28        JweEncryptor::default()
29    }
30
31    /// Decrypts a JWE.
32    pub async fn decrypt<F>(self, unwrap_key: F) -> Result<Bytes>
33    where
34        F: AsyncFn(&str, &Algorithm, &[u8]) -> Result<WrapKeyResult>,
35    {
36        if self.header.typ != Type::JWE {
37            return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
38                format!("expected JWE, got {}", self.header.typ)
39            }));
40        }
41
42        // Decrypt the CEK.
43        let key_id = self
44            .header
45            .kid
46            .as_deref()
47            .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected kid"))?;
48        let result = unwrap_key(key_id, &self.header.alg, &self.cek).await?;
49
50        let enc = self
51            .header
52            .enc
53            .as_ref()
54            .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected enc"))?;
55        let alg: &'static aead::Algorithm = enc.try_into()?;
56        let aad = self.header.encode()?;
57
58        // `LessSafeKey` is used instead of `OpeningKey` + `NonceSequence` because decryption
59        // must supply an arbitrary nonce stored in the JWE compact form. The nonce-sequence
60        // APIs expect to own nonce tracking and cannot accept an externally supplied nonce
61        // without extra scaffolding.
62        let key = aead::LessSafeKey::new(
63            aead::UnboundKey::new(alg, &result.cek)
64                .map_err(|_| Error::with_message(ErrorKind::InvalidData, "invalid CEK"))?,
65        );
66        let nonce = aead::Nonce::try_assume_unique_for_key(&self.iv)
67            .map_err(|_| Error::with_message(ErrorKind::InvalidData, "invalid IV"))?;
68        // aws-lc-rs expects ciphertext || tag concatenated in a single buffer.
69        let mut buf = self.ciphertext.to_vec();
70        buf.extend_from_slice(&self.tag);
71        let plaintext = key
72            .open_in_place(nonce, aead::Aad::from(aad.as_bytes()), &mut buf)
73            .map_err(|_| Error::with_message(ErrorKind::Other, "decryption failed"))?;
74        let plaintext = Bytes::copy_from_slice(plaintext);
75
76        Ok(plaintext)
77    }
78
79    /// Gets the key identifier.
80    pub fn kid(&self) -> Option<&str> {
81        self.header.kid.as_deref()
82    }
83}
84
85impl Encode for Jwe {
86    fn decode(value: &str) -> Result<Self> {
87        value.parse()
88    }
89
90    fn encode(&self) -> Result<String> {
91        Ok([
92            self.header.encode()?,
93            base64::encode_url_safe(&self.cek),
94            base64::encode_url_safe(&self.iv),
95            base64::encode_url_safe(&self.ciphertext),
96            base64::encode_url_safe(&self.tag),
97        ]
98        .join("."))
99    }
100}
101
102impl FromStr for Jwe {
103    type Err = Error;
104    fn from_str(s: &str) -> Result<Self> {
105        const PARTS_ERROR: &str = "JWE must have exactly 5 parts separated by periods";
106
107        fn is_base64url_char(c: char) -> bool {
108            c.is_ascii_alphanumeric() || c == '-' || c == '_'
109        }
110
111        let mut parts = [0usize; 6];
112        let mut current_part_start = 0;
113        for (i, c) in s.char_indices() {
114            if c == '.' {
115                if current_part_start >= 5 {
116                    return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
117                        PARTS_ERROR
118                    }));
119                }
120
121                parts[current_part_start + 1] = i + 1;
122                current_part_start += 1;
123            } else if !is_base64url_char(c) {
124                return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
125                    "invalid character in JWE compact serialization"
126                }));
127            }
128        }
129
130        if current_part_start != 4 {
131            return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
132                PARTS_ERROR
133            }));
134        }
135
136        parts[5] = s.len() + 1;
137        let header = &s[parts[0]..parts[1] - 1];
138        let cek = &s[parts[1]..parts[2] - 1];
139        let iv = &s[parts[2]..parts[3] - 1];
140        let ciphertext = &s[parts[3]..parts[4] - 1];
141        let tag = &s[parts[4]..parts[5] - 1];
142
143        let header =
144            Header::decode(header).with_context_fn(ErrorKind::InvalidData, || "invalid header")?;
145        let cek = base64::decode_url_safe(cek)
146            .with_context_fn(ErrorKind::InvalidData, || "invalid cek")?
147            .into();
148        let iv = base64::decode_url_safe(iv)
149            .with_context_fn(ErrorKind::InvalidData, || "invalid iv")?
150            .into();
151        let ciphertext = base64::decode_url_safe(ciphertext)
152            .with_context_fn(ErrorKind::InvalidData, || "invalid ciphertext")?
153            .into();
154        let tag = base64::decode_url_safe(tag)
155            .with_context_fn(ErrorKind::InvalidData, || "invalid tag")?
156            .into();
157
158        Ok(Jwe {
159            header,
160            cek,
161            iv,
162            ciphertext,
163            tag,
164        })
165    }
166}
167
168/// A JWE encryptor.
169///
170/// Only JWEs with key identifiers are supported, specifically from Key Vault.
171#[derive(Debug)]
172pub struct JweEncryptor<C, K> {
173    alg: Option<Algorithm>,
174    enc: Option<EncryptionAlgorithm>,
175    kid: Option<String>,
176    cek: Option<Bytes>,
177    iv: Option<Bytes>,
178    plaintext: Option<Bytes>,
179    phantom: PhantomData<(C, K)>,
180}
181
182impl<C, K> JweEncryptor<C, K> {
183    /// Sets the JWE [`Algorithm`].
184    pub fn alg(self, alg: Algorithm) -> Self {
185        Self {
186            alg: Some(alg),
187            ..self
188        }
189    }
190
191    /// Sets the JWE [`EncryptionAlgorithm`].
192    pub fn enc(self, enc: EncryptionAlgorithm) -> Self {
193        Self {
194            enc: Some(enc),
195            ..self
196        }
197    }
198
199    /// Sets the JWE content encryption key.
200    pub fn cek(self, cek: &[u8]) -> Self {
201        Self {
202            cek: Some(Bytes::copy_from_slice(cek)),
203            ..self
204        }
205    }
206
207    /// Sets the JWE initialization vector.
208    pub fn iv(self, iv: &[u8]) -> Self {
209        Self {
210            iv: Some(Bytes::copy_from_slice(iv)),
211            ..self
212        }
213    }
214}
215
216impl<K> JweEncryptor<Unset, K> {
217    /// Sets the plaintext data.
218    pub fn plaintext(self, plaintext: &[u8]) -> JweEncryptor<Set, K> {
219        JweEncryptor::<Set, K> {
220            plaintext: Some(Bytes::copy_from_slice(plaintext)),
221            alg: self.alg,
222            enc: self.enc,
223            kid: self.kid,
224            cek: self.cek,
225            iv: self.iv,
226            phantom: PhantomData,
227        }
228    }
229
230    /// Sets the plaintext encoded as a string.
231    pub fn plaintext_str(self, plaintext: impl AsRef<str>) -> JweEncryptor<Set, K> {
232        JweEncryptor::plaintext(self, plaintext.as_ref().as_bytes())
233    }
234}
235
236impl<C> JweEncryptor<C, Unset> {
237    /// Sets the JWE key identifier.
238    pub fn kid(self, kid: impl Into<String>) -> JweEncryptor<C, Set> {
239        JweEncryptor::<C, Set> {
240            kid: Some(kid.into()),
241            alg: self.alg,
242            enc: self.enc,
243            cek: self.cek,
244            iv: self.iv,
245            plaintext: self.plaintext,
246            phantom: PhantomData,
247        }
248    }
249}
250
251impl JweEncryptor<Set, Set> {
252    /// Encrypts the JWE.
253    pub async fn encrypt<F>(self, wrap_key: F) -> Result<Jwe>
254    where
255        F: AsyncFn(&str, &Algorithm, &[u8]) -> Result<WrapKeyResult>,
256    {
257        // Determine how big the CEK should be.
258        let enc = &self.enc.unwrap_or(EncryptionAlgorithm::A128GCM);
259        let cipher: &'static aead::Algorithm = enc.try_into()?;
260
261        // Validate or generate the CEK.
262        let cek = match self.cek {
263            Some(v) if v.len() == cipher.key_len() => v,
264            Some(v) => {
265                return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
266                    format!(
267                        "require key size of {} bytes, got {}",
268                        cipher.key_len(),
269                        v.len()
270                    )
271                }));
272            }
273            None => {
274                // Allocate enough space for largest supported cipher.
275                let mut buf = [0; 32];
276                rand::fill(&mut buf)?;
277                Bytes::copy_from_slice(&buf[0..cipher.key_len()])
278            }
279        };
280
281        let kid = self
282            .kid
283            .as_deref()
284            .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected kid"))?;
285        let alg = self.alg.unwrap_or(Algorithm::RSA_OAEP);
286
287        // Encrypt the CEK so we get the full kid.
288        let result = wrap_key(kid, &alg, &cek).await?;
289
290        let header = Header {
291            alg,
292            enc: Some(enc.clone()),
293            kid: Some(result.kid),
294            typ: super::Type::JWE,
295        };
296        let aad = header.encode()?;
297
298        // All AES-GCM modes in aws-lc-rs use a fixed 96-bit (12-byte) nonce (aead::NONCE_LEN).
299        let iv = match self.iv {
300            Some(v) if v.len() == aead::NONCE_LEN => v,
301            Some(v) => {
302                return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
303                    format!(
304                        "require iv size of {} bytes, got {}",
305                        aead::NONCE_LEN,
306                        v.len()
307                    )
308                }));
309            }
310            None => {
311                let mut buf = [0u8; aead::NONCE_LEN];
312                rand::fill(&mut buf)?;
313                Bytes::copy_from_slice(&buf)
314            }
315        };
316
317        // `LessSafeKey` is used instead of `SealingKey` + `NonceSequence` because the IV is
318        // stored in the JWE compact form and must be reproduced verbatim for decryption. The
319        // nonce-sequence APIs own nonce generation and do not expose a way to supply an
320        // arbitrary nonce without extra scaffolding.
321        let key = aead::LessSafeKey::new(
322            aead::UnboundKey::new(cipher, &cek)
323                .map_err(|_| Error::with_message(ErrorKind::InvalidData, "invalid CEK"))?,
324        );
325        let nonce = aead::Nonce::try_assume_unique_for_key(&iv)
326            .map_err(|_| Error::with_message(ErrorKind::InvalidData, "invalid IV"))?;
327        let plaintext = self.plaintext.expect("expected plaintext");
328        // aws-lc-rs appends the authentication tag to the ciphertext buffer in-place.
329        let mut buf = plaintext.to_vec();
330        key.seal_in_place_append_tag(nonce, aead::Aad::from(aad.as_bytes()), &mut buf)
331            .map_err(|_| Error::with_message(ErrorKind::Other, "encryption failed"))?;
332        // Split the appended tag (last tag_len bytes) from the ciphertext.
333        let tag = buf.split_off(buf.len() - cipher.tag_len());
334        let ciphertext: Bytes = buf.into();
335
336        Ok(Jwe {
337            header,
338            cek: result.cek,
339            iv,
340            ciphertext,
341            tag: tag.into(),
342        })
343    }
344}
345
346impl<C, K> Default for JweEncryptor<C, K> {
347    fn default() -> Self {
348        Self {
349            alg: None,
350            enc: None,
351            kid: None,
352            cek: None,
353            iv: None,
354            plaintext: None,
355            phantom: PhantomData,
356        }
357    }
358}
359
360impl TryFrom<EncryptionAlgorithm> for &'static aead::Algorithm {
361    type Error = Error;
362    fn try_from(value: EncryptionAlgorithm) -> Result<Self> {
363        (&value).try_into()
364    }
365}
366
367impl TryFrom<&EncryptionAlgorithm> for &'static aead::Algorithm {
368    type Error = Error;
369    fn try_from(value: &EncryptionAlgorithm) -> Result<&'static aead::Algorithm> {
370        match value {
371            EncryptionAlgorithm::A128GCM => Ok(&aead::AES_128_GCM),
372            EncryptionAlgorithm::A192GCM => Ok(&aead::AES_192_GCM),
373            EncryptionAlgorithm::A256GCM => Ok(&aead::AES_256_GCM),
374            EncryptionAlgorithm::Other(value) => {
375                Err(Error::with_message_fn(ErrorKind::InvalidData, || {
376                    format!("unsupported encryption algorithm {value}")
377                }))
378            }
379        }
380    }
381}
382
383impl TryFrom<&Algorithm> for azure_security_keyvault_keys::models::EncryptionAlgorithm {
384    type Error = Error;
385    fn try_from(value: &Algorithm) -> Result<Self> {
386        match value {
387            Algorithm::RSA1_5 => Ok(Self::Rsa1_5),
388            Algorithm::RSA_OAEP => Ok(Self::RsaOaep),
389            Algorithm::RSA_OAEP_256 => Ok(Self::RsaOaep256),
390            Algorithm::Other(s) => Err(Error::with_message_fn(ErrorKind::InvalidData, || {
391                format!("unsupported algorithm {s}")
392            })),
393        }
394    }
395}
396
397/// Result for a key wrap operation.
398#[derive(Debug)]
399pub struct WrapKeyResult {
400    /// The key identifier.
401    pub kid: String,
402
403    /// The content encryption key.
404    pub cek: Bytes,
405}
406
407impl TryFrom<KeyOperationResult> for WrapKeyResult {
408    type Error = Error;
409    fn try_from(value: KeyOperationResult) -> Result<Self> {
410        Ok(Self {
411            kid: value
412                .kid
413                .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected kid"))?,
414            cek: value
415                .result
416                .map(Into::into)
417                .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected CEK"))?,
418        })
419    }
420}
421
#[cfg(test)]
mod tests {
    use super::*;
    use azure_core::Bytes;

    /// Compact serialization of the JWE built by `sample_jwe`.
    // cspell:disable-next-line
    const SAMPLE: &str = "eyJhbGciOiJSU0EtT0FFUC0yNTYiLCJlbmMiOiJBMTI4R0NNIiwia2lkIjoidGVzdC1rZXktaWQiLCJ0eXAiOiJKV0UifQ.EjRWeA.mrze8A.ASNFZw.iavN7w";

    const PARTS_ERROR: &str = "JWE must have exactly 5 parts separated by periods";

    /// Builds a small JWE with fixed, easily recognizable part contents.
    fn sample_jwe() -> Jwe {
        Jwe {
            header: Header {
                alg: Algorithm::RSA_OAEP_256,
                enc: Some(EncryptionAlgorithm::A128GCM),
                kid: Some("test-key-id".to_string()),
                typ: Type::JWE,
            },
            cek: Bytes::from_static(&[0x12, 0x34, 0x56, 0x78]),
            iv: Bytes::from_static(&[0x9a, 0xbc, 0xde, 0xf0]),
            ciphertext: Bytes::from_static(&[0x01, 0x23, 0x45, 0x67]),
            tag: Bytes::from_static(&[0x89, 0xab, 0xcd, 0xef]),
        }
    }

    /// Asserts that parsing `s` fails with an `InvalidData` error carrying `message`.
    fn assert_invalid(s: &str, message: &str) {
        let err = Jwe::from_str(s).unwrap_err();
        assert!(matches!(err.kind(), ErrorKind::InvalidData));
        assert_eq!(err.message(), Some(message));
    }

    #[test]
    fn decode_invalid() {
        for input in ["1.2.3.4", "1.2.3.4.5.6"] {
            let err = Jwe::decode(input).expect_err("decode should fail");
            assert_eq!(err.message(), Some(PARTS_ERROR));
        }
    }

    #[test]
    fn encode_decode_roundtrip() {
        let encoded = sample_jwe().encode().expect("encode should succeed");
        assert_eq!(encoded, SAMPLE);

        let decoded = Jwe::decode(&encoded).expect("decode should succeed");
        let expected = sample_jwe();
        assert_eq!(decoded.header.alg, expected.header.alg);
        assert_eq!(decoded.header.enc, expected.header.enc);
        assert_eq!(decoded.header.kid, expected.header.kid);
        assert_eq!(decoded.header.typ, expected.header.typ);
        assert_eq!(decoded.cek, expected.cek);
        assert_eq!(decoded.iv, expected.iv);
        assert_eq!(decoded.ciphertext, expected.ciphertext);
        assert_eq!(decoded.tag, expected.tag);
    }

    #[test]
    fn from_str_success() {
        let jwe = Jwe::from_str(SAMPLE).expect("should parse valid JWE");
        assert_eq!(jwe.header.alg, Algorithm::RSA_OAEP_256);
        assert_eq!(jwe.header.enc, Some(EncryptionAlgorithm::A128GCM));
        assert_eq!(jwe.header.kid, Some("test-key-id".to_string()));
        assert_eq!(jwe.header.typ, Type::JWE);
    }

    #[test]
    fn from_str_invalid_character() {
        // '!' is inserted into the cek part.
        assert_invalid(
            // cspell:disable-next-line
            "eyJhbGciOiJSU0EtT0FFUC0yNTYiLCJlbmMiOiJBMTI4R0NNIiwia2lkIjoidGVzdC1rZXktaWQiLCJ0eXAiOiJKV0UifQ.EjRW!eA.mrze8A.ASNFZw.iavN7w",
            "invalid character in JWE compact serialization",
        );
    }

    #[test]
    fn from_str_too_few_periods() {
        // Three periods, four parts.
        assert_invalid("a.b.c.d", PARTS_ERROR);
    }

    #[test]
    fn from_str_too_many_periods() {
        // Five periods, six parts.
        assert_invalid("a.b.c.d.e.f", PARTS_ERROR);
    }

    #[test]
    fn from_str_invalid_header() {
        // Valid base64url, but the header does not decode to a JOSE header.
        // cspell:disable-next-line
        assert_invalid("Zm9vYmFy.EjRWeA.mrze8A.ASNFZw.iavN7w", "invalid header");
    }

    #[test]
    fn encryption_algorithm_cipher() {
        for (enc, key_len) in [
            (EncryptionAlgorithm::A128GCM, 16),
            (EncryptionAlgorithm::A192GCM, 24),
            (EncryptionAlgorithm::A256GCM, 32),
        ] {
            let cipher: &'static aead::Algorithm =
                enc.try_into().expect("try_into should succeed");
            assert_eq!(cipher.nonce_len(), 12);
            assert_eq!(cipher.key_len(), key_len);
        }
    }

    #[tokio::test]
    async fn encrypt_decrypt_roundtrip() {
        const REQUESTED_KID: &str = "key-name";
        const VERSIONED_KID: &str = "key-name/key-version";
        let alg = Algorithm::RSA_OAEP;
        let plaintext = b"Hello, world!";

        // Identity "wrap": checks its inputs and returns the CEK unchanged with a versioned kid.
        let wrap_key = async |key_id: &str, wrap_alg: &Algorithm, cek: &[u8]| {
            assert_eq!(key_id, REQUESTED_KID);
            assert_eq!(wrap_alg, &alg);
            Ok(WrapKeyResult {
                kid: VERSIONED_KID.into(),
                cek: Bytes::copy_from_slice(cek),
            })
        };

        // Identity "unwrap": expects the versioned kid recorded by the wrap step.
        let unwrap_key = async |key_id: &str, wrap_alg: &Algorithm, cek: &[u8]| {
            assert_eq!(key_id, VERSIONED_KID);
            assert_eq!(wrap_alg, &alg);
            Ok(WrapKeyResult {
                kid: VERSIONED_KID.into(),
                cek: Bytes::copy_from_slice(cek),
            })
        };

        let jwe = Jwe::encryptor()
            .alg(alg.clone())
            .enc(EncryptionAlgorithm::A128GCM)
            .kid(REQUESTED_KID)
            .plaintext(plaintext)
            .encrypt(wrap_key)
            .await
            .expect("encryption should succeed");

        let decrypted = jwe
            .decrypt(unwrap_key)
            .await
            .expect("decryption should succeed");
        assert_eq!(decrypted, plaintext.as_ref());
    }
}