use crate::field::PrimeField;
use crate::poly::lagrange_eval_unchecked;
use crate::secure::ct_eq_biguint;
use crate::bigint::BigUint;
use crate::csprng::Csprng;
/// One share of a vector secret, as produced by [`split`].
///
/// `x` is the share's evaluation point (shares from `split` use `1..=n`), and
/// `y[c]` is the value at `x` of the sharing polynomial for secret component
/// `c`. Every share emitted by a single `split` call has the same `y.len()`.
///
/// `Debug` is implemented by hand (not derived) so that share values never
/// appear in formatted output.
#[derive(PartialEq, Eq, Clone)]
pub struct VectorShare {
    /// Evaluation point of this share.
    pub x: BigUint,
    /// Per-component polynomial values at `x`, one entry per secret component.
    pub y: Vec<BigUint>,
}
impl core::fmt::Debug for VectorShare {
    /// Prints a fixed placeholder instead of the share contents, so a share
    /// can show up in `{:?}` output (logs, panic messages, failed assertions)
    /// without revealing any share material.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "VectorShare(<elided>)")
    }
}
/// Splits `secret` (a vector of `m` field elements) into `n` shares, any `k`
/// of which suffice to reconstruct it.
///
/// Each component `c` is shared with its own polynomial of degree `k - 1`:
/// the constant term is the secret component reduced into the field, and the
/// other `k - 1` coefficients come from `field.random(rng)`. Share `i` (for
/// `i` in `1..=n`) has `x = i` and `y[c]` equal to that polynomial evaluated
/// at `i`.
///
/// Secret components are passed through `field.reduce` first, so a component
/// that is not already a canonical field element is shared (and will be
/// reconstructed) as its reduced value.
///
/// # Panics
///
/// Panics if `secret` is empty, if `k < 2`, if `n < k`, or if
/// `n >= field.modulus()`.
#[must_use]
pub fn split<R: Csprng>(
    field: &PrimeField,
    rng: &mut R,
    secret: &[BigUint],
    k: usize,
    n: usize,
) -> Vec<VectorShare> {
    let width = secret.len();
    assert!(width >= 1, "secret must have at least one component");
    assert!(k >= 2, "k must be at least 2 (k = 1 would leak the secret)");
    assert!(n >= k, "n must be at least k");
    // Evaluation points 1..=n must be distinct, non-zero field elements.
    assert!(
        BigUint::from_u64(n as u64) < *field.modulus(),
        "prime modulus must exceed n",
    );

    // coeffs[j][c] is the coefficient of X^j in component c's polynomial.
    // Row 0 is the reduced secret; rows 1..k are filled row by row, which
    // fixes the order in which randomness is drawn from `rng`.
    let mut coeffs: Vec<Vec<BigUint>> = Vec::with_capacity(k);
    coeffs.push(secret.iter().map(|s| field.reduce(s)).collect());
    while coeffs.len() < k {
        let row = (0..width).map(|_| field.random(rng)).collect();
        coeffs.push(row);
    }

    let mut shares = Vec::with_capacity(n);
    for i in 1..=n {
        let x = BigUint::from_u64(i as u64);

        // x^0, x^1, ..., x^(k-1): identical for every component, so build once.
        let mut powers: Vec<BigUint> = Vec::with_capacity(k);
        let mut cur = BigUint::one();
        for _ in 0..k {
            let next = field.mul(&cur, &x);
            powers.push(cur);
            cur = next;
        }

        let y = (0..width)
            .map(|c| {
                coeffs
                    .iter()
                    .zip(&powers)
                    .fold(BigUint::zero(), |acc, (row, pw)| {
                        field.add(&acc, &field.mul(&row[c], pw))
                    })
            })
            .collect();
        shares.push(VectorShare { x, y });
    }
    shares
}
/// Reconstructs the secret vector from at least `k` shares.
///
/// The first `k` shares are used to Lagrange-interpolate each component at
/// `x = 0`. Every share beyond the first `k` is then checked against the
/// interpolated polynomials; if any extra share disagrees in any component,
/// reconstruction is refused.
///
/// Share coordinates are reduced into the field (`field.reduce`) before they
/// are validated or interpolated. As a result:
/// - two shares whose `x` values are congruent modulo the prime count as
///   duplicates;
/// - an `x` that is a multiple of the prime counts as zero;
/// - a `y` that is congruent to the expected value is accepted even if it is
///   not in canonical form.
///
/// Without this reduction, such inputs would reach `lagrange_eval_unchecked`
/// with repeated or zero nodes.
///
/// Returns `None` when:
/// - `k == 0`, or fewer than `k` shares are supplied;
/// - the shares carry no components, or disagree on the component count;
/// - any share's `x` is zero in the field, or two shares share an `x` there;
/// - an extra share (beyond the first `k`) is inconsistent with the others.
#[must_use]
pub fn reconstruct(
    field: &PrimeField,
    shares: &[VectorShare],
    k: usize,
) -> Option<Vec<BigUint>> {
    if k == 0 || shares.len() < k {
        return None;
    }
    let m = shares[0].y.len();
    if m == 0 || shares.iter().any(|s| s.y.len() != m) {
        return None;
    }

    // Distinctness and non-zeroness only matter as field elements, so check
    // them on canonical representatives rather than on the raw integers.
    let xs: Vec<BigUint> = shares.iter().map(|s| field.reduce(&s.x)).collect();
    if xs.iter().any(|x| x.is_zero()) {
        return None;
    }
    for (i, xi) in xs.iter().enumerate() {
        if xs[i + 1..].iter().any(|xj| xj == xi) {
            return None;
        }
    }

    // The first k shares define the polynomials; the rest are cross-checks.
    let (basis_xs, extra_xs) = xs.split_at(k);
    let (basis_shares, extra_shares) = shares.split_at(k);

    let mut secret: Vec<BigUint> = Vec::with_capacity(m);
    for c in 0..m {
        let pts: Vec<(BigUint, BigUint)> = basis_xs
            .iter()
            .zip(basis_shares)
            .map(|(x, s)| (x.clone(), field.reduce(&s.y[c])))
            .collect();
        secret.push(lagrange_eval_unchecked(field, &pts, &BigUint::zero()));
        for (x, s) in extra_xs.iter().zip(extra_shares) {
            let pred = lagrange_eval_unchecked(field, &pts, x);
            // Compare canonical forms so a congruent-but-unreduced y is not
            // mistaken for a corrupted share.
            if !ct_eq_biguint(&pred, &field.reduce(&s.y[c])) {
                return None;
            }
        }
    }
    Some(secret)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csprng::ChaCha20Rng;

    /// Deterministic RNG so every test run sees the same shares.
    fn rng() -> ChaCha20Rng {
        ChaCha20Rng::from_seed(&[33u8; 32])
    }

    /// Prime field with modulus 2^61 - 1.
    fn small_field() -> PrimeField {
        PrimeField::new(BigUint::from_u64((1u64 << 61) - 1))
    }

    /// A two-component secret split 3-of-5, used by the rejection tests.
    fn two_component_setup() -> (PrimeField, Vec<BigUint>, Vec<VectorShare>) {
        let f = small_field();
        let mut r = rng();
        let secret = vec![BigUint::from_u64(7), BigUint::from_u64(11)];
        let shares = split(&f, &mut r, &secret, 3, 5);
        (f, secret, shares)
    }

    #[test]
    fn vector_round_trip() {
        let f = small_field();
        let mut r = rng();
        let secret: Vec<BigUint> = (1..=4).map(|i| BigUint::from_u64(0x100 + i)).collect();
        let shares = split(&f, &mut r, &secret, 3, 6);
        assert_eq!(reconstruct(&f, &shares[..3], 3), Some(secret.clone()));
        assert_eq!(reconstruct(&f, &shares[2..5], 3), Some(secret));
    }

    #[test]
    fn scalar_kgh_matches_shamir() {
        let f = small_field();
        let mut r = rng();
        let secret = vec![BigUint::from_u64(0xDECAF)];
        let shares = split(&f, &mut r, &secret, 4, 7);
        assert_eq!(reconstruct(&f, &shares[..4], 4), Some(secret));
    }

    #[test]
    fn below_threshold_is_refused() {
        let f = small_field();
        let mut r = rng();
        let secret = vec![BigUint::from_u64(1), BigUint::from_u64(2)];
        let shares = split(&f, &mut r, &secret, 3, 5);
        assert!(reconstruct(&f, &shares[..2], 3).is_none());
    }

    #[test]
    fn extras_must_be_consistent() {
        let (f, _secret, mut shares) = two_component_setup();
        shares[4].y[1] = f.add(&shares[4].y[1], &BigUint::from_u64(1));
        assert!(reconstruct(&f, &shares, 3).is_none());
    }

    #[test]
    fn consistent_extras_are_accepted() {
        let (f, secret, shares) = two_component_setup();
        assert_eq!(reconstruct(&f, &shares, 3), Some(secret));
    }

    #[test]
    fn duplicate_x_is_refused() {
        let (f, _secret, mut shares) = two_component_setup();
        shares[1].x = shares[0].x.clone();
        assert!(reconstruct(&f, &shares[..3], 3).is_none());
    }

    #[test]
    fn zero_x_is_refused() {
        let (f, _secret, mut shares) = two_component_setup();
        shares[0].x = BigUint::zero();
        assert!(reconstruct(&f, &shares[..3], 3).is_none());
    }

    #[test]
    fn mismatched_component_count_is_refused() {
        let (f, _secret, mut shares) = two_component_setup();
        shares[1].y.pop();
        assert!(reconstruct(&f, &shares[..3], 3).is_none());
    }

    #[test]
    fn zero_threshold_and_empty_components_are_refused() {
        let (f, _secret, shares) = two_component_setup();
        assert!(reconstruct(&f, &shares, 0).is_none());
        let empty = [VectorShare { x: BigUint::from_u64(1), y: Vec::new() }];
        assert!(reconstruct(&f, &empty, 1).is_none());
    }

    #[test]
    #[should_panic(expected = "k must be at least 2")]
    fn split_rejects_threshold_one() {
        let f = small_field();
        let mut r = rng();
        let _ = split(&f, &mut r, &[BigUint::from_u64(1)], 1, 3);
    }

    #[test]
    #[should_panic(expected = "n must be at least k")]
    fn split_rejects_n_below_k() {
        let f = small_field();
        let mut r = rng();
        let _ = split(&f, &mut r, &[BigUint::from_u64(1)], 4, 3);
    }
}