1use blake3::Hasher;
52use curve25519_dalek::constants::RISTRETTO_BASEPOINT_POINT;
53use curve25519_dalek::ristretto::RistrettoPoint;
54use curve25519_dalek::scalar::Scalar;
55use rand::RngCore;
56use serde::{Deserialize, Serialize};
57use thiserror::Error;
58
59fn eval_polynomial(coefficients: &[Scalar], x: Scalar) -> Scalar {
61 let mut result = Scalar::ZERO;
62 let mut x_power = Scalar::ONE;
63
64 for coeff in coefficients {
65 result += coeff * x_power;
66 x_power *= x;
67 }
68
69 result
70}
71
72pub fn generate_threshold_keys(
74 threshold: u32,
75 total: u32,
76) -> ThresholdEcdsaResult<ThresholdKeyShares> {
77 if threshold == 0 || threshold > total {
78 return Err(ThresholdEcdsaError::InvalidThreshold(format!(
79 "threshold={}, total={}",
80 threshold, total
81 )));
82 }
83
84 let secret = random_scalar();
86 let mut coefficients = vec![secret];
87 for _ in 1..threshold {
88 coefficients.push(random_scalar());
89 }
90
91 let group_pubkey = RISTRETTO_BASEPOINT_POINT * secret;
93
94 let mut shares = Vec::new();
96 for id in 1..=total {
97 let x = Scalar::from(id as u64);
98 let share_value = eval_polynomial(&coefficients, x);
99 let public_key = RISTRETTO_BASEPOINT_POINT * share_value;
100
101 shares.push((
102 id,
103 SecretShare {
104 signer_id: id,
105 share: share_value,
106 },
107 PublicShare {
108 signer_id: id,
109 public_key,
110 },
111 ));
112 }
113
114 Ok((group_pubkey, shares))
115}
116
/// Errors returned by the key-generation, signing, aggregation and decoding
/// functions in this module.
#[derive(Debug, Error)]
pub enum ThresholdEcdsaError {
    /// Invalid `threshold`/`total` combination. `generate_threshold_keys` returns it
    /// when `threshold` is 0 or greater than `total`; the payload shows both values.
    #[error("Invalid threshold: {0}")]
    InvalidThreshold(String),
    /// A signer id is not acceptable for the requested operation. For example,
    /// `compute_lagrange_coefficient` returns it when the signer is missing from
    /// the id list.
    #[error("Invalid signer ID")]
    InvalidSignerId,
    /// Not enough inputs were supplied. Examples: fewer nonce shares than
    /// `threshold` in `partial_sign`, or an empty input to one of the aggregation
    /// functions.
    #[error("Insufficient signers")]
    InsufficientSigners,
    /// Encoded public-share bytes did not decompress to a valid Ristretto point
    /// (`PublicShare::from_bytes`).
    #[error("Invalid public key")]
    InvalidPublicKey,
    /// An encoded signature could not be decoded (`ThresholdEcdsaSignature::from_bytes`).
    #[error("Invalid signature")]
    InvalidSignature,
    /// Inputs that must pair up element-by-element have different lengths, e.g.
    /// `nonce_shares` and `signer_ids` in `partial_sign`.
    #[error("Mismatched lengths")]
    MismatchedLengths,
    /// Serialization failure carrying a description.
    /// NOTE(review): not constructed by any code in this file.
    #[error("Serialization error: {0}")]
    Serialization(String),
}

/// Result alias used by the fallible functions in this module.
pub type ThresholdEcdsaResult<T> = Result<T, ThresholdEcdsaError>;
136
137fn random_scalar() -> Scalar {
139 let mut bytes = [0u8; 32];
140 rand::thread_rng().fill_bytes(&mut bytes);
141 Scalar::from_bytes_mod_order(bytes)
142}
143
/// One participant's secret share of the group secret. In `generate_threshold_keys`
/// this is the dealer polynomial evaluated at `signer_id`.
///
/// This type does not derive `Debug`. NOTE(review): presumably that keeps the share
/// out of logs; confirm. It derives `Serialize`, so serialized forms must be handled
/// as secret material. No zeroization on drop is implemented here.
#[derive(Clone, Serialize, Deserialize)]
pub struct SecretShare {
    // Evaluation point (x-coordinate) of this share.
    signer_id: u32,
    // Secret scalar value of the share.
    share: Scalar,
}

/// Public image `share·G` of a `SecretShare`, tagged with the same signer id.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub struct PublicShare {
    // Evaluation point (x-coordinate) of the corresponding secret share.
    signer_id: u32,
    // `share·G` for the corresponding secret share.
    public_key: RistrettoPoint,
}

