use rand_core::{CryptoRng, RngCore};
/// Length in bytes of the secret being split, and of each share's value.
pub const SECRET_LENGTH: usize = 32;
/// Length in bytes of the identifier stored alongside an encoded share.
pub const ID_LENGTH: usize = 16;
/// Total length of an encoded share: a 3-byte header (index, threshold,
/// share count), followed by the share value, followed by the identifier.
pub const ENCODED_LENGTH: usize = SECRET_LENGTH + ID_LENGTH + 3;
#[derive(Debug, PartialEq, Eq, thiserror::Error)]
pub enum ShamirError {
#[error("Threshold must be more than 1, but got {0}")]
InvalidThreshold(u32),
#[error("Share count must be between 1 and 255, but got {0}")]
InvalidShareCount(u32),
#[error("Share count {0} must be at least the threshold {1}")]
LessSharesThanThreshhold(u32, u32),
#[error(
"Not enough shares to reconstruct the secret. Provided {0} shares, but threshold is {1}"
)]
NotEnoughShares(u32, u32),
#[error("Inconsistent shares provided")]
InconsistentShares,
#[error("Duplicate share index: {0} and {1}")]
DuplicateShareIndex(u8, u8),
}
/// One share of a split secret.
///
/// `value` holds, for every byte of the secret, that byte's sharing polynomial
/// evaluated at the x-coordinate `index`. `threshold` and `share_count` record
/// the parameters the secret was split with, so shares can be cross-checked
/// when they are recombined.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Share {
    /// x-coordinate at which this share's polynomials were evaluated.
    pub index: u8,
    /// Number of shares needed to reconstruct the secret.
    pub threshold: u8,
    /// Total number of shares produced by the split.
    pub share_count: u8,
    /// Polynomial evaluations, one per secret byte.
    pub value: [u8; SECRET_LENGTH],
}

impl Share {
    /// Size of the fixed header: `index`, `threshold`, `share_count`, one byte each.
    const HEADER_LEN: usize = 3;

    /// Serializes this share together with a caller-supplied identifier.
    ///
    /// Layout: `index`, `threshold`, `share_count` (one byte each), then the
    /// `SECRET_LENGTH`-byte value, then the `ID_LENGTH`-byte identifier.
    pub fn encode(&self, id: [u8; ID_LENGTH]) -> [u8; ENCODED_LENGTH] {
        let mut buf = [0u8; ENCODED_LENGTH];
        let (header, body) = buf.split_at_mut(Self::HEADER_LEN);
        header.copy_from_slice(&[self.index, self.threshold, self.share_count]);
        let (value_part, id_part) = body.split_at_mut(SECRET_LENGTH);
        value_part.copy_from_slice(&self.value);
        id_part.copy_from_slice(&id);
        buf
    }

