use std::collections::{BTreeMap, HashMap};
use curve25519_dalek::{constants::ED25519_BASEPOINT_POINT, edwards::EdwardsPoint, scalar::Scalar};
use super::{
combine::lagrange_coefficient,
types::{KeyShare, ThresholdConfig, ThresholdDecryptionContext, ThresholdPublicKey},
};
use crate::crypto::error::CryptoError;
pub use frost_ristretto255::{
keys::{
dkg::{
self as frost_dkg,
round1::{Package as Round1Package, SecretPackage as Round1SecretPackage},
round2::{Package as Round2Package, SecretPackage as Round2SecretPackage},
},
KeyPackage, PublicKeyPackage,
},
Error as FrostError, Identifier,
};
/// Runs FROST DKG round 1 for a single participant.
///
/// Randomness is drawn from the operating-system RNG (`rand_core::OsRng`). On success the
/// caller receives its round-1 secret package (retained for round 2) and the round-1 package
/// that is handed to the other participants. Any `FrostError` raised by
/// `frost_dkg::part1` is returned unchanged.
pub fn round1(
    identifier: Identifier,
    max_signers: u16,
    min_signers: u16,
) -> Result<(Round1SecretPackage, Round1Package), FrostError> {
    frost_dkg::part1(identifier, max_signers, min_signers, rand_core::OsRng)
}
/// Runs FROST DKG round 2 for one participant by delegating to `frost_dkg::part2`.
///
/// Consumes the participant's round-1 secret package. `round1_packages` holds the round-1
/// packages received from the other participants (`run_dkg_ceremony` excludes the caller's
/// own package before calling this). Returns the round-2 secret package, which is kept for
/// round 3, and a map of round-2 packages keyed by the identifier of the participant each
/// package is addressed to. Any `FrostError` is returned unchanged.
pub fn round2(
    secret_package: Round1SecretPackage,
    round1_packages: &BTreeMap<Identifier, Round1Package>,
) -> Result<(Round2SecretPackage, BTreeMap<Identifier, Round2Package>), FrostError> {
    frost_dkg::part2(secret_package, round1_packages)
}
/// Runs the final FROST DKG round for one participant by delegating to `frost_dkg::part3`.
///
/// Takes three inputs:
/// - the participant's round-2 secret package;
/// - the round-1 packages from the other participants;
/// - the round-2 packages addressed to this participant, keyed by sender identifier.
///
/// Returns the participant's `KeyPackage` (which carries its signing share) together with the
/// `PublicKeyPackage`. Any `FrostError` is returned unchanged.
pub fn round3(
    secret_package: &Round2SecretPackage,
    round1_packages: &BTreeMap<Identifier, Round1Package>,
    round2_packages: &BTreeMap<Identifier, Round2Package>,
) -> Result<(KeyPackage, PublicKeyPackage), FrostError> {
    frost_dkg::part3(secret_package, round1_packages, round2_packages)
}
/// Maps a 1-based operator index onto a FROST [`Identifier`].
///
/// # Errors
///
/// Returns [`CryptoError::ThresholdDecrypt`] when `index` is larger than `u16::MAX`, or when
/// FROST refuses to build an identifier from the value (zero is rejected).
pub fn identifier_from_index(index: u32) -> Result<Identifier, CryptoError> {
    let narrow = match u16::try_from(index) {
        Ok(value) => value,
        Err(_) => {
            return Err(CryptoError::ThresholdDecrypt(format!(
                "operator index {} exceeds u16 range",
                index
            )))
        }
    };
    Identifier::try_from(narrow).map_err(|err| {
        CryptoError::ThresholdDecrypt(format!("invalid FROST identifier for index {}: {}", index, err))
    })
}
/// Recovers the operator index encoded in a FROST [`Identifier`].
///
/// This is the inverse of [`identifier_from_index`]: the identifier's serialized scalar is
/// read as a little-endian integer from its first four bytes, and every remaining byte must
/// be zero.
///
/// # Errors
///
/// Returns [`CryptoError::ThresholdDecrypt`] in any of these cases:
/// - the serialization is shorter than four bytes;
/// - any byte beyond the first four is non-zero, e.g. an identifier derived by hashing, which
///   does not correspond to a small operator index;
/// - the decoded index is zero;
/// - re-encoding the index does not reproduce `id` (out-of-range index, or a change in the
///   serialization format).
pub fn index_from_identifier(id: &Identifier) -> Result<u32, CryptoError> {
    let bytes = id.serialize();
    if bytes.len() < 4 {
        return Err(CryptoError::ThresholdDecrypt("FROST identifier too short".into()));
    }
    // Only the low four bytes are decoded below. Any non-zero byte above them would be
    // silently dropped, yielding an index that names a different operator.
    if bytes[4..].iter().any(|&b| b != 0) {
        return Err(CryptoError::ThresholdDecrypt(
            "FROST identifier does not encode a small operator index".into(),
        ));
    }
    let index = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
    if index == 0 {
        return Err(CryptoError::ThresholdDecrypt(
            "FROST identifier maps to zero index".into(),
        ));
    }
    // Enforce the roundtrip in every build profile. This catches both indices above u16::MAX
    // (which identifier_from_index cannot produce) and any change in the byte layout of
    // Identifier::serialize.
    match identifier_from_index(index) {
        Ok(roundtrip) if &roundtrip == id => Ok(index),
        _ => Err(CryptoError::ThresholdDecrypt(format!(
            "FROST identifier roundtrip failed for index {}",
            index
        ))),
    }
}
/// Decodes a FROST signing share into a curve25519-dalek [`Scalar`].
///
/// The serialized share must be exactly 32 bytes. It is decoded with
/// [`Scalar::from_canonical_bytes`], so a non-canonical encoding is rejected rather than
/// reduced. This reinterpretation relies on ristretto255 and Ed25519 sharing the same
/// prime-order scalar field.
fn extract_signing_scalar(key_package: &KeyPackage) -> Result<Scalar, CryptoError> {
    let raw = match <[u8; 32]>::try_from(key_package.signing_share().serialize()) {
        Ok(array) => array,
        Err(_) => return Err(CryptoError::ThresholdDecrypt("signing share is not 32 bytes".into())),
    };
    let decoded: Option<Scalar> = Scalar::from_canonical_bytes(raw).into();
    decoded.ok_or_else(|| {
        CryptoError::ThresholdDecrypt("signing share bytes are not a canonical scalar".into())
    })
}
/// Converts a FROST [`KeyPackage`] into this crate's [`KeyShare`].
///
/// The share is built from three parts:
/// - the operator index, decoded from the package identifier;
/// - the secret share, which is the FROST signing share read as an Ed25519 scalar;
/// - the public share, recomputed as `secret_share · B` on the Ed25519 basepoint rather than
///   taken from FROST's ristretto255 data.
///
/// # Errors
///
/// Propagates failures from identifier decoding first, then from signing-share decoding.
pub fn frost_to_key_share(key_package: &KeyPackage) -> Result<KeyShare, CryptoError> {
    let index = index_from_identifier(key_package.identifier())?;
    let secret_share = extract_signing_scalar(key_package)?;
    Ok(KeyShare {
        public_share: ED25519_BASEPOINT_POINT * secret_share,
        secret_share,
        index,
    })
}
/// Assembles a [`ThresholdDecryptionContext`] from FROST-derived key shares.
///
/// The master public key (MPK) is the Lagrange interpolation at zero of the shares' Ed25519
/// public points. Its Montgomery (X25519) encoding is exposed as the HPKE public key. Every
/// share's public point is also recorded in the context, keyed by operator index.
///
/// # Errors
///
/// Returns [`CryptoError::ThresholdDecrypt`] if `key_shares`:
/// - is empty;
/// - holds fewer than `config.threshold` shares;
/// - repeats an operator index;
///
/// or if interpolation fails.
pub fn frost_to_threshold_context(
    key_shares: &[KeyShare],
    config: ThresholdConfig,
) -> Result<ThresholdDecryptionContext, CryptoError> {
    if key_shares.is_empty() {
        return Err(CryptoError::ThresholdDecrypt("no key shares provided".into()));
    }
    // Shares lie on a polynomial of degree threshold - 1. Interpolating fewer than `threshold`
    // of its points does not recover its value at zero, so the "MPK" would be silently wrong.
    if key_shares.len() < config.threshold as usize {
        return Err(CryptoError::ThresholdDecrypt(format!(
            "at least {} key shares are required to derive the public key, got {}",
            config.threshold,
            key_shares.len()
        )));
    }
    let public_shares: HashMap<u32, EdwardsPoint> =
        key_shares.iter().map(|ks| (ks.index, ks.public_share)).collect();
    // A repeated index collapses into one map entry but would still appear twice in the
    // interpolation set, corrupting every Lagrange coefficient.
    if public_shares.len() != key_shares.len() {
        return Err(CryptoError::ThresholdDecrypt(
            "duplicate key share indices provided".into(),
        ));
    }
    let indices: Vec<u32> = key_shares.iter().map(|ks| ks.index).collect();
    let mpk = lagrange_interpolate_at_zero(&indices, &public_shares)?;
    let threshold_pk = ThresholdPublicKey {
        edwards: mpk,
        hpke_public_key: mpk.to_montgomery().to_bytes(),
    };
    Ok(ThresholdDecryptionContext {
        public_key: threshold_pk,
        public_shares,
        config,
    })
}
/// Interpolates the public shares at x = 0, i.e. computes `Σ λ_i · P_i` over `indices`.
///
/// `λ_i` is the Lagrange coefficient for `i` relative to the whole `indices` set, and `P_i` is
/// the point stored under `i` in `public_shares`. Accumulation starts from
/// `EdwardsPoint::default()`, the identity point.
///
/// # Errors
///
/// Returns [`CryptoError::ThresholdDecrypt`] when an index has no entry in `public_shares`.
/// Any error from `lagrange_coefficient` is propagated.
fn lagrange_interpolate_at_zero(
    indices: &[u32],
    public_shares: &HashMap<u32, EdwardsPoint>,
) -> Result<EdwardsPoint, CryptoError> {
    indices.iter().try_fold(
        EdwardsPoint::default(),
        |acc, &idx| -> Result<EdwardsPoint, CryptoError> {
            let point = public_shares.get(&idx).ok_or_else(|| {
                CryptoError::ThresholdDecrypt(format!("missing public share for index {}", idx))
            })?;
            let weight = lagrange_coefficient(idx, indices)?;
            Ok(acc + weight * point)
        },
    )
}
/// Runs a complete in-process FROST DKG among `config.total` simulated participants and
/// converts the results into key shares plus a decryption context.
///
/// Participants use identifiers `1..=config.total`. Packages are routed as follows:
/// - in rounds 2 and 3, each participant receives the round-1 packages of every other
///   participant;
/// - in round 3, each participant also receives the round-2 packages the others addressed
///   to it.
///
/// # Errors
///
/// Returns [`CryptoError::ThresholdDecrypt`] in these cases:
/// - `config.total` or `config.threshold` does not fit in a `u16`;
/// - any FROST round fails (for example, a 1-of-1 configuration is rejected);
/// - a share conversion fails.
#[cfg(test)]
pub fn run_dkg_ceremony(config: ThresholdConfig) -> Result<(Vec<KeyShare>, ThresholdDecryptionContext), CryptoError> {
    // Checked narrowing: an `as u16` cast would silently truncate, e.g. turning a threshold of
    // 65538 into a 2-of-n DKG.
    let n = u16::try_from(config.total).map_err(|_| {
        CryptoError::ThresholdDecrypt(format!("total {} exceeds u16 range", config.total))
    })?;
    let t = u16::try_from(config.threshold).map_err(|_| {
        CryptoError::ThresholdDecrypt(format!("threshold {} exceeds u16 range", config.threshold))
    })?;
    let identifiers: Vec<Identifier> = (1..=config.total)
        .map(identifier_from_index)
        .collect::<Result<Vec<_>, _>>()?;

    // Round 1: every participant produces a secret package and a shareable package.
    let mut round1_secrets: BTreeMap<Identifier, Round1SecretPackage> = BTreeMap::new();
    let mut round1_packages_all: BTreeMap<Identifier, Round1Package> = BTreeMap::new();
    for &id in &identifiers {
        let (secret, package) =
            round1(id, n, t).map_err(|e| CryptoError::ThresholdDecrypt(format!("FROST round1 failed: {}", e)))?;
        round1_secrets.insert(id, secret);
        round1_packages_all.insert(id, package);
    }

    // Round-1 packages as seen by participant `me`: everyone's except its own.
    let round1_from_others = |me: Identifier| -> BTreeMap<Identifier, Round1Package> {
        round1_packages_all
            .iter()
            .filter(|(&sender, _)| sender != me)
            .map(|(&sender, package)| (sender, package.clone()))
            .collect()
    };

    // Round 2: each participant consumes its round-1 secret and emits per-recipient packages.
    let mut round2_secrets: BTreeMap<Identifier, Round2SecretPackage> = BTreeMap::new();
    let mut round2_packages_map: BTreeMap<Identifier, BTreeMap<Identifier, Round2Package>> = BTreeMap::new();
    for &id in &identifiers {
        let received = round1_from_others(id);
        let round1_secret = round1_secrets
            .remove(&id)
            .ok_or_else(|| CryptoError::ThresholdDecrypt("missing round1 secret".into()))?;
        let (secret, packages) = round2(round1_secret, &received)
            .map_err(|e| CryptoError::ThresholdDecrypt(format!("FROST round2 failed: {}", e)))?;
        round2_secrets.insert(id, secret);
        round2_packages_map.insert(id, packages);
    }

    // Round 3: each participant collects the round-2 packages addressed to it and finalizes.
    let mut key_shares = Vec::with_capacity(identifiers.len());
    for &id in &identifiers {
        let received_round1 = round1_from_others(id);
        let received_round2: BTreeMap<Identifier, Round2Package> = round2_packages_map
            .iter()
            .filter(|(&sender, _)| sender != id)
            .filter_map(|(&sender, packages)| packages.get(&id).map(|pkg| (sender, pkg.clone())))
            .collect();
        let round2_secret = round2_secrets
            .get(&id)
            .ok_or_else(|| CryptoError::ThresholdDecrypt("missing round2 secret".into()))?;
        let (key_package, _public_key_package) = round3(round2_secret, &received_round1, &received_round2)
            .map_err(|e| CryptoError::ThresholdDecrypt(format!("FROST round3 failed: {}", e)))?;
        key_shares.push(frost_to_key_share(&key_package)?);
    }
    let context = frost_to_threshold_context(&key_shares, config)?;
    Ok((key_shares, context))
}
#[cfg(test)]
mod tests {
    //! Tests for the FROST DKG adapter. They cover:
    //! - the shape of the resulting shares;
    //! - consistency of the derived public key;
    //! - end-to-end HPKE threshold decryption with DKG-derived shares;
    //! - identifier/index conversion;
    //! - parity with dealer-generated shares.
    use super::*;
    use crate::dkg::{combine, dealer};

