Skip to main content

dkls23_core/protocols/
sign_session.rs

1use std::collections::BTreeMap;
2
3use crate::curve::DklsCurve;
4use crate::protocols::signature::EcdsaSignature;
5use crate::protocols::signing::{
6    Broadcast3to4, KeepPhase1to2, KeepPhase2to3, SignData, TransmitPhase1to2, TransmitPhase2to3,
7    UniqueKeep1to2, UniqueKeep2to3,
8};
9use crate::protocols::{Abort, AbortReason, Party, PartyIndex};
10
11pub struct SignSession<'a, C: DklsCurve> {
12    party: &'a Party<C>,
13    data: SignData,
14    phase1_to_2: Option<(UniqueKeep1to2<C>, BTreeMap<PartyIndex, KeepPhase1to2<C>>)>,
15    phase2_to_3: Option<(UniqueKeep2to3<C>, BTreeMap<PartyIndex, KeepPhase2to3<C>>)>,
16    x_coord: Option<String>,
17}
18
19impl<'a, C: DklsCurve> SignSession<'a, C> {
20    pub fn new(
21        party: &'a Party<C>,
22        data: SignData,
23    ) -> Result<(Self, Vec<TransmitPhase1to2>), Abort> {
24        let (unique_kept, kept, transmit) = party.sign_phase1(&data)?;
25        let session = Self {
26            party,
27            data,
28            phase1_to_2: Some((unique_kept, kept)),
29            phase2_to_3: None,
30            x_coord: None,
31        };
32        Ok((session, transmit))
33    }
34
35    pub fn phase2(
36        &mut self,
37        received: &[TransmitPhase1to2],
38    ) -> Result<Vec<TransmitPhase2to3<C>>, Abort> {
39        let (unique_kept, kept) = self.phase1_to_2.take().ok_or_else(|| {
40            Abort::recoverable(
41                self.party.party_index,
42                AbortReason::PhaseCalledOutOfOrder {
43                    phase: "phase2 called out of order".into(),
44                },
45            )
46        })?;
47        let (new_unique, new_kept, transmit) =
48            self.party
49                .sign_phase2(&self.data, &unique_kept, &kept, received)?;
50        self.phase2_to_3 = Some((new_unique, new_kept));
51        Ok(transmit)
52    }
53
54    pub fn phase3(&mut self, received: &[TransmitPhase2to3<C>]) -> Result<Broadcast3to4<C>, Abort> {
55        let (unique_kept, kept) = self.phase2_to_3.take().ok_or_else(|| {
56            Abort::recoverable(
57                self.party.party_index,
58                AbortReason::PhaseCalledOutOfOrder {
59                    phase: "phase3 called out of order".into(),
60                },
61            )
62        })?;
63        let (x_coord, broadcast) =
64            self.party
65                .sign_phase3(&self.data, &unique_kept, &kept, received)?;
66        self.x_coord = Some(x_coord);
67        Ok(broadcast)
68    }
69
70    pub fn phase4(
71        mut self,
72        received: &[Broadcast3to4<C>],
73        normalize: bool,
74    ) -> Result<EcdsaSignature, Abort> {
75        let x_coord = self.x_coord.take().ok_or_else(|| {
76            Abort::recoverable(
77                self.party.party_index,
78                AbortReason::PhaseCalledOutOfOrder {
79                    phase: "phase4 called out of order".into(),
80                },
81            )
82        })?;
83        let (s_hex, recovery_id) = self
84            .party
85            .sign_phase4(&self.data, &x_coord, received, normalize)?;
86
87        let mut r = [0u8; 32];
88        let mut s = [0u8; 32];
89        hex::decode_to_slice(&x_coord, &mut r).map_err(|e| {
90            Abort::recoverable(
91                self.party.party_index,
92                AbortReason::InvalidHex {
93                    detail: format!("invalid r hex: {e}"),
94                },
95            )
96        })?;
97        hex::decode_to_slice(&s_hex, &mut s).map_err(|e| {
98            Abort::recoverable(
99                self.party.party_index,
100                AbortReason::InvalidHex {
101                    detail: format!("invalid s hex: {e}"),
102                },
103            )
104        })?;
105        Ok(EcdsaSignature { r, s, recovery_id })
106    }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use k256::elliptic_curve::Field;
    use k256::Scalar;
    use k256::Secp256k1;

    use crate::protocols::re_key::re_key;
    use crate::protocols::signing::{verify_ecdsa_signature, SignData};
    use crate::protocols::Parameters;
    use crate::utilities::hashes::tagged_hash;
    use crate::utilities::rng;
    use rand::RngExt;

    /// Drives a full random t-of-n signing through `SignSession` and checks that
    /// every executing party finalizes to the same, valid ECDSA signature.
    #[test]
    fn test_sign_session_happy_path() {
        let threshold = rng::get_rng().random_range(2..=5);
        let offset = rng::get_rng().random_range(0..=5);
        let parameters = Parameters {
            threshold,
            share_count: threshold + offset,
        };

        let session_id = rng::get_rng().random::<[u8; 32]>();
        let secret_key = Scalar::random(&mut rng::get_rng());
        let (parties, _) = re_key::<Secp256k1>(&parameters, &session_id, &secret_key, None, |_| {
            String::new()
        });

        let sign_id = rng::get_rng().random::<[u8; 32]>();
        let message_to_sign = tagged_hash(b"test-sign", &["Message to sign!".as_bytes()]);
        let executing_parties: Vec<u8> = Vec::from_iter(1..=parameters.threshold);

        // Phase 1 — build each party's SignData and create its session.
        let mut sessions: BTreeMap<u8, SignSession<'_, Secp256k1>> = BTreeMap::new();
        let mut transmit_1to2: BTreeMap<u8, Vec<TransmitPhase1to2>> = BTreeMap::new();
        for &party_index in &executing_parties {
            let counterparties: Vec<PartyIndex> = executing_parties
                .iter()
                .filter(|&&i| i != party_index)
                .map(|&i| PartyIndex::new(i).unwrap())
                .collect();
            let data = SignData {
                sign_id: sign_id.to_vec(),
                counterparties,
                message_hash: message_to_sign,
            };
            let (session, transmit) =
                SignSession::new(&parties[(party_index - 1) as usize], data).unwrap();
            sessions.insert(party_index, session);
            transmit_1to2.insert(party_index, transmit);
        }

        // Route round 1 messages.
        let mut received_1to2: BTreeMap<u8, Vec<TransmitPhase1to2>> = BTreeMap::new();
        for &party_index in &executing_parties {
            let pi = PartyIndex::new(party_index).unwrap();
            let msgs: Vec<TransmitPhase1to2> = transmit_1to2
                .values()
                .flatten()
                .filter(|m| m.parties.receiver == pi)
                .cloned()
                .collect();
            received_1to2.insert(party_index, msgs);
        }

        // Phase 2 (sessions iterate in ascending party order, like `executing_parties`).
        let mut transmit_2to3: BTreeMap<u8, Vec<TransmitPhase2to3<Secp256k1>>> = BTreeMap::new();
        for (&party_index, session) in sessions.iter_mut() {
            let transmit = session.phase2(&received_1to2[&party_index]).unwrap();
            transmit_2to3.insert(party_index, transmit);
        }

        // Route round 2 messages.
        let mut received_2to3: BTreeMap<u8, Vec<TransmitPhase2to3<Secp256k1>>> = BTreeMap::new();
        for &party_index in &executing_parties {
            let pi = PartyIndex::new(party_index).unwrap();
            let msgs: Vec<TransmitPhase2to3<Secp256k1>> = transmit_2to3
                .values()
                .flatten()
                .filter(|m| m.parties.receiver == pi)
                .cloned()
                .collect();
            received_2to3.insert(party_index, msgs);
        }

        // Phase 3.
        let mut broadcasts: Vec<Broadcast3to4<Secp256k1>> =
            Vec::with_capacity(parameters.threshold as usize);
        for (&party_index, session) in sessions.iter_mut() {
            broadcasts.push(session.phase3(&received_2to3[&party_index]).unwrap());
        }

        // Phase 4 — consume every session. All parties combine the same
        // broadcasts, so each must arrive at the same signature.
        let signatures: Vec<_> = sessions
            .into_values()
            .map(|session| session.phase4(&broadcasts, true).unwrap())
            .collect();
        assert_eq!(signatures.len(), executing_parties.len());

        for signature in &signatures {
            // Fields are populated...
            assert_ne!(signature.r, [0u8; 32]);
            assert_ne!(signature.s, [0u8; 32]);
            // ...agree across parties...
            assert_eq!(signature.r, signatures[0].r);
            assert_eq!(signature.s, signatures[0].s);

            // ...and verify against the shared public key.
            let r_hex = hex::encode(signature.r);
            let s_hex = hex::encode(signature.s);
            assert!(verify_ecdsa_signature::<Secp256k1>(
                &message_to_sign,
                &parties[0].pk,
                &r_hex,
                &s_hex,
            ));
        }
    }

    /// Calling phases out of order must fail without advancing the session.
    #[test]
    fn test_sign_session_phase_order_error() {
        let parameters = Parameters {
            threshold: 2,
            share_count: 2,
        };
        let session_id = rng::get_rng().random::<[u8; 32]>();
        let secret_key = Scalar::random(&mut rng::get_rng());
        let (parties, _) = re_key::<Secp256k1>(&parameters, &session_id, &secret_key, None, |_| {
            String::new()
        });

        let data = SignData {
            sign_id: rng::get_rng().random::<[u8; 32]>().to_vec(),
            counterparties: vec![PartyIndex::new(2).unwrap()],
            message_hash: tagged_hash(b"test-sign", &["Message to sign!".as_bytes()]),
        };

        let (mut session, _) = SignSession::new(&parties[0], data).unwrap();

        // Skip phase2, call phase3 directly — should fail.
        assert!(session.phase3(&[]).is_err());
        // The rejected call must not have advanced the session.
        assert!(session.phase3(&[]).is_err());
        // Phase 4 cannot run before phase 3 has completed (consumes the session).
        assert!(session.phase4(&[], true).is_err());
    }
}