akv_cli/jose/
jwe.rs

1// Copyright 2025 Heath Stewart.
2// Licensed under the MIT License. See LICENSE.txt in the project root for license information.
3
4//! JSON Web Encryption types.
5
6use crate::{
7    jose::{Algorithm, Encode, EncryptionAlgorithm, Header, Set, Type, Unset},
8    Error, ErrorKind, Result, ResultExt as _,
9};
10use azure_core::{base64, Bytes};
11use azure_security_keyvault_keys::models::KeyOperationResult;
12use openssl::{
13    rand,
14    symm::{self, Cipher},
15};
16use std::{marker::PhantomData, str::FromStr};
17
18/// A JSON Web Encryption (JWE) structure.
19#[derive(Debug)]
20pub struct Jwe {
21    header: Header,
22    cek: Bytes,
23    iv: Bytes,
24    ciphertext: Bytes,
25    tag: Bytes,
26}
27
28impl Jwe {
29    /// Gets a JWE encryptor.
30    pub fn encryptor() -> JweEncryptor<Unset, Unset> {
31        JweEncryptor::default()
32    }
33
34    /// Decrypts a JWE.
35    pub async fn decrypt<F>(self, unwrap_key: F) -> Result<Bytes>
36    where
37        F: AsyncFn(&str, &Algorithm, &[u8]) -> Result<WrapKeyResult>,
38    {
39        if self.header.typ != Type::JWE {
40            return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
41                format!("expected JWE, got {}", self.header.typ)
42            }));
43        }
44
45        // Decrypt the CEK.
46        let key_id = self
47            .header
48            .kid
49            .as_deref()
50            .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected kid"))?;
51        let result = unwrap_key(key_id, &self.header.alg, &self.cek).await?;
52
53        let enc = self
54            .header
55            .enc
56            .as_ref()
57            .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected enc"))?;
58        let cipher: Cipher = enc.try_into()?;
59        let aad = self.header.encode()?;
60
61        let plaintext: Bytes = symm::decrypt_aead(
62            cipher,
63            &result.cek,
64            Some(&self.iv),
65            aad.as_bytes(),
66            &self.ciphertext,
67            &self.tag,
68        )?
69        .into();
70
71        Ok(plaintext)
72    }
73
74    /// Gets the key identifier.
75    pub fn kid(&self) -> Option<&str> {
76        self.header.kid.as_deref()
77    }
78}
79
80impl Encode for Jwe {
81    fn decode(value: &str) -> Result<Self> {
82        value.parse()
83    }
84
85    fn encode(&self) -> Result<String> {
86        Ok([
87            self.header.encode()?,
88            base64::encode_url_safe(&self.cek),
89            base64::encode_url_safe(&self.iv),
90            base64::encode_url_safe(&self.ciphertext),
91            base64::encode_url_safe(&self.tag),
92        ]
93        .join("."))
94    }
95}
96
97impl FromStr for Jwe {
98    type Err = Error;
99    fn from_str(s: &str) -> Result<Self> {
100        const PARTS_ERROR: &str = "JWE must have exactly 5 parts separated by periods";
101
102        fn is_base64url_char(c: char) -> bool {
103            c.is_ascii_alphanumeric() || c == '-' || c == '_'
104        }
105
106        let mut parts = [0usize; 6];
107        let mut current_part_start = 0;
108        for (i, c) in s.char_indices() {
109            if c == '.' {
110                if current_part_start >= 5 {
111                    return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
112                        PARTS_ERROR
113                    }));
114                }
115
116                parts[current_part_start + 1] = i + 1;
117                current_part_start += 1;
118            } else if !is_base64url_char(c) {
119                return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
120                    "invalid character in JWE compact serialization"
121                }));
122            }
123        }
124
125        if current_part_start != 4 {
126            return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
127                PARTS_ERROR
128            }));
129        }
130
131        parts[5] = s.len() + 1;
132        let header = &s[parts[0]..parts[1] - 1];
133        let cek = &s[parts[1]..parts[2] - 1];
134        let iv = &s[parts[2]..parts[3] - 1];
135        let ciphertext = &s[parts[3]..parts[4] - 1];
136        let tag = &s[parts[4]..parts[5] - 1];
137
138        let header =
139            Header::decode(header).with_context_fn(ErrorKind::InvalidData, || "invalid header")?;
140        let cek = base64::decode_url_safe(cek)
141            .with_context_fn(ErrorKind::InvalidData, || "invalid cek")?
142            .into();
143        let iv = base64::decode_url_safe(iv)
144            .with_context_fn(ErrorKind::InvalidData, || "invalid iv")?
145            .into();
146        let ciphertext = base64::decode_url_safe(ciphertext)
147            .with_context_fn(ErrorKind::InvalidData, || "invalid ciphertext")?
148            .into();
149        let tag = base64::decode_url_safe(tag)
150            .with_context_fn(ErrorKind::InvalidData, || "invalid tag")?
151            .into();
152
153        Ok(Jwe {
154            header,
155            cek,
156            iv,
157            ciphertext,
158            tag,
159        })
160    }
161}
162
163/// A JWE encryptor.
164///
165/// Only JWEs with key identifiers are supported, specifically from Key Vault.
166#[derive(Debug)]
167pub struct JweEncryptor<C, K> {
168    alg: Option<Algorithm>,
169    enc: Option<EncryptionAlgorithm>,
170    kid: Option<String>,
171    cek: Option<Bytes>,
172    iv: Option<Bytes>,
173    plaintext: Option<Bytes>,
174    phantom: PhantomData<(C, K)>,
175}
176
177impl<C, K> JweEncryptor<C, K> {
178    /// Sets the JWE [`Algorithm`].
179    pub fn alg(self, alg: Algorithm) -> Self {
180        Self {
181            alg: Some(alg),
182            ..self
183        }
184    }
185
186    /// Sets the JWE [`EncryptionAlgorithm`].
187    pub fn enc(self, enc: EncryptionAlgorithm) -> Self {
188        Self {
189            enc: Some(enc),
190            ..self
191        }
192    }
193
194    /// Sets the JWE content encryption key.
195    pub fn cek(self, cek: &[u8]) -> Self {
196        Self {
197            cek: Some(Bytes::copy_from_slice(cek)),
198            ..self
199        }
200    }
201
202    /// Sets the JWE initialization vector.
203    pub fn iv(self, iv: &[u8]) -> Self {
204        Self {
205            iv: Some(Bytes::copy_from_slice(iv)),
206            ..self
207        }
208    }
209}
210
211impl<K> JweEncryptor<Unset, K> {
212    /// Sets the plaintext data.
213    pub fn plaintext(self, plaintext: &[u8]) -> JweEncryptor<Set, K> {
214        JweEncryptor::<Set, K> {
215            plaintext: Some(Bytes::copy_from_slice(plaintext)),
216            alg: self.alg,
217            enc: self.enc,
218            kid: self.kid,
219            cek: self.cek,
220            iv: self.iv,
221            phantom: PhantomData,
222        }
223    }
224
225    /// Sets the plaintext encoded as a string.
226    pub fn plaintext_str(self, plaintext: impl AsRef<str>) -> JweEncryptor<Set, K> {
227        JweEncryptor::plaintext(self, plaintext.as_ref().as_bytes())
228    }
229}
230
231impl<C> JweEncryptor<C, Unset> {
232    /// Sets the JWE key identifier.
233    pub fn kid(self, kid: impl Into<String>) -> JweEncryptor<C, Set> {
234        JweEncryptor::<C, Set> {
235            kid: Some(kid.into()),
236            alg: self.alg,
237            enc: self.enc,
238            cek: self.cek,
239            iv: self.iv,
240            plaintext: self.plaintext,
241            phantom: PhantomData,
242        }
243    }
244}
245
246impl JweEncryptor<Set, Set> {
247    /// Encrypts the JWE.
248    pub async fn encrypt<F>(self, wrap_key: F) -> Result<Jwe>
249    where
250        F: AsyncFn(&str, &Algorithm, &[u8]) -> Result<WrapKeyResult>,
251    {
252        // Determine how big the CEK should be.
253        let enc = &self.enc.unwrap_or(EncryptionAlgorithm::A128GCM);
254        let cipher: Cipher = enc.try_into()?;
255
256        // Validate or generate the CEK.
257        let cek = match self.cek {
258            Some(v) if v.len() == cipher.key_len() => v,
259            Some(v) => {
260                return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
261                    format!(
262                        "require key size of {} bytes, got {}",
263                        cipher.key_len(),
264                        v.len()
265                    )
266                }));
267            }
268            None => {
269                // Allocate enough space for largest supported cipher.
270                let mut buf = [0; 32];
271                rand::rand_bytes(&mut buf)?;
272                Bytes::copy_from_slice(&buf[0..cipher.key_len()])
273            }
274        };
275
276        let kid = self
277            .kid
278            .as_deref()
279            .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected kid"))?;
280        let alg = self.alg.unwrap_or(Algorithm::RSA_OAEP);
281
282        // Encrypt the CEK so we get the full kid.
283        let result = wrap_key(kid, &alg, &cek).await?;
284
285        let header = Header {
286            alg,
287            enc: Some(enc.clone()),
288            kid: Some(result.kid),
289            typ: super::Type::JWE,
290        };
291        let aad = header.encode()?;
292
293        // Generate the IV.
294        let iv_len = cipher.iv_len().ok_or_else(|| {
295            Error::with_message(
296                ErrorKind::InvalidData,
297                format!("expected iv length for cipher {}", &enc),
298            )
299        })?;
300        let iv = match self.iv {
301            Some(v) if v.len() == iv_len => v,
302            Some(v) => {
303                return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
304                    format!("require iv size of {} bytes, got {}", iv_len, v.len())
305                }));
306            }
307            None => {
308                // Allocate enough space for largest supported cipher.
309                let mut buf = [0; 12];
310                rand::rand_bytes(&mut buf)?;
311                Bytes::copy_from_slice(&buf[0..iv_len])
312            }
313        };
314
315        let plaintext = self.plaintext.expect("expected plaintext");
316        let mut tag = [0; 16];
317        let ciphertext: Bytes = symm::encrypt_aead(
318            cipher,
319            &cek,
320            Some(&iv),
321            aad.as_bytes(),
322            &plaintext,
323            &mut tag,
324        )?
325        .into();
326
327        Ok(Jwe {
328            header,
329            cek: result.cek,
330            iv,
331            ciphertext,
332            tag: Bytes::copy_from_slice(&tag),
333        })
334    }
335}
336
337impl<C, K> Default for JweEncryptor<C, K> {
338    fn default() -> Self {
339        Self {
340            alg: None,
341            enc: None,
342            kid: None,
343            cek: None,
344            iv: None,
345            plaintext: None,
346            phantom: PhantomData,
347        }
348    }
349}
350
351impl TryFrom<EncryptionAlgorithm> for Cipher {
352    type Error = Error;
353    fn try_from(value: EncryptionAlgorithm) -> Result<Self> {
354        (&value).try_into()
355    }
356}
357
358impl TryFrom<&EncryptionAlgorithm> for Cipher {
359    type Error = Error;
360    fn try_from(value: &EncryptionAlgorithm) -> Result<Cipher> {
361        match value {
362            EncryptionAlgorithm::A128GCM => Ok(Cipher::aes_128_gcm()),
363            EncryptionAlgorithm::A192GCM => Ok(Cipher::aes_192_gcm()),
364            EncryptionAlgorithm::A256GCM => Ok(Cipher::aes_256_gcm()),
365            EncryptionAlgorithm::Other(value) => {
366                Err(Error::with_message_fn(ErrorKind::InvalidData, || {
367                    format!("unsupported encryption algorithm {value}")
368                }))
369            }
370        }
371    }
372}
373
374impl TryFrom<&Algorithm> for azure_security_keyvault_keys::models::EncryptionAlgorithm {
375    type Error = Error;
376    fn try_from(value: &Algorithm) -> Result<Self> {
377        match value {
378            Algorithm::RSA1_5 => Ok(Self::Rsa1_5),
379            Algorithm::RSA_OAEP => Ok(Self::RsaOaep),
380            Algorithm::RSA_OAEP_256 => Ok(Self::RsaOaep256),
381            Algorithm::Other(s) => Err(Error::with_message_fn(ErrorKind::InvalidData, || {
382                format!("unsupported algorithm {s}")
383            })),
384        }
385    }
386}
387
388/// Result for a key wrap operation.
389#[derive(Debug)]
390pub struct WrapKeyResult {
391    /// The key identifier.
392    pub kid: String,
393
394    /// The content encryption key.
395    pub cek: Bytes,
396}
397
398impl TryFrom<KeyOperationResult> for WrapKeyResult {
399    type Error = Error;
400    fn try_from(value: KeyOperationResult) -> Result<Self> {
401        Ok(Self {
402            kid: value
403                .kid
404                .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected kid"))?,
405            cek: value
406                .result
407                .map(Into::into)
408                .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected CEK"))?,
409        })
410    }
411}
412
#[cfg(test)]
mod tests {
    use super::*;
    use azure_core::Bytes;