/// Output of `generate_threshold_keys`: the group public key, plus one
/// `(signer_id, SecretShare, PublicShare)` triple per signer.
pub type ThresholdKeyShares = (RistrettoPoint, Vec<(u32, SecretShare, PublicShare)>);
160
/// One participant's signing state: its id, the `threshold`/`total` parameters,
/// and its secret and public key shares.
///
/// NOTE(review): despite the `Ecdsa` naming, the equations in this file are
/// Schnorr-style, not ECDSA. `partial_sign` computes `s_i = k_i + λ_i·c·x_i`, and
/// `verify_threshold_ecdsa` checks `s·G == R + c·X`.
#[derive(Clone)]
pub struct ThresholdEcdsaSigner {
    // This signer's id, used as the x-coordinate of its share in Lagrange interpolation.
    signer_id: u32,
    // Minimum number of nonce shares `partial_sign` accepts.
    threshold: u32,
    // Stored by the constructors but not read anywhere in this file.
    #[allow(dead_code)]
    total: u32,
    // This signer's secret share.
    secret_share: SecretShare,
    // Public share returned by `public_share()`; `from_share` does not check it
    // against `secret_share`.
    public_share: PublicShare,
}
171
/// A signer's secret nonce `k` for one signing session, plus its public commitment.
///
/// NOTE(review): this type derives `Clone`, and `partial_sign` only borrows it, so
/// nothing prevents signing twice with the same nonce. Two responses
/// `k + λ·c₁·x` and `k + λ·c₂·x` over different challenges reveal the secret share
/// `x`. Use each nonce for exactly one `partial_sign` call.
#[derive(Clone)]
pub struct NonceShare {
    // Id of the signer that generated this nonce (set by `generate_nonce_share`).
    #[allow(dead_code)]
    signer_id: u32,
    // Secret nonce scalar `k`.
    secret: Scalar,
    // Shareable commitment `k·G`, returned by `public()`.
    public: PublicNonceShare,
}

/// Public commitment `k·G` to one signer's nonce. Commitments from all participants
/// are passed to `partial_sign` and `aggregate_threshold_signatures`.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub struct PublicNonceShare {
    // Id of the signer that produced the commitment.
    signer_id: u32,
    // Commitment point `k·G`; all participants' points are summed into `R`.
    nonce_point: RistrettoPoint,
}
187
/// One signer's contribution to an aggregate signature.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub struct ThresholdPartialSignature {
    // Id of the signer that produced `s_share`.
    signer_id: u32,
    // Response `s_i = k_i + λ_i·c·x_i`, as computed by `partial_sign`.
    s_share: Scalar,
}

