use crate::bigint::BigUint;
use crate::csprng::Csprng;
use crate::field::PrimeField;
use crate::secure::ct_eq_biguint;
/// A linear secret-sharing scheme over a prime field, described by a `k × n`
/// matrix.
///
/// Share `c` is the dot product of the coefficient vector
/// `(secret, r_1, …, r_{k-1})` with column `c` of the matrix (see [`split`]).
#[derive(Clone, Debug)]
pub struct LinearScheme {
    /// Field in which all share arithmetic is performed.
    field: PrimeField,
    /// The `k × n` share-generation matrix; `rows[r][c]` is the weight of
    /// coefficient `r` in share `c`.
    rows: Vec<Vec<BigUint>>,
    /// Reconstruction threshold, equal to the number of matrix rows.
    k: usize,
    /// Total number of shares, equal to the number of matrix columns.
    n: usize,
}
impl LinearScheme {
    /// Builds a scheme from a `k × n` matrix over `field`.
    ///
    /// No linear-independence check is performed; use
    /// [`LinearScheme::new_checked`] for that.
    ///
    /// # Panics
    /// Panics if `k < 2`, `n < k`, `rows.len() != k`, or any row does not have
    /// exactly `n` entries.
    #[must_use]
    pub fn new(field: PrimeField, rows: Vec<Vec<BigUint>>, k: usize, n: usize) -> Self {
        assert!(k >= 2, "k must be at least 2");
        assert!(n >= k, "n must be at least k");
        assert_eq!(rows.len(), k, "rows.len() must equal k");
        for r in &rows {
            assert_eq!(r.len(), n, "every row must have n entries");
        }
        Self { field, rows, k, n }
    }

    /// Like [`LinearScheme::new`], but returns `None` unless every `k × k`
    /// submatrix formed by choosing `k` distinct columns is invertible.
    ///
    /// In other words, it requires that any `k` shares suffice to solve for the
    /// coefficient vector.
    ///
    /// This checks reconstructability only. It does NOT check privacy, i.e. that
    /// fewer than `k` shares reveal nothing about the secret. For example, a
    /// matrix with a column equal to `(1, 0, …, 0)` can pass this check even
    /// though that column's share equals the secret.
    ///
    /// Cost: one `k × k` elimination per `k`-subset of columns, i.e. up to
    /// `C(n, k)` eliminations.
    ///
    /// # Panics
    /// Panics under the same shape conditions as [`LinearScheme::new`].
    #[must_use]
    pub fn new_checked(
        field: PrimeField,
        rows: Vec<Vec<BigUint>>,
        k: usize,
        n: usize,
    ) -> Option<Self> {
        let scheme = Self::new(field, rows, k, n);
        let mut indices = Vec::with_capacity(k);
        if scheme.spreading_holds(&mut indices, 0) {
            Some(scheme)
        } else {
            None
        }
    }

    /// Enumerates every `k`-subset of columns that extends `indices` using only
    /// columns `>= start`.
    ///
    /// Returns `false` as soon as one subset yields a singular submatrix.
    /// `indices` is restored to its input contents before returning.
    fn spreading_holds(&self, indices: &mut Vec<usize>, start: usize) -> bool {
        if indices.len() == self.k {
            return self.submatrix_invertible(indices);
        }
        // Only start a column from which enough columns remain to complete a
        // k-subset. Later starts can never reach the base case, so skipping
        // them leaves the result unchanged. No underflow: the remaining count
        // is at most k, and k <= n is asserted in `new`.
        let still_needed = self.k - indices.len();
        let last = self.n - still_needed;
        for c in start..=last {
            indices.push(c);
            let ok = self.spreading_holds(indices, c + 1);
            indices.pop();
            if !ok {
                return false;
            }
        }
        true
    }

    /// Returns whether the `k × k` submatrix `mat[r][i] = rows[r][columns[i]]`
    /// is invertible.
    ///
    /// Gauss–Jordan elimination is run on a copy. The matrix is singular exactly
    /// when some column has no non-zero pivot, or a pivot has no inverse.
    // Index loops are kept: each step reads the pivot row while writing other
    // rows of the same matrix.
    #[allow(clippy::needless_range_loop)]
    fn submatrix_invertible(&self, columns: &[usize]) -> bool {
        let k = self.k;
        let mut mat: Vec<Vec<BigUint>> = (0..k)
            .map(|r| columns.iter().map(|&c| self.rows[r][c].clone()).collect())
            .collect();
        for col in 0..k {
            let pr = match (col..k).find(|&r| !mat[r][col].is_zero()) {
                Some(p) => p,
                None => return false,
            };
            if pr != col {
                mat.swap(pr, col);
            }
            let inv = match self.field.inv(&mat[col][col]) {
                Some(v) => v,
                None => return false,
            };
            for c in col..k {
                mat[col][c] = self.field.mul(&mat[col][c], &inv);
            }
            for r in 0..k {
                if r == col || mat[r][col].is_zero() {
                    continue;
                }
                let factor = mat[r][col].clone();
                for c in col..k {
                    let term = self.field.mul(&factor, &mat[col][c]);
                    mat[r][c] = self.field.sub(&mat[r][c], &term);
                }
            }
        }
        true
    }

