use alloc::vec::Vec;

use digest::Digest;
use generic_ec::{Curve, NonZero, Point, Scalar, SecretScalar};
use generic_ec_zkp::schnorr_pok;
use rand_core::{CryptoRng, RngCore};
use round_based::{
    rounds_router::simple_store::RoundInput, rounds_router::RoundsRouter, Delivery, Mpc, MpcParty,
    Outgoing, ProtocolMessage, SinkExt,
};
use serde::{Deserialize, Serialize};

use crate::progress::Tracer;
use crate::{
    errors::IoError,
    key_share::{CoreKeyShare, DirtyCoreKeyShare, DirtyKeyInfo, Validate},
    security_level::SecurityLevel,
    utils, ExecutionId,
};

use super::{Bug, KeygenAborted, KeygenError};

/// Builds a domain-separation tag by prepending the protocol prefix
/// `dfns.cggmp21.keygen.non_threshold.` to the given string literal.
///
/// The expansion is a `concat!` of two literals, so the result is a
/// `&'static str` usable in attribute position (e.g. `#[udigest(tag = ...)]`).
macro_rules! prefixed {
    ($name:tt) => {
        concat!("dfns.cggmp21.keygen.non_threshold.", $name)
    };
}

29#[derive(ProtocolMessage, Clone, Serialize, Deserialize)]
31#[serde(bound = "")]
32pub enum Msg<E: Curve, L: SecurityLevel, D: Digest> {
33 Round1(MsgRound1<D>),
35 ReliabilityCheck(MsgReliabilityCheck<D>),
37 Round2(MsgRound2<E, L>),
39 Round3(MsgRound3<E>),
41}
42
43#[derive(Clone, Serialize, Deserialize, udigest::Digestable)]
45#[serde(bound = "")]
46#[udigest(bound = "")]
47#[udigest(tag = prefixed!("round1"))]
48pub struct MsgRound1<D: Digest> {
49 #[udigest(as_bytes)]
51 pub commitment: digest::Output<D>,
52}
53#[serde_with::serde_as]
55#[derive(Clone, Serialize, Deserialize, udigest::Digestable)]
56#[serde(bound = "")]
57#[udigest(bound = "")]
58#[udigest(tag = prefixed!("round2"))]
59pub struct MsgRound2<E: Curve, L: SecurityLevel> {
60 #[serde_as(as = "utils::HexOrBin")]
62 #[udigest(as_bytes)]
63 pub rid: L::Rid,
64 pub X: NonZero<Point<E>>,
66 pub sch_commit: schnorr_pok::Commit<E>,
68 #[cfg(feature = "hd-wallet")]
70 #[serde_as(as = "Option<utils::HexOrBin>")]
71 #[udigest(as = Option<udigest::Bytes>)]
72 pub chain_code: Option<hd_wallet::ChainCode>,
73 #[serde(with = "hex::serde")]
75 #[udigest(as_bytes)]
76 pub decommit: L::Rid,
77}
78#[derive(Clone, Serialize, Deserialize)]
80#[serde(bound = "")]
81pub struct MsgRound3<E: Curve> {
82 pub sch_proof: schnorr_pok::Proof<E>,
84}
85#[derive(Clone, Serialize, Deserialize)]
87#[serde(bound = "")]
88pub struct MsgReliabilityCheck<D: Digest>(pub digest::Output<D>);
89
90mod unambiguous {
91 use crate::{ExecutionId, SecurityLevel};
92 use generic_ec::Curve;
93
94 #[derive(udigest::Digestable)]
95 #[udigest(tag = prefixed!("hash_commitment"))]
96 #[udigest(bound = "")]
97 pub struct HashCom<'a, E: Curve, L: SecurityLevel> {
98 pub sid: ExecutionId<'a>,
99 pub party_index: u16,
100 pub decommitment: &'a super::MsgRound2<E, L>,
101 }
102
103 #[derive(udigest::Digestable)]
104 #[udigest(tag = prefixed!("schnorr_pok"))]
105 #[udigest(bound = "")]
106 pub struct SchnorrPok<'a> {
107 pub sid: ExecutionId<'a>,
108 pub prover: u16,
109 #[udigest(as_bytes)]
110 pub rid: &'a [u8],
111 }
112
113 #[derive(udigest::Digestable)]
114 #[udigest(tag = prefixed!("echo_round"))]
115 #[udigest(bound = "")]
116 pub struct Echo<'a, D: digest::Digest> {
117 pub sid: ExecutionId<'a>,
118 pub commitment: &'a super::MsgRound1<D>,
119 }
120}
121
122pub async fn run_keygen<E, R, M, L, D>(
123 mut tracer: Option<&mut dyn Tracer>,
124 i: u16,
125 n: u16,
126 reliable_broadcast_enforced: bool,
127 sid: ExecutionId<'_>,
128 rng: &mut R,
129 party: M,
130 #[cfg(feature = "hd-wallet")] hd_enabled: bool,
131) -> Result<CoreKeyShare<E>, KeygenError>
132where
133 E: Curve,
134 L: SecurityLevel,
135 D: Digest + Clone + 'static,
136 R: RngCore + CryptoRng,
137 M: Mpc<ProtocolMessage = Msg<E, L, D>>,
138{
139 tracer.protocol_begins();
140
141 tracer.stage("Setup networking");
142 let MpcParty { delivery, .. } = party.into_party();
143 let (incomings, mut outgoings) = delivery.split();
144
145 let mut rounds = RoundsRouter::<Msg<E, L, D>>::builder();
146 let round1 = rounds.add_round(RoundInput::<MsgRound1<D>>::broadcast(i, n));
147 let round1_sync = rounds.add_round(RoundInput::<MsgReliabilityCheck<D>>::broadcast(i, n));
148 let round2 = rounds.add_round(RoundInput::<MsgRound2<E, L>>::broadcast(i, n));
149 let round3 = rounds.add_round(RoundInput::<MsgRound3<E>>::broadcast(i, n));
150 let mut rounds = rounds.listen(incomings);
151
152 tracer.round_begins();
154
155 tracer.stage("Sample x_i, rid_i, chain_code");
156 let x_i = NonZero::<SecretScalar<E>>::random(rng);
157 let X_i = Point::generator() * &x_i;
158
159 let mut rid = L::Rid::default();
160 rng.fill_bytes(rid.as_mut());
161
162 #[cfg(feature = "hd-wallet")]
163 let chain_code_local = if hd_enabled {
164 let mut chain_code = hd_wallet::ChainCode::default();
165 rng.fill_bytes(&mut chain_code);
166 Some(chain_code)
167 } else {
168 None
169 };
170
171 tracer.stage("Sample schnorr commitment");
172 let (sch_secret, sch_commit) = schnorr_pok::prover_commits_ephemeral_secret::<E, _>(rng);
173
174 tracer.stage("Commit to public data");
175 let my_decommitment = MsgRound2 {
176 rid,
177 X: X_i,
178 sch_commit,
179 #[cfg(feature = "hd-wallet")]
180 chain_code: chain_code_local,
181 decommit: {
182 let mut nonce = L::Rid::default();
183 rng.fill_bytes(nonce.as_mut());
184 nonce
185 },
186 };
187 let hash_commit = udigest::hash::<D>(&unambiguous::HashCom {
188 sid,
189 party_index: i,
190 decommitment: &my_decommitment,
191 });
192 let my_commitment = MsgRound1 {
193 commitment: hash_commit,
194 };
195
196 tracer.send_msg();
197 outgoings
198 .send(Outgoing::broadcast(Msg::Round1(my_commitment.clone())))
199 .await
200 .map_err(IoError::send_message)?;
201 tracer.msg_sent();
202
203 tracer.round_begins();
205
206 tracer.receive_msgs();
207 let commitments = rounds
208 .complete(round1)
209 .await
210 .map_err(IoError::receive_message)?;
211 tracer.msgs_received();
212
213 if reliable_broadcast_enforced {
215 tracer.stage("Hash received msgs (reliability check)");
216 let h_i = udigest::hash_iter::<D>(
217 commitments
218 .iter_including_me(&my_commitment)
219 .map(|commitment| unambiguous::Echo { sid, commitment }),
220 );
221
222 tracer.send_msg();
223 outgoings
224 .send(Outgoing::broadcast(Msg::ReliabilityCheck(
225 MsgReliabilityCheck(h_i.clone()),
226 )))
227 .await
228 .map_err(IoError::send_message)?;
229 tracer.msg_sent();
230
231 tracer.round_begins();
232
233 tracer.receive_msgs();
234 let round1_hashes = rounds
235 .complete(round1_sync)
236 .await
237 .map_err(IoError::receive_message)?;
238 tracer.msgs_received();
239
240 tracer.stage("Assert other parties hashed messages (reliability check)");
241 let parties_have_different_hashes = round1_hashes
242 .into_iter_indexed()
243 .filter(|(_j, _msg_id, hash_j)| hash_j.0 != h_i)
244 .map(|(j, msg_id, _)| (j, msg_id))
245 .collect::<Vec<_>>();
246 if !parties_have_different_hashes.is_empty() {
247 return Err(KeygenAborted::Round1NotReliable(parties_have_different_hashes).into());
248 }
249 }
250
251 tracer.send_msg();
252 outgoings
253 .send(Outgoing::broadcast(Msg::Round2(my_decommitment.clone())))
254 .await
255 .map_err(IoError::send_message)?;
256 tracer.msg_sent();
257
258 tracer.round_begins();
260
261 tracer.receive_msgs();
262 let decommitments = rounds
263 .complete(round2)
264 .await
265 .map_err(IoError::receive_message)?;
266 tracer.msgs_received();
267
268 tracer.stage("Validate decommitments");
269 let blame = utils::collect_blame(&commitments, &decommitments, |j, com, decom| {
270 let com_expected = udigest::hash::<D>(&unambiguous::HashCom {
271 sid,
272 party_index: j,
273 decommitment: decom,
274 });
275 com.commitment != com_expected
276 });
277 if !blame.is_empty() {
278 return Err(KeygenAborted::InvalidDecommitment(blame).into());
279 }
280
281 #[cfg(feature = "hd-wallet")]
282 let chain_code = if hd_enabled {
283 tracer.stage("Calculate chain_code");
284 let blame = utils::collect_simple_blame(&decommitments, |decom| decom.chain_code.is_none());
285 if !blame.is_empty() {
286 return Err(KeygenAborted::MissingChainCode(blame).into());
287 }
288 Some(decommitments.iter_including_me(&my_decommitment).try_fold(
289 hd_wallet::ChainCode::default(),
290 |acc, decom| {
291 Ok::<_, Bug>(utils::xor_array(
292 acc,
293 decom.chain_code.ok_or(Bug::NoChainCode)?,
294 ))
295 },
296 )?)
297 } else {
298 None
299 };
300
301 tracer.stage("Calculate challege rid");
302 let rid = decommitments
303 .iter_including_me(&my_decommitment)
304 .map(|d| &d.rid)
305 .fold(L::Rid::default(), utils::xor_array);
306 let challenge = Scalar::from_hash::<D>(&unambiguous::SchnorrPok {
307 sid,
308 prover: i,
309 rid: rid.as_ref(),
310 });
311 let challenge = schnorr_pok::Challenge { nonce: challenge };
312
313 tracer.stage("Prove knowledge of `x_i`");
314 let sch_proof = schnorr_pok::prove(&sch_secret, &challenge, &x_i);
315
316 tracer.send_msg();
317 let my_sch_proof = MsgRound3 { sch_proof };
318 outgoings
319 .send(Outgoing::broadcast(Msg::Round3(my_sch_proof.clone())))
320 .await
321 .map_err(IoError::send_message)?;
322 tracer.msg_sent();
323
324 tracer.round_begins();
326
327 tracer.receive_msgs();
328 let sch_proofs = rounds
329 .complete(round3)
330 .await
331 .map_err(IoError::receive_message)?;
332 tracer.msgs_received();
333
334 tracer.stage("Validate schnorr proofs");
335 let blame = utils::collect_blame(&decommitments, &sch_proofs, |j, decom, sch_proof| {
336 let challenge = Scalar::from_hash::<D>(&unambiguous::SchnorrPok {
337 sid,
338 prover: j,
339 rid: rid.as_ref(),
340 });
341 let challenge = schnorr_pok::Challenge { nonce: challenge };
342 sch_proof
343 .sch_proof
344 .verify(&decom.sch_commit, &challenge, &decom.X)
345 .is_err()
346 });
347 if !blame.is_empty() {
348 return Err(KeygenAborted::InvalidSchnorrProof(blame).into());
349 }
350
351 tracer.protocol_ends();
352
353 Ok(DirtyCoreKeyShare {
354 i,
355 key_info: DirtyKeyInfo {
356 curve: Default::default(),
357 shared_public_key: NonZero::from_point(
358 decommitments
359 .iter_including_me(&my_decommitment)
360 .map(|d| d.X)
361 .sum(),
362 )
363 .ok_or(Bug::ZeroPk)?,
364 public_shares: decommitments
365 .iter_including_me(&my_decommitment)
366 .map(|d| d.X)
367 .collect(),
368 vss_setup: None,
369 #[cfg(feature = "hd-wallet")]
370 chain_code,
371 },
372 x: x_i,
373 }
374 .validate()
375 .map_err(|e| Bug::InvalidKeyShare(e.into_error()))?)
376}