#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use crate::{
error::Error,
threshold::ThresholdConfig,
verify::verify_partial,
};
/// State of a single signer's slot inside a `SigningSession`.
///
/// A slot starts out `Empty` and is switched to `Valid` by
/// `SigningSession::add_signature` once the partial signature for that
/// signer index has been verified. Nothing in `SigningSession`'s impl
/// resets a `Valid` slot back to `Empty`.
#[derive(Clone, Debug)]
#[derive(PartialEq, Eq)]
enum SlotState {
    /// No signature has been accepted for this signer index yet.
    Empty,
    /// Raw bytes of a partial signature that passed verification.
    Valid(Vec<u8>),
}
/// Accumulates verified partial signatures over one message until the
/// threshold required by a [`ThresholdConfig`] is reached.
///
/// Every signer index gets exactly one slot. A signature is stored only after
/// it passes verification, so `valid_count` can be read directly instead of
/// rescanning `slots`.
#[derive(Debug)]
pub struct SigningSession {
    /// Threshold parameters (required count, signer public keys), cloned on construction.
    config: ThresholdConfig,
    /// The message that every partial signature must cover.
    message: Vec<u8>,
    /// One entry per signer index; sized to `config.total()` by `SigningSession::new`.
    slots: Vec<SlotState>,
    /// How many entries of `slots` are currently `SlotState::Valid`.
    valid_count: usize,
}
impl SigningSession {
    /// Starts a new signing session for `message` under `config`.
    ///
    /// The configuration and the message are copied into the session, and one
    /// empty slot is allocated for each of the `config.total()` signers.
    pub fn new(config: &ThresholdConfig, message: &[u8]) -> Self {
        let n = config.total();
        // `Vec::resize` instead of `vec![...; n]`: it needs no macro import, so it
        // compiles in `no_std` builds where only `alloc::vec::Vec` is in scope.
        let mut slots = Vec::with_capacity(n);
        slots.resize(n, SlotState::Empty);
        Self {
            config: config.clone(),
            message: message.to_vec(),
            slots,
            valid_count: 0,
        }
    }

    /// Verifies `signature` as signer `signer_index`'s partial signature over
    /// the session message and, if it is valid, records it.
    ///
    /// On any error the session is left unchanged, so a rejected signature does
    /// not consume the signer's slot.
    ///
    /// # Errors
    ///
    /// - [`Error::SignerIndexOutOfRange`] if `signer_index >= config.total()`.
    /// - [`Error::DuplicateSignature`] if a signature for `signer_index` was
    ///   already accepted.
    /// - [`Error::VerificationFailed`] if `verify_partial` reports the signature
    ///   as invalid.
    /// - Any error returned by `verify_partial` itself is propagated as-is.
    ///
    /// # Panics
    ///
    /// Panics if `config.get_public_key(signer_index)` fails for an index below
    /// `config.total()`. This is assumed to be an invariant upheld by
    /// `ThresholdConfig` (TODO confirm).
    pub fn add_signature(&mut self, signer_index: usize, signature: Vec<u8>) -> Result<(), Error> {
        let total = self.config.total();
        if signer_index >= total {
            return Err(Error::SignerIndexOutOfRange {
                index: signer_index,
                total,
            });
        }
        // Reject duplicates before doing any cryptographic work.
        if matches!(self.slots[signer_index], SlotState::Valid(_)) {
            return Err(Error::DuplicateSignature { index: signer_index });
        }
        let pk = self
            .config
            .get_public_key(signer_index)
            .expect("index is within bounds; already checked above");
        let valid = verify_partial(&self.message, &signature, pk.as_bytes(), signer_index)?;
        if !valid {
            return Err(Error::VerificationFailed { index: signer_index });
        }
        // All fallible checks are done; only now mutate the session.
        self.slots[signer_index] = SlotState::Valid(signature);
        self.valid_count += 1;
        Ok(())
    }

    /// Returns `true` once at least `config.required()` valid signatures have
    /// been collected.
    pub fn is_complete(&self) -> bool {
        self.valid_count >= self.config.required()
    }

