use curve25519_dalek::{constants::ED25519_BASEPOINT_POINT, edwards::EdwardsPoint, scalar::Scalar};
use rand_core::OsRng;
use super::types::{KeyShare, ThresholdConfig, ThresholdPublicKey, VssCommitment};
use crate::crypto::error::CryptoError;
pub fn generate_shares(
config: ThresholdConfig,
) -> Result<(ThresholdPublicKey, VssCommitment, Vec<KeyShare>), CryptoError> {
validate_config(config)?;
let mut rng = OsRng;
let mut coefficients: Vec<Scalar> = (0..config.threshold).map(|_| Scalar::random(&mut rng)).collect();
let commitments: Vec<EdwardsPoint> = coefficients.iter().map(|a_j| a_j * ED25519_BASEPOINT_POINT).collect();
let mpk = commitments[0];
let mpk_montgomery = mpk.to_montgomery();
let threshold_pk = ThresholdPublicKey {
edwards: mpk,
hpke_public_key: mpk_montgomery.to_bytes(),
};
let vss_commitment = VssCommitment { commitments };
let shares: Vec<KeyShare> = (1..=config.total)
.map(|i| {
let x = Scalar::from(i);
let secret_share = evaluate_polynomial(&coefficients, &x);
let public_share = secret_share * ED25519_BASEPOINT_POINT;
KeyShare {
index: i,
secret_share,
public_share,
}
})
.collect();
coefficients.fill(Scalar::ZERO);
Ok((threshold_pk, vss_commitment, shares))
}
pub fn generate_shares_from_secret(
config: ThresholdConfig,
master_secret: Scalar,
) -> Result<(ThresholdPublicKey, VssCommitment, Vec<KeyShare>), CryptoError> {
validate_config(config)?;
let mut rng = OsRng;
let mut coefficients = Vec::with_capacity(config.threshold as usize);
coefficients.push(master_secret);
for _ in 1..config.threshold {
coefficients.push(Scalar::random(&mut rng));
}
let commitments: Vec<EdwardsPoint> = coefficients.iter().map(|a_j| a_j * ED25519_BASEPOINT_POINT).collect();
let mpk = commitments[0];
let threshold_pk = ThresholdPublicKey {
edwards: mpk,
hpke_public_key: mpk.to_montgomery().to_bytes(),
};
let vss_commitment = VssCommitment { commitments };
let shares: Vec<KeyShare> = (1..=config.total)
.map(|i| {
let x = Scalar::from(i);
let secret_share = evaluate_polynomial(&coefficients, &x);
let public_share = secret_share * ED25519_BASEPOINT_POINT;
KeyShare {
index: i,
secret_share,
public_share,
}
})
.collect();
coefficients.fill(Scalar::ZERO);
Ok((threshold_pk, vss_commitment, shares))
}
pub fn verify_share(share: &KeyShare, commitment: &VssCommitment) -> bool {
let x = Scalar::from(share.index);
let mut x_pow = Scalar::ONE; let mut expected = EdwardsPoint::default();
for c_j in &commitment.commitments {
expected += x_pow * c_j;
x_pow *= x;
}
share.public_share.compress() == expected.compress()
}
/// Evaluates `Σ_j coefficients[j] · x^j` with Horner's rule.
///
/// `coefficients[0]` is the constant term. An empty slice evaluates to zero.
pub(crate) fn evaluate_polynomial(coefficients: &[Scalar], x: &Scalar) -> Scalar {
    // Walk from the highest-degree coefficient down: acc ← acc·x + c.
    coefficients
        .iter()
        .rev()
        .fold(Scalar::ZERO, |acc, c| acc * x + c)
}
fn validate_config(config: ThresholdConfig) -> Result<(), CryptoError> {
if config.threshold == 0 {
return Err(CryptoError::ThresholdDecrypt("threshold must be at least 1".into()));
}
if config.total == 0 {
return Err(CryptoError::ThresholdDecrypt("total must be at least 1".into()));
}
if config.threshold > config.total {
return Err(CryptoError::ThresholdDecrypt(format!(
"threshold ({}) cannot exceed total ({})",
config.threshold, config.total
)));
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    fn test_config() -> ThresholdConfig {
        ThresholdConfig { threshold: 3, total: 5 }
    }

    /// Lagrange-interpolates the shared polynomial at x = 0 from `shares`.
    ///
    /// This recovers the master secret when at least `threshold` distinct shares are
    /// given.
    fn reconstruct_secret(shares: &[KeyShare]) -> Scalar {
        shares.iter().fold(Scalar::ZERO, |acc, share_i| {
            let x_i = Scalar::from(share_i.index);
            // Lagrange basis l_i(0) = Π_{j≠i} x_j / (x_j - x_i)
            let basis = shares
                .iter()
                .filter(|share_j| share_j.index != share_i.index)
                .fold(Scalar::ONE, |l, share_j| {
                    let x_j = Scalar::from(share_j.index);
                    l * x_j * (x_j - x_i).invert()
                });
            acc + basis * share_i.secret_share
        })
    }

    #[test]
    fn generate_shares_produces_correct_count() {
        let (_, commitment, shares) = generate_shares(test_config()).unwrap();
        assert_eq!(shares.len(), 5);
        assert_eq!(commitment.commitments.len(), 3);
    }

    #[test]
    fn all_shares_verify_against_commitment() {
        let (_, commitment, shares) = generate_shares(test_config()).unwrap();
        for share in &shares {
            assert!(
                verify_share(share, &commitment),
                "share {} failed verification",
                share.index
            );
        }
    }

    #[test]
    fn shares_have_sequential_indices() {
        let (_, _, shares) = generate_shares(test_config()).unwrap();
        for (i, share) in shares.iter().enumerate() {
            assert_eq!(share.index, (i + 1) as u32);
        }
    }

    #[test]
    fn master_public_key_matches_first_commitment() {
        let (tpk, commitment, _) = generate_shares(test_config()).unwrap();
        assert_eq!(tpk.edwards.compress(), commitment.master_public_key().compress());
    }

    #[test]
    fn hpke_public_key_is_montgomery_conversion() {
        let (tpk, _, _) = generate_shares(test_config()).unwrap();
        let expected = tpk.edwards.to_montgomery().to_bytes();
        assert_eq!(tpk.hpke_public_key, expected);
    }

    #[test]
    fn tampered_share_fails_verification() {
        let (_, commitment, mut shares) = generate_shares(test_config()).unwrap();
        shares[0].secret_share += Scalar::ONE;
        shares[0].public_share = shares[0].secret_share * ED25519_BASEPOINT_POINT;
        assert!(!verify_share(&shares[0], &commitment));
    }

    #[test]
    fn wrong_commitment_fails_verification() {
        let (_, _, shares) = generate_shares(test_config()).unwrap();
        let (_, wrong_commitment, _) = generate_shares(test_config()).unwrap();
        assert!(!verify_share(&shares[0], &wrong_commitment));
    }

    #[test]
    fn generate_from_secret_uses_provided_key() {
        let secret = Scalar::from(42u64);
        let expected_mpk = secret * ED25519_BASEPOINT_POINT;
        let config = test_config();
        let (tpk, commitment, shares) = generate_shares_from_secret(config, secret).unwrap();
        assert_eq!(tpk.edwards.compress(), expected_mpk.compress());
        assert_eq!(commitment.master_public_key().compress(), expected_mpk.compress());
        for share in &shares {
            assert!(verify_share(share, &commitment));
        }
    }

    #[test]
    fn any_threshold_subset_reconstructs_secret() {
        let secret = Scalar::from(1234u64);
        let (_, _, shares) = generate_shares_from_secret(test_config(), secret).unwrap();
        // Different 3-of-5 subsets, plus the full set, must all recover the secret.
        assert_eq!(reconstruct_secret(&shares[..3]), secret);
        assert_eq!(reconstruct_secret(&shares[2..]), secret);
        assert_eq!(reconstruct_secret(&shares[1..4]), secret);
        assert_eq!(reconstruct_secret(&shares), secret);
    }

    #[test]
    fn invalid_config_threshold_zero() {
        let config = ThresholdConfig { threshold: 0, total: 5 };
        assert!(generate_shares(config).is_err());
    }

    #[test]
    fn invalid_config_total_zero() {
        let config = ThresholdConfig { threshold: 1, total: 0 };
        assert!(generate_shares(config).is_err());
    }

    #[test]
    fn invalid_config_threshold_exceeds_total() {
        let config = ThresholdConfig { threshold: 6, total: 5 };
        assert!(generate_shares(config).is_err());
    }

    #[test]
    fn invalid_config_rejected_by_generate_from_secret() {
        let config = ThresholdConfig { threshold: 4, total: 3 };
        assert!(generate_shares_from_secret(config, Scalar::from(7u64)).is_err());
    }

    #[test]
    fn trivial_1_of_1_scheme() {
        let config = ThresholdConfig { threshold: 1, total: 1 };
        let (_, commitment, shares) = generate_shares(config).unwrap();
        assert_eq!(shares.len(), 1);
        assert!(verify_share(&shares[0], &commitment));
    }

    #[test]
    fn polynomial_evaluation_horner() {
        let coeffs = vec![Scalar::from(3u64), Scalar::from(2u64), Scalar::from(1u64)];
        let result = evaluate_polynomial(&coeffs, &Scalar::from(2u64));
        assert_eq!(result, Scalar::from(11u64));
    }
}