use frost_ed25519 as frost;
use rand::rngs::OsRng;
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
/// Public description of a threshold-signing committee, as produced by
/// [`run_dkg_in_memory`].
///
/// `deny_unknown_fields` makes deserialisation reject any key not listed here.
/// Deserialisation does NOT re-check the invariants that `run_dkg_in_memory`
/// enforces before creating a config (`threshold >= 1` and
/// `threshold <= member_count`). A config read from storage or the network may
/// therefore carry values the DKG would have refused.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ThresholdConfig {
    /// Number of participants that took part in key generation
    /// (identifiers `1..=member_count`).
    pub member_count: u16,
    /// Minimum number of signers that `sign_round_trip` requires.
    pub threshold: u16,
    /// Hex encoding of the serialised group verifying key. This is the key that
    /// `verify_threshold_signature` checks signatures against.
    pub aggregated_pubkey_hex: String,
    /// Caller-chosen label for the committee.
    pub committee_id: String,
}
/// Wire message asking committee members to sign a payload.
///
/// Unknown JSON keys are rejected (`deny_unknown_fields`).
///
/// NOTE(review): nothing in this module constructs or consumes this type. The
/// field descriptions below are inferred from names and should be confirmed
/// against the code that uses it.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct SigningRequest {
    /// Committee the request is addressed to. Presumably compared against
    /// `ThresholdConfig::committee_id` (cf. `ThresholdError::CommitteeIdMismatch`).
    pub committee_id: String,
    /// Request identifier, presumably used to correlate responses with this request.
    pub request_id: String,
    /// Hex-encoded digest of the payload to sign. Presumably 32 bytes
    /// (cf. `ThresholdError::BadDigestLength`) — TODO confirm.
    pub payload_digest_hex: String,
    /// Deadline in seconds. Presumably an absolute Unix timestamp — TODO confirm.
    pub deadline_secs: u64,
    /// Free-form, human-readable description of what is being signed.
    pub human_context: String,
}
/// Wire message a committee member sends back for a `SigningRequest`.
///
/// Unknown JSON keys are rejected (`deny_unknown_fields`).
///
/// NOTE(review): nothing in this module constructs or consumes this type. The
/// field descriptions below are inferred from names and should be confirmed
/// against the code that uses it.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct SigningResponse {
    /// Committee the responding member belongs to.
    pub committee_id: String,
    /// Presumably echoes `SigningRequest::request_id`.
    pub request_id: String,
    /// Index of the responding member. Presumably the same `1..=member_count`
    /// value that `run_dkg_in_memory` turns into a FROST identifier — TODO confirm.
    pub signer_index: u16,
    /// Hex-encoded partial signature (signature share). Its contents when
    /// `refused` is true are not defined here — TODO confirm with the producer.
    pub partial_sig_hex: String,
    /// Presumably true when the member declined to sign.
    pub refused: bool,
}
/// One participant's output from [`run_dkg_in_memory`]: its identifier, its
/// private key package and the committee-wide public key package.
///
/// This value contains secret material: `key_package` provides the signing
/// share used in round 1 of signing.
/// NOTE(review): `Debug` is derived. Confirm that the FROST types redact
/// secrets in their `Debug` output before this value is ever logged.
#[derive(Debug, Clone)]
pub struct KeyMaterial {
    /// This participant's FROST identifier.
    pub identifier: frost::Identifier,
    /// The participant's private key package (holds its signing share).
    pub key_package: frost::keys::KeyPackage,
    /// Public package shared by the committee; exposes the group verifying key.
    pub public_key_package: frost::keys::PublicKeyPackage,
}
/// Errors returned by the threshold key-generation, signing and verification
/// helpers in this module.
#[derive(Debug, thiserror::Error)]
pub enum ThresholdError {
    /// Any error reported by the FROST library, stored as its `Debug` rendering.
    /// Also used for the "divergent verifying keys" consistency failure in
    /// `run_dkg_in_memory`.
    #[error("FROST cryptographic error: {0}")]
    Frost(String),
    /// A `u16` participant index could not be turned into a FROST identifier.
    #[error("invalid identifier value (must be 1..=u16::MAX): {0}")]
    InvalidIdentifier(u16),
    /// `run_dkg_in_memory` was asked for more required signers than members.
    #[error("threshold/member-count mismatch: {threshold} > {member_count}")]
    ThresholdAboveMemberCount { threshold: u16, member_count: u16 },
    /// `run_dkg_in_memory` was asked for a threshold of zero.
    #[error("threshold must be at least 1, got {0}")]
    ThresholdZero(u16),
    /// Not produced by the functions visible in this module. Presumably raised
    /// when a request targets a different committee than the local config.
    #[error("committee id mismatch: request `{request_committee}` vs config `{config_committee}`")]
    CommitteeIdMismatch {
        request_committee: String,
        config_committee: String,
    },
    /// Invalid hex input (e.g. the aggregated public key passed to
    /// `verify_threshold_signature`).
    #[error("hex decode error: {0}")]
    Hex(#[from] hex::FromHexError),
    /// A byte string was not 32 bytes long.
    /// NOTE(review): `verify_threshold_signature` also returns this variant
    /// when the decoded *public key* has the wrong length, even though the
    /// message mentions a payload digest.
    #[error("payload digest must be exactly 32 bytes, got {0}")]
    BadDigestLength(usize),
    /// `sign_round_trip` received fewer signers than the config's threshold.
    #[error("not enough signers: got {got}, need {needed}")]
    BelowThreshold { got: u16, needed: u16 },
}
impl From<frost::Error> for ThresholdError {
fn from(e: frost::Error) -> Self {
ThresholdError::Frost(format!("{e:?}"))
}
}
/// Runs a full three-part FROST distributed key generation among
/// `member_count` simulated participants, entirely inside this process.
///
/// Participants get identifiers `1..=member_count`. Every part-1 broadcast is
/// handed to every *other* participant. Every part-2 package is routed into its
/// recipient's inbox. This mirrors what a network transport would do, without
/// one.
///
/// Returns one [`KeyMaterial`] per participant, ordered by identifier, plus a
/// [`ThresholdConfig`] carrying the hex-encoded group verifying key and
/// `committee_id`.
///
/// # Errors
///
/// - [`ThresholdError::ThresholdZero`] if `threshold == 0`.
/// - [`ThresholdError::ThresholdAboveMemberCount`] if `threshold > member_count`.
/// - [`ThresholdError::InvalidIdentifier`] if an index cannot become a FROST
///   identifier.
/// - [`ThresholdError::Frost`] for any failure reported by the FROST DKG,
///   including parameter combinations the library itself rejects. Also
///   returned if the participants end up disagreeing on the group verifying key.
pub fn run_dkg_in_memory(
    member_count: u16,
    threshold: u16,
    committee_id: impl Into<String>,
) -> Result<(Vec<KeyMaterial>, ThresholdConfig), ThresholdError> {
    // Guard clauses: reject impossible parameters before any key material exists.
    if threshold == 0 {
        return Err(ThresholdError::ThresholdZero(threshold));
    }
    if threshold > member_count {
        return Err(ThresholdError::ThresholdAboveMemberCount {
            threshold,
            member_count,
        });
    }

    let mut rng = OsRng;

    // Part 1: each participant produces a secret package and a public broadcast.
    let mut part1_secrets: BTreeMap<frost::Identifier, frost::keys::dkg::round1::SecretPackage> =
        BTreeMap::new();
    let mut broadcasts: BTreeMap<frost::Identifier, frost::keys::dkg::round1::Package> =
        BTreeMap::new();
    for index in 1u16..=member_count {
        let participant = frost::Identifier::try_from(index)
            .map_err(|_| ThresholdError::InvalidIdentifier(index))?;
        // The `&mut rng` borrow is deliberate: one RNG handle is threaded through
        // every call. The allow presumably silences clippy's suggestion to pass
        // the RNG by value instead.
        #[allow(clippy::needless_borrows_for_generic_args)]
        let (secret, broadcast) =
            frost::keys::dkg::part1(participant, member_count, threshold, &mut rng)?;
        part1_secrets.insert(participant, secret);
        broadcasts.insert(participant, broadcast);
    }

    // What participant `me` receives in part 1: everyone else's broadcast.
    // Parts 2 and 3 both need this view, so it is computed in one place.
    let broadcasts_for =
        |me: frost::Identifier| -> BTreeMap<frost::Identifier, frost::keys::dkg::round1::Package> {
            broadcasts
                .iter()
                .filter(|(sender, _)| **sender != me)
                .map(|(sender, package)| (*sender, package.clone()))
                .collect()
        };

    // Part 2: consume the part-1 secret, and route each outgoing package to
    // its recipient's inbox, keyed by the sender.
    let mut part2_secrets: BTreeMap<frost::Identifier, frost::keys::dkg::round2::SecretPackage> =
        BTreeMap::new();
    let mut inboxes: BTreeMap<
        frost::Identifier,
        BTreeMap<frost::Identifier, frost::keys::dkg::round2::Package>,
    > = BTreeMap::new();
    for (participant, secret) in part1_secrets {
        let (part2_secret, outgoing) =
            frost::keys::dkg::part2(secret, &broadcasts_for(participant))?;
        part2_secrets.insert(participant, part2_secret);
        for (recipient, package) in outgoing {
            inboxes
                .entry(recipient)
                .or_default()
                .insert(participant, package);
        }
    }

    // Part 3: each participant derives its final key package. The first
    // failure short-circuits the whole collection.
    let materials = part2_secrets
        .into_iter()
        .map(
            |(participant, part2_secret)| -> Result<KeyMaterial, ThresholdError> {
                let inbox = inboxes.remove(&participant).unwrap_or_default();
                let (key_package, public_key_package) = frost::keys::dkg::part3(
                    &part2_secret,
                    &broadcasts_for(participant),
                    &inbox,
                )?;
                Ok(KeyMaterial {
                    identifier: participant,
                    key_package,
                    public_key_package,
                })
            },
        )
        .collect::<Result<Vec<_>, ThresholdError>>()?;

    // Consistency check: every participant must derive the same group key.
    // `materials` is non-empty because 1 <= threshold <= member_count.
    let group_key = materials[0].public_key_package.verifying_key();
    let all_agree = materials
        .iter()
        .all(|m| m.public_key_package.verifying_key() == group_key);
    if !all_agree {
        return Err(ThresholdError::Frost(
            "DKG produced divergent verifying keys across participants".into(),
        ));
    }

    let config = ThresholdConfig {
        member_count,
        threshold,
        aggregated_pubkey_hex: hex::encode(group_key.serialize()?),
        committee_id: committee_id.into(),
    };
    Ok((materials, config))
}
/// Runs both FROST signing rounds in-process for `signers`, then aggregates
/// the shares into one signature over `payload`.
///
/// Signers are first deduplicated by [`frost::Identifier`]. FROST keys
/// commitments and shares by identifier, so a participant listed twice
/// contributes one share and is counted once toward the threshold. When the
/// same identifier appears more than once, the last entry wins.
///
/// Aggregation uses the public key package of `signers[0]`, so all signers are
/// expected to come from the same DKG run. Mixing committees surfaces as a
/// [`ThresholdError::Frost`] from the library.
///
/// # Errors
///
/// - [`ThresholdError::BelowThreshold`] when the number of *distinct* signers
///   is below `config.threshold`. A threshold of 0 can only appear in a
///   hand-built or deserialised config; it is treated as 1, so an empty signer
///   list is always rejected instead of panicking.
/// - [`ThresholdError::Frost`] when commitment, share generation or aggregation
///   fails inside the FROST library.
pub fn sign_round_trip(
    config: &ThresholdConfig,
    signers: &[&KeyMaterial],
    payload: &[u8],
) -> Result<frost::Signature, ThresholdError> {
    // Count each participant once. Duplicates would otherwise inflate the
    // threshold check while FROST silently collapses them.
    let distinct: BTreeMap<frost::Identifier, &KeyMaterial> =
        signers.iter().map(|s| (s.identifier, *s)).collect();

    // Saturate rather than `as u16`, which would wrap for huge slices.
    let got = u16::try_from(distinct.len()).unwrap_or(u16::MAX);
    // Never accept zero signers, even if the config says threshold 0.
    let needed = config.threshold.max(1);
    if got < needed {
        return Err(ThresholdError::BelowThreshold { got, needed });
    }

    // Round 1: each distinct signer commits to fresh nonces. Each nonce stays
    // paired with its signer, so round 2 needs no lookup.
    let mut rng = OsRng;
    let mut round1: Vec<(&KeyMaterial, frost::round1::SigningNonces)> =
        Vec::with_capacity(distinct.len());
    let mut commitments_by_id: BTreeMap<frost::Identifier, frost::round1::SigningCommitments> =
        BTreeMap::new();
    for (&id, &signer) in &distinct {
        let (nonce, commitment) =
            frost::round1::commit(signer.key_package.signing_share(), &mut rng);
        commitments_by_id.insert(id, commitment);
        round1.push((signer, nonce));
    }
    let signing_package = frost::SigningPackage::new(commitments_by_id, payload);

    // Round 2: each signer produces its signature share over the package.
    let mut signature_shares: BTreeMap<frost::Identifier, frost::round2::SignatureShare> =
        BTreeMap::new();
    for (signer, nonce) in &round1 {
        let share = frost::round2::sign(&signing_package, nonce, &signer.key_package)?;
        signature_shares.insert(signer.identifier, share);
    }

    // `signers` is non-empty here: got >= needed >= 1.
    let pubkey_package = &signers[0].public_key_package;
    let signature = frost::aggregate(&signing_package, &signature_shares, pubkey_package)?;
    Ok(signature)
}
/// Checks `signature` over `payload` against a hex-encoded group verifying key,
/// such as [`ThresholdConfig::aggregated_pubkey_hex`].
///
/// # Errors
///
/// - [`ThresholdError::Hex`] if `aggregated_pubkey_hex` is not valid hex.
/// - [`ThresholdError::BadDigestLength`] if the decoded key is not exactly
///   32 bytes. NOTE(review): the variant name and its message refer to a
///   payload digest, but here it reports the public-key length. It is kept for
///   compatibility with callers that match on it.
/// - [`ThresholdError::Frost`] if the key bytes are not a valid verifying key,
///   or if the signature does not verify.
pub fn verify_threshold_signature(
    aggregated_pubkey_hex: &str,
    payload: &[u8],
    signature: &frost::Signature,
) -> Result<(), ThresholdError> {
    let raw_key = hex::decode(aggregated_pubkey_hex)?;
    let key_bytes = <[u8; 32]>::try_from(raw_key)
        .map_err(|rejected| ThresholdError::BadDigestLength(rejected.len()))?;
    frost::VerifyingKey::deserialize(&key_bytes)?
        .verify(payload, signature)
        .map_err(ThresholdError::from)
}
/// Unit tests for DKG, signing, verification and the serde wire types.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn dkg_3_of_5_produces_consistent_pubkey() {
        let (materials, config) = run_dkg_in_memory(5, 3, "test-committee").unwrap();
        assert_eq!(materials.len(), 5);
        assert_eq!(config.member_count, 5);
        assert_eq!(config.threshold, 3);
        let canonical = materials[0].public_key_package.verifying_key();
        for m in &materials[1..] {
            assert_eq!(m.public_key_package.verifying_key(), canonical);
        }
        let expected = hex::encode(canonical.serialize().unwrap());
        assert_eq!(config.aggregated_pubkey_hex, expected);
    }

    #[test]
    fn dkg_rejects_zero_threshold() {
        let err = run_dkg_in_memory(5, 0, "x").unwrap_err();
        assert!(matches!(err, ThresholdError::ThresholdZero(0)));
    }

    #[test]
    fn dkg_rejects_threshold_above_member_count() {
        let err = run_dkg_in_memory(3, 5, "x").unwrap_err();
        assert!(matches!(
            err,
            ThresholdError::ThresholdAboveMemberCount {
                threshold: 5,
                member_count: 3
            }
        ));
    }

    #[test]
    fn sign_3_of_5_round_trip_verifies() {
        let (materials, config) = run_dkg_in_memory(5, 3, "test").unwrap();
        let payload = b"hello, threshold world";
        let signers: Vec<&KeyMaterial> = materials.iter().take(3).collect();
        let sig = sign_round_trip(&config, &signers, payload).unwrap();
        verify_threshold_signature(&config.aggregated_pubkey_hex, payload, &sig).unwrap();
    }

    #[test]
    fn sign_4_of_5_round_trip_verifies() {
        let (materials, config) = run_dkg_in_memory(5, 3, "test").unwrap();
        let payload = b"4 of 5 sigs";
        let signers: Vec<&KeyMaterial> = materials.iter().take(4).collect();
        let sig = sign_round_trip(&config, &signers, payload).unwrap();
        verify_threshold_signature(&config.aggregated_pubkey_hex, payload, &sig).unwrap();
    }

    #[test]
    fn sign_below_threshold_fails() {
        let (materials, config) = run_dkg_in_memory(5, 3, "test").unwrap();
        let payload = b"won't reach threshold";
        let signers: Vec<&KeyMaterial> = materials.iter().take(2).collect();
        let err = sign_round_trip(&config, &signers, payload).unwrap_err();
        assert!(matches!(
            err,
            ThresholdError::BelowThreshold { got: 2, needed: 3 }
        ));
    }

    #[test]
    fn signature_does_not_verify_under_wrong_key() {
        let (m1, c1) = run_dkg_in_memory(5, 3, "alpha").unwrap();
        let (_m2, c2) = run_dkg_in_memory(5, 3, "beta").unwrap();
        let payload = b"crossed wires";
        let signers: Vec<&KeyMaterial> = m1.iter().take(3).collect();
        // Sign under committee "alpha" using its own DKG config.
        let sig = sign_round_trip(&c1, &signers, payload).unwrap();
        // Sanity check: valid under the key that actually produced it...
        verify_threshold_signature(&c1.aggregated_pubkey_hex, payload, &sig).unwrap();
        // ...but rejected under an unrelated committee's key.
        let err = verify_threshold_signature(&c2.aggregated_pubkey_hex, payload, &sig).unwrap_err();
        assert!(matches!(err, ThresholdError::Frost(_)));
    }

    #[test]
    fn signature_does_not_verify_with_tampered_payload() {
        let (materials, config) = run_dkg_in_memory(5, 3, "test").unwrap();
        let payload = b"original";
        let tampered = b"tampered";
        let signers: Vec<&KeyMaterial> = materials.iter().take(3).collect();
        let sig = sign_round_trip(&config, &signers, payload).unwrap();
        let err =
            verify_threshold_signature(&config.aggregated_pubkey_hex, tampered, &sig).unwrap_err();
        assert!(matches!(err, ThresholdError::Frost(_)));
    }

    #[test]
    fn verify_rejects_malformed_pubkey_hex() {
        let (materials, config) = run_dkg_in_memory(5, 3, "test").unwrap();
        let payload = b"malformed key";
        let signers: Vec<&KeyMaterial> = materials.iter().take(3).collect();
        let sig = sign_round_trip(&config, &signers, payload).unwrap();
        // Not hex at all.
        let err = verify_threshold_signature("not hex", payload, &sig).unwrap_err();
        assert!(matches!(err, ThresholdError::Hex(_)));
        // Valid hex, wrong length: reported through BadDigestLength today.
        let err = verify_threshold_signature("00", payload, &sig).unwrap_err();
        assert!(matches!(err, ThresholdError::BadDigestLength(1)));
    }

    #[test]
    fn config_round_trip_via_json() {
        let (_m, c) = run_dkg_in_memory(5, 3, "rt").unwrap();
        let s = serde_json::to_string(&c).unwrap();
        let back: ThresholdConfig = serde_json::from_str(&s).unwrap();
        assert_eq!(c, back);
    }

    #[test]
    fn signing_request_round_trip() {
        let r = SigningRequest {
            committee_id: "x".into(),
            request_id: "req-1".into(),
            payload_digest_hex: "ab".repeat(32),
            deadline_secs: 1_777_905_000,
            human_context: "swap context".into(),
        };
        let s = serde_json::to_string(&r).unwrap();
        let back: SigningRequest = serde_json::from_str(&s).unwrap();
        assert_eq!(r, back);
    }

    #[test]
    fn signing_response_round_trip() {
        let r = SigningResponse {
            committee_id: "x".into(),
            request_id: "req-1".into(),
            signer_index: 2,
            partial_sig_hex: "cd".repeat(32),
            refused: false,
        };
        let s = serde_json::to_string(&r).unwrap();
        let back: SigningResponse = serde_json::from_str(&s).unwrap();
        assert_eq!(r, back);
    }

    #[test]
    fn deny_unknown_fields_in_config() {
        let bad = r#"{
            "member_count": 5,
            "threshold": 3,
            "aggregated_pubkey_hex": "00",
            "committee_id": "x",
            "extra": "rejected"
        }"#;
        let res: Result<ThresholdConfig, _> = serde_json::from_str(bad);
        assert!(res.is_err());
    }

    #[test]
    fn different_signer_subsets_both_verify() {
        let (materials, config) = run_dkg_in_memory(5, 3, "test").unwrap();
        let payload = b"deterministic payload, randomised sigs";
        let sig_a = sign_round_trip(
            &config,
            &materials[0..3].iter().collect::<Vec<_>>(),
            payload,
        )
        .unwrap();
        let sig_b = sign_round_trip(
            &config,
            &materials[2..5].iter().collect::<Vec<_>>(),
            payload,
        )
        .unwrap();
        verify_threshold_signature(&config.aggregated_pubkey_hex, payload, &sig_a).unwrap();
        verify_threshold_signature(&config.aggregated_pubkey_hex, payload, &sig_b).unwrap();
        // Fresh nonces per run make equal signatures astronomically unlikely.
        let a_bytes = sig_a.serialize().unwrap();
        let b_bytes = sig_b.serialize().unwrap();
        assert_ne!(a_bytes, b_bytes);
    }
}