use crate::bigint::{BigUint, MontgomeryCtx};
use crate::csprng::Csprng;
use crate::field::PrimeField;
use crate::poly::{horner, lagrange_eval};
use crate::primes::{is_probable_prime, random_below};
use crate::secure::Zeroizing;
/// Parameters of a subgroup of the multiplicative group modulo `p`, together
/// with a Montgomery context for exponentiation modulo `p`.
///
/// The fields are private; values are built through `DlogGroup::new`, which
/// validates the parameters before constructing the struct.
#[derive(Clone, Debug)]
pub struct DlogGroup {
/// Modulus. `new` requires it to be odd, at least 3, and a probable prime.
p: BigUint,
/// Subgroup order. `new` requires `1 < q < p`, `q` probably prime, and `q | p - 1`.
q: BigUint,
/// Generator, stored reduced modulo `p`. `new` requires it to be neither 0
/// nor 1 and to satisfy `g^q == 1 (mod p)`.
g: BigUint,
/// Montgomery context created from `p`; used by `pow`.
mont: MontgomeryCtx,
}
impl DlogGroup {
    /// Validates `(p, q, g)` and builds the group description.
    ///
    /// Returns `None` when any of the following holds:
    /// - `p` is below 3, even, or fails the probable-prime test;
    /// - `q` is not strictly between 1 and `p`, or fails the probable-prime test;
    /// - `q` does not divide `p - 1`;
    /// - `g mod p` is 0 or 1;
    /// - the Montgomery context cannot be created for `p`;
    /// - `g^q mod p` is not 1.
    ///
    /// The generator is stored reduced modulo `p`.
    #[must_use]
    pub fn new(p: BigUint, q: BigUint, g: BigUint) -> Option<Self> {
        let unit = BigUint::one();

        // Cheap size/parity filters run before the primality test.
        let modulus_ok = p >= BigUint::from_u64(3) && p.is_odd() && is_probable_prime(&p);
        if !modulus_ok {
            return None;
        }

        // The subgroup order must lie strictly between 1 and p and be prime.
        let order_ok = q > unit && q < p && is_probable_prime(&q);
        if !order_ok {
            return None;
        }

        // q has to divide p - 1, otherwise no subgroup of that order exists.
        let (_, remainder) = p.sub_ref(&unit).div_rem(&q);
        if !remainder.is_zero() {
            return None;
        }

        // Work with the canonical representative of the generator.
        let g = g.modulo(&p);
        if g.is_zero() || g == unit {
            return None;
        }

        let mont = MontgomeryCtx::new(&p)?;
        // g must be annihilated by q, i.e. lie in the order-q subgroup.
        if mont.pow(&g, &q) != unit {
            return None;
        }

        Some(Self { p, q, g, mont })
    }

    /// Returns the modulus `p`.
    #[must_use]
    pub fn p(&self) -> &BigUint {
        &self.p
    }

    /// Returns the subgroup order `q`.
    #[must_use]
    pub fn q(&self) -> &BigUint {
        &self.q
    }

    /// Returns the generator `g` (already reduced modulo `p`).
    #[must_use]
    pub fn g(&self) -> &BigUint {
        &self.g
    }

    /// Raises `base` to `exp` using the Montgomery context built for `p`.
    #[must_use]
    pub fn pow(&self, base: &BigUint, exp: &BigUint) -> BigUint {
        self.mont.pow(base, exp)
    }

