1use std::collections::BTreeMap;
2
3use crate::curve::DklsCurve;
4use crate::protocols::signature::EcdsaSignature;
5use crate::protocols::signing::{
6 Broadcast3to4, KeepPhase1to2, KeepPhase2to3, SignData, TransmitPhase1to2, TransmitPhase2to3,
7 UniqueKeep1to2, UniqueKeep2to3,
8};
9use crate::protocols::{Abort, AbortReason, Party, PartyIndex};
10
11pub struct SignSession<'a, C: DklsCurve> {
12 party: &'a Party<C>,
13 data: SignData,
14 phase1_to_2: Option<(UniqueKeep1to2<C>, BTreeMap<PartyIndex, KeepPhase1to2<C>>)>,
15 phase2_to_3: Option<(UniqueKeep2to3<C>, BTreeMap<PartyIndex, KeepPhase2to3<C>>)>,
16 x_coord: Option<String>,
17}
18
19impl<'a, C: DklsCurve> SignSession<'a, C> {
20 pub fn new(
21 party: &'a Party<C>,
22 data: SignData,
23 ) -> Result<(Self, Vec<TransmitPhase1to2>), Abort> {
24 let (unique_kept, kept, transmit) = party.sign_phase1(&data)?;
25 let session = Self {
26 party,
27 data,
28 phase1_to_2: Some((unique_kept, kept)),
29 phase2_to_3: None,
30 x_coord: None,
31 };
32 Ok((session, transmit))
33 }
34
35 pub fn phase2(
36 &mut self,
37 received: &[TransmitPhase1to2],
38 ) -> Result<Vec<TransmitPhase2to3<C>>, Abort> {
39 let (unique_kept, kept) = self.phase1_to_2.take().ok_or_else(|| {
40 Abort::recoverable(
41 self.party.party_index,
42 AbortReason::PhaseCalledOutOfOrder {
43 phase: "phase2 called out of order".into(),
44 },
45 )
46 })?;
47 let (new_unique, new_kept, transmit) =
48 self.party
49 .sign_phase2(&self.data, &unique_kept, &kept, received)?;
50 self.phase2_to_3 = Some((new_unique, new_kept));
51 Ok(transmit)
52 }
53
54 pub fn phase3(&mut self, received: &[TransmitPhase2to3<C>]) -> Result<Broadcast3to4<C>, Abort> {
55 let (unique_kept, kept) = self.phase2_to_3.take().ok_or_else(|| {
56 Abort::recoverable(
57 self.party.party_index,
58 AbortReason::PhaseCalledOutOfOrder {
59 phase: "phase3 called out of order".into(),
60 },
61 )
62 })?;
63 let (x_coord, broadcast) =
64 self.party
65 .sign_phase3(&self.data, &unique_kept, &kept, received)?;
66 self.x_coord = Some(x_coord);
67 Ok(broadcast)
68 }
69
70 pub fn phase4(
71 mut self,
72 received: &[Broadcast3to4<C>],
73 normalize: bool,
74 ) -> Result<EcdsaSignature, Abort> {
75 let x_coord = self.x_coord.take().ok_or_else(|| {
76 Abort::recoverable(
77 self.party.party_index,
78 AbortReason::PhaseCalledOutOfOrder {
79 phase: "phase4 called out of order".into(),
80 },
81 )
82 })?;
83 let (s_hex, recovery_id) = self
84 .party
85 .sign_phase4(&self.data, &x_coord, received, normalize)?;
86
87 let mut r = [0u8; 32];
88 let mut s = [0u8; 32];
89 hex::decode_to_slice(&x_coord, &mut r).map_err(|e| {
90 Abort::recoverable(
91 self.party.party_index,
92 AbortReason::InvalidHex {
93 detail: format!("invalid r hex: {e}"),
94 },
95 )
96 })?;
97 hex::decode_to_slice(&s_hex, &mut s).map_err(|e| {
98 Abort::recoverable(
99 self.party.party_index,
100 AbortReason::InvalidHex {
101 detail: format!("invalid s hex: {e}"),
102 },
103 )
104 })?;
105 Ok(EcdsaSignature { r, s, recovery_id })
106 }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use k256::elliptic_curve::Field;
    use k256::Scalar;
    use k256::Secp256k1;

    use crate::protocols::re_key::re_key;
    use crate::protocols::signing::{verify_ecdsa_signature, SignData};
    use crate::protocols::Parameters;
    use crate::utilities::hashes::tagged_hash;
    use crate::utilities::rng;
    use rand::RngExt;

    /// Drives all four phases for a random threshold / share-count setup and
    /// checks that the resulting signature verifies under the shared public key.
    #[test]
    fn test_sign_session_happy_path() {
        let threshold = rng::get_rng().random_range(2..=5);
        let offset = rng::get_rng().random_range(0..=5);
        let parameters = Parameters {
            threshold,
            share_count: threshold + offset,
        };

        let session_id = rng::get_rng().random::<[u8; 32]>();
        let secret_key = Scalar::random(&mut rng::get_rng());
        let (parties, _) = re_key::<Secp256k1>(&parameters, &session_id, &secret_key, None, |_| {
            String::new()
        });

        let sign_id = rng::get_rng().random::<[u8; 32]>();
        let message_to_sign = tagged_hash(b"test-sign", &["Message to sign!".as_bytes()]);
        // Parties 1..=threshold sign; their indices are 1-based.
        let executing_parties: Vec<u8> = (1..=parameters.threshold).collect();

        // Phase 1: open one session per executing party. Each party's
        // counterparties are all the other executing parties.
        let mut sessions: BTreeMap<u8, SignSession<'_, Secp256k1>> = BTreeMap::new();
        let mut transmit_1to2: BTreeMap<u8, Vec<TransmitPhase1to2>> = BTreeMap::new();
        for &party_index in &executing_parties {
            let counterparties: Vec<PartyIndex> = executing_parties
                .iter()
                .filter(|&&i| i != party_index)
                .map(|&i| PartyIndex::new(i).unwrap())
                .collect();
            let data = SignData {
                sign_id: sign_id.to_vec(),
                counterparties,
                message_hash: message_to_sign,
            };
            let (session, transmit) =
                SignSession::new(&parties[usize::from(party_index - 1)], data).unwrap();
            sessions.insert(party_index, session);
            transmit_1to2.insert(party_index, transmit);
        }

        // Deliver phase-1 messages to their receivers.
        let mut received_1to2: BTreeMap<u8, Vec<TransmitPhase1to2>> = BTreeMap::new();
        for &party_index in &executing_parties {
            let pi = PartyIndex::new(party_index).unwrap();
            let msgs: Vec<TransmitPhase1to2> = transmit_1to2
                .values()
                .flatten()
                .filter(|m| m.parties.receiver == pi)
                .cloned()
                .collect();
            received_1to2.insert(party_index, msgs);
        }

        // Phase 2.
        let mut transmit_2to3: BTreeMap<u8, Vec<TransmitPhase2to3<Secp256k1>>> = BTreeMap::new();
        for &party_index in &executing_parties {
            let transmit = sessions
                .get_mut(&party_index)
                .unwrap()
                .phase2(&received_1to2[&party_index])
                .unwrap();
            transmit_2to3.insert(party_index, transmit);
        }

        // Deliver phase-2 messages to their receivers.
        let mut received_2to3: BTreeMap<u8, Vec<TransmitPhase2to3<Secp256k1>>> = BTreeMap::new();
        for &party_index in &executing_parties {
            let pi = PartyIndex::new(party_index).unwrap();
            let msgs: Vec<TransmitPhase2to3<Secp256k1>> = transmit_2to3
                .values()
                .flatten()
                .filter(|m| m.parties.receiver == pi)
                .cloned()
                .collect();
            received_2to3.insert(party_index, msgs);
        }

        // Phase 3: every party broadcasts to everyone.
        let mut broadcasts: Vec<Broadcast3to4<Secp256k1>> =
            Vec::with_capacity(usize::from(parameters.threshold));
        for &party_index in &executing_parties {
            let broadcast = sessions
                .get_mut(&party_index)
                .unwrap()
                .phase3(&received_2to3[&party_index])
                .unwrap();
            broadcasts.push(broadcast);
        }

        // Phase 4: any single party can assemble the final signature.
        let some_index = executing_parties[0];
        let session = sessions.remove(&some_index).unwrap();
        let signature = session.phase4(&broadcasts, true).unwrap();

        assert_ne!(signature.r, [0u8; 32]);
        assert_ne!(signature.s, [0u8; 32]);

        let r_hex = hex::encode(signature.r);
        let s_hex = hex::encode(signature.s);
        assert!(verify_ecdsa_signature::<Secp256k1>(
            &message_to_sign,
            &parties[0].pk,
            &r_hex,
            &s_hex,
        ));
    }

    /// Invoking a phase before its predecessor must return an error, not panic.
    #[test]
    fn test_sign_session_phase_order_error() {
        let parameters = Parameters {
            threshold: 2,
            share_count: 2,
        };
        let session_id = rng::get_rng().random::<[u8; 32]>();
        let secret_key = Scalar::random(&mut rng::get_rng());
        let (parties, _) = re_key::<Secp256k1>(&parameters, &session_id, &secret_key, None, |_| {
            String::new()
        });

        let data = SignData {
            sign_id: rng::get_rng().random::<[u8; 32]>().to_vec(),
            counterparties: vec![PartyIndex::new(2).unwrap()],
            message_hash: tagged_hash(b"test-sign", &["Message to sign!".as_bytes()]),
        };

        let (mut session, _) = SignSession::new(&parties[0], data).unwrap();

        // phase3 before phase2: no phase-2 state exists yet.
        assert!(session.phase3(&[]).is_err());

        // phase4 before a successful phase3: no x-coordinate has been recorded.
        assert!(session.phase4(&[], true).is_err());
    }
}