    /// A 2-of-3 ceremony yields one share and one public share per participant. The context
    /// carries the requested configuration through unchanged.
    #[test]
    fn frost_dkg_produces_valid_key_shares() {
        let config = ThresholdConfig { threshold: 2, total: 3 };
        let (shares, ctx) = run_dkg_ceremony(config).unwrap();
        assert_eq!(shares.len(), 3);
        assert_eq!(ctx.public_shares.len(), 3);
        assert_eq!(ctx.config.threshold, 2);
        assert_eq!(ctx.config.total, 3);
    }

    /// Shares come back in identifier order with 1-based indices `1..=total`.
    #[test]
    fn frost_shares_have_sequential_indices() {
        let config = ThresholdConfig { threshold: 2, total: 3 };
        let (shares, _) = run_dkg_ceremony(config).unwrap();
        for (i, share) in shares.iter().enumerate() {
            assert_eq!(share.index, (i + 1) as u32);
        }
    }

    /// Each public share equals its secret share times the Ed25519 basepoint.
    #[test]
    fn frost_public_shares_match_secret_shares() {
        let config = ThresholdConfig { threshold: 2, total: 3 };
        let (shares, _) = run_dkg_ceremony(config).unwrap();
        for share in &shares {
            let expected = share.secret_share * ED25519_BASEPOINT_POINT;
            assert_eq!(
                share.public_share.compress(),
                expected.compress(),
                "public share mismatch for index {}",
                share.index
            );
        }
    }

