1use std::{collections::HashMap, fmt, io};
2
3use ark_poly::{EvaluationDomain, GeneralEvaluationDomain};
4use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
5use ferveo_common::serialization;
6pub use ferveo_tdec::{
7 api::{
8 prepare_combine_simple, share_combine_precomputed,
9 share_combine_simple, DecryptionSharePrecomputed, Fr, G1Affine,
10 G1Prepared, G2Affine, SecretBox, E,
11 },
12 DomainPoint,
13};
14use generic_array::{
15 typenum::{Unsigned, U48},
16 GenericArray,
17};
18use rand::{thread_rng, RngCore};
19use serde::{de::DeserializeOwned, Deserialize, Serialize};
20use serde_with::serde_as;
21
22#[cfg(feature = "bindings-python")]
23use crate::bindings_python;
24#[cfg(feature = "bindings-wasm")]
25use crate::bindings_wasm;
26pub use crate::EthereumAddress;
27use crate::{
28 do_verify_aggregation, Error, PubliclyVerifiableSS, Result,
29 UpdateTranscript,
30};
31
/// `ferveo_common::PublicKey` instantiated with the engine `E` re-exported by this module.
pub type ValidatorPublicKey = ferveo_common::PublicKey<E>;
/// `ferveo_common::Keypair` instantiated with the engine `E`.
pub type ValidatorKeypair = ferveo_common::Keypair<E>;
/// The crate-level `Validator` instantiated with the engine `E`.
pub type Validator = crate::Validator<E>;
/// A `PubliclyVerifiableSS` instantiated with the engine `E`.
pub type Transcript = PubliclyVerifiableSS<E>;
/// An `UpdateTranscript` instantiated with the engine `E`.
pub type RefreshTranscript = UpdateTranscript<E>;
/// A validator paired with a transcript; the unit consumed by the aggregation
/// and verification functions in this module.
pub type ValidatorMessage = (Validator, Transcript);
38
39pub fn to_bytes<T: CanonicalSerialize>(item: &T) -> Result<Vec<u8>> {
43 let mut writer = Vec::new();
44 item.serialize_compressed(&mut writer)?;
45 Ok(writer)
46}
47
48pub fn from_bytes<T: CanonicalDeserialize>(bytes: &[u8]) -> Result<T> {
49 let mut reader = io::Cursor::new(bytes);
50 let item = T::deserialize_compressed(&mut reader)?;
51 Ok(item)
52}
53
54pub fn encrypt(
55 message: SecretBox<Vec<u8>>,
56 aad: &[u8],
57 public_key: &DkgPublicKey,
58) -> Result<Ciphertext> {
59 let mut rng = thread_rng();
60 let ciphertext =
61 ferveo_tdec::api::encrypt(message, aad, &public_key.0, &mut rng)?;
62 Ok(Ciphertext(ciphertext))
63}
64
65pub fn decrypt_with_shared_secret(
66 ciphertext: &Ciphertext,
67 aad: &[u8],
68 shared_secret: &SharedSecret,
69) -> Result<Vec<u8>> {
70 ferveo_tdec::api::decrypt_with_shared_secret(
71 &ciphertext.0,
72 aad,
73 &shared_secret.0,
74 )
75 .map_err(Error::from)
76}
77
/// Newtype around `ferveo_tdec::api::Ciphertext`, as produced by [`encrypt`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Eq)]
pub struct Ciphertext(ferveo_tdec::api::Ciphertext);
80
81impl Ciphertext {
82 pub fn header(&self) -> Result<CiphertextHeader> {
83 Ok(CiphertextHeader(self.0.header()?))
84 }
85
86 pub fn payload(&self) -> Vec<u8> {
87 self.0.payload()
88 }
89}
90
/// Newtype around `ferveo_tdec::api::CiphertextHeader`; obtained from
/// `Ciphertext::header` and consumed when creating decryption shares.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct CiphertextHeader(ferveo_tdec::api::CiphertextHeader);
93
/// Identifies one of the two decryption flavors exposed by this module.
///
/// The string form of each variant is defined by `FerveoVariant::as_str`.
#[derive(
    PartialEq, Eq, Debug, Serialize, Deserialize, Copy, Clone, PartialOrd,
)]
pub enum FerveoVariant {
    /// Presumably pairs with the `*_simple` share functions — TODO confirm.
    Simple,
    /// Presumably pairs with the `*_precomputed` share functions — TODO confirm.
    Precomputed,
}
104
105impl fmt::Display for FerveoVariant {
106 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
107 write!(f, "{}", self.as_str())
108 }
109}
110
111impl FerveoVariant {
112 pub fn as_str(&self) -> &'static str {
113 match self {
114 FerveoVariant::Simple => "FerveoVariant::Simple",
115 FerveoVariant::Precomputed => "FerveoVariant::Precomputed",
116 }
117 }
118
119 pub fn from_string(s: &str) -> Result<Self> {
120 match s {
121 "FerveoVariant::Simple" => Ok(FerveoVariant::Simple),
122 "FerveoVariant::Precomputed" => Ok(FerveoVariant::Precomputed),
123 _ => Err(Error::InvalidVariant(s.to_string())),
124 }
125 }
126}
127
#[cfg(feature = "bindings-python")]
impl From<bindings_python::FerveoVariant> for FerveoVariant {
    /// Unwraps the Python binding wrapper into the native variant.
    fn from(variant: bindings_python::FerveoVariant) -> Self {
        let native: Self = variant.0;
        native
    }
}
134
#[cfg(feature = "bindings-wasm")]
impl From<bindings_wasm::FerveoVariant> for FerveoVariant {
    /// Unwraps the WASM binding wrapper into the native variant.
    fn from(variant: bindings_wasm::FerveoVariant) -> Self {
        let native: Self = variant.0;
        native
    }
}
141
/// Newtype around `ferveo_tdec::DkgPublicKey<E>`.
///
/// The explicit serde bounds tie (de)serialization of this wrapper to the
/// (de)serializability of the wrapped key type.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct DkgPublicKey(
    #[serde(bound(
        serialize = "ferveo_tdec::DkgPublicKey<E>: Serialize",
        deserialize = "ferveo_tdec::DkgPublicKey<E>: DeserializeOwned"
    ))]
    pub(crate) ferveo_tdec::DkgPublicKey<E>,
);
150
151impl DkgPublicKey {
153 pub fn to_bytes(&self) -> Result<GenericArray<u8, U48>> {
154 let as_bytes = to_bytes(&self.0 .0)?;
155 Ok(GenericArray::<u8, U48>::from_slice(&as_bytes).to_owned())
156 }
157
158 pub fn from_bytes(bytes: &[u8]) -> Result<DkgPublicKey> {
159 let bytes =
160 GenericArray::<u8, U48>::from_exact_iter(bytes.iter().cloned())
161 .ok_or_else(|| {
162 Error::InvalidByteLength(
163 Self::serialized_size(),
164 bytes.len(),
165 )
166 })?;
167 let pk: G1Affine = from_bytes(&bytes)?;
168 Ok(DkgPublicKey(ferveo_tdec::DkgPublicKey(pk)))
169 }
170
171 pub fn serialized_size() -> usize {
172 U48::to_usize()
173 }
174}
175
/// Newtype around a scalar field element (`Fr`), serialized through
/// `serialization::SerdeAs`.
#[serde_as]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct FieldPoint(#[serde_as(as = "serialization::SerdeAs")] pub Fr);
180
181impl FieldPoint {
182 pub fn to_bytes(&self) -> Result<Vec<u8>> {
183 to_bytes(&self.0)
184 }
185
186 pub fn from_bytes(bytes: &[u8]) -> Result<FieldPoint> {
187 from_bytes(bytes).map(FieldPoint)
188 }
189}
190
/// Newtype around `crate::PubliclyVerifiableDkg<E>`; the entry point for
/// generating and aggregating transcripts in this module.
#[derive(Clone)]
pub struct Dkg(crate::PubliclyVerifiableDkg<E>);
193
194impl Dkg {
195 pub fn new(
196 tau: u32,
197 shares_num: u32,
198 security_threshold: u32,
199 validators: &[Validator],
200 me: &Validator,
201 ) -> Result<Self> {
202 let dkg_params =
203 crate::DkgParams::new(tau, security_threshold, shares_num)?;
204 let dkg = crate::PubliclyVerifiableDkg::<E>::new(
205 validators,
206 &dkg_params,
207 me,
208 )?;
209 Ok(Self(dkg))
210 }
211
212 pub fn generate_transcript<R: RngCore>(
213 &mut self,
214 rng: &mut R,
215 ) -> Result<Transcript> {
216 self.0.generate_transcript(rng)
217 }
218
219 pub fn aggregate_transcripts(
220 &self,
221 messages: &[ValidatorMessage],
222 ) -> Result<AggregatedTranscript> {
223 self.0
224 .aggregate_transcripts(messages)
225 .map(AggregatedTranscript)
226 }
227
228 pub fn generate_refresh_transcript<R: RngCore>(
229 &self,
230 rng: &mut R,
231 ) -> Result<RefreshTranscript> {
232 self.0.generate_refresh_transcript(rng)
233 }
234
235 pub fn generate_handover_transcript<R: RngCore>(
236 &self,
237 aggregate: &AggregatedTranscript,
238 handover_slot_index: u32,
239 incoming_validator_keypair: &ferveo_common::Keypair<E>,
240 rng: &mut R,
241 ) -> Result<HandoverTranscript> {
242 self.0
243 .generate_handover_transcript(
244 &aggregate.0,
245 handover_slot_index,
246 incoming_validator_keypair,
247 rng,
248 )
249 .map(HandoverTranscript)
250 }
251
252 pub fn me(&self) -> &Validator {
253 &self.0.me
254 }
255
256 pub fn domain_points(&self) -> Vec<DomainPoint<E>> {
257 self.0.domain_points()
258 }
259}
260
/// Newtype around `crate::AggregatedTranscript<E>`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct AggregatedTranscript(crate::AggregatedTranscript<E>);
263
/// Newtype around `crate::HandoverTranscript<E>`; produced by
/// `Dkg::generate_handover_transcript`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct HandoverTranscript(crate::HandoverTranscript<E>);
266
267impl AggregatedTranscript {
268 pub fn new(messages: &[ValidatorMessage]) -> Result<Self> {
269 let transcripts: Vec<_> = messages
270 .iter()
271 .map(|(_, transcript)| transcript.clone())
272 .collect();
273 let aggregated_transcript =
274 crate::AggregatedTranscript::<E>::from_transcripts(&transcripts)?;
275 Ok(AggregatedTranscript(aggregated_transcript))
276 }
277
278 pub fn verify(
279 &self,
280 validators_num: u32,
281 messages: &[ValidatorMessage],
282 ) -> Result<bool> {
283 if validators_num < messages.len() as u32 {
284 return Err(Error::InvalidAggregateVerificationParameters(
285 validators_num,
286 messages.len() as u32,
287 ));
288 }
289
290 let domain =
291 GeneralEvaluationDomain::<Fr>::new(validators_num as usize)
292 .expect("Unable to construct an evaluation domain");
293 let is_valid_optimistic = self.0.aggregate.verify_optimistic();
294 if !is_valid_optimistic {
295 return Err(Error::InvalidTranscriptAggregate);
296 }
297
298 let validators: Vec<_> = messages
299 .iter()
300 .map(|(validator, _)| validator)
301 .cloned()
302 .collect();
303 let pvss_list = messages
304 .iter()
305 .map(|(_validator, transcript)| transcript)
306 .cloned()
307 .collect::<Vec<_>>();
308 do_verify_aggregation(
310 &self.0.aggregate.coeffs,
311 &self.0.aggregate.shares,
312 &validators,
313 &domain,
314 &pvss_list,
315 )
316 }
317
318 pub fn create_decryption_share_precomputed(
319 &self,
320 dkg: &Dkg,
321 ciphertext_header: &CiphertextHeader,
322 aad: &[u8],
323 validator_keypair: &ValidatorKeypair,
324 selected_validators: &[Validator],
325 ) -> Result<DecryptionSharePrecomputed> {
326 let selected_domain_points = selected_validators
327 .iter()
328 .filter_map(|v| {
329 dkg.0
330 .get_domain_point(v.share_index)
331 .ok()
332 .map(|domain_point| (v.share_index, domain_point))
333 })
334 .collect::<HashMap<u32, ferveo_tdec::DomainPoint<E>>>();
335 self.0.aggregate.create_decryption_share_precomputed(
336 &ciphertext_header.0,
337 aad,
338 validator_keypair,
339 dkg.0.me.share_index,
340 &selected_domain_points,
341 )
342 }
343
344 pub fn create_decryption_share_simple(
345 &self,
346 dkg: &Dkg,
347 ciphertext_header: &CiphertextHeader,
348 aad: &[u8],
349 validator_keypair: &ValidatorKeypair,
350 ) -> Result<DecryptionShareSimple> {
351 let share = self.0.aggregate.create_decryption_share_simple(
352 &ciphertext_header.0,
353 aad,
354 validator_keypair,
355 dkg.0.me.share_index,
356 )?;
357 let domain_point = dkg.0.get_domain_point(dkg.0.me.share_index)?;
358 Ok(DecryptionShareSimple {
359 share,
360 domain_point,
361 })
362 }
363
364 pub fn public_key(&self) -> DkgPublicKey {
365 DkgPublicKey(self.0.public_key)
366 }
367
368 pub fn refresh(
369 &self,
370 update_transcripts: &HashMap<u32, RefreshTranscript>,
371 validator_keys_map: &HashMap<u32, ValidatorPublicKey>,
372 ) -> Result<Self> {
373 let updated_aggregate = self
375 .0
376 .aggregate
377 .refresh(update_transcripts, validator_keys_map)
378 .unwrap();
379 let eeww =
380 crate::AggregatedTranscript::<E>::from_aggregate(updated_aggregate)
381 .unwrap();
382 Ok(AggregatedTranscript(eeww))
383 }
384
385 pub fn finalize_handover(
386 &self,
387 handover_transcript: &HandoverTranscript,
388 validator_keypair: &ValidatorKeypair,
389 ) -> Result<Self> {
390 let new_aggregate = self
391 .0
392 .aggregate
393 .finalize_handover(&handover_transcript.0, validator_keypair)
394 .unwrap();
395 let eeww =
397 crate::AggregatedTranscript::<E>::from_aggregate(new_aggregate)
398 .unwrap();
399 Ok(AggregatedTranscript(eeww))
400 }
401}
402
/// A simple decryption share together with the domain point of the validator
/// that produced it (see `AggregatedTranscript::create_decryption_share_simple`).
#[serde_as]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct DecryptionShareSimple {
    /// The underlying `ferveo_tdec` decryption share.
    share: ferveo_tdec::api::DecryptionShareSimple,
    /// Domain point of the share's owner; serialized via `serialization::SerdeAs`.
    #[serde_as(as = "serialization::SerdeAs")]
    domain_point: DomainPoint<E>,
}
410
411pub fn combine_shares_simple(shares: &[DecryptionShareSimple]) -> SharedSecret {
412 let domain_points: Vec<_> = shares.iter().map(|s| s.domain_point).collect();
413 let lagrange_coefficients = prepare_combine_simple::<E>(&domain_points);
414
415 let shares: Vec<_> = shares.iter().cloned().map(|s| s.share).collect();
416 let shared_secret =
417 share_combine_simple(&shares, &lagrange_coefficients[..]);
418 SharedSecret(shared_secret)
419}
420
/// Newtype around `ferveo_tdec::api::SharedSecret<E>`; the input expected by
/// [`decrypt_with_shared_secret`].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct SharedSecret(pub ferveo_tdec::api::SharedSecret<E>);
423
424#[cfg(test)]
425mod test_ferveo_api {
426
427 use ark_std::{iterable::Iterable, UniformRand};
428 use ferveo_tdec::SecretBox;
429 use itertools::{izip, Itertools};
430 use rand::{
431 prelude::{SliceRandom, StdRng},
432 SeedableRng,
433 };
434 use test_case::test_case;
435
436 use crate::{
437 api::*,
438 test_common::{gen_address, gen_keypairs, AAD, MSG, TAU},
439 };
440
// Bundle returned by `make_test_inputs`: the validator messages, the
// validators, and the validators' keypairs, in that order.
type TestInputs =
    (Vec<ValidatorMessage>, Vec<Validator>, Vec<ValidatorKeypair>);
443
444 fn make_test_inputs(
446 rng: &mut StdRng,
447 tau: u32,
448 security_threshold: u32,
449 shares_num: u32,
450 validators_num: u32,
451 ) -> TestInputs {
452 let validator_keypairs = gen_keypairs(validators_num);
453 let validators = validator_keypairs
454 .iter()
455 .enumerate()
456 .map(|(i, keypair)| Validator {
457 address: gen_address(i),
458 public_key: keypair.public_key(),
459 share_index: i as u32,
460 })
461 .collect::<Vec<_>>();
462
463 let mut messages: Vec<_> = validators
466 .iter()
467 .map(|sender| {
468 let dkg = Dkg::new(
469 tau,
470 shares_num,
471 security_threshold,
472 &validators,
473 sender,
474 )
475 .unwrap();
476 (sender.clone(), dkg.0.generate_transcript(rng).unwrap())
477 })
478 .collect();
479 messages.shuffle(rng);
480 (messages, validators, validator_keypairs)
481 }
482
483 fn random_dkg_public_key() -> DkgPublicKey {
484 let mut rng = thread_rng();
485 let g1 = G1Affine::rand(&mut rng);
486 DkgPublicKey(ferveo_tdec::DkgPublicKey(g1))
487 }
488
// A public key must survive a `to_bytes` / `from_bytes` round trip and encode
// to exactly 48 bytes.
#[test]
fn test_dkg_pk_serialization() {
    let original_key = random_dkg_public_key();
    let encoded = original_key.to_bytes().unwrap();
    let decoded = DkgPublicKey::from_bytes(encoded.as_slice()).unwrap();
    assert_eq!(encoded.len(), 48_usize);
    assert_eq!(original_key, decoded);
}
497
// End-to-end check of the precomputed variant: aggregate and verify the
// transcripts, encrypt under the resulting public key, collect
// `security_threshold` precomputed decryption shares, and decrypt. A second
// attempt with one share fewer must fail to decrypt.
#[test_case(4, 3; "N is a power of 2, t is 1 + 50%")]
#[test_case(4, 4; "N is a power of 2, t=N")]
#[test_case(30, 16; "N is not a power of 2, t is 1 + 50%")]
#[test_case(30, 30; "N is not a power of 2, t=N")]
fn test_server_api_tdec_precomputed(
    shares_num: u32,
    security_threshold: u32,
) {
    let validators_num = shares_num;
    // Fixed seed: every RNG draw below happens in a deterministic order.
    let rng = &mut StdRng::seed_from_u64(0);
    let (messages, validators, validator_keypairs) = make_test_inputs(
        rng,
        TAU,
        security_threshold,
        shares_num,
        validators_num,
    );
    let messages = &messages[..shares_num as usize];

    // Aggregate and verify from the point of view of the first validator.
    let me = validators[0].clone();
    let dkg =
        Dkg::new(TAU, shares_num, security_threshold, &validators, &me)
            .unwrap();
    let local_aggregate = dkg.aggregate_transcripts(messages).unwrap();
    assert!(local_aggregate.verify(validators_num, messages).unwrap());

    let dkg_public_key = local_aggregate.public_key();

    let ciphertext =
        encrypt(SecretBox::new(MSG.to_vec()), AAD, &dkg_public_key)
            .unwrap();

    // Pick `security_threshold` validators at random to produce shares.
    let selected_validators: Vec<_> = validators
        .choose_multiple(rng, security_threshold as usize)
        .cloned()
        .collect();

    let mut decryption_shares = selected_validators
        .iter()
        .map(|validator| {
            // Look up the keypair that matches the selected validator.
            let validator_keypair = validator_keypairs
                .iter()
                .find(|kp| kp.public_key() == validator.public_key)
                .unwrap();
            // Each validator rebuilds and verifies the aggregate on its own.
            let dkg = Dkg::new(
                TAU,
                shares_num,
                security_threshold,
                &validators,
                validator,
            )
            .unwrap();
            let server_aggregate =
                dkg.aggregate_transcripts(messages).unwrap();
            assert!(server_aggregate
                .verify(validators_num, messages)
                .unwrap());

            server_aggregate
                .create_decryption_share_precomputed(
                    &dkg,
                    &ciphertext.header().unwrap(),
                    AAD,
                    validator_keypair,
                    &selected_validators,
                )
                .unwrap()
        })
        .take(security_threshold as usize)
        .collect::<Vec<DecryptionSharePrecomputed>>();
    // The order of the shares must not matter for combination.
    decryption_shares.shuffle(rng);

    let shared_secret = share_combine_precomputed(&decryption_shares);
    let plaintext = decrypt_with_shared_secret(
        &ciphertext,
        AAD,
        &SharedSecret(shared_secret),
    )
    .unwrap();
    assert_eq!(plaintext, MSG);

    // One share below the threshold: decryption is expected to fail.
    let decryption_shares = decryption_shares
        .iter()
        .take(security_threshold as usize - 1)
        .cloned()
        .collect::<Vec<_>>();
    let shared_secret = share_combine_precomputed(&decryption_shares);
    let result = decrypt_with_shared_secret(
        &ciphertext,
        AAD,
        &SharedSecret(shared_secret),
    );
    assert!(result.is_err());
}
606
// End-to-end check of the simple variant: aggregate on the client side,
// encrypt, let validators create simple decryption shares, and decrypt with
// exactly `security_threshold` shares. With one share fewer, decryption must
// fail.
#[test_case(4, 3; "N is a power of 2, t is 1 + 50%")]
#[test_case(4, 4; "N is a power of 2, t=N")]
#[test_case(30, 16; "N is not a power of 2, t is 1 + 50%")]
#[test_case(30, 30; "N is not a power of 2, t=N")]
fn test_server_api_tdec_simple(shares_num: u32, security_threshold: u32) {
    // Fixed seed: every RNG draw below happens in a deterministic order.
    let rng = &mut StdRng::seed_from_u64(0);
    let validators_num: u32 = shares_num;
    let (messages, validators, validator_keypairs) = make_test_inputs(
        rng,
        TAU,
        security_threshold,
        shares_num,
        validators_num,
    );
    let messages = &messages[..shares_num as usize];

    // Client-side aggregation: no `Dkg` instance is involved here.
    let local_aggregate = AggregatedTranscript::new(messages).unwrap();
    assert!(local_aggregate.verify(validators_num, messages).unwrap());

    let public_key = local_aggregate.public_key();

    let ciphertext =
        encrypt(SecretBox::new(MSG.to_vec()), AAD, &public_key).unwrap();

    let mut decryption_shares: Vec<_> =
        izip!(&validators, &validator_keypairs)
            .map(|(validator, validator_keypair)| {
                // Each validator rebuilds and verifies the aggregate itself.
                let dkg = Dkg::new(
                    TAU,
                    shares_num,
                    security_threshold,
                    &validators,
                    validator,
                )
                .unwrap();
                let server_aggregate =
                    dkg.aggregate_transcripts(messages).unwrap();
                assert!(server_aggregate
                    .verify(validators_num, messages)
                    .unwrap());
                server_aggregate
                    .create_decryption_share_simple(
                        &dkg,
                        &ciphertext.header().unwrap(),
                        AAD,
                        validator_keypair,
                    )
                    .unwrap()
            })
            .take(security_threshold as usize)
            .collect();
    // The order of the shares must not matter for combination.
    decryption_shares.shuffle(rng);

    let decryption_shares =
        decryption_shares[..security_threshold as usize].to_vec();

    let shared_secret = combine_shares_simple(&decryption_shares);
    let plaintext =
        decrypt_with_shared_secret(&ciphertext, AAD, &shared_secret)
            .unwrap();
    assert_eq!(plaintext, MSG);

    // One share below the threshold: decryption is expected to fail.
    let decryption_shares =
        decryption_shares[..security_threshold as usize - 1].to_vec();

    let shared_secret = combine_shares_simple(&decryption_shares);
    let result =
        decrypt_with_shared_secret(&ciphertext, AAD, &shared_secret);
    assert!(result.is_err());
}
689
690 #[test_case(4, 3; "N is a power of 2, t is 1 + 50%")]
694 #[test_case(4, 4; "N is a power of 2, t=N")]
695 #[test_case(30, 16; "N is not a power of 2, t is 1 + 50%")]
696 #[test_case(30, 30; "N is not a power of 2, t=N")]
697 fn server_side_local_verification(
698 shares_num: u32,
699 security_threshold: u32,
700 ) {
701 let rng = &mut StdRng::seed_from_u64(0);
702 let validators_num: u32 = shares_num; let (messages, validators, _) = make_test_inputs(
704 rng,
705 TAU,
706 security_threshold,
707 shares_num,
708 validators_num,
709 );
710 let messages = &messages[..shares_num as usize];
712
713 let me = validators[0].clone();
716 let dkg =
717 Dkg::new(TAU, shares_num, security_threshold, &validators, &me)
718 .unwrap();
719 let good_aggregate = dkg.aggregate_transcripts(messages).unwrap();
720 assert!(good_aggregate.verify(validators_num, messages).is_ok());
721
722 assert!(matches!(
729 good_aggregate.verify(messages.len() as u32 - 1, messages),
730 Err(Error::InvalidAggregateVerificationParameters(_, _))
731 ));
732
733 let dkg =
735 Dkg::new(TAU, shares_num, security_threshold, &validators, &me)
736 .unwrap();
737 assert!(matches!(
738 dkg.aggregate_transcripts(&[]),
739 Err(Error::NoTranscriptsToAggregate)
740 ));
741
742 let dkg =
744 Dkg::new(TAU, shares_num, security_threshold, &validators, &me)
745 .unwrap();
746 let not_enough_messages = &messages[..security_threshold as usize - 1];
747 assert!(not_enough_messages.len() < security_threshold as usize);
748 let insufficient_aggregate =
749 dkg.aggregate_transcripts(not_enough_messages).unwrap();
750 assert!(matches!(
751 insufficient_aggregate.verify(validators_num, messages),
752 Err(Error::InvalidTranscriptAggregate)
753 ));
754
755 let messages_with_duplicated_transcript = [
757 (
758 validators[security_threshold as usize - 1].clone(),
759 messages[security_threshold as usize - 1].1.clone(),
760 ),
761 (
762 validators[security_threshold as usize - 1].clone(),
763 messages[security_threshold as usize - 2].1.clone(),
764 ),
765 ];
766 assert!(dkg
767 .aggregate_transcripts(&messages_with_duplicated_transcript)
768 .is_err());
769
770 let messages_with_duplicated_transcript = [
771 (
772 validators[security_threshold as usize - 1].clone(),
773 messages[security_threshold as usize - 1].1.clone(),
774 ),
775 (
776 validators[security_threshold as usize - 2].clone(),
777 messages[security_threshold as usize - 1].1.clone(),
778 ),
779 ];
780 assert!(dkg
781 .aggregate_transcripts(&messages_with_duplicated_transcript)
782 .is_err());
783
784 let mut dkg =
787 Dkg::new(TAU, shares_num, security_threshold, &validators, &me)
788 .unwrap();
789 let bad_message = (
790 messages[security_threshold as usize - 1].0.clone(),
792 dkg.generate_transcript(rng).unwrap(),
793 );
794 let mixed_messages = [
795 &messages[..(security_threshold - 1) as usize],
796 &[bad_message],
797 ]
798 .concat();
799 assert_eq!(mixed_messages.len(), security_threshold as usize);
800 let bad_aggregate = dkg.aggregate_transcripts(&mixed_messages).unwrap();
801 assert!(matches!(
802 bad_aggregate.verify(validators_num, messages),
803 Err(Error::InvalidTranscriptAggregate)
804 ));
805 }
806
807 #[test_case(4, 3; "N is a power of 2, t is 1 + 50%")]
808 #[test_case(4, 4; "N is a power of 2, t=N")]
809 #[test_case(30, 16; "N is not a power of 2, t is 1 + 50%")]
810 #[test_case(30, 30; "N is not a power of 2, t=N")]
811 fn client_side_local_verification(
812 shares_num: u32,
813 security_threshold: u32,
814 ) {
815 let rng = &mut StdRng::seed_from_u64(0);
816 let validators_num: u32 = shares_num; let (messages, _, _) = make_test_inputs(
818 rng,
819 TAU,
820 security_threshold,
821 shares_num,
822 validators_num,
823 );
824
825 let messages = &messages[..shares_num as usize];
827
828 let good_aggregate = AggregatedTranscript::new(messages).unwrap();
830
831 let result = good_aggregate.verify(validators_num, messages);
836 assert!(result.is_ok());
837 assert!(result.unwrap());
838
839 assert!(matches!(
843 good_aggregate.verify(messages.len() as u32 - 1, messages),
844 Err(Error::InvalidAggregateVerificationParameters(_, _))
845 ));
846
847 assert!(matches!(
849 AggregatedTranscript::new(&[]),
850 Err(Error::NoTranscriptsToAggregate)
851 ));
852
853 let not_enough_messages = &messages[..security_threshold as usize - 1];
855 assert!(not_enough_messages.len() < security_threshold as usize);
856 let insufficient_aggregate =
857 AggregatedTranscript::new(not_enough_messages).unwrap();
858 let _result = insufficient_aggregate.verify(validators_num, messages);
859 assert!(matches!(
860 insufficient_aggregate.verify(validators_num, messages),
861 Err(Error::InvalidTranscriptAggregate)
862 ));
863
864 let (bad_messages, _, _) = make_test_inputs(
867 rng,
868 TAU,
869 security_threshold,
870 shares_num,
871 validators_num,
872 );
873 let mixed_messages = [&messages[..2], &bad_messages[..1]].concat();
874 let bad_aggregate = AggregatedTranscript::new(&mixed_messages).unwrap();
875 assert!(matches!(
876 bad_aggregate.verify(validators_num, messages),
877 Err(Error::InvalidTranscriptAggregate)
878 ));
879 }
880
// Shared setup for the share-update tests. Returns, in order: the validator
// messages, the validators, their keypairs, one `Dkg` per validator, the
// header of a ciphertext encrypted under the aggregate's public key, and the
// shared secret computed by `create_shared_secret_simple_tdec` before any
// update takes place.
fn make_share_update_test_inputs(
    shares_num: u32,
    validators_num: u32,
    rng: &mut StdRng,
    security_threshold: u32,
) -> (
    Vec<ValidatorMessage>,
    Vec<Validator>,
    Vec<ValidatorKeypair>,
    Vec<Dkg>,
    CiphertextHeader,
    SharedSecret,
) {
    let (messages, validators, validator_keypairs) = make_test_inputs(
        rng,
        TAU,
        security_threshold,
        shares_num,
        validators_num,
    );
    // One DKG instance per validator, in validator order.
    let dkgs = validators
        .iter()
        .map(|validator| {
            Dkg::new(
                TAU,
                shares_num,
                security_threshold,
                &validators,
                validator,
            )
            .unwrap()
        })
        .collect::<Vec<_>>();

    // Aggregate and verify from the first validator's point of view.
    let dkg = dkgs[0].clone();
    let server_aggregate =
        dkg.aggregate_transcripts(messages.as_slice()).unwrap();
    assert!(server_aggregate
        .verify(validators_num, messages.as_slice())
        .unwrap());

    let public_key = server_aggregate.public_key();
    let ciphertext =
        encrypt(SecretBox::new(MSG.to_vec()), AAD, &public_key).unwrap();
    let ciphertext_header = ciphertext.header().unwrap();
    let transcripts = messages
        .iter()
        .map(|(_, transcript)| transcript)
        .cloned()
        .collect::<Vec<_>>();
    // Only the third element of the helper's result is needed here.
    let (_, _, old_shared_secret) =
        crate::test_dkg_full::create_shared_secret_simple_tdec(
            &dkg.0,
            AAD,
            &ciphertext_header.0,
            validator_keypairs.as_slice(),
            &transcripts,
        );
    (
        messages,
        validators,
        validator_keypairs,
        dkgs,
        ciphertext_header,
        SharedSecret(old_shared_secret),
    )
}
951
// Currently ignored (see the `#[ignore]` reason). The test removes the last
// validator from every local collection, creates decryption shares for the
// remaining validators, and combines `security_threshold` of them.
//
// NOTE(review): the final check is `assert_ne!` although its message reads
// "Shared secret reconstruction failed" — confirm whether inequality is the
// intended expectation while the recovery steps are absent.
#[ignore = "Re-introduce recovery tests - #193"]
#[test_case(4, 4, true; "number of shares (validators) is a power of 2")]
#[test_case(7, 7, true; "number of shares (validators) is not a power of 2")]
#[test_case(4, 6, true; "number of validators greater than the number of shares")]
#[test_case(4, 6, false; "recovery at a specific point")]
fn test_dkg_simple_tdec_share_recovery(
    shares_num: u32,
    validators_num: u32,
    _recover_at_random_point: bool,
) {
    // Fixed seed: every RNG draw below happens in a deterministic order.
    let rng = &mut StdRng::seed_from_u64(0);
    let security_threshold = shares_num / 2 + 1;
    let (
        mut messages,
        mut validators,
        mut validator_keypairs,
        mut dkgs,
        ciphertext_header,
        old_shared_secret,
    ) = make_share_update_test_inputs(
        shares_num,
        validators_num,
        rng,
        security_threshold,
    );

    // Aggregate and verify before anything is removed.
    let aggregated_transcript = dkgs[0]
        .clone()
        .aggregate_transcripts(messages.as_slice())
        .unwrap();
    assert!(aggregated_transcript
        .verify(validators_num, messages.as_slice())
        .unwrap());

    // Drop the last validator's domain point from the map ...
    let mut domain_points = dkgs[0].0.domain_point_map();
    let _removed_domain_point = domain_points
        .remove(&validators.last().unwrap().share_index)
        .unwrap();

    // ... and the last entry from each of the parallel collections.
    messages.pop().unwrap();
    dkgs.pop();
    validator_keypairs.pop().unwrap();
    let _removed_validator = validators.pop().unwrap();

    // Remaining validators create shares from the original aggregate.
    let mut decryption_shares: Vec<DecryptionShareSimple> =
        validator_keypairs
            .iter()
            .zip_eq(dkgs.iter())
            .map(|(validator_keypair, validator_dkg)| {
                aggregated_transcript
                    .create_decryption_share_simple(
                        validator_dkg,
                        &ciphertext_header,
                        AAD,
                        validator_keypair,
                    )
                    .unwrap()
            })
            .collect();
    decryption_shares.shuffle(rng);

    let domain_points = domain_points
        .values()
        .take(security_threshold as usize)
        .cloned()
        .collect::<Vec<_>>();
    let decryption_shares =
        &decryption_shares[..security_threshold as usize];
    assert_eq!(domain_points.len(), security_threshold as usize);
    assert_eq!(decryption_shares.len(), security_threshold as usize);

    let new_shared_secret = combine_shares_simple(decryption_shares);
    assert_ne!(
        old_shared_secret, new_shared_secret,
        "Shared secret reconstruction failed"
    );
}
1141
// Refreshes every validator's aggregate with a set of refresh transcripts and
// checks that combining `security_threshold` decryption shares created from
// the refreshed aggregates yields the same shared secret as before.
#[test_case(4, 3; "N is a power of 2, t is 1 + 50%")]
#[test_case(4, 4; "N is a power of 2, t=N")]
#[test_case(30, 16; "N is not a power of 2, t is 1 + 50%")]
#[test_case(30, 30; "N is not a power of 2, t=N")]
fn test_dkg_api_simple_tdec_share_refresh(
    shares_num: u32,
    security_threshold: u32,
) {
    // Fixed seed: every RNG draw below happens in a deterministic order.
    let rng = &mut StdRng::seed_from_u64(0);
    let validators_num: u32 = shares_num;
    let (
        messages,
        _validators,
        validator_keypairs,
        dkgs,
        ciphertext_header,
        old_shared_secret,
    ) = make_share_update_test_inputs(
        shares_num,
        validators_num,
        rng,
        security_threshold,
    );

    // Both maps are keyed by share index.
    let mut update_transcripts: HashMap<u32, RefreshTranscript> =
        HashMap::new();
    let mut validator_map: HashMap<u32, _> = HashMap::new();

    // NOTE(review): the inner loop inserts under every validator's share
    // index for each `dkg`, so later iterations of the outer loop overwrite
    // earlier entries — confirm this is intended.
    for dkg in &dkgs {
        for validator in dkg.0.validators.values() {
            update_transcripts.insert(
                validator.share_index,
                dkg.generate_refresh_transcript(rng).unwrap(),
            );
            validator_map.insert(
                validator.share_index,
                validator_keypairs
                    .get(validator.share_index as usize)
                    .unwrap()
                    .public_key(),
            );
        }
    }

    // Each validator aggregates, verifies, and then refreshes its aggregate.
    let refreshed_aggregates: Vec<AggregatedTranscript> = dkgs
        .iter()
        .map(|validator_dkg| {
            let aggregate = validator_dkg
                .clone()
                .aggregate_transcripts(messages.as_slice())
                .unwrap();
            assert!(aggregate
                .verify(validators_num, messages.as_slice())
                .unwrap());

            aggregate
                .refresh(&update_transcripts, &validator_map)
                .unwrap()
        })
        .collect();

    // Shares are created from the refreshed aggregate at the position given
    // by the validator's own share index.
    let mut decryption_shares: Vec<DecryptionShareSimple> =
        validator_keypairs
            .iter()
            .zip_eq(dkgs.iter())
            .map(|(validator_keypair, validator_dkg)| {
                let validator_index =
                    validator_dkg.me().share_index as usize;

                let aggregate =
                    refreshed_aggregates.get(validator_index).unwrap();

                aggregate
                    .create_decryption_share_simple(
                        validator_dkg,
                        &ciphertext_header,
                        AAD,
                        validator_keypair,
                    )
                    .unwrap()
            })
            .take(security_threshold as usize)
            .collect();
    decryption_shares.shuffle(rng);

    let decryption_shares =
        &decryption_shares[..security_threshold as usize];
    assert_eq!(decryption_shares.len(), security_threshold as usize);

    // The refresh must leave the reconstructed shared secret unchanged.
    let new_shared_secret = combine_shares_simple(decryption_shares);
    assert_eq!(
        old_shared_secret, new_shared_secret,
        "Shared secret reconstruction failed"
    );
}
1252}