akv_cli/jose/
jwe.rs

1// Copyright 2025 Heath Stewart.
2// Licensed under the MIT License. See LICENSE.txt in the project root for license information.
3
4use crate::{
5    jose::{Algorithm, Encode, EncryptionAlgorithm, Header, Set, Type, Unset},
6    Error, ErrorKind, Result, ResultExt as _,
7};
8use azure_core::{base64, Bytes};
9use azure_security_keyvault_keys::models::KeyOperationResult;
10use openssl::{
11    rand,
12    symm::{self, Cipher},
13};
14use std::{marker::PhantomData, str::FromStr};
15
/// A JSON Web Encryption (JWE) structure.
///
/// Holds the five parts of the compact serialization produced by
/// [`Encode::encode`] and parsed by [`FromStr`]: header, encrypted key, IV,
/// ciphertext, and authentication tag.
#[derive(Debug)]
pub struct Jwe {
    /// JOSE header; its encoded form is also used as the AEAD additional authenticated data.
    header: Header,
    /// Content encryption key as returned by the key-wrapping callback (i.e., still wrapped).
    cek: Bytes,
    /// Initialization vector passed to the content cipher.
    iv: Bytes,
    /// Encrypted payload.
    ciphertext: Bytes,
    /// AEAD authentication tag produced alongside the ciphertext.
    tag: Bytes,
}
25
26impl Jwe {
27    pub fn encryptor() -> JweEncryptor<Unset, Unset> {
28        JweEncryptor::default()
29    }
30
31    pub async fn decrypt<F>(self, unwrap_key: F) -> Result<Bytes>
32    where
33        F: AsyncFn(&str, &Algorithm, &[u8]) -> Result<WrapKeyResult>,
34    {
35        if self.header.typ != Type::JWE {
36            return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
37                format!("expected JWE, got {}", self.header.typ)
38            }));
39        }
40
41        // Decrypt the CEK.
42        let key_id = self
43            .header
44            .kid
45            .as_deref()
46            .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected kid"))?;
47        let result = unwrap_key(key_id, &self.header.alg, &self.cek).await?;
48
49        let enc = self
50            .header
51            .enc
52            .as_ref()
53            .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected enc"))?;
54        let cipher: Cipher = enc.try_into()?;
55        let aad = self.header.encode()?;
56
57        let plaintext: Bytes = symm::decrypt_aead(
58            cipher,
59            &result.cek,
60            Some(&self.iv),
61            aad.as_bytes(),
62            &self.ciphertext,
63            &self.tag,
64        )?
65        .into();
66
67        Ok(plaintext)
68    }
69
70    pub fn kid(&self) -> Option<&str> {
71        self.header.kid.as_deref()
72    }
73}
74
75impl Encode for Jwe {
76    fn decode(value: &str) -> Result<Self> {
77        value.parse()
78    }
79
80    fn encode(&self) -> Result<String> {
81        Ok([
82            self.header.encode()?,
83            base64::encode_url_safe(&self.cek),
84            base64::encode_url_safe(&self.iv),
85            base64::encode_url_safe(&self.ciphertext),
86            base64::encode_url_safe(&self.tag),
87        ]
88        .join("."))
89    }
90}
91
92impl FromStr for Jwe {
93    type Err = Error;
94    fn from_str(s: &str) -> Result<Self> {
95        const PARTS_ERROR: &str = "JWE must have exactly 5 parts separated by periods";
96
97        fn is_base64url_char(c: char) -> bool {
98            c.is_ascii_alphanumeric() || c == '-' || c == '_'
99        }
100
101        let mut parts = [0usize; 6];
102        let mut current_part_start = 0;
103        for (i, c) in s.char_indices() {
104            if c == '.' {
105                if current_part_start >= 5 {
106                    return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
107                        PARTS_ERROR
108                    }));
109                }
110
111                parts[current_part_start + 1] = i + 1;
112                current_part_start += 1;
113            } else if !is_base64url_char(c) {
114                return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
115                    "invalid character in JWE compact serialization"
116                }));
117            }
118        }
119
120        if current_part_start != 4 {
121            return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
122                PARTS_ERROR
123            }));
124        }
125
126        parts[5] = s.len() + 1;
127        let header = &s[parts[0]..parts[1] - 1];
128        let cek = &s[parts[1]..parts[2] - 1];
129        let iv = &s[parts[2]..parts[3] - 1];
130        let ciphertext = &s[parts[3]..parts[4] - 1];
131        let tag = &s[parts[4]..parts[5] - 1];
132
133        let header =
134            Header::decode(header).with_context_fn(ErrorKind::InvalidData, || "invalid header")?;
135        let cek = base64::decode_url_safe(cek)
136            .with_context_fn(ErrorKind::InvalidData, || "invalid cek")?
137            .into();
138        let iv = base64::decode_url_safe(iv)
139            .with_context_fn(ErrorKind::InvalidData, || "invalid iv")?
140            .into();
141        let ciphertext = base64::decode_url_safe(ciphertext)
142            .with_context_fn(ErrorKind::InvalidData, || "invalid ciphertext")?
143            .into();
144        let tag = base64::decode_url_safe(tag)
145            .with_context_fn(ErrorKind::InvalidData, || "invalid tag")?
146            .into();
147
148        Ok(Jwe {
149            header,
150            cek,
151            iv,
152            ciphertext,
153            tag,
154        })
155    }
156}
157
/// Builder that encrypts plaintext into a [`Jwe`].
///
/// The type parameters track at compile time which required inputs were provided:
/// `C` becomes `Set` once `plaintext` is called and `K` becomes `Set` once `kid` is
/// called. `encrypt` is only available on `JweEncryptor<Set, Set>`.
#[derive(Debug)]
pub struct JweEncryptor<C, K> {
    /// Key-wrapping algorithm; `encrypt` falls back to `RSA_OAEP` when unset.
    alg: Option<Algorithm>,
    /// Content encryption algorithm; `encrypt` falls back to `A128GCM` when unset.
    enc: Option<EncryptionAlgorithm>,
    /// Key ID passed to the key-wrapping callback.
    kid: Option<String>,
    /// Caller-supplied content encryption key; `encrypt` generates one when unset.
    cek: Option<Bytes>,
    /// Caller-supplied initialization vector; `encrypt` generates one when unset.
    iv: Option<Bytes>,
    /// Data to encrypt.
    plaintext: Option<Bytes>,
    /// Carries the `C` and `K` state markers without storing values of those types.
    phantom: PhantomData<(C, K)>,
}
168
169impl<C, K> JweEncryptor<C, K> {
170    pub fn alg(self, alg: Algorithm) -> Self {
171        Self {
172            alg: Some(alg),
173            ..self
174        }
175    }
176
177    pub fn enc(self, enc: EncryptionAlgorithm) -> Self {
178        Self {
179            enc: Some(enc),
180            ..self
181        }
182    }
183
184    pub fn cek(self, cek: &[u8]) -> Self {
185        Self {
186            cek: Some(Bytes::copy_from_slice(cek)),
187            ..self
188        }
189    }
190
191    pub fn iv(self, iv: &[u8]) -> Self {
192        Self {
193            iv: Some(Bytes::copy_from_slice(iv)),
194            ..self
195        }
196    }
197}
198
199impl<K> JweEncryptor<Unset, K> {
200    pub fn plaintext(self, plaintext: &[u8]) -> JweEncryptor<Set, K> {
201        JweEncryptor::<Set, K> {
202            plaintext: Some(Bytes::copy_from_slice(plaintext)),
203            alg: self.alg,
204            enc: self.enc,
205            kid: self.kid,
206            cek: self.cek,
207            iv: self.iv,
208            phantom: PhantomData,
209        }
210    }
211
212    pub fn plaintext_str(self, plaintext: impl AsRef<str>) -> JweEncryptor<Set, K> {
213        JweEncryptor::plaintext(self, plaintext.as_ref().as_bytes())
214    }
215}
216
217impl<C> JweEncryptor<C, Unset> {
218    pub fn kid(self, kid: impl Into<String>) -> JweEncryptor<C, Set> {
219        JweEncryptor::<C, Set> {
220            kid: Some(kid.into()),
221            alg: self.alg,
222            enc: self.enc,
223            cek: self.cek,
224            iv: self.iv,
225            plaintext: self.plaintext,
226            phantom: PhantomData,
227        }
228    }
229}
230
231impl JweEncryptor<Set, Set> {
232    pub async fn encrypt<F>(self, wrap_key: F) -> Result<Jwe>
233    where
234        F: AsyncFn(&str, &Algorithm, &[u8]) -> Result<WrapKeyResult>,
235    {
236        // Determine how big the CEK should be.
237        let enc = &self.enc.unwrap_or(EncryptionAlgorithm::A128GCM);
238        let cipher: Cipher = enc.try_into()?;
239
240        // Validate or generate the CEK.
241        let cek = match self.cek {
242            Some(v) if v.len() == cipher.key_len() => v,
243            Some(v) => {
244                return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
245                    format!(
246                        "require key size of {} bytes, got {}",
247                        cipher.key_len(),
248                        v.len()
249                    )
250                }));
251            }
252            None => {
253                // Allocate enough space for largest supported cipher.
254                let mut buf = [0; 32];
255                rand::rand_bytes(&mut buf)?;
256                Bytes::copy_from_slice(&buf[0..cipher.key_len()])
257            }
258        };
259
260        let kid = self
261            .kid
262            .as_deref()
263            .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected kid"))?;
264        let alg = self.alg.unwrap_or(Algorithm::RSA_OAEP);
265
266        // Encrypt the CEK so we get the full kid.
267        let result = wrap_key(kid, &alg, &cek).await?;
268
269        let header = Header {
270            alg,
271            enc: Some(enc.clone()),
272            kid: Some(result.kid),
273            typ: super::Type::JWE,
274        };
275        let aad = header.encode()?;
276
277        // Generate the IV.
278        let iv_len = cipher.iv_len().ok_or_else(|| {
279            Error::with_message(
280                ErrorKind::InvalidData,
281                format!("expected iv length for cipher {}", &enc),
282            )
283        })?;
284        let iv = match self.iv {
285            Some(v) if v.len() == iv_len => v,
286            Some(v) => {
287                return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
288                    format!("require iv size of {} bytes, got {}", iv_len, v.len())
289                }));
290            }
291            None => {
292                // Allocate enough space for largest supported cipher.
293                let mut buf = [0; 12];
294                rand::rand_bytes(&mut buf)?;
295                Bytes::copy_from_slice(&buf[0..iv_len])
296            }
297        };
298
299        let plaintext = self.plaintext.expect("expected plaintext");
300        let mut tag = [0; 16];
301        let ciphertext: Bytes = symm::encrypt_aead(
302            cipher,
303            &cek,
304            Some(&iv),
305            aad.as_bytes(),
306            &plaintext,
307            &mut tag,
308        )?
309        .into();
310
311        Ok(Jwe {
312            header,
313            cek: result.cek,
314            iv,
315            ciphertext,
316            tag: Bytes::copy_from_slice(&tag),
317        })
318    }
319}
320
// Implemented by hand: `#[derive(Default)]` would add `Default` bounds on the
// marker types `C` and `K`.
impl<C, K> Default for JweEncryptor<C, K> {
    /// Returns an encryptor with every option unset.
    fn default() -> Self {
        Self {
            alg: None,
            enc: None,
            kid: None,
            cek: None,
            iv: None,
            plaintext: None,
            phantom: PhantomData,
        }
    }
}
334
335impl TryFrom<EncryptionAlgorithm> for Cipher {
336    type Error = Error;
337    fn try_from(value: EncryptionAlgorithm) -> Result<Self> {
338        (&value).try_into()
339    }
340}
341
342impl TryFrom<&EncryptionAlgorithm> for Cipher {
343    type Error = Error;
344    fn try_from(value: &EncryptionAlgorithm) -> Result<Cipher> {
345        match value {
346            EncryptionAlgorithm::A128GCM => Ok(Cipher::aes_128_gcm()),
347            EncryptionAlgorithm::A192GCM => Ok(Cipher::aes_192_gcm()),
348            EncryptionAlgorithm::A256GCM => Ok(Cipher::aes_256_gcm()),
349            EncryptionAlgorithm::Other(value) => {
350                Err(Error::with_message_fn(ErrorKind::InvalidData, || {
351                    format!("unsupported encryption algorithm {value}")
352                }))
353            }
354        }
355    }
356}
357
358impl TryFrom<&Algorithm> for azure_security_keyvault_keys::models::EncryptionAlgorithm {
359    type Error = Error;
360    fn try_from(value: &Algorithm) -> Result<Self> {
361        match value {
362            Algorithm::RSA1_5 => Ok(Self::RSA1_5),
363            Algorithm::RSA_OAEP => Ok(Self::RsaOaep),
364            Algorithm::RSA_OAEP_256 => Ok(Self::RsaOAEP256),
365            Algorithm::Other(s) => Err(Error::with_message_fn(ErrorKind::InvalidData, || {
366                format!("unsupported algorithm {s}")
367            })),
368        }
369    }
370}
371
/// Result returned by the key callbacks passed to `JweEncryptor::encrypt` and
/// `Jwe::decrypt`.
#[derive(Debug)]
pub struct WrapKeyResult {
    /// ID of the key that performed the operation; `encrypt` writes it to the JWE header.
    pub kid: String,
    /// Output of the key operation: `encrypt` stores it as the JWE encrypted key,
    /// while `decrypt` uses it as the content decryption key.
    pub cek: Bytes,
}
377
378impl TryFrom<KeyOperationResult> for WrapKeyResult {
379    type Error = Error;
380    fn try_from(value: KeyOperationResult) -> Result<Self> {
381        Ok(Self {
382            kid: value
383                .kid
384                .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected kid"))?,
385            cek: value
386                .result
387                .map(Into::into)
388                .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected CEK"))?,
389        })
390    }
391}
392
#[cfg(test)]
mod tests {
    use super::*;
    use azure_core::Bytes;

