1use alloc::boxed::Box;
2use alloc::string::String;
3
4use ferveo::api::{CiphertextHeader, FerveoVariant};
5use serde::{Deserialize, Serialize};
6use umbral_pre::serde_bytes;
7
8use crate::access_control::AccessControlPolicy;
9use crate::conditions::Context;
10use crate::session::key::{SessionSharedSecret, SessionStaticKey};
11use crate::session::{decrypt_with_shared_secret, encrypt_with_shared_secret, DecryptionError};
12use crate::versioning::{
13 messagepack_deserialize, messagepack_serialize, ProtocolObject, ProtocolObjectInner,
14};
15
16#[derive(PartialEq, Eq, Debug, Clone, Serialize, Deserialize)]
18pub struct ThresholdDecryptionRequest {
19 pub ritual_id: u32,
21 pub ciphertext_header: CiphertextHeader,
23 pub acp: AccessControlPolicy,
25 pub context: Option<Context>,
27 pub variant: FerveoVariant,
29}
30
31impl ThresholdDecryptionRequest {
32 pub fn new(
34 ritual_id: u32,
35 ciphertext_header: &CiphertextHeader,
36 acp: &AccessControlPolicy,
37 context: Option<&Context>,
38 variant: FerveoVariant,
39 ) -> Self {
40 Self {
41 ritual_id,
42 ciphertext_header: ciphertext_header.clone(),
43 acp: acp.clone(),
44 context: context.cloned(),
45 variant,
46 }
47 }
48
49 pub fn encrypt(
51 &self,
52 shared_secret: &SessionSharedSecret,
53 requester_public_key: &SessionStaticKey,
54 ) -> EncryptedThresholdDecryptionRequest {
55 EncryptedThresholdDecryptionRequest::new(self, shared_secret, requester_public_key)
56 }
57}
58
59impl<'a> ProtocolObjectInner<'a> for ThresholdDecryptionRequest {
60 fn version() -> (u16, u16) {
61 (4, 0)
62 }
63
64 fn brand() -> [u8; 4] {
65 *b"ThRq"
66 }
67
68 fn unversioned_to_bytes(&self) -> Box<[u8]> {
69 messagepack_serialize(&self)
70 }
71
72 fn unversioned_from_bytes(minor_version: u16, bytes: &[u8]) -> Option<Result<Self, String>> {
73 if minor_version == 0 {
74 Some(messagepack_deserialize(bytes))
75 } else {
76 None
77 }
78 }
79}
80
81impl<'a> ProtocolObject<'a> for ThresholdDecryptionRequest {}
82
83#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
85pub struct EncryptedThresholdDecryptionRequest {
86 pub ritual_id: u32,
88
89 pub requester_public_key: SessionStaticKey,
91
92 #[serde(with = "serde_bytes::as_base64")]
93 ciphertext: Box<[u8]>,
95}
96
97impl EncryptedThresholdDecryptionRequest {
98 fn new(
99 request: &ThresholdDecryptionRequest,
100 shared_secret: &SessionSharedSecret,
101 requester_public_key: &SessionStaticKey,
102 ) -> Self {
103 let ciphertext = encrypt_with_shared_secret(shared_secret, &request.to_bytes())
104 .expect("encryption failed - out of memory?");
105 Self {
106 ritual_id: request.ritual_id,
107 requester_public_key: *requester_public_key,
108 ciphertext,
109 }
110 }
111
112 pub fn decrypt(
114 &self,
115 shared_secret: &SessionSharedSecret,
116 ) -> Result<ThresholdDecryptionRequest, DecryptionError> {
117 let decryption_request_bytes = decrypt_with_shared_secret(shared_secret, &self.ciphertext)?;
118 let decryption_request = ThresholdDecryptionRequest::from_bytes(&decryption_request_bytes)
119 .map_err(DecryptionError::DeserializationFailed)?;
120 Ok(decryption_request)
121 }
122}
123
124impl<'a> ProtocolObjectInner<'a> for EncryptedThresholdDecryptionRequest {
125 fn version() -> (u16, u16) {
126 (2, 0)
127 }
128
129 fn brand() -> [u8; 4] {
130 *b"ETRq"
131 }
132
133 fn unversioned_to_bytes(&self) -> Box<[u8]> {
134 messagepack_serialize(&self)
135 }
136
137 fn unversioned_from_bytes(minor_version: u16, bytes: &[u8]) -> Option<Result<Self, String>> {
138 if minor_version == 0 {
139 Some(messagepack_deserialize(bytes))
140 } else {
141 None
142 }
143 }
144}
145
146impl<'a> ProtocolObject<'a> for EncryptedThresholdDecryptionRequest {}
147
148#[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Clone)]
150pub struct ThresholdDecryptionResponse {
151 pub ritual_id: u32,
153
154 #[serde(with = "serde_bytes::as_base64")]
156 pub decryption_share: Box<[u8]>,
157}
158
159impl ThresholdDecryptionResponse {
160 pub fn new(ritual_id: u32, decryption_share: &[u8]) -> Self {
162 ThresholdDecryptionResponse {
163 ritual_id,
164 decryption_share: decryption_share.to_vec().into(),
165 }
166 }
167
168 pub fn encrypt(
170 &self,
171 shared_secret: &SessionSharedSecret,
172 ) -> EncryptedThresholdDecryptionResponse {
173 EncryptedThresholdDecryptionResponse::new(self, shared_secret)
174 }
175}
176
177impl<'a> ProtocolObjectInner<'a> for ThresholdDecryptionResponse {
178 fn version() -> (u16, u16) {
179 (2, 0)
180 }
181
182 fn brand() -> [u8; 4] {
183 *b"ThRs"
184 }
185
186 fn unversioned_to_bytes(&self) -> Box<[u8]> {
187 messagepack_serialize(&self)
188 }
189
190 fn unversioned_from_bytes(minor_version: u16, bytes: &[u8]) -> Option<Result<Self, String>> {
191 if minor_version == 0 {
192 Some(messagepack_deserialize(bytes))
193 } else {
194 None
195 }
196 }
197}
198
199impl<'a> ProtocolObject<'a> for ThresholdDecryptionResponse {}
200
201#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
203pub struct EncryptedThresholdDecryptionResponse {
204 pub ritual_id: u32,
206
207 #[serde(with = "serde_bytes::as_base64")]
208 ciphertext: Box<[u8]>,
209}
210
211impl EncryptedThresholdDecryptionResponse {
212 fn new(response: &ThresholdDecryptionResponse, shared_secret: &SessionSharedSecret) -> Self {
213 let ciphertext = encrypt_with_shared_secret(shared_secret, &response.to_bytes())
214 .expect("encryption failed - out of memory?");
215 Self {
216 ritual_id: response.ritual_id,
217 ciphertext,
218 }
219 }
220
221 pub fn decrypt(
223 &self,
224 shared_secret: &SessionSharedSecret,
225 ) -> Result<ThresholdDecryptionResponse, DecryptionError> {
226 let decryption_response_bytes =
227 decrypt_with_shared_secret(shared_secret, &self.ciphertext)?;
228 let decryption_response =
229 ThresholdDecryptionResponse::from_bytes(&decryption_response_bytes)
230 .map_err(DecryptionError::DeserializationFailed)?;
231 Ok(decryption_response)
232 }
233}
234
235impl<'a> ProtocolObjectInner<'a> for EncryptedThresholdDecryptionResponse {
236 fn version() -> (u16, u16) {
237 (2, 0)
238 }
239
240 fn brand() -> [u8; 4] {
241 *b"ETRs"
242 }
243
244 fn unversioned_to_bytes(&self) -> Box<[u8]> {
245 messagepack_serialize(&self)
246 }
247
248 fn unversioned_from_bytes(minor_version: u16, bytes: &[u8]) -> Option<Result<Self, String>> {
249 if minor_version == 0 {
250 Some(messagepack_deserialize(bytes))
251 } else {
252 None
253 }
254 }
255}
256
257impl<'a> ProtocolObject<'a> for EncryptedThresholdDecryptionResponse {}
258
#[cfg(test)]
mod tests {
    use crate::access_control::AccessControlPolicy;
    use crate::conditions::{Conditions, Context};
    use crate::session::key::SessionStaticSecret;
    use crate::test_utils::util::random_dkg_pubkey;
    use crate::versioning::ProtocolObject;
    use crate::{
        AuthenticatedData, EncryptedThresholdDecryptionRequest,
        EncryptedThresholdDecryptionResponse, ThresholdDecryptionRequest,
        ThresholdDecryptionResponse,
    };
    use ferveo::api::{encrypt as ferveo_encrypt, FerveoVariant, SecretBox};

