1#[cfg(not(feature = "std"))]
38use alloc::vec::Vec;
39
40use crate::{
41 error::Error,
42 threshold::ThresholdConfig,
43 verify::verify_partial,
44};
45
/// Per-signer slot held by a `SigningSession`.
///
/// Exactly one slot exists per signer index; a slot moves from `Empty` to
/// `Valid` only after the session has accepted that signer's signature.
#[derive(Debug, Clone, PartialEq, Eq)]
enum SlotState {
    /// No signature has been accepted for this signer yet.
    Empty,
    /// The accepted signature bytes for this signer.
    Valid(Vec<u8>),
}
54
55#[derive(Debug)]
64pub struct SigningSession {
65 config: ThresholdConfig,
66 message: Vec<u8>,
67 slots: Vec<SlotState>,
68 valid_count: usize,
69}
70
71impl SigningSession {
72 pub fn new(config: &ThresholdConfig, message: &[u8]) -> Self {
80 let n = config.total();
81 Self {
82 config: config.clone(),
83 message: message.to_vec(),
84 slots: vec![SlotState::Empty; n],
85 valid_count: 0,
86 }
87 }
88
89 pub fn add_signature(&mut self, signer_index: usize, signature: Vec<u8>) -> Result<(), Error> {
101 let total = self.config.total();
102
103 if signer_index >= total {
104 return Err(Error::SignerIndexOutOfRange {
105 index: signer_index,
106 total,
107 });
108 }
109
110 if self.slots[signer_index] != SlotState::Empty {
111 return Err(Error::DuplicateSignature { index: signer_index });
112 }
113
114 let pk = self
115 .config
116 .get_public_key(signer_index)
117 .expect("index is within bounds; already checked above");
118
119 let valid = verify_partial(&self.message, &signature, pk.as_bytes(), signer_index)?;
120
121 if !valid {
122 return Err(Error::VerificationFailed { index: signer_index });
123 }
124
125 self.slots[signer_index] = SlotState::Valid(signature);
126 self.valid_count += 1;
127 Ok(())
128 }
129
130 pub fn is_complete(&self) -> bool {
132 self.valid_count >= self.config.required()
133 }
134
135 pub fn verify(&self) -> Result<bool, Error> {
145 if self.valid_count < self.config.required() {
146 return Err(Error::ThresholdNotMet {
147 have: self.valid_count,
148 need: self.config.required(),
149 });
150 }
151 Ok(true)
152 }
153
154 pub fn valid_signature_count(&self) -> usize {
156 self.valid_count
157 }
158
159 pub fn required(&self) -> usize {
161 self.config.required()
162 }
163
164 pub fn progress(&self) -> (usize, usize) {
166 (self.valid_count, self.config.required())
167 }
168
169 pub fn message(&self) -> &[u8] {
171 &self.message
172 }
173
174 pub fn config(&self) -> &ThresholdConfig {
176 &self.config
177 }
178
179 pub fn get_signature(&self, signer_index: usize) -> Option<&[u8]> {
183 match self.slots.get(signer_index)? {
184 SlotState::Valid(sig) => Some(sig.as_slice()),
185 SlotState::Empty => None,
186 }
187 }
188
189 pub fn signed_indices(&self) -> Vec<usize> {
191 self.slots
192 .iter()
193 .enumerate()
194 .filter_map(|(i, slot)| {
195 if *slot != SlotState::Empty {
196 Some(i)
197 } else {
198 None
199 }
200 })
201 .collect()
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use crate::keypair::KeyPair;

    /// Generates `total` fresh keypairs and a `required`-of-`total` config over their public keys.
    fn setup(required: usize, total: usize) -> (Vec<KeyPair>, ThresholdConfig) {
        let keypairs: Vec<KeyPair> = (0..total).map(|_| KeyPair::generate()).collect();
        let pks = keypairs.iter().map(|kp| kp.public_key().clone()).collect();
        let config = ThresholdConfig::new(required, pks).unwrap();
        (keypairs, config)
    }

    #[test]
    fn session_2of3_complete_roundtrip() {
        let (kps, cfg) = setup(2, 3);
        let msg = b"payload";
        let mut session = SigningSession::new(&cfg, msg);

        session.add_signature(0, kps[0].sign(msg)).unwrap();
        assert!(!session.is_complete());

        session.add_signature(2, kps[2].sign(msg)).unwrap();
        assert!(session.is_complete());
        assert!(session.verify().unwrap());
    }

    #[test]
    fn session_3of5_any_three_suffice() {
        let (kps, cfg) = setup(3, 5);
        let msg = b"3-of-5 test";
        let mut session = SigningSession::new(&cfg, msg);

        session.add_signature(1, kps[1].sign(msg)).unwrap();
        session.add_signature(3, kps[3].sign(msg)).unwrap();
        session.add_signature(4, kps[4].sign(msg)).unwrap();

        assert!(session.is_complete());
        assert!(session.verify().unwrap());
    }

    #[test]
    fn verify_before_threshold_met_returns_error() {
        let (kps, cfg) = setup(2, 3);
        let msg = b"incomplete";
        let mut session = SigningSession::new(&cfg, msg);

        session.add_signature(0, kps[0].sign(msg)).unwrap();

        let err = session.verify().unwrap_err();
        assert!(matches!(err, Error::ThresholdNotMet { have: 1, need: 2 }));
    }

    #[test]
    fn duplicate_signature_rejected() {
        let (kps, cfg) = setup(2, 3);
        let msg = b"dup test";
        let mut session = SigningSession::new(&cfg, msg);

        session.add_signature(0, kps[0].sign(msg)).unwrap();
        let err = session.add_signature(0, kps[0].sign(msg)).unwrap_err();
        assert!(matches!(err, Error::DuplicateSignature { index: 0 }));
    }

    #[test]
    fn out_of_range_index_rejected() {
        let (kps, cfg) = setup(2, 3);
        let msg = b"oob test";
        let mut session = SigningSession::new(&cfg, msg);

        let err = session.add_signature(99, kps[0].sign(msg)).unwrap_err();
        assert!(matches!(err, Error::SignerIndexOutOfRange { index: 99, total: 3 }));
    }

    #[test]
    fn wrong_key_signature_rejected() {
        // The session's own keypairs are not needed; only the config is.
        let (_, cfg) = setup(2, 3);
        let attacker = KeyPair::generate();
        let msg = b"attack";
        let mut session = SigningSession::new(&cfg, msg);

        // Signed by a key that is not signer 0's registered public key.
        let forged = attacker.sign(msg);
        let err = session.add_signature(0, forged).unwrap_err();
        assert!(matches!(err, Error::VerificationFailed { index: 0 }));
    }

    #[test]
    fn wrong_message_signature_rejected() {
        let (kps, cfg) = setup(2, 3);
        let msg = b"correct message";
        let mut session = SigningSession::new(&cfg, msg);

        // Right key, but over a different message than the session's.
        let sig_for_wrong_msg = kps[0].sign(b"wrong message");
        let err = session.add_signature(0, sig_for_wrong_msg).unwrap_err();
        assert!(matches!(err, Error::VerificationFailed { index: 0 }));
    }

    #[test]
    fn rejected_signature_does_not_consume_slot() {
        let (kps, cfg) = setup(2, 3);
        let attacker = KeyPair::generate();
        let msg = b"retry test";
        let mut session = SigningSession::new(&cfg, msg);

        // A failed submission must leave the slot free for the legitimate signer.
        assert!(session.add_signature(0, attacker.sign(msg)).is_err());
        assert_eq!(session.progress(), (0, 2));
        assert!(session.get_signature(0).is_none());

        session.add_signature(0, kps[0].sign(msg)).unwrap();
        assert_eq!(session.progress(), (1, 2));
    }

    #[test]
    fn progress_reports_correctly() {
        let (kps, cfg) = setup(3, 5);
        let msg = b"progress test";
        let mut session = SigningSession::new(&cfg, msg);

        assert_eq!(session.progress(), (0, 3));
        session.add_signature(0, kps[0].sign(msg)).unwrap();
        assert_eq!(session.progress(), (1, 3));
        session.add_signature(1, kps[1].sign(msg)).unwrap();
        assert_eq!(session.progress(), (2, 3));
    }

    #[test]
    fn signed_indices_tracks_contributors() {
        let (kps, cfg) = setup(2, 4);
        let msg = b"indices test";
        let mut session = SigningSession::new(&cfg, msg);

        session.add_signature(0, kps[0].sign(msg)).unwrap();
        session.add_signature(3, kps[3].sign(msg)).unwrap();

        let indices = session.signed_indices();
        assert_eq!(indices, vec![0, 3]);
    }

    #[test]
    fn get_signature_returns_correct_bytes() {
        let (kps, cfg) = setup(2, 3);
        let msg = b"get sig test";
        let mut session = SigningSession::new(&cfg, msg);

        let sig = kps[1].sign(msg);
        session.add_signature(1, sig.clone()).unwrap();

        assert_eq!(session.get_signature(1), Some(sig.as_slice()));
        assert!(session.get_signature(0).is_none());
    }

    #[test]
    fn get_signature_out_of_range_is_none() {
        let (_, cfg) = setup(2, 3);
        let session = SigningSession::new(&cfg, b"oob get");

        assert!(session.get_signature(3).is_none());
        assert!(session.get_signature(usize::MAX).is_none());
    }

    #[test]
    fn n_of_n_requires_all_signers() {
        let (kps, cfg) = setup(4, 4);
        let msg = b"unanimous";
        let mut session = SigningSession::new(&cfg, msg);

        for (i, kp) in kps.iter().enumerate().take(3) {
            session.add_signature(i, kp.sign(msg)).unwrap();
            let signed = i + 1;
            assert!(!session.is_complete(), "should not be complete after {signed} sigs");
        }

        session.add_signature(3, kps[3].sign(msg)).unwrap();
        assert!(session.is_complete());
        assert!(session.verify().unwrap());
    }
}