pq_envelope/
recipient.rs

1use crate::util::scheme_from_ciphertext_length;
2use crate::*;
3use serde::{
4    Deserialize, Deserializer, Serialize, Serializer,
5    de::{MapAccess, Visitor},
6    ser::SerializeStruct,
7};
8
9/// The recipient structure that holds the necessary metadata for a recipient to decrypt the data.
10#[derive(Clone, Debug)]
11pub struct Recipient {
12    /// The KEM ciphertext
13    pub(crate) capsule: oqs::kem::Ciphertext,
14    /// The wrapped data encryption key (DEK)
15    pub(crate) wrapped_dek: [u8; 40],
16}
17
18impl std::fmt::Display for Recipient {
19    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20        write!(
21            f,
22            "{{ capsule: {}, wrapped_dek: {} }}",
23            hex::encode(self.capsule.as_ref()),
24            hex::encode(self.wrapped_dek)
25        )
26    }
27}
28
29impl Serialize for Recipient {
30    fn serialize<S>(&self, s: S) -> std::result::Result<S::Ok, S::Error>
31    where
32        S: Serializer,
33    {
34        if s.is_human_readable() {
35            let mut state = s.serialize_struct("Recipient", 2)?;
36            state.serialize_field("capsule", &hex::encode(&self.capsule))?;
37            state.serialize_field("wrapped_dek", &hex::encode(self.wrapped_dek))?;
38            state.end()
39        } else {
40            let mut state = s.serialize_struct("Recipient", 2)?;
41            state.serialize_field("capsule", self.capsule.as_ref())?;
42            state.serialize_field("wrapped_dek", &serde_big_array::Array(self.wrapped_dek))?;
43            state.end()
44        }
45    }
46}
47
48impl<'de> Deserialize<'de> for Recipient {
49    fn deserialize<D>(d: D) -> std::result::Result<Self, D::Error>
50    where
51        D: Deserializer<'de>,
52    {
53        fn process_capsule_bytes(capsule_bytes: &[u8]) -> Result<oqs::kem::Ciphertext> {
54            let scheme = scheme_from_ciphertext_length(capsule_bytes.len())?;
55            let kem: oqs::kem::Kem = scheme.into();
56            kem.ciphertext_from_bytes(capsule_bytes)
57                .map(|pk| pk.to_owned())
58                .ok_or(Error::CapsuleConversion)
59        }
60
61        if d.is_human_readable() {
62            struct RecipientVisitor;
63
64            impl<'de> Visitor<'de> for RecipientVisitor {
65                type Value = Recipient;
66
67                fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
68                    write!(f, "a map representing a Recipient")
69                }
70
71                fn visit_map<A>(self, mut map: A) -> std::result::Result<Self::Value, A::Error>
72                where
73                    A: MapAccess<'de>,
74                {
75                    let mut capsule: Option<String> = None;
76                    let mut wrapped_dek: Option<String> = None;
77                    while let Some(key) = map.next_key::<String>()? {
78                        match key.as_str() {
79                            "capsule" => {
80                                if capsule.is_some() {
81                                    return Err(serde::de::Error::duplicate_field("capsule"));
82                                }
83                                capsule = Some(map.next_value()?);
84                            }
85                            "wrapped_dek" => {
86                                if wrapped_dek.is_some() {
87                                    return Err(serde::de::Error::duplicate_field("wrapped_dek"));
88                                }
89                                wrapped_dek = Some(map.next_value()?);
90                            }
91                            _ => {
92                                return Err(serde::de::Error::unknown_field(
93                                    &key,
94                                    &["capsule", "wrapped_dek"],
95                                ));
96                            }
97                        }
98                    }
99                    let capsule =
100                        capsule.ok_or_else(|| serde::de::Error::missing_field("capsule"))?;
101                    let wrapped_dek = wrapped_dek
102                        .ok_or_else(|| serde::de::Error::missing_field("wrapped_dek"))?;
103
104                    let capsule_bytes = hex::decode(&capsule).map_err(serde::de::Error::custom)?;
105                    let wrapped_dek_bytes =
106                        hex::decode(&wrapped_dek).map_err(serde::de::Error::custom)?;
107                    if wrapped_dek_bytes.len() != 40 {
108                        return Err(serde::de::Error::custom("wrapped_dek must be 40 bytes"));
109                    }
110                    let mut wrapped_dek_array = [0u8; 40];
111                    wrapped_dek_array.copy_from_slice(&wrapped_dek_bytes);
112
113                    let capsule =
114                        process_capsule_bytes(&capsule_bytes).map_err(serde::de::Error::custom)?;
115
116                    Ok(Recipient {
117                        capsule,
118                        wrapped_dek: wrapped_dek_array,
119                    })
120                }
121            }
122            d.deserialize_struct("Recipient", &["capsule", "wrapped_dek"], RecipientVisitor)
123        } else {
124            #[derive(Deserialize)]
125            struct RecipientHelper {
126                capsule: Vec<u8>,
127                #[serde(with = "serde_big_array::BigArray")]
128                wrapped_dek: [u8; 40],
129            }
130            let helper = RecipientHelper::deserialize(d)?;
131
132            Ok(Recipient {
133                capsule: process_capsule_bytes(&helper.capsule)
134                    .map_err(serde::de::Error::custom)?,
135                wrapped_dek: helper.wrapped_dek,
136            })
137        }
138    }
139}
140
141impl Recipient {
142    pub(crate) fn new(
143        data_encryption_key: &[u8; 32],
144        recipient_public_key: &oqs::kem::PublicKey,
145        scheme: Scheme,
146    ) -> Result<Self> {
147        let kem: oqs::kem::Kem = scheme.into();
148        let (capsule, shared_secret) = kem.encapsulate(recipient_public_key)?;
149        let kw = scheme.create_kek(shared_secret);
150        let mut wrapped_dek = [0u8; 40];
151        kw.wrap(data_encryption_key, &mut wrapped_dek)?;
152
153        Ok(Recipient {
154            capsule,
155            wrapped_dek,
156        })
157    }
158
159    pub(crate) fn unwrap_dek(
160        &self,
161        recipient_secret_key: &oqs::kem::SecretKey,
162        scheme: Scheme,
163    ) -> Result<[u8; 32]> {
164        let kem: oqs::kem::Kem = scheme.into();
165        let shared_secret = kem.decapsulate(recipient_secret_key, &self.capsule)?;
166        let kw = scheme.create_kek(shared_secret);
167        let mut dek = [0u8; 32];
168        kw.unwrap(&self.wrapped_dek, &mut dek)?;
169        Ok(dek)
170    }
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::*;

    /// Asserts that two recipients hold the same capsule bytes and wrapped DEK.
    fn assert_same_recipient(expected: &Recipient, actual: &Recipient) {
        assert_eq!(expected.capsule.as_ref(), actual.capsule.as_ref());
        assert_eq!(expected.wrapped_dek, actual.wrapped_dek);
    }

    /// A recipient survives a JSON (human-readable) round trip unchanged.
    #[rstest]
    #[case::small(Scheme::Small)]
    #[case::nist(Scheme::Nist)]
    #[case::secure(Scheme::Secure)]
    fn serialization_human_readable(#[case] scheme: Scheme) {
        let (pk, _) = scheme.key_pair().unwrap();
        let original = Recipient::new(&[0u8; 32], &pk, scheme).unwrap();

        let json = serde_json::to_string(&original).unwrap();
        let restored: Recipient = serde_json::from_str(&json).unwrap();

        assert_same_recipient(&original, &restored);
    }

    /// A recipient survives a postcard (binary) round trip unchanged and has
    /// the encoded size the scheme reports.
    #[rstest]
    #[case::small(Scheme::Small)]
    #[case::nist(Scheme::Nist)]
    #[case::secure(Scheme::Secure)]
    fn serialization_binary(#[case] scheme: Scheme) {
        let (pk, _) = scheme.key_pair().unwrap();
        let original = Recipient::new(&[0u8; 32], &pk, scheme).unwrap();

        let bytes = postcard::to_stdvec(&original).unwrap();
        let restored: Recipient = postcard::from_bytes(&bytes).unwrap();

        assert_same_recipient(&original, &restored);
        assert_eq!(bytes.len(), scheme.recipient_binary_size());
    }

    /// Unwrapping with the matching secret key yields the original DEK.
    #[rstest]
    #[case::small(Scheme::Small)]
    #[case::nist(Scheme::Nist)]
    #[case::secure(Scheme::Secure)]
    fn dek_unwrap(#[case] scheme: Scheme) {
        let (pk, sk) = scheme.key_pair().unwrap();
        let dek = [1u8; 32];

        let recovered = Recipient::new(&dek, &pk, scheme)
            .unwrap()
            .unwrap_dek(&sk, scheme)
            .unwrap();

        assert_eq!(dek, recovered);
    }

    /// Unwrapping with a key and scheme other than the one used to create the
    /// recipient fails.
    #[test]
    fn incompatibility() {
        let dek = [1u8; 32];
        let (pk_small, sk_small) = Scheme::Small.key_pair().unwrap();
        let (pk_nist, sk_nist) = Scheme::Nist.key_pair().unwrap();

        let for_small = Recipient::new(&dek, &pk_small, Scheme::Small).unwrap();
        let for_nist = Recipient::new(&dek, &pk_nist, Scheme::Nist).unwrap();

        assert!(for_small.unwrap_dek(&sk_nist, Scheme::Nist).is_err());
        assert!(for_nist.unwrap_dek(&sk_small, Scheme::Small).is_err());
    }
}