pq_envelope/
envelope.rs

1use crate::*;
2use aes_gcm::{Aes256Gcm, Key, KeyInit, Nonce, aead::Aead};
3use rand::prelude::*;
4use serde::de::SeqAccess;
5use serde::{
6    Deserialize, Deserializer, Serialize, Serializer,
7    de::{Error as DError, MapAccess, Visitor},
8    ser::SerializeStruct,
9};
10
/// The envelope structure that holds the encrypted data along with the necessary metadata.
///
/// The payload is encrypted once with AES-256-GCM under a 32-byte data encryption
/// key (DEK); every entry in `recipients` is built from that same DEK and one
/// recipient public key.
#[derive(Clone, Debug)]
pub struct Envelope {
    /// The encrypted data: a 12-byte nonce followed by the AES-256-GCM output
    ciphertext: Vec<u8>,
    /// One entry per recipient public key, in the order the keys were supplied
    recipients: Vec<Recipient>,
}
19
20impl std::fmt::Display for Envelope {
21    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22        write!(
23            f,
24            "Envelope {{ recipients: [{}], ciphertext: {} }}",
25            self.display_recipients(),
26            hex::encode(&self.ciphertext),
27        )
28    }
29}
30
31impl Serialize for Envelope {
32    fn serialize<S>(&self, s: S) -> std::result::Result<S::Ok, S::Error>
33    where
34        S: Serializer,
35    {
36        if s.is_human_readable() {
37            let mut state = s.serialize_struct("Envelope", 2)?;
38            state.serialize_field("recipients", &self.recipients)?;
39            state.serialize_field("ciphertext", &hex::encode(&self.ciphertext))?;
40            state.end()
41        } else {
42            let mut state = s.serialize_struct("Envelope", 2)?;
43            state.serialize_field("recipients", &self.recipients)?;
44            state.serialize_field("ciphertext", &self.ciphertext)?;
45            state.end()
46        }
47    }
48}
49
50impl<'de> Deserialize<'de> for Envelope {
51    fn deserialize<D>(d: D) -> std::result::Result<Self, D::Error>
52    where
53        D: Deserializer<'de>,
54    {
55        if d.is_human_readable() {
56            struct EnvelopeVisitor;
57
58            impl<'de> Visitor<'de> for EnvelopeVisitor {
59                type Value = Envelope;
60
61                fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
62                    write!(f, "struct Envelope or map")
63                }
64
65                fn visit_map<A>(self, mut map: A) -> std::result::Result<Self::Value, A::Error>
66                where
67                    A: MapAccess<'de>,
68                {
69                    let mut recipients: Option<Vec<Recipient>> = None;
70                    let mut ciphertext: Option<String> = None;
71
72                    while let Some(key) = map.next_key::<&str>()? {
73                        match key {
74                            "recipients" => {
75                                if recipients.is_some() {
76                                    return Err(DError::duplicate_field("recipients"));
77                                }
78                                recipients = Some(map.next_value()?);
79                            }
80                            "ciphertext" => {
81                                if ciphertext.is_some() {
82                                    return Err(DError::duplicate_field("ciphertext"));
83                                }
84                                ciphertext = Some(map.next_value()?);
85                            }
86                            _ => {
87                                let _: serde::de::IgnoredAny = map.next_value()?;
88                            }
89                        }
90                    }
91
92                    let recipients =
93                        recipients.ok_or_else(|| DError::missing_field("recipients"))?;
94                    let ciphertext_hex =
95                        ciphertext.ok_or_else(|| DError::missing_field("ciphertext"))?;
96                    let ciphertext = hex::decode(&ciphertext_hex)
97                        .map_err(|_| DError::custom("Invalid hex in ciphertext"))?;
98
99                    Ok(Envelope {
100                        recipients,
101                        ciphertext,
102                    })
103                }
104            }
105            d.deserialize_struct("Envelope", &["recipients", "ciphertext"], EnvelopeVisitor)
106        } else {
107            struct EnvelopeVisitor;
108            impl<'de> Visitor<'de> for EnvelopeVisitor {
109                type Value = Envelope;
110
111                fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
112                    write!(f, "struct Envelope or map")
113                }
114
115                fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
116                where
117                    A: SeqAccess<'de>,
118                {
119                    let recipients = seq
120                        .next_element()?
121                        .ok_or_else(|| DError::missing_field("recipients"))?;
122                    let ciphertext = seq
123                        .next_element()?
124                        .ok_or_else(|| DError::missing_field("ciphertext"))?;
125
126                    Ok(Envelope {
127                        recipients,
128                        ciphertext,
129                    })
130                }
131            }
132            d.deserialize_struct("Envelope", &["recipients", "ciphertext"], EnvelopeVisitor)
133        }
134    }
135}
136
137impl Envelope {
138    pub(crate) fn display_recipients(&self) -> String {
139        let mut s = String::new();
140        for (i, r) in self.recipients.iter().enumerate() {
141            if i > 0 {
142                s.push_str(", ");
143            }
144            s.push_str(&format!("{}", r));
145        }
146        s
147    }
148
149    /// Create a new envelope for the given recipients with the specified data.
150    ///
151    /// Optionally, an already existing data encryption key can be provided.
152    /// However, if it is not provided, a new one will be created.
153    pub fn new<B: AsRef<[u8]>>(
154        recipients: &[PublicKey],
155        data: B,
156        data_encryption_key: Option<[u8; 32]>,
157    ) -> Result<Self> {
158        if recipients.is_empty() {
159            return Err(Error::NoRecipients);
160        }
161
162        let mut rng = rand::rng();
163        let dek = data_encryption_key.unwrap_or_else(|| rng.random());
164        let mut envelope_recipients = Vec::with_capacity(recipients.len());
165        let mut scheme: Option<Scheme> = None;
166
167        for pk in recipients {
168            match scheme {
169                None => {
170                    scheme = Some(scheme_from_public_key_length(pk.as_ref().len())?);
171                }
172                Some(s) => {
173                    let pk_scheme = scheme_from_public_key_length(pk.as_ref().len())?;
174                    if s != pk_scheme {
175                        return Err(Error::SchemeMismatch);
176                    }
177                }
178            }
179            let s = scheme.expect("scheme should be set");
180            envelope_recipients.push(Recipient::new(&dek, pk, s)?);
181        }
182
183        Ok(Self {
184            recipients: envelope_recipients,
185            ciphertext: Self::encrypt_data(data, &dek)?,
186        })
187    }
188
189    /// Return the list of recipients that can decrypt the data
190    pub fn recipients(&self) -> &[Recipient] {
191        &self.recipients
192    }
193
194    /// Return the encrypted data
195    pub fn ciphertext(&self) -> &[u8] {
196        &self.ciphertext
197    }
198
199    /// Decrypt the envelope using the recipient's secret key
200    ///
201    /// This method will attempt to decrypt the envelope using each recipient's capsule
202    /// until one succeeds. If none succeed, an error is returned.
203    pub fn decrypt_by_recipient_secret_key(
204        &self,
205        recipient_secret_key: &SecretKey,
206    ) -> Result<Vec<u8>> {
207        let scheme = scheme_from_secret_key_length(recipient_secret_key.as_ref().len())?;
208        for recipient in &self.recipients {
209            if let Ok(k) = recipient.unwrap_dek(recipient_secret_key, scheme) {
210                return Self::decrypt_data(&self.ciphertext, &k);
211            }
212        }
213        Err(Error::InvalidDecapsulationKey)
214    }
215
216    /// Decrypt the envelope using the recipient's index and secret key
217    ///
218    /// This method will attempt to decrypt the envelope using the recipient at the specified index.
219    /// If the index is out of bounds or the decryption fails, an error is returned
220    pub fn decrypt_by_recipient_index(
221        &self,
222        index: usize,
223        recipient_secret_key: &SecretKey,
224    ) -> Result<Vec<u8>> {
225        if index >= self.recipients.len() {
226            return Err(Error::InvalidDecapsulationKey);
227        }
228        let scheme = scheme_from_secret_key_length(recipient_secret_key.as_ref().len())?;
229        let recipient = &self.recipients[index];
230        let dek = recipient.unwrap_dek(recipient_secret_key, scheme)?;
231        Self::decrypt_data(&self.ciphertext, &dek)
232    }
233
234    fn encrypt_data<B: AsRef<[u8]>>(data: B, dek: &[u8; 32]) -> Result<Vec<u8>> {
235        let mut rng = rand::rng();
236        let nonce: [u8; 12] = rng.random();
237        let cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(dek));
238        let nonce = Nonce::clone_from_slice(&nonce);
239        let mut ciphertext = cipher.encrypt(&nonce, data.as_ref())?;
240        let mut result = Vec::with_capacity(nonce.len() + ciphertext.len());
241        result.extend_from_slice(&nonce);
242        result.append(&mut ciphertext);
243        Ok(result)
244    }
245
246    fn decrypt_data<B: AsRef<[u8]>>(ciphertext: B, dek: &[u8; 32]) -> Result<Vec<u8>> {
247        let ct = ciphertext.as_ref();
248        if ct.len() < 28 {
249            return Err(Error::AesGcm);
250        }
251        let (nonce, ct) = ct.split_at(12);
252        let cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(dek));
253        let nonce = Nonce::clone_from_slice(nonce);
254        let plaintext = cipher.decrypt(&nonce, ct)?;
255        Ok(plaintext)
256    }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::*;

    /// Assert that two envelopes carry the same ciphertext and the same
    /// per-recipient capsule and wrapped key, in the same order.
    fn assert_same_contents(expected: &Envelope, actual: &Envelope) {
        assert_eq!(expected.ciphertext, actual.ciphertext);
        assert_eq!(expected.recipients.len(), actual.recipients.len());
        for (r1, r2) in expected.recipients.iter().zip(actual.recipients.iter()) {
            assert_eq!(r1.capsule.as_ref(), r2.capsule.as_ref());
            assert_eq!(r1.wrapped_dek, r2.wrapped_dek);
        }
    }

    /// Round-trip through a human-readable format (JSON).
    #[rstest]
    #[case(Scheme::Small, 5)]
    #[case(Scheme::Nist, 4)]
    #[case(Scheme::Secure, 3)]
    fn serialization_human_readable(#[case] scheme: Scheme, #[case] num_recipients: usize) {
        // Only the public keys are needed to build the envelope.
        let recipients_pk: Vec<_> = (0..num_recipients)
            .map(|_| scheme.key_pair().unwrap().0)
            .collect();

        let data = b"Hello, world!";
        let envelope = Envelope::new(&recipients_pk, data.as_ref(), None).unwrap();
        let serialized = serde_json::to_string(&envelope).unwrap();
        let deserialized: Envelope = serde_json::from_str(&serialized).unwrap();
        assert_same_contents(&envelope, &deserialized);
    }

    /// Round-trip through a binary format (postcard).
    #[rstest]
    #[case(Scheme::Small, 4)]
    #[case(Scheme::Nist, 5)]
    #[case(Scheme::Secure, 3)]
    fn serialization_binary(#[case] scheme: Scheme, #[case] num_recipients: usize) {
        // Only the public keys are needed to build the envelope.
        let recipients_pk: Vec<_> = (0..num_recipients)
            .map(|_| scheme.key_pair().unwrap().0)
            .collect();

        let data = b"Hello, world!";
        let envelope = Envelope::new(&recipients_pk, data.as_ref(), None).unwrap();
        let serialized = postcard::to_stdvec(&envelope).unwrap();
        let deserialized: Envelope = postcard::from_bytes(&serialized).unwrap();
        assert_same_contents(&envelope, &deserialized);
    }

    /// Every recipient can decrypt by key search and by its own index, and
    /// fails when pointed at a different recipient's entry.
    #[rstest]
    #[case(Scheme::Small, 6)]
    #[case(Scheme::Nist, 4)]
    #[case(Scheme::Secure, 5)]
    fn decryption(#[case] scheme: Scheme, #[case] num_recipients: usize) {
        let (recipients_pk, recipients_sk): (Vec<_>, Vec<_>) = (0..num_recipients)
            .map(|_| scheme.key_pair().unwrap())
            .unzip();

        let data = b"envelope decryption";
        let envelope = Envelope::new(&recipients_pk, data.as_ref(), None).unwrap();
        for sk in &recipients_sk {
            let decrypted = envelope.decrypt_by_recipient_secret_key(sk).unwrap();
            assert_eq!(decrypted, data.as_ref());
        }

        for (i, sk) in recipients_sk.iter().enumerate() {
            let decrypted = envelope.decrypt_by_recipient_index(i, sk).unwrap();
            assert_eq!(decrypted, data.as_ref());

            // Wrap around by the number of recipients (not the key's byte length)
            // so the last recipient is also checked against a real, but wrong,
            // entry instead of an out-of-bounds index.
            let wrong_index = (i + 1) % recipients_sk.len();
            let decrypt_fail = envelope.decrypt_by_recipient_index(wrong_index, sk);
            assert!(decrypt_fail.is_err());
        }
    }
}