use crate::field::PrimeField;
use crate::poly::horner;
use crate::secure::ct_eq_biguint;
use crate::shamir::Share;
use crate::bigint::BigUint;
/// Reconstructs a Shamir secret from `shares`, correcting up to `max_errors`
/// shares whose `y` value is wrong (Berlekamp–Welch decoding).
///
/// The secret is the value at `x = 0` of the polynomial with `k` coefficients
/// that agrees with all but at most `max_errors` of the shares.
///
/// # Arguments
/// * `field` – the prime field all share coordinates live in.
/// * `shares` – candidate shares; every `x` must be nonzero and pairwise distinct.
/// * `k` – reconstruction threshold (number of polynomial coefficients); must be ≥ 1.
/// * `max_errors` – number of corrupted shares to tolerate; requires
///   `shares.len() >= k + 2 * max_errors`.
///
/// # Returns
/// `Some(secret)` on success. `None` if the inputs are invalid (`k == 0`, too few
/// shares, overflow while computing the share requirement, a zero or repeated `x`)
/// or if no polynomial within the decoding radius fits the shares. With
/// `max_errors == 0`, every share beyond the first `k` must be consistent with
/// the polynomial interpolated from the first `k`.
#[must_use]
pub fn reconstruct_with_errors(
    field: &PrimeField,
    shares: &[Share],
    k: usize,
    max_errors: usize,
) -> Option<BigUint> {
    let m = shares.len();
    // Unique decoding of `t` errors needs `k + 2t` points; overflow counts as "not enough".
    let needed = k.checked_add(2usize.checked_mul(max_errors)?)?;
    if k == 0 || m < needed {
        return None;
    }
    // x = 0 is where the secret itself is evaluated (see the evaluations at zero
    // below), so it can never be a valid share abscissa.
    if shares.iter().any(|s| s.x.is_zero()) {
        return None;
    }
    // Repeated abscissae would make interpolation / the linear system ill-defined.
    let has_duplicate_x = shares
        .iter()
        .enumerate()
        .any(|(i, a)| shares[i + 1..].iter().any(|b| a.x == b.x));
    if has_duplicate_x {
        return None;
    }

    if max_errors == 0 {
        // No error budget: interpolate from the first k shares and require every
        // remaining share to lie on that same polynomial.
        let pts: Vec<(BigUint, BigUint)> = shares[..k]
            .iter()
            .map(|s| (s.x.clone(), s.y.clone()))
            .collect();
        let consistent = shares[k..].iter().all(|s| {
            let pred = crate::poly::lagrange_eval_unchecked(field, &pts, &s.x);
            ct_eq_biguint(&pred, &s.y)
        });
        if !consistent {
            return None;
        }
        return Some(crate::poly::lagrange_eval_unchecked(
            field,
            &pts,
            &BigUint::zero(),
        ));
    }

    // Berlekamp–Welch: find Q (k + t coefficients) and E (t + 1 coefficients)
    // with Q(x_i) - y_i * E(x_i) = 0 for every share; the message polynomial is Q / E.
    let t = max_errors;
    let q_len = k + t;
    let e_len = t + 1;
    let cols = q_len + e_len;
    let mut mat: Vec<Vec<BigUint>> = Vec::with_capacity(m);
    for s in shares {
        // Powers x^0, x^1, ... shared by the Q and E halves of the row.
        let n_pows = q_len.max(e_len);
        let mut pows: Vec<BigUint> = Vec::with_capacity(n_pows);
        let mut pow = BigUint::one();
        for _ in 0..n_pows {
            pows.push(pow.clone());
            pow = field.mul(&pow, &s.x);
        }
        let mut row = Vec::with_capacity(cols);
        row.extend(pows.iter().take(q_len).cloned());
        row.extend(
            pows.iter()
                .take(e_len)
                .map(|p| field.neg(&field.mul(&s.y, p))),
        );
        mat.push(row);
    }

    let kernel = nullspace_basis(field, &mut mat, cols)?;
    for candidate in kernel {
        let (q_coeffs, e_coeffs) = candidate.split_at(q_len);
        if e_coeffs.iter().all(|c| c.is_zero()) {
            continue;
        }
        let Some(m_coeffs) = poly_div_exact(field, q_coeffs, e_coeffs, k) else {
            continue;
        };
        let agree = shares
            .iter()
            .filter(|s| ct_eq_biguint(&horner(field, &m_coeffs, &s.x), &s.y))
            .count();
        // Accept only if at most `t` shares disagree — exactly the promised
        // decoding radius. (`m >= k + 2t`, so `m - t` cannot underflow.)
        if agree >= m - t {
            return Some(horner(field, &m_coeffs, &BigUint::zero()));
        }
    }
    None
}
/// Computes a basis of the right null space of `mat` over `field`.
///
/// `mat` is reduced **in place** to reduced row-echelon form by Gauss–Jordan
/// elimination; only its first `cols` columns are touched, so every row must
/// have at least `cols` entries. One basis vector (length `cols`) is produced
/// per non-pivot ("free") column: it holds `1` in that column and the negated
/// reduced entries of that column in the pivot positions.
///
/// Returns `None` if every column is a pivot (only the zero vector solves
/// `mat · v = 0`) or if a pivot element has no inverse in `field`.
// Column indices address several distinct rows at once, so range loops read better here.
#[allow(clippy::needless_range_loop)]
fn nullspace_basis(
    field: &PrimeField,
    mat: &mut [Vec<BigUint>],
    cols: usize,
) -> Option<Vec<Vec<BigUint>>> {
    let rows = mat.len();
    // `pivots.len()` doubles as the index of the next row to receive a pivot.
    let mut pivots: Vec<usize> = Vec::new();
    for col in 0..cols {
        let rank = pivots.len();
        if rank >= rows {
            break;
        }
        // First row at or below `rank` with a nonzero entry in this column.
        let Some(offset) = mat[rank..].iter().position(|r| !r[col].is_zero()) else {
            continue;
        };
        mat.swap(rank, rank + offset);

        // Scale the pivot row so its pivot becomes 1.
        let inv = field.inv(&mat[rank][col])?;
        for entry in &mut mat[rank][col..cols] {
            *entry = field.mul(&*entry, &inv);
        }

        // Clear this column in every other row, above and below the pivot.
        let (above, rest) = mat.split_at_mut(rank);
        let (pivot_row, below) = rest
            .split_first_mut()
            .expect("rank < rows, so the pivot row exists");
        for other in above.iter_mut().chain(below.iter_mut()) {
            if other[col].is_zero() {
                continue;
            }
            let factor = other[col].clone();
            for (dst, src) in other[col..cols].iter_mut().zip(&pivot_row[col..cols]) {
                let term = field.mul(&factor, src);
                *dst = field.sub(&*dst, &term);
            }
        }
        pivots.push(col);
    }

    let free_cols: Vec<usize> = (0..cols).filter(|c| !pivots.contains(c)).collect();
    if free_cols.is_empty() {
        return None;
    }
    let basis: Vec<Vec<BigUint>> = free_cols
        .into_iter()
        .map(|free| {
            let mut v = vec![BigUint::zero(); cols];
            v[free] = BigUint::one();
            // Row i of the reduced matrix carries the pivot for column pivots[i].
            for (reduced, &pc) in mat.iter().zip(&pivots) {
                v[pc] = field.neg(&reduced[free]);
            }
            v
        })
        .collect();
    Some(basis)
}
fn poly_div_exact(
field: &PrimeField,
q: &[BigUint],
e: &[BigUint],
expected_quot_len: usize,
) -> Option<Vec<BigUint>> {
let strip = |v: &[BigUint]| -> Vec<BigUint> {
let mut end = v.len();
while end > 0 && v[end - 1].is_zero() {
end -= 1;
}
v[..end].to_vec()
};
let mut rem = strip(q);
let div = strip(e);
if div.is_empty() {
return None;
}
let deg_div = div.len() - 1;
let lead_inv = field.inv(&div[deg_div])?;
if rem.len() < div.len() {
if rem.is_empty() {
return Some(vec![BigUint::zero(); expected_quot_len]);
}
return None;
}
let mut quot = vec![BigUint::zero(); rem.len() - deg_div];
while rem.len() > deg_div {
let deg_rem = rem.len() - 1;
let coef = field.mul(&rem[deg_rem], &lead_inv);
let shift = deg_rem - deg_div;
quot[shift] = coef.clone();
for j in 0..=deg_div {
let term = field.mul(&coef, &div[j]);
rem[shift + j] = field.sub(&rem[shift + j], &term);
}
while !rem.is_empty() && rem.last().unwrap().is_zero() {
rem.pop();
}
}
if !rem.is_empty() {
return None;
}
if quot.len() > expected_quot_len {
for c in "[expected_quot_len..] {
if !c.is_zero() {
return None;
}
}
quot.truncate(expected_quot_len);
} else {
quot.resize(expected_quot_len, BigUint::zero());
}
Some(quot)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csprng::ChaCha20Rng;
    use crate::shamir::split;

    /// Fixed-seed generator so every test draws the same sharing polynomial.
    fn rng() -> ChaCha20Rng {
        ChaCha20Rng::from_seed(&[1u8; 32])
    }

    /// Field modulo 2^61 - 1, used by the end-to-end reconstruction tests.
    fn small_field() -> PrimeField {
        PrimeField::new(BigUint::from_u64((1u64 << 61) - 1))
    }

    /// Field modulo 257, used by the polynomial-division tests.
    fn tiny_field() -> PrimeField {
        PrimeField::new(BigUint::from_u64(257))
    }

    /// Builds a coefficient vector (constant term first) from small integers.
    fn coeffs(vals: &[u64]) -> Vec<BigUint> {
        vals.iter().map(|&v| BigUint::from_u64(v)).collect()
    }

    /// Corrupts a share by adding `delta` to its y-coordinate in `field`.
    fn tamper(field: &PrimeField, share: &mut Share, delta: u64) {
        share.y = field.add(&share.y, &BigUint::from_u64(delta));
    }

    #[test]
    fn no_errors_matches_plain_reconstruct() {
        let field = small_field();
        let secret = BigUint::from_u64(0x0C0D_EBAD_F00D);
        let shares = split(&field, &mut rng(), &secret, 3, 7);
        assert_eq!(reconstruct_with_errors(&field, &shares, 3, 0), Some(secret));
    }

    #[test]
    fn corrects_one_tampered_share() {
        let field = small_field();
        let secret = BigUint::from_u64(0xABCD_1234);
        let mut shares = split(&field, &mut rng(), &secret, 3, 7);
        tamper(&field, &mut shares[4], 1);
        assert_eq!(reconstruct_with_errors(&field, &shares, 3, 1), Some(secret));
    }

    #[test]
    fn corrects_one_error_with_surplus_shares() {
        // m = 9 > k + 2t = 5: acceptance must still succeed with a single bad share.
        let field = small_field();
        let secret = BigUint::from_u64(0x2468_ACE0);
        let mut shares = split(&field, &mut rng(), &secret, 3, 9);
        tamper(&field, &mut shares[7], 5);
        assert_eq!(reconstruct_with_errors(&field, &shares, 3, 1), Some(secret));
    }

    #[test]
    fn corrects_two_tampered_shares() {
        let field = small_field();
        let secret = BigUint::from_u64(0xFEED_F00D);
        let mut shares = split(&field, &mut rng(), &secret, 3, 9);
        tamper(&field, &mut shares[1], 7);
        tamper(&field, &mut shares[6], 13);
        assert_eq!(reconstruct_with_errors(&field, &shares, 3, 2), Some(secret));
    }

    #[test]
    fn corrects_three_tampered_shares() {
        let field = small_field();
        let secret = BigUint::from_u64(0x1234_5678_9ABC);
        let mut shares = split(&field, &mut rng(), &secret, 4, 11);
        tamper(&field, &mut shares[0], 1);
        shares[5].y = BigUint::zero();
        tamper(&field, &mut shares[10], 99);
        assert_eq!(reconstruct_with_errors(&field, &shares, 4, 3), Some(secret));
    }

    #[test]
    fn fails_above_decoding_radius() {
        let field = small_field();
        let secret = BigUint::from_u64(0xBAAA);
        let mut shares = split(&field, &mut rng(), &secret, 3, 7);
        for share in shares.iter_mut().take(3) {
            tamper(&field, share, 1);
        }
        assert!(reconstruct_with_errors(&field, &shares, 3, 2).is_none());
    }

    #[test]
    fn handles_erasures_via_omission() {
        let field = small_field();
        let secret = BigUint::from_u64(0x1357_9BDF);
        let shares = split(&field, &mut rng(), &secret, 3, 7);
        assert_eq!(
            reconstruct_with_errors(&field, &shares[..4], 3, 0),
            Some(secret)
        );
    }

    #[test]
    fn rejects_below_radius() {
        let field = small_field();
        let secret = BigUint::from_u64(0xAA);
        let shares = split(&field, &mut rng(), &secret, 3, 6);
        assert!(reconstruct_with_errors(&field, &shares, 3, 2).is_none());
    }

    #[test]
    fn rejects_zero_threshold() {
        let field = small_field();
        let secret = BigUint::from_u64(0x55);
        let shares = split(&field, &mut rng(), &secret, 3, 7);
        assert!(reconstruct_with_errors(&field, &shares, 0, 0).is_none());
    }

    #[test]
    fn rejects_zero_abscissa() {
        let field = small_field();
        let secret = BigUint::from_u64(0x77);
        let mut shares = split(&field, &mut rng(), &secret, 3, 7);
        shares[2].x = BigUint::zero();
        assert!(reconstruct_with_errors(&field, &shares, 3, 0).is_none());
    }

    #[test]
    fn rejects_duplicate_abscissa() {
        let field = small_field();
        let secret = BigUint::from_u64(0x99);
        let mut shares = split(&field, &mut rng(), &secret, 3, 7);
        shares[1].x = shares[0].x.clone();
        assert!(reconstruct_with_errors(&field, &shares, 3, 0).is_none());
    }

    #[test]
    fn poly_div_exact_basic() {
        // (x^2 + 3x + 2) / (x + 1) = x + 2
        let quotient = poly_div_exact(&tiny_field(), &coeffs(&[2, 3, 1]), &coeffs(&[1, 1]), 2);
        assert_eq!(quotient, Some(coeffs(&[2, 1])));
    }

    #[test]
    fn poly_div_inexact_returns_none() {
        // x^2 + 1 = (x + 1)(x - 1) + 2, so the remainder is nonzero.
        let quotient = poly_div_exact(&tiny_field(), &coeffs(&[1, 0, 1]), &coeffs(&[1, 1]), 2);
        assert!(quotient.is_none());
    }

    #[test]
    fn poly_div_exact_rejects_overlong_quotient() {
        // Exact quotient x + 2 needs 2 coefficients, but only 1 is allowed.
        let quotient = poly_div_exact(&tiny_field(), &coeffs(&[2, 3, 1]), &coeffs(&[1, 1]), 1);
        assert!(quotient.is_none());
    }

    #[test]
    fn poly_div_exact_pads_short_quotient() {
        // (x + 2) / 1 = x + 2, zero-padded to 3 coefficients.
        let quotient = poly_div_exact(&tiny_field(), &coeffs(&[2, 1]), &coeffs(&[1]), 3);
        assert_eq!(quotient, Some(coeffs(&[2, 1, 0])));
    }

    #[test]
    fn poly_div_exact_zero_numerator() {
        let quotient = poly_div_exact(&tiny_field(), &coeffs(&[0, 0]), &coeffs(&[1, 1]), 2);
        assert_eq!(quotient, Some(coeffs(&[0, 0])));
    }
}