use crate::field::PrimeField;
use crate::poly::lagrange_eval;
use crate::secure::ct_eq_biguint;
use crate::shamir::Share;
use crate::bigint::BigUint;
use crate::csprng::Csprng;
/// Splits `secret` (a vector of `L = secret.len()` field elements) into `n`
/// shares with reconstruction threshold `k`.
///
/// A polynomial is fixed by `k` anchor points:
/// - `x = 1..=L` carry the secret elements (each passed through `field.reduce`);
/// - `x = L+1..=k` carry padding values drawn from `rng` via `field.random`.
///
/// Share `i` (for `i = 1..=n`) is that polynomial evaluated with
/// `lagrange_eval` at `x = k + i`, so no share ever sits on an anchor position.
///
/// Because secret elements are passed through `field.reduce` before being
/// embedded, callers that need an exact round trip through [`reconstruct`]
/// should supply already-reduced elements.
///
/// # Panics
///
/// Panics (in this order of checking) if `secret` is empty, if
/// `secret.len() > k`, if `k < 2`, if `n < k`, or if the field modulus is not
/// strictly greater than `k + n`.
#[must_use]
pub fn split<R: Csprng>(
    field: &PrimeField,
    rng: &mut R,
    secret: &[BigUint],
    k: usize,
    n: usize,
) -> Vec<Share> {
    let secret_len = secret.len();
    assert!(secret_len >= 1, "secret must be non-empty");
    assert!(secret_len <= k, "L = secret.len() must be ≤ k");
    assert!(k >= 2, "k must be at least 2");
    assert!(n >= k, "n must be at least k (otherwise the share set is unreconstructable)");

    // Every x-coordinate used below lies in 1..=k+n; requiring k + n < p keeps
    // all of them distinct and nonzero as field elements.
    let max_x = BigUint::from_u64((k + n) as u64);
    assert!(max_x < *field.modulus(), "prime modulus must exceed k + n");

    // y-values for anchors x = 1..=k: the reduced secret first, then k - L
    // random padding elements (drawn in order, after all reductions).
    let anchor_ys = secret
        .iter()
        .map(|s| field.reduce(s))
        .chain((secret_len..k).map(|_| field.random(&mut *rng)));
    let anchors: Vec<(BigUint, BigUint)> = (1..=k as u64)
        .map(BigUint::from_u64)
        .zip(anchor_ys)
        .collect();

    let mut shares = Vec::with_capacity(n);
    for i in 1..=n {
        let x = BigUint::from_u64((k + i) as u64);
        // The anchor x's 1..=k are pairwise distinct mod p (k < k + n < p),
        // which is the invariant this `expect` relies on.
        let y = lagrange_eval(field, &anchors, &x).expect("distinct anchors");
        shares.push(Share { x, y });
    }
    shares
}
/// Recovers the `l` secret elements from shares produced by [`split`] with
/// threshold `k`.
///
/// The first `k` entries of `shares` are interpolated with `lagrange_eval`,
/// and the secret is read back at `x = 1..=l`. Entries beyond the first `k`
/// are not used for interpolation. Instead, each one must agree with the value
/// the first `k` predict at its x-coordinate, so a tampered or inconsistent
/// extra share causes rejection.
///
/// Coordinates are compared after `field.reduce`, so raw values that reduce
/// to the same field element count as the same point (or the same value).
///
/// Returns `None` if:
/// - `k < 2`, `l == 0`, `l > k`, or fewer than `k` shares are supplied;
/// - any reduced x-coordinate is `<= k` (those positions are the anchors that
///   `split` never hands out as shares);
/// - two shares have the same reduced x-coordinate;
/// - an extra share disagrees with the interpolation of the first `k`;
/// - `lagrange_eval` itself returns `None`.
#[must_use]
pub fn reconstruct(
    field: &PrimeField,
    shares: &[Share],
    k: usize,
    l: usize,
) -> Option<Vec<BigUint>> {
    // `shares.len() < k` with `k >= 2` also rejects an empty slice.
    if k < 2 || l == 0 || l > k || shares.len() < k {
        return None;
    }

    // Compare x-coordinates as field elements. Two raw values that reduce to
    // the same element name the same evaluation point and must be treated as
    // duplicates; otherwise an "extra" share could just replay a head share.
    let xs: Vec<BigUint> = shares.iter().map(|s| field.reduce(&s.x)).collect();
    let k_big = BigUint::from_u64(k as u64);
    // split() only issues x = k+1..=k+n. Anything in 0..=k (zero included)
    // would coincide with a secret/padding anchor, so it is rejected.
    if xs.iter().any(|x| *x <= k_big) {
        return None;
    }
    if xs.iter().enumerate().any(|(i, x)| xs[i + 1..].contains(x)) {
        return None;
    }

    let (head, extras) = shares.split_at(k);
    let pts: Vec<(BigUint, BigUint)> = head
        .iter()
        .map(|s| (s.x.clone(), s.y.clone()))
        .collect();

    // Each extra share must lie on the polynomial fixed by the first k.
    // Both sides are reduced so an unreduced-but-congruent y is judged by its
    // field value, matching how the head shares' y's enter interpolation.
    for s in extras {
        let pred = field.reduce(&lagrange_eval(field, &pts, &s.x)?);
        if !ct_eq_biguint(&pred, &field.reduce(&s.y)) {
            return None;
        }
    }

    // Collecting into Option<Vec<_>> short-circuits on the first None.
    (1..=l as u64)
        .map(|j| lagrange_eval(field, &pts, &BigUint::from_u64(j)))
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csprng::ChaCha20Rng;

    /// Field over the Mersenne prime 2^61 - 1.
    fn small() -> PrimeField {
        PrimeField::new(BigUint::from_u64((1u64 << 61) - 1))
    }

    /// Deterministically seeded RNG so test runs are reproducible.
    fn rng() -> ChaCha20Rng {
        ChaCha20Rng::from_seed(&[0xCDu8; 32])
    }

    #[test]
    fn round_trip_l_equals_k() {
        let f = small();
        let mut r = rng();
        let secret: Vec<BigUint> = (1..=4).map(|i| BigUint::from_u64(0x100 + i)).collect();
        let shares = split(&f, &mut r, &secret, 4, 7);
        assert_eq!(shares.len(), 7);
        assert_eq!(reconstruct(&f, &shares[..4], 4, 4), Some(secret.clone()));
        assert_eq!(reconstruct(&f, &shares[3..], 4, 4), Some(secret));
    }

    #[test]
    fn round_trip_l_less_than_k() {
        let f = small();
        let mut r = rng();
        let secret: Vec<BigUint> = (1..=2).map(|i| BigUint::from_u64(0x300 + i)).collect();
        let shares = split(&f, &mut r, &secret, 5, 8);
        assert_eq!(shares.len(), 8);
        assert_eq!(reconstruct(&f, &shares[..5], 5, 2), Some(secret.clone()));
        assert_eq!(reconstruct(&f, &shares[2..7], 5, 2), Some(secret));
    }

    #[test]
    fn round_trip_l_equals_one_matches_shamir() {
        let f = small();
        let mut r = rng();
        let secret = BigUint::from_u64(0xF00D);
        let shares = split(&f, &mut r, std::slice::from_ref(&secret), 3, 5);
        let got = reconstruct(&f, &shares[..3], 3, 1).unwrap();
        assert_eq!(got, vec![secret]);
    }

    #[test]
    fn extras_validated_and_tampering_rejected() {
        let f = small();
        let mut r = rng();
        let secret: Vec<BigUint> = (1..=3).map(|i| BigUint::from_u64(0x500 + i)).collect();
        let mut shares = split(&f, &mut r, &secret, 4, 7);
        assert_eq!(reconstruct(&f, &shares, 4, 3), Some(secret.clone()));
        shares[5].y = f.add(&shares[5].y, &BigUint::from_u64(1));
        assert!(reconstruct(&f, &shares, 4, 3).is_none());
    }

    #[test]
    fn below_threshold_returns_none() {
        let f = small();
        let mut r = rng();
        let secret: Vec<BigUint> = (1..=2).map(|i| BigUint::from_u64(0x600 + i)).collect();
        let shares = split(&f, &mut r, &secret, 4, 6);
        assert!(reconstruct(&f, &shares[..3], 4, 2).is_none());
    }

    #[test]
    #[should_panic(expected = "L = secret.len() must be ≤ k")]
    fn split_rejects_l_greater_than_k() {
        let f = small();
        let mut r = rng();
        let secret: Vec<BigUint> = (1..=4).map(BigUint::from_u64).collect();
        let _ = split(&f, &mut r, &secret, 3, 5);
    }

    #[test]
    fn share_payload_is_one_field_element() {
        let f = small();
        let mut r = rng();
        let secret: Vec<BigUint> = (1..=6).map(BigUint::from_u64).collect();
        let shares = split(&f, &mut r, &secret, 8, 10);
        assert_eq!(shares.len(), 10);
        for (i, s) in shares.iter().enumerate() {
            // Share i (0-based) sits at x = k + i + 1 with k = 8, and its
            // payload is a single reduced field element.
            assert_eq!(s.x, BigUint::from_u64(8 + i as u64 + 1));
            assert!(s.y < *f.modulus(), "share y must be a reduced field element");
        }
        assert_eq!(reconstruct(&f, &shares[..8], 8, 6), Some(secret));
    }

    #[test]
    #[should_panic(expected = "n must be at least k")]
    fn split_rejects_n_below_k() {
        let f = small();
        let mut r = rng();
        let secret: Vec<BigUint> = (1..=2).map(BigUint::from_u64).collect();
        let _ = split(&f, &mut r, &secret, 8, 5);
    }

    #[test]
    fn k_minus_l_shares_does_not_yield_secret() {
        let f = small();
        let mut r = rng();
        let secret: Vec<BigUint> = (1..=2).map(|i| BigUint::from_u64(0x1000 + i)).collect();
        let shares = split(&f, &mut r, &secret, 5, 8);
        assert!(reconstruct(&f, &shares[..3], 5, 2).is_none());
        assert!(reconstruct(&f, &shares[..4], 5, 2).is_none());
    }

    #[test]
    fn duplicate_x_rejected() {
        let f = small();
        let mut r = rng();
        let secret: Vec<BigUint> = (1..=2).map(|i| BigUint::from_u64(0x700 + i)).collect();
        let shares = split(&f, &mut r, &secret, 3, 5);
        // Three honest shares plus a replay of the first one.
        let mut set: Vec<Share> = shares[..3]
            .iter()
            .map(|s| Share { x: s.x.clone(), y: s.y.clone() })
            .collect();
        set.push(Share { x: shares[0].x.clone(), y: shares[0].y.clone() });
        assert!(reconstruct(&f, &set, 3, 2).is_none());
    }

    #[test]
    fn anchor_position_share_rejected() {
        let f = small();
        let mut r = rng();
        let secret: Vec<BigUint> = (1..=2).map(|i| BigUint::from_u64(0x800 + i)).collect();
        let shares = split(&f, &mut r, &secret, 3, 5);
        // x = 1 is the position of secret[0]; split never issues it as a share.
        let mut set: Vec<Share> = shares[..2]
            .iter()
            .map(|s| Share { x: s.x.clone(), y: s.y.clone() })
            .collect();
        set.push(Share { x: BigUint::from_u64(1), y: secret[0].clone() });
        assert!(reconstruct(&f, &set, 3, 2).is_none());
    }
}