use rand_core::{CryptoRng, RngCore};
use super::CryptoError;
/// Length in bytes of every secret split and recombined by this module (64 bytes).
// NOTE(review): the original spelled this `32 + 32`, presumably two concatenated
// 32-byte values; confirm against callers before relying on that split.
pub const SECRET_LENGTH: usize = 2 * 32;
/// One share of a secret produced by `split_secret`.
///
/// Each share records the parameters of the split it came from so that
/// `combine_shares` can reject sets of shares that do not belong together.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Share {
    /// x-coordinate the sharing polynomials were evaluated at; `split_secret`
    /// assigns `1..=share_count`, and `combine_shares` rejects 0.
    pub index: u8,
    /// Minimum number of shares required to reconstruct the secret.
    pub threshold: u8,
    /// Total number of shares produced by the split.
    pub share_count: u8,
    /// One polynomial evaluation per byte of the secret.
    pub value: [u8; SECRET_LENGTH],
}
/// Splits `secret` into `share_count` Shamir shares over GF(2^8).
///
/// Every byte of the secret is the constant term of its own polynomial of
/// degree `threshold - 1`. The other coefficients are drawn from `rng`, one
/// `SECRET_LENGTH`-byte block per coefficient. Share `x`, for `x` in
/// `1..=share_count`, stores all of those polynomials evaluated at `x`.
/// Any `threshold` of the returned shares reconstruct the secret via
/// `combine_shares`.
///
/// # Errors
///
/// * `CryptoError::ThresholdTooLow` if `threshold < 2`.
/// * `CryptoError::ShareCountError` if `share_count == 0`.
/// * `CryptoError::ThresholdMoreThanShareCound` if `share_count < threshold`.
pub fn split_secret(
    mut rng: impl RngCore + CryptoRng,
    secret: [u8; SECRET_LENGTH],
    threshold: u8,
    share_count: u8,
) -> Result<Vec<Share>, CryptoError> {
    // The arm order fixes which error wins when several parameters are invalid.
    match (threshold, share_count) {
        (t, _) if t < 2 => return Err(CryptoError::ThresholdTooLow(u32::from(t))),
        (_, 0) => return Err(CryptoError::ShareCountError(0)),
        (t, n) if n < t => {
            return Err(CryptoError::ThresholdMoreThanShareCound(
                u32::from(n),
                u32::from(t),
            ));
        }
        _ => {}
    }

    // Coefficient 0 is the secret itself; coefficients 1..threshold are random.
    let coeff_count = usize::from(threshold);
    let mut poly: Vec<[u8; SECRET_LENGTH]> = Vec::with_capacity(coeff_count);
    poly.push(secret);
    for _ in 1..coeff_count {
        let mut random_coeff = [0u8; SECRET_LENGTH];
        rng.fill_bytes(&mut random_coeff);
        poly.push(random_coeff);
    }

    let shares: Vec<Share> = (1..=share_count)
        .map(|x| {
            let mut value = [0u8; SECRET_LENGTH];
            for (pos, out) in value.iter_mut().enumerate() {
                // Horner's rule, highest-degree coefficient first; starting from 0
                // makes the first step yield that coefficient unchanged.
                *out = poly
                    .iter()
                    .rev()
                    .fold(0u8, |acc, coeff| gf256_add(gf256_mul(acc, x), coeff[pos]));
            }
            Share {
                index: x,
                threshold,
                share_count,
                value,
            }
        })
        .collect();
    Ok(shares)
}
/// Reconstructs the secret from shares produced by `split_secret`.
///
/// Requirements on the input:
/// * Every share must agree on `threshold` and `share_count`.
/// * Every share must have a non-zero index.
/// * No two shares may have the same index.
/// * At least `threshold` shares must be supplied.
///
/// Only the first `threshold` shares of the slice are interpolated. The rest are
/// validated but otherwise unused.
///
/// # Errors
///
/// * `CryptoError::NotEnoughShares`: `shares` is empty (reported as `(0, 0)`),
///   or it holds fewer than `threshold` shares (reported as `(supplied, threshold)`).
/// * `CryptoError::InconsistentShares`: the shares disagree on `threshold` or
///   `share_count`, or some share has index 0.
/// * `CryptoError::ThresholdTooLow`: the shares claim a threshold below 2,
///   which `split_secret` never produces.
/// * `CryptoError::DuplicateShareIndex`: two shares carry the same index.
pub fn combine_shares(shares: &[Share]) -> Result<[u8; SECRET_LENGTH], CryptoError> {
    if shares.is_empty() {
        return Err(CryptoError::NotEnoughShares(0, 0));
    }
    let first = &shares[0];
    let inconsistent = shares.iter().any(|share| {
        share.threshold != first.threshold
            || share.share_count != first.share_count
            || share.index == 0
    });
    if inconsistent {
        return Err(CryptoError::InconsistentShares);
    }
    // Without this guard, a corrupted or forged share claiming threshold 0
    // interpolates over zero points and "recovers" an all-zero secret.
    // Threshold 1 would return a single share's bytes verbatim.
    if first.threshold < 2 {
        return Err(CryptoError::ThresholdTooLow(u32::from(first.threshold)));
    }
    let threshold = usize::from(first.threshold);
    if shares.len() < threshold {
        return Err(CryptoError::NotEnoughShares(
            shares.len() as u32,
            threshold as u32,
        ));
    }
    // Report the first duplicate pair in (i, j) order. Distinct indices also
    // guarantee that every Lagrange denominator below is non-zero, so
    // `gf256_inv` cannot panic.
    for (i, share) in shares.iter().enumerate() {
        if let Some(dup) = shares[i + 1..]
            .iter()
            .find(|other| other.index == share.index)
        {
            return Err(CryptoError::DuplicateShareIndex(share.index, dup.index));
        }
    }

    let used = &shares[..threshold];
    // Lagrange basis polynomials evaluated at x = 0:
    //   L_i(0) = prod_{m != i} x_m / (x_m - x_i)
    // Subtraction in GF(2^8) is XOR (`gf256_add`). These values depend only on
    // the share indices, so compute them once instead of once per secret byte.
    let basis: Vec<u8> = used
        .iter()
        .enumerate()
        .map(|(i, share)| {
            let xi = share.index;
            let (num, den) = used
                .iter()
                .enumerate()
                .filter(|&(m, _)| m != i)
                .fold((1u8, 1u8), |(num, den), (_, other)| {
                    (
                        gf256_mul(num, other.index),
                        gf256_mul(den, gf256_add(other.index, xi)),
                    )
                });
            gf256_mul(num, gf256_inv(den))
        })
        .collect();

    let mut secret = [0u8; SECRET_LENGTH];
    for (byte_idx, out_byte) in secret.iter_mut().enumerate() {
        *out_byte = used
            .iter()
            .zip(&basis)
            .fold(0u8, |acc, (share, &li0)| {
                gf256_add(acc, gf256_mul(share.value[byte_idx], li0))
            });
    }
    Ok(secret)
}
/// Adds (equivalently, subtracts) two elements of GF(2^8).
///
/// The field has characteristic 2, so addition is bitwise XOR.
fn gf256_add(lhs: u8, rhs: u8) -> u8 {
    lhs ^ rhs
}
/// Multiplies two elements of GF(2^8).
///
/// The modulus is the AES polynomial x^8 + x^4 + x^3 + x + 1; `0x1B` is its
/// low byte and is XORed in whenever the shifted value overflows 8 bits.
///
/// The operands here include secret polynomial coefficients and share bytes.
/// The loop therefore always runs eight iterations and uses bit masks instead
/// of `if` branches, so operand bits do not steer control flow. This is a
/// best-effort timing mitigation; the compiler is not formally prevented from
/// reintroducing branches.
fn gf256_mul(mut a: u8, mut b: u8) -> u8 {
    let mut p = 0u8;
    for _ in 0..8 {
        // 0xFF when the low bit of `b` is set, 0x00 otherwise.
        let add_mask = 0u8.wrapping_sub(b & 1);
        p ^= a & add_mask;
        // 0xFF when the top bit of `a` is shifted out and must be reduced.
        let reduce_mask = 0u8.wrapping_sub(a >> 7);
        a = (a << 1) ^ (0x1B & reduce_mask);
        b >>= 1;
    }
    p
}
/// Raises `base` to the power `exponent` in GF(2^8) (with `0^0 == 1`).
///
/// Uses left-to-right square-and-multiply over the eight exponent bits.
/// It branches on the exponent's bits, so the exponent should be public; within
/// this file it is only called with the constant 254.
fn gf256_pow(base: u8, exponent: u8) -> u8 {
    (0..8u32).rev().fold(1u8, |acc, bit| {
        let squared = gf256_mul(acc, acc);
        if (exponent >> bit) & 1 == 1 {
            gf256_mul(squared, base)
        } else {
            squared
        }
    })
}
/// Returns the multiplicative inverse of `a` in GF(2^8).
///
/// The non-zero elements form a group of order 255, so `a^255 == 1` and
/// therefore `a^254 == a^-1`.
///
/// # Panics
///
/// Panics if `a` is zero, which has no inverse.
fn gf256_inv(a: u8) -> u8 {
    assert!(a != 0);
    const INVERSE_EXPONENT: u8 = u8::MAX - 1;
    gf256_pow(a, INVERSE_EXPONENT)
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand_core::OsRng;

    /// Draws a fresh random secret from the operating-system RNG.
    fn random_secret() -> [u8; SECRET_LENGTH] {
        let mut rng = OsRng;
        let mut s = [0u8; SECRET_LENGTH];
        rng.fill_bytes(&mut s);
        s
    }

    /// Splits a fresh random secret 3-of-5 and returns the secret with its shares.
    fn split_3_of_5() -> ([u8; SECRET_LENGTH], Vec<Share>) {
        let mut rng = OsRng;
        let secret = random_secret();
        let shares = split_secret(&mut rng, secret, 3, 5).unwrap();
        (secret, shares)
    }

    #[test]
    fn split_and_recombine_3_of_5() {
        let (secret, shares) = split_3_of_5();
        let recovered = combine_shares(&shares[0..3]).unwrap();
        assert_eq!(secret, recovered);
    }

    #[test]
    fn shares_carry_split_parameters() {
        let (_, shares) = split_3_of_5();
        let indices: Vec<u8> = shares.iter().map(|s| s.index).collect();
        assert_eq!(indices, vec![1, 2, 3, 4, 5]);
        assert!(shares.iter().all(|s| s.threshold == 3 && s.share_count == 5));
    }

    #[test]
    fn recombinations_3_of_5() {
        let (secret, shares) = split_3_of_5();
        for i in 0..5 {
            for j in (i + 1)..5 {
                for k in (j + 1)..5 {
                    let combo = vec![shares[i].clone(), shares[j].clone(), shares[k].clone()];
                    let recovered = combine_shares(&combo).unwrap();
                    assert_eq!(secret, recovered);
                }
            }
        }
    }

    #[test]
    fn recombines_from_more_than_threshold_shares() {
        let (secret, shares) = split_3_of_5();
        assert_eq!(combine_shares(&shares).unwrap(), secret);
    }

    #[test]
    fn rejects_not_enough_shares() {
        let (_, shares) = split_3_of_5();
        let result = combine_shares(&shares[0..2]);
        assert!(matches!(result, Err(CryptoError::NotEnoughShares(2, 3))));
    }

    #[test]
    fn rejects_empty_share_list() {
        let result = combine_shares(&[]);
        assert!(matches!(result, Err(CryptoError::NotEnoughShares(0, 0))));
    }

    #[test]
    fn fails_with_duplicate_indices() {
        let (_, mut shares) = split_3_of_5();
        shares[1].index = shares[0].index;
        let result = combine_shares(&shares[0..3]);
        assert!(matches!(result, Err(CryptoError::DuplicateShareIndex(1, 1))));
    }

    #[test]
    fn rejects_zero_index() {
        let (_, mut shares) = split_3_of_5();
        shares[0].index = 0;
        let result = combine_shares(&shares[0..3]);
        assert!(matches!(result, Err(CryptoError::InconsistentShares)));
    }

    #[test]
    fn rejects_inconsistent_threshold() {
        let (_, mut shares) = split_3_of_5();
        shares[0].threshold = 2;
        let result = combine_shares(&shares[0..3]);
        assert!(matches!(result, Err(CryptoError::InconsistentShares)));
    }

    #[test]
    fn rejects_inconsistent_share_count() {
        let (_, mut shares) = split_3_of_5();
        shares[0].share_count = 9;
        let result = combine_shares(&shares[0..3]);
        assert!(matches!(result, Err(CryptoError::InconsistentShares)));
    }

    #[test]
    fn split_rejects_invalid_parameters() {
        let mut rng = OsRng;
        let secret = random_secret();
        assert!(matches!(
            split_secret(&mut rng, secret, 1, 5),
            Err(CryptoError::ThresholdTooLow(1))
        ));
        assert!(matches!(
            split_secret(&mut rng, secret, 3, 0),
            Err(CryptoError::ShareCountError(0))
        ));
        assert!(matches!(
            split_secret(&mut rng, secret, 4, 3),
            Err(CryptoError::ThresholdMoreThanShareCound(3, 4))
        ));
    }

    #[test]
    fn gf256_arithmetic_known_values() {
        // Worked multiplication examples from FIPS-197, section 4.2.
        assert_eq!(gf256_mul(0x57, 0x83), 0xC1);
        assert_eq!(gf256_mul(0x57, 0x13), 0xFE);
        for a in 1..=255u8 {
            assert_eq!(gf256_mul(a, gf256_inv(a)), 1);
        }
    }

    #[test]
    fn multiple_random_runs() {
        let mut rng = OsRng;
        for _ in 0..50 {
            let secret = random_secret();
            let shares = split_secret(&mut rng, secret, 4, 7).unwrap();
            let recovered = combine_shares(&shares[1..5]).unwrap();
            assert_eq!(secret, recovered);
        }
    }
}