use crate::field::PrimeField;
use crate::poly::{horner, lagrange_eval};
use crate::bigint::BigUint;
use crate::csprng::Csprng;
use crate::secure::ct_eq_biguint;
/// Wire-format version byte written as the first byte of every share by `split`
/// and required by `reconstruct`.
const SHARE_VERSION: u8 = 0x01;
/// Share header size in bytes: version (1) + label, i.e. the share's
/// x-coordinate (1) + big-endian `u32` secret length (4).
const HEADER_LEN: usize = 1 + 1 + 4;
/// Number of secret bytes packed into each field element.
///
/// A block of this many bytes is always below `2^(bits - 1)` (with `bits()` as
/// the modulus bit length), so every block value is already smaller than the
/// modulus and round-trips through the field unchanged.
///
/// # Panics
///
/// Panics if the modulus has fewer than 9 bits, i.e. not even one whole byte
/// fits below it.
#[must_use]
pub fn block_len(field: &PrimeField) -> usize {
    let modulus_bits = field.modulus().bits();
    assert!(modulus_bits >= 9, "field too small for byte-block Shamir");
    let usable_bits = modulus_bits - 1;
    usable_bits / 8
}
/// Fixed width, in bytes, of every encoded share element: the modulus bit
/// length rounded up to whole bytes, so any reduced field element fits.
#[must_use]
fn share_elem_len(field: &PrimeField) -> usize {
    let modulus_bits = field.modulus().bits();
    modulus_bits.div_ceil(8)
}
/// Splits `secret` into `n` byte-encoded Shamir shares, intended to be
/// recombined by [`reconstruct`] from any `k` of them.
///
/// The secret is zero-padded to a multiple of [`block_len`] bytes and each
/// block is shared independently. For every block, a fresh coefficient vector
/// is built: the constant term is the block, and the remaining `k - 1`
/// coefficients are drawn with `field.random(rng)`. Share `i` (1-based) is laid
/// out as `[SHARE_VERSION, i, secret length as big-endian u32, y_1, y_2, ...]`.
/// Here `y_j` is block `j`'s polynomial evaluated at `x = i` and written in
/// `share_elem_len(field)` bytes.
///
/// # Panics
///
/// Panics if `k < 2`, `n < k`, `n > 255`, the field modulus does not exceed
/// `n`, `secret.len()` does not fit in a `u32`, or [`block_len`] rejects the
/// field as too small.
#[must_use]
pub fn split<R: Csprng>(
    field: &PrimeField,
    rng: &mut R,
    secret: &[u8],
    k: usize,
    n: usize,
) -> Vec<Vec<u8>> {
    assert!(k >= 2, "k must be at least 2 (k = 1 would leak the secret)");
    assert!(n >= k, "n must be at least k");
    assert!(n <= 255, "byte-encoded shares support up to 255 trustees");
    assert!(
        BigUint::from_u64(n as u64) < *field.modulus(),
        "prime modulus must exceed n",
    );
    assert!(
        secret.len() <= u32::MAX as usize,
        "secret length must fit in u32 (wire-format length header is 4 bytes)",
    );

    let block_bytes = block_len(field);
    let elem_bytes = share_elem_len(field);

    // Zero-pad so the secret divides evenly into whole blocks.
    let padded_len = secret.len().div_ceil(block_bytes) * block_bytes;
    let mut padded = secret.to_vec();
    padded.resize(padded_len, 0);

    let share_len = HEADER_LEN + (padded_len / block_bytes) * elem_bytes;
    let len_field = (secret.len() as u32).to_be_bytes();
    let mut shares: Vec<Vec<u8>> = Vec::with_capacity(n);
    for label in 1..=n {
        let mut share = Vec::with_capacity(share_len);
        share.push(SHARE_VERSION);
        // `n <= 255` was asserted above, so the label cast never truncates.
        share.push(label as u8);
        share.extend_from_slice(&len_field);
        shares.push(share);
    }

    for block in padded.chunks_exact(block_bytes) {
        // Constant term carries the secret block; the other k - 1 are random.
        let mut coeffs: Vec<BigUint> = Vec::with_capacity(k);
        coeffs.push(field.reduce(&BigUint::from_be_bytes(block)));
        for _ in 1..k {
            coeffs.push(field.random(rng));
        }

        for (idx, share) in shares.iter_mut().enumerate() {
            let x = BigUint::from_u64(idx as u64 + 1);
            let y = horner(field, &coeffs, &x);
            share.extend_from_slice(&field_element_to_bytes(&y, elem_bytes));
        }
    }

    shares
}
/// Reconstructs a secret from byte-encoded shares produced by [`split`].
///
/// The first `k` shares (in slice order) fix the interpolating polynomial for
/// every block. Each share beyond the first `k` is checked against that
/// polynomial, so supplying extra shares lets corruption be detected.
///
/// `k` is not stored in the shares. If the caller passes a `k` smaller than
/// the threshold used by [`split`] and supplies no extra shares, the wrong
/// result may go undetected.
///
/// Returns `None` when:
/// - `k < 2` (the minimum [`split`] accepts) or fewer than `k` shares are given;
/// - a share is shorter than the header, has an unknown version byte, or uses
///   the reserved label 0;
/// - shares disagree on the secret length, or a payload's length does not
///   match that secret length;
/// - two shares carry the same label;
/// - an encoded element is not reduced modulo the field prime;
/// - `lagrange_eval` fails;
/// - an extra share disagrees with the polynomial fixed by the first `k`;
/// - a reconstructed block does not fit in [`block_len`] bytes, or the zero
///   padding that [`split`] appends after the secret is not zero.
#[must_use]
pub fn reconstruct(field: &PrimeField, shares: &[&[u8]], k: usize) -> Option<Vec<u8>> {
    // `split` never produces a sharing with k < 2. With k = 1 the first share's
    // raw y-value would otherwise be returned as if it were the secret.
    if k < 2 || shares.len() < k {
        return None;
    }
    let bl = block_len(field);
    let sl = share_elem_len(field);

    let mut parsed: Vec<(u8, &[u8])> = Vec::with_capacity(shares.len());
    let mut secret_len: Option<usize> = None;
    // Labels are single bytes, so a fixed table detects duplicates in O(1) per share.
    let mut seen_labels = [false; 256];
    for &share in shares {
        let (label, len, payload) = parse_share(share)?;
        if *secret_len.get_or_insert(len) != len {
            return None;
        }
        let seen = &mut seen_labels[usize::from(label)];
        if *seen {
            return None;
        }
        *seen = true;
        parsed.push((label, payload));
    }
    let secret_len = secret_len?;

    // The length header is untrusted. Checked arithmetic keeps a huge value from
    // wrapping (possible where usize is 32 bits) and slipping past the payload check.
    let padded_len = secret_len.checked_next_multiple_of(bl)?;
    let num_blocks = padded_len / bl;
    let payload_len = num_blocks.checked_mul(sl)?;
    if parsed.iter().any(|(_, payload)| payload.len() != payload_len) {
        return None;
    }

    // `parsed.len() == shares.len() >= k`, so this split cannot panic.
    let (basis, extras) = parsed.split_at(k);
    let mut out = Vec::with_capacity(padded_len);
    for block_idx in 0..num_blocks {
        let start = block_idx * sl;
        let end = start + sl;

        let mut pts: Vec<(BigUint, BigUint)> = Vec::with_capacity(k);
        for (label, payload) in basis {
            let y = decode_share_element(field, &payload[start..end])?;
            pts.push((BigUint::from_u64(u64::from(*label)), y));
        }
        // The secret block is the polynomial's value at x = 0.
        let secret_y = lagrange_eval(field, &pts, &BigUint::zero())?;

        // Every extra share must lie on the polynomial fixed by the first k.
        for (label, payload) in extras {
            let y = decode_share_element(field, &payload[start..end])?;
            let pred = lagrange_eval(field, &pts, &BigUint::from_u64(u64::from(*label)))?;
            if !ct_eq_biguint(&pred, &y) {
                return None;
            }
        }

        out.extend_from_slice(&field_element_to_bytes_checked(&secret_y, bl)?);
    }

    // `split` zero-fills the tail of the final block. Anything else there means
    // the shares are corrupt or were combined with the wrong `k`.
    if out[secret_len..].iter().any(|&b| b != 0) {
        return None;
    }
    out.truncate(secret_len);
    Some(out)
}