    /// Message reported whenever the input does not consist of exactly five parts.
    const PARTS_ERROR: &str = "JWE must have exactly 5 parts separated by periods";

    /// Compact serialization of the JWE built by `sample_jwe`.
    // cspell:disable-next-line
    const COMPACT: &str = "eyJhbGciOiJSU0EtT0FFUC0yNTYiLCJlbmMiOiJBMTI4R0NNIiwia2lkIjoidGVzdC1rZXktaWQiLCJ0eXAiOiJKV0UifQ.EjRWeA.mrze8A.ASNFZw.iavN7w";

    const CEK: [u8; 4] = [0x12, 0x34, 0x56, 0x78];
    const IV: [u8; 4] = [0x9a, 0xbc, 0xde, 0xf0];
    const CIPHERTEXT: [u8; 4] = [0x01, 0x23, 0x45, 0x67];
    const TAG: [u8; 4] = [0x89, 0xab, 0xcd, 0xef];

    /// Builds the JWE whose compact serialization is `COMPACT`.
    fn sample_jwe() -> Jwe {
        Jwe {
            header: Header {
                alg: Algorithm::RSA_OAEP_256,
                enc: Some(EncryptionAlgorithm::A128GCM),
                kid: Some("test-key-id".to_string()),
                typ: Type::JWE,
            },
            cek: Bytes::from_static(&CEK),
            iv: Bytes::from_static(&IV),
            ciphertext: Bytes::from_static(&CIPHERTEXT),
            tag: Bytes::from_static(&TAG),
        }
    }

