use getrandom::getrandom as os_getrandom;
/// Parameters of a `t`-of-`n` threshold scheme.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ThresholdConfig {
    /// Minimum number of shares required to reconstruct the secret.
    pub t: u8,
    /// Total number of shares produced when splitting; share indices run `1..=n`.
    pub n: u8,
}

impl ThresholdConfig {
    /// The 2-of-3 configuration.
    pub const DEFAULT: Self = Self { t: 2, n: 3 };

    /// Returns `true` when `2 <= t <= n`.
    ///
    /// A threshold below 2 is rejected, as is a threshold larger than the number of
    /// shares (the secret could never be reconstructed).
    #[must_use]
    pub fn is_valid(&self) -> bool {
        (2..=self.n).contains(&self.t)
    }
}
/// One share of a split secret: a set of polynomial evaluations at a common point.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Share {
    /// Evaluation point `x` of this share. Splitting never produces `0`, and combining
    /// rejects it.
    pub index: u8,
    /// One evaluated byte per secret byte, in the same order as the secret.
    pub bytes: Vec<u8>,
}
#[non_exhaustive]
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum ThresholdError {
#[error("invalid threshold config: t={t} n={n}")]
InvalidConfig {
t: u8,
n: u8,
},
#[error("insufficient shares: need {need}, got {got}")]
InsufficientShares {
need: usize,
got: usize,
},
#[error("invalid share index: {index}")]
InvalidShareIndex {
index: u8,
},
#[error("duplicate share indices in combine input")]
DuplicateShares,
#[error("share byte length mismatch: expected {expected}, got {got}")]
ShareLengthMismatch {
expected: usize,
got: usize,
},
#[error("entropy source failure: {0}")]
EntropyFailure(String),
}
/// Reduction polynomial of GF(2^8): x^8 + x^4 + x^3 + x + 1 (the AES field).
const GF_REDUCER: u16 = 0x11B;

/// Addition in GF(2^8): bitwise XOR. Because the field has characteristic 2, this is
/// also subtraction.
#[inline]
fn gf_add(a: u8, b: u8) -> u8 {
    a ^ b
}

