1use crate::*;
2use aes_gcm::{Aes256Gcm, Key, KeyInit, Nonce, aead::Aead};
3use rand::prelude::*;
4use serde::de::SeqAccess;
5use serde::{
6 Deserialize, Deserializer, Serialize, Serializer,
7 de::{Error as DError, MapAccess, Visitor},
8 ser::SerializeStruct,
9};
10
/// An encrypted payload together with the per-recipient material needed to
/// open it.
///
/// When built by `Envelope::new`, `ciphertext` is the output of AES-256-GCM
/// under a single data-encryption key (DEK), prefixed with the random 12-byte
/// nonce. `recipients` holds one `Recipient` entry per public key the DEK was
/// wrapped for.
#[derive(Clone, Debug)]
pub struct Envelope {
    /// Nonce-prefixed AES-256-GCM ciphertext of the payload.
    ciphertext: Vec<u8>,
    /// One entry per recipient; each wraps the same data-encryption key.
    recipients: Vec<Recipient>,
}
19
20impl std::fmt::Display for Envelope {
21 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22 write!(
23 f,
24 "Envelope {{ recipients: [{}], ciphertext: {} }}",
25 self.display_recipients(),
26 hex::encode(&self.ciphertext),
27 )
28 }
29}
30
31impl Serialize for Envelope {
32 fn serialize<S>(&self, s: S) -> std::result::Result<S::Ok, S::Error>
33 where
34 S: Serializer,
35 {
36 if s.is_human_readable() {
37 let mut state = s.serialize_struct("Envelope", 2)?;
38 state.serialize_field("recipients", &self.recipients)?;
39 state.serialize_field("ciphertext", &hex::encode(&self.ciphertext))?;
40 state.end()
41 } else {
42 let mut state = s.serialize_struct("Envelope", 2)?;
43 state.serialize_field("recipients", &self.recipients)?;
44 state.serialize_field("ciphertext", &self.ciphertext)?;
45 state.end()
46 }
47 }
48}
49
50impl<'de> Deserialize<'de> for Envelope {
51 fn deserialize<D>(d: D) -> std::result::Result<Self, D::Error>
52 where
53 D: Deserializer<'de>,
54 {
55 if d.is_human_readable() {
56 struct EnvelopeVisitor;
57
58 impl<'de> Visitor<'de> for EnvelopeVisitor {
59 type Value = Envelope;
60
61 fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
62 write!(f, "struct Envelope or map")
63 }
64
65 fn visit_map<A>(self, mut map: A) -> std::result::Result<Self::Value, A::Error>
66 where
67 A: MapAccess<'de>,
68 {
69 let mut recipients: Option<Vec<Recipient>> = None;
70 let mut ciphertext: Option<String> = None;
71
72 while let Some(key) = map.next_key::<&str>()? {
73 match key {
74 "recipients" => {
75 if recipients.is_some() {
76 return Err(DError::duplicate_field("recipients"));
77 }
78 recipients = Some(map.next_value()?);
79 }
80 "ciphertext" => {
81 if ciphertext.is_some() {
82 return Err(DError::duplicate_field("ciphertext"));
83 }
84 ciphertext = Some(map.next_value()?);
85 }
86 _ => {
87 let _: serde::de::IgnoredAny = map.next_value()?;
88 }
89 }
90 }
91
92 let recipients =
93 recipients.ok_or_else(|| DError::missing_field("recipients"))?;
94 let ciphertext_hex =
95 ciphertext.ok_or_else(|| DError::missing_field("ciphertext"))?;
96 let ciphertext = hex::decode(&ciphertext_hex)
97 .map_err(|_| DError::custom("Invalid hex in ciphertext"))?;
98
99 Ok(Envelope {
100 recipients,
101 ciphertext,
102 })
103 }
104 }
105 d.deserialize_struct("Envelope", &["recipients", "ciphertext"], EnvelopeVisitor)
106 } else {
107 struct EnvelopeVisitor;
108 impl<'de> Visitor<'de> for EnvelopeVisitor {
109 type Value = Envelope;
110
111 fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
112 write!(f, "struct Envelope or map")
113 }
114
115 fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
116 where
117 A: SeqAccess<'de>,
118 {
119 let recipients = seq
120 .next_element()?
121 .ok_or_else(|| DError::missing_field("recipients"))?;
122 let ciphertext = seq
123 .next_element()?
124 .ok_or_else(|| DError::missing_field("ciphertext"))?;
125
126 Ok(Envelope {
127 recipients,
128 ciphertext,
129 })
130 }
131 }
132 d.deserialize_struct("Envelope", &["recipients", "ciphertext"], EnvelopeVisitor)
133 }
134 }
135}
136
137impl Envelope {
138 pub(crate) fn display_recipients(&self) -> String {
139 let mut s = String::new();
140 for (i, r) in self.recipients.iter().enumerate() {
141 if i > 0 {
142 s.push_str(", ");
143 }
144 s.push_str(&format!("{}", r));
145 }
146 s
147 }
148
149 pub fn new<B: AsRef<[u8]>>(
154 recipients: &[PublicKey],
155 data: B,
156 data_encryption_key: Option<[u8; 32]>,
157 ) -> Result<Self> {
158 if recipients.is_empty() {
159 return Err(Error::NoRecipients);
160 }
161
162 let mut rng = rand::rng();
163 let dek = data_encryption_key.unwrap_or_else(|| rng.random());
164 let mut envelope_recipients = Vec::with_capacity(recipients.len());
165 let mut scheme: Option<Scheme> = None;
166
167 for pk in recipients {
168 match scheme {
169 None => {
170 scheme = Some(scheme_from_public_key_length(pk.as_ref().len())?);
171 }
172 Some(s) => {
173 let pk_scheme = scheme_from_public_key_length(pk.as_ref().len())?;
174 if s != pk_scheme {
175 return Err(Error::SchemeMismatch);
176 }
177 }
178 }
179 let s = scheme.expect("scheme should be set");
180 envelope_recipients.push(Recipient::new(&dek, pk, s)?);
181 }
182
183 Ok(Self {
184 recipients: envelope_recipients,
185 ciphertext: Self::encrypt_data(data, &dek)?,
186 })
187 }
188
189 pub fn recipients(&self) -> &[Recipient] {
191 &self.recipients
192 }
193
194 pub fn ciphertext(&self) -> &[u8] {
196 &self.ciphertext
197 }
198
199 pub fn decrypt_by_recipient_secret_key(
204 &self,
205 recipient_secret_key: &SecretKey,
206 ) -> Result<Vec<u8>> {
207 let scheme = scheme_from_secret_key_length(recipient_secret_key.as_ref().len())?;
208 for recipient in &self.recipients {
209 if let Ok(k) = recipient.unwrap_dek(recipient_secret_key, scheme) {
210 return Self::decrypt_data(&self.ciphertext, &k);
211 }
212 }
213 Err(Error::InvalidDecapsulationKey)
214 }
215
216 pub fn decrypt_by_recipient_index(
221 &self,
222 index: usize,
223 recipient_secret_key: &SecretKey,
224 ) -> Result<Vec<u8>> {
225 if index >= self.recipients.len() {
226 return Err(Error::InvalidDecapsulationKey);
227 }
228 let scheme = scheme_from_secret_key_length(recipient_secret_key.as_ref().len())?;
229 let recipient = &self.recipients[index];
230 let dek = recipient.unwrap_dek(recipient_secret_key, scheme)?;
231 Self::decrypt_data(&self.ciphertext, &dek)
232 }
233
234 fn encrypt_data<B: AsRef<[u8]>>(data: B, dek: &[u8; 32]) -> Result<Vec<u8>> {
235 let mut rng = rand::rng();
236 let nonce: [u8; 12] = rng.random();
237 let cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(dek));
238 let nonce = Nonce::clone_from_slice(&nonce);
239 let mut ciphertext = cipher.encrypt(&nonce, data.as_ref())?;
240 let mut result = Vec::with_capacity(nonce.len() + ciphertext.len());
241 result.extend_from_slice(&nonce);
242 result.append(&mut ciphertext);
243 Ok(result)
244 }
245
246 fn decrypt_data<B: AsRef<[u8]>>(ciphertext: B, dek: &[u8; 32]) -> Result<Vec<u8>> {
247 let ct = ciphertext.as_ref();
248 if ct.len() < 28 {
249 return Err(Error::AesGcm);
250 }
251 let (nonce, ct) = ct.split_at(12);
252 let cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(dek));
253 let nonce = Nonce::clone_from_slice(nonce);
254 let plaintext = cipher.decrypt(&nonce, ct)?;
255 Ok(plaintext)
256 }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::*;

    /// Generates `n` key pairs for `scheme`, split into parallel vectors of
    /// public and secret keys.
    fn key_pairs(scheme: Scheme, n: usize) -> (Vec<PublicKey>, Vec<SecretKey>) {
        (0..n).map(|_| scheme.key_pair().unwrap()).unzip()
    }

    /// Asserts that two envelopes carry the same ciphertext and the same
    /// recipient entries (capsule and wrapped DEK), in the same order.
    fn assert_same_envelope(a: &Envelope, b: &Envelope) {
        assert_eq!(a.ciphertext, b.ciphertext);
        assert_eq!(a.recipients.len(), b.recipients.len());
        for (r1, r2) in a.recipients.iter().zip(b.recipients.iter()) {
            assert_eq!(r1.capsule.as_ref(), r2.capsule.as_ref());
            assert_eq!(r1.wrapped_dek, r2.wrapped_dek);
        }
    }

    #[rstest]
    #[case(Scheme::Small, 5)]
    #[case(Scheme::Nist, 4)]
    #[case(Scheme::Secure, 3)]
    fn serialization_human_readable(#[case] scheme: Scheme, #[case] num_recipients: usize) {
        let (recipients_pk, recipients_sk) = key_pairs(scheme, num_recipients);

        let data = b"Hello, world!";
        let envelope = Envelope::new(&recipients_pk, data.as_ref(), None).unwrap();
        let serialized = serde_json::to_string(&envelope).unwrap();
        let deserialized: Envelope = serde_json::from_str(&serialized).unwrap();
        assert_same_envelope(&envelope, &deserialized);

        // The round-tripped envelope must still open for every recipient.
        for sk in &recipients_sk {
            let decrypted = deserialized.decrypt_by_recipient_secret_key(sk).unwrap();
            assert_eq!(decrypted, data.as_ref());
        }
    }

    #[rstest]
    #[case(Scheme::Small, 4)]
    #[case(Scheme::Nist, 5)]
    #[case(Scheme::Secure, 3)]
    fn serialization_binary(#[case] scheme: Scheme, #[case] num_recipients: usize) {
        let (recipients_pk, recipients_sk) = key_pairs(scheme, num_recipients);

        let data = b"Hello, world!";
        let envelope = Envelope::new(&recipients_pk, data.as_ref(), None).unwrap();
        let serialized = postcard::to_stdvec(&envelope).unwrap();
        let deserialized: Envelope = postcard::from_bytes(&serialized).unwrap();
        assert_same_envelope(&envelope, &deserialized);

        // The round-tripped envelope must still open for every recipient.
        for sk in &recipients_sk {
            let decrypted = deserialized.decrypt_by_recipient_secret_key(sk).unwrap();
            assert_eq!(decrypted, data.as_ref());
        }
    }

    #[rstest]
    #[case(Scheme::Small, 6)]
    #[case(Scheme::Nist, 4)]
    #[case(Scheme::Secure, 5)]
    fn decryption(#[case] scheme: Scheme, #[case] num_recipients: usize) {
        let (recipients_pk, recipients_sk) = key_pairs(scheme, num_recipients);

        let data = b"envelope decryption";
        let envelope = Envelope::new(&recipients_pk, data.as_ref(), None).unwrap();
        for sk in &recipients_sk {
            let decrypted = envelope.decrypt_by_recipient_secret_key(sk).unwrap();
            assert_eq!(decrypted, data.as_ref());
        }

        for (i, sk) in recipients_sk.iter().enumerate() {
            let decrypted = envelope.decrypt_by_recipient_index(i, sk).unwrap();
            assert_eq!(decrypted, data.as_ref());
            // Each key must be rejected by a *different* recipient's slot. Wrap on
            // the recipient count (not the key's byte length) so the last key is
            // checked against slot 0 rather than an out-of-range index.
            let other = (i + 1) % recipients_sk.len();
            let decrypt_fail = envelope.decrypt_by_recipient_index(other, sk);
            assert!(decrypt_fail.is_err());
        }

        // An out-of-range index is reported as an error, not a panic.
        let out_of_range = envelope.decrypt_by_recipient_index(num_recipients, &recipients_sk[0]);
        assert!(out_of_range.is_err());
    }

    #[test]
    fn new_rejects_empty_recipients() {
        let result = Envelope::new(&[], b"data", None);
        assert!(matches!(result, Err(Error::NoRecipients)));
    }

    #[test]
    fn new_rejects_mixed_schemes() {
        let (small_pk, _) = Scheme::Small.key_pair().unwrap();
        let (secure_pk, _) = Scheme::Secure.key_pair().unwrap();
        let result = Envelope::new(&[small_pk, secure_pk], b"data", None);
        assert!(matches!(result, Err(Error::SchemeMismatch)));
    }
}