cggmp21_keygen/threshold.rs

1use alloc::vec::Vec;
2
3use digest::Digest;
4use generic_ec::{Curve, NonZero, Point, Scalar, SecretScalar};
5use generic_ec_zkp::{polynomial::Polynomial, schnorr_pok};
6use rand_core::{CryptoRng, RngCore};
7use round_based::{
8    rounds_router::simple_store::RoundInput, rounds_router::RoundsRouter, Delivery, Mpc, MpcParty,
9    Outgoing, ProtocolMessage, SinkExt,
10};
11use serde::{Deserialize, Serialize};
12use serde_with::serde_as;
13
14use crate::progress::Tracer;
15use crate::{
16    errors::IoError,
17    key_share::{CoreKeyShare, DirtyCoreKeyShare, DirtyKeyInfo, Validate, VssSetup},
18    security_level::SecurityLevel,
19    utils, ExecutionId,
20};
21
22use super::{Bug, KeygenAborted, KeygenError};
23
/// Builds a domain-separation tag for this protocol at compile time.
///
/// Prepends `dfns.cggmp21.keygen.threshold.` to the given string literal using
/// `concat!`, so the expansion is a `&'static str`. It is used for the
/// `#[udigest(tag = ...)]` attributes in this module, which keeps every tag
/// hashed here under a common, protocol-specific prefix.
macro_rules! prefixed {
    ($suffix:tt) => (
        concat!("dfns.cggmp21.keygen.threshold.", $suffix)
    );
}
29
/// Message of key generation protocol
///
/// Every message exchanged by `run_threshold_keygen` is one of these variants. The function
/// uses this type as the `RoundsRouter` message type and registers one `RoundInput` per
/// variant.
///
/// `#[serde(bound = "")]` drops the `E: Serialize`/`L: Serialize`/`D: Serialize` bounds that
/// serde's derive would otherwise infer for the generic parameters.
// NOTE(review): variant names are part of serde's (externally tagged) encoding, and the
// `ProtocolMessage` derive works off the variants — renaming or reordering them presumably
// changes the wire format; confirm before touching.
#[derive(ProtocolMessage, Clone, Serialize, Deserialize)]
#[serde(bound = "")]
pub enum Msg<E: Curve, L: SecurityLevel, D: Digest> {
    /// Round 1 message: hash commitment, broadcast
    Round1(MsgRound1<D>),
    /// Round 2a message: decommitment (public data), broadcast
    Round2Broad(MsgRound2Broad<E, L>),
    /// Round 2b message: secret share for one specific party, sent point-to-point
    Round2Uni(MsgRound2Uni<E>),
    /// Round 3 message: Schnorr proof, broadcast
    Round3(MsgRound3<E>),
    /// Reliability check message (optional additional round, only when reliable broadcast
    /// is enforced)
    ReliabilityCheck(MsgReliabilityCheck<D>),
}
45
/// Message from round 1
///
/// Carries the sender's hash commitment $V_i$ to the data it reveals in round 2
/// ([`MsgRound2Broad`]). It is broadcast to all parties. It also derives `Digestable`
/// because, when reliable broadcast is enforced, every received commitment is re-hashed in
/// the reliability check.
#[derive(Clone, Serialize, Deserialize, udigest::Digestable)]
#[serde(bound = "")]
#[udigest(bound = "")]
#[udigest(tag = prefixed!("round1"))]
pub struct MsgRound1<D: Digest> {
    /// $V_i$: digest `D` of `unambiguous::HashCom { sid, party_index, decommitment }`.
    #[udigest(as_bytes)]
    pub commitment: digest::Output<D>,
}
/// Message from round 2, broadcast to everyone
///
/// This is the decommitment of the sender's round-1 commitment. Receivers re-hash it together
/// with `sid` and the sender's index, then compare the result against [`MsgRound1::commitment`].
/// `#[serde_as]` has to stay above `#[derive]` so that the `serde_as(...)` field attributes are
/// rewritten before serde's derive runs.
#[serde_as]
#[derive(Clone, Serialize, Deserialize, udigest::Digestable)]
#[serde(bound = "")]
#[udigest(bound = "")]
#[udigest(tag = prefixed!("round2_broad"))]
pub struct MsgRound2Broad<E: Curve, L: SecurityLevel> {
    /// `rid_i`: the sender's random contribution to the shared `rid`. The contributions of all
    /// parties are XOR-ed together.
    #[serde_as(as = "utils::HexOrBin")]
    #[udigest(as_bytes)]
    pub rid: L::Rid,
    /// $\vec S_i$: commitment to the sender's secret polynomial (the polynomial multiplied by
    /// the curve generator). Receivers require `F.degree() + 1 == t`.
    pub F: Polynomial<Point<E>>,
    /// $A_i$: Schnorr commitment for the proof the sender publishes in round 3.
    pub sch_commit: schnorr_pok::Commit<E>,
    /// Party contribution to chain code. It is `Some` only when HD derivation is enabled, and
    /// the contributions of all parties are XOR-ed together.
    #[cfg(feature = "hd-wallet")]
    #[serde_as(as = "Option<utils::HexOrBin>")]
    #[udigest(as = Option<udigest::Bytes>)]
    pub chain_code: Option<hd_wallet::ChainCode>,
    /// $u_i$: fresh random nonce included in the committed data.
    // NOTE(review): serialized with `hex::serde`, unlike `rid` which uses `utils::HexOrBin`.
    // Switching would change the wire format, so it is intentionally left unchanged here.
    #[serde(with = "hex::serde")]
    #[udigest(as_bytes)]
    pub decommit: L::Rid,
}
/// Message from round 2, unicast to each party
///
/// Sent point-to-point: party `i` sends party `j` (0-based) the evaluation of its secret
/// polynomial at `j + 1`. The payload is secret-share material.
#[derive(Clone, Serialize, Deserialize)]
#[serde(bound = "")]
pub struct MsgRound2Uni<E: Curve> {
    /// $\sigma_{i,j} = f_i(j + 1)$. The receiver checks it against the sender's committed
    /// polynomial `F` (Feldman check).
    pub sigma: Scalar<E>,
}
/// Message from round 3
///
/// Broadcast Schnorr proof that the sender knows its final secret share.
#[derive(Clone, Serialize, Deserialize)]
#[serde(bound = "")]
pub struct MsgRound3<E: Curve> {
    /// $\psi_i$: verified against the sender's public share and the `sch_commit` it revealed
    /// in round 2.
    pub sch_proof: schnorr_pok::Proof<E>,
}
/// Message parties exchange to ensure reliability of broadcast channel
///
/// This message is only sent when reliable broadcast is enforced. It holds a digest `D` over
/// every round-1 commitment the sender saw (its own included), each hashed together with
/// `sid`. If any party's digest differs from the local one, the protocol aborts and blames
/// that party.
#[derive(Clone, Serialize, Deserialize)]
#[serde(bound = "")]
pub struct MsgReliabilityCheck<D: Digest>(pub digest::Output<D>);
99
/// Structures hashed with `udigest` during the protocol.
///
/// Each structure has its own `prefixed!` tag and includes the execution id `sid`. This keeps
/// digests computed for different purposes, or in different protocol executions, apart.
// NOTE(review): udigest encodes the struct (tag and fields) into the hash input, so changing
// tags, field names or field order presumably changes every digest and breaks interop with
// other parties. Treat these definitions as part of the protocol.
mod unambiguous {
    use generic_ec::{Curve, NonZero, Point};

