use rand::RngCore;
use serde::{Deserialize, Serialize};
use zeroize::{Zeroize, ZeroizeOnDrop};
use crate::{
key::{MasterKey, KEY_SIZE},
Result, VaultError,
};
/// Arithmetic in GF(2^8) with the reduction polynomial x^8 + x^4 + x^3 + x + 1
/// (0x11B, the AES field).
///
/// These routines operate on secret key bytes and share values. They are written
/// without data-dependent branches or table lookups so that their timing does not
/// depend on operand values. This is a source-level property only: it does not
/// guarantee how the compiler lowers the code.
mod gf256 {
    /// Multiplies `a` and `b` in GF(2^8) (shift-and-add with modular reduction).
    pub const fn mul(mut a: u8, mut b: u8) -> u8 {
        let mut result: u8 = 0;
        let mut i = 0;
        // Fixed 8 iterations, one per bit of `b`, independent of the operand values.
        while i < 8 {
            // 0xFF when the low bit of `b` is set, 0x00 otherwise: add `a` without branching.
            result ^= a & (b & 1).wrapping_neg();
            // 0xFF when `a`'s high bit is set: reduce by 0x1B (low byte of 0x11B) after the shift.
            let reduce = (a >> 7).wrapping_neg();
            a = (a << 1) ^ (0x1B & reduce);
            b >>= 1;
            i += 1;
        }
        result
    }

    /// Returns the multiplicative inverse of `a`, computed as `a^254`.
    ///
    /// The multiplicative group has order 255, so `a^254 = a^-1` for nonzero `a`.
    /// For `a == 0` the same power is 0, so `inv(0) == 0` without a special-case branch.
    pub const fn inv(a: u8) -> u8 {
        let a2 = mul(a, a);
        let a4 = mul(a2, a2);
        let a8 = mul(a4, a4);
        let a16 = mul(a8, a8);
        let a32 = mul(a16, a16);
        let a64 = mul(a32, a32);
        let a128 = mul(a64, a64);
        // 254 = 128 + 64 + 32 + 16 + 8 + 4 + 2
        mul(mul(mul(mul(mul(mul(a128, a64), a32), a16), a8), a4), a2)
    }

    /// Evaluates the polynomial at `x` using Horner's rule.
    ///
    /// `coeffs[i]` is the coefficient of `x^i`, so `coeffs[0]` is the constant term.
    /// An empty slice evaluates to 0.
    pub fn eval_poly(coeffs: &[u8], x: u8) -> u8 {
        coeffs
            .iter()
            .rev()
            .fold(0u8, |acc, &coeff| mul(acc, x) ^ coeff)
    }
}
/// Parameters for splitting a master key with [`split_master_key`].
///
/// Any `threshold` of the `total_shares` produced shares are enough to reconstruct
/// the key. The fields are public and deserialized without checks. They are
/// validated only when the config is passed to `split_master_key`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShamirConfig {
    /// Number of shares to produce; their x-coordinates run from 1 to this value.
    pub total_shares: u8,
    /// Minimum number of shares required to reconstruct (polynomial degree + 1).
    pub threshold: u8,
}
impl ShamirConfig {
    /// Checks that this configuration describes a usable threshold scheme.
    ///
    /// # Errors
    ///
    /// Returns [`VaultError::ShamirError`] when `total_shares < 2`,
    /// `threshold < 2`, or `threshold > total_shares`.
    fn validate(&self) -> Result<()> {
        // Checked first on purpose. If it ran after the other two checks it could
        // never fire, because `2 <= threshold <= total_shares` already implies it.
        if self.total_shares < 2 {
            return Err(VaultError::ShamirError(
                "total shares must be at least 2".to_string(),
            ));
        }
        if self.threshold < 2 {
            return Err(VaultError::ShamirError(
                "threshold must be at least 2".to_string(),
            ));
        }
        if self.threshold > self.total_shares {
            return Err(VaultError::ShamirError(
                "threshold cannot exceed total shares".to_string(),
            ));
        }
        Ok(())
    }
}
/// One Shamir share of a master key, as produced by [`split_master_key`].
///
/// `data` holds `KEY_SIZE + 1` bytes: byte 0 is the share's x-coordinate and the
/// rest are one y-value per key byte. `index` duplicates `data[0]`, but
/// [`reconstruct_master_key`] reads only `data[0]`.
///
/// Shares are secret material. They are wiped on drop, and their `Debug`
/// output redacts `data`.
#[derive(Clone, Zeroize, ZeroizeOnDrop)]
pub struct KeyShare {
    /// Share x-coordinate; equals `data[0]` for shares built by `split_master_key`.
    pub index: u8,
    /// x-coordinate followed by one y-value per key byte.
    pub data: Vec<u8>,
}