    /// Round-trips a request through encryption and byte serialization for
    /// each Ferveo variant, and checks that an unrelated secret cannot decrypt it.
    #[test]
    fn threshold_decryption_request() {
        for variant in [FerveoVariant::Simple, FerveoVariant::Precomputed] {
            let ritual_id = 0;

            let service_secret = SessionStaticSecret::random();

            let requester_secret = SessionStaticSecret::random();
            let requester_public_key = requester_secret.public_key();

            // Produce a real ciphertext so the request carries a valid header.
            let dkg_pk = random_dkg_pubkey();
            let message = "The Tyranny of Merit".as_bytes().to_vec();
            let aad = "my-add".as_bytes();
            let ciphertext = ferveo_encrypt(SecretBox::new(message), aad, &dkg_pk).unwrap();

            let auth_data = AuthenticatedData::new(&dkg_pk, &Conditions::new("abcd"));

            let authorization = b"self_authorization";
            let acp = AccessControlPolicy::new(&auth_data, authorization);

            let ciphertext_header = ciphertext.header().unwrap();

            let request = ThresholdDecryptionRequest::new(
                ritual_id,
                &ciphertext_header,
                &acp,
                Some(&Context::new("efgh")),
                variant,
            );

            // Requester side: encrypt the request for the service.
            let service_public_key = service_secret.public_key();
            let requester_shared_secret =
                requester_secret.derive_shared_secret(&service_public_key);
            let encrypted_request =
                request.encrypt(&requester_shared_secret, &requester_public_key);

            let encrypted_request_bytes = encrypted_request.to_bytes();
            let encrypted_request_from_bytes =
                EncryptedThresholdDecryptionRequest::from_bytes(&encrypted_request_bytes)
                    .unwrap();

            // The unencrypted fields survive serialization.
            assert_eq!(encrypted_request_from_bytes.ritual_id, ritual_id);
            assert_eq!(
                encrypted_request_from_bytes.requester_public_key,
                requester_public_key
            );

            // Service side: derive the same secret from the requester's public key.
            let service_shared_secret = service_secret
                .derive_shared_secret(&encrypted_request_from_bytes.requester_public_key);
            assert_eq!(
                service_shared_secret.as_bytes(),
                requester_shared_secret.as_bytes()
            );
            let decrypted_request = encrypted_request_from_bytes
                .decrypt(&service_shared_secret)
                .unwrap();
            assert_eq!(decrypted_request, request);

            // A secret derived from an unrelated key must not decrypt the request.
            let random_secret_key = SessionStaticSecret::random();
            let random_shared_secret =
                random_secret_key.derive_shared_secret(&requester_public_key);
            assert!(encrypted_request_from_bytes
                .decrypt(&random_shared_secret)
                .is_err());
        }
    }