    /// Parses bytes in the layout written by [`Share::encode`], returning the
    /// share and its identifier.
    ///
    /// The header fields are taken as-is; no validation is performed here.
    pub fn decode(encoded: &[u8; ENCODED_LENGTH]) -> (Self, [u8; ID_LENGTH]) {
        let (header, body) = encoded.split_at(Self::HEADER_LEN);
        let (value_bytes, id_bytes) = body.split_at(SECRET_LENGTH);

        let mut value = [0u8; SECRET_LENGTH];
        value.copy_from_slice(value_bytes);
        let mut id = [0u8; ID_LENGTH];
        id.copy_from_slice(id_bytes);

        let share = Share {
            index: header[0],
            threshold: header[1],
            share_count: header[2],
            value,
        };
        (share, id)
    }
}
/// Splits `secret` into `share_count` shares, any `threshold` of which recover it.
///
/// Each secret byte is shared independently over GF(2^8). The polynomial's
/// constant term is the secret byte, and its `threshold - 1` higher coefficients
/// are filled from `rng`. Share number `x` (for `x` in `1..=share_count`) stores
/// every byte's polynomial evaluated at `x`.
///
/// # Errors
///
/// - [`ShamirError::InvalidThreshold`] if `threshold` is 0.
/// - [`ShamirError::InvalidShareCount`] if `share_count` is 0.
/// - [`ShamirError::LessSharesThanThreshhold`] if `share_count < threshold`.
pub fn split_secret(
    mut rng: impl RngCore + CryptoRng,
    secret: [u8; SECRET_LENGTH],
    threshold: u8,
    share_count: u8,
) -> Result<Vec<Share>, ShamirError> {
    if threshold == 0 {
        return Err(ShamirError::InvalidThreshold(u32::from(threshold)));
    }
    if share_count == 0 {
        return Err(ShamirError::InvalidShareCount(u32::from(share_count)));
    }
    if share_count < threshold {
        return Err(ShamirError::LessSharesThanThreshhold(
            u32::from(share_count),
            u32::from(threshold),
        ));
    }

    // coeffs[k] holds the degree-k coefficient for every byte position:
    // degree 0 is the secret itself, higher degrees come from the RNG in order.
    let coeff_count = usize::from(threshold);
    let mut coeffs: Vec<[u8; SECRET_LENGTH]> = Vec::with_capacity(coeff_count);
    coeffs.push(secret);
    for _ in 1..coeff_count {
        let mut random_coeff = [0u8; SECRET_LENGTH];
        rng.fill_bytes(&mut random_coeff);
        coeffs.push(random_coeff);
    }

    let shares: Vec<Share> = (1..=share_count)
        .map(|x| {
            let mut value = [0u8; SECRET_LENGTH];
            for (pos, slot) in value.iter_mut().enumerate() {
                // Horner's rule: fold from the highest-degree coefficient down.
                *slot = coeffs
                    .iter()
                    .rev()
                    .fold(0u8, |acc, coeff| gf256::add(gf256::mul(acc, x), coeff[pos]));
            }
            Share {
                index: x,
                threshold,
                share_count,
                value,
            }
        })
        .collect();
    Ok(shares)
}
/// Reconstructs the secret from shares produced by [`split_secret`].
///
/// Each secret byte is the constant term of its sharing polynomial. That term
/// is recovered by Lagrange interpolation at x = 0 over GF(2^8). Every supplied
/// share is checked for consistent metadata and a distinct index, but only the
/// first `threshold` shares take part in the interpolation.
///
/// # Errors
///
/// - [`ShamirError::NotEnoughShares`] if `shares` is empty (reported as `(0, 0)`)
///   or holds fewer shares than the recorded threshold.
/// - [`ShamirError::InconsistentShares`] if the shares disagree on `threshold`
///   or `share_count`, or any share has index 0.
/// - [`ShamirError::InvalidThreshold`] if the recorded threshold is 0.
/// - [`ShamirError::LessSharesThanThreshhold`] if the recorded share count is
///   smaller than the recorded threshold.
/// - [`ShamirError::DuplicateShareIndex`] if two shares carry the same index.
pub fn combine_shares(shares: &[Share]) -> Result<[u8; SECRET_LENGTH], ShamirError> {
    let first = match shares.first() {
        Some(first) => first,
        None => return Err(ShamirError::NotEnoughShares(0, 0)),
    };

    let consistent = shares.iter().all(|share| {
        share.threshold == first.threshold
            && share.share_count == first.share_count
            && share.index != 0
    });
    if !consistent {
        return Err(ShamirError::InconsistentShares);
    }

    // Shares may come from `Share::decode` on untrusted bytes, so their metadata
    // is not guaranteed to be something `split_secret` would produce. A zero
    // threshold would select no shares and "reconstruct" an all-zero secret.
    if first.threshold == 0 {
        return Err(ShamirError::InvalidThreshold(0));
    }
    if first.share_count < first.threshold {
        return Err(ShamirError::LessSharesThanThreshhold(
            u32::from(first.share_count),
            u32::from(first.threshold),
        ));
    }

    let threshold = usize::from(first.threshold);
    if shares.len() < threshold {
        // len < threshold <= 255, so the cast cannot truncate.
        return Err(ShamirError::NotEnoughShares(
            shares.len() as u32,
            u32::from(first.threshold),
        ));
    }

    // Indices are u8, so a 256-entry table detects duplicates in a single pass.
    let mut seen = [false; 256];
    for share in shares {
        let slot = &mut seen[usize::from(share.index)];
        if *slot {
            return Err(ShamirError::DuplicateShareIndex(share.index, share.index));
        }
        *slot = true;
    }

    let used = &shares[..threshold];

    // The Lagrange basis values l_i(0) depend only on the share indices, so
    // compute them once instead of once per secret byte. In GF(2^8),
    // subtraction is XOR, so (0 - x_m) / (x_i - x_m) = x_m / (x_m ^ x_i).
    // The indices are nonzero and distinct (checked above), so every
    // denominator factor is nonzero and `inv` cannot panic.
    let basis: Vec<u8> = used
        .iter()
        .enumerate()
        .map(|(i, share)| {
            let xi = share.index;
            let (num, den) = used
                .iter()
                .enumerate()
                .filter(|&(m, _)| m != i)
                .fold((1u8, 1u8), |(num, den), (_, other)| {
                    let xm = other.index;
                    (gf256::mul(num, xm), gf256::mul(den, gf256::add(xm, xi)))
                });
            gf256::mul(num, gf256::inv(den))
        })
        .collect();

    let mut secret = [0u8; SECRET_LENGTH];
    for (byte_idx, out_byte) in secret.iter_mut().enumerate() {
        *out_byte = used
            .iter()
            .zip(&basis)
            .fold(0u8, |acc, (share, &li0)| {
                gf256::add(acc, gf256::mul(share.value[byte_idx], li0))
            });
    }
    Ok(secret)
}
/// Arithmetic in GF(2^8) with reduction polynomial x^8 + x^4 + x^3 + x + 1
/// (0x11B, the field used by AES).
pub mod gf256 {
    /// Adds two field elements. Addition and subtraction in GF(2^8) are both XOR.
    pub fn add(a: u8, b: u8) -> u8 {
        a ^ b
    }