    const PARTS_ERROR: &str = "JWE must have exactly 5 parts separated by periods";

    // Compact serialization of `sample_jwe()`.
    // cspell:disable-next-line
    const SAMPLE_COMPACT: &str = "eyJhbGciOiJSU0EtT0FFUC0yNTYiLCJlbmMiOiJBMTI4R0NNIiwia2lkIjoidGVzdC1rZXktaWQiLCJ0eXAiOiJKV0UifQ.EjRWeA.mrze8A.ASNFZw.iavN7w";

    /// Builds a JWE with fixed, recognizable field values.
    fn sample_jwe() -> Jwe {
        Jwe {
            header: Header {
                alg: Algorithm::RSA_OAEP_256,
                enc: Some(EncryptionAlgorithm::A128GCM),
                kid: Some("test-key-id".to_string()),
                typ: Type::JWE,
            },
            cek: Bytes::from_static(&[0x12, 0x34, 0x56, 0x78]),
            iv: Bytes::from_static(&[0x9a, 0xbc, 0xde, 0xf0]),
            ciphertext: Bytes::from_static(&[0x01, 0x23, 0x45, 0x67]),
            tag: Bytes::from_static(&[0x89, 0xab, 0xcd, 0xef]),
        }
    }

    /// Asserts that `err` is `InvalidData` with exactly `message`.
    fn assert_invalid_data(err: Error, message: &str) {
        assert!(matches!(err.kind(), ErrorKind::InvalidData));
        assert_eq!(err.message(), Some(message));
    }