    /// The context's MPK equals `s · B`, where `s` is the secret reconstructed from the first
    /// `threshold` shares by scalar Lagrange interpolation.
    #[test]
    fn frost_mpk_equals_lagrange_reconstruction() {
        let config = ThresholdConfig { threshold: 3, total: 5 };
        let (shares, ctx) = run_dkg_ceremony(config).unwrap();
        let indices: Vec<u32> = shares[..3].iter().map(|s| s.index).collect();
        let mut reconstructed_secret = Scalar::ZERO;
        for share in &shares[..3] {
            let lambda = lagrange_coefficient(share.index, &indices).unwrap();
            reconstructed_secret += lambda * share.secret_share;
        }
        let expected_mpk = reconstructed_secret * ED25519_BASEPOINT_POINT;
        assert_eq!(
            ctx.public_key.edwards.compress(),
            expected_mpk.compress(),
            "MPK from context doesn't match Lagrange reconstruction"
        );
    }

    /// Several different 3-element subsets of a 3-of-5 sharing all reconstruct the same
    /// secret scalar.
    #[test]
    fn frost_any_t_subset_reconstructs_same_secret() {
        let config = ThresholdConfig { threshold: 3, total: 5 };
        let (shares, _) = run_dkg_ceremony(config).unwrap();
        // Positions into `shares`, not operator indices.
        let subsets: Vec<Vec<usize>> = vec![vec![0, 1, 2], vec![0, 2, 4], vec![1, 3, 4], vec![2, 3, 4]];
        let mut secrets: Vec<Scalar> = Vec::new();
        for subset in &subsets {
            let selected: Vec<&KeyShare> = subset.iter().map(|&i| &shares[i]).collect();
            let indices: Vec<u32> = selected.iter().map(|s| s.index).collect();
            let mut reconstructed = Scalar::ZERO;
            for share in &selected {
                let lambda = lagrange_coefficient(share.index, &indices).unwrap();
                reconstructed += lambda * share.secret_share;
            }
            secrets.push(reconstructed);
        }
        for (i, s) in secrets.iter().enumerate().skip(1) {
            assert_eq!(secrets[0], *s, "subset {} reconstructed different secret", i);
        }
    }