    /// Multiplies two field elements (shift-and-add with reduction by 0x11B).
    ///
    /// Callers pass secret-dependent bytes here, so the conditional XORs are
    /// done with bit masks instead of `if` branches. The operation sequence then
    /// does not depend on operand values at the source level. The compiler is
    /// not formally prevented from reintroducing branches.
    pub fn mul(mut a: u8, mut b: u8) -> u8 {
        let mut p = 0u8;
        for _ in 0..8 {
            // 0xFF when the low bit of `b` is set, otherwise 0x00.
            let take = 0u8.wrapping_sub(b & 1);
            p ^= a & take;
            // 0xFF when shifting `a` would overflow past x^7, otherwise 0x00;
            // in that case reduce by the low byte of the polynomial (0x1B).
            let carry = 0u8.wrapping_sub(a >> 7);
            a = (a << 1) ^ (0x1B & carry);
            b >>= 1;
        }
        p
    }

    /// Raises `a` to the power `e` by square-and-multiply.
    ///
    /// Branches on the bits of `e`, so `e` should not be secret.
    pub fn pow(mut a: u8, mut e: u8) -> u8 {
        let mut r = 1u8;
        while e != 0 {
            if (e & 1) != 0 {
                r = mul(r, a);
            }
            a = mul(a, a);
            e >>= 1;
        }
        r
    }

    /// Returns the multiplicative inverse of `a`.
    ///
    /// Uses a^254 = a^-1, which holds because a^255 = 1 for every nonzero `a`.
    ///
    /// # Panics
    ///
    /// Panics if `a` is 0, which has no inverse.
    pub fn inv(a: u8) -> u8 {
        assert!(a != 0);
        pow(a, 254)
    }
}
#[cfg(test)]
mod tests {
    use crate::rng::os_rng_hkdf;
    use super::*;

    /// Draws a fresh random secret from the operating system.
    fn random_secret() -> [u8; SECRET_LENGTH] {
        let mut s = [0u8; SECRET_LENGTH];
        getrandom::fill(&mut s).unwrap();
        s
    }

    /// Splits a fresh random secret with a fresh RNG; returns the secret and its shares.
    fn split_random(threshold: u8, share_count: u8) -> ([u8; SECRET_LENGTH], Vec<Share>) {
        let mut rng = os_rng_hkdf(None, b"").unwrap();
        let secret = random_secret();
        let shares = split_secret(&mut rng, secret, threshold, share_count).unwrap();
        (secret, shares)
    }

    #[test]
    fn split_and_recombine_3_of_5() {
        let (secret, shares) = split_random(3, 5);
        assert_eq!(combine_shares(&shares[0..3]), Ok(secret));
    }

    #[test]
    fn shares_carry_split_parameters() {
        let (_, shares) = split_random(3, 5);
        let indices: Vec<u8> = shares.iter().map(|s| s.index).collect();
        assert_eq!(indices, vec![1, 2, 3, 4, 5]);
        assert!(shares.iter().all(|s| s.threshold == 3 && s.share_count == 5));
    }

    #[test]
    fn recombinations_3_of_5() {
        let (secret, shares) = split_random(3, 5);
        for i in 0..5 {
            for j in (i + 1)..5 {
                for k in (j + 1)..5 {
                    let combo = vec![shares[i].clone(), shares[j].clone(), shares[k].clone()];
                    assert_eq!(combine_shares(&combo), Ok(secret));
                }
            }
        }
    }

