1use log::error;
2use rand::RngCore;
3use thiserror::Error;
4
5use async_trait::async_trait;
6use threshold_bls::group::Curve;
7
8use super::{
9 board::BoardPublisher,
10 primitives::{
11 phases::{Phase0, Phase1, Phase2, Phase3},
12 types::{BundledJustification, BundledResponses, BundledShares, DKGOutput},
13 DKGError,
14 },
15};
16
17#[derive(Debug, Error)]
18pub enum DKGNodeError {
20 #[error("Could not publish to board")]
22 PublisherError,
23 #[error("DKG Error: {0}")]
25 DKGError(#[from] DKGError),
26}
27
28#[derive(Clone, Debug)]
30pub enum Phase2Result<C: Curve, P: Phase3<C>> {
31 Output(DKGOutput<C>),
33 GoToPhase3(P),
35}
36
37type NodeResult<T> = std::result::Result<T, DKGNodeError>;
38
39#[async_trait]
41pub trait DKGPhase<C: Curve, B: BoardPublisher<C>, T> {
42 type Next: Send;
44
45 async fn run(self, board: &mut B, arg: T) -> NodeResult<Self::Next>
48 where
49 C: 'async_trait,
50 T: 'async_trait;
51}
52
53#[async_trait]
54impl<C, B, R, F, P> DKGPhase<C, B, F> for P
55where
56 C: Curve,
57 B: BoardPublisher<C> + Send,
58 P: Phase0<C> + Send,
59 P::Next: Send,
60 R: RngCore,
61 F: Fn() -> R + Send,
62{
63 type Next = P::Next;
64
65 async fn run(mut self, board: &mut B, rng: F) -> NodeResult<Self::Next>
66 where
67 C: 'async_trait,
68 F: 'async_trait,
69 {
70 self.set_rpc_endpoint();
71 let (next, shares) = self.encrypt_shares(rng)?;
72 if let Some(sh) = shares {
73 board.publish_shares(sh).await.map_err(|e| {
74 error!("{:?}", e);
75 DKGNodeError::PublisherError
76 })?;
77 }
78
79 Ok(next)
80 }
81}
82
83#[async_trait]
84impl<C, B, P> DKGPhase<C, B, &[BundledShares<C>]> for P
85where
86 C: Curve,
87 B: BoardPublisher<C> + Send,
88 P: Phase1<C> + Send,
89 P::Next: Send,
90{
91 type Next = P::Next;
92
93 async fn run(
94 self,
95 board: &mut B,
96 shares: &'async_trait [BundledShares<C>],
97 ) -> NodeResult<Self::Next>
98 where
99 C: 'async_trait,
100 {
101 match self.process_shares(shares, true) {
102 Ok((next, bundle)) => {
103 if let Some(bundle) = bundle {
104 board.publish_responses(bundle).await.map_err(|e| {
105 error!("{:?}", e);
106 DKGNodeError::PublisherError
107 })?;
108 }
109
110 Ok(next)
111 }
112 Err(e) => Err(DKGNodeError::DKGError(e)),
113 }
114 }
115}
116
117#[async_trait]
118impl<C, B, P> DKGPhase<C, B, &[BundledResponses]> for P
119where
120 C: Curve,
121 B: BoardPublisher<C> + Send,
122 P: Phase2<C> + Send,
123 P::Next: Send,
124{
125 type Next = Phase2Result<C, P::Next>;
126
127 async fn run(
128 self,
129 board: &mut B,
130 responses: &'async_trait [BundledResponses],
131 ) -> NodeResult<Self::Next>
132 where
133 C: 'async_trait,
134 {
135 match self.process_responses(responses) {
136 Ok(output) => Ok(Phase2Result::Output(output)),
137 Err(next) => {
138 match next {
139 Ok((next, justifications)) => {
140 if let Some(justifications) = justifications {
145 board
146 .publish_justifications(justifications)
147 .await
148 .map_err(|e| {
149 error!("{:?}", e);
150 DKGNodeError::PublisherError
151 })?;
152 }
153
154 Ok(Phase2Result::GoToPhase3(next))
155 }
156 Err(e) => Err(DKGNodeError::DKGError(e)),
157 }
158 }
159 }
160 }
161}
162
163#[async_trait]
164impl<C, B, P> DKGPhase<C, B, &[BundledJustification<C>]> for P
165where
166 C: Curve,
167 B: BoardPublisher<C> + Send,
168 P: Phase3<C> + Send,
169{
170 type Next = DKGOutput<C>;
171
172 async fn run(
173 self,
174 _: &mut B,
175 responses: &'async_trait [BundledJustification<C>],
176 ) -> NodeResult<Self::Next>
177 where
178 C: 'async_trait,
179 {
180 Ok(self.process_justifications(responses)?)
181 }
182}
183
#[cfg(test)]
mod tests {
    use crate::{
        primitives::{
            group::{Group, Node},
            joint_feldman,
        },
        test_helpers::InMemoryBoard,
    };
    use threshold_bls::{
        curve::bn254::{self, PairingCurve as BN254},
        poly::Idx,
        sig::{BlindThresholdScheme, G1Scheme, G2Scheme, Scheme, SignatureScheme, ThresholdScheme},
    };

    use super::*;

    /// Runs phase 0 for a misbehaving node.
    ///
    /// The node's shares are encrypted but deliberately discarded, so they never
    /// reach the board.
    fn bad_phase0<C: Curve, R: RngCore, P: Phase0<C>>(phase0: P, rng: fn() -> R) -> P::Next {
        let (next, _) = phase0.encrypt_shares(rng).unwrap();
        next
    }

    /// End-to-end DKG followed by a blind threshold signature, on both BN254 groups.
    #[tokio::test]
    async fn dkg_sign_e2e() {
        let (t, n) = (3, 5);
        dkg_sign_e2e_curve::<bn254::G1Curve, G1Scheme<BN254>>(n, t).await;
        dkg_sign_e2e_curve::<bn254::G2Curve, G2Scheme<BN254>>(n, t).await;
    }

    /// Runs an honest DKG over `n` nodes with threshold `t`.
    ///
    /// Every node then produces a blind partial signature on a random message. The
    /// partials are aggregated and unblinded, and the result is verified against the
    /// group public key.
    async fn dkg_sign_e2e_curve<C, S>(n: usize, t: usize)
    where
        C: Curve,
        S: Scheme<Public = <C as Curve>::Point, Private = <C as Curve>::Scalar>
            + BlindThresholdScheme
            + ThresholdScheme
            + SignatureScheme,
    {
        let msg = rand::random::<[u8; 32]>().to_vec();

        let outputs = run_dkg::<C, S>(n, t).await;

        let (token, blinded_msg) = S::blind_msg(&msg[..], &mut rand::thread_rng());

        let partial_sigs = outputs
            .iter()
            .map(|output| S::sign_blind_partial(&output.share, &blinded_msg[..]).unwrap())
            .collect::<Vec<_>>();

        let blinded_sig = S::aggregate(t, &partial_sigs).unwrap();

        let unblinded_sig = S::unblind_sig(&token, &blinded_sig).unwrap();

        let pubkey = outputs[0].public.public_key();

        S::verify(&pubkey, &msg, &unblinded_sig).unwrap();
    }

    /// Runs a fully honest DKG through phases 0-2 and returns each node's output.
    ///
    /// Asserts that no node needs phase 3 and that all nodes agree on the public
    /// polynomial.
    async fn run_dkg<C, S>(n: usize, t: usize) -> Vec<DKGOutput<C>>
    where
        C: Curve,
        S: Scheme<Public = <C as Curve>::Point, Private = <C as Curve>::Scalar>,
    {
        let rng = &mut rand::thread_rng();

        let (mut board, phase0s) = setup::<C, S, _>(n, t, rng);

        let mut phase1s = Vec::new();
        for phase0 in phase0s {
            phase1s.push(phase0.run(&mut board, rand::thread_rng).await.unwrap());
        }

        // Snapshot the board so every node processes the same set of shares.
        let shares = board.shares.clone();

        let mut phase2s = Vec::new();
        for phase1 in phase1s {
            phase2s.push(phase1.run(&mut board, &shares).await.unwrap());
        }

        let responses = board.responses.clone();

        let mut results = Vec::new();
        for phase2 in phase2s {
            results.push(phase2.run(&mut board, &responses).await.unwrap());
        }

        let outputs = results
            .into_iter()
            .map(|res| match res {
                Phase2Result::Output(out) => out,
                Phase2Result::GoToPhase3(_) => unreachable!("should not get here"),
            })
            .collect::<Vec<_>>();
        assert!(is_all_same(outputs.iter().map(|output| &output.public)));

        outputs
    }

    /// With more than `t` nodes withholding their shares, phase 1 must fail on
    /// every node with `NotEnoughValidShares`.
    #[tokio::test]
    async fn not_enough_validator_shares() {
        let (t, n) = (6, 10);
        let bad = t + 1;
        let honest = n - bad;

        let rng = &mut rand::thread_rng();
        let (mut board, phase0s) = setup::<bn254::G1Curve, G1Scheme<BN254>, _>(n, t, rng);

        let mut phase1s = Vec::new();
        for (i, phase0) in phase0s.into_iter().enumerate() {
            let phase1 = if i < bad {
                bad_phase0(phase0, rand::thread_rng)
            } else {
                phase0.run(&mut board, rand::thread_rng).await.unwrap()
            };
            phase1s.push(phase1);
        }

        let shares = board.shares.clone();

        let mut errs = Vec::new();
        for phase1 in phase1s {
            let err = match phase1.run(&mut board, &shares).await.unwrap_err() {
                DKGNodeError::DKGError(err) => err,
                _ => panic!("should get dkg error"),
            };
            errs.push(err);
        }

        // Nodes that withheld their own shares see every honest share.
        for err in &errs[..bad] {
            match err {
                DKGError::NotEnoughValidShares(got, required, _) => {
                    assert_eq!(*got, honest);
                    assert_eq!(*required, t);
                }
                _ => panic!("should not get here"),
            };
        }

        // Honest nodes report one fewer valid share. Presumably their own share
        // is not counted among those read from the board — confirm against
        // `process_shares`.
        for err in &errs[bad..] {
            match err {
                DKGError::NotEnoughValidShares(got, required, _) => {
                    assert_eq!(*got, honest - 1);
                    assert_eq!(*required, t);
                }
                _ => panic!("should not get here"),
            };
        }
    }

    /// With a few nodes withholding their shares, every node is forced into
    /// phase 3.
    ///
    /// All nodes must agree on the qualified set. The honest nodes must share a
    /// public polynomial that differs from the ones computed by the bad nodes.
    #[tokio::test]
    async fn dkg_phase3() {
        let (t, n) = (5, 8);
        let bad = 2;

        let rng = &mut rand::thread_rng();
        let (mut board, phase0s) = setup::<bn254::G1Curve, G1Scheme<BN254>, _>(n, t, rng);

        let mut phase1s = Vec::new();
        for (i, phase0) in phase0s.into_iter().enumerate() {
            let phase1 = if i < bad {
                bad_phase0(phase0, rand::thread_rng)
            } else {
                phase0.run(&mut board, rand::thread_rng).await.unwrap()
            };
            phase1s.push(phase1);
        }

        let shares = board.shares.clone();

        let mut phase2s = Vec::new();
        for phase1 in phase1s {
            phase2s.push(phase1.run(&mut board, &shares).await.unwrap());
        }

        let responses = board.responses.clone();

        let mut results = Vec::new();
        for phase2 in phase2s {
            results.push(phase2.run(&mut board, &responses).await.unwrap());
        }

        let phase3s = results
            .into_iter()
            .map(|res| match res {
                Phase2Result::GoToPhase3(p3) => p3,
                _ => unreachable!("should not get here"),
            })
            .collect::<Vec<_>>();

        let justifications = board.justifs.clone();

        let mut outputs = Vec::new();
        for phase3 in phase3s {
            outputs.push(phase3.run(&mut board, &justifications).await.unwrap());
        }

        assert!(is_all_same(outputs.iter().map(|output| &output.qual)));

        assert!(is_all_same(outputs[bad..].iter().map(|output| &output.public)));

        let pubkey = &outputs[bad].public;
        for output in &outputs[..bad] {
            assert_ne!(&output.public, pubkey);
        }
    }

    /// Builds `n` keypairs, a threshold-`t` group over their public keys, one
    /// phase-0 DKG state per node, and an empty in-memory board.
    fn setup<C, S, R: rand::RngCore>(
        n: usize,
        t: usize,
        rng: &mut R,
    ) -> (InMemoryBoard<C>, Vec<joint_feldman::DKG<C>>)
    where
        C: Curve,
        S: Scheme<Public = <C as Curve>::Point, Private = <C as Curve>::Scalar>,
    {
        let keypairs = (0..n).map(|_| S::keypair(rng)).collect::<Vec<_>>();

        let nodes = keypairs
            .iter()
            .enumerate()
            .map(|(i, (_, public))| Node::<C>::new(i as Idx, public.clone()))
            .collect::<Vec<_>>();

        let group = Group::new(nodes, t).unwrap();

        let phase0s = keypairs
            .iter()
            .map(|(private, _)| {
                joint_feldman::DKG::new(private.clone(), String::from(""), group.clone()).unwrap()
            })
            .collect::<Vec<_>>();

        let board = InMemoryBoard::<C>::new();

        (board, phase0s)
    }

    /// Returns `true` if every item yielded by `arr` equals the first one.
    ///
    /// # Panics
    ///
    /// Panics if `arr` is empty. Every caller compares a non-empty set of DKG
    /// outputs, so an empty iterator means the test setup is broken.
    fn is_all_same<T: PartialEq>(mut arr: impl Iterator<Item = T>) -> bool {
        let first = arr
            .next()
            .expect("is_all_same called with an empty iterator");
        arr.all(|item| item == first)
    }
}
460}