use curve25519_dalek::{constants::ED25519_BASEPOINT_POINT, edwards::EdwardsPoint, scalar::Scalar};
use rand_core::OsRng;
use super::{combine, dealer};
use crate::crypto::error::CryptoError;
/// Samples a random refresh polynomial whose constant term is zero.
///
/// The returned vector holds the coefficients in ascending degree order:
/// element 0 is always `Scalar::ZERO`, followed by `threshold - 1` scalars
/// drawn from `OsRng`. Because the constant term is zero, the polynomial
/// evaluates to zero at `x = 0`.
///
/// For `threshold` of 0 or 1 the result is the single-element vector
/// `[Scalar::ZERO]`.
pub fn generate_refresh_polynomial(threshold: u32) -> Vec<Scalar> {
    let mut csprng = OsRng;
    let mut poly = Vec::with_capacity(threshold as usize);
    // Fixed zero constant term; only the higher-degree terms are random.
    poly.push(Scalar::ZERO);
    poly.extend((1..threshold).map(|_| Scalar::random(&mut csprng)));
    poly
}
/// Samples a random resharing polynomial whose constant term is `share`.
///
/// The returned vector holds the coefficients in ascending degree order:
/// element 0 is a copy of `share`, followed by `new_threshold - 1` scalars
/// drawn from `OsRng`.
///
/// For `new_threshold` of 0 or 1 the result is the single-element vector
/// `[*share]`.
pub fn generate_reshare_polynomial(share: &Scalar, new_threshold: u32) -> Vec<Scalar> {
    let mut csprng = OsRng;
    let mut poly = Vec::with_capacity(new_threshold as usize);
    // The existing share is embedded as the constant term.
    poly.push(*share);
    poly.extend((1..new_threshold).map(|_| Scalar::random(&mut csprng)));
    poly
}
/// Computes one commitment per polynomial coefficient.
///
/// Each coefficient `c` is mapped to `c * ED25519_BASEPOINT_POINT`; the output
/// has the same length and order as `coefficients`.
pub fn polynomial_commitments(coefficients: &[Scalar]) -> Vec<EdwardsPoint> {
    let mut commitments = Vec::with_capacity(coefficients.len());
    for coefficient in coefficients {
        commitments.push(coefficient * ED25519_BASEPOINT_POINT);
    }
    commitments
}
/// Evaluates a polynomial at the scalar corresponding to a participant index.
///
/// `index` is converted to a `Scalar` and forwarded, together with
/// `coefficients`, to `dealer::evaluate_polynomial`.
///
/// NOTE(review): `dealer::evaluate_polynomial` is defined elsewhere; it is
/// presumed to treat `coefficients` as ascending-degree terms, in which case
/// `index == 0` would return the constant term itself — confirm against the
/// dealer module and callers.
pub fn evaluate_at(coefficients: &[Scalar], index: u32) -> Scalar {
    dealer::evaluate_polynomial(coefficients, &Scalar::from(index))
}
/// Adds every refresh delta to the current share and returns the new share.
///
/// With an empty `deltas` slice the result equals `*current_share`.
pub fn accumulate_refresh(current_share: &Scalar, deltas: &[Scalar]) -> Scalar {
    deltas
        .iter()
        .fold(*current_share, |refreshed, delta| refreshed + delta)
}
/// Combines the sub-shares received from old participants into one new share.
///
/// Each entry of `sub_shares` is `(old_index, value)`. The value is weighted by
/// `combine::lagrange_coefficient(old_index, old_participants)` and the
/// weighted values are summed. An empty `sub_shares` slice yields
/// `Scalar::ZERO`.
///
/// # Errors
///
/// Returns the first error reported by `combine::lagrange_coefficient`;
/// later entries are not processed.
///
/// NOTE(review): this function does not itself check that `sub_shares` covers
/// every index in `old_participants` or that indices are unique — confirm
/// whether callers or `lagrange_coefficient` enforce that.
pub fn combine_reshare(sub_shares: &[(u32, Scalar)], old_participants: &[u32]) -> Result<Scalar, CryptoError> {
    sub_shares.iter().try_fold(
        Scalar::ZERO,
        |combined, &(old_index, sub_share_value)| -> Result<Scalar, CryptoError> {
            let weight = combine::lagrange_coefficient(old_index, old_participants)?;
            Ok(combined + weight * sub_share_value)
        },
    )
}
/// Checks a share value against the dealer's coefficient commitments.
///
/// With `x = recipient_index`, the expected point is
/// `sum_j x^j * commitments[j]`. The function returns `true` when
/// `share_value * ED25519_BASEPOINT_POINT` has the same compressed encoding as
/// that expected point.
///
/// With an empty `commitments` slice the expected point is
/// `EdwardsPoint::default()`.
///
/// NOTE(review): this only checks consistency between the share and the
/// commitments. It does not check that `commitments[0]` is the identity, i.e.
/// that the committed polynomial really has a zero constant term — confirm
/// whether callers perform that check for refresh rounds.
pub fn verify_refresh_share(share_value: &Scalar, recipient_index: u32, commitments: &[EdwardsPoint]) -> bool {
    let x = Scalar::from(recipient_index);
    // Carry the running power x^j alongside the accumulated point so each
    // commitment is scaled by the matching power of the recipient index.
    let (expected, _) = commitments.iter().fold(
        (EdwardsPoint::default(), Scalar::ONE),
        |(acc, power), commitment| (acc + power * commitment, power * x),
    );
    let actual = share_value * ED25519_BASEPOINT_POINT;
    actual.compress() == expected.compress()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        crypto::hpke,
        dkg::{
            combine::{compute_partial_decryption, montgomery_to_edwards, threshold_decrypt},
            dealer,
            types::{KeyShare, ThresholdConfig},
        },
    };
    use curve25519_dalek::{constants::ED25519_BASEPOINT_POINT, edwards::EdwardsPoint, scalar::Scalar};
    use std::collections::HashMap;

    /// Builds a `KeyShare` whose public share is `secret_share * G`.
    fn key_share(index: u32, secret_share: Scalar) -> KeyShare {
        KeyShare {
            index,
            secret_share,
            public_share: secret_share * ED25519_BASEPOINT_POINT,
        }
    }

    /// Interpolates the master public key from the public shares of `shares`,
    /// using Lagrange coefficients over all of their indices.
    fn reconstruct_mpk(shares: &[KeyShare]) -> EdwardsPoint {
        let indices: Vec<u32> = shares.iter().map(|s| s.index).collect();
        let mut mpk = EdwardsPoint::default();
        for share in shares {
            let lambda = combine::lagrange_coefficient(share.index, &indices).unwrap();
            mpk += lambda * share.public_share;
        }
        mpk
    }

    /// Maps each share index to its public share.
    fn public_share_map(shares: &[KeyShare]) -> HashMap<u32, EdwardsPoint> {
        shares.iter().map(|s| (s.index, s.public_share)).collect()
    }

    /// Runs a full reshare from the `old` participants to a new committee with
    /// indices `1..=new_total` and threshold `new_threshold`.
    fn reshare(old: &[&KeyShare], new_threshold: u32, new_total: u32) -> Vec<KeyShare> {
        let old_participants: Vec<u32> = old.iter().map(|s| s.index).collect();
        // One resharing polynomial per old participant, in the same order as `old`.
        let reshare_polys: Vec<Vec<Scalar>> = old
            .iter()
            .map(|s| generate_reshare_polynomial(&s.secret_share, new_threshold))
            .collect();
        (1..=new_total)
            .map(|new_index| {
                let sub_shares: Vec<(u32, Scalar)> = old
                    .iter()
                    .zip(reshare_polys.iter())
                    .map(|(old_share, poly)| (old_share.index, evaluate_at(poly, new_index)))
                    .collect();
                let new_secret = combine_reshare(&sub_shares, &old_participants).unwrap();
                key_share(new_index, new_secret)
            })
            .collect()
    }

    #[test]
    fn refresh_polynomial_has_zero_constant() {
        for threshold in 1..=5 {
            let poly = generate_refresh_polynomial(threshold);
            assert_eq!(poly.len(), threshold as usize);
            assert_eq!(
                poly[0],
                Scalar::ZERO,
                "constant term must be zero for threshold={threshold}"
            );
        }
    }

    #[test]
    fn refresh_preserves_secret_2_of_3() {
        let config = ThresholdConfig { threshold: 2, total: 3 };
        let (tpk, _commitment, shares) = dealer::generate_shares(config).unwrap();
        let original_mpk = tpk.edwards;
        let pk = hpke::HpkePublicKey::from_bytes(&tpk.hpke_public_key).unwrap();
        let plaintext = b"PSS refresh test payload";
        let aad = b"refresh-context";
        let (enc, ct) = hpke::encrypt(&pk, plaintext, aad).unwrap();

        // Every participant contributes one zero-constant refresh polynomial.
        let refresh_polys: Vec<Vec<Scalar>> = (0..config.total)
            .map(|_| generate_refresh_polynomial(config.threshold))
            .collect();

        let new_shares: Vec<KeyShare> = shares
            .iter()
            .map(|recipient| {
                let deltas: Vec<Scalar> = refresh_polys
                    .iter()
                    .map(|poly| evaluate_at(poly, recipient.index))
                    .collect();
                key_share(
                    recipient.index,
                    accumulate_refresh(&recipient.secret_share, &deltas),
                )
            })
            .collect();

        // A refresh that leaves the shares untouched would trivially preserve
        // the MPK, so make sure the shares were actually re-randomized.
        for (old, new) in shares.iter().zip(new_shares.iter()) {
            assert!(
                old.secret_share != new.secret_share,
                "share {} must change after refresh",
                old.index
            );
        }

        assert_eq!(
            reconstruct_mpk(&new_shares).compress(),
            original_mpk.compress(),
            "MPK must be unchanged after refresh"
        );

        let public_shares = public_share_map(&new_shares);
        let enc_bytes: [u8; 32] = enc[..32].try_into().unwrap();
        let enc_edwards = montgomery_to_edwards(&enc_bytes).unwrap();
        let partials: Vec<_> = [&new_shares[0], &new_shares[2]]
            .iter()
            .map(|s| compute_partial_decryption(s.index, &s.secret_share, &enc_edwards))
            .collect();
        let recovered = threshold_decrypt(
            &partials,
            &enc_bytes,
            &tpk.hpke_public_key,
            &ct,
            aad,
            &public_shares,
            config.threshold,
        )
        .unwrap();
        assert_eq!(recovered[..], plaintext[..]);
    }

    #[test]
    fn refresh_share_verification() {
        let poly = generate_refresh_polynomial(3);
        let commitments = polynomial_commitments(&poly);
        let share_val = evaluate_at(&poly, 2);
        assert!(
            verify_refresh_share(&share_val, 2, &commitments),
            "valid share must verify"
        );
        let tampered = share_val + Scalar::ONE;
        assert!(
            !verify_refresh_share(&tampered, 2, &commitments),
            "tampered share must fail verification"
        );
        assert!(
            !verify_refresh_share(&share_val, 3, &commitments),
            "wrong index must fail verification"
        );
    }

    #[test]
    fn reshare_to_larger_set() {
        let old_config = ThresholdConfig { threshold: 2, total: 2 };
        let (tpk, _commitment, old_shares) = dealer::generate_shares(old_config).unwrap();
        let original_mpk = tpk.edwards;
        let pk = hpke::HpkePublicKey::from_bytes(&tpk.hpke_public_key).unwrap();
        let plaintext = b"reshare expansion test";
        let aad = b"reshare-context";
        let (enc, ct) = hpke::encrypt(&pk, plaintext, aad).unwrap();
        let new_threshold = 2u32;
        let new_total = 3u32;

        // All old participants take part in the reshare.
        let participating_old: Vec<&KeyShare> = old_shares.iter().collect();
        let new_shares = reshare(&participating_old, new_threshold, new_total);

        assert_eq!(
            reconstruct_mpk(&new_shares).compress(),
            original_mpk.compress(),
            "MPK must be unchanged after resharing to larger set"
        );

        let public_shares = public_share_map(&new_shares);
        let enc_bytes: [u8; 32] = enc[..32].try_into().unwrap();
        let enc_edwards = montgomery_to_edwards(&enc_bytes).unwrap();
        // Every 2-of-3 subset of the new committee must be able to decrypt.
        let subsets: Vec<Vec<usize>> = vec![vec![0, 1], vec![0, 2], vec![1, 2]];
        for subset in &subsets {
            let partials: Vec<_> = subset
                .iter()
                .map(|&i| compute_partial_decryption(new_shares[i].index, &new_shares[i].secret_share, &enc_edwards))
                .collect();
            let recovered = threshold_decrypt(
                &partials,
                &enc_bytes,
                &tpk.hpke_public_key,
                &ct,
                aad,
                &public_shares,
                new_threshold,
            )
            .unwrap();
            assert_eq!(
                recovered[..],
                plaintext[..],
                "decryption failed for subset {:?}",
                subset
            );
        }
    }

    #[test]
    fn reshare_to_smaller_set() {
        let old_config = ThresholdConfig { threshold: 2, total: 3 };
        let (tpk, _commitment, old_shares) = dealer::generate_shares(old_config).unwrap();
        let original_mpk = tpk.edwards;
        let pk = hpke::HpkePublicKey::from_bytes(&tpk.hpke_public_key).unwrap();
        let plaintext = b"reshare shrink test";
        let aad = b"shrink-context";
        let (enc, ct) = hpke::encrypt(&pk, plaintext, aad).unwrap();
        let new_threshold = 2u32;
        let new_total = 2u32;

        // Only a threshold-sized subset of the old committee takes part.
        let participating_old: Vec<&KeyShare> = old_shares.iter().take(old_config.threshold as usize).collect();
        let new_shares = reshare(&participating_old, new_threshold, new_total);

        assert_eq!(
            reconstruct_mpk(&new_shares).compress(),
            original_mpk.compress(),
            "MPK must be unchanged after resharing to smaller set"
        );

        let public_shares = public_share_map(&new_shares);
        let enc_bytes: [u8; 32] = enc[..32].try_into().unwrap();
        let enc_edwards = montgomery_to_edwards(&enc_bytes).unwrap();
        let partials: Vec<_> = new_shares
            .iter()
            .map(|s| compute_partial_decryption(s.index, &s.secret_share, &enc_edwards))
            .collect();
        let recovered = threshold_decrypt(
            &partials,
            &enc_bytes,
            &tpk.hpke_public_key,
            &ct,
            aad,
            &public_shares,
            new_threshold,
        )
        .unwrap();
        assert_eq!(recovered[..], plaintext[..]);
    }
}