Skip to main content

dkls23_core/protocols/
dkg_session.rs

1use std::collections::BTreeMap;
2use std::fmt;
3
4use zeroize::Zeroize;
5
6use crate::curve::DklsCurve;
7use crate::protocols::dkg::{
8    self, BroadcastDerivationPhase2to4, BroadcastDerivationPhase3to4, KeepInitMulPhase3to4,
9    KeepInitZeroSharePhase2to3, KeepInitZeroSharePhase3to4, ProofCommitment, SessionData,
10    TransmitInitMulPhase3to4, TransmitInitZeroSharePhase2to4, TransmitInitZeroSharePhase3to4,
11    UniqueKeepDerivationPhase2to3,
12};
13use crate::protocols::{Abort, AbortReason, Parameters, Party, PartyIndex, PublicKeyPackage};
14
15pub struct DkgSession<C: DklsCurve> {
16    data: SessionData,
17    poly_point: Option<C::Scalar>,
18    proof_commitment: Option<ProofCommitment<C>>,
19    zero_kept_2to3: Option<BTreeMap<PartyIndex, KeepInitZeroSharePhase2to3>>,
20    bip_kept_2to3: Option<UniqueKeepDerivationPhase2to3>,
21    zero_kept_3to4: Option<BTreeMap<PartyIndex, KeepInitZeroSharePhase3to4>>,
22    mul_kept_3to4: Option<BTreeMap<PartyIndex, KeepInitMulPhase3to4<C>>>,
23}
24
25impl<C: DklsCurve> DkgSession<C> {
26    #[must_use]
27    pub fn new(parameters: Parameters, party_index: PartyIndex, session_id: Vec<u8>) -> Self {
28        DkgSession {
29            data: SessionData {
30                parameters,
31                party_index,
32                session_id,
33            },
34            poly_point: None,
35            proof_commitment: None,
36            zero_kept_2to3: None,
37            bip_kept_2to3: None,
38            zero_kept_3to4: None,
39            mul_kept_3to4: None,
40        }
41    }
42
43    #[must_use]
44    pub fn phase1(&self) -> Vec<C::Scalar> {
45        dkg::phase1::<C>(&self.data)
46    }
47
48    pub fn phase2(
49        &mut self,
50        poly_fragments: &[C::Scalar],
51    ) -> Result<
52        (
53            ProofCommitment<C>,
54            Vec<TransmitInitZeroSharePhase2to4>,
55            BroadcastDerivationPhase2to4,
56        ),
57        Abort,
58    > {
59        if self.poly_point.is_some() {
60            return Err(Abort::recoverable(
61                self.data.party_index,
62                AbortReason::PhaseCalledOutOfOrder {
63                    phase: "phase2 already called on this session".into(),
64                },
65            ));
66        }
67
68        let (poly_point, proof_commitment, zero_keep, zero_transmit, bip_keep, bip_broadcast) =
69            dkg::phase2::<C>(&self.data, poly_fragments);
70
71        self.poly_point = Some(poly_point);
72        self.proof_commitment = Some(proof_commitment.clone());
73        self.zero_kept_2to3 = Some(zero_keep);
74        self.bip_kept_2to3 = Some(bip_keep);
75
76        Ok((proof_commitment, zero_transmit, bip_broadcast))
77    }
78
79    pub fn phase3(
80        &mut self,
81    ) -> Result<
82        (
83            Vec<TransmitInitZeroSharePhase3to4>,
84            Vec<TransmitInitMulPhase3to4<C>>,
85            BroadcastDerivationPhase3to4,
86        ),
87        Abort,
88    > {
89        let zero_kept = self.zero_kept_2to3.as_ref().ok_or_else(|| {
90            Abort::recoverable(
91                self.data.party_index,
92                AbortReason::PhaseCalledOutOfOrder {
93                    phase: "phase3 called before phase2".into(),
94                },
95            )
96        })?;
97        let bip_kept = self.bip_kept_2to3.as_ref().ok_or_else(|| {
98            Abort::recoverable(
99                self.data.party_index,
100                AbortReason::PhaseCalledOutOfOrder {
101                    phase: "phase3 called before phase2".into(),
102                },
103            )
104        })?;
105
106        let (zero_keep_3to4, zero_transmit, mul_keep, mul_transmit, bip_broadcast) =
107            dkg::phase3::<C>(&self.data, zero_kept, bip_kept);
108
109        if let Some(ref mut map) = self.zero_kept_2to3 {
110            for v in map.values_mut() {
111                v.seed.zeroize();
112                v.salt.zeroize();
113            }
114            map.clear();
115        }
116        self.zero_kept_2to3 = None;
117
118        if let Some(ref mut bip) = self.bip_kept_2to3 {
119            bip.aux_chain_code.zeroize();
120            bip.cc_salt.zeroize();
121        }
122        self.bip_kept_2to3 = None;
123        self.zero_kept_3to4 = Some(zero_keep_3to4);
124        self.mul_kept_3to4 = Some(mul_keep);
125
126        Ok((zero_transmit, mul_transmit, bip_broadcast))
127    }
128
129    pub fn phase4(
130        self,
131        proofs_commitments: &[ProofCommitment<C>],
132        zero_received_phase2: &[TransmitInitZeroSharePhase2to4],
133        zero_received_phase3: &[TransmitInitZeroSharePhase3to4],
134        mul_received: &[TransmitInitMulPhase3to4<C>],
135        bip_received_phase2: &BTreeMap<PartyIndex, BroadcastDerivationPhase2to4>,
136        bip_received_phase3: &BTreeMap<PartyIndex, BroadcastDerivationPhase3to4>,
137        address_fn: impl Fn(&C::AffinePoint) -> String,
138    ) -> Result<(Party<C>, PublicKeyPackage<C>), Abort> {
139        let poly_point = self.poly_point.as_ref().ok_or_else(|| {
140            Abort::recoverable(
141                self.data.party_index,
142                AbortReason::PhaseCalledOutOfOrder {
143                    phase: "phase4 called before phase2".into(),
144                },
145            )
146        })?;
147        let zero_kept = self.zero_kept_3to4.as_ref().ok_or_else(|| {
148            Abort::recoverable(
149                self.data.party_index,
150                AbortReason::PhaseCalledOutOfOrder {
151                    phase: "phase4 called before phase3".into(),
152                },
153            )
154        })?;
155        let mul_kept = self.mul_kept_3to4.as_ref().ok_or_else(|| {
156            Abort::recoverable(
157                self.data.party_index,
158                AbortReason::PhaseCalledOutOfOrder {
159                    phase: "phase4 called before phase3".into(),
160                },
161            )
162        })?;
163
164        dkg::phase4::<C>(
165            &self.data,
166            poly_point,
167            proofs_commitments,
168            zero_kept,
169            zero_received_phase2,
170            zero_received_phase3,
171            mul_kept,
172            mul_received,
173            bip_received_phase2,
174            bip_received_phase3,
175            address_fn,
176        )
177    }
178}
179
180impl<C: DklsCurve> fmt::Debug for DkgSession<C> {
181    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
182        let phase = if self.mul_kept_3to4.is_some() {
183            "phase3 complete"
184        } else if self.poly_point.is_some() {
185            "phase2 complete"
186        } else {
187            "initialized"
188        };
189        f.debug_struct("DkgSession")
190            .field("party_index", &self.data.party_index)
191            .field("threshold", &self.data.parameters.threshold)
192            .field("share_count", &self.data.parameters.share_count)
193            .field("state", &phase)
194            .finish()
195    }
196}
197
198impl<C: DklsCurve> Zeroize for DkgSession<C> {
199    fn zeroize(&mut self) {
200        self.data.session_id.zeroize();
201
202        if let Some(ref mut pp) = self.poly_point {
203            pp.zeroize();
204        }
205        self.poly_point = None;
206        self.proof_commitment = None;
207
208        if let Some(ref mut map) = self.zero_kept_2to3 {
209            for v in map.values_mut() {
210                v.seed.zeroize();
211                v.salt.zeroize();
212            }
213            map.clear();
214        }
215        self.zero_kept_2to3 = None;
216
217        if let Some(ref mut bip) = self.bip_kept_2to3 {
218            bip.aux_chain_code.zeroize();
219            bip.cc_salt.zeroize();
220        }
221        self.bip_kept_2to3 = None;
222
223        if let Some(ref mut map) = self.zero_kept_3to4 {
224            for v in map.values_mut() {
225                v.seed.zeroize();
226            }
227            map.clear();
228        }
229        self.zero_kept_3to4 = None;
230
231        if let Some(ref mut map) = self.mul_kept_3to4 {
232            for v in map.values_mut() {
233                v.ot_sender.s.zeroize();
234                v.ot_receiver.seed.zeroize();
235                v.nonce.zeroize();
236                v.vec_r.zeroize();
237                v.correlation.zeroize();
238            }
239            map.clear();
240        }
241        self.mul_kept_3to4 = None;
242    }
243}
244
245impl<C: DklsCurve> Drop for DkgSession<C> {
246    fn drop(&mut self) {
247        self.zeroize();
248    }
249}
250
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocols::AbortReason;
    use crate::utilities::rng;
    use k256::Secp256k1;
    use rand::RngExt;

    const SESSION_ID_LEN: usize = 32;

    /// Simulates point-to-point delivery for parties `1..=share_count`.
    ///
    /// Entry `i` of the result holds, in sender order, every message in
    /// `outgoing` whose `receiver_of` is party `i + 1`.
    fn route<T: Clone>(
        outgoing: &[Vec<T>],
        share_count: u8,
        receiver_of: impl Fn(&T) -> PartyIndex,
    ) -> Vec<Vec<T>> {
        (1..=share_count)
            .map(|i| {
                let pi = PartyIndex::new(i).unwrap();
                outgoing
                    .iter()
                    .flatten()
                    .filter(|message| receiver_of(*message) == pi)
                    .cloned()
                    .collect::<Vec<T>>()
            })
            .collect()
    }

    #[test]
    fn test_dkg_session_full_flow() {
        let threshold = rng::get_rng().random_range(2..=5);
        let offset = rng::get_rng().random_range(0..=5);

        let parameters = Parameters {
            threshold,
            share_count: threshold + offset,
        };
        let session_id = rng::get_rng().random::<[u8; SESSION_ID_LEN]>();

        let n = parameters.share_count as usize;

        // Create sessions for each party.
        let mut sessions: Vec<DkgSession<Secp256k1>> = (0..parameters.share_count)
            .map(|i| {
                DkgSession::new(
                    parameters.clone(),
                    PartyIndex::new(i + 1).unwrap(),
                    session_id.to_vec(),
                )
            })
            .collect();

        // Phase 1
        let mut dkg_1: Vec<Vec<k256::Scalar>> = Vec::with_capacity(n);
        for session in &sessions {
            dkg_1.push(session.phase1());
        }

        // Communication round 1: transpose poly fragments.
        let mut poly_fragments = vec![Vec::<k256::Scalar>::with_capacity(n); n];
        for row in dkg_1 {
            for j in 0..parameters.share_count {
                poly_fragments[j as usize].push(row[j as usize]);
            }
        }

        // Phase 2
        let mut proofs_commitments: Vec<ProofCommitment<Secp256k1>> = Vec::with_capacity(n);
        let mut zero_transmit_2to4: Vec<Vec<TransmitInitZeroSharePhase2to4>> =
            Vec::with_capacity(n);
        let mut bip_broadcast_2to4: BTreeMap<PartyIndex, BroadcastDerivationPhase2to4> =
            BTreeMap::new();

        for (i, session) in sessions.iter_mut().enumerate() {
            let (proof_commitment, zero_transmit, bip_broadcast) =
                session.phase2(&poly_fragments[i]).unwrap();

            proofs_commitments.push(proof_commitment);
            zero_transmit_2to4.push(zero_transmit);
            bip_broadcast_2to4.insert(PartyIndex::new(i as u8 + 1).unwrap(), bip_broadcast);
        }

        // Communication round 2: route zero-share messages.
        let zero_received_2to4 =
            route(&zero_transmit_2to4, parameters.share_count, |m| m.parties.receiver);

        // Phase 3
        let mut zero_transmit_3to4: Vec<Vec<TransmitInitZeroSharePhase3to4>> =
            Vec::with_capacity(n);
        let mut mul_transmit_3to4: Vec<Vec<TransmitInitMulPhase3to4<Secp256k1>>> =
            Vec::with_capacity(n);
        let mut bip_broadcast_3to4: BTreeMap<PartyIndex, BroadcastDerivationPhase3to4> =
            BTreeMap::new();

        for (i, session) in sessions.iter_mut().enumerate() {
            let (zero_transmit, mul_transmit, bip_broadcast) = session.phase3().unwrap();

            zero_transmit_3to4.push(zero_transmit);
            mul_transmit_3to4.push(mul_transmit);
            bip_broadcast_3to4.insert(PartyIndex::new(i as u8 + 1).unwrap(), bip_broadcast);
        }

        // Communication round 3: route zero-share and mul messages.
        let zero_received_3to4 =
            route(&zero_transmit_3to4, parameters.share_count, |m| m.parties.receiver);
        let mul_received_3to4 =
            route(&mul_transmit_3to4, parameters.share_count, |m| m.parties.receiver);

        // Phase 4
        let mut parties: Vec<Party<Secp256k1>> = Vec::with_capacity(n);
        for (i, session) in sessions.into_iter().enumerate() {
            let (party, _pkg) = session
                .phase4(
                    &proofs_commitments,
                    &zero_received_2to4[i],
                    &zero_received_3to4[i],
                    &mul_received_3to4[i],
                    &bip_broadcast_2to4,
                    &bip_broadcast_3to4,
                    |_| String::new(),
                )
                .unwrap_or_else(|abort| {
                    panic!("Party {} aborted: {:?}", abort.index, abort.description())
                });
            parties.push(party);
        }

        let expected_pk = parties[0].pk;
        let expected_chain_code = parties[0].derivation_data.chain_code;
        for party in &parties {
            assert_eq!(expected_pk, party.pk);
            assert_eq!(expected_chain_code, party.derivation_data.chain_code);
        }
    }

    #[test]
    fn test_dkg_session_phase_ordering() {
        let parameters = Parameters {
            threshold: 2,
            share_count: 2,
        };
        let session_id = rng::get_rng().random::<[u8; SESSION_ID_LEN]>();
        let pi = PartyIndex::new(1).unwrap();

        // phase3 before phase2
        let mut session = DkgSession::<Secp256k1>::new(parameters.clone(), pi, session_id.to_vec());
        let result = session.phase3();
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err().reason,
            AbortReason::PhaseCalledOutOfOrder { ref phase } if phase.contains("phase3 called before phase2")
        ));

        // phase4 before phase2
        let session = DkgSession::<Secp256k1>::new(parameters.clone(), pi, session_id.to_vec());
        let result = session.phase4(
            &[],
            &[],
            &[],
            &[],
            &BTreeMap::new(),
            &BTreeMap::new(),
            |_| String::new(),
        );
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err().reason,
            AbortReason::PhaseCalledOutOfOrder { ref phase } if phase.contains("phase4 called before phase2")
        ));

        // phase4 after phase2 but before phase3
        let mut session = DkgSession::<Secp256k1>::new(parameters, pi, session_id.to_vec());
        let fragments = session.phase1();
        session.phase2(&fragments).unwrap();
        let result = session.phase4(
            &[],
            &[],
            &[],
            &[],
            &BTreeMap::new(),
            &BTreeMap::new(),
            |_| String::new(),
        );
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err().reason,
            AbortReason::PhaseCalledOutOfOrder { ref phase } if phase.contains("phase4 called before phase3")
        ));
    }

    #[test]
    fn test_dkg_session_double_phase2() {
        let parameters = Parameters {
            threshold: 2,
            share_count: 2,
        };
        let session_id = rng::get_rng().random::<[u8; SESSION_ID_LEN]>();
        let pi = PartyIndex::new(1).unwrap();

        let mut session = DkgSession::<Secp256k1>::new(parameters, pi, session_id.to_vec());

        let fragments = session.phase1();
        session.phase2(&fragments).unwrap();

        let result = session.phase2(&fragments);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err().reason,
            AbortReason::PhaseCalledOutOfOrder { .. }
        ));
    }

    #[test]
    fn test_dkg_session_double_phase3() {
        let parameters = Parameters {
            threshold: 2,
            share_count: 2,
        };
        let session_id = rng::get_rng().random::<[u8; SESSION_ID_LEN]>();
        let pi = PartyIndex::new(1).unwrap();

        let mut session = DkgSession::<Secp256k1>::new(parameters, pi, session_id.to_vec());

        let fragments = session.phase1();
        session.phase2(&fragments).unwrap();
        session.phase3().unwrap();

        // The phase-2 state is wiped by the first phase3, so a second call must abort.
        let result = session.phase3();
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err().reason,
            AbortReason::PhaseCalledOutOfOrder { .. }
        ));
    }
}