    /// End to end (2-of-3): HPKE-encrypt to the context's public key, then recover the
    /// plaintext from two partial decryptions.
    #[test]
    fn frost_shares_work_with_threshold_decrypt() {
        let config = ThresholdConfig { threshold: 2, total: 3 };
        let (shares, ctx) = run_dkg_ceremony(config).unwrap();
        let pk = crate::crypto::hpke::HpkePublicKey::from_bytes(&ctx.public_key.hpke_public_key).unwrap();
        let plaintext = b"FROST DKG threshold decryption test";
        let aad = b"newton-frost-test";
        let (enc, ciphertext) = crate::crypto::hpke::encrypt(&pk, plaintext, aad).unwrap();
        // NOTE(review): assumes the first 32 bytes of `enc` are the X25519 encapsulated key;
        // they are converted to Edwards form for the partial decryptions.
        let enc_bytes: [u8; 32] = enc[..32].try_into().unwrap();
        let enc_edwards = combine::montgomery_to_edwards(&enc_bytes).unwrap();
        let partials: Vec<_> = shares[..2]
            .iter()
            .map(|s| combine::compute_partial_decryption(s.index, &s.secret_share, &enc_edwards))
            .collect();
        let recovered = combine::threshold_decrypt(
            &partials,
            &enc_bytes,
            &ctx.public_key.hpke_public_key,
            &ciphertext,
            aad,
            &ctx.public_shares,
            config.threshold,
        )
        .unwrap();
        assert_eq!(&*recovered, plaintext);
    }

