1use elliptic_curve::{Field, Group, ScalarPrimitive};
2
3use crate::compat::CSCurve;
4use crate::participants::ParticipantCounter;
5use crate::protocol::internal::{make_protocol, Context, SharedChannel};
6use crate::protocol::{InitializationError, Protocol};
7use crate::triples::{TriplePub, TripleShare};
8use crate::KeygenOutput;
9use crate::{
10 participants::ParticipantList,
11 protocol::{Participant, ProtocolError},
12};
13
14#[derive(Debug, Clone)]
19pub struct PresignOutput<C: CSCurve> {
20 pub big_r: C::AffinePoint,
22 pub k: C::Scalar,
24 pub sigma: C::Scalar,
26}
27
28#[derive(Debug, Clone)]
30pub struct PresignArguments<C: CSCurve> {
31 pub triple0: (TripleShare<C>, TriplePub<C>),
33 pub triple1: (TripleShare<C>, TriplePub<C>),
35 pub keygen_out: KeygenOutput<C>,
37 pub threshold: usize,
39}
40
/// Runs this participant's side of the presigning protocol.
///
/// Uses two previously generated triples:
/// - `triple0` provides this participant's shares of `k` (its `a` component) and of
///   `kd` (its `c` component), plus the public points `K`, `D` and `KD`;
/// - `triple1` provides shares of `a`, `b`, `c` that mask `k` and the secret key `x`.
///
/// Two broadcasts are made, each on its own waitpoint: an additive share of `kd`,
/// then additive shares of `k + a` and `x + b`. Every reconstructed value is checked
/// against the corresponding public points before it is used.
///
/// # Errors
///
/// Returns [`ProtocolError::AssertionFailed`] if a reconstructed value does not match
/// its public point or if `kd` is not invertible, and propagates any error returned
/// while receiving from the channel.
async fn do_presign<C: CSCurve>(
    mut chan: SharedChannel,
    participants: ParticipantList,
    me: Participant,
    args: PresignArguments<C>,
) -> Result<PresignOutput<C>, ProtocolError> {
    // Public points of the first triple: K (for k), D, and KD (for kd).
    let big_k: C::ProjectivePoint = args.triple0.1.big_a.into();
    let big_d = args.triple0.1.big_b;
    let big_kd = args.triple0.1.big_c;

    // Group public key X = x·G.
    let big_x: C::ProjectivePoint = args.keygen_out.public_key.into();

    // Public points of the masking triple: A and B.
    let big_a: C::ProjectivePoint = args.triple1.1.big_a.into();
    let big_b: C::ProjectivePoint = args.triple1.1.big_b.into();

    // Lagrange coefficient of `me` within the active participant set. Multiplying a
    // threshold share by it gives an additive share: summing everyone's additive
    // shares (as done below) reconstructs the secret.
    let lambda = participants.lagrange::<C>(me);

    // `k_i` stays a threshold share (it is returned as-is); the primed values are
    // additive shares used only for the broadcasts.
    let k_i = args.triple0.0.a;
    let k_prime_i = lambda * k_i;
    let kd_i: C::Scalar = lambda * args.triple0.0.c;

    let a_i = args.triple1.0.a;
    let b_i = args.triple1.0.b;
    let c_i = args.triple1.0.c;
    let a_prime_i = lambda * a_i;
    let b_prime_i = lambda * b_i;

    let x_prime_i = lambda * args.keygen_out.private_share;

    // First broadcast: our additive share of kd, sent in `ScalarPrimitive` form.
    let wait0 = chan.next_waitpoint();
    {
        let kd_i: ScalarPrimitive<C> = kd_i.into();
        chan.send_many(wait0, &kd_i).await;
    }

    // Additive shares of the masked values k + a and x + b.
    let ka_i: C::Scalar = k_prime_i + a_prime_i;
    let xb_i: C::Scalar = x_prime_i + b_prime_i;

    // Second broadcast on its own waitpoint. Both sends are issued before any
    // receive below, so we never block on peers before contributing our own shares.
    let wait1 = chan.next_waitpoint();
    {
        let ka_i: ScalarPrimitive<C> = ka_i.into();
        let xb_i: ScalarPrimitive<C> = xb_i.into();
        chan.send_many(wait1, &(ka_i, xb_i)).await;
    }

    // Sum every participant's kd share. Messages for which `seen.put` returns false
    // (presumably repeats from an already-counted sender) are skipped.
    let mut kd = kd_i;
    let mut seen = ParticipantCounter::new(&participants);
    seen.put(me);
    while !seen.full() {
        let (from, kd_j): (_, ScalarPrimitive<C>) = chan.recv(wait0).await?;
        if !seen.put(from) {
            continue;
        }
        kd += C::Scalar::from(kd_j);
    }

    // The reconstructed kd must match the triple's public KD.
    if big_kd != (C::ProjectivePoint::generator() * kd).into() {
        return Err(ProtocolError::AssertionFailed(
            "received incorrect shares of kd".to_string(),
        ));
    }

    // Same collection for k + a and x + b; the counter is reset and reused.
    let mut ka = ka_i;
    let mut xb = xb_i;
    seen.clear();
    seen.put(me);
    while !seen.full() {
        let (from, (ka_j, xb_j)): (_, (ScalarPrimitive<C>, ScalarPrimitive<C>)) =
            chan.recv(wait1).await?;
        if !seen.put(from) {
            continue;
        }
        ka += C::Scalar::from(ka_j);
        xb += C::Scalar::from(xb_j);
    }

    // Check (k + a)·G == K + A and (x + b)·G == X + B.
    if (C::ProjectivePoint::generator() * ka != big_k + big_a)
        || (C::ProjectivePoint::generator() * xb != big_x + big_b)
    {
        return Err(ProtocolError::AssertionFailed(
            "received incorrect shares of additive triple phase.".to_string(),
        ));
    }

    // R = D · (kd)^{-1}. Field inversion only fails for zero, i.e. when kd == 0.
    let kd_inv: Option<C::Scalar> = kd.invert().into();
    let kd_inv =
        kd_inv.ok_or_else(|| ProtocolError::AssertionFailed("failed to invert kd".to_string()))?;
    let big_r = (C::ProjectivePoint::from(big_d) * kd_inv).into();

    // sigma_i = (k + a)·x_i - (x + b)·a_i + c_i. Interpolating these shares gives
    // (k + a)·x - (x + b)·a + c = k·x - a·b + c, which is k·x given the triple
    // relation c = a·b.
    let sigma_i = ka * args.keygen_out.private_share - xb * a_i + c_i;

    Ok(PresignOutput {
        big_r,
        k: k_i,
        sigma: sigma_i,
    })
}
148
149pub fn presign<C: CSCurve>(
157 participants: &[Participant],
158 me: Participant,
159 args: PresignArguments<C>,
160) -> Result<impl Protocol<Output = PresignOutput<C>>, InitializationError> {
161 if participants.len() < 2 {
162 return Err(InitializationError::BadParameters(format!(
163 "participant count cannot be < 2, found: {}",
164 participants.len()
165 )));
166 };
167 if args.threshold > participants.len() {
169 return Err(InitializationError::BadParameters(
170 "threshold must be <= participant count".to_string(),
171 ));
172 }
173 if args.threshold != args.triple0.1.threshold || args.threshold != args.triple1.1.threshold {
179 return Err(InitializationError::BadParameters(
180 "New threshold must match the threshold of both triples".to_string(),
181 ));
182 }
183
184 let participants = ParticipantList::new(participants).ok_or_else(|| {
185 InitializationError::BadParameters("participant list cannot contain duplicates".to_string())
186 })?;
187
188 let ctx = Context::new();
189 let fut = do_presign(ctx.shared_channel(), participants, me, args);
190 Ok(make_protocol(ctx, fut))
191}
192
#[cfg(test)]
mod test {
    use super::*;
    use rand_core::OsRng;

