1use std::collections::BTreeMap;
5use std::fmt;
6use std::marker::PhantomData;
7
8use zeroize::Zeroize;
9
10use crate::curve::DklsCurve;
11use crate::protocols::derivation::DerivData;
12use crate::utilities::multiplication::{MulReceiver, MulSender};
13use crate::utilities::zero_shares::ZeroShare;
14
15pub mod derivation;
16pub mod dkg;
17pub mod dkg_session;
18#[cfg(feature = "serde")]
19pub mod messages;
20pub mod re_key;
21pub mod refresh;
22pub mod sign_session;
23pub mod signature;
24pub mod signing;
25
/// Error returned when a party index of `0` is supplied.
///
/// Produced by `PartyIndex::new` and `PartyIndex::try_from`.
#[derive(Debug, Clone)]
pub struct InvalidPartyIndex;

impl fmt::Display for InvalidPartyIndex {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("party index must be > 0")
    }
}

impl std::error::Error for InvalidPartyIndex {}

/// Error returned when threshold parameters violate `1 < threshold <= share_count`.
///
/// Produced by `Parameters::new`.
#[derive(Debug, Clone)]
pub struct InvalidParameters;

impl fmt::Display for InvalidParameters {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("parameters must satisfy 1 < threshold <= share_count")
    }
}