    /// End to end (3-of-5) with a non-contiguous subset of operators (indices 1, 3 and 5).
    #[test]
    fn frost_3_of_5_threshold_decrypt() {
        let config = ThresholdConfig { threshold: 3, total: 5 };
        let (shares, ctx) = run_dkg_ceremony(config).unwrap();
        let pk = crate::crypto::hpke::HpkePublicKey::from_bytes(&ctx.public_key.hpke_public_key).unwrap();
        let plaintext = b"3-of-5 FROST threshold test";
        let aad = b"newton-frost-3of5";
        let (enc, ciphertext) = crate::crypto::hpke::encrypt(&pk, plaintext, aad).unwrap();
        let enc_bytes: [u8; 32] = enc[..32].try_into().unwrap();
        let enc_edwards = combine::montgomery_to_edwards(&enc_bytes).unwrap();
        let selected = [&shares[0], &shares[2], &shares[4]];
        let partials: Vec<_> = selected
            .iter()
            .map(|s| combine::compute_partial_decryption(s.index, &s.secret_share, &enc_edwards))
            .collect();
        let recovered = combine::threshold_decrypt(
            &partials,
            &enc_bytes,
            &ctx.public_key.hpke_public_key,
            &ciphertext,
            aad,
            &ctx.public_shares,
            config.threshold,
        )
        .unwrap();
        assert_eq!(&*recovered, plaintext);
    }

