1use alloc::vec::Vec;
2
3use digest::Digest;
4use generic_ec::{Curve, NonZero, Point, Scalar, SecretScalar};
5use generic_ec_zkp::{polynomial::Polynomial, schnorr_pok};
6use rand_core::{CryptoRng, RngCore};
7use round_based::{
8 rounds_router::simple_store::RoundInput, rounds_router::RoundsRouter, Delivery, Mpc, MpcParty,
9 Outgoing, ProtocolMessage, SinkExt,
10};
11use serde::{Deserialize, Serialize};
12use serde_with::serde_as;
13
14use crate::progress::Tracer;
15use crate::{
16 errors::IoError,
17 key_share::{CoreKeyShare, DirtyCoreKeyShare, DirtyKeyInfo, Validate, VssSetup},
18 security_level::SecurityLevel,
19 utils, ExecutionId,
20};
21
22use super::{Bug, KeygenAborted, KeygenError};
23
/// Builds a domain-separation tag for this protocol at compile time.
///
/// The single token `$suffix` (expected to be a string literal) is appended to the
/// `"dfns.cggmp21.keygen.threshold."` namespace via [`concat!`]. The result is a string
/// literal, so it can be used wherever one is accepted, e.g. `#[udigest(tag = ...)]`.
macro_rules! prefixed {
    [$suffix:tt] => (
        concat!("dfns.cggmp21.keygen.threshold.", $suffix)
    );
}
29
/// Message of the threshold key generation protocol (one variant per message kind).
///
/// NOTE(review): the derived `ProtocolMessage` and serde impls are likely sensitive to
/// variant order; reordering variants may break wire compatibility — confirm before
/// changing.
#[derive(ProtocolMessage, Clone, Serialize, Deserialize)]
#[serde(bound = "")]
pub enum Msg<E: Curve, L: SecurityLevel, D: Digest> {
    /// Round 1: broadcast hash commitment to the sender's round 2 broadcast.
    Round1(MsgRound1<D>),
    /// Round 2: broadcast decommitment, i.e. the data committed to in round 1.
    Round2Broad(MsgRound2Broad<E, L>),
    /// Round 2: secret share sent peer-to-peer to a single recipient.
    Round2Uni(MsgRound2Uni<E>),
    /// Round 3: broadcast Schnorr proof of knowledge of the sender's secret share.
    Round3(MsgRound3<E>),
    /// Optional reliability-check round, exchanged only when reliable broadcast is
    /// enforced: hash of all round 1 commitments as seen by the sender.
    ReliabilityCheck(MsgReliabilityCheck<D>),
}
45
/// Round 1 message: hash commitment to the sender's round 2 broadcast.
#[derive(Clone, Serialize, Deserialize, udigest::Digestable)]
#[serde(bound = "")]
#[udigest(bound = "")]
#[udigest(tag = prefixed!("round1"))]
pub struct MsgRound1<D: Digest> {
    /// Hash (under `D`) of the sender's `MsgRound2Broad`, computed together with the
    /// execution id and the sender's party index.
    #[udigest(as_bytes)]
    pub commitment: digest::Output<D>,
}
/// Round 2 message broadcast to every party: the decommitment of the round 1 commitment.
///
/// The whole struct (including `decommit`) is the preimage of `MsgRound1::commitment`,
/// so field order and the serde/udigest attributes below are part of the wire/hash format.
#[serde_as]
#[derive(Clone, Serialize, Deserialize, udigest::Digestable)]
#[serde(bound = "")]
#[udigest(bound = "")]
#[udigest(tag = prefixed!("round2_broad"))]
pub struct MsgRound2Broad<E: Curve, L: SecurityLevel> {
    /// Sender's random contribution to the shared `rid`; all parties' contributions
    /// are XOR-ed together.
    #[serde_as(as = "utils::HexOrBin")]
    #[udigest(as_bytes)]
    pub rid: L::Rid,
    /// Public commitment to the sender's secret polynomial (each coefficient multiplied
    /// by the curve generator). Receivers require its degree to be `t - 1`.
    pub F: Polynomial<Point<E>>,
    /// Commitment of the sender's round 3 Schnorr proof of knowledge.
    pub sch_commit: schnorr_pok::Commit<E>,
    /// Sender's contribution to the HD-wallet chain code; `Some` when HD derivation is
    /// enabled. Contributions are XOR-ed together.
    #[cfg(feature = "hd-wallet")]
    #[serde_as(as = "Option<utils::HexOrBin>")]
    #[udigest(as = Option<udigest::Bytes>)]
    pub chain_code: Option<hd_wallet::ChainCode>,
    /// Fresh random bytes hashed together with the rest of this message when forming
    /// the round 1 commitment (salt of the hash commitment).
    ///
    /// NOTE(review): serialized via `hex::serde`, unlike `rid` which uses
    /// `utils::HexOrBin`; changing either would alter the wire format.
    #[serde(with = "hex::serde")]
    #[udigest(as_bytes)]
    pub decommit: L::Rid,
}
/// Round 2 message sent peer-to-peer to one specific party.
#[derive(Clone, Serialize, Deserialize)]
#[serde(bound = "")]
pub struct MsgRound2Uni<E: Curve> {
    /// Sender's secret polynomial evaluated at the recipient's point (recipient
    /// index + 1). This is secret share material and is only sent via `Outgoing::p2p`.
    pub sigma: Scalar<E>,
}
/// Round 3 message broadcast to every party.
#[derive(Clone, Serialize, Deserialize)]
#[serde(bound = "")]
pub struct MsgRound3<E: Curve> {
    /// Schnorr proof of knowledge of the sender's combined secret share; verified
    /// against the sender's public share and its round 2 `sch_commit`.
    pub sch_proof: schnorr_pok::Proof<E>,
}
/// Message of the optional reliability-check round: hash of the list of all round 1
/// commitments (the sender's own included) as received by the sender.
#[derive(Clone, Serialize, Deserialize)]
#[serde(bound = "")]
pub struct MsgReliabilityCheck<D: Digest>(pub digest::Output<D>);
99
/// Structures hashed with `udigest` inside this protocol.
///
/// Each struct has its own `prefixed!` tag, so hashes computed for different purposes
/// (commitments, Schnorr challenges, echo hashes) are domain-separated. Field order
/// determines the digest, so treat these layouts as fixed.
mod unambiguous {
    use generic_ec::{Curve, NonZero, Point};