    /// Confirms that the threshold has been met and returns `Ok(true)` if so.
    ///
    /// This does not re-run signature verification. Each stored signature was
    /// already checked by [`SigningSession::add_signature`] before it was accepted.
    ///
    /// # Errors
    ///
    /// [`Error::ThresholdNotMet`] if fewer than `config.required()` valid
    /// signatures have been collected.
    pub fn verify(&self) -> Result<bool, Error> {
        if self.valid_count < self.config.required() {
            return Err(Error::ThresholdNotMet {
                have: self.valid_count,
                need: self.config.required(),
            });
        }
        Ok(true)
    }

    /// Number of verified signatures collected so far.
    pub fn valid_signature_count(&self) -> usize {
        self.valid_count
    }

    /// Number of signatures needed to complete the session.
    pub fn required(&self) -> usize {
        self.config.required()
    }

    /// Returns `(collected, required)`.
    pub fn progress(&self) -> (usize, usize) {
        (self.valid_count, self.config.required())
    }

    /// The message being signed.
    pub fn message(&self) -> &[u8] {
        &self.message
    }

    /// The threshold configuration this session was created with.
    pub fn config(&self) -> &ThresholdConfig {
        &self.config
    }

    /// Returns the accepted signature bytes for `signer_index`.
    ///
    /// Returns `None` if that signer has not contributed yet or if the index is
    /// out of range.
    pub fn get_signature(&self, signer_index: usize) -> Option<&[u8]> {
        match self.slots.get(signer_index)? {
            SlotState::Valid(sig) => Some(sig.as_slice()),
            SlotState::Empty => None,
        }
    }

