use crate::bigint::BigUint;
use crate::csprng::Csprng;
use crate::field::PrimeField;
use crate::secure::ct_eq_biguint;
/// A monotone span program (MSP) over a prime field, used as the access structure for
/// linear secret sharing.
///
/// A coalition of players qualifies exactly when the rows it owns span the target vector
/// `e_1 = (1, 0, …, 0)` (see `recovery_coefficients`).
#[derive(Clone, Debug)]
pub struct SpanProgram {
    /// Field in which all matrix entries live and all arithmetic is done.
    field: PrimeField,
    /// The `d × m` matrix, one `Vec` per row; `new` enforces that every row has width `m`.
    rows: Vec<Vec<BigUint>>,
    /// `labels[j]` is the 1-based player that owns row `j`.
    labels: Vec<usize>,
    /// Largest row label (presumably the number of players when labels are `1..=n`).
    n: usize,
    /// Row width (number of columns of the matrix).
    m: usize,
}
impl SpanProgram {
    /// Builds a monotone span program from its matrix rows and per-row owner labels.
    ///
    /// `rows[j]` is row `j` of the `d × m` matrix (`d = rows.len()`), and `labels[j]` is the
    /// 1-based player who owns that row. `n` is taken to be the largest label.
    ///
    /// # Panics
    ///
    /// Panics if `rows` is empty, if `labels.len() != rows.len()`, if the row width `m` is 0,
    /// if rows have differing widths, if any entry is not a reduced field element
    /// (`entry >= field.modulus()`), or if any label is 0.
    #[must_use]
    pub fn new(field: PrimeField, rows: Vec<Vec<BigUint>>, labels: Vec<usize>) -> Self {
        assert!(!rows.is_empty(), "MSP must have at least one row");
        assert_eq!(rows.len(), labels.len(), "labels.len() must match rows.len()");
        let m = rows[0].len();
        assert!(
            m > 0,
            "MSP rows must have width m ≥ 1 — m = 0 is a degenerate \
             access structure with no target vector to span",
        );
        for r in &rows {
            assert_eq!(r.len(), m, "all rows must have the same width");
            // `solve_least_constraint` picks pivots with a raw `is_zero()` test, so an
            // unreduced entry that is ≡ 0 (mod p), e.g. the modulus itself, would be taken
            // as a non-zero pivot and break recovery. Require canonical representatives,
            // mirroring the `secret < modulus` precondition enforced by `split`.
            assert!(
                r.iter().all(|v| v < field.modulus()),
                "MSP entries must be < field modulus"
            );
        }
        for &lbl in &labels {
            assert!(lbl != 0, "labels are 1-based; 0 is not a valid player");
        }
        // `labels` has as many entries as `rows`, which is non-empty, so the fallback is unused.
        let n = labels.iter().copied().max().unwrap_or(0);
        Self {
            field,
            rows,
            labels,
            n,
            m,
        }
    }

    /// Largest player label appearing in the program.
    #[must_use]
    pub fn n(&self) -> usize {
        self.n
    }

    /// Number of rows `d` of the matrix.
    #[must_use]
    pub fn d(&self) -> usize {
        self.rows.len()
    }

    /// Row width `m` (number of columns, i.e. length of the sharing vector).
    #[must_use]
    pub fn m(&self) -> usize {
        self.m
    }

    /// The prime field the program is defined over.
    #[must_use]
    pub fn field(&self) -> &PrimeField {
        &self.field
    }

    /// Returns `true` when the rows owned by `coalition` span `e_1 = (1, 0, …, 0)`,
    /// i.e. when that coalition can reconstruct the secret.
    #[must_use]
    pub fn qualifies(&self, coalition: &[usize]) -> bool {
        self.recovery_coefficients(coalition).is_some()
    }

