use crate::field::PrimeField;
use crate::poly::{horner, lagrange_eval};
use crate::bigint::BigUint;
use crate::csprng::Csprng;
use crate::secure::{ct_eq_biguint, Zeroizing};
/// A single share: one point `(x, y)` on the sharing polynomial.
///
/// [`split`] and [`split_multi`] emit shares with `x` in `1..=n`; [`reconstruct`]
/// and [`reconstruct_multi`] consume them. `Debug` is implemented by hand for
/// this type so that coordinates are never printed.
#[derive(PartialEq, Eq, Clone)]
pub struct Share {
    /// Evaluation point (x-coordinate).
    pub x: BigUint,
    /// Value of the polynomial at `x`.
    pub y: BigUint,
}
/// Redacted `Debug`: always prints the fixed placeholder `Share(<elided>)`,
/// so share contents cannot leak through logs or assertion failure messages.
impl core::fmt::Debug for Share {
    fn fmt(&self, out: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(out, "Share(<elided>)")
    }
}
#[must_use]
pub fn split<R: Csprng>(
field: &PrimeField,
rng: &mut R,
secret: &BigUint,
k: usize,
n: usize,
) -> Vec<Share> {
assert!(k >= 2, "k must be at least 2 (k = 1 would leak the secret)");
assert!(n >= k, "n must be at least k");
assert!(
BigUint::from_u64(n as u64) < *field.modulus(),
"prime modulus must exceed n",
);
let mut coeffs = Zeroizing::new(Vec::<BigUint>::with_capacity(k));
coeffs.push(field.reduce(secret));
for _ in 1..k {
coeffs.push(field.random(rng));
}
(1..=n)
.map(|i| {
let x = BigUint::from_u64(i as u64);
let y = horner(field, &coeffs, &x);
Share { x, y }
})
.collect()
}
/// Recovers a Shamir secret from `shares`, using a polynomial of degree `k - 1`.
///
/// The first `k` shares are interpolated at `x = 0` with `lagrange_eval`.
/// Any shares beyond the first `k` act as redundancy: each must lie on the
/// interpolated polynomial (checked with `ct_eq_biguint`), otherwise the whole
/// reconstruction is rejected.
///
/// `k` must be the threshold used when splitting. With a smaller `k`, too few
/// shares still interpolate to *some* value, which is not the secret.
///
/// Returns `None` when:
/// - `k == 0` or fewer than `k` shares are supplied;
/// - any x-coordinate is zero or not strictly below the field modulus;
/// - two shares have the same x-coordinate;
/// - `lagrange_eval` fails;
/// - an extra share is inconsistent with the first `k`.
#[must_use]
pub fn reconstruct(field: &PrimeField, shares: &[Share], k: usize) -> Option<BigUint> {
    if k == 0 || shares.len() < k {
        return None;
    }
    // Validate x as a *field element*, not a raw integer. `x = p` passes an
    // `is_zero()` test yet is congruent to 0 (the point that would hand its
    // `y` straight back as the secret). Likewise `x = p + 1` aliases `x = 1`
    // without tripping the equality-based duplicate check. Honest shares
    // always have `1 <= x <= n < p`, so requiring canonical x loses nothing.
    let modulus = field.modulus();
    if shares.iter().any(|s| s.x.is_zero() || s.x >= *modulus) {
        return None;
    }
    // With canonical x, integer equality is equality mod p.
    let has_duplicate_x = shares
        .iter()
        .enumerate()
        .any(|(i, a)| shares[i + 1..].iter().any(|b| a.x == b.x));
    if has_duplicate_x {
        return None;
    }

    let (used, extra) = shares.split_at(k);
    let pts: Vec<(BigUint, BigUint)> = used
        .iter()
        .map(|s| (s.x.clone(), s.y.clone()))
        .collect();
    let secret = lagrange_eval(field, &pts, &BigUint::zero())?;
    // Every surplus share must agree with the polynomial fixed by `used`.
    for s in extra {
        let pred = lagrange_eval(field, &pts, &s.x)?;
        if !ct_eq_biguint(&pred, &s.y) {
            return None;
        }
    }
    Some(secret)
}
/// Packs `secrets` (ℓ = `secrets.len()` values) into one polynomial of degree
/// `k - 1` and returns `n` shares of it.
///
/// Coefficients, lowest degree first, are the reduced secrets followed by
/// `k - ℓ` random field elements. Share `i` for `i = 1..=n` is `(i, f(i))`.
/// The coefficient buffer is held in `Zeroizing`.
///
/// Note: with ℓ > 1 this is a ramp scheme. Only `k - ℓ` random coefficients
/// mask the secrets, so any `k` shares recover them. Sets of more than
/// `k - ℓ` but fewer than `k` shares can reveal partial information.
///
/// # Panics
/// If `k < 2`, if `secrets` is empty or has more than `k` entries, if `n < k`,
/// or if `n` is not smaller than the field modulus.
#[must_use]
pub fn split_multi<R: Csprng>(
    field: &PrimeField,
    rng: &mut R,
    secrets: &[BigUint],
    k: usize,
    n: usize,
) -> Vec<Share> {
    let ell = secrets.len();
    assert!(k >= 2, "k must be at least 2 (k = 1 would leak the secret)");
    assert!((1..=k).contains(&ell), "need 1 ≤ ℓ ≤ k secrets");
    assert!(n >= k, "n must be at least k");
    assert!(
        BigUint::from_u64(n as u64) < *field.modulus(),
        "prime modulus must exceed n",
    );

    let mut coeffs = Zeroizing::new(Vec::<BigUint>::with_capacity(k));
    coeffs.extend(secrets.iter().map(|s| field.reduce(s)));
    coeffs.extend((ell..k).map(|_| field.random(rng)));

    let mut shares = Vec::with_capacity(n);
    for i in 1..=n {
        let x = BigUint::from_u64(i as u64);
        let y = horner(field, &coeffs, &x);
        shares.push(Share { x, y });
    }
    shares
}
#[must_use]
#[allow(clippy::needless_range_loop)] pub fn reconstruct_multi(
field: &PrimeField,
shares: &[Share],
k: usize,
ell: usize,
) -> Option<Vec<BigUint>> {
if ell == 0 || ell > k || shares.len() < k {
return None;
}
for s in shares {
if s.x.is_zero() {
return None;
}
}
for i in 0..shares.len() {
for j in (i + 1)..shares.len() {
if shares[i].x == shares[j].x {
return None;
}
}
}
let used = &shares[..k];
let mut mat: Vec<Vec<BigUint>> = Vec::with_capacity(k);
for s in used {
let mut row = Vec::with_capacity(k + 1);
let mut x_pow = BigUint::one();
for _ in 0..k {
row.push(x_pow.clone());
x_pow = field.mul(&x_pow, &s.x);
}
row.push(s.y.clone());
mat.push(row);
}
for col in 0..k {
let mut pivot_row = None;
for r in col..k {
if !mat[r][col].is_zero() {
pivot_row = Some(r);
break;
}
}
let pr = pivot_row?;
if pr != col {
mat.swap(pr, col);
}
let inv = field.inv(&mat[col][col])?;
for c in col..=k {
mat[col][c] = field.mul(&mat[col][c], &inv);
}
for r in 0..k {
if r == col {
continue;
}
if mat[r][col].is_zero() {
continue;
}
let factor = mat[r][col].clone();
for c in col..=k {
let term = field.mul(&factor, &mat[col][c]);
mat[r][c] = field.sub(&mat[r][c], &term);
}
}
}
let coeffs: Vec<BigUint> = (0..k).map(|i| mat[i][k].clone()).collect();
for s in &shares[k..] {
let pred = crate::poly::horner(field, &coeffs, &s.x);
if !ct_eq_biguint(&pred, &s.y) {
return None;
}
}
Some(coeffs.into_iter().take(ell).collect())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csprng::ChaCha20Rng;

    /// Deterministic RNG so every test run draws the same polynomials.
    fn rng() -> ChaCha20Rng {
        ChaCha20Rng::from_seed(&[42u8; 32])
    }

    /// The Mersenne prime 2^61 - 1: small enough for fast tests.
    fn small_field() -> PrimeField {
        PrimeField::new(BigUint::from_u64((1u64 << 61) - 1))
    }

    #[test]
    fn basic_round_trip() {
        let f = small_field();
        let mut r = rng();
        let secret = BigUint::from_u64(0xC0FFEE);
        for &(k, n) in &[(2usize, 3usize), (3, 5), (5, 9), (2, 7)] {
            let shares = split(&f, &mut r, &secret, k, n);
            assert_eq!(shares.len(), n);
            assert_eq!(reconstruct(&f, &shares[..k], k), Some(secret.clone()));
            assert_eq!(reconstruct(&f, &shares[1..1 + k], k), Some(secret.clone()));
        }
    }

    #[test]
    fn k_minus_one_does_not_yield_secret() {
        let f = small_field();
        let mut r = rng();
        let secret = BigUint::from_u64(987_654_321);
        let shares = split(&f, &mut r, &secret, 4, 7);
        assert!(reconstruct(&f, &shares[..3], 4).is_none());
        // Lying about the threshold still interpolates, but to the wrong value.
        let partial = reconstruct(&f, &shares[..3], 3);
        assert_ne!(partial, Some(secret));
    }

    #[test]
    fn duplicate_x_is_rejected() {
        let f = small_field();
        let mut r = rng();
        let secret = BigUint::from_u64(7);
        let mut shares = split(&f, &mut r, &secret, 2, 3);
        shares[1].x = shares[0].x.clone();
        assert!(reconstruct(&f, &shares, 2).is_none());
    }

    #[test]
    fn zero_x_is_rejected() {
        let f = small_field();
        let mut r = rng();
        let secret = BigUint::from_u64(11);
        let mut shares = split(&f, &mut r, &secret, 2, 3);
        shares[0].x = BigUint::zero();
        assert!(reconstruct(&f, &shares, 2).is_none());
        let mut multi = split_multi(&f, &mut r, std::slice::from_ref(&secret), 2, 3);
        multi[0].x = BigUint::zero();
        assert!(reconstruct_multi(&f, &multi, 2, 1).is_none());
    }

    #[test]
    fn multi_secret_duplicate_x_is_rejected() {
        let f = small_field();
        let mut r = rng();
        let secrets: Vec<BigUint> = (1..=2).map(|i| BigUint::from_u64(70 + i)).collect();
        let mut shares = split_multi(&f, &mut r, &secrets, 3, 5);
        shares[2].x = shares[0].x.clone();
        assert!(reconstruct_multi(&f, &shares, 3, 2).is_none());
    }

    #[test]
    fn multi_secret_recovers_all_secrets() {
        let f = small_field();
        let mut r = rng();
        let secrets: Vec<BigUint> = (1..=3).map(|i| BigUint::from_u64(1000 + i)).collect();
        let shares = split_multi(&f, &mut r, &secrets, 4, 6);
        let recovered = reconstruct_multi(&f, &shares, 4, 3).expect("decode");
        assert_eq!(recovered, secrets);
    }

    #[test]
    fn multi_secret_threshold_holds() {
        let f = small_field();
        let mut r = rng();
        let secrets: Vec<BigUint> = (1..=3).map(|i| BigUint::from_u64(2000 + i)).collect();
        let shares = split_multi(&f, &mut r, &secrets, 4, 6);
        assert!(reconstruct_multi(&f, &shares[..3], 4, 3).is_none());
    }

    #[test]
    fn multi_secret_rejects_inconsistent_extra_share() {
        let f = small_field();
        let mut r = rng();
        let secrets: Vec<BigUint> = (1..=2).map(|i| BigUint::from_u64(50 + i)).collect();
        let mut shares = split_multi(&f, &mut r, &secrets, 3, 6);
        shares[5].y = f.add(&shares[5].y, &BigUint::from_u64(1));
        assert!(reconstruct_multi(&f, &shares, 3, 2).is_none());
    }

    #[test]
    fn reconstruct_below_threshold_returns_none() {
        let f = small_field();
        let mut r = rng();
        let secret = BigUint::from_u64(0xCAFE);
        let shares = split(&f, &mut r, &secret, 4, 7);
        assert!(reconstruct(&f, &shares[..3], 4).is_none());
    }

    #[test]
    fn reconstruct_rejects_inconsistent_extra_share() {
        let f = small_field();
        let mut r = rng();
        let secret = BigUint::from_u64(0xBEEF);
        let mut shares = split(&f, &mut r, &secret, 3, 6);
        shares[5].y = f.add(&shares[5].y, &BigUint::from_u64(1));
        assert!(reconstruct(&f, &shares, 3).is_none());
    }

    #[test]
    fn multi_secret_with_l_equals_one_matches_shamir() {
        let f = small_field();
        let secret = BigUint::from_u64(0xDEAD_BEEF);
        let mut r1 = rng();
        let mut r2 = rng();
        let multi = split_multi(&f, &mut r1, std::slice::from_ref(&secret), 3, 5);
        let shamir = split(&f, &mut r2, &secret, 3, 5);
        // With ℓ = 1 and identically seeded RNGs both constructions draw the
        // same polynomial, so the shares must be identical.
        assert_eq!(multi, shamir);
        assert_eq!(
            reconstruct_multi(&f, &multi, 3, 1).map(|v| v[0].clone()),
            Some(secret.clone())
        );
        assert_eq!(reconstruct(&f, &shamir, 3), Some(secret));
    }
}