Skip to main content

dkls23_core/
protocols.rs

//! `DKLs23` main protocols and related ones.
//!
//! Types shared by most of the protocols are defined here.
4use std::collections::BTreeMap;
5use std::fmt;
6use std::marker::PhantomData;
7
8use zeroize::Zeroize;
9
10use crate::curve::DklsCurve;
11use crate::protocols::derivation::DerivData;
12use crate::utilities::multiplication::{MulReceiver, MulSender};
13use crate::utilities::zero_shares::ZeroShare;
14
15pub mod derivation;
16pub mod dkg;
17pub mod dkg_session;
18#[cfg(feature = "serde")]
19pub mod messages;
20pub mod re_key;
21pub mod refresh;
22pub mod sign_session;
23pub mod signature;
24pub mod signing;
25
/// Error returned when attempting to construct a `PartyIndex` from `0`.
///
/// Zero-sized marker type: the only information it carries is its
/// [`fmt::Display`] message. `PartialEq`/`Eq` are derived so that results
/// carrying this error can be compared directly (e.g. with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InvalidPartyIndex;

impl fmt::Display for InvalidPartyIndex {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("party index must be > 0")
    }
}

impl std::error::Error for InvalidPartyIndex {}

/// Error returned when constructing [`Parameters`] with invalid values.
///
/// Zero-sized marker type: the violated constraint is spelled out in its
/// [`fmt::Display`] message. `PartialEq`/`Eq` are derived so that results
/// carrying this error can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InvalidParameters;

impl fmt::Display for InvalidParameters {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("parameters must satisfy 1 < threshold <= share_count")
    }
}

impl std::error::Error for InvalidParameters {}
49
50/// Strongly-typed 1-based participant identifier.
51#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Zeroize)]
52#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
53#[repr(transparent)]
54#[cfg_attr(feature = "serde", serde(try_from = "u8", into = "u8"))]
55pub struct PartyIndex(u8);
56
57impl PartyIndex {
58    pub fn new(value: u8) -> Result<Self, InvalidPartyIndex> {
59        if value == 0 {
60            Err(InvalidPartyIndex)
61        } else {
62            Ok(Self(value))
63        }
64    }
65
66    #[must_use]
67    pub fn as_u8(&self) -> u8 {
68        self.0
69    }
70}
71
72impl TryFrom<u8> for PartyIndex {
73    type Error = InvalidPartyIndex;
74    fn try_from(value: u8) -> Result<Self, Self::Error> {
75        Self::new(value)
76    }
77}
78
79impl From<PartyIndex> for u8 {
80    fn from(pi: PartyIndex) -> Self {
81        pi.0
82    }
83}
84
85impl fmt::Display for PartyIndex {
86    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87        write!(f, "{}", self.0)
88    }
89}
90
91/// Contains the values `t` and  `n` from `DKLs23`.
92#[derive(Clone, Debug)]
93#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
94pub struct Parameters {
95    pub threshold: u8,   //t
96    pub share_count: u8, //n
97}
98
99impl Parameters {
100    /// Creates validated parameters.
101    ///
102    /// Requires `1 < threshold <= share_count`.
103    pub fn new(threshold: u8, share_count: u8) -> Result<Self, InvalidParameters> {
104        if threshold < 2 || threshold > share_count {
105            return Err(InvalidParameters);
106        }
107        Ok(Self {
108            threshold,
109            share_count,
110        })
111    }
112}
113
/// Represents a party after key generation ready to sign a message.
///
/// The struct holds secret material (`poly_point` behaves as the secret key
/// share). `Zeroize` and `Drop` are implemented by hand for this type in this
/// file, so the sensitive fields are wiped when a `Party` is dropped.
/// `Clone` is derived: every clone carries its own copy of that material and
/// is wiped independently when it is dropped.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "serde",
    serde(bound(
        serialize = "C::AffinePoint: serde::Serialize, C::Scalar: serde::Serialize",
        deserialize = "C::AffinePoint: serde::Deserialize<'de>, C::Scalar: serde::Deserialize<'de>"
    ))
)]
pub struct Party<C: DklsCurve> {
    /// Threshold `t` and share count `n`.
    pub parameters: Parameters,
    /// Index identifying this party.
    pub party_index: PartyIndex,
    /// Session identifier, stored as raw bytes.
    pub session_id: Vec<u8>,