    /// Computes a recovery vector for `coalition`.
    ///
    /// Collects the rows owned by players listed in `coalition` and solves
    /// `Σ_k λ_k · rows[j_k] = e_1` for `λ`. Returns one `(row index, λ)` pair per owned row
    /// (some coefficients may be zero). Returns `None` when the coalition owns no rows or
    /// `e_1` is not in the span of its rows.
    #[must_use]
    fn recovery_coefficients(&self, coalition: &[usize]) -> Option<Vec<(usize, BigUint)>> {
        let row_indices: Vec<usize> = (0..self.d())
            .filter(|&j| coalition.contains(&self.labels[j]))
            .collect();
        if row_indices.is_empty() {
            return None;
        }
        let cols = row_indices.len();
        // Augmented system M_Aᵀ · λ = e_1: equation `i` takes column `i` of every selected
        // row as coefficients, with right-hand side 1 for `i == 0` and 0 otherwise.
        let mut mat: Vec<Vec<BigUint>> = (0..self.m)
            .map(|i| {
                let mut row = Vec::with_capacity(cols + 1);
                for &j in &row_indices {
                    row.push(self.rows[j][i].clone());
                }
                row.push(if i == 0 {
                    BigUint::one()
                } else {
                    BigUint::zero()
                });
                row
            })
            .collect();
        let coeffs = solve_least_constraint(&self.field, &mut mat, cols)?;
        // Defensive re-check: the combination must reproduce e_1 exactly.
        for i in 0..self.m {
            let mut acc = BigUint::zero();
            for (k, &j) in row_indices.iter().enumerate() {
                let term = self.field.mul(&coeffs[k], &self.rows[j][i]);
                acc = self.field.add(&acc, &term);
            }
            let want = if i == 0 {
                BigUint::one()
            } else {
                BigUint::zero()
            };
            if !ct_eq_biguint(&acc, &want) {
                return None;
            }
        }
        Some(row_indices.into_iter().zip(coeffs).collect())
    }
}
/// One player's share: the values of the MSP rows that player owns.
///
/// `Debug` is deliberately not derived; the manual impl prints no fragment values.
#[derive(Clone, Eq, PartialEq)]
pub struct PlayerShare {
    /// 1-based player label (matches the row labels of the `SpanProgram`).
    pub player: usize,
    /// `(row index, row value)` pairs. In shares produced by `split`, each value is that row's
    /// inner product with the sharing vector `(secret, r_2, …, r_m)`.
    pub fragments: Vec<(usize, BigUint)>,
}
impl core::fmt::Debug for PlayerShare {
    /// Prints a fixed placeholder and never the player or fragment values. The fragments
    /// are derived from the secret, so this presumably keeps shares out of logs.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "PlayerShare(<elided>)")
    }
}
/// Splits `secret` into per-player shares according to `program`.
///
/// Builds the sharing vector `ρ = (secret, r_2, …, r_m)`, where each `r_i` is drawn with
/// `field.random(rng)`. Every row `j` then receives the value `⟨rows[j], ρ⟩`. Row values are
/// grouped by owning player. The returned shares are ordered by ascending player label, and
/// each share lists its rows in ascending row order.
///
/// # Panics
///
/// Panics if `secret` is not smaller than the field modulus.
#[must_use]
pub fn split<R: Csprng>(
    program: &SpanProgram,
    rng: &mut R,
    secret: &BigUint,
) -> Vec<PlayerShare> {
    let field = &program.field;
    assert!(secret < field.modulus(), "secret must be < field modulus");

    // ρ[0] is the secret; the remaining m - 1 coordinates are fresh randomness.
    let rho: Vec<BigUint> = std::iter::once(secret.clone())
        .chain((1..program.m).map(|_| field.random(rng)))
        .collect();

    let mut grouped: std::collections::BTreeMap<usize, Vec<(usize, BigUint)>> =
        std::collections::BTreeMap::new();
    for (row_idx, (coeffs, &owner)) in program.rows.iter().zip(&program.labels).enumerate() {
        let value = coeffs
            .iter()
            .zip(&rho)
            .fold(BigUint::zero(), |acc, (a, b)| field.add(&acc, &field.mul(a, b)));
        grouped.entry(owner).or_default().push((row_idx, value));
    }

    grouped
        .into_iter()
        .map(|(player, fragments)| PlayerShare { player, fragments })
        .collect()
}
#[must_use]
pub fn reconstruct(program: &SpanProgram, shares: &[PlayerShare]) -> Option<BigUint> {
for i in 0..shares.len() {
for j in (i + 1)..shares.len() {
if shares[i].player == shares[j].player {
return None;
}
}
}
for s in shares {
for (j, _) in &s.fragments {
if *j >= program.d() {
return None;
}
if program.labels[*j] != s.player {
return None;
}
}
}
let coalition: Vec<usize> = shares.iter().map(|s| s.player).collect();
let coeffs = program.recovery_coefficients(&coalition)?;
let mut value_by_row: std::collections::HashMap<usize, &BigUint> =
std::collections::HashMap::new();
for s in shares {
for (j, v) in &s.fragments {
if let Some(prev) = value_by_row.get(j) {
if !crate::secure::ct_eq_biguint(prev, v) {
return None;
}
continue;
}
value_by_row.insert(*j, v);
}
}
let mut secret = BigUint::zero();
for (j, c) in &coeffs {
let val = value_by_row.get(j)?;
let term = program.field.mul(c, val);
secret = program.field.add(&secret, &term);
}
Some(secret)
}
/// Solves the augmented linear system held in `mat` over `field` by Gauss–Jordan elimination.
///
/// `mat` has one row per equation. Every row must have `cols + 1` entries: columns `0..cols`
/// hold the coefficients of the `cols` unknowns, and column `cols` holds the right-hand side.
/// `mat` is reduced in place to reduced row-echelon form.
///
/// Returns one solution, with every free (non-pivot) unknown set to zero. Returns `None` if
/// the system is inconsistent or `field.inv` fails on a pivot.
// Index loops are kept on purpose: each elimination step reads the pivot row while writing
// other rows of the same matrix, which plain iterator borrows cannot express without copying.
#[allow(clippy::needless_range_loop)]
fn solve_least_constraint(
    field: &PrimeField,
    mat: &mut [Vec<BigUint>],
    cols: usize,
) -> Option<Vec<BigUint>> {
    let m = mat.len();
    let aug = cols;
    // `pivots[r]` is the column of the leading 1 in row `r`; `pivots.len()` is the current rank.
    let mut pivots: Vec<usize> = Vec::with_capacity(m.min(cols));
    for col in 0..cols {
        let row = pivots.len();
        if row >= m {
            break;
        }
        let Some(pr) = (row..m).find(|&r| !mat[r][col].is_zero()) else {
            // No pivot in this column: the unknown is free.
            continue;
        };
        if pr != row {
            mat.swap(pr, row);
        }
        let inv = field.inv(&mat[row][col])?;
        // Entries left of `col` in the pivot row are already zero, so start at `col`.
        for c in col..=aug {
            mat[row][c] = field.mul(&mat[row][c], &inv);
        }
        for r in 0..m {
            if r == row || mat[r][col].is_zero() {
                continue;
            }
            let factor = mat[r][col].clone();
            for c in col..=aug {
                let term = field.mul(&factor, &mat[row][c]);
                mat[r][c] = field.sub(&mat[r][c], &term);
            }
        }
        pivots.push(col);
    }

    // A row reading `0 = non-zero` means the system has no solution.
    let inconsistent = mat
        .iter()
        .any(|r| r[..cols].iter().all(|v| v.is_zero()) && !r[aug].is_zero());
    if inconsistent {
        return None;
    }

    // With free unknowns fixed at zero, each pivot unknown equals its row's right-hand side.
    let mut solution = vec![BigUint::zero(); cols];
    for (r, &pcol) in pivots.iter().enumerate() {
        solution[pcol] = mat[r][aug].clone();
    }
    Some(solution)
}
/// Builds the Shamir-style `k`-of-`n` threshold span program.
///
/// Player `x` (for `x` in `1..=n`) owns a single row `(1, x, x², …, x^{k-1})`, computed in
/// `field`. Any `k` such rows form an invertible Vandermonde system, while fewer than `k`
/// cannot span `e_1`.
///
/// # Panics
///
/// Panics unless `2 ≤ k ≤ n` and the field modulus exceeds `n`. The modulus bound keeps the
/// evaluation points `1..=n` distinct and non-zero.
#[must_use]
pub fn threshold_msp(field: PrimeField, k: usize, n: usize) -> SpanProgram {
    assert!(k >= 2 && n >= k, "need 2 ≤ k ≤ n");
    assert!(
        BigUint::from_u64(n as u64) < *field.modulus(),
        "modulus must exceed n"
    );

    let (rows, labels): (Vec<Vec<BigUint>>, Vec<usize>) = (1..=n)
        .map(|player| {
            let x = BigUint::from_u64(player as u64);
            let powers: Vec<BigUint> =
                std::iter::successors(Some(BigUint::one()), |prev| Some(field.mul(prev, &x)))
                    .take(k)
                    .collect();
            (powers, player)
        })
        .unzip();

    SpanProgram::new(field, rows, labels)
}
#[cfg(test)]
mod tests {
    // Unit tests for MSP construction, sharing, qualification and reconstruction.
    use super::*;
    use crate::csprng::ChaCha20Rng;
    // Deterministic RNG (fixed seed), so every run uses the same sharing randomness.
    fn rng() -> ChaCha20Rng {
        ChaCha20Rng::from_seed(&[0x9Cu8; 32])
    }
    // Field modulo the Mersenne prime 2^61 − 1.
    fn small() -> PrimeField {
        PrimeField::new(BigUint::from_u64((1u64 << 61) - 1))
    }
    // Selects the shares of the players listed in `wanted`, preserving the order of `shares`.
    fn pick(shares: &[PlayerShare], wanted: &[usize]) -> Vec<PlayerShare> {
        shares
            .iter()
            .filter(|s| wanted.contains(&s.player))
            .cloned()
            .collect()
    }
    // 3-of-5 threshold: the listed 3-subsets recover the secret; the listed 2-subsets fail.
    #[test]
    fn threshold_msp_round_trip() {
        let prog = threshold_msp(small(), 3, 5);
        let mut r = rng();
        let secret = BigUint::from_u64(0xC0FFEE);
        let shares = split(&prog, &mut r, &secret);
        assert_eq!(shares.len(), 5);
        for &(a, b, c) in &[(1usize, 2, 3), (1, 3, 5), (2, 4, 5), (3, 4, 5)] {
            let coalition = pick(&shares, &[a, b, c]);
            assert_eq!(reconstruct(&prog, &coalition), Some(secret.clone()), "subset ({a},{b},{c})");
        }
        for &(a, b) in &[(1usize, 2), (3, 5), (4, 5)] {
            let coalition = pick(&shares, &[a, b]);
            assert!(reconstruct(&prog, &coalition).is_none(), "subset ({a},{b}) must fail");
        }
    }
    // `qualifies` accepts 3-subsets and rejects a 2-subset and the empty coalition.
    #[test]
    fn qualifies_matches_reconstruct() {
        let prog = threshold_msp(small(), 3, 5);
        assert!(prog.qualifies(&[1, 2, 3]));
        assert!(prog.qualifies(&[1, 4, 5]));
        assert!(!prog.qualifies(&[1, 2]));
        assert!(!prog.qualifies(&[]));
    }
    // OR gate: both players own the 1-wide row (1), so either one alone spans e_1.
    #[test]
    fn explicit_or_msp() {
        let f = small();
        let rows = vec![vec![BigUint::one()], vec![BigUint::one()]];
        let labels = vec![1usize, 2];
        let prog = SpanProgram::new(f, rows, labels);
        assert_eq!(prog.m(), 1);
        assert!(prog.qualifies(&[1]));
        assert!(prog.qualifies(&[2]));
        assert!(prog.qualifies(&[1, 2]));
        let mut r = rng();
        let secret = BigUint::from_u64(7);
        let shares = split(&prog, &mut r, &secret);
        assert_eq!(reconstruct(&prog, &pick(&shares, &[1])), Some(secret.clone()));
        assert_eq!(reconstruct(&prog, &pick(&shares, &[2])), Some(secret));
    }
    // AND gate: rows (1, 1) and (0, −1) sum to e_1, but neither row alone spans it.
    #[test]
    fn explicit_and_msp() {
        let f = small();
        let neg_one = f.sub(&BigUint::zero(), &BigUint::one());
        let rows = vec![
            vec![BigUint::one(), BigUint::one()],
            vec![BigUint::zero(), neg_one],
        ];
        let labels = vec![1usize, 2];
        let prog = SpanProgram::new(f.clone(), rows, labels);
        assert!(!prog.qualifies(&[1]));
        assert!(!prog.qualifies(&[2]));
        assert!(prog.qualifies(&[1, 2]));
        let mut r = rng();
        let secret = BigUint::from_u64(100);
        let shares = split(&prog, &mut r, &secret);
        let both = pick(&shares, &[1, 2]);
        assert_eq!(reconstruct(&prog, &both), Some(secret));
        assert!(reconstruct(&prog, &pick(&shares, &[1])).is_none());
        assert!(reconstruct(&prog, &pick(&shares, &[2])).is_none());
    }
    // Submitting the same player's share twice must be rejected.
    #[test]
    fn duplicate_player_rejected() {
        let prog = threshold_msp(small(), 3, 5);
        let mut r = rng();
        let secret = BigUint::from_u64(11);
        let shares = split(&prog, &mut r, &secret);
        let dup = vec![shares[0].clone(), shares[0].clone(), shares[1].clone()];
        assert!(reconstruct(&prog, &dup).is_none());
    }
    // A share carrying a fragment for a row owned by another player must be rejected.
    #[test]
    fn malformed_fragment_rejected() {
        let prog = threshold_msp(small(), 3, 5);
        let mut r = rng();
        let secret = BigUint::from_u64(22);
        let mut shares = split(&prog, &mut r, &secret);
        let row_idx_for_player_2 = (0..prog.d()).find(|&j| prog.labels[j] == 2).unwrap();
        shares[0]
            .fragments
            .push((row_idx_for_player_2, BigUint::zero()));
        let coalition = vec![shares[0].clone(), shares[1].clone(), shares[2].clone()];
        assert!(reconstruct(&prog, &coalition).is_none());
    }
    // `split` rejects a secret that is not below the modulus (300 ≥ 257).
    #[test]
    #[should_panic(expected = "secret must be < field modulus")]
    fn split_rejects_oversize_secret() {
        let prog = threshold_msp(PrimeField::new(BigUint::from_u64(257)), 2, 3);
        let mut r = rng();
        let _ = split(&prog, &mut r, &BigUint::from_u64(300));
    }
    // 2-of-4: for every non-empty subset of players, `qualifies` agrees with `reconstruct`,
    // and every successful reconstruction yields the secret.
    #[test]
    fn qualifies_consistent_with_reconstruct() {
        let prog = threshold_msp(small(), 2, 4);
        let mut r = rng();
        let secret = BigUint::from_u64(0xDEAD);
        let shares = split(&prog, &mut r, &secret);
        // `mask` enumerates all 15 non-empty subsets of players {1, 2, 3, 4}.
        for mask in 1u8..(1 << 4) {
            let coalition: Vec<usize> = (0..4)
                .filter(|i| mask & (1 << i) != 0)
                .map(|i| i + 1)
                .collect();
            let shares_for: Vec<PlayerShare> = pick(&shares, &coalition);
            let q = prog.qualifies(&coalition);
            // This `r` shadows the RNG for the rest of the loop body.
            let r = reconstruct(&prog, &shares_for);
            assert_eq!(
                q,
                r.is_some(),
                "qualifies and reconstruct disagree for {coalition:?}",
            );
            if q {
                assert_eq!(r, Some(secret.clone()), "wrong secret for {coalition:?}");
            }
        }
    }
    // Player 1 owns e_1; players 2 and 3 own e_2 and e_3; player 4 owns (1, 1, 1).
    // A coalition qualifies iff it contains player 1 or all of players 2, 3 and 4.
    #[test]
    fn unqualified_coalition_returns_none_in_oversized_program() {
        let f = small();
        let rows = vec![
            vec![BigUint::one(), BigUint::zero(), BigUint::zero()],
            vec![BigUint::zero(), BigUint::one(), BigUint::zero()],
            vec![BigUint::zero(), BigUint::zero(), BigUint::one()],
            vec![BigUint::one(), BigUint::one(), BigUint::one()],
        ];
        let labels = vec![1, 2, 3, 4];
        let prog = SpanProgram::new(f, rows, labels);
        let mut r = rng();
        let secret = BigUint::from_u64(99);
        let shares = split(&prog, &mut r, &secret);
        for unq in &[vec![2usize], vec![3], vec![4], vec![2, 3], vec![2, 4], vec![3, 4]] {
            assert!(reconstruct(&prog, &pick(&shares, unq)).is_none(), "{unq:?} must fail");
        }
        for q in &[vec![1usize], vec![2, 3, 4], vec![1, 2, 3, 4]] {
            assert_eq!(
                reconstruct(&prog, &pick(&shares, q)),
                Some(secret.clone()),
                "{q:?} must succeed",
            );
        }
    }
    // A program whose rows have width 0 is rejected at construction.
    #[test]
    #[should_panic(expected = "MSP rows must have width m ≥ 1")]
    fn rejects_zero_width_msp() {
        let f = small();
        let rows: Vec<Vec<BigUint>> = vec![vec![]];
        let labels = vec![1usize];
        let _ = SpanProgram::new(f, rows, labels);
    }
    // Several (k, n) shapes and seeds: the first k players always recover the secret.
    #[test]
    fn fuzz_threshold_round_trip() {
        for &(k, n) in &[(2usize, 3usize), (3, 5), (4, 7), (5, 9)] {
            for seed in 0u8..6 {
                let prog = threshold_msp(small(), k, n);
                let mut r = ChaCha20Rng::from_seed(&[seed; 32]);
                let secret = BigUint::from_u64(seed as u64 * 12345);
                let shares = split(&prog, &mut r, &secret);
                let chosen: Vec<usize> = (1..=k).collect();
                let pick_first_k: Vec<PlayerShare> = pick(&shares, &chosen);
                assert_eq!(reconstruct(&prog, &pick_first_k), Some(secret));
            }
        }
    }
}