/// Aggregated signature `(R, s)`, checked by `verify_threshold_ecdsa` as
/// `s·G == R + c·X`. `to_bytes` encodes it as 64 bytes.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub struct ThresholdEcdsaSignature {
    // Aggregate nonce commitment `R = Σ R_i`.
    r: RistrettoPoint,
    // Aggregate response `s = Σ s_i`.
    s: Scalar,
}
201
202impl ThresholdEcdsaSigner {
203 pub fn new(signer_id: u32, threshold: u32, total: u32) -> Self {
210 if threshold == 0 || threshold > total {
211 panic!("Invalid threshold: {} (total: {})", threshold, total);
212 }
213 if signer_id == 0 || signer_id > total {
214 panic!("Invalid signer ID: {} (total: {})", signer_id, total);
215 }
216
217 let secret = random_scalar();
219 let public_key = RISTRETTO_BASEPOINT_POINT * secret;
220
221 Self {
222 signer_id,
223 threshold,
224 total,
225 secret_share: SecretShare {
226 signer_id,
227 share: secret,
228 },
229 public_share: PublicShare {
230 signer_id,
231 public_key,
232 },
233 }
234 }
235
236 pub fn from_share(
238 threshold: u32,
239 total: u32,
240 secret_share: SecretShare,
241 public_share: PublicShare,
242 ) -> Self {
243 Self {
244 signer_id: secret_share.signer_id,
245 threshold,
246 total,
247 secret_share,
248 public_share,
249 }
250 }
251
252 pub fn public_share(&self) -> PublicShare {
254 self.public_share
255 }
256
257 pub fn generate_nonce_share(&self) -> NonceShare {
259 let secret = random_scalar();
260 let nonce_point = RISTRETTO_BASEPOINT_POINT * secret;
261
262 NonceShare {
263 signer_id: self.signer_id,
264 secret,
265 public: PublicNonceShare {
266 signer_id: self.signer_id,
267 nonce_point,
268 },
269 }
270 }
271
272 pub fn partial_sign(
280 &self,
281 message: &[u8],
282 nonce: &NonceShare,
283 nonce_shares: &[PublicNonceShare],
284 signer_ids: &[u32],
285 ) -> ThresholdEcdsaResult<ThresholdPartialSignature> {
286 if nonce_shares.len() < self.threshold as usize {
287 return Err(ThresholdEcdsaError::InsufficientSigners);
288 }
289
290 if nonce_shares.len() != signer_ids.len() {
291 return Err(ThresholdEcdsaError::MismatchedLengths);
292 }
293
294 let mut r = RistrettoPoint::default();
296 for nonce_share in nonce_shares {
297 r += nonce_share.nonce_point;
298 }
299
300 let lambda = compute_lagrange_coefficient(self.signer_id, signer_ids)?;
302
303 let challenge = compute_challenge(&r, message);
305
306 let s_share = nonce.secret + lambda * challenge * self.secret_share.share;
308
309 Ok(ThresholdPartialSignature {
310 signer_id: self.signer_id,
311 s_share,
312 })
313 }
314}
315
316impl NonceShare {
317 pub fn public(&self) -> PublicNonceShare {
319 self.public
320 }
321}
322
323fn compute_lagrange_coefficient(
325 signer_id: u32,
326 signer_ids: &[u32],
327) -> ThresholdEcdsaResult<Scalar> {
328 if !signer_ids.contains(&signer_id) {
329 return Err(ThresholdEcdsaError::InvalidSignerId);
330 }
331
332 let mut numerator = Scalar::ONE;
333 let mut denominator = Scalar::ONE;
334
335 let id_scalar = Scalar::from(signer_id as u64);
336
337 for &other_id in signer_ids {
338 if other_id != signer_id {
339 let other_scalar = Scalar::from(other_id as u64);
340 numerator *= other_scalar;
341 denominator *= other_scalar - id_scalar;
342 }
343 }
344
345 let denom_inv = denominator.invert();
347
348 Ok(numerator * denom_inv)
349}
350
351pub fn aggregate_threshold_public_key(
353 shares: &[PublicShare],
354) -> ThresholdEcdsaResult<RistrettoPoint> {
355 if shares.is_empty() {
356 return Err(ThresholdEcdsaError::InsufficientSigners);
357 }
358
359 let mut aggregated = RistrettoPoint::default();
360 for share in shares {
361 aggregated += share.public_key;
362 }
363
364 Ok(aggregated)
365}
366
367pub fn aggregate_threshold_signatures(
369 partials: &[ThresholdPartialSignature],
370 nonce_shares: &[PublicNonceShare],
371) -> ThresholdEcdsaResult<ThresholdEcdsaSignature> {
372 if partials.is_empty() {
373 return Err(ThresholdEcdsaError::InsufficientSigners);
374 }
375
376 let mut r = RistrettoPoint::default();
378 for nonce_share in nonce_shares {
379 r += nonce_share.nonce_point;
380 }
381
382 let mut s = Scalar::ZERO;
384 for partial in partials {
385 s += partial.s_share;
386 }
387
388 Ok(ThresholdEcdsaSignature { r, s })
389}
390
391fn compute_challenge(r: &RistrettoPoint, message: &[u8]) -> Scalar {
393 let mut hasher = Hasher::new();
394 hasher.update(&r.compress().to_bytes());
395 hasher.update(message);
396
397 let hash = hasher.finalize();
398 Scalar::from_bytes_mod_order(*hash.as_bytes())
399}
400
401pub fn verify_threshold_ecdsa(
403 public_key: &RistrettoPoint,
404 message: &[u8],
405 signature: &ThresholdEcdsaSignature,
406) -> bool {
407 let challenge = compute_challenge(&signature.r, message);
409
410 let lhs = RISTRETTO_BASEPOINT_POINT * signature.s;
412 let rhs = signature.r + challenge * public_key;
413
414 lhs == rhs
415}
416
417impl ThresholdEcdsaSignature {
418 pub fn to_bytes(&self) -> [u8; 64] {
420 let mut bytes = [0u8; 64];
421 bytes[..32].copy_from_slice(&self.r.compress().to_bytes());
422 bytes[32..].copy_from_slice(&self.s.to_bytes());
423 bytes
424 }
425
426 pub fn from_bytes(bytes: &[u8; 64]) -> ThresholdEcdsaResult<Self> {
428 let r = curve25519_dalek::ristretto::CompressedRistretto(bytes[..32].try_into().unwrap())
429 .decompress()
430 .ok_or(ThresholdEcdsaError::InvalidSignature)?;
431 let s = Scalar::from_bytes_mod_order(bytes[32..].try_into().unwrap());
432
433 Ok(Self { r, s })
434 }
435}
436
437impl PublicShare {
438 pub fn to_bytes(&self) -> [u8; 36] {
440 let mut bytes = [0u8; 36];
441 bytes[..4].copy_from_slice(&self.signer_id.to_le_bytes());
442 bytes[4..].copy_from_slice(&self.public_key.compress().to_bytes());
443 bytes
444 }
445
446 pub fn from_bytes(bytes: &[u8; 36]) -> ThresholdEcdsaResult<Self> {
448 let signer_id = u32::from_le_bytes(bytes[..4].try_into().unwrap());
449 let public_key =
450 curve25519_dalek::ristretto::CompressedRistretto(bytes[4..].try_into().unwrap())
451 .decompress()
452 .ok_or(ThresholdEcdsaError::InvalidPublicKey)?;
453
454 Ok(Self {
455 signer_id,
456 public_key,
457 })
458 }
459}
460
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs trusted-dealer key generation and wraps every share in a signer.
    ///
    /// Signers come back in id order, so `signers[i]` has signer id `i + 1`.
    fn setup(threshold: u32, total: u32) -> (RistrettoPoint, Vec<ThresholdEcdsaSigner>) {
        let (group_pubkey, key_shares) = generate_threshold_keys(threshold, total).unwrap();
        let signers: Vec<ThresholdEcdsaSigner> = key_shares
            .into_iter()
            .map(|(_, secret_share, public_share)| {
                ThresholdEcdsaSigner::from_share(threshold, total, secret_share, public_share)
            })
            .collect();
        (group_pubkey, signers)
    }

    /// Runs both signing rounds (nonces, then partials) for `participants` and
    /// aggregates the result. Fresh nonces are drawn on every call.
    fn sign(participants: &[&ThresholdEcdsaSigner], message: &[u8]) -> ThresholdEcdsaSignature {
        let nonces: Vec<NonceShare> = participants
            .iter()
            .map(|signer| signer.generate_nonce_share())
            .collect();
        let nonce_shares: Vec<PublicNonceShare> = nonces.iter().map(NonceShare::public).collect();
        let signer_ids: Vec<u32> = participants.iter().map(|signer| signer.signer_id).collect();

        let partials: Vec<ThresholdPartialSignature> = participants
            .iter()
            .zip(&nonces)
            .map(|(signer, nonce)| {
                signer
                    .partial_sign(message, nonce, &nonce_shares, &signer_ids)
                    .unwrap()
            })
            .collect();

        aggregate_threshold_signatures(&partials, &nonce_shares).unwrap()
    }

    #[test]
    fn test_threshold_ecdsa_2_of_3() {
        let (group_pubkey, signers) = setup(2, 3);
        let message = b"Test message";

        let signature = sign(&[&signers[0], &signers[1]], message);

        assert!(verify_threshold_ecdsa(&group_pubkey, message, &signature));
    }

    #[test]
    fn test_threshold_ecdsa_different_signers() {
        let (group_pubkey, signers) = setup(2, 3);
        let message = b"Test message";

        let signature = sign(&[&signers[0], &signers[2]], message);

        assert!(verify_threshold_ecdsa(&group_pubkey, message, &signature));
    }

    #[test]
    fn test_threshold_ecdsa_3_of_5() {
        let (group_pubkey, signers) = setup(3, 5);
        let message = b"3-of-5 threshold test";

        let signature = sign(&[&signers[0], &signers[2], &signers[4]], message);

        assert!(verify_threshold_ecdsa(&group_pubkey, message, &signature));
    }

    #[test]
    fn test_insufficient_signers() {
        let (_group_pubkey, signers) = setup(3, 5);
        let (signer1, signer2) = (&signers[0], &signers[1]);
        let message = b"Test message";

        let nonce1 = signer1.generate_nonce_share();
        let nonce2 = signer2.generate_nonce_share();
        let nonce_shares = [nonce1.public(), nonce2.public()];
        let signer_ids = [1, 2];

        // Two nonce shares are below the threshold of three.
        let result = signer1.partial_sign(message, &nonce1, &nonce_shares, &signer_ids);
        assert!(matches!(
            result,
            Err(ThresholdEcdsaError::InsufficientSigners)
        ));
    }

    #[test]
    fn test_wrong_message() {
        let (group_pubkey, signers) = setup(2, 3);

        let signature = sign(&[&signers[0], &signers[1]], b"Original message");

        assert!(!verify_threshold_ecdsa(
            &group_pubkey,
            b"Wrong message",
            &signature
        ));
    }

    #[test]
    fn test_wrong_public_key() {
        let (_group_pubkey, signers) = setup(2, 3);
        let (other_pubkey, _other_signers) = setup(2, 3);
        let message = b"Test message";

        let signature = sign(&[&signers[0], &signers[1]], message);

        assert!(!verify_threshold_ecdsa(&other_pubkey, message, &signature));
    }

    #[test]
    fn test_signature_serialization() {
        let (group_pubkey, signers) = setup(2, 3);
        let message = b"Serialization test";

        let signature = sign(&[&signers[0], &signers[1]], message);
        let bytes = signature.to_bytes();
        let recovered = ThresholdEcdsaSignature::from_bytes(&bytes).unwrap();

        assert_eq!(recovered.to_bytes(), bytes);
        assert!(verify_threshold_ecdsa(&group_pubkey, message, &recovered));
    }

    #[test]
    fn test_public_share_serialization() {
        let (_group_pubkey, signers) = setup(2, 3);
        let pub_share = signers[0].public_share();

        let recovered = PublicShare::from_bytes(&pub_share.to_bytes()).unwrap();

        assert_eq!(pub_share.signer_id, recovered.signer_id);
        assert_eq!(pub_share.public_key, recovered.public_key);
    }

    #[test]
    fn test_all_signers_participate() {
        let (group_pubkey, signers) = setup(3, 3);
        let message = b"All signers participate";
        let everyone: Vec<&ThresholdEcdsaSigner> = signers.iter().collect();

        let signature = sign(&everyone, message);

        assert!(verify_threshold_ecdsa(&group_pubkey, message, &signature));
    }

    #[test]
    fn test_lagrange_coefficient() {
        let signer_ids = [1, 2, 3];

        let lambda1 = compute_lagrange_coefficient(1, &signer_ids).unwrap();
        let lambda2 = compute_lagrange_coefficient(2, &signer_ids).unwrap();
        let lambda3 = compute_lagrange_coefficient(3, &signer_ids).unwrap();

        // λ1 = 2·3 / ((2−1)(3−1)) = 3
        // λ2 = 1·3 / ((1−2)(3−2)) = −3
        // λ3 = 1·2 / ((1−3)(2−3)) = 1
        assert_eq!(lambda1, Scalar::from(3u64));
        assert_eq!(lambda2, -Scalar::from(3u64));
        assert_eq!(lambda3, Scalar::ONE);
        // Interpolating the constant polynomial 1 at x = 0 must give 1.
        assert_eq!(lambda1 + lambda2 + lambda3, Scalar::ONE);
    }

    #[test]
    fn test_lagrange_coefficient_rejects_unknown_signer() {
        let result = compute_lagrange_coefficient(4, &[1, 2, 3]);
        assert!(matches!(result, Err(ThresholdEcdsaError::InvalidSignerId)));
    }

    #[test]
    fn test_shares_interpolate_to_group_key() {
        let threshold = 3;
        let (group_pubkey, key_shares) = generate_threshold_keys(threshold, 5).unwrap();

        // Any `threshold` shares determine f(0); use ids 2, 4 and 5.
        let chosen = [&key_shares[1], &key_shares[3], &key_shares[4]];
        let ids: Vec<u32> = chosen.iter().map(|(id, _, _)| *id).collect();
        let secret = chosen
            .iter()
            .fold(Scalar::ZERO, |acc, (id, secret_share, _)| {
                acc + compute_lagrange_coefficient(*id, &ids).unwrap() * secret_share.share
            });

        assert_eq!(RISTRETTO_BASEPOINT_POINT * secret, group_pubkey);
    }

    #[test]
    fn test_multiple_signatures_same_key() {
        let (group_pubkey, signers) = setup(2, 3);
        let pair = [&signers[0], &signers[1]];
        let message1 = b"First message";
        let message2 = b"Second message";

        // Each signing session draws its own fresh nonces.
        let sig1 = sign(&pair, message1);
        let sig2 = sign(&pair, message2);

        assert!(verify_threshold_ecdsa(&group_pubkey, message1, &sig1));
        assert!(verify_threshold_ecdsa(&group_pubkey, message2, &sig2));

        assert!(!verify_threshold_ecdsa(&group_pubkey, message1, &sig2));
        assert!(!verify_threshold_ecdsa(&group_pubkey, message2, &sig1));
    }
}