    /// Behaves as the secret key share.
    pub poly_point: C::Scalar,
    /// Public key.
    pub pk: C::AffinePoint,

    /// Used for computing shares of zero during signing.
    pub zero_share: ZeroShare,

    /// Initializations for two-party multiplication (sender side).
    /// The key in the `BTreeMap` represents the other party.
    pub mul_senders: BTreeMap<PartyIndex, MulSender<C>>,
    /// Initializations for two-party multiplication (receiver side).
    /// The key in the `BTreeMap` represents the other party.
    pub mul_receivers: BTreeMap<PartyIndex, MulReceiver<C>>,

    /// Data for BIP-32 derivation.
    pub derivation_data: DerivData<C>,

    /// Address calculated from the public key.
    pub address: String,
}
148
149impl<C: DklsCurve> Zeroize for Party<C> {
150    fn zeroize(&mut self) {
151        // `parameters`, `party_index`, and `pk` are public values — not zeroized.
152        self.session_id.zeroize();
153        self.poly_point.zeroize();
154        self.zero_share.zeroize();
155        // Zeroize each value in the BTreeMaps, then clear the maps.
156        for sender in self.mul_senders.values_mut() {
157            sender.zeroize();
158        }
159        self.mul_senders.clear();
160        for receiver in self.mul_receivers.values_mut() {
161            receiver.zeroize();
162        }
163        self.mul_receivers.clear();
164        self.derivation_data.zeroize();
165        self.address.zeroize();
166    }
167}
168
169impl<C: DklsCurve> Drop for Party<C> {
170    fn drop(&mut self) {
171        self.zeroize();
172    }
173}
174
175/// Aggregates the group public key, per-participant verification shares,
176/// and threshold parameters produced by DKG or re-key.
177#[derive(Debug, Clone)]
178#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
179#[cfg_attr(
180    feature = "serde",
181    serde(bound(
182        serialize = "C::AffinePoint: serde::Serialize, C::Scalar: serde::Serialize",
183        deserialize = "C::AffinePoint: serde::Deserialize<'de>, C::Scalar: serde::Deserialize<'de>"
184    ))
185)]
186pub struct PublicKeyPackage<C: DklsCurve> {
187    verifying_key: C::AffinePoint,
188    verifying_shares: BTreeMap<PartyIndex, C::AffinePoint>,
189    parameters: Parameters,
190    #[cfg_attr(feature = "serde", serde(skip))]
191    _curve: PhantomData<C>,
192}
193
194impl<C: DklsCurve> PublicKeyPackage<C> {
195    #[must_use]
196    pub fn new(
197        verifying_key: C::AffinePoint,
198        verifying_shares: BTreeMap<PartyIndex, C::AffinePoint>,
199        parameters: Parameters,
200    ) -> Self {
201        Self {
202            verifying_key,
203            verifying_shares,
204            parameters,
205            _curve: PhantomData,
206        }
207    }
208
209    #[must_use]
210    pub fn verifying_key(&self) -> &C::AffinePoint {
211        &self.verifying_key
212    }
213
214    #[must_use]
215    pub fn verifying_share(&self, party: PartyIndex) -> Option<&C::AffinePoint> {
216        self.verifying_shares.get(&party)
217    }
218
219    #[must_use]
220    pub fn threshold(&self) -> u8 {
221        self.parameters.threshold
222    }
223
224    #[must_use]
225    pub fn share_count(&self) -> u8 {
226        self.parameters.share_count
227    }
228
229    #[must_use]
230    pub fn verify_share(&self, party: PartyIndex, verification_share: &C::AffinePoint) -> bool {
231        self.verifying_shares
232            .get(&party)
233            .is_some_and(|stored| stored == verification_share)
234    }
235}
236
/// Classifies the severity and required response for an abort.
///
/// Stored in the `kind` field of [`Abort`].
///
/// The `DKLs23` protocol reuses base OT correlations across signing sessions.
/// If the COTe consistency check fails, information about this reused state
/// is leaked. The paper mandates that the offending counterparty must be
/// **permanently banned** from all future sessions. Failure to do so
/// enables a key extraction attack over multiple sessions.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum AbortKind {
    /// The protocol failed but can be retried safely.
    /// No long-term state was compromised.
    Recoverable,
    /// The identified counterparty cheated in a way that leaks information
    /// about reusable OT state. This party **MUST** be permanently excluded
    /// from all future signing and refresh sessions. Continuing to interact
    /// with this party enables private key extraction.
    ///
    /// The payload is the index of the counterparty to ban.
    BanCounterparty(PartyIndex),
}
256
/// Machine-readable reason for a protocol abort.
///
/// The enum is `#[non_exhaustive]`, so matches outside this crate need a
/// wildcard arm. The `Display` implementation in this file turns each
/// variant into a short human-readable message.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum AbortReason {
    // --- Input validation (all Recoverable) ---
    /// A party index is out of the valid range.
    InvalidPartyIndex {
        index: PartyIndex,
    },
    /// The number of counterparties (`got`) differs from `expected`.
    WrongCounterpartyCount {
        expected: usize,
        got: usize,
    },
    /// The counterparty `index` is listed more than once.
    DuplicateCounterparty {
        index: PartyIndex,
    },
    /// The party's own index appears in its counterparty list.
    SelfInCounterparties,
    /// No multiplication state is available for `counterparty`.
    MissingMulState {
        counterparty: PartyIndex,
    },

    // --- Message routing (all Recoverable) ---
    /// A message was addressed to `actual_receiver` although
    /// `expected_receiver` was expected.
    MisroutedMessage {
        expected_receiver: PartyIndex,
        actual_receiver: PartyIndex,
    },
    /// A message arrived from a sender that was not expected.
    UnexpectedSender {
        sender: PartyIndex,
    },
    /// More than one message was received from `sender`.
    DuplicateSender {
        sender: PartyIndex,
    },
    /// The number of messages (`got`) differs from `expected`.
    WrongMessageCount {
        expected: usize,
        got: usize,
    },
    /// No message from `party` was found.
    MissingMessageFromParty {
        party: PartyIndex,
    },

    // --- Cryptographic verification (severity varies) ---
    /// Proof verification failed for `counterparty`.
    ProofVerificationFailed {
        counterparty: PartyIndex,
    },
    /// A commitment of `counterparty` did not match.
    CommitmentMismatch {
        counterparty: PartyIndex,
    },
    /// A polynomial inconsistency was detected.
    PolynomialInconsistency,
    /// `counterparty` supplied a trivial instance point.
    TrivialInstancePoint {
        counterparty: PartyIndex,
    },
    /// The public key is trivial.
    TrivialPublicKey,
    /// The key share is trivial.
    TrivialKeyShare,
    /// No committed point is available for `party`.
    MissingCommittedPoint {
        party: PartyIndex,
    },

    // --- OT/Multiplication failures (typically BanCounterparty) ---
    /// The OT consistency check failed for `counterparty`.
    OtConsistencyCheckFailed {
        counterparty: PartyIndex,
    },
    /// Multiplication verification failed for `counterparty`; `detail` is
    /// free-form text appended to the displayed message.
    MultiplicationVerificationFailed {
        counterparty: PartyIndex,
        detail: String,
    },
    /// A gamma-u inconsistency was detected for `counterparty`.
    GammaUInconsistency {
        counterparty: PartyIndex,
    },

    // --- Signature assembly ---
    /// Signature verification failed.
    SignatureVerificationFailed,
    /// A zero denominator occurred during signature assembly.
    ZeroDenominator,
    /// A Lagrange coefficient could not be computed.
    LagrangeCoefficientFailed,
    /// The x-coordinate hex string is invalid.
    InvalidXCoordinateHex,

    // --- Zero-share initialization ---
    /// Zero-share decommitment failed for `counterparty`.
    ZeroShareDecommitFailed {
        counterparty: PartyIndex,
    },

    // --- Chain code / BIP derivation ---
    /// The chain code commitment failed for `party`.
    ChainCodeCommitmentFailed {
        party: PartyIndex,
    },

    // --- Session state machine ---
    /// A phase was called out of order. `phase` is displayed verbatim, with
    /// no prefix added.
    PhaseCalledOutOfOrder {
        phase: String,
    },

    // --- Hex parsing ---
    /// Hex input could not be parsed; `detail` is free-form text appended to
    /// the displayed message.
    InvalidHex {
        detail: String,
    },
}
352
353impl fmt::Display for AbortReason {
354    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
355        match self {
356            Self::InvalidPartyIndex { index } => {
357                write!(f, "party index {index} is out of valid range")
358            }
359            Self::WrongCounterpartyCount { expected, got } => {
360                write!(
361                    f,
362                    "wrong counterparty count: expected {expected}, got {got}"
363                )
364            }
365            Self::DuplicateCounterparty { index } => {
366                write!(f, "duplicate counterparty: {index}")
367            }
368            Self::SelfInCounterparties => write!(f, "own index in counterparty list"),
369            Self::MissingMulState { counterparty } => {
370                write!(f, "missing multiplication state for party {counterparty}")
371            }
372            Self::MisroutedMessage {
373                expected_receiver,
374                actual_receiver,
375            } => write!(
376                f,
377                "message addressed to {actual_receiver}, expected {expected_receiver}"
378            ),
379            Self::UnexpectedSender { sender } => write!(f, "unexpected sender: {sender}"),
380            Self::DuplicateSender { sender } => {
381                write!(f, "duplicate message from party {sender}")
382            }
383            Self::WrongMessageCount { expected, got } => {
384                write!(f, "wrong message count: expected {expected}, got {got}")
385            }
386            Self::MissingMessageFromParty { party } => {
387                write!(f, "missing message from party {party}")
388            }
389            Self::ProofVerificationFailed { counterparty } => {
390                write!(f, "proof verification failed for party {counterparty}")
391            }
392            Self::CommitmentMismatch { counterparty } => {
393                write!(f, "commitment mismatch for party {counterparty}")
394            }
395            Self::PolynomialInconsistency => write!(f, "polynomial inconsistency"),
396            Self::TrivialInstancePoint { counterparty } => {
397                write!(f, "trivial instance point from party {counterparty}")
398            }
399            Self::TrivialPublicKey => write!(f, "trivial public key"),
400            Self::TrivialKeyShare => write!(f, "trivial key share"),
401            Self::MissingCommittedPoint { party } => {
402                write!(f, "missing committed point for party {party}")
403            }
404            Self::OtConsistencyCheckFailed { counterparty } => {
405                write!(f, "OT consistency check failed for party {counterparty}")
406            }
407            Self::MultiplicationVerificationFailed {
408                counterparty,
409                detail,
410            } => {
411                write!(
412                    f,
413                    "multiplication verification failed for party {counterparty}: {detail}"
414                )
415            }
416            Self::GammaUInconsistency { counterparty } => {
417                write!(f, "gamma-u inconsistency for party {counterparty}")
418            }
419            Self::SignatureVerificationFailed => write!(f, "signature verification failed"),
420            Self::ZeroDenominator => write!(f, "zero denominator in signature assembly"),
421            Self::LagrangeCoefficientFailed => {
422                write!(f, "failed to compute Lagrange coefficient")
423            }
424            Self::InvalidXCoordinateHex => write!(f, "invalid x-coordinate hex"),
425            Self::ZeroShareDecommitFailed { counterparty } => {
426                write!(f, "zero-share decommitment failed for party {counterparty}")
427            }
428            Self::ChainCodeCommitmentFailed { party } => {
429                write!(f, "chain code commitment failed for party {party}")
430            }
431            Self::PhaseCalledOutOfOrder { phase } => {
432                write!(f, "{phase}")
433            }
434            Self::InvalidHex { detail } => write!(f, "invalid hex: {detail}"),
435        }
436    }
437}
438
439#[derive(Debug, Clone, PartialEq, Eq)]
440#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
441pub struct Abort {
442    /// Index of the party generating the abort message.
443    pub index: PartyIndex,
444    /// Indicates whether the abort requires permanently banning a counterparty.
445    pub kind: AbortKind,
446    /// Machine-readable reason for the abort.
447    pub reason: AbortReason,
448}
449
450impl Abort {
451    /// Creates a recoverable `Abort`.
452    #[must_use]
453    pub fn recoverable(index: PartyIndex, reason: AbortReason) -> Abort {
454        Abort {
455            index,
456            kind: AbortKind::Recoverable,
457            reason,
458        }
459    }
460
461    /// Creates an `Abort` that requires permanently banning a counterparty.
462    ///
463    /// This MUST be used when the COTe consistency check or the multiplication
464    /// protocol's verification step fails. The counterparty identified here
465    /// has either cheated or been compromised, and continuing to sign with
466    /// them leaks information enabling key extraction.
467    #[must_use]
468    pub fn ban(index: PartyIndex, counterparty: PartyIndex, reason: AbortReason) -> Abort {
469        Abort {
470            index,
471            kind: AbortKind::BanCounterparty(counterparty),
472            reason,
473        }
474    }
475
476    /// Human-readable description for logging/debugging.
477    #[must_use]
478    pub fn description(&self) -> String {
479        self.reason.to_string()
480    }
481}
482
483/// Saves the sender and receiver of a message.
484#[derive(Clone, Debug, Zeroize)]
485#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
486pub struct PartiesMessage {
487    pub sender: PartyIndex,
488    pub receiver: PartyIndex,
489}
490
491impl PartiesMessage {
492    /// Swaps the sender with the receiver, returning another instance of `PartiesMessage`.
493    #[must_use]
494    pub fn reverse(&self) -> PartiesMessage {
495        PartiesMessage {
496            sender: self.receiver,
497            receiver: self.sender,
498        }
499    }
500}
501
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;

    #[test]
    fn party_index_rejects_zero() {
        assert!(PartyIndex::new(0).is_err());
        assert!(PartyIndex::try_from(0u8).is_err());
    }

    #[test]
    fn party_index_accepts_nonzero() {
        for i in 1..=u8::MAX {
            assert!(PartyIndex::new(i).is_ok());
        }
    }

    #[test]
    fn party_index_round_trip() {
        for i in 1..=u8::MAX {
            let pi = PartyIndex::new(i).unwrap();
            assert_eq!(pi.as_u8(), i);
            assert_eq!(u8::from(pi), i);
        }
    }

    // The two serde tests below are gated on the `serde` feature: `PartyIndex`
    // only implements `Serialize`/`Deserialize` when that feature is enabled,
    // so without the gate the whole test module fails to compile in a build
    // that leaves the feature off.
    #[cfg(feature = "serde")]
    #[test]
    fn party_index_serde_json_transparent() {
        let pi = PartyIndex::new(5).unwrap();
        let json = serde_json::to_string(&pi).unwrap();
        assert_eq!(json, "5");

        let deserialized: PartyIndex = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, pi);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn party_index_serde_rejects_zero() {
        let result: Result<PartyIndex, _> = serde_json::from_str("0");
        assert!(result.is_err());
    }

    #[test]
    fn party_index_btreemap_ordering() {
        let mut map = BTreeMap::new();
        map.insert(PartyIndex::new(3).unwrap(), "c");
        map.insert(PartyIndex::new(1).unwrap(), "a");
        map.insert(PartyIndex::new(2).unwrap(), "b");

        let keys: Vec<u8> = map.keys().map(|k| k.as_u8()).collect();
        assert_eq!(keys, vec![1, 2, 3]);
    }

    #[test]
    fn parameters_new_valid() {
        assert!(Parameters::new(2, 3).is_ok());
        assert!(Parameters::new(2, 2).is_ok());
        assert!(Parameters::new(5, 10).is_ok());
    }

    #[test]
    fn parameters_new_rejects_invalid() {
        // threshold < 2
        assert!(Parameters::new(0, 3).is_err());
        assert!(Parameters::new(1, 3).is_err());
        // threshold > share_count
        assert!(Parameters::new(4, 3).is_err());
    }
}