    /// A 1-of-1 configuration makes the ceremony fail.
    #[test]
    fn frost_1_of_1_rejected() {
        let config = ThresholdConfig { threshold: 1, total: 1 };
        let result = run_dkg_ceremony(config);
        assert!(result.is_err(), "FROST should reject 1-of-1 threshold");
    }

    /// Index -> Identifier -> index is lossless across the supported range, up to u16::MAX.
    #[test]
    fn frost_identifier_roundtrip() {
        for index in [1u32, 2, 3, 100, 255, 1000, 65535] {
            let id = identifier_from_index(index).unwrap();
            let recovered = index_from_identifier(&id).unwrap();
            assert_eq!(index, recovered, "identifier roundtrip failed for index {}", index);
        }
    }

    /// Index 0 cannot be turned into a FROST identifier.
    #[test]
    fn frost_identifier_zero_rejected() {
        assert!(identifier_from_index(0).is_err());
    }

    /// DKG-derived and dealer-generated 2-of-3 shares both support the same
    /// encrypt/partial-decrypt/combine flow. Each scheme uses its own key, so the two are not
    /// expected to share an MPK.
    #[test]
    fn frost_dkg_equivalent_to_dealer_for_decryption() {
        let config = ThresholdConfig { threshold: 2, total: 3 };
        let (dealer_tpk, _commitment, dealer_shares) = dealer::generate_shares(config).unwrap();
        let (frost_shares, frost_ctx) = run_dkg_ceremony(config).unwrap();
        let pk = crate::crypto::hpke::HpkePublicKey::from_bytes(&frost_ctx.public_key.hpke_public_key).unwrap();
        let plaintext = b"cross-scheme compatibility test";
        let aad = b"newton-compat";
        // FROST path.
        let (enc, ciphertext) = crate::crypto::hpke::encrypt(&pk, plaintext, aad).unwrap();
        let enc_bytes: [u8; 32] = enc[..32].try_into().unwrap();
        let enc_edwards = combine::montgomery_to_edwards(&enc_bytes).unwrap();
        let frost_partials: Vec<_> = frost_shares[..2]
            .iter()
            .map(|s| combine::compute_partial_decryption(s.index, &s.secret_share, &enc_edwards))
            .collect();
        let frost_recovered = combine::threshold_decrypt(
            &frost_partials,
            &enc_bytes,
            &frost_ctx.public_key.hpke_public_key,
            &ciphertext,
            aad,
            &frost_ctx.public_shares,
            config.threshold,
        )
        .unwrap();
        assert_eq!(&*frost_recovered, plaintext);
        // Dealer path: public shares are assembled by hand since there is no context here.
        let dealer_pk = crate::crypto::hpke::HpkePublicKey::from_bytes(&dealer_tpk.hpke_public_key).unwrap();
        let (enc2, ct2) = crate::crypto::hpke::encrypt(&dealer_pk, plaintext, aad).unwrap();
        let enc_bytes2: [u8; 32] = enc2[..32].try_into().unwrap();
        let enc_edwards2 = combine::montgomery_to_edwards(&enc_bytes2).unwrap();
        let dealer_public_shares: HashMap<u32, EdwardsPoint> =
            dealer_shares.iter().map(|s| (s.index, s.public_share)).collect();
        let dealer_partials: Vec<_> = dealer_shares[..2]
            .iter()
            .map(|s| combine::compute_partial_decryption(s.index, &s.secret_share, &enc_edwards2))
            .collect();
        let dealer_recovered = combine::threshold_decrypt(
            &dealer_partials,
            &enc_bytes2,
            &dealer_tpk.hpke_public_key,
            &ct2,
            aad,
            &dealer_public_shares,
            config.threshold,
        )
        .unwrap();
        assert_eq!(&*dealer_recovered, plaintext);
    }
}