1use std::collections::BTreeMap;
2use std::fmt;
3
4use zeroize::Zeroize;
5
6use crate::curve::DklsCurve;
7use crate::protocols::dkg::{
8 self, BroadcastDerivationPhase2to4, BroadcastDerivationPhase3to4, KeepInitMulPhase3to4,
9 KeepInitZeroSharePhase2to3, KeepInitZeroSharePhase3to4, ProofCommitment, SessionData,
10 TransmitInitMulPhase3to4, TransmitInitZeroSharePhase2to4, TransmitInitZeroSharePhase3to4,
11 UniqueKeepDerivationPhase2to3,
12};
13use crate::protocols::{Abort, AbortReason, Parameters, Party, PartyIndex, PublicKeyPackage};
14
/// Stateful driver that takes one party through the four DKG phases
/// (`phase1` .. `phase4`), storing what each phase must hand to the next.
///
/// Every phase-to-phase field starts as `None` and is filled in by the phase
/// that produces it. The secret values are zeroized by `Zeroize::zeroize`
/// (which `Drop` calls), and the phase-2 keep state is additionally wiped as
/// soon as `phase3` has used it.
pub struct DkgSession<C: DklsCurve> {
    /// Parameters, this party's index and the session id; passed to every
    /// `dkg::phase*` call.
    data: SessionData,
    /// Secret scalar returned by `dkg::phase2`; set by `phase2`, read by `phase4`.
    poly_point: Option<C::Scalar>,
    /// Copy of the proof/commitment returned by `phase2`.
    /// NOTE(review): not read by any method visible here — candidate for removal.
    proof_commitment: Option<ProofCommitment<C>>,
    /// Zero-share keep state (seed + salt per `PartyIndex`) from phase 2 for phase 3.
    zero_kept_2to3: Option<BTreeMap<PartyIndex, KeepInitZeroSharePhase2to3>>,
    /// Derivation keep state (aux chain code + salt) from phase 2 for phase 3.
    bip_kept_2to3: Option<UniqueKeepDerivationPhase2to3>,
    /// Zero-share keep state (seed per `PartyIndex`) from phase 3 for phase 4.
    zero_kept_3to4: Option<BTreeMap<PartyIndex, KeepInitZeroSharePhase3to4>>,
    /// Multiplication / OT keep state per `PartyIndex` from phase 3 for phase 4.
    mul_kept_3to4: Option<BTreeMap<PartyIndex, KeepInitMulPhase3to4<C>>>,
}
24
25impl<C: DklsCurve> DkgSession<C> {
26 #[must_use]
27 pub fn new(parameters: Parameters, party_index: PartyIndex, session_id: Vec<u8>) -> Self {
28 DkgSession {
29 data: SessionData {
30 parameters,
31 party_index,
32 session_id,
33 },
34 poly_point: None,
35 proof_commitment: None,
36 zero_kept_2to3: None,
37 bip_kept_2to3: None,
38 zero_kept_3to4: None,
39 mul_kept_3to4: None,
40 }
41 }
42
43 #[must_use]
44 pub fn phase1(&self) -> Vec<C::Scalar> {
45 dkg::phase1::<C>(&self.data)
46 }
47
48 pub fn phase2(
49 &mut self,
50 poly_fragments: &[C::Scalar],
51 ) -> Result<
52 (
53 ProofCommitment<C>,
54 Vec<TransmitInitZeroSharePhase2to4>,
55 BroadcastDerivationPhase2to4,
56 ),
57 Abort,
58 > {
59 if self.poly_point.is_some() {
60 return Err(Abort::recoverable(
61 self.data.party_index,
62 AbortReason::PhaseCalledOutOfOrder {
63 phase: "phase2 already called on this session".into(),
64 },
65 ));
66 }
67
68 let (poly_point, proof_commitment, zero_keep, zero_transmit, bip_keep, bip_broadcast) =
69 dkg::phase2::<C>(&self.data, poly_fragments);
70
71 self.poly_point = Some(poly_point);
72 self.proof_commitment = Some(proof_commitment.clone());
73 self.zero_kept_2to3 = Some(zero_keep);
74 self.bip_kept_2to3 = Some(bip_keep);
75
76 Ok((proof_commitment, zero_transmit, bip_broadcast))
77 }
78
79 pub fn phase3(
80 &mut self,
81 ) -> Result<
82 (
83 Vec<TransmitInitZeroSharePhase3to4>,
84 Vec<TransmitInitMulPhase3to4<C>>,
85 BroadcastDerivationPhase3to4,
86 ),
87 Abort,
88 > {
89 let zero_kept = self.zero_kept_2to3.as_ref().ok_or_else(|| {
90 Abort::recoverable(
91 self.data.party_index,
92 AbortReason::PhaseCalledOutOfOrder {
93 phase: "phase3 called before phase2".into(),
94 },
95 )
96 })?;
97 let bip_kept = self.bip_kept_2to3.as_ref().ok_or_else(|| {
98 Abort::recoverable(
99 self.data.party_index,
100 AbortReason::PhaseCalledOutOfOrder {
101 phase: "phase3 called before phase2".into(),
102 },
103 )
104 })?;
105
106 let (zero_keep_3to4, zero_transmit, mul_keep, mul_transmit, bip_broadcast) =
107 dkg::phase3::<C>(&self.data, zero_kept, bip_kept);
108
109 if let Some(ref mut map) = self.zero_kept_2to3 {
110 for v in map.values_mut() {
111 v.seed.zeroize();
112 v.salt.zeroize();
113 }
114 map.clear();
115 }
116 self.zero_kept_2to3 = None;
117
118 if let Some(ref mut bip) = self.bip_kept_2to3 {
119 bip.aux_chain_code.zeroize();
120 bip.cc_salt.zeroize();
121 }
122 self.bip_kept_2to3 = None;
123 self.zero_kept_3to4 = Some(zero_keep_3to4);
124 self.mul_kept_3to4 = Some(mul_keep);
125
126 Ok((zero_transmit, mul_transmit, bip_broadcast))
127 }
128
129 pub fn phase4(
130 self,
131 proofs_commitments: &[ProofCommitment<C>],
132 zero_received_phase2: &[TransmitInitZeroSharePhase2to4],
133 zero_received_phase3: &[TransmitInitZeroSharePhase3to4],
134 mul_received: &[TransmitInitMulPhase3to4<C>],
135 bip_received_phase2: &BTreeMap<PartyIndex, BroadcastDerivationPhase2to4>,
136 bip_received_phase3: &BTreeMap<PartyIndex, BroadcastDerivationPhase3to4>,
137 address_fn: impl Fn(&C::AffinePoint) -> String,
138 ) -> Result<(Party<C>, PublicKeyPackage<C>), Abort> {
139 let poly_point = self.poly_point.as_ref().ok_or_else(|| {
140 Abort::recoverable(
141 self.data.party_index,
142 AbortReason::PhaseCalledOutOfOrder {
143 phase: "phase4 called before phase2".into(),
144 },
145 )
146 })?;
147 let zero_kept = self.zero_kept_3to4.as_ref().ok_or_else(|| {
148 Abort::recoverable(
149 self.data.party_index,
150 AbortReason::PhaseCalledOutOfOrder {
151 phase: "phase4 called before phase3".into(),
152 },
153 )
154 })?;
155 let mul_kept = self.mul_kept_3to4.as_ref().ok_or_else(|| {
156 Abort::recoverable(
157 self.data.party_index,
158 AbortReason::PhaseCalledOutOfOrder {
159 phase: "phase4 called before phase3".into(),
160 },
161 )
162 })?;
163
164 dkg::phase4::<C>(
165 &self.data,
166 poly_point,
167 proofs_commitments,
168 zero_kept,
169 zero_received_phase2,
170 zero_received_phase3,
171 mul_kept,
172 mul_received,
173 bip_received_phase2,
174 bip_received_phase3,
175 address_fn,
176 )
177 }
178}
179
180impl<C: DklsCurve> fmt::Debug for DkgSession<C> {
181 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
182 let phase = if self.mul_kept_3to4.is_some() {
183 "phase3 complete"
184 } else if self.poly_point.is_some() {
185 "phase2 complete"
186 } else {
187 "initialized"
188 };
189 f.debug_struct("DkgSession")
190 .field("party_index", &self.data.party_index)
191 .field("threshold", &self.data.parameters.threshold)
192 .field("share_count", &self.data.parameters.share_count)
193 .field("state", &phase)
194 .finish()
195 }
196}
197
198impl<C: DklsCurve> Zeroize for DkgSession<C> {
199 fn zeroize(&mut self) {
200 self.data.session_id.zeroize();
201
202 if let Some(ref mut pp) = self.poly_point {
203 pp.zeroize();
204 }
205 self.poly_point = None;
206 self.proof_commitment = None;
207
208 if let Some(ref mut map) = self.zero_kept_2to3 {
209 for v in map.values_mut() {
210 v.seed.zeroize();
211 v.salt.zeroize();
212 }
213 map.clear();
214 }
215 self.zero_kept_2to3 = None;
216
217 if let Some(ref mut bip) = self.bip_kept_2to3 {
218 bip.aux_chain_code.zeroize();
219 bip.cc_salt.zeroize();
220 }
221 self.bip_kept_2to3 = None;
222
223 if let Some(ref mut map) = self.zero_kept_3to4 {
224 for v in map.values_mut() {
225 v.seed.zeroize();
226 }
227 map.clear();
228 }
229 self.zero_kept_3to4 = None;
230
231 if let Some(ref mut map) = self.mul_kept_3to4 {
232 for v in map.values_mut() {
233 v.ot_sender.s.zeroize();
234 v.ot_receiver.seed.zeroize();
235 v.nonce.zeroize();
236 v.vec_r.zeroize();
237 v.correlation.zeroize();
238 }
239 map.clear();
240 }
241 self.mul_kept_3to4 = None;
242 }
243}
244
245impl<C: DklsCurve> Drop for DkgSession<C> {
246 fn drop(&mut self) {
247 self.zeroize();
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocols::AbortReason;
    use crate::utilities::rng;
    use k256::Secp256k1;
    use rand::RngExt;

    const SESSION_ID_LEN: usize = 32;

    /// Drives a randomly sized group (threshold 2..=5, up to 5 extra parties)
    /// through all four phases. Checks that every party ends with the same
    /// public key and chain code.
    #[test]
    fn test_dkg_session_full_flow() {
        let threshold = rng::get_rng().random_range(2..=5);
        let offset = rng::get_rng().random_range(0..=5);

        let parameters = Parameters {
            threshold,
            share_count: threshold + offset,
        };
        let session_id = rng::get_rng().random::<[u8; SESSION_ID_LEN]>();

        let n = parameters.share_count as usize;

        // One session per party; party indices are 1-based.
        let mut sessions: Vec<DkgSession<Secp256k1>> = (0..parameters.share_count)
            .map(|i| {
                DkgSession::new(
                    parameters.clone(),
                    PartyIndex::new(i + 1).unwrap(),
                    session_id.to_vec(),
                )
            })
            .collect();

        // Phase 1: every party produces its outgoing fragments.
        let mut dkg_1: Vec<Vec<k256::Scalar>> = Vec::with_capacity(n);
        for session in &sessions {
            dkg_1.push(session.phase1());
        }

        // Transpose: poly_fragments[j] collects entry j of every party's
        // phase-1 output, and is then fed to the session at position j.
        let mut poly_fragments = vec![Vec::<k256::Scalar>::with_capacity(n); n];
        for row in dkg_1 {
            for j in 0..parameters.share_count {
                poly_fragments[j as usize].push(row[j as usize]);
            }
        }

        // Phase 2 outputs: proofs/commitments (all parties), point-to-point
        // zero-share messages, and derivation broadcasts keyed by sender.
        let mut proofs_commitments: Vec<ProofCommitment<Secp256k1>> = Vec::with_capacity(n);
        let mut zero_transmit_2to4: Vec<Vec<TransmitInitZeroSharePhase2to4>> =
            Vec::with_capacity(n);
        let mut bip_broadcast_2to4: BTreeMap<PartyIndex, BroadcastDerivationPhase2to4> =
            BTreeMap::new();

        for (i, session) in sessions.iter_mut().enumerate() {
            let (proof_commitment, zero_transmit, bip_broadcast) =
                session.phase2(&poly_fragments[i]).unwrap();

            proofs_commitments.push(proof_commitment);
            zero_transmit_2to4.push(zero_transmit);
            bip_broadcast_2to4.insert(PartyIndex::new(i as u8 + 1).unwrap(), bip_broadcast);
        }

        // Route phase-2 zero-share messages: row i-1 holds every message whose
        // receiver is party i.
        let mut zero_received_2to4: Vec<Vec<TransmitInitZeroSharePhase2to4>> =
            Vec::with_capacity(n);
        for i in 1..=parameters.share_count {
            let pi = PartyIndex::new(i).unwrap();
            let mut row = Vec::with_capacity(n - 1);
            for party in &zero_transmit_2to4 {
                for message in party {
                    if message.parties.receiver == pi {
                        row.push(message.clone());
                    }
                }
            }
            zero_received_2to4.push(row);
        }

        // Phase 3 outputs: point-to-point zero-share and multiplication
        // messages, and derivation broadcasts keyed by sender.
        let mut zero_transmit_3to4: Vec<Vec<TransmitInitZeroSharePhase3to4>> =
            Vec::with_capacity(n);
        let mut mul_transmit_3to4: Vec<Vec<TransmitInitMulPhase3to4<Secp256k1>>> =
            Vec::with_capacity(n);
        let mut bip_broadcast_3to4: BTreeMap<PartyIndex, BroadcastDerivationPhase3to4> =
            BTreeMap::new();

        for (i, session) in sessions.iter_mut().enumerate() {
            let (zero_transmit, mul_transmit, bip_broadcast) = session.phase3().unwrap();

            zero_transmit_3to4.push(zero_transmit);
            mul_transmit_3to4.push(mul_transmit);
            bip_broadcast_3to4.insert(PartyIndex::new(i as u8 + 1).unwrap(), bip_broadcast);
        }

        // Route phase-3 messages by receiver, the same way as phase 2.
        let mut zero_received_3to4: Vec<Vec<TransmitInitZeroSharePhase3to4>> =
            Vec::with_capacity(n);
        let mut mul_received_3to4: Vec<Vec<TransmitInitMulPhase3to4<Secp256k1>>> =
            Vec::with_capacity(n);
        for i in 1..=parameters.share_count {
            let pi = PartyIndex::new(i).unwrap();
            let mut zero_row = Vec::with_capacity(n - 1);
            for party in &zero_transmit_3to4 {
                for message in party {
                    if message.parties.receiver == pi {
                        zero_row.push(message.clone());
                    }
                }
            }
            zero_received_3to4.push(zero_row);

            let mut mul_row = Vec::with_capacity(n - 1);
            for party in &mul_transmit_3to4 {
                for message in party {
                    if message.parties.receiver == pi {
                        mul_row.push(message.clone());
                    }
                }
            }
            mul_received_3to4.push(mul_row);
        }

        // Phase 4 consumes each session; any abort fails the test with the
        // aborting party's index.
        let mut parties: Vec<Party<Secp256k1>> = Vec::with_capacity(n);
        for (i, session) in sessions.into_iter().enumerate() {
            let (party, _pkg) = session
                .phase4(
                    &proofs_commitments,
                    &zero_received_2to4[i],
                    &zero_received_3to4[i],
                    &mul_received_3to4[i],
                    &bip_broadcast_2to4,
                    &bip_broadcast_3to4,
                    |_| String::new(),
                )
                .unwrap_or_else(|abort| {
                    panic!("Party {} aborted: {:?}", abort.index, abort.description())
                });
            parties.push(party);
        }

        // All parties must agree on the public key and the chain code.
        let expected_pk = parties[0].pk;
        let expected_chain_code = parties[0].derivation_data.chain_code;
        for party in &parties {
            assert_eq!(expected_pk, party.pk);
            assert_eq!(expected_chain_code, party.derivation_data.chain_code);
        }
    }

    /// Checks that calling a phase before its prerequisite returns a
    /// `PhaseCalledOutOfOrder` abort naming the missing phase.
    #[test]
    fn test_dkg_session_phase_ordering() {
        let parameters = Parameters {
            threshold: 2,
            share_count: 2,
        };
        let session_id = rng::get_rng().random::<[u8; SESSION_ID_LEN]>();
        let pi = PartyIndex::new(1).unwrap();

        // phase3 on a fresh session.
        let mut session = DkgSession::<Secp256k1>::new(parameters.clone(), pi, session_id.to_vec());
        let result = session.phase3();
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err().reason,
            AbortReason::PhaseCalledOutOfOrder { ref phase } if phase.contains("phase3 called before phase2")
        ));

        // phase4 on a fresh session.
        let session = DkgSession::<Secp256k1>::new(parameters.clone(), pi, session_id.to_vec());
        let result = session.phase4(
            &[],
            &[],
            &[],
            &[],
            &BTreeMap::new(),
            &BTreeMap::new(),
            |_| String::new(),
        );
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err().reason,
            AbortReason::PhaseCalledOutOfOrder { ref phase } if phase.contains("phase4 called before phase2")
        ));

        // phase4 after phase2 but without phase3.
        let mut session = DkgSession::<Secp256k1>::new(parameters, pi, session_id.to_vec());
        let fragments = session.phase1();
        session.phase2(&fragments).unwrap();
        let result = session.phase4(
            &[],
            &[],
            &[],
            &[],
            &BTreeMap::new(),
            &BTreeMap::new(),
            |_| String::new(),
        );
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err().reason,
            AbortReason::PhaseCalledOutOfOrder { ref phase } if phase.contains("phase4 called before phase3")
        ));
    }

    /// Checks that a second `phase2` call on the same session is rejected.
    #[test]
    fn test_dkg_session_double_phase2() {
        let parameters = Parameters {
            threshold: 2,
            share_count: 2,
        };
        let session_id = rng::get_rng().random::<[u8; SESSION_ID_LEN]>();
        let pi = PartyIndex::new(1).unwrap();

        let mut session = DkgSession::<Secp256k1>::new(parameters.clone(), pi, session_id.to_vec());

        let fragments = session.phase1();
        session.phase2(&fragments).unwrap();

        let result = session.phase2(&fragments);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err().reason,
            AbortReason::PhaseCalledOutOfOrder { .. }
        ));
    }
}