use crate::bigint::BigUint;
use crate::csprng::Csprng;
use crate::field::PrimeField;
use crate::poly::{horner, lagrange_eval};
use crate::secure::{ct_eq_biguint, Zeroizing};
use crate::shamir::Share;
/// Proactively re-randomises Shamir shares while keeping the shared secret.
///
/// One masking polynomial is sampled per input share. Its coefficient list has
/// `k` entries: the first (constant) one is fixed to zero and the rest are drawn
/// from `field.random(rng)`. Each recipient's new `y` is its old `y` plus the sum
/// of every masking polynomial evaluated (via `horner`) at the recipient's `x`.
/// Output shares keep their x-coordinates and the input order.
///
/// # Panics
///
/// Panics if `k < 2`, if `shares` holds fewer than `k` entries, if any share has
/// a zero x-coordinate, or if two shares have the same x-coordinate.
#[must_use]
pub fn refresh<R: Csprng>(
    field: &PrimeField,
    rng: &mut R,
    shares: &[Share],
    k: usize,
) -> Vec<Share> {
    // Argument validation; the order of these checks decides which panic fires first.
    assert!(k >= 2, "k must be at least 2");
    assert!(
        shares.len() >= k,
        "input must have ≥ k shares to remain reconstructable"
    );
    shares
        .iter()
        .for_each(|share| assert!(!share.x.is_zero(), "shares must have nonzero x"));
    for (idx, left) in shares.iter().enumerate() {
        for right in &shares[idx + 1..] {
            assert_ne!(left.x, right.x, "shares must have distinct x-coordinates");
        }
    }

    // Masking polynomials live inside `Zeroizing`; presumably wiped on drop —
    // confirm against `crate::secure::Zeroizing`.
    let holders = shares.len();
    let mut masks = Zeroizing::new(Vec::<Vec<BigUint>>::with_capacity(holders));
    for _ in 0..holders {
        let mut poly = Vec::with_capacity(k);
        // Zero constant term: adding this polynomial leaves the value at x = 0 unchanged.
        poly.push(BigUint::zero());
        while poly.len() < k {
            poly.push(field.random(rng));
        }
        masks.push(poly);
    }

    // Each recipient absorbs every holder's mask evaluated at its own x.
    let mut refreshed = Vec::with_capacity(holders);
    for recipient in shares {
        let y = masks.iter().fold(recipient.y.clone(), |acc, poly| {
            let delta = horner(field, poly, &recipient.x);
            field.add(&acc, &delta)
        });
        refreshed.push(Share {
            x: recipient.x.clone(),
            y,
        });
    }
    refreshed
}
/// Recomputes the share at `x_lost` by Lagrange interpolation over `live` shares.
///
/// The first `k` entries of `live` define the interpolating polynomial. Every
/// remaining entry is checked against that polynomial (compared with
/// `ct_eq_biguint`), so a tampered or mismatched extra share makes the call fail.
///
/// Returns `None` when:
/// - `k < 2`, or `live` holds fewer than `k` shares;
/// - `x_lost` is zero: the polynomial's value at 0 is the secret itself, so
///   "recovering" that point would leak the secret rather than produce a share;
/// - `x_lost` equals the x-coordinate of some live share;
/// - any live share has a zero x-coordinate, or two live shares share an x;
/// - any extra live share (beyond the first `k`) disagrees with the first `k`;
/// - `lagrange_eval` itself returns `None`.
#[must_use]
pub fn recover_share(
    field: &PrimeField,
    live: &[Share],
    k: usize,
    x_lost: &BigUint,
) -> Option<Share> {
    if k < 2 || live.len() < k {
        return None;
    }
    // Never interpolate at x = 0: that evaluation is the secret, not a share.
    // This mirrors the nonzero-x requirement placed on live shares below.
    if x_lost.is_zero() {
        return None;
    }
    if live.iter().any(|s| s.x.is_zero() || s.x == *x_lost) {
        return None;
    }
    for (i, a) in live.iter().enumerate() {
        if live[i + 1..].iter().any(|b| a.x == b.x) {
            return None;
        }
    }

    // `live.len() >= k` was checked above, so this split cannot panic.
    let (basis, extras) = live.split_at(k);
    let pts: Vec<(BigUint, BigUint)> = basis
        .iter()
        .map(|s| (s.x.clone(), s.y.clone()))
        .collect();

    // Cross-check surplus shares so inconsistent input is rejected instead of
    // silently producing a wrong share.
    for s in extras {
        let pred = lagrange_eval(field, &pts, &s.x)?;
        if !ct_eq_biguint(&pred, &s.y) {
            return None;
        }
    }

    let y_lost = lagrange_eval(field, &pts, x_lost)?;
    Some(Share {
        x: x_lost.clone(),
        y: y_lost,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csprng::ChaCha20Rng;
    use crate::shamir;

    /// Fixed-seed RNG so every test run is deterministic.
    fn rng() -> ChaCha20Rng {
        ChaCha20Rng::from_seed(&[0x70u8; 32])
    }

    /// Prime field over the Mersenne prime 2^61 - 1.
    fn small() -> PrimeField {
        PrimeField::new(BigUint::from_u64((1u64 << 61) - 1))
    }

    /// Clones the shares at `indices`, in the order listed.
    fn pick(shares: &[Share], indices: &[usize]) -> Vec<Share> {
        indices.iter().map(|&i| shares[i].clone()).collect()
    }

    /// Placeholder share at (1, 1); only used to trip argument validation.
    fn unit_share() -> Share {
        Share {
            x: BigUint::one(),
            y: BigUint::one(),
        }
    }

    #[test]
    fn refresh_preserves_secret() {
        let field = small();
        let mut prng = rng();
        let secret = BigUint::from_u64(0xC0FFEE);
        let original = shamir::split(&field, &mut prng, &secret, 3, 5);
        let refreshed = refresh(&field, &mut prng, &original, 3);

        assert_eq!(refreshed.len(), 5);
        assert!(original
            .iter()
            .zip(refreshed.iter())
            .all(|(old, new)| old.x == new.x));
        let opened = shamir::reconstruct(&field, &refreshed[..3], 3).unwrap();
        assert_eq!(opened, secret);
    }

    #[test]
    fn refresh_actually_changes_share_values() {
        let field = small();
        let mut prng = rng();
        let secret = BigUint::from_u64(0xBAD);
        let original = shamir::split(&field, &mut prng, &secret, 3, 5);
        let refreshed = refresh(&field, &mut prng, &original, 3);

        let unchanged = original
            .iter()
            .zip(refreshed.iter())
            .all(|(old, new)| old.y == new.y);
        assert!(!unchanged, "refresh must change at least one y value");
    }

    #[test]
    fn many_refreshes_preserve_secret() {
        let field = small();
        let mut prng = rng();
        let secret = BigUint::from_u64(42);
        let initial = shamir::split(&field, &mut prng, &secret, 4, 7);
        let last = (0..10).fold(initial, |current, _| refresh(&field, &mut prng, &current, 4));

        let opened = shamir::reconstruct(&field, &last[..4], 4).unwrap();
        assert_eq!(opened, secret);
    }

    #[test]
    fn old_shares_do_not_combine_with_new() {
        let field = small();
        let mut prng = rng();
        let secret = BigUint::from_u64(99);
        let original = shamir::split(&field, &mut prng, &secret, 4, 7);
        let refreshed = refresh(&field, &mut prng, &original, 4);

        // Two pre-refresh shares mixed with two post-refresh shares.
        let mixed: Vec<Share> = [&original[..2], &refreshed[2..4]].concat();
        let attempt = shamir::reconstruct(&field, &mixed, 4);
        assert_ne!(attempt, Some(secret));
    }

    #[test]
    fn recover_lost_share_round_trip() {
        let field = small();
        let mut prng = rng();
        let secret = BigUint::from_u64(0xCAFE);
        let all = shamir::split(&field, &mut prng, &secret, 3, 5);

        let survivors = pick(&all, &[0, 1, 3]);
        let rebuilt = recover_share(&field, &survivors, 3, &all[2].x).unwrap();
        assert_eq!(rebuilt, all[2]);
    }

    #[test]
    fn recover_share_on_present_x_returns_none() {
        let field = small();
        let mut prng = rng();
        let secret = BigUint::from_u64(7);
        let all = shamir::split(&field, &mut prng, &secret, 3, 5);

        assert!(recover_share(&field, &all[..3], 3, &all[0].x).is_none());
    }

    #[test]
    fn recover_share_below_threshold_returns_none() {
        let field = small();
        let mut prng = rng();
        let secret = BigUint::from_u64(11);
        let all = shamir::split(&field, &mut prng, &secret, 4, 6);

        let target = BigUint::from_u64(99);
        assert!(recover_share(&field, &all[..3], 4, &target).is_none());
    }

    #[test]
    #[should_panic(expected = "k must be at least 2")]
    fn refresh_rejects_k_one() {
        let field = small();
        let mut prng = rng();
        let _ = refresh(&field, &mut prng, &[unit_share()], 1);
    }

    #[test]
    #[should_panic(expected = "input must have ≥ k shares")]
    fn refresh_rejects_too_few_shares() {
        let field = small();
        let mut prng = rng();
        let _ = refresh(&field, &mut prng, &[unit_share()], 3);
    }

    #[test]
    fn recover_share_validates_extras() {
        let field = small();
        let mut prng = rng();
        let secret = BigUint::from_u64(0x55);
        let all = shamir::split(&field, &mut prng, &secret, 3, 6);

        let mut survivors = all[..5].to_vec();
        // Corrupt one of the surplus shares (index 3 is beyond the first k = 3).
        let tampered = &mut survivors[3];
        tampered.y = field.add(&tampered.y, &BigUint::from_u64(1));

        let attempt = recover_share(&field, &survivors, 3, &all[5].x);
        assert!(attempt.is_none(), "tampered extra must yield None");
    }

    #[test]
    fn refresh_then_recover_lost_share_pipeline() {
        let field = small();
        let mut prng = rng();
        let secret = BigUint::from_u64(0xDECAF);
        let dealt = shamir::split(&field, &mut prng, &secret, 3, 5);
        let current = refresh(&field, &mut prng, &dealt, 3);

        let lost_x = current[2].x.clone();
        let survivors = pick(&current, &[0, 1, 3]);
        let rebuilt = recover_share(&field, &survivors, 3, &lost_x).unwrap();

        let mut restored = pick(&current, &[0, 1]);
        restored.push(rebuilt);
        restored.extend(pick(&current, &[3, 4]));
        restored.sort_by(|lhs, rhs| lhs.x.cmp(&rhs.x));

        assert_eq!(
            shamir::reconstruct(&field, &restored[..3], 3).unwrap(),
            secret
        );
    }
}