/// Splits one wire-format share into `(label, secret_len, payload)`.
///
/// Returns `None` if the share is shorter than `HEADER_LEN`, carries a version
/// byte other than `SHARE_VERSION`, or uses label 0. Label 0 is excluded because
/// the secret itself is read off at x = 0, and `split` only emits labels `1..=n`.
fn parse_share(share: &[u8]) -> Option<(u8, usize, &[u8])> {
    if share.len() < HEADER_LEN || share[0] != SHARE_VERSION {
        return None;
    }
    let label = share[1];
    if label == 0 {
        return None;
    }
    let len = u32::from_be_bytes([share[2], share[3], share[4], share[5]]);
    Some((label, usize::try_from(len).ok()?, &share[HEADER_LEN..]))
}

/// Decodes one big-endian share element, rejecting encodings that are not
/// reduced modulo the field prime.
fn decode_share_element(field: &PrimeField, bytes: &[u8]) -> Option<BigUint> {
    let y = BigUint::from_be_bytes(bytes);
    if y >= *field.modulus() {
        None
    } else {
        Some(y)
    }
}
/// Encodes `value` as exactly `width` big-endian bytes.
///
/// Short encodings are left-padded with zeros. Leading zero bytes beyond
/// `width` are dropped. Returns `None` if the value genuinely needs more than
/// `width` bytes.
fn field_element_to_bytes_checked(value: &BigUint, width: usize) -> Option<Vec<u8>> {
    let be = value.to_be_bytes();
    match be.len().checked_sub(width) {
        // Shorter than requested: prepend zero bytes.
        None => {
            let mut out = vec![0u8; width - be.len()];
            out.extend_from_slice(&be);
            Some(out)
        }
        Some(0) => Some(be),
        // Longer than requested: acceptable only if the excess is all zero.
        Some(excess) => {
            let (head, tail) = be.split_at(excess);
            head.iter().all(|&b| b == 0).then(|| tail.to_vec())
        }
    }
}
fn field_element_to_bytes(value: &BigUint, width: usize) -> Vec<u8> {
let mut be = value.to_be_bytes();
if be.len() < width {
let mut padded = vec![0u8; width - be.len()];
padded.append(&mut be);
padded
} else if be.len() == width {
be
} else {
let extra = be.len() - width;
assert!(
be[..extra].iter().all(|&b| b == 0),
"field element exceeds requested encoding width",
);
be[extra..].to_vec()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::field::mersenne127;
    use crate::csprng::ChaCha20Rng;

    /// Deterministically seeded RNG so test runs are reproducible.
    fn rng() -> ChaCha20Rng {
        ChaCha20Rng::from_seed(&[9u8; 32])
    }

    /// Splits `secret` over the Mersenne-127 field used throughout these tests.
    fn split_m127(secret: &[u8], k: usize, n: usize) -> (PrimeField, Vec<Vec<u8>>) {
        let f = PrimeField::new(mersenne127());
        let mut r = rng();
        let shares = split(&f, &mut r, secret, k, n);
        (f, shares)
    }

    #[test]
    fn block_len_for_mersenne127_is_15() {
        let f = PrimeField::new(mersenne127());
        assert_eq!(block_len(&f), 15);
    }

    #[test]
    fn round_trip_short_secret() {
        let f = PrimeField::new(mersenne127());
        let mut r = rng();
        let secret = b"hello, secret!".to_vec();
        let shares = split(&f, &mut r, &secret, 2, 4);
        let refs: Vec<&[u8]> = shares.iter().map(Vec::as_slice).collect();
        let got = reconstruct(&f, &refs[..2], 2).unwrap();
        assert_eq!(got, secret);
    }

    #[test]
    fn round_trip_multi_block_secret() {
        let f = PrimeField::new(mersenne127());
        let mut r = rng();
        let secret: Vec<u8> = (0..64u8).collect();
        let shares = split(&f, &mut r, &secret, 3, 7);
        let refs: Vec<&[u8]> = shares.iter().map(Vec::as_slice).collect();
        let got = reconstruct(&f, &refs[2..5], 3).unwrap();
        assert_eq!(got, secret);
    }

    #[test]
    fn round_trip_empty_secret() {
        let f = PrimeField::new(mersenne127());
        let mut r = rng();
        let secret: Vec<u8> = Vec::new();
        let shares = split(&f, &mut r, &secret, 2, 3);
        let refs: Vec<&[u8]> = shares.iter().map(Vec::as_slice).collect();
        let got = reconstruct(&f, &refs[..2], 2).unwrap();
        assert_eq!(got, secret);
    }

    #[test]
    fn every_k_subset_reconstructs() {
        // 30 bytes is exactly two 15-byte blocks, so no padding is involved.
        let secret: Vec<u8> = (0..30u8).collect();
        let (f, shares) = split_m127(&secret, 3, 5);
        let refs: Vec<&[u8]> = shares.iter().map(Vec::as_slice).collect();
        for a in 0..5 {
            for b in (a + 1)..5 {
                for c in (b + 1)..5 {
                    let subset = [refs[a], refs[b], refs[c]];
                    assert_eq!(reconstruct(&f, &subset, 3).as_deref(), Some(&secret[..]));
                }
            }
        }
    }

    #[test]
    fn extra_consistent_shares_are_accepted_in_any_order() {
        let secret = b"order should not matter".to_vec();
        let (f, shares) = split_m127(&secret, 3, 5);
        let refs: Vec<&[u8]> = shares.iter().rev().map(Vec::as_slice).collect();
        assert_eq!(reconstruct(&f, &refs, 3), Some(secret));
    }

    #[test]
    fn split_share_layout() {
        // 16 bytes with 15-byte blocks -> two blocks, the second mostly padding.
        let secret = vec![0xABu8; 16];
        let (f, shares) = split_m127(&secret, 2, 3);
        assert_eq!(shares.len(), 3);
        let expected_len = HEADER_LEN + 2 * share_elem_len(&f);
        for (i, share) in shares.iter().enumerate() {
            assert_eq!(share.len(), expected_len);
            assert_eq!(share[0], SHARE_VERSION);
            assert_eq!(usize::from(share[1]), i + 1);
            assert_eq!(&share[2..HEADER_LEN], &16u32.to_be_bytes()[..]);
        }
    }

    #[test]
    #[should_panic(expected = "k must be at least 2")]
    fn split_rejects_threshold_of_one() {
        let f = PrimeField::new(mersenne127());
        let mut r = rng();
        let _ = split(&f, &mut r, b"secret", 1, 3);
    }

    #[test]
    fn below_threshold_returns_none() {
        let f = PrimeField::new(mersenne127());
        let mut r = rng();
        let secret = b"32-byte secret like an AES key!!".to_vec();
        let shares = split(&f, &mut r, &secret, 3, 5);
        let refs: Vec<&[u8]> = shares.iter().map(Vec::as_slice).collect();
        assert!(reconstruct(&f, &refs[..2], 3).is_none());
    }

    #[test]
    fn corrupted_extra_is_rejected() {
        let f = PrimeField::new(mersenne127());
        let mut r = rng();
        let secret = b"do not corrupt me!".to_vec();
        let mut shares = split(&f, &mut r, &secret, 3, 5);
        shares[4][HEADER_LEN] ^= 0x01;
        let refs: Vec<&[u8]> = shares.iter().map(Vec::as_slice).collect();
        assert!(reconstruct(&f, &refs, 3).is_none());
    }

    #[test]
    fn malformed_version_is_rejected() {
        let f = PrimeField::new(mersenne127());
        let mut r = rng();
        let secret = b"x".to_vec();
        let mut shares = split(&f, &mut r, &secret, 2, 3);
        shares[0][0] = 0xFF;
        let refs: Vec<&[u8]> = shares.iter().map(Vec::as_slice).collect();
        assert!(reconstruct(&f, &refs[..2], 2).is_none());
    }

    #[test]
    fn duplicate_labels_are_rejected() {
        let (f, shares) = split_m127(b"dup", 2, 3);
        let refs = [shares[0].as_slice(), shares[0].as_slice()];
        assert!(reconstruct(&f, &refs, 2).is_none());
        let refs = [shares[0].as_slice(), shares[1].as_slice(), shares[0].as_slice()];
        assert!(reconstruct(&f, &refs, 2).is_none());
    }

    #[test]
    fn zero_label_is_rejected() {
        let (f, mut shares) = split_m127(b"zero", 2, 3);
        shares[0][1] = 0;
        let refs: Vec<&[u8]> = shares.iter().map(Vec::as_slice).collect();
        assert!(reconstruct(&f, &refs[..2], 2).is_none());
    }

    #[test]
    fn mismatched_length_headers_are_rejected() {
        let (f, mut shares) = split_m127(b"length", 2, 3);
        shares[1][HEADER_LEN - 1] ^= 0x01;
        let refs: Vec<&[u8]> = shares.iter().map(Vec::as_slice).collect();
        assert!(reconstruct(&f, &refs[..2], 2).is_none());
    }

    #[test]
    fn truncated_share_is_rejected() {
        let (f, mut shares) = split_m127(b"truncate", 2, 3);
        shares[1].pop();
        let refs: Vec<&[u8]> = shares.iter().map(Vec::as_slice).collect();
        assert!(reconstruct(&f, &refs[..2], 2).is_none());
    }

    #[test]
    fn unreduced_element_is_rejected() {
        let (f, mut shares) = split_m127(b"range", 2, 3);
        let sl = share_elem_len(&f);
        // All-0xFF at full width is at least the modulus, so it is never a valid element.
        shares[0][HEADER_LEN..HEADER_LEN + sl].fill(0xFF);
        let refs: Vec<&[u8]> = shares.iter().map(Vec::as_slice).collect();
        assert!(reconstruct(&f, &refs[..2], 2).is_none());
    }
}