    use crate::{ExecutionId, SecurityLevel};

    /// Input of the round 1 hash commitment.
    #[derive(udigest::Digestable)]
    #[udigest(tag = prefixed!("hash_commitment"))]
    #[udigest(bound = "")]
    pub struct HashCom<'a, E: Curve, L: SecurityLevel> {
        /// Execution id of the protocol run.
        pub sid: ExecutionId<'a>,
        /// Index of the committing party.
        pub party_index: u16,
        /// The round 2 broadcast being committed to.
        pub decommitment: &'a super::MsgRound2Broad<E, L>,
    }

    /// Input from which the challenge of the round 3 Schnorr proof is derived
    /// (Fiat-Shamir style: the challenge is a hash of this struct).
    #[derive(udigest::Digestable)]
    #[udigest(tag = prefixed!("schnorr_pok"))]
    #[udigest(bound = "")]
    pub struct SchnorrPok<'a, E: Curve> {
        /// Execution id of the protocol run.
        pub sid: ExecutionId<'a>,
        /// Index of the proving party.
        pub prover: u16,
        /// Combined `rid` (XOR of all parties' contributions).
        #[udigest(as_bytes)]
        pub rid: &'a [u8],
        /// Prover's public share — the statement being proven.
        pub y: NonZero<Point<E>>,
        /// Prover's Schnorr commitment point.
        pub h: Point<E>,
    }

    /// Per-commitment input of the reliability-check (echo) hash.
    #[derive(udigest::Digestable)]
    #[udigest(tag = prefixed!("echo_round"))]
    #[udigest(bound = "")]
    pub struct Echo<'a, D: digest::Digest> {
        /// Execution id of the protocol run.
        pub sid: ExecutionId<'a>,
        /// One round 1 commitment (received from a peer, or our own).
        pub commitment: &'a super::MsgRound1<D>,
    }
}
134
/// Runs the threshold (`t`-out-of-`n`) key generation protocol as party `i`.
///
/// Outline of what the body does:
/// 1. Sample `rid_i`, an ephemeral Schnorr secret/commitment, a random polynomial `f`
///    of degree `t - 1` and (with `hd-wallet` and `hd_enabled`) a chain code
///    contribution; broadcast a hash commitment to the public part of this data.
/// 2. If `reliable_broadcast_enforced`, run an extra echo round: broadcast a hash of
///    all round 1 commitments and abort if any party's hash differs.
/// 3. Broadcast the decommitment and send every peer `j` its share `f(j + 1)`
///    privately; validate the received decommitments, polynomial degrees and Feldman
///    VSS checks; derive the combined `rid`, chain code, public shares and own share.
/// 4. Broadcast a Schnorr proof of knowledge of the own share, verify everyone else's
///    proof and assemble the key share.
///
/// # Parameters
/// - `tracer`: optional progress tracer.
/// - `i`: this party's index; used to index `n`-element vectors, so it must be `< n`.
/// - `t`: threshold; polynomials have degree `t - 1` and `t` is stored as
///   `min_signers`. NOTE(review): `t == 0` would underflow in `usize::from(t) - 1`;
///   presumably the caller guarantees `t >= 1`.
/// - `n`: number of parties.
/// - `reliable_broadcast_enforced`: whether to run the round 1 echo/reliability check.
/// - `sid`: execution id; included in every hash computed by the protocol.
/// - `rng`: cryptographically secure RNG used for all sampling.
/// - `party`: MPC party providing the message delivery channel.
/// - `hd_enabled` (feature `hd-wallet`): whether to jointly generate a chain code.
///
/// # Errors
/// - `KeygenAborted` variants (converted into `KeygenError`) carrying the blamed
///   parties when another party misbehaves.
/// - `IoError` when sending or receiving a message fails.
/// - `Bug` when an internal invariant is violated (e.g. a zero share or an invalid
///   resulting key share).
pub async fn run_threshold_keygen<E, R, M, L, D>(
    mut tracer: Option<&mut dyn Tracer>,
    i: u16,
    t: u16,
    n: u16,
    reliable_broadcast_enforced: bool,
    sid: ExecutionId<'_>,
    rng: &mut R,
    party: M,
    #[cfg(feature = "hd-wallet")] hd_enabled: bool,
) -> Result<CoreKeyShare<E>, KeygenError>
where
    E: Curve,
    L: SecurityLevel,
    D: Digest + Clone + 'static,
    R: RngCore + CryptoRng,
    M: Mpc<ProtocolMessage = Msg<E, L, D>>,
{
    tracer.protocol_begins();

    tracer.stage("Setup networking");
    let MpcParty { delivery, .. } = party.into_party();
    let (incomings, mut outgoings) = delivery.split();

    // All rounds are registered upfront, including `round1_sync`, which is only
    // completed when `reliable_broadcast_enforced` is set.
    let mut rounds = RoundsRouter::<Msg<E, L, D>>::builder();
    let round1 = rounds.add_round(RoundInput::<MsgRound1<D>>::broadcast(i, n));
    let round1_sync = rounds.add_round(RoundInput::<MsgReliabilityCheck<D>>::broadcast(i, n));
    let round2_broad = rounds.add_round(RoundInput::<MsgRound2Broad<E, L>>::broadcast(i, n));
    let round2_uni = rounds.add_round(RoundInput::<MsgRound2Uni<E>>::p2p(i, n));
    let round3 = rounds.add_round(RoundInput::<MsgRound3<E>>::broadcast(i, n));
    let mut rounds = rounds.listen(incomings);

    // ----- Round 1: sample secrets and commit to the public data -----
    tracer.round_begins();

    tracer.stage("Sample rid_i, schnorr commitment, polynomial, chain_code");
    let mut rid = L::Rid::default();
    rng.fill_bytes(rid.as_mut());

    // `r` is the ephemeral Schnorr secret (used in round 3), `h` its public commitment.
    let (r, h) = schnorr_pok::prover_commits_ephemeral_secret::<E, _>(rng);

    // Secret polynomial of degree `t - 1`; `F` commits to it coefficient-wise (`f * G`).
    let f = Polynomial::<SecretScalar<E>>::sample(rng, usize::from(t) - 1);
    let F = &f * &Point::generator();
    // `sigmas[j] = f(j + 1)` is the share destined for party `j`. Evaluation points
    // start at 1 so that no party is handed `f(0)`.
    let sigmas = (0..n)
        .map(|j| {
            let x = Scalar::from(j + 1);
            f.value(&x)
        })
        .collect::<Vec<_>>();
    debug_assert_eq!(sigmas.len(), usize::from(n));

    #[cfg(feature = "hd-wallet")]
    let chain_code_local = if hd_enabled {
        let mut chain_code = hd_wallet::ChainCode::default();
        rng.fill_bytes(&mut chain_code);
        Some(chain_code)
    } else {
        None
    };

    tracer.stage("Commit to public data");
    // Everything other parties need in round 2, plus a random `decommit` nonce that is
    // hashed together with it.
    let my_decommitment = MsgRound2Broad {
        rid,
        F: F.clone(),
        sch_commit: h,
        #[cfg(feature = "hd-wallet")]
        chain_code: chain_code_local,
        decommit: {
            let mut nonce = L::Rid::default();
            rng.fill_bytes(nonce.as_mut());
            nonce
        },
    };
    // The commitment also covers the execution id and our index; receivers recompute
    // it with the sender's index when validating decommitments.
    let hash_commit = udigest::hash::<D>(&unambiguous::HashCom {
        sid,
        party_index: i,
        decommitment: &my_decommitment,
    });

    tracer.send_msg();
    let my_commitment = MsgRound1 {
        commitment: hash_commit,
    };
    outgoings
        .send(Outgoing::broadcast(Msg::Round1(my_commitment.clone())))
        .await
        .map_err(IoError::send_message)?;
    tracer.msg_sent();

    // ----- Round 2: collect commitments, optionally check reliability, reveal -----
    tracer.round_begins();

    tracer.receive_msgs();
    let commitments = rounds
        .complete(round1)
        .await
        .map_err(IoError::receive_message)?;
    tracer.msgs_received();

    if reliable_broadcast_enforced {
        // Optional echo round: hash the full list of round 1 commitments (ours included)
        // and compare with every other party's hash. A mismatch means parties did not
        // all receive the same round 1 broadcast.
        tracer.stage("Hash received msgs (reliability check)");
        let h_i = udigest::hash_iter::<D>(
            commitments
                .iter_including_me(&my_commitment)
                .map(|commitment| unambiguous::Echo { sid, commitment }),
        );

        tracer.send_msg();
        outgoings
            .send(Outgoing::broadcast(Msg::ReliabilityCheck(
                MsgReliabilityCheck(h_i.clone()),
            )))
            .await
            .map_err(IoError::send_message)?;
        tracer.msg_sent();

        tracer.round_begins();

        tracer.receive_msgs();
        let hashes = rounds
            .complete(round1_sync)
            .await
            .map_err(IoError::receive_message)?;
        tracer.msgs_received();

        tracer.stage("Assert other parties hashed messages (reliability check)");
        // Blame every party whose echo hash differs from ours.
        let parties_have_different_hashes = hashes
            .into_iter_indexed()
            .filter(|(_j, _msg_id, h_j)| h_i != h_j.0)
            .map(|(j, msg_id, _)| (j, msg_id))
            .collect::<Vec<_>>();
        if !parties_have_different_hashes.is_empty() {
            return Err(KeygenAborted::Round1NotReliable(parties_have_different_hashes).into());
        }
    }

    // Reveal the committed data to everyone...
    tracer.send_msg();
    outgoings
        .send(Outgoing::broadcast(Msg::Round2Broad(
            my_decommitment.clone(),
        )))
        .await
        .map_err(IoError::send_message)?;

    // ...and privately send each peer `j` its share `f(j + 1)`.
    // NOTE(review): assumes `iter_peers(i, n)` yields every index in `0..n` except `i`.
    for j in utils::iter_peers(i, n) {
        let message = MsgRound2Uni {
            sigma: sigmas[usize::from(j)],
        };
        outgoings
            .send(Outgoing::p2p(j, Msg::Round2Uni(message)))
            .await
            .map_err(IoError::send_message)?;
    }
    tracer.msg_sent();

    // ----- Round 3: validate round 2 data, derive key material, prove knowledge -----
    tracer.round_begins();

    tracer.receive_msgs();
    let decommitments = rounds
        .complete(round2_broad)
        .await
        .map_err(IoError::receive_message)?;
    let sigmas_msg = rounds
        .complete(round2_uni)
        .await
        .map_err(IoError::receive_message)?;
    tracer.msgs_received();

    tracer.stage("Validate decommitments");
    // Recompute each sender's commitment from its decommitment and blame mismatches.
    let blame = utils::collect_blame(&commitments, &decommitments, |j, com, decom| {
        let com_expected = udigest::hash::<D>(&unambiguous::HashCom {
            sid,
            party_index: j,
            decommitment: decom,
        });
        com.commitment != com_expected
    });
    if !blame.is_empty() {
        return Err(KeygenAborted::InvalidDecommitment(blame).into());
    }

    tracer.stage("Validate data size");
    // Every committed polynomial must have degree exactly `t - 1`.
    let blame = decommitments
        .iter_indexed()
        .filter(|(_, _, d)| d.F.degree() + 1 != usize::from(t))
        .map(|t| t.0)
        .collect::<Vec<_>>();
    if !blame.is_empty() {
        return Err(KeygenAborted::InvalidDataSize { parties: blame }.into());
    }

    tracer.stage("Validate Feldmann VSS");
    // Feldman check: the share received from party `j` must satisfy
    // `sigma * G == F_j(i + 1)`.
    // NOTE(review): entries are paired with `zip`, which assumes `decommitments` and
    // `sigmas_msg` yield messages in the same sender order.
    let blame = decommitments
        .iter_indexed()
        .zip(sigmas_msg.iter())
        .filter(|((_, _, d), s)| {
            d.F.value::<_, Point<_>>(&Scalar::from(i + 1)) != Point::generator() * s.sigma
        })
        .map(|t| t.0 .0)
        .collect::<Vec<_>>();
    if !blame.is_empty() {
        return Err(KeygenAborted::FeldmanVerificationFailed { parties: blame }.into());
    }

    tracer.stage("Compute rid");
    // Combined `rid` is the XOR of every party's contribution, ours included.
    let rid = decommitments
        .iter_including_me(&my_decommitment)
        .map(|d| &d.rid)
        .fold(L::Rid::default(), utils::xor_array);
    #[cfg(feature = "hd-wallet")]
    let chain_code = if hd_enabled {
        tracer.stage("Compute chain_code");
        // With HD derivation enabled, every other party must contribute a chain code.
        let blame = utils::collect_simple_blame(&decommitments, |decom| decom.chain_code.is_none());
        if !blame.is_empty() {
            return Err(KeygenAborted::MissingChainCode(blame).into());
        }
        // Combined chain code is the XOR of all contributions. Peers' were checked
        // above and ours is `Some` because `hd_enabled`, so `Bug::NoChainCode` is not
        // expected to be hit.
        Some(decommitments.iter_including_me(&my_decommitment).try_fold(
            hd_wallet::ChainCode::default(),
            |acc, decom| {
                Ok::<_, Bug>(utils::xor_array(
                    acc,
                    decom.chain_code.ok_or(Bug::NoChainCode)?,
                ))
            },
        )?)
    } else {
        None
    };
    tracer.stage("Compute Ys");
    // Public share of party `l`: `ys[l] = (sum of all F_j)(l + 1)`.
    let polynomial_sum = decommitments
        .iter_including_me(&my_decommitment)
        .map(|d| &d.F)
        .sum::<Polynomial<_>>();
    let ys = (0..n)
        .map(|l| polynomial_sum.value(&Scalar::from(l + 1)))
        .map(|y_j: Point<E>| NonZero::from_point(y_j).ok_or(Bug::ZeroShare))
        .collect::<Result<Vec<_>, _>>()?;
    tracer.stage("Compute sigma");
    // Own secret share: sum of shares received from peers plus our own `f(i + 1)`.
    let sigma: Scalar<E> = sigmas_msg.iter().map(|msg| msg.sigma).sum();
    let mut sigma = sigma + sigmas[usize::from(i)];
    let sigma = NonZero::from_secret_scalar(SecretScalar::new(&mut sigma)).ok_or(Bug::ZeroShare)?;
    debug_assert_eq!(Point::generator() * &sigma, ys[usize::from(i)]);

    tracer.stage("Calculate challenge");
    // Non-interactive challenge: hash of the execution id, prover index, combined rid,
    // our public share and our Schnorr commitment.
    let challenge = Scalar::from_hash::<D>(&unambiguous::SchnorrPok {
        sid,
        prover: i,
        rid: rid.as_ref(),
        y: ys[usize::from(i)],
        h: my_decommitment.sch_commit.0,
    });
    let challenge = schnorr_pok::Challenge { nonce: challenge };

    tracer.stage("Prove knowledge of `sigma_i`");
    let z = schnorr_pok::prove(&r, &challenge, &sigma);

    tracer.send_msg();
    let my_sch_proof = MsgRound3 { sch_proof: z };
    outgoings
        .send(Outgoing::broadcast(Msg::Round3(my_sch_proof.clone())))
        .await
        .map_err(IoError::send_message)?;
    tracer.msg_sent();

    // ----- Output: verify peers' proofs and assemble the key share -----
    tracer.round_begins();

    tracer.receive_msgs();
    let sch_proofs = rounds
        .complete(round3)
        .await
        .map_err(IoError::receive_message)?;
    tracer.msgs_received();

    tracer.stage("Validate schnorr proofs");
    // Recompute each prover's challenge exactly as the prover did, then verify its
    // proof against its public share `ys[j]` and its round 2 Schnorr commitment.
    let blame = utils::collect_blame(&decommitments, &sch_proofs, |j, decom, sch_proof| {
        let challenge = Scalar::from_hash::<D>(&unambiguous::SchnorrPok {
            sid,
            prover: j,
            rid: rid.as_ref(),
            y: ys[usize::from(j)],
            h: decom.sch_commit.0,
        });
        let challenge = schnorr_pok::Challenge { nonce: challenge };
        sch_proof
            .sch_proof
            .verify(&decom.sch_commit, &challenge, &ys[usize::from(j)])
            .is_err()
    });
    if !blame.is_empty() {
        return Err(KeygenAborted::InvalidSchnorrProof(blame).into());
    }

    tracer.stage("Derive resulting public key and other data");
    // Shared public key: sum of the constant terms `F_j(0)` of all committed polynomials.
    let y: Point<E> = decommitments
        .iter_including_me(&my_decommitment)
        .map(|d| d.F.coefs()[0])
        .sum();
    // Party `l`'s evaluation point is `l + 1`, matching how shares were computed above.
    let key_shares_indexes = (1..=n)
        .map(|i| NonZero::from_scalar(Scalar::from(i)))
        .collect::<Option<Vec<_>>>()
        .ok_or(Bug::NonZeroScalar)?;

    tracer.protocol_ends();

    // A validation failure here is reported as an internal bug (`Bug::InvalidKeyShare`)
    // rather than as misbehaviour of another party.
    Ok(DirtyCoreKeyShare {
        i,
        key_info: DirtyKeyInfo {
            curve: Default::default(),
            shared_public_key: NonZero::from_point(y).ok_or(Bug::ZeroPk)?,
            public_shares: ys,
            vss_setup: Some(VssSetup {
                min_signers: t,
                I: key_shares_indexes,
            }),
            #[cfg(feature = "hd-wallet")]
            chain_code,
        },
        x: sigma,
    }
    .validate()
    .map_err(|err| Bug::InvalidKeyShare(err.into_error()))?)
}