use std::collections::HashMap;
use curve25519_dalek::{
constants::ED25519_BASEPOINT_POINT,
edwards::{CompressedEdwardsY, EdwardsPoint},
montgomery::MontgomeryPoint,
scalar::Scalar,
};
use zeroize::Zeroizing;
use super::{dleq, types::PartialDecryption};
use crate::crypto::{error::CryptoError, threshold::decrypt_with_precomputed_dh};
/// Lifts a 32-byte Montgomery u-coordinate (X25519 encoding) to an Edwards point.
///
/// The sign bit handed to `MontgomeryPoint::to_edwards` is always `0`, so callers get a fixed,
/// deterministic choice of Edwards point for a given encoding.
///
/// # Errors
/// Returns [`CryptoError::ThresholdDecrypt`] when `to_edwards` yields `None` for these bytes.
pub fn montgomery_to_edwards(montgomery_bytes: &[u8; 32]) -> Result<EdwardsPoint, CryptoError> {
    match MontgomeryPoint(*montgomery_bytes).to_edwards(0) {
        Some(edwards) => Ok(edwards),
        None => Err(CryptoError::ThresholdDecrypt(
            "failed to lift Montgomery point to Edwards".into(),
        )),
    }
}
/// Computes operator `index`'s partial decryption `secret_share · enc_edwards`.
///
/// The result carries a proof from `dleq::prove` over the tuple
/// (basepoint, `secret_share · basepoint`, `enc_edwards`, partial). The proof lets a combiner
/// check the partial against the operator's public share without learning `secret_share`.
pub fn compute_partial_decryption(index: u32, secret_share: &Scalar, enc_edwards: &EdwardsPoint) -> PartialDecryption {
    let generator = ED25519_BASEPOINT_POINT;
    // The operator's public share, recomputed locally so the proof statement is self-contained.
    let share_public = secret_share * generator;
    let partial = secret_share * enc_edwards;

    PartialDecryption {
        index,
        proof: dleq::prove(secret_share, &generator, &share_public, enc_edwards, &partial),
        partial,
    }
}
/// Verifies partial decryptions and Lagrange-combines `threshold` of them into the DH value.
///
/// For every partial, the operator's public share is looked up by `pd.index` and its DLEQ proof
/// is checked against (basepoint, public share, `enc_edwards`, `pd.partial`). Partials with a
/// failing proof are logged and skipped. Later partials that repeat an already-accepted index
/// are also dropped. The first `threshold` survivors are combined as `Σ λ_i · partial_i`, and
/// the Montgomery u-coordinate bytes of the sum are returned.
///
/// # Errors
/// Returns [`CryptoError::ThresholdDecrypt`] when:
/// - `threshold` is `0`, which would otherwise combine nothing into the identity point;
/// - fewer than `threshold` partials are supplied, or fewer than `threshold` remain after
///   verification and de-duplication;
/// - any partial's index is absent from `public_shares`. This aborts the whole call rather than
///   skipping that partial;
/// - [`lagrange_coefficient`] fails.
pub fn combine_partial_decryptions(
    partials: &[PartialDecryption],
    enc_edwards: &EdwardsPoint,
    public_shares: &HashMap<u32, EdwardsPoint>,
    threshold: u32,
) -> Result<[u8; 32], CryptoError> {
    if threshold == 0 {
        return Err(CryptoError::ThresholdDecrypt("threshold must be at least 1".into()));
    }
    // Compare in `usize`: casting `partials.len()` down to `u32` could truncate, whereas
    // widening a `u32` is lossless on 32- and 64-bit targets.
    let needed = threshold as usize;
    if partials.len() < needed {
        return Err(CryptoError::ThresholdDecrypt(format!(
            "insufficient partials: got {}, need {}",
            partials.len(),
            threshold
        )));
    }

    let base_g = ED25519_BASEPOINT_POINT;
    let mut seen_indices = std::collections::HashSet::with_capacity(partials.len());
    let mut valid_partials: Vec<&PartialDecryption> = Vec::with_capacity(partials.len());
    for pd in partials {
        let pk = public_shares
            .get(&pd.index)
            .ok_or_else(|| CryptoError::ThresholdDecrypt(format!("unknown operator index {}", pd.index)))?;
        if !dleq::verify(&pd.proof, &base_g, pk, enc_edwards, &pd.partial) {
            tracing::warn!(operator_index = pd.index, "DLEQ proof verification failed");
            continue;
        }
        // Keep only the first verified partial per operator so interpolation sees distinct x's.
        if seen_indices.insert(pd.index) {
            valid_partials.push(pd);
        }
    }

    if valid_partials.len() < needed {
        return Err(CryptoError::ThresholdDecrypt(format!(
            "insufficient valid partials after DLEQ verification: got {}, need {}",
            valid_partials.len(),
            threshold
        )));
    }

    let partials_to_use = &valid_partials[..needed];
    let indices: Vec<u32> = partials_to_use.iter().map(|pd| pd.index).collect();
    let mut combined = EdwardsPoint::default();
    for pd in partials_to_use {
        let lambda = lagrange_coefficient(pd.index, &indices)?;
        combined += lambda * pd.partial;
    }
    Ok(combined.to_montgomery().to_bytes())
}
/// Decrypts an HPKE ciphertext from threshold partial decryptions of its encapsulated key.
///
/// `enc` is lifted to Edwards form and screened for small-subgroup hazards. Verified partials
/// are combined via [`combine_partial_decryptions`] into the DH value, which is passed with
/// `enc`, `pk_r`, `ciphertext` and `aad` to `decrypt_with_precomputed_dh`.
///
/// # Errors
/// Returns [`CryptoError::ThresholdDecrypt`] in these cases:
/// - `enc` cannot be lifted to Edwards form;
/// - `enc` is small-order;
/// - `enc` has a torsion component.
///
/// Errors from [`combine_partial_decryptions`] and `decrypt_with_precomputed_dh` are also
/// propagated.
pub fn threshold_decrypt(
    partials: &[PartialDecryption],
    enc: &[u8; 32],
    pk_r: &[u8; 32],
    ciphertext: &[u8],
    aad: &[u8],
    public_shares: &HashMap<u32, EdwardsPoint>,
    threshold: u32,
) -> Result<Zeroizing<Vec<u8>>, CryptoError> {
    let enc_edwards = montgomery_to_edwards(enc)?;
    if enc_edwards.is_small_order() {
        return Err(CryptoError::ThresholdDecrypt(
            "enc point is low-order, rejecting to prevent small subgroup attack".into(),
        ));
    }
    // A mixed-order point (prime-order part plus a torsion part) is not small-order, so the
    // check above lets it through. Yet scalar multiples of it expose each share's residue modulo
    // the torsion order. Standard X25519 ephemeral keys are multiples of the prime-order
    // basepoint (RFC 7748), so honest `enc` values always pass this check.
    if !enc_edwards.is_torsion_free() {
        return Err(CryptoError::ThresholdDecrypt(
            "enc point has a torsion component, rejecting to prevent small subgroup attack".into(),
        ));
    }
    let dh = combine_partial_decryptions(partials, &enc_edwards, public_shares, threshold)?;
    decrypt_with_precomputed_dh(&dh, enc, pk_r, ciphertext, aad)
}
/// Returns the Lagrange basis coefficient for interpolating at `x = 0`:
/// `λ_i = Π_{j ≠ i} x_j / (x_j − x_i)`, with every index mapped to a scalar via `Scalar::from`.
///
/// # Errors
/// Returns [`CryptoError::ThresholdDecrypt`] in either of these cases:
/// - `participants` repeats an index. In the original arithmetic a repeat of `i` was silently
///   skipped and a repeat of another index squared its factor, so the error never fired and a
///   wrong coefficient came back.
/// - `participants` does not contain `i`, so no basis polynomial is defined for it.
pub(crate) fn lagrange_coefficient(i: u32, participants: &[u32]) -> Result<Scalar, CryptoError> {
    let x_i = Scalar::from(i);
    let mut seen = std::collections::HashSet::with_capacity(participants.len());
    let mut numerator = Scalar::ONE;
    let mut denominator = Scalar::ONE;
    for &j in participants {
        if !seen.insert(j) {
            return Err(CryptoError::ThresholdDecrypt(
                "duplicate participant index in Lagrange interpolation".into(),
            ));
        }
        if j == i {
            continue;
        }
        let x_j = Scalar::from(j);
        numerator *= x_j;
        denominator *= x_j - x_i;
    }
    if !seen.contains(&i) {
        return Err(CryptoError::ThresholdDecrypt(format!(
            "participant index {i} not in Lagrange participant set"
        )));
    }
    // Distinct u32 indices stay distinct modulo the group order, so this cannot trigger once the
    // duplicate check has passed. It is kept as a guard because `invert` of zero yields zero.
    if denominator == Scalar::ZERO {
        return Err(CryptoError::ThresholdDecrypt(
            "duplicate participant index in Lagrange interpolation".into(),
        ));
    }
    Ok(numerator * denominator.invert())
}
/// Builds a threshold decryption context from operators' compressed Edwards public shares.
///
/// Each `(index, bytes)` pair is decompressed into an `EdwardsPoint` keyed by `index`. The
/// master public key is then obtained by Lagrange interpolation at zero over *all* supplied
/// shares. It is stored both in Edwards form and as the Montgomery bytes used as the HPKE
/// public key.
///
/// # Errors
/// Returns [`CryptoError::ThresholdDecrypt`] when:
/// - `shares` is empty;
/// - `config.threshold` is `0`;
/// - fewer than `config.threshold` shares are given;
/// - any encoding fails to decompress;
/// - an index repeats;
/// - [`lagrange_coefficient`] fails.
pub fn threshold_context_from_public_shares(
    shares: &[(u32, [u8; 32])],
    config: super::types::ThresholdConfig,
) -> Result<super::types::ThresholdDecryptionContext, CryptoError> {
    if shares.is_empty() {
        return Err(CryptoError::ThresholdDecrypt("no public shares provided".into()));
    }
    if config.threshold == 0 {
        return Err(CryptoError::ThresholdDecrypt("threshold must be at least 1".into()));
    }
    // Compare in `usize` so a very long slice is not truncated by an `as u32` cast.
    if shares.len() < config.threshold as usize {
        return Err(CryptoError::ThresholdDecrypt(format!(
            "expected at least {} public shares, got {}",
            config.threshold,
            shares.len()
        )));
    }

    let mut public_shares: HashMap<u32, EdwardsPoint> = HashMap::with_capacity(shares.len());
    for &(index, bytes) in shares {
        let point = CompressedEdwardsY(bytes).decompress().ok_or_else(|| {
            CryptoError::ThresholdDecrypt(format!("invalid compressed EdwardsPoint for share index {index}"))
        })?;
        if public_shares.insert(index, point).is_some() {
            return Err(CryptoError::ThresholdDecrypt(format!("duplicate share index {index}")));
        }
    }

    // Interpolating over every supplied share yields the intended master key only if the shares
    // come from a single consistent sharing polynomial. That consistency is not verified here.
    let indices: Vec<u32> = public_shares.keys().copied().collect();
    let mut mpk = EdwardsPoint::default();
    for &i in &indices {
        let lambda = lagrange_coefficient(i, &indices)?;
        mpk += lambda * public_shares[&i];
    }

    let mpk_montgomery = mpk.to_montgomery();
    let threshold_pk = super::types::ThresholdPublicKey {
        edwards: mpk,
        hpke_public_key: mpk_montgomery.to_bytes(),
    };
    Ok(super::types::ThresholdDecryptionContext {
        public_key: threshold_pk,
        public_shares,
        config,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        crypto::hpke,
        dkg::{
            dealer,
            types::{KeyShare, ThresholdConfig, ThresholdPublicKey},
        },
    };

    /// Plaintext encrypted by most fixtures.
    const PAYLOAD: &[u8] = b"threshold decryption test payload";
    /// Associated data bound to the ciphertext by most fixtures.
    const CONTEXT_AAD: &[u8] = b"newton-privacy-context";

    /// A dealer-generated `threshold`-of-`total` key plus one HPKE ciphertext encrypted to it.
    struct Fixture {
        shares: Vec<KeyShare>,
        tpk: ThresholdPublicKey,
        public_shares: HashMap<u32, EdwardsPoint>,
        enc: [u8; 32],
        enc_edwards: EdwardsPoint,
        ciphertext: Vec<u8>,
    }

    impl Fixture {
        /// Deals fresh shares and encrypts `plaintext` with `aad` to the resulting public key.
        fn new(threshold: u32, total: u32, plaintext: &[u8], aad: &[u8]) -> Self {
            let (tpk, _commitment, shares) =
                dealer::generate_shares(ThresholdConfig { threshold, total }).unwrap();
            let recipient = hpke::HpkePublicKey::from_bytes(&tpk.hpke_public_key).unwrap();
            let (enc, ciphertext) = hpke::encrypt(&recipient, plaintext, aad).unwrap();
            let enc: [u8; 32] = enc[..32].try_into().unwrap();
            let enc_edwards = montgomery_to_edwards(&enc).unwrap();
            let public_shares = shares.iter().map(|s| (s.index, s.public_share)).collect();
            Self {
                shares,
                tpk,
                public_shares,
                enc,
                enc_edwards,
                ciphertext,
            }
        }

        /// Partial decryptions from the shares at the given positions of `self.shares`.
        fn partials(&self, positions: impl IntoIterator<Item = usize>) -> Vec<PartialDecryption> {
            positions
                .into_iter()
                .map(|pos| {
                    let share = &self.shares[pos];
                    compute_partial_decryption(share.index, &share.secret_share, &self.enc_edwards)
                })
                .collect()
        }

        /// Runs `threshold_decrypt` on this fixture's ciphertext.
        fn decrypt(
            &self,
            partials: &[PartialDecryption],
            aad: &[u8],
            threshold: u32,
        ) -> Result<Zeroizing<Vec<u8>>, CryptoError> {
            threshold_decrypt(
                partials,
                &self.enc,
                &self.tpk.hpke_public_key,
                &self.ciphertext,
                aad,
                &self.public_shares,
                threshold,
            )
        }
    }

    #[test]
    fn lagrange_coefficients_sum_to_one_at_zero() {
        let indices = [1u32, 2, 3];
        let total: Scalar = indices
            .iter()
            .map(|&i| lagrange_coefficient(i, &indices).unwrap())
            .sum();
        assert_eq!(total, Scalar::ONE);
    }

    #[test]
    fn lagrange_reconstructs_secret() {
        let (tpk, _commitment, shares) =
            dealer::generate_shares(ThresholdConfig { threshold: 3, total: 5 }).unwrap();
        let chosen = [&shares[0], &shares[2], &shares[4]];
        let indices: Vec<u32> = chosen.iter().map(|s| s.index).collect();
        let secret: Scalar = chosen
            .iter()
            .map(|s| lagrange_coefficient(s.index, &indices).unwrap() * s.secret_share)
            .sum();
        assert_eq!(
            (secret * ED25519_BASEPOINT_POINT).compress(),
            tpk.edwards.compress()
        );
    }

    #[test]
    fn full_threshold_decrypt_3_of_5() {
        let fx = Fixture::new(3, 5, PAYLOAD, CONTEXT_AAD);
        let recovered = fx.decrypt(&fx.partials([0, 2, 4]), CONTEXT_AAD, 3).unwrap();
        assert_eq!(&recovered[..], PAYLOAD);
    }

    #[test]
    fn full_threshold_decrypt_2_of_3() {
        let fx = Fixture::new(2, 3, PAYLOAD, CONTEXT_AAD);
        let recovered = fx.decrypt(&fx.partials([1, 2]), CONTEXT_AAD, 2).unwrap();
        assert_eq!(&recovered[..], PAYLOAD);
    }

    #[test]
    fn full_threshold_decrypt_all_shares() {
        let fx = Fixture::new(3, 5, PAYLOAD, CONTEXT_AAD);
        let recovered = fx
            .decrypt(&fx.partials(0..fx.shares.len()), CONTEXT_AAD, 3)
            .unwrap();
        assert_eq!(&recovered[..], PAYLOAD);
    }

    #[test]
    fn threshold_decrypt_1_of_1() {
        let fx = Fixture::new(1, 1, PAYLOAD, CONTEXT_AAD);
        let recovered = fx.decrypt(&fx.partials([0]), CONTEXT_AAD, 1).unwrap();
        assert_eq!(&recovered[..], PAYLOAD);
    }

    #[test]
    fn insufficient_partials_fails() {
        let fx = Fixture::new(3, 5, PAYLOAD, CONTEXT_AAD);
        let result = fx.decrypt(&fx.partials([0, 1]), CONTEXT_AAD, 3);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("insufficient"));
    }

    #[test]
    fn invalid_dleq_proof_filtered_out() {
        let fx = Fixture::new(2, 4, PAYLOAD, CONTEXT_AAD);
        let mut partials = fx.partials(0..3);
        // Corrupt one proof; the two untouched partials still meet the threshold of 2.
        partials[0].proof.challenge += Scalar::ONE;
        let recovered = fx.decrypt(&partials, CONTEXT_AAD, 2).unwrap();
        assert_eq!(&recovered[..], PAYLOAD);
    }

    #[test]
    fn empty_aad_works() {
        let plaintext: &[u8] = b"no aad";
        let aad: &[u8] = b"";
        let fx = Fixture::new(2, 3, plaintext, aad);
        let recovered = fx.decrypt(&fx.partials(0..2), aad, 2).unwrap();
        assert_eq!(&recovered[..], plaintext);
    }

    #[test]
    fn any_t_subset_produces_same_result() {
        let fx = Fixture::new(3, 5, PAYLOAD, CONTEXT_AAD);
        for subset in [[0, 1, 2], [0, 2, 4], [1, 3, 4], [2, 3, 4]] {
            let recovered = fx.decrypt(&fx.partials(subset), CONTEXT_AAD, 3).unwrap();
            assert_eq!(
                &recovered[..],
                PAYLOAD,
                "subset {:?} produced wrong result",
                subset
            );
        }
    }
}