use crate::field::PrimeField;
use crate::poly::{horner, lagrange_eval};
use crate::bigint::BigUint;
use crate::secure::ct_eq_biguint;
/// Wire-format version byte; it is the first byte of every share.
const SHARE_VERSION: u8 = 0x02;
/// Share header size: version byte, label byte, then the data length as a big-endian `u32`.
const HEADER_LEN: usize = 1 + 1 + core::mem::size_of::<u32>();
/// Number of data bytes packed into a single field element.
///
/// Only `bits - 1` bits of the modulus are used. Assuming `bits()` is the bit
/// length of `p`, then `p >= 2^(bits-1)`, so every `block_len`-byte big-endian
/// value is strictly below the modulus.
///
/// # Panics
/// Panics if the modulus has fewer than 9 bits, i.e. a block could not hold a byte.
#[must_use]
fn block_len(field: &PrimeField) -> usize {
    let modulus_bits = field.modulus().bits();
    assert!(modulus_bits >= 9, "field too small for byte-block IDA");
    let usable_bits = modulus_bits - 1;
    usable_bits / 8
}
/// Number of bytes used to serialize one field element inside a share.
///
/// This is the modulus bit length rounded up to whole bytes.
#[must_use]
fn share_elem_len(field: &PrimeField) -> usize {
    let modulus_bits = field.modulus().bits();
    let whole_bytes = modulus_bits / 8;
    if modulus_bits % 8 == 0 {
        whole_bytes
    } else {
        whole_bytes + 1
    }
}
/// Disperses `data` into `n` shares, any `k` of which are enough to rebuild it.
///
/// The data is zero-padded to a whole number of groups of `k * block_len(field)`
/// bytes. Within a group, block `j` (read big-endian) becomes polynomial
/// coefficient `j`. Share `i` (for `i` in `1..=n`) stores that polynomial
/// evaluated at `x = i` via `horner`, encoded as `share_elem_len(field)`
/// big-endian bytes. No randomness is involved, so shares are a deterministic
/// function of `data`: this gives redundancy, not confidentiality.
///
/// Each share has the layout `[SHARE_VERSION, i, data.len() as big-endian u32, elements...]`.
///
/// # Panics
/// Panics if `k < 2`, `n < k`, `n > 255`, the modulus does not exceed `n`,
/// `data.len()` does not fit in a `u32`, or the modulus has fewer than 9 bits.
#[must_use]
pub fn split(field: &PrimeField, data: &[u8], k: usize, n: usize) -> Vec<Vec<u8>> {
    assert!(k >= 2, "k must be at least 2");
    assert!(n >= k, "n must be at least k");
    assert!(n <= 255, "byte-encoded shares support up to 255 trustees");
    assert!(
        BigUint::from_u64(n as u64) < *field.modulus(),
        "prime modulus must exceed n",
    );
    assert!(
        data.len() <= u32::MAX as usize,
        "data length must fit in u32 (wire-format length header is 4 bytes)",
    );
    // Lossless: bounded by the assertion above.
    let len_header = (data.len() as u32).to_be_bytes();

    let bl = block_len(field);
    let sl = share_elem_len(field);
    let group = k * bl;
    let num_groups = data.len().div_ceil(group);
    let mut padded = Vec::with_capacity(num_groups * group);
    padded.extend_from_slice(data);
    padded.resize(num_groups * group, 0);

    // The evaluation points x = 1..=n are the same for every group, so build
    // them once instead of once per (group, share) pair.
    let xs: Vec<BigUint> = (1..=n as u64).map(BigUint::from_u64).collect();

    let mut shares: Vec<Vec<u8>> = (1..=n)
        .map(|label| {
            let mut share = Vec::with_capacity(HEADER_LEN + num_groups * sl);
            share.push(SHARE_VERSION);
            share.push(label as u8); // lossless: n <= 255 asserted above
            share.extend_from_slice(&len_header);
            share
        })
        .collect();

    for chunk in padded.chunks_exact(group) {
        // Block j of the group becomes coefficient j of the polynomial.
        let coeffs: Vec<BigUint> = chunk.chunks_exact(bl).map(BigUint::from_be_bytes).collect();
        for (share, x) in shares.iter_mut().zip(&xs) {
            let y = horner(field, &coeffs, x);
            share.extend_from_slice(&field_element_to_bytes(&y, sl));
        }
    }
    shares
}
/// Rebuilds the original data from at least `k` shares produced by [`split`].
///
/// The first `k` shares (in slice order) determine each group's polynomial.
/// Every additional share is checked against that polynomial, so extra shares
/// turn corruption into a `None` result. This is detection only; there is no
/// error correction.
///
/// Returns `None` when any of the following holds:
/// - `k < 2`, or fewer than `k` shares are given.
/// - A share has a short header, an unknown version byte, a zero or duplicate
///   label, a payload that is not a whole number of field elements, or a length
///   header that disagrees with the other shares.
/// - The payload size does not match the header's data length for this `k`.
/// - An encoded element is not below the modulus.
/// - An extra share does not lie on the polynomial from the first `k` shares.
/// - A recovered coefficient does not fit in a data block.
/// - The padding bytes after the declared data length are not all zero.
///
/// # Panics
/// Panics (via `block_len`) if the modulus has fewer than 9 bits.
#[must_use]
pub fn reconstruct(field: &PrimeField, shares: &[&[u8]], k: usize) -> Option<Vec<u8>> {
    if k < 2 || shares.len() < k {
        return None;
    }
    let bl = block_len(field);
    let sl = share_elem_len(field);
    let modulus = field.modulus();

    // Parse and validate headers. Labels are a single byte, so a 256-entry
    // table detects duplicates in one pass instead of comparing every pair.
    let mut seen = [false; 256];
    let mut parsed: Vec<(u8, &[u8])> = Vec::with_capacity(shares.len());
    let mut data_len: Option<usize> = None;
    for s in shares {
        if s.len() < HEADER_LEN || s[0] != SHARE_VERSION {
            return None;
        }
        let label = s[1];
        // split() labels shares 1..=n; x = 0 is never an evaluation point.
        if label == 0 {
            return None;
        }
        let slot = &mut seen[usize::from(label)];
        if *slot {
            return None;
        }
        *slot = true;
        let len = usize::try_from(u32::from_be_bytes([s[2], s[3], s[4], s[5]])).ok()?;
        match data_len {
            Some(prev) if prev != len => return None,
            Some(_) => {}
            None => data_len = Some(len),
        }
        let payload = &s[HEADER_LEN..];
        if payload.len() % sl != 0 {
            return None;
        }
        parsed.push((label, payload));
    }
    let data_len = data_len?;

    // Recompute the padded layout from the header. The header is untrusted, so
    // use checked arithmetic: on 32-bit targets `data_len + pad` or
    // `num_groups * sl` could otherwise wrap and accept a bogus length.
    let group = k.checked_mul(bl)?;
    let pad = (group - data_len % group) % group;
    let padded_len = data_len.checked_add(pad)?;
    let num_groups = padded_len / group;
    let payload_len = num_groups.checked_mul(sl)?;
    if parsed.iter().any(|(_, payload)| payload.len() != payload_len) {
        return None;
    }

    let (basis, extras) = parsed.split_at(k);
    let mut out = Vec::with_capacity(padded_len);
    for g in 0..num_groups {
        let span = g * sl..(g + 1) * sl;
        let mut pts: Vec<(BigUint, BigUint)> = Vec::with_capacity(k);
        for (label, payload) in basis {
            let x = BigUint::from_u64(u64::from(*label));
            let y = BigUint::from_be_bytes(&payload[span.clone()]);
            // Reject non-canonical encodings (values >= p).
            if y >= *modulus {
                return None;
            }
            pts.push((x, y));
        }
        let coeffs = vandermonde_solve(field, &pts)?;
        // Every extra share must lie on the polynomial fixed by the first k.
        for (label, payload) in extras {
            let x = BigUint::from_u64(u64::from(*label));
            let y = BigUint::from_be_bytes(&payload[span.clone()]);
            if y >= *modulus {
                return None;
            }
            let pred = horner(field, &coeffs, &x);
            if !ct_eq_biguint(&pred, &y) {
                return None;
            }
        }
        for c in &coeffs {
            // A coefficient that does not fit in `bl` bytes cannot be a data block.
            out.extend_from_slice(&field_element_to_bytes_checked(c, bl)?);
        }
    }

    // split() zero-fills the tail of the final group. Non-zero bytes here mean
    // the payload or the (unauthenticated) length header was altered, e.g. a
    // header rewritten to silently truncate the data.
    if out.get(data_len..)?.iter().any(|&b| b != 0) {
        return None;
    }
    out.truncate(data_len);
    Some(out)
}
/// Solves the `k x k` Vandermonde system defined by `points` over `field`.
///
/// Row `i` of the augmented matrix is `[1, x_i, x_i^2, ..., x_i^(k-1) | y_i]`.
/// The returned vector therefore holds the coefficients in ascending order of
/// degree. The solver uses Gauss-Jordan elimination. It returns `None` if some
/// column has no non-zero pivot, or if `field.inv` rejects a pivot.
// Gauss-Jordan is expressed in row/column positions; index loops are deliberate.
#[allow(clippy::needless_range_loop)]
fn vandermonde_solve(
    field: &PrimeField,
    points: &[(BigUint, BigUint)],
) -> Option<Vec<BigUint>> {
    let k = points.len();

    // Augmented matrix, one row per point: successive powers of x, then y.
    let mut mat: Vec<Vec<BigUint>> = points
        .iter()
        .map(|(x, y)| {
            let mut row = Vec::with_capacity(k + 1);
            let mut power = BigUint::one();
            for _ in 0..k {
                let next = field.mul(&power, x);
                row.push(power);
                power = next;
            }
            row.push(y.clone());
            row
        })
        .collect();

    for col in 0..k {
        // First row at or below the diagonal with a non-zero entry in this column.
        let pivot = (col..k).find(|&r| !mat[r][col].is_zero())?;
        mat.swap(pivot, col);

        // Scale the pivot row so the diagonal entry becomes 1.
        let inv = field.inv(&mat[col][col])?;
        for cell in &mut mat[col][col..] {
            *cell = field.mul(cell, &inv);
        }

        // Clear this column in every other row using the normalized pivot row.
        let pivot_row = mat[col].clone();
        for (r, row) in mat.iter_mut().enumerate() {
            if r == col || row[col].is_zero() {
                continue;
            }
            let factor = row[col].clone();
            for (cell, p) in row[col..].iter_mut().zip(&pivot_row[col..]) {
                let term = field.mul(&factor, p);
                *cell = field.sub(cell, &term);
            }
        }
    }

    // The augmented (last) column now holds the solution.
    mat.into_iter().map(|mut row| row.pop()).collect()
}
fn field_element_to_bytes(value: &BigUint, width: usize) -> Vec<u8> {
let mut be = value.to_be_bytes();
if be.len() < width {
let mut padded = vec![0u8; width - be.len()];
padded.append(&mut be);
padded
} else if be.len() == width {
be
} else {
let extra = be.len() - width;
assert!(
be[..extra].iter().all(|&b| b == 0),
"field element exceeds requested encoding width",
);
be[extra..].to_vec()
}
}
/// Encodes `value` as exactly `width` big-endian bytes. Returns `None` if it does not fit.
///
/// Shorter encodings are left-padded with zeros. Longer encodings are accepted
/// only when their surplus leading bytes are all zero, in which case those
/// bytes are dropped.
fn field_element_to_bytes_checked(value: &BigUint, width: usize) -> Option<Vec<u8>> {
    let raw = value.to_be_bytes();
    let Some(surplus) = raw.len().checked_sub(width) else {
        // Too short: prepend zero bytes up to `width`.
        let mut out = vec![0u8; width - raw.len()];
        out.extend_from_slice(&raw);
        return Some(out);
    };
    let (leading, trailing) = raw.split_at(surplus);
    leading.iter().all(|&b| b == 0).then(|| trailing.to_vec())
}
// Nothing in this module calls this function. Judging by its name, it keeps the
// `lagrange_eval` import referenced and type-checked against `PrimeField` /
// `BigUint`, which is why `dead_code` is allowed.
#[allow(dead_code)]
fn _ensure_lagrange_path_compiles(field: &PrimeField, pts: &[(BigUint, BigUint)]) -> Option<BigUint> {
    let origin = BigUint::zero();
    lagrange_eval(field, pts, &origin)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::field::mersenne127;

    /// Field shared by every test, built from `mersenne127()`.
    fn f() -> PrimeField {
        PrimeField::new(mersenne127())
    }

    /// Borrows each owned share as a slice, the form `reconstruct` accepts.
    fn as_refs(shares: &[Vec<u8>]) -> Vec<&[u8]> {
        shares.iter().map(Vec::as_slice).collect()
    }

    #[test]
    fn round_trip_short() {
        let field = f();
        let data = b"information dispersal - quick brown fox".to_vec();
        let shares = split(&field, &data, 3, 6);
        assert_eq!(shares.len(), 6);
        let refs = as_refs(&shares);
        assert_eq!(reconstruct(&field, &refs[..3], 3).unwrap(), data);
        assert_eq!(reconstruct(&field, &refs[2..5], 3).unwrap(), data);
    }

    #[test]
    fn round_trip_long() {
        let field = f();
        let data: Vec<u8> = (0..1024u32).map(|i| (i & 0xFF) as u8).collect();
        let shares = split(&field, &data, 5, 9);
        let refs = as_refs(&shares);
        assert_eq!(reconstruct(&field, &refs[..5], 5).unwrap(), data);
        let picked: Vec<&[u8]> = vec![refs[0], refs[2], refs[4], refs[6], refs[8]];
        assert_eq!(reconstruct(&field, &picked, 5).unwrap(), data);
    }

    #[test]
    fn round_trip_with_extra_shares() {
        // More than k shares: the extras are cross-checked and must all agree.
        let field = f();
        let data = b"extra shares are cross-checked".to_vec();
        let shares = split(&field, &data, 3, 6);
        assert_eq!(reconstruct(&field, &as_refs(&shares), 3).unwrap(), data);
    }

    #[test]
    fn share_size_is_data_over_k() {
        let field = f();
        let data: Vec<u8> = (0..1500u32).map(|i| (i & 0xFF) as u8).collect();
        let k = 3;
        let bl = block_len(&field);
        let sl = share_elem_len(&field);
        let shares = split(&field, &data, k, 5);
        let group = k * bl;
        let num_groups = data.len().div_ceil(group);
        let expected_payload = num_groups * sl;
        assert_eq!(shares[0].len() - HEADER_LEN, expected_payload);
        let ratio = sl as f64 / bl as f64;
        let lower = (data.len() as f64) / (k as f64) * ratio;
        let upper = lower + sl as f64;
        assert!(
            expected_payload as f64 >= lower && (expected_payload as f64) < upper,
            "payload {} not in [{}, {})",
            expected_payload,
            lower,
            upper
        );
    }

    #[test]
    fn round_trip_empty() {
        let field = f();
        let data: Vec<u8> = Vec::new();
        let shares = split(&field, &data, 2, 3);
        let refs = as_refs(&shares);
        assert_eq!(reconstruct(&field, &refs[..2], 2).unwrap(), data);
    }

    #[test]
    fn corrupted_extra_rejected() {
        let field = f();
        let data = b"please don't corrupt me".to_vec();
        let mut shares = split(&field, &data, 3, 5);
        shares[4][HEADER_LEN] ^= 0x01;
        let refs = as_refs(&shares);
        assert!(reconstruct(&field, &refs, 3).is_none());
    }

    #[test]
    fn below_threshold_returns_none() {
        let field = f();
        let data = b"need three shares".to_vec();
        let shares = split(&field, &data, 3, 5);
        let refs = as_refs(&shares);
        assert!(reconstruct(&field, &refs[..2], 3).is_none());
    }

    #[test]
    fn k_below_two_returns_none() {
        let field = f();
        let shares = split(&field, b"k must be >= 2", 2, 3);
        assert!(reconstruct(&field, &as_refs(&shares), 1).is_none());
    }

    #[test]
    fn malformed_version_rejected() {
        let field = f();
        let data = b"x".to_vec();
        let mut shares = split(&field, &data, 2, 3);
        shares[0][0] = 0xFF;
        let refs = as_refs(&shares);
        assert!(reconstruct(&field, &refs[..2], 2).is_none());
    }

    #[test]
    fn zero_label_rejected() {
        let field = f();
        let mut shares = split(&field, b"zero label", 2, 3);
        shares[0][1] = 0;
        assert!(reconstruct(&field, &as_refs(&shares)[..2], 2).is_none());
    }

    #[test]
    fn inconsistent_length_header_rejected() {
        let field = f();
        let mut shares = split(&field, b"length header", 2, 3);
        // Byte 5 is the low byte of the big-endian u32 length header.
        shares[1][5] ^= 0x01;
        assert!(reconstruct(&field, &as_refs(&shares)[..2], 2).is_none());
    }

    #[test]
    fn partial_element_rejected() {
        let field = f();
        let mut shares = split(&field, b"truncate me", 2, 3);
        shares[0].pop();
        assert!(reconstruct(&field, &as_refs(&shares)[..2], 2).is_none());
    }

    #[test]
    fn non_canonical_element_rejected() {
        let field = f();
        let mut shares = split(&field, b"out of range", 2, 3);
        let sl = share_elem_len(&field);
        // All-0xFF spans 8*sl >= bits bits, so it is >= the modulus.
        shares[0][HEADER_LEN..HEADER_LEN + sl].fill(0xFF);
        assert!(reconstruct(&field, &as_refs(&shares)[..2], 2).is_none());
    }

    #[test]
    fn first_k_tamper_does_not_panic() {
        let field = f();
        let data: Vec<u8> = (0..120u8).collect();
        let shares = split(&field, &data, 3, 5);
        let mut bad = shares.clone();
        for offset in 0..16 {
            bad[0][HEADER_LEN + offset] = bad[0][HEADER_LEN + offset].wrapping_add(0x37);
        }
        let refs = as_refs(&bad);
        let _ = reconstruct(&field, &refs[..3], 3);
    }

    #[test]
    fn duplicate_label_rejected() {
        let field = f();
        let data = b"hi".to_vec();
        let mut shares = split(&field, &data, 2, 3);
        shares[1][1] = shares[0][1];
        let refs = as_refs(&shares);
        assert!(reconstruct(&field, &refs[..2], 2).is_none());
    }
}