1use curve25519_dalek::constants::RISTRETTO_BASEPOINT_POINT;
55use curve25519_dalek::ristretto::RistrettoPoint;
56use curve25519_dalek::scalar::Scalar;
57use curve25519_dalek::traits::Identity;
58use rand::Rng;
59use serde::{Deserialize, Serialize};
60use thiserror::Error;
61
62fn random_scalar() -> Scalar {
64 let mut rng = rand::thread_rng();
65 let mut bytes = [0u8; 32];
66 rng.fill(&mut bytes);
67 Scalar::from_bytes_mod_order(bytes)
68}
69
70#[allow(dead_code)]
72fn random_point() -> RistrettoPoint {
73 RISTRETTO_BASEPOINT_POINT * random_scalar()
74}
75
76#[derive(Error, Debug)]
78pub enum DkgError {
79 #[error("Invalid threshold: must have 1 <= threshold <= total_parties")]
80 InvalidThreshold,
81 #[error("Invalid participant ID")]
82 InvalidParticipantId,
83 #[error("Invalid share")]
84 InvalidShare,
85 #[error("Share verification failed")]
86 ShareVerificationFailed,
87 #[error("Insufficient shares for reconstruction")]
88 InsufficientShares,
89 #[error("Duplicate participant ID")]
90 DuplicateParticipant,
91}
92
93pub type DkgResult<T> = Result<T, DkgError>;
94
95#[derive(Clone, Debug)]
97pub struct DkgParams {
98 pub total_parties: usize,
100 pub threshold: usize,
102 g: RistrettoPoint,
104}
105
106impl DkgParams {
107 pub fn new(total_parties: usize, threshold: usize) -> Self {
125 assert!(threshold > 0 && threshold <= total_parties);
126
127 let g = RISTRETTO_BASEPOINT_POINT;
129
130 Self {
131 total_parties,
132 threshold,
133 g,
134 }
135 }
136}
137
138pub struct DkgParticipant {
140 id: usize,
142 params: DkgParams,
144 coefficients: Vec<Scalar>,
146 commitments: Vec<RistrettoPoint>,
148 received_shares: Vec<Option<Scalar>>,
150}
151
152#[derive(Clone, Debug, Serialize, Deserialize)]
154pub struct DkgCommitments {
155 pub participant_id: usize,
157 pub commitments: Vec<Vec<u8>>, }
160
161#[derive(Clone, Debug, Serialize, Deserialize)]
163pub struct DkgShare {
164 value: Vec<u8>, }
167
168impl DkgParticipant {
169 pub fn new(params: &DkgParams, id: usize) -> Self {
176 assert!(id < params.total_parties);
177
178 let coefficients: Vec<Scalar> = (0..params.threshold).map(|_| random_scalar()).collect();
181
182 let commitments: Vec<RistrettoPoint> =
184 coefficients.iter().map(|coeff| params.g * coeff).collect();
185
186 let received_shares = vec![None; params.total_parties];
187
188 Self {
189 id,
190 params: params.clone(),
191 coefficients,
192 commitments,
193 received_shares,
194 }
195 }
196
197 pub fn get_commitments(&self) -> DkgCommitments {
199 DkgCommitments {
200 participant_id: self.id,
201 commitments: self
202 .commitments
203 .iter()
204 .map(|c| c.compress().as_bytes().to_vec())
205 .collect(),
206 }
207 }
208
209 pub fn generate_share(&self, target_id: usize) -> DkgResult<DkgShare> {
215 if target_id >= self.params.total_parties {
216 return Err(DkgError::InvalidParticipantId);
217 }
218
219 let x = Scalar::from((target_id + 1) as u64);
222 let share_value = evaluate_polynomial(&self.coefficients, x);
223
224 Ok(DkgShare {
225 value: share_value.to_bytes().to_vec(),
226 })
227 }
228
229 pub fn receive_share(
237 &mut self,
238 from_id: usize,
239 share: DkgShare,
240 commitments: &DkgCommitments,
241 ) -> DkgResult<()> {
242 if from_id >= self.params.total_parties {
243 return Err(DkgError::InvalidParticipantId);
244 }
245
246 if commitments.participant_id != from_id {
247 return Err(DkgError::InvalidShare);
248 }
249
250 if self.received_shares[from_id].is_some() {
251 return Err(DkgError::DuplicateParticipant);
252 }
253
254 if share.value.len() != 32 {
256 return Err(DkgError::InvalidShare);
257 }
258 let mut share_bytes = [0u8; 32];
259 share_bytes.copy_from_slice(&share.value);
260 let share_scalar = Scalar::from_bytes_mod_order(share_bytes);
261
262 let x = Scalar::from((self.id + 1) as u64);
265
266 let mut expected = RistrettoPoint::identity();
267 let mut x_power = Scalar::ONE;
268
269 for commitment_bytes in &commitments.commitments {
270 if commitment_bytes.len() != 32 {
271 return Err(DkgError::InvalidShare);
272 }
273
274 let mut bytes = [0u8; 32];
275 bytes.copy_from_slice(commitment_bytes);
276
277 let commitment = curve25519_dalek::ristretto::CompressedRistretto(bytes)
278 .decompress()
279 .ok_or(DkgError::InvalidShare)?;
280
281 expected += commitment * x_power;
282 x_power *= x;
283 }
284
285 let actual = self.params.g * share_scalar;
286
287 if actual != expected {
288 return Err(DkgError::ShareVerificationFailed);
289 }
290
291 self.received_shares[from_id] = Some(share_scalar);
293
294 Ok(())
295 }
296
297 pub fn get_secret_share(&self) -> DkgResult<Scalar> {
301 let own_share = self.generate_share(self.id)?;
303 let mut own_bytes = [0u8; 32];
304 own_bytes.copy_from_slice(&own_share.value);
305 let mut total = Scalar::from_bytes_mod_order(own_bytes);
306
307 for share in self.received_shares.iter().flatten() {
309 total += share;
310 }
311
312 Ok(total)
313 }
314}
315
316pub fn aggregate_public_key(all_commitments: &[DkgCommitments]) -> RistrettoPoint {
340 let mut joint_pk = RistrettoPoint::identity();
342
343 for commitments in all_commitments {
344 if !commitments.commitments.is_empty() {
345 let mut bytes = [0u8; 32];
346 bytes.copy_from_slice(&commitments.commitments[0]);
347
348 if let Some(point) =
349 curve25519_dalek::ristretto::CompressedRistretto(bytes).decompress()
350 {
351 joint_pk += point;
352 }
353 }
354 }
355
356 joint_pk
357}
358
359fn evaluate_polynomial(coefficients: &[Scalar], x: Scalar) -> Scalar {
361 let mut result = Scalar::ZERO;
362 let mut x_power = Scalar::ONE;
363
364 for coeff in coefficients {
365 result += coeff * x_power;
366 x_power *= x;
367 }
368
369 result
370}
371
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates every participant for `params` and delivers each participant's
    /// share to every other participant, unwrapping every verification.
    fn run_full_dkg(params: &DkgParams) -> (Vec<DkgParticipant>, Vec<DkgCommitments>) {
        let n = params.total_parties;
        let mut participants: Vec<_> = (0..n).map(|i| DkgParticipant::new(params, i)).collect();
        let commitments: Vec<_> = participants.iter().map(|p| p.get_commitments()).collect();

        for i in 0..n {
            for j in 0..n {
                if i != j {
                    let share = participants[j].generate_share(i).unwrap();
                    participants[i]
                        .receive_share(j, share, &commitments[j])
                        .unwrap();
                }
            }
        }

        (participants, commitments)
    }

    /// Lagrange-interpolates `(participant_id, share)` pairs at `x = 0`, using
    /// the evaluation points `x = id + 1` that `generate_share` uses.
    fn reconstruct_secret(points: &[(usize, Scalar)]) -> Scalar {
        let mut secret = Scalar::ZERO;
        for &(i, share) in points {
            let xi = Scalar::from((i + 1) as u64);
            let mut lambda = Scalar::ONE;
            for &(j, _) in points {
                if j != i {
                    let xj = Scalar::from((j + 1) as u64);
                    lambda *= xj * (xj - xi).invert();
                }
            }
            secret += share * lambda;
        }
        secret
    }

    #[test]
    fn test_dkg_basic() {
        let params = DkgParams::new(3, 2);
        let (participants, commitments) = run_full_dkg(&params);

        let shares: Vec<_> = participants
            .iter()
            .map(|p| p.get_secret_share().unwrap())
            .collect();
        assert_eq!(shares.len(), 3);

        // Every 2-of-3 subset must interpolate to the joint secret behind the
        // aggregated public key.
        let joint_pk = aggregate_public_key(&commitments);
        for (a, b) in [(0, 1), (0, 2), (1, 2)] {
            let secret = reconstruct_secret(&[(a, shares[a]), (b, shares[b])]);
            assert_eq!(RISTRETTO_BASEPOINT_POINT * secret, joint_pk);
        }
    }

    #[test]
    fn test_dkg_aggregate_public_key() {
        let params = DkgParams::new(5, 3);

        let participants: Vec<_> = (0..5).map(|i| DkgParticipant::new(&params, i)).collect();

        let commitments: Vec<_> = participants.iter().map(|p| p.get_commitments()).collect();

        let public_key = aggregate_public_key(&commitments);

        assert_ne!(public_key, RistrettoPoint::identity());
    }

    #[test]
    fn test_dkg_invalid_target_id() {
        let params = DkgParams::new(3, 2);
        let participant = DkgParticipant::new(&params, 0);

        assert!(matches!(
            participant.generate_share(10),
            Err(DkgError::InvalidParticipantId)
        ));
    }

    #[test]
    #[should_panic]
    fn test_dkg_invalid_threshold() {
        let _ = DkgParams::new(3, 0);
    }

    #[test]
    #[should_panic]
    fn test_dkg_threshold_above_total_panics() {
        let _ = DkgParams::new(3, 4);
    }

    #[test]
    fn test_dkg_share_verification() {
        let params = DkgParams::new(3, 2);
        let mut participant0 = DkgParticipant::new(&params, 0);
        let participant1 = DkgParticipant::new(&params, 1);

        let commitments1 = participant1.get_commitments();
        let share = participant1.generate_share(0).unwrap();

        assert!(participant0
            .receive_share(1, share.clone(), &commitments1)
            .is_ok());

        // A second share from the same dealer is rejected.
        assert!(matches!(
            participant0.receive_share(1, share, &commitments1),
            Err(DkgError::DuplicateParticipant)
        ));
    }

    #[test]
    fn test_dkg_different_thresholds() {
        for (total, threshold) in [(3, 2), (5, 3), (7, 4)] {
            let params = DkgParams::new(total, threshold);
            let (participants, commitments) = run_full_dkg(&params);

            let shares: Vec<Scalar> = participants
                .iter()
                .map(|p| p.get_secret_share().unwrap())
                .collect();

            let pk = aggregate_public_key(&commitments);
            assert_ne!(pk, RistrettoPoint::identity());

            // Exactly `threshold` shares, taken from either end, recover the joint secret.
            let head: Vec<_> = (0..threshold).map(|i| (i, shares[i])).collect();
            let tail: Vec<_> = (total - threshold..total).map(|i| (i, shares[i])).collect();
            assert_eq!(RISTRETTO_BASEPOINT_POINT * reconstruct_secret(&head), pk);
            assert_eq!(RISTRETTO_BASEPOINT_POINT * reconstruct_secret(&tail), pk);

            // One share fewer interpolates a different value (except with negligible probability).
            let too_few: Vec<_> = (0..threshold - 1).map(|i| (i, shares[i])).collect();
            assert_ne!(RISTRETTO_BASEPOINT_POINT * reconstruct_secret(&too_few), pk);
        }
    }

    #[test]
    fn test_evaluate_polynomial() {
        // f(x) = 1 + 2x + 3x^2
        let coefficients = vec![Scalar::from(1u64), Scalar::from(2u64), Scalar::from(3u64)];

        let x = Scalar::from(2u64);
        let result = evaluate_polynomial(&coefficients, x);
        // f(2) = 1 + 4 + 12
        let expected = Scalar::from(17u64);

        assert_eq!(result, expected);
    }

    #[test]
    fn test_dkg_partial_shares() {
        let params = DkgParams::new(5, 3);
        let mut participants: Vec<_> = (0..5).map(|i| DkgParticipant::new(&params, i)).collect();

        let commitments: Vec<_> = participants.iter().map(|p| p.get_commitments()).collect();

        // Only participants 0 and 1 deal shares.
        for i in 0..5 {
            for j in 0..2 {
                if i != j {
                    let share = participants[j].generate_share(i).unwrap();
                    participants[i]
                        .receive_share(j, share, &commitments[j])
                        .unwrap();
                }
            }
        }

        let result = participants[0].get_secret_share();
        assert!(result.is_ok());

        let partial_share = result.unwrap();
        assert_ne!(partial_share, Scalar::ZERO);
    }

    #[test]
    fn test_dkg_serialization() {
        let params = DkgParams::new(3, 2);
        let participant = DkgParticipant::new(&params, 0);

        let commitments = participant.get_commitments();

        let serialized = crate::codec::encode(&commitments).unwrap();

        let deserialized: DkgCommitments = crate::codec::decode(&serialized).unwrap();

        assert_eq!(
            commitments.commitments.len(),
            deserialized.commitments.len()
        );
        for (orig, deser) in commitments
            .commitments
            .iter()
            .zip(deserialized.commitments.iter())
        {
            assert_eq!(orig, deser);
        }
    }

    #[test]
    fn test_dkg_invalid_share_verification() {
        let params = DkgParams::new(3, 2);
        let mut participant0 = DkgParticipant::new(&params, 0);
        let participant1 = DkgParticipant::new(&params, 1);

        let commitments1 = participant1.get_commitments();

        let valid_share = participant1.generate_share(0).unwrap();

        // Flip the low byte so the share no longer matches the commitments.
        let mut corrupted_value = valid_share.value.clone();
        corrupted_value[0] = corrupted_value[0].wrapping_add(1);
        let invalid_share = DkgShare {
            value: corrupted_value,
        };

        let result = participant0.receive_share(1, invalid_share, &commitments1);
        assert!(result.is_err());
        assert!(matches!(result, Err(DkgError::ShareVerificationFailed)));
    }

    #[test]
    fn test_dkg_commitment_consistency() {
        let params = DkgParams::new(3, 2);

        let participants: Vec<_> = (0..3).map(|i| DkgParticipant::new(&params, i)).collect();

        let commitments: Vec<_> = participants.iter().map(|p| p.get_commitments()).collect();

        for commitment in &commitments {
            assert_eq!(commitment.commitments.len(), params.threshold);
        }

        let identity_bytes = RistrettoPoint::identity().compress().to_bytes();
        for commitment in &commitments {
            for point_bytes in &commitment.commitments {
                assert_ne!(point_bytes.as_slice(), identity_bytes.as_slice());
            }
        }

        let pk1 = aggregate_public_key(&commitments);
        let pk2 = aggregate_public_key(&commitments);
        assert_eq!(pk1, pk2);
    }
}