1use crate::util::scheme_from_ciphertext_length;
2use crate::*;
3use serde::{
4 Deserialize, Deserializer, Serialize, Serializer,
5 de::{MapAccess, Visitor},
6 ser::SerializeStruct,
7};
8
/// One recipient's entry: a KEM ciphertext ("capsule") together with the
/// data-encryption key (DEK) wrapped for that recipient.
///
/// The scheme used to produce the entry is not stored; when deserializing it is
/// inferred from the capsule's byte length.
#[derive(Clone, Debug)]
pub struct Recipient {
    /// KEM ciphertext produced by encapsulating to the recipient's public key.
    pub(crate) capsule: oqs::kem::Ciphertext,
    /// The 32-byte DEK wrapped under a key created from the KEM shared secret
    /// (see `Recipient::new`); the wrapped form is 40 bytes.
    pub(crate) wrapped_dek: [u8; 40],
}
17
18impl std::fmt::Display for Recipient {
19 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20 write!(
21 f,
22 "{{ capsule: {}, wrapped_dek: {} }}",
23 hex::encode(self.capsule.as_ref()),
24 hex::encode(self.wrapped_dek)
25 )
26 }
27}
28
29impl Serialize for Recipient {
30 fn serialize<S>(&self, s: S) -> std::result::Result<S::Ok, S::Error>
31 where
32 S: Serializer,
33 {
34 if s.is_human_readable() {
35 let mut state = s.serialize_struct("Recipient", 2)?;
36 state.serialize_field("capsule", &hex::encode(&self.capsule))?;
37 state.serialize_field("wrapped_dek", &hex::encode(self.wrapped_dek))?;
38 state.end()
39 } else {
40 let mut state = s.serialize_struct("Recipient", 2)?;
41 state.serialize_field("capsule", self.capsule.as_ref())?;
42 state.serialize_field("wrapped_dek", &serde_big_array::Array(self.wrapped_dek))?;
43 state.end()
44 }
45 }
46}
47
48impl<'de> Deserialize<'de> for Recipient {
49 fn deserialize<D>(d: D) -> std::result::Result<Self, D::Error>
50 where
51 D: Deserializer<'de>,
52 {
53 fn process_capsule_bytes(capsule_bytes: &[u8]) -> Result<oqs::kem::Ciphertext> {
54 let scheme = scheme_from_ciphertext_length(capsule_bytes.len())?;
55 let kem: oqs::kem::Kem = scheme.into();
56 kem.ciphertext_from_bytes(capsule_bytes)
57 .map(|pk| pk.to_owned())
58 .ok_or(Error::CapsuleConversion)
59 }
60
61 if d.is_human_readable() {
62 struct RecipientVisitor;
63
64 impl<'de> Visitor<'de> for RecipientVisitor {
65 type Value = Recipient;
66
67 fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
68 write!(f, "a map representing a Recipient")
69 }
70
71 fn visit_map<A>(self, mut map: A) -> std::result::Result<Self::Value, A::Error>
72 where
73 A: MapAccess<'de>,
74 {
75 let mut capsule: Option<String> = None;
76 let mut wrapped_dek: Option<String> = None;
77 while let Some(key) = map.next_key::<String>()? {
78 match key.as_str() {
79 "capsule" => {
80 if capsule.is_some() {
81 return Err(serde::de::Error::duplicate_field("capsule"));
82 }
83 capsule = Some(map.next_value()?);
84 }
85 "wrapped_dek" => {
86 if wrapped_dek.is_some() {
87 return Err(serde::de::Error::duplicate_field("wrapped_dek"));
88 }
89 wrapped_dek = Some(map.next_value()?);
90 }
91 _ => {
92 return Err(serde::de::Error::unknown_field(
93 &key,
94 &["capsule", "wrapped_dek"],
95 ));
96 }
97 }
98 }
99 let capsule =
100 capsule.ok_or_else(|| serde::de::Error::missing_field("capsule"))?;
101 let wrapped_dek = wrapped_dek
102 .ok_or_else(|| serde::de::Error::missing_field("wrapped_dek"))?;
103
104 let capsule_bytes = hex::decode(&capsule).map_err(serde::de::Error::custom)?;
105 let wrapped_dek_bytes =
106 hex::decode(&wrapped_dek).map_err(serde::de::Error::custom)?;
107 if wrapped_dek_bytes.len() != 40 {
108 return Err(serde::de::Error::custom("wrapped_dek must be 40 bytes"));
109 }
110 let mut wrapped_dek_array = [0u8; 40];
111 wrapped_dek_array.copy_from_slice(&wrapped_dek_bytes);
112
113 let capsule =
114 process_capsule_bytes(&capsule_bytes).map_err(serde::de::Error::custom)?;
115
116 Ok(Recipient {
117 capsule,
118 wrapped_dek: wrapped_dek_array,
119 })
120 }
121 }
122 d.deserialize_struct("Recipient", &["capsule", "wrapped_dek"], RecipientVisitor)
123 } else {
124 #[derive(Deserialize)]
125 struct RecipientHelper {
126 capsule: Vec<u8>,
127 #[serde(with = "serde_big_array::BigArray")]
128 wrapped_dek: [u8; 40],
129 }
130 let helper = RecipientHelper::deserialize(d)?;
131
132 Ok(Recipient {
133 capsule: process_capsule_bytes(&helper.capsule)
134 .map_err(serde::de::Error::custom)?,
135 wrapped_dek: helper.wrapped_dek,
136 })
137 }
138 }
139}
140
141impl Recipient {
142 pub(crate) fn new(
143 data_encryption_key: &[u8; 32],
144 recipient_public_key: &oqs::kem::PublicKey,
145 scheme: Scheme,
146 ) -> Result<Self> {
147 let kem: oqs::kem::Kem = scheme.into();
148 let (capsule, shared_secret) = kem.encapsulate(recipient_public_key)?;
149 let kw = scheme.create_kek(shared_secret);
150 let mut wrapped_dek = [0u8; 40];
151 kw.wrap(data_encryption_key, &mut wrapped_dek)?;
152
153 Ok(Recipient {
154 capsule,
155 wrapped_dek,
156 })
157 }
158
159 pub(crate) fn unwrap_dek(
160 &self,
161 recipient_secret_key: &oqs::kem::SecretKey,
162 scheme: Scheme,
163 ) -> Result<[u8; 32]> {
164 let kem: oqs::kem::Kem = scheme.into();
165 let shared_secret = kem.decapsulate(recipient_secret_key, &self.capsule)?;
166 let kw = scheme.create_kek(shared_secret);
167 let mut dek = [0u8; 32];
168 kw.unwrap(&self.wrapped_dek, &mut dek)?;
169 Ok(dek)
170 }
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::*;

    /// JSON round-trip preserves both fields for every scheme.
    #[rstest]
    #[case::small(Scheme::Small)]
    #[case::nist(Scheme::Nist)]
    #[case::secure(Scheme::Secure)]
    fn serialization_human_readable(#[case] scheme: Scheme) {
        let (pk, _sk) = scheme.key_pair().unwrap();
        let dek = [0u8; 32];
        let recipient = Recipient::new(&dek, &pk, scheme).unwrap();
        let serialized = serde_json::to_string(&recipient).unwrap();
        let deserialized: Recipient = serde_json::from_str(&serialized).unwrap();
        assert_eq!(recipient.capsule.as_ref(), deserialized.capsule.as_ref());
        assert_eq!(recipient.wrapped_dek, deserialized.wrapped_dek);
    }

    /// Postcard round-trip preserves both fields and has the documented size.
    #[rstest]
    #[case::small(Scheme::Small)]
    #[case::nist(Scheme::Nist)]
    #[case::secure(Scheme::Secure)]
    fn serialization_binary(#[case] scheme: Scheme) {
        let (pk, _sk) = scheme.key_pair().unwrap();
        let dek = [0u8; 32];
        let recipient = Recipient::new(&dek, &pk, scheme).unwrap();
        let serialized = postcard::to_stdvec(&recipient).unwrap();
        let deserialized: Recipient = postcard::from_bytes(&serialized).unwrap();
        assert_eq!(recipient.capsule.as_ref(), deserialized.capsule.as_ref());
        assert_eq!(recipient.wrapped_dek, deserialized.wrapped_dek);

        assert_eq!(serialized.len(), scheme.recipient_binary_size());
    }

    /// Wrapping then unwrapping with the matching key returns the original DEK.
    #[rstest]
    #[case::small(Scheme::Small)]
    #[case::nist(Scheme::Nist)]
    #[case::secure(Scheme::Secure)]
    fn dek_unwrap(#[case] scheme: Scheme) {
        let (pk, sk) = scheme.key_pair().unwrap();
        let dek = [1u8; 32];
        let recipient = Recipient::new(&dek, &pk, scheme).unwrap();
        let unwrapped_dek = recipient.unwrap_dek(&sk, scheme).unwrap();
        assert_eq!(dek, unwrapped_dek);
    }

    /// A recipient made under one scheme must not unwrap under any other scheme's key,
    /// and must still unwrap under its own.
    #[test]
    fn incompatibility() {
        let dek = [1u8; 32];
        let parties: Vec<_> = [Scheme::Small, Scheme::Nist, Scheme::Secure]
            .into_iter()
            .map(|scheme| {
                let (pk, sk) = scheme.key_pair().unwrap();
                let recipient = Recipient::new(&dek, &pk, scheme).unwrap();
                (scheme, sk, recipient)
            })
            .collect();

        for (i, (_, _, recipient)) in parties.iter().enumerate() {
            for (j, (scheme, sk, _)) in parties.iter().enumerate() {
                let result = recipient.unwrap_dek(sk, *scheme);
                if i == j {
                    assert_eq!(result.unwrap(), dek);
                } else {
                    assert!(result.is_err());
                }
            }
        }
    }

    /// Hex-encoded fields of a freshly created recipient, used to build bad JSON inputs.
    fn hex_fields(scheme: Scheme) -> (String, String) {
        let (pk, _sk) = scheme.key_pair().unwrap();
        let recipient = Recipient::new(&[0u8; 32], &pk, scheme).unwrap();
        (
            hex::encode(recipient.capsule.as_ref()),
            hex::encode(recipient.wrapped_dek),
        )
    }

    /// Every error branch of the human-readable deserializer rejects its input.
    #[test]
    fn human_readable_rejects_malformed_input() {
        let (capsule, dek) = hex_fields(Scheme::Nist);
        // One byte (two hex digits) short of the required 40 bytes.
        let short_dek = &dek[..dek.len() - 2];
        let cases = [
            format!(r#"{{"capsule":"{capsule}","wrapped_dek":"{short_dek}"}}"#),
            format!(r#"{{"capsule":"{capsule}"}}"#),
            format!(r#"{{"wrapped_dek":"{dek}"}}"#),
            format!(r#"{{"capsule":"{capsule}","wrapped_dek":"{dek}","extra":"00"}}"#),
            format!(r#"{{"capsule":"{capsule}","capsule":"{capsule}","wrapped_dek":"{dek}"}}"#),
            format!(r#"{{"capsule":"zz","wrapped_dek":"{dek}"}}"#),
            format!(r#"{{"capsule":"{capsule}","wrapped_dek":"zz"}}"#),
            // A 1-byte capsule cannot be a valid ciphertext for any scheme.
            format!(r#"{{"capsule":"00","wrapped_dek":"{dek}"}}"#),
        ];
        for json in &cases {
            assert!(
                serde_json::from_str::<Recipient>(json).is_err(),
                "malformed input was accepted: {json}"
            );
        }
    }

    /// Truncated binary input is rejected rather than producing a partial recipient.
    #[test]
    fn binary_rejects_truncated_input() {
        let (pk, _sk) = Scheme::Nist.key_pair().unwrap();
        let recipient = Recipient::new(&[0u8; 32], &pk, Scheme::Nist).unwrap();
        let serialized = postcard::to_stdvec(&recipient).unwrap();
        let truncated = &serialized[..serialized.len() - 1];
        assert!(postcard::from_bytes::<Recipient>(truncated).is_err());
    }
}