    use crate::{ExecutionId, SecurityLevel};

    /// Input hashed to form a party's round-1 commitment. Receivers hash it again to validate
    /// the round-2 decommitment.
    #[derive(udigest::Digestable)]
    #[udigest(tag = prefixed!("hash_commitment"))]
    #[udigest(bound = "")]
    pub struct HashCom<'a, E: Curve, L: SecurityLevel> {
        /// Execution id of this protocol run.
        pub sid: ExecutionId<'a>,
        /// Index of the committing party, which binds the commitment to its sender.
        pub party_index: u16,
        /// Data being committed to; it is revealed in round 2.
        pub decommitment: &'a super::MsgRound2Broad<E, L>,
    }

    /// Input to the Fiat–Shamir challenge of the Schnorr proof of knowledge of a party's
    /// secret share.
    #[derive(udigest::Digestable)]
    #[udigest(tag = prefixed!("schnorr_pok"))]
    #[udigest(bound = "")]
    pub struct SchnorrPok<'a, E: Curve> {
        /// Execution id of this protocol run.
        pub sid: ExecutionId<'a>,
        /// Index of the proving party.
        pub prover: u16,
        /// Combined `rid` (XOR of all parties' `rid_i`).
        #[udigest(as_bytes)]
        pub rid: &'a [u8],
        /// Prover's public share.
        pub y: NonZero<Point<E>>,
        /// Prover's Schnorr commitment point.
        pub h: Point<E>,
    }

    /// One element of the reliability-check hash: a round-1 commitment together with `sid`.
    #[derive(udigest::Digestable)]
    #[udigest(tag = prefixed!("echo_round"))]
    #[udigest(bound = "")]
    pub struct Echo<'a, D: digest::Digest> {
        /// Execution id of this protocol run.
        pub sid: ExecutionId<'a>,
        /// A round-1 commitment received from (or sent by) some party.
        pub commitment: &'a super::MsgRound1<D>,
    }
}
134
/// Runs the threshold (`t`-out-of-`n`) key generation protocol as party `i`.
///
/// Protocol outline:
/// - Each party samples a secret polynomial of degree `t - 1` and commits to its public data.
/// - It then reveals that data and sends every other party an evaluation of its polynomial.
/// - The evaluations are checked against the committed polynomials (Feldman VSS).
/// - Each party proves knowledge of its resulting secret share with a Schnorr proof.
///
/// The shared public key is the sum of the constant terms of all parties' committed polynomials.
///
/// # Parameters
/// - `tracer`: optional progress tracer; it is notified of stages, rounds and message events.
/// - `i`: index of this party in `0..n`. Its share is the evaluation at `i + 1`.
/// - `t`: threshold. Every party's polynomial has exactly `t` coefficients.
/// - `n`: total number of parties.
/// - `reliable_broadcast_enforced`: when `true`, an extra echo round checks that all parties
///   received the same round-1 commitments.
/// - `sid`: execution id. It is mixed into every hash (commitment, echo, Schnorr challenge).
/// - `rng`: cryptographically secure randomness source.
/// - `party`: this party's network connection.
/// - `hd_enabled` (feature `hd-wallet`): when `true`, parties also jointly derive a chain code.
///
/// # Errors
/// - `IoError` if sending or receiving a message fails.
/// - `KeygenAborted` if another party misbehaves. The possible causes are:
///   - inconsistent round-1 broadcast (`Round1NotReliable`);
///   - a decommitment that does not match its commitment;
///   - a polynomial of the wrong size;
///   - a failed Feldman check;
///   - a missing chain code;
///   - an invalid Schnorr proof.
///
///   The error carries the blamed parties.
/// - `Bug` for states that should be unreachable, such as a zero share, a zero public key, or
///   a resulting key share that fails validation.
///
/// # Panics
/// - `t == 0` makes `usize::from(t) - 1` underflow (a panic when overflow checks are on).
/// - `i >= n` makes `sigmas[usize::from(i)]` go out of bounds.
///
/// Presumably callers validate `1 <= t <= n` and `i < n` before calling this function
/// (TODO: confirm).
pub async fn run_threshold_keygen<E, R, M, L, D>(
    mut tracer: Option<&mut dyn Tracer>,
    i: u16,
    t: u16,
    n: u16,
    reliable_broadcast_enforced: bool,
    sid: ExecutionId<'_>,
    rng: &mut R,
    party: M,
    #[cfg(feature = "hd-wallet")] hd_enabled: bool,
) -> Result<CoreKeyShare<E>, KeygenError>
where
    E: Curve,
    L: SecurityLevel,
    D: Digest + Clone + 'static,
    R: RngCore + CryptoRng,
    M: Mpc<ProtocolMessage = Msg<E, L, D>>,
{
    tracer.protocol_begins();

    tracer.stage("Setup networking");
    let MpcParty { delivery, .. } = party.into_party();
    let (incomings, mut outgoings) = delivery.split();

    // One round input per message kind. `round1_sync` is always registered, but it is only
    // awaited when `reliable_broadcast_enforced` is set.
    let mut rounds = RoundsRouter::<Msg<E, L, D>>::builder();
    let round1 = rounds.add_round(RoundInput::<MsgRound1<D>>::broadcast(i, n));
    let round1_sync = rounds.add_round(RoundInput::<MsgReliabilityCheck<D>>::broadcast(i, n));
    let round2_broad = rounds.add_round(RoundInput::<MsgRound2Broad<E, L>>::broadcast(i, n));
    let round2_uni = rounds.add_round(RoundInput::<MsgRound2Uni<E>>::p2p(i, n));
    let round3 = rounds.add_round(RoundInput::<MsgRound3<E>>::broadcast(i, n));
    let mut rounds = rounds.listen(incomings);

    // Round 1
    tracer.round_begins();

    // Randomness is drawn from `rng` in a fixed order: rid, Schnorr secret, polynomial,
    // chain code, decommitment nonce.
    tracer.stage("Sample rid_i, schnorr commitment, polynomial, chain_code");
    let mut rid = L::Rid::default();
    rng.fill_bytes(rid.as_mut());

    // `r` is the ephemeral Schnorr secret; it is kept until round 3 to prove knowledge of
    // `sigma`. `h` is its public commitment ($A_i$).
    let (r, h) = schnorr_pok::prover_commits_ephemeral_secret::<E, _>(rng);

    // Secret polynomial `f` of degree `t - 1`, and its public commitment `F = f * G`.
    let f = Polynomial::<SecretScalar<E>>::sample(rng, usize::from(t) - 1);
    let F = &f * &Point::generator();
    // sigmas[j] = f(j + 1) is the share for party `j`. Index `i` holds our own share.
    // NOTE(review): these secret shares are kept as plain `Scalar<E>` values in a `Vec`;
    // consider whether they should be zeroized after use.
    let sigmas = (0..n)
        .map(|j| {
            let x = Scalar::from(j + 1);
            f.value(&x)
        })
        .collect::<Vec<_>>();
    debug_assert_eq!(sigmas.len(), usize::from(n));

    #[cfg(feature = "hd-wallet")]
    let chain_code_local = if hd_enabled {
        let mut chain_code = hd_wallet::ChainCode::default();
        rng.fill_bytes(&mut chain_code);
        Some(chain_code)
    } else {
        None
    };

    // Commit-then-reveal: only a hash of this data is sent now. The data itself is broadcast
    // in round 2, after every party's commitment has been received.
    tracer.stage("Commit to public data");
    let my_decommitment = MsgRound2Broad {
        rid,
        F: F.clone(),
        sch_commit: h,
        #[cfg(feature = "hd-wallet")]
        chain_code: chain_code_local,
        decommit: {
            let mut nonce = L::Rid::default();
            rng.fill_bytes(nonce.as_mut());
            nonce
        },
    };
    // The hash binds the commitment to this execution (`sid`) and to our index.
    let hash_commit = udigest::hash::<D>(&unambiguous::HashCom {
        sid,
        party_index: i,
        decommitment: &my_decommitment,
    });

    tracer.send_msg();
    let my_commitment = MsgRound1 {
        commitment: hash_commit,
    };
    outgoings
        .send(Outgoing::broadcast(Msg::Round1(my_commitment.clone())))
        .await
        .map_err(IoError::send_message)?;
    tracer.msg_sent();

    // Round 2
    tracer.round_begins();

    tracer.receive_msgs();
    let commitments = rounds
        .complete(round1)
        .await
        .map_err(IoError::receive_message)?;
    tracer.msgs_received();

    // Optional reliability check: every party hashes the full list of round-1 commitments
    // (its own included) and broadcasts the digest. A party whose digest differs saw
    // different round-1 messages, so the broadcast was not reliable.
    if reliable_broadcast_enforced {
        tracer.stage("Hash received msgs (reliability check)");
        let h_i = udigest::hash_iter::<D>(
            commitments
                .iter_including_me(&my_commitment)
                .map(|commitment| unambiguous::Echo { sid, commitment }),
        );

        tracer.send_msg();
        outgoings
            .send(Outgoing::broadcast(Msg::ReliabilityCheck(
                MsgReliabilityCheck(h_i.clone()),
            )))
            .await
            .map_err(IoError::send_message)?;
        tracer.msg_sent();

        tracer.round_begins();

        tracer.receive_msgs();
        let hashes = rounds
            .complete(round1_sync)
            .await
            .map_err(IoError::receive_message)?;
        tracer.msgs_received();

        tracer.stage("Assert other parties hashed messages (reliability check)");
        // Collect (party index, message id) for every party whose digest differs from ours.
        let parties_have_different_hashes = hashes
            .into_iter_indexed()
            .filter(|(_j, _msg_id, h_j)| h_i != h_j.0)
            .map(|(j, msg_id, _)| (j, msg_id))
            .collect::<Vec<_>>();
        if !parties_have_different_hashes.is_empty() {
            return Err(KeygenAborted::Round1NotReliable(parties_have_different_hashes).into());
        }
    }

    // Reveal the committed data to everyone.
    tracer.send_msg();
    outgoings
        .send(Outgoing::broadcast(Msg::Round2Broad(
            my_decommitment.clone(),
        )))
        .await
        .map_err(IoError::send_message)?;

    // Send each peer its private share f(j + 1). `iter_peers` presumably yields every index
    // in `0..n` except `i`; confirm in `utils`.
    for j in utils::iter_peers(i, n) {
        let message = MsgRound2Uni {
            sigma: sigmas[usize::from(j)],
        };
        outgoings
            .send(Outgoing::p2p(j, Msg::Round2Uni(message)))
            .await
            .map_err(IoError::send_message)?;
    }
    tracer.msg_sent();

    // Round 3
    tracer.round_begins();

    tracer.receive_msgs();
    let decommitments = rounds
        .complete(round2_broad)
        .await
        .map_err(IoError::receive_message)?;
    let sigmas_msg = rounds
        .complete(round2_uni)
        .await
        .map_err(IoError::receive_message)?;
    tracer.msgs_received();

    // Recompute each sender's commitment from its revealed data, using the sender's index
    // `j`. Blame every party whose result does not match its round-1 commitment.
    tracer.stage("Validate decommitments");
    let blame = utils::collect_blame(&commitments, &decommitments, |j, com, decom| {
        let com_expected = udigest::hash::<D>(&unambiguous::HashCom {
            sid,
            party_index: j,
            decommitment: decom,
        });
        com.commitment != com_expected
    });
    if !blame.is_empty() {
        return Err(KeygenAborted::InvalidDecommitment(blame).into());
    }

    // Every committed polynomial must have exactly `t` coefficients (degree `t - 1`).
    tracer.stage("Validate data size");
    let blame = decommitments
        .iter_indexed()
        .filter(|(_, _, d)| d.F.degree() + 1 != usize::from(t))
        .map(|t| t.0)
        .collect::<Vec<_>>();
    if !blame.is_empty() {
        return Err(KeygenAborted::InvalidDataSize { parties: blame }.into());
    }

    // Feldman check: sender j's committed polynomial evaluated at our point `i + 1` must equal
    // `sigma_{j,i} * G`.
    // NOTE(review): the `zip` pairs decommitments and shares by position. This assumes both
    // collections yield messages in the same sender order, presumably ascending sender index;
    // confirm with round_based.
    tracer.stage("Validate Feldmann VSS");
    let blame = decommitments
        .iter_indexed()
        .zip(sigmas_msg.iter())
        .filter(|((_, _, d), s)| {
            d.F.value::<_, Point<_>>(&Scalar::from(i + 1)) != Point::generator() * s.sigma
        })
        .map(|t| t.0 .0)
        .collect::<Vec<_>>();
    if !blame.is_empty() {
        return Err(KeygenAborted::FeldmanVerificationFailed { parties: blame }.into());
    }

    // The shared rid is the XOR of every party's rid_i, our own included.
    tracer.stage("Compute rid");
    let rid = decommitments
        .iter_including_me(&my_decommitment)
        .map(|d| &d.rid)
        .fold(L::Rid::default(), utils::xor_array);
    #[cfg(feature = "hd-wallet")]
    let chain_code = if hd_enabled {
        tracer.stage("Compute chain_code");
        let blame = utils::collect_simple_blame(&decommitments, |decom| decom.chain_code.is_none());
        if !blame.is_empty() {
            return Err(KeygenAborted::MissingChainCode(blame).into());
        }
        // The shared chain code is the XOR of every contribution. At this point every peer's
        // chain code was checked to be `Some`, and ours is `Some` because `hd_enabled`. So
        // `Bug::NoChainCode` is only a defensive fallback.
        Some(decommitments.iter_including_me(&my_decommitment).try_fold(
            hd_wallet::ChainCode::default(),
            |acc, decom| {
                Ok::<_, Bug>(utils::xor_array(
                    acc,
                    decom.chain_code.ok_or(Bug::NoChainCode)?,
                ))
            },
        )?)
    } else {
        None
    };
    // Public shares: ys[l] is the sum of all committed polynomials evaluated at `l + 1`.
    // A zero point here is treated as a bug.
    tracer.stage("Compute Ys");
    let polynomial_sum = decommitments
        .iter_including_me(&my_decommitment)
        .map(|d| &d.F)
        .sum::<Polynomial<_>>();
    let ys = (0..n)
        .map(|l| polynomial_sum.value(&Scalar::from(l + 1)))
        .map(|y_j: Point<E>| NonZero::from_point(y_j).ok_or(Bug::ZeroShare))
        .collect::<Result<Vec<_>, _>>()?;
    // Our secret share is the sum of all received shares plus our own evaluation f(i + 1).
    // It must satisfy `sigma * G == ys[i]`, which is checked in debug builds.
    tracer.stage("Compute sigma");
    let sigma: Scalar<E> = sigmas_msg.iter().map(|msg| msg.sigma).sum();
    let mut sigma = sigma + sigmas[usize::from(i)];
    let sigma = NonZero::from_secret_scalar(SecretScalar::new(&mut sigma)).ok_or(Bug::ZeroShare)?;
    debug_assert_eq!(Point::generator() * &sigma, ys[usize::from(i)]);

    // Fiat–Shamir challenge, bound to the execution, our index, the shared rid, our public
    // share and our Schnorr commitment.
    tracer.stage("Calculate challenge");
    let challenge = Scalar::from_hash::<D>(&unambiguous::SchnorrPok {
        sid,
        prover: i,
        rid: rid.as_ref(),
        y: ys[usize::from(i)],
        h: my_decommitment.sch_commit.0,
    });
    let challenge = schnorr_pok::Challenge { nonce: challenge };

    tracer.stage("Prove knowledge of `sigma_i`");
    let z = schnorr_pok::prove(&r, &challenge, &sigma);

    tracer.send_msg();
    let my_sch_proof = MsgRound3 { sch_proof: z };
    outgoings
        .send(Outgoing::broadcast(Msg::Round3(my_sch_proof.clone())))
        .await
        .map_err(IoError::send_message)?;
    tracer.msg_sent();

    // Output round
    tracer.round_begins();

    tracer.receive_msgs();
    let sch_proofs = rounds
        .complete(round3)
        .await
        .map_err(IoError::receive_message)?;
    tracer.msgs_received();

    // Recompute each prover's challenge from its index and commitment, then verify the proof
    // against its public share ys[j].
    tracer.stage("Validate schnorr proofs");
    let blame = utils::collect_blame(&decommitments, &sch_proofs, |j, decom, sch_proof| {
        let challenge = Scalar::from_hash::<D>(&unambiguous::SchnorrPok {
            sid,
            prover: j,
            rid: rid.as_ref(),
            y: ys[usize::from(j)],
            h: decom.sch_commit.0,
        });
        let challenge = schnorr_pok::Challenge { nonce: challenge };
        sch_proof
            .sch_proof
            .verify(&decom.sch_commit, &challenge, &ys[usize::from(j)])
            .is_err()
    });
    if !blame.is_empty() {
        return Err(KeygenAborted::InvalidSchnorrProof(blame).into());
    }

    // The shared public key is the sum of the constant terms of all committed polynomials.
    tracer.stage("Derive resulting public key and other data");
    let y: Point<E> = decommitments
        .iter_including_me(&my_decommitment)
        .map(|d| d.F.coefs()[0])
        .sum();
    // Evaluation points 1..=n as non-zero scalars. A zero value would be a bug.
    let key_shares_indexes = (1..=n)
        .map(|i| NonZero::from_scalar(Scalar::from(i)))
        .collect::<Option<Vec<_>>>()
        .ok_or(Bug::NonZeroScalar)?;

    tracer.protocol_ends();

    // Assemble the key share and validate it. A validation failure is reported as a bug,
    // because all inputs were already checked above.
    Ok(DirtyCoreKeyShare {
        i,
        key_info: DirtyKeyInfo {
            curve: Default::default(),
            shared_public_key: NonZero::from_point(y).ok_or(Bug::ZeroPk)?,
            public_shares: ys,
            vss_setup: Some(VssSetup {
                min_signers: t,
                I: key_shares_indexes,
            }),
            #[cfg(feature = "hd-wallet")]
            chain_code,
        },
        x: sigma,
    }
    .validate()
    .map_err(|err| Bug::InvalidKeyShare(err.into_error()))?)
}