Skip to main content

ferveo_nucypher/
api.rs

1use std::{collections::HashMap, fmt, io};
2
3use ark_poly::{EvaluationDomain, GeneralEvaluationDomain};
4use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
5use ferveo_common::serialization;
6pub use ferveo_tdec::{
7    api::{
8        prepare_combine_simple, share_combine_precomputed,
9        share_combine_simple, DecryptionSharePrecomputed, Fr, G1Affine,
10        G1Prepared, G2Affine, SecretBox, E,
11    },
12    DomainPoint,
13};
14use generic_array::{
15    typenum::{Unsigned, U48},
16    GenericArray,
17};
18use rand::{thread_rng, RngCore};
19use serde::{de::DeserializeOwned, Deserialize, Serialize};
20use serde_with::serde_as;
21
22#[cfg(feature = "bindings-python")]
23use crate::bindings_python;
24#[cfg(feature = "bindings-wasm")]
25use crate::bindings_wasm;
26pub use crate::EthereumAddress;
27use crate::{
28    do_verify_aggregation, Error, PubliclyVerifiableSS, Result,
29    UpdateTranscript,
30};
31
/// `ferveo_common::PublicKey` specialized to the `E` type re-exported from `ferveo_tdec::api`.
pub type ValidatorPublicKey = ferveo_common::PublicKey<E>;
/// `ferveo_common::Keypair` specialized to `E`.
pub type ValidatorKeypair = ferveo_common::Keypair<E>;
/// The crate-level `Validator` specialized to `E`.
pub type Validator = crate::Validator<E>;
/// The value returned by [`Dkg::generate_transcript`]; alias for `PubliclyVerifiableSS<E>`.
pub type Transcript = PubliclyVerifiableSS<E>;
/// The value returned by [`Dkg::generate_refresh_transcript`]; alias for `UpdateTranscript<E>`.
pub type RefreshTranscript = UpdateTranscript<E>; // TODO: Consider renaming to UpdateTranscript when dealing with #193
/// A validator paired with a transcript; this is the unit consumed by the
/// aggregation and verification functions in this module.
pub type ValidatorMessage = (Validator, Transcript);
38
39// Normally, we would use a custom trait for this, but we can't because
40// the `arkworks` will not let us create a blanket implementation for G1Affine
41// and `Fr` types. So instead, we're using this shared utility function:
42pub fn to_bytes<T: CanonicalSerialize>(item: &T) -> Result<Vec<u8>> {
43    let mut writer = Vec::new();
44    item.serialize_compressed(&mut writer)?;
45    Ok(writer)
46}
47
48pub fn from_bytes<T: CanonicalDeserialize>(bytes: &[u8]) -> Result<T> {
49    let mut reader = io::Cursor::new(bytes);
50    let item = T::deserialize_compressed(&mut reader)?;
51    Ok(item)
52}
53
54pub fn encrypt(
55    message: SecretBox<Vec<u8>>,
56    aad: &[u8],
57    public_key: &DkgPublicKey,
58) -> Result<Ciphertext> {
59    let mut rng = thread_rng();
60    let ciphertext =
61        ferveo_tdec::api::encrypt(message, aad, &public_key.0, &mut rng)?;
62    Ok(Ciphertext(ciphertext))
63}
64
65pub fn decrypt_with_shared_secret(
66    ciphertext: &Ciphertext,
67    aad: &[u8],
68    shared_secret: &SharedSecret,
69) -> Result<Vec<u8>> {
70    ferveo_tdec::api::decrypt_with_shared_secret(
71        &ciphertext.0,
72        aad,
73        &shared_secret.0,
74    )
75    .map_err(Error::from)
76}
77
78#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Eq)]
79pub struct Ciphertext(ferveo_tdec::api::Ciphertext);
80
81impl Ciphertext {
82    pub fn header(&self) -> Result<CiphertextHeader> {
83        Ok(CiphertextHeader(self.0.header()?))
84    }
85
86    pub fn payload(&self) -> Vec<u8> {
87        self.0.payload()
88    }
89}
90
/// Wrapper around `ferveo_tdec::api::CiphertextHeader`.
///
/// Obtained from [`Ciphertext::header`] and passed to the
/// `create_decryption_share_*` methods of [`AggregatedTranscript`].
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct CiphertextHeader(ferveo_tdec::api::CiphertextHeader);
93
94/// The ferveo variant to use for the decryption share derivation.
95#[derive(
96    PartialEq, Eq, Debug, Serialize, Deserialize, Copy, Clone, PartialOrd,
97)]
98pub enum FerveoVariant {
99    /// The simple variant requires m of n shares to decrypt
100    Simple,
101    /// The precomputed variant requires n of n shares to decrypt
102    Precomputed,
103}
104
105impl fmt::Display for FerveoVariant {
106    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
107        write!(f, "{}", self.as_str())
108    }
109}
110
111impl FerveoVariant {
112    pub fn as_str(&self) -> &'static str {
113        match self {
114            FerveoVariant::Simple => "FerveoVariant::Simple",
115            FerveoVariant::Precomputed => "FerveoVariant::Precomputed",
116        }
117    }
118
119    pub fn from_string(s: &str) -> Result<Self> {
120        match s {
121            "FerveoVariant::Simple" => Ok(FerveoVariant::Simple),
122            "FerveoVariant::Precomputed" => Ok(FerveoVariant::Precomputed),
123            _ => Err(Error::InvalidVariant(s.to_string())),
124        }
125    }
126}
127
#[cfg(feature = "bindings-python")]
impl From<bindings_python::FerveoVariant> for FerveoVariant {
    /// Unwraps the Python binding type into the variant it carries.
    fn from(variant: bindings_python::FerveoVariant) -> Self {
        let bindings_python::FerveoVariant(inner) = variant;
        inner
    }
}
134
#[cfg(feature = "bindings-wasm")]
impl From<bindings_wasm::FerveoVariant> for FerveoVariant {
    /// Unwraps the WASM binding type into the variant it carries.
    fn from(variant: bindings_wasm::FerveoVariant) -> Self {
        let bindings_wasm::FerveoVariant(inner) = variant;
        inner
    }
}
141
/// Wrapper around `ferveo_tdec::DkgPublicKey<E>`.
///
/// Returned by [`AggregatedTranscript::public_key`] and consumed by
/// [`encrypt`].
#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct DkgPublicKey(
    // The serde bounds are spelled out explicitly in terms of the wrapped type.
    #[serde(bound(
        serialize = "ferveo_tdec::DkgPublicKey<E>: Serialize",
        deserialize = "ferveo_tdec::DkgPublicKey<E>: DeserializeOwned"
    ))]
    pub(crate) ferveo_tdec::DkgPublicKey<E>,
);
150
151// TODO: Consider moving these implementation details to ferveo_tdec::DkgPublicKey - #197
152impl DkgPublicKey {
153    pub fn to_bytes(&self) -> Result<GenericArray<u8, U48>> {
154        let as_bytes = to_bytes(&self.0 .0)?;
155        Ok(GenericArray::<u8, U48>::from_slice(&as_bytes).to_owned())
156    }
157
158    pub fn from_bytes(bytes: &[u8]) -> Result<DkgPublicKey> {
159        let bytes =
160            GenericArray::<u8, U48>::from_exact_iter(bytes.iter().cloned())
161                .ok_or_else(|| {
162                    Error::InvalidByteLength(
163                        Self::serialized_size(),
164                        bytes.len(),
165                    )
166                })?;
167        let pk: G1Affine = from_bytes(&bytes)?;
168        Ok(DkgPublicKey(ferveo_tdec::DkgPublicKey(pk)))
169    }
170
171    pub fn serialized_size() -> usize {
172        U48::to_usize()
173    }
174}
175
176// TODO: Consider if FieldPoint should be removed - #197
177#[serde_as]
178#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
179pub struct FieldPoint(#[serde_as(as = "serialization::SerdeAs")] pub Fr);
180
181impl FieldPoint {
182    pub fn to_bytes(&self) -> Result<Vec<u8>> {
183        to_bytes(&self.0)
184    }
185
186    pub fn from_bytes(bytes: &[u8]) -> Result<FieldPoint> {
187        from_bytes(bytes).map(FieldPoint)
188    }
189}
190
191#[derive(Clone)]
192pub struct Dkg(crate::PubliclyVerifiableDkg<E>);
193
194impl Dkg {
195    pub fn new(
196        tau: u32,
197        shares_num: u32,
198        security_threshold: u32,
199        validators: &[Validator],
200        me: &Validator,
201    ) -> Result<Self> {
202        let dkg_params =
203            crate::DkgParams::new(tau, security_threshold, shares_num)?;
204        let dkg = crate::PubliclyVerifiableDkg::<E>::new(
205            validators,
206            &dkg_params,
207            me,
208        )?;
209        Ok(Self(dkg))
210    }
211
212    pub fn generate_transcript<R: RngCore>(
213        &mut self,
214        rng: &mut R,
215    ) -> Result<Transcript> {
216        self.0.generate_transcript(rng)
217    }
218
219    pub fn aggregate_transcripts(
220        &self,
221        messages: &[ValidatorMessage],
222    ) -> Result<AggregatedTranscript> {
223        self.0
224            .aggregate_transcripts(messages)
225            .map(AggregatedTranscript)
226    }
227
228    pub fn generate_refresh_transcript<R: RngCore>(
229        &self,
230        rng: &mut R,
231    ) -> Result<RefreshTranscript> {
232        self.0.generate_refresh_transcript(rng)
233    }
234
235    pub fn generate_handover_transcript<R: RngCore>(
236        &self,
237        aggregate: &AggregatedTranscript,
238        handover_slot_index: u32,
239        incoming_validator_keypair: &ferveo_common::Keypair<E>,
240        rng: &mut R,
241    ) -> Result<HandoverTranscript> {
242        self.0
243            .generate_handover_transcript(
244                &aggregate.0,
245                handover_slot_index,
246                incoming_validator_keypair,
247                rng,
248            )
249            .map(HandoverTranscript)
250    }
251
252    pub fn me(&self) -> &Validator {
253        &self.0.me
254    }
255
256    pub fn domain_points(&self) -> Vec<DomainPoint<E>> {
257        self.0.domain_points()
258    }
259}
260
/// Wrapper around `crate::AggregatedTranscript<E>`.
///
/// Built with [`AggregatedTranscript::new`] or [`Dkg::aggregate_transcripts`].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct AggregatedTranscript(crate::AggregatedTranscript<E>);

/// Wrapper around `crate::HandoverTranscript<E>`.
///
/// Produced by [`Dkg::generate_handover_transcript`] and consumed by
/// [`AggregatedTranscript::finalize_handover`].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct HandoverTranscript(crate::HandoverTranscript<E>);
266
267impl AggregatedTranscript {
268    pub fn new(messages: &[ValidatorMessage]) -> Result<Self> {
269        let transcripts: Vec<_> = messages
270            .iter()
271            .map(|(_, transcript)| transcript.clone())
272            .collect();
273        let aggregated_transcript =
274            crate::AggregatedTranscript::<E>::from_transcripts(&transcripts)?;
275        Ok(AggregatedTranscript(aggregated_transcript))
276    }
277
278    pub fn verify(
279        &self,
280        validators_num: u32,
281        messages: &[ValidatorMessage],
282    ) -> Result<bool> {
283        if validators_num < messages.len() as u32 {
284            return Err(Error::InvalidAggregateVerificationParameters(
285                validators_num,
286                messages.len() as u32,
287            ));
288        }
289
290        let domain =
291            GeneralEvaluationDomain::<Fr>::new(validators_num as usize)
292                .expect("Unable to construct an evaluation domain");
293        let is_valid_optimistic = self.0.aggregate.verify_optimistic();
294        if !is_valid_optimistic {
295            return Err(Error::InvalidTranscriptAggregate);
296        }
297
298        let validators: Vec<_> = messages
299            .iter()
300            .map(|(validator, _)| validator)
301            .cloned()
302            .collect();
303        let pvss_list = messages
304            .iter()
305            .map(|(_validator, transcript)| transcript)
306            .cloned()
307            .collect::<Vec<_>>();
308        // This check also includes `verify_full`. See impl. for details.
309        do_verify_aggregation(
310            &self.0.aggregate.coeffs,
311            &self.0.aggregate.shares,
312            &validators,
313            &domain,
314            &pvss_list,
315        )
316    }
317
318    pub fn create_decryption_share_precomputed(
319        &self,
320        dkg: &Dkg,
321        ciphertext_header: &CiphertextHeader,
322        aad: &[u8],
323        validator_keypair: &ValidatorKeypair,
324        selected_validators: &[Validator],
325    ) -> Result<DecryptionSharePrecomputed> {
326        let selected_domain_points = selected_validators
327            .iter()
328            .filter_map(|v| {
329                dkg.0
330                    .get_domain_point(v.share_index)
331                    .ok()
332                    .map(|domain_point| (v.share_index, domain_point))
333            })
334            .collect::<HashMap<u32, ferveo_tdec::DomainPoint<E>>>();
335        self.0.aggregate.create_decryption_share_precomputed(
336            &ciphertext_header.0,
337            aad,
338            validator_keypair,
339            dkg.0.me.share_index,
340            &selected_domain_points,
341        )
342    }
343
344    pub fn create_decryption_share_simple(
345        &self,
346        dkg: &Dkg,
347        ciphertext_header: &CiphertextHeader,
348        aad: &[u8],
349        validator_keypair: &ValidatorKeypair,
350    ) -> Result<DecryptionShareSimple> {
351        let share = self.0.aggregate.create_decryption_share_simple(
352            &ciphertext_header.0,
353            aad,
354            validator_keypair,
355            dkg.0.me.share_index,
356        )?;
357        let domain_point = dkg.0.get_domain_point(dkg.0.me.share_index)?;
358        Ok(DecryptionShareSimple {
359            share,
360            domain_point,
361        })
362    }
363
364    pub fn public_key(&self) -> DkgPublicKey {
365        DkgPublicKey(self.0.public_key)
366    }
367
368    pub fn refresh(
369        &self,
370        update_transcripts: &HashMap<u32, RefreshTranscript>,
371        validator_keys_map: &HashMap<u32, ValidatorPublicKey>,
372    ) -> Result<Self> {
373        // TODO: Aggregates structs should be refactored, this is a bit of a mess - #162
374        let updated_aggregate = self
375            .0
376            .aggregate
377            .refresh(update_transcripts, validator_keys_map)
378            .unwrap();
379        let eeww =
380            crate::AggregatedTranscript::<E>::from_aggregate(updated_aggregate)
381                .unwrap();
382        Ok(AggregatedTranscript(eeww))
383    }
384
385    pub fn finalize_handover(
386        &self,
387        handover_transcript: &HandoverTranscript,
388        validator_keypair: &ValidatorKeypair,
389    ) -> Result<Self> {
390        let new_aggregate = self
391            .0
392            .aggregate
393            .finalize_handover(&handover_transcript.0, validator_keypair)
394            .unwrap();
395        // TODO: Aggregates structs should be refactored, this is a bit of a mess - #162
396        let eeww =
397            crate::AggregatedTranscript::<E>::from_aggregate(new_aggregate)
398                .unwrap();
399        Ok(AggregatedTranscript(eeww))
400    }
401}
402
/// A simple-variant decryption share together with the domain point of the
/// validator that produced it.
///
/// Created by [`AggregatedTranscript::create_decryption_share_simple`] and
/// consumed by [`combine_shares_simple`].
#[serde_as]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct DecryptionShareSimple {
    // The underlying `ferveo_tdec` share.
    share: ferveo_tdec::api::DecryptionShareSimple,
    // Domain point looked up for the creating validator's `share_index`.
    #[serde_as(as = "serialization::SerdeAs")]
    domain_point: DomainPoint<E>,
}
410
411pub fn combine_shares_simple(shares: &[DecryptionShareSimple]) -> SharedSecret {
412    let domain_points: Vec<_> = shares.iter().map(|s| s.domain_point).collect();
413    let lagrange_coefficients = prepare_combine_simple::<E>(&domain_points);
414
415    let shares: Vec<_> = shares.iter().cloned().map(|s| s.share).collect();
416    let shared_secret =
417        share_combine_simple(&shares, &lagrange_coefficients[..]);
418    SharedSecret(shared_secret)
419}
420
/// Wrapper around `ferveo_tdec::api::SharedSecret<E>`.
///
/// Returned by [`combine_shares_simple`] and consumed by
/// [`decrypt_with_shared_secret`].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct SharedSecret(pub ferveo_tdec::api::SharedSecret<E>);
423
424#[cfg(test)]
425mod test_ferveo_api {
426
427    use ark_std::{iterable::Iterable, UniformRand};
428    use ferveo_tdec::SecretBox;
429    use itertools::{izip, Itertools};
430    use rand::{
431        prelude::{SliceRandom, StdRng},
432        SeedableRng,
433    };
434    use test_case::test_case;
435
436    use crate::{
437        api::*,
438        test_common::{gen_address, gen_keypairs, AAD, MSG, TAU},
439    };
440
    // Return type of `make_test_inputs`: (messages, validators, validator keypairs).
    type TestInputs =
        (Vec<ValidatorMessage>, Vec<Validator>, Vec<ValidatorKeypair>);
443
444    // TODO: validators_num - #197
445    fn make_test_inputs(
446        rng: &mut StdRng,
447        tau: u32,
448        security_threshold: u32,
449        shares_num: u32,
450        validators_num: u32,
451    ) -> TestInputs {
452        let validator_keypairs = gen_keypairs(validators_num);
453        let validators = validator_keypairs
454            .iter()
455            .enumerate()
456            .map(|(i, keypair)| Validator {
457                address: gen_address(i),
458                public_key: keypair.public_key(),
459                share_index: i as u32,
460            })
461            .collect::<Vec<_>>();
462
463        // Each validator holds their own DKG instance and generates a transcript every
464        // validator, including themselves
465        let mut messages: Vec<_> = validators
466            .iter()
467            .map(|sender| {
468                let dkg = Dkg::new(
469                    tau,
470                    shares_num,
471                    security_threshold,
472                    &validators,
473                    sender,
474                )
475                .unwrap();
476                (sender.clone(), dkg.0.generate_transcript(rng).unwrap())
477            })
478            .collect();
479        messages.shuffle(rng);
480        (messages, validators, validator_keypairs)
481    }
482
483    fn random_dkg_public_key() -> DkgPublicKey {
484        let mut rng = thread_rng();
485        let g1 = G1Affine::rand(&mut rng);
486        DkgPublicKey(ferveo_tdec::DkgPublicKey(g1))
487    }
488
489    #[test]
490    fn test_dkg_pk_serialization() {
491        let dkg_pk = random_dkg_public_key();
492        let serialized = dkg_pk.to_bytes().unwrap();
493        let deserialized = DkgPublicKey::from_bytes(&serialized).unwrap();
494        assert_eq!(serialized.len(), 48_usize);
495        assert_eq!(dkg_pk, deserialized);
496    }
497
    // End-to-end flow for the precomputed variant: aggregate and verify
    // transcripts, encrypt, create decryption shares for a selected subset of
    // validators, combine them, and decrypt. Finally checks that dropping one
    // share makes decryption fail.
    #[test_case(4, 3; "N is a power of 2, t is 1 + 50%")]
    #[test_case(4, 4; "N is a power of 2, t=N")]
    #[test_case(30, 16; "N is not a power of 2, t is 1 + 50%")]
    #[test_case(30, 30; "N is not a power of 2, t=N")]
    fn test_server_api_tdec_precomputed(
        shares_num: u32,
        security_threshold: u32,
    ) {
        let validators_num = shares_num; // TODO: #197
        // Fixed seed: the test consumes `rng` in a fixed order.
        let rng = &mut StdRng::seed_from_u64(0);
        let (messages, validators, validator_keypairs) = make_test_inputs(
            rng,
            TAU,
            security_threshold,
            shares_num,
            validators_num,
        );
        // We only need `shares_num` transcripts to aggregate
        let messages = &messages[..shares_num as usize];

        // Every validator can aggregate the transcripts
        let me = validators[0].clone();
        let dkg =
            Dkg::new(TAU, shares_num, security_threshold, &validators, &me)
                .unwrap();
        let local_aggregate = dkg.aggregate_transcripts(messages).unwrap();
        assert!(local_aggregate.verify(validators_num, messages).unwrap());

        // At this point, any given validator should be able to provide a DKG public key
        let dkg_public_key = local_aggregate.public_key();

        // In the meantime, the client creates a ciphertext and decryption request
        let ciphertext =
            encrypt(SecretBox::new(MSG.to_vec()), AAD, &dkg_public_key)
                .unwrap();

        // In the precomputed variant, the client selects a specific subset of
        // validators to create decryption shares
        let selected_validators: Vec<_> = validators
            .choose_multiple(rng, security_threshold as usize)
            .cloned()
            .collect();

        // Having aggregated the transcripts, the validators can now create decryption shares
        let mut decryption_shares = selected_validators
            .iter()
            .map(|validator| {
                // Look up the keypair that belongs to this validator.
                let validator_keypair = validator_keypairs
                    .iter()
                    .find(|kp| kp.public_key() == validator.public_key)
                    .unwrap();
                // Each validator holds their own instance of DKG and creates their own aggregate
                let dkg = Dkg::new(
                    TAU,
                    shares_num,
                    security_threshold,
                    &validators,
                    validator,
                )
                .unwrap();
                let server_aggregate =
                    dkg.aggregate_transcripts(messages).unwrap();
                assert!(server_aggregate
                    .verify(validators_num, messages)
                    .unwrap());

                // And then each validator creates their own decryption share
                server_aggregate
                    .create_decryption_share_precomputed(
                        &dkg,
                        &ciphertext.header().unwrap(),
                        AAD,
                        validator_keypair,
                        &selected_validators,
                    )
                    .unwrap()
            })
            // We only need `security_threshold` shares to be able to decrypt.
            // `selected_validators` holds at most that many entries, so this
            // `take` does not drop anything here.
            .take(security_threshold as usize)
            .collect::<Vec<DecryptionSharePrecomputed>>();
        decryption_shares.shuffle(rng);

        // Now, the decryption shares can be used to decrypt the ciphertext
        // This part belongs to the client API
        let shared_secret = share_combine_precomputed(&decryption_shares);
        let plaintext = decrypt_with_shared_secret(
            &ciphertext,
            AAD,
            &SharedSecret(shared_secret),
        )
        .unwrap();
        assert_eq!(plaintext, MSG);

        // We need `security_threshold` shares to be able to decrypt
        // So if we remove one share, we should not be able to decrypt
        let decryption_shares = decryption_shares
            .iter()
            .take(security_threshold as usize - 1)
            .cloned()
            .collect::<Vec<_>>();
        let shared_secret = share_combine_precomputed(&decryption_shares);
        let result = decrypt_with_shared_secret(
            &ciphertext,
            AAD,
            &SharedSecret(shared_secret),
        );
        assert!(result.is_err());
    }
606
    // End-to-end flow for the simple variant: aggregate and verify
    // transcripts, encrypt, create `security_threshold` decryption shares,
    // combine them, and decrypt. Finally checks that dropping one share makes
    // decryption fail.
    #[test_case(4, 3; "N is a power of 2, t is 1 + 50%")]
    #[test_case(4, 4; "N is a power of 2, t=N")]
    #[test_case(30, 16; "N is not a power of 2, t is 1 + 50%")]
    #[test_case(30, 30; "N is not a power of 2, t=N")]
    fn test_server_api_tdec_simple(shares_num: u32, security_threshold: u32) {
        // Fixed seed: the test consumes `rng` in a fixed order.
        let rng = &mut StdRng::seed_from_u64(0);
        let validators_num: u32 = shares_num; // TODO: #197
        let (messages, validators, validator_keypairs) = make_test_inputs(
            rng,
            TAU,
            security_threshold,
            shares_num,
            validators_num,
        );
        // We only need `shares_num` transcripts to aggregate
        let messages = &messages[..shares_num as usize];

        // Now that every validator holds a dkg instance and a transcript for every other validator,
        // every validator can aggregate the transcripts
        let local_aggregate = AggregatedTranscript::new(messages).unwrap();
        assert!(local_aggregate.verify(validators_num, messages).unwrap());

        // At this point, any given validator should be able to provide a DKG public key
        let public_key = local_aggregate.public_key();

        // In the meantime, the client creates a ciphertext and decryption request
        let ciphertext =
            encrypt(SecretBox::new(MSG.to_vec()), AAD, &public_key).unwrap();

        // Having aggregated the transcripts, the validators can now create decryption shares
        let mut decryption_shares: Vec<_> =
            izip!(&validators, &validator_keypairs)
                .map(|(validator, validator_keypair)| {
                    // Each validator holds their own instance of DKG and creates their own aggregate
                    let dkg = Dkg::new(
                        TAU,
                        shares_num,
                        security_threshold,
                        &validators,
                        validator,
                    )
                    .unwrap();
                    let server_aggregate =
                        dkg.aggregate_transcripts(messages).unwrap();
                    assert!(server_aggregate
                        .verify(validators_num, messages)
                        .unwrap());
                    server_aggregate
                        .create_decryption_share_simple(
                            &dkg,
                            &ciphertext.header().unwrap(),
                            AAD,
                            validator_keypair,
                        )
                        .unwrap()
                })
                // We only need `security_threshold` shares to be able to decrypt
                .take(security_threshold as usize)
                .collect();
        decryption_shares.shuffle(rng);

        // Now, the decryption shares can be used to decrypt the ciphertext
        // This part belongs to the client API
        // (`take` above already limited the vector to `security_threshold`
        // entries, so this slice keeps every element.)
        let decryption_shares =
            decryption_shares[..security_threshold as usize].to_vec();

        let shared_secret = combine_shares_simple(&decryption_shares);
        let plaintext =
            decrypt_with_shared_secret(&ciphertext, AAD, &shared_secret)
                .unwrap();
        assert_eq!(plaintext, MSG);

        // We need `security_threshold` shares to be able to decrypt
        // So if we remove one share, we should not be able to decrypt
        let decryption_shares =
            decryption_shares[..security_threshold as usize - 1].to_vec();

        let shared_secret = combine_shares_simple(&decryption_shares);
        let result =
            decrypt_with_shared_secret(&ciphertext, AAD, &shared_secret);
        assert!(result.is_err());
    }
689
690    /// Note that the server and client code are using the same underlying
691    /// implementation for aggregation and aggregate verification.
692    /// Here, we focus on testing user-facing APIs for server and client users.
693    #[test_case(4, 3; "N is a power of 2, t is 1 + 50%")]
694    #[test_case(4, 4; "N is a power of 2, t=N")]
695    #[test_case(30, 16; "N is not a power of 2, t is 1 + 50%")]
696    #[test_case(30, 30; "N is not a power of 2, t=N")]
697    fn server_side_local_verification(
698        shares_num: u32,
699        security_threshold: u32,
700    ) {
701        let rng = &mut StdRng::seed_from_u64(0);
702        let validators_num: u32 = shares_num; // TODO: #197
703        let (messages, validators, _) = make_test_inputs(
704            rng,
705            TAU,
706            security_threshold,
707            shares_num,
708            validators_num,
709        );
710        // We only need `shares_num` transcripts to aggregate
711        let messages = &messages[..shares_num as usize];
712
713        // Now that every validator holds a dkg instance and a transcript for every other validator,
714        // every validator can aggregate the transcripts
715        let me = validators[0].clone();
716        let dkg =
717            Dkg::new(TAU, shares_num, security_threshold, &validators, &me)
718                .unwrap();
719        let good_aggregate = dkg.aggregate_transcripts(messages).unwrap();
720        assert!(good_aggregate.verify(validators_num, messages).is_ok());
721
722        // Test negative cases
723
724        // Notice that the dkg instance is mutable, so we need to get a fresh one
725        // for every test case
726
727        // Should fail if the number of validators is less than the number of messages
728        assert!(matches!(
729            good_aggregate.verify(messages.len() as u32 - 1, messages),
730            Err(Error::InvalidAggregateVerificationParameters(_, _))
731        ));
732
733        // Should fail if no transcripts are provided
734        let dkg =
735            Dkg::new(TAU, shares_num, security_threshold, &validators, &me)
736                .unwrap();
737        assert!(matches!(
738            dkg.aggregate_transcripts(&[]),
739            Err(Error::NoTranscriptsToAggregate)
740        ));
741
742        // Not enough transcripts
743        let dkg =
744            Dkg::new(TAU, shares_num, security_threshold, &validators, &me)
745                .unwrap();
746        let not_enough_messages = &messages[..security_threshold as usize - 1];
747        assert!(not_enough_messages.len() < security_threshold as usize);
748        let insufficient_aggregate =
749            dkg.aggregate_transcripts(not_enough_messages).unwrap();
750        assert!(matches!(
751            insufficient_aggregate.verify(validators_num, messages),
752            Err(Error::InvalidTranscriptAggregate)
753        ));
754
755        // Duplicated transcripts
756        let messages_with_duplicated_transcript = [
757            (
758                validators[security_threshold as usize - 1].clone(),
759                messages[security_threshold as usize - 1].1.clone(),
760            ),
761            (
762                validators[security_threshold as usize - 1].clone(),
763                messages[security_threshold as usize - 2].1.clone(),
764            ),
765        ];
766        assert!(dkg
767            .aggregate_transcripts(&messages_with_duplicated_transcript)
768            .is_err());
769
770        let messages_with_duplicated_transcript = [
771            (
772                validators[security_threshold as usize - 1].clone(),
773                messages[security_threshold as usize - 1].1.clone(),
774            ),
775            (
776                validators[security_threshold as usize - 2].clone(),
777                messages[security_threshold as usize - 1].1.clone(),
778            ),
779        ];
780        assert!(dkg
781            .aggregate_transcripts(&messages_with_duplicated_transcript)
782            .is_err());
783
784        // Unexpected transcripts in the aggregate or transcripts from a different ritual
785        // Using same DKG parameters, but different DKG instances and validators
786        let mut dkg =
787            Dkg::new(TAU, shares_num, security_threshold, &validators, &me)
788                .unwrap();
789        let bad_message = (
790            // Reusing a good validator, but giving them a bad transcript
791            messages[security_threshold as usize - 1].0.clone(),
792            dkg.generate_transcript(rng).unwrap(),
793        );
794        let mixed_messages = [
795            &messages[..(security_threshold - 1) as usize],
796            &[bad_message],
797        ]
798        .concat();
799        assert_eq!(mixed_messages.len(), security_threshold as usize);
800        let bad_aggregate = dkg.aggregate_transcripts(&mixed_messages).unwrap();
801        assert!(matches!(
802            bad_aggregate.verify(validators_num, messages),
803            Err(Error::InvalidTranscriptAggregate)
804        ));
805    }
806
807    #[test_case(4, 3; "N is a power of 2, t is 1 + 50%")]
808    #[test_case(4, 4; "N is a power of 2, t=N")]
809    #[test_case(30, 16; "N is not a power of 2, t is 1 + 50%")]
810    #[test_case(30, 30; "N is not a power of 2, t=N")]
811    fn client_side_local_verification(
812        shares_num: u32,
813        security_threshold: u32,
814    ) {
815        let rng = &mut StdRng::seed_from_u64(0);
816        let validators_num: u32 = shares_num; // TODO: #197
817        let (messages, _, _) = make_test_inputs(
818            rng,
819            TAU,
820            security_threshold,
821            shares_num,
822            validators_num,
823        );
824
825        // We only need `shares_num` transcripts to aggregate
826        let messages = &messages[..shares_num as usize];
827
828        // Create an aggregated transcript on the client side
829        let good_aggregate = AggregatedTranscript::new(messages).unwrap();
830
831        // We are separating the verification from the aggregation since the client may fetch
832        // the aggregate from a side-channel or decide to persist it and verify it later
833
834        // Now, the client can verify the aggregated transcript
835        let result = good_aggregate.verify(validators_num, messages);
836        assert!(result.is_ok());
837        assert!(result.unwrap());
838
839        // Test negative cases
840
841        // Should fail if the number of validators is less than the number of messages
842        assert!(matches!(
843            good_aggregate.verify(messages.len() as u32 - 1, messages),
844            Err(Error::InvalidAggregateVerificationParameters(_, _))
845        ));
846
847        // Should fail if no transcripts are provided
848        assert!(matches!(
849            AggregatedTranscript::new(&[]),
850            Err(Error::NoTranscriptsToAggregate)
851        ));
852
853        // Not enough transcripts
854        let not_enough_messages = &messages[..security_threshold as usize - 1];
855        assert!(not_enough_messages.len() < security_threshold as usize);
856        let insufficient_aggregate =
857            AggregatedTranscript::new(not_enough_messages).unwrap();
858        let _result = insufficient_aggregate.verify(validators_num, messages);
859        assert!(matches!(
860            insufficient_aggregate.verify(validators_num, messages),
861            Err(Error::InvalidTranscriptAggregate)
862        ));
863
864        // Unexpected transcripts in the aggregate or transcripts from a different ritual
865        // Using same DKG parameters, but different DKG instances and validators
866        let (bad_messages, _, _) = make_test_inputs(
867            rng,
868            TAU,
869            security_threshold,
870            shares_num,
871            validators_num,
872        );
873        let mixed_messages = [&messages[..2], &bad_messages[..1]].concat();
874        let bad_aggregate = AggregatedTranscript::new(&mixed_messages).unwrap();
875        assert!(matches!(
876            bad_aggregate.verify(validators_num, messages),
877            Err(Error::InvalidTranscriptAggregate)
878        ));
879    }
880
881    // TODO: validators_num #197
882    fn make_share_update_test_inputs(
883        shares_num: u32,
884        validators_num: u32,
885        rng: &mut StdRng,
886        security_threshold: u32,
887    ) -> (
888        Vec<ValidatorMessage>,
889        Vec<Validator>,
890        Vec<ValidatorKeypair>,
891        Vec<Dkg>,
892        CiphertextHeader,
893        SharedSecret,
894    ) {
895        let (messages, validators, validator_keypairs) = make_test_inputs(
896            rng,
897            TAU,
898            security_threshold,
899            shares_num,
900            validators_num,
901        );
902        let dkgs = validators
903            .iter()
904            .map(|validator| {
905                Dkg::new(
906                    TAU,
907                    shares_num,
908                    security_threshold,
909                    &validators,
910                    validator,
911                )
912                .unwrap()
913            })
914            .collect::<Vec<_>>();
915
916        // Creating a copy to avoiding accidentally changing DKG state
917        let dkg = dkgs[0].clone();
918        let server_aggregate =
919            dkg.aggregate_transcripts(messages.as_slice()).unwrap();
920        assert!(server_aggregate
921            .verify(validators_num, messages.as_slice())
922            .unwrap());
923
924        // Create an initial shared secret for testing purposes
925        let public_key = server_aggregate.public_key();
926        let ciphertext =
927            encrypt(SecretBox::new(MSG.to_vec()), AAD, &public_key).unwrap();
928        let ciphertext_header = ciphertext.header().unwrap();
929        let transcripts = messages
930            .iter()
931            .map(|(_, transcript)| transcript)
932            .cloned()
933            .collect::<Vec<_>>();
934        let (_, _, old_shared_secret) =
935            crate::test_dkg_full::create_shared_secret_simple_tdec(
936                &dkg.0,
937                AAD,
938                &ciphertext_header.0,
939                validator_keypairs.as_slice(),
940                &transcripts,
941            );
942        (
943            messages,
944            validators,
945            validator_keypairs,
946            dkgs,
947            ciphertext_header,
948            SharedSecret(old_shared_secret),
949        )
950    }
951
    // FIXME: This test is currently broken, and adjusted to allow compilation
    // Also, see test cases in other tests that include threshold as a parameter
    //
    // What the live code does today: builds the share-update fixture, aggregates
    // and verifies the transcripts, off-boards the last validator, creates
    // decryption shares for the remaining validators and combines the first
    // `security_threshold` of them. The recovery steps themselves are kept below
    // as commented-out code until they are re-introduced (#193), which is why
    // `_recover_at_random_point` is currently unused.
    #[ignore = "Re-introduce recovery tests - #193"]
    #[test_case(4, 4, true; "number of shares (validators) is a power of 2")]
    #[test_case(7, 7, true; "number of shares (validators) is not a power of 2")]
    #[test_case(4, 6, true; "number of validators greater than the number of shares")]
    #[test_case(4, 6, false; "recovery at a specific point")]
    fn test_dkg_simple_tdec_share_recovery(
        shares_num: u32,
        validators_num: u32,
        _recover_at_random_point: bool,
    ) {
        let rng = &mut StdRng::seed_from_u64(0);
        // Threshold is a strict majority of the shares
        let security_threshold = shares_num / 2 + 1;
        let (
            mut messages,
            mut validators,
            mut validator_keypairs,
            mut dkgs,
            ciphertext_header,
            old_shared_secret,
        ) = make_share_update_test_inputs(
            shares_num,
            validators_num,
            rng,
            security_threshold,
        );

        // We assume that all participants have the same aggregate, and that participants created
        // their own aggregates before the off-boarding of the validator
        // If we didn't create this aggregate here, we risk having a "dangling validator message"
        // later when we off-board the validator
        let aggregated_transcript = dkgs[0]
            .clone()
            .aggregate_transcripts(messages.as_slice())
            .unwrap();
        assert!(aggregated_transcript
            .verify(validators_num, messages.as_slice())
            .unwrap());

        // We need to save this domain point to be used in the recovery testing scenario
        // (it is only referenced by the commented-out recovery code below)
        let mut domain_points = dkgs[0].0.domain_point_map();
        let _removed_domain_point = domain_points
            .remove(&validators.last().unwrap().share_index)
            .unwrap();

        // Remove one participant from the contexts and all nested structure
        // to simulate off-boarding a validator
        messages.pop().unwrap();
        dkgs.pop();
        validator_keypairs.pop().unwrap();
        let _removed_validator = validators.pop().unwrap();

        // Now, we're going to recover a new share at a random point or at a specific point
        // and check that the shared secret is still the same.
        // let _x_r = if recover_at_random_point {
        //     // Onboarding a validator with a completely new private key share
        //     DomainPoint<E>::rand(rng)
        // } else {
        //     // Onboarding a validator with a private key share recovered from the removed validator
        //     removed_domain_point
        // };

        // Each participant prepares an update for each other participant
        // let share_updates = dkgs
        //     .iter()
        //     .map(|validator_dkg| {
        //         let share_update =
        //             ShareRecoveryUpdate::create_recovery_updates(
        //                 validator_dkg,
        //                 &x_r,
        //             )
        //             .unwrap();
        //         (validator_dkg.me().address.clone(), share_update)
        //     })
        //     .collect::<HashMap<_, _>>();

        // Participants share updates and update their shares

        // Now, every participant separately:
        // let updated_shares: HashMap<u32, _> = dkgs
        //     .iter()
        //     .map(|validator_dkg| {
        //         // Current participant receives updates from other participants
        //         let updates_for_participant: Vec<_> = share_updates
        //             .values()
        //             .map(|updates| {
        //                 updates.get(&validator_dkg.me().share_index).unwrap()
        //             })
        //             .cloned()
        //             .collect();

        //         // Each validator uses their decryption key to update their share
        //         let validator_keypair = validator_keypairs
        //             .get(validator_dkg.me().share_index as usize)
        //             .unwrap();

        //         // And creates updated private key shares
        //         let updated_key_share = aggregated_transcript
        //             .get_private_key_share(
        //                 validator_keypair,
        //                 validator_dkg.me().share_index,
        //             )
        //             .unwrap()
        //             .create_updated_private_key_share_for_recovery(
        //                 &updates_for_participant,
        //             )
        //             .unwrap();
        //         (validator_dkg.me().share_index, updated_key_share)
        //     })
        //     .collect();

        // Now, we have to combine new share fragments into a new share
        // let recovered_key_share =
        // PrivateKeyShare::recover_share_from_updated_private_shares(
        //     &x_r,
        //     &domain_points,
        //     &updated_shares,
        // )
        // .unwrap();

        // Get decryption shares from remaining participants
        // (keypairs and DKG instances are paired positionally)
        let mut decryption_shares: Vec<DecryptionShareSimple> =
            validator_keypairs
                .iter()
                .zip_eq(dkgs.iter())
                .map(|(validator_keypair, validator_dkg)| {
                    aggregated_transcript
                        .create_decryption_share_simple(
                            validator_dkg,
                            &ciphertext_header,
                            AAD,
                            validator_keypair,
                        )
                        .unwrap()
                })
                .collect();
        decryption_shares.shuffle(rng);

        // In order to test the recovery, we need to create a new decryption share from the recovered
        // private key share. To do that, we need a new validator

        // Let's create and onboard a new validator
        // TODO: Add test scenarios for onboarding and offboarding validators
        // let new_validator_keypair = Keypair::random();
        // Normally, we would get these from the Coordinator:
        // let new_validator_share_index = removed_validator.share_index;
        // let new_validator = Validator {
        //     address: gen_address(new_validator_share_index as usize),
        //     public_key: new_validator_keypair.public_key(),
        //     share_index: new_validator_share_index,
        // };
        // validators.push(new_validator.clone());
        // let new_validator_dkg = Dkg::new(
        //     TAU,
        //     shares_num,
        //     security_threshold,
        //     &validators,
        //     &new_validator,
        // )
        // .unwrap();

        // let new_decryption_share = recovered_key_share
        //     .create_decryption_share_simple(
        //         &new_validator_dkg,
        //         &ciphertext_header,
        //         &new_validator_keypair,
        //         AAD,
        //     )
        //     .unwrap();
        // decryption_shares.push(new_decryption_share);
        // domain_points.insert(new_validator_share_index, x_r);

        // Keep only `security_threshold` domain points and decryption shares.
        // The domain points are only length-checked here; they are not passed
        // to `combine_shares_simple`.
        let domain_points = domain_points
            .values()
            .take(security_threshold as usize)
            .cloned()
            .collect::<Vec<_>>();
        let decryption_shares =
            &decryption_shares[..security_threshold as usize];
        assert_eq!(domain_points.len(), security_threshold as usize);
        assert_eq!(decryption_shares.len(), security_threshold as usize);

        let new_shared_secret = combine_shares_simple(decryption_shares);
        // NOTE(review): this asserts that the secrets DIFFER, while the message
        // reads as a reconstruction failure. Confirm the intended polarity when
        // the recovery logic above is restored (#193).
        assert_ne!(
            old_shared_secret, new_shared_secret,
            "Shared secret reconstruction failed"
        );
    }
1141
1142    #[test_case(4, 3; "N is a power of 2, t is 1 + 50%")]
1143    #[test_case(4, 4; "N is a power of 2, t=N")]
1144    #[test_case(30, 16; "N is not a power of 2, t is 1 + 50%")]
1145    #[test_case(30, 30; "N is not a power of 2, t=N")]
1146    fn test_dkg_api_simple_tdec_share_refresh(
1147        shares_num: u32,
1148        security_threshold: u32,
1149    ) {
1150        let rng = &mut StdRng::seed_from_u64(0);
1151        let validators_num: u32 = shares_num; // TODO: #197
1152        let (
1153            messages,
1154            _validators,
1155            validator_keypairs,
1156            dkgs,
1157            ciphertext_header,
1158            old_shared_secret,
1159        ) = make_share_update_test_inputs(
1160            shares_num,
1161            validators_num,
1162            rng,
1163            security_threshold,
1164        );
1165
1166        // When the share refresh protocol is necessary, each participant
1167        // prepares an UpdateTranscript, containing updates for each other.
1168        let mut update_transcripts: HashMap<u32, RefreshTranscript> =
1169            HashMap::new();
1170        let mut validator_map: HashMap<u32, _> = HashMap::new();
1171
1172        for dkg in &dkgs {
1173            for validator in dkg.0.validators.values() {
1174                update_transcripts.insert(
1175                    validator.share_index,
1176                    dkg.generate_refresh_transcript(rng).unwrap(),
1177                );
1178                validator_map.insert(
1179                    validator.share_index,
1180                    validator_keypairs
1181                        .get(validator.share_index as usize)
1182                        .unwrap()
1183                        .public_key(),
1184                );
1185            }
1186        }
1187
1188        // Participants distribute UpdateTranscripts and update their shares
1189        // accordingly. The result is a new AggregatedTranscript.
1190        // In this test, all participants will obtain the same AggregatedTranscript,
1191        // but we're anyway computing it independently for each participant.
1192
1193        // So, every participant separately:
1194        let refreshed_aggregates: Vec<AggregatedTranscript> = dkgs
1195            .iter()
1196            .map(|validator_dkg| {
1197                // Obtain the original aggregate (in the real world, this would be already available)
1198                let aggregate = validator_dkg
1199                    .clone()
1200                    .aggregate_transcripts(messages.as_slice())
1201                    .unwrap();
1202                assert!(aggregate
1203                    .verify(validators_num, messages.as_slice())
1204                    .unwrap());
1205
1206                // Each participant updates their own DKG aggregate
1207                // using the UpdateTranscripts of all participants
1208                aggregate
1209                    .refresh(&update_transcripts, &validator_map)
1210                    .unwrap()
1211            })
1212            .collect();
1213
1214        // TODO: test that refreshed aggregates are all the same
1215
1216        // Participants create decryption shares
1217        let mut decryption_shares: Vec<DecryptionShareSimple> =
1218            validator_keypairs
1219                .iter()
1220                .zip_eq(dkgs.iter())
1221                .map(|(validator_keypair, validator_dkg)| {
1222                    let validator_index =
1223                        validator_dkg.me().share_index as usize;
1224
1225                    let aggregate =
1226                        refreshed_aggregates.get(validator_index).unwrap();
1227
1228                    aggregate
1229                        .create_decryption_share_simple(
1230                            validator_dkg,
1231                            &ciphertext_header,
1232                            AAD,
1233                            validator_keypair,
1234                        )
1235                        .unwrap()
1236                })
1237                // We only need `security_threshold` shares to be able to decrypt
1238                .take(security_threshold as usize)
1239                .collect();
1240        decryption_shares.shuffle(rng);
1241
1242        let decryption_shares =
1243            &decryption_shares[..security_threshold as usize];
1244        assert_eq!(decryption_shares.len(), security_threshold as usize);
1245
1246        let new_shared_secret = combine_shares_simple(decryption_shares);
1247        assert_eq!(
1248            old_shared_secret, new_shared_secret,
1249            "Shared secret reconstruction failed"
1250        );
1251    }
1252}