impl std::error::Error for InvalidParameters {}
49
50#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Zeroize)]
52#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
53#[repr(transparent)]
54#[cfg_attr(feature = "serde", serde(try_from = "u8", into = "u8"))]
55pub struct PartyIndex(u8);
56
57impl PartyIndex {
58 pub fn new(value: u8) -> Result<Self, InvalidPartyIndex> {
59 if value == 0 {
60 Err(InvalidPartyIndex)
61 } else {
62 Ok(Self(value))
63 }
64 }
65
66 #[must_use]
67 pub fn as_u8(&self) -> u8 {
68 self.0
69 }
70}
71
72impl TryFrom<u8> for PartyIndex {
73 type Error = InvalidPartyIndex;
74 fn try_from(value: u8) -> Result<Self, Self::Error> {
75 Self::new(value)
76 }
77}
78
79impl From<PartyIndex> for u8 {
80 fn from(pi: PartyIndex) -> Self {
81 pi.0
82 }
83}
84
85impl fmt::Display for PartyIndex {
86 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87 write!(f, "{}", self.0)
88 }
89}
90
91#[derive(Clone, Debug)]
93#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
94pub struct Parameters {
95 pub threshold: u8, pub share_count: u8, }
98
99impl Parameters {
100 pub fn new(threshold: u8, share_count: u8) -> Result<Self, InvalidParameters> {
104 if threshold < 2 || threshold > share_count {
105 return Err(InvalidParameters);
106 }
107 Ok(Self {
108 threshold,
109 share_count,
110 })
111 }
112}
113
114#[derive(Clone, Debug)]
116#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
117#[cfg_attr(
118 feature = "serde",
119 serde(bound(
120 serialize = "C::AffinePoint: serde::Serialize, C::Scalar: serde::Serialize",
121 deserialize = "C::AffinePoint: serde::Deserialize<'de>, C::Scalar: serde::Deserialize<'de>"
122 ))
123)]
124pub struct Party<C: DklsCurve> {
125 pub parameters: Parameters,
126 pub party_index: PartyIndex,
127 pub session_id: Vec<u8>,
128
129 pub poly_point: C::Scalar,
131 pub pk: C::AffinePoint,
133
134 pub zero_share: ZeroShare,
136
137 pub mul_senders: BTreeMap<PartyIndex, MulSender<C>>,
140 pub mul_receivers: BTreeMap<PartyIndex, MulReceiver<C>>,
141
142 pub derivation_data: DerivData<C>,
144
145 pub address: String,
147}
148
149impl<C: DklsCurve> Zeroize for Party<C> {
150 fn zeroize(&mut self) {
151 self.session_id.zeroize();
153 self.poly_point.zeroize();
154 self.zero_share.zeroize();
155 for sender in self.mul_senders.values_mut() {
157 sender.zeroize();
158 }
159 self.mul_senders.clear();
160 for receiver in self.mul_receivers.values_mut() {
161 receiver.zeroize();
162 }
163 self.mul_receivers.clear();
164 self.derivation_data.zeroize();
165 self.address.zeroize();
166 }
167}
168
169impl<C: DklsCurve> Drop for Party<C> {
170 fn drop(&mut self) {
171 self.zeroize();
172 }
173}
174
175#[derive(Debug, Clone)]
178#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
179#[cfg_attr(
180 feature = "serde",
181 serde(bound(
182 serialize = "C::AffinePoint: serde::Serialize, C::Scalar: serde::Serialize",
183 deserialize = "C::AffinePoint: serde::Deserialize<'de>, C::Scalar: serde::Deserialize<'de>"
184 ))
185)]
186pub struct PublicKeyPackage<C: DklsCurve> {
187 verifying_key: C::AffinePoint,
188 verifying_shares: BTreeMap<PartyIndex, C::AffinePoint>,
189 parameters: Parameters,
190 #[cfg_attr(feature = "serde", serde(skip))]
191 _curve: PhantomData<C>,
192}
193
194impl<C: DklsCurve> PublicKeyPackage<C> {
195 #[must_use]
196 pub fn new(
197 verifying_key: C::AffinePoint,
198 verifying_shares: BTreeMap<PartyIndex, C::AffinePoint>,
199 parameters: Parameters,
200 ) -> Self {
201 Self {
202 verifying_key,
203 verifying_shares,
204 parameters,
205 _curve: PhantomData,
206 }
207 }
208
209 #[must_use]
210 pub fn verifying_key(&self) -> &C::AffinePoint {
211 &self.verifying_key
212 }
213
214 #[must_use]
215 pub fn verifying_share(&self, party: PartyIndex) -> Option<&C::AffinePoint> {
216 self.verifying_shares.get(&party)
217 }
218
219 #[must_use]
220 pub fn threshold(&self) -> u8 {
221 self.parameters.threshold
222 }
223
224 #[must_use]
225 pub fn share_count(&self) -> u8 {
226 self.parameters.share_count
227 }
228
229 #[must_use]
230 pub fn verify_share(&self, party: PartyIndex, verification_share: &C::AffinePoint) -> bool {
231 self.verifying_shares
232 .get(&party)
233 .is_some_and(|stored| stored == verification_share)
234 }
235}
236
237#[derive(Debug, Clone, PartialEq, Eq)]
245#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
246pub enum AbortKind {
247 Recoverable,
250 BanCounterparty(PartyIndex),
255}
256
257#[derive(Debug, Clone, PartialEq, Eq)]
259#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
260#[non_exhaustive]
261pub enum AbortReason {
262 InvalidPartyIndex {
264 index: PartyIndex,
265 },
266 WrongCounterpartyCount {
267 expected: usize,
268 got: usize,
269 },
270 DuplicateCounterparty {
271 index: PartyIndex,
272 },
273 SelfInCounterparties,
274 MissingMulState {
275 counterparty: PartyIndex,
276 },
277
278 MisroutedMessage {
280 expected_receiver: PartyIndex,
281 actual_receiver: PartyIndex,
282 },
283 UnexpectedSender {
284 sender: PartyIndex,
285 },
286 DuplicateSender {
287 sender: PartyIndex,
288 },
289 WrongMessageCount {
290 expected: usize,
291 got: usize,
292 },
293 MissingMessageFromParty {
294 party: PartyIndex,
295 },
296
297 ProofVerificationFailed {
299 counterparty: PartyIndex,
300 },
301 CommitmentMismatch {
302 counterparty: PartyIndex,
303 },
304 PolynomialInconsistency,
305 TrivialInstancePoint {
306 counterparty: PartyIndex,
307 },
308 TrivialPublicKey,
309 TrivialKeyShare,
310 MissingCommittedPoint {
311 party: PartyIndex,
312 },
313
314 OtConsistencyCheckFailed {
316 counterparty: PartyIndex,
317 },
318 MultiplicationVerificationFailed {
319 counterparty: PartyIndex,
320 detail: String,
321 },
322 GammaUInconsistency {
323 counterparty: PartyIndex,
324 },
325
326 SignatureVerificationFailed,
328 ZeroDenominator,
329 LagrangeCoefficientFailed,
330 InvalidXCoordinateHex,
331
332 ZeroShareDecommitFailed {
334 counterparty: PartyIndex,
335 },
336
337 ChainCodeCommitmentFailed {
339 party: PartyIndex,
340 },
341
342 PhaseCalledOutOfOrder {
344 phase: String,
345 },
346
347 InvalidHex {
349 detail: String,
350 },
351}
352
353impl fmt::Display for AbortReason {
354 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
355 match self {
356 Self::InvalidPartyIndex { index } => {
357 write!(f, "party index {index} is out of valid range")
358 }
359 Self::WrongCounterpartyCount { expected, got } => {
360 write!(
361 f,
362 "wrong counterparty count: expected {expected}, got {got}"
363 )
364 }
365 Self::DuplicateCounterparty { index } => {
366 write!(f, "duplicate counterparty: {index}")
367 }
368 Self::SelfInCounterparties => write!(f, "own index in counterparty list"),
369 Self::MissingMulState { counterparty } => {
370 write!(f, "missing multiplication state for party {counterparty}")
371 }
372 Self::MisroutedMessage {
373 expected_receiver,
374 actual_receiver,
375 } => write!(
376 f,
377 "message addressed to {actual_receiver}, expected {expected_receiver}"
378 ),
379 Self::UnexpectedSender { sender } => write!(f, "unexpected sender: {sender}"),
380 Self::DuplicateSender { sender } => {
381 write!(f, "duplicate message from party {sender}")
382 }
383 Self::WrongMessageCount { expected, got } => {
384 write!(f, "wrong message count: expected {expected}, got {got}")
385 }
386 Self::MissingMessageFromParty { party } => {
387 write!(f, "missing message from party {party}")
388 }
389 Self::ProofVerificationFailed { counterparty } => {
390 write!(f, "proof verification failed for party {counterparty}")
391 }
392 Self::CommitmentMismatch { counterparty } => {
393 write!(f, "commitment mismatch for party {counterparty}")
394 }
395 Self::PolynomialInconsistency => write!(f, "polynomial inconsistency"),
396 Self::TrivialInstancePoint { counterparty } => {
397 write!(f, "trivial instance point from party {counterparty}")
398 }
399 Self::TrivialPublicKey => write!(f, "trivial public key"),
400 Self::TrivialKeyShare => write!(f, "trivial key share"),
401 Self::MissingCommittedPoint { party } => {
402 write!(f, "missing committed point for party {party}")
403 }
404 Self::OtConsistencyCheckFailed { counterparty } => {
405 write!(f, "OT consistency check failed for party {counterparty}")
406 }
407 Self::MultiplicationVerificationFailed {
408 counterparty,
409 detail,
410 } => {
411 write!(
412 f,
413 "multiplication verification failed for party {counterparty}: {detail}"
414 )
415 }
416 Self::GammaUInconsistency { counterparty } => {
417 write!(f, "gamma-u inconsistency for party {counterparty}")
418 }
419 Self::SignatureVerificationFailed => write!(f, "signature verification failed"),
420 Self::ZeroDenominator => write!(f, "zero denominator in signature assembly"),
421 Self::LagrangeCoefficientFailed => {
422 write!(f, "failed to compute Lagrange coefficient")
423 }
424 Self::InvalidXCoordinateHex => write!(f, "invalid x-coordinate hex"),
425 Self::ZeroShareDecommitFailed { counterparty } => {
426 write!(f, "zero-share decommitment failed for party {counterparty}")
427 }
428 Self::ChainCodeCommitmentFailed { party } => {
429 write!(f, "chain code commitment failed for party {party}")
430 }
431 Self::PhaseCalledOutOfOrder { phase } => {
432 write!(f, "{phase}")
433 }
434 Self::InvalidHex { detail } => write!(f, "invalid hex: {detail}"),
435 }
436 }
437}
438
439#[derive(Debug, Clone, PartialEq, Eq)]
440#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
441pub struct Abort {
442 pub index: PartyIndex,
444 pub kind: AbortKind,
446 pub reason: AbortReason,
448}
449
450impl Abort {
451 #[must_use]
453 pub fn recoverable(index: PartyIndex, reason: AbortReason) -> Abort {
454 Abort {
455 index,
456 kind: AbortKind::Recoverable,
457 reason,
458 }
459 }
460
461 #[must_use]
468 pub fn ban(index: PartyIndex, counterparty: PartyIndex, reason: AbortReason) -> Abort {
469 Abort {
470 index,
471 kind: AbortKind::BanCounterparty(counterparty),
472 reason,
473 }
474 }
475
476 #[must_use]
478 pub fn description(&self) -> String {
479 self.reason.to_string()
480 }
481}
482
483#[derive(Clone, Debug, Zeroize)]
485#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
486pub struct PartiesMessage {
487 pub sender: PartyIndex,
488 pub receiver: PartyIndex,
489}
490
491impl PartiesMessage {
492 #[must_use]
494 pub fn reverse(&self) -> PartiesMessage {
495 PartiesMessage {
496 sender: self.receiver,
497 receiver: self.sender,
498 }
499 }
500}
501
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;

    #[test]
    fn party_index_rejects_zero() {
        assert!(PartyIndex::new(0).is_err());
        assert!(PartyIndex::try_from(0u8).is_err());
    }

    #[test]
    fn party_index_accepts_nonzero() {
        for i in 1..=u8::MAX {
            assert!(PartyIndex::new(i).is_ok());
        }
    }

    #[test]
    fn party_index_round_trip() {
        for i in 1..=u8::MAX {
            let pi = PartyIndex::new(i).unwrap();
            assert_eq!(pi.as_u8(), i);
            assert_eq!(u8::from(pi), i);
            assert_eq!(pi.to_string(), i.to_string());
        }
    }

    // `PartyIndex` only implements `Serialize`/`Deserialize` when the `serde`
    // feature is enabled. Without this gate, `cargo test` with the feature off
    // fails to compile. (These tests also need `serde_json` as a dev-dependency.)
    #[cfg(feature = "serde")]
    #[test]
    fn party_index_serde_json_transparent() {
        let pi = PartyIndex::new(5).unwrap();
        let json = serde_json::to_string(&pi).unwrap();
        assert_eq!(json, "5");

        let deserialized: PartyIndex = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, pi);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn party_index_serde_rejects_zero() {
        let result: Result<PartyIndex, _> = serde_json::from_str("0");
        assert!(result.is_err());
    }

    #[test]
    fn party_index_btreemap_ordering() {
        let mut map = BTreeMap::new();
        map.insert(PartyIndex::new(3).unwrap(), "c");
        map.insert(PartyIndex::new(1).unwrap(), "a");
        map.insert(PartyIndex::new(2).unwrap(), "b");

        let keys: Vec<u8> = map.keys().map(|k| k.as_u8()).collect();
        assert_eq!(keys, vec![1, 2, 3]);
    }

    #[test]
    fn parameters_new_valid() {
        assert!(Parameters::new(2, 3).is_ok());
        assert!(Parameters::new(2, 2).is_ok());
        assert!(Parameters::new(5, 10).is_ok());
        assert!(Parameters::new(u8::MAX, u8::MAX).is_ok());

        let params = Parameters::new(3, 7).unwrap();
        assert_eq!(params.threshold, 3);
        assert_eq!(params.share_count, 7);
    }

    #[test]
    fn parameters_new_rejects_invalid() {
        // Threshold below the minimum of 2.
        assert!(Parameters::new(0, 3).is_err());
        assert!(Parameters::new(1, 3).is_err());
        assert!(Parameters::new(0, 0).is_err());
        assert!(Parameters::new(1, 1).is_err());
        // Threshold above the share count.
        assert!(Parameters::new(4, 3).is_err());
    }

    #[test]
    fn parties_message_reverse_swaps_roles() {
        let msg = PartiesMessage {
            sender: PartyIndex::new(1).unwrap(),
            receiver: PartyIndex::new(2).unwrap(),
        };
        let rev = msg.reverse();
        assert_eq!(rev.sender, msg.receiver);
        assert_eq!(rev.receiver, msg.sender);
    }
}