    /// Asserts that parsing `input` fails with `InvalidData` and the given message.
    fn assert_from_str_fails(input: &str, message: &str) {
        let err = Jwe::from_str(input).unwrap_err();
        assert!(matches!(err.kind(), ErrorKind::InvalidData));
        assert_eq!(err.message(), Some(message));
    }

    #[test]
    fn decode_invalid() {
        for input in ["1.2.3.4", "1.2.3.4.5.6"] {
            assert!(matches!(Jwe::decode(input), Err(err) if err.message() == Some(PARTS_ERROR)));
        }
    }

    #[test]
    fn encode_decode_roundtrip() {
        let jwe = sample_jwe();

        let encoded = jwe.encode().expect("encode should succeed");
        assert_eq!(encoded, COMPACT);

        let decoded = Jwe::decode(&encoded).expect("decode should succeed");
        assert_eq!(decoded.header.alg, Algorithm::RSA_OAEP_256);
        assert_eq!(decoded.header.enc, Some(EncryptionAlgorithm::A128GCM));
        assert_eq!(decoded.header.kid, Some("test-key-id".to_string()));
        assert_eq!(decoded.header.typ, Type::JWE);
        assert_eq!(decoded.cek, Bytes::from_static(&CEK));
        assert_eq!(decoded.iv, Bytes::from_static(&IV));
        assert_eq!(decoded.ciphertext, Bytes::from_static(&CIPHERTEXT));
        assert_eq!(decoded.tag, Bytes::from_static(&TAG));
    }