    #[test]
    fn decode_invalid() {
        for input in ["1.2.3.4", "1.2.3.4.5.6"] {
            assert!(
                matches!(Jwe::decode(input), Err(err) if err.message() == Some(PARTS_ERROR)),
                "unexpected result for {input:?}"
            );
        }
    }

    #[test]
    fn encode_decode_roundtrip() {
        let expected = sample_jwe();

        let encoded = expected.encode().expect("encode should succeed");
        assert_eq!(encoded, SAMPLE_COMPACT);

        let decoded = Jwe::decode(&encoded).expect("decode should succeed");
        assert_eq!(decoded.header.alg, expected.header.alg);
        assert_eq!(decoded.header.enc, expected.header.enc);
        assert_eq!(decoded.header.kid, expected.header.kid);
        assert_eq!(decoded.header.typ, expected.header.typ);
        assert_eq!(decoded.cek, expected.cek);
        assert_eq!(decoded.iv, expected.iv);
        assert_eq!(decoded.ciphertext, expected.ciphertext);
        assert_eq!(decoded.tag, expected.tag);
    }

    #[test]
    fn from_str_success() {
        let jwe = Jwe::from_str(SAMPLE_COMPACT).expect("should parse valid JWE");
        assert_eq!(jwe.header.alg, Algorithm::RSA_OAEP_256);
        assert_eq!(jwe.header.enc, Some(EncryptionAlgorithm::A128GCM));
        assert_eq!(jwe.header.kid, Some("test-key-id".to_string()));
        assert_eq!(jwe.header.typ, Type::JWE);
    }

