cggmp21_keygen/
non_threshold.rs

1use alloc::vec::Vec;
2
3use digest::Digest;
4use generic_ec::{Curve, NonZero, Point, Scalar, SecretScalar};
5use generic_ec_zkp::schnorr_pok;
6use rand_core::{CryptoRng, RngCore};
7use round_based::{
8    rounds_router::simple_store::RoundInput, rounds_router::RoundsRouter, Delivery, Mpc, MpcParty,
9    Outgoing, ProtocolMessage, SinkExt,
10};
11use serde::{Deserialize, Serialize};
12
13use crate::progress::Tracer;
14use crate::{
15    errors::IoError,
16    key_share::{CoreKeyShare, DirtyCoreKeyShare, DirtyKeyInfo, Validate},
17    security_level::SecurityLevel,
18    utils, ExecutionId,
19};
20
21use super::{Bug, KeygenAborted, KeygenError};
22
/// Prepends this module's domain-separation prefix to a string literal at compile time.
///
/// For example, `prefixed!("round1")` expands to the `&'static str`
/// `"dfns.cggmp21.keygen.non_threshold.round1"`.
macro_rules! prefixed {
    [$suffix:tt] => {
        concat!("dfns.cggmp21.keygen.non_threshold.", $suffix)
    };
}
28
29/// Message of key generation protocol
30#[derive(ProtocolMessage, Clone, Serialize, Deserialize)]
31#[serde(bound = "")]
32pub enum Msg<E: Curve, L: SecurityLevel, D: Digest> {
33    /// Round 1 message
34    Round1(MsgRound1<D>),
35    /// Reliability check message (optional additional round)
36    ReliabilityCheck(MsgReliabilityCheck<D>),
37    /// Round 2 message
38    Round2(MsgRound2<E, L>),
39    /// Round 3 message
40    Round3(MsgRound3<E>),
41}
42
43/// Message from round 1
44#[derive(Clone, Serialize, Deserialize, udigest::Digestable)]
45#[serde(bound = "")]
46#[udigest(bound = "")]
47#[udigest(tag = prefixed!("round1"))]
48pub struct MsgRound1<D: Digest> {
49    /// $V_i$
50    #[udigest(as_bytes)]
51    pub commitment: digest::Output<D>,
52}
53/// Message from round 2
54#[serde_with::serde_as]
55#[derive(Clone, Serialize, Deserialize, udigest::Digestable)]
56#[serde(bound = "")]
57#[udigest(bound = "")]
58#[udigest(tag = prefixed!("round2"))]
59pub struct MsgRound2<E: Curve, L: SecurityLevel> {
60    /// `rid_i`
61    #[serde_as(as = "utils::HexOrBin")]
62    #[udigest(as_bytes)]
63    pub rid: L::Rid,
64    /// $X_i$
65    pub X: NonZero<Point<E>>,
66    /// $A_i$
67    pub sch_commit: schnorr_pok::Commit<E>,
68    /// Party contribution to chain code
69    #[cfg(feature = "hd-wallet")]
70    #[serde_as(as = "Option<utils::HexOrBin>")]
71    #[udigest(as = Option<udigest::Bytes>)]
72    pub chain_code: Option<hd_wallet::ChainCode>,
73    /// $u_i$
74    #[serde(with = "hex::serde")]
75    #[udigest(as_bytes)]
76    pub decommit: L::Rid,
77}
78/// Message from round 3
79#[derive(Clone, Serialize, Deserialize)]
80#[serde(bound = "")]
81pub struct MsgRound3<E: Curve> {
82    /// $\psi_i$
83    pub sch_proof: schnorr_pok::Proof<E>,
84}
85/// Message parties exchange to ensure reliability of broadcast channel
86#[derive(Clone, Serialize, Deserialize)]
87#[serde(bound = "")]
88pub struct MsgReliabilityCheck<D: Digest>(pub digest::Output<D>);
89
90mod unambiguous {
91    use crate::{ExecutionId, SecurityLevel};
92    use generic_ec::Curve;
93
94    #[derive(udigest::Digestable)]
95    #[udigest(tag = prefixed!("hash_commitment"))]
96    #[udigest(bound = "")]
97    pub struct HashCom<'a, E: Curve, L: SecurityLevel> {
98        pub sid: ExecutionId<'a>,
99        pub party_index: u16,
100        pub decommitment: &'a super::MsgRound2<E, L>,
101    }
102
103    #[derive(udigest::Digestable)]
104    #[udigest(tag = prefixed!("schnorr_pok"))]
105    #[udigest(bound = "")]
106    pub struct SchnorrPok<'a> {
107        pub sid: ExecutionId<'a>,
108        pub prover: u16,
109        #[udigest(as_bytes)]
110        pub rid: &'a [u8],
111    }
112
113    #[derive(udigest::Digestable)]
114    #[udigest(tag = prefixed!("echo_round"))]
115    #[udigest(bound = "")]
116    pub struct Echo<'a, D: digest::Digest> {
117        pub sid: ExecutionId<'a>,
118        pub commitment: &'a super::MsgRound1<D>,
119    }
120}
121
122pub async fn run_keygen<E, R, M, L, D>(
123    mut tracer: Option<&mut dyn Tracer>,
124    i: u16,
125    n: u16,
126    reliable_broadcast_enforced: bool,
127    sid: ExecutionId<'_>,
128    rng: &mut R,
129    party: M,
130    #[cfg(feature = "hd-wallet")] hd_enabled: bool,
131) -> Result<CoreKeyShare<E>, KeygenError>
132where
133    E: Curve,
134    L: SecurityLevel,
135    D: Digest + Clone + 'static,
136    R: RngCore + CryptoRng,
137    M: Mpc<ProtocolMessage = Msg<E, L, D>>,
138{
139    tracer.protocol_begins();
140
141    tracer.stage("Setup networking");
142    let MpcParty { delivery, .. } = party.into_party();
143    let (incomings, mut outgoings) = delivery.split();
144
145    let mut rounds = RoundsRouter::<Msg<E, L, D>>::builder();
146    let round1 = rounds.add_round(RoundInput::<MsgRound1<D>>::broadcast(i, n));
147    let round1_sync = rounds.add_round(RoundInput::<MsgReliabilityCheck<D>>::broadcast(i, n));
148    let round2 = rounds.add_round(RoundInput::<MsgRound2<E, L>>::broadcast(i, n));
149    let round3 = rounds.add_round(RoundInput::<MsgRound3<E>>::broadcast(i, n));
150    let mut rounds = rounds.listen(incomings);
151
152    // Round 1
153    tracer.round_begins();
154
155    tracer.stage("Sample x_i, rid_i, chain_code");
156    let x_i = NonZero::<SecretScalar<E>>::random(rng);
157    let X_i = Point::generator() * &x_i;
158
159    let mut rid = L::Rid::default();
160    rng.fill_bytes(rid.as_mut());
161
162    #[cfg(feature = "hd-wallet")]
163    let chain_code_local = if hd_enabled {
164        let mut chain_code = hd_wallet::ChainCode::default();
165        rng.fill_bytes(&mut chain_code);
166        Some(chain_code)
167    } else {
168        None
169    };
170
171    tracer.stage("Sample schnorr commitment");
172    let (sch_secret, sch_commit) = schnorr_pok::prover_commits_ephemeral_secret::<E, _>(rng);
173
174    tracer.stage("Commit to public data");
175    let my_decommitment = MsgRound2 {
176        rid,
177        X: X_i,
178        sch_commit,
179        #[cfg(feature = "hd-wallet")]
180        chain_code: chain_code_local,
181        decommit: {
182            let mut nonce = L::Rid::default();
183            rng.fill_bytes(nonce.as_mut());
184            nonce
185        },
186    };
187    let hash_commit = udigest::hash::<D>(&unambiguous::HashCom {
188        sid,
189        party_index: i,
190        decommitment: &my_decommitment,
191    });
192    let my_commitment = MsgRound1 {
193        commitment: hash_commit,
194    };
195
196    tracer.send_msg();
197    outgoings
198        .send(Outgoing::broadcast(Msg::Round1(my_commitment.clone())))
199        .await
200        .map_err(IoError::send_message)?;
201    tracer.msg_sent();
202
203    // Round 2
204    tracer.round_begins();
205
206    tracer.receive_msgs();
207    let commitments = rounds
208        .complete(round1)
209        .await
210        .map_err(IoError::receive_message)?;
211    tracer.msgs_received();
212
213    // Optional reliability check
214    if reliable_broadcast_enforced {
215        tracer.stage("Hash received msgs (reliability check)");
216        let h_i = udigest::hash_iter::<D>(
217            commitments
218                .iter_including_me(&my_commitment)
219                .map(|commitment| unambiguous::Echo { sid, commitment }),
220        );
221
222        tracer.send_msg();
223        outgoings
224            .send(Outgoing::broadcast(Msg::ReliabilityCheck(
225                MsgReliabilityCheck(h_i.clone()),
226            )))
227            .await
228            .map_err(IoError::send_message)?;
229        tracer.msg_sent();
230
231        tracer.round_begins();
232
233        tracer.receive_msgs();
234        let round1_hashes = rounds
235            .complete(round1_sync)
236            .await
237            .map_err(IoError::receive_message)?;
238        tracer.msgs_received();
239
240        tracer.stage("Assert other parties hashed messages (reliability check)");
241        let parties_have_different_hashes = round1_hashes
242            .into_iter_indexed()
243            .filter(|(_j, _msg_id, hash_j)| hash_j.0 != h_i)
244            .map(|(j, msg_id, _)| (j, msg_id))
245            .collect::<Vec<_>>();
246        if !parties_have_different_hashes.is_empty() {
247            return Err(KeygenAborted::Round1NotReliable(parties_have_different_hashes).into());
248        }
249    }
250
251    tracer.send_msg();
252    outgoings
253        .send(Outgoing::broadcast(Msg::Round2(my_decommitment.clone())))
254        .await
255        .map_err(IoError::send_message)?;
256    tracer.msg_sent();
257
258    // Round 3
259    tracer.round_begins();
260
261    tracer.receive_msgs();
262    let decommitments = rounds
263        .complete(round2)
264        .await
265        .map_err(IoError::receive_message)?;
266    tracer.msgs_received();
267
268    tracer.stage("Validate decommitments");
269    let blame = utils::collect_blame(&commitments, &decommitments, |j, com, decom| {
270        let com_expected = udigest::hash::<D>(&unambiguous::HashCom {
271            sid,
272            party_index: j,
273            decommitment: decom,
274        });
275        com.commitment != com_expected
276    });
277    if !blame.is_empty() {
278        return Err(KeygenAborted::InvalidDecommitment(blame).into());
279    }
280
281    #[cfg(feature = "hd-wallet")]
282    let chain_code = if hd_enabled {
283        tracer.stage("Calculate chain_code");
284        let blame = utils::collect_simple_blame(&decommitments, |decom| decom.chain_code.is_none());
285        if !blame.is_empty() {
286            return Err(KeygenAborted::MissingChainCode(blame).into());
287        }
288        Some(decommitments.iter_including_me(&my_decommitment).try_fold(
289            hd_wallet::ChainCode::default(),
290            |acc, decom| {
291                Ok::<_, Bug>(utils::xor_array(
292                    acc,
293                    decom.chain_code.ok_or(Bug::NoChainCode)?,
294                ))
295            },
296        )?)
297    } else {
298        None
299    };
300
301    tracer.stage("Calculate challege rid");
302    let rid = decommitments
303        .iter_including_me(&my_decommitment)
304        .map(|d| &d.rid)
305        .fold(L::Rid::default(), utils::xor_array);
306    let challenge = Scalar::from_hash::<D>(&unambiguous::SchnorrPok {
307        sid,
308        prover: i,
309        rid: rid.as_ref(),
310    });
311    let challenge = schnorr_pok::Challenge { nonce: challenge };
312
313    tracer.stage("Prove knowledge of `x_i`");
314    let sch_proof = schnorr_pok::prove(&sch_secret, &challenge, &x_i);
315
316    tracer.send_msg();
317    let my_sch_proof = MsgRound3 { sch_proof };
318    outgoings
319        .send(Outgoing::broadcast(Msg::Round3(my_sch_proof.clone())))
320        .await
321        .map_err(IoError::send_message)?;
322    tracer.msg_sent();
323
324    // Round 4
325    tracer.round_begins();
326
327    tracer.receive_msgs();
328    let sch_proofs = rounds
329        .complete(round3)
330        .await
331        .map_err(IoError::receive_message)?;
332    tracer.msgs_received();
333
334    tracer.stage("Validate schnorr proofs");
335    let blame = utils::collect_blame(&decommitments, &sch_proofs, |j, decom, sch_proof| {
336        let challenge = Scalar::from_hash::<D>(&unambiguous::SchnorrPok {
337            sid,
338            prover: j,
339            rid: rid.as_ref(),
340        });
341        let challenge = schnorr_pok::Challenge { nonce: challenge };
342        sch_proof
343            .sch_proof
344            .verify(&decom.sch_commit, &challenge, &decom.X)
345            .is_err()
346    });
347    if !blame.is_empty() {
348        return Err(KeygenAborted::InvalidSchnorrProof(blame).into());
349    }
350
351    tracer.protocol_ends();
352
353    Ok(DirtyCoreKeyShare {
354        i,
355        key_info: DirtyKeyInfo {
356            curve: Default::default(),
357            shared_public_key: NonZero::from_point(
358                decommitments
359                    .iter_including_me(&my_decommitment)
360                    .map(|d| d.X)
361                    .sum(),
362            )
363            .ok_or(Bug::ZeroPk)?,
364            public_shares: decommitments
365                .iter_including_me(&my_decommitment)
366                .map(|d| d.X)
367                .collect(),
368            vss_setup: None,
369            #[cfg(feature = "hd-wallet")]
370            chain_code,
371        },
372        x: x_i,
373    }
374    .validate()
375    .map_err(|e| Bug::InvalidKeyShare(e.into_error()))?)
376}