    #[test]
    fn from_str_success() {
        let jwe = Jwe::from_str(COMPACT).expect("should parse valid JWE");
        assert_eq!(jwe.header.alg, Algorithm::RSA_OAEP_256);
        assert_eq!(jwe.header.enc, Some(EncryptionAlgorithm::A128GCM));
        assert_eq!(jwe.header.kid, Some("test-key-id".to_string()));
        assert_eq!(jwe.header.typ, Type::JWE);
    }

    #[test]
    fn from_str_invalid_character() {
        // Same as `COMPACT` but with an invalid character ('!') in the cek part.
        // cspell:disable-next-line
        let s = "eyJhbGciOiJSU0EtT0FFUC0yNTYiLCJlbmMiOiJBMTI4R0NNIiwia2lkIjoidGVzdC1rZXktaWQiLCJ0eXAiOiJKV0UifQ.EjRW!eA.mrze8A.ASNFZw.iavN7w";
        assert_from_str_fails(s, "invalid character in JWE compact serialization");
    }

    #[test]
    fn from_str_too_few_periods() {
        // Only 3 periods (4 parts).
        assert_from_str_fails("a.b.c.d", PARTS_ERROR);
    }

    #[test]
    fn from_str_too_many_periods() {
        // 5 periods (6 parts).
        assert_from_str_fails("a.b.c.d.e.f", PARTS_ERROR);
    }