    #[test]
    fn from_str_invalid_character() {
        // An invalid character ('!') inserted into the cek part.
        // cspell:disable-next-line
        let s = "eyJhbGciOiJSU0EtT0FFUC0yNTYiLCJlbmMiOiJBMTI4R0NNIiwia2lkIjoidGVzdC1rZXktaWQiLCJ0eXAiOiJKV0UifQ.EjRW!eA.mrze8A.ASNFZw.iavN7w";
        let err = Jwe::from_str(s).unwrap_err();
        assert_invalid_data(err, "invalid character in JWE compact serialization");
    }

    #[test]
    fn from_str_too_few_periods() {
        // Only 3 periods (4 parts).
        let err = Jwe::from_str("a.b.c.d").unwrap_err();
        assert_invalid_data(err, PARTS_ERROR);
    }

    #[test]
    fn from_str_too_many_periods() {
        // 5 periods (6 parts).
        let err = Jwe::from_str("a.b.c.d.e.f").unwrap_err();
        assert_invalid_data(err, PARTS_ERROR);
    }

    #[test]
    fn from_str_invalid_header() {
        // Valid base64url, but not a valid header.
        // cspell:disable-next-line
        let err = Jwe::from_str("Zm9vYmFy.EjRWeA.mrze8A.ASNFZw.iavN7w").unwrap_err();
        assert_invalid_data(err, "invalid header");
    }

    #[test]
    fn encryption_algorithm_cipher() {
        for (enc, key_len) in [
            (EncryptionAlgorithm::A128GCM, 16),
            (EncryptionAlgorithm::A192GCM, 24),
            (EncryptionAlgorithm::A256GCM, 32),
        ] {
            let cipher: Cipher = enc.try_into().expect("try_into should succeed");
            assert_eq!(cipher.iv_len(), Some(12));
            assert_eq!(cipher.key_len(), key_len);
        }
    }