    /// Multiplies `a` and `b` modulo `p` via `BigUint::mod_mul`.
    #[must_use]
    pub fn mul(&self, a: &BigUint, b: &BigUint) -> BigUint {
        BigUint::mod_mul(a, b, &self.p)
    }
}
/// One participant's share as produced by `deal`.
///
/// `Debug` is implemented by hand for this type and prints a fixed
/// placeholder instead of the fields.
#[derive(Clone, Eq, PartialEq)]
pub struct VssShare {
/// Participant index. `deal` assigns indices `1..=n`; `verify_share` and
/// `reconstruct` reject index 0 and indices that are not below `q`.
pub player: usize,
/// Share value: in `deal`, the result of `horner` over the coefficient
/// vector at the point `player`, computed in the field built from `q`.
pub value: BigUint,
}
/// Debug formatting that never prints the share's contents.
impl core::fmt::Debug for VssShare {
    /// Writes the fixed placeholder `VssShare(<elided>)` regardless of the
    /// field values or the formatter flags.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "VssShare(<elided>)")
    }
}
/// Coefficient commitments published by `deal` and consumed by `verify_share`.
///
/// `Debug` is implemented by hand for this type and prints a fixed
/// placeholder instead of the contents.
#[derive(Clone, Eq, PartialEq)]
pub struct Commitments {
/// One entry per polynomial coefficient, in coefficient order: `deal`
/// stores `g^a_i mod p` for each coefficient `a_i`, starting with the secret.
pub c: Vec<BigUint>,
}
/// Debug formatting that never prints the commitment values.
impl core::fmt::Debug for Commitments {
    /// Writes the fixed placeholder `Commitments(<elided>)` regardless of the
    /// stored values or the formatter flags.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "Commitments(<elided>)")
    }
}
/// Splits `secret` into `n` shares with threshold `k` and publishes one
/// commitment per polynomial coefficient.
///
/// The polynomial has `k` coefficients: the secret as the constant term,
/// followed by `k - 1` values drawn with `random_below(rng, q)`. Share `j`
/// (for `j` in `1..=n`) is the polynomial evaluated at `j` with `horner` in the
/// field built from `q`; commitment `i` is `g^a_i` computed with `group.pow`.
///
/// The coefficient vector lives in a `Zeroizing` wrapper for the duration of
/// the call.
///
/// # Panics
/// Panics when `k < 2`, when `n < k`, when `n` is not below `q`, when `secret`
/// is not below `q`, or when `random_below` fails to produce a value.
#[must_use]
pub fn deal<R: Csprng>(
    group: &DlogGroup,
    rng: &mut R,
    secret: &BigUint,
    k: usize,
    n: usize,
) -> (Vec<VssShare>, Commitments) {
    assert!(k >= 2, "k must be at least 2");
    assert!(n >= k, "n must be at least k");
    assert!(
        BigUint::from_u64(n as u64) < *group.q(),
        "subgroup order must exceed n",
    );
    assert!(secret < group.q(), "secret must be < q");

    let q_field = PrimeField::new_unchecked(group.q().clone());

    // Constant term is the secret; keep drawing until there are k coefficients.
    let mut poly = Zeroizing::new(Vec::<BigUint>::with_capacity(k));
    poly.push(secret.clone());
    while poly.len() < k {
        let coefficient = random_below(rng, group.q()).expect("q > 0");
        poly.push(coefficient);
    }

    // One commitment g^a_i per coefficient, in coefficient order.
    let mut committed = Vec::with_capacity(poly.len());
    for a in poly.iter() {
        committed.push(group.pow(group.g(), a));
    }
    let commitments = Commitments { c: committed };

    // Player j receives the polynomial evaluated at x = j.
    let share_for = |player: usize| {
        let x = BigUint::from_u64(player as u64);
        VssShare {
            player,
            value: horner(&q_field, &poly, &x),
        }
    };
    let shares: Vec<VssShare> = (1..=n).map(share_for).collect();

    (shares, commitments)
}
/// Checks a share against the dealer's published commitments.
///
/// With `j = share.player` and commitments `c_0, c_1, ...`, the share is
/// accepted when `g^value == c_0^(j^0) * c_1^(j^1) * ... (mod p)`, where the
/// exponents `j^i` are reduced modulo `q`.
///
/// Returns `false` without evaluating that equation when:
/// - the commitment vector is empty (the product would be the empty product 1,
///   so a zero share would otherwise verify against no polynomial at all);
/// - `share.player` is 0 or not below `q`;
/// - `share.value` is not below `q` (a value `v + q` would otherwise be
///   accepted as an alias of `v`);
/// - any commitment is 0, is not below `p`, or is not annihilated by `q`
///   (i.e. `c_i^q != 1 mod p`).
#[must_use]
pub fn verify_share(group: &DlogGroup, commits: &Commitments, share: &VssShare) -> bool {
    let p = group.p();
    let q = group.q();
    let one = BigUint::one();

    if commits.c.is_empty() {
        return false;
    }

    if share.player == 0 {
        return false;
    }
    let j = BigUint::from_u64(share.player as u64);
    if j >= *q {
        return false;
    }

    // Only canonical representatives of Z_q are valid share values.
    if share.value >= *q {
        return false;
    }

    // Every commitment must be a canonical, non-zero residue in the order-q
    // subgroup; this is also what makes reducing the exponents mod q sound.
    for c_i in &commits.c {
        if c_i.is_zero() || c_i >= p {
            return false;
        }
        if group.pow(c_i, q) != one {
            return false;
        }
    }

    let lhs = group.pow(group.g(), &share.value);

    // rhs = prod_i c_i^(j^i), with j^i tracked incrementally modulo q.
    let mut rhs = BigUint::one();
    let mut pow_j = BigUint::one();
    for c_i in &commits.c {
        let term = group.pow(c_i, &pow_j);
        rhs = group.mul(&rhs, &term);
        pow_j = BigUint::mod_mul(&pow_j, &j, q);
    }

    lhs == rhs
}
/// Recovers the secret from at least `k` shares by evaluating `lagrange_eval`
/// at zero in the field built from `q`.
///
/// Every supplied share is validated, but only the first `k` are used for the
/// interpolation.
///
/// Returns `None` when `k < 2`, when fewer than `k` shares are supplied, when
/// any share has player index 0 or an index that is not below `q`, when two
/// shares carry the same player index, or when `lagrange_eval` itself yields
/// `None`.
#[must_use]
pub fn reconstruct(group: &DlogGroup, shares: &[VssShare], k: usize) -> Option<BigUint> {
    if k < 2 || shares.len() < k {
        return None;
    }
    let q_field = PrimeField::new_unchecked(group.q().clone());

    // Each evaluation point must be a non-zero index below q.
    let index_in_range =
        |s: &VssShare| s.player != 0 && BigUint::from_u64(s.player as u64) < *group.q();
    if !shares.iter().all(index_in_range) {
        return None;
    }

    // Interpolation needs pairwise distinct evaluation points.
    let has_duplicate = shares
        .iter()
        .enumerate()
        .any(|(i, a)| shares[i + 1..].iter().any(|b| a.player == b.player));
    if has_duplicate {
        return None;
    }

    // Interpolate through the first k points only.
    let mut pts: Vec<(BigUint, BigUint)> = Vec::with_capacity(k);
    for s in &shares[..k] {
        pts.push((BigUint::from_u64(s.player as u64), s.value.clone()));
    }
    lagrange_eval(&q_field, &pts, &BigUint::zero())
}
/// Builds the toy group `p = 23`, `q = 11`, `g = 4`.
///
/// # Panics
/// Panics if `DlogGroup::new` rejects these parameters.
#[must_use]
pub fn small_test_group() -> DlogGroup {
    let modulus = BigUint::from_u64(23);
    let order = BigUint::from_u64(11);
    let generator = BigUint::from_u64(4);
    DlogGroup::new(modulus, order, generator).expect("hand-validated toy group")
}
/// Builds the group from the hard-coded hexadecimal parameters below, which
/// the panic message attributes to RFC 5114 §2.3.
///
/// The three literals are parsed with `hex_to_biguint` and then passed through
/// the full `DlogGroup::new` validation.
///
/// # Panics
/// Panics if a literal fails to parse or `DlogGroup::new` rejects the
/// parameters.
#[must_use]
pub fn rfc5114_modp_2048_256() -> DlogGroup {
    // Modulus p. The trailing backslashes join the lines into one literal.
    const P_HEX: &str = "\
        87A8E61DB4B6663CFFBBD19C651959998CEEF608660DD0F2\
        5D2CEED4435E3B00E00DF8F1D61957D4FAF7DF4561B2AA30\
        16C3D91134096FAA3BF4296D830E9A7C209E0C6497517ABD\
        5A8A9D306BCF67ED91F9E6725B4758C022E0B1EF4275BF7B\
        6C5BFC11D45F9088B941F54EB1E59BB8BC39A0BF12307F5C\
        4FDB70C581B23F76B63ACAE1CAA6B7902D52526735488A0E\
        F13C6D9A51BFA4AB3AD8347796524D8EF6A167B5A41825D9\
        67E144E5140564251CCACB83E6B486F6B3CA3F7971506026\
        C0B857F689962856DED4010ABD0BE621C3A3960A54E710C3\
        75F26375D7014103A4B54330C198AF126116D2276E11715F\
        693877FAD7EF09CADB094AE91E1A1597";
    // Subgroup order q.
    const Q_HEX: &str =
        "8CF83642A709A097B447997640129DA299B1A47D1EB3750BA308B0FE64F5FBD3";
    // Generator g.
    const G_HEX: &str = "\
        3FB32C9B73134D0B2E77506660EDBD484CA7B18F21EF205407F4793A1A0BA125\
        10DBC15077BE463FFF4FED4AAC0BB555BE3A6C1B0C6B47B1BC3773BF7E8C6F62\
        901228F8C28CBB18A55AE31341000A650196F931C77A57F2DDF463E5E9EC144B\
        777DE62AAAB8A8628AC376D282D6ED3864E67982428EBC831D14348F6F2F9193\
        B5045AF2767164E1DFC967C1FB3F2E55A4BD1BFFE83B9C80D052B985D182EA0A\
        DB2A3B7313D3FE14C8484B1E052588B9B7D2BBD2DF016199ECD06E1557CD0915\
        B3353BBB64E0EC377FD028370DF92B52C7891428CDC67EB6184B523D1DB246C3\
        2F63078490F00EF8D647D148D47954515E2327CFEF98C582664B4C0F6CC41659";

    let modulus = hex_to_biguint(P_HEX);
    let order = hex_to_biguint(Q_HEX);
    let generator = hex_to_biguint(G_HEX);
    DlogGroup::new(modulus, order, generator).expect("RFC 5114 §2.3 group is well-formed")
}
/// Parses a big-endian hexadecimal literal into a `BigUint`.
///
/// ASCII whitespace in `hex` is ignored; the remaining characters are decoded
/// two at a time (high nibble first) into bytes that are handed to
/// `BigUint::from_be_bytes`.
///
/// # Panics
/// Panics when the whitespace-stripped literal has odd length, or when
/// `nibble` rejects a character.
fn hex_to_biguint(hex: &str) -> BigUint {
    let digits: Vec<u8> = hex
        .bytes()
        .filter(|b| !b.is_ascii_whitespace())
        .collect();
    assert!(digits.len().is_multiple_of(2), "hex literal must have even length");

    let mut decoded = Vec::with_capacity(hex.len() / 2);
    decoded.extend(
        digits
            .chunks_exact(2)
            .map(|pair| (nibble(pair[0]) << 4) | nibble(pair[1])),
    );
    BigUint::from_be_bytes(&decoded)
}
/// Decodes one ASCII hexadecimal digit (upper or lower case) to its value
/// in `0..=15`.
///
/// # Panics
/// Panics when `b` is not an ASCII hex digit.
fn nibble(b: u8) -> u8 {
    match char::from(b).to_digit(16) {
        // A base-16 digit is at most 15, so narrowing to u8 cannot truncate.
        Some(value) => value as u8,
        None => panic!("non-hex character in literal: {b:#x}"),
    }
}
// Unit tests for group validation, dealing, share verification and
// reconstruction. Most tests run on the toy group p = 23, q = 11, g = 4.
#[cfg(test)]
mod tests {
use super::*;
use crate::csprng::ChaCha20Rng;
// RNG built from a fixed 32-byte seed, shared by every test below.
fn rng() -> ChaCha20Rng {
ChaCha20Rng::from_seed(&[0xC9u8; 32])
}
// Deals 3-of-5, checks that every share verifies, and reconstructs the
// secret from two different 3-share subsets.
#[test]
fn round_trip_3_of_5() {
let group = small_test_group();
let mut r = rng();
let secret = BigUint::from_u64(7); let (shares, commits) = deal(&group, &mut r, &secret, 3, 5);
assert_eq!(shares.len(), 5);
assert_eq!(commits.c.len(), 3);
for s in &shares {
assert!(verify_share(&group, &commits, s), "player {} verifies", s.player);
}
assert_eq!(reconstruct(&group, &shares[..3], 3), Some(secret.clone()));
assert_eq!(reconstruct(&group, &shares[2..5], 3), Some(secret));
}
// Adds 1 (mod q) to player 3's share value: that share must fail while the
// untouched ones still verify.
#[test]
fn tampered_share_fails_verification() {
let group = small_test_group();
let mut r = rng();
let secret = BigUint::from_u64(3);
let (mut shares, commits) = deal(&group, &mut r, &secret, 3, 5);
let q = BigUint::from_u64(11);
shares[2].value = shares[2].value.add_ref(&BigUint::one()).modulo(&q);
assert!(!verify_share(&group, &commits, &shares[2]));
for s in shares.iter().filter(|s| s.player != 3) {
assert!(verify_share(&group, &commits, s));
}
}
// A player index equal to q (11) is out of range and must be rejected.
#[test]
fn verify_rejects_oversized_player() {
let group = small_test_group();
let mut r = rng();
let secret = BigUint::from_u64(2);
let (_shares, commits) = deal(&group, &mut r, &secret, 3, 5);
let bad = VssShare {
player: 11,
value: BigUint::zero(),
};
assert!(!verify_share(&group, &commits, &bad));
}
// Replaces the first commitment with 5, which fails the c^q == 1 check
// (5^11 mod 23 is 22), so no share may verify.
#[test]
fn verify_rejects_non_subgroup_commitment() {
let group = small_test_group();
let mut r = rng();
let secret = BigUint::from_u64(2);
let (shares, mut commits) = deal(&group, &mut r, &secret, 3, 5);
commits.c[0] = BigUint::from_u64(5);
for s in &shares {
assert!(!verify_share(&group, &commits, s));
}
}
// Multiplies the first commitment by 2 modulo p: every share must then
// fail verification.
#[test]
fn tampered_commitment_breaks_all_shares() {
let group = small_test_group();
let mut r = rng();
let secret = BigUint::from_u64(5);
let (shares, mut commits) = deal(&group, &mut r, &secret, 3, 5);
commits.c[0] = group.mul(&commits.c[0], &BigUint::from_u64(2));
for s in &shares {
assert!(!verify_share(&group, &commits, s));
}
}
// Two shares with threshold 3: reconstruct must return None.
#[test]
fn below_threshold_reconstruct_returns_none() {
let group = small_test_group();
let mut r = rng();
let secret = BigUint::from_u64(8);
let (shares, _) = deal(&group, &mut r, &secret, 3, 5);
assert!(reconstruct(&group, &shares[..2], 3).is_none());
}
// The same player's share supplied twice: reconstruct must return None.
#[test]
fn duplicate_player_in_reconstruct_returns_none() {
let group = small_test_group();
let mut r = rng();
let secret = BigUint::from_u64(2);
let (shares, _) = deal(&group, &mut r, &secret, 3, 5);
let dup = vec![shares[0].clone(), shares[0].clone(), shares[1].clone()];
assert!(reconstruct(&group, &dup, 3).is_none());
}
// 4-of-7 round trip on a second toy group (p = 167, q = 83, g = 4).
#[test]
fn larger_group_round_trip() {
let group = DlogGroup::new(
BigUint::from_u64(167),
BigUint::from_u64(83),
BigUint::from_u64(4),
)
.expect("p=167, q=83 is a valid Schnorr group with g=4");
let mut r = rng();
let secret = BigUint::from_u64(42);
let (shares, commits) = deal(&group, &mut r, &secret, 4, 7);
for s in &shares {
assert!(verify_share(&group, &commits, s));
}
assert_eq!(reconstruct(&group, &shares[..4], 4), Some(secret));
}
// g = 24 reduces to 1 modulo 23, so the constructor must refuse it.
#[test]
fn rejects_identity_generator_after_reduction() {
let bad = DlogGroup::new(
BigUint::from_u64(23),
BigUint::from_u64(11),
BigUint::from_u64(24), );
assert!(bad.is_none());
}
// g = 5 does not satisfy g^11 == 1 (mod 23), so the constructor must refuse it.
#[test]
fn rejects_non_subgroup_generator() {
let bad = DlogGroup::new(
BigUint::from_u64(23),
BigUint::from_u64(11),
BigUint::from_u64(5),
);
assert!(bad.is_none());
}
// An even modulus (22) must be rejected.
#[test]
fn rejects_even_modulus() {
let bad = DlogGroup::new(
BigUint::from_u64(22),
BigUint::from_u64(11),
BigUint::from_u64(4),
);
assert!(bad.is_none());
}
// q equal to p violates q < p and must be rejected.
#[test]
fn rejects_q_too_large() {
let bad = DlogGroup::new(
BigUint::from_u64(23),
BigUint::from_u64(23),
BigUint::from_u64(4),
);
assert!(bad.is_none());
}
}