    use crate::{math::Polynomial, protocol::run_protocol, triples};

    use k256::{ProjectivePoint, Secp256k1};

    /// End-to-end check: three of four dealt participants run presign, then the
    /// outputs of any two of them interpolate to `R = k^{-1}·G` and `sigma = k·x`.
    #[test]
    fn test_presign() {
        let participants = vec![
            Participant::from(0u32),
            Participant::from(1u32),
            Participant::from(2u32),
            Participant::from(3u32),
        ];
        let original_threshold = 2;
        let f = Polynomial::<Secp256k1>::random(&mut OsRng, original_threshold);
        let big_x = (ProjectivePoint::GENERATOR * f.evaluate_zero()).to_affine();
        let threshold = 2;

        let (triple0_pub, triple0_shares) =
            triples::deal(&mut OsRng, &participants, original_threshold);
        let (triple1_pub, triple1_shares) =
            triples::deal(&mut OsRng, &participants, original_threshold);

        // Only a subset of the dealt participants takes part in presigning.
        let active = &participants[..3];

        #[allow(clippy::type_complexity)]
        let mut protocols: Vec<(
            Participant,
            Box<dyn Protocol<Output = PresignOutput<Secp256k1>>>,
        )> = Vec::with_capacity(active.len());

        for ((p, triple0), triple1) in active.iter().zip(triple0_shares).zip(triple1_shares) {
            // `expect` reports the actual initialization error on failure, unlike
            // `assert!(protocol.is_ok())`.
            let protocol = presign(
                active,
                *p,
                PresignArguments {
                    triple0: (triple0, triple0_pub.clone()),
                    triple1: (triple1, triple1_pub.clone()),
                    keygen_out: KeygenOutput {
                        private_share: f.evaluate(&p.scalar::<Secp256k1>()),
                        public_key: big_x,
                    },
                    threshold,
                },
            )
            .expect("presign should accept valid parameters");
            protocols.push((*p, Box::new(protocol)));
        }

        let result = run_protocol(protocols).expect("presign protocol should succeed");

        assert_eq!(result.len(), active.len());
        assert_eq!(result[0].1.big_r, result[1].1.big_r);
        assert_eq!(result[1].1.big_r, result[2].1.big_r);

        let big_r = result[2].1.big_r;

        // Any `threshold` participants suffice to reconstruct k and sigma.
        let signers = [result[0].0, result[1].0];
        let k_shares = [result[0].1.k, result[1].1.k];
        let sigma_shares = [result[0].1.sigma, result[1].1.sigma];
        let p_list = ParticipantList::new(&signers).unwrap();
        let k = p_list.lagrange::<Secp256k1>(signers[0]) * k_shares[0]
            + p_list.lagrange::<Secp256k1>(signers[1]) * k_shares[1];
        assert_eq!(ProjectivePoint::GENERATOR * k.invert().unwrap(), big_r);
        let sigma = p_list.lagrange::<Secp256k1>(signers[0]) * sigma_shares[0]
            + p_list.lagrange::<Secp256k1>(signers[1]) * sigma_shares[1];
        assert_eq!(sigma, k * f.evaluate_zero());
    }
}