    /// Round-trips a response through encryption and byte serialization, and
    /// checks that an unrelated secret cannot decrypt it.
    #[test]
    fn threshold_decryption_response() {
        let ritual_id = 5;

        let service_secret = SessionStaticSecret::random();
        let requester_secret = SessionStaticSecret::random();

        let decryption_share = b"The Tyranny of Merit";

        let response = ThresholdDecryptionResponse::new(ritual_id, decryption_share);

        let requester_public_key = requester_secret.public_key();

        // Service side: encrypt the response for the requester.
        let service_shared_secret = service_secret.derive_shared_secret(&requester_public_key);
        let encrypted_response = response.encrypt(&service_shared_secret);
        assert_eq!(encrypted_response.ritual_id, ritual_id);

        let encrypted_response_bytes = encrypted_response.to_bytes();
        let encrypted_response_from_bytes =
            EncryptedThresholdDecryptionResponse::from_bytes(&encrypted_response_bytes).unwrap();
        // The unencrypted ritual ID survives serialization.
        assert_eq!(encrypted_response_from_bytes.ritual_id, ritual_id);

        // Requester side: derive the same secret and decrypt.
        let service_public_key = service_secret.public_key();
        let requester_shared_secret = requester_secret.derive_shared_secret(&service_public_key);
        assert_eq!(
            requester_shared_secret.as_bytes(),
            service_shared_secret.as_bytes()
        );
        let decrypted_response = encrypted_response_from_bytes
            .decrypt(&requester_shared_secret)
            .unwrap();
        assert_eq!(response, decrypted_response);
        // Check the decrypted copy, not the original, so this assertion is not tautological.
        assert_eq!(decrypted_response.ritual_id, ritual_id);
        assert_eq!(
            response.decryption_share,
            decrypted_response.decryption_share
        );

        // A secret derived from an unrelated key must not decrypt the response.
        let random_secret_key = SessionStaticSecret::random();
        let random_shared_secret = random_secret_key.derive_shared_secret(&requester_public_key);
        assert!(encrypted_response_from_bytes
            .decrypt(&random_shared_secret)
            .is_err());
    }
}