    #[test]
    fn from_str_invalid_header() {
        // Valid base64url, but not a valid header.
        // cspell:disable-next-line
        assert_from_str_fails("Zm9vYmFy.EjRWeA.mrze8A.ASNFZw.iavN7w", "invalid header");
    }

    #[test]
    fn encryption_algorithm_cipher() {
        let expected_key_lens = [
            (EncryptionAlgorithm::A128GCM, 16),
            (EncryptionAlgorithm::A192GCM, 24),
            (EncryptionAlgorithm::A256GCM, 32),
        ];

        for (enc, key_len) in expected_key_lens {
            let cipher: Cipher = enc.try_into().expect("try_into should succeed");
            assert_eq!(cipher.iv_len(), Some(12));
            assert_eq!(cipher.key_len(), key_len);
        }
    }

    #[tokio::test]
    async fn encrypt_decrypt_roundtrip() {
        const VERSIONED_KID: &str = "key-name/key-version";

        let kid = "key-name";
        let alg = Algorithm::RSA_OAEP;
        let enc = EncryptionAlgorithm::A128GCM;
        let plaintext = b"Hello, world!";

        // wrap_key callback: asserts kid and alg, returns the cek as-is under the versioned kid.
        let wrap_key = async |key_id: &str, wrap_alg: &Algorithm, cek: &[u8]| {
            assert_eq!(key_id, kid);
            assert_eq!(wrap_alg, &alg);
            Ok(WrapKeyResult {
                kid: VERSIONED_KID.into(),
                cek: Bytes::copy_from_slice(cek),
            })
        };

        // unwrap_key callback: asserts the versioned kid and alg, returns the cek as-is.
        let unwrap_key = async |key_id: &str, wrap_alg: &Algorithm, cek: &[u8]| {
            assert_eq!(key_id, VERSIONED_KID);
            assert_eq!(wrap_alg, &alg);
            Ok(WrapKeyResult {
                kid: VERSIONED_KID.into(),
                cek: Bytes::copy_from_slice(cek),
            })
        };

        let jwe = Jwe::encryptor()
            .alg(alg.clone())
            .enc(enc)
            .kid(kid)
            .plaintext(plaintext)
            .encrypt(wrap_key)
            .await
            .expect("encryption should succeed");

        let decrypted = jwe
            .decrypt(unwrap_key)
            .await
            .expect("decryption should succeed");
        assert_eq!(decrypted, plaintext.as_ref());
    }
}