    /// Encrypts then decrypts with `enc`. The key-wrap stand-ins assert the key
    /// identifier and wrap algorithm, and return the CEK unchanged.
    async fn assert_roundtrip(enc: EncryptionAlgorithm) {
        let kid = "key-name";
        let alg = Algorithm::RSA_OAEP;
        let plaintext = b"Hello, world!";

        // wrap_key stand-in: asserts kid and alg, returns the CEK as-is.
        let wrap_key = async |key_id: &str, wrap_alg: &Algorithm, cek: &[u8]| {
            assert_eq!(key_id, kid);
            assert_eq!(wrap_alg, &alg);
            Ok(WrapKeyResult {
                kid: "key-name/key-version".into(),
                cek: Bytes::copy_from_slice(cek),
            })
        };

        // unwrap_key stand-in: asserts the full kid and alg, returns the CEK as-is.
        let unwrap_key = async |key_id: &str, wrap_alg: &Algorithm, cek: &[u8]| {
            assert_eq!(key_id, "key-name/key-version");
            assert_eq!(wrap_alg, &alg);
            Ok(WrapKeyResult {
                kid: "key-name/key-version".into(),
                cek: Bytes::copy_from_slice(cek),
            })
        };

        let jwe = Jwe::encryptor()
            .alg(alg.clone())
            .enc(enc)
            .kid(kid)
            .plaintext(plaintext)
            .encrypt(wrap_key)
            .await
            .expect("encryption should succeed");

        let decrypted = jwe
            .decrypt(unwrap_key)
            .await
            .expect("decryption should succeed");
        assert_eq!(decrypted, plaintext.as_ref());
    }

    #[tokio::test]
    async fn encrypt_decrypt_roundtrip() {
        for enc in [
            EncryptionAlgorithm::A128GCM,
            EncryptionAlgorithm::A192GCM,
            EncryptionAlgorithm::A256GCM,
        ] {
            assert_roundtrip(enc).await;
        }
    }

    #[tokio::test]
    async fn decrypt_rejects_tampered_ciphertext() {
        let wrap_key = async |_key_id: &str, _wrap_alg: &Algorithm, cek: &[u8]| {
            Ok(WrapKeyResult {
                kid: "key-name/key-version".into(),
                cek: Bytes::copy_from_slice(cek),
            })
        };
        let unwrap_key = async |_key_id: &str, _wrap_alg: &Algorithm, cek: &[u8]| {
            Ok(WrapKeyResult {
                kid: "key-name/key-version".into(),
                cek: Bytes::copy_from_slice(cek),
            })
        };

        let mut jwe = Jwe::encryptor()
            .kid("key-name")
            .plaintext(b"Hello, world!")
            .encrypt(wrap_key)
            .await
            .expect("encryption should succeed");

        // Flip one bit of the ciphertext; GCM authentication must catch it.
        let mut ciphertext = jwe.ciphertext.to_vec();
        ciphertext[0] ^= 0x01;
        jwe.ciphertext = ciphertext.into();

        jwe.decrypt(unwrap_key)
            .await
            .expect_err("decrypting tampered ciphertext should fail");
    }

    #[tokio::test]
    async fn encrypt_rejects_wrong_cek_size() {
        let wrap_key = async |_key_id: &str, _wrap_alg: &Algorithm, cek: &[u8]| {
            Ok(WrapKeyResult {
                kid: "key-name/key-version".into(),
                cek: Bytes::copy_from_slice(cek),
            })
        };

        let err = Jwe::encryptor()
            .enc(EncryptionAlgorithm::A128GCM)
            .cek(&[0; 4])
            .kid("key-name")
            .plaintext(b"Hello, world!")
            .encrypt(wrap_key)
            .await
            .expect_err("a 4-byte CEK is invalid for A128GCM");
        assert_invalid_data(err, "require key size of 16 bytes, got 4");
    }

    #[tokio::test]
    async fn encrypt_rejects_wrong_iv_size() {
        let wrap_key = async |_key_id: &str, _wrap_alg: &Algorithm, cek: &[u8]| {
            Ok(WrapKeyResult {
                kid: "key-name/key-version".into(),
                cek: Bytes::copy_from_slice(cek),
            })
        };

        let err = Jwe::encryptor()
            .enc(EncryptionAlgorithm::A128GCM)
            .iv(&[0; 4])
            .kid("key-name")
            .plaintext(b"Hello, world!")
            .encrypt(wrap_key)
            .await
            .expect_err("a 4-byte IV is invalid for A128GCM");
        assert_invalid_data(err, "require iv size of 12 bytes, got 4");
    }
}