    #[test]
    fn extra_shares_are_accepted() {
        let (secret, shares) = split_random(3, 5);
        assert_eq!(combine_shares(&shares), Ok(secret));
    }

    #[test]
    fn threshold_one_shares_equal_secret() {
        let (secret, shares) = split_random(1, 3);
        for share in &shares {
            assert_eq!(share.value, secret);
            assert_eq!(combine_shares(std::slice::from_ref(share)), Ok(secret));
        }
    }

    #[test]
    fn full_index_range_2_of_255() {
        let (secret, shares) = split_random(2, 255);
        assert_eq!(shares.last().map(|s| s.index), Some(255));
        assert_eq!(combine_shares(&shares[253..]), Ok(secret));
    }

    #[test]
    fn split_rejects_invalid_parameters() {
        let mut rng = os_rng_hkdf(None, b"").unwrap();
        let secret = random_secret();
        assert_eq!(
            split_secret(&mut rng, secret, 0, 5),
            Err(ShamirError::InvalidThreshold(0))
        );
        assert_eq!(
            split_secret(&mut rng, secret, 1, 0),
            Err(ShamirError::InvalidShareCount(0))
        );
        assert_eq!(
            split_secret(&mut rng, secret, 4, 3),
            Err(ShamirError::LessSharesThanThreshhold(3, 4))
        );
    }

    #[test]
    fn rejects_empty_input() {
        let empty: Vec<Share> = Vec::new();
        assert_eq!(combine_shares(&empty), Err(ShamirError::NotEnoughShares(0, 0)));
    }

    #[test]
    fn rejects_not_enough_shares() {
        let (_, shares) = split_random(3, 5);
        assert_eq!(
            combine_shares(&shares[0..2]),
            Err(ShamirError::NotEnoughShares(2, 3))
        );
    }

    #[test]
    fn fails_with_duplicate_indices() {
        let (_, mut shares) = split_random(3, 5);
        let dup = shares[0].index;
        shares[1].index = dup;
        assert_eq!(
            combine_shares(&shares[0..3]),
            Err(ShamirError::DuplicateShareIndex(dup, dup))
        );
    }

    #[test]
    fn rejects_zero_index() {
        let (_, mut shares) = split_random(3, 5);
        shares[2].index = 0;
        assert_eq!(
            combine_shares(&shares[0..3]),
            Err(ShamirError::InconsistentShares)
        );
    }

    #[test]
    fn rejects_inconsistent_threshold() {
        let (_, mut shares) = split_random(3, 5);
        shares[0].threshold = 2;
        assert_eq!(
            combine_shares(&shares[0..3]),
            Err(ShamirError::InconsistentShares)
        );
    }

    #[test]
    fn rejects_inconsistent_share_count() {
        let (_, mut shares) = split_random(3, 5);
        shares[0].share_count = 9;
        assert_eq!(
            combine_shares(&shares[0..3]),
            Err(ShamirError::InconsistentShares)
        );
    }

    #[test]
    fn multiple_random_runs() {
        let mut rng = os_rng_hkdf(None, b"").unwrap();
        for _ in 0..50 {
            let secret = random_secret();
            let shares = split_secret(&mut rng, secret, 4, 7).unwrap();
            assert_eq!(combine_shares(&shares[1..5]), Ok(secret));
        }
    }

    #[test]
    fn encode_decode() {
        let mut rng = os_rng_hkdf(None, b"").unwrap();
        let mut value = [0u8; SECRET_LENGTH];
        rng.fill_bytes(&mut value);
        let mut id = [0u8; ID_LENGTH];
        rng.fill_bytes(&mut id);
        let share = Share {
            index: 3,
            threshold: 4,
            share_count: 7,
            value,
        };
        let encoded = share.encode(id);
        let (decoded_share, decoded_id) = Share::decode(&encoded);
        assert_eq!(share, decoded_share);
        assert_eq!(id, decoded_id);
    }

    #[test]
    fn gf256_known_answers() {
        // Worked examples from FIPS-197 (AES field arithmetic).
        assert_eq!(gf256::add(0x57, 0x83), 0xD4);
        assert_eq!(gf256::mul(0x57, 0x83), 0xC1);
        assert_eq!(gf256::mul(0x57, 0x13), 0xFE);
        assert_eq!(gf256::inv(0x53), 0xCA);
        for a in 1..=255u8 {
            assert_eq!(gf256::mul(a, gf256::inv(a)), 1);
        }
    }
}