// Hand-written instead of derived so that logging or panicking with a share in
// scope never prints its y-values. A threshold's worth of leaked shares reveals
// the key.
impl std::fmt::Debug for KeyShare {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("KeyShare")
            .field("index", &self.index)
            .field("data", &format_args!("<redacted; {} bytes>", self.data.len()))
            .finish()
    }
}
/// Splits `key` into `config.total_shares` Shamir shares. Any `config.threshold`
/// of them rebuild the key via [`reconstruct_master_key`].
///
/// Each key byte is shared independently. It becomes the constant term of a fresh
/// polynomial of degree `threshold - 1` over GF(2^8) whose other coefficients are
/// random. Share `x` (for `x` in `1..=total_shares`) stores that polynomial
/// evaluated at `x`.
///
/// Every returned share's `data` is `KEY_SIZE + 1` bytes: the x-coordinate
/// followed by one y-value per key byte. `index` repeats the x-coordinate.
///
/// # Errors
///
/// Returns [`VaultError::ShamirError`] if `config` fails validation.
pub fn split_master_key(key: &MasterKey, config: &ShamirConfig) -> Result<Vec<KeyShare>> {
    config.validate()?;

    let secret = key.as_bytes();
    let mut rng = rand::rng();

    // One row per share: byte 0 is the x-coordinate, bytes 1.. are y-values.
    let mut rows: Vec<Vec<u8>> = Vec::with_capacity(usize::from(config.total_shares));
    for x in 1..=config.total_shares {
        let mut row = vec![0u8; KEY_SIZE + 1];
        row[0] = x;
        rows.push(row);
    }

    // poly[0] carries the secret byte; poly[1..] is refreshed with random bytes
    // for every key byte (threshold >= 2, so that range is never empty).
    let mut poly = vec![0u8; usize::from(config.threshold)];
    for (pos, &secret_byte) in secret[..KEY_SIZE].iter().enumerate() {
        poly[0] = secret_byte;
        rng.fill_bytes(&mut poly[1..]);
        for row in rows.iter_mut() {
            row[pos + 1] = gf256::eval_poly(&poly, row[0]);
        }
    }
    // The coefficient buffer still contains the last key byte; scrub it.
    poly.zeroize();

    let shares = rows
        .into_iter()
        .map(|data| KeyShare {
            index: data[0],
            data,
        })
        .collect();
    Ok(shares)
}
/// Rebuilds the master key from Shamir shares produced by [`split_master_key`].
///
/// Each key byte is recovered by Lagrange interpolation at x = 0 over GF(2^8).
/// The x-coordinate of every share is read from `data[0]`; `index` is not consulted.
///
/// The shares themselves carry no threshold, so a short set of shares cannot be
/// told apart from a complete one. If fewer shares than the split threshold are
/// supplied, this returns `Ok` with a key that differs from the original.
/// Callers that must detect that case need to verify the result by other means.
///
/// # Errors
///
/// Returns [`VaultError::ShamirError`] when fewer than two shares are given,
/// a share's `data` is not `KEY_SIZE + 1` bytes, an x-coordinate is zero,
/// or two shares have the same x-coordinate.
pub fn reconstruct_master_key(shares: &[KeyShare]) -> Result<MasterKey> {
    if shares.len() < 2 {
        return Err(VaultError::ShamirError(
            "need at least 2 shares to reconstruct".to_string(),
        ));
    }
    let expected_len = KEY_SIZE + 1;
    for share in shares {
        if share.data.len() != expected_len {
            return Err(VaultError::ShamirError(format!(
                "invalid share length: {} (expected {expected_len})",
                share.data.len()
            )));
        }
        // x = 0 is where the secret lives; a share there would be the secret itself.
        if share.data[0] == 0 {
            return Err(VaultError::ShamirError(
                "share x-coordinate cannot be zero".to_string(),
            ));
        }
    }

    let x_coords: Vec<u8> = shares.iter().map(|s| s.data[0]).collect();

    // Equal coordinates would make a Lagrange denominator (x_j - x_k) zero.
    let mut sorted = x_coords.clone();
    sorted.sort_unstable();
    if sorted.windows(2).any(|w| w[0] == w[1]) {
        return Err(VaultError::ShamirError(
            "duplicate share x-coordinates".to_string(),
        ));
    }

    // basis[j] = L_j(0) = prod_{k != j} x_k / (x_j - x_k); subtraction in GF(2^8) is XOR.
    let basis: Vec<u8> = x_coords
        .iter()
        .enumerate()
        .map(|(j, &xj)| {
            let mut num = 1u8;
            let mut den = 1u8;
            for (k, &xk) in x_coords.iter().enumerate() {
                if k != j {
                    num = gf256::mul(num, xk);
                    den = gf256::mul(den, xj ^ xk);
                }
            }
            // `den` is nonzero because the coordinates were checked to be distinct.
            gf256::mul(num, gf256::inv(den))
        })
        .collect();

    let mut key_bytes = [0u8; KEY_SIZE];
    for (byte_idx, key_byte) in key_bytes.iter_mut().enumerate() {
        *key_byte = shares
            .iter()
            .zip(&basis)
            .fold(0u8, |acc, (share, &b)| {
                acc ^ gf256::mul(share.data[byte_idx + 1], b)
            });
    }

    // `[u8; KEY_SIZE]` is `Copy`: `from_bytes` receives a copy, and this stack
    // buffer would otherwise keep the plaintext key until it is overwritten.
    let key = MasterKey::from_bytes(key_bytes);
    key_bytes.zeroize();
    Ok(key)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed, recognizable key used by most tests.
    fn test_key() -> MasterKey {
        MasterKey::from_bytes([42u8; KEY_SIZE])
    }

    #[test]
    fn test_gf256_mul_identity() {
        for a in 0..=255u8 {
            assert_eq!(gf256::mul(a, 1), a);
            assert_eq!(gf256::mul(1, a), a);
        }
    }

    #[test]
    fn test_gf256_mul_zero() {
        for a in 0..=255u8 {
            assert_eq!(gf256::mul(a, 0), 0);
            assert_eq!(gf256::mul(0, a), 0);
        }
    }

    #[test]
    fn test_gf256_mul_commutative() {
        for a in 0..=255u8 {
            for b in 0..=255u8 {
                assert_eq!(gf256::mul(a, b), gf256::mul(b, a));
            }
        }
    }

    /// Known-answer values from FIPS-197 section 4.2 (AES field, polynomial 0x11B).
    #[test]
    fn test_gf256_mul_known_answers() {
        assert_eq!(gf256::mul(0x57, 0x83), 0xC1);
        assert_eq!(gf256::mul(0x57, 0x13), 0xFE);
        assert_eq!(gf256::mul(0x53, 0xCA), 0x01);
        assert_eq!(gf256::inv(0x53), 0xCA);
    }

    #[test]
    fn test_gf256_inv_exhaustive() {
        for a in 1..=255u8 {
            let ai = gf256::inv(a);
            assert_ne!(ai, 0, "inverse of {a} should be nonzero");
            assert_eq!(
                gf256::mul(a, ai),
                1,
                "a={a}, inv={ai}, a*inv={}, expected 1",
                gf256::mul(a, ai)
            );
        }
    }

    #[test]
    fn test_gf256_inv_zero() {
        assert_eq!(gf256::inv(0), 0);
    }

    #[test]
    fn test_gf256_eval_poly_constant() {
        assert_eq!(gf256::eval_poly(&[42], 0), 42);
        assert_eq!(gf256::eval_poly(&[42], 1), 42);
        assert_eq!(gf256::eval_poly(&[42], 255), 42);
    }

    #[test]
    fn test_gf256_eval_poly_linear() {
        let coeffs = [5u8, 3u8];
        assert_eq!(gf256::eval_poly(&coeffs, 0), 5);
        assert_eq!(gf256::eval_poly(&coeffs, 1), 5 ^ 3);
    }

    #[test]
    fn test_split_and_reconstruct_roundtrip() {
        let key = test_key();
        let config = ShamirConfig {
            total_shares: 5,
            threshold: 3,
        };
        let shares = split_master_key(&key, &config).unwrap();
        assert_eq!(shares.len(), 5);
        let reconstructed = reconstruct_master_key(&shares[..3]).unwrap();
        assert_eq!(reconstructed.as_bytes(), key.as_bytes());
    }

    #[test]
    fn test_reconstruct_with_all_shares() {
        let key = test_key();
        let config = ShamirConfig {
            total_shares: 5,
            threshold: 3,
        };
        let shares = split_master_key(&key, &config).unwrap();
        let reconstructed = reconstruct_master_key(&shares).unwrap();
        assert_eq!(reconstructed.as_bytes(), key.as_bytes());
    }

    #[test]
    fn test_reconstruct_with_different_subset() {
        let key = test_key();
        let config = ShamirConfig {
            total_shares: 5,
            threshold: 3,
        };
        let shares = split_master_key(&key, &config).unwrap();
        let subset = vec![shares[1].clone(), shares[3].clone(), shares[4].clone()];
        let reconstructed = reconstruct_master_key(&subset).unwrap();
        assert_eq!(reconstructed.as_bytes(), key.as_bytes());
    }

    #[test]
    fn test_insufficient_shares_fails() {
        let result = reconstruct_master_key(&[KeyShare {
            index: 1,
            data: vec![1; KEY_SIZE + 1],
        }]);
        assert!(result.is_err());
        assert!(matches!(result, Err(VaultError::ShamirError(_))));
    }

    #[test]
    fn test_threshold_too_low() {
        let key = test_key();
        let config = ShamirConfig {
            total_shares: 5,
            threshold: 1,
        };
        assert!(split_master_key(&key, &config).is_err());
    }

    #[test]
    fn test_threshold_exceeds_total() {
        let key = test_key();
        let config = ShamirConfig {
            total_shares: 3,
            threshold: 5,
        };
        assert!(split_master_key(&key, &config).is_err());
    }

    #[test]
    fn test_total_shares_too_low() {
        let key = test_key();
        let config = ShamirConfig {
            total_shares: 1,
            threshold: 1,
        };
        assert!(split_master_key(&key, &config).is_err());
    }

    #[test]
    fn test_minimum_config() {
        let key = test_key();
        let config = ShamirConfig {
            total_shares: 2,
            threshold: 2,
        };
        let shares = split_master_key(&key, &config).unwrap();
        assert_eq!(shares.len(), 2);
        let reconstructed = reconstruct_master_key(&shares).unwrap();
        assert_eq!(reconstructed.as_bytes(), key.as_bytes());
    }

    #[test]
    fn test_shares_have_unique_indices() {
        let key = test_key();
        let config = ShamirConfig {
            total_shares: 5,
            threshold: 3,
        };
        let shares = split_master_key(&key, &config).unwrap();
        let mut indices: Vec<u8> = shares.iter().map(|s| s.index).collect();
        indices.sort_unstable();
        indices.dedup();
        assert_eq!(indices.len(), 5);
    }

    #[test]
    fn test_different_keys_different_shares() {
        let key1 = MasterKey::from_bytes([1u8; KEY_SIZE]);
        let key2 = MasterKey::from_bytes([2u8; KEY_SIZE]);
        let config = ShamirConfig {
            total_shares: 3,
            threshold: 2,
        };
        let shares1 = split_master_key(&key1, &config).unwrap();
        let shares2 = split_master_key(&key2, &config).unwrap();
        let r1 = reconstruct_master_key(&shares1[..2]).unwrap();
        let r2 = reconstruct_master_key(&shares2[..2]).unwrap();
        assert_ne!(r1.as_bytes(), r2.as_bytes());
    }

    /// Exercises the derived `Zeroize` impl: wiping a share clears its
    /// x-coordinate and empties its data buffer.
    #[test]
    fn test_share_zeroize_clears_contents() {
        let key = test_key();
        let config = ShamirConfig {
            total_shares: 3,
            threshold: 2,
        };
        let shares = split_master_key(&key, &config).unwrap();
        let mut share = shares[0].clone();
        assert!(!share.data.is_empty());
        assert_ne!(share.index, 0);
        share.zeroize();
        assert_eq!(share.index, 0);
        assert!(share.data.is_empty());
    }

    #[test]
    fn test_large_threshold() {
        let key = test_key();
        let config = ShamirConfig {
            total_shares: 10,
            threshold: 10,
        };
        let shares = split_master_key(&key, &config).unwrap();
        assert_eq!(shares.len(), 10);
        let reconstructed = reconstruct_master_key(&shares).unwrap();
        assert_eq!(reconstructed.as_bytes(), key.as_bytes());
    }

    #[test]
    fn test_empty_shares_fails() {
        let result = reconstruct_master_key(&[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_duplicate_shares_rejected() {
        let key = test_key();
        let config = ShamirConfig {
            total_shares: 3,
            threshold: 2,
        };
        let shares = split_master_key(&key, &config).unwrap();
        let dupes = vec![shares[0].clone(), shares[0].clone()];
        match reconstruct_master_key(&dupes) {
            Err(e) => assert!(e.to_string().contains("duplicate")),
            Ok(_) => panic!("expected error for duplicate shares"),
        }
    }

    #[test]
    fn test_wrong_share_length_rejected() {
        let short = vec![
            KeyShare {
                index: 1,
                data: vec![1, 2, 3],
            },
            KeyShare {
                index: 2,
                data: vec![2, 3, 4],
            },
        ];
        match reconstruct_master_key(&short) {
            Err(e) => assert!(e.to_string().contains("invalid share length")),
            Ok(_) => panic!("expected error for wrong length"),
        }
    }

    #[test]
    fn test_zero_x_coordinate_rejected() {
        let bad = vec![
            KeyShare {
                index: 0,
                data: vec![0; KEY_SIZE + 1],
            },
            KeyShare {
                index: 1,
                data: vec![1; KEY_SIZE + 1],
            },
        ];
        match reconstruct_master_key(&bad) {
            Err(e) => assert!(e.to_string().contains("x-coordinate cannot be zero")),
            Ok(_) => panic!("expected error for zero x-coordinate"),
        }
    }

    #[test]
    fn test_share_data_format() {
        let key = test_key();
        let config = ShamirConfig {
            total_shares: 3,
            threshold: 2,
        };
        let shares = split_master_key(&key, &config).unwrap();
        for (i, share) in shares.iter().enumerate() {
            assert_eq!(share.index, u8::try_from(i + 1).unwrap());
            assert_eq!(share.data[0], share.index);
            assert_eq!(share.data.len(), KEY_SIZE + 1);
        }
    }

    #[test]
    fn test_random_key_roundtrip() {
        let mut key_bytes = [0u8; KEY_SIZE];
        rand::rng().fill_bytes(&mut key_bytes);
        let key = MasterKey::from_bytes(key_bytes);
        let config = ShamirConfig {
            total_shares: 5,
            threshold: 3,
        };
        let shares = split_master_key(&key, &config).unwrap();
        for combo in &[
            vec![0, 1, 2],
            vec![0, 2, 4],
            vec![1, 3, 4],
            vec![0, 1, 2, 3, 4],
        ] {
            let subset: Vec<_> = combo.iter().map(|&i| shares[i].clone()).collect();
            let r = reconstruct_master_key(&subset).unwrap();
            assert_eq!(r.as_bytes(), key.as_bytes());
        }
    }
}