/// Multiplies two elements of GF(2^8) modulo [`GF_REDUCER`].
///
/// The operands are often secret (share bytes, polynomial coefficients, intermediate
/// Horner accumulators). This implementation therefore always runs exactly eight rounds.
/// It selects with bit masks rather than branching on operand bits or stopping early, so
/// the instruction sequence does not depend on the values being multiplied.
/// Branch-free source strongly encourages, but cannot strictly guarantee, constant-time
/// machine code.
fn gf_mul(mut a: u8, mut b: u8) -> u8 {
    // Low byte of the reducer; the x^8 term corresponds to the bit shifted out of `a`.
    const REDUCER_LOW: u8 = (GF_REDUCER & 0xFF) as u8;
    let mut product: u8 = 0;
    for _ in 0..8 {
        // 0xFF if the current low bit of `b` is set, else 0x00.
        let take = (b & 1).wrapping_neg();
        product ^= a & take;
        // 0xFF if `a * x` overflows degree 7 and needs reduction, else 0x00.
        let carry = (a >> 7).wrapping_neg();
        a = (a << 1) ^ (carry & REDUCER_LOW);
        b >>= 1;
    }
    product
}
/// Multiplicative inverse in GF(2^8), computed as `a^254`.
///
/// For non-zero `a`, `a^255 == 1`, so `a^254` is the inverse. Zero has no inverse;
/// by convention this returns `0` for it.
fn gf_inv(a: u8) -> u8 {
    // 254 = 0b1111_1110; scanned least-significant bit first (square-and-multiply).
    const EXPONENT: u8 = 254;
    if a == 0 {
        return 0;
    }
    let (inverse, _) = (0..8u32).fold((1u8, a), |(acc, square), bit| {
        let acc = if (EXPONENT >> bit) & 1 == 1 {
            gf_mul(acc, square)
        } else {
            acc
        };
        (acc, gf_mul(square, square))
    });
    inverse
}
/// Evaluates `coeffs[0] + coeffs[1]·x + … + coeffs[k]·x^k` over GF(2^8) by Horner's rule.
///
/// `coeffs` is ordered from the constant term upward; an empty slice evaluates to `0`.
fn gf_poly_eval(coeffs: &[u8], x: u8) -> u8 {
    coeffs
        .iter()
        .rev()
        .fold(0u8, |acc, &coeff| gf_add(gf_mul(acc, x), coeff))
}
/// Splits `secret` into `config.n` shares, any `config.t` of which reconstruct it.
///
/// Each secret byte independently becomes the constant term of a polynomial of degree at
/// most `t - 1` over GF(2^8). The other `t - 1` coefficients come from the OS random
/// source. Share `i` (for `i` in `1..=n`) holds every byte's polynomial evaluated at
/// `x = i`, so each share is `secret.len()` bytes long. An empty secret yields `n` shares
/// with empty `bytes`.
///
/// # Errors
/// - [`ThresholdError::InvalidConfig`] if `config` fails [`ThresholdConfig::is_valid`].
/// - [`ThresholdError::EntropyFailure`] if the OS random source fails.
pub fn split_secret(secret: &[u8], config: ThresholdConfig) -> Result<Vec<Share>, ThresholdError> {
    if !config.is_valid() {
        return Err(ThresholdError::InvalidConfig {
            t: config.t,
            n: config.n,
        });
    }
    let t = usize::from(config.t);
    // `is_valid` guarantees t >= 2, so every polynomial has at least one random coefficient
    // and `chunks_exact(degree)` below never receives a zero chunk size.
    let degree = t - 1;

    // Coefficients for byte `k` occupy `random_coeffs[k * degree..(k + 1) * degree]`.
    let mut random_coeffs = vec![0u8; secret.len() * degree];
    if !random_coeffs.is_empty() {
        os_getrandom(&mut random_coeffs)
            .map_err(|e| ThresholdError::EntropyFailure(e.to_string()))?;
    }

    let mut shares: Vec<Share> = (1..=config.n)
        .map(|index| Share {
            index,
            bytes: Vec::with_capacity(secret.len()),
        })
        .collect();

    // A single scratch polynomial is reused for every byte. The secret bytes it holds then
    // live in one buffer that is wiped below, not in one short-lived allocation per
    // (share, byte) pair.
    let mut poly = vec![0u8; t];
    for (&secret_byte, coeffs) in secret.iter().zip(random_coeffs.chunks_exact(degree)) {
        poly[0] = secret_byte;
        poly[1..].copy_from_slice(coeffs);
        for share in &mut shares {
            share.bytes.push(gf_poly_eval(&poly, share.index));
        }
    }

    // Best-effort wipe of secret-bearing buffers. The optimizer may elide stores into a
    // buffer that is dropped immediately afterwards, so this is not a guaranteed zeroization.
    poly.fill(0);
    random_coeffs.fill(0);
    Ok(shares)
}
/// Reconstructs the secret from at least `config.t` shares by Lagrange interpolation at `x = 0`.
///
/// Every supplied share is validated, but only the first `config.t` are used for
/// interpolation. There is no integrity protection: a corrupted share produces a wrong
/// secret, not an error.
///
/// # Errors
/// - [`ThresholdError::InvalidConfig`] if `config` fails [`ThresholdConfig::is_valid`].
/// - [`ThresholdError::InsufficientShares`] if fewer than `config.t` shares are supplied.
/// - [`ThresholdError::InvalidShareIndex`] if any share has index `0`.
/// - [`ThresholdError::DuplicateShares`] if two shares have the same index.
/// - [`ThresholdError::ShareLengthMismatch`] if a share's length differs from the first share's.
pub fn combine_shares(
    shares: &[Share],
    config: ThresholdConfig,
) -> Result<Vec<u8>, ThresholdError> {
    // Mirror `split_secret`'s validation. With t == 1, interpolation would return one
    // share's raw y-values as the "secret"; with t == 0 and no shares, `shares[0]` below
    // would panic.
    if !config.is_valid() {
        return Err(ThresholdError::InvalidConfig {
            t: config.t,
            n: config.n,
        });
    }
    let t = usize::from(config.t);
    if shares.len() < t {
        return Err(ThresholdError::InsufficientShares {
            need: t,
            got: shares.len(),
        });
    }
    // x = 0 is where the secret lives. It is never a valid share point, and it would make a
    // Lagrange denominator zero.
    if shares.iter().any(|s| s.index == 0) {
        return Err(ThresholdError::InvalidShareIndex { index: 0 });
    }
    // One slot per possible `u8` index gives linear-time duplicate detection.
    let mut seen = [false; 256];
    for share in shares {
        let slot = &mut seen[usize::from(share.index)];
        if *slot {
            return Err(ThresholdError::DuplicateShares);
        }
        *slot = true;
    }
    // `t >= 2` and `shares.len() >= t`, so `shares[0]` exists.
    let secret_len = shares[0].bytes.len();
    if let Some(bad) = shares.iter().find(|s| s.bytes.len() != secret_len) {
        return Err(ThresholdError::ShareLengthMismatch {
            expected: secret_len,
            got: bad.bytes.len(),
        });
    }

    let selected = &shares[..t];
    // The Lagrange basis values at x = 0 depend only on the share indices, so compute them
    // once instead of once per secret byte. Subtraction in GF(2^8) is XOR, so
    // l_i(0) = prod_{j != i} x_j / (x_i + x_j).
    let basis: Vec<u8> = selected
        .iter()
        .enumerate()
        .map(|(i, s_i)| {
            let (num, den) = selected
                .iter()
                .enumerate()
                .filter(|&(j, _)| j != i)
                .fold((1u8, 1u8), |(num, den), (_, s_j)| {
                    (
                        gf_mul(num, s_j.index),
                        gf_mul(den, gf_add(s_i.index, s_j.index)),
                    )
                });
            // `den` is non-zero because indices are distinct and non-zero (checked above).
            gf_mul(num, gf_inv(den))
        })
        .collect();

    let secret: Vec<u8> = (0..secret_len)
        .map(|byte_idx| {
            selected
                .iter()
                .zip(&basis)
                .fold(0u8, |acc, (share, &l)| {
                    gf_add(acc, gf_mul(share.bytes[byte_idx], l))
                })
        })
        .collect();
    Ok(secret)
}
#[cfg(test)]
// Panicking on an unexpected `Err` is the intended failure mode inside unit tests.
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn default_config_is_2_of_3() {
        let c = ThresholdConfig::DEFAULT;
        assert_eq!(c.t, 2);
        assert_eq!(c.n, 3);
        assert!(c.is_valid());
    }

    #[test]
    fn reject_t_greater_than_n() {
        let c = ThresholdConfig { t: 5, n: 3 };
        assert!(!c.is_valid());
    }

    #[test]
    fn reject_t_below_two() {
        let c = ThresholdConfig { t: 1, n: 3 };
        assert!(!c.is_valid());
    }

    #[test]
    fn split_invalid_config_errors() {
        let err = split_secret(b"secret", ThresholdConfig { t: 1, n: 3 }).unwrap_err();
        assert!(matches!(err, ThresholdError::InvalidConfig { .. }));
    }

    #[test]
    fn combine_insufficient_shares_errors() {
        let shares = [Share {
            index: 1,
            bytes: vec![0u8; 32],
        }];
        let err = combine_shares(&shares, ThresholdConfig::DEFAULT).unwrap_err();
        assert_eq!(err, ThresholdError::InsufficientShares { need: 2, got: 1 });
    }

    #[test]
    fn roundtrip_2_of_3() {
        let secret = b"auto-promote-token-sample-bytes!";
        assert_eq!(secret.len(), 32);
        let shares = split_secret(secret, ThresholdConfig::DEFAULT).unwrap();
        assert_eq!(shares.len(), 3);
        for (a, b) in [(0, 1), (0, 2), (1, 2)] {
            let subset = [shares[a].clone(), shares[b].clone()];
            let recovered = combine_shares(&subset, ThresholdConfig::DEFAULT).unwrap();
            assert_eq!(&recovered[..], &secret[..]);
        }
    }

    #[test]
    fn roundtrip_3_of_5() {
        let config = ThresholdConfig { t: 3, n: 5 };
        let secret: Vec<u8> = (0..64).collect();
        let shares = split_secret(&secret, config).unwrap();
        assert_eq!(shares.len(), 5);
        let subset = [shares[0].clone(), shares[2].clone(), shares[4].clone()];
        let recovered = combine_shares(&subset, config).unwrap();
        assert_eq!(recovered, secret);
    }

    #[test]
    fn roundtrip_empty_secret() {
        let shares = split_secret(b"", ThresholdConfig::DEFAULT).unwrap();
        assert_eq!(shares.len(), 3);
        assert!(shares.iter().all(|s| s.bytes.is_empty()));
        let recovered = combine_shares(&shares[..2], ThresholdConfig::DEFAULT).unwrap();
        assert!(recovered.is_empty());
    }

    #[test]
    fn combine_ignores_shares_beyond_threshold() {
        let secret = b"extra shares are fine";
        let shares = split_secret(secret, ThresholdConfig::DEFAULT).unwrap();
        let recovered = combine_shares(&shares, ThresholdConfig::DEFAULT).unwrap();
        assert_eq!(&recovered[..], &secret[..]);
    }

    #[test]
    fn share_indices_are_one_through_n() {
        let shares = split_secret(b"hi", ThresholdConfig { t: 2, n: 4 }).unwrap();
        let indices: Vec<u8> = shares.iter().map(|s| s.index).collect();
        assert_eq!(indices, vec![1, 2, 3, 4]);
    }

    #[test]
    fn combine_with_fewer_than_t_shares_errors() {
        let shares = split_secret(b"topsecret", ThresholdConfig::DEFAULT).unwrap();
        let err = combine_shares(&shares[..1], ThresholdConfig::DEFAULT).unwrap_err();
        assert!(matches!(err, ThresholdError::InsufficientShares { .. }));
    }

    #[test]
    fn reject_duplicate_share_indices() {
        let shares = split_secret(b"hi", ThresholdConfig::DEFAULT).unwrap();
        let forged = [shares[0].clone(), shares[0].clone()];
        let err = combine_shares(&forged, ThresholdConfig::DEFAULT).unwrap_err();
        assert_eq!(err, ThresholdError::DuplicateShares);
    }

    #[test]
    fn reject_zero_index_share() {
        let mut shares = split_secret(b"hi", ThresholdConfig::DEFAULT).unwrap();
        shares[0].index = 0;
        let err = combine_shares(&shares, ThresholdConfig::DEFAULT).unwrap_err();
        assert_eq!(err, ThresholdError::InvalidShareIndex { index: 0 });
    }

    #[test]
    fn reject_mismatched_share_length() {
        let shares = split_secret(b"hello", ThresholdConfig::DEFAULT).unwrap();
        let mut bad = shares[1].clone();
        bad.bytes.push(0xFF);
        let combined = [shares[0].clone(), bad];
        let err = combine_shares(&combined, ThresholdConfig::DEFAULT).unwrap_err();
        assert_eq!(
            err,
            ThresholdError::ShareLengthMismatch {
                expected: 5,
                got: 6
            }
        );
    }

    #[test]
    fn tampered_share_yields_wrong_secret_silently() {
        let secret = b"integrity-marker";
        let mut shares = split_secret(secret, ThresholdConfig::DEFAULT).unwrap();
        shares[1].bytes[0] ^= 0xFF;
        let recovered = combine_shares(&shares[..2], ThresholdConfig::DEFAULT).unwrap();
        assert_ne!(&recovered[..], &secret[..]);
    }

    #[test]
    fn gf_mul_matches_aes_vectors() {
        assert_eq!(gf_mul(0x02, 0xD4), 0xB3);
        assert_eq!(gf_mul(0x03, 0xBF), 0xDA);
        assert_eq!(gf_mul(0x00, 0xFF), 0x00);
        assert_eq!(gf_mul(0x01, 0x57), 0x57);
        // FIPS-197 section 4.2 worked examples.
        assert_eq!(gf_mul(0x57, 0x83), 0xC1);
        assert_eq!(gf_mul(0x57, 0x13), 0xFE);
    }

    #[test]
    fn gf_mul_is_commutative() {
        for a in 0..=255u8 {
            for b in 0..=255u8 {
                assert_eq!(gf_mul(a, b), gf_mul(b, a), "a={a} b={b}");
            }
        }
    }

    #[test]
    fn gf_inv_is_multiplicative_inverse() {
        for x in 1..=255u8 {
            let inv = gf_inv(x);
            assert_ne!(inv, 0, "inverse of {x} was zero");
            assert_eq!(gf_mul(x, inv), 1, "x={x}");
        }
    }

    #[test]
    fn gf_inv_of_zero_is_zero() {
        assert_eq!(gf_inv(0), 0);
    }
}