    /// Indices of all signers whose signatures have been accepted, in ascending order.
    pub fn signed_indices(&self) -> Vec<usize> {
        self.slots
            .iter()
            .enumerate()
            .filter(|(_, slot)| matches!(slot, SlotState::Valid(_)))
            .map(|(index, _)| index)
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::keypair::KeyPair;

    /// Generates `total` fresh keypairs and a `required`-of-`total` config over them.
    fn setup(required: usize, total: usize) -> (Vec<KeyPair>, ThresholdConfig) {
        let keypairs: Vec<KeyPair> = (0..total).map(|_| KeyPair::generate()).collect();
        let pks = keypairs.iter().map(|kp| kp.public_key().clone()).collect();
        let config = ThresholdConfig::new(required, pks).unwrap();
        (keypairs, config)
    }

    #[test]
    fn session_2of3_complete_roundtrip() {
        let (kps, cfg) = setup(2, 3);
        let msg = b"payload";
        let mut session = SigningSession::new(&cfg, msg);
        session.add_signature(0, kps[0].sign(msg)).unwrap();
        assert!(!session.is_complete());
        session.add_signature(2, kps[2].sign(msg)).unwrap();
        assert!(session.is_complete());
        assert!(session.verify().unwrap());
    }

    #[test]
    fn session_3of5_any_three_suffice() {
        let (kps, cfg) = setup(3, 5);
        let msg = b"3-of-5 test";
        let mut session = SigningSession::new(&cfg, msg);
        session.add_signature(1, kps[1].sign(msg)).unwrap();
        session.add_signature(3, kps[3].sign(msg)).unwrap();
        session.add_signature(4, kps[4].sign(msg)).unwrap();
        assert!(session.is_complete());
        assert!(session.verify().unwrap());
    }

    #[test]
    fn verify_before_threshold_met_returns_error() {
        let (kps, cfg) = setup(2, 3);
        let msg = b"incomplete";
        let mut session = SigningSession::new(&cfg, msg);
        session.add_signature(0, kps[0].sign(msg)).unwrap();
        let err = session.verify().unwrap_err();
        assert!(matches!(err, Error::ThresholdNotMet { have: 1, need: 2 }));
    }

    #[test]
    fn duplicate_signature_rejected() {
        let (kps, cfg) = setup(2, 3);
        let msg = b"dup test";
        let mut session = SigningSession::new(&cfg, msg);
        session.add_signature(0, kps[0].sign(msg)).unwrap();
        let err = session.add_signature(0, kps[0].sign(msg)).unwrap_err();
        assert!(matches!(err, Error::DuplicateSignature { index: 0 }));
    }

    #[test]
    fn out_of_range_index_rejected() {
        let (kps, cfg) = setup(2, 3);
        let msg = b"oob test";
        let mut session = SigningSession::new(&cfg, msg);
        let err = session.add_signature(99, kps[0].sign(msg)).unwrap_err();
        assert!(matches!(err, Error::SignerIndexOutOfRange { index: 99, total: 3 }));
    }

    #[test]
    fn wrong_key_signature_rejected() {
        // Only the config is needed; the legitimate keypairs are intentionally unused.
        let (_, cfg) = setup(2, 3);
        let attacker = KeyPair::generate();
        let msg = b"attack";
        let mut session = SigningSession::new(&cfg, msg);
        let forged = attacker.sign(msg);
        let err = session.add_signature(0, forged).unwrap_err();
        assert!(matches!(err, Error::VerificationFailed { index: 0 }));
    }

    #[test]
    fn rejected_signature_leaves_slot_open() {
        let (kps, cfg) = setup(2, 3);
        let attacker = KeyPair::generate();
        let msg = b"retry after rejection";
        let mut session = SigningSession::new(&cfg, msg);
        assert!(session.add_signature(0, attacker.sign(msg)).is_err());
        // A rejected signature must not count toward progress or occupy the slot.
        assert_eq!(session.progress(), (0, 2));
        assert!(session.signed_indices().is_empty());
        assert!(session.get_signature(0).is_none());
        session.add_signature(0, kps[0].sign(msg)).unwrap();
        assert_eq!(session.progress(), (1, 2));
    }

    #[test]
    fn wrong_message_signature_rejected() {
        let (kps, cfg) = setup(2, 3);
        let msg = b"correct message";
        let mut session = SigningSession::new(&cfg, msg);
        let sig_for_wrong_msg = kps[0].sign(b"wrong message");
        let err = session.add_signature(0, sig_for_wrong_msg).unwrap_err();
        assert!(matches!(err, Error::VerificationFailed { index: 0 }));
    }

    #[test]
    fn progress_reports_correctly() {
        let (kps, cfg) = setup(3, 5);
        let msg = b"progress test";
        let mut session = SigningSession::new(&cfg, msg);
        assert_eq!(session.progress(), (0, 3));
        session.add_signature(0, kps[0].sign(msg)).unwrap();
        assert_eq!(session.progress(), (1, 3));
        session.add_signature(1, kps[1].sign(msg)).unwrap();
        assert_eq!(session.progress(), (2, 3));
    }

    #[test]
    fn signed_indices_tracks_contributors() {
        let (kps, cfg) = setup(2, 4);
        let msg = b"indices test";
        let mut session = SigningSession::new(&cfg, msg);
        session.add_signature(0, kps[0].sign(msg)).unwrap();
        session.add_signature(3, kps[3].sign(msg)).unwrap();
        let indices = session.signed_indices();
        assert_eq!(indices, vec![0, 3]);
    }

    #[test]
    fn get_signature_returns_correct_bytes() {
        let (kps, cfg) = setup(2, 3);
        let msg = b"get sig test";
        let mut session = SigningSession::new(&cfg, msg);
        let sig = kps[1].sign(msg);
        session.add_signature(1, sig.clone()).unwrap();
        assert_eq!(session.get_signature(1), Some(sig.as_slice()));
        assert!(session.get_signature(0).is_none());
        // Out-of-range lookups return `None` rather than panicking.
        assert!(session.get_signature(99).is_none());
    }

    #[test]
    fn n_of_n_requires_all_signers() {
        let (kps, cfg) = setup(4, 4);
        let msg = b"unanimous";
        let mut session = SigningSession::new(&cfg, msg);
        for (i, kp) in kps.iter().enumerate().take(3) {
            session.add_signature(i, kp.sign(msg)).unwrap();
            // After adding index `i`, exactly `i + 1` signatures are present.
            assert!(
                !session.is_complete(),
                "should not be complete after {} sigs",
                i + 1
            );
        }
        session.add_signature(3, kps[3].sign(msg)).unwrap();
        assert!(session.is_complete());
        assert!(session.verify().unwrap());
    }
}