akv_cli/jose/jwe.rs

1// Copyright 2025 Heath Stewart.
2// Licensed under the MIT License. See LICENSE.txt in the project root for license information.
3
4use crate::{
5    jose::{Algorithm, Encode, EncryptionAlgorithm, Header, Set, Type, Unset},
6    Error, ErrorKind, Result,
7};
8use azure_core::{base64, Bytes};
9use azure_security_keyvault_keys::models::KeyOperationResult;
10use openssl::{
11    rand,
12    symm::{self, Cipher},
13};
14use std::marker::PhantomData;
15
16/// A JSON Web Encryption (JWE) structure.
17#[derive(Debug)]
18pub struct Jwe {
19    header: Header,
20    cek: Bytes,
21    iv: Bytes,
22    ciphertext: Bytes,
23    tag: Bytes,
24}
25
26impl Jwe {
27    pub fn encryptor() -> JweEncryptor<Unset, Unset> {
28        JweEncryptor::default()
29    }
30
31    pub async fn decrypt<F>(self, unwrap_key: F) -> Result<Bytes>
32    where
33        F: AsyncFn(&str, &Algorithm, &[u8]) -> Result<WrapKeyResult>,
34    {
35        if self.header.typ != Type::JWE {
36            return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
37                format!("expected JWE, got {}", self.header.typ)
38            }));
39        }
40
41        // Decrypt the CEK.
42        let key_id = self
43            .header
44            .kid
45            .as_deref()
46            .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected kid"))?;
47        let result = unwrap_key(key_id, &self.header.alg, &self.cek).await?;
48
49        let enc = self
50            .header
51            .enc
52            .as_ref()
53            .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected enc"))?;
54        let cipher: Cipher = enc.try_into()?;
55        let aad = self.header.encode()?;
56
57        let plaintext: Bytes = symm::decrypt_aead(
58            cipher,
59            &result.cek,
60            Some(&self.iv),
61            aad.as_bytes(),
62            &self.ciphertext,
63            &self.tag,
64        )?
65        .into();
66
67        Ok(plaintext)
68    }
69
70    pub fn kid(&self) -> Option<&str> {
71        self.header.kid.as_deref()
72    }
73}
74
75impl Encode for Jwe {
76    fn decode(value: &str) -> Result<Self> {
77        let parts: Vec<_> = value.split(".").collect();
78        if parts.len() != 5 {
79            return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
80                format!("invalid compact JWE: expected 5 parts, got {}", parts.len())
81            }));
82        }
83
84        Ok(Self {
85            header: Header::decode(parts[0])?,
86            cek: base64::decode_url_safe(parts[1])?.into(),
87            iv: base64::decode_url_safe(parts[2])?.into(),
88            ciphertext: base64::decode_url_safe(parts[3])?.into(),
89            tag: base64::decode_url_safe(parts[4])?.into(),
90        })
91    }
92
93    fn encode(&self) -> Result<String> {
94        Ok([
95            self.header.encode()?,
96            base64::encode_url_safe(&self.cek),
97            base64::encode_url_safe(&self.iv),
98            base64::encode_url_safe(&self.ciphertext),
99            base64::encode_url_safe(&self.tag),
100        ]
101        .join("."))
102    }
103}
104
105#[derive(Debug)]
106pub struct JweEncryptor<C, K> {
107    alg: Option<Algorithm>,
108    enc: Option<EncryptionAlgorithm>,
109    kid: Option<String>,
110    cek: Option<Bytes>,
111    iv: Option<Bytes>,
112    plaintext: Option<Bytes>,
113    phantom: PhantomData<(C, K)>,
114}
115
116impl<C, K> JweEncryptor<C, K> {
117    pub fn alg(self, alg: Algorithm) -> Self {
118        Self {
119            alg: Some(alg),
120            ..self
121        }
122    }
123
124    pub fn enc(self, enc: EncryptionAlgorithm) -> Self {
125        Self {
126            enc: Some(enc),
127            ..self
128        }
129    }
130
131    pub fn cek(self, cek: &[u8]) -> Self {
132        Self {
133            cek: Some(Bytes::copy_from_slice(cek)),
134            ..self
135        }
136    }
137
138    pub fn iv(self, iv: &[u8]) -> Self {
139        Self {
140            iv: Some(Bytes::copy_from_slice(iv)),
141            ..self
142        }
143    }
144}
145
146impl<K> JweEncryptor<Unset, K> {
147    pub fn plaintext(self, plaintext: &[u8]) -> JweEncryptor<Set, K> {
148        JweEncryptor::<Set, K> {
149            plaintext: Some(Bytes::copy_from_slice(plaintext)),
150            alg: self.alg,
151            enc: self.enc,
152            kid: self.kid,
153            cek: self.cek,
154            iv: self.iv,
155            phantom: PhantomData,
156        }
157    }
158
159    pub fn plaintext_str(self, plaintext: impl AsRef<str>) -> JweEncryptor<Set, K> {
160        JweEncryptor::plaintext(self, plaintext.as_ref().as_bytes())
161    }
162}
163
164impl<C> JweEncryptor<C, Unset> {
165    pub fn kid(self, kid: impl Into<String>) -> JweEncryptor<C, Set> {
166        JweEncryptor::<C, Set> {
167            kid: Some(kid.into()),
168            alg: self.alg,
169            enc: self.enc,
170            cek: self.cek,
171            iv: self.iv,
172            plaintext: self.plaintext,
173            phantom: PhantomData,
174        }
175    }
176}
177
178impl JweEncryptor<Set, Set> {
179    pub async fn encrypt<F>(self, wrap_key: F) -> Result<Jwe>
180    where
181        F: AsyncFn(&str, &Algorithm, &[u8]) -> Result<WrapKeyResult>,
182    {
183        // Determine how big the CEK should be.
184        let enc = &self.enc.unwrap_or(EncryptionAlgorithm::A128GCM);
185        let cipher: Cipher = enc.try_into()?;
186
187        // Validate or generate the CEK.
188        let cek = match self.cek {
189            Some(v) if v.len() == cipher.key_len() => v,
190            Some(v) => {
191                return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
192                    format!(
193                        "require key size of {} bytes, got {}",
194                        cipher.key_len(),
195                        v.len()
196                    )
197                }));
198            }
199            None => {
200                // Allocate enough space for largest supported cipher.
201                let mut buf = [0; 32];
202                rand::rand_bytes(&mut buf)?;
203                Bytes::copy_from_slice(&buf[0..cipher.key_len()])
204            }
205        };
206
207        let kid = self
208            .kid
209            .as_deref()
210            .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected kid"))?;
211        let alg = self.alg.unwrap_or(Algorithm::RSA_OAEP);
212
213        // Encrypt the CEK so we get the full kid.
214        let result = wrap_key(kid, &alg, &cek).await?;
215
216        let header = Header {
217            alg,
218            enc: Some(enc.clone()),
219            kid: Some(result.kid),
220            typ: super::Type::JWE,
221        };
222        let aad = header.encode()?;
223
224        // Generate the IV.
225        let iv_len = cipher.iv_len().ok_or_else(|| {
226            Error::with_message(
227                ErrorKind::InvalidData,
228                format!("expected iv length for cipher {}", &enc),
229            )
230        })?;
231        let iv = match self.iv {
232            Some(v) if v.len() == iv_len => v,
233            Some(v) => {
234                return Err(Error::with_message_fn(ErrorKind::InvalidData, || {
235                    format!("require iv size of {} bytes, got {}", iv_len, v.len())
236                }));
237            }
238            None => {
239                // Allocate enough space for largest supported cipher.
240                let mut buf = [0; 12];
241                rand::rand_bytes(&mut buf)?;
242                Bytes::copy_from_slice(&buf[0..iv_len])
243            }
244        };
245
246        let plaintext = self.plaintext.expect("expected plaintext");
247        let mut tag = [0; 16];
248        let ciphertext: Bytes = symm::encrypt_aead(
249            cipher,
250            &cek,
251            Some(&iv),
252            aad.as_bytes(),
253            &plaintext,
254            &mut tag,
255        )?
256        .into();
257
258        Ok(Jwe {
259            header,
260            cek: result.cek,
261            iv,
262            ciphertext,
263            tag: Bytes::copy_from_slice(&tag),
264        })
265    }
266}
267
268impl<C, K> Default for JweEncryptor<C, K> {
269    fn default() -> Self {
270        Self {
271            alg: None,
272            enc: None,
273            kid: None,
274            cek: None,
275            iv: None,
276            plaintext: None,
277            phantom: PhantomData,
278        }
279    }
280}
281
282impl TryFrom<EncryptionAlgorithm> for Cipher {
283    type Error = Error;
284    fn try_from(value: EncryptionAlgorithm) -> Result<Self> {
285        (&value).try_into()
286    }
287}
288
289impl TryFrom<&EncryptionAlgorithm> for Cipher {
290    type Error = Error;
291    fn try_from(value: &EncryptionAlgorithm) -> Result<Cipher> {
292        match value {
293            EncryptionAlgorithm::A128GCM => Ok(Cipher::aes_128_gcm()),
294            EncryptionAlgorithm::A192GCM => Ok(Cipher::aes_192_gcm()),
295            EncryptionAlgorithm::A256GCM => Ok(Cipher::aes_256_gcm()),
296            EncryptionAlgorithm::Other(value) => {
297                Err(Error::with_message_fn(ErrorKind::InvalidData, || {
298                    format!("unsupported encryption algorithm {value}")
299                }))
300            }
301        }
302    }
303}
304
305impl TryFrom<&Algorithm> for azure_security_keyvault_keys::models::EncryptionAlgorithm {
306    type Error = Error;
307    fn try_from(value: &Algorithm) -> Result<Self> {
308        match value {
309            Algorithm::RSA1_5 => Ok(Self::RSA1_5),
310            Algorithm::RSA_OAEP => Ok(Self::RsaOaep),
311            Algorithm::RSA_OAEP_256 => Ok(Self::RsaOAEP256),
312            Algorithm::Other(s) => Err(Error::with_message_fn(ErrorKind::InvalidData, || {
313                format!("unsupported algorithm {s}")
314            })),
315        }
316    }
317}
318
319#[derive(Debug)]
320pub struct WrapKeyResult {
321    pub kid: String,
322    pub cek: Bytes,
323}
324
325impl TryFrom<KeyOperationResult> for WrapKeyResult {
326    type Error = Error;
327    fn try_from(value: KeyOperationResult) -> Result<Self> {
328        Ok(Self {
329            kid: value
330                .kid
331                .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected kid"))?,
332            cek: value
333                .result
334                .map(Into::into)
335                .ok_or_else(|| Error::with_message(ErrorKind::InvalidData, "expected CEK"))?,
336        })
337    }
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use azure_core::Bytes;

    #[test]
    fn decode_invalid() {
        // A compact JWE has exactly five segments; too few and too many must both be rejected
        // with a message reporting the count actually seen.
        let cases = [
            ("1.2.3.4", "invalid compact JWE: expected 5 parts, got 4"),
            ("1.2.3.4.5.6", "invalid compact JWE: expected 5 parts, got 6"),
        ];
        for (input, expected) in cases {
            let result = Jwe::decode(input);
            assert!(
                matches!(&result, Err(err) if err.message() == Some(expected)),
                "unexpected result for {input:?}: {result:?}"
            );
        }
    }

    #[test]
    fn encode_decode_roundtrip() {
        let cek = Bytes::from_static(&[0x12, 0x34, 0x56, 0x78]);
        let iv = Bytes::from_static(&[0x9a, 0xbc, 0xde, 0xf0]);
        let ciphertext = Bytes::from_static(&[0x01, 0x23, 0x45, 0x67]);
        let tag = Bytes::from_static(&[0x89, 0xab, 0xcd, 0xef]);

        let jwe = Jwe {
            header: Header {
                alg: Algorithm::RSA_OAEP_256,
                enc: Some(EncryptionAlgorithm::A128GCM),
                kid: Some("test-key-id".to_string()),
                typ: Type::JWE,
            },
            cek: cek.clone(),
            iv: iv.clone(),
            ciphertext: ciphertext.clone(),
            tag: tag.clone(),
        };

        // cspell:disable-next-line
        const EXPECTED: &str = "eyJhbGciOiJSU0EtT0FFUC0yNTYiLCJlbmMiOiJBMTI4R0NNIiwia2lkIjoidGVzdC1rZXktaWQiLCJ0eXAiOiJKV0UifQ.EjRWeA.mrze8A.ASNFZw.iavN7w";

        let encoded = jwe.encode().expect("encode should succeed");
        assert_eq!(encoded, EXPECTED);

        let decoded = Jwe::decode(&encoded).expect("decode should succeed");
        assert_eq!(decoded.header.alg, Algorithm::RSA_OAEP_256);
        assert_eq!(decoded.header.enc, Some(EncryptionAlgorithm::A128GCM));
        assert_eq!(decoded.header.kid, Some("test-key-id".to_string()));
        assert_eq!(decoded.header.typ, Type::JWE);
        assert_eq!(decoded.cek, cek);
        assert_eq!(decoded.iv, iv);
        assert_eq!(decoded.ciphertext, ciphertext);
        assert_eq!(decoded.tag, tag);
    }

    #[test]
    fn encryption_algorithm_cipher() {
        // Every supported AES-GCM variant uses a 96-bit IV; only the key length differs.
        let expectations = [
            (EncryptionAlgorithm::A128GCM, 16),
            (EncryptionAlgorithm::A192GCM, 24),
            (EncryptionAlgorithm::A256GCM, 32),
        ];
        for (enc, key_len) in expectations {
            let cipher: Cipher = enc.try_into().expect("try_into should succeed");
            assert_eq!(cipher.iv_len(), Some(12));
            assert_eq!(cipher.key_len(), key_len);
        }
    }

    #[tokio::test]
    async fn encrypt_decrypt_roundtrip() {
        const KEY_NAME: &str = "key-name";
        const VERSIONED_KID: &str = "key-name/key-version";
        let alg = Algorithm::RSA_OAEP;
        let plaintext = b"Hello, world!";

        // Stand-in wrap_key: checks the requested kid and alg, reports a versioned kid, and
        // returns the CEK unchanged.
        let wrap_key = async |key_id: &str, wrap_alg: &Algorithm, cek: &[u8]| {
            assert_eq!(key_id, KEY_NAME);
            assert_eq!(wrap_alg, &alg);
            Ok(WrapKeyResult {
                kid: VERSIONED_KID.into(),
                cek: Bytes::copy_from_slice(cek),
            })
        };

        // Stand-in unwrap_key: expects the versioned kid recorded in the header and the same
        // alg, and returns the CEK unchanged.
        let unwrap_key = async |key_id: &str, wrap_alg: &Algorithm, cek: &[u8]| {
            assert_eq!(key_id, VERSIONED_KID);
            assert_eq!(wrap_alg, &alg);
            Ok(WrapKeyResult {
                kid: VERSIONED_KID.into(),
                cek: Bytes::copy_from_slice(cek),
            })
        };

        let jwe = Jwe::encryptor()
            .alg(alg.clone())
            .enc(EncryptionAlgorithm::A128GCM)
            .kid(KEY_NAME)
            .plaintext(plaintext)
            .encrypt(wrap_key)
            .await
            .expect("encryption should succeed");

        let decrypted = jwe
            .decrypt(unwrap_key)
            .await
            .expect("decryption should succeed");
        assert_eq!(decrypted, plaintext.as_ref());
    }
}
456}