    /// Reconstruction threshold: the number of shares needed to recover the secret.
    #[must_use]
    pub fn k(&self) -> usize {
        self.k
    }

    /// Total number of shares produced by [`split`].
    #[must_use]
    pub fn n(&self) -> usize {
        self.n
    }

    /// The field over which shares are computed.
    #[must_use]
    pub fn field(&self) -> &PrimeField {
        &self.field
    }
}
/// Splits `secret` into `scheme.n()` shares.
///
/// The coefficient vector is `(secret, r_1, …, r_{k-1})`, with each `r_i` drawn
/// from `scheme`'s field via `rng`. Share `c` is the field dot product of that
/// vector with column `c` of the scheme matrix.
///
/// # Panics
/// Panics if `secret` is not strictly less than the field modulus.
#[must_use]
pub fn split<R: Csprng>(scheme: &LinearScheme, rng: &mut R, secret: &BigUint) -> Vec<BigUint> {
    assert!(
        secret < scheme.field.modulus(),
        "secret must be < field modulus; reduce explicitly if wrap-around is intended"
    );
    let field = &scheme.field;

    // The secret occupies slot 0; the remaining k - 1 slots are random masks.
    let mut coeffs: Vec<BigUint> = Vec::with_capacity(scheme.k);
    coeffs.push(secret.clone());
    while coeffs.len() < scheme.k {
        coeffs.push(field.random(rng));
    }

    (0..scheme.n)
        .map(|col| {
            coeffs
                .iter()
                .zip(&scheme.rows)
                .fold(BigUint::zero(), |sum, (coeff, row)| {
                    let term = field.mul(coeff, &row[col]);
                    field.add(&sum, &term)
                })
        })
        .collect()
}
/// Reconstructs the secret from `(column, share)` pairs.
///
/// The first `k` pairs are used to solve for the coefficient vector, whose first
/// entry is the secret. Every pair beyond the first `k` is then re-derived from
/// that solution and compared in constant time. Any disagreement makes the
/// whole call return `None`.
///
/// Returns `None` when:
/// - fewer than `k` pairs are supplied,
/// - a column index is `>= n` or appears more than once,
/// - a share value is not a canonical field element (`>= modulus`),
/// - the system built from the first `k` columns is singular, or
/// - a surplus share is inconsistent with the solution.
///
/// With exactly `k` shares there is nothing to cross-check, so a tampered share
/// produces a wrong secret rather than `None`.
#[must_use]
pub fn reconstruct(scheme: &LinearScheme, shares: &[(usize, BigUint)]) -> Option<BigUint> {
    let k = scheme.k;
    if shares.len() < k {
        return None;
    }

    // One pass validates the range, uniqueness and canonical form of every pair.
    // `split` only emits values below the modulus, so anything else is malformed.
    // Rejecting it here treats such values the same whether they land among the
    // first k shares or the surplus ones.
    let modulus = scheme.field.modulus();
    let mut seen = vec![false; scheme.n];
    for (col, val) in shares {
        if *col >= scheme.n || seen[*col] || val >= modulus {
            return None;
        }
        seen[*col] = true;
    }

    let (used, extras) = shares.split_at(k);

    // Augmented system: row i holds column `used[i].0` of the scheme matrix,
    // followed by the share value as the right-hand side.
    let mut mat: Vec<Vec<BigUint>> = used
        .iter()
        .map(|(col, val)| {
            let mut row = Vec::with_capacity(k + 1);
            row.extend(scheme.rows.iter().map(|coeffs| coeffs[*col].clone()));
            row.push(val.clone());
            row
        })
        .collect();
    let solution = gaussian_eliminate(&scheme.field, &mut mat)?;

    for (col, val) in extras {
        let expected = solution
            .iter()
            .zip(&scheme.rows)
            .fold(BigUint::zero(), |acc, (u_r, row)| {
                let term = scheme.field.mul(u_r, &row[*col]);
                scheme.field.add(&acc, &term)
            });
        if !ct_eq_biguint(&expected, val) {
            return None;
        }
    }
    Some(solution[0].clone())
}
/// Solves the augmented square system in `mat` by Gauss–Jordan elimination,
/// reducing `mat` in place.
///
/// `mat` has `k` rows. The first `k` entries of each row are coefficients and
/// entry `k` is the right-hand side; wider rows are reduced along with the rest.
/// Returns the solution vector read from column `k`. Returns `None` if some
/// column has no non-zero pivot or a pivot has no inverse. An empty `mat` yields
/// an empty solution. Rows shorter than `k + 1` entries may cause an indexing
/// panic.
// Index loops are kept: each step reads the pivot row while writing other rows
// of the same slice.
#[allow(clippy::needless_range_loop)]
fn gaussian_eliminate(field: &PrimeField, mat: &mut [Vec<BigUint>]) -> Option<Vec<BigUint>> {
    let size = mat.len();
    if size == 0 {
        return Some(vec![]);
    }
    for pivot in 0..size {
        let found = (pivot..size).find(|&row| !mat[row][pivot].is_zero())?;
        if found != pivot {
            mat.swap(found, pivot);
        }
        let inverse = field.inv(&mat[pivot][pivot])?;
        let width = mat[pivot].len();

        // Scale the pivot row so its leading entry becomes one.
        for col in pivot..width {
            mat[pivot][col] = field.mul(&mat[pivot][col], &inverse);
        }

        // Clear the pivot column from every other row.
        for row in (0..size).filter(|&row| row != pivot) {
            if mat[row][pivot].is_zero() {
                continue;
            }
            let factor = mat[row][pivot].clone();
            for col in pivot..width {
                let term = field.mul(&factor, &mat[pivot][col]);
                mat[row][col] = field.sub(&mat[row][col], &term);
            }
        }
    }
    Some(mat.iter().map(|row| row[size].clone()).collect())
}
/// Builds the Shamir-style scheme whose column `c` is `(1, x, x², …, x^{k-1})`
/// with evaluation point `x = c + 1`.
///
/// # Panics
/// Panics unless `2 ≤ k ≤ n` and `n` is below the field modulus. The modulus
/// bound keeps the points `1..=n` distinct and non-zero.
#[must_use]
pub fn vandermonde(field: PrimeField, k: usize, n: usize) -> LinearScheme {
    assert!(k >= 2 && n >= k, "need 2 ≤ k ≤ n");
    assert!(
        BigUint::from_u64(n as u64) < *field.modulus(),
        "prime modulus must exceed n"
    );
    let mut matrix: Vec<Vec<BigUint>> = (0..k).map(|_| Vec::with_capacity(n)).collect();
    for point in 1..=n {
        let x = BigUint::from_u64(point as u64);
        let mut power = BigUint::one();
        for row in matrix.iter_mut() {
            row.push(power.clone());
            power = field.mul(&power, &x);
        }
    }
    LinearScheme::new(field, matrix, k, n)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csprng::ChaCha20Rng;

    /// Deterministic RNG so every test sees the same random coefficients.
    fn rng() -> ChaCha20Rng {
        ChaCha20Rng::from_seed(&[0x4Bu8; 32])
    }

    /// Field modulo the Mersenne prime 2^61 - 1.
    fn small() -> PrimeField {
        PrimeField::new(BigUint::from_u64((1u64 << 61) - 1))
    }

    /// Field modulo the prime 65537.
    fn tiny() -> PrimeField {
        PrimeField::new(BigUint::from_u64(65_537))
    }

    /// Pairs each requested column index with its share.
    fn pick(shares: &[BigUint], cols: &[usize]) -> Vec<(usize, BigUint)> {
        cols.iter().map(|&c| (c, shares[c].clone())).collect()
    }

    /// Builds raw Vandermonde rows (`rows[r][c] = (c + 1)^r`) without going
    /// through `vandermonde`, so `new_checked` can be exercised directly.
    fn vandermonde_rows(f: &PrimeField, k: usize, n: usize) -> Vec<Vec<BigUint>> {
        let mut rows: Vec<Vec<BigUint>> = (0..k).map(|_| Vec::with_capacity(n)).collect();
        for c in 0..n {
            let x = BigUint::from_u64((c + 1) as u64);
            let mut pow = BigUint::one();
            for row in &mut rows {
                row.push(pow.clone());
                pow = f.mul(&pow, &x);
            }
        }
        rows
    }

    /// Builds a matrix from small integer literals.
    fn matrix(entries: &[&[u64]]) -> Vec<Vec<BigUint>> {
        entries
            .iter()
            .map(|row| row.iter().map(|&v| BigUint::from_u64(v)).collect())
            .collect()
    }

    #[test]
    fn vandermonde_round_trip() {
        let scheme = vandermonde(small(), 3, 5);
        let mut r = rng();
        let secret = BigUint::from_u64(0xC0FFEE);
        let shares = split(&scheme, &mut r, &secret);
        assert_eq!(shares.len(), 5);
        assert_eq!(
            reconstruct(&scheme, &pick(&shares, &[0, 1, 2])),
            Some(secret.clone())
        );
        assert_eq!(reconstruct(&scheme, &pick(&shares, &[0, 2, 4])), Some(secret));
    }

    #[test]
    fn shares_in_any_order_with_extra_reconstruct() {
        let scheme = vandermonde(small(), 3, 5);
        let mut r = rng();
        let secret = BigUint::from_u64(1234);
        let shares = split(&scheme, &mut r, &secret);
        assert_eq!(reconstruct(&scheme, &pick(&shares, &[4, 1, 3, 0])), Some(secret));
    }

    #[test]
    fn extras_validated_and_tampering_rejected() {
        let scheme = vandermonde(small(), 3, 5);
        let mut r = rng();
        let secret = BigUint::from_u64(42);
        let shares = split(&scheme, &mut r, &secret);
        let all = pick(&shares, &[0, 1, 2, 3, 4]);
        assert_eq!(reconstruct(&scheme, &all), Some(secret));
        let mut bad = all;
        bad[4].1 = scheme.field.add(&bad[4].1, &BigUint::from_u64(1));
        assert!(reconstruct(&scheme, &bad).is_none());
    }

    #[test]
    fn below_threshold_returns_none() {
        let scheme = vandermonde(small(), 4, 6);
        let mut r = rng();
        let secret = BigUint::from_u64(7);
        let shares = split(&scheme, &mut r, &secret);
        assert!(reconstruct(&scheme, &pick(&shares, &[0, 1, 2])).is_none());
    }

    #[test]
    fn empty_shares_returns_none() {
        let scheme = vandermonde(small(), 3, 5);
        assert!(reconstruct(&scheme, &[]).is_none());
    }

    #[test]
    fn duplicate_column_rejected() {
        let scheme = vandermonde(small(), 3, 5);
        let mut r = rng();
        let secret = BigUint::from_u64(11);
        let shares = split(&scheme, &mut r, &secret);
        assert!(reconstruct(&scheme, &pick(&shares, &[0, 0, 1])).is_none());
    }

    #[test]
    fn out_of_range_column_rejected() {
        let scheme = vandermonde(small(), 3, 5);
        let mut r = rng();
        let secret = BigUint::from_u64(11);
        let shares = split(&scheme, &mut r, &secret);
        let pairs = vec![
            (0, shares[0].clone()),
            (1, shares[1].clone()),
            (5, shares[0].clone()),
        ];
        assert!(reconstruct(&scheme, &pairs).is_none());
    }

    #[test]
    fn checked_constructor_accepts_vandermonde() {
        let f = tiny();
        let (k, n) = (3, 5);
        let rows = vandermonde_rows(&f, k, n);
        assert!(LinearScheme::new_checked(f, rows, k, n).is_some());
    }

    #[test]
    fn exactly_k_shares_with_one_tamper_silently_wrong() {
        let scheme = vandermonde(small(), 3, 5);
        let mut r = rng();
        let secret = BigUint::from_u64(0xC0FFEE);
        let shares = split(&scheme, &mut r, &secret);
        let mut bad = pick(&shares, &[0, 1, 2]);
        bad[0].1 = scheme.field.add(&bad[0].1, &BigUint::from_u64(1));
        let got = reconstruct(&scheme, &bad).expect("k shares always solve");
        assert_ne!(got, secret);
    }

    #[test]
    #[should_panic(expected = "secret must be < field modulus")]
    fn split_rejects_oversize_secret() {
        let scheme = vandermonde(tiny(), 3, 5);
        let mut r = rng();
        let _ = split(&scheme, &mut r, &BigUint::from_u64(70_000));
    }

    #[test]
    #[should_panic(expected = "every row must have n entries")]
    fn new_rejects_ragged_rows() {
        let _ = LinearScheme::new(tiny(), matrix(&[&[1, 2, 3], &[4, 5]]), 2, 3);
    }

    #[test]
    fn checked_constructor_rejects_repeated_column() {
        let rows = matrix(&[&[1, 1, 2], &[3, 3, 7]]);
        assert!(LinearScheme::new_checked(tiny(), rows, 2, 3).is_none());
    }

    #[test]
    fn checked_constructor_rejects_zero_column() {
        let rows = matrix(&[&[1, 0, 2], &[3, 0, 7]]);
        assert!(LinearScheme::new_checked(tiny(), rows, 2, 3).is_none());
    }
}